integrations/haskell/app/RustClient.hs

105 lines
4.3 KiB
Haskell
Raw Normal View History

module RustClient (
callRustByteChecksum,
callRustBytePattern,
callRustMessage,
callRustSequence,
callRustSliceSum,
callRustSummary,
) where
import Control.Exception (bracket)
import Foreign.C.String (CString, peekCString, withCString)
import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..))
import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (peekArray, withArrayLen)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable (peek)
import qualified GHC.Foreign as GHCForeign
import System.IO (utf8)

import Interop.Shared (SharedI32Buffer (..), SharedU8Buffer (..), Summary, SharedStats, summaryFromSharedStats)
-- Raw bindings to the Rust library's C ABI.
--
-- Every import is declared @unsafe@: GHC performs such calls without
-- releasing the capability, so the callee must not call back into Haskell
-- and should return quickly. The Haskell types here must match the Rust
-- @extern "C"@ signatures exactly; nothing in this module can check that.

-- | @rust_compute_stats left right out@: writes a 'SharedStats' through
-- @out@ and returns a status code ('callRustSummary' treats any non-zero
-- value as failure and only reads @out@ on success).
foreign import ccall unsafe "rust_compute_stats"
  rustComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt

-- | @rust_sum_slice ptr len@: called with a 'CInt' array and its element
-- count; the result comes back as a 'CInt'.
foreign import ccall unsafe "rust_sum_slice"
  rustSumSlice :: Ptr CInt -> CSize -> IO CInt

-- | @rust_checksum_bytes ptr len@: called with a byte array and its length;
-- the result comes back as a 'CUInt'.
foreign import ccall unsafe "rust_checksum_bytes"
  rustChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt

-- | @rust_make_message name left right@: returns a Rust-allocated C string,
-- or null (which 'callRustMessage' checks for). A non-null result is
-- released with 'rustFreeString'.
foreign import ccall unsafe "rust_make_message"
  rustMakeMessage :: CString -> CInt -> CInt -> IO CString

-- | @rust_make_sequence start count out@: writes a Rust-owned
-- 'SharedI32Buffer' through @out@ and returns a status code (non-zero is
-- treated as failure). The buffer is released with 'rustFreeI32Buffer'.
foreign import ccall unsafe "rust_make_sequence"
  rustMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt

-- | @rust_make_byte_pattern seed count out@: writes a Rust-owned
-- 'SharedU8Buffer' through @out@ and returns a status code (non-zero is
-- treated as failure). The buffer is released with 'rustFreeU8Buffer'.
foreign import ccall unsafe "rust_make_byte_pattern"
  rustMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt

-- | Release a string previously returned by 'rustMakeMessage'.
foreign import ccall unsafe "rust_free_string"
  rustFreeString :: CString -> IO ()

-- | Release an i32 buffer, given the (pointer, length, capacity) triple that
-- 'rustMakeSequence' produced.
foreign import ccall unsafe "rust_free_i32_buffer"
  rustFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO ()

-- | Release a u8 buffer, given the (pointer, length, capacity) triple that
-- 'rustMakeBytePattern' produced.
foreign import ccall unsafe "rust_free_u8_buffer"
  rustFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO ()
-- | Ask Rust to compute statistics for @left@ and @right@ and convert the
-- result into a 'Summary'.
--
-- The 'SharedStats' out-parameter lives in temporary storage obtained with
-- 'alloca' and freed when the callback returns. Both arguments are narrowed
-- to 'CInt' with 'fromIntegral'. The out-parameter is read only when Rust
-- reports status 0; any other status raises an 'IOError' via 'fail'.
callRustSummary :: Int -> Int -> IO Summary
callRustSummary left right = alloca fillAndRead
  where
    fillAndRead outStats = do
      rc <- rustComputeStats (fromIntegral left) (fromIntegral right) outStats
      case rc of
        0 -> summaryFromSharedStats <$> peek outStats
        _ -> fail ("rustComputeStats returned status " ++ show rc)
-- | Pass a list of integers to Rust's @rust_sum_slice@ and return its result.
--
-- The list is copied into a temporary C array that is released when
-- 'withArrayLen' returns. Each element is narrowed to 'CInt' with
-- 'fromIntegral', so values outside the 'CInt' range wrap.
callRustSliceSum :: [Int] -> IO Int
callRustSliceSum values = do
  total <- withArrayLen (fromIntegral <$> values) $ \len arrayPtr ->
    rustSumSlice arrayPtr (fromIntegral len)
  pure (fromIntegral total)
