Add Rust/Haskell slice and buffer interop demos
This commit is contained in:
parent
317f1d0a0c
commit
728322ef0e
@ -4,7 +4,6 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
license = "MIT OR Apache-2.0"
|
||||
publish = false
|
||||
build = "rust/build.rs"
|
||||
|
||||
[lib]
|
||||
path = "rust/lib.rs"
|
||||
|
||||
14
README.md
14
README.md
@ -23,8 +23,20 @@ The boundary is deliberately C-shaped:
|
||||
|
||||
- integers
|
||||
- a shared struct layout
|
||||
- borrowed slices as `ptr + len`
|
||||
- owned buffers as `ptr + len + cap`
|
||||
- owned C strings with explicit free functions on each side
|
||||
|
||||
The current examples cover:
|
||||
|
||||
- scalar values and shared structs
|
||||
- borrowed `i32` slices
|
||||
- owned `i32` buffers returned across the boundary
|
||||
- borrowed raw byte buffers
|
||||
- owned raw byte buffers returned across the boundary
|
||||
|
||||
The byte examples are separate from C strings on purpose. They demonstrate handling embedded zero bytes safely.
|
||||
|
||||
## Why The Build Uses Both Static And Shared Libraries
|
||||
|
||||
This demo uses different library styles for the two directions because that keeps each path simpler:
|
||||
@ -70,6 +82,8 @@ CABAL_DIR=$PWD/../.cabal XDG_STATE_HOME=$PWD/../.cabal/state XDG_CACHE_HOME=$PWD
|
||||
- the ABI boundary must stay simple and explicit
|
||||
- Rust and Haskell do not share ownership rules automatically
|
||||
- struct layout must match on both sides
|
||||
- borrowed data and owned returned buffers need different FFI patterns
|
||||
- raw bytes are not the same thing as C strings
|
||||
- Rust calling Haskell is the harder direction because it must initialize the GHC runtime correctly
|
||||
- build tooling is part of the integration problem, not just an implementation detail
|
||||
|
||||
|
||||
@ -19,10 +19,20 @@ The code keeps the FFI surface small on purpose. The boundary uses:
|
||||
|
||||
- integers
|
||||
- a fixed C-shaped struct
|
||||
- borrowed slices as `ptr + len`
|
||||
- owned returned buffers as `ptr + len + cap`
|
||||
- owned C strings with an explicit free function on each side
|
||||
|
||||
That is enough to demonstrate the main challenges without pulling in code generation or large bindings.
|
||||
|
||||
The current examples include:
|
||||
|
||||
- scalar stats and message passing
|
||||
- borrowed `i32` slices sent across the boundary
|
||||
- owned `i32` buffers returned across the boundary
|
||||
- borrowed raw byte buffers with embedded zero bytes
|
||||
- owned raw byte buffers returned across the boundary
|
||||
|
||||
## Build And Run
|
||||
|
||||
From the repository root:
|
||||
@ -56,6 +66,8 @@ cargo run -- rust-calls-haskell Ada 7 5
|
||||
|
||||
- The boundary must stay C-shaped. Rich Rust and Haskell types do not cross directly.
|
||||
- Strings need explicit ownership rules. Each side exports its own free function.
|
||||
- Borrowed slices and owned returned buffers are different patterns and must be modeled differently.
|
||||
- Raw byte buffers should be treated separately from C strings.
|
||||
- Struct layout must be mirrored carefully on both sides.
|
||||
- Rust calling Haskell is the harder direction because it must initialize and shut down the GHC runtime correctly.
|
||||
- Build order is part of the design. Haskell links against the Rust static library, and Rust loads the Haskell foreign library after Cabal builds it.
|
||||
|
||||
@ -1,7 +1,14 @@
|
||||
module Main (main) where
|
||||
|
||||
import Interop.Shared (Summary (..))
|
||||
import RustClient (callRustMessage, callRustSummary)
|
||||
import RustClient (
|
||||
callRustByteChecksum,
|
||||
callRustBytePattern,
|
||||
callRustMessage,
|
||||
callRustSequence,
|
||||
callRustSliceSum,
|
||||
callRustSummary,
|
||||
)
|
||||
import System.Environment (getArgs)
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
@ -11,11 +18,23 @@ main = do
|
||||
let (name, left, right) = parseArgs args
|
||||
summary <- callRustSummary left right
|
||||
message <- callRustMessage name left right
|
||||
let sliceValues = [2, 4, 6, 8]
|
||||
sliceSum <- callRustSliceSum sliceValues
|
||||
sequenceValues <- callRustSequence left 5
|
||||
let byteValues = [72, 0, 105, 255]
|
||||
byteChecksum <- callRustByteChecksum byteValues
|
||||
bytePattern <- callRustBytePattern left 6
|
||||
|
||||
putStrLn "Haskell -> Rust demo"
|
||||
putStrLn $ "Inputs: name=" ++ name ++ ", left=" ++ show left ++ ", right=" ++ show right
|
||||
putStrLn $ "Stats from Rust: " ++ renderSummary summary
|
||||
putStrLn $ "Message from Rust: " ++ message
|
||||
putStrLn $ "Slice sent to Rust: " ++ show sliceValues
|
||||
putStrLn $ "Rust summed slice to: " ++ show sliceSum
|
||||
putStrLn $ "Vector returned from Rust: " ++ show sequenceValues
|
||||
putStrLn $ "Byte slice sent to Rust: " ++ show byteValues
|
||||
putStrLn $ "Rust checksummed bytes to: " ++ show byteChecksum
|
||||
putStrLn $ "Byte buffer returned from Rust: " ++ show bytePattern
|
||||
|
||||
parseArgs :: [String] -> (String, Int, Int)
|
||||
parseArgs args =
|
||||
|
||||
@ -1,25 +1,48 @@
|
||||
module RustClient (
|
||||
callRustByteChecksum,
|
||||
callRustBytePattern,
|
||||
callRustMessage,
|
||||
callRustSequence,
|
||||
callRustSliceSum,
|
||||
callRustSummary,
|
||||
) where
|
||||
|
||||
import Control.Exception (bracket)
|
||||
import Foreign.C.String (CString, peekCString, withCString)
|
||||
import Foreign.C.Types (CInt (..))
|
||||
import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..))
|
||||
import Foreign.Marshal.Alloc (alloca)
|
||||
import Foreign.Marshal.Array (peekArray, withArrayLen)
|
||||
import Foreign.Ptr (Ptr, nullPtr)
|
||||
import Foreign.Storable (peek)
|
||||
import Interop.Shared (Summary, SharedStats, summaryFromSharedStats)
|
||||
import Interop.Shared (SharedI32Buffer (..), SharedU8Buffer (..), Summary, SharedStats, summaryFromSharedStats)
|
||||
|
||||
foreign import ccall unsafe "rust_compute_stats"
|
||||
rustComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt
|
||||
|
||||
foreign import ccall unsafe "rust_sum_slice"
|
||||
rustSumSlice :: Ptr CInt -> CSize -> IO CInt
|
||||
|
||||
foreign import ccall unsafe "rust_checksum_bytes"
|
||||
rustChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt
|
||||
|
||||
foreign import ccall unsafe "rust_make_message"
|
||||
rustMakeMessage :: CString -> CInt -> CInt -> IO CString
|
||||
|
||||
foreign import ccall unsafe "rust_make_sequence"
|
||||
rustMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt
|
||||
|
||||
foreign import ccall unsafe "rust_make_byte_pattern"
|
||||
rustMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt
|
||||
|
||||
foreign import ccall unsafe "rust_free_string"
|
||||
rustFreeString :: CString -> IO ()
|
||||
|
||||
foreign import ccall unsafe "rust_free_i32_buffer"
|
||||
rustFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO ()
|
||||
|
||||
foreign import ccall unsafe "rust_free_u8_buffer"
|
||||
rustFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO ()
|
||||
|
||||
callRustSummary :: Int -> Int -> IO Summary
|
||||
callRustSummary left right =
|
||||
alloca $ \outStats -> do
|
||||
@ -28,6 +51,16 @@ callRustSummary left right =
|
||||
then fail ("rustComputeStats returned status " ++ show status)
|
||||
else summaryFromSharedStats <$> peek outStats
|
||||
|
||||
callRustSliceSum :: [Int] -> IO Int
|
||||
callRustSliceSum values =
|
||||
withArrayLen (map fromIntegral values) $ \valueCount valuesPtr ->
|
||||
fmap fromIntegral (rustSumSlice valuesPtr (fromIntegral valueCount))
|
||||
|
||||
callRustByteChecksum :: [Int] -> IO Int
|
||||
callRustByteChecksum values =
|
||||
withArrayLen (map (fromIntegral . (`mod` 256)) values) $ \valueCount valuesPtr ->
|
||||
fmap fromIntegral (rustChecksumBytes valuesPtr (fromIntegral valueCount))
|
||||
|
||||
callRustMessage :: String -> Int -> Int -> IO String
|
||||
callRustMessage name left right =
|
||||
withCString name $ \namePtr -> do
|
||||
@ -39,3 +72,33 @@ callRustMessage name left right =
|
||||
(pure messagePtr)
|
||||
rustFreeString
|
||||
peekCString
|
||||
|
||||
callRustSequence :: Int -> Int -> IO [Int]
|
||||
callRustSequence start count =
|
||||
alloca $ \outBuffer -> do
|
||||
status <- rustMakeSequence (fromIntegral start) (fromIntegral count) outBuffer
|
||||
if status /= 0
|
||||
then fail ("rustMakeSequence returned status " ++ show status)
|
||||
else do
|
||||
buffer <- peek outBuffer
|
||||
values <-
|
||||
if sharedBufferPtr buffer == nullPtr
|
||||
then pure []
|
||||
else map fromIntegral <$> peekArray (fromIntegral (sharedBufferLen buffer)) (sharedBufferPtr buffer)
|
||||
rustFreeI32Buffer (sharedBufferPtr buffer) (sharedBufferLen buffer) (sharedBufferCap buffer)
|
||||
pure values
|
||||
|
||||
callRustBytePattern :: Int -> Int -> IO [Int]
|
||||
callRustBytePattern seed count =
|
||||
alloca $ \outBuffer -> do
|
||||
status <- rustMakeBytePattern (fromIntegral (seed `mod` 256)) (fromIntegral count) outBuffer
|
||||
if status /= 0
|
||||
then fail ("rustMakeBytePattern returned status " ++ show status)
|
||||
else do
|
||||
buffer <- peek outBuffer
|
||||
values <-
|
||||
if sharedByteBufferPtr buffer == nullPtr
|
||||
then pure []
|
||||
else map fromIntegral <$> peekArray (fromIntegral (sharedByteBufferLen buffer)) (sharedByteBufferPtr buffer)
|
||||
rustFreeU8Buffer (sharedByteBufferPtr buffer) (sharedByteBufferLen buffer) (sharedByteBufferCap buffer)
|
||||
pure values
|
||||
|
||||
@ -1,15 +1,29 @@
|
||||
module Interop.Exports (
|
||||
hsChecksumBytes,
|
||||
hsComputeStats,
|
||||
hsFreeI32Buffer,
|
||||
hsFreeU8Buffer,
|
||||
hsFreeString,
|
||||
hsMakeBytePattern,
|
||||
hsMakeMessage,
|
||||
hsMakeSequence,
|
||||
hsSumSlice,
|
||||
) where
|
||||
|
||||
import Foreign.C.String (CString, newCString, peekCString)
|
||||
import Foreign.C.Types (CInt (..))
|
||||
import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..))
|
||||
import Foreign.Marshal.Alloc (free)
|
||||
import Foreign.Marshal.Array (newArray, peekArray)
|
||||
import Foreign.Ptr (Ptr, nullPtr)
|
||||
import Foreign.Storable (poke)
|
||||
import Interop.Shared (SharedStats, calculateSummary, formatHaskellMessage, summaryToSharedStats)
|
||||
import Interop.Shared (
|
||||
SharedI32Buffer (..),
|
||||
SharedStats,
|
||||
SharedU8Buffer (..),
|
||||
calculateSummary,
|
||||
formatHaskellMessage,
|
||||
summaryToSharedStats,
|
||||
)
|
||||
|
||||
hsComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt
|
||||
hsComputeStats left right outStats
|
||||
@ -20,6 +34,22 @@ hsComputeStats left right outStats
|
||||
calculateSummary (fromIntegral left) (fromIntegral right)
|
||||
pure 0
|
||||
|
||||
hsSumSlice :: Ptr CInt -> CSize -> IO CInt
|
||||
hsSumSlice valuesPtr valueCount
|
||||
| valuesPtr == nullPtr = pure 0
|
||||
| otherwise = do
|
||||
values <- peekArray (fromIntegral valueCount) valuesPtr
|
||||
pure (sum values)
|
||||
|
||||
hsChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt
|
||||
hsChecksumBytes bytesPtr byteCount
|
||||
| bytesPtr == nullPtr = pure 0
|
||||
| otherwise = do
|
||||
values <- peekArray (fromIntegral byteCount) bytesPtr
|
||||
pure $
|
||||
fromIntegral $
|
||||
sum (map (fromIntegral :: CUChar -> Int) values)
|
||||
|
||||
hsMakeMessage :: CString -> CInt -> CInt -> IO CString
|
||||
hsMakeMessage namePtr left right
|
||||
| namePtr == nullPtr = newCString "Haskell received a null name pointer"
|
||||
@ -28,11 +58,70 @@ hsMakeMessage namePtr left right
|
||||
let summary = calculateSummary (fromIntegral left) (fromIntegral right)
|
||||
newCString (formatHaskellMessage name summary)
|
||||
|
||||
hsMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt
|
||||
hsMakeSequence start count outBuffer
|
||||
| outBuffer == nullPtr = pure 1
|
||||
| count == 0 = do
|
||||
poke outBuffer SharedI32Buffer{sharedBufferPtr = nullPtr, sharedBufferLen = 0, sharedBufferCap = 0}
|
||||
pure 0
|
||||
| otherwise = do
|
||||
let values =
|
||||
[ start + fromIntegral offset
|
||||
| offset <- [0 .. (fromIntegral count :: Int) - 1]
|
||||
]
|
||||
valuesPtr <- newArray values
|
||||
poke
|
||||
outBuffer
|
||||
SharedI32Buffer
|
||||
{ sharedBufferPtr = valuesPtr
|
||||
, sharedBufferLen = count
|
||||
, sharedBufferCap = count
|
||||
}
|
||||
pure 0
|
||||
|
||||
hsMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt
|
||||
hsMakeBytePattern seed count outBuffer
|
||||
| outBuffer == nullPtr = pure 1
|
||||
| count == 0 = do
|
||||
poke outBuffer SharedU8Buffer{sharedByteBufferPtr = nullPtr, sharedByteBufferLen = 0, sharedByteBufferCap = 0}
|
||||
pure 0
|
||||
| otherwise = do
|
||||
let seedValue = fromIntegral seed :: Int
|
||||
let values =
|
||||
[ fromIntegral ((seedValue + offset) `mod` 256) :: CUChar
|
||||
| offset <- [0 .. (fromIntegral count :: Int) - 1]
|
||||
]
|
||||
valuesPtr <- newArray values
|
||||
poke
|
||||
outBuffer
|
||||
SharedU8Buffer
|
||||
{ sharedByteBufferPtr = valuesPtr
|
||||
, sharedByteBufferLen = count
|
||||
, sharedByteBufferCap = count
|
||||
}
|
||||
pure 0
|
||||
|
||||
hsFreeString :: CString -> IO ()
|
||||
hsFreeString ptr
|
||||
| ptr == nullPtr = pure ()
|
||||
| otherwise = free ptr
|
||||
|
||||
hsFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO ()
|
||||
hsFreeI32Buffer ptr _ _
|
||||
| ptr == nullPtr = pure ()
|
||||
| otherwise = free ptr
|
||||
|
||||
hsFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO ()
|
||||
hsFreeU8Buffer ptr _ _
|
||||
| ptr == nullPtr = pure ()
|
||||
| otherwise = free ptr
|
||||
|
||||
foreign export ccall "hs_compute_stats" hsComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt
|
||||
foreign export ccall "hs_sum_slice" hsSumSlice :: Ptr CInt -> CSize -> IO CInt
|
||||
foreign export ccall "hs_checksum_bytes" hsChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt
|
||||
foreign export ccall "hs_make_message" hsMakeMessage :: CString -> CInt -> CInt -> IO CString
|
||||
foreign export ccall "hs_make_sequence" hsMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt
|
||||
foreign export ccall "hs_make_byte_pattern" hsMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt
|
||||
foreign export ccall "hs_free_string" hsFreeString :: CString -> IO ()
|
||||
foreign export ccall "hs_free_i32_buffer" hsFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO ()
|
||||
foreign export ccall "hs_free_u8_buffer" hsFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO ()
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
module Interop.Shared (
|
||||
SharedI32Buffer (..),
|
||||
SharedStats (..),
|
||||
SharedU8Buffer (..),
|
||||
Summary (..),
|
||||
calculateSummary,
|
||||
formatHaskellMessage,
|
||||
@ -7,7 +9,8 @@ module Interop.Shared (
|
||||
summaryToSharedStats,
|
||||
) where
|
||||
|
||||
import Foreign.C.Types (CInt)
|
||||
import Foreign.C.Types (CInt, CSize, CUChar)
|
||||
import Foreign.Ptr (Ptr)
|
||||
import Foreign.Storable (Storable (..), peekByteOff, pokeByteOff)
|
||||
|
||||
data Summary = Summary
|
||||
@ -43,6 +46,20 @@ data SharedStats = SharedStats
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data SharedI32Buffer = SharedI32Buffer
|
||||
{ sharedBufferPtr :: Ptr CInt
|
||||
, sharedBufferLen :: CSize
|
||||
, sharedBufferCap :: CSize
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data SharedU8Buffer = SharedU8Buffer
|
||||
{ sharedByteBufferPtr :: Ptr CUChar
|
||||
, sharedByteBufferLen :: CSize
|
||||
, sharedByteBufferCap :: CSize
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Storable SharedStats where
|
||||
sizeOf _ = fieldSize * 3
|
||||
where
|
||||
@ -65,6 +82,56 @@ instance Storable SharedStats where
|
||||
where
|
||||
fieldSize = sizeOf (undefined :: CInt)
|
||||
|
||||
instance Storable SharedI32Buffer where
|
||||
sizeOf _ = ptrSize + (fieldSize * 2)
|
||||
where
|
||||
ptrSize = sizeOf (undefined :: Ptr CInt)
|
||||
fieldSize = sizeOf (undefined :: CSize)
|
||||
|
||||
alignment _ = alignment (undefined :: Ptr CInt)
|
||||
|
||||
peek ptr =
|
||||
SharedI32Buffer
|
||||
<$> peekByteOff ptr 0
|
||||
<*> peekByteOff ptr ptrSize
|
||||
<*> peekByteOff ptr (ptrSize + fieldSize)
|
||||
where
|
||||
ptrSize = sizeOf (undefined :: Ptr CInt)
|
||||
fieldSize = sizeOf (undefined :: CSize)
|
||||
|
||||
poke ptr value = do
|
||||
pokeByteOff ptr 0 (sharedBufferPtr value)
|
||||
pokeByteOff ptr ptrSize (sharedBufferLen value)
|
||||
pokeByteOff ptr (ptrSize + fieldSize) (sharedBufferCap value)
|
||||
where
|
||||
ptrSize = sizeOf (undefined :: Ptr CInt)
|
||||
fieldSize = sizeOf (undefined :: CSize)
|
||||
|
||||
instance Storable SharedU8Buffer where
|
||||
sizeOf _ = ptrSize + (fieldSize * 2)
|
||||
where
|
||||
ptrSize = sizeOf (undefined :: Ptr CUChar)
|
||||
fieldSize = sizeOf (undefined :: CSize)
|
||||
|
||||
alignment _ = alignment (undefined :: Ptr CUChar)
|
||||
|
||||
peek ptr =
|
||||
SharedU8Buffer
|
||||
<$> peekByteOff ptr 0
|
||||
<*> peekByteOff ptr ptrSize
|
||||
<*> peekByteOff ptr (ptrSize + fieldSize)
|
||||
where
|
||||
ptrSize = sizeOf (undefined :: Ptr CUChar)
|
||||
fieldSize = sizeOf (undefined :: CSize)
|
||||
|
||||
poke ptr value = do
|
||||
pokeByteOff ptr 0 (sharedByteBufferPtr value)
|
||||
pokeByteOff ptr ptrSize (sharedByteBufferLen value)
|
||||
pokeByteOff ptr (ptrSize + fieldSize) (sharedByteBufferCap value)
|
||||
where
|
||||
ptrSize = sizeOf (undefined :: Ptr CUChar)
|
||||
fieldSize = sizeOf (undefined :: CSize)
|
||||
|
||||
summaryToSharedStats :: Summary -> SharedStats
|
||||
summaryToSharedStats summary =
|
||||
SharedStats
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "1.83.0"
|
||||
channel = "1.92.0"
|
||||
components = ["rustfmt", "clippy", "rust-analyzer"]
|
||||
|
||||
262
rust/build.rs
262
rust/build.rs
@ -1,262 +0,0 @@
|
||||
use std::collections::{BTreeSet, HashSet};
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-env-changed=GHC_LIBDIR");
|
||||
println!("cargo:rerun-if-env-changed=GHC_RTS_LIB");
|
||||
println!("cargo:rerun-if-changed=rust/build.rs");
|
||||
|
||||
let libdir = env::var("GHC_LIBDIR").unwrap_or_else(|_| ghc_print_libdir());
|
||||
let explicit_rts = env::var("GHC_RTS_LIB").ok().map(PathBuf::from);
|
||||
let rts_path = explicit_rts.unwrap_or_else(|| find_rts_library(Path::new(&libdir)));
|
||||
let rts_dir = rts_path
|
||||
.parent()
|
||||
.unwrap_or_else(|| Path::new(&libdir))
|
||||
.to_path_buf();
|
||||
let rts_name = rts_path
|
||||
.file_stem()
|
||||
.and_then(|stem| stem.to_str())
|
||||
.map(strip_library_prefix)
|
||||
.unwrap_or_else(|| panic!("failed to resolve GHC RTS library name from {}", rts_path.display()));
|
||||
|
||||
let mut search_dirs = BTreeSet::new();
|
||||
let mut haskell_libs = Vec::new();
|
||||
let mut seen_haskell_libs = HashSet::new();
|
||||
let mut native_libs = BTreeSet::new();
|
||||
|
||||
search_dirs.insert(rts_dir);
|
||||
seen_haskell_libs.insert(rts_name.clone());
|
||||
haskell_libs.push(rts_name);
|
||||
|
||||
for package in ["base", "ghc-prim", "ghc-bignum"] {
|
||||
let info = ghc_pkg_describe(package);
|
||||
for dir in &info.dynamic_library_dirs {
|
||||
search_dirs.insert(dir.clone());
|
||||
}
|
||||
for library in info.hs_libraries {
|
||||
let resolved = resolve_dynamic_hs_library(&library, &info.dynamic_library_dirs);
|
||||
if seen_haskell_libs.insert(resolved.clone()) {
|
||||
haskell_libs.push(resolved);
|
||||
}
|
||||
}
|
||||
|
||||
for library in info.extra_libraries {
|
||||
native_libs.insert(library);
|
||||
}
|
||||
}
|
||||
|
||||
let rts_info = ghc_pkg_describe("rts");
|
||||
for dir in rts_info.dynamic_library_dirs {
|
||||
search_dirs.insert(dir);
|
||||
}
|
||||
for library in rts_info.extra_libraries {
|
||||
native_libs.insert(library);
|
||||
}
|
||||
|
||||
for dir in search_dirs {
|
||||
println!("cargo:rustc-link-search=native={}", dir.display());
|
||||
println!("cargo:rustc-link-arg=-Wl,-rpath,{}", dir.display());
|
||||
}
|
||||
|
||||
println!("cargo:rustc-link-arg=-Wl,--no-as-needed");
|
||||
for library in haskell_libs {
|
||||
println!("cargo:rustc-link-lib=dylib={library}");
|
||||
}
|
||||
println!("cargo:rustc-link-arg=-Wl,--as-needed");
|
||||
|
||||
for library in native_libs {
|
||||
println!("cargo:rustc-link-lib=dylib={library}");
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct PackageInfo {
|
||||
hs_libraries: Vec<String>,
|
||||
extra_libraries: Vec<String>,
|
||||
dynamic_library_dirs: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
fn ghc_print_libdir() -> String {
|
||||
let output = Command::new("ghc")
|
||||
.arg("--print-libdir")
|
||||
.output()
|
||||
.unwrap_or_else(|error| panic!("failed to run `ghc --print-libdir`: {error}"));
|
||||
if !output.status.success() {
|
||||
panic!("`ghc --print-libdir` did not exit successfully");
|
||||
}
|
||||
|
||||
String::from_utf8_lossy(&output.stdout).trim().to_string()
|
||||
}
|
||||
|
||||
fn ghc_pkg_describe(package: &str) -> PackageInfo {
|
||||
let output = Command::new("ghc-pkg")
|
||||
.args(["describe", package])
|
||||
.output()
|
||||
.unwrap_or_else(|error| panic!("failed to run `ghc-pkg describe {package}`: {error}"));
|
||||
if !output.status.success() {
|
||||
panic!("`ghc-pkg describe {package}` did not exit successfully");
|
||||
}
|
||||
|
||||
let description = String::from_utf8_lossy(&output.stdout);
|
||||
parse_package_description(&description)
|
||||
}
|
||||
|
||||
fn parse_package_description(description: &str) -> PackageInfo {
|
||||
let mut info = PackageInfo::default();
|
||||
let mut current_field = String::new();
|
||||
let mut pkgroot = String::new();
|
||||
|
||||
for raw_line in description.lines() {
|
||||
let line = raw_line.trim_end();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if raw_line.starts_with(' ') || raw_line.starts_with('\t') {
|
||||
push_field_values(¤t_field, line.trim(), &mut pkgroot, &mut info);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((field, rest)) = line.split_once(':') {
|
||||
current_field = field.trim().to_string();
|
||||
push_field_values(¤t_field, rest.trim(), &mut pkgroot, &mut info);
|
||||
}
|
||||
}
|
||||
|
||||
if !pkgroot.is_empty() {
|
||||
for dir in &mut info.dynamic_library_dirs {
|
||||
let resolved = dir
|
||||
.display()
|
||||
.to_string()
|
||||
.replace("${pkgroot}", &pkgroot);
|
||||
*dir = PathBuf::from(resolved);
|
||||
}
|
||||
}
|
||||
|
||||
info
|
||||
}
|
||||
|
||||
fn push_field_values(
|
||||
field: &str,
|
||||
values: &str,
|
||||
pkgroot: &mut String,
|
||||
info: &mut PackageInfo,
|
||||
) {
|
||||
if values.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
match field {
|
||||
"pkgroot" => {
|
||||
*pkgroot = values.trim_matches('"').to_string();
|
||||
}
|
||||
"hs-libraries" => {
|
||||
for value in values.split_whitespace() {
|
||||
info.hs_libraries
|
||||
.push(strip_library_prefix(value.trim_matches('"')));
|
||||
}
|
||||
}
|
||||
"extra-libraries" => {
|
||||
for value in values.split_whitespace() {
|
||||
info.extra_libraries.push(value.trim_matches('"').to_string());
|
||||
}
|
||||
}
|
||||
"dynamic-library-dirs" => {
|
||||
for value in values.split_whitespace() {
|
||||
info.dynamic_library_dirs
|
||||
.push(PathBuf::from(value.trim_matches('"')));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_dynamic_hs_library(library: &str, search_dirs: &[PathBuf]) -> String {
|
||||
for dir in search_dirs {
|
||||
let Ok(entries) = fs::read_dir(dir) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let exact_name = format!("lib{library}.so");
|
||||
let versioned_prefix = format!("lib{library}-");
|
||||
if (file_name == exact_name
|
||||
|| (file_name.starts_with(&versioned_prefix) && file_name.ends_with(".so")))
|
||||
&& path.is_file()
|
||||
{
|
||||
if let Some(stem) = path.file_stem().and_then(|value| value.to_str()) {
|
||||
return strip_library_prefix(stem);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
library.to_string()
|
||||
}
|
||||
|
||||
fn find_rts_library(libdir: &Path) -> PathBuf {
|
||||
let mut candidates = Vec::new();
|
||||
walk_for_rts(libdir, &mut candidates);
|
||||
candidates.sort_by_key(|path| rts_priority(path));
|
||||
|
||||
candidates
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap_or_else(|| panic!("failed to locate a GHC RTS library under {}", libdir.display()))
|
||||
}
|
||||
|
||||
fn walk_for_rts(root: &Path, candidates: &mut Vec<PathBuf>) {
|
||||
let Ok(entries) = fs::read_dir(root) else {
|
||||
return;
|
||||
};
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
walk_for_rts(&path, candidates);
|
||||
continue;
|
||||
}
|
||||
|
||||
if is_threaded_rts_library(&path) {
|
||||
candidates.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_threaded_rts_library(path: &Path) -> bool {
|
||||
let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
path.extension().and_then(|ext| ext.to_str()) == Some("so") && file_name.starts_with("libHSrts-")
|
||||
}
|
||||
|
||||
fn rts_priority(path: &Path) -> (u8, String) {
|
||||
let file_name = path
|
||||
.file_name()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
let rank = if file_name.contains("_debug") {
|
||||
3
|
||||
} else if file_name.contains("_thr") {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
(rank, file_name)
|
||||
}
|
||||
|
||||
fn strip_library_prefix(stem: &str) -> String {
|
||||
stem.strip_prefix("lib").unwrap_or(stem).to_string()
|
||||
}
|
||||
230
rust/haskell.rs
230
rust/haskell.rs
@ -1,19 +1,22 @@
|
||||
use crate::interop::SharedStats;
|
||||
use crate::interop::{SharedI32Buffer, SharedStats, SharedU8Buffer};
|
||||
use libloading::Library;
|
||||
use std::env;
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::fs;
|
||||
use std::os::raw::{c_char, c_int};
|
||||
use std::os::raw::{c_char, c_int, c_uint};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn hs_init(argc: *mut c_int, argv: *mut *mut *mut c_char);
|
||||
fn hs_exit();
|
||||
}
|
||||
|
||||
type HsComputeStats = unsafe extern "C" fn(c_int, c_int, *mut SharedStats) -> c_int;
|
||||
type HsSumSlice = unsafe extern "C" fn(*const c_int, usize) -> c_int;
|
||||
type HsChecksumBytes = unsafe extern "C" fn(*const u8, usize) -> c_uint;
|
||||
type HsMakeMessage = unsafe extern "C" fn(*const c_char, c_int, c_int) -> *mut c_char;
|
||||
type HsMakeSequence = unsafe extern "C" fn(c_int, usize, *mut SharedI32Buffer) -> c_int;
|
||||
type HsMakeBytePattern = unsafe extern "C" fn(u8, usize, *mut SharedU8Buffer) -> c_int;
|
||||
type HsFreeString = unsafe extern "C" fn(*mut c_char);
|
||||
type HsFreeI32Buffer = unsafe extern "C" fn(*mut c_int, usize, usize);
|
||||
type HsFreeU8Buffer = unsafe extern "C" fn(*mut u8, usize, usize);
|
||||
type HsInit = unsafe extern "C" fn(*mut c_int, *mut *mut *mut c_char);
|
||||
type HsExit = unsafe extern "C" fn();
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct DemoArgs {
|
||||
@ -25,18 +28,40 @@ pub struct DemoArgs {
|
||||
|
||||
pub fn run_haskell_demo(args: &DemoArgs) -> Result<String, String> {
|
||||
let library_path = resolve_library_path(args.library_path.as_deref())?;
|
||||
let library = unsafe { Library::new(&library_path) }
|
||||
.map_err(|error| format!("failed to load {}: {error}", library_path.display()))?;
|
||||
let runtime = HaskellRuntime::start()?;
|
||||
|
||||
let output = load_and_run(&library_path, args);
|
||||
let output = load_and_run(&library_path, &library, args);
|
||||
|
||||
drop(runtime);
|
||||
output
|
||||
}
|
||||
|
||||
struct HaskellRuntime;
|
||||
struct HaskellRuntime {
|
||||
_rts_library: Library,
|
||||
hs_exit: HsExit,
|
||||
}
|
||||
|
||||
impl HaskellRuntime {
|
||||
fn start() -> Result<Self, String> {
|
||||
let rts_library_path = resolve_rts_library_path()?;
|
||||
let rts_library = unsafe { Library::new(&rts_library_path) }.map_err(|error| {
|
||||
format!(
|
||||
"failed to load GHC RTS library {}: {error}",
|
||||
rts_library_path.display()
|
||||
)
|
||||
})?;
|
||||
let hs_init: HsInit = unsafe {
|
||||
*rts_library
|
||||
.get(b"hs_init\0")
|
||||
.map_err(|error| format!("failed to load hs_init: {error}"))?
|
||||
};
|
||||
let hs_exit: HsExit = unsafe {
|
||||
*rts_library
|
||||
.get(b"hs_exit\0")
|
||||
.map_err(|error| format!("failed to load hs_exit: {error}"))?
|
||||
};
|
||||
let mut argc: c_int = 1;
|
||||
let program_name = CString::new("integrations-hs-runtime")
|
||||
.map_err(|_| "failed to create runtime program name".to_string())?;
|
||||
@ -47,37 +72,67 @@ impl HaskellRuntime {
|
||||
hs_init(&mut argc, &mut argv_ptr);
|
||||
}
|
||||
|
||||
Ok(Self)
|
||||
Ok(Self {
|
||||
_rts_library: rts_library,
|
||||
hs_exit,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HaskellRuntime {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
hs_exit();
|
||||
(self.hs_exit)();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result<String, String> {
|
||||
let library = unsafe { Library::new(library_path) }
|
||||
.map_err(|error| format!("failed to load {}: {error}", library_path.display()))?;
|
||||
|
||||
fn load_and_run(library_path: &Path, library: &Library, args: &DemoArgs) -> Result<String, String> {
|
||||
let compute_stats: HsComputeStats = unsafe {
|
||||
*library
|
||||
.get(b"hs_compute_stats\0")
|
||||
.map_err(|error| format!("failed to load hs_compute_stats: {error}"))?
|
||||
};
|
||||
let sum_slice: HsSumSlice = unsafe {
|
||||
*library
|
||||
.get(b"hs_sum_slice\0")
|
||||
.map_err(|error| format!("failed to load hs_sum_slice: {error}"))?
|
||||
};
|
||||
let checksum_bytes: HsChecksumBytes = unsafe {
|
||||
*library
|
||||
.get(b"hs_checksum_bytes\0")
|
||||
.map_err(|error| format!("failed to load hs_checksum_bytes: {error}"))?
|
||||
};
|
||||
let make_message: HsMakeMessage = unsafe {
|
||||
*library
|
||||
.get(b"hs_make_message\0")
|
||||
.map_err(|error| format!("failed to load hs_make_message: {error}"))?
|
||||
};
|
||||
let make_sequence: HsMakeSequence = unsafe {
|
||||
*library
|
||||
.get(b"hs_make_sequence\0")
|
||||
.map_err(|error| format!("failed to load hs_make_sequence: {error}"))?
|
||||
};
|
||||
let make_byte_pattern: HsMakeBytePattern = unsafe {
|
||||
*library
|
||||
.get(b"hs_make_byte_pattern\0")
|
||||
.map_err(|error| format!("failed to load hs_make_byte_pattern: {error}"))?
|
||||
};
|
||||
let free_string: HsFreeString = unsafe {
|
||||
*library
|
||||
.get(b"hs_free_string\0")
|
||||
.map_err(|error| format!("failed to load hs_free_string: {error}"))?
|
||||
};
|
||||
let free_i32_buffer: HsFreeI32Buffer = unsafe {
|
||||
*library
|
||||
.get(b"hs_free_i32_buffer\0")
|
||||
.map_err(|error| format!("failed to load hs_free_i32_buffer: {error}"))?
|
||||
};
|
||||
let free_u8_buffer: HsFreeU8Buffer = unsafe {
|
||||
*library
|
||||
.get(b"hs_free_u8_buffer\0")
|
||||
.map_err(|error| format!("failed to load hs_free_u8_buffer: {error}"))?
|
||||
};
|
||||
|
||||
let mut stats = SharedStats::default();
|
||||
let status = unsafe { compute_stats(args.left, args.right, &mut stats) };
|
||||
@ -99,8 +154,40 @@ fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result<String, String>
|
||||
free_string(message_ptr);
|
||||
}
|
||||
|
||||
let sample_values: [c_int; 4] = [2, 4, 6, 8];
|
||||
let slice_sum = unsafe { sum_slice(sample_values.as_ptr(), sample_values.len()) };
|
||||
let sample_bytes: [u8; 4] = [72, 0, 105, 255];
|
||||
let byte_checksum = unsafe { checksum_bytes(sample_bytes.as_ptr(), sample_bytes.len()) };
|
||||
|
||||
let mut sequence_buffer = SharedI32Buffer::default();
|
||||
let sequence_status = unsafe { make_sequence(args.left, 5, &mut sequence_buffer) };
|
||||
if sequence_status != 0 {
|
||||
return Err(format!("hs_make_sequence returned status {sequence_status}"));
|
||||
}
|
||||
|
||||
let sequence_values = read_i32_buffer(&sequence_buffer)?;
|
||||
unsafe {
|
||||
free_i32_buffer(
|
||||
sequence_buffer.ptr,
|
||||
sequence_buffer.len,
|
||||
sequence_buffer.cap,
|
||||
);
|
||||
}
|
||||
|
||||
let byte_seed = normalize_byte_seed(args.left);
|
||||
let mut byte_buffer = SharedU8Buffer::default();
|
||||
let byte_status = unsafe { make_byte_pattern(byte_seed, 6, &mut byte_buffer) };
|
||||
if byte_status != 0 {
|
||||
return Err(format!("hs_make_byte_pattern returned status {byte_status}"));
|
||||
}
|
||||
|
||||
let returned_bytes = read_u8_buffer(&byte_buffer)?;
|
||||
unsafe {
|
||||
free_u8_buffer(byte_buffer.ptr, byte_buffer.len, byte_buffer.cap);
|
||||
}
|
||||
|
||||
Ok(format!(
|
||||
"Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}",
|
||||
"Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}\nSlice sent to Haskell: {:?}\nHaskell summed slice to: {}\nVector returned from Haskell: {:?}\nByte slice sent to Haskell: {:?}\nHaskell checksummed bytes to: {}\nByte buffer returned from Haskell: {:?}",
|
||||
library_path.display(),
|
||||
args.name,
|
||||
args.left,
|
||||
@ -109,9 +196,43 @@ fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result<String, String>
|
||||
stats.product,
|
||||
stats.gap,
|
||||
message,
|
||||
sample_values,
|
||||
slice_sum,
|
||||
sequence_values,
|
||||
sample_bytes,
|
||||
byte_checksum,
|
||||
returned_bytes,
|
||||
))
|
||||
}
|
||||
|
||||
fn read_i32_buffer(buffer: &SharedI32Buffer) -> Result<Vec<c_int>, String> {
|
||||
if buffer.ptr.is_null() {
|
||||
if buffer.len == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
return Err("received a null buffer pointer with a non-zero length".to_string());
|
||||
}
|
||||
|
||||
let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) };
|
||||
Ok(values.to_vec())
|
||||
}
|
||||
|
||||
fn read_u8_buffer(buffer: &SharedU8Buffer) -> Result<Vec<u8>, String> {
|
||||
if buffer.ptr.is_null() {
|
||||
if buffer.len == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
return Err("received a null byte buffer pointer with a non-zero length".to_string());
|
||||
}
|
||||
|
||||
let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) };
|
||||
Ok(values.to_vec())
|
||||
}
|
||||
|
||||
fn normalize_byte_seed(value: i32) -> u8 {
|
||||
value.rem_euclid(256) as u8
|
||||
}
|
||||
|
||||
fn resolve_library_path(explicit_path: Option<&str>) -> Result<PathBuf, String> {
|
||||
if let Some(path) = explicit_path {
|
||||
return Ok(PathBuf::from(path));
|
||||
@ -133,6 +254,33 @@ fn resolve_library_path(explicit_path: Option<&str>) -> Result<PathBuf, String>
|
||||
})
|
||||
}
|
||||
|
||||
fn resolve_rts_library_path() -> Result<PathBuf, String> {
|
||||
if let Ok(path) = env::var("GHC_RTS_LIB") {
|
||||
return Ok(PathBuf::from(path));
|
||||
}
|
||||
|
||||
let libdir = if let Ok(path) = env::var("GHC_LIBDIR") {
|
||||
PathBuf::from(path)
|
||||
} else {
|
||||
let output = std::process::Command::new("ghc")
|
||||
.arg("--print-libdir")
|
||||
.output()
|
||||
.map_err(|error| format!("failed to run `ghc --print-libdir`: {error}"))?;
|
||||
if !output.status.success() {
|
||||
return Err("`ghc --print-libdir` did not exit successfully".to_string());
|
||||
}
|
||||
PathBuf::from(String::from_utf8_lossy(&output.stdout).trim())
|
||||
};
|
||||
|
||||
let mut matches = Vec::new();
|
||||
collect_matching_rts_libraries(&libdir, &mut matches)?;
|
||||
matches.sort_by_key(|path| rts_priority(path));
|
||||
matches
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| format!("could not find a GHC RTS library under {}", libdir.display()))
|
||||
}
|
||||
|
||||
fn collect_matching_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result<(), String> {
|
||||
if !root.exists() {
|
||||
return Ok(());
|
||||
@ -156,6 +304,29 @@ fn collect_matching_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_matching_rts_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result<(), String> {
|
||||
if !root.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let entries = fs::read_dir(root)
|
||||
.map_err(|error| format!("failed to read {}: {error}", root.display()))?;
|
||||
for entry in entries {
|
||||
let entry = entry.map_err(|error| format!("failed to inspect {}: {error}", root.display()))?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
collect_matching_rts_libraries(&path, matches)?;
|
||||
continue;
|
||||
}
|
||||
|
||||
if is_rts_library(&path) {
|
||||
matches.push(path);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_haskell_foreign_library(path: &Path) -> bool {
|
||||
let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
|
||||
return false;
|
||||
@ -167,3 +338,30 @@ fn is_haskell_foreign_library(path: &Path) -> bool {
|
||||
|
||||
is_library && file_name.contains("interop_hs")
|
||||
}
|
||||
|
||||
fn is_rts_library(path: &Path) -> bool {
|
||||
let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
path.extension().and_then(|ext| ext.to_str()) == Some("so")
|
||||
&& file_name.starts_with("libHSrts-")
|
||||
}
|
||||
|
||||
fn rts_priority(path: &Path) -> (u8, String) {
|
||||
let file_name = path
|
||||
.file_name()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
let rank = if file_name.contains("_debug") {
|
||||
3
|
||||
} else if file_name.contains("_thr") {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
(rank, file_name)
|
||||
}
|
||||
|
||||
177
rust/interop.rs
177
rust/interop.rs
@ -1,5 +1,5 @@
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::os::raw::{c_char, c_int};
|
||||
use std::os::raw::{c_char, c_int, c_uint};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
@ -9,6 +9,22 @@ pub struct SharedStats {
|
||||
pub gap: c_int,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
pub struct SharedI32Buffer {
|
||||
pub ptr: *mut c_int,
|
||||
pub len: usize,
|
||||
pub cap: usize,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
pub struct SharedU8Buffer {
|
||||
pub ptr: *mut u8,
|
||||
pub len: usize,
|
||||
pub cap: usize,
|
||||
}
|
||||
|
||||
pub fn compute_stats(left: c_int, right: c_int) -> SharedStats {
|
||||
SharedStats {
|
||||
total: left.saturating_add(right),
|
||||
@ -17,6 +33,19 @@ pub fn compute_stats(left: c_int, right: c_int) -> SharedStats {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sum_slice(values: &[c_int]) -> c_int {
|
||||
values
|
||||
.iter()
|
||||
.copied()
|
||||
.fold(0, |accumulator, value| accumulator.saturating_add(value))
|
||||
}
|
||||
|
||||
pub fn checksum_bytes(values: &[u8]) -> c_uint {
|
||||
values.iter().fold(0_u32, |accumulator, value| {
|
||||
accumulator.saturating_add(u32::from(*value))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn make_rust_message(name: &str, left: c_int, right: c_int) -> String {
|
||||
let stats = compute_stats(left, right);
|
||||
format!(
|
||||
@ -25,7 +54,31 @@ pub fn make_rust_message(name: &str, left: c_int, right: c_int) -> String {
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub fn make_sequence_values(start: c_int, count: usize) -> Vec<c_int> {
|
||||
(0..count)
|
||||
.map(|index| {
|
||||
let offset = match c_int::try_from(index) {
|
||||
Ok(value) => value,
|
||||
Err(_) => c_int::MAX,
|
||||
};
|
||||
start.saturating_add(offset)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn make_byte_pattern(seed: u8, count: usize) -> Vec<u8> {
|
||||
(0..count)
|
||||
.map(|index| {
|
||||
let offset = match u8::try_from(index) {
|
||||
Ok(value) => value,
|
||||
Err(_) => u8::MAX,
|
||||
};
|
||||
seed.wrapping_add(offset)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_compute_stats(
|
||||
left: c_int,
|
||||
right: c_int,
|
||||
@ -35,11 +88,23 @@ pub unsafe extern "C" fn rust_compute_stats(
|
||||
return 1;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
out_stats.write(compute_stats(left, right));
|
||||
}
|
||||
0
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_sum_slice(ptr: *const c_int, len: usize) -> c_int {
|
||||
if ptr.is_null() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let values = unsafe { std::slice::from_raw_parts(ptr, len) };
|
||||
sum_slice(values)
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_make_message(
|
||||
name: *const c_char,
|
||||
left: c_int,
|
||||
@ -49,17 +114,99 @@ pub unsafe extern "C" fn rust_make_message(
|
||||
return string_into_raw("Rust received a null name pointer".to_string());
|
||||
}
|
||||
|
||||
let name = CStr::from_ptr(name).to_string_lossy();
|
||||
let name = unsafe { CStr::from_ptr(name) }.to_string_lossy();
|
||||
string_into_raw(make_rust_message(name.as_ref(), left, right))
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_checksum_bytes(ptr: *const u8, len: usize) -> c_uint {
|
||||
if ptr.is_null() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let values = unsafe { std::slice::from_raw_parts(ptr, len) };
|
||||
checksum_bytes(values)
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_make_sequence(
|
||||
start: c_int,
|
||||
count: usize,
|
||||
out_buffer: *mut SharedI32Buffer,
|
||||
) -> c_int {
|
||||
if out_buffer.is_null() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
let mut values = make_sequence_values(start, count);
|
||||
let buffer = SharedI32Buffer {
|
||||
ptr: values.as_mut_ptr(),
|
||||
len: values.len(),
|
||||
cap: values.capacity(),
|
||||
};
|
||||
|
||||
std::mem::forget(values);
|
||||
unsafe {
|
||||
out_buffer.write(buffer);
|
||||
}
|
||||
0
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_make_byte_pattern(
|
||||
seed: u8,
|
||||
count: usize,
|
||||
out_buffer: *mut SharedU8Buffer,
|
||||
) -> c_int {
|
||||
if out_buffer.is_null() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
let mut values = make_byte_pattern(seed, count);
|
||||
let buffer = SharedU8Buffer {
|
||||
ptr: values.as_mut_ptr(),
|
||||
len: values.len(),
|
||||
cap: values.capacity(),
|
||||
};
|
||||
|
||||
std::mem::forget(values);
|
||||
unsafe {
|
||||
out_buffer.write(buffer);
|
||||
}
|
||||
0
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_free_string(ptr: *mut c_char) {
|
||||
if ptr.is_null() {
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
drop(CString::from_raw(ptr));
|
||||
}
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_free_i32_buffer(ptr: *mut c_int, len: usize, cap: usize) {
|
||||
if ptr.is_null() {
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
drop(Vec::from_raw_parts(ptr, len, cap));
|
||||
}
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn rust_free_u8_buffer(ptr: *mut u8, len: usize, cap: usize) {
|
||||
if ptr.is_null() {
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
drop(Vec::from_raw_parts(ptr, len, cap));
|
||||
}
|
||||
}
|
||||
|
||||
fn string_into_raw(message: String) -> *mut c_char {
|
||||
@ -95,4 +242,24 @@ mod tests {
|
||||
assert!(message.contains("product=35"));
|
||||
assert!(message.contains("gap=2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_slice_handles_multiple_values() {
|
||||
assert_eq!(sum_slice(&[2, 4, 6, 8]), 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn make_sequence_values_builds_contiguous_output() {
|
||||
assert_eq!(make_sequence_values(7, 5), vec![7, 8, 9, 10, 11]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn checksum_bytes_handles_embedded_zeroes() {
|
||||
assert_eq!(checksum_bytes(&[72, 0, 105, 255]), 432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn make_byte_pattern_wraps_like_bytes() {
|
||||
assert_eq!(make_byte_pattern(254, 4), vec![254, 255, 0, 1]);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user