March 16, 2021

Increase PySpark's JDBC parallelism through predicates

Increase the parallelism of loading data through JDBC with Spark using predicates on non-numeric columns.

By default, Spark uses only one partition to read data through a JDBC connection. You can raise this with the options numPartitions, lowerBound, upperBound and column, with the caveat that column has to be of a numeric type, and therefore lowerBound and upperBound have to be numeric as well. That rules out date columns, which are the typical partitioning candidate for event-based data.

It was supposedly fixed in 2.4.0, but in 3.x (and maybe earlier already), Spark has reverted the jdbc() method back to accepting only integers. Although, to be fair, regardless of what the ticket says, the Scala API docs for 2.4.0 also only specify Long types. Presumably this is because there is another way to partition JDBC data at the source: by using predicates. So if you stumbled on the following exception, or you were just searching for an example using predicates, read on!

", line 625, in jdbc
TypeError: int() argument must be a string, a bytes-like object or a number, not 'datetime.datetime'

The predicates parameter in one of the jdbc() signatures takes a list of strings that can be used in where clauses. By creating a list of non-overlapping intervals, every partition is filled with the chunk of the table defined by one interval.
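
To make the non-overlapping requirement concrete, here is a minimal sketch that runs hand-written predicates as where clauses against an in-memory SQLite table standing in for the JDBC source. The table name events, the column created_at and the sample rows are made up for illustration, and it assumes timestamps are stored as ISO-8601 strings so they compare correctly as text.

```python
import sqlite3

# Stand-in for the JDBC source: an in-memory SQLite table with ISO-8601 timestamps.
# Table name, column name and rows are hypothetical.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (id INTEGER, created_at TEXT)")
rows = [
    (1, "2021-01-05T10:00:00"),
    (2, "2021-01-31T23:59:59"),
    (3, "2021-02-01T00:00:00"),
    (4, "2021-02-14T12:30:00"),
    (5, "2021-03-15T08:00:00"),
]
conn.executemany("INSERT INTO events VALUES (?, ?)", rows)

# Half-open intervals [start, end): the end of one is the start of the next,
# so no row can match two predicates and no row falls in between.
predicates = [
    "'2021-01-01T00:00:00' <= created_at and created_at < '2021-02-01T00:00:00'",
    "'2021-02-01T00:00:00' <= created_at and created_at < '2021-03-01T00:00:00'",
    "'2021-03-01T00:00:00' <= created_at and created_at < '2021-04-01T00:00:00'",
]

# One query per predicate, mimicking one partition per predicate.
chunks = [
    [r[0] for r in conn.execute(f"SELECT id FROM events WHERE {p} ORDER BY id")]
    for p in predicates
]
print(chunks)  # [[1, 2], [3, 4], [5]]
```

Every row lands in exactly one chunk. A check like this against a small sample is a cheap way to validate a predicate list before pointing a cluster at the real source.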

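# spark is an active SparkSession; JDBC_URL and TABLE_NAME are placeholders
input_df = spark.read \
    .option("driver", DRIVER_CLASS_NAME) \
    .jdbc(JDBC_URL, TABLE_NAME, predicates=_generate_predicates(config))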

The magic is in the _generate_predicates() function, which takes a start and end timestamp and divides the range between them into as many chunks as the desired number of partitions. Each resulting interval is formatted into a string, and these strings fill the list of predicates. The first example below is maybe a bit dense, although it is advisable to work this way with Spark in a professional data engineering setting (e.g. with configuration objects).

def _generate_predicates(config):
    predicate_template = "'{}' <= {} and {} < '{}'"
    delta_t = (config.load_until - config.load_since) / config.num_chunks
    predicates = []
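    for i in range(config.num_chunks):
        interval_start = (config.load_since + i * delta_t).isoformat()
        interval_end = (config.load_since + (i + 1) * delta_t).isoformat()
        # config.partition_column is an assumed attribute naming the timestamp column
        predicates.append(predicate_template.format(
            interval_start, config.partition_column,
            config.partition_column, interval_end))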

    return predicates
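
To see what such a function produces, here is a self-contained sketch with a hypothetical configuration object. The LoadConfig dataclass, its field names (in particular partition_column) and the sample dates are assumptions made for this illustration, not part of any Spark API.

```python
from dataclasses import dataclass
from datetime import datetime


@dataclass
class LoadConfig:
    # Hypothetical configuration object; the field names are assumptions.
    load_since: datetime
    load_until: datetime
    num_chunks: int
    partition_column: str


def _generate_predicates(config):
    predicate_template = "'{}' <= {} and {} < '{}'"
    delta_t = (config.load_until - config.load_since) / config.num_chunks
    predicates = []
    for i in range(config.num_chunks):
        interval_start = (config.load_since + i * delta_t).isoformat()
        interval_end = (config.load_since + (i + 1) * delta_t).isoformat()
        predicates.append(predicate_template.format(
            interval_start, config.partition_column,
            config.partition_column, interval_end))
    return predicates


config = LoadConfig(
    load_since=datetime(2021, 3, 1),
    load_until=datetime(2021, 3, 5),
    num_chunks=4,
    partition_column="created_at",
)
predicates = _generate_predicates(config)
for predicate in predicates:
    print(predicate)
# First line printed:
# '2021-03-01T00:00:00' <= created_at and created_at < '2021-03-02T00:00:00'
```

Four days split into four chunks gives one predicate per day. Note that isoformat() puts a T between date and time; whether your database accepts that literal depends on the source system, so treat the format as something to verify.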

Notice how the template uses single quotes to enclose the date values. The next snippet makes the date logic clearer:

# or defaultParallelism from your Spark context
num_chunks = 10

# the step size; load_since and load_until are datetime objects defined elsewhere
delta_t = (load_until - load_since) / num_chunks

# the template
tpl = "'{}' <= {} and {} < '{}'"

# the column to which to apply the partitioning
p_col = "my-column"

# then build the list
predicates = []
for i in range(num_chunks):
    interval_start = load_since + i * delta_t
    interval_end = load_since + (i + 1) * delta_t
    predicates.append(tpl.format(interval_start, p_col, p_col, interval_end))
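
One subtlety with a precomputed step size: Python rounds a timedelta division to whole microseconds, so when the range does not divide evenly, the last interval can end slightly before load_until. The sketch below shows the effect and one possible way around it; chunk_boundaries is a hypothetical helper name, and the one-second range is chosen only to make the rounding visible.

```python
from datetime import datetime, timedelta


def chunk_boundaries(load_since, load_until, num_chunks):
    """Return num_chunks + 1 boundaries; the last one is exactly load_until."""
    total = load_until - load_since
    # Multiply before dividing so rounding happens once per boundary
    # instead of accumulating through a pre-rounded step size.
    return [load_since + (total * i) / num_chunks for i in range(num_chunks + 1)]


load_since = datetime(2021, 3, 1)
load_until = load_since + timedelta(seconds=1)
num_chunks = 3

# Step-size approach: one second divided by three rounds to 333333 microseconds.
delta_t = (load_until - load_since) / num_chunks
naive_end = load_since + num_chunks * delta_t
print(load_until - naive_end)  # 0:00:00.000001

bounds = chunk_boundaries(load_since, load_until, num_chunks)
print(bounds[-1] == load_until)  # True
```

Consecutive pairs from bounds can then be fed into the same template. For day- or hour-sized ranges the gap is at most a few microseconds, so whether it matters depends on the timestamp precision of your source.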

With this simple piece of code it is possible to increase the parallelism of your data ingestion via JDBC using predicates. Do watch out that you don't bring your source to its knees with a big Spark cluster.