PySpark, dalle fondamenta Lezione 16 / 60

Aggregazioni 101: groupBy, agg e il catalogo delle funzioni di sintesi

groupBy + agg, le funzioni di aggregazione di base, le aggregazioni multi-colonna in un'unica passata, e perché agg è una wide transformation.

Metà di ogni job Spark reale finisce con la stessa forma di domanda: “per ogni X, qual è il totale / la media / il conteggio di Y?”. Ricavi per Paese. Ordini per cliente. Tentativi di login per ora. Click-through per campagna. La risposta coinvolge sempre groupBy e agg, e una volta che conosci il pattern ne scriverai cinquanta a settimana senza pensarci.

Questa lezione è il catalogo. Vedremo le funzioni di aggregazione quotidiane, l’idioma single-pass che scansiona i dati una volta sola invece di N volte, e l’unico dettaglio che distingue l’aggregazione da tutto ciò che abbiamo fatto finora: è una wide transformation. Spark deve fare shuffle. Imposteremo il perché adesso e ne sviscereremo il costo come si deve nella lezione 21.

Setup

Un piccolo DataFrame con cui giocare:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = (SparkSession.builder
         .appName("Aggregations101")
         .master("local[*]")
         .config("spark.sql.shuffle.partitions", "8")
         .getOrCreate())

orders = spark.createDataFrame(
    [
        (1001, 1, 59.00,  "NL", "2026-03-05"),
        (1002, 1, 29.00,  "NL", "2026-03-18"),
        (1003, 2, 149.00, "IT", "2026-02-15"),
        (1004, 2, 89.50,  "IT", "2026-03-22"),
        (1005, 3, 199.00, "DE", "2026-03-10"),
        (1006, 4, 42.42,  "RO", "2026-03-28"),
        (1007, 1, 12.00,  "NL", "2026-03-30"),
        (1008, 2, 75.00,  "IT", "2026-03-31"),
    ],
    "OrderId INT, CustomerId INT, Total DOUBLE, Country STRING, OrderDate STRING",
)

Otto righe bastano per dimostrare tutto; gli stessi pattern scalano a otto miliardi in modo identico.

groupBy restituisce un GroupedData, non un DataFrame

È il tranello che frega tutti la prima volta:

g = orders.groupBy("Country")
print(type(g))
# <class 'pyspark.sql.group.GroupedData'>

GroupedData è un oggetto intermedio. È lo stadio “ho deciso come raggruppare, ma non ti ho ancora detto cosa calcolare”. Non puoi chiamarci sopra .show(), non puoi .write , non puoi fare nulla che abbia la forma di un DataFrame. Devi darlo in pasto a uno di questi:

  • .count() — scorciatoia, restituisce un DataFrame con una colonna chiamata count.
  • .sum("col"), .avg("col"), .min("col"), .max("col"), .mean("col") — scorciatoie a singola aggregazione.
  • .agg(...) — la potenza completa, più aggregazioni in una volta con rinomine ed espressioni.

Le scorciatoie sono comode per lookup veloci una tantum:

orders.groupBy("Country").count().show()
# +-------+-----+
# |Country|count|
# +-------+-----+
# |     NL|    3|
# |     IT|    3|
# |     DE|    1|
# |     RO|    1|
# +-------+-----+

orders.groupBy("Country").sum("Total").show()
# +-------+----------+
# |Country|sum(Total)|
# +-------+----------+
# |     NL|     100.0|
# |     IT|     313.5|
# |     DE|     199.0|
# |     RO|     42.42|
# +-------+----------+

Per qualsiasi cosa al di là di un singolo numero, lascia perdere le scorciatoie e usa .agg(...).

Il pattern single-pass di agg

Ecco la cosa che nessuno ti dice: ogni aggregazione separata è una scansione separata. Se scrivi tre query per calcolare tre KPI, Spark legge i dati tre volte. I computer adorano le letture sequenziali ma non così tanto.

Passa tutto a un’unica chiamata agg(...):

