
Fixing replicateM's space leak

In my previous post I had a problem with some simple Haskell code that exploded in memory. This morning I worked out why, and how to fix it, using the technique of difference lists. The problematic code was:

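import Control.Monad (replicateM)

-- All lists s of length m + 2 (for m from 0 to p) with entries in [1 .. p]
-- such that the sum of (bias + s_i) * s_(i+1) over adjacent pairs equals p.
shapes :: Integer -> Integer -> [[Integer]]
shapes bias p =
  [ s
  | m <- [0 .. p]
  , s <- replicateM (fromInteger m + 2) [1 .. p]
  , p == sum (zipWith (*) (map (bias +) s) (tail s))
  ]

main :: IO ()
main = mapM_ (print . length . shapes 0) [1..10]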

To see how memory explodes, here's my session:

$ ghc -O2 replicateM.hs -rtsopts
...
$ timeout 60 ./replicateM +RTS -h
1
3
5
10
14
27
37
$ hp2pretty replicateM.hp

and here is the SVG graph output by hp2pretty:

heap profile graph showing a memory leak

Memory use climbs dramatically, and although the program ran for a whole minute of wall-clock time, less than 17 seconds of that was useful work once garbage collection time is subtracted.

I had a hunch that the problem is replicateM (the rest of the code is so simple that it almost obviously can't be anything else). I'm using it for lists, where it computes a Cartesian product: all possible combinations of one item from each of cnt0 copies of the list f. Let's look at its definition in the source code:
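As a quick illustration of the Cartesian-product behaviour for lists, here is a small self-contained check that uses only replicateM from Control.Monad plus Prelude functions. The names pairs and sameAsSequence are just for this example. For a list xs, replicateM n xs should produce (length xs)^n lists, and it should agree with sequence (replicate n xs):

```haskell
import Control.Monad (replicateM)

-- For lists, replicateM n xs picks one element from each of n copies of xs.
pairs :: [String]
pairs = replicateM 2 "ab"

-- It agrees with sequencing n copies of the same list.
sameAsSequence :: Bool
sameAsSequence = replicateM 3 [1, 2, 3 :: Int] == sequence (replicate 3 [1, 2, 3])

main :: IO ()
main = do
  print pairs                                    -- ["aa","ab","ba","bb"]
  print (length (replicateM 4 [1 .. 5 :: Int]))  -- 625, i.e. 5^4
  print sameAsSequence                           -- True
```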

replicateM cnt0 f =
    loop cnt0
  where
    loop cnt
        | cnt <= 0  = pure []
        | otherwise = liftA2 (:) f (loop (cnt - 1))
-- base-4.16.1.0:Control.Monad.replicateM

To understand liftA2 intuitively, one can rewrite it in do-notation. This changes the type (the constraint becomes Monad rather than Applicative) unless {-# LANGUAGE ApplicativeDo #-} is enabled:

replicateM cnt0 f =
    loop cnt0
  where
    loop cnt
        | cnt <= 0  = pure []
        | otherwise = do
            x <- f
            xs <- loop (cnt - 1)
            pure (x : xs)

but to see where the leak is coming from we can evaluate the critical part of the original version, specifically for the [String] type:

> result = words "hello world how are you"
> mapM_ putStrLn $ liftA2 (:) "¿¡" result
¿hello
¿world
¿how
¿are
¿you
¡hello
¡world
¡how
¡are
¡you

One can see immediately that it needs all of result to print it once with the first prefix, and then needs all of it again to print it with the second prefix. Because loop (cnt - 1) is passed as an argument to liftA2, its value is shared in the same way as the result list of strings was. Sharing like this forces the whole value to be kept in memory, and it can only start to be freed once the last prefix has started being emitted.
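To make the traversal pattern concrete, here is a sketch of a list-only liftA2 written as a list comprehension. The name liftA2List is hypothetical. It is intended to behave like base's liftA2 for lists, and the example compares the two on a small input rather than claiming this is how base defines it. The second list is walked once for every element of the first, which is why it has to be retained:

```haskell
import Control.Applicative (liftA2)

-- Hypothetical list-only liftA2: every x is paired with every y.
liftA2List :: (a -> b -> c) -> [a] -> [b] -> [c]
liftA2List f xs ys = [f x y | x <- xs, y <- ys]

-- ys is walked from start to end once per element of xs, so when ys is a
-- lazily computed value (like loop (cnt - 1)) it cannot start being freed
-- until the last element of xs is reached.
greetings :: [String]
greetings = liftA2List (:) "-+" (words "hello world how are you")

-- Compare against liftA2 from base on a small input.
agreesWithBase :: Bool
agreesWithBase = liftA2List (:) "ab" ["x", "y"] == liftA2 (:) "ab" ["x", "y"]

main :: IO ()
main = do
  mapM_ putStrLn greetings  -- "-hello" .. "-you", then "+hello" .. "+you"
  print agreesWithBase      -- True
```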

The problem in replicateM is compounded because the same thing happens in each recursive call to loop, although the retained values get smaller deeper down, so it's not such a huge deal there. The real problem is at the top of the recursion, where it has to store loop (cnt0 - 1), a list of length (length f)^(cnt0 - 1), which grows exponentially with cnt0.
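Some rough arithmetic for the test program, counting only the number of lists in that shared inner result (not cons cells or Integer boxes). In shapes, f = [1 .. p] and cnt0 = m + 2 with m running up to p, so the worst case for each p is cnt0 = p + 2, i.e. p^(p + 1) lists of length p + 1. The helper names retainedLists and worstCase are hypothetical:

```haskell
-- Lists in loop (cnt0 - 1) for replicateM cnt0 f, where len = length f.
retainedLists :: Integer -> Integer -> Integer
retainedLists len cnt0 = len ^ (cnt0 - 1)

-- Worst case within shapes for a given p: f = [1 .. p], cnt0 = p + 2.
worstCase :: Integer -> Integer
worstCase p = retainedLists p (p + 2)

main :: IO ()
main = mapM_ (\p -> putStrLn (show p ++ ": " ++ show (worstCase p))) [1 .. 8]
-- 7: 5764801
-- 8: 134217728
```

So by p = 8 the shared list would hold over a hundred million lists, which is consistent with the profile blowing up around there.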

So that's the issue; how to fix it? Difference lists traditionally work by turning left-nested appends like ((x ++ y) ++ z) ++ w into something like (((prepend x . prepend y) . prepend z) . prepend w) empty. The first is expensive because (++) costs the length of its left-hand argument, and left nesting makes that argument longer and longer. The second is much more efficient because function composition (.) has constant cost and prepend costs the length of its argument. Writing X for the length of x (and so on), the first costs X + (X + Y) + (X + Y + Z) = 3 X + 2 Y + Z, while the second costs X + Y + Z + W (and the + W could be avoided by applying the composition to w at the end instead of prepending w to empty).
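A minimal difference-list sketch of the idea above, written from scratch for illustration (the DList type and its functions here are not the API of the dlist package). A list is represented by the function that prepends it, so appending is just composition:

```haskell
-- A list represented as the function that prepends it.
newtype DList a = DList ([a] -> [a])

-- prepend: costs length xs when the function is finally applied.
fromList :: [a] -> DList a
fromList xs = DList (xs ++)

toList :: DList a -> [a]
toList (DList f) = f []

-- Appending is function composition: constant cost however it is nested.
append :: DList a -> DList a -> DList a
append (DList f) (DList g) = DList (f . g)

-- Left-nested, like ((x ++ y) ++ z) ++ w, but x is only walked once.
leftNested :: [a] -> [a] -> [a] -> [a] -> [a]
leftNested x y z w =
  toList (((fromList x `append` fromList y) `append` fromList z) `append` fromList w)

-- Applying the composition to w directly skips prepending w onto [].
leftNestedTail :: [a] -> [a] -> [a] -> [a] -> [a]
leftNestedTail x y z w =
  let DList f = (fromList x `append` fromList y) `append` fromList z
  in f w

main :: IO ()
main = print (leftNested "ab" "cd" "ef" "gh", leftNestedTail "ab" "cd" "ef" "gh")
-- ("abcdefgh","abcdefgh")
```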

Here's what I came up with. I'm not 100% sure why it works better, and I'm not sure it is even correct for anything other than m = [] (my first attempt output the items reversed). Maybe it has similarities to foldl' versus foldr in its use of an accumulator, as well as using difference lists. Anyway:

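{-# LANGUAGE ApplicativeDo #-}
-- Accumulate a difference list ([a] -> [a]) inside the Applicative, one
-- element at a time, and only apply it to [] once all n have been chosen.
replicateM' :: Applicative m => Int -> m a -> m [a]
replicateM' n ls = go n (pure id)
  where
    go k fs
      | k <= 0 = do
          f <- fs          -- finished: turn each difference list
          pure (f [])      -- into an ordinary list
      | otherwise = go (k - 1) gs
          where
            -- extend each accumulated prefix with each element of ls;
            -- with ApplicativeDo this is (\f l -> f . (l:)) <$> fs <*> ls
            gs = do
              f <- fs
              l <- ls
              pure (f . (l:))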
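To connect this with the foldl' comparison, here is a sketch of the same computation specialised to lists and written as an explicit left fold. It assumes that ApplicativeDo turns gs into (\f l -> f . (l:)) <$> fs <*> ls, which for lists is the comprehension in step. The name replicateML is hypothetical, and no claim is made about its memory behaviour; it only checks that the results match Control.Monad.replicateM on small inputs:

```haskell
import Control.Monad (replicateM)
import Data.List (foldl')

-- The accumulator holds one difference list per prefix chosen so far;
-- each step extends every prefix with every element of xs.
replicateML :: Int -> [a] -> [[a]]
replicateML n ls = map ($ []) (foldl' step [id] (replicate n ls))
  where
    step fs xs = [f . (x :) | f <- fs, x <- xs]

-- Same results, in the same order, as replicateM for small inputs.
agrees :: Bool
agrees =
  and [ replicateML n xs == replicateM n xs
      | n <- [0 .. 4], xs <- [[], [1], [1, 2, 3 :: Int]] ]

main :: IO ()
main = do
  print (replicateML 2 "ab")  -- ["aa","ab","ba","bb"]
  print agrees                -- True
```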

Here's the heap profile graph for the same program, with replicateM replaced by this replicateM':

heap profile graph showing small constant memory use

What a difference! Tiny constant memory, as it should be, and almost none of the time is spent garbage collecting.