{-# LANGUAGE TypeApplications #-}

module Main where

import Data.Char (isSpace, readLitChar, showLitChar)
import Data.Function (on)
import Data.Ix (inRange)
import Data.List (foldl', nub, partition, sort, sortBy, tails)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace
import System.Environment (getArgs, getProgName)

-- | Read the puzzle input from the file named by the first command-line
-- argument (extra arguments are ignored) and print both answers.
main :: IO ()
main = do
  args <- getArgs
  path <- case args of
    (p : _) -> pure p
    [] -> do
      prog <- getProgName
      fail ("usage: " ++ prog ++ " INPUT_FILE")
  input <- readFile path
  putStrLn $ unwords ["Part 1:", show $ part1 1000 input]
  -- NOTE: 'part2' is still a stub that calls 'error', so this line
  -- currently aborts the program after Part 1 has been printed.
  putStrLn $ unwords ["Part 2:", show $ part2 input]

-- | A junction box position given as integer @(x, y, z)@ coordinates.
newtype Pos = Pos {unPos :: (Int, Int, Int)}
  deriving (Eq, Ord)

-- | Renders a position as @(x,y,z)@.
instance Show Pos where
  showsPrec _ (Pos (x, y, z)) =
    showParen True $
      shows x . showChar ',' . shows y . showChar ',' . shows z

-- | Parses the input line format @x,y,z@. The triple may optionally be
-- wrapped in parentheses, so that @read . show@ round-trips.
instance Read Pos where
  readsPrec _ = readParen False $ \s -> do
    (x, ',' : rest) <- reads s
    (y, ',' : rest') <- reads rest
    (z, rest'') <- reads rest'
    pure (Pos (x, y, z), rest'')

-- | /Squared/ Euclidean distance between two positions.
--
-- The square root is deliberately omitted: within this program distances
-- are only compared (sorted), and squaring preserves their order.
dist :: Pos -> Pos -> Double
dist (Pos (x1, y1, z1)) (Pos (x2, y2, z2)) =
  sum . map ((^ (2 :: Int)) . fromIntegral) $ [x2 - x1, y2 - y1, z2 - z1]

-- | Connect the @n@ closest pairs of junction boxes, then return the product
-- of the sizes of the three largest resulting circuits.
--
-- Every one of the @n@ pairs counts, including a pair whose two boxes are
-- already in the same circuit; connecting such a pair changes nothing.
--
-- Boxes that were never connected are not tracked. They would be circuits
-- of size 1, and a factor of 1 does not change the product.
--
-- All @k * (k - 1) / 2@ pairs are generated and sorted, which is
-- O(p log p) in the number of pairs p. The pairs are already distinct
-- (assuming distinct input positions). Do not add a 'nub' here: it is
-- quadratic in p and is what previously made this approach grind to a halt.
part1 :: Int -> String -> Int
part1 n =
  productOfLargest
    . foldl' (\circuits (_, (a, b)) -> connect circuits a b) []
    . take n
    . sortBy (compare `on` fst)
    . pairs
    . map (read @Pos)
    . filter (not . all isSpace) -- tolerate blank lines (e.g. a trailing one)
    . lines
  where
    -- Product of the sizes of the (up to) three largest circuits.
    productOfLargest :: [Set Pos] -> Int
    productOfLargest = product . take 3 . sortBy invCompare . map S.size

    -- Each element paired with every element after it, tagged with distance.
    pairs :: [Pos] -> [(Double, (Pos, Pos))]
    pairs ps = [(dist a b, (a, b)) | (a : rest) <- tails ps, b <- rest]

    -- Connect a and b. Every circuit containing either box, plus the two
    -- boxes themselves, is merged into one circuit. If neither box is in
    -- any circuit yet, this creates the new circuit {a, b}.
    connect :: [Set Pos] -> Pos -> Pos -> [Set Pos]
    connect circuits a b = S.unions (S.fromList [a, b] : touched) : untouched
      where
        (touched, untouched) =
          partition (\c -> S.member a c || S.member b c) circuits

-- | 'compare' with its arguments swapped, for sorting in descending order.
invCompare :: (Ord a) => a -> a -> Ordering
invCompare a b = b `compare` a

-- | Part 2: not implemented yet.
part2 :: String -> Int
part2 = error "Not implemented"

-- | Small sample input for trying the solvers in GHCi,
-- e.g. @part1 10 testInput@.
testInput :: String
testInput =
  unlines
    [ "162,817,812",
      "57,618,57",
      "906,360,560",
      "592,479,940",
      "352,342,300",
      "466,668,158",
      "542,29,236",
      "431,825,988",
      "739,650,466",
      "52,470,668",
      "216,146,977",
      "819,987,18",
      "117,168,530",
      "805,96,715",
      "346,949,466",
      "970,615,88",
      "941,993,340",
      "862,61,35",
      "984,92,344",
      "425,690,689"
    ]