orders.groupBy("Country").agg(
    F.count("*").alias("orders"),
    F.sum("Total").alias("revenue"),
    F.avg("Total").alias("aov"),                # average order value
    F.min("Total").alias("smallest_order"),
    F.max("Total").alias("biggest_order"),
    F.countDistinct("CustomerId").alias("unique_customers"),
).show()
# +-------+------+-------+-----+--------------+-------------+----------------+
# |Country|orders|revenue|  aov|smallest_order|biggest_order|unique_customers|
# +-------+------+-------+-----+--------------+-------------+----------------+
# |     NL|     3|  100.0|33.33|          12.0|         59.0|               1|
# |     IT|     3|  313.5|104.5|          75.0|        149.0|               1|
# |     DE|     1|  199.0|199.0|         199.0|        199.0|               1|
# |     RO|     1|  42.42|42.42|         42.42|        42.42|               1|
# +-------+------+-------+-----+--------------+-------------+----------------+

Una scansione, sei aggregazioni, ogni metrica che una persona di business potrebbe chiederti. Questa è la forma che vuoi scrivere per riflesso.

Le chiamate .alias(...) non sono opzionali nello spirito: senza, ti ritrovi nomi di colonna come sum(Total) e avg(Total) che sono dolorosi da referenziare a valle e brutti nelle dashboard. Metti sempre l’alias.

Le funzioni di aggregazione che userai davvero, classificate per quanto spesso le digito personalmente:

F.count("*")                     # row count, never null-skipping
F.count("col")                   # row count, NULLs excluded (lesson 13)
F.sum("col")                     # total
F.avg("col") / F.mean("col")     # arithmetic mean (aliases for each other)
F.min("col") / F.max("col")      # smallest / largest
F.countDistinct("col")           # unique value count -- exact, expensive
F.approx_count_distinct("col")   # HyperLogLog estimate, much cheaper
F.stddev("col") / F.variance("col")   # sample standard deviation / variance
F.stddev_pop("col")              # population variant if that's what you need
F.collect_list("col")            # gather all values (with duplicates) into an array
F.collect_set("col")             # gather unique values into an array
F.first("col") / F.last("col")   # first/last value in each group (order-dependent -- be careful)

Due di queste meritano uno sguardo più attento.

approx_count_distinct è quella che sembra spaventosa ed è invece un regalo. Contare valori distinti in modo esatto su un miliardo di righe significa che Spark deve tener traccia di ogni valore visto: costoso in memoria e lento. La versione approssimata usa HyperLogLog e ti dà un errore relativo di ~2% per ordini di grandezza in meno di lavoro. Per le dashboard e i KPI tipo “utenti attivi” va quasi sempre bene:

orders.groupBy("Country").agg(
    F.countDistinct("CustomerId").alias("exact_unique"),
    F.approx_count_distinct("CustomerId").alias("approx_unique"),
).show()

collect_list e collect_set ti permettono di ripiegare un gruppo dentro un array: utile quando vuoi una riga per gruppo con una lista di tutti i valori correlati:

orders.groupBy("CustomerId").agg(
    F.count("*").alias("orders"),
    F.collect_list("OrderId").alias("order_ids"),
    F.collect_set("Country").alias("countries_ordered_from"),
).show(truncate=False)
# +----------+------+----------------+----------------------+
# |CustomerId|orders|order_ids       |countries_ordered_from|
# +----------+------+----------------+----------------------+
# |1         |3     |[1001, 1002, 1007]|[NL]                |
# |2         |3     |[1003, 1004, 1008]|[IT]                |
# |3         |1     |[1005]          |[DE]                  |
# |4         |1     |[1006]          |[RO]                  |
# +----------+------+----------------+----------------------+

Attenzione con collect_list su gruppi grossi: se un singolo gruppo ha milioni di righe, stai costruendo un array da un milione di elementi su un singolo executor. È una bomba di memoria. Per l’analytics, preferisci aggregazioni che producono scalari.

Perchè l’aggregazione è una wide transformation

