PySpark, dalle fondamenta Lezione 23 / 60

Caching e persistence: storage level, quando ognuno ha senso

df.cache() e df.persist(): cosa fanno davvero, gli storage level offerti da Spark e i pattern tipici in cui il caching ripaga.

La lezione 22 ha coperto come Spark divide il job in stage e task. Una cosa che il modello DAG implica ma non rende ovvia: ogni action ricomputa l’intero grafo dalla sorgente. Se chiami .count() e poi .write.parquet() sullo stesso DataFrame, Spark legge la sorgente due volte, fa girare tutti i filtri due volte, fa girare tutti i join due volte. Spark è lazy di default e dimentica di default.

Il caching è la leva che tiri quando hai intenzione di usare un DataFrame più di una volta. Dice a Spark “la prossima volta che lo computi, tienitelo da parte; passerò di nuovo”. Usato bene, taglia di ordini di grandezza il tempo di iterazione su job a output multipli. Usato male, spreca memoria e rallenta le cose. La lezione di oggi è la versione “usato bene”. Quella di domani (la 24) è la versione “usato male”.

.cache() vs .persist()

I due metodi su ogni DataFrame:

df.cache()                                    # forma abbreviata
df.persist()                                  # come .cache() col livello di default
df.persist(StorageLevel.MEMORY_AND_DISK)      # esplicito, identico al precedente
df.persist(StorageLevel.DISK_ONLY)            # storage level diverso

.cache() è esattamente equivalente a .persist(StorageLevel.MEMORY_AND_DISK). Tutto qui. Niente magia, nessuna differenza. Usa .cache() per il caso comune; ricorri a .persist(...) quando ti serve uno storage level non di default.

Mina storica che vale la pena conoscere: in Spark 1.x, .cache() aveva default MEMORY_ONLY. Se il tuo DataFrame non ci stava in memoria, dei pezzi venivano semplicemente droppati e Spark li ricomputava on demand, in silenzio e lentamente. Spark 2.0 ha cambiato il default in MEMORY_AND_DISK, che spilla l’eccedenza su disco locale invece di buttarla. Molto più amichevole. Se mai leggi vecchi blog post su Spark che mettono in guardia da “DataFrame cachati che spariscono”, quello è il mondo che descrivono. Non ci viviamo più.

Gli storage level

StorageLevel vive in pyspark.storagelevel. I sette livelli che incontrerai:

  • MEMORY_ONLY: memorizza oggetti JVM deserializzati nell’heap dell’executor. Accesso veloce, footprint di memoria più grande. Se non ci sta, le partizioni vengono droppate e ricomputate al volo. Oggi raramente la scelta giusta.
  • MEMORY_AND_DISK (default per .cache()): prova prima la memoria; spilla il resto su disco locale. La scelta “senza sorprese”. Usa questa per quasi tutto il caching.
  • MEMORY_ONLY_SER: memorizza byte serializzati (Kryo o Java). Più piccolo in memoria di MEMORY_ONLY, ma ogni lettura paga un costo di deserializzazione. Caso di nicchia.
  • MEMORY_AND_DISK_SER: come sopra, ma spilla anche su disco i byte serializzati. Si usa quando la pressione di memoria è severa e puoi pagare la tassa di deserializzazione in cambio di meno RAM.
  • DISK_ONLY: dritto su disco locale, senza copia in memoria. Utile per DataFrame genuinamente enormi che non stanno nella memoria del cluster ma costano molto da ricomputare (per esempio un join pesante che riutilizzerai).
  • OFF_HEAP: memorizza fuori dall’heap della JVM. Originariamente per l’integrazione con Tachyon/Alluxio. Raro nel PySpark moderno.
  • Le varianti _2: ogni livello di cui sopra ha una variante _2 (MEMORY_AND_DISK_2, DISK_ONLY_2, ecc.) che replica ogni partizione cachata su due executor. Utile se hai una sessione interattiva long-running e la morte di un executor sarebbe molto costosa da recuperare.

Se non hai una ragione forte per scegliere altro, MEMORY_AND_DISK è la risposta. È quello che ti dà .cache(). Vai avanti.

Il caching è esso stesso lazy

Ecco il dettaglio che frega tutti la prima volta:

df = spark.read.parquet("big_input/")
df.cache()           # marca df per il caching... ma per ora NON fa nulla
df.filter(...).count()   # ORA df viene letto E cachato

Chiamare .cache() non innesca la computazione. Registra solo un flag: “la prossima volta che mi computi, tieniti il risultato”. La materializzazione vera avviene quando un’action gira contro il DataFrame.

Questo ha una conseguenza sottile. Se la tua “prossima action” tocca solo una parte del DataFrame, solo quella parte potrebbe finire cachata. Per esempio:

