PySpark Pivoting

Today's post covers the following:

  • Basic pivot operation
  • Pivot with multiple aggregations
  • Conditional pivoting
  • Pivoting with explicit column values

Basic Pivot Operation

from pyspark.sql import SparkSession
from pyspark.sql.functions import sum, avg, when

spark = SparkSession.builder.getOrCreate()

data = [("Alice", "HR", 50000), 
        ("Bob", "IT", 60000), 
        ("Cathy", "HR", 55000)]

df = spark.createDataFrame(data, ["name", "dept", "salary"])
df.show()
+-----+----+------+
| name|dept|salary|
+-----+----+------+
|Alice|  HR| 50000|
|  Bob|  IT| 60000|
|Cathy|  HR| 55000|
+-----+----+------+
# Pivot department to columns with salary summed per name
pivot_df = df.groupBy("name").pivot("dept").sum("salary")
pivot_df.show()
+-----+-----+-----+
| name|   HR|   IT|
+-----+-----+-----+
|  Bob| NULL|60000|
|Alice|50000| NULL|
|Cathy|55000| NULL|
+-----+-----+-----+
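
The NULL cells mark name/dept combinations that have no rows. If zeros read better, a fillna() after the pivot handles it (a small sketch using the pivot_df built above):

# Replace NULLs from missing name/dept combinations with 0
pivot_df.fillna(0).show()
+-----+-----+-----+
| name|   HR|   IT|
+-----+-----+-----+
|  Bob|    0|60000|
|Alice|50000|    0|
|Cathy|55000|    0|
+-----+-----+-----+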

Pivot with Multiple Aggregations

You can apply multiple aggregations to each pivot value by passing them to agg().

# Apply several aggregations per pivot value; each aggregation becomes its own column
multi_agg_df = df.groupBy("name").pivot("dept").agg(
    sum("salary").alias("total"),
    avg("salary").alias("avg")
)
multi_agg_df.show()
+-----+--------+-------+--------+-------+
| name|HR_total| HR_avg|IT_total| IT_avg|
+-----+--------+-------+--------+-------+
|  Bob|    NULL|   NULL|   60000|60000.0|
|Alice|   50000|50000.0|    NULL|   NULL|
|Cathy|   55000|55000.0|    NULL|   NULL|
+-----+--------+-------+--------+-------+
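
With more than one aggregation, Spark names each output column <pivot value>_<alias>, as the header above shows. These are ordinary columns, so you can select or rename them afterwards, for example:

# Keep only the totals, dropping the averages
multi_agg_df.select("name", "HR_total", "IT_total").show()
+-----+--------+--------+
| name|HR_total|IT_total|
+-----+--------+--------+
|  Bob|    NULL|   60000|
|Alice|   50000|    NULL|
|Cathy|   55000|    NULL|
+-----+--------+--------+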

Conditional Pivoting

# Use expressions with when inside your aggregation to pivot with conditions.
conditional_df = df.groupBy("name") \
    .pivot("dept") \
    .agg(sum(when(df.salary > 52000, df.salary)).alias("high_salary"))
conditional_df.show()
+-----+-----+-----+
| name|   HR|   IT|
+-----+-----+-----+
|  Bob| NULL|60000|
|Alice| NULL| NULL|
|Cathy|55000| NULL|
+-----+-----+-----+
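
Note that with a single aggregation, Spark names the columns after the pivot values alone, so the high_salary alias does not appear in the header above. The same pattern works for conditional counts: count() ignores NULLs, so wrapping the condition in when() counts only the matching rows. A sketch along the same lines (combinations with no rows at all still come back as NULL):

from pyspark.sql.functions import count

# Count rows per name/dept where salary exceeds 52000
df.groupBy("name").pivot("dept").agg(
    count(when(df.salary > 52000, True))
).show()
+-----+----+----+
| name|  HR|  IT|
+-----+----+----+
|  Bob|NULL|   1|
|Alice|   0|NULL|
|Cathy|   1|NULL|
+-----+----+----+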

Pivoting with Explicit Values

# To control which values are used as output columns, explicitly provide them to pivot()
df.groupBy("name") \
  .pivot("dept", ["HR", "IT", "Eng"]).sum("salary").show()
+-----+-----+-----+----+
| name|   HR|   IT| Eng|
+-----+-----+-----+----+
|  Bob| NULL|60000|NULL|
|Alice|50000| NULL|NULL|
|Cathy|55000| NULL|NULL|
+-----+-----+-----+----+
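
The Eng column is all NULL because no row has that department; explicit values let you surface such columns (or drop unwanted ones) and fix their order. Passing values up front has one more benefit: without them, Spark first runs a separate job just to compute the distinct values of the pivot column. If you want all values but would rather collect them yourself, one sketch:

# Compute the pivot values once, then pass them in explicitly
depts = sorted(r["dept"] for r in df.select("dept").distinct().collect())
df.groupBy("name").pivot("dept", depts).sum("salary").show()
+-----+-----+-----+
| name|   HR|   IT|
+-----+-----+-----+
|  Bob| NULL|60000|
|Alice|50000| NULL|
|Cathy|55000| NULL|
+-----+-----+-----+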