
Apache Spark Shuffles Explained In Depth


I originally intended this to be part of a much longer post about memory in Spark, but I figured it would be useful to talk about shuffles on their own, so I can skim over them in the memory discussion and keep it a bit more digestible. Shuffles are among the most memory- and network-intensive parts of most Spark jobs, so it's important to understand when they occur and what's going on when you're trying to improve performance.

What is a shuffle?

Let's start with the simplest case: an RDD with no shuffle at all. You have an RDD of a text file and you apply a map, followed by a write to disk. For example:

rdd.flatMap { line => line.split(' ') }.saveAsTextFile(outputPath)

The memory usage of this is actually fairly straightforward. In the case of a text file, Spark streams the file in, applies your flatMap operation line by line, and immediately writes the results (possibly with some buffering) out to disk. You're not doing anything terribly memory intensive, and Spark doesn't have to do anything fancy in terms of coordination. If this application reads two files from disk, it will read the two files (partitions) in parallel and write them out in parallel. No coordination is required, and this is what is known as an "embarrassingly parallel" process.

Now we bring a shuffle into the picture. Let's say you create a word count application which consists of a flatMap on that text file, followed by a reduceByKey(). For example:

rdd.flatMap { line => line.split(' ') }.map((_, 1)).reduceByKey((a, b) => a + b).collect()

For clarity: this splits each line on spaces, turns each word into a row, maps each row to a tuple of (word, 1), and then aggregates those tuples to produce (word, number of occurrences of that word).

In this case, you actually begin by performing the same streaming, parallel map operation as in the previous example, but things get a little more complex after that. In order to count all of the words across an entire dataset (with a bunch of partitions) in Spark, each partition must aggregate the counts of the words within that partition, but those partial counts must then also be summed across partitions. The process of moving data from partition to partition in order to aggregate, join, match up, or spread out in some other way is known as shuffling. The aggregation/reduction that takes place before data is moved across partitions is known as a map-side reduce (or map-side combine).

The following operations are examples of shuffle-inducing operations for RDDs (a couple of them are sketched in code after the list):

  • groupBy/subtractByKey/foldByKey/aggregateByKey/reduceByKey
  • cogroup
  • any of the join transformations
  • distinct

If you think about how you would go about performing those operations, it should intuitively make sense why shuffles occur. Just ask yourself: does this operation depend on one row or on multiple rows? If it depends on multiple rows, and assuming you know nothing about the contents of a partition, do you need to look across partitions to match up rows in any way? If the answer to both is yes, then you likely need a shuffle.

Performance Analysis

Let's look at the simplified worst-case memory/time complexity of the word count just described. Say we have N total words with K distinct words, evenly distributed (i.e. N/K occurrences of each word). In the worst case, there is one of each key on each executor and there are K executors. Assuming the words hash evenly, each word will end up on a different executor, so you'll have to pay the cost of moving K partial aggregates to each executor. Then you have to finish the aggregate on each of those executors across the cluster. This means that each of the K executors has to perform K-1 fetches (one from each other node). Reasoning about time complexity is a bit hard here because the fetches on a given executor are done serially, but K executors are fetching at once. This means O(K) fetches are in flight at any given moment across the whole cluster and O(K^2) are done in total. It's also worth noting that the executors doing the fetching are also serving up files to other executors, and processing the data that they fetch, all within the same JVM. So as you add more executors, you can expect the shuffle to have a sort of diminishing payoff with regard to parallelism.


Diagram representing the shuffle of a word count.

The above example pretty much describes the behavior of the hash-based shuffle manager. As its name suggests, keys are hashed and each executor is responsible for a certain bucket of keys. This is probably the easiest manager to understand intuitively and to reason about performance for.

Until around Spark 1.2, this was also the default manager. Spark 1.1 added the sort-based shuffle manager, and Spark 1.2 made it the default. At the bottom of this page we link to some more reading from Cloudera on the sort-based shuffle. If you're interested in how it works and why it's the default, I would suggest you read on there. It follows a similar behavior to the hash-based shuffle manager but is more efficient with regard to memory.
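
For reference, here's roughly how you would pin the shuffle manager explicitly on a Spark 1.x build. This is just a minimal sketch: the app name is a placeholder, and since sort is already the default from 1.2 onward, you would normally only set this to switch back to hash.

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("shuffle-manager-example")  // placeholder name
      .set("spark.shuffle.manager", "sort")   // or "hash" for the older manager
    val sc = new SparkContext(conf)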

Tuning Shuffles

Shuffles are tuned by a few parameters: the shuffle manager (which determines which keys go to which partitions), the number/size of partitions (these two are somewhat interchangeable), and the number/size of the executors processing your data. When you increase the number of partitions, you inherently reduce the size of each partition (or, in certain cases, keep it the same).

In most cases you won't actually change the shuffle manager: the two have similar performance, and the sort-based shuffle manager has been found to be the more efficient of the two. This means that one of the biggest knobs you can turn for shuffle performance is the number of partitions. In practice this is done using the "repartition" or "coalesce" functions.
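
As a rough sketch of what that looks like in code (the partition counts are arbitrary example values, and pairs stands in for a keyed RDD like the word-count pairs above):

    val pairs = rdd.flatMap(_.split(' ')).map((_, 1))

    // repartition(n) shuffles the data into n roughly equal partitions.
    val spreadOut = pairs.repartition(200)

    // coalesce(n) merges existing partitions down to n without a full shuffle.
    val squeezed = spreadOut.coalesce(10)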

A Series of Examples

One Executor, One Core, One Partition: Looking at the extreme case: your data sits in a single partition and you have one executor with one core provisioned for Spark. No shuffle occurs in this case because everything happens within that one executor. This is great because the shuffle cost is now zero, but your program may be adversely affected in other ways, such as OOMs if you can't fit the whole partition into memory, or long GC pauses if it comes close to filling memory. Also, your map tasks get no parallelism because only one thread can operate on a partition at a time.

One Executor, Four Cores, Four Partitions: Let's now alter the above example, still keeping one executor, but increasing the number of cores to 4 and the number of partitions of data to 4. We've achieved higher parallelism because our data is now split into 4 chunks which can be processed on the map side in parallel. We still haven't sacrificed shuffle performance because all the shuffle data stays within the same JVM, so there's no network cost to pay. But there is a more subtle challenge here: now that you have 4 threads processing data at the same time, you may be increasing your potential for OOMs or GC pressure because each of these threads needs to keep its working set in memory.

Two Executors, Two Cores, Four Partitions: So let's make another change. Let's drop down to 2 cores per executor for the sake of keeping things interesting and add another executor. Now each thread has twice as much memory available to it, but you still have two-way parallelism within each executor. On top of that, when shuffles take place you now actually have network traffic to deal with, and each executor carries the burden of fetching and serving shuffle files. If these four partitions are rather small (or have relatively few keys), it's possible that the amount of data shuffled across the wire is low and this isn't a huge burden.
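
As a sketch of how you might configure that last setup (inputPath and the memory value are just illustrative placeholders; on YARN the equivalent spark-submit flags would be --num-executors and --executor-cores):

    val conf = new SparkConf()
      .setAppName("two-executors-example")    // placeholder name
      .set("spark.executor.instances", "2")   // two executors (YARN)
      .set("spark.executor.cores", "2")       // two cores per executor
      .set("spark.executor.memory", "4g")     // example value
    val sc = new SparkContext(conf)

    // Four input partitions to process in parallel on the map side.
    val rdd = sc.textFile(inputPath, 4)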

Skewed Keys: An even worse situation can arise in the previous example when partitions, instead of having relatively few keys, have a ton. Or, in the most extreme situation, you have really skewed keys where, for example, one partition has thousands or millions of distinct keys while the others have only hundreds. In this case, the partition with many keys will be responsible for aggregating much more data on the map side, causing memory pressure and potentially OOMs. On top of that, the one executor which processed the big partition will be responsible for serving a disproportionate amount of data relative to the other executors and will slow down because of it. And because everyone else needs to fetch rows from that executor during the shuffle, the performance of the other nodes is intimately tied to the performance of the slow node. What's the best way to combat that issue? Perhaps we could shuffle the data and repartition to a larger number of partitions, so that there are fewer distinct keys per partition and the pressure is more evenly distributed.
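
That mitigation can be a one-liner: give the shuffle many more output partitions so each one holds fewer distinct keys. This is just a sketch; 400 is an arbitrary example value and pairs stands in for whatever keyed RDD you have.

    val counts = pairs.reduceByKey((a, b) => a + b, 400)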

These are a few of the ways you might think about changing the parameters to tune shuffles for your use case. It's really important to keep in mind that these are essentially heuristics, and because each use case is different, they're best paired with empirical measurements as you make changes, to ensure that performance is actually improving. Because there are so many variables that could be changed (including your workflow itself, which we didn't explore above), the theory alone isn't enough to decide on the best approach.

Other Notes

It's also worth noting that more shuffles aren't always bad. Shuffling lets you do things that were previously impossible due to memory constraints, etc. For example, increasing the number of shuffle partitions (and therefore tasks) can prevent OOMs or long GC pauses by decreasing the amount of data each individual task operates on. The Spark Tuning Guide itself specifically calls this out, pointing out that:

"Spark’s shuffle operations (sortByKey, groupByKey, reduceByKey, join, etc) build a hash table within each task to perform the grouping, which can often be large. The simplest fix here is to increase the level of parallelism, so that each task’s input set is smaller"

What they're essentially suggesting is increasing the number of shuffle tasks (via the number of partitions) so that each partition, and therefore each task, is smaller.
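
Concretely, that can be as simple as raising the default parallelism, or passing an explicit partition count to the shuffle operator itself. A sketch, with 200 as an arbitrary example value:

    // Raise the default number of partitions used when a shuffle operation
    // isn't given an explicit count.
    val conf = new SparkConf().set("spark.default.parallelism", "200")

    // Or ask for more (smaller) tasks on a per-operation basis.
    val grouped = pairs.groupByKey(200)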

Sometimes it's possible to avoid shuffling altogether, for example in the case of a join between a large table (e.g. 10MM rows) and a small table (e.g. 100 rows). In a situation like this you would normally shuffle the rows of both the small table and the large table to make sure matching keys are colocated across partitions, but moving all of that data can be more expensive than just putting a copy of the small table everywhere. You can do that by collecting your smaller table, creating a map keyed the same way, broadcasting it out using "sc.broadcast", and then performing your join inside a map function. This way you get a linear-time join with no shuffles, because the cost is just the time it takes to apply a map over the large dataset.
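
Here's a rough sketch of that map-side (broadcast) join under those assumptions; smallRDD and largeRDD are hypothetical pair RDDs standing in for the 100-row and 10MM-row tables.

    // Collect the small table to the driver and broadcast it to every executor.
    val smallLookup = sc.broadcast(smallRDD.collectAsMap())

    // "Join" by mapping over the large dataset and looking each key up in the
    // broadcast map. Rows with no match are dropped, like an inner join.
    val joined = largeRDD.flatMap { case (key, bigValue) =>
      smallLookup.value.get(key).map(smallValue => (key, (bigValue, smallValue)))
    }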

For general application tuning/debugging, you can probably stop here. You should have a sufficient understanding of what's going on under the hood when your job runs to understand where data is being moved around and at what scale. If you're curious about the infrastructure that facilitates shuffling and how it's done, keep reading.

How does this work under the hood?

There are two key parts that make shuffles work in Spark: marking where in the RDD graph a shuffle needs to take place (planning), and actually facilitating the shuffle during execution.

In the RDD Graph

For the above operations, something called a ShuffledRDD is usually created (this is not the case for subtractByKey/cogroup, but we'll get to that), which contains an explicit ShuffleDependency on the parent RDD along with the corresponding aggregation function.

  1. When collect is called, the job is submitted to the DAGScheduler, which breaks up the RDD graph by its dependencies.
  2. When a dependency involves a shuffle (e.g. the above operations), it creates a separate stage. The way it knows whether a dependency involves a shuffle is through the RDD's getDependencies call, which can return a ShuffleDependency.
    1. By default, non-shuffle operations create a OneToOneDependency, which specifies that the parent RDD and child RDD partitions correspond one-to-one with one another (i.e. no shuffling involved).
    2. For shuffle-inducing operations, a ShuffleDependency is provided. This dependency tracks what the key/value types are and how the keys are combined/ordered for this shuffle.

What this means for execution is that each stage is processed independently, and the results of that stage are tracked by the MapOutputTracker/ShuffleManager until the next downstream stage that depends on them requires them. When that downstream stage is executed, it passes in the ShuffleDependency object, which identifies the set of blocks in the ShuffleManager, and it gets back an Iterator[(K, V)] of data from the block manager.
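
You can actually see these dependencies and stage boundaries yourself from the Spark shell. This is just a quick sketch (inputPath is a placeholder, and the exact output format varies by Spark version):

    val wordCounts = sc.textFile(inputPath)
      .flatMap(_.split(' '))
      .map((_, 1))
      .reduceByKey(_ + _)

    // toDebugString prints the RDD lineage; the indentation marks stage
    // boundaries, i.e. where a shuffle dependency sits.
    println(wordCounts.toDebugString)

    // The direct dependencies of the shuffled RDD include a ShuffleDependency.
    wordCounts.dependencies.foreach(dep => println(dep.getClass.getSimpleName))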

Execution

In practice, the way this works when your job runs is that all map operations are executed over each row until a shuffle-related operation is hit; that row is then "grouped" for the shuffle (we'll explain how the grouping occurs in a second), and execution moves on to the next row. What this means is that if you have a bunch of map operations (map/filter/flatMap/sample/etc.) followed by a shuffle operation followed by another map operation, for example:

rdd.flatMap(_.split(' ')).map((_, 1)).reduceByKey((x, y) => x + y).filter(_._2 > 100)

The above job will apply the flatMap to the first row, immediately map over the resulting words and "group" the keys for the shuffle, and then proceed to the next row to be processed. Once all the rows have been processed and grouped, the next stage, which requires the keys to be grouped, will begin. Each task within this stage is in charge of handling some subset of the groups from the previous stage, but it can act with confidence that all the information necessary for each of its groups lives within the partition that it processes.

In the case of a reduceByKey, we would start streaming that partition and keep track of the reduced value by keeping a running sum for each key. In our example above, each partition may contain multiple keys, but one key will never be spread across multiple partitions. That means that after processing all of the groups for a given partition, the task can proceed to apply any remaining map functions over all of the resulting reduced rows. If there were more shuffle operations, that same process would be repeated until the job completed.

So the only remaining questions are how a reduce task gets hold of one of the grouped partitions, which may live across multiple executors, and how we ensure that every key ends up in only one partition. After the initial map stages complete, depending on your shuffle manager, each row is either hashed by its key or sorted, and written into a file on disk on the machine it was sourced from. That executor then lets something called the ShuffleManager know that it currently has a block of data corresponding to the given key. The ShuffleManager keeps track of all keys and their locations. Once all of the map-side work is done, the next stage starts, and the executors each reach out to the ShuffleManager to figure out where the blocks for their keys live. Once they know where those blocks live, each executor reaches out to the corresponding executor to fetch the data and pull it down to be processed locally. To enable this, every executor runs a Netty server which can serve blocks that are requested from that specific executor.

So to recap, it proceeds as follows:

  1. Map operations pipelined and executed
  2. Map side shuffle operations performed (map side reduce, etc)
  3. Map side blocks written to disk and tracked within the ShuffleManager/BlockManager
  4. Reduce stage begins to fetch keys and blocks from the ShuffleManager/BlockManager
  5. Reduce side aggregate takes place.
  6. Next set of map/shuffle stages takes place, repeating from step #1.

How does this change in Spark SQL?

So far all we've talked about is how shuffles are performed with RDD code. There's a pretty good chance that you're using DataFrames to interact with Spark, and under the hood it still performs shuffles in a pretty similar way. The only differences are how it keeps track of shuffles within the query plan and the information it keeps about what kind of shuffles are required.

The process of building your DataFrame and executing it involves a query planner/optimizer which can drastically change the order and type of operations being done in order to compute your results more efficiently. I wrote a piece on query planning here, but one significant fact is that when you build a DataFrame, before it is executed, Spark optimizes the plan by potentially moving operations around and then figuring out where the required shuffles go; in Spark SQL a shuffle shows up in the physical plan as an operator called an Exchange.

Spark tracks the Distribution of the data at each stage of the physical plan in order to ensure that operations get all the data they need. Each operator in Spark SQL specifies both the distribution it requires of its children (the operators it sources its data from) and the distribution it produces as output. This is pretty much analogous to the ShuffleDependency in RDDs, but it keeps track of additional SQL information such as the expressions being shuffled on and potentially more detail about the distribution.
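
A quick way to see this for yourself (a sketch against the Spark 1.x DataFrame API, with made-up column names) is to print the physical plan and look for the Exchange operator:

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(("spark", 1), ("shuffle", 1), ("spark", 1)))
      .toDF("word", "n")

    // groupBy needs rows with the same key co-located, so the physical plan
    // printed here should contain an Exchange (Spark SQL's shuffle).
    df.groupBy("word").sum("n").explain()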

More Reading

