Thursday, 29 October 2009

Becoming really rich with Scala

Update: 31 March 2014: This article uses Scala 2.8
  • An updated version of this code using Scala 2.11 is available on GitHub:
  •   git clone https://github.com/azzoti/get-rich-with-scala.git
  • It's an Eclipse Scala IDE project and a Maven project
  • Can be run with:
  •   mvn scala:run -DmainClass=etf.analyzer.Program

"Becoming really rich with C#" was a great example of what's coming in the version of C# in Visual Studio 2010, and it makes it obvious that C# is leaving Java in the dust. Now that the Visual Studio 2010 beta 2 is available, you can download Luca's code and try it out.

For this post, I've translated the C# code into Scala while trying to preserve the original C# style. To do this I had to add some support code, which was easy and shows off Scala's extensibility.

The features and libraries added or used to do this are:
  • Scala-Time: A Java Joda Time library wrapper.
  • The "using" block from Martin Odersky's FOSDEM '09 talk
  • An EventHandler class for simulating C# Events
  • The Jetty HTTPClient from Eclipse that I wrapped to resemble the C# WebClient api.
  • Arithmetic operations for Option[Double]. (Option[Double] is Scala's equivalent of the C# nullable type double?. In C# you can use double? variables in expressions, with the expression returning null if any part of it is null. In Scala, you can't use Option[Double] in arithmetic expressions out of the box, but it's very easy to add this ability in a small library.)
  • The Scala code is written with the latest 2.8 pre version of Scala and uses one or two features from its latest standard library not present in the latest stable release.
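
As a rough illustration of the Option[Double] arithmetic mentioned in the list above, here is a minimal sketch. The names OptionMath, OptionDoubleOps and lrs are hypothetical and are not the API of the library imported in the listing; the sketch only shows one way an expression over Option[Double] values can evaluate to None as soon as one operand is missing. It is written for a current Scala version rather than 2.8.

```scala
object OptionMath {
  // Hypothetical helper, not the author's library: arithmetic on Option[Double]
  // that yields None as soon as an operand is missing.
  implicit class OptionDoubleOps(val lhs: Option[Double]) extends AnyVal {
    def +(rhs: Option[Double]): Option[Double] = for (a <- lhs; b <- rhs) yield a + b
    def *(rhs: Double): Option[Double] = lhs.map(_ * rhs)
    def /(rhs: Double): Option[Double] = lhs.map(_ / rhs)
  }

  // Same shape as the LRS formula in the listing: average of four optional returns.
  def lrs(fourWeeks: Option[Double], threeMonths: Option[Double],
          sixMonths: Option[Double], oneYear: Option[Double]): Option[Double] =
    (fourWeeks + threeMonths + sixMonths + oneYear) / 4
}
```

With this in scope, lrs returns a value only when all four inputs are defined, which is the behaviour the C# double? arithmetic gives for free.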

While the Scala is slightly shorter than the C# code, it is supported by extra code or libraries that I have found or had to write. C# already has using blocks, Events, reasonable datetime management, a WebClient, and nullable double types that handle arithmetic operations sensibly.
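
To give an idea of how little support code a C#-style event needs, here is a minimal sketch. The class below is a hypothetical stand-in, not the EventHandler class used in the listing, and the method name fire is an assumption; it is also not thread safe.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical stand-in for a C#-style event; the author's EventHandler may differ.
class EventHandler[A] {
  private val handlers = ListBuffer.empty[A => Unit]

  // Subscribe a handler, mirroring C#'s `event += handler`.
  def +=(handler: A => Unit): Unit = { handlers += handler; () }

  // Invoke every subscribed handler in subscription order.
  def fire(args: A): Unit = handlers.foreach(h => h(args))
}
```

A subscriber would be attached with `completed += (result => println(result))`, and the owner of the event would call `completed.fire(...)` when the work finishes.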

For whatever reason, the Scala code runs much faster than the C# code, but there is a large amount of internet access involved and I suspect that the C# web client should be configured to use more threads. [Update: Luca just suggested I comment out the C# line ServicePointManager.DefaultConnectionLimit = 10; and this does indeed make the C# code much faster.]



Original C# and Scala listings
See the notes after the listings. The C# original comes first, followed by the Scala translation.

Original C#:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using System.IO;

namespace ETFAnalyzer {

struct Event {
  internal Event(DateTime date, double price) { Date = date; Price = price; }
  internal readonly DateTime Date;
  internal readonly double Price;
}
class Summary {
  internal Summary(string ticker, string name, string assetClass,
          string assetSubClass, double? weekly, double? fourWeeks,
          double? threeMonths, double? sixMonths, double? oneYear,
          double? stdDev, double price, double? mav200) {
    Ticker = ticker;
    Name = name;
    AssetClass = assetClass;
    AssetSubClass = assetSubClass;
    // Abracadabra ...
    LRS = (fourWeeks + threeMonths + sixMonths + oneYear) / 4;
    Weekly = weekly;
    FourWeeks = fourWeeks;
    ThreeMonths = threeMonths;
    SixMonths = sixMonths;
    OneYear = oneYear;
    StdDev = stdDev;
    Mav200 = mav200;
    Price = price;
  }
  internal readonly string Ticker;
  internal readonly string Name;
  internal readonly string AssetClass;
  internal readonly string AssetSubClass;
  internal readonly double? LRS;
  internal readonly double? Weekly;
  internal readonly double? FourWeeks;
  internal readonly double? ThreeMonths;
  internal readonly double? SixMonths;
  internal readonly double? OneYear;
  internal readonly double? StdDev;
  internal readonly double? Mav200;
  internal double Price;

  internal static void Banner() {
    Console.Write("{0,-6}", "Ticker");
    Console.Write("{0,-50}", "Name");
    Console.Write("{0,-12}", "Asset Class");
    Console.Write("{0,4}", "RS");
    Console.Write("{0,4}", "1Wk");
    Console.Write("{0,4}", "4Wk");
    Console.Write("{0,4}", "3Ms");
    Console.Write("{0,4}", "6Ms");
    Console.Write("{0,4}", "1Yr");
    Console.Write("{0,6}", "Vol");
    Console.WriteLine("{0,2}", "Mv");
  }

  internal void Print() {

    Console.Write("{0,-6}", Ticker);
    Console.Write("{0,-50}", new String(Name.Take(48).ToArray()));
    Console.Write("{0,-12}", new String(AssetClass.Take(10).ToArray()));
    Console.Write("{0,4:N0}", LRS * 100);
    Console.Write("{0,4:N0}", Weekly * 100);
    Console.Write("{0,4:N0}", FourWeeks * 100);
    Console.Write("{0,4:N0}", ThreeMonths * 100);
    Console.Write("{0,4:N0}", SixMonths * 100);
    Console.Write("{0,4:N0}", OneYear * 100);
    Console.Write("{0,6:N0}", StdDev * 100);
    if (Price <= Mav200)
      Console.WriteLine("{0,2}", "X");
    else
      Console.WriteLine();
  }
}

class TimeSeries {
  internal readonly string Ticker;
  readonly DateTime _start;
  readonly Dictionary<DateTime, double> _adjDictionary;
  readonly string _name;
  readonly string _assetClass;
  readonly string _assetSubClass;

  internal TimeSeries(string ticker, string name, string assetClass, string assetSubClass, IEnumerable<Event> events) {
    Ticker = ticker;
    _name = name;
    _assetClass = assetClass;
    _assetSubClass = assetSubClass;
    _start = events.Last().Date;
    _adjDictionary = events.ToDictionary(e => e.Date, e => e.Price);
  }

  bool GetPrice(DateTime when, out double price, out double shift) {
    // To nullify the effect of hours/min/sec/millisec being different from 0
    when = new DateTime(when.Year, when.Month, when.Day);
    var found = false;
    shift = 1;
    double aPrice = 0;
    while (when >= _start && !found) {
      if (_adjDictionary.TryGetValue(when, out aPrice)) {
        found = true;
      }
      when = when.AddDays(-1);
      shift -= 1;
    }
    price = aPrice;
    return found;
  }

  double? GetReturn(DateTime start, DateTime end) {
    var startPrice = 0.0;
    var endPrice = 0.0;
    var shift = 0.0;
    var foundEnd = GetPrice(end, out endPrice, out shift);
    var foundStart = GetPrice(start.AddDays(shift), out startPrice, out shift);
    if (!foundStart || !foundEnd)
      return null;
    else
      return endPrice / startPrice - 1;
  }

  internal double? LastWeekReturn() {
    return GetReturn(DateTime.Now.AddDays(-7), DateTime.Now);
  }
  internal double? Last4WeeksReturn() {
    return GetReturn(DateTime.Now.AddDays(-28), DateTime.Now);
  }
  internal double? Last3MonthsReturn() {
    return GetReturn(DateTime.Now.AddMonths(-3), DateTime.Now);
  }
  internal double? Last6MonthsReturn() {
    return GetReturn(DateTime.Now.AddMonths(-6), DateTime.Now);
  }
  internal double? LastYearReturn() {
    return GetReturn(DateTime.Now.AddYears(-1), DateTime.Now);
  }
  internal double? StdDev() {
    var now = DateTime.Now;
    now = new DateTime(now.Year, now.Month, now.Day);
    var limit = now.AddYears(-3);
    var rets = new List<double>();
    while (now >= _start.AddDays(12) && now >= limit) {
      var ret = GetReturn(now.AddDays(-7), now);
      rets.Add(ret.Value);
      now = now.AddDays(-7);
    }
    var mean = rets.Average();
    var variance = rets.Select(r => Math.Pow(r - mean, 2)).Sum();
    var weeklyStdDev = Math.Sqrt(variance / rets.Count);
    return weeklyStdDev * Math.Sqrt(40);
  }
  internal double? MAV200() {
    return _adjDictionary.ToList()
           .OrderByDescending(k => k.Key)
           .Take(200).Average(k => k.Value);
  }
  internal double TodayPrice() {
    var price = 0.0;
    var shift = 0.0;
    GetPrice(DateTime.Now, out price, out shift);
    return price;
  }
  internal Summary GetSummary() {
    return new Summary(Ticker, _name, _assetClass, _assetSubClass,
           LastWeekReturn(), Last4WeeksReturn(), Last3MonthsReturn(),
           Last6MonthsReturn(), LastYearReturn(), StdDev(), TodayPrice(), 
           MAV200());
  }
}

class Program {

  static string CreateUrl(string ticker, DateTime start, DateTime end)
  {
    return @"http://ichart.finance.yahoo.com/table.csv?s=" + ticker + 
      "&a="+(start.Month - 1).ToString()+"&b="+start.Day.ToString()+"&c="+start.Year.ToString() + 
      "&d="+(end.Month - 1).ToString()+"&e="+end.Day.ToString()+"&f="+end.Year.ToString() + 
      "&g=d&ignore=.csv";
  }

  static void Main(string[] args) {
    // If you rise this above 5 you tend to get frequent connection closing on my machine
    // I'm not sure if it is msft network or yahoo web service
    ServicePointManager.DefaultConnectionLimit = 10;

    var tickers =
      File.ReadAllLines("ETFTest.csv")
      .Skip(1)
      .Select(l => l.Split(new[] { ',' }))
      .Where(v => v[2] != "Leveraged")
      .Select(values => Tuple.Create(values[0], values[1], values[2], values[3]))
      .ToArray();

    var len = tickers.Length;

    var start = DateTime.Now.AddYears(-2);
    var end = DateTime.Now;
    var cevent = new CountdownEvent(len);
    var summaries = new Summary[len];
    
    for(var i = 0; i < len; i++)  {
      var t = tickers[i];
      var url = CreateUrl(t.Item1, start, end);
      using (var webClient = new WebClient()) {
        webClient.DownloadStringCompleted +=
                        new DownloadStringCompletedEventHandler(downloadStringCompleted);
        webClient.DownloadStringAsync(new Uri(url), Tuple.Create(t, cevent, summaries, i));
      }
    }

    cevent.Wait();
    Console.WriteLine("\n");

    var top15perc =
        summaries
        .Where(s => s.LRS.HasValue)
        .OrderByDescending(s => s.LRS)
        .Take((int)(len * 0.15));
    var bottom15perc =
        summaries
        .Where(s => s.LRS.HasValue)
        .OrderBy(s => s.LRS)
        .Take((int)(len * 0.15));

    Console.WriteLine();
    Summary.Banner();
    Console.WriteLine("TOP 15%");
    foreach(var s in top15perc)
      s.Print();

    Console.WriteLine();
    Console.WriteLine("Bottom 15%");
    foreach (var s in bottom15perc)
      s.Print();
      
  }

  static void downloadStringCompleted(object sender, DownloadStringCompletedEventArgs e) {
    var bigTuple = (Tuple<Tuple<string, string, string, string>, CountdownEvent, Summary[], int>)e.UserState;
    var tuple = bigTuple.Item1;
    var cevent = bigTuple.Item2;
    var summaries = bigTuple.Item3;
    var i = bigTuple.Item4;
    var ticker = tuple.Item1;
    var name = tuple.Item2;
    var asset = tuple.Item3;
    var subAsset = tuple.Item4;

    if (e.Error == null) {
      var adjustedPrices =
          e.Result
          .Split(new[] { '\n' })
          .Skip(1)
          .Select(l => l.Split(new[] { ',' }))
          .Where(l => l.Length == 7)
          .Select(v => new Event(DateTime.Parse(v[0]), Double.Parse(v[6])));

      var timeSeries = new TimeSeries(ticker, name, asset, subAsset, adjustedPrices);
      summaries[i] = timeSeries.GetSummary();
      cevent.Signal();
      Console.Write("{0} ", ticker);
    } else {
      Console.WriteLine("[{0} ERROR] ", ticker);
      summaries[i] = new Summary(ticker,name,"ERROR","ERROR",0,0,0,0,0,0,0,0); 
      cevent.Signal();
    }
  }
}
}

Scala:
package etf.analyzer

import scala.io.Source
import org.scala_tools.time.Imports._
import org.scala_tools.option.math.Imports._  
import org.joda.time.Days
import org.scala_tools.using.Using
import org.scala_tools.web.WebClient
import org.scala_tools.web.WebClientConnections
import org.scala_tools.web.DownloadStringCompletedEventArgs 
import java.io.File
import java.util.concurrent.CountDownLatch 

case class Event (date : DateTime, price : Double) {}


case class Summary (
  ticker : String, name : String, assetClass : String,
  assetSubClass : String, weekly : Option[Double], 
  fourWeeks : Option[Double], threeMonths : Option[Double], 
  sixMonths : Option[Double], oneYear : Option[Double],
  stdDev : Double, price : Double, mav200 : Double 
) {

  // Abracadabra ...
  val LRS = (fourWeeks + threeMonths + sixMonths + oneYear) / 4

  def banner() = {
    printf("%-6s", "Ticker")
    printf("%-50s", "Name")
    printf("%-12s", "Asset Class")
    printf("%4s", "RS")
    printf("%4s", "1Wk")
    printf("%4s", "4Wk")
    printf("%4s", "3Ms")
    printf("%4s", "6Ms")
    printf("%4s", "1Yr")
    printf("%6s", "Vol")
    printf("%2s\n", "Mv")
  }
  
  def print() = {
  
    printf("%-6s", ticker);
    printf("%-50s", new String(name.toArray.take(48)))
    printf("%-12s", new String(assetClass.toArray.take(10)));
    printf("%4.0f", LRS * 100 getOrElse null)
    printf("%4.0f", weekly * 100 getOrElse null)
    printf("%4.0f", fourWeeks * 100 getOrElse null)
    printf("%4.0f", threeMonths * 100 getOrElse null)
    printf("%4.0f", sixMonths * 100 getOrElse null)
    printf("%4.0f", oneYear * 100 getOrElse null)
    printf("%6.0f", stdDev * 100);
    if (price <= mav200) {
      printf("%2s\n", "X");
    } else {
      println();
    }
  }  
}

case class TimeSeries (
    ticker : String, name : String, assetClass : String, 
    assetSubClass : String, private events : Iterable[Event]
) {
  
  private val _adjDictionary : Map[DateTime, Double] 
              = Map() ++ events.map(e => (e.date -> e.price))
  private val _start = events.last.date




  // Add the sum and average function to all Iterables[Double] used locally
  private implicit def iterableWithSumAndAverage(c: Iterable[Double]) = new { 
    def sum = c.foldLeft(0.0)(_ + _) 
    def average = sum / c.size
  }  
  
  def getPrice(whenp : DateTime) : Option[(Double,Int)] =  {
    var when = new DateTime(whenp.year.get,whenp.month.get,whenp.day.get,0,0,0,0)
    var found = false
    var shift = 1
    var aPrice = 0.0
    while (when >= _start && !found) {
        if (_adjDictionary.contains(when)) {
            aPrice = _adjDictionary(when)
            found = true
        }
        when = when - 1.days
        shift -= 1
    }
    // Either return the price and the shift or None if no price was found
    if (found) Some((aPrice, shift)) else None
  }  

  def getReturn(start: DateTime, end: DateTime) : Option[Double] = {
    for {
      (endPrice,daysBefore) <- getPrice(end)
      (startPrice,_) <- getPrice(start + daysBefore.days)
    } yield (endPrice / startPrice - 1.0)
  }

  def lastWeekReturn = getReturn(DateTime.now - 7.days, DateTime.now)
  def last4WeeksReturn = getReturn(DateTime.now - 28.days, DateTime.now)
  def last3MonthsReturn = getReturn(DateTime.now - 3.months, DateTime.now)
  def last6MonthsReturn = getReturn(DateTime.now - 6.months, DateTime.now)
  def lastYearReturn = getReturn(DateTime.now - 1.years, DateTime.now)

  def stdDev() : Double = {
    val today = DateTime.now
    val limit = today - 3.years
    val dates = Iterator.iterate(today)(_ - 7.days)
                .takeWhile(d => d >= (_start + 12.days) && d >= limit)
                .toList
    val rets = dates.map(d => getReturn(d - 7.days, d).get)
    val mean = rets.average
    val variance = rets.map(r => Math.pow(r - mean, 2)).average
    val weeklyStdDev = Math.sqrt(variance);
    return weeklyStdDev * Math.sqrt(40);
  }



  def mav200(): Double = {
    return _adjDictionary.toList
           .sortWith((elem1, elem2) => elem1._1 >= elem2._1)
           .take(200).map(keyValue => keyValue._2).average
  }
  def todayPrice() : Double = {
    getPrice(DateTime.now) match {
      case None => 0.0
      case Some((price,_)) => price 
    }
  }
  def getSummary() = 
    Summary(ticker, name, assetClass, assetSubClass, 
      lastWeekReturn, last4WeeksReturn, last3MonthsReturn, 
      last6MonthsReturn, lastYearReturn, stdDev, todayPrice, 
      mav200)
}

object Program extends Using {

  def createUrl(ticker: String, start: DateTime, end: DateTime) : String = {
    return """http://ichart.finance.yahoo.com/table.csv?s=""" + ticker +
      "&a="+(start.month.get-1)+ "&b=" + start.day.get + "&c=" + start.year.get +
      "&d="+(end.month.get  -1)+ "&e=" + end.day.get + "&f=" + end.year.get +
      "&g=d&ignore=.csv"
  } 


 
  def main(args : Array[String]) : Unit = {

    
    val tickers = 
     Source.fromFile(new File("ETFTest.csv")).getLines()
     .drop(1)
     .map(l => l.trim.split(','))
     .filter(v => v(2) != "Leveraged")
     .map(values => (values(0),values(1),values(2),if (values.length==4) values(3) else ""))
     .toSeq.toArray

    val len = tickers.length;

    val start = DateTime.now - 2.years
    val end = DateTime.now
    val cevent = new CountDownLatch(len)
    val summaries = new Array[Summary](len)
    
    using(new WebClientConnections(connectionsPerAddress = 10, threadPool=10)) {
      webClientConnections =>
      for (i <- 0 until len) {
        val t = tickers(i)
        val url = createUrl(t._1, start, end)
        val webClient = webClientConnections.getWebClient
        webClient.downloadStringCompleted += downloadStringCompleted
        webClient.downloadStringAsync(url, (t, cevent, summaries, i))
      }
      cevent.await
    }
    println

    val top15perc =
      summaries
      .filter(s => s.LRS.isDefined)
      .sortWith((elem1, elem2) => elem1.LRS >= elem2.LRS)
      .take((len * 0.15).toInt)      
    val bottom15perc =
      summaries
      .filter(s => s.LRS.isDefined)
      .sortWith((elem1, elem2) => elem1.LRS <= elem2.LRS)
      .take((len * 0.15).toInt)

    println
    summaries(0).banner()
    println("TOP 15%")
    for (s <- top15perc) s.print()

    println
    println("Bottom 15%")
    for (s <- bottom15perc) s.print()
  }

  def downloadStringCompleted(e : DownloadStringCompletedEventArgs) : Unit = {
    val bigTuple = e.userState.asInstanceOf[((String,String,String,String),CountDownLatch,Array[Summary],Int)]
    val ((ticker,name,asset,subAsset),cevent,summaries,i) = bigTuple
    def parse(s : String) = DateTimeFormat.forPattern("yyyy-MM-dd").parseDateTime(s)

    if (e.error == null) {
      val adjustedPrices = e.result
          .split('\n')
          .drop(1)
          .map(l => l.split(','))
          .filter(l => l.length == 7)
          .map(v => Event(parse(v(0)),v(6).toDouble))

      val timeSeries = new TimeSeries(ticker, name, asset, subAsset, adjustedPrices);
      summaries(i) = timeSeries.getSummary();
      cevent.countDown() 
      printf("%s ", ticker)
    } else {
      printf("[%s ERROR] \n", ticker)
      summaries(i) = Summary(ticker,name,"ERROR","ERROR",Some(0),Some(0),Some(0),Some(0),Some(0),0,0,0) 
      cevent.countDown()
    }
  }  
}


Notes

TimeSeries getPrice method: Scala does not have output parameters on methods. It doesn't need them because the return type from a method can be a tuple and you can return as many values as you like. Also the method shown copies the C# style closely using loop variables. Another way of writing the same method in Scala making use of list functions is:
def getPrice(whenp : DateTime) : Option[(Double,Int)] =  {
  // Drop the time of day so that lookups match the midnight keys in _adjDictionary
  val when = new DateTime(whenp.year.get,whenp.month.get,whenp.day.get,0,0,0,0)
  // Find the most recent day with a price starting from when, but don't go back past _start
  val latestDayWithPrice
      = Iterator.iterate(when)(_ - 1.days)
        .dropWhile (d => !_adjDictionary.contains(d) && d >= _start)
        .next
  if (_adjDictionary.contains(latestDayWithPrice)) {
    // The shift is zero or negative: the number of days stepped back from when
    val shift = Days.daysBetween(when,latestDayWithPrice).getDays()
    val aPrice = _adjDictionary(latestDayWithPrice)
    Some((aPrice,shift))
  } else {
    None
  }
}

TimeSeries getReturn method: The two calls to getPrice() each return an Option[(price, days offset)].
Using Option[...] in a for expression is an easy way of dealing with the possibility of either Option[...] being None. If either getPrice() call returns None, then the yield will return None as well. Another, perhaps simpler to understand, getReturn implementation is:
def getReturn(start: DateTime, end: DateTime) : Option[Double] = {
  val endPriceDetails = getPrice(end)
  if (endPriceDetails == None) return None
  val (endPrice,daysBefore) = endPriceDetails.get
  val startPriceDetails = getPrice(start + daysBefore.days)
  if (startPriceDetails == None) return None
  val (startPrice,_) = startPriceDetails.get
  // Wrap the result so that it matches the declared Option[Double] return type
  Some(endPrice / startPrice - 1.0)
}
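
The short-circuiting behaviour of a for expression over Options can be seen in isolation with a toy example. This is only a sketch: ReturnSketch, prices and weeklyReturn are made-up names, and plain Int day numbers stand in for the DateTime keys of _adjDictionary.

```scala
object ReturnSketch {
  // Hypothetical price table keyed by day number, standing in for _adjDictionary.
  val prices: Map[Int, Double] = Map(1 -> 100.0, 8 -> 150.0)

  // Both lookups must succeed; if either is None the whole expression is None.
  def weeklyReturn(startDay: Int, endDay: Int): Option[Double] =
    for {
      endPrice   <- prices.get(endDay)
      startPrice <- prices.get(startDay)
    } yield endPrice / startPrice - 1.0
}
```

Here weeklyReturn(1, 8) is Some(0.5), while weeklyReturn(2, 8) is None because day 2 has no price.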

TimeSeries mav200 method: The Scala version is slightly harder work than the .NET 3.5 LINQ OrderByDescending method with its key selector syntax: .OrderByDescending(k => k.Key). The Scala version has to say *how* to do it. The LINQ version says *what* is required. The same is true for the C# use of the Average function, which uses a field selector.
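
For what it's worth, the standard library's sortBy method takes a key selector, so something closer to the LINQ style is possible. The sketch below is an illustration only: Mav and movingAverage are hypothetical names, and the keys are plain Long day numbers (an assumption made to keep the example free of Joda-Time) so that negating the key gives a descending sort.

```scala
object Mav {
  // Average of the n most recent prices. Keys are day numbers, not DateTime values.
  def movingAverage(prices: Map[Long, Double], n: Int): Double = {
    val latest = prices.toList
      .sortBy { case (day, _) => -day }     // key selector, newest first
      .take(n)
      .map { case (_, price) => price }
    latest.sum / latest.size                // NaN for an empty map
  }
}
```

With DateTime keys an Ordering for the key type would be needed in place of the negation trick.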

I'm not sure whether the concurrent access to the summaries array in downloadStringCompleted is safe. It seems to work, but I don't know if it is genuinely thread safe. I've just copied the C# code, which may have built-in thread-safe array access.
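
The pattern in question can be reduced to a small sketch: each task writes only its own array slot and then counts down, and the reader looks at the array only after await() returns. The documentation for java.util.concurrent.CountDownLatch states that actions in a thread before countDown() happen-before actions following a successful return from await() in another thread. Whether the wrapped web client in the listing behaves exactly like the plain thread pool assumed here is not something this sketch establishes; LatchSketch and squares are made-up names.

```scala
import java.util.concurrent.{CountDownLatch, Executors}

object LatchSketch {
  // Fill an array from several threads, one slot per task, and wait on a latch.
  def squares(n: Int): Array[Int] = {
    val results = new Array[Int](n)
    val latch = new CountDownLatch(n)
    val pool = Executors.newFixedThreadPool(4)
    try {
      for (i <- 0 until n) {
        pool.execute(new Runnable {
          def run(): Unit = {
            results(i) = i * i   // distinct index per task, no shared slot
            latch.countDown()    // done after the write, so the write is published
          }
        })
      }
      latch.await()              // returns once all n tasks have counted down
      results
    } finally pool.shutdown()
  }
}
```

The key points are that no two tasks share an index and that every write happens before the corresponding countDown() call.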


Some of the features of Scala that are shown here
  • Easy Java library interop. See use of CountDownLatch, Days, File.
  • Good old fashioned casting if you really need it. See asInstanceOf.
  • No semicolons
  • No need to use () for a function declaration with no parameters or for a call to it. See TimeSeries.getSummary(). (Brackets are recommended if there are side effects.)
  • Type declarations are unnecessary except in method parameters, but can be declared explicitly if it aids readability. See _adjDictionary.
  • Named and default parameters. See WebClientConnections
  • Much less boilerplate with "case" classes providing automatic constructors, fields, toString, equals, hashCode. See Event, Summary, TimeSeries
  • Joda time wrapper so you can say "today - 3.years"
  • Pattern matching assignment "val ((ticker,name),cevent,summaries,i) = bigTuple" in downloadStringCompleted
  • "using" block for automatic resource closing. In the main method, the "using(new WebClientConnections(" block will close down the WebClientConnections thread pool at the end of the block. This is very similar to the C# "using" code.
  • local "implicit" function definitions allowing you to effectively add methods to existing classes in a tightly controlled and scoped way. (see def iterableWithSumAndAverage)
  • Pattern matching, switch on steroids. See todayPrice().
  • Use of powerful list manipulation functions, such as Iterator.iterate and takeWhile, to replace traditional state-based loops. See the iterate/takeWhile example in stdDev(), the iterate/dropWhile example in the alternative getPrice, and in main(): drop, map, filter, sortWith, take. See the infamous foldLeft example at work in the sum function.
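
As an illustration of the "using" idea from the list above, here is a minimal sketch of the loan pattern. It is a hypothetical helper, not the Using trait imported in the listing, and it assumes the resource implements java.lang.AutoCloseable.

```scala
object UsingSketch {
  // Run body with the resource and close the resource afterwards,
  // whether body returns normally or throws.
  def using[R <: AutoCloseable, A](resource: R)(body: R => A): A =
    try body(resource) finally resource.close()
}
```

A call looks like `UsingSketch.using(new java.io.FileInputStream(path))(in => in.read())`, where path is whatever file you want to read. Scala 2.13 later added scala.util.Using to the standard library for the same purpose.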