Thursday 29 October 2009

Becoming really rich with Scala

Update: 31 March 2014: This article uses Scala 2.8.
  • An updated version of this code using Scala 2.11 is available on GitHub:
  •   git clone https://github.com/azzoti/get-rich-with-scala.git
  • It's an Eclipse Scala IDE project and a Maven project.
  • It can be run with:
  •   mvn scala:run -DmainClass=etf.analyzer.Program

Becoming really rich with C# was a great example of what's coming in the version of C# in Visual Studio 2010 and it makes it obvious that C# is leaving Java in the dust. Now that the Visual Studio 2010 beta 2 is available, you can download Luca's code and try it out.

For this post, I've translated the C# code into Scala while trying to preserve the C# original style. To do this I have added support code in order to match the C# style. This was easy and shows off Scala's extensibility.

The features and libraries added or used to do this are:
  • Scala-Time: A Java Joda Time library wrapper.
  • The "using" block from Martin Odersky's FOSDEM '09 talk
  • An EventHandler class for simulating C# Events
  • The Jetty HTTPClient from Eclipse that I wrapped to resemble the C# WebClient api.
  • Arithmetic operations for Option[Double]. (Option[Double] is Scala's equivalent of the C# nullable type double?. In C# you can use double? variables in expressions, with the expression evaluating to null if any part of it is null. In Scala you can't use Option[Double] in arithmetic expressions out of the box, but it's very easy to add this ability in a small library.)
  • The Scala code is written with the latest 2.8 pre version of Scala and uses one or two features from its latest standard library not present in the latest stable release.
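
To give a feel for how small such an Option[Double] library can be, here is a minimal sketch. The names OptionMath and OptionDoubleOps are made up for illustration, and it uses an implicit class, which needs a newer Scala than the 2.8 used in this post; the actual org.scala_tools.option.math code may well be organised differently.

```scala
object OptionMath {
  // Hypothetical wrapper that adds arithmetic to Option[Double].
  implicit class OptionDoubleOps(self: Option[Double]) {
    // Combine two optional values; the result is None if either side is None.
    private def lift(other: Option[Double])(f: (Double, Double) => Double): Option[Double] =
      for (a <- self; b <- other) yield f(a, b)

    def +(other: Option[Double]): Option[Double] = lift(other)(_ + _)
    def -(other: Option[Double]): Option[Double] = lift(other)(_ - _)
    def *(factor: Double): Option[Double] = self.map(_ * factor)
    def /(divisor: Double): Option[Double] = self.map(_ / divisor)
  }
}
```

With import OptionMath._ in scope, an expression shaped like (fourWeeks + threeMonths) / 2 yields None as soon as one operand is None, which is the behaviour the C# double? type gives for free.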

While the Scala is slightly shorter than the C# code, it is supported by extra code or libraries that I have found or had to write. C# already has using blocks, Events, reasonable datetime management, a WebClient and Nullable double types that handle arithmetic operations sensibly.
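
An EventHandler that simulates C# events can be equally small. This is only a sketch under my own assumptions: the class shape and the fire method name are hypothetical, it is not thread safe, and it is not the support code used by the listing in this post.

```scala
object Events {
  // Hypothetical minimal event type: handlers are registered with += and called in order.
  class EventHandler[E] {
    private var handlers = List.empty[E => Unit]

    // Register a handler, mirroring the C# "event += handler" syntax.
    def +=(handler: E => Unit): Unit = {
      handlers = handlers :+ handler
    }

    // Invoke every registered handler with the given event payload.
    def fire(event: E): Unit = handlers.foreach(h => h(event))
  }
}
```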

For whatever reason, the Scala code runs much faster than the C# code, but there is a large amount of internet access involved and I suspect that the C# web client should be configured to use more threads. [Update: Luca just suggested I comment out the C# line ServicePointManager.DefaultConnectionLimit = 10; and this does indeed make the C# code much faster.]



Original C# (first listing) and Scala (second listing)
See the notes after the listings

Original C#:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using System.IO;

namespace ETFAnalyzer {

struct Event {
  internal Event(DateTime date, double price) { Date = date; Price = price; }
  internal readonly DateTime Date;
  internal readonly double Price;
}
class Summary {
  internal Summary(string ticker, string name, string assetClass,
          string assetSubClass, double? weekly, double? fourWeeks,
          double? threeMonths, double? sixMonths, double? oneYear,
          double? stdDev, double price, double? mav200) {
    Ticker = ticker;
    Name = name;
    AssetClass = assetClass;
    AssetSubClass = assetSubClass;
    // Abracadabra ...
    LRS = (fourWeeks + threeMonths + sixMonths + oneYear) / 4;
    Weekly = weekly;
    FourWeeks = fourWeeks;
    ThreeMonths = threeMonths;
    SixMonths = sixMonths;
    OneYear = oneYear;
    StdDev = stdDev;
    Mav200 = mav200;
    Price = price;
  }
  internal readonly string Ticker;
  internal readonly string Name;
  internal readonly string AssetClass;
  internal readonly string AssetSubClass;
  internal readonly double? LRS;
  internal readonly double? Weekly;
  internal readonly double? FourWeeks;
  internal readonly double? ThreeMonths;
  internal readonly double? SixMonths;
  internal readonly double? OneYear;
  internal readonly double? StdDev;
  internal readonly double? Mav200;
  internal double Price;

  internal static void Banner() {
    Console.Write("{0,-6}", "Ticker");
    Console.Write("{0,-50}", "Name");
    Console.Write("{0,-12}", "Asset Class");
    Console.Write("{0,4}", "RS");
    Console.Write("{0,4}", "1Wk");
    Console.Write("{0,4}", "4Wk");
    Console.Write("{0,4}", "3Ms");
    Console.Write("{0,4}", "6Ms");
    Console.Write("{0,4}", "1Yr");
    Console.Write("{0,6}", "Vol");
    Console.WriteLine("{0,2}", "Mv");
  }

  internal void Print() {

    Console.Write("{0,-6}", Ticker);
    Console.Write("{0,-50}", new String(Name.Take(48).ToArray()));
    Console.Write("{0,-12}", new String(AssetClass.Take(10).ToArray()));
    Console.Write("{0,4:N0}", LRS * 100);
    Console.Write("{0,4:N0}", Weekly * 100);
    Console.Write("{0,4:N0}", FourWeeks * 100);
    Console.Write("{0,4:N0}", ThreeMonths * 100);
    Console.Write("{0,4:N0}", SixMonths * 100);
    Console.Write("{0,4:N0}", OneYear * 100);
    Console.Write("{0,6:N0}", StdDev * 100);
    if (Price <= Mav200)
      Console.WriteLine("{0,2}", "X");
    else
      Console.WriteLine();
  }
}

class TimeSeries {
  internal readonly string Ticker;
  readonly DateTime _start;
  readonly Dictionary<DateTime, double> _adjDictionary;
  readonly string _name;
  readonly string _assetClass;
  readonly string _assetSubClass;

  internal TimeSeries(string ticker, string name, string assetClass, string assetSubClass, IEnumerable<Event> events) {
    Ticker = ticker;
    _name = name;
    _assetClass = assetClass;
    _assetSubClass = assetSubClass;
    _start = events.Last().Date;
    _adjDictionary = events.ToDictionary(e => e.Date, e => e.Price);
  }

  bool GetPrice(DateTime when, out double price, out double shift) {
    // To nullify the effect of hours/min/sec/millisec being different from 0
    when = new DateTime(when.Year, when.Month, when.Day);
    var found = false;
    shift = 1;
    double aPrice = 0;
    while (when >= _start && !found) {
      if (_adjDictionary.TryGetValue(when, out aPrice)) {
        found = true;
      }
      when = when.AddDays(-1);
      shift -= 1;
    }
    price = aPrice;
    return found;
  }

  double? GetReturn(DateTime start, DateTime end) {
    var startPrice = 0.0;
    var endPrice = 0.0;
    var shift = 0.0;
    var foundEnd = GetPrice(end, out endPrice, out shift);
    var foundStart = GetPrice(start.AddDays(shift), out startPrice, out shift);
    if (!foundStart || !foundEnd)
      return null;
    else
      return endPrice / startPrice - 1;
  }

  internal double? LastWeekReturn() {
    return GetReturn(DateTime.Now.AddDays(-7), DateTime.Now);
  }
  internal double? Last4WeeksReturn() {
    return GetReturn(DateTime.Now.AddDays(-28), DateTime.Now);
  }
  internal double? Last3MonthsReturn() {
    return GetReturn(DateTime.Now.AddMonths(-3), DateTime.Now);
  }
  internal double? Last6MonthsReturn() {
    return GetReturn(DateTime.Now.AddMonths(-6), DateTime.Now);
  }
  internal double? LastYearReturn() {
    return GetReturn(DateTime.Now.AddYears(-1), DateTime.Now);
  }
  internal double? StdDev() {
    var now = DateTime.Now;
    now = new DateTime(now.Year, now.Month, now.Day);
    var limit = now.AddYears(-3);
    var rets = new List<double>();
    while (now >= _start.AddDays(12) && now >= limit) {
      var ret = GetReturn(now.AddDays(-7), now);
      rets.Add(ret.Value);
      now = now.AddDays(-7);
    }
    var mean = rets.Average();
    var variance = rets.Select(r => Math.Pow(r - mean, 2)).Sum();
    var weeklyStdDev = Math.Sqrt(variance / rets.Count);
    return weeklyStdDev * Math.Sqrt(40);
  }
  internal double? MAV200() {
    return _adjDictionary.ToList()
           .OrderByDescending(k => k.Key)
           .Take(200).Average(k => k.Value);
  }
  internal double TodayPrice() {
    var price = 0.0;
    var shift = 0.0;
    GetPrice(DateTime.Now, out price, out shift);
    return price;
  }
  internal Summary GetSummary() {
    return new Summary(Ticker, _name, _assetClass, _assetSubClass,
           LastWeekReturn(), Last4WeeksReturn(), Last3MonthsReturn(),
           Last6MonthsReturn(), LastYearReturn(), StdDev(), TodayPrice(), 
           MAV200());
  }
}

class Program {

  static string CreateUrl(string ticker, DateTime start, DateTime end)
  {
    return @"http://ichart.finance.yahoo.com/table.csv?s=" + ticker + 
      "&a="+(start.Month - 1).ToString()+"&b="+start.Day.ToString()+"&c="+start.Year.ToString() + 
      "&d="+(end.Month - 1).ToString()+"&e="+end.Day.ToString()+"&f="+end.Year.ToString() + 
      "&g=d&ignore=.csv";
  }

  static void Main(string[] args) {
    // If you rise this above 5 you tend to get frequent connection closing on my machine
    // I'm not sure if it is msft network or yahoo web service
    ServicePointManager.DefaultConnectionLimit = 10;

    var tickers =
      File.ReadAllLines("ETFTest.csv")
      .Skip(1)
      .Select(l => l.Split(new[] { ',' }))
      .Where(v => v[2] != "Leveraged")
      .Select(values => Tuple.Create(values[0], values[1], values[2], values[3]))
      .ToArray();

    var len = tickers.Length;

    var start = DateTime.Now.AddYears(-2);
    var end = DateTime.Now;
    var cevent = new CountdownEvent(len);
    var summaries = new Summary[len];
    
    for(var i = 0; i < len; i++)  {
      var t = tickers[i];
      var url = CreateUrl(t.Item1, start, end);
      using (var webClient = new WebClient()) {
        webClient.DownloadStringCompleted +=
                        new DownloadStringCompletedEventHandler(downloadStringCompleted);
        webClient.DownloadStringAsync(new Uri(url), Tuple.Create(t, cevent, summaries, i));
      }
    }

    cevent.Wait();
    Console.WriteLine("\n");

    var top15perc =
        summaries
        .Where(s => s.LRS.HasValue)
        .OrderByDescending(s => s.LRS)
        .Take((int)(len * 0.15));
    var bottom15perc =
        summaries
        .Where(s => s.LRS.HasValue)
        .OrderBy(s => s.LRS)
        .Take((int)(len * 0.15));

    Console.WriteLine();
    Summary.Banner();
    Console.WriteLine("TOP 15%");
    foreach(var s in top15perc)
      s.Print();

    Console.WriteLine();
    Console.WriteLine("Bottom 15%");
    foreach (var s in bottom15perc)
      s.Print();
      
  }

  static void downloadStringCompleted(object sender, DownloadStringCompletedEventArgs e) {
    var bigTuple = (Tuple<Tuple<string, string, string, string>, CountdownEvent, Summary[], int>)e.UserState;
    var tuple = bigTuple.Item1;
    var cevent = bigTuple.Item2;
    var summaries = bigTuple.Item3;
    var i = bigTuple.Item4;
    var ticker = tuple.Item1;
    var name = tuple.Item2;
    var asset = tuple.Item3;
    var subAsset = tuple.Item4;

    if (e.Error == null) {
      var adjustedPrices =
          e.Result
          .Split(new[] { '\n' })
          .Skip(1)
          .Select(l => l.Split(new[] { ',' }))
          .Where(l => l.Length == 7)
          .Select(v => new Event(DateTime.Parse(v[0]), Double.Parse(v[6])));

      var timeSeries = new TimeSeries(ticker, name, asset, subAsset, adjustedPrices);
      summaries[i] = timeSeries.GetSummary();
      cevent.Signal();
      Console.Write("{0} ", ticker);
    } else {
      Console.WriteLine("[{0} ERROR] ", ticker);
      summaries[i] = new Summary(ticker,name,"ERROR","ERROR",0,0,0,0,0,0,0,0); 
      cevent.Signal();
    }
  }
}
}

Scala:
package etf.analyzer

import scala.io.Source
import org.scala_tools.time.Imports._
import org.scala_tools.option.math.Imports._  
import org.joda.time.Days
import org.scala_tools.using.Using
import org.scala_tools.web.WebClient
import org.scala_tools.web.WebClientConnections
import org.scala_tools.web.DownloadStringCompletedEventArgs 
import java.io.File
import java.util.concurrent.CountDownLatch 

case class Event (date : DateTime, price : Double) {}


case class Summary (
  ticker : String, name : String, assetClass : String,
  assetSubClass : String, weekly : Option[Double], 
  fourWeeks : Option[Double], threeMonths : Option[Double], 
  sixMonths : Option[Double], oneYear : Option[Double],
  stdDev : Double, price : Double, mav200 : Double 
) {



  // Abracadabra ...
  val LRS = (fourWeeks + threeMonths + sixMonths+ oneYear) / 4   














  




  def banner() = {
    printf("%-6s", "Ticker")
    printf("%-50s", "Name")
    printf("%-12s", "Asset Class")
    printf("%4s", "RS")
    printf("%4s", "1Wk")
    printf("%4s", "4Wk")
    printf("%4s", "3Ms")
    printf("%4s", "6Ms")
    printf("%4s", "1Yr")
    printf("%6s", "Vol")
    printf("%2s\n", "Mv")
  }
  
  def print() = {
  
    printf("%-6s", ticker);
    printf("%-50s", new String(name.toArray.take(48)))
    printf("%-12s", new String(assetClass.toArray.take(10)));
    printf("%4.0f", LRS * 100 getOrElse null)
    printf("%4.0f", weekly * 100 getOrElse null)
    printf("%4.0f", fourWeeks * 100 getOrElse null)
    printf("%4.0f", threeMonths * 100 getOrElse null)
    printf("%4.0f", sixMonths * 100 getOrElse null)
    printf("%4.0f", oneYear * 100 getOrElse null)
    printf("%6.0f", stdDev * 100);
    if (price <= mav200) {
      printf("%2s\n", "X");
    } else {
      println();
    }
  }  
}

case class TimeSeries (
    ticker : String, name : String, assetClass : String, 
    assetSubClass : String, private val events : Iterable[Event]
) {
  
  private val _adjDictionary : Map[DateTime, Double] 
              = Map() ++ events.map(e => (e.date -> e.price))
  private val _start = events.last.date




  // Add the sum and average function to all Iterables[Double] used locally
  private implicit def iterableWithSumAndAverage(c: Iterable[Double]) = new { 
    def sum = c.foldLeft(0.0)(_ + _) 
    def average = sum / c.size
  }  
  
  def getPrice(whenp : DateTime) : Option[(Double,Int)] =  {
    var when = new DateTime(whenp.year.get,whenp.month.get,whenp.day.get,0,0,0,0)
    var found = false
    var shift = 1
    var aPrice = 0.0
    while (when >= _start && !found) {
        if (_adjDictionary.contains(when)) {
            aPrice = _adjDictionary(when)
            found = true
        }
        when = when - 1.days
        shift -= 1
    }
    // Either return the price and the shift or None if no price was found
    if (found) Some(aPrice,shift) else return None
  }  

  def getReturn(start: DateTime, end: DateTime) : Option[Double] = {
    for {
      (endPrice,daysBefore) <- getPrice(end)
      (startPrice,_) <- getPrice(start + daysBefore.days)
    } yield (endPrice / startPrice - 1.0)
  }

  def lastWeekReturn = getReturn(DateTime.now - 7.days, DateTime.now)
  def last4WeeksReturn = getReturn(DateTime.now - 28.days, DateTime.now)
  def last3MonthsReturn = getReturn(DateTime.now - 3.months, DateTime.now)
  def last6MonthsReturn = getReturn(DateTime.now - 6.months, DateTime.now)
  def lastYearReturn = getReturn(DateTime.now - 1.years, DateTime.now)

  def stdDev : Double = {
    val today = DateTime.now
    val limit = today - 3.years
    val dates = Iterator.iterate(today)(_ - 7.days)
                .takeWhile(d => d >= (_start + 12.days) && d >= limit)
                .toList
    val rets = dates.map(d => getReturn(d - 7.days, d).get)
    val mean = rets.average
    val variance = rets.map(r => Math.pow(r - mean, 2)).average
    val weeklyStdDev = Math.sqrt(variance);
    return weeklyStdDev * Math.sqrt(40);
  }



  def mav200(): Double = {
    return _adjDictionary.toList
           .sortWith((elem1, elem2) => elem1._1 >= elem2._1)
           .take(200).map(keyValue => keyValue._2).average
  }
  def todayPrice() : Double = {
    getPrice(DateTime.now) match {
      case None => 0.0
      case Some((price,_)) => price 
    }
  }
  def getSummary() = 
    Summary(ticker, name, assetClass, assetSubClass, 
      lastWeekReturn, last4WeeksReturn, last3MonthsReturn, 
      last6MonthsReturn, lastYearReturn, stdDev, todayPrice, 
      mav200)
}

object Program extends Using {

  def createUrl(ticker: String, start: DateTime, end: DateTime) : String = {
    return """http://ichart.finance.yahoo.com/table.csv?s=""" + ticker +
      "&a="+(start.month.get-1)+ "&b=" + start.day.get + "&c=" + start.year.get +
      "&d="+(end.month.get  -1)+ "&e=" + end.day.get + "&f=" + end.year.get +
      "&g=d&ignore=.csv"
  } 


 
  def main(args : Array[String]) : Unit = {

    
    val tickers = 
     Source.fromFile(new File("ETFTest.csv")).getLines()
     .drop(1)
     .map(l => l.trim.split(','))
     .filter(v => v(2) != "Leveraged")
     .map(values => (values(0),values(1),values(2),if (values.length==4) values(3) else ""))
     .toSeq.toArray

    val len = tickers.length;

    val start = DateTime.now - 2.years
    val end = DateTime.now
    val cevent = new CountDownLatch(len)
    val summaries = new Array[Summary](len)
    
    using(new WebClientConnections(connectionsPerAddress = 10, threadPool=10)) {
      webClientConnections =>
      for (i <- 0 until len) {
        val t = tickers(i)
        val url = createUrl(t._1, start, end)
        val webClient = webClientConnections.getWebClient
        webClient.downloadStringCompleted += downloadStringCompleted
        webClient.downloadStringAsync(url, (t, cevent, summaries, i))
      }
    }

    cevent.await()
    println

    val top15perc =
      summaries
      .filter(s => s.LRS.isDefined)
      .sortWith((elem1, elem2) => elem1.LRS >= elem2.LRS)
      .take((len * 0.15).toInt)      
    val bottom15perc =
      summaries
      .filter(s => s.LRS.isDefined)
      .sortWith((elem1, elem2) => elem1.LRS <= elem2.LRS)
      .take((len * 0.15).toInt)

    println
    summaries(0).banner()
    println("TOP 15%")
    for (s <- top15perc)
      s.print()

    println
    println("Bottom 15%")
    for (s <- bottom15perc)
      s.print()
  }

  def downloadStringCompleted(e : DownloadStringCompletedEventArgs) {
    val bigTuple = e.userState.asInstanceOf[((String, String, String, String), CountDownLatch, Array[Summary], Int)]
    val ((ticker, name, asset, subAsset), cevent, summaries, i) = bigTuple
    def parse(s : String) = DateTimeFormat.forPattern("yyyy-MM-dd").parseDateTime(s)

    if (e.error == null) {
      val adjustedPrices =
          e.result
          .split("\n")
          .drop(1)
          .map(l => l.split(','))
          .filter(l => l.length == 7)
          .map(v => Event(parse(v(0)),v(6).toDouble))

      val timeSeries = new TimeSeries(ticker, name, asset, subAsset, adjustedPrices);
      summaries(i) = timeSeries.getSummary();
      cevent.countDown() 
      printf("%s ", ticker)
    } else {
      printf("[%s ERROR] \n", ticker)
      summaries(i) = Summary(ticker,name,"ERROR","ERROR",Some(0),Some(0),Some(0),Some(0),Some(0),0,0,0) 
      cevent.countDown()
    }
  }  
}


Notes

TimeSeries getPrice method: Scala does not have output parameters on methods. It doesn't need them because the return type from a method can be a tuple and you can return as many values as you like. Also the method shown copies the C# style closely using loop variables. Another way of writing the same method in Scala making use of list functions is:
def getPrice(when : DateTime) : Option[(Double,Int)] =  {
  // Find the most recent day with a price starting from when, but don't go back past _start 
  val latestDayWithPrice 
      = Iterator.iterate(when)(_ - 1.days)
        .dropWhile (d=> !_adjDictionary.contains(d) && d >= _start )
        .next
  if (_adjDictionary.contains(latestDayWithPrice)) {
    val shift = Days.daysBetween(when,latestDayWithPrice).getDays()
    val aPrice = _adjDictionary(latestDayWithPrice) 
    Some((aPrice,shift))
  } else {
    None
  }
}  

TimeSeries getReturn method: The two calls to getPrice() each return an Option[(price, days offset)].
Using a for expression over Option[...] values is an easy way of handling the possibility that either of them is None: if either getPrice() call returns None, the yield produces None as well. Another, perhaps simpler to understand, getReturn implementation is:
def getReturn(start: DateTime, end: DateTime) : Option[Double] = {
  val endPriceDetails = getPrice(end)
  if (endPriceDetails == None) return None
  val (endPrice,daysBefore) = endPriceDetails.get
  val startPriceDetails = getPrice(start + daysBefore.days)
  if (startPriceDetails == None) return None
  val (startPrice,_) = startPriceDetails.get
  Some(endPrice / startPrice - 1.0)   // wrapped in Some to match the Option[Double] return type
}

TimeSeries mav200 method: The scala version is slightly harder work than .net 3.5 LINQ OrderByDescending method with key selector syntax: .OrderByDescending(k => k.Key). The Scala version has to say *how* to do it. The LINQ version says *what* is required. The same is true for the C# use of the average function, which uses a field selector.
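
For what it's worth, current Scala collections also have sortBy, which takes a key selector much like the LINQ call. A small sketch, assuming plain Int day numbers as keys instead of Joda DateTime so that it stays self-contained (the object and method names are mine, not part of the listing):

```scala
object MovingAverage {
  // Average of the `n` most recent prices, where a larger key means a more recent day.
  def mav(prices: Map[Int, Double], n: Int): Double = {
    val latest = prices.toList
      .sortBy(_._1)(Ordering[Int].reverse) // key selector plus a reversed ordering
      .take(n)
      .map(_._2)
    latest.sum / latest.size
  }
}
```

Here the selector says what to sort on and the reversed Ordering says in which direction, so the comparison lambda passed to sortWith is no longer needed.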

I'm not sure whether the concurrent access to the summaries array in downloadStringCompleted is safe. It seems to work, but I don't know if it is genuinely thread safe. I've just copied the C# code, which may have built-in thread-safe array access.
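
As far as I can tell from the java.util.concurrent documentation, the pattern itself is sound on the JVM: each callback writes only its own array slot, and actions performed before countDown() happen-before a successful return from await() in another thread. Here is a sketch of just that pattern, using plain threads rather than the WebClient wrapper (LatchDemo and fillConcurrently are names invented for this example):

```scala
import java.util.concurrent.CountDownLatch

object LatchDemo {
  // Each worker writes only its own slot and then counts down;
  // the caller reads the array only after await() has returned.
  def fillConcurrently(len: Int): Array[Int] = {
    val results = new Array[Int](len)
    val latch = new CountDownLatch(len)
    for (i <- 0 until len) {
      val worker = new Thread(new Runnable {
        def run(): Unit = {
          results(i) = i * i // distinct index per thread
          latch.countDown()  // the write above is visible to whoever returns from await()
        }
      })
      worker.start()
    }
    latch.await() // blocks until every worker has counted down
    results
  }
}
```

The same reasoning only holds for the real program if every callback, including the error branch, touches nothing but its own index before calling countDown().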


Some of the features of Scala that are shown here
  • Easy Java library interop. See use of CountDownLatch, Days, File.
  • Good old fashioned casting if you really need it. See asInstanceOf.
  • No semicolons
  • No need to use () for a function declaration with no parameters or for a call to it. See TimeSeries.getSummary(). (Brackets are recommended if there are side effects.)
  • Type declarations are unnecessary except in method parameters, but can be declared explicitly if it aids readability. See _adjDictionary.
  • Named and default parameters. See WebClientConnections
  • Much less boilerplate with "case" classes providing automatic constructors, fields, toString, equals, hashcode. See Event, Summary, TimeSeries
  • Joda time wrapper so you can say "today - 3.years"
  • Pattern matching assignment "val ((ticker,name),cevent,summaries,i) = bigTuple" in downloadStringCompleted
  • "using" block for automatic resource closing. In the main method, the "using(new WebClientConnections(" block will close down the WebClientConnections thread pool at the end of the block. This is very similar to the C# "using" code.
  • local "implicit" function definitions allowing you to effectively add methods to existing classes in a tightly controlled and scoped way. (see def iterableWithSumAndAverage)
  • Pattern matching, switch on steroids. See todayPrice().
  • Use of powerful list manipulation functions, such as Iterator.iterate and takeWhile, to replace traditional state-based loops. See the iterate/takeWhile example in stdDev(), the iterate/dropWhile example in the alternative getPrice in the notes, and drop, map, filter, sortWith and take in main(). See the infamous foldLeft example at work in the sum function.
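
The "using" block mentioned in the list is an instance of what is often called the loan pattern. A minimal sketch, assuming the resource implements java.io.Closeable; the Loan object and firstChar helper are hypothetical stand-ins, not the Using trait from the listing:

```scala
import java.io.{Closeable, StringReader}

object Loan {
  // Run `body` with the resource, then always close it, even if `body` throws.
  def using[R <: Closeable, A](resource: R)(body: R => A): A =
    try body(resource) finally resource.close()

  // Usage: the reader is closed as soon as the block has produced its result.
  def firstChar(text: String): Char =
    using(new StringReader(text))(reader => reader.read().toChar)
}
```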

Tuesday 20 October 2009

Checking how to use blogger for side by side code comparison

I had to change the Blogger HTML template to do this.
.post {
padding-$startSide:0%;
padding-$endSide:0%;
}


                                                                                
object Primes {

def primes = {
  def sieve(is: Stream[Int]): Stream[Int] = is match { 
case p #:: xs => p #:: sieve(for (x <- xs if x % p > 0) yield x)
}
sieve(Stream from 2)
}

def main(args: Array[String]) {
primes take 100 foreach println
}
}


                                                                                

Wednesday 14 October 2009

Scala, lazy evaluation and the Sieve of Eratosthenes

I found a Scala algorithm for the Sieve of Eratosthenes here

Since that was written, the #:: object has been added to Stream and can be used instead of Stream.cons. Also, as of a recent nightly build, #:: now works in Stream pattern matching, which means the Sieve of Eratosthenes can be even closer to the Haskell implementation. (Presumably a good thing, as it broadens your choices.)


object Primes {

/* Haskell...
primes = sieve [2..]
sieve (p : xs) = p : sieve [x | x <- xs, x `mod` p > 0]
*/

def primes = {
def sieve(is: Stream[Int]): Stream[Int] = is match {
case p #:: xs => p #:: sieve(for (x <- xs if x % p > 0) yield x)
}
sieve(Stream from 2)
}

def main(args: Array[String]) {
primes take 100 foreach println
}
}


Nice!

Tuesday 13 October 2009

Scala Streams for iteration / Use streams to avoid TCO(?)


My last post, Scala: C style iterate function for building lists, got an amazing response from somebody called Johan. (Thanks Johan.) It introduced me to the amazing power of lazy Scala streams.

Here's what Johan said, slightly edited (in Johan's comment he used lazy_::, which is now called #:: in Scala 2.8):
   
[Your iterateFor] function is very similar to until (but with a reversed predicate) in Haskell: http://www.haskell.org/ghc/docs/latest/html/libraries/base/Prelude.html#v%3Auntil. Basically, it's iterate and takeWhile.

Unfortunately, the iterate functions in Scala's traversable collections are not polymorphic. (Maybe it's because it would only work for lazy collections.) But here's one for Stream:

def iterate[A](step: A => A)(seed: A): Stream[A] = seed #:: iterate(step)(step(seed))

Now it's easy to write iterateFor:

def iterateFor[A](predicate: A => Boolean)(step: A => A)(seed: A): Stream[A] = iterate(step)(seed) takeWhile predicate

or, all in one go:

def until[A](...) = if (!predicate(seed)) seed #:: until(predicate)(step)(step(seed)) else Stream.empty



Now the symbol #:: is a special lazy "cons" operator for Streams (used just like the List :: operator) which only gets evaluated on demand. The #:: operator can be used to define infinite lists as Johan has done with the iterate function.

Now before I get on to trying this stuff out, the way Johan defined the functions, you need to specify the type of [A] when you call them, for example iterate[Int].... But if you move the seed parameter from the last parameter to the first then you can get away with not defining A explicitly. So here are the functions reorganized to make them easier to use:


def iterate[A]
(seed: A)(step: A => A)
: Stream[A]
= seed #:: iterate(step(seed))(step)


def iterateFor[A] (seed: A)(predicate: A => Boolean)(step: A => A)
: Stream[A]
= iterate(seed)(step) takeWhile predicate

def until[A]
(seed: A)(predicate: A => Boolean)(step: A => A)
: Stream[A]
= if (!predicate(seed)) seed #:: until(step(seed))(predicate)(step) else Stream.empty



So what can I do with all this ****? Well first of all, iterate defines an infinite list, that is evaluated as needed so the following loop will iterate forever printing as it goes,

for (i <- iterate(0)(_ + 1)) {
println(i)
}

as will

iterate(0)(_ + 1).foreach(println)

Note that I said printing as it goes.

If you defined iterate using List instead of Stream, replacing #:: with ::, then the loops above would get a StackOverflowError after munching stack on each iteration, and nothing would be printed, because the entire (infinite) List would have to be calculated before anything is printed.

So what use is the iterate function using Streams? Well you can reinvent a convoluted for loop, by taking the first n of the infinite stream.

for (i <- iterate(0)(_ + 1) take 10 ) {
println(i)
}


Big deal, but more importantly it's a building block: Johan went on to redefine my original iterateFor function (see the last post) using iterate and takeWhile.

def iterateFor[A] (seed: A)(predicate: A => Boolean)(step: A => A)
: Stream[A]
= iterate(seed)(step) takeWhile predicate


How simple is that!

As he also pointed out iterateFor can be defined without iterate like this:

def iterateFor[A](seed: A)(predicate: A => Boolean)(step: A => A) : Stream[A]
= if (predicate(seed)) seed #:: iterateFor(step(seed))(predicate)(step) else Stream.empty

except that he called that version "until" to match the Haskell function. I don't like this name in Scala for two reasons: (1) "until" is already used in "for (i <- 1 until 10)" and (2) with the seed parameter moved to the first position (as I explained above), until(0)(_ >= 10)(_ + 1) does not read well.

A Reminder About What This Is All For!

iterateFor is not for counting from 1 to 10! Use the normal Scala for (i <- 1 to 10) for that!

This is what I originally wanted it for:

val dates = iterateFor(today)(d => d >= _start && d >= limit)(d => d - 7.days)
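
To make that line something you can paste and run, here is a sketch that swaps Joda Time for java.time.LocalDate (my substitution, purely to avoid the extra dependency) and builds iterateFor on Iterator.iterate, returning a List. The DateWalk object and the weeklyDates name are invented for the example.

```scala
import java.time.LocalDate

object DateWalk {
  // iterate plus takeWhile: keep stepping from the seed while the predicate holds.
  def iterateFor[A](seed: A)(predicate: A => Boolean)(step: A => A): List[A] =
    Iterator.iterate(seed)(step).takeWhile(predicate).toList

  // Walk back one week at a time while the date is on or after both bounds.
  def weeklyDates(today: LocalDate, start: LocalDate, limit: LocalDate): List[LocalDate] =
    iterateFor(today)(d => !d.isBefore(start) && !d.isBefore(limit))(_.minusDays(7))
}
```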


The elephants in the room

1. Stack? These lazy stream versions of iterateFor DO NOT munch stack even though they do not look like they can be TCO'd. Amazing. I don't understand how it works, but it does. Maybe they munch a lot more heap. Who knows.
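
A likely explanation for the stack behaviour: the right-hand side of #:: is passed by name, so the "recursive" call is only wrapped in a thunk when the cell is built and is run later, from the consumer's loop, after the previous call has already returned. The hand-rolled sketch below shows the idea; Lazy, Cons and Empty are hypothetical names, and this is not how the library's Stream is actually written.

```scala
object LazyCons {
  // A stripped-down lazy list: the tail is a by-name argument cached in a lazy val.
  sealed trait Lazy[A]
  final class Empty[A] extends Lazy[A]
  final class Cons[A](val head: A, next: => Lazy[A]) extends Lazy[A] {
    lazy val tail: Lazy[A] = next // evaluated on first access, then kept on the heap
  }

  // Looks recursive, but the inner call is deferred, so this returns immediately.
  def iterate[A](seed: A)(step: A => A): Lazy[A] =
    new Cons(seed, iterate(step(seed))(step))

  // The consumer is an ordinary loop; each tail is forced one cell at a time.
  def take[A](xs: Lazy[A], n: Int): List[A] = {
    val buf = List.newBuilder[A]
    var cur = xs
    var left = n
    while (left > 0) cur match {
      case c: Cons[A] =>
        buf += c.head
        cur = c.tail
        left -= 1
      case _ =>
        left = 0
    }
    buf.result()
  }
}
```

The heap guess is reasonable too: every forced cell stays reachable for as long as something still holds on to the head of the list.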

2. Performance. I have not studied this in detail, but from what I have found with some simple (and probably flawed) tests, there is NO performance penalty. The use of streams in this way seems to be at least as fast as the TCOd recursive algorithms and the imperative while algorithm.




Monday 12 October 2009

Scala: C style iterate function for building lists

In my previous post, I used a forloop function for creating a List from a sort of c-like for loop

val rets = forloop(today)(d => d >= _start && d >= limit)(d => d - 7.days) {
day =>
val ret = getReturn(day - 7.days, day)
ret.get // might throw Exception (same as c# code)
}


I liked the fact that I could use break and continue semantics in the loop although I didn't need it in the algorithm.

But perhaps forloop is/was trying to do too much, effectively (a) producing a list of dates (etc.) and then (b) for each date in the list calculating another value (a Double) to return in a list. Of course (b) is just the map function. So here's another function, which I've called iterateFor, that only does (a). Using it to replace the code above gives:


val dates = iterateFor(today)(d => d >= _start && d >= limit)(d => d - 7.days)
val rets = dates.map(d => getReturn(d - 7.days, d).get)


I was hoping to find something already built in to the Scala libraries that would do what iterateFor does. As Daniel pointed out in the comments, there is Iterator.iterate:

def iterate [T](start : T)(f : (T) => T) : Iterator[T]


and if you use takeWhile on the Iterator then the implementation of iterateFor is simple:


def iterateFor[T]
(init: T)
(cond: T => Boolean)
(next: T => T)
: List[T] =
{
Iterator.iterate(init)(next).takeWhile(cond).toList
}


Or it can be implemented as a recursive function like this:


def iterateFor[T]
(init: T)
(cond: T => Boolean)
(next: T => T)
: List[T] =
{
def _iterateFor(currentResults: List[T])(loopcounter: T) : List[T] = {
if (cond(loopcounter)) {
val newResults = loopcounter :: currentResults
_iterateFor(newResults)(next(loopcounter))
} else {
currentResults
}
}
_iterateFor(List())(init).reverse
}


I suspect that there are plenty of little while loop type algorithms that use vars that could be replaced by iterateFor (or forloop).

Saturday 10 October 2009

C style for loops in Scala with break and continue

(If you don't want to read about how I ended up hand rolling C style loops in scala then go straight to the end result here.)

Becoming really rich with C# is a great example of what's coming in the version of C# in Visual Studio 2010 and it makes it obvious that C# is leaving Java in the dust. That's why I'm looking at Scala.

As an exercise, I translated the C# version of the stdDev function into Scala, copying the imperative non-functional style of the original C# code.

Imperative version of stddev
  // 1 for 1 translation of imperative c# version
def stdDevImperative(): Double = {
var now = Dates.now()
val limit = now - 3.years
var rets = List[Double]()
while (now >= _start + 12.days && now >= limit) {
val ret = getReturn(now - 7.days, now);
val retd = ret.get // might throw Exception (same as c# code)
rets = rets ::: retd :: Nil
now = now - 7.days;
}
val mean = average(rets)
val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
val weeklyStdDev = Math.sqrt(variance)
return weeklyStdDev * Math.sqrt(40)
}


Couple of notes here. I'm using the joda time scala wrapper (import org.scala_tools.time.Imports._). The sum() and average() functions should be obvious. getReturn(now - 7.days, now) returns an Option[Double] of the increase in price in the last seven days as per the c# version.
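
In case they are not obvious, here is roughly what sum() and average() look like, written the same way as the foldLeft-based helper used in the later TimeSeries listing. The Stats object is just a container for this sketch.

```scala
object Stats {
  // Add up all values, starting from 0.0.
  def sum(xs: List[Double]): Double = xs.foldLeft(0.0)(_ + _)

  // Arithmetic mean; an empty list gives NaN because 0.0 / 0 is NaN for doubles.
  def average(xs: List[Double]): Double = sum(xs) / xs.length
}
```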





If you are learning Scala, you can stop right here! The imperative version shown above is absolutely fine. What follows here is an intellectual exercise in using functional programming techniques to reorganize the algorithm so that it does not use vars. Many people would see this as a pointless exercise.


It also turns out that there is a very simple solution to this exercise which is described in the follow up section at the end of the post.





I wanted to convert my stdDevImperative to a functional style, getting rid of all those icky vars. So this was my first go:

Simple Recursive version of stddev


// simplistic recursive version of imperative C# version
def stdDevSimpleRecursive(): Double = {
  val now = Dates.now() // a val: nothing is reassigned in this version
  val limit = now - 3.years
  def getListOfReturns(d: DateTime): List[Double] = {
    if (d >= _start + 12.days && d >= limit)
      getReturn(d - 7.days, d).get :: getListOfReturns(d - 7.days) // Not TCO
    else
      Nil
  }
  val rets = getListOfReturns(now)
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}


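No more vars, but it's longer and the algorithm is less obvious than the imperative version (at least for me). Also, the compiler cannot perform "tail call optimisation", because the recursive call is not the last operation: its result still has to be prepended to with ::. And so it munches stack.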

Next step: allow the compiler to use "tail call optimisation".

TCO Recursive version of stddev


// recursive version allowing TCO and insert result at head of list
def stdDevTCORecursive(): Double = {
  val now = Dates.now() // a val: nothing is reassigned in this version
  val limit = now - 3.years
  def getListOfReturns(results: List[Double], d: DateTime): List[Double] = {
    if (d >= _start + 12.days && d >= limit) {
      val newResults = getReturn(d - 7.days, d).get :: results // optimised to add to head not tail
      getListOfReturns(newResults, d - 7.days) // TCO possible
    } else
      results
  }
  val rets = getListOfReturns(List(), now).reverse // reverse unnecessary for algorithm, but matches C# ordering
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}

Well it works, but it's getting uglier and the actual algorithm is being obscured. (Yeah I know, premature optimisation is the root of all evil. Yada. But TCO is pretty basic).

So I got to thinking about whether there was a better algorithm and couldn't think of one. All the neat algorithms use a starting list/sequence and I couldn't see one. We are going back in time 7 days at a time, stopping when we get to the first date (_start) of the available price history or when we have collected 3 years' worth of returns. Hell, I liked the imperative algorithm. Maybe that will change as my brain adapts to functional techniques. I thought: what I need is a way of factoring out the recursion into a "thing".

Enter the unfold function described nicely by David Pollak. I thought that this might be the solution. So here goes:
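As an aside, Scala 2.8 also provides the scala.annotation.tailrec annotation, which makes the compiler reject a method it cannot turn into a loop. The sketch below shows the accumulator pattern on plain Ints standing in for dates; the names TailRecSketch, countDown and loop are illustrative and not from the original code, and step is assumed to be positive.

```scala
import scala.annotation.tailrec

object TailRecSketch {
  // Collects start, start - step, ... while the value stays >= limit.
  // Assumes step > 0, otherwise the loop never terminates.
  def countDown(start: Int, limit: Int, step: Int): List[Int] = {
    @tailrec // compilation fails if the recursive call is not in tail position
    def loop(results: List[Int], current: Int): List[Int] =
      if (current >= limit) loop(current :: results, current - step) // tail call
      else results
    loop(Nil, start).reverse
  }
}
```

Writing the body as current :: loop(...) instead, in the style of the simple recursive version, would be rejected by the annotation rather than silently using stack.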

Unfold version of stddev


// This unfold function should probably be built into a Scala standard library somewhere
def unfold[T, R](init: T)(f: T => Option[(R, T)]): List[R] = f(init) match {
  case None => Nil
  case Some((r, v)) => r :: unfold(v)(f) // r is collected, v is the next loop value
}

def stdDevUnfold(): Double = {
  val now = Dates.now()
  val limit = now - 3.years
  def getReturnOnDayAndNextDay(someDay: DateTime): Option[(Double, DateTime)] = {
    if (someDay < _start || someDay < limit) return None
    val ret = getReturn(someDay - 7.days, someDay)
    Some((ret.get, someDay - 7.days)) // might throw Exception (same as C# code)
  }
  val rets = unfold(now)(getReturnOnDayAndNextDay)
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2)))
  val weeklyStdDev = Math.sqrt(variance / rets.length)
  return weeklyStdDev * Math.sqrt(40)
}


So the recursion is now out of my stdDev function but the code is even more obscure. The problem is that the unfold parameter function f() has two jobs to do. (1) manage a loop counter, (in this case the current date) and (2) return a value to be collected in the returned results.
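To make the "two jobs" point concrete, here is a small sketch of the same unfold applied to plain Ints. The names UnfoldSketch and squaresDownFrom are illustrative only. Newer Scala versions (2.13 onwards) ship a similar unfold on collection companions such as List, but the sketch defines its own so that it does not depend on that.

```scala
object UnfoldSketch {
  // Same shape as the unfold above, with the pattern written as an explicit tuple.
  def unfold[T, R](init: T)(f: T => Option[(R, T)]): List[R] = f(init) match {
    case None => Nil
    case Some((r, next)) => r :: unfold(next)(f)
  }

  // The function passed to unfold does everything at once:
  // it decides when to stop (None), produces the value to collect (i * i)
  // and produces the next loop counter (i - 1).
  def squaresDownFrom(n: Int): List[Int] =
    unfold[Int, Int](n)(i => if (i <= 0) None else Some((i * i, i - 1)))
}
```

Here squaresDownFrom(3) walks the counter 3, 2, 1 and stops at 0, collecting the squares on the way.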

So I liked the idea of unfold because it hides the recursion, but not the multiple jobs that the function has to do. The obvious answer is to have more than one function argument: One function to get the next value for the iteration or terminate and another function that works out the result. Hang on a minute, that first function is doing two jobs: getting the next value and working out when to terminate. How about three functions? Hey wait while we're there...

This is what I see in my head as the clearest way of defining the algorithm, in a sort of Java-like pseudocode:


// Java like pseudocode
ArrayList a = new ArrayList();
for (DateTime d = today(); d > firstDateForWhichPriceIsAvailable && d > today - 3.years; d = d - 7.days) {
  a.append(getReturn(d - 7.days, d));
}


Final version of stddev with forloop


The challenge was to come up with something completely stateless of course. After many iterations (ha ha) this is what I have ended up with as my new stdDev method:



def stdDevForLoop(): Double = {
  val today = Dates.now()
  val limit = today - 3.years
  val rets =
    forloop(today)(d => d >= _start && d >= limit)(d => d - 7.days) {
      day =>
        val ret = getReturn(day - 7.days, day)
        ret.get // might throw Exception (same as C# code)
    }
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}

I was pretty happy with the result, which, to my non-functional brain, expresses the algorithm (slightly) more clearly than the original imperative algorithm but without using any vars. The recursion is neatly tucked away in the forloop method. The definition of forloop is shown at the end.

The icing on the cake: break and continue


I realised that C-style break and continue would be relatively easy to implement by copying the techniques used in breakable. As ever, the devil is in the detail: as the forloop body yields a value that is collected in the list of results, I added a breakWith 'keyword' too, so that you can return a last value when you break. So the following becomes possible:


def stdDevForLoopBreakAndContinue(): Double = {
  val today = Dates.now()
  val limit = today - 3.years
  val rets = forloop(today)(d => d >= _start && d >= limit)(d => d - 7.days) {
    day =>
      val ret = getReturn(day - 7.days, day)
      if (day == blackFriday1981) continue
      if (day == blackWednesday1973) breakWith(0d) // or just plain break, if no final value required
      ret.get // might throw Exception (same as C# code)
  }
  val mean = average(rets)
  val variance = sum(rets.map(r => Math.pow(r - mean, 2))) / rets.length
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}
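For comparison, this is roughly how the breakable technique from scala.util.control.Breaks, which break and continue above are modelled on, looks when used directly. It is only a sketch on a list of Ints: the names BreaksSketch and sumUntilNegative are illustrative, and unlike forloop it uses a var to keep the example short.

```scala
import scala.util.control.Breaks

object BreaksSketch {
  // Sums the numbers up to the first negative one, skipping zeros.
  def sumUntilNegative(xs: List[Int]): Int = {
    val outer = new Breaks // leaving this block acts as "break"
    val inner = new Breaks // leaving this block acts as "continue"
    var total = 0
    outer.breakable {
      for (x <- xs) {
        inner.breakable {
          if (x == 0) inner.break() // skip to the next element
          if (x < 0) outer.break()  // stop the whole loop
          total += x
        }
      }
    }
    total
  }
}
```

Each Breaks instance only catches its own exception, which is what lets the inner block behave like continue while the outer one behaves like break.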

What's missing?

  • Multiple loop counters. Of course you can do this by using a Tuple as the loop counter containing as many 'embedded' loop counters as you need, but this gets Lispy with the brackets. Maybe there is a nicer way.
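
A sketch of the tuple approach, to show what it looks like in practice. To stay self-contained it uses a cut-down stand-in for forloop with the same signature but no break or continue support; the names TupleCounterSketch and products are illustrative only.

```scala
import scala.annotation.tailrec

object TupleCounterSketch {
  // Cut-down forloop: same parameter lists as the real one, no break/continue.
  def forloop[T, R](init: T)(cond: T => Boolean)(next: T => T)(body: T => R): List[R] = {
    @tailrec
    def loop(acc: List[R], counter: T): List[R] =
      if (cond(counter)) loop(body(counter) :: acc, next(counter))
      else acc.reverse
    loop(Nil, init)
  }

  // Two counters in one tuple: i counts up from 0, j counts down from n,
  // and the loop stops as soon as they meet or cross.
  def products(n: Int): List[Int] =
    forloop((0, n))(c => c._1 < c._2)(c => (c._1 + 1, c._2 - 1)) {
      case (i, j) => i * j
    }
}
```

The c._1 and c._2 accessors in the condition and step functions are where the bracket noise comes from; a pattern match in the body at least gives the counters names.
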
Here's the code for forloop:



trait ForLoops
{

  def forloop[T, R]
    (init: T)
    (cond: T => Boolean)
    (next: T => T)
    (body: T => R)
    : List[R] =
  {
    def _forloop(currentResults: List[R])(loopcounter: T) : List[R] = {
      if (cond(loopcounter)) {
        // Get the newResults by appending the result of the body() function on the current loopcounter
        val newResults = try {
          // normal case: body() returns something based on loopcounter
          val bodyResult = body(loopcounter)

          bodyResult match {
            // if the body has no return value (Unit) then don't bother appending to the newResults
            case _:Unit => currentResults
            // if the body has a return value then append to the newResults
            case _ => bodyResult :: currentResults
          }

        } catch {
          case ex: ContinueException => currentResults // don't add to the currentResults
          case ex: BreakException => return currentResults // exit loop so return currentResults
          case BreakWithException(v) => return v.asInstanceOf[R] :: currentResults
        }
        // This can be tail call optimized
        // Recursive call to loop with the next value of the loop counter
        _forloop(newResults)(next(loopcounter))
      } else {
        // loop has terminated so return results
        currentResults
      }
    }
    // Seed the _forloop currentResults with an empty list and the first value (init) of the loopcounter
    _forloop(List())(init).reverse
  }

  // Copied break and continue technique from "breakable" in Scala 2.8
  def continue = throw continueException
  def break = throw breakException
  def breakWith(v:Any) = throw new BreakWithException(v)

  private class ContinueException extends RuntimeException
  private val continueException = new ContinueException

  private class BreakException extends RuntimeException
  private val breakException = new BreakException

  private case class BreakWithException(v:Any) extends RuntimeException {
    override def fillInStackTrace():Throwable = {
      // this is a performance optimisation. the actual stack trace is not required.
      return null;
    }
  }

}


Follow up: An implementation using Iterator.iterate:

In the follow-up posts I look at various alternatives to the forloop function described above, and I got some very helpful comments from various people. The simplest (and built-in) solution that I found uses the new Iterator.iterate method in Scala 2.8pre along with takeWhile and map. A very neat alternative to Iterator.iterate is to define an iterate method that uses a lazy Stream.



def stdDevIteratorIterate(): Double = {
  val today = Dates.now()
  val limit = today - 3.years
  val dates = Iterator.iterate(today)(_ - 7.days) takeWhile (d => d >= _start && d >= limit) toList
  val rets = dates.map(d => getReturn(d - 7.days, d).get)
  val mean = rets.average
  val variance = rets.map(r => Math.pow(r - mean, 2)).average
  val weeklyStdDev = Math.sqrt(variance)
  return weeklyStdDev * Math.sqrt(40)
}
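
The Stream-based alternative mentioned above is not shown in this post, so the following is only a guess at its general shape, using plain Ints in place of dates. The names StreamIterateSketch and weeklySteps are illustrative. Note that Stream is deprecated in favour of LazyList from Scala 2.13 onwards, so this reflects the 2.8-era API.

```scala
object StreamIterateSketch {
  // The infinite lazy sequence init, next(init), next(next(init)), ...
  // Stream.cons takes its tail by name, so each element is only computed on demand.
  def iterate[T](init: T)(next: T => T): Stream[T] =
    Stream.cons(init, iterate(next(init))(next))

  // Same shape as the dates expression in stdDevIteratorIterate:
  // step back 7 at a time and stop at the first value below the limit.
  def weeklySteps(today: Int, limit: Int): List[Int] =
    iterate(today)(_ - 7).takeWhile(_ >= limit).toList
}
```

Because the stream is lazy, takeWhile cuts the infinite sequence off before toList forces it.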

Friday 9 October 2009

Using syntax version of SyntaxHighlighter with Blogger

The latest version 2.1.364 of the amazing SyntaxHighlighter from Alex Gorbatchev needs a bit of TLC to get working on Blogger:

Below is what I have right before the </head> tag in the Edit HTML tab of my Blogger layout.


<script src='http://alexgorbatchev.com/pub/sh/2.1.364/scripts/shCore.js' type='text/javascript'/>
<script src='http://alexgorbatchev.com/pub/sh/2.1.364/scripts/shBrushCSharp.js' type='text/javascript'/>
<script src='http://alexgorbatchev.com/pub/sh/2.1.364/scripts/shBrushJava.js' type='text/javascript'/>
<script src='http://alexgorbatchev.com/pub/sh/2.1.364/scripts/shBrushJScript.js' type='text/javascript'/>
<script src='http://alexgorbatchev.com/pub/sh/2.1.364/scripts/shBrushPlain.js' type='text/javascript'/>
<script src='http://alexgorbatchev.com/pub/sh/2.1.364/scripts/shBrushScala.js' type='text/javascript'/>
<script src='http://alexgorbatchev.com/pub/sh/2.1.364/scripts/shBrushXml.js' type='text/javascript'/>
<link href='http://alexgorbatchev.com/pub/sh/2.1.364/styles/shCore.css' rel='stylesheet' type='text/css'/>
<link href='http://alexgorbatchev.com/pub/sh/2.1.364/styles/shThemeDefault.css' id='shTheme' rel='stylesheet' type='text/css'/>
<style type='text/css'>
.syntaxhighlighter .line {
font-size: 76% !important;
}
</style>
<script type='text/javascript'>
SyntaxHighlighter.config.clipboardSwf = 'http://alexgorbatchev.com/pub/sh/2.1.364/scripts/clipboard.swf';
SyntaxHighlighter.all();
SyntaxHighlighter.config.bloggerMode=true;
SyntaxHighlighter.defaults['font-size'] = '50%';
</script>


That's all you need to make it work.

The ugly section

<style type='text/css'>
.syntaxhighlighter .line {
font-size: 76% !important;
}
</style>

works around a bug in this version of SyntaxHighlighter described here. This allows SyntaxHighlighter.defaults['font-size'] = '50%'; to actually have an effect. If you don't do this, for some reason the code is shown at a ridiculously large size.

So now I can type exactly what you see here into Blogger's Compose editor

<pre class="brush: scala">
def factorialInt(i: Int): Int = {
  def fact(i: Int)(accumulator: Int): Int = i match {
    case 1 => accumulator
    case _ => fact(i - 1)(i * accumulator)
  }
  fact(i)(1)
}
</pre>


which then appears as follows


def factorialInt(i: Int): Int = {
  def fact(i: Int)(accumulator: Int): Int = i match {
    case 1 => accumulator
    case _ => fact(i - 1)(i * accumulator)
  }
  fact(i)(1)
}


There is a newer syntax as of 2.1.364 which uses a CDATA section for your code. It allows you to put special characters such as < and > directly in the code without them being encoded as &lt; or &gt;, which is possibly useful for Scala. Blogger does this automatic encoding of special characters for you, though, so it's not too important for Blogger posts.

So if you type exactly what you see here into Blogger's Compose editor (although the "pre" syntax above is simpler when using Blogger)




then you get this