-- | Pass a byte sequence to Rust's @rust_checksum_bytes@ and return its result.
--
-- Every input is reduced modulo 256 (Haskell's non-negative 'mod') before
-- being marshalled as a 'CUChar'. For example, @256@ and @0@ contribute the
-- same byte, and @-1@ becomes @255@. The bytes live in a temporary C array
-- for the duration of the call only.
callRustByteChecksum :: [Int] -> IO Int
callRustByteChecksum values =
  fromIntegral
    <$> withArrayLen bytes (\len bytePtr -> rustChecksumBytes bytePtr (fromIntegral len))
  where
    bytes = [fromIntegral (v `mod` 256) | v <- values]
-- | Ask Rust's @rust_make_message@ to build a message from @name@, @left@ and
-- @right@, and return it as a Haskell 'String'.
--
-- Encoding: @name@ is passed as a temporary NUL-terminated UTF-8 string, and
-- the returned bytes are decoded as UTF-8, independent of the process
-- locale. Rust strings are UTF-8.
-- NOTE(review): assumes the Rust side reads and writes UTF-8. Confirm
-- against the Rust implementation.
--
-- Ownership: the returned pointer is Rust-allocated and is released with
-- @rust_free_string@. The call happens as 'bracket''s acquire step, so the
-- string is freed even if decoding fails or an asynchronous exception
-- arrives right after the call returns.
--
-- Raises an 'IOError' (via 'fail') when Rust returns a null pointer.
callRustMessage :: String -> Int -> Int -> IO String
callRustMessage name left right =
  GHCForeign.withCString utf8 name $ \namePtr ->
    bracket
      (rustMakeMessage namePtr (fromIntegral left) (fromIntegral right))
      releaseMessage
      decodeMessage
  where
    -- Only a non-null pointer refers to a Rust allocation that needs freeing.
    releaseMessage messagePtr
      | messagePtr == nullPtr = pure ()
      | otherwise = rustFreeString messagePtr
    -- Copy the Rust string into Haskell memory before it is released.
    decodeMessage messagePtr
      | messagePtr == nullPtr = fail "rustMakeMessage returned a null pointer"
      | otherwise = GHCForeign.peekCString utf8 messagePtr
-- | Call Rust's @rust_make_sequence@ with @start@ and @count@ and return the
-- contents of the buffer it produces as a Haskell list.
--
-- Rust hands back a 'SharedI32Buffer' (pointer, length, capacity) through an
-- out-parameter. The elements are copied out, and the allocation is returned
-- with @rust_free_i32_buffer@ inside 'bracket', so it is freed even if
-- copying throws. A null buffer pointer is treated as an empty result and is
-- never passed to the free function, since there is nothing to release.
-- NOTE(review): a Rust-side @Vec::from_raw_parts@ would require a non-null
-- pointer; confirm against the Rust implementation.
--
-- Raises an 'IOError' (via 'fail') when:
--
-- * @count@ is negative. It would otherwise wrap to an enormous 'CSize'.
-- * Rust reports a non-zero status.
callRustSequence :: Int -> Int -> IO [Int]
callRustSequence start count
  | count < 0 = fail ("callRustSequence: negative count " ++ show count)
  | otherwise =
      alloca $ \outBuffer -> do
        status <- rustMakeSequence (fromIntegral start) (fromIntegral count) outBuffer
        if status /= 0
          then fail ("rustMakeSequence returned status " ++ show status)
          else bracket (peek outBuffer) releaseBuffer copyElements
  where
    -- Copy the Rust-owned elements into a fresh Haskell list.
    copyElements buffer
      | sharedBufferPtr buffer == nullPtr = pure []
      | otherwise =
          map fromIntegral
            <$> peekArray (fromIntegral (sharedBufferLen buffer)) (sharedBufferPtr buffer)
    -- Hand the allocation back to Rust; a null pointer owns nothing.
    releaseBuffer buffer
      | sharedBufferPtr buffer == nullPtr = pure ()
      | otherwise =
          rustFreeI32Buffer
            (sharedBufferPtr buffer)
            (sharedBufferLen buffer)
            (sharedBufferCap buffer)
-- | Call Rust's @rust_make_byte_pattern@ with @seed@ (reduced modulo 256) and
-- @count@ and return the bytes it produces as a list of 'Int's.
--
-- Rust hands back a 'SharedU8Buffer' (pointer, length, capacity) through an
-- out-parameter. The bytes are copied out, and the allocation is returned
-- with @rust_free_u8_buffer@ inside 'bracket', so it is freed even if
-- copying throws. A null buffer pointer is treated as an empty result and is
-- never passed to the free function, since there is nothing to release.
-- NOTE(review): a Rust-side @Vec::from_raw_parts@ would require a non-null
-- pointer; confirm against the Rust implementation.
--
-- Raises an 'IOError' (via 'fail') when:
--
-- * @count@ is negative. It would otherwise wrap to an enormous 'CSize'.
-- * Rust reports a non-zero status.
callRustBytePattern :: Int -> Int -> IO [Int]
callRustBytePattern seed count
  | count < 0 = fail ("callRustBytePattern: negative count " ++ show count)
  | otherwise =
      alloca $ \outBuffer -> do
        status <- rustMakeBytePattern (fromIntegral (seed `mod` 256)) (fromIntegral count) outBuffer
        if status /= 0
          then fail ("rustMakeBytePattern returned status " ++ show status)
          else bracket (peek outBuffer) releaseBuffer copyElements
  where
    -- Copy the Rust-owned bytes into a fresh Haskell list.
    copyElements buffer
      | sharedByteBufferPtr buffer == nullPtr = pure []
      | otherwise =
          map fromIntegral
            <$> peekArray (fromIntegral (sharedByteBufferLen buffer)) (sharedByteBufferPtr buffer)
    -- Hand the allocation back to Rust; a null pointer owns nothing.
    releaseBuffer buffer
      | sharedByteBufferPtr buffer == nullPtr = pure ()
      | otherwise =
          rustFreeU8Buffer
            (sharedByteBufferPtr buffer)
            (sharedByteBufferLen buffer)
            (sharedByteBufferCap buffer)