PySpark Pivoting¶
Today's post covers the following:
- Basic pivot operation
- Pivot with multiple aggregations
- Conditional pivoting
- Pivoting with specified column values
Basic Pivot Operation¶
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum, avg, when
spark = SparkSession.builder.getOrCreate()
data = [("Alice", "HR", 50000),
        ("Bob", "IT", 60000),
        ("Cathy", "HR", 55000)]
df = spark.createDataFrame(data, ["name", "dept", "salary"])
df.show()
+-----+----+------+
| name|dept|salary|
+-----+----+------+
|Alice| HR| 50000|
| Bob| IT| 60000|
|Cathy| HR| 55000|
+-----+----+------+
# Pivot department to columns with salary summed per name
pivot_df = df.groupBy("name").pivot("dept").sum("salary")
pivot_df.show()
+-----+-----+-----+
| name| HR| IT|
+-----+-----+-----+
| Bob| NULL|60000|
|Alice|50000| NULL|
|Cathy|55000| NULL|
+-----+-----+-----+
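Cells with no matching name/dept combination come back as NULL. If zeros read better downstream, one common follow-up (a sketch, not part of the original example) is fillna():
# Replace the NULLs produced by the pivot with 0 (applies to all numeric columns)
pivot_df.fillna(0).show()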
Pivot with Multiple Aggregations¶
You can apply multiple aggregations to each pivoted column using agg().
# Compute both the total and the average salary for each department column
multi_agg_df = df.groupBy("name").pivot("dept").agg(
    sum("salary").alias("total"),
    avg("salary").alias("avg")
)
multi_agg_df.show()
+-----+--------+-------+--------+-------+
| name|HR_total| HR_avg|IT_total| IT_avg|
+-----+--------+-------+--------+-------+
| Bob| NULL| NULL| 60000|60000.0|
|Alice| 50000|50000.0| NULL| NULL|
|Cathy| 55000|55000.0| NULL| NULL|
+-----+--------+-------+--------+-------+
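With more than one aggregation, Spark names the pivoted columns <pivot value>_<alias>, as seen above. That convention makes it easy to slice out a single metric by suffix; a quick sketch, not from the original post:
# Keep only the *_total columns alongside name
total_cols = [c for c in multi_agg_df.columns if c.endswith("_total")]
multi_agg_df.select("name", *total_cols).show()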
Conditional Pivoting¶
# Use a when() expression inside the aggregation to pivot conditionally
conditional_df = df.groupBy("name")\
    .pivot("dept")\
    .agg(
        sum(when(df.salary > 52000, df.salary)).alias("high_salary")
    )
conditional_df.show()
+-----+-----+-----+
| name| HR| IT|
+-----+-----+-----+
| Bob| NULL|60000|
|Alice| NULL| NULL|
|Cathy|55000| NULL|
+-----+-----+-----+
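With a single aggregation the alias is dropped and the pivoted columns are named after the department values alone, which is why the output shows HR and IT rather than HR_high_salary and IT_high_salary.
Pivoting with Specified Column Values¶
By default, pivot() runs an extra pass over the data to collect the distinct values of the pivot column. Passing the values explicitly skips that scan and also pins the output columns, including values absent from the data. Here is a minimal sketch against the same df; the "Finance" value is hypothetical, included only to show that a listed value with no matching rows yields an all-NULL column:
# Supplying the pivot values up front avoids the distinct-value scan
# and fixes the column order; "Finance" has no rows, so its column stays NULL
explicit_df = df.groupBy("name").pivot("dept", ["HR", "IT", "Finance"]).sum("salary")
explicit_df.show()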