105 lines
4.3 KiB
Haskell
105 lines
4.3 KiB
Haskell
module RustClient (
|
|
callRustByteChecksum,
|
|
callRustBytePattern,
|
|
callRustMessage,
|
|
callRustSequence,
|
|
callRustSliceSum,
|
|
callRustSummary,
|
|
) where
|
|
|
|
import Control.Exception (bracket)
|
|
import Foreign.C.String (CString, peekCString, withCString)
|
|
import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..))
|
|
import Foreign.Marshal.Alloc (alloca)
|
|
import Foreign.Marshal.Array (peekArray, withArrayLen)
|
|
import Foreign.Ptr (Ptr, nullPtr)
|
|
import Foreign.Storable (peek)
|
|
import Interop.Shared (SharedI32Buffer (..), SharedU8Buffer (..), Summary, SharedStats, summaryFromSharedStats)
|
|
|
|
-- | Raw binding to @rust_compute_stats@.
--
-- Takes two 'CInt' inputs and a pointer to a caller-allocated 'SharedStats'
-- slot. 'callRustSummary' treats a returned status of @0@ as success and only
-- then reads the slot with 'peek'.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_compute_stats"
  rustComputeStats
    :: CInt
    -> CInt
    -> Ptr SharedStats
    -> IO CInt
|
|
|
|
-- | Raw binding to @rust_sum_slice@.
--
-- Takes a pointer to a contiguous array of 'CInt' plus its element count and
-- returns a 'CInt' (presumably the sum, going by the name; confirm against
-- the Rust side). 'callRustSliceSum' passes a temporary array from
-- 'withArrayLen', so the pointer is only valid for the duration of the call.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_sum_slice"
  rustSumSlice
    :: Ptr CInt
    -> CSize
    -> IO CInt
|
|
|
|
-- | Raw binding to @rust_checksum_bytes@.
--
-- Takes a pointer to a contiguous array of 'CUChar' plus its element count
-- and returns a 'CUInt' checksum. The pointer comes from a temporary
-- 'withArrayLen' array in 'callRustByteChecksum', so it is only valid during
-- the call.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_checksum_bytes"
  rustChecksumBytes
    :: Ptr CUChar
    -> CSize
    -> IO CUInt
|
|
|
|
-- | Raw binding to @rust_make_message@.
--
-- Takes a NUL-terminated name (supplied via 'withCString') and two 'CInt'
-- values. Returns a newly allocated C string, or a null pointer on failure.
-- A non-null result must be released with 'rustFreeString'.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_make_message"
  rustMakeMessage
    :: CString
    -> CInt
    -> CInt
    -> IO CString
|
|
|
|
-- | Raw binding to @rust_make_sequence@.
--
-- Takes a 'CInt' start value, a 'CSize' count and a pointer to a
-- caller-allocated 'SharedI32Buffer' that the call fills in. A status of @0@
-- means success. The filled-in buffer must then be handed back through
-- 'rustFreeI32Buffer' together with its pointer, length and capacity.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_make_sequence"
  rustMakeSequence
    :: CInt
    -> CSize
    -> Ptr SharedI32Buffer
    -> IO CInt
|
|
|
|
-- | Raw binding to @rust_make_byte_pattern@.
--
-- Takes a 'CUChar' seed, a 'CSize' count and a pointer to a caller-allocated
-- 'SharedU8Buffer' that the call fills in. A status of @0@ means success. The
-- filled-in buffer must then be handed back through 'rustFreeU8Buffer'
-- together with its pointer, length and capacity.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_make_byte_pattern"
  rustMakeBytePattern
    :: CUChar
    -> CSize
    -> Ptr SharedU8Buffer
    -> IO CInt
|
|
|
|
-- | Raw binding to @rust_free_string@.
--
-- Releases a string returned by 'rustMakeMessage'. It is presumably only
-- valid for pointers that Rust itself produced; confirm whether it tolerates
-- null before relying on that.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_free_string"
  rustFreeString
    :: CString
    -> IO ()
