SIMD Support for the vector library

26 Mar 2012

Now that the ICFP deadline is past, I’ve returned to working on adding SIMD support to GHC and associated libraries. The short-term goal is to be able to leverage SIMD instructions from the vector library and—hopefully transparently—from Data Parallel Haskell. Back in November I added a fairly complete set of primops to GHC that generate SSE instructions when using the LLVM back-end. If you’re interested in the original design for SIMD support in GHC and the state of the implementation, you can read about it here. It’s a bit of a hack and requires the LLVM back-end, but it does work.

Primops are necessary, but what we’d really like is a higher-level interface to these low-level instructions. This post is a very short introduction to my recent efforts in that direction. All the code I describe is public—you can find directions for getting up and running with the GHC simd branch on github.

Because this is Haskell, we’ll start by introducing a data type for a SIMD vector that is indexed by the type of the scalar values it contains. The term “vector” is already overloaded, so, at Simon Peyton Jones’ suggestion, we call a SIMD vector containing scalars of type a a Multi a. Because we want to choose a different primitive representation for each different a, Multi is a type family (actually an associated type family). Along with Multi, we define a type class MultiPrim that allows us to treat primitive operations on Multi’s in a uniform way, just as the Prim type class defined by the primitive library does for scalars.

Here’s the first part of the definition of the MultiPrim type class and the Multi associated family. You can see that it defines functions for replicating scalars across Multi’s, folding a function over the scalar elements of a Multi, reading a Multi out of a ByteArray#, etc. Right now there are instance definitions for Multi Float, Multi Double, Multi Int32, Multi Int64, and Multi Int. This type class and the rest of the code I’ll be showing are part of the simd branch of the vector library that I’ve put on github. You can go look there for further details, like the Num instances defined for the Multi’s.

class (Prim a, Prim (Multi a)) => MultiPrim a where
    data Multi a

    -- | The number of elements of type @a@ in a @Multi a@.
    multiplicity :: Multi a -> Int

    -- | Replicate a scalar across a @Multi a@.
    multireplicate :: a -> Multi a

    -- | Map a function over the elements of a @Multi a@.
    multimap :: (a -> a) -> Multi a -> Multi a

    -- | Fold a function over the elements of a @Multi a@.
    multifold :: (b -> a -> b) -> b -> Multi a -> b

    -- | Read a multi-value from the array. The offset is in elements of type
    -- @a@ rather than in elements of type @Multi a@.
    indexByteArrayAsMulti# :: ByteArray# -> Int# -> Multi a
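
As a quick illustration of how these pieces fit together, here is a tiny usage sketch (the function name is mine, and it assumes the Data.Primitive.Multi module from the simd branch is in scope): replicate a scalar across the lanes of a Multi Float, then fold the lanes back down to a single Float.

import Data.Primitive.Multi

-- Replicate 1 across all lanes of a Multi Float (four lanes with SSE)
-- and then sum the lanes back into a scalar, giving 4 in that case.
lanesum :: Float
lanesum = multifold (+) 0 (multireplicate 1 :: Multi Float)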

Now that we have the Multi type, we would like to use it to operate over Vector’s—that is, vector types from the vector library. A Vector has scalar elements, so for us to be able to use SIMD operations on these scalars we need to know something about the representation the Vector uses, namely that it lays out its scalars contiguously in memory. The PackedVector type class lets us express this constraint in Haskell’s type system. I won’t say anything more about it here, but instances are defined for the appropriate vector types in the Data.Vector.Unboxed and Data.Vector.Storable modules.

Of course the next step is to define appropriate versions of our old friends, map, zip, and fold, that will let us exploit SIMD operations. Here they are.

mmap :: (PackedVector v a, PackedVector v b)
     => (a -> b)
     -> (Multi a -> Multi b)
     -> v a
     -> v b
mzipWith :: (PackedVector v a, PackedVector v b, PackedVector v c)
         => (a -> b -> c)
         -> (Multi a -> Multi b -> Multi c)
         -> v a
         -> v b
         -> v c
mfoldl' :: PackedVector v b
        => (a -> b -> a)
        -> (a -> Multi b -> a)
        -> a
        -> v b
        -> a
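
As a quick sketch of how the scalar and Multi arguments pair up (this assumes mmap is exported from Data.Vector.Unboxed in the same way mzipWith and mfoldl' are, and relies on the Num instance for Multi Float; the function name is mine), here is a function that scales every element of an unboxed vector by a constant:

import qualified Data.Vector.Unboxed as U

import Data.Primitive.Multi

-- Multiply every element by k: element-wise for any scalar leftovers,
-- and Multi-wise for the bulk of the vector.
scale :: Float -> U.Vector Float -> U.Vector Float
scale k v = U.mmap (* k) (\mx -> mx * multireplicate k) v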

If you’re familiar with the vector library, you may know that it uses stream fusion to generate very efficient code—many operations are typically compiled to tight loops similar to what one would get from a C compiler. Stream fusion works by re-expressing high-level operations like map, zip, and fold in terms of “step” functions. Each step function takes some state and an element and produces either some new state and a new element, just some new state, or a value that says it is done processing elements. To support computing over vectors using SIMD operations, I have added a new “stream” variant whose step functions can receive not just scalar elements but Multi elements. That is, at every step, the stream consumer could be handed either a scalar or a Multi and must be prepared for either case. mmap, mzipWith, and mfoldl' are almost exactly like their scalar-only counterparts, but they each take an extra function argument for handling Multi’s.
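
To make the “either a scalar or a Multi” idea concrete, here is a rough sketch of the shape such a step type could take (illustrative only; the constructor names are mine, not the actual internals of the simd branch):

import Data.Primitive.Multi

-- Each step hands the consumer a single scalar, a whole Multi of scalars,
-- just a new state, or a signal that the stream is exhausted.
data Step s a
  = Yield  a         s   -- one scalar element and a new state
  | YieldM (Multi a) s   -- a Multi of elements and a new state
  | Skip             s   -- no element, just a new state
  | Done                 -- no more elements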

Let’s see if this stuff actually works by starting off with something easy—summing up all the elements in a vector. The following code uses the new vector library primitives multifold and U.mfoldl' to exploit SIMD instructions.

import qualified Data.Vector.Unboxed as U

import Data.Primitive.Multi

multisum :: U.Vector Float -> Float
multisum v =
    multifold (+) s ms
  where
    s  :: Float
    ms :: Multi Float
    -- Fold over the vector, accumulating any leftover scalar elements in s
    -- and the SIMD partial sums in ms; the two are combined above.
    (s, ms) = U.mfoldl' plus1 plusm (0, 0) v
    plusm (x, mx) my = (x, mx + my)
    plus1 (x, mx) y  = (x + y, mx)
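
For a quick smoke test, here is a hypothetical example (it assumes the imports and the multisum definition above are in the same module):

-- Summing 1 through 8 should print 36.0.
main :: IO ()
main = print (multisum (U.fromList [1..8]))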

We’ll compare it with five other versions. “Scalar” and “Scalar (C)” are plain old scalar versions written in Haskell and C, respectively. “Manual” and “Manual (C)” are hand-written SIMD versions in Haskell and C, respectively; the Haskell version explicitly iterates over the vector instead of using a fold. The “vector” version is the code we just saw, and the “multivector” version is based on a library I wrote to test out fusion when I first added SSE support to GHC; it implements a small subset of the vector library API. Here we go. (All timings were done on a laptop with a 2.70GHz Intel® Core™ i7-2620M CPU with frequency scaling disabled. 100 trials were performed at each data point. C code was compiled by GCC at -O3, and llc and opt were invoked with -O3.)

[Plot: Timings for the sum function]

Not bad. The following table gives the timings for vectors with $2^{24}$ elements. In this case, Haskell is as fast as C. This isn’t too surprising, as we’ve seen before that Haskell can be as fast as C.

Variant      Time (ms)
Scalar       19.7 $\pm$ 0.2
Scalar (C)   19.7 $\pm$ 0.4
Manual       4.62 $\pm$ 0.03
Manual (C)   4.58 $\pm$ 0.02
vector       4.62 $\pm$ 0.02
multivector  4.62 $\pm$ 0.02

Summing vectors of size $2^{24}$.

Of course, summing up the elements in a vector isn’t so hard. The great thing about the vector library is that you can write high-level Haskell code and, through the magic of fusion, end up with a tight inner loop that looks like what you might have gotten out of a C compiler had you chosen to write in C. Let’s try a slightly more difficult computation that will require fusion—dot product.

Computing the dot product efficiently requires fusing two loops to perform a combined addition and multiplication. Here is the scalar version in Haskell:

import qualified Data.Vector.Unboxed as U

dotp :: U.Vector Float -> U.Vector Float -> Float
dotp v w =
    U.sum $ U.zipWith (*) v w

And here is our first cut at a SIMD version.

import qualified Data.Vector.Unboxed as U

import Data.Primitive.Multi

multidotp :: U.Vector Float -> U.Vector Float -> Float
multidotp v w =
    multifold (+) s ms
  where
    s  :: Float
    ms :: Multi Float
    -- Fold over the fused element-wise product, accumulating scalar
    -- leftovers in s and SIMD partial sums in ms.
    (s, ms)          = U.mfoldl' plus1 plusm (0, 0) $ U.mzipWith (*) (*) v w
    plusm (x, mx) my = (x,     mx + my)
    plus1 (x, mx) y  = (x + y, mx)

Let’s look at performance once more. Again, “Manual” is a Haskell version that manually iterates over the vector once and fuses the addition and multiplication, the idea being that this is what we would hope to get out of GHC after fusion, inlining, constructor specialization, etc.

[Plot: Timings for the dotp function]

For reference, here are the timings for the case with $n = 2^{24}$ again.

Variant      Time (ms)
Scalar       16.98 $\pm$ 0.08
Scalar (C)   16.63 $\pm$ 0.09
Manual       8.87 $\pm$ 0.03
Manual (C)   8.64 $\pm$ 0.02
vector       13.03 $\pm$ 0.07
multivector  9.5 $\pm$ 0.1

Calculating the dot product of vectors of size $2^{24}$.

Not so hot. Although our hand-written Haskell implementation (“Manual” in the plot and table) is competitive with C, the vector version is not. Interestingly, the “multivector” version is competitive. What could be going on?

The first thing that jumps to mind is that fusion might not be kicking in: I could’ve screwed up the implementation of the SIMD-enabled combinators! To check this hypothesis, let’s look at the GHC core generated for the main loop in multidotp (this is the loop that iterates over elements SIMD-vector-wise):

    letrec {
      $s$wmfoldlM_loopm_s4ri [Occ=LoopBreaker]
        :: GHC.Prim.Int#
           -> GHC.Prim.Int#
           -> GHC.Prim.~#
                *
                Data.Primitive.Multi.FloatX4.FloatX4
                (Data.Primitive.Multi.Multi GHC.Types.Float)
           -> GHC.Prim.FloatX4#
           -> GHC.Prim.Float#
           -> (# GHC.Types.Float,
                 Data.Primitive.Multi.Multi GHC.Types.Float #)
      [LclId, Arity=5, Str=DmdType LLLLL]
      $s$wmfoldlM_loopm_s4ri =
        \ (sc_s4nR :: GHC.Prim.Int#)
          (sc1_s4nS :: GHC.Prim.Int#)
          (sg_s4nT
             :: GHC.Prim.~#
                  *
                  Data.Primitive.Multi.FloatX4.FloatX4
                  (Data.Primitive.Multi.Multi GHC.Types.Float))
          (sc2_s4nU :: GHC.Prim.FloatX4#)
          (sc3_s4nV :: GHC.Prim.Float#) ->
          case GHC.Prim.>=# sc1_s4nS ipv7_aHm of _ {
            GHC.Types.False ->
              case GHC.Prim.indexFloatArrayAsFloatX4#
                     ipv2_s4kn (GHC.Prim.+# ipv_s4kl sc1_s4nS)
              of wild_a4j3 { __DEFAULT ->
              case GHC.Prim.>=# sc_s4nR ipv6_XI3 of _ {
                GHC.Types.False ->
                  case GHC.Prim.indexFloatArrayAsFloatX4#
                         ipv5_s4l7 (GHC.Prim.+# ipv3_s4l5 sc_s4nR)
                  of wild3_X4jF { __DEFAULT ->
                  $s$wmfoldlM_loopm_s4ri
                    (GHC.Prim.+# sc_s4nR 4)
                    (GHC.Prim.+# sc1_s4nS 4)
                    @~ (Sym (Data.Primitive.Multi.NTCo:R:MultiFloat) ; Sym
                                                                         (Data.Primitive.Multi.TFCo:R:MultiFloat))
                    (GHC.Prim.plusFloatX4#
                       sc2_s4nU (GHC.Prim.timesFloatX4# wild_a4j3 wild3_X4jF)) 
                    sc3_s4nV
                  };
                GHC.Types.True -> ...
              }
              };
            GHC.Types.True -> ...
          };

Core for main loop of multidotp

We can see that the two loops have been fused. I won’t show the core for the other Haskell implementations, but I’ll note that it looks pretty much the same except for one thing: multidotp is carrying around two pieces of state during the fold it performs, a scalar Float and a Multi Float. That shouldn’t make a difference though—these guys should just live in two separate registers. There’s only one reasonable thing left to do: look at some assembly.

Just so we have an idea of what we want to see, let’s examine the inner loop of the C version first:

.L3:
	movaps	(%rdi,%rax), %xmm0
	mulps	(%rdx,%rax), %xmm0
	addq	$16, %rax
	cmpq	%r8, %rax
	addps	%xmm0, %xmm1
	jne	.L3

Inner loop of Manual (C).

Cool. Our array pointers live in rdi and rdx, our index in rax, and the array bounds in r8. Now on to the “manual” Haskell version.

.LBB5_3:                                # %n5oi
                                        # =>This Inner Loop Header: Depth=1
	movups	(%rcx), %xmm2
	movups	(%rdx), %xmm1
	mulps	%xmm2, %xmm1
	addps	%xmm1, %xmm0
	addq	$16, %rcx
	addq	$16, %rdx
	addq	$4, %r14
	cmpq	%r14, %rax
	jg	.LBB5_3

Inner loop of Manual.

Still pretty good. This time our array pointers live in rcx and rdx, our index in r14, and our bounds in rax. Note that the index is now measured in float’s instead of bytes. How about the “multivector” version?

.LBB1_2:                                # %n3JW.i
                                        #   in Loop: Header=BB1_1 Depth=1
	cmpq	%rax, %r8
	jle	.LBB1_5
# BB#3:                                 # %n3K9.i
                                        #   in Loop: Header=BB1_1 Depth=1
	movq	8(%rcx), %rdx
	addq	%rax, %rdx
	movq	16(%rcx), %rdi
	movups	16(%rdi,%rdx,4), %xmm2
	movups	(%rbx), %xmm1
	mulps	%xmm2, %xmm1
	addps	%xmm1, %xmm0
	movups	%xmm0, -56(%rcx)        # multivector spill
	addq	$16, %rbx
	addq	$4, %rax
.LBB1_1:                                # %tailrecurse.i
                                        # =>This Inner Loop Header: Depth=1
	cmpq	%rax, %r9
	jg	.LBB1_2

Inner loop of multivector.

There is definitely more junk here. Still, it’s not horrible, except for the movups marked “multivector spill”, where we spill the accumulator to the stack. Apparently the spill doesn’t cost us much (on an AMD machine I have access to, this spill does incur a penalty). Now for the “vector” version that had performance issues.

.LBB4_2:                                # %n4H3
                                        #   in Loop: Header=BB4_1 Depth=1
	cmpq	%r14, 43(%rbx)
	jle	.LBB4_5
# BB#3:                                 # %n4Hw
                                        #   in Loop: Header=BB4_1 Depth=1
	movq	35(%rbx), %rdx
	addq	%r14, %rdx
	movq	3(%rbx), %rcx
	movq	11(%rbx), %rdi
	movups	16(%rdi,%rdx,4), %xmm0
	movq	27(%rbx), %rdx
	addq	%rsi, %rdx
	movups	16(%rcx,%rdx,4), %xmm1
	mulps	%xmm0, %xmm1
	movups	(%rbp), %xmm0           # vector load
	addps	%xmm1, %xmm0
	movups	%xmm0, (%rbp)           # vector spill
	addq	$4, %r14
	addq	$4, %rsi
.LBB4_1:                                # %tailrecurse
                                        # =>This Inner Loop Header: Depth=1
	cmpq	%rsi, %rax
	jg	.LBB4_2

Inner loop of vector.

Ah-hah, there’s our likely culprit: our accumulator is loaded from the stack at the movups marked “vector load” and spilled back at the movups marked “vector spill”. Yuck! It looks like carrying around that extra bit of state really cost us. I’m not sure why LLVM didn’t spill the Float portion of the state to the stack temporarily so that it could use the register for the main loop, but it seems likely that this is related to the GHC calling convention used by the LLVM back-end.

I’m disappointed that we weren’t able to get C-competitive performance from our high-level Haskell code, especially since it seems so tantalizingly close. At least there is hope that with some prodding we can convince LLVM to keep our accumulating parameter in a register.
