Optimising Haskell: joy of profiling

My Advent of Code 2024 day 12 solution was the slowest of the whole event, so it was my first target for optimising. Going in, all I knew was that it took nearly a minute to find the right answer, despite the underlying algorithm (union-find) being theoretically fast. There were a number of things I could have done to speed up the program, but I had no idea which of them would be effective.

The original code, as written during the event itself, is in MainUFSlow.hs.

Enter the profiler to find where the program was spending its time.

cabal run advent12 --enable-profiling -- +RTS -N -p -s -hT

That produced a profile file that showed which functions took what time.

COST CENTRE          MODULE                           SRC                                     no.     entries  %time %alloc   %time %alloc

 main                Main                             advent12/MainUFSlow.hs:(77,1)-(95,28)   759           0    0.0    0.0   100.0  100.0
  findFenceLength    Main                             advent12/MainUFSlow.hs:(137,1)-(138,41) 778       19600    7.2    0.0    41.6   46.6
   meets             Main                             advent12/MainUFSlow.hs:(66,3)-(67,71)   795   384160000   15.6    2.6    34.4   46.6
    pos              Main                             advent12/MainUFSlow.hs:11:20-22         796   768320000    0.9    0.0     0.9    0.0
    neighbours       Main                             advent12/MainUFSlow.hs:143:1-54         797   384160000   13.5   29.9    17.8   44.0
     neighboursH     Main                             advent12/MainUFSlow.hs:141:1-47         798   384160000    4.3   14.1     4.3   14.1
     neighboursV     Main                             advent12/MainUFSlow.hs:142:1-47         799   384121080    0.0    0.0     0.0    0.0
    plant            Main                             advent12/MainUFSlow.hs:11:37-41         800      155680    0.0    0.0     0.0    0.0
  findRegions        Main                             advent12/MainUFSlow.hs:119:1-41         769           1    0.0    0.0    49.5   52.8
   merge             Main                             advent12/MainUFSlow.hs:65:10-22         771           0    0.0    0.0    49.5   52.8
    merge            Main                             advent12/MainUFSlow.hs:49:3-44          772           1    0.0    0.0    49.5   52.8
     mergeItem       Main                             advent12/MainUFSlow.hs:65:10-22         781           0    0.0    0.0    49.5   52.8
      mergeItem      Main                             advent12/MainUFSlow.hs:(52,3)-(53,45)   782       19600   14.2    6.2    49.5   52.8
       meets         Main                             advent12/MainUFSlow.hs:(66,3)-(67,71)   783   384160000   17.0    2.6    35.1   46.6
        pos          Main                             advent12/MainUFSlow.hs:11:20-22         784   768320000    0.9    0.0     0.9    0.0
        neighbours   Main                             advent12/MainUFSlow.hs:143:1-54         785   384160000   13.6   29.9    17.2   44.0
         neighboursH Main                             advent12/MainUFSlow.hs:141:1-47         786   384160000    3.5   14.1     3.5   14.1
         neighboursV Main                             advent12/MainUFSlow.hs:142:1-47         787   384121080    0.0    0.0     0.0    0.0

Finding fences

The first thing I noticed was that the program spent about 40% of its time in the initial call to findFenceLength, which found the length of fence around each plot in the map.

findFenceLength :: Region -> Plot -> Plot
findFenceLength region plot = plot { fenceLength = 4 - (length nbrs) }
  where nbrs = filter (meets plot) region

Most of that time was spent in neighbours in meets, which finds the plots that touch this one and have the same plant type. This function was called on the initial field, which contained all the plots in the input. This meant that each invocation of filter had to trawl through every plot in the input, even though most of them were far from the current plot and therefore would be discarded.

A better way was to do the same calculation, but after finding the regions. That also allowed me to simplify the Plot type, as it didn't need to keep the fenceLength there any more.

data Plot = Plot { pos :: Position, plant :: Char }
  deriving (Show, Eq, Ord)

part1 :: UFind Plot -> Int
part1 regions = sum $ fmap fenceCost $ distinctSets regions

perimeter :: Region -> Int
perimeter region = sum $ fmap (fenceLength region) region

fenceLength :: Region -> Plot -> Int
fenceLength region plot = 4 - (length nbrs)
  where nbrs = filter (meets plot) region
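
To see the per-region calculation on a concrete case, here is a minimal self-contained sketch. The definitions of Position, Region, neighbours and meets are my assumptions (the post does not show them), written with plain record accessors, and square is a made-up 2x2 region. The point is that filter now only scans the plots in one region rather than the whole field.

```haskell
-- Assumed definitions: the post does not show Position, Region or neighbours.
type Position = (Int, Int)

data Plot = Plot { pos :: Position, plant :: Char }
  deriving (Show, Eq, Ord)

type Region = [Plot]

-- The four orthogonally adjacent positions (assumed layout: row, column).
neighbours :: Position -> [Position]
neighbours (r, c) = [(r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)]

-- Two plots meet if they are adjacent and grow the same plant.
meets :: Plot -> Plot -> Bool
meets p1 p2 = pos p1 `elem` neighbours (pos p2) && plant p1 == plant p2

-- Each plot needs four fence panels, minus one per neighbour in its region.
fenceLength :: Region -> Plot -> Int
fenceLength region plot = 4 - length (filter (meets plot) region)

-- The filter inside fenceLength only ever scans this one region.
perimeter :: Region -> Int
perimeter region = sum (fmap (fenceLength region) region)

-- Hypothetical example data: a 2x2 block of 'A' plots.
square :: Region
square = [Plot (r, c) 'A' | r <- [0, 1], c <- [0, 1]]
```

Loaded into GHCi, perimeter square evaluates to 8: each of the four plots has two neighbours inside the region, so each contributes 4 - 2 = 2.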

Making just that change reduced the runtime from 59 seconds to 32 seconds, a saving of 46%. This version of the program is present as MainNoFence.hs.

Version     Runtime (s)   Fraction of original
UFSlow      59.6          1.00
NoFence     32.0          0.54

Better meeting

Profiling of this version showed it was now spending about 83% of the runtime finding the initial regions for the map. This was the place to make further speedups. The profiling results showed that most of the time was spent in meets inside the mergeItem function. Making meets faster looked like the best option to try.

This is what it looked like.

class Ord a => Joinable a where
  ...
  merge :: UFind a -> UFind a
  merge uf = foldl' mergeItem uf $ M.keys uf
  
  mergeItem :: UFind a -> a -> UFind a
  mergeItem uf x = foldl' (\u y -> join u x y) uf nbrs
    where nbrs = filter (meets x) $ M.keys uf
    
instance Joinable Plot where
  meets plot1 plot2 = 
    plot1.pos `elem` neighbours plot2.pos && plot1.plant == plot2.plant

instance Joinable SideFragment where
  meets (SideFragment p1 T) (SideFragment p2 T) = p1 `elem` neighboursH p2
  meets (SideFragment p1 B) (SideFragment p2 B) = p1 `elem` neighboursH p2
  meets (SideFragment p1 L) (SideFragment p2 L) = p1 `elem` neighboursV p2
  meets (SideFragment p1 R) (SideFragment p2 R) = p1 `elem` neighboursV p2
  meets _ _ = False

I could see that there's a lot of repeated work being done in meets. The imperative-style description of the processing looks like this:

  • For each plot (x or plot1)
    • For each other plot (plot2)
      • For each neighbour of plot2
        • Does this new plot equal plot1?

That's a lot of neighbours being generated. A better approach would be to move the generation of neighbours to plot1, giving something like:

  • For each plot (x or plot1)
    • Generate the neighbours of plot1, once
    • For each other plot plot2
      • Is plot2 one of the neighbours of plot1?

Moving the generation of neighbours up a level should save some time.

That gives a new definition of mergeItem, which itself uses a new function adjacents. That's all the change that's needed.

  mergeItem :: UFind a -> a -> UFind a
  mergeItem uf x = foldl' (\u y -> join u x y) uf nbrs
    where nbrs = filter (`elem` (adjacents x)) $ M.keys uf

instance Joinable Plot where
  adjacents plot = fmap (\p -> plot { pos = p }) $ neighbours $ pos plot

instance Joinable SideFragment where
  adjacents (SideFragment p T) = fmap (\p' -> SideFragment p' T) $ neighboursH p
  adjacents (SideFragment p B) = fmap (\p' -> SideFragment p' B) $ neighboursH p
  adjacents (SideFragment p L) = fmap (\p' -> SideFragment p' L) $ neighboursV p
  adjacents (SideFragment p R) = fmap (\p' -> SideFragment p' R) $ neighboursV p

That produces a modest speedup, producing answers in about 23 seconds. That's about three quarters of the time of the "no fence" version and about two fifths of the original time. This version of the program is present as MainAdjacent.hs.

Version     Runtime (s)   Fraction of original
UFSlow      59.6          1.00
NoFence     32.0          0.54
Adjacent    23.3          0.39

Using internals

That's better than we had, but not great. Profiling the MainAdjacent version isn't overly helpful. It shows that the program spends 74% of its time in mergeItem, but it's not clear how much of that is spent in the foldl' or filter functions within it. I can get that visibility by creating additional cost centres for these lines of code.

  mergeItem :: UFind a -> a -> UFind a
  mergeItem uf x = {-# SCC mergeItemFold #-} foldl' (\u y -> join u x y) uf nbrs
    where nbrs = {-# SCC mergeItemFilter #-} filter (`elem` (adjacents x)) $ M.keys uf

That shows that all the time in mergeItem is spent in the filter rather than the foldl'. But how to make the filter faster?

This call to filter is finding the plots (elements of uf) that are adjacent to the current plot. A look in the documentation for Data.Map reveals the function restrictKeys. This limits a Map to only the keys specified in the given Set. The implementations of Map and Set are similar, and use the same tree-based representation of elements. A look at the source of restrictKeys indicates that it uses the detail of this representation to do the work. That should be faster. What happens if I use restrictKeys rather than filter?

  mergeItem :: UFind a -> a -> UFind a
  mergeItem uf x = foldl' (\u y -> join u x y) uf nbrs
    where nbrs = M.keys $ M.restrictKeys uf $ S.fromList $ adjacents x

That makes a huge difference. Just that one change drops the runtime from 23 seconds to less than 4 seconds. That's a factor of six speedup on the previous version, or a factor of fifteen on the original.
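
As a sanity check that the two approaches select the same keys, here is a small sketch on a toy map. The names grid, nbrsByFilter and nbrsByRestrict are hypothetical, and Position and neighbours are my assumed definitions rather than the ones in the repo. The filter version walks every key in the map, while restrictKeys is documented as O(m log(n/m + 1)), which for a handful of neighbours is roughly logarithmic in the size of the map.

```haskell
import qualified Data.Map.Strict as M
import qualified Data.Set as S

-- Assumed definitions, not taken from the post.
type Position = (Int, Int)

neighbours :: Position -> [Position]
neighbours (r, c) = [(r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)]

-- Hypothetical stand-in for the union-find map: a 3x3 grid of positions,
-- each mapped to itself.
grid :: M.Map Position Position
grid = M.fromList [((r, c), (r, c)) | r <- [0 .. 2], c <- [0 .. 2]]

-- Linear scan: tests every key in the map against the neighbour list.
nbrsByFilter :: M.Map Position v -> Position -> [Position]
nbrsByFilter uf x = filter (`elem` neighbours x) (M.keys uf)

-- Tree-based lookup: keeps only the keys that appear in the neighbour set.
nbrsByRestrict :: M.Map Position v -> Position -> [Position]
nbrsByRestrict uf x = M.keys (M.restrictKeys uf (S.fromList (neighbours x)))
```

Both functions return keys in ascending order, so for the corner position (0,0) each gives [(0,1),(1,0)]. Neighbours that fall outside the grid are simply absent from the result.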

Version     Runtime (s)   Fraction of original
UFSlow      59.6          1.00
NoFence     32.0          0.54
Adjacent    23.3          0.39
Final       3.9           0.07

That's good enough for me.

Next steps?

The next step would have been converting the union-find structure from a Map to a mutable Array. That would allow me to build the union-find structure imperatively inside the ST monad, which would probably be faster. But that would have been a fair bit of change to the code, so I'm glad I didn't need it.
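
For the curious, this is roughly what an array-backed version might look like. It is a sketch I have not benchmarked and it is not in the repo. It assumes each plot has already been given an Int index (say, row * width + column), and the names findRoot, unionSets, components and representatives are all made up for illustration. It does path compression but leaves out union by rank for brevity.

```haskell
import Control.Monad (forM_, when)
import Control.Monad.ST (ST)
import Data.Array.ST (STUArray, newListArray, readArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, elems)

-- Follow parent pointers to the representative, compressing the path on the way back.
findRoot :: STUArray s Int Int -> Int -> ST s Int
findRoot parent i = do
  p <- readArray parent i
  if p == i
    then pure i
    else do
      root <- findRoot parent p
      writeArray parent i root
      pure root

-- Join the sets containing a and b by pointing one root at the other.
unionSets :: STUArray s Int Int -> Int -> Int -> ST s ()
unionSets parent a b = do
  ra <- findRoot parent a
  rb <- findRoot parent b
  when (ra /= rb) $ writeArray parent ra rb

-- Build the structure for elements 0 .. n-1 from a list of pairs that meet.
-- The final pass makes every element point directly at its representative.
components :: Int -> [(Int, Int)] -> UArray Int Int
components n pairs = runSTUArray $ do
  parent <- newListArray (0, n - 1) [0 .. n - 1]
  forM_ pairs $ \(a, b) -> unionSets parent a b
  forM_ [0 .. n - 1] $ \i -> findRoot parent i
  pure parent

-- The representative of each element, in index order.
representatives :: Int -> [(Int, Int)] -> [Int]
representatives n pairs = elems (components n pairs)
```

For example, representatives 5 [(0,1),(3,4),(1,2)] gives [2,2,2,4,4]: elements 0, 1 and 2 end up in one set, and 3 and 4 in another.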

Code

You can get the code from my locally-hosted Git repo, or from Codeberg.