
PySpark Aggregations

Today's post covers the following:

  • Aggregations without grouping
  • Aggregations with grouping
  • Filtering after grouping

Aggregations without grouping

You can refer to columns using any of these notations: df.age, df['age'], or col('age').
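
Every snippet below assumes a running SparkSession, the aggregate functions imported from pyspark.sql.functions, and a DataFrame with Department, Employee, and Salary columns. Here is a minimal setup sketch; the sample data is made up purely for illustration:

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, count, countDistinct, sum, min, max, avg, collect_list, collect_set
)

spark = SparkSession.builder.getOrCreate()

#Made-up sample data, purely for illustration
#Note: importing sum/min/max from pyspark.sql.functions shadows Python's built-ins
data = [
    ("IT", "Alice", 5000),
    ("IT", "Bob", 4500),
    ("HR", "Carol", 4000),
]
df = spark.createDataFrame(data, ["Department", "Employee", "Salary"])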

#Count rows
df.count()

#Count distinct values in a column
df.select(countDistinct("Department")).show()
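
If you want the number itself rather than a one-row DataFrame, an equivalent approach is to de-duplicate the column first and count the result:

#Equivalent: de-duplicate the column, then count
df.select("Department").distinct().count()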

#Sum
df.select(sum("Salary")).show()

#Multiple aggregations
df.select(min("Salary"), max("Salary")).show()

Aggregations with grouping

#Group by a single column
df.groupBy("Department").sum("Salary").show()

#Group by multiple columns
df.groupBy("Department", "Employee").sum("Salary").show()

#Group by with multiple aggregations
df.groupBy("Department").agg(
                              count("Employee").alias("Employee_Count"),
                              avg("Salary").alias("Average_Salary"),
                              max("Salary").alias("Max_Salary")
)
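
agg() also accepts a dictionary mapping column names to aggregate function names. This form is handy for quick one-off aggregations, though you cannot alias the result columns inline:

#Dictionary form of agg()
df.groupBy("Department").agg({"Salary": "avg", "Employee": "count"}).show()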

Filtering after grouping

This is the equivalent of SQL's HAVING clause: the data is filtered after it has been grouped and aggregated.

#Filter after aggregation
df.groupBy("Department").agg(sum("Salary").alias("Total_Salary")).filter("Total_Salary > 8000").show()
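
The condition can also be written as a column expression instead of a SQL string; both forms are equivalent:

#Same filter, written with col() instead of a SQL string
df.groupBy("Department") \
    .agg(sum("Salary").alias("Total_Salary")) \
    .filter(col("Total_Salary") > 8000) \
    .show()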


Commonly used aggregation functions

Function          Description                           Example
count()           Counts rows in a group.               df.groupBy("Department").count()
sum()             Sums values in a group.               df.groupBy("Department").sum("Salary")
avg() / mean()    Calculates the average value.         df.groupBy("Department").avg("Salary")
min()             Finds the minimum value.              df.groupBy("Department").min("Salary")
max()             Finds the maximum value.              df.groupBy("Department").max("Salary")
countDistinct()   Counts distinct values in a group.    df.select(countDistinct("Employee"))
collect_list()    Collects all values into a list.      df.groupBy("Department").agg(collect_list("Employee"))
collect_set()     Collects unique values into a set.    df.groupBy("Department").agg(collect_set("Employee"))
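
To close, here is a quick sketch of the last two functions from the table, since they behave differently from the numeric aggregates: collect_list keeps duplicate values, while collect_set drops them. This assumes the sample df from the top of the post:

#collect_list keeps duplicates, collect_set removes them
df.groupBy("Department").agg(
    collect_list("Employee").alias("All_Employees"),
    collect_set("Employee").alias("Unique_Employees")
).show()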