
HowTo's with Spark DataFrames

Some examples of how to do analysis with Spark DataFrames

Based on information gathered from the web; some example sites:

For the pyspark.sql module functions, see: https://spark.apache.org/docs/2.1.0/api/python/pyspark.sql.html

!which python
/cvmfs/sft.cern.ch/lcg/views/LCG_95apython3_nxcals/x86_64-centos7-gcc7-opt/bin/python
from cern.nxcals.api.extraction.data.builders import *
from pyspark.sql.functions import rank, col, count
from pyspark.sql.window import Window
from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import DoubleType
from pyspark.sql.types import ArrayType
from pyspark.sql.types import IntegerType
from pyspark.sql.types import StringType

import pyspark.sql.functions as func
import time
import numpy as np
import pandas as pd 
import os
from matplotlib import pyplot as plt

# sc - Spark Context
# spark - Spark Session in Memory
spark
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
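
The Arrow setting above only matters when converting between Spark and pandas DataFrames; it speeds up toPandas() and createDataFrame(pandas_df) considerably. Note that on Spark 2.x releases the equivalent key is spark.sql.execution.arrow.enabled (the ...arrow.pyspark.enabled spelling is the Spark 3 name). A minimal round-trip sketch (the variable names are illustrative):

# pandas -> Spark -> pandas round trip; with Arrow enabled this avoids
# row-by-row serialisation of the data
pdf = pd.DataFrame({"x": np.arange(5), "y": np.random.rand(5)})
sdf = spark.createDataFrame(pdf)   # pandas -> Spark
pdf_back = sdf.toPandas()          # Spark -> pandas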

Create a DataFrame

columns=["id", "dt"]
data = [ (int(i), float(i*100)) for i in np.random.uniform(1, 10, 50)]
sparkdf = spark.createDataFrame(data).toDF(*columns)
sparkdf.show()
+---+------------------+
| id|                dt|
+---+------------------+
|  9| 941.8761129079987|
|  3|390.16172459144695|
|  2| 242.1360392045945|
|  2|216.62042402269836|
|  7|  774.128580355669|
|  6| 637.5278973819181|
|  9| 977.6018401145111|
|  3|388.24187511910503|
|  9| 965.4869912727394|
|  8| 820.6383035808919|
|  9| 992.4769742044103|
|  2|237.41710242214492|
|  7|   791.88156236064|
|  3|352.29105395671434|
|  3|360.15990622238814|
|  2| 223.3502382932214|
|  3| 362.1403963446267|
|  3|399.03128048047455|
|  4|474.08328044389805|
|  8|  885.475638549277|
+---+------------------+
only showing top 20 rows
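
The same DataFrame can also be created with an explicit schema; this is what the StructType/StructField imports above are for. A minimal sketch (sparkdf2 is just an illustrative name):

# build the schema explicitly instead of letting Spark infer it
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("dt", DoubleType(), True)
])
sparkdf2 = spark.createDataFrame(data, schema=schema)
sparkdf2.printSchema()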

Group and aggregate data

A typical task is to group the data based on the information in a column and apply some functions to the grouped data. Here I use the pyspark Window module, which in combination with other built-in Spark functions offers a powerful way to do the data analysis.

Group and select

  • In this example I group the data by "id", order by "dt" and add a "count" column
w = Window.partitionBy("id").orderBy("dt")
sparkdf.select("id","dt",count("dt").over(w).alias("count")).show()
+---+------------------+-----+
| id|                dt|count|
+---+------------------+-----+
|  7| 709.5887233317628|    1|
|  7| 739.6305448935279|    2|
|  7|   759.92889335773|    3|
|  7|  774.128580355669|    4|
|  7|   791.88156236064|    5|
|  7| 794.2388948556547|    6|
|  7| 799.6724743949059|    7|
|  6| 607.3996650878397|    1|
|  6| 637.5278973819181|    2|
|  9| 904.2611870491163|    1|
|  9|  922.228736608961|    2|
|  9| 941.8761129079987|    3|
|  9| 965.4869912727394|    4|
|  9| 972.3761922794508|    5|
|  9| 977.6018401145111|    6|
|  9| 992.4769742044103|    7|
|  9| 997.7288430877943|    8|
|  5|503.51719225583287|    1|
|  5| 518.5626967253762|    2|
|  5| 562.0211027479502|    3|
+---+------------------+-----+
only showing top 20 rows
  • the ordering can also be descending
w = Window.partitionBy("id").orderBy(sparkdf["dt"].desc())
sparkdf.select("id","dt",count("dt").over(w).alias("count")).show()
+---+------------------+-----+
| id|                dt|count|
+---+------------------+-----+
|  7| 799.6724743949059|    1|
|  7| 794.2388948556547|    2|
|  7|   791.88156236064|    3|
|  7|  774.128580355669|    4|
|  7|   759.92889335773|    5|
|  7| 739.6305448935279|    6|
|  7| 709.5887233317628|    7|
|  6| 637.5278973819181|    1|
|  6| 607.3996650878397|    2|
|  9| 997.7288430877943|    1|
|  9| 992.4769742044103|    2|
|  9| 977.6018401145111|    3|
|  9| 972.3761922794508|    4|
|  9| 965.4869912727394|    5|
|  9| 941.8761129079987|    6|
|  9|  922.228736608961|    7|
|  9| 904.2611870491163|    8|
|  5| 562.0211027479502|    1|
|  5| 518.5626967253762|    2|
|  5|503.51719225583287|    3|
+---+------------------+-----+
only showing top 20 rows
  • select the first element in the group
sparkdf.select("*", rank().over(w).alias('rank')).show()
+---+------------------+----+
| id|                dt|rank|
+---+------------------+----+
|  7| 799.6724743949059|   1|
|  7| 794.2388948556547|   2|
|  7|   791.88156236064|   3|
|  7|  774.128580355669|   4|
|  7|   759.92889335773|   5|
|  7| 739.6305448935279|   6|
|  7| 709.5887233317628|   7|
|  6| 637.5278973819181|   1|
|  6| 607.3996650878397|   2|
|  9| 997.7288430877943|   1|
|  9| 992.4769742044103|   2|
|  9| 977.6018401145111|   3|
|  9| 972.3761922794508|   4|
|  9| 965.4869912727394|   5|
|  9| 941.8761129079987|   6|
|  9|  922.228736608961|   7|
|  9| 904.2611870491163|   8|
|  5| 562.0211027479502|   1|
|  5| 518.5626967253762|   2|
|  5|503.51719225583287|   3|
+---+------------------+----+
only showing top 20 rows
sparkdf.select("*", rank().over(w).alias('rank')).filter(col('rank') == 1).show()
+---+------------------+----+
| id|                dt|rank|
+---+------------------+----+
|  7| 799.6724743949059|   1|
|  6| 637.5278973819181|   1|
|  9| 997.7288430877943|   1|
|  5| 562.0211027479502|   1|
|  1| 183.3569183212222|   1|
|  3|399.03128048047455|   1|
|  8|  885.475638549277|   1|
|  2| 274.5040163491951|   1|
|  4|474.08328044389805|   1|
+---+------------------+----+
  • select the first element in the group (alternative method)
_tmp = sparkdf.select("id","dt",count("dt").over(w).alias("count"))
_tmp.filter(_tmp["count"] == 1).show()
+---+------------------+-----+
| id|                dt|count|
+---+------------------+-----+
|  7| 799.6724743949059|    1|
|  6| 637.5278973819181|    1|
|  9| 997.7288430877943|    1|
|  5| 562.0211027479502|    1|
|  1| 183.3569183212222|    1|
|  3|399.03128048047455|    1|
|  8|  885.475638549277|    1|
|  2| 274.5040163491951|    1|
|  4|474.08328044389805|    1|
+---+------------------+-----+
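
An alternative that avoids the extra count column is row_number(), which numbers the rows within each window partition; filtering on 1 again keeps the first element of each group. Unlike rank(), row_number() never produces ties, so exactly one row per group survives. A minimal sketch using the same (descending) window w:

from pyspark.sql.functions import row_number
sparkdf.select("*", row_number().over(w).alias("rn")).filter(col("rn") == 1).show()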

Select (Filter) data

Select rows based on values in columns

_tmp.filter((_tmp["count"] == 1) | (_tmp["count"] == 3)).show()
+---+------------------+-----+
| id|                dt|count|
+---+------------------+-----+
|  7| 799.6724743949059|    1|
|  7|   791.88156236064|    3|
|  6| 637.5278973819181|    1|
|  9| 997.7288430877943|    1|
|  9| 977.6018401145111|    3|
|  5| 562.0211027479502|    1|
|  5|503.51719225583287|    3|
|  1| 183.3569183212222|    1|
|  1|163.95075537903253|    3|
|  3|399.03128048047455|    1|
|  3|388.24187511910503|    3|
|  8|  885.475638549277|    1|
|  8| 820.6383035808919|    3|
|  2| 274.5040163491951|    1|
|  2|237.41710242214492|    3|
|  4|474.08328044389805|    1|
|  4| 418.3061092636971|    3|
+---+------------------+-----+
  • use a filter DataFrame; this opens the possibility of using any function to generate the selection data
filterdata = [1,2,3]
filter_df = spark.createDataFrame(filterdata, IntegerType())
_tmp.join(filter_df, _tmp['count'] == filter_df['value']).show()
+---+------------------+-----+-----+
| id|                dt|count|value|
+---+------------------+-----+-----+
|  1| 183.3569183212222|    1|    1|
|  2| 274.5040163491951|    1|    1|
|  7| 799.6724743949059|    1|    1|
|  5| 562.0211027479502|    1|    1|
|  6| 637.5278973819181|    1|    1|
|  3|399.03128048047455|    1|    1|
|  9| 997.7288430877943|    1|    1|
|  8|  885.475638549277|    1|    1|
|  4|474.08328044389805|    1|    1|
|  1|163.95075537903253|    3|    3|
|  2|237.41710242214492|    3|    3|
|  9| 977.6018401145111|    3|    3|
|  7|   791.88156236064|    3|    3|
|  8| 820.6383035808919|    3|    3|
|  5|503.51719225583287|    3|    3|
|  3|388.24187511910503|    3|    3|
|  4| 418.3061092636971|    3|    3|
|  1| 165.0773077929639|    2|    2|
|  2| 242.1360392045945|    2|    2|
|  9| 992.4769742044103|    2|    2|
+---+------------------+-----+-----+
only showing top 20 rows
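
When the selection values are already known on the driver, the same filter can be written without a join using Column.isin; the join approach above is mostly useful when the selection values are themselves produced by another DataFrame computation. A sketch:

# keep only rows whose count is in the local Python list
_tmp.filter(_tmp["count"].isin(filterdata)).show()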

Group and Aggregate

Apply library or user functions to data

sparkdf.groupby('id').agg({'dt':'mean'}).show()
+---+------------------+
| id|           avg(dt)|
+---+------------------+
|  7| 767.0099533642699|
|  6|  622.463781234879|
|  9| 959.2546096906227|
|  5| 528.0336639097197|
|  1|157.36084603583848|
|  3|354.68385362150997|
|  8| 849.4268890773079|
|  2| 230.7501911178383|
|  4| 446.2081817399182|
+---+------------------+
# -- import pre-defined functions of pyspark; for a complete list see the pyspark.sql module documentation

from pyspark.sql.functions import mean, avg, stddev, stddev_pop, stddev_samp, collect_list
sparkdf.groupBy("id").agg(
        count(col("dt")).alias('count'),
        avg(col("dt")).alias('avg'),
        stddev(col("dt")).alias("stdev"),
        stddev_pop(col("dt")).alias("stdev_pop"),
        stddev_samp(col("dt")).alias("stdev_samp")
        ).show()
+---+-----+------------------+------------------+------------------+------------------+
| id|count|               avg|             stdev|         stdev_pop|        stdev_samp|
+---+-----+------------------+------------------+------------------+------------------+
|  7|    7| 767.0099533642699| 33.10211069956878|30.646599430556815| 33.10211069956878|
|  6|    2|  622.463781234879| 21.30387736030639|15.064116147039215| 21.30387736030639|
|  9|    8| 959.2546096906228| 33.44256371568842| 31.28265388986594| 33.44256371568842|
|  5|    3| 528.0336639097197|30.380113565827955|24.805258854694554|30.380113565827955|
|  1|    4|157.36084603583848|28.302257885932015|24.510474313675587|28.302257885932015|
|  3|   12|354.68385362150997|35.306444778914816|33.803347309835274|35.306444778914816|
|  8|    3| 849.4268890773079| 33.02275711028052|26.962968273350803| 33.02275711028052|
|  2|    8|230.75019111783837| 21.51951205456684| 20.12966030968522| 21.51951205456684|
|  4|    3| 446.2081817399182|27.888595373461367|22.770942769308004|27.888595373461367|
+---+-----+------------------+------------------+------------------+------------------+
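
The same pattern works with any of the built-in aggregate functions, for example minimum, maximum and sum per group (a minimal sketch using the func alias imported above):

sparkdf.groupBy("id").agg(
        func.min(col("dt")).alias("min"),
        func.max(col("dt")).alias("max"),
        func.sum(col("dt")).alias("sum")
        ).show()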

Apply a user function

In the first example we apply a simple function that does a calculation on the input data. In the second example we pass additional arguments so that the function can use local (driver-side) data.

Important: the return value of the function must be of a well defined type following the PySpark types, otherwise we get an error.

sqrt_udf = func.udf(lambda x: float(np.sqrt(x*x)), DoubleType())  # note the explicit DoubleType return type
sparkdf.withColumn('tsval', sqrt_udf(col('dt'))).show()
+---+------------------+------------------+
| id|                dt|             tsval|
+---+------------------+------------------+
|  9| 941.8761129079987| 941.8761129079987|
|  3|390.16172459144695|390.16172459144695|
|  2| 242.1360392045945| 242.1360392045945|
|  2|216.62042402269836|216.62042402269836|
|  7|  774.128580355669|  774.128580355669|
|  6| 637.5278973819181| 637.5278973819181|
|  9| 977.6018401145111| 977.6018401145111|
|  3|388.24187511910503|388.24187511910503|
|  9| 965.4869912727394| 965.4869912727394|
|  8| 820.6383035808919| 820.6383035808919|
|  9| 992.4769742044103| 992.4769742044103|
|  2|237.41710242214492|237.41710242214492|
|  7|   791.88156236064|   791.88156236064|
|  3|352.29105395671434|352.29105395671434|
|  3|360.15990622238814|360.15990622238814|
|  2| 223.3502382932214| 223.3502382932214|
|  3| 362.1403963446267| 362.1403963446267|
|  3|399.03128048047455|399.03128048047455|
|  4|474.08328044389805|474.08328044389805|
|  8|  885.475638549277|  885.475638549277|
+---+------------------+------------------+
only showing top 20 rows
aa = 'ABCDEFGHIJ'  # local (driver-side) data used inside the UDF
def mystr(ix, aa):
    # map the 1-based id to the corresponding letter of aa
    return str(aa[ix-1:ix])
my_udf = func.udf(lambda x: mystr(x, aa), StringType())
sparkdf.withColumn('myudf', my_udf(col('id'))).show()
+---+------------------+-----+
| id|                dt|myudf|
+---+------------------+-----+
|  9| 941.8761129079987|    I|
|  3|390.16172459144695|    C|
|  2| 242.1360392045945|    B|
|  2|216.62042402269836|    B|
|  7|  774.128580355669|    G|
|  6| 637.5278973819181|    F|
|  9| 977.6018401145111|    I|
|  3|388.24187511910503|    C|
|  9| 965.4869912727394|    I|
|  8| 820.6383035808919|    H|
|  9| 992.4769742044103|    I|
|  2|237.41710242214492|    B|
|  7|   791.88156236064|    G|
|  3|352.29105395671434|    C|
|  3|360.15990622238814|    C|
|  2| 223.3502382932214|    B|
|  3| 362.1403963446267|    C|
|  3|399.03128048047455|    C|
|  4|474.08328044389805|    D|
|  8|  885.475638549277|    H|
+---+------------------+-----+
only showing top 20 rows
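
For larger datasets a vectorised pandas UDF (available from Spark 2.3 on, assuming pyarrow is installed) is usually much faster than a row-by-row func.udf, because it operates on whole pandas Series at once. A minimal sketch of the first example rewritten this way:

from pyspark.sql.functions import pandas_udf

@pandas_udf(DoubleType())
def sqrt_pandas_udf(x):
    # x is a pandas Series; return a Series of the same length
    return np.sqrt(x * x)

sparkdf.withColumn('tsval', sqrt_pandas_udf(col('dt'))).show()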
_aux = sparkdf.groupBy("id").agg(
        count(col("dt")).alias('count'),
        avg(col("dt")).alias('avg'),
        stddev(col("dt")).alias("stdev"),
        stddev_pop(col("dt")).alias("stdev_pop"),
        stddev_samp(col("dt")).alias("stdev_samp"),
        collect_list(col("dt")).alias("elements")
        )
_aux.select('id','count','elements').show()
+---+-----+--------------------+
| id|count|            elements|
+---+-----+--------------------+
|  7|    7|[759.92889335773,...|
|  6|    2|[637.527897381918...|
|  9|    8|[977.601840114511...|
|  5|    3|[503.517192255832...|
|  1|    4|[163.950755379032...|
|  3|   12|[388.241875119105...|
|  8|    3|[885.475638549277...|
|  2|    8|[223.350238293221...|
|  4|    3|[446.235155512159...|
+---+-----+--------------------+
_aux.printSchema()
root
 |-- id: long (nullable = true)
 |-- count: long (nullable = false)
 |-- avg: double (nullable = true)
 |-- stdev: double (nullable = true)
 |-- stdev_pop: double (nullable = true)
 |-- stdev_samp: double (nullable = true)
 |-- elements: array (nullable = true)
 |    |-- element: double (containsNull = true)
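
The array column produced by collect_list can be turned back into one row per value with explode, which is handy after a grouped computation. A minimal sketch:

from pyspark.sql.functions import explode
_aux.select('id', explode('elements').alias('dt')).show(5)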

Add a column to the DataFrame

from pyspark.sql.functions import lit

_aux.withColumn("newcol", lit(0)).show()
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
| id|count|               avg|             stdev|         stdev_pop|        stdev_samp|            elements|newcol|
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
|  7|    7| 767.0099533642702|33.102110699568776|30.646599430556808|33.102110699568776|[759.92889335773,...|     0|
|  6|    2|  622.463781234879| 21.30387736030639|15.064116147039215| 21.30387736030639|[637.527897381918...|     0|
|  9|    8| 959.2546096906228|33.442563715688415|31.282653889865937|33.442563715688415|[977.601840114511...|     0|
|  5|    3| 528.0336639097197|30.380113565827955|24.805258854694554|30.380113565827955|[503.517192255832...|     0|
|  1|    4|157.36084603583848|28.302257885932015|24.510474313675587|28.302257885932015|[163.950755379032...|     0|
|  3|   12|354.68385362150997| 35.30644477891481| 33.80334730983527| 35.30644477891481|[390.161724591446...|     0|
|  8|    3| 849.4268890773079|33.022757110280516|26.962968273350796|33.022757110280516|[885.475638549277...|     0|
|  2|    8|230.75019111783834|21.519512054566842|20.129660309685224|21.519512054566842|[237.417102422144...|     0|
|  4|    3| 446.2081817399182|27.888595373461367|22.770942769308004|27.888595373461367|[446.235155512159...|     0|
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
_aux.withColumn("strcol", lit('dummy')).show()
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
| id|count|               avg|             stdev|         stdev_pop|        stdev_samp|            elements|strcol|
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
|  7|    7| 767.0099533642699| 33.10211069956878|30.646599430556815| 33.10211069956878|[791.88156236064,...| dummy|
|  6|    2|  622.463781234879| 21.30387736030639|15.064116147039215| 21.30387736030639|[637.527897381918...| dummy|
|  9|    8| 959.2546096906228|33.442563715688415|31.282653889865937|33.442563715688415|[977.601840114511...| dummy|
|  5|    3| 528.0336639097197|30.380113565827955|24.805258854694554|30.380113565827955|[503.517192255832...| dummy|
|  1|    4|157.36084603583848|28.302257885932015|24.510474313675587|28.302257885932015|[163.950755379032...| dummy|
|  3|   12|   354.68385362151| 35.30644477891481| 33.80334730983527| 35.30644477891481|[390.161724591446...| dummy|
|  8|    3| 849.4268890773079| 33.02275711028052|26.962968273350803| 33.02275711028052|[820.638303580891...| dummy|
|  2|    8|230.75019111783834|21.519512054566853|20.129660309685235|21.519512054566853|[242.136039204594...| dummy|
|  4|    3| 446.2081817399182|27.888595373461367|22.770942769308004|27.888595373461367|[474.083280443898...| dummy|
+---+-----+------------------+------------------+------------------+------------------+--------------------+------+
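
withColumn accepts any column expression, not only a literal, so new columns can also be derived from existing ones, for example the relative spread per group (a minimal sketch; the column name is illustrative):

_aux.withColumn("rel_spread", col("stdev") / col("avg")).select("id", "avg", "stdev", "rel_spread").show()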