Tail call functions in Rust I
Ivan Molina Rebolledo- 4 min read
Motivation
Recursive functions are often a great way of writing declarative programs, and they allow for greater expressiveness when writing code. However, they come with a catch: slowness and the risk of stack overflow. The so-called tail call functions have a specific form that can be transformed into plain jump instructions using "Tail Call Optimization" (TCO). However, this is not exactly an article about TCO or recursive functions, but rather one about a way of describing tail-recursive algorithms in Rust (which lacks proper TCO).
The Closure-Enum way
I found this algorithm that @10maurycy10 wrote. It only uses a sum type and a helper function.
enum TCO<A, B> {
    Rec(A),
    Ret(B),
}
fn tco<A, B>(f: fn(A) -> TCO<A, B>) -> impl Fn(A) -> B {
    move |p: A| {
        let mut c: TCO<A, B> = TCO::Rec(p);
        loop {
            match c {
                TCO::Rec(i) => c = f(i),
                TCO::Ret(b) => return b,
            }
        }
    }
}
We have a type \(\textrm{TCO<A, B>}\) with two constructors: \(\textrm{Rec(A)}\) and \(\textrm{Ret(B)}\). The first, \(\textrm{Rec(A)}\), stores the arguments for the next iteration, while \(\textrm{Ret(B)}\) carries the actual return value of the function. The biggest downside of this method is that we need to mark the return and recursion points explicitly.
The function \(\textrm{tco}\) contains the actual loop. The interesting part is that the function we want to behave as a proper tail call is not recursive at all: every place where we would normally make a recursive call becomes a call to the \(\textrm{Rec(A)}\) constructor. It is just fake recursion.
As we can see in the code, the loop runs until \(c\) is no longer a \(\textrm{Rec(A)}\) holding the next arguments. Every time we get a \(\textrm{Rec(A)}\), we call the function again with those new values; when we get a \(\textrm{Ret(B)}\), we return its contents.
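To make the mechanism concrete, here is a small sketch that repeats the definitions above so it compiles on its own. The function sum_to is a made-up example (it is not part of the original algorithm): it adds the numbers from 1 to n with an accumulator, and each "recursive call" is only a Rec value handed back to the loop.

```rust
// Trampoline definitions repeated from the article so the sketch stands alone.
enum TCO<A, B> {
    Rec(A), // "call again" with these arguments
    Ret(B), // final result
}

fn tco<A, B>(f: fn(A) -> TCO<A, B>) -> impl Fn(A) -> B {
    move |p: A| {
        let mut c: TCO<A, B> = TCO::Rec(p);
        loop {
            match c {
                TCO::Rec(i) => c = f(i),
                TCO::Ret(b) => return b,
            }
        }
    }
}

// Hypothetical example: sum of 1..=n, carried in an accumulator.
fn sum_to(n: u64) -> u64 {
    tco(|(n, acc): (u64, u64)| match n {
        0 => TCO::Ret(acc),              // return point
        _ => TCO::Rec((n - 1, acc + n)), // recursion point
    })((n, 0))
}

fn main() {
    // One million steps, each one an ordinary call that returns
    // before the next begins, so the stack depth stays constant.
    assert_eq!(sum_to(1_000_000), 500_000_500_000);
    println!("{}", sum_to(1_000_000));
}
```

A directly recursive version of the same function would need one stack frame per step unless the compiler happened to optimize the call away, which Rust does not promise.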
Example I: Fibonacci as closure
It is easy to write recursive closures this way. We define the tail-recursive Fibonacci as follows:
fn main() {
    let fib = tco(|s: (u64, u64, u64)| {
        match s {
            (0, i, _) => TCO::Ret(i),
            (1, _, j) => TCO::Ret(j),
            (2, i, j) => TCO::Ret(i + j),
            // u64 is Copy, so no clone is needed here
            (c, i, j) => TCO::Rec((c - 1, j, i + j))
        }
    });
    assert_eq!(21, fib((8, 0, 1)));
}
This works as expected.
Example II: Fibonacci as function
fn fib(c: u64, i: u64, j: u64) -> u64 {
    tco(|s: (u64, u64, u64)| {
        match s {
            (0, i, _) => TCO::Ret(i),
            (1, _, j) => TCO::Ret(j),
            (2, i, j) => TCO::Ret(i + j),
            // u64 is Copy, so no clone is needed here
            (c, i, j) => TCO::Rec((c - 1, j, i + j))
        }
    })((c, i, j))
}
fn main() {
    assert_eq!(21, fib(8, 0, 1));
}
And this also works as expected. It is a little more verbose than using the closure directly, but it can be used in situations where closures are not an option.
Analysing assembly output
So I decided to test the second version with the following:
use std::env;

fn main() {
    let args: Vec<String> = env::args().collect();
    // args[0] is the program name; the first real argument is args[1]
    let i = args[1].parse::<u64>().unwrap();
    let val = fib((8u64, 0u64, 1u64));
    let val2 = fib((21u64, 0u64, 1u64));
    let val3 = fib((i, 0u64, 1u64));
    println!("Hello, world!, {} {} {}", val, val3, val2);
}
There is a difference here. Instead of taking individual arguments (as in the example), this fib takes the tuple and passes it as-is to the closure. I do not expect this to have an impact, but I could be wrong.
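The tuple-taking variant is not shown above, so here is a sketch of what it could look like. This is a reconstruction under that assumption, not the exact code that was compiled; the main function also falls back to a made-up default of 10 when no argument is given, so the sketch runs without arguments.

```rust
use std::env;

enum TCO<A, B> {
    Rec(A),
    Ret(B),
}

fn tco<A, B>(f: fn(A) -> TCO<A, B>) -> impl Fn(A) -> B {
    move |p: A| {
        let mut c: TCO<A, B> = TCO::Rec(p);
        loop {
            match c {
                TCO::Rec(i) => c = f(i),
                TCO::Ret(b) => return b,
            }
        }
    }
}

// Assumed shape of the tested function: the tuple goes straight to the closure.
fn fib(s: (u64, u64, u64)) -> u64 {
    tco(|s: (u64, u64, u64)| match s {
        (0, i, _) => TCO::Ret(i),
        (1, _, j) => TCO::Ret(j),
        (2, i, j) => TCO::Ret(i + j),
        (c, i, j) => TCO::Rec((c - 1, j, i + j)),
    })(s)
}

fn main() {
    // A runtime value keeps the compiler from folding this call away.
    // The default of 10 is an assumption for illustration only.
    let i: u64 = env::args()
        .nth(1)
        .and_then(|a| a.parse().ok())
        .unwrap_or(10);
    println!("{} {} {}", fib((8, 0, 1)), fib((i, 0, 1)), fib((21, 0, 1)));
}
```

The two constant calls are the ones the compiler can evaluate ahead of time; only the call that depends on the command-line argument has to run the loop.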
Let us take a look into the assembly:
Ltmp14:
testb $1, %al
jne LBB10_47
movq $21, -232(%rbp)
movq $10946, -88(%rbp)
movl $1, %ecx
xorl %eax, %eax
testq %rdx, %rdx
je LBB10_54
.p2align 4, 0x90
LBB10_50:
cmpq $1, %rdx
je LBB10_51
cmpq $2, %rdx
je LBB10_53
decq %rdx
movq %rax, %rsi
addq %rcx, %rsi
movq %rcx, %rax
movq %rsi, %rcx
testq %rdx, %rdx
jne LBB10_50
jmp LBB10_54
.p2align 4, 0x90
This is the interesting part of the whole file.
The first thing we can notice is that the first two fib invocations are already computed.
movq $21, -232(%rbp)
movq $10946, -88(%rbp)
So the compiler has already evaluated these two calls at compile time, which is great.
But we also have the block at the \(\textrm{LBB10\_50}\) label. This is the actual loop of the Fibonacci algorithm. And, as you can see, it looks just as imperative as an iterative approach. These are the steps:
- Is the counter at 1?
- Return if yes.
- Is the counter at 2?
- Return if yes.
- Decrement the counter by 1.
- Move the first accumulator to a temporary variable (\(\textrm{let}\ \mathrm{tmp} = i\))
- Add the second accumulator to the temporary variable (\(\mathrm{tmp} += j\))
- Move the second accumulator to the position of the first (\(i = j\))
- Move the temporary variable to the position of the second accumulator (\(j = \mathrm{tmp}\))
- Is the counter at 0?
- Loop to the start if no.
We can easily write this as a for loop.
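As a sketch, the steps above map onto plain Rust like this. The names fib_steps and fib_for are made up for illustration, and this is a hand-written reading of the assembly, not something the compiler produced. fib_steps follows the listed steps one by one, while fib_for is the compact loop one would normally write.

```rust
// Follows the listed steps: c is the counter, i and j are the accumulators.
fn fib_steps(mut c: u64, mut i: u64, mut j: u64) -> u64 {
    if c == 0 {
        return i; // the zero check done before entering the loop
    }
    loop {
        if c == 1 {
            return j;
        }
        if c == 2 {
            return i + j;
        }
        c -= 1;
        let tmp = i + j; // tmp = i; tmp += j
        i = j;
        j = tmp;
        // The assembly re-tests the counter against 0 here. It cannot be 0,
        // because it was greater than 2 before the decrement.
    }
}

// The same computation as an ordinary for loop, starting from (0, 1).
fn fib_for(n: u64) -> u64 {
    let (mut i, mut j) = (0u64, 1u64);
    for _ in 0..n {
        let tmp = i + j;
        i = j;
        j = tmp;
    }
    i
}

fn main() {
    assert_eq!(fib_steps(8, 0, 1), 21);
    assert_eq!(fib_for(21), 10946);
    println!("{} {}", fib_steps(8, 0, 1), fib_for(21));
}
```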
Is this really effective?
Yes and no.
It works, and that is great. But I should remind you that this test used only 8-byte values. Chances are that if you are calculating Fibonacci numbers, you will need much larger numbers, and thus a bigint library.
And I did try it with a bigint library. It was slower than the normal iterative approach (but I was using a function from std to copy the data, so the comparison may have been unfair). Still, the performance was really close, even if not the same.
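To show where 8-byte values stop being enough, here is a sketch that uses checked_add from the standard library instead of a bigint crate. It is not the bigint experiment described above. The function fib_checked is a made-up name, and returning None on overflow is just one possible design.

```rust
enum TCO<A, B> {
    Rec(A),
    Ret(B),
}

fn tco<A, B>(f: fn(A) -> TCO<A, B>) -> impl Fn(A) -> B {
    move |p: A| {
        let mut c: TCO<A, B> = TCO::Rec(p);
        loop {
            match c {
                TCO::Rec(i) => c = f(i),
                TCO::Ret(b) => return b,
            }
        }
    }
}

// Hypothetical variant: Some(F(n)) if it fits in a u64, None otherwise.
fn fib_checked(n: u64) -> Option<u64> {
    tco(|(c, i, j): (u64, u64, u64)| match c {
        0 => TCO::Ret(Some(i)),
        1 => TCO::Ret(Some(j)),
        _ => match i.checked_add(j) {
            Some(next) => TCO::Rec((c - 1, j, next)),
            None => TCO::Ret(None), // the sum no longer fits in 64 bits
        },
    })((n, 0, 1))
}

fn main() {
    assert_eq!(fib_checked(8), Some(21));
    assert!(fib_checked(93).is_some()); // still fits in a u64
    assert!(fib_checked(94).is_none()); // overflows a u64
    println!("{:?} {:?}", fib_checked(93), fib_checked(94));
}
```

Past that point a wider integer type or a bigint library is needed, and the tuple then holds values that have to be moved or copied explicitly instead of being plain Copy integers.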
Future improvements
I have some cool ideas about this. Anyways, thanks for reading.