Thus far in our programming, we have no way to repeat. A natural way to repeat is using recursive functions. Let us consider defining a Scala function that computes factorial. Recall from discrete mathematics that factorial, written \(n!\), corresponds to the number of permutations of \(n\) elements and is defined as follows: \[\begin{array}{rcl} n! & \stackrel{\text{\tiny def}}{=}& n \cdot (n - 1) \cdot \;\cdots\; \cdot 1 \\ 0! & \stackrel{\text{\tiny def}}{=}& 1 \end{array}\] From the definition above, we see that factorial satisfies the following equation for \(n \geq 1\): \[ n! = n \cdot (n - 1)! \] Based on this equation, we can define a Scala function to compute factorial as follows:
def factorial(n: Int): Int = if (n == 0) 1 else n * factorial(n - 1)
factorial(3)
defined function factorial res0_1: Int = 6
Let us write out some steps of evaluating factorial(3):
| | factorial(3) |
| \(\longrightarrow^{\ast}\) | if (3 == 0) 1 else 3 * factorial(3 - 1) |
| \(\longrightarrow^{\ast}\) | 3 * factorial(2) |
| \(\longrightarrow^{\ast}\) | 3 * 2 * factorial(1) |
| \(\longrightarrow^{\ast}\) | 3 * 2 * 1 * factorial(0) |
| \(\longrightarrow^{\ast}\) | 3 * 2 * 1 * (if (0 == 0) 1 else 0 * factorial(0 - 1)) |
| \(\longrightarrow^{\ast}\) | 3 * 2 * 1 * 1 |
| \(\longrightarrow^{\ast}\) | 6 |
where the sequence above is shorthand for expressing that each successive pair of expressions is related by the multi-step evaluation relation \(\longrightarrow^{\ast}\) written between them.
Observe that the variable factorial needs to be in scope in the function body (i.e., the expression after =) to enable the recursive definition. To define a recursive function, the return type : Int has to be given for factorial to be in scope in the function body. (Why? To enable static type checking.)
def factorial(n: Int) = if (n == 0) 1 else n * factorial(n - 1)
cmd1.sc:1: recursive method factorial needs result type
def factorial(n: Int) = if (n == 0) 1 else n * factorial(n - 1)
^
Compilation Failed
Induction is an important proof technique for reasoning about recursively defined objects that you might recall from a discrete mathematics course. Here, we consider basic proofs of properties of recursive Scala functions.
The simplest form of induction is what we call mathematical induction, that is, induction over natural numbers. Intuitively, to prove a property \(P\) over all natural numbers (i.e., \(\forall n\in\mathbb{N}.P(n)\)), we consider two cases: (a) we prove the property holds for \(0\) (i.e., \(P(0)\)), which is called the base case; and (b) we prove that the property holds for \(n+1\) assuming it holds for an \(n \geq 0\) (i.e., \(\forall n\in\mathbb{N}.(P(n) \implies P(n + 1))\)), which is called the inductive case.
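Stated as a single rule (a standard presentation, offered here only as a compact summary of the two cases above and not as notation this text relies on), mathematical induction reads: \[ \frac{P(0) \qquad \forall n\in\mathbb{N}.(P(n) \implies P(n + 1))}{\forall n\in\mathbb{N}.P(n)} \] The premises above the line are the base case and the inductive case, and the conclusion below the line is the property for all natural numbers.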
As an example, let us prove that our Scala function factorial computes the mathematical definition of factorial \(n!\). To state this property precisely, we need a way to relate mathematical numbers with Scala values. To do so, we use the notation \(\llcorner n \lrcorner\) to mean the Scala integer value corresponding to the mathematical number \(n\) (i.e., \(\llcorner n \lrcorner\) : Int as long as \(n\) is representable as an Int).
Theorem 8.1 For all integers \(n\) such that \(n \geq 0\), \[\texttt{factorial(}\llcorner n \lrcorner\texttt{)} \longrightarrow^{\ast} \llcorner n! \lrcorner \;.\]
Proof. By mathematical induction on \(n\).
Case \(n = 0\): Note that \(\llcorner 0 \lrcorner = \texttt{0}\). Taking a few steps of evaluation, we have that \[\texttt{factorial(0)} \longrightarrow^{\ast}\texttt{1} \;.\] Then, the Scala value can also be written as \(\llcorner 0! \lrcorner\) because mathematically \(0! = 1\).
Case \(n = n' + 1\) for some \(n' \geq 0\): The induction hypothesis is as follows: \[\texttt{factorial(}\llcorner n' \lrcorner\texttt{)} \longrightarrow^{\ast}\llcorner n'! \lrcorner \;.\]
Let us evaluate \(\texttt{factorial(}\llcorner n \lrcorner\texttt{)}\) a few steps, and we have the following: \[ \texttt{factorial(}\llcorner n \lrcorner\texttt{)} \longrightarrow^{\ast} \llcorner n \lrcorner \;\texttt{*}\; \texttt{factorial(}\llcorner n - 1 \lrcorner\texttt{)} \] because we know that \(n \neq 0\).
Applying the induction hypothesis (observing that \(n - 1 = n'\)), we have that \[ \llcorner n \lrcorner\;\texttt{*}\;\texttt{factorial(}\llcorner n' \lrcorner\texttt{)} \longrightarrow^{\ast} \llcorner n \lrcorner \;\texttt{*}\; \llcorner n'! \lrcorner \;.\] By further evaluation, we have that \[ \llcorner n \lrcorner \;\texttt{*}\; \llcorner n'! \lrcorner \longrightarrow \llcorner n \cdot n'! \lrcorner \;.\] Note that \(n \cdot n'! = n \cdot (n - 1)! = n!\), which completes this case.
In the above, we are actually using an abstract notion of evaluation where Scala integer values are unbounded. In implementation, Scala Int values are in fact 32-bit signed two’s complement integers, a detail that we have ignored in our evaluation relation. It is often convenient to use abstract models of evaluation to separate concerns. Here, we use an abstract model of evaluation to ignore overflow.
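To see where the abstract model and the 32-bit implementation part ways, the following sketch compares the Int version of factorial with a reference version over BigInt. The name factorialBig and the values agreesUpTo12, wrapped, and exact are hypothetical, introduced only for this illustration.

```scala
// Direct recursive factorial over 32-bit Int, as defined earlier in this section.
def factorial(n: Int): Int = if (n == 0) 1 else n * factorial(n - 1)

// Hypothetical reference version over BigInt, standing in for unbounded mathematical integers.
def factorialBig(n: Int): BigInt =
  if (n == 0) BigInt(1) else BigInt(n) * factorialBig(n - 1)

// The two versions agree on 0 through 12, where n! still fits in an Int.
val agreesUpTo12: Boolean =
  (0 to 12).forall(n => BigInt(factorial(n)) == factorialBig(n))

// 13! = 6227020800 exceeds Int.MaxValue = 2147483647, so the Int product wraps around.
val wrapped: Int = factorial(13)     // 1932053504
val exact: BigInt = factorialBig(13) // 6227020800
```

The theorem therefore describes the Int implementation faithfully only for \(n \leq 12\). Beyond that, it is a statement about the abstract model.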
There is another style of writing recursive functions using pattern matching that looks somewhat closer to the structure of an inductive proof. For example, we can write an implementation of factorial equivalent to the one above as follows:
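A minimal sketch of such a pattern-matching definition, assuming it has the same shape as the match-based versions that appear later in this section (the value name sample is hypothetical):

```scala
// Pattern-matching factorial: the two cases mirror the base and inductive cases of the proof.
def factorial(n: Int): Int = n match {
  case 0 => 1                      // base case: 0! = 1
  case _ => n * factorial(n - 1)   // inductive case: n! = n * (n - 1)!
}

val sample: Int = factorial(3)  // 6
```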
defined function factorial res1_1: Int = 6
The \(\texttt{match}\) expression has the following form:
\[\begin{array}{l} e\; \texttt{match} \; \texttt{\{} \\ \quad \texttt{case} \; \mathit{pattern}_1 \; \texttt{=>} \; e_1 \\ \quad \ldots \\ \quad \texttt{case} \; \mathit{pattern}_n \; \texttt{=>} \; e_n \\ \texttt{\}} \end{array}\]
and evaluates by comparing the value of expression \(e\) against the patterns given by the \(\texttt{case}\)s. Patterns are tried in sequence from \(\mathit{pattern}_1\) to \(\mathit{pattern}_n\). Evaluation continues with the corresponding expression for the first pattern that matches. Again, we will revisit pattern matching in detail in the later material on data structures and pattern matching. For the moment, simply recognize that patterns in general bind names (as seen previously in Section 4.4.4). In Listing 8.1, we use the “wildcard” pattern _ to match anything that is non-zero.
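As a small illustration of the first-match rule and of a pattern that binds a name, consider the following sketch. The function describe is hypothetical and not part of the running example.

```scala
// Hypothetical function illustrating that cases are tried from top to bottom.
def describe(n: Int): String = n match {
  case 0 => "zero"                    // tried first
  case 1 => "one"                     // tried only if n is not 0
  case k => s"some other number: $k"  // variable pattern: matches anything and binds k to it
}
```

Because the variable pattern k matches every value, placing it first would make the other two cases unreachable.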
The definitions of factorial given above implicitly assume that they are called with non-negative integer values. Consider evaluating factorial(-2):
| | factorial(-2) |
| \(\longrightarrow^{\ast}\) | -2 * factorial(-3) |
| \(\longrightarrow^{\ast}\) | -2 * -3 * factorial(-4) |
| \(\longrightarrow^{\ast}\) | -2 * -3 * -4 * factorial(-5) |
| \(\longrightarrow^{\ast}\) | -2 * -3 * -4 * -5 * factorial(-6) |
| \(\longrightarrow^{\ast}\) | \(\ldots\) |
We see that we have non-termination with infinite recursion. In implementation, we recurse until the run-time yields a stack overflow error.
Following principles of good design, we should at least document in a comment the requirement that the input parameter n be non-negative. In Scala, we can do something a bit better in that we can specify such preconditions in code:
def factorial(n: Int): Int = {
  require(n >= 0)
  n match {
    case 0 => 1
    case _ => n * factorial(n - 1)
  }
}
factorial(-2)
java.lang.IllegalArgumentException: requirement failed scala.Predef$.require(Predef.scala:325) ammonite.$sess.cmd2$Helper.factorial(cmd2.sc:2) ammonite.$sess.cmd2$Helper.<init>(cmd2.sc:8) ammonite.$sess.cmd2$.<clinit>(cmd2.sc:7)
If this version of factorial is called with a negative integer, it will result in a run-time exception. The require function does nothing if its argument evaluates to true and throws an IllegalArgumentException if its argument evaluates to false.
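The require function in Predef also has a two-argument form that takes a message to include in the exception. The sketch below uses it and captures the outcome with scala.util.Try so that the failure can be inspected as a value. The message text and the names ok and rejected are assumptions made for this illustration.

```scala
import scala.util.Try

def factorial(n: Int): Int = {
  // Two-argument require: the message is evaluated only if the check fails.
  require(n >= 0, s"n must be non-negative but was $n")
  n match {
    case 0 => 1
    case _ => n * factorial(n - 1)
  }
}

// Try captures a thrown exception as a Failure value instead of aborting the program.
val ok = Try(factorial(3))        // Success(6)
val rejected = Try(factorial(-2)) // Failure wrapping an IllegalArgumentException
```

The exception message here is "requirement failed: n must be non-negative but was -2", which points more directly at the violated precondition than the bare "requirement failed" shown above.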
For factorial, it is clear that the require will never fail in any recursive call. We really only need to check the initial n from the initiating call to factorial. One way we can do this is to use a helper function that actually performs the recursive computation:
def factorial(n: Int): Int = {
  require(n >= 0)
  def f(n: Int): Int = n match {
    case 0 => 1
    case _ => n * f(n - 1)
  }
  f(n)
}
factorial(3)
defined function factorial res3_1: Int = 6
Here, the f function is local to the factorial function. The function f does not do any checking on its argument, but the require check in factorial will ensure that f always terminates.
Examining the evaluation of the various versions of factorial in this section, we observe that they all behave similarly: (1) the recursion builds up an expression consisting of a sequence of multiplication * operations, and then (2) the multiplication operations are evaluated to yield the result. In a typical run-time system, step (1) grows the call stack of activation records with recursive calls recording pending evaluation (i.e., the * operation), and each individual * operation in step (2) is executed while unwinding the call stack on return. Our abstract notation for evaluation does not represent a call stack explicitly, but we can see the corresponding behavior in the growing “pending” expression.
Not all recursive functions require a call stack of activation records. In particular, when there is nothing left to do on return, there is no “pending computation” to record. This kind of recursive function is called tail recursive. A tail-recursive version of the factorial function is given below.
def factorial(n: Int): Int = {
  require(n >= 0)
  def loop(acc: Int, n: Int): Int = n match {
    case 0 => acc
    case _ => loop(acc * n, n - 1)
  }
  loop(1, n)
}
factorial(3)
defined function factorial res4_1: Int = 6
Let us write out some steps of evaluating \(\texttt{factorial(3)}\) for this version:
| | factorial(3) |
| \(\longrightarrow^{\ast}\) | loop(1, 3) |
| \(\longrightarrow^{\ast}\) | loop(1 * 3, 2) |
| \(\longrightarrow^{\ast}\) | loop(3 * 2, 1) |
| \(\longrightarrow^{\ast}\) | loop(6 * 1, 0) |
| \(\longrightarrow^{\ast}\) | 6 |
Observe that the acc variable serves to accumulate the result. When we reach the base case (i.e., 0), we simply return the accumulator variable acc. Notice that no expression gets built up during the course of the recursion. When the last call to loop returns, we have the final result. It is an important optimization for compilers to recognize tail recursion and avoid building a call stack unnecessarily.
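The Scala compiler performs this optimization for self-recursive calls in tail position. The standard library annotation scala.annotation.tailrec asks the compiler to report an error if the annotated function cannot be optimized this way. A minimal sketch, reusing the loop helper from above with the annotation added:

```scala
import scala.annotation.tailrec

def factorial(n: Int): Int = {
  require(n >= 0)
  // @tailrec makes compilation fail if loop is not tail recursive.
  @tailrec
  def loop(acc: Int, n: Int): Int = n match {
    case 0 => acc
    case _ => loop(acc * n, n - 1) // nothing is left to do after this call returns
  }
  loop(1, n)
}
```

Adding the same annotation to the direct recursive f helper from earlier would be rejected by the compiler, because the multiplication n * f(n - 1) is still pending when the recursive call returns.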
A tail-recursive function corresponds closely to a loop (e.g., a while loop) but does not require mutation. For example, consider the following imperative version of factorial:
def factorial(n: Int): Int = {
  require(n >= 0)
  println(s"factorial(n = $n)")
  var acc = 1
  var i = n
  while (i != 0) {
    println(s"acc -> $acc, i -> $i")
    acc = acc * i
    i = i - 1
  }
  println(s"acc -> $acc, i -> $i")
  acc
}
factorial(3)
factorial(n = 3)
acc -> 1, i -> 3
acc -> 3, i -> 2
acc -> 6, i -> 1
acc -> 6, i -> 0
defined function factorial res5_1: Int = 6
Conceptually, each iteration of the while loop corresponds to a call to loop. The values of acc and i in each iteration of the while loop correspond to the values bound to acc and n on each tail-recursive call to loop. We see this by comparing the instrumentation that prints the values of acc and i on each loop iteration with the instrumentation that prints the values of acc and n in each tail-recursive call:
def factorial(n: Int): Int = {
  require(n >= 0)
  println(s"factorial(n = $n)")
  def loop(acc: Int, n: Int): Int = {
    println(s"-->* loop(acc = $acc, n = $n)")
    n match {
      case 0 => acc
      case _ => loop(acc * n, n - 1)
    }
  }
  val r = loop(1, n)
  println(s"-->* $r")
  r
}
factorial(3)
factorial(n = 3)
-->* loop(acc = 1, n = 3)
-->* loop(acc = 3, n = 2)
-->* loop(acc = 6, n = 1)
-->* loop(acc = 6, n = 0)
-->* 6
defined function factorial res6_1: Int = 6
Exercise 8.1 An example very similar to factorial is the exponentiation function exp that computes \(x^n\) for \(n \geq 0\). Define it:
def exp(x: Int, n: Int): Int = {
  require(n >= 0)
  ???
}
assert(exp(2,4) == 16)
scala.NotImplementedError: an implementation is missing scala.Predef$.$qmark$qmark$qmark(Predef.scala:345) ammonite.$sess.cmd7$Helper.exp(cmd7.sc:3) ammonite.$sess.cmd7$Helper.<init>(cmd7.sc:5) ammonite.$sess.cmd7$.<clinit>(cmd7.sc:7)
Let us consider the fibonacci function that computes the \(n^\text{\textrm{th}}\) Fibonacci number:
def fibonacci(n: Int): Int = {
  require(n >= 0)
  n match {
    case 0 | 1 => 1
    case _ => fibonacci(n - 1) + fibonacci(n - 2)
  }
}
defined function fibonacci
The fibonacci function is more interesting than factorial because it makes two recursive calls. Does it terminate on all inputs \(n\)? Yes, we can reason by induction just like with factorial.
Is it tail recursive? Most definitely not, as the result of each recursive call is still needed afterward to apply + to the two results. This is potentially problematic because each pending call requires the allocation of a stack frame.
For fibonacci(\(n\)), how many recursive calls are made? Let’s consider an instrumented version that records the stack depth (depth) and the running count (count) of total calls to f:
def fibonacci(n: Int): Int = {
  require(n >= 0)
  println(s"fibonacci($n)")
  def f(n: Int, depth: Int, count: Int): (Int, Int) = {
    val r = n match {
      case 0 | 1 => (1, count)
      case _ => {
        val (b, countb) = f(n - 1, depth + 1, count + 1)
        val (a, counta) = f(n - 2, depth + 1, countb + 1)
        (a + b, counta)
      }
    }
    println(s"${" " * depth}- f(n = $n, depth = $depth, count = $count) = $r")
    r
  }
  val (r, _) = f(n, 0, 1)
  r
}
fibonacci(0)
fibonacci(1)
fibonacci(2)
fibonacci(3)
fibonacci(4)
fibonacci(5)
fibonacci(0)
- f(n = 0, depth = 0, count = 1) = (1,1)
fibonacci(1)
- f(n = 1, depth = 0, count = 1) = (1,1)
fibonacci(2)
 - f(n = 1, depth = 1, count = 2) = (1,2)
 - f(n = 0, depth = 1, count = 3) = (1,3)
- f(n = 2, depth = 0, count = 1) = (2,3)
fibonacci(3)
  - f(n = 1, depth = 2, count = 3) = (1,3)
  - f(n = 0, depth = 2, count = 4) = (1,4)
 - f(n = 2, depth = 1, count = 2) = (2,4)
 - f(n = 1, depth = 1, count = 5) = (1,5)
- f(n = 3, depth = 0, count = 1) = (3,5)
fibonacci(4)
   - f(n = 1, depth = 3, count = 4) = (1,4)
   - f(n = 0, depth = 3, count = 5) = (1,5)
  - f(n = 2, depth = 2, count = 3) = (2,5)
  - f(n = 1, depth = 2, count = 6) = (1,6)
 - f(n = 3, depth = 1, count = 2) = (3,6)
  - f(n = 1, depth = 2, count = 8) = (1,8)
  - f(n = 0, depth = 2, count = 9) = (1,9)
 - f(n = 2, depth = 1, count = 7) = (2,9)
- f(n = 4, depth = 0, count = 1) = (5,9)
fibonacci(5)
    - f(n = 1, depth = 4, count = 5) = (1,5)
    - f(n = 0, depth = 4, count = 6) = (1,6)
   - f(n = 2, depth = 3, count = 4) = (2,6)
   - f(n = 1, depth = 3, count = 7) = (1,7)
  - f(n = 3, depth = 2, count = 3) = (3,7)
   - f(n = 1, depth = 3, count = 9) = (1,9)
   - f(n = 0, depth = 3, count = 10) = (1,10)
  - f(n = 2, depth = 2, count = 8) = (2,10)
 - f(n = 4, depth = 1, count = 2) = (5,10)
   - f(n = 1, depth = 3, count = 13) = (1,13)
   - f(n = 0, depth = 3, count = 14) = (1,14)
  - f(n = 2, depth = 2, count = 12) = (2,14)
  - f(n = 1, depth = 2, count = 15) = (1,15)
 - f(n = 3, depth = 1, count = 11) = (3,15)
- f(n = 5, depth = 0, count = 1) = (8,15)
defined function fibonacci res9_1: Int = 1 res9_2: Int = 1 res9_3: Int = 2 res9_4: Int = 3 res9_5: Int = 5 res9_6: Int = 8
Unfortunately, the number of recursive calls grows exponentially in \(n\). Thus, computing fibonacci(40) requires more than 300 million calls.
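The final counts in the traces above (1, 1, 3, 5, 9, 15 for \(n\) from 0 to 5) follow a simple recurrence: one call for \(n\) itself plus the calls made by its two recursive calls. The sketch below states that recurrence directly. The names calls and counts are hypothetical, and calls is itself exponential-time, so it is only meant for small inputs.

```scala
// Hypothetical helper: the number of calls the direct recursive fibonacci makes on input n.
def calls(n: Int): Long = {
  require(n >= 0)
  n match {
    case 0 | 1 => 1L                           // a base case is a single call
    case _ => 1L + calls(n - 1) + calls(n - 2) // this call plus those of both recursive calls
  }
}

// Call counts for n = 0 through 5, matching the instrumented traces.
val counts: Seq[Long] = (0 to 5).map(calls)
```

Comparing with the results in the traces, the count for each \(n\) is one less than twice fibonacci(\(n\)), so the number of calls grows at the same exponential rate as the Fibonacci numbers themselves.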
However, we can see from above that there is wasted work in repeatedly computing smaller Fibonacci numbers. We can define a tail-recursive version of the fibonacci function by computing the \(n^\text{\textrm{th}}\) Fibonacci number “bottom up” starting from the \(0^\text{\textrm{th}}\), \(1^\text{\textrm{st}}\), \(2^\text{\textrm{nd}}\), \(3^\text{\textrm{rd}}\), \(\ldots\) (using what you might remember from other classes as dynamic programming).
Exercise 8.2 Give a tail-recursive definition fib of the Fibonacci function:
def fib(n: Int): Int = {
  require(n >= 0)
  ???
}
defined function fib
See that you can compute much larger Fibonacci numbers using your linear-time tail-recursive implementation fib compared to the direct recursive fibonacci.