Haskell adventure: Project Euler #191

I thought I'd try to be like one of the cool Haskell kids and do one of those literate programming blog posts. That means you can copy and paste this whole blog post into a .lhs file and it will actually compile. This is about how I solved Project Euler problem 191 in Haskell. So if you're trying to work through the problems on your own, don't read this yet!

Let's get to it. Already read the problem statement? Cool.

> import Control.Monad ( guard )
> import Data.List ( groupBy
>                  , isInfixOf
>                  , sort 
>                  )

One of the most obvious things you might want to do, especially since the problem statement used the word "string," would be to represent the O's, L's, and A's as Chars. But I think it's nicer to use a distinct type like this:

> data Day = O | L | A
>   deriving ( Eq, Show )

As you'll see, it will also be helpful to have a list of the three possible "day" values:

> days :: [Day]
> days = [O, L, A]

Then we can represent a student's attendance record as a list of those values. (I'm calling it Record1 here because I later came up with a better representation, as you'll see.)

> type Record1 = [Day]

How do we know whether a particular record is prize-winning or not? This is more or less just a translation of what the problem tells us.

> prize1 :: Record1 -> Bool
> prize1 record = (not $ [A,A,A] `isInfixOf` record) && notLateTwice record
> notLateTwice :: Record1 -> Bool
> notLateTwice record = case filter (== L) record of
>   (L:L:_) -> False
>   _       -> True

Now we can attack the actual question: After n days, how many prize strings are possible? For n=0, there is only one, namely, an empty record:

> prizes1 :: Int -> [Record1]
> prizes1 0 = [[]]

On the nth day, take all the prize-winning strings from the (n-1)th day, and for each one, tack on an O, L, and A. Now you have three times as many strings, and you can check each of them to see if they're prize-winning. (We don't have to check the non-prize-winning strings from day n-1, because once you've lost the prize, there's no way to get it back.)

> prizes1 n = do
>   prevPrizeString <- prizes1 (n - 1)
>   nextDay <- days
>   let newString = nextDay:prevPrizeString
>   guard $ prize1 newString
>   return newString

The code after "do" gets evaluated once for each possible combination of a string from prizes1 (n - 1) and a day from [O,L,A]. In another language, you might write this as a double "for" loop, or possibly a list comprehension. I probably could have used a Haskell list comprehension instead, but I think this is nicer. Anyway, then we add the 'nextDay' onto the prize string from the (n - 1)th day, and check whether the result is still a prize string. If it is, we "return" it, which means it will end up in the list of prize strings for day n; if not, the "guard" function ensures it will not be returned.
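For comparison, here's a sketch of the list-comprehension version mentioned above. It's a standalone file (a plain fenced block, so it isn't part of the literate program), it repeats Day, days, and the prize test so it loads on its own, and the name prizesLC is made up for illustration.

```haskell
import Control.Monad (guard)
import Data.List (isInfixOf)

data Day = O | L | A
  deriving (Eq, Show)

days :: [Day]
days = [O, L, A]

-- Same condition as prize1/notLateTwice in the post.
prize1 :: [Day] -> Bool
prize1 record = not ([A, A, A] `isInfixOf` record) && notLateTwice
  where
    notLateTwice = case filter (== L) record of
      (L:L:_) -> False
      _       -> True

-- The do-notation version from the post.
prizes1 :: Int -> [[Day]]
prizes1 0 = [[]]
prizes1 n = do
  prev <- prizes1 (n - 1)
  nextDay <- days
  let new = nextDay : prev
  guard (prize1 new)
  return new

-- Hypothetical list-comprehension version: the two generators are the
-- nested loops, and the final Boolean qualifier does the job of guard.
prizesLC :: Int -> [[Day]]
prizesLC 0 = [[]]
prizesLC n =
  [ new
  | prev <- prizesLC (n - 1)
  , nextDay <- days
  , let new = nextDay : prev
  , prize1 new
  ]
```

The generators are visited in the same order as the do-block, so prizesLC n and prizes1 n should give the same strings in the same order, not just the same count. Which one reads better is mostly a matter of taste.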

You may notice that I stuck nextDay onto the front of the list instead of the end. That's just because it's faster and cleaner than writing "++ [nextDay]", although it turns out to be useful later too, as you'll see. You can check that "length $ prizes1 4" is 43, which is a good sign we probably haven't messed up too badly yet. And then "length $ prizes1 30" should be the answer. I fired up ghci and typed it in, and ... nothing. The CPU cranked away, but after several seconds it still hadn't come up with anything. The rule of thumb for Project Euler is that your code should run in a minute or less. But I had a sneaking suspicion that there was a solution for this problem that would run almost instantly. So let's optimize!

One thing to notice is that we don't really care about absences in the distant past. As far as absences go, you only lose the prize if you're absent three days in a row. And we'll never end up with a string like OAAAO, because once you hit the third A, you've already lost your prize and we stop keeping track of you at all. So the function we pass to "guard" can just look for A's at the beginning of the string (remember, more recent days are at the beginning, not the end), rather than using "isInfixOf" to look for an "AAA" sequence anywhere in the string.

> checkRecord :: Record1 -> Bool
> checkRecord (A:A:A:_) = False
> checkRecord (L:ds) = L `notElem` ds
> checkRecord _ = True

If today is your third consecutive absence (A:A:A:_), you don't get a prize. If you were late today, you can still get a prize, but only if you were never late in the past. In all other cases, if you haven't already lost your prize, then you're still eligible for it. Now in the definition of "prizes1," we can just replace "guard $ prize1 newString" with "guard $ checkRecord newString" and it should be a bit faster. It was still well short of "instant," so I kept looking for better approaches.
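Here's a standalone sketch of that swap, using nothing beyond base. To compare the two tests side by side, it passes the guard test in as a parameter; prizesWith is a hypothetical name, not something from the post.

```haskell
import Control.Monad (guard)
import Data.List (isInfixOf)

data Day = O | L | A
  deriving (Eq, Show)

days :: [Day]
days = [O, L, A]

-- Whole-record check (same condition as prize1 in the post).
prize1 :: [Day] -> Bool
prize1 record = not ([A, A, A] `isInfixOf` record) && length (filter (== L) record) < 2

-- Cheap check from the post: only inspects the newest day (the head),
-- which is enough because the rest of the record already won the prize.
checkRecord :: [Day] -> Bool
checkRecord (A:A:A:_) = False
checkRecord (L:ds)    = L `notElem` ds
checkRecord _         = True

-- Hypothetical helper: the post's prizes1 with the guard test passed in
-- as a parameter, so both checks can be compared.
prizesWith :: ([Day] -> Bool) -> Int -> [[Day]]
prizesWith _ 0 = [[]]
prizesWith ok n = do
  prev <- prizesWith ok (n - 1)
  nextDay <- days
  let new = nextDay : prev
  guard (ok new)
  return new
```

Note that checkRecord [O, A, A, A] is True even though that record contains AAA. That's fine here, because such a record could only be built by extending [A, A, A], which had already been thrown away.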

Writing "checkRecord" was a step in the right direction, but we were still keeping track of lots of information we didn't actually care about. All that really matters is a student's current absence streak and their total number of lates. So let's just store those, and not the actual sequences:

> data Record2 = Record2 { consecutiveAbsences :: Int, lates :: Int }
>   deriving ( Eq, Ord, Show )

You could also just use a tuple, (Int, Int), but this way there's no risk of forgetting which field is which. Plus, this syntax is called "record syntax" so it's only appropriate to use it for our "Record" type, right? Right. Now that records aren't just lists, we can't tack on the next O, L, or A with the (:) operator; we have to actually keep track of what those two Ints should be. That's what the (#) function does. (Why did I choose "#"? No particular reason, I just picked a character.) Anyway, here it is:

> (#) :: Record2 -> Day -> Record2
> r # O = r { consecutiveAbsences = 0 }
> r # L = r { consecutiveAbsences = 0, lates = lates r + 1 }
> r # A = r { consecutiveAbsences = consecutiveAbsences r + 1 }

If you're absent, we increase "consecutiveAbsences" by 1. If not, we reset it to 0. And if you're late, we increase "lates" by 1. Now it's really easy to check whether a particular record is prize-winning:

> prize2 :: Record2 -> Bool
> prize2 r = consecutiveAbsences r < 3 && lates r < 2
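To see (#) and prize2 working together, here's a small standalone sketch that folds a whole history through (#), starting from Record2 0 0. The helper name summarise is hypothetical; note that it takes days oldest-first, the opposite of how Record1 lists are stored.

```haskell
data Day = O | L | A
  deriving (Eq, Show)

data Record2 = Record2 { consecutiveAbsences :: Int, lates :: Int }
  deriving (Eq, Ord, Show)

-- Same update rule as the post's (#).
(#) :: Record2 -> Day -> Record2
r # O = r { consecutiveAbsences = 0 }
r # L = r { consecutiveAbsences = 0, lates = lates r + 1 }
r # A = r { consecutiveAbsences = consecutiveAbsences r + 1 }

prize2 :: Record2 -> Bool
prize2 r = consecutiveAbsences r < 3 && lates r < 2

-- Hypothetical helper: summarise a history given OLDEST day first
-- (a Record1-style list would need to be reversed before passing it in).
summarise :: [Day] -> Record2
summarise = foldl (#) (Record2 0 0)
```

For example, summarise [A, A, L, A] is Record2 1 1 (still prize-winning), while summarise [L, O, L] is Record2 0 2 (late twice, so prize2 is False). One caveat: summarise [A, A, A, O] is Record2 0 0, because prize2 only sees the current streak. That's why prizes2 only ever extends records that were already prize-winning.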

And we can do more or less the same thing we did before:

> prizes2 :: Int -> [Record2]
> prizes2 0 = [Record2 0 0]
> prizes2 n = do
>   r <- prizes2 (n - 1)
>   d <- days
>   let r' = r # d
>   guard (prize2 r')
>   return r'

This should be a bit faster, I think, at least in theory. But I was still convinced the "right" solution was instantaneous, and this one definitely wasn't. The problem is, we're still building a list whose length grows exponentially with n (there are 3^30 possible strings of length 30), and we really don't need to. If you look at the new Record2 type, you realize that there are only a few distinct records we ever care about: in a prize-winning record, consecutiveAbsences can only be 0, 1, or 2, and lates can only be 0 or 1, so there are only 3*2=6 possible records we'll ever care about. So instead of keeping a huge list containing many copies of identical records, we could just keep a list of those six records, each paired with a number indicating how many times that record would appear in the list:

> prizes3 :: Int -> [(Record2, Integer)]
> prizes3 0 = [(Record2 0 0, 1)]
> prizes3 numDays = reduce $ do
>   (record, count) <- prizes3 (numDays - 1)
>   day <- days
>   let record' = record # day
>   guard (prize2 record')
>   return (record', count)

If we leave out the "reduce" this will be the same as "prizes2", except that every record will be paired with a "1" which is kind of useless. The "reduce" function takes all the Record2 0 0's and puts them together, then takes all the Record2 0 1's and puts them together, and so on, so that each possible record value ends up paired with its total count. There are at least a couple of ways to do this, but what I did was this:

> reduce :: [(Record2, Integer)] -> [(Record2, Integer)]
> reduce = map f . group . sort where
>   group = groupBy (\(r,_) (s,_) -> r == s)
>   f list@((r,_):_) = (r, sum $ map snd list)

Remember that with the (.) function, it's often easier to read right to left. So the reduce function takes a list of (Record2, Integer) pairs, sorts it, then calls "group" on that sorted list, then maps the function "f" over the result of that. The "group" function groups all the identical records together, returning a list of lists. Then the "f" function reduces each list into a single (Record2, Integer) pair. To get the total number of prize strings after n days, we can't just use "length" anymore; we need to sum the counts from all the pairs:

> prizeCount :: Int -> Integer
> prizeCount = sum . map snd . prizes3
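Since there are a couple of ways to write reduce, here's one alternative as a standalone sketch: let Data.Map (from the containers package that ships with GHC) do the grouping and summing. reduceMap is a hypothetical name; everything else mirrors the post's definitions.

```haskell
import Control.Monad (guard)
import qualified Data.Map as Map

data Day = O | L | A
  deriving (Eq, Show)

days :: [Day]
days = [O, L, A]

data Record2 = Record2 { consecutiveAbsences :: Int, lates :: Int }
  deriving (Eq, Ord, Show)

(#) :: Record2 -> Day -> Record2
r # O = r { consecutiveAbsences = 0 }
r # L = r { consecutiveAbsences = 0, lates = lates r + 1 }
r # A = r { consecutiveAbsences = consecutiveAbsences r + 1 }

prize2 :: Record2 -> Bool
prize2 r = consecutiveAbsences r < 3 && lates r < 2

-- Alternative reduce: fromListWith (+) adds up the counts of equal keys,
-- and toList returns the pairs in ascending key order.
reduceMap :: [(Record2, Integer)] -> [(Record2, Integer)]
reduceMap = Map.toList . Map.fromListWith (+)

prizes3 :: Int -> [(Record2, Integer)]
prizes3 0 = [(Record2 0 0, 1)]
prizes3 numDays = reduceMap $ do
  (record, count) <- prizes3 (numDays - 1)
  day <- days
  let record' = record # day
  guard (prize2 record')
  return (record', count)

prizeCount :: Int -> Integer
prizeCount = sum . map snd . prizes3
```

Because Map.toList returns pairs in ascending key order, the result has the same shape as the sort/groupBy version; the Map approach just avoids the explicit sort and the partial helper f.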

Again, you can check that prizeCount 4 is 43 (really nice of the Project Euler people to give you that sanity check, isn't it?) and then

> answer :: Integer
> answer = prizeCount 30

and it runs instantly! If you didn't know much Haskell before, I hope you learned something from this post, or at least enjoyed kind of half-following along. If you did already know Haskell, maybe you can point out something I did wrong, or a more elegant way to accomplish one of these steps. Either way, leave a comment and let me know what you think!

1 comment:

Noah said...

Cool stuff!

Since there are only six records we care about, and the total number of days doesn't change them (only their frequency), we can look at this as a state machine. Each record is a state, and a single day transitions us between states.

It's easy to enumerate the available transitions for each record (this needs "import Data.Map ( Map, fromList )" at the top of the file):

> transitions :: Map Record2 [Record2]
> transitions = fromList
>   [ (r00, [ r00, r01, r10 ] )
>   , (r01, [ r01, r11 ] )
>   , (r10, [ r00, r01, r20 ] )
>   , (r11, [ r01, r21 ] )
>   , (r20, [ r00, r01 ] )
>   , (r21, [ r01 ] )
>   ]
>   where r00 = Record2 0 0
>         r01 = Record2 0 1
>         r10 = Record2 1 0
>         r11 = Record2 1 1
>         r20 = Record2 2 0
>         r21 = Record2 2 1

We can also reverse this, and enumerate the states that can transition to a given state:

> rtransitions :: Map Record2 [Record2]
> rtransitions = fromList
>   [ (r00, [ r00, r10, r20 ] )
>   , (r01, [ r00, r01, r10, r11, r20, r21 ] )
>   , (r10, [ r00 ] )
>   , (r11, [ r01 ] )
>   , (r20, [ r10 ] )
>   , (r21, [ r11 ] )
>   ]
>   where r00 = Record2 0 0
>         r01 = Record2 0 1
>         r10 = Record2 1 0
>         r11 = Record2 1 1
>         r20 = Record2 2 0
>         r21 = Record2 2 1
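Both tables can also be generated instead of written by hand, which is a cheap way to double-check them. This standalone sketch derives them from the post's (#) and prize2; states is a hypothetical name for the list of six prize-winning records.

```haskell
import qualified Data.Map as Map

data Day = O | L | A
  deriving (Eq, Show)

days :: [Day]
days = [O, L, A]

data Record2 = Record2 { consecutiveAbsences :: Int, lates :: Int }
  deriving (Eq, Ord, Show)

(#) :: Record2 -> Day -> Record2
r # O = r { consecutiveAbsences = 0 }
r # L = r { consecutiveAbsences = 0, lates = lates r + 1 }
r # A = r { consecutiveAbsences = consecutiveAbsences r + 1 }

prize2 :: Record2 -> Bool
prize2 r = consecutiveAbsences r < 3 && lates r < 2

-- The six prize-winning states, in the order r00, r01, r10, r11, r20, r21.
states :: [Record2]
states = [Record2 a l | a <- [0 .. 2], l <- [0 .. 1]]

-- Forward table: where one more day (tried in O, L, A order) can take each state.
transitions :: Map.Map Record2 [Record2]
transitions = Map.fromList
  [ (s, [s' | d <- days, let s' = s # d, prize2 s']) | s <- states ]

-- Reverse table: which states can move into each state.
rtransitions :: Map.Map Record2 [Record2]
rtransitions = Map.fromList
  [ (t, [s | s <- states, t `elem` (transitions Map.! s)]) | t <- states ]
```

Because days are tried in O, L, A order and states runs r00, r01, r10, r11, r20, r21, the derived maps should match the two hand-written tables exactly, list order included.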

What we really need to keep track of isn't individual states, but the frequencies of the various states. So why don't we just do that?

> data Dist = Dist Int Int Int Int Int Int deriving (Eq, Ord, Show)

Now we can model the effect that a single day has on the distribution of states, and use that to calculate the distribution of states after any number of days:

> step :: Dist -> Dist
> step (Dist r00 r01 r10 r11 r20 r21) =
>   Dist (r00+r10+r20) (r00+r01+r10+r11+r20+r21) r00 r01 r10 r11

> prizeDists :: [Dist]
> prizeDists = iterate step $ Dist 1 0 0 0 0 0
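The step function is the rtransitions table turned into arithmetic: each state's new count is the sum of the old counts of the states that lead into it. Here's a standalone sketch that makes that explicit with a Data.Map of counts; stepMap and distAfter are hypothetical names, and the shortcut only works because no state reaches another via two different kinds of day.

```haskell
import qualified Data.Map as Map

data Record2 = Record2 { consecutiveAbsences :: Int, lates :: Int }
  deriving (Eq, Ord, Show)

r00, r01, r10, r11, r20, r21 :: Record2
r00 = Record2 0 0
r01 = Record2 0 1
r10 = Record2 1 0
r11 = Record2 1 1
r20 = Record2 2 0
r21 = Record2 2 1

states :: [Record2]
states = [r00, r01, r10, r11, r20, r21]

-- The reverse table from the comment: which states can move into each state.
rtransitions :: Map.Map Record2 [Record2]
rtransitions = Map.fromList
  [ (r00, [r00, r10, r20])
  , (r01, states)
  , (r10, [r00])
  , (r11, [r01])
  , (r20, [r10])
  , (r21, [r11])
  ]

-- One day: a state's new count is the sum of its predecessors' old counts.
stepMap :: Map.Map Record2 Integer -> Map.Map Record2 Integer
stepMap dist = Map.fromList
  [ (t, sum [Map.findWithDefault 0 s dist | s <- rtransitions Map.! t]) | t <- states ]

-- Counts for every state after n days, starting from one empty record.
distAfter :: Int -> Map.Map Record2 Integer
distAfter n = iterate stepMap (Map.singleton r00 1) !! n
```

distAfter 4 should hold the counts 7, 19, 4, 8, 2, 3 for r00 through r21, matching the fifth element of prizeDists. The hand-written step is leaner; the Map version just makes the connection to rtransitions visible.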

Counting how many prize strings exist then is just a matter of summing the state frequencies:

> count :: Dist -> Int
> count (Dist r00 r01 r10 r11 r20 r21) = r00+r01+r10+r11+r20+r21
> prizeCounts :: [Int]
> prizeCounts = map count prizeDists

As a check, we drop into ghci:

ghci> take 10 prizeDists
[Dist 1 0 0 0 0 0,Dist 1 1 1 0 0 0,Dist 2 3 1 1 1 0,Dist 4 8 2 3 1 1,Dist 7 19 4 8 2 3,Dist 13 43 7 19 4 8,Dist 24 94 13 43 7 19,Dist 44 200 24 94 13 43,Dist 81 418 44 200 24 94,Dist 149 861 81 418 44 200]
ghci> take 10 prizeCounts
[1,3,8,19,43,94,200,418,861,1753]

We could stop here, but then again, we could have stopped any time :). Is there a pattern here we can infer? A possible closed form? prizeCounts mixes all six states together, but maybe there are patterns in the individual parts of the distribution, so let's try focusing on them.

> r00,r01,r10,r11,r20,r21 :: Dist -> Int
> r00 (Dist i _ _ _ _ _) = i
> r01 (Dist _ i _ _ _ _) = i
> r10 (Dist _ _ i _ _ _) = i
> r11 (Dist _ _ _ i _ _) = i
> r20 (Dist _ _ _ _ i _) = i
> r21 (Dist _ _ _ _ _ i) = i

ghci> take 10 $ map r00 prizeDists
[1,1,2,4,7,13,24,44,81,149]

I'll admit, the pattern didn't jump out at me, but after a while I noticed that 1+1+2=4, 1+2+4=7, 2+4+7=13, 4+7+13=24, and so on: each term is the sum of the previous three, r00_{n+3} = r00_{n+2} + r00_{n+1} + r00_{n}. Wikipedia calls these the tribonacci numbers.
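The tribonacci pattern follows from step: r00_{n+1} = r00_{n} + r10_{n} + r20_{n}, and since r10_{n} = r00_{n-1} and r20_{n} = r00_{n-2}, each term is the sum of the three before it. A quick way to check this against the data is to build the sequence lazily with zipWith3 and compare. This standalone sketch repeats the comment's Dist, step, and prizeDists; tribs and r00Of are made-up names.

```haskell
data Dist = Dist Int Int Int Int Int Int
  deriving (Eq, Ord, Show)

-- step and prizeDists, copied unchanged from the comment.
step :: Dist -> Dist
step (Dist r00 r01 r10 r11 r20 r21) =
  Dist (r00 + r10 + r20) (r00 + r01 + r10 + r11 + r20 + r21) r00 r01 r10 r11

prizeDists :: [Dist]
prizeDists = iterate step (Dist 1 0 0 0 0 0)

-- The r00 field of a distribution.
r00Of :: Dist -> Int
r00Of (Dist a _ _ _ _ _) = a

-- Tribonacci-style sequence seeded with 1, 1, 2: every later term is the
-- sum of the three terms before it (the list refers to itself lazily).
tribs :: [Int]
tribs = 1 : 1 : 2 : zipWith3 (\a b c -> a + b + c) tribs (tail tribs) (drop 2 tribs)
```

In ghci, take 30 tribs == take 30 (map r00Of prizeDists) should come out True.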

ghci> take 10 $ map r01 prizeDists
[0,1,3,8,19,43,94,200,418,861]

The best I could come up with for these is r01_{n+3} = r00_{n+3} + r01_{n+2} + r01_{n+1} + r01_{n}.

r10 and r20 are just offsets of r00 (r10_{n+1} = r00_{n} and r20_{n+2} = r00_{n}); likewise, r11 and r21 are just offsets of r01 (r11_{n+1} = r01_{n} and r21_{n+2} = r01_{n}).
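These relationships can be checked numerically the same way. This standalone sketch repeats the comment's Dist, step, and prizeDists, then tests the recurrence r01_{n+3} = r00_{n+3} + r01_{n+2} + r01_{n+1} + r01_{n} and the offsets (r10 and r20 lagging r00 by one and two days, r11 and r21 lagging r01 the same way) for the first k values of n. column, r01RecurrenceHolds, and offsetsHold are hypothetical names.

```haskell
data Dist = Dist Int Int Int Int Int Int
  deriving (Eq, Ord, Show)

step :: Dist -> Dist
step (Dist r00 r01 r10 r11 r20 r21) =
  Dist (r00 + r10 + r20) (r00 + r01 + r10 + r11 + r20 + r21) r00 r01 r10 r11

prizeDists :: [Dist]
prizeDists = iterate step (Dist 1 0 0 0 0 0)

-- Column i (0 = r00, 1 = r01, ..., 5 = r21) of prizeDists, as an infinite list.
column :: Int -> [Int]
column i = map (\(Dist a b c d e f) -> [a, b, c, d, e, f] !! i) prizeDists

-- Does r01_{n+3} = r00_{n+3} + r01_{n+2} + r01_{n+1} + r01_n hold for n < k?
r01RecurrenceHolds :: Int -> Bool
r01RecurrenceHolds k =
  and [ r01 !! (n + 3) == r00 !! (n + 3) + r01 !! (n + 2) + r01 !! (n + 1) + r01 !! n
      | n <- [0 .. k - 1] ]
  where
    r00 = column 0
    r01 = column 1

-- Do r10/r20 lag r00 by one/two days, and r11/r21 lag r01 likewise, for n < k?
offsetsHold :: Int -> Bool
offsetsHold k =
  and [ column 2 !! (n + 1) == column 0 !! n
        && column 4 !! (n + 2) == column 0 !! n
        && column 3 !! (n + 1) == column 1 !! n
        && column 5 !! (n + 2) == column 1 !! n
      | n <- [0 .. k - 1] ]
```

Both r01RecurrenceHolds 25 and offsetsHold 25 should come out True. The r01 recurrence comes from step: r01_{n+1} is the sum of all six fields at day n, because every prize string can move into r01 in exactly one way (an L if it has never been late, an O if it has). So the r01 column is just prizeCounts shifted by one day.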

The next step would be to figure out the closed form, but since I'm not a professor, I'll just say I don't know one, rather than leaving it as an exercise to the reader.