-- | Haskell bindings to the Rust FFI library.
--
-- Each @callRust*@ function marshals Haskell values into C types, calls the
-- corresponding @rust_*@ symbol, and converts the result back to Haskell
-- types. Strings and buffers that Rust hands back are returned to Rust through
-- the matching @rust_free_*@ function inside a 'bracket', so they are released
-- even if copying them into Haskell throws.
--
-- Failures reported by Rust (a non-zero status or a null pointer) and invalid
-- arguments (negative counts) are raised in 'IO' via 'fail', i.e. as a user
-- 'IOError'.
module RustClient
  ( callRustByteChecksum
  , callRustBytePattern
  , callRustMessage
  , callRustSequence
  , callRustSliceSum
  , callRustSummary
  ) where

import Control.Exception (bracket)
import Foreign.C.String (CString, peekCString, withCString)
import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..))
import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (peekArray, withArrayLen)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable (peek)
import qualified GHC.Foreign as GHCForeign
import System.IO (utf8)

import Interop.Shared (SharedI32Buffer (..), SharedU8Buffer (..), Summary, SharedStats, summaryFromSharedStats)

-- All imports below are @unsafe@ calls: cheaper than @safe@ ones, but they
-- cannot call back into Haskell and they block the calling capability (and
-- GC) until they return. NOTE(review): this assumes every Rust function is
-- short-running, even for large counts — confirm.

foreign import ccall unsafe "rust_compute_stats"
  rustComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt

foreign import ccall unsafe "rust_sum_slice"
  rustSumSlice :: Ptr CInt -> CSize -> IO CInt

foreign import ccall unsafe "rust_checksum_bytes"
  rustChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt

foreign import ccall unsafe "rust_make_message"
  rustMakeMessage :: CString -> CInt -> CInt -> IO CString

foreign import ccall unsafe "rust_make_sequence"
  rustMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt

foreign import ccall unsafe "rust_make_byte_pattern"
  rustMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt

foreign import ccall unsafe "rust_free_string"
  rustFreeString :: CString -> IO ()

foreign import ccall unsafe "rust_free_i32_buffer"
  rustFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO ()

foreign import ccall unsafe "rust_free_u8_buffer"
  rustFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO ()

-- | Convert a caller-supplied element count to 'CSize', rejecting negative
-- values. A plain 'fromIntegral' would silently wrap e.g. @-1@ to
-- @maxBound :: CSize@ before the value ever reaches Rust.
toCount :: String -> Int -> IO CSize
toCount caller count
  | count < 0 = fail (caller ++ ": negative count " ++ show count)
  | otherwise = pure (fromIntegral count)

-- | Call @rust_compute_stats@ with @left@ and @right@ and convert the
-- 'SharedStats' it writes into a 'Summary'.
--
-- Both arguments are converted to 'CInt' with 'fromIntegral', so values
-- outside the 'CInt' range are truncated. Throws a user 'IOError' if Rust
-- returns a non-zero status.
callRustSummary :: Int -> Int -> IO Summary
callRustSummary left right =
  alloca $ \outStats -> do
    status <- rustComputeStats (fromIntegral left) (fromIntegral right) outStats
    if status /= 0
      then fail ("rustComputeStats returned status " ++ show status)
      else summaryFromSharedStats <$> peek outStats

-- | Pass @values@ to @rust_sum_slice@ as a temporary C array and return its
-- result.
--
-- Elements are converted to 'CInt' with 'fromIntegral' (out-of-range values
-- are truncated). The array is only borrowed for the duration of the call.
callRustSliceSum :: [Int] -> IO Int
callRustSliceSum values =
  withArrayLen (map fromIntegral values) $ \valueCount valuesPtr ->
    fromIntegral <$> rustSumSlice valuesPtr (fromIntegral valueCount)

