Example #20: nth Fibonacci number
Q1 What is wrong with the following recursive approach to compute the nth Fibonacci number?
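A1 It “hangs” for larger values of “n” such as 50. Each call with n > 1 makes two more recursive calls, and the same Fibonacci numbers are recomputed over and over. The number of calls therefore grows exponentially with n: computing fibNaive(50) takes about 40 billion calls.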
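```scala
object RecursionVsIteration extends App {
  // 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144

  // find the nth Fibonacci number
  // Each call with n > 1 makes two more calls and recomputes the same
  // sub-results again and again, so the running time grows exponentially.
  def fibNaive(n: Long): Long = {
    if (n <= 1) n
    else fibNaive(n - 1) + fibNaive(n - 2)
  }

  println(fibNaive(8))  // 8th number is 21
  println(fibNaive(12)) // 12th number is 144
  println(fibNaive(50)) // appears to hang: roughly 40 billion calls
}
```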
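To see why the naive version bogs down, here is a small sketch that counts the calls. It is not part of the original example. fibWithCalls is a hypothetical, side-effect-free copy of fibNaive. It branches the same way but returns the Fibonacci number together with the number of calls made to compute it.

```scala
object NaiveCallCount {
  // Hypothetical instrumented variant of fibNaive: returns (fib(n), number of calls made).
  // It branches exactly like fibNaive, so the call count grows just as fast.
  def fibWithCalls(n: Long): (Long, Long) =
    if (n <= 1) (n, 1L)
    else {
      val (a, callsA) = fibWithCalls(n - 1)
      val (b, callsB) = fibWithCalls(n - 2)
      (a + b, callsA + callsB + 1L) // +1 for this call itself
    }

  def main(args: Array[String]): Unit =
    for (n <- Seq(10L, 20L, 30L)) {
      val (fib, calls) = fibWithCalls(n)
      println(s"fib($n) = $fib took $calls calls")
    }
}
```

This should report 177, 21891, and 2692537 calls for n = 10, 20, and 30. The count follows C(n) = 2·F(n+1) - 1, which for n = 50 is 40,730,022,147 calls.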
Q2 How about using an iterative approach as shown below?
```scala
object RecursionVsIteration extends App {
  // 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233

  def fibLoop(n: Long): Long = {
    if (n <= 1) n // fib(0) = 0, fib(1) = 1
    else {
      var a = 1L // previous number
      var b = 1L // current number
      for (_ <- 1L to (n - 2L)) {
        val tmp = b
        b = a + b // next b
        a = tmp   // old b becomes the next a
      }
      b
    }
  }

  println(fibLoop(8))  // 8th number is 21
  println(fibLoop(12)) // 12th number is 144
  println(fibLoop(50)) // 50th number is 12586269025
}
```
A2 The iterative approach works, but it relies on mutable variables such as “a” and “b”. Functional programming, by its very nature, encourages you to write thread-safe code, and immutability is key to that. Immutable data can be shared between threads without locks, which makes code easier to parallelize. Recursion promotes immutability, but plain recursion is not efficient here because fibNaive recomputes the same values again and again.
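The following sketch makes the thread-safety point concrete. It is not from the original. It computes several Fibonacci numbers concurrently with the standard scala.concurrent Future API. fibPure and fibsInParallel are hypothetical names. fibPure has no vars: it builds a new immutable (a, b) pair on every step, so concurrent calls share no state that could be corrupted.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object ParallelFib {
  // Hypothetical var-free variant: each step produces a new immutable (a, b) pair
  def fibPure(n: Int): Long =
    Iterator.iterate((0L, 1L)) { case (a, b) => (b, a + b) }.drop(n).next()._1

  // Safe to run concurrently: fibPure reads and writes no shared state
  def fibsInParallel(ns: Seq[Int]): Seq[Long] = {
    val futures = ns.map(n => Future(fibPure(n)))
    Await.result(Future.sequence(futures), 10.seconds) // results keep the input order
  }

  def main(args: Array[String]): Unit =
    println(fibsInParallel(Seq(8, 12, 50)))
}
```

The hazard that immutability removes is mutable state shared between threads. A function whose only inputs are its parameters can be handed to any number of threads at once.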
Q3 How can we then use recursion, which promotes immutability, to fix the code hanging above?
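A3 “Tail recursion” to the rescue. In “fibRecursion” below, the previous and current numbers are carried along as the parameters “a” and “b”, so each Fibonacci number is computed only once. The recursive call is the last operation in the function (it is in tail position), so the Scala compiler turns it into a loop and the call stack does not grow with n.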
```scala
object RecursionVsIteration extends App {
  // 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233

  def fibTail(n: Long): Long = {
    if (n <= 1) n
    else fibRecursion(1, 1, n - 2)
  }

  // a = previous number, b = current number, n = steps still to take.
  // The recursive call is in tail position, so the compiler turns it into a loop.
  def fibRecursion(a: Long, b: Long, n: Long): Long = {
    if (n == 0) b
    else fibRecursion(b, a + b, n - 1)
  }

  println(fibTail(8))  // 8th number is 21
  println(fibTail(12)) // 12th number is 144
  println(fibTail(50)) // 50th number is 12586269025
}
```
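To see how the accumulators move, here is a sketch with a hypothetical helper, fibRecursionTrace. It is the same recursion as fibRecursion, but it also records every (a, b, n) state it visits. For fib(8) there is exactly one call per step, so the work grows linearly with n instead of exponentially. The @tailrec annotation (from scala.annotation) is used here only to have the compiler confirm that the call is still in tail position.

```scala
import scala.annotation.tailrec

object FibTrace {
  // Hypothetical helper: same logic as fibRecursion, plus a log of each (a, b, n) state
  @tailrec
  def fibRecursionTrace(a: Long, b: Long, n: Long,
                        trace: Vector[(Long, Long, Long)]): (Long, Vector[(Long, Long, Long)]) = {
    val visited = trace :+ ((a, b, n))
    if (n == 0) (b, visited)
    else fibRecursionTrace(b, a + b, n - 1, visited)
  }

  def main(args: Array[String]): Unit = {
    val (result, states) = fibRecursionTrace(1, 1, 8 - 2, Vector.empty)
    states.foreach { case (a, b, n) => println(s"a=$a b=$b n=$n") }
    println(s"fib(8) = $result")
  }
}
```

The trace runs from (1, 1, 6) to (13, 21, 0) in seven calls, and the final b is 21.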
Example #21: Higher Order Function
Here is some very simple Scala code that tests the above Fibonacci code without using any unit testing framework.
```scala
import com.ex1.RecursionVsIteration._

object RecursionVsIterationTest extends App {
  def test(): Unit = {
    assert(1L == fibTail(1))
    assert(1L == fibTail(2))
    assert(21L == fibTail(8))
    assert(144L == fibTail(12))
    assert(12586269025L == fibTail(50))
  }

  test()
}
```
Q4 What if you want to test all three functions (naive recursion, looping, and tail recursion) without having to copy the contents of the “test” method three times?
A4 “Higher order” functions to the rescue. A higher order function is a function that takes another function as an input parameter or returns a function. In the following example, the “test()” method is modified to take a function as an input parameter, which makes it reusable for “fibNaive”, “fibTail”, and “fibLoop”. The parameter type “Long => Long” is a function type: any function that takes a Long and returns a Long can be passed in.
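```scala
import com.ex1.RecursionVsIteration._

object RecursionVsIterationTest extends App {
  // test is a higher order function: its parameter f is itself a function
  def test(f: Long => Long): Unit = {
    assert(1L == f(1))
    assert(1L == f(2))
    assert(21L == f(8))
    assert(144L == f(12))
    assert(12586269025L == f(50))
  }

  // each method is converted (eta-expanded) into a Long => Long function value
  test(fibTail)
  test(fibLoop)
  test(fibNaive) // slow: fibNaive(50) takes a very long time
}
```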
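The test above covers the first half of the definition, a function that takes a function. The sketch below covers the second half, a function that returns a function. makeChecker, fibChecker, and the local copy of fibTail are hypothetical names used only for illustration. makeChecker takes (input, expected) pairs and returns a new function. That function accepts any Long => Long and reports the inputs for which it gives the wrong answer.

```scala
import scala.annotation.tailrec

object CheckerDemo {
  // Local copy of the tail-recursive version so this sketch is self-contained
  def fibTail(n: Long): Long = {
    @tailrec
    def loop(a: Long, b: Long, k: Long): Long = if (k == 0) b else loop(b, a + b, k - 1)
    if (n <= 1) n else loop(1, 1, n - 2)
  }

  // Hypothetical: returns a function, so makeChecker is a higher order function
  def makeChecker(cases: Seq[(Long, Long)]): (Long => Long) => Seq[Long] =
    f => cases.collect { case (n, expected) if f(n) != expected => n }

  val fibChecker: (Long => Long) => Seq[Long] =
    makeChecker(Seq(1L -> 1L, 2L -> 1L, 8L -> 21L, 12L -> 144L, 50L -> 12586269025L))

  def main(args: Array[String]): Unit = {
    println(fibChecker(fibTail)) // empty: no failing inputs
    println(fibChecker(n => n))  // identity is not Fibonacci: 2, 8, 12 and 50 fail
  }
}
```

Returning a function lets the test cases be fixed once, and the resulting checker can then be applied to any number of candidate implementations.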
Q5 Is there anything else not right with the above Fibonacci code?
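A5 #1. The “Long” data type overflows for larger values of “n”. The 92nd Fibonacci number (7540113804746346429) is the largest one that fits in a Long. From the 93rd onwards, Long addition silently wraps around and returns wrong results. So favor the BigInt data type over Long.
#2. Nest the helper function inside fibTail, as shown below, so that it is not exposed to callers.
#3. Use pattern matching instead of if/else.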
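```scala
import scala.annotation.tailrec

object RecursionVsIteration extends App {
  def fibTail(n: Long): BigInt = {
    // defined locally so it is always initialized, even when fibTail is
    // called from another object before this App's body has run
    val zero = BigInt(0)

    // nested helper: only visible inside fibTail
    @tailrec
    def fibRecursion(a: BigInt, b: BigInt, n: BigInt): BigInt = {
      n match {
        case `zero` => b // wrap with backticks to match against the value of zero
        case _      => fibRecursion(b, a + b, n - 1)
      }
    }

    if (n <= 1) BigInt(n) // fib(0) = 0, fib(1) = 1
    else fibRecursion(1, 1, n - 2)
  }

  println(fibTail(95)) // 95th number is 31940434634990099905
}
```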
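The following sketch pins down exactly where Long gives out. It uses the JDK method java.lang.Math.addExact, which throws an ArithmeticException on overflow instead of silently wrapping around. safeAdd, firstOverflowIndex, and fibBig are hypothetical names used for illustration.

```scala
import scala.annotation.tailrec

object LongOverflowCheck {
  // Math.addExact throws ArithmeticException instead of silently wrapping around
  def safeAdd(a: Long, b: Long): Option[Long] =
    try Some(Math.addExact(a, b))
    catch { case _: ArithmeticException => None }

  // Returns the index of the first Fibonacci number that does not fit in a Long
  def firstOverflowIndex(): Int = {
    // invariant: a = fib(i - 1), b = fib(i)
    @tailrec
    def loop(a: Long, b: Long, i: Int): Int = safeAdd(a, b) match {
      case Some(next) => loop(b, next, i + 1)
      case None       => i + 1 // fib(i + 1) would overflow
    }
    loop(0L, 1L, 1)
  }

  // BigInt grows as needed, so it can represent fib(93) and beyond
  def fibBig(n: Int): BigInt = {
    @tailrec
    def loop(a: BigInt, b: BigInt, k: Int): BigInt = if (k == 0) a else loop(b, a + b, k - 1)
    loop(BigInt(0), BigInt(1), n)
  }

  def main(args: Array[String]): Unit = {
    println(firstOverflowIndex()) // 93
    println(fibBig(93))           // 12200160415121876738
  }
}
```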
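Note: The variable “zero” is wrapped in backticks (`) in the pattern match. Without them, a lowercase name in a case pattern is treated as a new variable that matches any value. The @tailrec annotation asks the compiler to verify that the annotated function is tail recursive, so that it can be compiled into a loop. If it is not, the compiler reports an error.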
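The backtick rule is easy to get wrong, so here is a minimal side-by-side sketch. isZeroStable and isZeroShadowed are hypothetical names. The first function matches against the existing value of zero. In the second, zero is a fresh pattern variable that shadows the outer one and matches everything.

```scala
object BacktickDemo {
  val zero = BigInt(0)

  // With backticks: `zero` refers to the existing val, so only values equal to 0 match
  def isZeroStable(x: BigInt): Boolean = x match {
    case `zero` => true
    case _      => false
  }

  // Without backticks: zero is a NEW variable bound to x, so this case matches every value
  def isZeroShadowed(x: BigInt): Boolean = x match {
    case zero => true
  }
}
```

isZeroStable(BigInt(5)) is false, while isZeroShadowed(BigInt(5)) is true. In a recursive function such as fibRecursion, the unquoted version would match on the first call and return immediately.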
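Scala collection functions like foldLeft, foldRight, scanLeft, scanRight, etc. are accumulating functions. They thread an accumulator through a sequence, much like “a” and “b” are threaded through each call of the tail-recursive helper, and they can be defined recursively.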
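As an illustration of that connection, here is a sketch that replaces the hand-written tail-recursive helper with foldLeft and scanLeft from the standard collections. fibFold and fibsUpTo are hypothetical names. The immutable (a, b) pair is the accumulator, playing the same role as fibRecursion's parameters.

```scala
object FoldFib {
  // foldLeft threads an immutable (a, b) pair through n steps; after n steps a = fib(n)
  def fibFold(n: Int): BigInt =
    (1 to n).foldLeft((BigInt(0), BigInt(1))) { case ((a, b), _) => (b, a + b) }._1

  // scanLeft keeps every intermediate accumulator, giving fib(0) .. fib(n)
  def fibsUpTo(n: Int): Seq[BigInt] =
    (1 to n).scanLeft((BigInt(0), BigInt(1))) { case ((a, b), _) => (b, a + b) }.map(_._1)

  def main(args: Array[String]): Unit = {
    println(fibFold(95))
    println(fibsUpTo(12).mkString(", "))
  }
}
```

fibFold(95) gives the same 31940434634990099905 as the BigInt version above, and fibsUpTo(12) lists 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144.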