Wednesday, July 15, 2009

Improving performance in Jython

About two weeks ago I published a writeup of my findings on the performance of synchronization primitives in Jython, from my presentation at JavaOne. During the presentation I said that these performance issues were something I was going to work on and improve. And indeed I did. I cannot take full credit for this; Jim Baker played a substantial part in this work as well. The end result is still something I'm very proud of, since we managed to improve the performance of this benchmark by as much as 50 times.

The benchmarks

The comparisons were performed by running this benchmark script, invoked with:

  • JAVA_HOME=$JAVA_6_HOME jython synchbench.py
  • JAVA_HOME=$JAVA_6_HOME jython -J-server synchbench.py

# -*- coding: utf-8 -*-
from __future__ import with_statement, division

from java.lang.System import nanoTime
from java.util.concurrent import Executors, Callable
from java.util.concurrent.atomic import AtomicInteger

from functools import wraps
from threading import Lock

def adder(a, b):
    return a+b


count = 0
def counting_adder(a, b):
    global count
    count += 1 # NOT SYNCHRONIZED!
    return a+b


lock = Lock()
sync_count = 0
def synchronized_counting_adder(a, b):
    global sync_count
    with lock:
        sync_count += 1
    return a+b


atomic_count = AtomicInteger()
def atomic_counting_adder(a,b):
    atomic_count.incrementAndGet()
    return a+b


class Task(Callable):
    def __init__(self, func):
        self.call = func

def callit(function):
    @Task
    @wraps(function)
    def callable():
        timings = []
        for _ in xrange(5):                  # five timed rounds...
            start = nanoTime()
            for _ in xrange(10000):          # ...of 10000 calls each
                function(5,10)
            timings.append((nanoTime() - start)/1000000.0)  # ns -> ms
        return min(timings)                  # report the best round
    return callable

def timeit(function):
    futures = []
    for i in xrange(40):
        futures.append(pool.submit(function))
    total = 0
    for future in futures:
        total += future.get()
    print total

benchmarks = (adder, counting_adder, synchronized_counting_adder, atomic_counting_adder)
benchmarks = [callit(f) for f in benchmarks]

WARMUP = 20000
print "<WARMUP>"
for function in benchmarks:
    function.call()
for function in benchmarks:
    for _ in xrange(WARMUP):
        function.call()
print "</WARMUP>"

pool = Executors.newFixedThreadPool(3)

for function in benchmarks:
    print
    print function.call.__name__
    timeit(function)
pool.shutdown()

glob = list(globals())
for name in glob:
    if name.endswith('count'):
        print name, globals()[name]

And the JRuby equivalent for comparison:

require 'java'
import java.lang.System
import java.util.concurrent.Executors
require 'thread'

def adder(a,b)
  a+b
end

class Counting
  def initialize
    @count = 0
  end
  def count
    @count
  end
  def adder(a,b)
    @count = @count + 1
    a+b
  end
end

class Synchronized
  def initialize
    @mutex = Mutex.new
    @count = 0
  end
  def count
    @count
  end
  def adder(a,b)
    @mutex.synchronize {
      @count = @count + 1
    }
    a + b
  end
end

counting = Counting.new
synchronized = Synchronized.new

puts "<WARMUP>"
10.times do
  10000.times do
    adder 5, 10
    counting.adder 5, 10
    synchronized.adder 5, 10
  end
end
puts "</WARMUP>"

class Body
  def initialize
    @pool = Executors.newFixedThreadPool(3)
  end
  def timeit(name)
    puts
    puts name
    result = []
    40.times do
      result << @pool.submit do
        times = []
        5.times do
          t = System.nanoTime
          10000.times do
            yield
          end
          times << (System.nanoTime - t) / 1000000.0
        end
        times.min
      end
    end
    result.each {|future| puts future.get()}
  end
  def done
    @pool.shutdown
  end
end

body = Body.new

body.timeit("adder") {adder 5, 10}
body.timeit("counting adder") {counting.adder 5, 10}
body.timeit("synchronized adder") {synchronized.adder 5, 10}

body.done

Where we started

A week ago the performance of this Jython benchmark was bad. Compared to the equivalent code in JRuby, Jython required over 10 times as much time to complete.

When I analyzed the code that Jython and JRuby generated and executed, I came to the conclusion that Jython performed so badly because the call path from the running code to the actual lock/unlock instructions introduced too much overhead for the JVM to have any chance of analyzing and optimizing the lock. I published this analysis in my writeup on the problem. It would of course be possible to lower this overhead by importing and using the pure Java classes for synchronization instead of the Jython threading module, but we like how the with-statement reads for synchronization:

with lock:
    counter += 1
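To give a sense of how much sits between that two-line block and the actual lock operations, here is the expansion that PEP 343 specifies for a with-statement, written as an ordinary Python function. It is a sketch in modern CPython 3; run_with and increment are hypothetical names, and the block body is passed in as a callable. Every step is a dynamic lookup or call that happens before and after the lock is acquired and released.

```python
import sys
import threading

def run_with(mgr, body):
    """Run body(value) as if written `with mgr as value: body(value)`,
    following the with-statement expansion given in PEP 343."""
    exit_ = type(mgr).__exit__           # looked up before __enter__ is called
    value = type(mgr).__enter__(mgr)
    exc = True
    try:
        try:
            return body(value)
        except BaseException:
            exc = False
            # A true return value from __exit__ swallows the exception.
            if not exit_(mgr, *sys.exc_info()):
                raise
    finally:
        if exc:                          # normal exit, including return
            exit_(mgr, None, None, None)

lock = threading.Lock()
counter = 0

def increment(_):
    global counter
    counter += 1

run_with(lock, increment)   # behaves like: with lock: counter += 1
```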

Getting better

Based on my analysis of how the with-statement compiles and the way this introduces overhead, I worked out the following redesign of the interaction between the with-statement and the context manager, which would allow us to get closer to the metal while remaining compatible with PEP 343:

  • When entering the with-block, we transform the object that constitutes the context manager into a ContextManager object.
  • If the object that constitutes the context manager implements the ContextManager interface, it is simply returned. This is where context managers written in Java get their huge benefit, by getting really close to the metal.
  • Otherwise a default implementation of ContextManager is returned. This object is created by retrieving the __exit__ method and invoking the __enter__ method of the context manager object.
  • The compiled code of the with-statement then only invokes the __enter__ and __exit__ methods of the returned ContextManager object.
  • This has the added benefit that, even for context managers written in pure Python, the ContextManager could be optimized and cached once we implement call site caching.
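The design in the list above can be sketched in modern CPython 3. This is one possible reading of it, not Jython's code: the real interface and default implementation live in Java inside Jython, and every name here (ContextManager, DefaultContextManager, as_context_manager, execute_with, FastLock) is hypothetical. The default adapter retrieves __exit__ and __enter__ up front and forwards to them; an object that already implements the interface, like the Java-written lock, is used as-is.

```python
import sys
import threading

class ContextManager:
    """Hypothetical stand-in for the Java-level ContextManager interface."""
    def __enter__(self):
        raise NotImplementedError
    def __exit__(self, exc_type, exc, tb):
        raise NotImplementedError

class DefaultContextManager(ContextManager):
    """Adapter for any object that only follows the Python protocol."""
    def __init__(self, obj):
        self._obj = obj
        self._exit = type(obj).__exit__    # retrieved first, as PEP 343 requires
        self._enter = type(obj).__enter__
    def __enter__(self):
        return self._enter(self._obj)
    def __exit__(self, exc_type, exc, tb):
        return self._exit(self._obj, exc_type, exc, tb)

def as_context_manager(obj):
    """Return obj itself if it implements the interface, else wrap it."""
    if isinstance(obj, ContextManager):
        return obj
    return DefaultContextManager(obj)

def execute_with(obj, body):
    """What the compiled with-statement does under this design: coerce once,
    then call only __enter__ and __exit__ on the returned object."""
    mgr = as_context_manager(obj)
    value = mgr.__enter__()
    try:
        result = body(value)
    except BaseException:
        if not mgr.__exit__(*sys.exc_info()):
            raise
        return None
    mgr.__exit__(None, None, None)
    return result

class FastLock(ContextManager):
    """A lock implementing the interface directly, standing in for a lock
    written in Java: no adapter object is created for it."""
    def __init__(self):
        self._lock = threading.Lock()
    def __enter__(self):
        self._lock.acquire()
        return self
    def __exit__(self, exc_type, exc, tb):
        self._lock.release()
        return False
```

The fast path is the isinstance check in as_context_manager: for FastLock the with-statement goes straight to acquire and release, while a plain threading.Lock pays for one adapter and an extra level of indirection per call.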

This specification was easily implemented by Jim, and he could then rewrite the threading module in Java so that the lock implementation takes direct advantage of the rewritten with-statement, which puts the actual code really close to the locking and unlocking. The results were immediate and beyond expectation.

Not only did we improve performance, but we passed the performance of the JRuby equivalent! Even using the client compiler with no warmup, we perform almost two times better than JRuby. Turn on the server compiler, let the JIT warm up and perform all its compilation, and we end up with a speedup of slightly more than 50 times.

A disclaimer is appropriate here. With the first benchmark (before this was optimized) I didn't have time to wait for a full warmup. The benchmark was incredibly slow at that point, and I was running it quite shortly before the presentation, so I didn't have time to leave it running overnight. Instead I lowered the compilation threshold of the HotSpot server compiler and ran just a few warmup iterations. It is possible that the JVM could have optimized the previous code slightly better given (a lot) more time. The actual speedup might be closer to the speedup from the old code to the new code using the client compiler and no warmup. But that is still a speedup of almost 20 times, which is still something I'm very proud of. It is also possible that I didn't run or implement the JRuby version in the best possible way, meaning that there might be ways of making it run faster that I don't know about. The new figures are still very nice, much nicer than the old ones for sure.

The current state of performance of Jython synchronization primitives

It is also interesting to see how the current implementation compares to the other Jython versions that I included in my presentation.

Without synchronization the code runs about three times as fast as with synchronization, but the counter does not return the correct result here, due to race conditions. That version is interesting for analyzing the overhead added by synchronization, but not as an actual implementation. An overhead of roughly two times the cost of the unsynchronized version is quite good in my opinion. What is more interesting is that the fastest version from the presentation, the one using AtomicInteger, now falls behind the synchronized version because of the reflection overhead required for its method invocations. On a system with more hardware threads (commonly referred to as "cores"), the implementation based on AtomicInteger could still be faster, though.
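Why the unsynchronized counter loses updates is easy to see at the bytecode level. Python semantics define count += 1 on an int as reading the name, computing a new object, and rebinding the name, so two threads can both read the same value and one increment is lost. The sketch below uses CPython's dis module to show this; Jython compiles to JVM bytecode instead, so the exact instructions differ, and counting_increment is an illustrative name.

```python
import dis

count = 0

def counting_increment():
    global count
    count += 1  # NOT SYNCHRONIZED

# Opcode names CPython compiled the increment into.
ops = [instr.opname for instr in dis.get_instructions(counting_increment)]
print(ops)

# The read (LOAD_GLOBAL) and the write (STORE_GLOBAL) are separate
# instructions, so another thread can run between them.
read_at = ops.index("LOAD_GLOBAL")
write_at = ops.index("STORE_GLOBAL")
print("read at", read_at, "write at", write_at)
```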

Where do we proceed from here?

Now that we have shown that it is possible to get a nice speedup from this redesign of the code paths, the next step is to provide the same kind of optimizations for code written in pure Python. Providing a better version of contextlib.contextmanager that exploits these faster code paths should be the easiest way to improve context managers written in Python. Beyond that, there is of course a wide range of other areas in Python where performance could be improved through the same kind of thorough analysis. I don't know at this point what we will focus on next, but you can look forward to many more performance improvements in Jython in the time to come.
