Half the Spark jobs in the real world end with the same shape of question: “for each X, what’s the total / average / count of Y?” Revenue per country. Orders per customer. Login attempts per hour. Click-through per campaign. The answer always involves groupBy and agg, and once you know the pattern you’ll write fifty of these a week without thinking.
This lesson is the catalog. We’ll cover the everyday aggregate functions, the single-pass idiom that scans your data once instead of N times, and the one detail that sets aggregation apart from everything we’ve done so far: it’s a wide transformation. Spark has to shuffle. We’ll set up the why now and unpack the cost properly in lesson 21.
Setup
A small DataFrame we can poke at:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = (SparkSession.builder
         .appName("Aggregations101")
         .master("local[*]")
         # 8 shuffle partitions suit a laptop demo; the default is 200
         .config("spark.sql.shuffle.partitions", "8")
         .getOrCreate())
orders = spark.createDataFrame(
    [
        (1001, 1, 59.00, "NL", "2026-03-05"),
        (1002, 1, 29.00, "NL", "2026-03-18"),
        (1003, 2, 149.00, "IT", "2026-02-15"),
        (1004, 2, 89.50, "IT", "2026-03-22"),
        (1005, 3, 199.00, "DE", "2026-03-10"),
        (1006, 4, 42.42, "RO", "2026-03-28"),
        (1007, 1, 12.00, "NL", "2026-03-30"),
        (1008, 2, 75.00, "IT", "2026-03-31"),
    ],
    "OrderId INT, CustomerId INT, Total DOUBLE, Country STRING, OrderDate STRING",
)
Eight rows are enough to demonstrate everything; the same code runs unchanged on eight billion.
groupBy returns a GroupedData, not a DataFrame
This is the gotcha that trips everyone the first time:
g = orders.groupBy("Country")
print(type(g))
# <class 'pyspark.sql.group.GroupedData'>
GroupedData is an intermediate object. It’s the “I’ve decided how to group, but I haven’t told you what to compute yet” stage. You can’t call .show() on it, you can’t .write it, you can’t do anything DataFrame-shaped. You have to feed it into one of these:
- .count(): shorthand that returns a DataFrame with the grouping columns plus one column called count.
- .sum("col"), .avg("col"), .min("col"), .max("col"), .mean("col"): single-aggregate shortcuts (.mean is an alias for .avg).
- .agg(...): the full power, with multiple aggregates at once, renames, and expressions.
The shortcuts are convenient for one-off lookups:
orders.groupBy("Country").count().show()
# +-------+-----+
# |Country|count|
# +-------+-----+
# |     NL|    3|
# |     IT|    3|
# |     DE|    1|
# |     RO|    1|
# +-------+-----+
orders.groupBy("Country").sum("Total").show()
# +-------+----------+
# |Country|sum(Total)|
# +-------+----------+
# |     NL|     100.0|
# |     IT|     313.5|
# |     DE|     199.0|
# |     RO|     42.42|
# +-------+----------+
For anything beyond a single number, drop the shortcuts and use .agg(...).
The single-pass agg pattern
Here’s the thing nobody tells you: every separate aggregation query is a separate job, and each job scans the source again. If you write three queries to compute three KPIs, Spark reads the data three times (unless you’ve cached it). Computers love sequential reads, but they don’t love them that much.
Pass everything into one agg(...) call:
orders.groupBy("Country").agg(
    F.count("*").alias("orders"),
    F.sum("Total").alias("revenue"),
    F.round(F.avg("Total"), 2).alias("aov"),  # average order value, rounded to cents
    F.min("Total").alias("smallest_order"),
    F.max("Total").alias("biggest_order"),
    F.countDistinct("CustomerId").alias("unique_customers"),
).show()
# +-------+------+-------+-----+--------------+-------------+----------------+
# |Country|orders|revenue|  aov|smallest_order|biggest_order|unique_customers|
# +-------+------+-------+-----+--------------+-------------+----------------+
# |     NL|     3|  100.0|33.33|          12.0|         59.0|               1|
# |     IT|     3|  313.5|104.5|          75.0|        149.0|               1|
# |     DE|     1|  199.0|199.0|         199.0|        199.0|               1|
# |     RO|     1|  42.42|42.42|         42.42|        42.42|               1|
# +-------+------+-------+-----+--------------+-------------+----------------+
One scan, six aggregates, every metric a business person would ask for. This is the shape you want to write by reflex.
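Why can one agg(...) call get away with a single scan? Each aggregate needs only a small running state per group, so every aggregate can be updated from the same row as it streams past. The plain-Python sketch below models that idea. It is an illustration, not Spark's implementation, and `Acc` and `single_pass` are hypothetical names. Note that the exact distinct count is the one accumulator that has to remember every value it has seen.

```python
from dataclasses import dataclass, field


@dataclass
class Acc:
    """Running state for one group; all aggregates update from the same row."""
    count: int = 0
    total: float = 0.0
    smallest: float = float("inf")
    biggest: float = float("-inf")
    customers: set = field(default_factory=set)  # exact distinct: keeps every value


def single_pass(rows):
    """rows: iterable of (order_id, customer_id, total, country) tuples."""
    groups = {}
    for _order_id, customer_id, total, country in rows:  # the one and only scan
        acc = groups.setdefault(country, Acc())
        acc.count += 1
        acc.total += total
        acc.smallest = min(acc.smallest, total)
        acc.biggest = max(acc.biggest, total)
        acc.customers.add(customer_id)
    return {
        country: {
            "orders": a.count,
            "revenue": a.total,
            "aov": a.total / a.count,
            "smallest_order": a.smallest,
            "biggest_order": a.biggest,
            "unique_customers": len(a.customers),
        }
        for country, a in groups.items()
    }


rows = [
    (1001, 1, 59.00, "NL"), (1002, 1, 29.00, "NL"),
    (1003, 2, 149.00, "IT"), (1004, 2, 89.50, "IT"),
    (1005, 3, 199.00, "DE"), (1006, 4, 42.42, "RO"),
    (1007, 1, 12.00, "NL"), (1008, 2, 75.00, "IT"),
]
result = single_pass(rows)
print(result["IT"])
```

Spark does the same thing in a distributed way, with one buffer per group per partition, merged later. That merge step is what the wide-transformation section below is about.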
The .alias(...) calls are not optional in spirit — without them you get column names like sum(Total) and avg(Total) that are painful to reference downstream and ugly in dashboards. Always alias.
The catalog
The aggregate functions you’ll actually use, ranked by how often I personally type them:
F.count("*") # row count, never null-skipping
F.count("col") # row count, NULLs excluded (lesson 13)
F.sum("col") # total
F.avg("col") / F.mean("col") # arithmetic mean (aliases for each other)
F.min("col") / F.max("col") # smallest / largest
F.countDistinct("col") # unique value count — exact, expensive
F.approx_count_distinct("col") # HyperLogLog estimate, much cheaper
F.stddev("col") / F.variance("col") # sample standard deviation / variance
F.stddev_pop("col") # population variant if that's what you need
F.collect_list("col") # gather all values (with duplicates) into an array
F.collect_set("col") # gather unique values into an array
F.first("col") / F.last("col") # first/last value in each group (order-dependent — be careful)
Two of those are worth a closer look.
approx_count_distinct is the one that sounds scary and is actually a gift. Counting distinct values exactly across a billion rows means Spark has to keep track of every value it’s seen, which is memory-expensive and slow. The approximate version uses HyperLogLog++. By default it targets a maximum relative standard deviation of 5% (tunable through its rsd argument), for orders of magnitude less work. For dashboards and “active users” KPIs, that’s almost always fine:
orders.groupBy("Country").agg(
    F.countDistinct("CustomerId").alias("exact_unique"),
    F.approx_count_distinct("CustomerId").alias("approx_unique"),
).show()
collect_list and collect_set let you fold a group into an array — useful when you want one row per group with a list of all the related values:
orders.groupBy("CustomerId").agg(
    F.count("*").alias("orders"),
    F.collect_list("OrderId").alias("order_ids"),
    F.collect_set("Country").alias("countries_ordered_from"),
).show(truncate=False)
# +----------+------+------------------+----------------------+
# |CustomerId|orders|order_ids         |countries_ordered_from|
# +----------+------+------------------+----------------------+
# |1         |3     |[1001, 1002, 1007]|[NL]                  |
# |2         |3     |[1003, 1004, 1008]|[IT]                  |
# |3         |1     |[1005]            |[DE]                  |
# |4         |1     |[1006]            |[RO]                  |
# +----------+------+------------------+----------------------+
Watch out for collect_list on big groups — if a single group has millions of rows, you’re building a million-element array on a single executor. That’s a memory bomb. For analytics, prefer aggregates that produce scalars.
Why aggregation is a wide transformation
Until now, every transformation we’ve written has been narrow: each output partition depends on exactly one input partition. select, filter, withColumn — all narrow. Spark just streams rows through whatever executors already hold them.
groupBy breaks that. To compute the sum for Country = 'NL', every NL contribution has to meet in the same partition. The rows don’t start there: they’re scattered across partitions, wherever the data happened to land when it was read. So Spark splits the work in two. First, each input partition computes a partial sum per country, which is still narrow. Then Spark has to shuffle: hash each partial result by Country, send it across the network to the partition that owns that hash, and merge the partials there.
That network step is usually the most expensive thing a Spark job does. It’s why the rule of thumb “filter before you aggregate” exists: every row you filter away is work that never reaches the shuffle. It’s why partitioning your data well at write time matters. It’s why a 100GB join can take ten minutes while a 100GB filter takes thirty seconds.
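To make the "partial results, then shuffle, then merge" flow concrete, here is a toy model in plain Python. It is a teaching sketch, not Spark's code. The three input partitions, `NUM_SHUFFLE_PARTITIONS`, and the use of `zlib.crc32` are stand-ins; Spark's own hash partitioner is Murmur3-based. Each input partition computes partial sums locally. Only those partials are routed by hash(key) to an output partition, where they are merged.

```python
import zlib
from collections import defaultdict

NUM_SHUFFLE_PARTITIONS = 3  # hypothetical; Spark's default spark.sql.shuffle.partitions is 200

# Input partitions as they were read: NL rows are scattered across two of them
input_partitions = [
    [("NL", 59.00), ("NL", 29.00), ("IT", 149.00)],
    [("IT", 89.50), ("IT", 75.00), ("DE", 199.00)],
    [("RO", 42.42), ("NL", 12.00)],
]


def target_partition(key):
    """Stand-in hash partitioner: the same key always maps to the same partition."""
    return zlib.crc32(key.encode()) % NUM_SHUFFLE_PARTITIONS


# Step 1 (narrow): partial sums inside each input partition, no network involved
partials = []
for part in input_partitions:
    local = defaultdict(float)
    for country, total in part:
        local[country] += total
    partials.append(local)

# The shuffle: send every (key, partial) to the partition its hash picks
shuffled = [defaultdict(list) for _ in range(NUM_SHUFFLE_PARTITIONS)]
records_sent = 0
for local in partials:
    for country, partial_sum in local.items():
        shuffled[target_partition(country)][country].append(partial_sum)
        records_sent += 1

# Step 2: each output partition merges the partials it received
revenue = {
    country: sum(parts)
    for out_part in shuffled
    for country, parts in out_part.items()
}
print(records_sent)   # 6 partial records crossed the "network" instead of 8 rows
print(revenue["NL"])  # 100.0
```

Filtering before the groupBy shrinks step 1, and with it everything that follows.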
We’re going to spend an entire lesson on shuffle in lesson 25, with .explain() plans and the Spark UI. For now, just register: groupBy is the moment your job leaves “stream this through” territory and enters “redistribute the universe” territory. Be intentional about it.
NULLs and count — the gotcha worth memorizing
A subtle point that catches everyone at least once. count("*") and count("col") are not the same:
- count("*") counts rows. Always.
- count("col") counts rows where col is not NULL.
from pyspark.sql import Row
with_nulls = spark.createDataFrame([
    Row(country="NL", customer=1),
    Row(country="NL", customer=None),
    Row(country="IT", customer=2),
    Row(country="IT", customer=None),
])
with_nulls.groupBy("country").agg(
    F.count("*").alias("rows"),
    F.count("customer").alias("non_null_customers"),
).show()
# +-------+----+------------------+
# |country|rows|non_null_customers|
# +-------+----+------------------+
# |     NL|   2|                 1|
# |     IT|   2|                 1|
# +-------+----+------------------+
This is exactly the SQL behaviour and is genuinely useful — count("col") doubles as a “how many rows actually had a value here?” check. Just remember the difference when you’re computing percentages: dividing count("col") by count("*") gives you a fill rate, not a row count.
The same NULL-skipping applies to sum, avg, min, max — they all silently ignore NULL inputs. avg of [10, NULL, 30] is 20, not 13.33. Usually that’s what you want. When it isn’t, replace NULLs with zero (or whatever the right default is) using F.coalesce(col, F.lit(0)) before aggregating.
SQL-style subtotals: rollup and cube
One last trick. Sometimes you want grouped totals and a grand total in the same result. SQL has WITH ROLLUP and WITH CUBE; PySpark exposes them as .rollup(...) and .cube(...).
orders.rollup("Country").agg(
    F.sum("Total").alias("revenue"),
    F.count("*").alias("orders"),
).orderBy("Country").show()
# +-------+-------+------+
# |Country|revenue|orders|
# +-------+-------+------+
# |   NULL| 654.92|     8|  ← grand total
# |     DE|  199.0|     1|
# |     IT|  313.5|     3|
# |     NL|  100.0|     3|
# |     RO|  42.42|     1|
# +-------+-------+------+
rollup("Country", "CustomerId") would give you per-country-per-customer totals, plus per-country subtotals, plus a grand total. rollup is hierarchical, so it never produces per-customer subtotals across all countries. cube does, because it emits subtotals for every combination of the grouping columns. That’s useful for pivot-style reports where finance wants “by country, by year, by both, and the bottom line” all in one query.
F.grouping("col") returns 1 if a row’s NULL came from rollup/cube and 0 if it’s a real NULL — handy when you want to label the totals row instead of leaving a blank.
Filtering aggregates: agg plus where afterwards
Spark’s DataFrame API has no separate having() method (Spark SQL itself does accept HAVING), and you don’t need one. After agg, you have a normal DataFrame. Filter it with where like anything else:
# "Countries with revenue above 100 EUR and at least 2 orders"
(orders
    .groupBy("Country")
    .agg(
        F.sum("Total").alias("revenue"),
        F.count("*").alias("orders"),
    )
    .where((F.col("revenue") > 100) & (F.col("orders") >= 2))
    .show())
This reads naturally. Catalyst will push a predicate on a grouping column (say, Country = 'NL') below the aggregation for you, but a predicate on an aggregated value like revenue can only run after it. So the thing to remember is this: filters that reference aggregated columns have to come after agg. Filters on the raw columns belong before. They’re cheaper because every row you eliminate is a row that doesn’t have to ride the shuffle.
# Pre-aggregation filter: cheap, narrow transformation
(orders
    .where(F.col("OrderDate") >= "2026-03-01")  # ISO-8601 date strings compare correctly as text
    .groupBy("Country")
    .agg(F.sum("Total").alias("revenue"))
    .where(F.col("revenue") > 100)  # post-aggregation filter
    .show())
The order matters for performance. Filter rows first, then group, then filter groups. Same logic as SQL’s WHERE vs HAVING.
A note on window functions
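Sometimes you don’t want to collapse groups; you want to annotate rows with group-level numbers. “For each order, what fraction of its country’s total revenue does it represent?” That’s a window function: F.sum("Total").over(Window.partitionBy("Country")), with Window imported from pyspark.sql. Like groupBy, it needs a shuffle by Country, but the output shape is different: every input row stays and gets a new column. That also means Spark can’t pre-aggregate before the shuffle the way it does for groupBy, so whole rows travel across the network.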
We’ll cover windows in detail in lesson 38. For now, when you see groupBy, think “collapse”; when you see over, think “annotate.”
Run this on your own machine
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = (SparkSession.builder
         .appName("Aggregations101")
         .master("local[*]")
         .getOrCreate())
orders = spark.createDataFrame(
    [
        (1001, 1, 59.00, "NL"), (1002, 1, 29.00, "NL"),
        (1003, 2, 149.00, "IT"), (1004, 2, 89.50, "IT"),
        (1005, 3, 199.00, "DE"), (1006, 4, 42.42, "RO"),
        (1007, 1, 12.00, "NL"), (1008, 2, 75.00, "IT"),
    ],
    "OrderId INT, CustomerId INT, Total DOUBLE, Country STRING",
)
# 1. Single-pass aggregation, the shape you want to memorize
orders.groupBy("Country").agg(
    F.count("*").alias("orders"),
    F.sum("Total").alias("revenue"),
    F.avg("Total").alias("aov"),
    F.min("Total").alias("smallest"),
    F.max("Total").alias("biggest"),
    F.countDistinct("CustomerId").alias("unique_customers"),
).orderBy(F.col("revenue").desc()).show()
# 2. Per-customer fold with collect_set
orders.groupBy("CustomerId").agg(
    F.count("*").alias("orders"),
    F.sum("Total").alias("ltv"),
    F.collect_set("Country").alias("countries"),
).show(truncate=False)
# 3. Rollup with a grand total
orders.rollup("Country").agg(
    F.sum("Total").alias("revenue"),
    F.count("*").alias("orders"),
).orderBy(F.col("Country").asc_nulls_first()).show()
# 4. Approximate vs exact distinct
orders.agg(
    F.countDistinct("CustomerId").alias("exact"),
    F.approx_count_distinct("CustomerId").alias("approx"),
).show()
Run each query. Notice how query 1 returns one row per country with six metrics, query 2 returns one row per customer with arrays inside, and query 3 has the extra “everything” row at the top. That’s the whole expressive range of groupBy + agg.
Next lesson: sorting at scale. orderBy, sort, the cost of a global sort, and the sortWithinPartitions trick that saves you when you don’t actually need a global order.