-- | Pass @values@ to @rust_checksum_bytes@ as a temporary byte array and
-- return its result.
--
-- Each element is first reduced with @`mod` 256@ (Haskell's 'mod', so
-- negative inputs land in 0..255) before being stored as a 'CUChar'.
callRustByteChecksum :: [Int] -> IO Int
callRustByteChecksum values =
  withArrayLen (map (fromIntegral . (`mod` 256)) values) $ \valueCount valuesPtr ->
    fromIntegral <$> rustChecksumBytes valuesPtr (fromIntegral valueCount)

-- | Call @rust_make_message@ with @name@, @left@ and @right@ and return the
-- string it produces.
--
-- @name@ is encoded, and the result decoded, as UTF-8 (the encoding of Rust
-- strings) rather than the process locale, so non-ASCII text round-trips
-- regardless of locale. The returned C string is always released with
-- 'rustFreeString'. Throws a user 'IOError' if Rust returns a null pointer.
callRustMessage :: String -> Int -> Int -> IO String
callRustMessage name left right =
  GHCForeign.withCString utf8 name $ \namePtr ->
    -- The FFI call and the null check run in the (masked) acquire step, so a
    -- non-null result is always paired with 'rustFreeString'.
    bracket (produce namePtr) rustFreeString (GHCForeign.peekCString utf8)
  where
    produce namePtr = do
      messagePtr <- rustMakeMessage namePtr (fromIntegral left) (fromIntegral right)
      if messagePtr == nullPtr
        then fail "rustMakeMessage returned a null pointer"
        else pure messagePtr

-- | Call @rust_make_sequence@ with @start@ and @count@ and copy the
-- resulting 32-bit buffer into a list.
--
-- Throws a user 'IOError' if @count@ is negative or Rust returns a non-zero
-- status. On success the buffer is released with 'rustFreeI32Buffer' even if
-- copying its elements throws. A null buffer pointer yields @[]@.
callRustSequence :: Int -> Int -> IO [Int]
callRustSequence start count = do
  len <- toCount "callRustSequence" count
  alloca $ \outBuffer -> bracket (produce len outBuffer) release copyOut
  where
    produce len outBuffer = do
      status <- rustMakeSequence (fromIntegral start) len outBuffer
      if status /= 0
        then fail ("rustMakeSequence returned status " ++ show status)
        else peek outBuffer
    -- NOTE(review): called even when the pointer is null (as before);
    -- presumably the Rust side tolerates that — confirm.
    release buffer =
      rustFreeI32Buffer
        (sharedBufferPtr buffer)
        (sharedBufferLen buffer)
        (sharedBufferCap buffer)
    copyOut buffer
      | sharedBufferPtr buffer == nullPtr = pure []
      | otherwise =
          map fromIntegral
            <$> peekArray (fromIntegral (sharedBufferLen buffer)) (sharedBufferPtr buffer)

-- | Call @rust_make_byte_pattern@ with @seed@ (reduced @`mod` 256@) and
-- @count@ and copy the resulting byte buffer into a list.
--
-- Throws a user 'IOError' if @count@ is negative or Rust returns a non-zero
-- status. On success the buffer is released with 'rustFreeU8Buffer' even if
-- copying its elements throws. A null buffer pointer yields @[]@.
callRustBytePattern :: Int -> Int -> IO [Int]
callRustBytePattern seed count = do
  len <- toCount "callRustBytePattern" count
  alloca $ \outBuffer -> bracket (produce len outBuffer) release copyOut
  where
    produce len outBuffer = do
      status <- rustMakeBytePattern (fromIntegral (seed `mod` 256)) len outBuffer
      if status /= 0
        then fail ("rustMakeBytePattern returned status " ++ show status)
        else peek outBuffer
    -- NOTE(review): called even when the pointer is null (as before);
    -- presumably the Rust side tolerates that — confirm.
    release buffer =
      rustFreeU8Buffer
        (sharedByteBufferPtr buffer)
        (sharedByteBufferLen buffer)
        (sharedByteBufferCap buffer)
    copyOut buffer
      | sharedByteBufferPtr buffer == nullPtr = pure []
      | otherwise =
          map fromIntegral
            <$> peekArray (fromIntegral (sharedByteBufferLen buffer)) (sharedByteBufferPtr buffer)