Fino ad ora, ogni trasformazione che abbiamo scritto è stata narrow: ogni partizione di output dipende esattamente da una partizione di input. select, filter, withColumn: tutte narrow. Spark si limita a far fluire le righe attraverso gli executor che già le contengono.

groupBy rompe questa cosa. Per calcolare la somma per Country = 'NL', ogni riga NL deve finire sullo stesso executor. Non partono li’: sono sparpagliate tra le partizioni, dovunque siano atterrati i dati al momento della lettura. Spark deve fare shuffle: applica l’hash a ciascuna riga su Country, la spedisce attraverso la rete all’executor giusto, poi aggrega.

Quel passaggio di rete è la cosa più costosa che fa Spark. È per questo che esiste la regola “filtra prima di aggregare”: ogni riga che escludi con where è una riga che non deve viaggiare in rete. È per questo che partizionare bene i dati al momento della scrittura conta. È per questo che una join da 100GB può richiedere dieci minuti mentre un filtro da 100GB ne richiede trenta secondi.

Dedicheremo un’intera lezione allo shuffle nella lezione 25, con i piani di .explain() e la Spark UI. Per ora, registra solo questo: groupBy è il momento in cui il tuo job esce dal territorio “stream questi dati attraverso” ed entra in quello del “ridistribuisci l’universo”. Sii intenzionale al riguardo.

NULL e count: il tranello da memorizzare

Un punto sottile che frega tutti almeno una volta. count("*") e count("col") non sono la stessa cosa:

  • count("*") conta le righe. Sempre.
  • count("col") conta le righe in cui col è non NULL.
from pyspark.sql import Row

with_nulls = spark.createDataFrame([
    Row(country="NL", customer=1),
    Row(country="NL", customer=None),
    Row(country="IT", customer=2),
    Row(country="IT", customer=None),
])

with_nulls.groupBy("country").agg(
    F.count("*").alias("rows"),
    F.count("customer").alias("non_null_customers"),
).show()
# +-------+----+------------------+
# |country|rows|non_null_customers|
# +-------+----+------------------+
# |     NL|   2|                 1|
# |     IT|   2|                 1|
# +-------+----+------------------+

È esattamente il comportamento di SQL ed è genuinamente utile: count("col") funziona anche come check del tipo “quante righe avevano davvero un valore qui?”. Ricordati solo della differenza quando calcoli percentuali: dividere count("col") per count("*") ti dà un fill rate, non un conteggio righe.

Lo stesso null-skipping si applica a sum, avg, min, max: ignorano tutti silenziosamente gli input NULL. La avg di [10, NULL, 30] è 20, non 13.33. Di solito è quello che vuoi. Quando non lo è, sostituisci i NULL con zero (o qualunque sia il default giusto) usando F.coalesce(col, F.lit(0)) prima di aggregare.

Subtotal in stile SQL: rollup e cube

Un ultimo trucco. A volte vuoi totali raggruppati e un totale generale nello stesso risultato. SQL ha WITH ROLLUP e WITH CUBE; PySpark li espone come .rollup(...) e .cube(...).

orders.rollup("Country").agg(
    F.sum("Total").alias("revenue"),
    F.count("*").alias("orders"),
).orderBy("Country").show()
# +-------+-------+------+
# |Country|revenue|orders|
# +-------+-------+------+
# |   NULL|  654.92|    8|   <- grand total
# |     DE|  199.00|    1|
# |     IT|  313.50|    3|
# |     NL|  100.00|    3|
# |     RO|   42.42|    1|
# +-------+-------+------+

rollup("Country", "CustomerId") ti darebbe i totali per-Paese-per-cliente, più i subtotali per Paese, più un totale generale. cube ti dà ogni combinazione di subtotali (ogni colonna x ogni altra). Utile per i report in stile pivot dove finance vuole “per Paese, per anno, per entrambi, e la riga finale” tutto in un’unica query.

F.grouping("col") restituisce 1 se il NULL di una riga viene da rollup/cube e 0 se è un NULL reale: comodo quando vuoi etichettare la riga dei totali invece di lasciare un buco.

Filtrare le aggregazioni: agg più where dopo

