weak-refs/main.hs
2023-04-07 21:48:29 -04:00

336 lines
10 KiB
Haskell

{-# LANGUAGE MagicHash #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ScopedTypeVariables #-}
import System.Mem
import System.Mem.Weak
import Data.IORef
import GHC.IORef
import GHC.STRef
import GHC.IO
import GHC.Weak
import GHC.Prim
import Control.Monad.Primitive
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.Set (Set)
import qualified Data.Set as Set
import System.IO.Unsafe
import Control.Monad
import Control.Concurrent
import Data.Foldable
-- | Run every weak-reference experiment, in a fixed order. Each experiment
-- signals an unexpected liveness result through a failed pattern match in 'IO'.
main :: IO ()
main =
  sequence_
    [ testDualWeak
    , testWeakChain
    , testGraphChain
    , testGraphX
    , testGraphO
    ]
-- | Demonstrate that a weak pointer's key keeps its value alive, even when the
-- 'Weak' object itself has been discarded.
--
-- @outer@ is the key of a (discarded) weak pointer whose value is @inner@, and
-- @inner@ is in turn the key of the probe weak pointer. As long as @outer@ is
-- held, the probe must still dereference successfully.
testWeakChain :: IO ()
testWeakChain = do
  outer <- newIORef ()
  inner <- newIORef ()
  void $ mkWeakWithIORefKey outer inner
  probe <- mkWeakWithIORefKey inner "b is still alive"
  performGCUntilFinalizersQuiesce
  Just "b is still alive" <- deRefWeak probe
  touch outer
-- | Exercise the linear graph @a -> b -> c -> d -> e@ from 'buildGraphChain'.
--
-- * Both the head's input and the tail's output are held: both traversals are
--   expected to see all five values.
-- * Only the head's input is held: the forward traversal is expected to be
--   empty.
-- * Only the tail's output is held: the backward traversal is expected to be
--   empty.
testGraphChain :: IO ()
testGraphChain = do
    bothEnds
    headOnly
    tailOnly
  where
    bothEnds = do
      (headIn, tailOut) <- buildGraphChain
      performGCUntilFinalizersQuiesce
      ["a", "b", "c", "d", "e"] <- toList <$> findForwardVals headIn
      ["a", "b", "c", "d", "e"] <- toList <$> findBackwardVals tailOut
      touch (headIn, tailOut)
    headOnly = do
      (headIn, _) <- buildGraphChain
      performGCUntilFinalizersQuiesce
      [] <- toList <$> findForwardVals headIn
      touch headIn
    tailOnly = do
      (_, tailOut) <- buildGraphChain
      performGCUntilFinalizersQuiesce
      [] <- toList <$> findBackwardVals tailOut
      touch tailOut
-- | Build the chain @a -> b -> c -> d -> e@.
--
-- All five nodes are created first and then linked in order. Returns the input
-- of the first node and the output of the last one. Every other handle is
-- dropped.
buildGraphChain :: IO (NodeInput String, NodeOutput String)
buildGraphChain = do
  (firstIn, firstOut) <- newNode "a"
  rest <- traverse newNode ["b", "c", "d", "e"]
  lastOut <- foldM linkNext firstOut rest
  pure (firstIn, lastOut)
  where
    -- Link the previous node's output to this node's input, then carry this
    -- node's output forward as the next "previous" output.
    linkNext prevOut (nextIn, nextOut) = nextOut <$ link prevOut nextIn
-- | Build an X-shaped graph: @a@ and @b@ both feed @c@, which feeds both @d@
-- and @e@.
--
-- > a   b
-- >  \ /
-- >   c
-- >  / \
-- > d   e
--
-- Returns the inputs of the two sources and the outputs of the two sinks.
buildGraphX :: IO (NodeInput String, NodeInput String, NodeOutput String, NodeOutput String)
buildGraphX = do
  (aIn, aOut) <- newNode "a"
  (bIn, bOut) <- newNode "b"
  (cIn, cOut) <- newNode "c"
  (dIn, dOut) <- newNode "d"
  (eIn, eOut) <- newNode "e"
  traverse_ (uncurry link)
    [ (aOut, cIn)
    , (bOut, cIn)
    , (cOut, dIn)
    , (cOut, eIn)
    ]
  pure (aIn, bIn, dOut, eOut)
-- | Exercise the X-shaped graph from 'buildGraphX'.
--
-- 1. All four external handles are held: every traversal is expected to see
--    every node on its path.
-- 2. The output of @e@ is dropped: forward traversals are expected to stop
--    at @d@ and no longer report @e@.
-- 3. Both sink outputs are dropped: both forward traversals are expected to
--    be empty.
--
-- Dropped handles are bound to @_@ so it is explicit which handles each
-- scenario deliberately releases (and no unused-binding warnings arise).
testGraphX :: IO ()
testGraphX = do
  do
    (nai, nbi, ndo, neo) <- buildGraphX
    performGCUntilFinalizersQuiesce
    ["a", "c", "d", "e"] <- toList <$> findForwardVals nai
    ["b", "c", "d", "e"] <- toList <$> findForwardVals nbi
    ["a", "b", "c", "d"] <- toList <$> findBackwardVals ndo
    ["a", "b", "c", "e"] <- toList <$> findBackwardVals neo
    touch (nai, nbi, ndo, neo)
  do
    -- Deliberately discard e's output handle.
    (nai, nbi, ndo, _) <- buildGraphX
    performGCUntilFinalizersQuiesce
    ["a", "c", "d"] <- toList <$> findForwardVals nai
    ["b", "c", "d"] <- toList <$> findForwardVals nbi
    touch (nai, nbi, ndo)
  do
    -- Deliberately discard both sink output handles.
    (nai, nbi, _, _) <- buildGraphX
    performGCUntilFinalizersQuiesce
    [] <- toList <$> findForwardVals nai
    [] <- toList <$> findForwardVals nbi
    touch (nai, nbi)
-- | Exercise the cycle @a -> b -> c -> a@ from 'buildGraphO'.
--
-- While both of @a@'s handles are held, both traversals are expected to find
-- all three values. When only one of the two handles is held, the traversal
-- from that handle is expected to be empty, even though the graph is cyclic.
testGraphO :: IO ()
testGraphO = do
    withBothHandles
    withOutputOnly
    withInputOnly
  where
    withBothHandles = do
      (entryIn, entryOut) <- buildGraphO
      performGCUntilFinalizersQuiesce
      ["a", "b", "c"] <- toList <$> findBackward entryOut
      ["a", "b", "c"] <- toList <$> findForward entryIn
      touch (entryIn, entryOut)
    withOutputOnly = do
      (_, entryOut) <- buildGraphO
      performGCUntilFinalizersQuiesce
      [] <- toList <$> findBackward entryOut
      touch entryOut
    withInputOnly = do
      (entryIn, _) <- buildGraphO
      performGCUntilFinalizersQuiesce
      [] <- toList <$> findForward entryIn
      touch entryIn
-- | Build the three-node cycle @a -> b -> c -> a@ and return both handles of
-- node @a@. The handles of @b@ and @c@ are dropped.
buildGraphO :: IO (NodeInput String, NodeOutput String)
buildGraphO = do
  (aIn, aOut) <- newNode "a"
  (bIn, bOut) <- newNode "b"
  (cIn, cOut) <- newNode "c"
  traverse_ (uncurry link) [(aOut, bIn), (bOut, cIn), (cOut, aIn)]
  pure (aIn, aOut)
-- | The "input" handle of a graph node, as created by 'newNode'.
--
-- It holds the node's forward-link map strongly. In 'newNode' that map carries
-- a finalizer which, once the map becomes unreachable, clears the node's
-- contents and back links.
data NodeInput a = NodeInput
  { _nodeInput_key :: !Int
    -- ^ Node identifier, shared with the matching 'NodeOutput' and used as the
    -- key in other nodes' link maps.
  , _nodeInput_contents :: !(Weak (IORef (Maybe a)))
    -- ^ Weak pointer to the node's contents cell, which holds 'Nothing' once
    -- cleared.
  , _nodeInput_forwardLinks :: !(IORef (IntMap (NodeInput a)))
    -- ^ Inputs of downstream nodes, inserted by 'link' and held strongly, so
    -- holding an input keeps every input forward of it reachable.
  , _nodeInput_backLinks :: !(Weak (IORef (IntMap (NodeOutput a))))
    -- ^ Weak pointer to this node's back-link map. The map itself is held
    -- strongly only by the matching 'NodeOutput'.
  }
-- | The "output" handle of a graph node, as created by 'newNode'.
--
-- It holds the node's back-link map strongly. In 'newNode' the contents weak
-- pointer is keyed on that map, so an output keeps its node's contents cell
-- reachable. The cell may still be cleared to 'Nothing' when the input side
-- dies.
data NodeOutput a = NodeOutput
  { _nodeOutput_key :: !Int
    -- ^ Node identifier, shared with the matching 'NodeInput'.
  , _nodeOutput_contents :: !(Weak (IORef (Maybe a)))
    -- ^ Weak pointer to the node's contents cell; the same 'Weak' as in the
    -- matching 'NodeInput'.
  , _nodeOutput_forwardLinks :: !(Weak (IORef (IntMap (NodeInput a))))
    -- ^ Weak pointer to this node's forward-link map. The map itself is held
    -- strongly only by the matching 'NodeInput'.
  , _nodeOutput_backLinks :: !(IORef (IntMap (NodeOutput a)))
    -- ^ Outputs of upstream nodes, inserted by 'link' and held strongly, so
    -- holding an output keeps every output behind it reachable.
  }
{-# NOINLINE globalNodeIdRef #-}
-- | Process-wide counter for node identifiers; the first identifier issued is 1.
--
-- Built with 'unsafePerformIO'. The NOINLINE pragma above is what guarantees a
-- single shared 'IORef' instead of a fresh one at each use site.
globalNodeIdRef :: IORef Int
globalNodeIdRef = unsafePerformIO (newIORef 1)

-- | Atomically return the current identifier and advance the counter.
newNodeId :: IO Int
newNodeId =
  atomicModifyIORef' globalNodeIdRef $ \current ->
    let next = succ current
     in (next, current)
-- | Create a graph node holding @v@ and return its input and output handles.
--
-- Liveness wiring:
--
-- * The contents cell is the value of a weak pointer keyed on the node's
--   back-link map. Only the 'NodeOutput' holds that map strongly.
-- * The forward-link map is held strongly only by the 'NodeInput', and it
--   carries a finalizer. When the map becomes unreachable, the finalizer writes
--   'Nothing' into the contents cell and empties the back-link map, provided
--   those are still alive.
--
-- Both internal weak pointers (@w@ and @wBack@) are keyed on @backLinks@, so
-- they die together. The finalizer closure captures only these 'Weak' objects,
-- never the 'IORef's themselves, so it does not keep the contents or back
-- links alive.
newNode :: a -> IO (NodeInput a, NodeOutput a)
newNode v = do
  nodeId <- newNodeId
  backLinks <- newIORef mempty
  forwardLinks <- newIORef mempty
  contentsRef <- newIORef $ Just v
  -- The contents cell stays reachable exactly as long as the back-link map.
  w <- mkWeakWithIORefKey backLinks contentsRef
  -- Self-keyed weak pointer: lets the input side and the finalizer reach the
  -- back links without retaining them.
  wBack <- mkWeakWithIORefKey backLinks backLinks
  wForward <- mkWeakWithIORefKeyWithFinalizer forwardLinks forwardLinks $ do
    -- Test instrumentation polled by performGCUntilFinalizersQuiesce.
    writeIORef finalizerDidRunRef True
    deRefWeak w >>= \case
      Nothing -> pure ()
      Just contentsRef' -> do
        writeIORef contentsRef' Nothing
    deRefWeak wBack >>= \case
      Nothing -> pure ()
      Just backLinks' -> do
        writeIORef backLinks' mempty
    --TODO: Clear out all the nodes forward of us
  pure
    ( NodeInput
        { _nodeInput_key = nodeId
        , _nodeInput_contents = w
        , _nodeInput_backLinks = wBack
        , _nodeInput_forwardLinks = forwardLinks
        }
    , NodeOutput
        { _nodeOutput_key = nodeId
        , _nodeOutput_contents = w
        , _nodeOutput_backLinks = backLinks
        , _nodeOutput_forwardLinks = wForward
        }
    )
-- | Read a node's contents through its weak contents pointer. Yields 'Nothing'
-- when the cell has been collected or has been cleared to 'Nothing'.
getNodeContents :: Weak (IORef (Maybe a)) -> IO (Maybe a)
getNodeContents w = maybe (pure Nothing) readIORef =<< deRefWeak w
-- | Connect @aOut@ to @bIn@.
--
-- @bIn@ is recorded in @aOut@'s forward-link map, and @aOut@ is recorded in
-- @bIn@'s back-link map. If either map has already been collected, nothing
-- is recorded.
link :: NodeOutput a -> NodeInput a -> IO ()
link aOut bIn = do
  mBackLinks <- deRefWeak (_nodeInput_backLinks bIn)
  for_ mBackLinks $ \backLinks -> do
    mForwardLinks <- deRefWeak (_nodeOutput_forwardLinks aOut)
    for_ mForwardLinks $ \forwardLinks -> do
      insertInto forwardLinks (_nodeInput_key bIn) bIn
      insertInto backLinks (_nodeOutput_key aOut) aOut
  where
    -- Strictly and atomically insert one entry into a link map.
    insertInto :: IORef (IntMap v) -> Int -> v -> IO ()
    insertInto ref k x = atomicModifyIORef' ref $ \m -> (IntMap.insert k x m, ())
-- | The set of live values reachable by following forward links from an input.
findForwardVals :: Ord a => NodeInput a -> IO (Set a)
findForwardVals = fmap (Set.fromList . IntMap.elems) . findForward
-- | Collect, keyed by node id, the contents of every live node reachable from
-- @i@ by following forward links. The result includes @i@ itself if it is live.
--
-- A node whose contents have been cleared or collected is neither reported
-- nor expanded, so traversal stops at dead nodes. The worklist is drained in
-- ascending key order via 'IntMap.minViewWithKey'.
--
-- Every dequeued key, live or dead, is remembered in @visited@. Newly found
-- links are filtered against it, so a node is examined at most once. This
-- covers self-loops and dead nodes reached from several parents.
findForward :: forall a. NodeInput a -> IO (IntMap a)
findForward i = go IntMap.empty IntMap.empty (IntMap.singleton (_nodeInput_key i) i)
  where
    go :: IntMap () -> IntMap a -> IntMap (NodeInput a) -> IO (IntMap a)
    go visited found toSearch = case IntMap.minViewWithKey toSearch of
      Nothing -> pure found
      Just ((k, thisIn), restToSearch) -> do
        let visited' = IntMap.insert k () visited
        getNodeContents (_nodeInput_contents thisIn) >>= \case
          Nothing -> go visited' found restToSearch
          Just this -> do
            newLinks <- readIORef (_nodeInput_forwardLinks thisIn)
            go visited'
               (IntMap.insert k this found)
               (IntMap.union restToSearch (newLinks `IntMap.difference` visited'))
-- | The set of live values reachable by following back links from an output.
findBackwardVals :: Ord a => NodeOutput a -> IO (Set a)
findBackwardVals = fmap (Set.fromList . IntMap.elems) . findBackward
-- | Collect, keyed by node id, the contents of every live node reachable from
-- @o@ by following back links. The result includes @o@ itself if it is live.
--
-- A node whose contents have been cleared or collected is neither reported
-- nor expanded, so traversal stops at dead nodes. The worklist is drained in
-- ascending key order via 'IntMap.minViewWithKey'.
--
-- Every dequeued key, live or dead, is remembered in @visited@. Newly found
-- links are filtered against it, so a node is examined at most once. This
-- covers self-loops and dead nodes reached from several children.
findBackward :: forall a. NodeOutput a -> IO (IntMap a)
findBackward o = go IntMap.empty IntMap.empty (IntMap.singleton (_nodeOutput_key o) o)
  where
    go :: IntMap () -> IntMap a -> IntMap (NodeOutput a) -> IO (IntMap a)
    go visited found toSearch = case IntMap.minViewWithKey toSearch of
      Nothing -> pure found
      Just ((k, thisOut), restToSearch) -> do
        let visited' = IntMap.insert k () visited
        getNodeContents (_nodeOutput_contents thisOut) >>= \case
          Nothing -> go visited' found restToSearch
          Just this -> do
            newLinks <- readIORef (_nodeOutput_backLinks thisOut)
            go visited'
               (IntMap.insert k this found)
               (IntMap.union restToSearch (newLinks `IntMap.difference` visited'))
-- | Check that a 'DualWeak' keeps its referent alive only while both tickets
-- are alive.
--
-- Each scenario also keys an independent probe weak pointer on the referent.
-- This lets the test observe whether the referent itself was collected, not
-- only what 'getDualWeak' reports.
testDualWeak :: IO ()
testDualWeak = do
    bothTicketsHeld
    outerTicketOnly
    innerTicketOnly
  where
    bothTicketsHeld = do
      referent <- newIORef ()
      probe <- mkWeakWithIORefKey referent ()
      (dual, outerTicket, innerTicket) <- newDualWeak referent
      performGC
      threadDelay 1000000
      Just () <- deRefWeak probe
      Just _ <- getDualWeak dual
      touch (outerTicket, innerTicket)
      performGCUntilFinalizersQuiesce
    outerTicketOnly = do
      referent <- newIORef ()
      probe <- mkWeakWithIORefKey referent ()
      (dual, outerTicket, _) <- newDualWeak referent
      performGC
      threadDelay 1000000
      Nothing <- deRefWeak probe
      Nothing <- getDualWeak dual
      touch outerTicket
      performGCUntilFinalizersQuiesce
    innerTicketOnly = do
      referent <- newIORef ()
      probe <- mkWeakWithIORefKey referent ()
      (dual, _, innerTicket) <- newDualWeak referent
      performGC
      threadDelay 1000000
      -- Presumably the first GC only triggers the outer weak's finalizer,
      -- which clears the cell during the delay. The referent can become
      -- collectable only afterwards, hence a second GC.
      performGC
      Nothing <- deRefWeak probe
      Nothing <- getDualWeak dual
      touch innerTicket
      performGCUntilFinalizersQuiesce
-- | Repeatedly collect garbage until a round completes in which no finalizer
-- set 'finalizerDidRunRef'.
--
-- Each round does the following:
--
-- 1. Clear the flag.
-- 2. Run a GC.
-- 3. Plant a 'FinalizerFence'.
-- 4. Run a second GC so the fence's key dies.
-- 5. Wait for the fence's finalizer.
--
-- The fence is assumed (from memory, unverified) to run after every finalizer
-- scheduled by the first GC, because finalizers appear to be processed in a
-- queue. It may run at any point relative to finalizers from the second GC, so
-- a single round cannot prove that GC ran none. Hence the loop continues until
-- a round leaves the flag unset.
performGCUntilFinalizersQuiesce :: IO ()
performGCUntilFinalizersQuiesce = loop
  where
    loop = do
      writeIORef finalizerDidRunRef False
      performGC
      fence <- createFinalizerFence
      performGC
      waitForFinalizerFence fence
      finalizersRan <- readIORef finalizerDidRunRef
      when finalizersRan loop
{-# NOINLINE finalizerDidRunRef #-}
-- | Test instrumentation flag, initially 'False'.
--
-- The finalizers installed by 'newNode' and 'newDualWeak' set it to 'True'.
-- 'performGCUntilFinalizersQuiesce' resets it and polls it. The NOINLINE
-- pragma above guarantees a single shared 'IORef' despite 'unsafePerformIO'.
finalizerDidRunRef :: IORef Bool
finalizerDidRunRef = unsafePerformIO (newIORef False)
-- | A one-shot signal that fires when the finalizer of a throwaway weak key
-- runs.
newtype FinalizerFence = FinalizerFence (MVar ())

-- | Plant a fence. A fresh 'IORef' that nothing else retains is used as a weak
-- key, and its finalizer fills the fence's 'MVar'.
createFinalizerFence :: IO FinalizerFence
createFinalizerFence = do
  key <- newIORef ()
  signal <- newEmptyMVar
  FinalizerFence signal <$ mkWeakWithIORefKeyWithFinalizer key () (putMVar signal ())

-- | Block until the fence's finalizer has filled its 'MVar'.
waitForFinalizerFence :: FinalizerFence -> IO ()
waitForFinalizerFence (FinalizerFence signal) = takeMVar signal
-- | Weak reference produced by 'newDualWeak'.
--
-- It is an outer 'Weak', keyed on the outer ticket, whose value is an inner
-- 'Weak', keyed on the inner ticket. The inner 'Weak' points at the cell that
-- holds the referent.
newtype DualWeak a = DualWeak (Weak (Weak (IORef (Maybe a))))
-- | Liveness token returned by 'newDualWeak'. Keep it reachable to keep the
-- corresponding layer of the 'DualWeak' alive.
newtype Ticket = Ticket (IORef ())
-- | Creates a weak reference to @v@ that remains alive only while *both*
-- tickets are alive.
--
-- Returns @(dualWeak, outerTicket, innerTicket)@.
--
-- * The inner weak pointer is keyed on the inner ticket and holds the cell
--   containing @Just v@. If the inner ticket dies, the cell becomes
--   unreachable.
-- * The outer weak pointer is keyed on the outer ticket and holds the inner
--   weak pointer. Its finalizer writes 'Nothing' into the cell if the inner
--   weak is still alive, so losing the outer ticket also releases @v@.
--
-- The finalizer captures only @wInner@, not @vRef@, so it does not keep the
-- cell alive itself.
newDualWeak :: a -> IO (DualWeak a, Ticket, Ticket)
newDualWeak v = do
  vRef <- newIORef $ Just v
  tInner <- newIORef ()
  wInner <- mkWeakWithIORefKey tInner vRef
  tOuter <- newIORef ()
  wOuter <- mkWeakWithIORefKeyWithFinalizer tOuter wInner $ do
    -- Test instrumentation polled by performGCUntilFinalizersQuiesce.
    writeIORef finalizerDidRunRef True
    deRefWeak wInner >>= \case
      Nothing -> pure ()
      Just vRef' -> do
        writeIORef vRef' Nothing
  pure (DualWeak wOuter, Ticket tOuter, Ticket tInner)
-- | Read the referent of a 'DualWeak'.
--
-- Returns 'Nothing' if either weak layer has died, or if the cell has been
-- cleared by the outer finalizer.
getDualWeak :: DualWeak a -> IO (Maybe a)
getDualWeak (DualWeak outer) =
  deRefWeak outer >>= maybe (pure Nothing) readInner
  where
    readInner inner = deRefWeak inner >>= maybe (pure Nothing) readIORef
-- | Make a weak pointer to @value@, with no finalizer. Its key is the
-- primitive 'MutVar#' underlying the given 'IORef', not the boxed 'IORef'
-- wrapper.
--
-- This mirrors base's 'mkWeakIORef', which also keys on the 'MutVar#'. The
-- boxed wrapper may be unpacked or re-allocated by the optimiser, so a weak
-- pointer keyed on it could die before the reference it wraps.
mkWeakWithIORefKey :: IORef a -> b -> IO (Weak b)
mkWeakWithIORefKey (IORef (STRef var#)) value = IO $ \s0 ->
  case mkWeakNoFinalizer# var# value s0 of
    (# s1, weak# #) -> (# s1, Weak weak# #)
-- | Like 'mkWeakWithIORefKey', but attaches @runFinalizer@. The runtime
-- schedules it once the 'MutVar#' underlying the 'IORef' is found to be
-- unreachable.
mkWeakWithIORefKeyWithFinalizer :: IORef a -> b -> IO () -> IO (Weak b)
mkWeakWithIORefKeyWithFinalizer (IORef (STRef var#)) value (IO runFinalizer) = IO $ \s0 ->
  case mkWeak# var# value runFinalizer s0 of
    (# s1, weak# #) -> (# s1, Weak weak# #)