{-# LANGUAGE MagicHash #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- Experiments with GHC weak pointers and finalizers, and with a graph of
-- nodes whose links are held partly strongly and partly weakly. Assertions
-- are written as refutable pattern binds in IO, so a mismatch fails the
-- program via MonadFail.

import System.Mem
import System.Mem.Weak
import Data.IORef
import GHC.IORef
import GHC.STRef
import GHC.IO
import GHC.Weak
import GHC.Prim
import Control.Monad.Primitive
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.Set (Set)
import qualified Data.Set as Set
import System.IO.Unsafe
import Control.Monad
import Control.Concurrent
import Data.Foldable

main :: IO ()
main = do
  testDualWeak
  testWeakChain
  testGraphChain
  testGraphX
  testGraphO

-- | Demonstrate that a Weak's key keeps its value alive, even if the Weak
-- itself is dead: @b@ is only referenced as the value of a discarded Weak
-- keyed on @a@, yet because @a@ is kept alive, @b@ (and therefore @w@) is too.
testWeakChain :: IO ()
testWeakChain = do
  a <- newIORef ()
  b <- newIORef ()
  _ <- mkWeakWithIORefKey a b
  w <- mkWeakWithIORefKey b "b is still alive"
  performGCUntilFinalizersQuiesce
  Just "b is still alive" <- deRefWeak w
  touch a

-- | Linear chain a -> b -> c -> d -> e (see 'buildGraphChain').
testGraphChain :: IO ()
testGraphChain = do
  -- Both ends held: the whole chain is reachable in both directions.
  do
    (nai, neo) <- buildGraphChain
    performGCUntilFinalizersQuiesce
    ["a", "b", "c", "d", "e"] <- toList <$> findForwardVals nai
    ["a", "b", "c", "d", "e"] <- toList <$> findBackwardVals neo
    touch (nai, neo)
  -- Only the input of "a" held: nothing is expected to remain reachable.
  do
    (nai, _) <- buildGraphChain
    performGCUntilFinalizersQuiesce
    [] <- toList <$> findForwardVals nai
    touch nai
  -- Only the output of "e" held: nothing is expected to remain reachable.
  do
    (_, neo) <- buildGraphChain
    performGCUntilFinalizersQuiesce
    [] <- toList <$> findBackwardVals neo
    touch neo

-- | Build the chain a -> b -> c -> d -> e and return the input handle of the
-- first node and the output handle of the last.
buildGraphChain :: IO (NodeInput String, NodeOutput String)
buildGraphChain = do
  (nai, nao) <- newNode "a"
  (nbi, nbo) <- newNode "b"
  (nci, nco) <- newNode "c"
  (ndi, ndo) <- newNode "d"
  (nei, neo) <- newNode "e"
  link nao nbi
  link nbo nci
  link nco ndi
  link ndo nei
  pure (nai, neo)

-- | Build an X-shaped graph: a and b both feed c, and c feeds both d and e.
-- Returns the inputs of a and b and the outputs of d and e.
buildGraphX
  :: IO (NodeInput String, NodeInput String, NodeOutput String, NodeOutput String)
buildGraphX = do
  (nai, nao) <- newNode "a"
  (nbi, nbo) <- newNode "b"
  (nci, nco) <- newNode "c"
  (ndi, ndo) <- newNode "d"
  (nei, neo) <- newNode "e"
  link nao nci
  link nbo nci
  link nco ndi
  link nco nei
  pure (nai, nbi, ndo, neo)

-- | X-shaped graph (see 'buildGraphX').
testGraphX :: IO ()
testGraphX = do
  -- All four ends held.
  do
    (nai, nbi, ndo, neo) <- buildGraphX
    performGCUntilFinalizersQuiesce
    ["a", "c", "d", "e"] <- toList <$> findForwardVals nai
    ["b", "c", "d", "e"] <- toList <$> findForwardVals nbi
    ["a", "b", "c", "d"] <- toList <$> findBackwardVals ndo
    ["a", "b", "c", "e"] <- toList <$> findBackwardVals neo
    touch (nai, nbi, ndo, neo)
  -- Output of "e" dropped: "e" is expected to disappear.
  do
    (nai, nbi, ndo, _) <- buildGraphX
    performGCUntilFinalizersQuiesce
    ["a", "c", "d"] <- toList <$> findForwardVals nai
    ["b", "c", "d"] <- toList <$> findForwardVals nbi
    touch (nai, nbi, ndo)
  -- Both outputs dropped: nothing is expected to remain reachable.
  do
    (nai, nbi, _, _) <- buildGraphX
    performGCUntilFinalizersQuiesce
    [] <- toList <$> findForwardVals nai
    [] <- toList <$> findForwardVals nbi
    touch (nai, nbi)

-- | Cycle a -> b -> c -> a (see 'buildGraphO').
testGraphO :: IO ()
testGraphO = do
  -- Both handles of "a" held: the whole cycle is reachable.
  do
    (nai, nao) <- buildGraphO
    performGCUntilFinalizersQuiesce
    ["a", "b", "c"] <- toList <$> findBackward nao
    ["a", "b", "c"] <- toList <$> findForward nai
    touch (nai, nao)
  -- Only the output of "a" held.
  do
    (_, nao) <- buildGraphO
    performGCUntilFinalizersQuiesce
    [] <- toList <$> findBackward nao
    touch nao
  -- Only the input of "a" held.
  do
    (nai, _) <- buildGraphO
    performGCUntilFinalizersQuiesce
    [] <- toList <$> findForward nai
    touch nai

-- | Build the cycle a -> b -> c -> a and return both handles of "a".
buildGraphO :: IO (NodeInput String, NodeOutput String)
buildGraphO = do
  (nai, nao) <- newNode "a"
  (nbi, nbo) <- newNode "b"
  (nci, nco) <- newNode "c"
  link nao nbi
  link nbo nci
  link nco nai
  pure (nai, nao)

-- | The input-side handle of a node. It holds the node's forward links
-- strongly and its back links only weakly.
data NodeInput a = NodeInput
  { _nodeInput_key :: !Int
  , _nodeInput_contents :: !(Weak (IORef (Maybe a)))
  , _nodeInput_forwardLinks :: !(IORef (IntMap (NodeInput a)))
  , _nodeInput_backLinks :: !(Weak (IORef (IntMap (NodeOutput a))))
  }

-- | The output-side handle of a node. It holds the node's back links
-- strongly and its forward links only weakly.
data NodeOutput a = NodeOutput
  { _nodeOutput_key :: !Int
  , _nodeOutput_contents :: !(Weak (IORef (Maybe a)))
  , _nodeOutput_forwardLinks :: !(Weak (IORef (IntMap (NodeInput a))))
  , _nodeOutput_backLinks :: !(IORef (IntMap (NodeOutput a)))
  }

-- | Global counter for node ids. NOINLINE keeps the 'unsafePerformIO'
-- from being duplicated, so there is exactly one counter.
{-# NOINLINE globalNodeIdRef #-}
globalNodeIdRef :: IORef Int
globalNodeIdRef = unsafePerformIO $ newIORef 1

-- | Atomically hand out the next node id (starting at 1).
newNodeId :: IO Int
newNodeId = atomicModifyIORef' globalNodeIdRef $ \n -> (succ n, n)

-- | Create a node holding @v@ and return its input and output handles.
newNode :: a -> IO (NodeInput a, NodeOutput a)
newNode v = do
  nodeId <- newNodeId
  backLinks <- newIORef mempty
  forwardLinks <- newIORef mempty
  contentsRef <- newIORef $ Just v
  -- The contents and the NodeInput's weak view of backLinks are both keyed
  -- on backLinks, which only the NodeOutput holds strongly. When the output
  -- side becomes unreachable, both die with it.
  w <- mkWeakWithIORefKey backLinks contentsRef
  wBack <- mkWeakWithIORefKey backLinks backLinks
  -- forwardLinks is held strongly only by the NodeInput. When the input side
  -- becomes unreachable, this finalizer clears the contents and back links
  -- by hand, because those may still be kept alive through backLinks.
  wForward <- mkWeakWithIORefKeyWithFinalizer forwardLinks forwardLinks $ do
    -- Report activity to performGCUntilFinalizersQuiesce.
    writeIORef finalizerDidRunRef True
    deRefWeak w >>= \case
      Nothing -> pure ()
      Just contentsRef' -> writeIORef contentsRef' Nothing
    deRefWeak wBack >>= \case
      Nothing -> pure ()
      Just backLinks' -> do
        writeIORef backLinks' mempty
        -- TODO: Clear out all the nodes forward of us
  pure
    ( NodeInput
        { _nodeInput_key = nodeId
        , _nodeInput_contents = w
        , _nodeInput_backLinks = wBack
        , _nodeInput_forwardLinks = forwardLinks
        }
    , NodeOutput
        { _nodeOutput_key = nodeId
        , _nodeOutput_contents = w
        , _nodeOutput_backLinks = backLinks
        , _nodeOutput_forwardLinks = wForward
        }
    )

-- | Read a node's contents. Returns 'Nothing' if the contents cell has been
-- collected or has been cleared by a finalizer.
getNodeContents :: Weak (IORef (Maybe a)) -> IO (Maybe a)
getNodeContents w =
  deRefWeak w >>= \case
    Nothing -> pure Nothing
    Just r -> readIORef r

-- | Add an edge from @aOut@ to @bIn@, recording it in both @aOut@'s forward
-- links and @bIn@'s back links. If either link map is already dead, this
-- is a no-op.
link :: NodeOutput a -> NodeInput a -> IO ()
link aOut bIn =
  deRefWeak (_nodeInput_backLinks bIn) >>= \case
    Nothing -> pure ()
    Just backLinks ->
      deRefWeak (_nodeOutput_forwardLinks aOut) >>= \case
        Nothing -> pure ()
        Just forwardLinks -> do
          atomicModifyIORef' forwardLinks $ \m ->
            (IntMap.insert (_nodeInput_key bIn) bIn m, ())
          atomicModifyIORef' backLinks $ \m ->
            (IntMap.insert (_nodeOutput_key aOut) aOut m, ())

-- | The set of live contents reachable forwards from @i@ (including @i@).
findForwardVals :: Ord a => NodeInput a -> IO (Set a)
findForwardVals i = Set.fromList . IntMap.elems <$> findForward i

-- | Live contents reachable forwards from @i@ (including @i@), keyed by
-- node id.
findForward :: NodeInput a -> IO (IntMap a)
findForward =
  collectReachable
    _nodeInput_key
    _nodeInput_contents
    (readIORef . _nodeInput_forwardLinks)

-- | The set of live contents reachable backwards from @i@ (including @i@).
findBackwardVals :: Ord a => NodeOutput a -> IO (Set a)
findBackwardVals i = Set.fromList . IntMap.elems <$> findBackward i

-- | Live contents reachable backwards from @o@ (including @o@), keyed by
-- node id.
findBackward :: NodeOutput a -> IO (IntMap a)
findBackward =
  collectReachable
    _nodeOutput_key
    _nodeOutput_contents
    (readIORef . _nodeOutput_backLinks)

-- | Worklist traversal shared by 'findForward' and 'findBackward'.
--
-- Nodes whose contents are gone are skipped and not expanded. Every popped
-- key is recorded in @visited@, so each node is processed at most once even
-- on cycles, self-loops, or dead nodes with many live neighbours.
collectReachable
  :: forall node a.
     (node -> Int)                     -- ^ node id
  -> (node -> Weak (IORef (Maybe a)))  -- ^ node contents
  -> (node -> IO (IntMap node))        -- ^ neighbours to continue with
  -> node                              -- ^ starting node
  -> IO (IntMap a)
collectReachable keyOf contentsOf neighboursOf start =
  go mempty mempty (IntMap.singleton (keyOf start) start)
  where
    go :: IntMap () -> IntMap a -> IntMap node -> IO (IntMap a)
    go visited found toSearch = case IntMap.minViewWithKey toSearch of
      Nothing -> pure found
      Just ((k, this), restToSearch) -> do
        let visited' = IntMap.insert k () visited
        getNodeContents (contentsOf this) >>= \case
          Nothing -> go visited' found restToSearch
          Just v -> do
            newLinks <- neighboursOf this
            go visited'
               (IntMap.insert k v found)
               (IntMap.union restToSearch (newLinks `IntMap.difference` visited'))

-- | Check that a 'DualWeak' keeps its target alive only while both tickets
-- are alive.
testDualWeak :: IO ()
testDualWeak = do
  -- Both tickets held: the target stays alive.
  do
    target <- newIORef ()
    r <- mkWeakWithIORefKey target ()
    (w, outer, inner) <- newDualWeak target
    performGC
    threadDelay 1000000
    Just () <- deRefWeak r
    Just _ <- getDualWeak w
    touch (outer, inner)
    performGCUntilFinalizersQuiesce
  -- Inner ticket dropped: the value cell is only reachable through the inner
  -- Weak, so one GC is enough to release the target.
  do
    target <- newIORef ()
    r <- mkWeakWithIORefKey target ()
    (w, outer, _) <- newDualWeak target
    performGC
    threadDelay 1000000
    Nothing <- deRefWeak r
    Nothing <- getDualWeak w
    touch outer
    performGCUntilFinalizersQuiesce
  -- Outer ticket dropped: the outer Weak's finalizer must write Nothing into
  -- the value cell before the target becomes unreachable, hence the second
  -- performGC.
  do
    target <- newIORef ()
    r <- mkWeakWithIORefKey target ()
    (w, _, inner) <- newDualWeak target
    performGC
    threadDelay 1000000
    performGC
    Nothing <- deRefWeak r
    Nothing <- getDualWeak w
    touch inner
    performGCUntilFinalizersQuiesce

-- | Run GCs until a round completes without any of this file's finalizers
-- reporting (through 'finalizerDidRunRef') that they ran.
performGCUntilFinalizersQuiesce :: IO ()
performGCUntilFinalizersQuiesce = do
  writeIORef finalizerDidRunRef False
  performGC
  fence <- createFinalizerFence
  -- NOTE(review): this assumes finalizers are processed in a queue, so the
  -- fence (which becomes garbage in the next GC) fires after every finalizer
  -- triggered by the GC we just finished. That ordering is unverified.
  -- The fence may fire at ANY point during the subsequent GC's finalizers,
  -- so it cannot tell us that *that* GC ran no finalizers.
  performGC
  waitForFinalizerFence fence
  readIORef finalizerDidRunRef >>= \case
    True -> performGCUntilFinalizersQuiesce
    False -> pure ()

-- | Set to True by the finalizers in this file. Read and reset by
-- 'performGCUntilFinalizersQuiesce'. NOINLINE keeps it a single global.
{-# NOINLINE finalizerDidRunRef #-}
finalizerDidRunRef :: IORef Bool
finalizerDidRunRef = unsafePerformIO $ newIORef False

-- | An MVar that is filled by the finalizer of an otherwise unreachable key.
newtype FinalizerFence = FinalizerFence (MVar ())

-- | Create a fence whose key is dropped immediately, so its finalizer fills
-- the MVar once a GC collects that key.
createFinalizerFence :: IO FinalizerFence
createFinalizerFence = do
  r <- newIORef ()
  v <- newEmptyMVar
  _ <- mkWeakWithIORefKeyWithFinalizer r () $ putMVar v ()
  pure $ FinalizerFence v

-- | Block until the fence's finalizer has run.
waitForFinalizerFence :: FinalizerFence -> IO ()
waitForFinalizerFence (FinalizerFence v) = takeMVar v

-- | Outer Weak (keyed on the outer ticket) whose value is the inner Weak
-- (keyed on the inner ticket), whose value is the cell holding the target.
newtype DualWeak a = DualWeak (Weak (Weak (IORef (Maybe a))))

newtype Ticket = Ticket (IORef ())

-- | Creates a weak reference to `a` which only remains alive if *both*
-- tickets are alive. Returns @(ref, outerTicket, innerTicket)@.
--
-- If the outer ticket dies, its finalizer clears the value cell. The target
-- is therefore only released by a GC after that finalizer has run.
newDualWeak :: a -> IO (DualWeak a, Ticket, Ticket)
newDualWeak v = do
  vRef <- newIORef $ Just v
  tInner <- newIORef ()
  wInner <- mkWeakWithIORefKey tInner vRef
  tOuter <- newIORef ()
  wOuter <- mkWeakWithIORefKeyWithFinalizer tOuter wInner $ do
    writeIORef finalizerDidRunRef True
    deRefWeak wInner >>= \case
      Nothing -> pure ()
      Just vRef' -> writeIORef vRef' Nothing
  pure (DualWeak wOuter, Ticket tOuter, Ticket tInner)

-- | Read a 'DualWeak'. Returns 'Nothing' if either layer is dead or the
-- value has been cleared.
getDualWeak :: DualWeak a -> IO (Maybe a)
getDualWeak (DualWeak wOuter) =
  deRefWeak wOuter >>= \case
    Nothing -> pure Nothing
    Just wInner ->
      deRefWeak wInner >>= \case
        Nothing -> pure Nothing
        Just vRef -> readIORef vRef

-- | Make a finalizer-less Weak whose key is the IORef's underlying MutVar#,
-- so liveness is tied to the mutable cell rather than to the IORef box.
mkWeakWithIORefKey :: IORef a -> b -> IO (Weak b)
mkWeakWithIORefKey (IORef (STRef r#)) v = IO $ \s ->
  case mkWeakNoFinalizer# r# v s of
    (# s1, w #) -> (# s1, Weak w #)

-- | Like 'mkWeakWithIORefKey', but runs the given finalizer once the key's
-- MutVar# is found to be unreachable.
mkWeakWithIORefKeyWithFinalizer :: IORef a -> b -> IO () -> IO (Weak b)
mkWeakWithIORefKeyWithFinalizer (IORef (STRef r#)) v (IO f) = IO $ \s ->
  case mkWeak# r# v f s of
    (# s1, w #) -> (# s1, Weak w #)