If you’ve spent any time with Spark, you already know which line of code in your job is going to time out. It’s the join. It’s almost always the join. This post is the version of “what to do about that” I wish someone had handed me on day one.
Why joins are the hard part
Spark is a distributed engine. Your data lives on many machines, partitioned by some hash of some column. A WHERE filter is easy, because each partition can answer it on its own. A groupBy().sum() is harder but still tractable. A join is the painful one, because to join row A from one dataframe with row B from another, both rows have to physically end up on the same machine.
The mechanism for “physically end up on the same machine” is called a shuffle. The shuffle is the slowest, most expensive thing Spark does. It writes intermediate data to disk on every executor, sends it across the network, and reads it back on the other side. Every join you write that doesn’t fit one of the special cases below triggers one. If your job is slow, the answer is almost always “you have a join doing a shuffle of half a terabyte.”
The default: sort-merge join
Out of the box, Spark joins two dataframes with a sort-merge join:
- Hash both sides on the join key.
- Shuffle the rows so all matching keys land on the same executor.
- Sort each side locally by the key.
- Walk them in lockstep and emit matches.
This is fine when both sides are large and roughly the same size. It is the default for a reason. But it has a fixed minimum cost — the full shuffle of both sides — and that cost dominates everything else.
from pyspark.sql import SparkSession, functions as F
spark = SparkSession.builder.appName("joins").getOrCreate()
orders = spark.read.parquet("s3://bucket/orders/") # 800M rows
products = spark.read.parquet("s3://bucket/products/") # 2k rows
# This will work, but it shuffles 800M rows of `orders`
# across the cluster for absolutely no reason.
joined = orders.join(products, on="product_id", how="left")
That’s the most common Spark mistake in the wild: a sort-merge join where one of the sides is small enough to fit in memory.
Broadcast joins: the easy win
If one side of the join is small — and “small” means “fits comfortably in memory on every executor” — you can avoid the shuffle entirely by broadcasting that side. Instead of moving the big dataframe around, Spark sends a copy of the small one to every executor. Each executor then joins locally, no shuffle, no sort. It’s faster by an order of magnitude when it applies.
from pyspark.sql.functions import broadcast
# 2k rows of `products` get serialized and sent to every executor.
# `orders` stays where it is. No shuffle.
joined = orders.join(broadcast(products), on="product_id", how="left")
Spark will broadcast small tables for you automatically when it can confidently estimate their size — controlled by spark.sql.autoBroadcastJoinThreshold (default 10 MB). The catch is that the auto-broadcast only fires when the optimizer can read the table’s size from statistics. Read from a CSV with no statistics? Read after a chain of filters that confused the optimizer? It silently falls back to a shuffle and you wait twenty minutes wondering why.
The fix is to be explicit. If you know one side is small, wrap it in broadcast(). The hint costs you nothing if Spark would have broadcast it anyway, and it saves you from the optimizer guessing wrong.
The other catch: there’s a hard ceiling on what you can broadcast. Spark refuses to broadcast a table larger than 8 GB, and in practice you should be nervous about anything above a few hundred MB. Broadcasting a 4 GB table to a 200-executor cluster means sending 800 GB across the network, at which point the shuffle was actually cheaper.
Skew: when one key has all the rows
The other failure mode is data skew. Spark assumes a roughly even distribution of join keys across partitions. When that’s wrong — when 80% of your orders have customer_id = NULL, or when one VIP customer has 50 million orders and everyone else has 50 — one executor ends up doing all the work while the rest sit idle. You’ll see it in the Spark UI as a stage where 199 tasks finish in 30 seconds and one task runs for 45 minutes. That one task is the skewed key.
The first thing to check is whether you can just filter the offending key out. NULL join keys are almost always a bug:
# If NULL join keys aren't meaningful, drop them before the join.
orders_clean = orders.filter(F.col("customer_id").isNotNull())
joined = orders_clean.join(customers, on="customer_id")
If the skew is from a real, legitimately-popular key, the standard trick is salting: add a random suffix to the hot key on one side, explode the other side to match, then join on the combined key. It distributes the work across many partitions at the cost of multiplying the small side.
# Pick a salt range. Bigger = more parallelism, more memory on the small side.
SALT_BUCKETS = 16
# Salt the big (skewed) side: each row gets a random bucket id.
orders_salted = orders.withColumn(
"salt",
(F.rand() * SALT_BUCKETS).cast("int"),
)
# Explode the small side: every row is duplicated SALT_BUCKETS times,
# once for each possible salt value.
customers_exploded = customers.withColumn(
"salt",
F.explode(F.array([F.lit(i) for i in range(SALT_BUCKETS)])),
)
joined = orders_salted.join(
customers_exploded,
on=["customer_id", "salt"],
)
It feels ugly the first time you write it. It is ugly. It also takes a job that was running for an hour and brings it down to ten minutes, which is the only thing the on-call rotation cares about.
Spark 3.0+ also ships Adaptive Query Execution (AQE), which can detect skew at runtime and split the offending partitions automatically. Turn it on:
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
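AQE decides what counts as "skewed" with two knobs, and if it isn't splitting a partition you know is skewed, loosening them is the first thing to try. The defaults below are from the Spark docs; the looser values are illustrative, not a recommendation:

```python
# AQE calls a partition skewed only when BOTH are true:
#   size > skewedPartitionFactor * median partition size  (default factor: 5)
#   size > skewedPartitionThresholdInBytes                (default: 256MB)
# Loosening both makes AQE more aggressive about splitting.
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "3")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "128MB")
```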
AQE handles the easy skew cases for free. For the gnarly ones (pre-aggregations, multi-stage pipelines, joins on derived columns) you’ll still need salting. Note that AQE has been on by default since Spark 3.2; on older versions, set the flags above yourself.
A few smaller habits that pay off
- Push filters before joins, always. Joining a billion-row table and then filtering to last week’s data is the same answer as filtering first and then joining a 7-million-row table. The second one runs in a fraction of the time. Spark’s optimizer will sometimes do this for you. Don’t rely on it — write the filter first.
- Project only the columns you need. A SELECT * before a join carries every column through the shuffle, including the 400-character JSON blob you’re not even using. df.select("id", "amount", "ts") before joining is free performance.
- Avoid .toPandas() on intermediate results. It collects the entire dataframe to the driver. People do this to “check” the output and accidentally OOM the driver on a multi-terabyte job.
- Read the Spark UI. I know, it’s ugly. I know, it has six tabs you don’t understand. But the Stages tab will tell you, in five seconds, which task in your job is the slow one and how much data it’s shuffling. It’s the single most useful debugging tool Spark gives you, and most people never open it.
- Cache strategically, not reflexively. Caching a dataframe that’s only used once is pure overhead. Caching one that’s used in three downstream joins is gold. Be deliberate.
The mental model
The whole game with PySpark performance is: avoid moving data. Filters move zero data. Projections move zero data. Broadcasts move a tiny amount of data once. Sort-merge joins move all the data. If you’re slow, find the shuffle and ask whether it has to be there. Most of the time, it doesn’t.