In this continuing back-to-basics series, I've been talking about recursion and various strategies for using it effectively. So far I've covered the basic types of recursion (linear, binary, and tail), then taken it a step further with list recursion and memoization techniques. I hope the series serves as a refresher for those who don't use these things on a daily basis.
This time, we're going to talk about recursion and Continuation Passing Style (CPS). We'll use both F# and C# for the examples.
Continuation Passing Style
Earlier in the series, I talked about several different ways to approach recursion. Today we're going to bring CPS into the picture. Let's start with the first word in the name: continuation. Simply put, a continuation represents the remainder of a computation from a given point in that computation. You can almost think of a continuation as a GOTO that carries a parameter: the value produced by the function that transferred control. It differs from a function call because, although it is possible to return to the original computation, doing so is not always expected or necessary. Even so, many people get tripped up by this definition.
A function written in CPS takes an explicit continuation argument: a function that receives the result of the computation performed inside it. The caller supplies this continuation, and the function invokes it at the end instead of returning. In other words, rather than returning a value, the function passes the value to the code that continues the computation. This style makes several things explicit: intermediate values are named, the order in which arguments are evaluated is fixed, and every call becomes a tail call. Why is this important? Besides making you the obvious hit at any party and an absolute dev maven, it matters if you want to understand monads as well as asynchronous calls. But that's for another day. Below, I'll go into a scenario where CPS is genuinely useful for recursion. In the meantime, let's look at how CPS differs from direct calls.
Let's first show a quick example of CPS in action using F#:
let square_func x = x * x
(* val square_func : int -> int *)
let square_cps x cont = cont(x * x)
(* val square_cps : int -> (int -> 'a) -> 'a *)
let result = square_cps 4 (fun x -> x)
Above, I took a standard squaring function and gave it a continuation parameter that handles the result of the computation. At the call site, I passed an identity continuation, which simply returns its input, so result is 16. This is all well and good, but let's apply it to our story at hand: recursion.
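To make the claims about intermediate values and evaluation order a little more concrete, here is a small sketch that composes CPS functions. `add_cps` and `sum_of_squares_cps` are hypothetical helpers written only for illustration; each intermediate result is bound to a lambda parameter, and the nesting fixes the order in which the steps run.

```fsharp
// square_cps as defined above; add_cps and sum_of_squares_cps are
// hypothetical helpers written only for this illustration.
let square_cps x cont = cont (x * x)
let add_cps x y cont = cont (x + y)

// Direct style would be: square a + square b.
// In CPS every intermediate value is named by a lambda parameter,
// and the nesting makes the order explicit: a2 first, then b2, then the sum.
let sum_of_squares_cps a b cont =
    square_cps a (fun a2 ->
        square_cps b (fun b2 ->
            add_cps a2 b2 cont))

// The caller decides what happens to the final result.
let asInt = sum_of_squares_cps 3 4 id        // 25
let asText = sum_of_squares_cps 3 4 string   // "25"
printfn "%d %s" asInt asText
```

Because the result is handed to `cont` rather than returned, the same function can feed an `int` consumer or a `string` consumer without change.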
CPS and Recursion
Earlier in this series, I showed how to take a normal recursive function and move it from linear to tail recursion. This time, we'll take the same approach and go one step further, to CPS. In the past I've been guilty of using the Fibonacci sequence and factorial, so let's switch it up and use our ImmutableList, or just a List<'a> for those in the F# world. The immutable list is part of my Functional C# library on MSDN Code Gallery.
Let's start with standard linear recursion and see what it looks like when calculating the length of our list.
let rec length = function
    | [] -> 0
    | _ :: t -> 1 + length t
public static int Length<T>(this IImmutableList<T> list)
{
    if (list.IsEmpty) return 0;
    return 1 + Length(list.Tail);
}
Notice that the last operation isn't the recursive call; instead, we add 1 to its result, so the call isn't a tail call. Every call has to keep its stack frame alive until the call below it returns, which can overflow the stack on large lists. Let's rewrite it to use a tail call, which usually means an inner function with an accumulator.
let length_tail lst =
    let rec length_acc l acc =
        match l with
        | [] -> acc
        | _ :: t -> length_acc t (1 + acc)
    length_acc lst 0
public static int LengthTail<T>(this IImmutableList<T> list)
{
    Func<IImmutableList<T>, int, int> length_acc = null;
    length_acc = (l, acc) =>
        l.IsEmpty ? acc : length_acc(l.Tail, 1 + acc);
    return length_acc(list, 0);
}
Now the recursive call in both functions is in tail position. I created an inner function that takes the list and an accumulator and calls itself, adding 1 to the accumulator each time. But, as I've stated before, the C# compiler doesn't optimize tail calls. This has been a known issue for some time now, and I don't think it's a high priority to fix. The x64 JIT compiler does perform a simple tail call optimization on its own, but the x86 JIT does not, so on x86 you'll still get stack overflows. That's one more reason to look at F# for recursive algorithms.
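One way to picture what "optimized for the tail call" means in F#: a function like `length_acc`, whose call to itself is in tail position, is compiled into a loop. The sketch below is a hand-written loop that behaves roughly like that compiled form; it is an illustration under that assumption, not the compiler's literal output, and `length_loop` is a hypothetical name.

```fsharp
// Hypothetical illustration: a hand-written loop that behaves like the
// compiled form of length_acc. It is not the compiler's literal output.
let length_loop (lst: 'a list) =
    let mutable l = lst   // stands in for the parameter l
    let mutable acc = 0   // stands in for the accumulator
    // Each iteration corresponds to one tail call length_acc t (1 + acc):
    // update the "arguments" in place and jump back to the top.
    while not l.IsEmpty do
        acc <- acc + 1
        l <- l.Tail
    acc

// No stack growth, regardless of the list's length.
let n = length_loop (List.replicate 1000000 0)
printfn "%d" n
```

Since each tail call simply overwrites the arguments and jumps, no stack frames pile up, which is why the F# version doesn't depend on the JIT.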
But let's take this a step further. What would it take to turn the function above into CPS? To do it, you have to think backwards just a little. You'll see what I mean.
let length_cps lst =
    let rec length_cont l cont =
        match l with
        | [] -> cont 0
        | _ :: t -> length_cont t (fun x -> cont (1 + x))
    length_cont lst (fun x -> x)
public static int LengthCont<T>(this IImmutableList<T> list)
{
    Func<IImmutableList<T>, Func<int, int>, int> length_cont = null;
    length_cont = (l, cont) =>
        l.IsEmpty ? cont(0) : length_cont(l.Tail, x => cont(1 + x));
    return length_cont(list, x => x);
}
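To see where the "inside out" feeling comes from, it can help to expand a call by hand. The sketch below repeats the F# `length_cont` and then spells out the continuations that `length_cont [10; 20; 30] id` builds; the names `k0` through `k3` are hypothetical labels for the intermediate continuations.

```fsharp
// length_cont as defined above (F# version).
let rec length_cont l cont =
    match l with
    | [] -> cont 0
    | _ :: t -> length_cont t (fun x -> cont (1 + x))

// Hand expansion of length_cont [10; 20; 30] id.
// Each step wraps the previous continuation in another "add 1".
let k0 : int -> int = id        // what to do with the final length
let k1 = fun x -> k0 (1 + x)    // built while skipping 10
let k2 = fun x -> k1 (1 + x)    // built while skipping 20
let k3 = fun x -> k2 (1 + x)    // built while skipping 30; list is now empty

// At [] the function calls k3 0, which unwinds as k2 1, k1 2, k0 3.
let expanded = k3 0
let direct = length_cont [10; 20; 30] id
printfn "%d %d" expanded direct
```

The additions that direct recursion performs on the way back up the stack are now stored in a chain of closures and run only when the empty list is reached.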
As you can see, CPS turns the function inside out: instead of adding 1 after the recursive call returns, each step wraps the current continuation in a new one that adds 1 once the rest of the list has been counted. It's enough to make your head hurt sometimes. Let's look at one last example of turning a function inside out, and yes, I lied about not bringing the Fibonacci sequence in here:
let fibonacci_cps n =
    let rec fibonacci_cont a cont =
        if a <= 2 then cont 1
        else
            fibonacci_cont (a - 2) (fun x ->
                fibonacci_cont (a - 1) (fun y ->
                    cont (x + y)))
    fibonacci_cont n (fun x -> x)
static int FibonacciCont(int n)
{
    Func<int, Func<int, int>, int> fibonacci_cont = null;
    fibonacci_cont = (a, cont) => a <= 2
        ? cont(1)
        : fibonacci_cont(a - 2, x => fibonacci_cont(a - 1, y => cont(x + y)));
    return fibonacci_cont(n, x => x);
}
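A quick way to convince yourself that the inside-out version still computes the same thing is to check it against plain binary recursion. Below is a small sketch; `fibonacci_cps` is repeated so the snippet is self-contained, and the direct `fibonacci` function and the comparison over 1 to 20 are hypothetical test scaffolding, not part of the original samples.

```fsharp
// fibonacci_cps, repeated here so the snippet is self-contained.
let fibonacci_cps n =
    let rec fibonacci_cont a cont =
        if a <= 2 then cont 1
        else
            fibonacci_cont (a - 2) (fun x ->
                fibonacci_cont (a - 1) (fun y ->
                    cont (x + y)))
    fibonacci_cont n id

// Plain binary recursion, used only as a reference for comparison.
let rec fibonacci n =
    if n <= 2 then 1 else fibonacci (n - 1) + fibonacci (n - 2)

// Both versions should agree on every input we try.
let agree = [1 .. 20] |> List.forall (fun n -> fibonacci_cps n = fibonacci n)
let f10 = fibonacci_cps 10
printfn "%b %d" agree f10
```

Note that the continuation for the first call (`fun x -> ...`) itself launches the second recursive call, which is how CPS serializes the two branches of a binary recursion.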
This by itself is quite interesting. The problem, though, is that C# still doesn't optimize tail calls, so I see no benefit to writing it this way in C#. Also, CPS tends to be an order of magnitude slower than standard recursion and even accumulator-style tail calls, since every step allocates a closure. In F#, though, tail calls are optimized, so the CPS version compiles to well-formed code that doesn't grow the stack. And there are cases where you should prefer CPS over standard tail calls. One such case is traversing unbalanced trees.
Traversing Unbalanced Trees
Tree structures are a lot harder to handle with tail recursion than linear data such as lists: each node has two subtrees, and only one of the two recursive calls can be in tail position. This is where CPS comes to the rescue. For these examples, I'm only going to use F#, because C# lacks the tail call support this technique relies on. First, let's define the tree structure we're going to traverse and a binary recursive function that calculates the size of our tree (here, the number of leaves).
type Tree<'a> =
    | Node of 'a * Tree<'a> * Tree<'a>
    | Leaf of 'a
let rec tree_size = function
    | Leaf _ -> 1
    | Node(_, left, right) -> tree_size left + tree_size right
The problem with this function is that neither recursive call is a tail call, so on large, unbalanced trees we risk a stack overflow. So, let's try moving it toward tail calls to fix the issue.
let tree_size_tail tree =
    let rec size_acc tree acc =
        match tree with
        | Leaf _ -> 1 + acc
        | Node(_, left, right) ->
            let acc = size_acc left acc
            size_acc right acc
    size_acc tree 0
As written, only the call on the right subtree is a tail call; the call on the left subtree is not. That might be acceptable if the tree were skewed to the right, but if it's skewed to the left, we'll still have a problem. This is where CPS can help. Let's apply what we learned above and modify the function to use CPS along with the accumulator.
let tree_size_cont tree =
    let rec size_acc tree acc cont =
        match tree with
        | Leaf _ -> cont (1 + acc)
        | Node(_, left, right) ->
            size_acc left acc (fun left_size ->
                size_acc right left_size cont)
    size_acc tree 0 (fun x -> x)
Here's what this version does:
- It creates an inner function that takes the tree, an accumulator, and a continuation function as input.
- When it reaches a leaf, it passes 1 plus the accumulator to the continuation.
- Otherwise, it recurses into the left subtree, passing along a new continuation that will compute the right subtree's size once the left size is known.
- That continuation calls the function on the right subtree, using the left size as the accumulator and passing the original continuation through to receive the final result.
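To try the CPS version on the problem case, a left-skewed tree, you need some way to build one. The sketch below assumes a hypothetical helper, `left_skewed`, that builds a left spine `depth` nodes deep using `List.fold`, so constructing the tree doesn't recurse. The type and `tree_size_cont` are repeated so the snippet stands alone.

```fsharp
type Tree<'a> =
    | Node of 'a * Tree<'a> * Tree<'a>
    | Leaf of 'a

// tree_size_cont from above.
let tree_size_cont tree =
    let rec size_acc tree acc cont =
        match tree with
        | Leaf _ -> cont (1 + acc)
        | Node(_, left, right) ->
            size_acc left acc (fun left_size ->
                size_acc right left_size cont)
    size_acc tree 0 id

// Hypothetical helper: each step wraps the tree built so far as the LEFT
// child of a new node, so the left spine is `depth` nodes deep and the
// tree has depth + 1 leaves. Uses a fold, not recursion, to build it.
let left_skewed depth =
    List.fold (fun acc i -> Node(i, acc, Leaf i)) (Leaf 0) [1 .. depth]

let small = tree_size_cont (left_skewed 3)
let deeper = tree_size_cont (left_skewed 1000)
printfn "%d %d" small deeper
```

With tail calls enabled (typically the case for Release builds), the same call should also cope with far deeper left spines, because the pending work lives in the chain of closures rather than on the stack.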
Using this approach, we don't face the threat of stack overflows: every call is now a tail call, and each Node allocates only one short-lived continuation, which lives on the heap rather than the stack. This is a fairly advanced technique and isn't used much in most languages, but once you understand how continuations work and what they're for, it's very powerful.
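The same pattern carries over to other binary recursions on the tree. As a hypothetical example (not part of the original samples), here is the height of a tree in CPS. Height combines the two subtree results with `max` rather than adding to a running total, so instead of threading an accumulator, the continuations nest the way they did in the Fibonacci example.

```fsharp
type Tree<'a> =
    | Node of 'a * Tree<'a> * Tree<'a>
    | Leaf of 'a

// Hypothetical example: tree height written in CPS.
// The inner continuation receives both subtree heights and combines them.
let tree_height_cont tree =
    let rec height tree cont =
        match tree with
        | Leaf _ -> cont 1
        | Node(_, left, right) ->
            height left (fun lh ->
                height right (fun rh ->
                    cont (1 + max lh rh)))
    height tree id

let t = Node(1, Node(2, Leaf 3, Leaf 4), Leaf 5)
let h = tree_height_cont t
printfn "%d" h
```

Here the left subtree has height 2 and the right has height 1, so the root's continuation receives 1 + max 2 1.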
Wrapping it Up
I hope this brief introduction to CPS and recursion was interesting. If you'd like to know more about continuations in general, Wes Dyer of the Volta team has written a good explanation of CPS and its various uses. From here, it's best to move on to monads (known as computation expressions in F#), especially for asynchronous processing. As always, my code samples are available through the MSDN Code Gallery in the Functional C# library.