module Main where

import Data.Char (showLitChar)
import Data.Function (on)
import Data.List (intersect, partition, sortBy, tails, uncons)
import Data.Set (Set)
import Data.Set qualified as S
import System.Environment (getArgs)

-- | Read the puzzle input from the file named by the first command-line
-- argument and print the answers to both parts.
--
-- Aborts with an error if no argument is given.
main :: IO ()
main = do
  args <- getArgs
  input <- case uncons args of
    Just (path, _) -> readFile path
    Nothing -> error "Missing argument to input file"
  putStrLn $ unwords ["Part 1:", show $ part1 1000 input]
  putStrLn $ unwords ["Part 2:", show $ part2 input]

-- | A position in 3D space, stored as @(x, y, z)@.
newtype Pos = Pos {unPos :: (Int, Int, Int)}
  deriving (Eq, Ord)

-- | Renders a position as @(x,y,z)@, always wrapped in parentheses.
instance Show Pos where
  showsPrec _ (Pos (x, y, z)) =
    showParen True $
      shows x . showLitChar ',' . shows y . showLitChar ',' . shows z

-- | Parses the bare input format @x,y,z@.
--
-- NOTE(review): this is not the inverse of 'show', which adds surrounding
-- parentheses that this parser does not consume.
instance Read Pos where
  readsPrec _ s = do
    (x, ',' : rest) <- reads s
    (y, ',' : rest') <- reads rest
    (z, rest'') <- reads rest'
    pure (Pos (x, y, z), rest'')

-- | /Squared/ Euclidean distance between two positions.
--
-- No square root is taken; the value is only used for ordering pairs, and
-- squaring preserves that order.
dist :: Pos -> Pos -> Double
dist (Pos (x1, y1, z1)) (Pos (x2, y2, z2)) =
  sum . map ((^ (2 :: Integer)) . fromIntegral) $ [x2 - x1, y2 - y1, z2 - z1]

-- | Every unordered pair of distinct positions together with its (squared)
-- distance, in ascending order of @(distance, pair)@.
--
-- Each pair is normalised to @(larger, smaller)@ and collected in a 'Set',
-- so a pair appears only once even if the input repeats a position.
distances :: [Pos] -> [(Double, (Pos, Pos))]
distances bs =
  S.toAscList . S.fromList $
    [ (dist hi lo, (hi, lo))
    | (a : rest) <- tails bs -- pair each box only with the boxes after it
    , b <- rest
    , a /= b
    , let hi = max a b
    , let lo = min a b
    ]

-- | Idea: Combine all positions to a cartesian product and
-- calculate the distances between them. Pick the n shortest distances
-- and connect them.
--
-- The result is the product of the sizes of the (up to) three largest
-- circuits. Boxes that were never connected are not stored as circuits.
part1 :: Int -> String -> Int
part1 n =
  go []
    . take n
    . sortBy (compare `on` fst)
    . distances
    . map (read @Pos)
    . lines
  where
    -- Each circuit carries its boxes and the accumulated (rounded, squared)
    -- cable length; the length is bookkeeping only and does not influence
    -- the returned value.
    go :: [(Set Pos, Int)] -> [(Double, (Pos, Pos))] -> Int
    go circuits [] =
      product . take 3 . sortBy (flip compare) $ map (S.size . fst) circuits
    go circuits ((d, (a, b)) : rest) = go (connect circuits a b (round d)) rest

    -- Check if the fuse boxes are already part of a circuit.
    --
    -- There are a few possible scenarios:
    --   * both fuse boxes are part of the same circuit => nothing happens.
    --   * both fuse boxes are part of different circuits => merge circuits.
    --   * one fuse box is part of a circuit, the other is not => add the
    --     unconnected fuse box to the circuit.
    --   * none of the fuse boxes are part of any circuits => create a new
    --     circuit with the new boxes.
    connect :: [(Set Pos, Int)] -> Pos -> Pos -> Int -> [(Set Pos, Int)]
    connect circuits a b d =
      case ( partition (\p -> S.member a (fst p)) circuits
           , partition (\p -> S.member b (fst p)) circuits
           ) of
        -- a & b are both part of a circuit
        (([(c, l)], rest), ([(c', l')], rest'))
          | c == c' -> circuits
          | otherwise -> (S.union c c', l + l' + d) : intersect rest rest'
        -- no fuse box is part of any circuit
        (([], _), ([], _)) -> (S.fromList [a, b], d) : circuits
        -- a is part of a circuit
        (([(c, l)], rest), ([], _)) -> (S.insert b c, l + d) : rest
        -- b is part of a circuit
        (([], _), ([(c, l)], rest)) -> (S.insert a c, l + d) : rest
        -- a box found in more than one circuit: circuits are kept disjoint,
        -- so this is not expected to happen
        _ -> error "Not implemented"

-- | Duplicate a value into both components of a pair.
dupl :: a -> (a, a)
dupl a = (a, a)

-- | Apply one function to each component of a pair.
(***) :: (a -> b) -> (c -> d) -> (a, c) -> (b, d)
(***) f1 f2 (a, b) = (f1 a, f2 b)

-- | Idea: Combine all positions to a cartesian product and calculate
-- the distances between them. Sort them in increasing distances and
-- connect them until there are no more separate circuits. Keep track
-- of unconnected fuse boxes as well.
--
-- The result is the product of the x coordinates of the pair whose
-- connection first leaves every box in one single circuit.
--
-- Raises an error if the pairs run out first, which happens when the input
-- holds fewer than two distinct positions.
part2 :: String -> Int
part2 =
  uncurry (go [])
    . (***) (sortBy (compare `on` fst) . distances) S.fromList
    . dupl
    . map (read @Pos)
    . lines
  where
    go :: [Set Pos] -> [(Double, (Pos, Pos))] -> Set Pos -> Int
    go _ [] _ = error "unreachable"
    go circuits ((_, (a, b)) : pairs) unconnected =
      case connect (circuits, unconnected) a b of
        -- Done only when exactly one circuit is left AND no box is still
        -- unconnected. An empty 'unconnected' set alone is not enough:
        -- the boxes may still be split over several disjoint circuits.
        ([_], unconnected')
          | S.null unconnected' -> x a * x b
        (circuits', unconnected') -> go circuits' pairs unconnected'

    x :: Pos -> Int
    x (Pos (x', _, _)) = x'

    -- Check if the fuse boxes are already part of a circuit.
    --
    -- There are a few possible scenarios:
    --   * both fuse boxes are part of the same circuit => nothing happens.
    --   * both fuse boxes are part of different circuits => merge circuits.
    --   * one fuse box is part of a circuit, the other is not => add the
    --     unconnected fuse box to the circuit.
    --   * none of the fuse boxes are part of any circuits => create a new
    --     circuit with the new boxes.
    connect :: ([Set Pos], Set Pos) -> Pos -> Pos -> ([Set Pos], Set Pos)
    connect (circuits, unconnected) a b =
      case (partition (S.member a) circuits, partition (S.member b) circuits) of
        -- a & b are both part of a circuit
        (([c], rest), ([c'], rest'))
          | c == c' -> (circuits, unconnected)
          | otherwise -> (S.union c c' : intersect rest rest', unconnected)
        -- no fuse box is part of any circuit
        (([], _), ([], _)) ->
          (S.fromList [a, b] : circuits, S.delete b $ S.delete a unconnected)
        -- a is part of a circuit
        (([c], rest), ([], _)) -> (S.insert b c : rest, S.delete b unconnected)
        -- b is part of a circuit
        (([], _), ([c], rest)) -> (S.insert a c : rest, S.delete a unconnected)
        -- a box found in more than one circuit: circuits are kept disjoint,
        -- so this is not expected to happen
        _ -> error "Not implemented"

-- | Small sample input in the same @x,y,z@-per-line format as the real
-- puzzle input.
testInput :: String
testInput =
  unlines
    [ "162,817,812"
    , "57,618,57"
    , "906,360,560"
    , "592,479,940"
    , "352,342,300"
    , "466,668,158"
    , "542,29,236"
    , "431,825,988"
    , "739,650,466"
    , "52,470,668"
    , "216,146,977"
    , "819,987,18"
    , "117,168,530"
    , "805,96,715"
    , "346,949,466"
    , "970,615,88"
    , "941,993,340"
    , "862,61,35"
    , "984,92,344"
    , "425,690,689"
    ]