summaryrefslogtreecommitdiff
path: root/src/Data/Graph/Dynamic/Internal/Avl.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Graph/Dynamic/Internal/Avl.hs')
-rw-r--r--src/Data/Graph/Dynamic/Internal/Avl.hs483
1 files changed, 483 insertions, 0 deletions
diff --git a/src/Data/Graph/Dynamic/Internal/Avl.hs b/src/Data/Graph/Dynamic/Internal/Avl.hs
new file mode 100644
index 0000000..4f02445
--- /dev/null
+++ b/src/Data/Graph/Dynamic/Internal/Avl.hs
@@ -0,0 +1,483 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE TypeFamilies #-}
+module Data.Graph.Dynamic.Internal.Avl
+ ( Tree
+
+ , singleton
+ , append
+ , concat
+ , join
+ , split
+ , root
+ , connected
+ , label
+ , aggregate
+ , toList
+
+ -- * Debugging only
+ , freeze
+ , print
+ , assertInvariants
+ , assertSingleton
+ , assertRoot
+ ) where
+
+import Control.Monad (foldM, when)
+import Control.Monad.Primitive (PrimMonad (..))
+import qualified Data.Graph.Dynamic.Internal.Tree as Class
+import Data.List.NonEmpty (NonEmpty)
+import qualified Data.List.NonEmpty as NonEmpty
+import Data.Monoid ((<>))
+import Data.Primitive.MutVar (MutVar)
+import qualified Data.Primitive.MutVar as MutVar
+import qualified Data.Tree as Tree
+import Prelude hiding (concat, print)
+
+data Tree s a v = Tree
+ { tParent :: {-# UNPACK #-} !(MutVar s (Tree s a v))
+ , tLeft :: {-# UNPACK #-} !(MutVar s (Tree s a v))
+ , tRight :: {-# UNPACK #-} !(MutVar s (Tree s a v))
+ , tAggs :: {-# UNPACK #-} !(MutVar s (Aggs v))
+ , tLabel :: !a
+ , tValue :: !v
+ }
+
+instance Eq (Tree s a v) where
+ -- Reference equality through a MutVar.
+ t1 == t2 = tParent t1 == tParent t2
+
+data Aggs v = Aggs
+ { aHeight :: {-# UNPACK #-} !Int
+ , aAggregate :: !v
+ } deriving (Eq, Show)
+
+emptyAggs :: Monoid v => Aggs v
+emptyAggs = Aggs 0 mempty
+
+singletonAggs :: v -> Aggs v
+singletonAggs = Aggs 1
+
+joinAggs :: Monoid v => Aggs v -> v -> Aggs v -> Aggs v
+joinAggs (Aggs lh la) a (Aggs rh ra) =
+ Aggs (max lh rh + 1) (la <> a <> ra)
+
+singleton :: PrimMonad m => a -> v -> m (Tree (PrimState m) a v)
+singleton tLabel tValue = do
+ tParent <- MutVar.newMutVar undefined
+ tLeft <- MutVar.newMutVar undefined
+ tRight <- MutVar.newMutVar undefined
+ tAggs <- MutVar.newMutVar $ singletonAggs tValue
+ let tree = Tree {..}
+ MutVar.writeMutVar tParent tree
+ MutVar.writeMutVar tLeft tree
+ MutVar.writeMutVar tRight tree
+ return tree
+
+root :: PrimMonad m => Tree (PrimState m) a v -> m (Tree (PrimState m) a v)
+root tree@Tree {..} = do
+ parent <- MutVar.readMutVar tParent
+ if parent == tree then return tree else root parent
+
+concat
+ :: (PrimMonad m, Monoid v)
+ => NonEmpty (Tree (PrimState m) a v)
+ -> m (Tree (PrimState m) a v)
+concat (x0 NonEmpty.:| xs0) =
+ foldM append x0 xs0
+
+split
+ :: (PrimMonad m, Monoid v)
+ => Tree (PrimState m) a v
+ -> m (Maybe (Tree (PrimState m) a v), Maybe (Tree (PrimState m) a v))
+split x0 = do
+ (mbL, mbR, p, left) <- cut x0
+ if p == x0 then
+ return (mbL, mbR)
+ else do
+ upwards mbL mbR p left
+ where
+ upwards lacc0 racc0 x left0 = do
+ (mbL, mbR, p, left1) <- cut x
+ if left0 then do
+ racc1 <- join racc0 x mbR
+ if p == x then
+ return (lacc0, Just racc1)
+ else
+ upwards lacc0 (Just racc1) p left1
+ else do
+ lacc1 <- join mbL x lacc0
+ if p == x then
+ return (Just lacc1, racc0)
+ else
+ upwards (Just lacc1) racc0 p left1
+
+ cut x = do
+ p <- MutVar.readMutVar (tParent x)
+ pl <- MutVar.readMutVar (tLeft p)
+ l <- MutVar.readMutVar (tLeft x)
+ r <- MutVar.readMutVar (tRight x)
+ when (l /= x) $ removeParent l
+ when (r /= x) $ removeParent r
+ removeParent x
+ removeLeft x
+ removeRight x
+ updateAggs x
+ if pl == x then removeLeft p else removeRight p
+ return
+ ( if l == x then Nothing else Just l
+ , if r == x then Nothing else Just r
+ , p
+ , pl == x
+ )
+
+append
+ :: (PrimMonad m, Monoid v)
+ => Tree (PrimState m) a v
+ -> Tree (PrimState m) a v
+ -> m (Tree (PrimState m) a v)
+append l0 r0 = do
+ -- NOTE: there is a faster way to do this by just following the right spine
+ -- and joining along the way.
+ rm <- getRightMost l0
+ (mbL, mbR) <- split rm
+ case mbR of
+ Just _ -> error "append: invalid state"
+ _ -> assertSingleton rm
+ join mbL rm (Just r0)
+ where
+ getRightMost x = do
+ r <- MutVar.readMutVar (tRight x)
+ if r == x then return x else getRightMost r
+
+connected
+ :: (PrimMonad m, Monoid v)
+ => Tree (PrimState m) a v
+ -> Tree (PrimState m) a v
+ -> m Bool
+connected x y = do
+ xr <- root x
+ yr <- root y
+ return $ xr == yr
+
+label :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m a
+label = return . tLabel
+
+aggregate :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m v
+aggregate = fmap aAggregate . MutVar.readMutVar . tAggs
+
+-- | For debugging/testing.
+toList
+ :: PrimMonad m => Tree (PrimState m) a v -> m [a]
+toList = go []
+ where
+ go acc0 tree@Tree {..} = do
+ left <- MutVar.readMutVar tLeft
+ right <- MutVar.readMutVar tRight
+ acc1 <- if right == tree then return acc0 else go acc0 right
+ let acc2 = tLabel : acc1
+ if left == tree then return acc2 else go acc2 left
+
+join
+ :: (PrimMonad m, Monoid v)
+ => Maybe (Tree (PrimState m) a v)
+ -> Tree (PrimState m) a v -- Must be a singleton
+ -> Maybe (Tree (PrimState m) a v)
+ -> m (Tree (PrimState m) a v)
+join mbL c mbR = do
+ lh <- maybe (return 0) (fmap aHeight . MutVar.readMutVar . tAggs) mbL
+ rh <- maybe (return 0) (fmap aHeight . MutVar.readMutVar . tAggs) mbR
+ if | lh > rh + 1, Just l <- mbL ->
+ joinRight l c mbR
+ | rh > lh + 1, Just r <- mbR ->
+ joinLeft mbL c r
+ | otherwise -> do
+ case mbL of Just l -> setLeft c l; _ -> return ()
+ case mbR of Just r -> setRight c r; _ -> return ()
+ updateAggs c
+ return c
+
+joinLeft
+ :: (PrimMonad m, Monoid v)
+ => Maybe (Tree (PrimState m) a v)
+ -> Tree (PrimState m) a v -- Must be a singleton
+ -> Tree (PrimState m) a v
+ -> m (Tree (PrimState m) a v)
+joinLeft mbL c r = do
+ rl <- MutVar.readMutVar (tLeft r)
+ rla <- leftAggs r rl
+
+ rr <- MutVar.readMutVar (tRight r)
+ rra <- rightAggs r rr
+
+ la <- maybe (return emptyAggs) (MutVar.readMutVar . tAggs) mbL
+
+ if aHeight rla <= aHeight la + 1 then do
+ setLeft r c
+ when (rl /= r) $ setRight c rl
+ case mbL of Just l -> setLeft c l; _ -> return ()
+
+ let !ca = joinAggs rla (tValue c) la
+
+ -- Invalidity in the parent is fixed with two rotations
+ if aHeight rra + 1 < aHeight ca then do
+ rotateLeft c rl
+ rotateRight r rl
+
+ updateAggs c
+ updateAggs r
+ updateAggsToRoot rl
+ else do
+ -- One rotation
+ updateAggs c
+ updateAggs r
+ upLeft r
+ else
+ joinLeft mbL c rl
+
+upLeft
+ :: (PrimMonad m, Monoid v)
+ => Tree (PrimState m) a v
+ -> m (Tree (PrimState m) a v)
+upLeft l = do
+ p <- MutVar.readMutVar (tParent l)
+ if p == l then
+ return l
+ else do
+ r <- MutVar.readMutVar (tRight p)
+ ra <- rightAggs p r
+ la <- leftAggs p l
+ if aHeight ra + 1 < aHeight la then do
+ rotateRight p l
+ updateAggs p
+ updateAggsToRoot l
+ else do
+ updateAggs p -- Stuff below us might have changed.
+ upLeft p
+
+joinRight
+ :: (PrimMonad m, Monoid v)
+ => Tree (PrimState m) a v
+ -> Tree (PrimState m) a v -- Must be a singleton
+ -> Maybe (Tree (PrimState m) a v)
+ -> m (Tree (PrimState m) a v)
+joinRight l c mbR = do
+ lr <- MutVar.readMutVar (tRight l)
+ lra <- rightAggs l lr
+
+ ll <- MutVar.readMutVar (tLeft l)
+ lla <- leftAggs l ll
+
+ ra <- maybe (return emptyAggs) (MutVar.readMutVar . tAggs) mbR
+
+ if aHeight lra <= aHeight ra + 1 then do
+ setRight l c
+ when (lr /= l) $ setLeft c lr
+ case mbR of Just r -> setRight c r; _ -> return ()
+
+ let !ca = joinAggs lra (tValue c) ra
+
+ -- Invalidity in the parent is fixed with two rotations
+ if aHeight lla + 1 < aHeight ca then do
+ rotateRight c lr
+ rotateLeft l lr
+
+ -- Many of these are already computed...
+ updateAggs l
+ updateAggs c
+ updateAggsToRoot lr
+ else do
+ -- One rotation
+ updateAggs c
+ updateAggs l
+ upRight l
+ else
+ joinRight lr c mbR
+
+upRight
+ :: (PrimMonad m, Monoid v)
+ => Tree (PrimState m) a v
+ -> m (Tree (PrimState m) a v)
+upRight r = do
+ p <- MutVar.readMutVar (tParent r)
+ if p == r then
+ return p
+ else do
+ l <- MutVar.readMutVar (tLeft p)
+ la <- leftAggs p l
+ ra <- rightAggs p r
+ if aHeight la + 1 < aHeight ra then do
+ rotateLeft p r
+ updateAggs p
+ updateAggsToRoot r
+ else do
+ updateAggs p -- Stuff below us might have changed.
+ upRight p
+
+rotateLeft, rotateRight
+ :: PrimMonad m
+ => Tree (PrimState m) a v -- X's parent
+ -> Tree (PrimState m) a v -- X
+ -> m ()
+rotateLeft p x = do
+ b <- MutVar.readMutVar (tLeft x)
+ if b == x then removeRight p else setRight p b
+ gp <- MutVar.readMutVar (tParent p)
+ if gp == p then removeParent x else replace gp p x
+ setLeft x p
+rotateRight p x = do
+ b <- MutVar.readMutVar (tRight x)
+ if b == x then removeLeft p else setLeft p b
+ gp <- MutVar.readMutVar (tParent p)
+ if gp == p then removeParent x else replace gp p x
+ setRight x p
+
+setLeft, setRight
+ :: PrimMonad m
+ => Tree (PrimState m) a v -- Parent
+ -> Tree (PrimState m) a v -- New child
+ -> m ()
+setLeft p x = do
+ MutVar.writeMutVar (tParent x) p
+ MutVar.writeMutVar (tLeft p) x
+setRight p x = do
+ MutVar.writeMutVar (tParent x) p
+ MutVar.writeMutVar (tRight p) x
+
+removeParent, removeLeft, removeRight
+ :: PrimMonad m
+ => Tree (PrimState m) a v -- Parent
+ -> m ()
+removeParent x = MutVar.writeMutVar (tParent x) x
+removeLeft x = MutVar.writeMutVar (tLeft x) x
+removeRight x = MutVar.writeMutVar (tRight x) x
+
+leftAggs, rightAggs
+ :: (PrimMonad m, Monoid v)
+ => Tree (PrimState m) a v -- Parent
+ -> Tree (PrimState m) a v -- Left or right child
+ -> m (Aggs v)
+leftAggs p l =
+ if p == l then return emptyAggs else MutVar.readMutVar (tAggs l)
+rightAggs p r =
+ if p == r then return emptyAggs else MutVar.readMutVar (tAggs r)
+
+-- | Replace X by Y in the tree. X must have a parent.
+replace
+ :: PrimMonad m
+ => Tree (PrimState m) a v -- ^ X's parent
+ -> Tree (PrimState m) a v -- ^ X
+ -> Tree (PrimState m) a v -- ^ Y
+ -> m ()
+replace p x y = do
+ pl <- MutVar.readMutVar (tLeft p)
+ MutVar.writeMutVar (tParent y) p
+ if pl == x
+ then MutVar.writeMutVar (tLeft p) y
+ else MutVar.writeMutVar (tRight p) y
+
+-- | Recompute the aggregate and height of a node.
+updateAggs
+ :: (Monoid v, PrimMonad m)
+ => Tree (PrimState m) a v
+ -> m ()
+updateAggs t = do
+ l <- MutVar.readMutVar (tLeft t)
+ r <- MutVar.readMutVar (tRight t)
+ la <- leftAggs t l
+ ra <- rightAggs t r
+ let !agg = joinAggs la (tValue t) ra
+ MutVar.writeMutVar (tAggs t) agg
+
+-- | Recompute aggregate and height all the way to the root of the tree.
+updateAggsToRoot
+ :: (PrimMonad m, Monoid v)
+ => Tree (PrimState m) a v
+ -> m (Tree (PrimState m) a v)
+updateAggsToRoot x = do
+ updateAggs x
+ p <- MutVar.readMutVar (tParent x)
+ if p == x then return x else updateAggsToRoot p
+
+-- | For debugging/testing.
+freeze :: PrimMonad m => Tree (PrimState m) a v -> m (Tree.Tree a)
+freeze tree@Tree {..} = do
+ left <- MutVar.readMutVar tLeft
+ right <- MutVar.readMutVar tRight
+ children <- sequence $
+ [freeze left | left /= tree] ++
+ [freeze right | right /= tree]
+ return $ Tree.Node tLabel children
+
+print :: Show a => Tree (PrimState IO) a v -> IO ()
+print = go 0
+ where
+ go d t@Tree {..} = do
+ left <- MutVar.readMutVar tLeft
+ when (left /= t) $ go (d + 1) left
+
+ putStrLn $ replicate d ' ' ++ show tLabel
+
+ right <- MutVar.readMutVar tRight
+ when (right /= t) $ go (d + 1) right
+
+assertInvariants
+ :: (PrimMonad m, Monoid v, Eq v, Show v) => Tree (PrimState m) a v -> m ()
+assertInvariants t = do
+ _ <- computeAggs t t
+ return ()
+ where
+ -- TODO: Check average
+ computeAggs p x = do
+ p' <- MutVar.readMutVar (tParent x)
+ when (p /= p') $ fail "broken parent pointer"
+
+ l <- MutVar.readMutVar (tLeft x)
+ r <- MutVar.readMutVar (tRight x)
+ la <- if l == x then return emptyAggs else computeAggs x l
+ ra <- if r == x then return emptyAggs else computeAggs x r
+
+ let actualAggs = joinAggs la (tValue x) ra
+ storedAggs <- MutVar.readMutVar (tAggs x)
+
+ when (actualAggs /= storedAggs) $ fail $
+ "error in stored aggregates: " ++ show storedAggs ++
+ ", actual: " ++ show actualAggs
+
+ when (abs (aHeight la - aHeight ra) > 1) $ fail "inbalanced"
+ return actualAggs
+
+assertSingleton :: PrimMonad m => Tree (PrimState m) a v -> m ()
+assertSingleton x = do
+ l <- MutVar.readMutVar (tLeft x)
+ r <- MutVar.readMutVar (tRight x)
+ p <- MutVar.readMutVar (tParent x)
+ when (l /= x || r /= x || p /= x) $ fail "not a singleton"
+
+assertRoot :: PrimMonad m => Tree (PrimState m) a v -> m ()
+assertRoot x = do
+ p <- MutVar.readMutVar (tParent x)
+ when (p /= x) $ fail "not the root"
+
+data TreeGen s = TreeGen
+
+instance Class.Tree Tree where
+ type TreeGen Tree = TreeGen
+ newTreeGen _ = return TreeGen
+
+ singleton _ = singleton
+ append = append
+ split = split
+ connected = connected
+ root = root
+ label = label
+ aggregate = aggregate
+ toList = toList
+
+instance Class.TestTree Tree where
+ print = print
+ assertInvariants = assertInvariants
+ assertSingleton = assertSingleton
+ assertRoot = assertRoot