Saturday, 10 October 2009

C style for loops in Scala with break and continue

(If you don't want to read about how I ended up hand-rolling C-style loops in Scala, skip straight to the final version near the end of the post.)

"Becoming really rich with C#" is a great example of what's coming in the version of C# that ships with Visual Studio 2010, and it makes it obvious that C# is leaving Java in the dust. That's why I'm looking at Scala.

As an exercise, I translated the C# version of the stdDev function into Scala, copying the imperative, non-functional style of the original C# code:

Imperative version of stddev
// One-for-one translation of the imperative C# version
def stdDevImperative(): Double = {
  var now = Dates.now()
  val limit = now - 3.years
  var rets = List[Double]()
  while (now >= _start + 12.days && now >= limit) {
    val ret = getReturn(now - 7.days, now)
    val retd = ret.get // might throw Exception (same as C# code)
    rets = rets ::: retd :: Nil // append to the end of the list
    now = now - 7.days
  }
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}


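A couple of notes here. I'm using the Scala wrapper for Joda-Time (import org.scala_tools.time.Imports._), which is what provides the 3.years and 7.days syntax. The sum() and average() helpers do what their names suggest: they return the sum and the arithmetic mean of a List[Double]. As in the C# version, getReturn(now - 7.days, now) returns an Option[Double] holding the increase in price over the seven days ending at now.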
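The post doesn't show sum(), average() or getReturn(). Here is a minimal sketch of what they might look like, for readers who want to run the examples. It swaps java.time.LocalDate in for the Joda-Time DateTime and uses a hypothetical two-entry price table (`prices`); the table and this `getReturn` are illustrative assumptions, not the author's code.

```scala
import java.time.LocalDate

object Helpers {
  // Sum of a list of doubles (0.0 for an empty list)
  def sum(xs: List[Double]): Double = xs.foldLeft(0.0)(_ + _)

  // Arithmetic mean; an empty list gives 0.0 / 0, i.e. NaN
  def average(xs: List[Double]): Double = sum(xs) / xs.length

  // Hypothetical price history: one closing price per date
  val prices: Map[LocalDate, Double] = Map(
    LocalDate.of(2009, 9, 26) -> 100.0,
    LocalDate.of(2009, 10, 3) -> 110.0
  )

  // Hypothetical getReturn: fractional price change between two dates,
  // or None if either price is missing (mirrors the Option[Double] in the post)
  def getReturn(from: LocalDate, to: LocalDate): Option[Double] =
    for {
      p0 <- prices.get(from)
      p1 <- prices.get(to)
    } yield (p1 - p0) / p0
}
```

With java.time, the post's `now - 7.days` becomes `now.minusDays(7)` and `now - 3.years` becomes `now.minusYears(3)`.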





If you are learning Scala, you can stop right here! The imperative version shown above is absolutely fine. What follows here is an intellectual exercise in using functional programming techniques to reorganize the algorithm so that it does not use vars. Many people would see this as a pointless exercise.


It also turns out that there is a very simple solution to this exercise, which is described in the follow-up section at the end of the post.





I wanted to convert my stdDevImperative to a functional style, getting rid of all those icky vars. So this was my first go:

Simple Recursive version of stddev


// simplistic recursive version of the imperative C# version
def stdDevSimpleRecursive(): Double = {
  val now = Dates.now()
  val limit = now - 3.years
  def getListOfReturns(d: DateTime): List[Double] = {
    if (d >= _start + 12.days && d >= limit)
      getReturn(d - 7.days, d).get :: getListOfReturns(d - 7.days) // not TCO: :: runs after the call returns
    else
      Nil
  }
  val rets = getListOfReturns(now)
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}


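No more vars, but it's longer and the algorithm is less obvious than in the imperative version (at least to me). Also, the recursive call is not the last thing the function does (the :: happens after it returns), so the compiler cannot perform "tail call optimisation". And so it munches stack: one frame per week of returns.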
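To make the stack point concrete, here is a toy sketch using integers instead of dates (`countdown` and `countdownAcc` are illustrative names). The first version has the same shape as getListOfReturns above; the second threads an accumulator through so the recursive call is in tail position, and the `@tailrec` annotation from scala.annotation makes compilation fail if scalac cannot turn it into a loop.

```scala
import scala.annotation.tailrec

object Recursion {
  // Not tail recursive: the cons (::) happens after the recursive call returns,
  // so every step keeps its own stack frame alive.
  def countdown(n: Int, step: Int): List[Int] =
    if (n > 0) n :: countdown(n - step, step) else Nil

  // Tail recursive: the recursive call is the last thing evaluated, so scalac
  // compiles it to a loop. @tailrec turns "not optimisable" into a compile error.
  @tailrec
  def countdownAcc(n: Int, step: Int, acc: List[Int] = Nil): List[Int] =
    if (n > 0) countdownAcc(n - step, step, n :: acc) else acc.reverse
}
```

Both return the same list, but only the accumulator version copes with very long inputs, because it uses constant stack.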

Next step: allow the compiler to use "tail call optimisation".

TCO Recursive version of stddev


// recursive version allowing TCO, inserting each result at the head of the list
def stdDevTCORecursive(): Double = {
  val now = Dates.now()
  val limit = now - 3.years
  def getListOfReturns(results: List[Double], d: DateTime): List[Double] = {
    if (d >= _start + 12.days && d >= limit) {
      val newResults = getReturn(d - 7.days, d).get :: results // optimised to add to head, not tail
      getListOfReturns(newResults, d - 7.days) // tail call, so TCO is possible
    } else
      results
  }
  val rets = getListOfReturns(List(), now).reverse // reverse unnecessary for the algorithm, but matches C# ordering
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}

Well, it works, but it's getting uglier and the actual algorithm is being obscured. (Yeah, I know, premature optimisation is the root of all evil. Yada yada. But TCO is pretty basic.)
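The "add to head not tail" comment is about cost as well as stack. A small sketch comparing three ways of building the same list (the object and method names are illustrative): appending with `:::` as in the imperative version copies the list on every step, prepending and reversing once is linear, and the standard library's `scala.collection.mutable.ListBuffer` offers constant-time append if you don't mind a local mutable buffer.

```scala
import scala.collection.mutable.ListBuffer

object Building {
  // Appending with ::: copies the whole accumulated list each time: O(n^2) overall.
  def byAppend(xs: Seq[Int]): List[Int] = {
    var acc = List[Int]()
    for (x <- xs) acc = acc ::: x :: Nil
    acc
  }

  // Prepending is O(1); a single reverse at the end restores the order: O(n) overall.
  def byPrependReverse(xs: Seq[Int]): List[Int] =
    xs.foldLeft(List[Int]())((acc, x) => x :: acc).reverse

  // ListBuffer supports O(1) append and converts to an immutable List at the end.
  def byListBuffer(xs: Seq[Int]): List[Int] = {
    val buf = ListBuffer[Int]()
    for (x <- xs) buf += x
    buf.toList
  }
}
```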

So I got to thinking about whether there was a better algorithm, and couldn't think of one. All the neat algorithms start from an existing list or sequence, and I couldn't see one here. We are going back in time 7 days at a time, stopping when we get to the first date (_start) of the available price history or when we have collected 3 years' worth of returns. Hell, I liked the imperative algorithm. Maybe that will change as my brain adapts to functional techniques. I thought that what I needed was a way of factoring the recursion out into a "thing".

Enter the unfold function, described nicely by David Pollak. I thought that this might be the solution. So here goes:

Unfold version of stddev


// This unfold function should probably be built into a Scala standard library somewhere
// (not tail recursive: r :: unfold(...) conses after the recursive call returns)
def unfold[T, R](init: T)(f: T => Option[(R, T)]): List[R] = f(init) match {
  case None => Nil
  case Some((r, v)) => r :: unfold(v)(f)
}

def stdDevUnfold(): Double = {
  val now = Dates.now()
  val limit = now - 3.years
  def getReturnOnDayAndNextDay(someDay: DateTime): Option[(Double, DateTime)] = {
    if (someDay < _start || someDay < limit) return None
    val ret = getReturn(someDay - 7.days, someDay)
    Some((ret.get, someDay - 7.days)) // might throw Exception (same as C# code)
  }
  val rets = unfold(now)(getReturnOnDayAndNextDay)
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2)))
  val weeklyStdDev = Math.sqrt(variance / rets.length)
  return weeklyStdDev * Math.sqrt(40)
}


So the recursion is now out of my stdDev function, but the code is even more obscure. The problem is that the function f() passed to unfold has two jobs to do: (1) manage a loop counter (in this case the current date) and (2) return a value to be collected in the returned results.
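A toy sketch makes the two jobs visible. It uses integer "days" instead of dates, and writes unfold with an accumulator so that it is tail recursive (otherwise the same signature as above); `weeksBack` is an illustrative name. Scala 2.13 later added `List.unfold` with essentially this shape.

```scala
import scala.annotation.tailrec

object Unfolding {
  // Same signature as the post's unfold, but tail recursive via an accumulator.
  def unfold[T, R](init: T)(f: T => Option[(R, T)]): List[R] = {
    @tailrec
    def loop(state: T, acc: List[R]): List[R] = f(state) match {
      case None         => acc.reverse
      case Some((r, s)) => loop(s, r :: acc)
    }
    loop(init, Nil)
  }

  // f does everything at once: it decides whether to stop (None) and, if not,
  // returns both the value to collect (d * 10) and the next counter (d - 7).
  def weeksBack(from: Int, limit: Int): List[Int] =
    unfold(from)(d => if (d < limit) None else Some((d * 10, d - 7)))
}
```

Starting at 21 with a limit of 7, the counter visits 21, 14, 7 and then 0, where f returns None, so the collected values are 210, 140 and 70.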

So I liked the idea of unfold because it hides the recursion, but not the multiple jobs that the function has to do. The obvious answer is to have more than one function argument: one function to get the next value for the iteration or terminate, and another function that works out the result. Hang on a minute, that first function is doing two jobs: getting the next value and working out when to terminate. How about three functions? Hey, wait, while we're there...
This is what I see in my head as the clearest way of defining the algorithm, in a sort of Java-like pseudocode:


// Java-like pseudocode
List<Double> a = new ArrayList<Double>();
for (DateTime d = today(); d > firstDateForWhichPriceIsAvailable && d > today() - 3.years; d = d - 7.days) {
    a.add(getReturn(d - 7.days, d));
}


Final version of stddev with forloop


The challenge was to come up with something completely stateless of course. After many iterations (ha ha) this is what I have ended up with as my new stdDev method:



def stdDevForLoop(): Double = {
  val today = Dates.now()
  val limit = today - 3.years
  val rets =
    forloop(today)(d => d >= _start && d >= limit)(d => d - 7.days) { day =>
      val ret = getReturn(day - 7.days, day)
      ret.get // might throw Exception (same as C# code)
    }
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}

I was pretty happy with the result, which, to my non-functional brain, expresses the algorithm (slightly) more clearly than the original imperative algorithm but without using any vars. The recursion is neatly tucked away in the forloop method. The definition of forloop is shown at the end.
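To see how the four parameter lists line up with a C for statement, here is a stripped-down sketch of the idea on plain integers. It is a simplification, not the author's forloop: it has no break/continue support and no special case for Unit bodies, and `SimpleForLoop` is an illustrative name.

```scala
import scala.annotation.tailrec

object SimpleForLoop {
  // init / cond / next play the roles of the three clauses of a C for statement;
  // body runs once per iteration and its result is collected in order.
  def forloop[T, R](init: T)(cond: T => Boolean)(next: T => T)(body: T => R): List[R] = {
    @tailrec
    def loop(counter: T, acc: List[R]): List[R] =
      if (cond(counter)) loop(next(counter), body(counter) :: acc)
      else acc.reverse
    loop(init, Nil)
  }
}

// C equivalent: for (i = 0; i < 5; i++) collect(i * i);
val squares = SimpleForLoop.forloop(0)(_ < 5)(_ + 1)(i => i * i)
```

Because the loop counter type T is fixed by the first parameter list, the later lambdas such as `_ < 5` need no type annotations.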

The icing on the cake: break and continue


I realised that C-style break and continue would be relatively easy to implement by copying the technique used by breakable (scala.util.control.Breaks in Scala 2.8). As ever, the devil is in the detail: because the forloop body yields a value that is collected in the list of results, I added a breakWith 'keyword' too, so that you can return a final value when you break. So the following becomes possible:


def stdDevForLoopBreakAndContinue(): Double = {
  val today = Dates.now()
  val limit = today - 3.years
  val rets = forloop(today)(d => d >= _start && d >= limit)(d => d - 7.days) { day =>
    val ret = getReturn(day - 7.days, day)
    if (day == blackFriday1981) continue
    if (day == blackWednesday1973) breakWith(0d) // or just plain break, if no final value required
    ret.get // might throw Exception (same as C# code)
  }
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}
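For comparison, this is the standard-library mechanism being copied: `scala.util.control.Breaks` provides `breakable` and `break()`, where `break()` throws a control exception that the nearest enclosing `breakable` catches. Where the `breakable` block goes decides whether `break()` behaves like C's break (around the loop) or continue (around the body). The object and method names below are illustrative.

```scala
import scala.util.control.Breaks._

object BreaksDemo {
  // breakable around the whole loop: break() behaves like C's break.
  def firstWeeksUntil(stop: Int): List[Int] = {
    var collected = List[Int]()
    breakable {
      for (week <- 1 to 10) {
        if (week == stop) break()
        collected = week :: collected
      }
    }
    collected.reverse
  }

  // breakable around the loop body: break() behaves like C's continue.
  def skipping(skip: Int): List[Int] = {
    var collected = List[Int]()
    for (week <- 1 to 5) {
      breakable {
        if (week == skip) break()
        collected = week :: collected
      }
    }
    collected.reverse
  }
}
```

Unlike forloop, `breakable` doesn't collect values, which is why these examples still need a var.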

What's missing?

  • Multiple loop counters. Of course you can do this by using a Tuple as the loop counter, containing as many 'embedded' loop counters as you need, but this gets Lispy with all the brackets. Maybe there is a nicer way.
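Here is a short sketch of the tuple approach, using the standard library's `Iterator.iterate` (discussed in the follow-up) instead of forloop so that the example stands alone. Two counters share one tuple: i counts up by 1 and j counts down by 2. The `{ case (i, j) => ... }` pattern-matching functions are where the extra brackets come from. `TupleCounters` is an illustrative name.

```scala
object TupleCounters {
  // C equivalent: for (i = 0, j = 10; i < j; i++, j -= 2) collect(i * j);
  val products: List[Int] =
    Iterator.iterate((0, 10)) { case (i, j) => (i + 1, j - 2) }
      .takeWhile { case (i, j) => i < j }
      .map { case (i, j) => i * j }
      .toList
}
```

The counters visit (0, 10), (1, 8), (2, 6) and (3, 4) before i < j fails at (4, 2).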
Here's the code for forloop:



trait ForLoops {

  def forloop[T, R](init: T)(cond: T => Boolean)(next: T => T)(body: T => R): List[R] = {
    def _forloop(currentResults: List[R])(loopcounter: T): List[R] = {
      if (cond(loopcounter)) {
        // Get the newResults by prepending the result of the body() function on the current loopcounter
        // (results are collected in reverse and reversed once at the end)
        val newResults = try {
          // normal case: body() returns something based on loopcounter
          val bodyResult = body(loopcounter)

          bodyResult match {
            // if the body has no return value (Unit) then don't bother adding it to the newResults
            case _: Unit => currentResults
            // if the body has a return value then add it to the newResults
            case _ => bodyResult :: currentResults
          }

        } catch {
          case ex: ContinueException => currentResults // don't add to the currentResults
          case ex: BreakException => return currentResults // exit loop so return currentResults
          case BreakWithException(v) => return v.asInstanceOf[R] :: currentResults
        }
        // This can be tail call optimised (the call is outside the try):
        // recursive call to loop with the next value of the loop counter
        _forloop(newResults)(next(loopcounter))
      } else {
        // loop has terminated so return results
        currentResults
      }
    }
    // Seed the _forloop currentResults with an empty list and the first value (init) of the loopcounter
    _forloop(List())(init).reverse
  }

  // Copied break and continue technique from "breakable" in Scala 2.8
  def continue = throw continueException
  def break = throw breakException
  def breakWith(v: Any) = throw new BreakWithException(v)

  private class ContinueException extends RuntimeException
  private val continueException = new ContinueException

  private class BreakException extends RuntimeException
  private val breakException = new BreakException

  private case class BreakWithException(v: Any) extends RuntimeException {
    // Performance optimisation: the stack trace is never needed, so don't capture it.
    override def fillInStackTrace(): Throwable = this
  }

}
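A compact variant of the same idea in later Scala, offered as a sketch rather than a replacement: the standard `scala.util.control.NoStackTrace` trait does the job of the fillInStackTrace override, and an Either records whether a step continues or stops so the recursive call stays outside the try and `@tailrec` can check it. The Unit special case is dropped, and `LoopControl` and the `*Signal` names are illustrative.

```scala
import scala.annotation.tailrec
import scala.util.control.NoStackTrace

object LoopControl {
  // Control-flow signals; NoStackTrace skips capturing the unused stack trace.
  private case object ContinueSignal extends RuntimeException with NoStackTrace
  private case object BreakSignal extends RuntimeException with NoStackTrace
  private final case class BreakWithSignal(v: Any) extends RuntimeException with NoStackTrace

  def continue: Nothing = throw ContinueSignal
  def break: Nothing = throw BreakSignal
  def breakWith(v: Any): Nothing = throw BreakWithSignal(v)

  def forloop[T, R](init: T)(cond: T => Boolean)(next: T => T)(body: T => R): List[R] = {
    @tailrec
    def loop(counter: T, acc: List[R]): List[R] =
      if (!cond(counter)) acc.reverse
      else {
        // Right(acc): keep looping with acc; Left(acc): stop and return acc.
        val step: Either[List[R], List[R]] =
          try Right(body(counter) :: acc)
          catch {
            case ContinueSignal     => Right(acc)
            case BreakSignal        => Left(acc)
            case BreakWithSignal(v) => Left(v.asInstanceOf[R] :: acc)
          }
        step match {
          case Right(a) => loop(next(counter), a) // tail call, outside the try
          case Left(a)  => a.reverse
        }
      }
    loop(init, Nil)
  }
}

// Usage: collect odd numbers, stopping at 7 with a final sentinel value.
val oddsUntilSeven: List[Int] = {
  import LoopControl._
  forloop(1)(_ <= 10)(_ + 1) { i =>
    if (i % 2 == 0) continue
    if (i == 7) breakWith(-1)
    i
  }
}
```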


Follow-up: An implementation using Iterator.iterate

In the next couple of posts I look at various alternatives to the forloop function described above, and I got some very helpful comments from various people. The simplest (and built-in) solution I found uses the new Iterator.iterate method in the Scala 2.8 pre-release, along with takeWhile and map. A very neat alternative to Iterator.iterate is to define your own iterate method that builds a lazy Stream.
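Here is a sketch of both building blocks on integer "days" stepping back 7 at a time. Scala 2.13 replaced Stream with LazyList, so the hand-rolled iterate below uses LazyList and assumes Scala 2.13 or later; `Iterating` and its methods are illustrative names. The standard library also has `LazyList.iterate`.

```scala
object Iterating {
  // Built-in: an infinite Iterator, cut off by takeWhile.
  def viaIterator(today: Int, start: Int): List[Int] =
    Iterator.iterate(today)(_ - 7).takeWhile(_ >= start).toList

  // Hand-rolled lazy iterate: #:: takes its right operand by name, so each tail
  // is only computed when demanded and the infinite sequence is safe to define.
  def iterate[T](init: T)(next: T => T): LazyList[T] =
    init #:: iterate(next(init))(next)

  def viaLazyList(today: Int, start: Int): List[Int] =
    iterate(today)(_ - 7).takeWhile(_ >= start).toList
}
```

The practical difference is that an Iterator can be traversed only once, while a LazyList memoises the elements it has computed and can be traversed again.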



def stdDevIteratorIterate(): Double = {
  val today = Dates.now()
  val limit = today - 3.years
  val dates = Iterator.iterate(today)(_ - 7.days).takeWhile(d => d >= _start && d >= limit).toList
  val rets = dates.map(d => getReturn(d - 7.days, d).get) // might throw Exception (same as C# code)
  val mean = average(rets)
  val variance = average(rets.map(r => Math.pow(r - mean, 2)))
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}
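For anyone who wants to run the final version without the Joda-Time wrapper, here is a self-contained sketch of the same pipeline using java.time. It takes getReturn as a parameter so that callers can plug in their own price data, and replaces `d >= _start` with `!d.isBefore(start)`; `StdDevSketch` and its signatures are assumptions for illustration, not the author's code.

```scala
import java.time.LocalDate

object StdDevSketch {
  def average(xs: List[Double]): Double = xs.sum / xs.length

  // Weekly standard deviation scaled by sqrt(40), as in the post.
  def annualisedStdDev(rets: List[Double]): Double = {
    val mean = average(rets)
    val variance = average(rets.map(r => math.pow(r - mean, 2)))
    math.sqrt(variance) * math.sqrt(40)
  }

  def stdDev(today: LocalDate, start: LocalDate,
             getReturn: (LocalDate, LocalDate) => Option[Double]): Double = {
    val limit = today.minusYears(3)
    val dates = Iterator.iterate(today)(_.minusDays(7))
      .takeWhile(d => !d.isBefore(start) && !d.isBefore(limit))
      .toList
    // .get throws NoSuchElementException if any weekly return is missing
    val rets = dates.map(d => getReturn(d.minusDays(7), d).get)
    annualisedStdDev(rets)
  }
}
```

With weekly returns of +1 and -1 the mean is 0 and the weekly standard deviation is 1, so the result is sqrt(40). A constant weekly return gives a standard deviation of zero, up to floating-point rounding.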