Since Java 7, the JDK has included a set of classes implementing the fork-join pattern, an approach that decomposes work into multiple tasks that can be executed in parallel. ForkJoinPool and ForkJoinTask are the two primary classes implementing it. This post walks through an example of using them by converting a Mergesort implementation from plain recursion to fork-join.

Mergesort is a classic divide-and-conquer approach to sorting. The data to be sorted is split into two halves, and each half is split in two again until every piece holds a single element and can no longer be split. The pieces are then merged back together, pair by pair, in sorted order.

This can be visualized using this diagram from Wikipedia’s Mergesort page.

A simple recursive implementation in Java is:

    public static void mergeSort(int[] array, int size) {
        // Arrays of size 0 or 1 are already sorted.
        if (size > 1) {
            int mid = size / 2;
            int[] left = new int[mid];
            int[] right = new int[size - mid];

            // Copy each half into its own array.
            System.arraycopy(array, 0, left, 0, left.length);
            System.arraycopy(array, mid, right, 0, right.length);

            mergeSort(left, left.length);
            mergeSort(right, right.length);

            // Merge the two sorted halves back into the original array.
            merge(array, left, right);
        }
    }

As you can see, the input array is split in half, and each half (left and right) is then recursively mergesorted. The recursion stops when the array size is 1. On the way back up, the left and right parts are merged into the original array in sorted order. The merge() method is implemented like this:

    private static void merge(final int[] dest, final int[] left, final int[] right) {
        int destIdx = 0, leftIdx = 0, rightIdx = 0;
        while (leftIdx < left.length && rightIdx < right.length) {
            if (left[leftIdx] < right[rightIdx]) {
                dest[destIdx++] = left[leftIdx++];
            }
            else {
                dest[destIdx++] = right[rightIdx++];
            }
        }
 
        for (int i = leftIdx; i < left.length; i++) {
            dest[destIdx++] = left[i];
        }
 
        for (int i = rightIdx; i < right.length; i++) {
            dest[destIdx++] = right[i];
        }
    }
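
To see the merge step in isolation, here is a small sketch that merges two already sorted arrays. The class name MergeDemo and the sample values are made up for illustration; the merge() method follows the same logic as the one above.

```java
import java.util.Arrays;

// Hypothetical demo class, not part of the original implementation.
class MergeDemo {

    // Same logic as the merge() shown above: merge two sorted arrays into dest.
    static void merge(final int[] dest, final int[] left, final int[] right) {
        int destIdx = 0, leftIdx = 0, rightIdx = 0;
        // Repeatedly take the smaller head element until one side runs out.
        while (leftIdx < left.length && rightIdx < right.length) {
            if (left[leftIdx] < right[rightIdx]) {
                dest[destIdx++] = left[leftIdx++];
            } else {
                dest[destIdx++] = right[rightIdx++];
            }
        }
        // Copy whatever remains on either side.
        while (leftIdx < left.length) {
            dest[destIdx++] = left[leftIdx++];
        }
        while (rightIdx < right.length) {
            dest[destIdx++] = right[rightIdx++];
        }
    }

    public static void main(String[] args) {
        int[] left = {1, 4, 7};
        int[] right = {2, 3, 9};
        int[] dest = new int[left.length + right.length];
        merge(dest, left, right);
        System.out.println(Arrays.toString(dest)); // prints [1, 2, 3, 4, 7, 9]
    }
}
```

Note that merge() only produces a sorted result when both inputs are already sorted, which is exactly what the recursive calls guarantee.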

Disclaimer: I do not in any way claim that this is the best way to implement the Mergesort or the merge part of the algorithm! But it is functionally correct.

So there you have an example of a divide-and-conquer recursive approach to implementing Mergesort. From this it should be fairly obvious where we’d plug in the fork-join … just use it instead of recursion.

Java ForkJoinPool and ForkJoinTask

You can take a look at the generic tutorial from Oracle here.

The generally recommended approach with the fork-join framework is to apply it only when the problem is larger than some threshold, since the overhead of creating and scheduling tasks is too costly for small inputs. With Mergesort, for example, sorting a small array is faster with the direct implementation shown above. So in this example I have (arbitrarily) set the cutoff at an array size of 1000.

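    public static void forkJoinMergesort(int[] array, int size) {
        if (size < 1000) {
            // Small input: the plain recursive version is cheaper.
            mergeSort(array, size);
        }
        else {
            // Large input: hand the work to the shared fork-join pool.
            ForkJoinPool.commonPool().invoke(new ForkJoinMerge(array, size));
        }
    }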

If the array size is smaller than 1000, just use the recursive implementation. Otherwise, use the fork-join framework. Note that this check happens only once, at the top level.

The ForkJoinPool is similar to other thread pools in Java. You can create your own pools, or you can use a shared common pool via the convenience method ForkJoinPool.commonPool() (available since Java 8). ForkJoinPool has several ways to submit tasks, including execute(), submit() and invoke(); you can review the Javadoc for the differences. The basic approach is to call invoke(), passing an implementation of ForkJoinTask, which runs the task and waits for it to finish. You could extend ForkJoinTask directly, but the JDK provides two subclasses that should suffice for many use cases. These are RecursiveTask and RecursiveAction, which return a result or not, respectively: essentially ForkJoinTask<V> and ForkJoinTask<Void>.
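
As a rough illustration of the result-returning variant, here is a minimal sketch of a RecursiveTask that sums an int array, invoked once on the common pool and once on a pool created by hand. The class name SumTask, the cutoff of 1,000 elements and the pool size of 4 are all assumptions made for this example.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical example task: sums the range [from, to) of an int array.
class SumTask extends RecursiveTask<Long> {
    private static final int THRESHOLD = 1_000; // arbitrary cutoff, an assumption

    private final int[] data;
    private final int from; // inclusive
    private final int to;   // exclusive

    public SumTask(int[] data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        // Small ranges are summed directly in the current thread.
        if (to - from <= THRESHOLD) {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += data[i];
            }
            return sum;
        }
        int mid = from + (to - from) / 2;
        SumTask left = new SumTask(data, from, mid);
        SumTask right = new SumTask(data, mid, to);
        left.fork();                     // schedule the left half asynchronously
        long rightSum = right.compute(); // compute the right half in this thread
        return left.join() + rightSum;   // wait for the left half and combine
    }

    public static void main(String[] args) {
        int[] data = new int[10_000];
        for (int i = 0; i < data.length; i++) {
            data[i] = i + 1;
        }

        // Shared common pool: nothing to create or shut down.
        long viaCommon = ForkJoinPool.commonPool().invoke(new SumTask(data, 0, data.length));

        // A pool of our own with 4 worker threads; shut it down when done.
        ForkJoinPool pool = new ForkJoinPool(4);
        try {
            long viaOwn = pool.invoke(new SumTask(data, 0, data.length));
            System.out.println(viaCommon + " " + viaOwn); // prints 50005000 50005000
        } finally {
            pool.shutdown();
        }
    }
}
```

Here invoke() returns the value produced by compute(). With a RecursiveAction there is nothing to return, and invoke() is called purely for its side effects.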

Since this Mergesort writes its sorted result back into the array it was given, there is no return value, so here we use RecursiveAction in the implementation. I have created a class called ForkJoinMerge that extends RecursiveAction.

    private static class ForkJoinMerge extends RecursiveAction {
 
        private final int[] array;
        private final int size;
 
        public ForkJoinMerge(final int[] array, final int size) {
            this.array = array;
            this.size = size;
        }

With a RecursiveAction or RecursiveTask, you implement the compute() method to do your work, similar to how you implement Runnable.run() or Callable.call() with regular threading.

The compute() method of my ForkJoinMerge does exactly the same thing as the recursive Mergesort but creates new instances of ForkJoinMerge for the left and right halves. Then it uses the same merge() implementation shown earlier.

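        @Override
        protected void compute() {
            // Arrays of size 0 or 1 are already sorted.
            if (size > 1) {
                int mid = size / 2;
                int[] left = new int[mid];
                int[] right = new int[size - mid];

                // Copy each half into its own array.
                System.arraycopy(array, 0, left, 0, left.length);
                System.arraycopy(array, mid, right, 0, right.length);

                var leftFork = new ForkJoinMerge(left, left.length);
                var rightFork = new ForkJoinMerge(right, right.length);

                // Schedule both halves to be sorted asynchronously.
                leftFork.fork();
                rightFork.fork();

                // Wait for both halves to finish.
                leftFork.join();
                rightFork.join();

                // Merge the two sorted halves back into the original array.
                merge(array, left, right);
            }
        }
    }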

Once the new ForkJoinMerge instances have been created, their fork() method is called. This schedules the tasks to run asynchronously in the same ForkJoinPool as the current task. As with typical thread execution, you call join() to await the completion of a task. Finally, as with the original algorithm, you take the left and right arrays and merge them.
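
One possible refinement, which is not what the implementation above does, is to apply the size threshold inside compute() so that small subarrays are sorted sequentially at every level of the recursion, and to start both halves with invokeAll() rather than separate fork() and join() calls. The sketch below is only one way to do this. The names ThresholdMergeSort, SortTask and THRESHOLD are hypothetical, and Arrays.sort() is swapped in for the sequential mergeSort() to keep the example short.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical variant: the cutoff is checked in every task, not just once.
class ThresholdMergeSort {
    static final int THRESHOLD = 1_000; // arbitrary cutoff, an assumption

    public static void sort(int[] array) {
        ForkJoinPool.commonPool().invoke(new SortTask(array));
    }

    private static class SortTask extends RecursiveAction {
        private final int[] array;

        SortTask(int[] array) {
            this.array = array;
        }

        @Override
        protected void compute() {
            if (array.length < THRESHOLD) {
                // Stand-in for the sequential mergeSort() from earlier.
                Arrays.sort(array);
                return;
            }
            int mid = array.length / 2;
            int[] left = Arrays.copyOfRange(array, 0, mid);
            int[] right = Arrays.copyOfRange(array, mid, array.length);

            // Runs both subtasks and returns once both have completed.
            invokeAll(new SortTask(left), new SortTask(right));

            merge(array, left, right);
        }
    }

    // Same merge logic as earlier in the post.
    static void merge(int[] dest, int[] left, int[] right) {
        int destIdx = 0, leftIdx = 0, rightIdx = 0;
        while (leftIdx < left.length && rightIdx < right.length) {
            if (left[leftIdx] < right[rightIdx]) {
                dest[destIdx++] = left[leftIdx++];
            } else {
                dest[destIdx++] = right[rightIdx++];
            }
        }
        while (leftIdx < left.length) {
            dest[destIdx++] = left[leftIdx++];
        }
        while (rightIdx < right.length) {
            dest[destIdx++] = right[rightIdx++];
        }
    }

    public static void main(String[] args) {
        int[] data = new Random(42).ints(100_000, 0, 1_000_000).toArray();
        int[] expected = data.clone();
        Arrays.sort(expected);

        sort(data);
        System.out.println(Arrays.equals(data, expected)); // prints true
    }
}
```

With this structure, far fewer tasks are created than when forking all the way down to arrays of size 1. Whether 1,000 is a good cutoff is something you would have to measure for your own data and hardware.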

Hopefully this gives a good basic introduction to the Java fork-join framework. For more detailed, and perhaps more useful, examples you can look at the JDK itself. ForkJoinTask is used in the implementation of parallel streams, for example. Or check out the Sorter class in DualPivotQuicksort: Sorter is a ForkJoinTask, and DualPivotQuicksort is used if you call Arrays.parallelSort() on a primitive array.
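
For completeness, calling Arrays.parallelSort() yourself looks like the sketch below. The class name ParallelSortDemo and the helper sortedCopy() are made up for this example, and for an array this small the JDK may well decide to sort sequentially.

```java
import java.util.Arrays;

// Hypothetical demo class showing the JDK's built-in parallel sort.
class ParallelSortDemo {

    // Returns a sorted copy, leaving the input untouched.
    public static int[] sortedCopy(int[] input) {
        int[] copy = input.clone();
        Arrays.parallelSort(copy); // sorts the copy in ascending order
        return copy;
    }

    public static void main(String[] args) {
        int[] data = {5, 3, 9, 1, 7};
        System.out.println(Arrays.toString(sortedCopy(data))); // prints [1, 3, 5, 7, 9]
    }
}
```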