PySpark, from the ground up Lesson 29 / 60

Salting: the standard fix when one key dominates

How to break up a hot key by adding a synthetic random suffix, the worked example, and the cost of the trick.

The previous lesson left you staring at a Spark UI where one task was running for thirty minutes while the other 199 finished in twelve seconds. The bottleneck was a single hot key — user_id = 1, or country = 'US', or whatever else dominates your data. Throwing more cluster at it doesn’t help, because the work is on one task, on one core. We need to change the shape of the key.

That’s salting.

The idea in one sentence

Take the hot key, glue a small random number onto it, and what used to be one partition becomes up to N partitions. The dominant key is split across multiple tasks, the hot key’s share of the stage shrinks to roughly 1/N of what it was, and you’re done.

The catch — and there’s always a catch — is that the other side of the join now has to be replicated, because every salted variant of the key has to find its match. We’ll see exactly how that works below, including the smarter version where you only salt the rows that need it.
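To make the idea concrete without a cluster, here is a minimal pure-Python sketch of what salting does to partition sizes. It is a toy model, not Spark: partition_of uses CRC32 as a stand-in for Spark's hash partitioner, the "#" separator is an arbitrary choice, and for simplicity it salts every key rather than only the hot one.

```python
import random
import zlib
from collections import Counter

NUM_PARTITIONS = 200      # mirrors a typical spark.sql.shuffle.partitions value
N = 8                     # salt range
rng = random.Random(29)   # fixed seed so the run is repeatable


def partition_of(key: str) -> int:
    """Toy stand-in for a hash partitioner (this is NOT Spark's real hash)."""
    return zlib.crc32(key.encode("utf-8")) % NUM_PARTITIONS


# 10,000 rows: 60% "US", the rest spread over nine other countries.
others = ["IT", "FR", "DE", "ES", "UK", "JP", "BR", "IN", "CA"]
keys = ["US"] * 6_000 + [others[i % 9] for i in range(4_000)]

# Unsalted: every "US" row lands in the same partition.
plain_load = Counter(partition_of(k) for k in keys)

# Salted: the effective key is "<key>#<salt>", so "US" becomes up to N keys.
salted_load = Counter(partition_of(f"{k}#{rng.randrange(N)}") for k in keys)

max_plain = max(plain_load.values())
max_salted = max(salted_load.values())
print(f"largest partition, unsalted: {max_plain} rows")
print(f"largest partition, salted:   {max_salted} rows")
```

The largest partition is what sets the stage's wall clock, so that is the number to compare between the two runs.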

The four-step pattern

Salting is mechanical once you’ve seen it. Four steps:

Step 1: on the skewed (fact) side, add a salt column. Pick a salt range N — typically 4, 8, or 16. For each row on the skewed side, generate a random integer salt = floor(rand() * N). The effective key is now (original_key, salt) instead of just original_key. The hot key, which used to hash to one partition, now hashes to up to N partitions.

Step 2: on the other (dimension) side, replicate each row N times. For every row in the dimension table, emit N copies — one for each possible salt value 0, 1, ..., N-1. The effective key on this side is also (original_key, salt), but it covers every possible salt for every key.

Step 3: join on (original_key, salt). Every fact row finds exactly the dim rows it would have matched without the salt (a single row, for a typical dimension table). The join now distributes the hot key across up to N partitions, roughly evenly.

Step 4: drop the salt column afterward. It served its purpose during the shuffle. The output rows are the same rows you would have gotten from a regular join — there are no duplicates, because each fact row matched exactly one of the N replicated dim rows.

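That’s it. Steps 2 and 4 are where most beginners get tripped up. Replicate the dim side for every salt value, and join on the salt as well as the key, so each fact row matches exactly one replicated copy. If you forget the replication, fact rows whose salt has no dim-side copy silently drop out of an inner join; if you forget the salt in the join condition, every fact row matches all N copies and the output is duplicated N times. Don’t aggregate before salting unless you mean to, and drop the salt before any downstream grouping or counting so it doesn’t leak into the result.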
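The two failure modes above are easy to see on a ten-row toy dataset. The sketch below is plain Python with a hypothetical nested-loop inner_join helper standing in for Spark's join; the data and column names are made up for illustration.

```python
import random

N = 4
rng = random.Random(0)

facts = [{"country": "US", "txn_id": i} for i in range(6)] + \
        [{"country": "FR", "txn_id": i} for i in range(6, 10)]
dim = [{"country": "US", "region": "NA"}, {"country": "FR", "region": "EU"}]


def inner_join(left, right, keys):
    """Naive nested-loop inner join on the given key columns."""
    return [{**l, **r} for l in left for r in right
            if all(l[k] == r[k] for k in keys)]


# Step 1: salt every fact row. Step 2: replicate every dim row N times.
facts_salted = [{**row, "salt": rng.randrange(N)} for row in facts]
dim_salted = [{**row, "salt": s} for row in dim for s in range(N)]

plain = inner_join(facts, dim, ["country"])

# Step 3 done right: join on (country, salt).
correct = inner_join(facts_salted, dim_salted, ["country", "salt"])

# Mistake A: salt left out of the join condition -> each fact row matches all N copies.
duplicated = inner_join(facts_salted, dim_salted, ["country"])

# Mistake B: dim side not replicated (only salt 0 exists) -> other salts find no match.
dim_unreplicated = [{**row, "salt": 0} for row in dim]
dropped = inner_join(facts_salted, dim_unreplicated, ["country", "salt"])

# Step 4: drop the salt; the result should equal the plain join.
correct_no_salt = [{k: v for k, v in row.items() if k != "salt"} for row in correct]

print(len(plain), len(correct), len(duplicated), len(dropped))
```

Comparing the four row counts is a cheap sanity check worth repeating on real data: the salted join should return exactly as many rows as the unsalted one.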

Worked example

Let’s do this against a realistic-shaped dataset. A fact table of transactions where 60% of rows are country = 'US', plus a small dimension table of country metadata.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = (SparkSession.builder
         .appName("Salting")
         .master("local[*]")
         .config("spark.sql.shuffle.partitions", "200")
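         .config("spark.sql.adaptive.enabled", "false")  # disable AQE so we see raw skew
         .config("spark.sql.autoBroadcastJoinThreshold", "-1")  # keep the tiny dim table from being broadcast, which would skip the shuffle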
         .getOrCreate())

# Fact table: 1M rows, 60% are US, the rest spread across 9 other countries
us_rows = spark.range(0, 600_000).select(
    F.lit("US").alias("country"),
    F.col("id").alias("txn_id"),
    (F.rand() * 1000).alias("amount"),
)

others = ["IT", "FR", "DE", "ES", "UK", "JP", "BR", "IN", "CA"]
other_rows = spark.range(0, 400_000).select(
    F.element_at(F.array(*[F.lit(c) for c in others]),
                 (F.col("id") % 9 + 1).cast("int")).alias("country"),
    F.col("id").alias("txn_id"),
    (F.rand() * 1000).alias("amount"),
)

facts = us_rows.unionByName(other_rows)

# Dim table: one row per country
dim = spark.createDataFrame(
    [(c, c + " full name", c + "-region") for c in ["US"] + others],
    "country STRING, country_name STRING, region STRING",
)

A vanilla join on country:

joined = facts.join(dim, on="country", how="inner")
joined.write.mode("overwrite").parquet("/tmp/skew-vanilla")

If you watch the Spark UI, the shuffle stage shows the textbook pattern from lesson 28: 199 quick tasks, one slow task on the partition holding all 600k US rows. On a real cluster with bigger numbers, that one task might run minutes while the rest finish in seconds.

Now the salted version. The smart variant only salts US — the hot key — not every country, because salting non-skewed keys just multiplies dim rows for no benefit.

N = 8  # salt range

# Step 1: salt the fact side, but ONLY for the hot key
facts_salted = facts.withColumn(
    "salt",
    F.when(F.col("country") == "US",
           (F.rand() * N).cast("int"))
     .otherwise(F.lit(0))
)

# Step 2: replicate the dim side, but only for the hot key
us_dim = dim.filter(F.col("country") == "US")
others_dim = dim.filter(F.col("country") != "US")

# Build [0, 1, ..., N-1]
salts = spark.range(N).select(F.col("id").cast("int").alias("salt"))

us_dim_salted = us_dim.crossJoin(salts)        # N copies of the US row, one per salt
others_dim_salted = others_dim.withColumn("salt", F.lit(0))

dim_salted = us_dim_salted.unionByName(others_dim_salted)

# Step 3: join on (country, salt)
joined_salted = facts_salted.join(dim_salted, on=["country", "salt"], how="inner")

# Step 4: drop the salt column
joined_salted = joined_salted.drop("salt")

joined_salted.write.mode("overwrite").parquet("/tmp/skew-salted")

The US rows on the fact side now spread evenly across 8 salt buckets. The US row on the dim side appears 8 times — once for each salt value — so every fact row matches exactly one dim row. Other countries are untouched: salt is 0 on both sides, and they join as before.

In the Spark UI, the salted version looks completely different. Where the vanilla version had one 30-second task and 199 fast ones, the salted version has up to 8 tasks each handling about 75k US rows in roughly 4 seconds, in parallel with the rest. (It is “up to” 8 because two salted keys can still hash to the same partition.) The stage’s wall clock drops from “limited by the slow task” to “limited by the median task”, which is exactly what you want.

The cost

Salting isn’t free. Three things to keep in mind:

Replicated dim rows are real rows. If N = 8 and the hot key has, say, 50 dim-side rows (some keys do — think about hot composite keys), you’ve turned 50 rows into 400. That extra data has to be shuffled. For tiny dim sides like the country table above, the cost is rounding error. For bigger dim sides — a per-product attribute table where the “hot SKU” has 200 attribute variants — the multiplication can add up. Check.

Pick N as small as you can get away with. A common mistake is picking N = 100 because “more spread is better.” It’s not. The fact-side hot key only had 600k rows; spreading them across 8 partitions gives ~75k per partition, which is plenty of parallelism. Going to 100 makes each partition tiny, but you’ve also multiplied the dim side’s hot rows by 100. For most workloads, N between 4 and 16 is the sweet spot. Start at 8, measure, adjust.

Only salt the hot keys. The example above is the smart pattern: an F.when(...).otherwise(...) that salts US and leaves everything else alone. The naive version (salting every fact row and replicating every dim row) works, but it pays the replication cost on the entire dim table for no benefit. If your skew is concentrated on one or two known keys, single them out. If you don’t know which keys are hot, run the diagnostic from lesson 28 first.
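The trade-off between fact-side parallelism and dim-side replication is simple arithmetic, sketched below. The salting_cost helper is hypothetical (not a Spark API), and the per-bucket figure is the expected average assuming the salt is uniformly distributed.

```python
import math


def salting_cost(hot_fact_rows: int, hot_dim_rows: int, n: int) -> dict:
    """Back-of-the-envelope cost of salting one hot key with salt range n."""
    return {
        # Average fact rows per salt bucket, assuming a uniform salt.
        "fact_rows_per_bucket": math.ceil(hot_fact_rows / n),
        # Every hot dim row is emitted once per salt value.
        "dim_rows_after_replication": hot_dim_rows * n,
        # Rows added to the shuffle purely because of replication.
        "extra_dim_rows": hot_dim_rows * (n - 1),
    }


# The numbers used in this lesson: 600k hot fact rows, 50 hot dim rows.
for n in (4, 8, 16, 100):
    print(n, salting_cost(600_000, 50, n))
```

Running it for a few values of N shows the pattern described above: the fact rows per bucket fall as 1/N while the replicated dim rows grow linearly in N.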

When NOT to salt

Three cases where salting is the wrong tool:

Broadcast join works. If the dim side fits in memory (lesson 27), broadcast it and you’re done. No shuffle, no skew, no salt. The salted approach only matters when the dim side is too big to broadcast and the fact side is skewed.

AQE handles it for you. Spark 3.x’s Adaptive Query Execution (lesson 59) has skew-join support. With spark.sql.adaptive.enabled = true and spark.sql.adaptive.skewJoin.enabled = true, Spark detects skewed partitions at runtime and splits them automatically, with no code changes. The feature is documented for sort-merge joins, it is a join optimization (so it does not help a skewed group-by), and it only kicks in past configurable thresholds (spark.sql.adaptive.skewJoin.skewedPartitionFactor and spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes), so it doesn’t replace salting for every case. Still, on recent versions (AQE has been enabled by default since Spark 3.2) it solves a lot of skew before you even know you have skew. Always check whether AQE is on before reaching for salt.

The “skew” is actually mild. If the top key is 3x the median, you don’t have a skew problem; you have a slightly uneven workload. Increasing spark.sql.shuffle.partitions (lesson 32) is a cheaper fix: it cannot split a single key, but it stops several moderately heavy keys from piling into the same partition. Salting is for the cases where one key has 100x the volume of the rest, where the math says nothing else will work.
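For the AQE case above, the sketch below lists the relevant settings and approximates the rule Spark's tuning guide describes for flagging a partition as skewed: larger than a factor times the median partition size, and also larger than an absolute byte threshold. The factor of 5 and the 256MB threshold are the defaults as documented for Spark 3.x to the best of my knowledge, so treat them as assumptions and check your version. The looks_skewed helper is hypothetical and only mirrors that rule; it is not Spark's implementation.

```python
from statistics import median

# Settings from the Spark SQL tuning guide. The two threshold values are the
# assumed Spark 3.x defaults; verify them against your version's documentation.
aqe_skew_conf = {
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.adaptive.skewJoin.enabled": "true",
    "spark.sql.adaptive.skewJoin.skewedPartitionFactor": "5",
    "spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes": "256MB",
}
# With a live session you could apply them like this:
#   for key, value in aqe_skew_conf.items():
#       spark.conf.set(key, value)

MB = 1024 ** 2


def looks_skewed(size_bytes, all_sizes, factor=5, threshold_bytes=256 * MB):
    """Approximate the documented rule: size > factor * median AND size > threshold."""
    return size_bytes > factor * median(all_sizes) and size_bytes > threshold_bytes


# One 2,000 MB partition among 199 partitions of 40 MB: both conditions hold.
big_cluster = [40 * MB] * 199 + [2_000 * MB]
flag_hot = looks_skewed(2_000 * MB, big_cluster)

# A 100 MB partition among 1 MB partitions: 100x the median, but under the byte threshold.
small_job = [1 * MB] * 199 + [100 * MB]
flag_small = looks_skewed(100 * MB, small_job)

print(flag_hot, flag_small)
```

The second case is the one to remember: a partition can be badly skewed in relative terms and still be ignored by AQE because it is below the size threshold, which is one situation where manual salting (or lowering the threshold) is still on the table.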

Salting for group-by, not just join

The same trick works for aggregations. If groupBy("user_id").sum("amount") has a hot user, you can:

  1. Add a salt column on the input.
  2. Group by (user_id, salt) — partial aggregation, distributed across N partitions.
  3. Group by user_id again on the partial results: small, cheap, no skew.

In code, assuming an events DataFrame with user_id and amount columns:

salted = events.withColumn("salt", (F.rand() * 8).cast("int"))

partial = (salted.groupBy("user_id", "salt")
                 .agg(F.sum("amount").alias("partial_sum")))

final = (partial.groupBy("user_id")
                .agg(F.sum("partial_sum").alias("total")))

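The partial aggregation does the heavy lifting in parallel; the final aggregation only sees at most N rows per user, which fits in a single task with room to spare. This works for aggregates that can be rebuilt from partial results, such as sum, count, min, and max; an average needs the partial sum and partial count carried separately and divided at the end. Same pattern as combiner / reduce in MapReduce: distribute the work, then collapse.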
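The claim that the two-stage aggregation gives the same answer as a direct sum can be checked with a small pure-Python sketch. The data below is synthetic (user 1 is the hot key), and integer amounts are used so the sums compare exactly; with floating-point amounts the totals could differ in the last digits because the addition order changes.

```python
import random
from collections import defaultdict

N = 8
rng = random.Random(1)

# (user_id, amount) events; user 1 is the hot key with 9,000 of the 10,000 rows.
events = [(1, rng.randint(1, 100)) for _ in range(9_000)] + \
         [(u, rng.randint(1, 100)) for u in range(2, 102) for _ in range(10)]

# Stage 1: partial sums keyed by (user_id, salt).
partial = defaultdict(int)
for user_id, amount in events:
    partial[(user_id, rng.randrange(N))] += amount

# Stage 2: collapse the partials per user_id.
final = defaultdict(int)
for (user_id, _salt), partial_sum in partial.items():
    final[user_id] += partial_sum

# Reference: one-pass sum without salting.
direct = defaultdict(int)
for user_id, amount in events:
    direct[user_id] += amount

# How many partial rows the final stage sees for the hot user (at most N).
hot_partials = sum(1 for (user_id, _salt) in partial if user_id == 1)
print(final == direct, hot_partials)
```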

What’s next

Lesson 30 ties off the joins-and-shuffles module by walking through how to read the physical plan of a join and predict its runtime before pressing go — broadcasts, sort-merge, shuffle hash, and how AQE rewrites the plan when it sees a problem. After that, Module 6 (lessons 31 and 32) zooms out from “fixing the slow stage” to “designing partitions intentionally in the first place,” because most skew problems are partition problems wearing a costume.

Two things worth committing to memory from this lesson:

  • The salting recipe: salt the fact side, replicate the dim side, join, drop the salt. Four steps. Don’t skip any.
  • Pick small N, only salt hot keys, and check whether AQE is already doing this for you before writing a line of code.

Salting is the kind of trick that feels clever the first time and routine the tenth. Keep the snippet above somewhere; the next time a stage hangs on one task, you’ll know exactly what to do.


References: Apache Spark documentation on join strategies and Adaptive Query Execution; Databricks engineering blog posts on skew remediation patterns. Retrieved 2026-05-01.