df = spark.read.parquet("orders/").cache()

# Questa action computa solo le prime righe
df.show()

# Questa action ha bisogno del resto, e potrebbe innescare una rilettura
df.count()

.show() materializza solo qualche partizione (quel che basta a riempire il display). Il resto del DataFrame non è ancora cachato. Quando gira .count(), Spark legge le partizioni mancanti dalla sorgente.

L’idioma canonico per “forza la cache a popolarsi tutta adesso”:

df.cache()
df.count()    # action sull'intero DataFrame; popola la cache da capo a fondo

Lo vedrai in continuazione nel codice di produzione. Non è laziness della laziness: è un pattern deliberato di “scaldare la cache prima che parta il loop”.

Quando il caching ripaga

Tre pattern tipici in cui il caching è la mossa giusta.

1. Action multiple sullo stesso DataFrame

prepared = (
    spark.read.parquet("orders/")
        .filter(col("year") == 2026)
        .join(broadcast(customers), "customer_id")
        .withColumn("revenue_eur", col("total") * col("eur_rate"))
)
prepared.cache()

prepared.write.parquet("out/by_country/", partitionBy="country")
prepared.write.parquet("out/by_status/", partitionBy="status")
prepared.groupBy("country").sum("revenue_eur").write.parquet("out/summary/")

Senza .cache(), quella grossa pipeline read-filter-join-withColumn gira tre volte. Con la cache, una. Su un dataset reale è la differenza tra un job da 9 minuti e uno da 3.

2. Algoritmi iterativi

Training di ML, traversal di grafi, qualunque cosa cicli sugli stessi dati:

training_set = featurize(raw).cache()
training_set.count()    # scalda la cache

for epoch in range(20):
    model = model.update(training_set)

Ogni iterazione legge training_set una volta. Senza la cache, ogni iterazione ricomputa featurize(raw) da zero. Venti epoche = venti ricomputazioni inutili.

3. Esplorazione interattiva su notebook

Stai smanettando su un DataFrame complesso, lanci 15 query diverse contro di lui. Cachalo una volta in cima al notebook, poi esplora in libertà. Ricordati solo di chiamare .unpersist() (o di riavviare il kernel) quando passi ad altro, altrimenti resta nella memoria del cluster per sempre.

Unpersist: la mossa di produzione

L’eviction LRU prima o poi butterà fuori i DataFrame cachati ormai stantii, ma “prima o poi” può voler dire “dopo che hanno sprecato memoria per tutto il job”. Nel codice di produzione, libera la cache esplicitamente:

prepared = expensive_pipeline().cache()
prepared.count()

# ... usa prepared in 3 write diverse ...

prepared.unpersist()   # libera la memoria; abbiamo finito

Particolarmente importante dentro applicazioni Spark long-running (notebook, job di structured streaming, batch schedulati che condividono una SparkSession tra task). Fai unpersist di quel che non serve più. La memoria è una risorsa; trattala come tale.

Scegliere un non-default storage level apposta

Il default, MEMORY_AND_DISK, è quasi sempre quello giusto. Ecco i pochi casi in cui ricorrere ad altro è giustificato.

DISK_ONLY quando:

  • Il tuo DataFrame è molto più grande della memoria totale degli executor e lo riutilizzerai pesantemente. Cachare in memoria farebbe solo thrashing; disk-only lo evita e i dati vengono letti sequenzialmente quando servono.
  • Vuoi liberare memoria per la computazione vera (join, aggregazioni) e ti basta un artefatto on-disk stabile. Spesso è segno che faresti meglio a scrivere su Parquet e rileggerlo, ma DISK_ONLY è un’alternativa veloce per il lavoro interattivo.

MEMORY_AND_DISK_SER quando:

  • Il tuo dataset è largo (molte colonne) e gli oggetti deserializzati pieni usano molto heap. La forma serializzata è tipicamente 2-5 volte più piccola. Paghi la deserializzazione su ogni lettura in cambio di farne stare di più in memoria.
  • Stai vedendo un GC time alto nella Spark UI e la cache è il sospettato.

Varianti di replica *_2 quando:

  • Ricostruire la cache è costoso (pensa: hai cachato il risultato di un join da 30 minuti), e il cluster è preemptible / spot instance dove gli executor muoiono. Il costo di storage 2x ti compra resilienza al fallimento di un singolo executor.

Le altre (MEMORY_ONLY, MEMORY_ONLY_SER, OFF_HEAP) sono raramente la risposta giusta in un deploy Spark moderno. Se sei tentato di usare MEMORY_ONLY perché l’hai letto in un tutorial, resisti: il default post-Spark-1.x è giustamente MEMORY_AND_DISK e dovresti lasciarglielo fare.

