24  Higher-Order Functions


Returning to programming principles, recall that in many languages like Scala, functions are first-class. This means that functions are values: they may be passed as arguments to other functions or returned as results from them. Functions that take functions as arguments (or return functions as results) are called higher-order functions.
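As a small illustration, here is a function that does both: it takes a function argument and returns a function as its result. The name twice is our own and not a library function.

```scala
// twice is a hypothetical helper: it takes a function f and returns
// a new function that applies f two times in a row.
def twice(f: Int => Int): Int => Int = x => f(f(x))

// A function built from another function.
val addTwo: Int => Int = twice(n => n + 1)
addTwo(40) // 42
```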

24.1 Currying

Recall that we can write down function literals and bind them to variables:

(n: Int) => n + 1
((n: Int) => n + 1)(41)

val incr: Int => Int = { n => n + 1 }
incr(41)
res0_0: Int => Int = ammonite.$sess.cmd0$Helper$$Lambda$1811/0x0000000800a39840@edb2f86
res0_1: Int = 42
incr: Int => Int = ammonite.$sess.cmd0$Helper$$Lambda$1813/0x0000000800a3b040@3c5a87fa
res0_3: Int = 42

We have seen that with first-class functions (and lexical scoping), we do not need tuples or other data structures to have multi-parameter functions. In particular, a function that returns another function behaves like a multi-parameter function. This is called currying.

val plus: Int => Int => Int = { x => y => x + y }
plus(3)(4)
plus: Int => Int => Int = ammonite.$sess.cmd1$Helper$$Lambda$1925/0x0000000800a96040@7a746d0d
res1_1: Int = 7
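To make the correspondence between the two styles concrete, here is a sketch that converts a two-parameter function into its curried form and back. The names curry and uncurry are our own; the standard library offers a similar conversion (for example, the curried method on two-argument function values), but writing them out shows there is nothing special involved.

```scala
// Turn a function on two arguments into a function that takes them one at a time.
def curry[A, B, C](f: (A, B) => C): A => B => C = a => b => f(a, b)

// And back again: feed both arguments to the curried function in turn.
def uncurry[A, B, C](f: A => B => C): (A, B) => C = (a, b) => f(a)(b)

val addPair: (Int, Int) => Int = (x, y) => x + y
val addCurried: Int => Int => Int = curry(addPair)
addCurried(3)(4)          // 7
uncurry(addCurried)(3, 4) // 7
```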

Since currying is a common thing to do, Scala has some syntactic sugar for it:

def plus(x: Int)(y: Int): Int = x + y
plus(3)(4)
defined function plus
res2_1: Int = 7

One reason to use currying is to enable partial application. For example, we can define incr using plus:

val incr: Int => Int = plus(1)
incr(41)
incr: Int => Int = ammonite.$sess.cmd3$Helper$$Lambda$1942/0x0000000800a9b840@570dea5c
res3_1: Int = 42

Sometimes partial application is simply for defining new functions in terms of others in a compact manner. Other times, partial application enables some non-trivial partial computation.

Here is a contrived example to illustrate the latter: a function addToFactorial that computes \(n!\) and then returns a function that adds some number to that result:

def addToFactorial(n: Int): Int => Int = {
  def factorial(n: Int): Int = n match {
    case 0 => 1
    case _ => n * factorial(n - 1)
  }
  val nth = factorial(n)
  m => nth + m
}
defined function addToFactorial

We can compute \(10!\) once and then reuse it with the function tenFactorialPlus:

val tenFactorialPlus = addToFactorial(10)
tenFactorialPlus(47)
tenFactorialPlus(59)
tenFactorialPlus: Int => Int = ammonite.$sess.cmd4$Helper$$Lambda$1970/0x0000000800ab2040@78256aa6
res5_1: Int = 3628847
res5_2: Int = 3628859
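Where the work happens depends on how the function is written. In the sketch below, a square stands in for the factorial, and the names addToSquare and addToSquareEachTime and the two counters are hypothetical instrumentation. The first version does its "expensive" step before returning the inner function, so that step runs once. The second uses two parameter lists, so partially applying it only delays the whole body, which then runs on every call.

```scala
var eagerRuns = 0 // how many times addToSquare's expensive step ran
def addToSquare(n: Int): Int => Int = {
  eagerRuns += 1 // runs when addToSquare(n) is called
  val sq = n * n
  m => sq + m    // the returned function only does the cheap addition
}

var delayedRuns = 0 // how many times addToSquareEachTime's body ran
def addToSquareEachTime(n: Int)(m: Int): Int = {
  delayedRuns += 1 // runs on every full application
  n * n + m
}

val plusSquare12 = addToSquare(12)
val plusSquare12EachTime: Int => Int = addToSquareEachTime(12)
val viaEager = List(1, 2, 3).map(plusSquare12)           // List(145, 146, 147)
val viaDelayed = List(1, 2, 3).map(plusSquare12EachTime) // List(145, 146, 147)
// Afterwards, eagerRuns == 1 but delayedRuns == 3.
```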

24.2 Collections and Callbacks

We have seen standard data types like lists, options, maps, and sets that are often called collections, as they are generic in the values they collect together (see Section 6.1).

What collection libraries have in common is that the client of the library needs some way to work with the elements the collection manages. Because the client, not the library, decides the element type, the library provides higher-order functions that take a callback function argument telling it “what to do with the elements.” For example, we have already seen the higher-order function foreach in the Scala standard library, which lets the client perform a side effect for each element of a list:

List(1, 2, 3).foreach(println)
1
2
3

Note that the Scala standard library chooses to define foreach as a method on objects of type List[A].
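To see that there is nothing special about foreach, here is a sketch of a standalone version written with pattern matching and recursion. It is our own definition, not the library's implementation. The callback returns Unit, so it is called only for its side effect.

```scala
import scala.collection.mutable.ListBuffer

// A standalone foreach: call f on each element, front to back, for its effect.
def foreach[A](l: List[A])(f: A => Unit): Unit = l match {
  case Nil => ()
  case h :: t =>
    f(h)          // perform the side effect for the head
    foreach(t)(f) // then continue with the tail
}

foreach(List(1, 2, 3))(println) // prints 1, 2, 3 on separate lines

// Collecting the visited elements makes the visiting order observable.
val seen = ListBuffer.empty[Int]
foreach(List(1, 2, 3)) { i => seen += i }
```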

In the following, we describe some standard higher-order functions on collections. Our intent is to discuss the fundamental higher-order programming patterns. While the examples are drawn from the Scala standard library, the patterns recur in many other contexts and languages. We also do not intend to describe the application programming interface (API) exhaustively; see the API documentation or library-specific tutorials for that.

24.2.1 Map

Recall that we can work with lists directly using pattern matching and recursion. For example, we can define functions to increment or double each integer in a given List[Int], or to get the length of each string in a given List[String]:

def increment(l: List[Int]): List[Int] = l match {
  case Nil => Nil
  case h :: t => (h + 1) :: increment(t)
}
increment(List(1, 2, 3))
defined function increment
res7_1: List[Int] = List(2, 3, 4)
def double(l: List[Int]): List[Int] = l match {
  case Nil => Nil
  case h :: t => (h * 2) :: double(t)
}
double(List(1, 2, 3))
defined function double
res8_1: List[Int] = List(2, 4, 6)
def eachLength(l: List[String]): List[Int] = l match {
  case Nil => Nil
  case h :: t => h.length :: eachLength(t)
}
eachLength(List("Neo", "Trinity", "Morpheus"))
defined function eachLength
res9_1: List[Int] = List(3, 7, 8)

We see that this transformation pattern is very common: we want to map each element of the input list to the corresponding element of the output list. We can implement this pattern generically, given a callback function argument f: A => B that tells us how to map an A to a B:

def map[A, B](l: List[A])(f: A => B): List[B] = l match {
  case Nil => Nil
  case h :: t => f(h) :: map(t)(f)
}
defined function map

We can then define increment, double, and eachLength as clients of the map function:

def increment(l: List[Int]): List[Int] = map(l) { h => h + 1 }
increment(List(1, 2, 3))
defined function increment
res11_1: List[Int] = List(2, 3, 4)
def double(l: List[Int]): List[Int] = map(l) { h => h * 2 }
double(List(1, 2, 3))
defined function double
res12_1: List[Int] = List(2, 4, 6)
def eachLength(l: List[String]): List[Int] = map(l) { h => h.length }
eachLength(List("Neo", "Trinity", "Morpheus"))
defined function eachLength
res13_1: List[Int] = List(3, 7, 8)

We have abstracted all of the common boilerplate code into the definition of map, leaving only what makes increment, double, and eachLength differ, which we pass as the callback argument.

As noted above, the Scala standard library chooses to define these higher-order functions as methods on the List[A] data type, so we use the built-in version of map as follows:

List(1, 2, 3).map(i => i * 3)
res14: List[Int] = List(3, 6, 9)

It is also common in Scala to use the binary operator (infix) form for map:

List(1, 2, 3) map { i => i * 3 }
List(1, 2, 3) map { _ * 3 }
res15_0: List[Int] = List(3, 6, 9)
res15_1: List[Int] = List(3, 6, 9)

The binary operator form yields chains reminiscent of Unix pipes:

List(1, 2, 3) map { i => i * 3 } map { i => i + 1 }
List(1, 2, 3) map { _ * 3 } map { _ + 1 }
res16_0: List[Int] = List(4, 7, 10)
res16_1: List[Int] = List(4, 7, 10)

The chained method-call form is what libraries in modern Java (and other object-oriented languages) call fluent interfaces:

List(1, 2, 3)
  .map(i => i * 3)
  .map(i => i + 1)

List(1, 2, 3)
  .map(_ * 3)
  .map(_ + 1)
res17_0: List[Int] = List(4, 7, 10)
res17_1: List[Int] = List(4, 7, 10)
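A chain of map calls can also be fused into a single pass: mapping with f and then with g gives the same list as mapping once with their composition, f.andThen(g), where andThen is a method on Scala function values. The snippet below is a quick check on one input, not a proof.

```scala
val times3: Int => Int = i => i * 3
val plus1: Int => Int = i => i + 1

// Two passes: builds an intermediate list List(3, 6, 9).
val twoPasses = List(1, 2, 3).map(times3).map(plus1)
// One pass: applies plus1(times3(i)) to each element.
val onePass = List(1, 2, 3).map(times3.andThen(plus1))
// both are List(4, 7, 10)
```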

Comprehensions

Scala has a loop-like form called a comprehension that translates to a map call:

for (i <- List(1, 2, 3)) yield i * 3
List(1, 2, 3) map { i => i * 3 }
res18_0: List[Int] = List(3, 6, 9)
res18_1: List[Int] = List(3, 6, 9)

A comprehension draws from set-comprehensions in mathematics:

\[ \{ i \cdot 3 \mid i \in \{ 1, 2, 3 \} \} \]

And Python has similar syntax for list-comprehensions:

Python
[i * 3 for i in [1, 2, 3]]

Comprehensions with constraints in mathematics and Python are also common:

\[ \{ i \cdot 3 \mid i \in \{ 1, 2, 3 \} \mathrel{\text{s.t.}} i \bmod 2 = 1 \} \]

Python
[i * 3 for i in [1, 2, 3] if i % 2 == 1]

Scala supports such constraints as well:

for (i <- List(1, 2, 3) if i % 2 == 1) yield i * 3
res19: List[Int] = List(3, 9)

A constraint corresponds to first applying a filter:

List(1, 2, 3) filter { i => i % 2 == 1 } map { i => i * 3 }
res20: List[Int] = List(3, 9)
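For comparison with our hand-written map, here is a sketch of a direct recursive filter in the same style. It is our own definition; the built-in filter is a method on List[A]. It keeps the head only when the predicate p holds.

```scala
def filter[A](l: List[A])(p: A => Boolean): List[A] = l match {
  case Nil => Nil
  case h :: t if p(h) => h :: filter(t)(p) // keep h
  case _ :: t => filter(t)(p)              // drop h
}

val odds = filter(List(1, 2, 3)) { i => i % 2 == 1 } // List(1, 3)
val oddsTripled = odds.map(_ * 3)                    // List(3, 9)
```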

Because filtering and then mapping is common, Scala provides an optimization: the withFilter method does not build an intermediate filtered list but records the predicate, which is then applied during the subsequent map.

List(1, 2, 3) filter { i => i % 2 == 1 }
List(1, 2, 3) filter { i => i % 2 == 1 } map { i => i }

List(1, 2, 3) withFilter { i => i % 2 == 1 }
List(1, 2, 3) withFilter { i => i % 2 == 1 } map { i => i }
res21_0: List[Int] = List(1, 3)
res21_1: List[Int] = List(1, 3)
res21_2: collection.WithFilter[Int, List[_]] = scala.collection.IterableOps$WithFilter@5181b20d
res21_3: List[Int] = List(1, 3)
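One way to observe the difference is to count predicate calls; the counter below is our own instrumentation, in a small sketch. filter runs the predicate on every element immediately, while withFilter runs nothing until the following map asks for elements.

```scala
var predicateCalls = 0
val isOddCounted: Int => Boolean = { i => predicateCalls += 1; i % 2 == 1 }

val eager = List(1, 2, 3) filter isOddCounted
val callsAfterFilter = predicateCalls // 3: filter already tested every element

predicateCalls = 0
val deferred = List(1, 2, 3) withFilter isOddCounted
val callsAfterWithFilter = predicateCalls // 0: nothing has been tested yet
val mapped = deferred map { i => i * 3 }
val callsAfterMap = predicateCalls // 3: the predicate ran during map
```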

The for-if-yield comprehension translates to a call of withFilter and then map:

for (i <- List(1, 2, 3) if i % 2 == 1) yield i * 3
List(1, 2, 3) withFilter { i => i % 2 == 1 } map { i => i * 3 }
res22_0: List[Int] = List(3, 9)
res22_1: List[Int] = List(3, 9)

Pattern Matching on the Formal Parameter

While using map, we often want to pattern match on the parameter of the callback. For example,

List(None, Some(3), Some(4), None) map { iopt => iopt match {
  case None => 0
  case Some(i) => i + 1
} }
res23: List[Int] = List(0, 4, 5, 0)

We can drop the iopt => iopt match part to get the same behavior:

List(None, Some(3), Some(4), None) map {
  case None => 0
  case Some(i) => i + 1
}
res24: List[Int] = List(0, 4, 5, 0)

In fact, this case-block form without match is also the Scala syntax for defining “partial functions,” a more specific kind of function that is defined only on the inputs its cases match.
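To see the partial-function view, we can give a case-block an explicit PartialFunction type and ask where it is defined. The standard collect method uses this to skip elements that no case matches. This is a small sketch; the name incrSome is our own.

```scala
// A case-block given a PartialFunction type: defined only on Some values.
val incrSome: PartialFunction[Option[Int], Int] = {
  case Some(i) => i + 1
}

incrSome.isDefinedAt(Some(3)) // true
incrSome.isDefinedAt(None)    // false

// collect keeps only the elements on which the partial function is defined.
val incremented = List(None, Some(3), Some(4), None) collect incrSome // List(4, 5)
```

By contrast, when a case-block is passed to map, an element that no case matches raises a MatchError at run time.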

24.2.2 FlatMap

The flatMap function generalizes both map and filter. Compare and contrast the types and implementations of map and flatMap:

def map[A, B](l: List[A])(f: A => B): List[B] = l match {
  case Nil => Nil
  case h :: t => f(h) :: map(t)(f)
}

def flatMap[A, B](l: List[A])(f: A => List[B]): List[B] = l match {
  case Nil => Nil
  case h :: t => f(h) ::: flatMap(t)(f)
}
defined function map
defined function flatMap

The flatMap function takes a callback function argument f: A => List[B] that may produce any number of output elements for each input element, allowing us to define, for example, duplicate:

def duplicate[A](l: List[A]) = l flatMap { a => List(a, a) }
duplicate(List(1, 2, 3))
defined function duplicate
res26_1: List[Int] = List(1, 1, 2, 2, 3, 3)

The flatMap method takes its name from being a combination of map and flatten:

val mapped = List(1, 2, 3) map { a => List(a, a) }
val flattened = mapped.flatten
mapped: List[List[Int]] = List(List(1, 1), List(2, 2), List(3, 3))
flattened: List[Int] = List(1, 1, 2, 2, 3, 3)
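flatMap is also what makes comprehensions with more than one generator work. As a sketch of the standard translation, a for with two generators becomes a flatMap over the first list whose callback maps over the second:

```scala
val xs = List(1, 2)
val ys = List("a", "b")

// Two generators: every x paired with every y.
val viaFor = for (x <- xs; y <- ys) yield (x, y)

// The same computation written with flatMap and map directly.
val viaFlatMap = xs flatMap { x => ys map { y => (x, y) } }
// both are List((1, "a"), (1, "b"), (2, "a"), (2, "b"))
```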

While a direct implementation of map and filter is more efficient, we can see that flatMap is a generalization by defining map and filter using flatMap:

Exercise 24.1 Define map in terms of flatMap.

def map[A, B](l: List[A])(f: A => B): List[B] = ???
defined function map

Exercise 24.2 Define filter in terms of flatMap.

def filter[A](l: List[A])(f: A => Boolean): List[A] = ???
defined function filter

24.2.3 FoldRight

The map and flatMap functions offer transformations that stay within the List type constructor. Let us look at two functions, addList and multList, that summarize a list into a single value and are defined by direct recursion:

def addList(l: List[Int]): Int = l match {
  case Nil => 0
  case h :: t => h + addList(t) 
}
addList(List(1, 2, 3, 4))
defined function addList
res30_1: Int = 10
def multList(l: List[Int]): Int = l match {
  case Nil => 1
  case h :: t => h * multList(t)
}
multList(List(1, 2, 3, 4))
defined function multList
res31_1: Int = 24

We recognize this summarization pattern: we use a binary operator to combine the current element with the result accumulated recursively from the rest of the list:

def foldRight[A, B](l: List[A])(z: B)(bop: (A, B) => B): B = l match {
  case Nil => z
  case h :: t => bop(h, foldRight(t)(z)(bop))
}
defined function foldRight

We can then define addList and multList as clients of the foldRight function:

def addList(l: List[Int]): Int = foldRight(l)(0) { (h, acc) => h + acc }
addList(List(1, 2, 3, 4))

def multList(l: List[Int]): Int = foldRight(l)(1) { (h, acc) => h * acc }
multList(List(1, 2, 3, 4))
defined function addList
res33_1: Int = 10
defined function multList
res33_3: Int = 24

Like map, foldRight is defined as a method on List[A] in Scala:

List(1, 2, 3, 4).foldRight(0) { (h, acc) => h + acc }
List(1, 2, 3, 4).foldRight(1) { (h, acc) => h * acc }
res34_0: Int = 10
res34_1: Int = 24
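To see the order in which foldRight combines elements, we can fold with a callback that builds a string instead of computing a number; this is purely an illustration. Each element is combined with the already-folded rest of the list, so the parentheses nest to the right:

```scala
val rightShape = List(1, 2, 3).foldRight("z") { (h, acc) => s"($h + $acc)" }
// rightShape == "(1 + (2 + (3 + z)))"
```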

Catamorphism

Take a closer look at the foldRight implementation:

def foldRight[A, B](l: List[A])(z: B)(bop: (A, B) => B): B = l match {
  case Nil => z
  case h :: t => bop(h, foldRight(t)(z)(bop))
}
defined function foldRight

and we see that it abstracts exactly structural recursion over the inductive data type List[A], where the z parameter corresponds to the Nil constructor and the bop parameter to the :: constructor. This pattern, called a catamorphism, carries over to any inductive data type: it abstracts structural recursion with one parameter for each constructor. We say that foldRight is the catamorphism for List.
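As a sketch of the same idea for another inductive type, consider a hypothetical binary tree type Tree[A] with two constructors, Leaf and Node. Its catamorphism, which we call foldTree (our own name), takes one argument per constructor and replaces each constructor with it:

```scala
// A hypothetical binary tree: Leaf carries no data; Node has two subtrees and a value.
sealed trait Tree[+A]
case object Leaf extends Tree[Nothing]
final case class Node[A](left: Tree[A], data: A, right: Tree[A]) extends Tree[A]

// One parameter per constructor: leaf replaces Leaf, node replaces Node.
def foldTree[A, B](t: Tree[A])(leaf: B)(node: (B, A, B) => B): B = t match {
  case Leaf => leaf
  case Node(l, d, r) => node(foldTree(l)(leaf)(node), d, foldTree(r)(leaf)(node))
}

val t: Tree[Int] = Node(Node(Leaf, 1, Leaf), 2, Node(Leaf, 3, Leaf))
val total = foldTree(t)(0) { (l, d, r) => l + d + r }     // 6
val nodeCount = foldTree(t)(0) { (l, _, r) => l + 1 + r } // 3
val inOrder = foldTree(t)(Nil: List[Int]) { (l, d, r) => l ::: d :: r } // List(1, 2, 3)
```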

It is good practice to write structurally recursive functions using foldRight:

Exercise 24.3 Define map in terms of foldRight

def map[A,B](l: List[A])(f: A => B): List[B] = ???
defined function map

Exercise 24.4 Define append: (List[A], List[A]) => List[A], which appends two lists into one list (i.e., returns l1 followed by l2), in terms of foldRight:

def append[A](l1: List[A], l2: List[A]): List[A] = ???
defined function append

24.2.4 Other Folds and Reduce

With lists, we have another common pattern: tail-recursive iteration. This pattern is abstracted with the foldLeft function:

def foldLeft[A, B](l: List[A])(z: B)(bop: (B, A) => B): B = {
  def loop(acc: B, l: List[A]): B = l match {
    case Nil => acc
    case h :: t => loop(bop(acc, h), t)
  }
  loop(z, l)
}
defined function foldLeft

Because multiplication is associative and 1 is its identity, we can also use the tail-recursive foldLeft to multiply the elements of an integer list:

List(1, 2, 3, 4).foldLeft(1) { (acc, h) => acc * h }
res39: Int = 24

The mnemonic for foldRight versus foldLeft is that foldRight accumulates from the right of the list, while foldLeft accumulates from the left.
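Building a string instead of a number (purely an illustration) makes the mnemonic visible: foldLeft nests its parentheses to the left. Subtraction, which is not associative, shows that the choice of fold can change the answer:

```scala
val leftShape = List(1, 2, 3).foldLeft("z") { (acc, h) => s"($acc + $h)" }
// leftShape == "(((z + 1) + 2) + 3)"

// Subtraction is not associative, so the two folds disagree.
val rightResult = List(1, 2, 3, 4).foldRight(0) { (h, acc) => h - acc } // 1 - (2 - (3 - (4 - 0))) == -2
val leftResult = List(1, 2, 3, 4).foldLeft(0) { (acc, h) => acc - h }   // (((0 - 1) - 2) - 3) - 4 == -10
```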

A good exercise is to write iterative, tail-recursive list functions using foldLeft.

Exercise 24.5 Define reverse of a list in terms of foldLeft.

def reverse[A](l: List[A]): List[A] = ???
defined function reverse

Reduce

When the order of combination does not matter because the binary operator is associative (and the start value is its identity), using the fold method allows the library to combine the elements in whatever order is most efficient:

List(1, 2, 3, 4).fold(1)(_ * _)
res41: Int = 24

A further special case for an associative binary operator on a non-empty list is reduce:

List(1, 2, 3, 4).reduce(_ * _)
res42: Int = 24

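Instead of taking a separate start value, reduce picks an element of the list as the starting accumulator, so it throws an exception on an empty list.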
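When the list might be empty, the standard reduceOption method returns an Option instead of failing. A small sketch:

```scala
val product = List(1, 2, 3, 4).reduceOption(_ * _)  // Some(24)
val noProduct = List.empty[Int].reduceOption(_ * _) // None
// List.empty[Int].reduce(_ * _) would instead throw an UnsupportedOperationException.
```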

24.3 Abstract Data Types

We have seen that the Map and Set data types, unlike List, are abstract data types: we cannot get at their underlying representation. By preventing the client from directly accessing the underlying balanced search tree representation, they can maintain its balance and search invariants, allowing for efficient key-based lookup.

At the same time, higher-order functions enable them to present the same collection view as lists, with map, flatMap, foldRight, and foldLeft:

val m = Map(2 -> List("two", "dos", "二"), 10 -> List("ten", "diez", "十"))
m: Map[Int, List[String]] = Map(
  2 -> List("two", "dos", "\u4e8c"),
  10 -> List("ten", "diez", "\u5341")
)
m map { case k -> words => k -> words.head }
res44: Map[Int, String] = Map(2 -> "two", 10 -> "ten")
m.foldRight(Nil: List[String]) {
  case (_ -> words, acc) => words.head :: acc
}
res45: List[String] = List("two", "ten")
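Because map on a Set builds a Set, the result respects the abstract type's own invariants; for example, duplicates produced by the callback collapse. Folds work regardless of the order in which the set presents its elements, as long as the operator does not depend on that order. A small sketch:

```scala
val parities = Set(1, 2, 3) map { i => i % 2 } // Set(1, 0): both odd numbers map to 1
val setSum = Set(1, 2, 3).foldLeft(0) { (acc, i) => acc + i } // 6
```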

Parallel and Distributed

This decoupling of the concrete representation from the higher-order accessor view is very powerful. For example, the same client code using map and reduce on sequential collections can be reused by loading a parallel collections library:

Scala Parallel Collections Library

Run the following cell to load the Scala Parallel Collections library.

import $ivy.`org.scala-lang.modules::scala-parallel-collections:1.0.4`, scala.collection.parallel.CollectionConverters._
import $ivy.$                                                         , scala.collection.parallel.CollectionConverters._
val par0to9999 = (0 to 9999).toList.par
val sum = par0to9999.map(_ + 1).reduce(_ + _)
assert(sum == 50005000)
par0to9999: collection.parallel.immutable.ParSeq[Int] = ParVector(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, ...
sum: Int = 50005000

The same idea underlies big-data frameworks, where the library takes care of scheduling distributed jobs, while the client code is the same kind of code that also works on small, in-memory collections locally.