HowTos with Spark DataFrames
Some examples of how to do analysis using Spark DataFrames.
Based on information I found on the web; some example sites:
For the pyspark.sql module functions, look here: - https://spark.apache.org/docs/2.1.0/api/python/pyspark.sql.html
!which python
/cvmfs/sft.cern.ch/lcg/views/LCG_95apython3_nxcals/x86_64-centos7-gcc7-opt/bin/python
from cern.nxcals.api.extraction.data.builders import *
from pyspark.sql.functions import rank, col, count
from pyspark.sql.window import Window
from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import DoubleType
from pyspark.sql.types import ArrayType
from pyspark.sql.types import IntegerType
from pyspark.sql.types import StringType
import pyspark.sql.functions as func
import time
import numpy as np
import pandas as pd
import os
from matplotlib import pyplot as plt
# sc - Spark Context
# spark - Spark Session in Memory
spark
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
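Arrow mainly speeds up conversions between Spark and pandas (toPandas and createDataFrame from a pandas DataFrame). A minimal sketch of a sanity check, assuming the session is available as spark (the toy data and the variable name pdf are just for illustration):
# quick check that the Arrow-backed conversion path works
pdf = spark.range(1000).toPandas()
print(len(pdf), type(pdf))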
Create a DataFrame
columns = ["id", "dt"]
# 50 random rows: id = int(x) for x uniform in [1, 10), dt = 100*x
data = [ (int(i), float(i*100)) for i in np.random.uniform(1, 10, 50)]
sparkdf = spark.createDataFrame(data).toDF(*columns)
sparkdf.show()
+---+------------------+
| id| dt|
+---+------------------+
| 9| 941.8761129079987|
| 3|390.16172459144695|
| 2| 242.1360392045945|
| 2|216.62042402269836|
| 7| 774.128580355669|
| 6| 637.5278973819181|
| 9| 977.6018401145111|
| 3|388.24187511910503|
| 9| 965.4869912727394|
| 8| 820.6383035808919|
| 9| 992.4769742044103|
| 2|237.41710242214492|
| 7| 791.88156236064|
| 3|352.29105395671434|
| 3|360.15990622238814|
| 2| 223.3502382932214|
| 3| 362.1403963446267|
| 3|399.03128048047455|
| 4|474.08328044389805|
| 8| 885.475638549277|
+---+------------------+
only showing top 20 rows
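The schema can also be given explicitly, using the types imported above (a sketch on the same toy data; sparkdf2 is a hypothetical name):
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("dt", DoubleType(), True)
])
sparkdf2 = spark.createDataFrame(data, schema=schema)
sparkdf2.printSchema()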
Group and aggregate data
A typical operation is to group the data based on the information in a column and apply some functions to the grouped data. Here I use the pyspark Window module, which in combination with other built-in Spark functions offers a powerful way to do the data analysis.
Group and select
- In this example I group the data by "id", order by "dt" and add a "count" column
w = Window.partitionBy("id").orderBy("dt")
sparkdf.select("id","dt",count("dt").over(w).alias("count")).show()
+---+------------------+-----+
| id| dt|count|
+---+------------------+-----+
| 7| 709.5887233317628| 1|
| 7| 739.6305448935279| 2|
| 7| 759.92889335773| 3|
| 7| 774.128580355669| 4|
| 7| 791.88156236064| 5|
| 7| 794.2388948556547| 6|
| 7| 799.6724743949059| 7|
| 6| 607.3996650878397| 1|
| 6| 637.5278973819181| 2|
| 9| 904.2611870491163| 1|
| 9| 922.228736608961| 2|
| 9| 941.8761129079987| 3|
| 9| 965.4869912727394| 4|
| 9| 972.3761922794508| 5|
| 9| 977.6018401145111| 6|
| 9| 992.4769742044103| 7|
| 9| 997.7288430877943| 8|
| 5|503.51719225583287| 1|
| 5| 518.5626967253762| 2|
| 5| 562.0211027479502| 3|
+---+------------------+-----+
only showing top 20 rows
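The same window can be reused with other aggregates, for example a running sum of "dt" within each group (a sketch using func.sum; the column name "cumsum" is just for illustration):
w = Window.partitionBy("id").orderBy("dt")
# cumulative sum of dt within each id group, in ascending dt order
sparkdf.select("id", "dt", func.sum("dt").over(w).alias("cumsum")).show()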
- the ordering can also be descending
w = Window.partitionBy("id").orderBy(sparkdf["dt"].desc())
sparkdf.select("id","dt",count("dt").over(w).alias("count")).show()
+---+------------------+-----+
| id| dt|count|
+---+------------------+-----+
| 7| 799.6724743949059| 1|
| 7| 794.2388948556547| 2|
| 7| 791.88156236064| 3|
| 7| 774.128580355669| 4|
| 7| 759.92889335773| 5|
| 7| 739.6305448935279| 6|
| 7| 709.5887233317628| 7|
| 6| 637.5278973819181| 1|
| 6| 607.3996650878397| 2|
| 9| 997.7288430877943| 1|
| 9| 992.4769742044103| 2|
| 9| 977.6018401145111| 3|
| 9| 972.3761922794508| 4|
| 9| 965.4869912727394| 5|
| 9| 941.8761129079987| 6|
| 9| 922.228736608961| 7|
| 9| 904.2611870491163| 8|
| 5| 562.0211027479502| 1|
| 5| 518.5626967253762| 2|
| 5|503.51719225583287| 3|
+---+------------------+-----+
only showing top 20 rows
- select the first element in each group (with the descending window, this is the row with the largest "dt")
sparkdf.select("*", rank().over(w).alias('rank')).show()
+---+------------------+----+
| id| dt|rank|
+---+------------------+----+
| 7| 799.6724743949059| 1|
| 7| 794.2388948556547| 2|
| 7| 791.88156236064| 3|
| 7| 774.128580355669| 4|
| 7| 759.92889335773| 5|
| 7| 739.6305448935279| 6|
| 7| 709.5887233317628| 7|
| 6| 637.5278973819181| 1|
| 6| 607.3996650878397| 2|
| 9| 997.7288430877943| 1|
| 9| 992.4769742044103| 2|
| 9| 977.6018401145111| 3|
| 9| 972.3761922794508| 4|
| 9| 965.4869912727394| 5|
| 9| 941.8761129079987| 6|
| 9| 922.228736608961| 7|
| 9| 904.2611870491163| 8|
| 5| 562.0211027479502| 1|
| 5| 518.5626967253762| 2|
| 5|503.51719225583287| 3|
+---+------------------+----+
only showing top 20 rows
sparkdf.select("*", rank().over(w).alias('rank')).filter(col('rank') == 1).show()
+---+------------------+----+
| id| dt|rank|
+---+------------------+----+
| 7| 799.6724743949059| 1|
| 6| 637.5278973819181| 1|
| 9| 997.7288430877943| 1|
| 5| 562.0211027479502| 1|
| 1| 183.3569183212222| 1|
| 3|399.03128048047455| 1|
| 8| 885.475638549277| 1|
| 2| 274.5040163491951| 1|
| 4|474.08328044389805| 1|
+---+------------------+----+
- select the first element in the group (alternative method)
_tmp = sparkdf.select("id","dt",count("dt").over(w).alias("count"))
_tmp.filter(_tmp["count"] == 1).show()
+---+------------------+-----+
| id| dt|count|
+---+------------------+-----+
| 7| 799.6724743949059| 1|
| 6| 637.5278973819181| 1|
| 9| 997.7288430877943| 1|
| 5| 562.0211027479502| 1|
| 1| 183.3569183212222| 1|
| 3|399.03128048047455| 1|
| 8| 885.475638549277| 1|
| 2| 274.5040163491951| 1|
| 4|474.08328044389805| 1|
+---+------------------+-----+
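Note that rank() can return more than one row per group when there are ties in "dt"; if exactly one row per group is required, row_number() is the safer choice. A sketch (the column name "rn" is just for illustration):
from pyspark.sql.functions import row_number
# row_number() assigns a unique, gap-free index within each group
sparkdf.select("*", row_number().over(w).alias("rn")).filter(col("rn") == 1).show()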
Select (Filter) data
Select rows based on values in columns
_tmp.filter((_tmp["count"] == 1) | (_tmp["count"] == 3)).show()
+---+------------------+-----+
| id| dt|count|
+---+------------------+-----+
| 7| 799.6724743949059| 1|
| 7| 791.88156236064| 3|
| 6| 637.5278973819181| 1|
| 9| 997.7288430877943| 1|
| 9| 977.6018401145111| 3|
| 5| 562.0211027479502| 1|
| 5|503.51719225583287| 3|
| 1| 183.3569183212222| 1|
| 1|163.95075537903253| 3|
| 3|399.03128048047455| 1|
| 3|388.24187511910503| 3|
| 8| 885.475638549277| 1|
| 8| 820.6383035808919| 3|
| 2| 274.5040163491951| 1|
| 2|237.41710242214492| 3|
| 4|474.08328044389805| 1|
| 4| 418.3061092636971| 3|
+---+------------------+-----+
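The same selection can be written more compactly with isin() (a sketch, equivalent to the OR condition above):
# keep rows whose count is in the given list of values
_tmp.filter(_tmp["count"].isin([1, 3])).show()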
- use a filter DataFrame and a join. This opens the possibility to use any function to generate the selection data
filterdata = [1,2,3]
filter_df = spark.createDataFrame(filterdata, IntegerType())
_tmp.join(filter_df, _tmp['count'] == filter_df['value']).show()
+---+------------------+-----+-----+
| id| dt|count|value|
+---+------------------+-----+-----+
| 1| 183.3569183212222| 1| 1|
| 2| 274.5040163491951| 1| 1|
| 7| 799.6724743949059| 1| 1|
| 5| 562.0211027479502| 1| 1|
| 6| 637.5278973819181| 1| 1|
| 3|399.03128048047455| 1| 1|
| 9| 997.7288430877943| 1| 1|
| 8| 885.475638549277| 1| 1|
| 4|474.08328044389805| 1| 1|
| 1|163.95075537903253| 3| 3|
| 2|237.41710242214492| 3| 3|
| 9| 977.6018401145111| 3| 3|
| 7| 791.88156236064| 3| 3|
| 8| 820.6383035808919| 3| 3|
| 5|503.51719225583287| 3| 3|
| 3|388.24187511910503| 3| 3|
| 4| 418.3061092636971| 3| 3|
| 1| 165.0773077929639| 2| 2|
| 2| 242.1360392045945| 2| 2|
| 9| 992.4769742044103| 2| 2|
+---+------------------+-----+-----+
only showing top 20 rows
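Since the filter DataFrame is tiny, it can be marked for a broadcast join to avoid a shuffle (a sketch; broadcast() comes from pyspark.sql.functions):
from pyspark.sql.functions import broadcast
# send the small filter table to every executor before joining
_tmp.join(broadcast(filter_df), _tmp["count"] == filter_df["value"]).show()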
Group and Aggregate
Apply library or user functions to data
sparkdf.groupby('id').agg({'dt':'mean'}).show()
+---+------------------+
| id| avg(dt)|
+---+------------------+
| 7| 767.0099533642699|
| 6| 622.463781234879|
| 9| 959.2546096906227|
| 5| 528.0336639097197|
| 1|157.36084603583848|
| 3|354.68385362150997|
| 8| 849.4268890773079|
| 2| 230.7501911178383|
| 4| 446.2081817399182|
+---+------------------+
# import pre-defined functions of pyspark; for a complete list see the pyspark.sql module documentation
from pyspark.sql.functions import mean, avg, stddev, stddev_pop, stddev_samp, collect_list
sparkdf.groupBy("id").agg(
count(col("dt")).alias('count'),
avg(col("dt")).alias('avg'),
stddev(col("dt")).alias("stdev"),
stddev_pop(col("dt")).alias("stdev_pop"),
stddev_samp(col("dt")).alias("stdev_samp")
).show()
+---+-----+------------------+------------------+------------------+------------------+
| id|count| avg| stdev| stdev_pop| stdev_samp|
+---+-----+------------------+------------------+------------------+------------------+
| 7| 7| 767.0099533642699| 33.10211069956878|30.646599430556815| 33.10211069956878|
| 6| 2| 622.463781234879| 21.30387736030639|15.064116147039215| 21.30387736030639|
| 9| 8| 959.2546096906228| 33.44256371568842| 31.28265388986594| 33.44256371568842|
| 5| 3| 528.0336639097197|30.380113565827955|24.805258854694554|30.380113565827955|
| 1| 4|157.36084603583848|28.302257885932015|24.510474313675587|28.302257885932015|
| 3| 12|354.68385362150997|35.306444778914816|33.803347309835274|35.306444778914816|
| 8| 3| 849.4268890773079| 33.02275711028052|26.962968273350803| 33.02275711028052|
| 2| 8|230.75019111783837| 21.51951205456684| 20.12966030968522| 21.51951205456684|
| 4| 3| 446.2081817399182|27.888595373461367|22.770942769308004|27.888595373461367|
+---+-----+------------------+------------------+------------------+------------------+
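The order of the groups in the output above is not deterministic; an explicit orderBy gives reproducible output (a short sketch with a reduced set of aggregates):
sparkdf.groupBy("id").agg(
    count(col("dt")).alias("count"),
    avg(col("dt")).alias("avg")
).orderBy("id").show()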
Apply a user function
In the first example we apply a simple function that does a calculation on the input data. In the second example we use additional arguments to pass local data into the function.
Important: the return value of the function must be of a well-defined type from pyspark.sql.types, otherwise we get an error.
# simple UDF: sqrt(x*x) == |x|; the result is cast explicitly to a Python float (DoubleType)
sqrt_udf = func.udf(lambda x: float(np.sqrt(x*x)), DoubleType())
sparkdf.withColumn('tsval', sqrt_udf(col('dt'))).show()
+---+------------------+------------------+
| id| dt| tsval|
+---+------------------+------------------+
| 9| 941.8761129079987| 941.8761129079987|
| 3|390.16172459144695|390.16172459144695|
| 2| 242.1360392045945| 242.1360392045945|
| 2|216.62042402269836|216.62042402269836|
| 7| 774.128580355669| 774.128580355669|
| 6| 637.5278973819181| 637.5278973819181|
| 9| 977.6018401145111| 977.6018401145111|
| 3|388.24187511910503|388.24187511910503|
| 9| 965.4869912727394| 965.4869912727394|
| 8| 820.6383035808919| 820.6383035808919|
| 9| 992.4769742044103| 992.4769742044103|
| 2|237.41710242214492|237.41710242214492|
| 7| 791.88156236064| 791.88156236064|
| 3|352.29105395671434|352.29105395671434|
| 3|360.15990622238814|360.15990622238814|
| 2| 223.3502382932214| 223.3502382932214|
| 3| 362.1403963446267| 362.1403963446267|
| 3|399.03128048047455|399.03128048047455|
| 4|474.08328044389805|474.08328044389805|
| 8| 885.475638549277| 885.475638549277|
+---+------------------+------------------+
only showing top 20 rows
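For larger data, a vectorized pandas UDF (which relies on Arrow/pyarrow being available, as enabled above) is usually faster than the row-by-row UDF; a minimal sketch of the same computation (sqrt_pd_udf is a hypothetical name):
from pyspark.sql.functions import pandas_udf
# vectorized UDF: operates on whole pandas Series instead of single values
sqrt_pd_udf = pandas_udf(lambda s: np.sqrt(s * s), DoubleType())
sparkdf.withColumn("tsval", sqrt_pd_udf(col("dt"))).show()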
aa = 'ABCDEFGHIJ'
def mystr(ix, aa):
    # map id = 1..10 to the corresponding letter of aa
    return str(aa[ix-1:ix])
my_udf = func.udf(lambda x: mystr(x, aa), StringType())
sparkdf.withColumn('myudf', my_udf(col('id'))).show()
+---+------------------+-----+
| id| dt|myudf|
+---+------------------+-----+
| 9| 941.8761129079987| I|
| 3|390.16172459144695| C|
| 2| 242.1360392045945| B|
| 2|216.62042402269836| B|
| 7| 774.128580355669| G|
| 6| 637.5278973819181| F|
| 9| 977.6018401145111| I|
| 3|388.24187511910503| C|
| 9| 965.4869912727394| I|
| 8| 820.6383035808919| H|
| 9| 992.4769742044103| I|
| 2|237.41710242214492| B|
| 7| 791.88156236064| G|
| 3|352.29105395671434| C|
| 3|360.15990622238814| C|
| 2| 223.3502382932214| B|
| 3| 362.1403963446267| C|
| 3|399.03128048047455| C|
| 4|474.08328044389805| D|
| 8| 885.475638549277| H|
+---+------------------+-----+
only showing top 20 rows
_aux = sparkdf.groupBy("id").agg(
count(col("dt")).alias('count'),
avg(col("dt")).alias('avg'),
stddev(col("dt")).alias("stdev"),
stddev_pop(col("dt")).alias("stdev_pop"),
stddev_samp(col("dt")).alias("stdev_samp"),
collect_list(col("dt")).alias("elements")
)
_aux.select('id','count','elements').show()
+---+-----+--------------------+
| id|count| elements|
+---+-----+--------------------+
| 7| 7|[759.92889335773,...|
| 6| 2|[637.527897381918...|
| 9| 8|[977.601840114511...|
| 5| 3|[503.517192255832...|
| 1| 4|[163.950755379032...|
| 3| 12|[388.241875119105...|
| 8| 3|[885.475638549277...|
| 2| 8|[223.350238293221...|
| 4| 3|[446.235155512159...|
+---+-----+--------------------+
_aux.printSchema()
root
|-- id: long (nullable = true)
|-- count: long (nullable = false)
|-- avg: double (nullable = true)
|-- stdev: double (nullable = true)
|-- stdev_pop: double (nullable = true)
|-- stdev_samp: double (nullable = true)
|-- elements: array (nullable = true)
| |-- element: double (containsNull = true)
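The collected arrays can be processed further with a UDF, for example computing the median of the elements in each group (a sketch; median_udf and the column name "median" are just for illustration):
# UDF operating on the collected array of doubles
median_udf = func.udf(lambda xs: float(np.median(xs)), DoubleType())
_aux.select("id", "count", median_udf(col("elements")).alias("median")).show()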
Add a column to the DataFrame
from pyspark.sql.functions import lit
_aux.withColumn("newcol", lit(0)).show()
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
| id|count| avg| stdev| stdev_pop| stdev_samp| elements|newcol|
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
| 7| 7| 767.0099533642702|33.102110699568776|30.646599430556808|33.102110699568776|[759.92889335773,...| 0|
| 6| 2| 622.463781234879| 21.30387736030639|15.064116147039215| 21.30387736030639|[637.527897381918...| 0|
| 9| 8| 959.2546096906228|33.442563715688415|31.282653889865937|33.442563715688415|[977.601840114511...| 0|
| 5| 3| 528.0336639097197|30.380113565827955|24.805258854694554|30.380113565827955|[503.517192255832...| 0|
| 1| 4|157.36084603583848|28.302257885932015|24.510474313675587|28.302257885932015|[163.950755379032...| 0|
| 3| 12|354.68385362150997| 35.30644477891481| 33.80334730983527| 35.30644477891481|[390.161724591446...| 0|
| 8| 3| 849.4268890773079|33.022757110280516|26.962968273350796|33.022757110280516|[885.475638549277...| 0|
| 2| 8|230.75019111783834|21.519512054566842|20.129660309685224|21.519512054566842|[237.417102422144...| 0|
| 4| 3| 446.2081817399182|27.888595373461367|22.770942769308004|27.888595373461367|[446.235155512159...| 0|
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
_aux.withColumn("strcol", lit('dummy')).show()
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
| id|count| avg| stdev| stdev_pop| stdev_samp| elements|strcol|
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
| 7| 7| 767.0099533642699| 33.10211069956878|30.646599430556815| 33.10211069956878|[791.88156236064,...| dummy|
| 6| 2| 622.463781234879| 21.30387736030639|15.064116147039215| 21.30387736030639|[637.527897381918...| dummy|
| 9| 8| 959.2546096906228|33.442563715688415|31.282653889865937|33.442563715688415|[977.601840114511...| dummy|
| 5| 3| 528.0336639097197|30.380113565827955|24.805258854694554|30.380113565827955|[503.517192255832...| dummy|
| 1| 4|157.36084603583848|28.302257885932015|24.510474313675587|28.302257885932015|[163.950755379032...| dummy|
| 3| 12| 354.68385362151| 35.30644477891481| 33.80334730983527| 35.30644477891481|[390.161724591446...| dummy|
| 8| 3| 849.4268890773079| 33.02275711028052|26.962968273350803| 33.02275711028052|[820.638303580891...| dummy|
| 2| 8|230.75019111783834|21.519512054566853|20.129660309685235|21.519512054566853|[242.136039204594...| dummy|
| 4| 3| 446.2081817399182|27.888595373461367|22.770942769308004|27.888595373461367|[474.083280443898...| dummy|
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
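Besides literal values, new columns can also be derived from existing ones, for example the relative spread of "dt" per group (a sketch; "rel_spread" is a hypothetical column name):
# relative spread of dt within each group
_aux.withColumn("rel_spread", col("stdev") / col("avg")).select("id", "avg", "stdev", "rel_spread").show()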