Come verificare che la cache stia funzionando

Due posti dove guardare:

1. Spark UI -> tab Storage. Elenca ogni DataFrame cachato, con dimensione in memoria, dimensione su disco, conteggio di repliche e storage level. Se il tuo DataFrame “cachato” non compare qui, non c’è ancora stata un’action che lo abbia materializzato.

2. Explain plan. Cerca InMemoryTableScan (o InMemoryRelation) nel physical plan:

prepared.cache()
prepared.count()        # innesca il caching
prepared.explain()

Il piano includerà InMemoryTableScan [...] vicino alla cima, a indicare che Spark sta servendo dalla cache invece di rileggere la sorgente. Se non lo vedi, la cache non viene usata (potrebbe essere una trasformazione non deterministica, oppure stai operando su un oggetto DataFrame diverso da quello che hai cachato: sì, succede).

Un bug sottile comune

df = spark.read.parquet("orders/")
df.cache()
df.count()

# Più avanti nel codice...
df = df.filter(col("year") == 2026)   # riassegna df

df.write.parquet("out/")              # NON viene servito dalla cache

Il caching è legato a un oggetto DataFrame, non a un nome. Nel momento in cui riassegni df, il nuovo DataFrame non è cachato. L’originale (ancora in memoria) è ora un orfano che nessuno referenzia. Resterà lì finché l’eviction LRU non lo butta fuori.

Il fix:

raw = spark.read.parquet("orders/").cache()
raw.count()

filtered = raw.filter(col("year") == 2026)   # nome diverso; raw è ancora cachato
filtered.write.parquet("out/")

Usa nomi distinti per DataFrame distinti. Non oscurare la variabile che hai cachato.

Cache vs checkpoint: una breve digressione

Concetto adiacente che a volte viene confuso col caching: il checkpointing.

spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

df = expensive_pipeline()
df = df.checkpoint()    # eager di default

Il checkpoint materializza il DataFrame su storage affidabile (HDFS, S3) e tronca il lineage: il piano del DataFrame risultante parte dal checkpoint, non dalla sorgente originale. Il caching mantiene il lineage; memoizza solo il risultato. Se una partizione cachata viene persa, Spark la ricomputa dal lineage. Se una partizione checkpointata viene persa, Spark la rilegge dallo storage durevole.

Quando usare checkpoint invece di cache:

  • Algoritmi iterativi in cui il lineage cresce così tanto che persino pianificare una nuova iterazione diventa lento (alcuni algoritmi su grafi).
  • Job di streaming che hanno bisogno di un punto di recovery stabile.

Per l’ETL batch quotidiano, cache è quel che vuoi. Rivisiteremo il checkpoint nel modulo streaming molto più avanti. Lo menziono qui solo perché tu non ci ricorra per sbaglio; risolve un problema diverso.

Provalo sulla tua macchina

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, rand
from pyspark.storagelevel import StorageLevel
import time

spark = (
    SparkSession.builder
        .appName("cache-demo")
        .config("spark.sql.shuffle.partitions", "8")
        .getOrCreate()
)

# Costruisci qualcosa deliberatamente costoso: un join + compute pesante
big = (
    spark.range(0, 5_000_000, numPartitions=32)
        .withColumn("k", (col("id") % 1000).cast("int"))
        .withColumn("v", rand() * 1000)
)
small = spark.range(0, 1000).withColumnRenamed("id", "k")

joined = (
    big.join(small, "k")
       .withColumn("v2", col("v") * 1.21)
       .withColumn("v3", col("v") + col("v2"))
)

# === Senza cache ===
t0 = time.time()
joined.count()
joined.groupBy("k").sum("v3").count()
print(f"Senza cache: {time.time() - t0:.2f}s")

# === Con cache ===
joined.cache()
joined.count()    # scalda la cache

t0 = time.time()
joined.count()
joined.groupBy("k").sum("v3").count()
print(f"Con cache:    {time.time() - t0:.2f}s")

# Guarda la tab Storage mentre lo script è in pausa
input("Premi Invio per liberare la cache e uscire... ")
joined.unpersist()
spark.stop()

Dovresti vedere il run cachato terminare in una piccola frazione del tempo del run a freddo. Apri la tab Storage durante la pausa e osserva la dimensione e lo storage level del DataFrame cachato.

Quello è il lato produttivo del caching. La prossima lezione, il post esistente che teniamo in questo slot, ma rilavorato, copre il lato oscuro: quando il caching peggiora le cose. Pressione di memoria, churn da eviction, lo script di dieci righe che diventa più veloce quando rimuovi la riga .cache(). Leggila subito dopo; è il contrappeso necessario a quella di oggi.

Cerca