Spark non ha la clausola HAVING di SQL come operatore separato: non ne hai bisogno. Dopo agg hai un normale DataFrame. Filtralo con where come qualsiasi altra cosa:

# "Countries with revenue above 100 EUR and at least 2 orders"
(orders
 .groupBy("Country")
 .agg(
     F.sum("Total").alias("revenue"),
     F.count("*").alias("orders"),
 )
 .where((F.col("revenue") > 100) & (F.col("orders") >= 2))
 .show())

Si legge in modo naturale e l’optimizer è abbastanza intelligente da spingere il filtro dove può. L’unica cosa da ricordare: i filtri che fanno riferimento a colonne aggregate devono venire dopo agg. I filtri sulle colonne grezze stanno prima: sono più economici perché ogni riga eliminata è una riga che non deve cavalcare lo shuffle.

# Pre-aggregation filter: cheap, narrow transformation
(orders
 .where(F.col("OrderDate") >= "2026-03-01")
 .groupBy("Country")
 .agg(F.sum("Total").alias("revenue"))
 .where(F.col("revenue") > 100)            # post-aggregation filter
 .show())

L’ordine conta per la performance. Filtra le righe prima, poi raggruppa, poi filtra i gruppi. Stessa logica del WHERE vs HAVING di SQL.

Una nota sulle window function

A volte non vuoi collassare i gruppi, vuoi annotare le righe con numeri a livello di gruppo. “Per ogni ordine, che frazione del fatturato totale del suo Paese rappresenta?”. Quella è una window function: F.sum("Total").over(Window.partitionBy("Country")). Stesso costo di shuffle di un groupBy, forma di output diversa: ogni riga di input rimane, e prende una nuova colonna.

Tratteremo le window in dettaglio nella lezione 38. Per ora, quando vedi groupBy pensa “collassa”; quando vedi over pensa “annota”.

Esegui questo sulla tua macchina

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = (SparkSession.builder
         .appName("Aggregations101")
         .master("local[*]")
         .getOrCreate())

orders = spark.createDataFrame(
    [
        (1001, 1, 59.00,  "NL"), (1002, 1, 29.00,  "NL"),
        (1003, 2, 149.00, "IT"), (1004, 2, 89.50,  "IT"),
        (1005, 3, 199.00, "DE"), (1006, 4, 42.42,  "RO"),
        (1007, 1, 12.00,  "NL"), (1008, 2, 75.00,  "IT"),
    ],
    "OrderId INT, CustomerId INT, Total DOUBLE, Country STRING",
)

# 1. Single-pass aggregation, the shape you want to memorize
orders.groupBy("Country").agg(
    F.count("*").alias("orders"),
    F.sum("Total").alias("revenue"),
    F.avg("Total").alias("aov"),
    F.min("Total").alias("smallest"),
    F.max("Total").alias("biggest"),
    F.countDistinct("CustomerId").alias("unique_customers"),
).orderBy(F.col("revenue").desc()).show()

# 2. Per-customer fold with collect_set
orders.groupBy("CustomerId").agg(
    F.count("*").alias("orders"),
    F.sum("Total").alias("ltv"),
    F.collect_set("Country").alias("countries"),
).show(truncate=False)

# 3. Rollup with a grand total
orders.rollup("Country").agg(
    F.sum("Total").alias("revenue"),
    F.count("*").alias("orders"),
).orderBy(F.col("Country").asc_nulls_first()).show()

# 4. Approximate vs exact distinct
orders.agg(
    F.countDistinct("CustomerId").alias("exact"),
    F.approx_count_distinct("CustomerId").alias("approx"),
).show()

Esegui ogni query. Nota come la query 1 restituisce una riga per Paese con sei metriche, la query 2 una riga per cliente con array dentro, e la query 3 ha la riga “tutto” in più in cima. È tutto il range espressivo di groupBy + agg.

Prossima lezione: ordinamento alla scala. orderBy, sort, il costo di un sort globale, e il trucco sortWithinPartitions che ti salva quando in realtà non ti serve un ordine globale.

Cerca