Recursion

Recursion is a general problem-solving technique that means breaking a problem down into smaller versions of itself. Presumably there is some “smallest” or “simplest” version of the problem that does not need to be broken down, where a solution can be found directly. If every “bigger” instance of the problem can (eventually) be broken down into the smallest instance, then logically any instance of the problem can be solved this way.

A simple example is computing the sum of the numbers from 1 to n:

$$\sum_{i=1}^n i = 1 + 2 + \cdots + n$$

A loop-based solution is fairly easy to build:

int sum(int n) {
    int s = 0;
    for(int i = 1; i <= n; ++i)
        s += i;

    return s;
}

To construct a recursive solution, we must think about how to break the problem of finding sum(n) into some sum(n') where \(n' < n\). Writing the sum out, we have

$$\mathtt{sum}(n) = 0 + 1 + 2 + \cdots + (n-1) + n$$

What is \(0 + 1 + 2 + \cdots + (n-1)\)? In fact,

$$0 + 1 + 2 + \cdots + (n-1) = \mathtt{sum}(n-1)$$

so we can substitute this into our original definition to get

$$\mathtt{sum}(n) = \mathtt{sum}(n-1) + n$$

We still need to specify some “smallest” version of sum that does not rely on itself. Since we specified that the sum starts at 1, we can say

$$\mathtt{sum}(n) = 0 \qquad\text{if}\; n < 1$$

Translating this into C++ we get

int sum(int n) {
    if(n < 1)
        return 0;
    else
        return n + sum(n-1);
}

This definition of sum is identical, in behavior, to the loop-based one.

Tracing recursive functions

It can be difficult to visualize what a recursive function is doing, so to help with this we can draw the recursion tree of a given function call. Suppose we want to trace through the function call sum(3). First, we work down through the recursion, like so:

sum(3) = 3 + sum(2)
         sum(2) = 2 + sum(1)
                  sum(1) = 1 + sum(0)
                           sum(0) = 0          (base case)

Once we reach the simple case, we work our way back up, until we get back to the original function call:

                           sum(0) = 0
                  sum(1) = 1 + 0 = 1
         sum(2) = 2 + 1 = 3
sum(3) = 3 + 3 = 6

Terminology

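In a recursive function, several components are important: the base case, where the function returns an answer directly without calling itself; the recursive case, where the function calls itself on a smaller version of the problem; and the parameter of recursion, the parameter that gets smaller with each recursive call and whose value decides which case applies. In sum, for example, the base case is \(n < 1\), the recursive case is n + sum(n-1), and the parameter of recursion is n.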

When you look at a recursive function, identifying the base case, recursive case, and parameter of recursion can help you figure out what the function is doing. When you go to write a recursive function, identifying these three factors will give you a blueprint for the structure of the function.

If, when we trace through a recursive function, the trace comes out as a single straight line, we call that a linearly recursive function; this means that each call to the function makes at most one recursive call. A function that makes more than one recursive call is nonlinearly recursive.

Some linearly recursive functions have the property that there is no work left to do on the way back up the trace; once we get to the bottom, we can return the result directly to the first call. A function like this is called tail recursive. Tail recursive functions can be transformed into a single loop. Some programming languages guarantee this transformation; the C++ standard does not, although optimizing compilers often perform it (tail-call optimization) in certain circumstances.
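As an illustration (a sketch; the names sum_tail and sum_loop are hypothetical, not from these notes): the recursive sum above is not tail recursive, because after sum(n-1) returns it still has to add n. Moving the running total into an extra accumulator parameter makes the recursive call the very last thing the function does, and then the call can be replaced by updating the parameters and jumping back to the top of a loop.

```cpp
// Tail-recursive sum: the running total travels down in acc,
// so there is nothing left to do on the way back up.
int sum_tail(int n, int acc = 0) {
    if(n < 1)
        return acc;                       // bottom reached: acc already holds the answer
    else
        return sum_tail(n - 1, acc + n);  // the recursive call is the last action
}

// The same computation with the tail call turned into a loop:
// each "call" just overwrites the parameters and goes around again.
int sum_loop(int n) {
    int acc = 0;
    while(!(n < 1)) {
        acc = acc + n;   // uses the old n, like the argument acc + n
        n = n - 1;       // like the argument n - 1
    }
    return acc;
}
```

Both versions compute the same values as the original sum, e.g. 6 for n = 3.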

Efficiency of recursive functions

Sometimes a recursive function can hide a gross inefficiency. Consider the Fibonacci sequence. This sequence is formed by starting with 1, 1 and then adding the previous two numbers to get the next number. Mathematically, it is defined as

$$F[0] = F[1] = 1$$ $$F[n] = F[n-1] + F[n-2]\qquad n \ge 2$$

Translating this into C++ gives us this function

int fib(int n) {
    if(n == 0 || n == 1)
        return 1;
    else
        return fib(n-1) + fib(n-2);
}

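But if we trace through fib(4), something terrible happens:

fib(4)
├── fib(3)
│   ├── fib(2)
│   │   ├── fib(1)
│   │   └── fib(0)
│   └── fib(1)
└── fib(2)
    ├── fib(1)
    └── fib(0)

The same values are computed over and over: fib(2) twice, fib(1) three times, and fib(0) twice, for 9 calls in all.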

The time complexity of a recursive function is determined by the number of function calls (when each call does only a constant amount of work of its own), i.e., by the size of its recursion tree. For this version of fib, the tree size, and hence the complexity, grows as \(O(2^n)\).

The trouble with this implementation is that we are going in the wrong direction: we work down from fib(n), recomputing the same smaller values over and over, instead of starting from fib(0) and fib(1) and working our way up. We can create an \(O(n)\) recursive implementation, but it’s no longer quite so simple:

int fib(int n, int p1 = 1, int p2 = 1) {
  if(n == 0)
      return p2;
  else if(n == 1)
      return p1;
  else
      return fib(n-1, p1 + p2, p1); 
}

p1 and p2 store the previous two values in the sequence. When we call fib(n), they are initialized to 1,1, but thereafter we set p1 to the sum (i.e., it becomes the next entry in the sequence) and we shift the old value of p1 down into p2. If we trace through this implementation, we’ll see that it’s now a straight line, rather than a tree:

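fib(4, 1, 1) calls fib(3, 2, 1)
fib(3, 2, 1) calls fib(2, 3, 2)
fib(2, 3, 2) calls fib(1, 5, 3)
fib(1, 5, 3) returns p1 = 5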

It’s easy to accidentally inject unnecessary branching into a recursive function, multiplying the number of calls it makes. Consider the integer power function:

$$\texttt{pow}(f,p) = f^p\qquad p \ge 0$$

Earlier we used a scheme in which, if \(p\) was even, we split the work in half by squaring \(f^{p/2}\). Written recursively, it looks like this:

float pow(float f, int p) {
    if(p == 0)
        return 1;
    else if(p == 1)
        return f;
    else if(p % 2 == 0)
        return pow(f, p/2) * pow(f, p/2);
    else
        return f * pow(f, p-1);
}

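Remember that the whole point of this was to turn pow from an \(O(p)\) algorithm into an \(O(\log p)\) algorithm. But if we trace out pow(f, 8):

level 0: pow(f,8)                      1 call
level 1: pow(f,4), pow(f,4)            2 calls
level 2: pow(f,2), four times          4 calls
level 3: pow(f,1), eight times         8 calls

That is 15 calls, and the count doubles at every level. The trace is only about \(\log_2 p\) levels deep, but the bottom level alone has \(2^{\log_2 p} = p\) calls, so we are back to \(O(p)\): the branching throws away everything we gained by halving \(p\). The culprit is the line pow(f, p/2) * pow(f, p/2); by computing the same value twice, we force the trace to branch when it’s not necessary. Instead, we should save the return value of pow(f, p/2) and then square that: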

float pow(float f, int p) {
    if(p == 0)
        return 1;
    else if(p == 1)
        return f;
    else if(p % 2 == 0) {
        float f2 = pow(f, p/2);
        return f2 * f2;
    }
    else
        return f * pow(f, p-1);
}

Tracing through this, we find the \(O(\log p)\) behavior we were expecting.
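One way to see the difference concretely is to count calls. The sketch below copies both versions of pow under hypothetical names (pow_branching for the one that calls pow(f, p/2) twice, pow_shared for the one that saves the result) and increments a global counter, calls, on every call; none of these names come from the notes.

```cpp
// Global call counter shared by both instrumented versions.
long calls = 0;

// Copy of the first version: computes pow(f, p/2) twice when p is even.
float pow_branching(float f, int p) {
    ++calls;
    if(p == 0)
        return 1;
    else if(p == 1)
        return f;
    else if(p % 2 == 0)
        return pow_branching(f, p/2) * pow_branching(f, p/2);  // two calls
    else
        return f * pow_branching(f, p-1);
}

// Copy of the corrected version: computes pow(f, p/2) once and reuses it.
float pow_shared(float f, int p) {
    ++calls;
    if(p == 0)
        return 1;
    else if(p == 1)
        return f;
    else if(p % 2 == 0) {
        float f2 = pow_shared(f, p/2);  // one call, result reused
        return f2 * f2;
    }
    else
        return f * pow_shared(f, p-1);
}
```

With f = 1 (so the result can't overflow) and p = 1024, the branching version makes 2047 calls and the shared version makes 11, matching the \(O(p)\) versus \(O(\log p)\) analysis.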

Recursive data and recursive functions

There are actually two kinds of “recursion” available to us: recursive functions, and recursive data types. Technically, the two are linked together; we can only write a recursive function if we have (or can fake) some kind of recursive data, and if we have a recursive data type, then it will require recursive functions to process.

In the above, we “faked” a recursive natural number type using int. The recursive structure of this pseudo-data type can be described like this:

Every natural number (\(n \ge 0\)) is either \(0\) (the base case), or \(1 + n'\) for some smaller natural number \(n'\) (the recursive case).

Using this definition we can break down any number into its “1+” form:

$$4 = 1 + (1 + (1 + (1 + (0))))$$

The two cases, base and recursive, cover the two possible “formations” of a number: either 0 (i.e., \(n < 1\)) or \(n \ge 1\). In the recursive case, we invert the 1 + by computing n - 1, giving us the next smaller number. If \(n \ge 1\), there is always a smaller natural number we can get to by computing \(n - 1\).

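Given this definition, any recursive function on natural numbers can be defined purely in terms of two cases: what it returns for \(0\), and what it returns for \(1 + n'\), computed using its result for \(n'\).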

All of our natural number functions will have two cases, one for zero and one for the 1+ case. The only difference is that we will usually write the second as n-1, because we start with n. (I.e., if \(n = n’+1\) then \(n’ = n - 1\).)

As an example, suppose that the normal arithmetic operators weren’t defined for us: + and * did not exist. Can we write them ourselves, using only the recursive definition?

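Look at \(a + b\). If \(a = 0\), then \(0 + b = b\). If \(a = 1 + a'\), then \(a + b = (1 + a') + b = 1 + (a' + b)\), which only needs the smaller sum \(a' + b\).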

Implemented as C++, this looks like this:

int plus(int a, int b) {
    if(a == 0)
        return b;
    else
        return 1 + plus(a-1,b);
}

Trace through plus(2,3).

What about multiplication? Once again, we consider what happens in the two cases in \(a * b\): if \(a = 0\), then \(0 * b = 0\); if \(a = 1 + a'\), then \(a * b = (1 + a') * b = b + (a' * b)\).

This gives us

int mult(int a, int b) {
    if(a == 0)
        return 0;
    else 
        return plus(b, mult(a-1, b));
}

(You might wonder why we don’t have a case for \(1 * b = b\); the answer is that it isn’t necessary: this breaks down into \(1 * b = b + 0 * b\) recursively, which gives the correct answer.)

In normal arithmetic, we know that \(a + 0 = a\), but this is not obvious from the definition we gave above. Can we show that a similar property holds for our plus, i.e., that plus(a, 0) = a? Yes, fairly easily:
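A sketch of that argument, by induction on \(a\) and using only the two cases of plus:

Case \(a = 0\): the base case of plus returns its second argument, so
$$\mathtt{plus}(0, 0) = 0$$

Case \(a \ge 1\): assume (the inductive hypothesis) that \(\mathtt{plus}(a-1, 0) = a-1\). Then the recursive case of plus gives
$$\mathtt{plus}(a, 0) = 1 + \mathtt{plus}(a-1, 0) = 1 + (a-1) = a$$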

Another way to think of this is that for any given value of \(a\), you could break it down into a series of steps showing that \(a + 0 = a\) because \((a-1) + 0 = a-1\) because … because \(0 + 0 = 0\) at the very end. There would be exactly as many steps as the magnitude of \(a\). This is known as proof by induction: showing that some property holds “for all the things” by showing how to break big things down into smaller ones.

Another property of + is associativity: \(a+(b+c) = (a+b)+c\). Does this property hold for our plus? Here we aim to show that

plus(a,plus(b, c)) = plus(plus(a,b),c)

Once again, we have two cases, based on the possible values for a:
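A sketch of both cases, using only the base case \(\mathtt{plus}(0, b) = b\) and the recursive case \(\mathtt{plus}(a, b) = 1 + \mathtt{plus}(a-1, b)\) for \(a \ge 1\):

Case \(a = 0\): applying the base case on each side,
$$\mathtt{plus}(0, \mathtt{plus}(b,c)) = \mathtt{plus}(b,c) = \mathtt{plus}(\mathtt{plus}(0,b), c)$$

Case \(a \ge 1\): assume \(\mathtt{plus}(a-1, \mathtt{plus}(b,c)) = \mathtt{plus}(\mathtt{plus}(a-1,b), c)\). Then
$$\begin{aligned}
\mathtt{plus}(a, \mathtt{plus}(b,c)) &= 1 + \mathtt{plus}(a-1, \mathtt{plus}(b,c)) \\
&= 1 + \mathtt{plus}(\mathtt{plus}(a-1,b), c) \\
&= \mathtt{plus}(1 + \mathtt{plus}(a-1,b), c) \\
&= \mathtt{plus}(\mathtt{plus}(a,b), c)
\end{aligned}$$

The first line is the recursive case of plus, the second is the inductive hypothesis, the third is the recursive case read backwards (its first argument \(1 + \mathtt{plus}(a-1,b)\) is at least 1), and the last is the recursive case again, applied to \(\mathtt{plus}(a,b)\).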

Complexity and inductive proofs

The technique of inductive proof is also useful for verifying a guess about the complexity (big-O) of an algorithm. Take, for example, the \(O(2^n)\) fib function above. Suppose we want to prove that it really takes \(O(2^n)\) steps. We divide this into two cases:
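A sketch of those two cases, taking the \(O(2^n)\) function to be the first (naive) fib and writing \(C(n)\) (a name introduced here) for the total number of calls that fib(n) makes. We show \(C(n) \le 2^n\).

Base cases (\(n = 0\) or \(n = 1\)): fib returns immediately, so
$$C(0) = C(1) = 1 \le 2^n$$

Inductive case (\(n \ge 2\)): fib(n) is one call plus everything done by fib(n-1) and fib(n-2). Assuming \(C(n-1) \le 2^{n-1}\) and \(C(n-2) \le 2^{n-2}\),
$$C(n) = 1 + C(n-1) + C(n-2) \le 1 + 2^{n-1} + 2^{n-2} \le 2^{n-2} + 2^{n-1} + 2^{n-2} = 2^n$$
where the second \(\le\) uses \(1 \le 2^{n-2}\) for \(n \ge 2\). Because this case relies on both \(n-1\) and \(n-2\), it is strictly a strong induction.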

Recursive Strings

Most of our exercises are going to be done on a type of recursive string. A recursive string is defined as either the empty string "", or a character placed on the front of a recursive string.

We call the action of building a string by putting a character onto the front of an existing string “consing a string”.

Internally, we implement recursive strings as a wrapper around string:

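#include <string>

class rstring {
  public:
    // Construct from a std::string, or directly from a string literal
    // (the second constructor is what lets us write cons(c, "") or return "").
    rstring(std::string s = "") : source(s) {}
    rstring(const char* s) : source(s) {}

    // An rstring tests as true when it is non-empty.
    explicit operator bool() const { return !source.empty(); }
    operator std::string() const { return source; }

    // Precondition for both: the rstring is non-empty.
    char    front() const { return source.front(); }
    rstring rest()  const { return rstring(source.substr(1)); } // everything after the first char

    friend rstring cons(char c, rstring rest);

  private:
    std::string source;
};

// Build a new rstring with c placed in front of rest.
inline rstring cons(char c, rstring rest) {
    return rstring(c + rest.source);
}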

(You can save the above as rstring.hpp if you want to try the examples.)

We can check a recursive string for empty-ness by just using it as a bool:

rstring s = ...;
if(s) 
    // s is not empty

If an rstring is non-empty, then we can extract its .front() and .rest(). We can also cons together a new rstring from a character and an existing rstring.
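A quick way to check these operations is a tiny test on a literal. This sketch repeats the wrapper so it compiles on its own, and it assumes two details: rest() uses substr(1) (substr(1,0) would always return an empty string), and a const char* constructor lets literals such as "cat" convert directly to rstring.

```cpp
#include <string>

class rstring {
  public:
    rstring(std::string s = "") : source(s) {}
    rstring(const char* s) : source(s) {}       // lets "cat" convert directly

    explicit operator bool() const { return !source.empty(); }  // true if non-empty
    operator std::string() const { return source; }

    char    front() const { return source.front(); }            // requires non-empty
    rstring rest()  const { return rstring(source.substr(1)); } // drop the first char

    friend rstring cons(char c, rstring rest);

  private:
    std::string source;
};

inline rstring cons(char c, rstring rest) {
    return rstring(c + rest.source);
}
```

For s = "cat", s.front() is 'c', s.rest() is "at", and cons('b', s.rest()) is "bat"; taking .rest() three times leaves the empty rstring, which tests as false.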

String operations

Let’s implement some standard string operations on rstrings.

Length

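How do we find the length of an rstring? Again, there are only two situations to consider: the string is empty, or it is non-empty, in which case it consists of a .front() character followed by a .rest().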

Starting out, we have

int length(rstring r) {
    if(!r)
        return 0;
    else
        return ...;
}

Recursively, we will define the length of a string s to be \(1 +\) the length of its rest, giving us

int length(rstring r) {
    if(!r)
        return 0;
    else
        return 1 + length(r.rest());
}

We can verify that this works by tracing through it: length("cat").

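length("cat")
= 1 + length("at")
= 1 + (1 + length("t"))
= 1 + (1 + (1 + length("")))
= 1 + (1 + (1 + 0))
= 3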

You can think of length as just being a process that converts empty into 0 and cons into 1+. (Note, also, that we’re only using 1+ here, not the general plus.)

String concatenation

String concatenation is the process of putting two strings together, end to end. For example the concatenation of "Hello" and "World" is "HelloWorld". How do we do string concatenation recursively?

Suppose we have the cons consisting of 'H' and "ello", and we want to concatenate it with "World". The only primitive operation we have available is putting a character on the front of a string. If we already had the concatenation of "ello" and "World", "elloWorld", then we could just cons 'H' onto the front of it and we’d be done. But that is exactly what the recursive call concat("ello", "World") gives us.

Thus we have

rstring concat(rstring a, rstring b) {
    if(!a)
        return b;
    else
        return cons(a.front(), concat(a.rest(), b));
}

Tracing through concat("cat", "dog") gives us

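concat("cat", "dog")
= cons('c', concat("at", "dog"))
= cons('c', cons('a', concat("t", "dog")))
= cons('c', cons('a', cons('t', concat("", "dog"))))
= cons('c', cons('a', cons('t', "dog")))
= cons('c', cons('a', "tdog"))
= cons('c', "atdog")
= "catdog"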

We would expect that length(s1) + length(s2) == length(concat(s1,s2)). That is, that concat is to strings what + is to numbers. We can actually show that this is true, by using the definitions of +, length and concat together:
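A sketch of that argument, by induction on \(s_1\):

Case \(s_1\) is empty:
$$\mathtt{length}(s_1) + \mathtt{length}(s_2) = 0 + \mathtt{length}(s_2) = \mathtt{length}(s_2) = \mathtt{length}(\mathtt{concat}(s_1, s_2))$$
by the definition of length, the base case of plus, and the base case of concat.

Case \(s_1 = \mathtt{cons}(c, r)\): assume \(\mathtt{length}(r) + \mathtt{length}(s_2) = \mathtt{length}(\mathtt{concat}(r, s_2))\). Then
$$\begin{aligned}
\mathtt{length}(s_1) + \mathtt{length}(s_2) &= (1 + \mathtt{length}(r)) + \mathtt{length}(s_2) \\
&= 1 + (\mathtt{length}(r) + \mathtt{length}(s_2)) \\
&= 1 + \mathtt{length}(\mathtt{concat}(r, s_2)) \\
&= \mathtt{length}(\mathtt{cons}(c, \mathtt{concat}(r, s_2))) \\
&= \mathtt{length}(\mathtt{concat}(s_1, s_2))
\end{aligned}$$
justified, line by line, by the definition of length, the recursive case of plus (\((1 + x) + y = 1 + (x + y)\)), the inductive hypothesis, the definition of length again, and the definition of concat (since \(s_1\) has front \(c\) and rest \(r\)).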

Push-back

A cons is like a push_front; can we do a push_back? This is just a concat of a one-character rstring:

rstring cons_back(rstring s, char c) {
    return concat(s, cons(c, ""));
}

I’ve named this cons_back to emphasize that, unlike push_back, it constructs a new string, rather than modifying it in place.

Character indexing

Can we look up a character at a particular index, akin to .at(n) or [n]? This will be recursive on both the string and the number.

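#include <stdexcept>

char at(rstring s, int n) {
    if(!s)
        // Error! We ran off the end: n was too large (or negative).
        throw std::out_of_range("at: index past the end of the rstring");
    else if(n == 0)
        return s.front();
    else
        return at(s.rest(), n - 1);
}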

Reverse

We can use cons_back to recursively construct the reversal of a string:

rstring reverse(rstring s) {
    if(!s)
        return "";
    else
        // Reverse the rest, then put the first character at the back.
        return cons_back(reverse(s.rest()), s.front());
}

We would expect that if we reverse a string, its length stays the same; that is, that

$$\texttt{length}(s) = \texttt{length}(\texttt{reverse}(s))$$

Can we prove this? In order to do so, we first need to prove a lemma, a kind of sub-proof, which says

$$\texttt{length}(\texttt{cons_back}(s,c)) = 1 + \texttt{length}(s)$$

But because cons_back is defined in terms of concat, and we earlier proved that length(concat(s1,s2)) = length(s1) + length(s2), all we have to do is prove that length(cons(c, "")) = 1. By the definition of length:

length(cons(c,"")) =
1 + length("") = 
1 + 0 = 
1

and we’re done. We’ll call this the ConsBack1 lemma.

With that in place, we want to show that

$$\texttt{length}(s) = \texttt{length}(\texttt{reverse}(s))$$
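A sketch by induction on \(s\), using the ConsBack1 lemma:

Case \(s\) is empty: reverse returns the empty string, so both sides are the length of the empty string, i.e., 0.

Case \(s = \mathtt{cons}(c, r)\): assume \(\mathtt{length}(r) = \mathtt{length}(\mathtt{reverse}(r))\). Then
$$\begin{aligned}
\mathtt{length}(\mathtt{reverse}(s)) &= \mathtt{length}(\mathtt{cons\_back}(\mathtt{reverse}(r), c)) \\
&= 1 + \mathtt{length}(\mathtt{reverse}(r)) \\
&= 1 + \mathtt{length}(r) \\
&= \mathtt{length}(s)
\end{aligned}$$
by the definition of reverse (which puts the front character at the back of the reversed rest), ConsBack1, the inductive hypothesis, and the definition of length.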

Extracting a prefix

A prefix of a string is some portion at the beginning. For example, "He" is a prefix of "Hello". We want to extract a prefix of a given length l (this is a first step to extracting a substring).

rstring prefix(rstring s, int l) {
    if(l == 0 || !s)
        return "";
    else
        return cons(s.front(), prefix(s.rest(), l-1));
}

Extracting a substring

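Extracting a substring involves doing two things: skipping over the first s characters of the string, and then taking the first l characters of what remains.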

We do the first part recursively, and the second part using prefix.

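rstring substring(rstring r, int s, int l) {
    if(s == 0)
        return prefix(r, l);                 // reached the start: take l characters
    else if(!r)
        return "";                           // s is past the end: nothing left
    else
        return substring(r.rest(), s-1, l);  // skip one character
}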

Removing a prefix

Removing a prefix is the first step to erasing a substring.

rstring remove_prefix(rstring s, int l) {
    if(l == 0)
        return s;
    else if(!s)
        return "";
    else
        return remove_prefix(s.rest(), l-1);
}

Erasing a substring

rstring erase(rstring r, int s, int l) {
    if(s == 0)
        return remove_prefix(r, l);
    else if(!r)
        return "";                                   // s is past the end: nothing to erase
    else
        return cons(r.front(), erase(r.rest(), s-1, l));
}

Note that in the recursive case we put the first character back onto the result of the recursive erase, because if \(s > 0\) we haven’t yet reached the portion of the string we want to erase.

Comparing two strings for equality

There are two base cases here: if both strings are empty, they are equal; if only one of them is empty, they are not equal.

In the recursive case, we compare the .front()s of each string: if they are different, then the strings must be different; if they are the same, then we compare (recursively) the .rest() of the strings.

bool equal(rstring a, rstring b) {
    if(!a && !b)
        return true;                  // both empty
    else if(!a || !b)
        return false;                 // exactly one is empty
    else if(a.front() != b.front())
        return false;
    else
        return equal(a.rest(), b.rest());
}

Searching for a substring

Suppose we want to reimplement the standard string find method, which returns the location of a substring (or a value bigger than the length of the string, if it was not found). This is a two-step process: matching the substring against the beginning of the target string (i.e., as a prefix) and then using that to match anywhere in the string.

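bool prefix_match(rstring src, rstring ptn) {
    if(!ptn)
        return true;              // an empty pattern matches at the start of anything
    else if(!src)
        return false;             // the source ran out before the pattern did
    else
        return src.front() == ptn.front() &&
               prefix_match(src.rest(), ptn.rest());
}

int find(rstring src, rstring ptn) {
    if(prefix_match(src, ptn))
        return 0;                          // ptn starts right here
    else if(!src)
        return 1;                          // not found: with the chain of 1 + ... above,
                                           // the final result is length(src) + 1
    else
        return 1 + find(src.rest(), ptn);  // try one position further along
}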

Finding the factorial

The factorial of n is defined as the product \(n (n-1) (n-2) \ldots (2) (1)\), i.e., the product of all positive integers up to and including n (with \(0! = 1\)). It’s easy to write as a loop:

int factorial_iter(int n) {
    int r = 1; // Factorial of 0 is 1
    for(int i = 1; i <= n; ++i)
        r *= i;
    return r;
}

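To write this, or any other algorithm, recursively, we have to ask two questions. What is the base case, the simplest version of the problem that we can solve directly? And what is the recursive case: how can we break the problem down into a smaller version of itself?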

For the factorial, the base case is what happens when \(n = 0\): the loop doesn’t run at all, and 1 is returned. So we can start our recursive version with

int factorial_rec(int n) {
    if(n == 0)
        return 1;
    else
        ...
}

To construct the recursive case, we need to look at what happens when n > 0. In particular, how can we break \(n!\) down into some \(n’ !, n’ < n\)? The most common case is \(n’ = n - 1\).

One way to look at this is to assume that we already have the value of \((n-1)!\), and we want to get \(n!\) from it. That is, assume that factorial_rec(n - 1) will work and give us the right answer; we just need to construct the factorial of n from it. How can we do this? \(n! = n (n-1)!\). So we write our recursive case like this:

int factorial_rec(int n) {
    if(n == 0)
        return 1;
    else
        return n * factorial_rec(n - 1);
}

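Let’s take a minute to walk through the process of computing factorial_rec(3):

factorial_rec(3)
= 3 * factorial_rec(2)
= 3 * (2 * factorial_rec(1))
= 3 * (2 * (1 * factorial_rec(0)))
= 3 * (2 * (1 * 1))
= 6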

Inductive proof

How do we show that a function does what it is supposed to do? We could test it, running it thousands or millions of times and checking that its output is what we expect. But this requires an independent way to define what the function should do (e.g., a different way of computing the factorial), which might itself be incorrect; and since testing cannot cover every possible input, it can only ever increase our confidence that the algorithm is correct.

If we want to be sure, then we need a logical, or mathematical, proof that it is correct. For recursive functions, this often takes the form of proof by induction. An inductive proof is roughly the mathematical equivalent of a recursive function. Like a recursive function, it has base case(s) (in fact, one for every base case in the function), and the base cases are usually easy. It also has inductive case(s) (one for each recursive case in the function), which are trickier, but let us do something like recursion: assume the result for smaller inputs and use it to prove the result for the current one.

Consider the example above. We want to prove that factorial_rec(n) \(= n!\), where \(n! = n(n-1)(n-2)\ldots(2)(1)\) and \(0! = 1\).

Proof by induction on n (whatever variable we do the recursion on, we say we are doing “proof by induction” on that variable):
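A sketch of the two cases, one for each case of factorial_rec:

Base case (\(n = 0\)):
$$\mathtt{factorial\_rec}(0) = 1 = 0!$$

Inductive case (\(n > 0\)): assume \(\mathtt{factorial\_rec}(n-1) = (n-1)!\). Then
$$\mathtt{factorial\_rec}(n) = n \cdot \mathtt{factorial\_rec}(n-1) = n \cdot (n-1)! = n(n-1)(n-2)\ldots(2)(1) = n!$$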

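Like recursion, the heart of an inductive proof is the act of applying the proof itself as an assumption about “smaller” values (\(n' < n\)). Technically, there are two kinds of inductive proofs: ordinary (weak) induction, where the inductive case assumes the property holds for \(n - 1\) and proves it for \(n\); and strong induction, where the inductive case assumes it holds for every \(n' < n\).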

The integer exponent calculation

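Remember when we worked out the runtime complexity of our “optimized” \(O(\log n)\) function for computing \(b^n\)? We can write a recursive version of that as well. Once again, we have to ask what the base case is (\(n = 0\), where \(b^0 = 1\)) and how to reduce \(b^n\) to a smaller power: if \(n\) is even, \(b^n = (b^{n/2})^2\); if \(n\) is odd, \(b^n = b \cdot b^{n-1}\).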

This gives us the following definition:

float powi(float b, int n) {   // assumes n >= 0
    if(n == 0)
        return 1;
    else if(n % 2 == 0) {
        // Even: b^n = (b^(n/2))^2, computing b^(n/2) only once
        float fp = powi(b, n / 2);
        return fp * fp;
    }
    else // Odd: b^n = b * b^(n-1)
        return b * powi(b, n - 1);
}

This has the same complexity as the loop-based version, and is arguably simpler.

In this case, if we want to prove that \(\mathtt{powi}(b,n) = b^n\), we’ll need strong induction, because one of the recursive cases shrinks the input by something other than 1 (it halves it).

Proof that \(\mathtt{powi}(b,n) = b^n\) by strong induction on \(n\):
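A sketch of the proof, treating float arithmetic as exact (real floats can round):

Base case (\(n = 0\)):
$$\mathtt{powi}(b, 0) = 1 = b^0$$

Inductive case (\(n > 0\)): assume \(\mathtt{powi}(b, n') = b^{n'}\) for every \(0 \le n' < n\).

If \(n\) is even, then \(n/2 < n\), so
$$\mathtt{powi}(b, n) = \mathtt{powi}(b, n/2)^2 = (b^{n/2})^2 = b^n$$

If \(n\) is odd, then \(n - 1 < n\), so
$$\mathtt{powi}(b, n) = b \cdot \mathtt{powi}(b, n-1) = b \cdot b^{n-1} = b^n$$

The even case is where strong induction is needed: its hypothesis is about \(n/2\), not \(n - 1\).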

Mutual recursion

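Mutual recursion is when we define several recursive functions in terms of each other. For example, consider the following definition of even and odd: 0 is even and 1 is odd; for \(n \ge 2\), \(n\) is even if \(n - 1\) is odd, and \(n\) is odd if \(n - 1\) is even.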

We can then define two functions (predicates) that recursively refer to each other:

bool is_odd(int n);  // declared ahead so that is_even can call it

// Both functions assume n >= 0.
bool is_even(int n) {
    if(n == 0)
        return true;
    else if(n == 1)
        return false;
    else
        return is_odd(n - 1);
}

bool is_odd(int n) {
    if(n == 0)
        return false;
    else if(n == 1)
        return true;
    else
        return is_even(n - 1);
}

If we trace the process of determining is_even(4), we’ll see that it bounces back and forth between is_even and is_odd: is_even(4) calls is_odd(3), which calls is_even(2), which calls is_odd(1), which returns true.

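We did a binary search iteratively, but we can do it recursively as well. The base case is an empty range (low > high), where the target is not present. Otherwise we compare the middle element with the target: if it matches, we are done; if not, we recursively search only the half that could still contain the target.

This looks like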

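#include <vector>

template<typename T>
int binary_search(const std::vector<T>& data, const T& target,
                  int low, int high) {

    if(low > high)
        return -1;                  // empty range: target not present

    int mid = low + (high - low) / 2; // Why did I do this?

    if(data.at(mid) == target)
        return mid;
    else if(data.at(mid) < target) // Search right
        return binary_search(data, target, mid+1, high);
    else                           // Search left
        return binary_search(data, target, low, mid-1);
}

// A default argument can't refer to another parameter (data.size()),
// so this overload supplies the initial range instead.
template<typename T>
int binary_search(const std::vector<T>& data, const T& target) {
    return binary_search(data, target, 0, static_cast<int>(data.size()) - 1);
}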

Another example: counting the number of copies of a value in a vector. For any vector-style recursion, we need to keep track of our “starting place” within the vector, because we can’t make the vector itself smaller; instead, we pass along a marker showing where the remaining part starts. We can do this in two ways: with an index (start) parameter, or by using iterators.

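#include <cstddef>
#include <vector>

// Named count_copies to avoid clashing with std::count.
// data is passed by const reference so each call doesn't copy the vector.
template<typename T>
int count_copies(const std::vector<T>& data, std::size_t start, const T& target) {
    if(start == data.size())
        return 0;
    else
        return (data.at(start) == target) +
               count_copies(data, start + 1, target);
}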

With iterators:

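#include <iterator>

// std::next(start) advances any forward iterator, not just random-access ones.
template<typename T, typename It>
int count_copies(It start, It finish, const T& target) {
    if(start == finish)
        return 0;
    else
        return (*start == target) +
               count_copies(std::next(start), finish, target);
}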

Iterators are kind of like pointers.
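The resemblance runs both ways: a plain pointer into an array works as an iterator. As a small sketch (the name count_copies is hypothetical, chosen so the template doesn't collide with std::count), the iterator version runs unchanged on a C-style array, using a pointer to the first element and a pointer one past the last as the range.

```cpp
#include <iterator>

// Iterator-based count: works for any forward iterator, including raw pointers.
template<typename T, typename It>
int count_copies(It start, It finish, const T& target) {
    if(start == finish)
        return 0;                                      // empty range
    else
        return (*start == target) +                    // 1 if this element matches
               count_copies(std::next(start), finish, target);
}
```

For int arr[] = {3, 1, 3, 3, 2}, count_copies(arr, arr + 5, 3) counts 3 copies, and std::begin(arr) and std::end(arr) give the same pair of pointers. Using std::next rather than start + 1 keeps the sketch usable with iterators that can only step forward.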