|
|
|
|
-- | Raw binding to @rust_free_i32_buffer@.
--
-- Returns a buffer produced by 'rustMakeSequence' to Rust. The arguments are
-- the pointer, length and capacity read back from the 'SharedI32Buffer'.
--
-- NOTE(review): passing both length and capacity suggests Rust rebuilds the
-- original allocation from them; confirm against the Rust side.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_free_i32_buffer"
  rustFreeI32Buffer
    :: Ptr CInt
    -> CSize
    -> CSize
    -> IO ()
|
|
|
|
-- | Raw binding to @rust_free_u8_buffer@.
--
-- Returns a buffer produced by 'rustMakeBytePattern' to Rust. The arguments
-- are the pointer, length and capacity read back from the 'SharedU8Buffer'.
--
-- NOTE(review): passing both length and capacity suggests Rust rebuilds the
-- original allocation from them; confirm against the Rust side.
--
-- Imported @unsafe@: the Rust side must not call back into Haskell.
foreign import ccall unsafe "rust_free_u8_buffer"
  rustFreeU8Buffer
    :: Ptr CUChar
    -> CSize
    -> CSize
    -> IO ()
|
|
|
|
-- | Ask Rust to compute statistics for @left@ and @right@ and convert the
-- result into a 'Summary'.
--
-- Both arguments are narrowed to 'CInt' with 'fromIntegral', so values
-- outside the 'CInt' range wrap silently. The 'SharedStats' slot is
-- stack-scoped via 'alloca' and is read only after a successful call.
--
-- Throws an 'IOError' (via 'fail') when @rust_compute_stats@ reports a
-- non-zero status.
callRustSummary :: Int -> Int -> IO Summary
callRustSummary left right = alloca runStats
  where
    runStats statsPtr = do
      rc <- rustComputeStats (fromIntegral left) (fromIntegral right) statsPtr
      case rc of
        0 -> summaryFromSharedStats <$> peek statsPtr
        _ -> fail ("rustComputeStats returned status " ++ show rc)
|
|
|
|
-- | Sum a list of integers on the Rust side via @rust_sum_slice@.
--
-- Each element is narrowed to 'CInt' before being copied into a temporary C
-- array, so elements outside the 'CInt' range wrap silently. The 'CInt'
-- result is widened back to 'Int'.
callRustSliceSum :: [Int] -> IO Int
callRustSliceSum values = withArrayLen cValues sumArray
  where
    cValues = map fromIntegral values
    sumArray len ptr = fromIntegral <$> rustSumSlice ptr (fromIntegral len)
|
|
|
|
-- | Checksum a list of byte values on the Rust side via
-- @rust_checksum_bytes@.
--
-- Each element is reduced modulo 256 before being stored as a 'CUChar', so
-- out-of-range values wrap (for example @-1@ becomes @255@). The 'CUInt'
-- checksum is widened back to 'Int'.
callRustByteChecksum :: [Int] -> IO Int
callRustByteChecksum values =
  withArrayLen bytes $ \byteCount bytesPtr -> do
    checksum <- rustChecksumBytes bytesPtr (fromIntegral byteCount)
    pure (fromIntegral checksum)
  where
    bytes = [fromIntegral (v `mod` 256) | v <- values]
|
|
|
|
-- | Build a message on the Rust side via @rust_make_message@ and copy it into
-- a Haskell 'String'.
--
-- @name@ is passed as a temporary NUL-terminated C string. @left@ and
-- @right@ are narrowed to 'CInt', so out-of-range values wrap silently. The
-- returned C string is decoded with 'peekCString' and then released with
-- @rust_free_string@.
--
-- The Rust call runs as the acquire step of 'bracket', which is masked. A
-- string Rust hands back is therefore always paired with its release
-- handler, even if an asynchronous exception arrives right after the call
-- returns.
--
-- Throws an 'IOError' (via 'fail') when Rust returns a null pointer.
callRustMessage :: String -> Int -> Int -> IO String
callRustMessage name left right =
  withCString name $ \namePtr ->
    bracket
      (rustMakeMessage namePtr (fromIntegral left) (fromIntegral right))
      releaseMessage
      readMessage
  where
    -- A null pointer means Rust allocated nothing, so there is nothing to free.
    releaseMessage messagePtr
      | messagePtr == nullPtr = pure ()
      | otherwise = rustFreeString messagePtr

    readMessage messagePtr
      | messagePtr == nullPtr = fail "rustMakeMessage returned a null pointer"
      | otherwise = peekCString messagePtr
