Join PySpark che non fanno saltare il cluster

Perché le join sono la principale fonte di dolore in Spark, cosa fa davvero lo shuffle, e i trucchi del broadcast e del salting che trasformano un job da 40 minuti in uno da 4.

Se hai passato un po’ di tempo con Spark, sai già quale riga di codice del tuo job andrà in timeout. È la join. È quasi sempre la join. Questo post è la versione di “cosa farci” che avrei voluto qualcuno mi mettesse in mano il primo giorno.

Perché le join sono la parte difficile

Spark è un motore distribuito. I tuoi dati vivono su tante macchine, partizionati per qualche hash di qualche colonna. Un filtro WHERE è facile, perché ogni partizione può rispondere da sola. Un groupBy().sum() è più difficile ma ancora gestibile. Una join è quella dolorosa, perché per fare il join della riga A da un dataframe con la riga B da un altro, le due righe devono finire fisicamente sulla stessa macchina.

Il meccanismo per “finire fisicamente sulla stessa macchina” si chiama shuffle. Lo shuffle è la cosa più lenta e costosa che fa Spark. Scrive dati intermedi su disco su ogni executor, li manda attraverso la rete e li rilegge dall’altra parte. Ogni join che scrivi che non rientra in uno dei casi speciali qui sotto ne fa scattare uno. Se il tuo job è lento, la risposta è quasi sempre “hai una join che sta facendo lo shuffle di mezzo terabyte.”

Il default: sort-merge join

Out of the box, Spark fa join tra due dataframe con un sort-merge join:

  1. Hash su entrambi i lati sulla join key.
  2. Shuffle delle righe in modo che tutte le chiavi corrispondenti finiscano sullo stesso executor.
  3. Sort locale di ogni lato per chiave.
  4. Cammina insieme e emetti i match.

Va bene quando entrambi i lati sono grandi e all’incirca della stessa dimensione. È il default per un motivo. Ma ha un costo minimo fisso — lo shuffle completo di entrambi i lati — e quel costo domina tutto il resto.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("joins").getOrCreate()

orders   = spark.read.parquet("s3://bucket/orders/")     # 800M righe
products = spark.read.parquet("s3://bucket/products/")   #   2k righe

# Funziona, ma fa lo shuffle di 800M righe di `orders`
# attraverso il cluster senza alcun motivo.
joined = orders.join(products, on="product_id", how="left")

Quello è l’errore più comune in Spark che si vede in giro: un sort-merge join in cui uno dei due lati è abbastanza piccolo da starsene comodamente in memoria.

Broadcast join: la vittoria facile

Se un lato della join è piccolo — e “piccolo” significa “ci sta comodamente in memoria su ogni executor” — puoi evitare lo shuffle del tutto broadcastando quel lato. Invece di spostare il dataframe grande, Spark manda una copia di quello piccolo a ogni executor. Ogni executor poi fa la join localmente, niente shuffle, niente sort. È più veloce di un ordine di grandezza quando si applica.

from pyspark.sql.functions import broadcast

# 2k righe di `products` vengono serializzate e mandate a ogni executor.
# `orders` resta dov'è. Niente shuffle.
joined = orders.join(broadcast(products), on="product_id", how="left")

Spark farà il broadcast di tabelle piccole automaticamente quando riesce a stimarne la dimensione con sicurezza — controllato da spark.sql.autoBroadcastJoinThreshold (default 10 MB). Il problema è che l’auto-broadcast scatta solo quando l’optimizer riesce a leggere la dimensione della tabella dalle statistiche. Letto da un CSV senza statistiche? Letto dopo una catena di filtri che ha confuso l’optimizer? Cade silenziosamente in shuffle e tu aspetti venti minuti chiedendoti perché.

La soluzione è essere espliciti. Se sai che un lato è piccolo, avvolgilo in broadcast(). L’hint non ti costa niente se Spark l’avrebbe broadcastato comunque, e ti salva dall’optimizer che indovina male.

L’altro problema: c’è un limite massimo a quello che puoi broadcastare. Il default è 8 GB per task; in pratica dovresti essere nervoso per qualsiasi cosa sopra qualche centinaio di MB. Broadcastare una tabella da 4 GB su un cluster da 200 executor significa mandare 800 GB attraverso la rete — a quel punto lo shuffle era effettivamente più economico.

Skew: quando una chiave ha tutte le righe

L’altro modo in cui le cose vanno male è lo skew dei dati. Spark assume una distribuzione più o meno uniforme delle join key tra le partizioni. Quando è sbagliato — quando l’80% dei tuoi orders ha customer_id = NULL, o quando un cliente VIP ha 50 milioni di ordini e tutti gli altri ne hanno 50 — un executor finisce per fare tutto il lavoro mentre il resto sta a girarsi i pollici. Lo vedi nella Spark UI come uno stage in cui 199 task finiscono in 30 secondi e un task gira per 45 minuti. Quel task è la chiave skewed.

La prima cosa da controllare è se puoi semplicemente filtrare via la chiave incriminata. Le join key NULL sono quasi sempre un bug:

# Se le join key NULL non sono significative, scartale prima della join.
orders_clean = orders.filter(F.col("customer_id").isNotNull())
joined = orders_clean.join(customers, on="customer_id")

Se lo skew viene da una chiave reale e legittimamente popolare, il trucco standard è il salting: aggiungi un suffisso casuale alla chiave calda da un lato, fai esplodere l’altro lato per farlo combaciare, poi fai la join sulla chiave combinata. Distribuisce il lavoro su molte partizioni al costo di moltiplicare il lato piccolo.

# Scegli un range di salt. Più grande = più parallelismo, più memoria sul lato piccolo.
SALT_BUCKETS = 16

# Salta il lato grande (skewed): ogni riga riceve un id di bucket casuale.
orders_salted = orders.withColumn(
    "salt",
    (F.rand() * SALT_BUCKETS).cast("int"),
)

# Esplodi il lato piccolo: ogni riga è duplicata SALT_BUCKETS volte,
# una per ciascun valore di salt possibile.
customers_exploded = customers.withColumn(
    "salt",
    F.explode(F.array([F.lit(i) for i in range(SALT_BUCKETS)])),
)

joined = orders_salted.join(
    customers_exploded,
    on=["customer_id", "salt"],
)

Sembra brutto la prima volta che lo scrivi. È brutto. Prende anche un job che girava da un’ora e lo porta a dieci minuti, che è l’unica cosa a cui tiene il turno di reperibilità.

Spark 3.0+ include anche Adaptive Query Execution (AQE), che può rilevare lo skew a runtime e splittare automaticamente le partizioni incriminate. Accendilo:

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

AQE gestisce gratis i casi di skew facili. Per quelli rognosi — pre-aggregazioni, pipeline a più stage, join su colonne derivate — ti servirà ancora il salting. Ma AQE acceso di default è una cosa ovvia nel 2026 e sono un po’ stupito che non lo sia.

Qualche piccola abitudine che paga

  • Spingi i filtri prima delle join, sempre. Fare la join di una tabella da un miliardo di righe e poi filtrare ai dati dell’ultima settimana dà lo stesso risultato di filtrare prima e fare la join su una tabella da 7 milioni di righe. La seconda gira in una frazione del tempo. L’optimizer di Spark a volte lo fa per te. Non contarci — scrivi il filtro per primo.
  • Proietta solo le colonne che ti servono. Una SELECT * prima di una join porta ogni colonna attraverso lo shuffle, incluso il blob JSON da 400 caratteri che non stai nemmeno usando. df.select("id", "amount", "ts") prima della join è performance gratuita.
  • Evita .toPandas() su risultati intermedi. Raccoglie l’intero dataframe sul driver. La gente lo fa per “controllare” l’output e accidentalmente manda in OOM il driver su un job da diversi terabyte.
  • Leggi la Spark UI. Lo so, è brutta. Lo so, ha sei tab che non capisci. Ma il tab Stages ti dirà, in cinque secondi, quale task del tuo job è quello lento e quanti dati sta facendo shuffle. È il singolo strumento di debug più utile che Spark ti dà, e la maggior parte della gente non lo apre mai.
  • Cache strategica, non riflessiva. Cachare un dataframe usato una sola volta è puro overhead. Cachare uno usato in tre join downstream è oro. Sii deliberato.

Il modello mentale

Tutto il gioco con la performance di PySpark è: evita di muovere i dati. I filtri muovono zero dati. Le proiezioni muovono zero dati. I broadcast muovono una piccola quantità di dati una volta sola. I sort-merge join muovono tutti i dati. Se sei lento, trova lo shuffle e chiediti se deve davvero stare lì. La maggior parte delle volte, no.