|
|
|
|
-- | Ask Rust for @count@ integers starting at @start@ via
-- @rust_make_sequence@ and copy them into a Haskell list.
--
-- @start@ is narrowed to 'CInt', so out-of-range values wrap silently. A
-- negative @count@ is rejected up front, because converting it to 'CSize'
-- would wrap to an enormous length.
--
-- The FFI call and the read of the resulting 'SharedI32Buffer' run as the
-- masked acquire step of 'bracket'. A buffer Rust produced is therefore
-- always returned via @rust_free_i32_buffer@ (pointer, length, capacity),
-- even if copying the elements is interrupted by an exception. A null
-- buffer pointer is treated as an empty sequence.
--
-- Throws an 'IOError' (via 'fail') on a negative @count@ or when Rust
-- reports a non-zero status.
callRustSequence :: Int -> Int -> IO [Int]
callRustSequence start count
  | count < 0 =
      fail ("callRustSequence: count must be non-negative, got " ++ show count)
  | otherwise =
      alloca $ \outBuffer ->
        bracket (acquire outBuffer) release consume
  where
    -- Left carries a failing status; Right carries a buffer Rust now expects back.
    acquire outBuffer = do
      status <- rustMakeSequence (fromIntegral start) (fromIntegral count) outBuffer
      if status /= 0
        then pure (Left status)
        else Right <$> peek outBuffer

    -- Freed unconditionally on success, matching how Rust handed it over.
    release (Left _) = pure ()
    release (Right buffer) =
      rustFreeI32Buffer
        (sharedBufferPtr buffer)
        (sharedBufferLen buffer)
        (sharedBufferCap buffer)

    consume (Left status) =
      fail ("rustMakeSequence returned status " ++ show status)
    consume (Right buffer)
      | sharedBufferPtr buffer == nullPtr = pure []
      | otherwise =
          map fromIntegral
            <$> peekArray
              (fromIntegral (sharedBufferLen buffer))
              (sharedBufferPtr buffer)
|
|
|
|
-- | Ask Rust for a @count@-byte pattern derived from @seed@ via
-- @rust_make_byte_pattern@ and copy it into a list of 'Int's in @[0, 255]@.
--
-- @seed@ is reduced modulo 256 before being passed as a 'CUChar'. A negative
-- @count@ is rejected up front, because converting it to 'CSize' would wrap
-- to an enormous length.
--
-- The FFI call and the read of the resulting 'SharedU8Buffer' run as the
-- masked acquire step of 'bracket'. A buffer Rust produced is therefore
-- always returned via @rust_free_u8_buffer@ (pointer, length, capacity),
-- even if copying the bytes is interrupted by an exception. A null buffer
-- pointer is treated as an empty pattern.
--
-- Throws an 'IOError' (via 'fail') on a negative @count@ or when Rust
-- reports a non-zero status.
callRustBytePattern :: Int -> Int -> IO [Int]
callRustBytePattern seed count
  | count < 0 =
      fail ("callRustBytePattern: count must be non-negative, got " ++ show count)
  | otherwise =
      alloca $ \outBuffer ->
        bracket (acquire outBuffer) release consume
  where
    -- Left carries a failing status; Right carries a buffer Rust now expects back.
    acquire outBuffer = do
      status <-
        rustMakeBytePattern (fromIntegral (seed `mod` 256)) (fromIntegral count) outBuffer
      if status /= 0
        then pure (Left status)
        else Right <$> peek outBuffer

    -- Freed unconditionally on success, matching how Rust handed it over.
    release (Left _) = pure ()
    release (Right buffer) =
      rustFreeU8Buffer
        (sharedByteBufferPtr buffer)
        (sharedByteBufferLen buffer)
        (sharedByteBufferCap buffer)

    consume (Left status) =
      fail ("rustMakeBytePattern returned status " ++ show status)
    consume (Right buffer)
      | sharedByteBufferPtr buffer == nullPtr = pure []
      | otherwise =
          map fromIntegral
            <$> peekArray
              (fromIntegral (sharedByteBufferLen buffer))
              (sharedByteBufferPtr buffer)
|