diff --git a/Cargo.toml b/Cargo.toml index 5a058af..dbf82e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ version = "0.1.0" edition = "2024" license = "MIT OR Apache-2.0" publish = false -build = "rust/build.rs" [lib] path = "rust/lib.rs" diff --git a/README.md b/README.md index 33a6528..e32ce3e 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,20 @@ The boundary is deliberately C-shaped: - integers - a shared struct layout +- borrowed slices as `ptr + len` +- owned buffers as `ptr + len + cap` - owned C strings with explicit free functions on each side +The current examples cover: + +- scalar values and shared structs +- borrowed `i32` slices +- owned `i32` buffers returned across the boundary +- borrowed raw byte buffers +- owned raw byte buffers returned across the boundary + +The byte examples are separate from C strings on purpose. They demonstrate handling embedded zero bytes safely. + ## Why The Build Uses Both Static And Shared Libraries This demo uses different library styles for the two directions because that keeps each path simpler: @@ -70,6 +82,8 @@ CABAL_DIR=$PWD/../.cabal XDG_STATE_HOME=$PWD/../.cabal/state XDG_CACHE_HOME=$PWD - the ABI boundary must stay simple and explicit - Rust and Haskell do not share ownership rules automatically - struct layout must match on both sides +- borrowed data and owned returned buffers need different FFI patterns +- raw bytes are not the same thing as C strings - Rust calling Haskell is the harder direction because it must initialize the GHC runtime correctly - build tooling is part of the integration problem, not just an implementation detail diff --git a/haskell/README.md b/haskell/README.md index 0bcbd02..f869e06 100644 --- a/haskell/README.md +++ b/haskell/README.md @@ -19,10 +19,20 @@ The code keeps the FFI surface small on purpose. The boundary uses: - integers - a fixed C-shaped struct +- borrowed slices as `ptr + len` +- owned returned buffers as `ptr + len + cap` - owned C strings with an explicit free function on each side That is enough to demonstrate the main challenges without pulling in code generation or large bindings. +The current examples include: + +- scalar stats and message passing +- borrowed `i32` slices sent across the boundary +- owned `i32` buffers returned across the boundary +- borrowed raw byte buffers with embedded zero bytes +- owned raw byte buffers returned across the boundary + ## Build And Run From the repository root: @@ -56,6 +66,8 @@ cargo run -- rust-calls-haskell Ada 7 5 - The boundary must stay C-shaped. Rich Rust and Haskell types do not cross directly. - Strings need explicit ownership rules. Each side exports its own free function. +- Borrowed slices and owned returned buffers are different patterns and must be modeled differently. +- Raw byte buffers should be treated separately from C strings. - Struct layout must be mirrored carefully on both sides. - Rust calling Haskell is the harder direction because it must initialize and shut down the GHC runtime correctly. - Build order is part of the design. Haskell links against the Rust static library, and Rust loads the Haskell foreign library after Cabal builds it. diff --git a/haskell/app/Main.hs b/haskell/app/Main.hs index eaa8e7d..3f660a3 100644 --- a/haskell/app/Main.hs +++ b/haskell/app/Main.hs @@ -1,7 +1,14 @@ module Main (main) where import Interop.Shared (Summary (..)) -import RustClient (callRustMessage, callRustSummary) +import RustClient ( + callRustByteChecksum, + callRustBytePattern, + callRustMessage, + callRustSequence, + callRustSliceSum, + callRustSummary, + ) import System.Environment (getArgs) import Text.Read (readMaybe) @@ -11,11 +18,23 @@ main = do let (name, left, right) = parseArgs args summary <- callRustSummary left right message <- callRustMessage name left right + let sliceValues = [2, 4, 6, 8] + sliceSum <- callRustSliceSum sliceValues + sequenceValues <- callRustSequence left 5 + let byteValues = [72, 0, 105, 255] + byteChecksum <- callRustByteChecksum byteValues + bytePattern <- callRustBytePattern left 6 putStrLn "Haskell -> Rust demo" putStrLn $ "Inputs: name=" ++ name ++ ", left=" ++ show left ++ ", right=" ++ show right putStrLn $ "Stats from Rust: " ++ renderSummary summary putStrLn $ "Message from Rust: " ++ message + putStrLn $ "Slice sent to Rust: " ++ show sliceValues + putStrLn $ "Rust summed slice to: " ++ show sliceSum + putStrLn $ "Vector returned from Rust: " ++ show sequenceValues + putStrLn $ "Byte slice sent to Rust: " ++ show byteValues + putStrLn $ "Rust checksummed bytes to: " ++ show byteChecksum + putStrLn $ "Byte buffer returned from Rust: " ++ show bytePattern parseArgs :: [String] -> (String, Int, Int) parseArgs args = diff --git a/haskell/app/RustClient.hs b/haskell/app/RustClient.hs index ac2a09a..d2de5f7 100644 --- a/haskell/app/RustClient.hs +++ b/haskell/app/RustClient.hs @@ -1,25 +1,48 @@ module RustClient ( + callRustByteChecksum, + callRustBytePattern, callRustMessage, + callRustSequence, + callRustSliceSum, callRustSummary, ) where import Control.Exception (bracket) import Foreign.C.String (CString, peekCString, withCString) -import Foreign.C.Types (CInt (..)) +import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..)) import Foreign.Marshal.Alloc (alloca) +import Foreign.Marshal.Array (peekArray, withArrayLen) import Foreign.Ptr (Ptr, nullPtr) import Foreign.Storable (peek) -import Interop.Shared (Summary, SharedStats, summaryFromSharedStats) +import Interop.Shared (SharedI32Buffer (..), SharedU8Buffer (..), Summary, SharedStats, summaryFromSharedStats) foreign import ccall unsafe "rust_compute_stats" rustComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt +foreign import ccall unsafe "rust_sum_slice" + rustSumSlice :: Ptr CInt -> CSize -> IO CInt + +foreign import ccall unsafe "rust_checksum_bytes" + rustChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt + foreign import ccall unsafe "rust_make_message" rustMakeMessage :: CString -> CInt -> CInt -> IO CString +foreign import ccall unsafe "rust_make_sequence" + rustMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt + +foreign import ccall unsafe "rust_make_byte_pattern" + rustMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt + foreign import ccall unsafe "rust_free_string" rustFreeString :: CString -> IO () +foreign import ccall unsafe "rust_free_i32_buffer" + rustFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO () + +foreign import ccall unsafe "rust_free_u8_buffer" + rustFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO () + callRustSummary :: Int -> Int -> IO Summary callRustSummary left right = alloca $ \outStats -> do @@ -28,6 +51,16 @@ callRustSummary left right = then fail ("rustComputeStats returned status " ++ show status) else summaryFromSharedStats <$> peek outStats +callRustSliceSum :: [Int] -> IO Int +callRustSliceSum values = + withArrayLen (map fromIntegral values) $ \valueCount valuesPtr -> + fmap fromIntegral (rustSumSlice valuesPtr (fromIntegral valueCount)) + +callRustByteChecksum :: [Int] -> IO Int +callRustByteChecksum values = + withArrayLen (map (fromIntegral . (`mod` 256)) values) $ \valueCount valuesPtr -> + fmap fromIntegral (rustChecksumBytes valuesPtr (fromIntegral valueCount)) + callRustMessage :: String -> Int -> Int -> IO String callRustMessage name left right = withCString name $ \namePtr -> do @@ -39,3 +72,33 @@ callRustMessage name left right = (pure messagePtr) rustFreeString peekCString + +callRustSequence :: Int -> Int -> IO [Int] +callRustSequence start count = + alloca $ \outBuffer -> do + status <- rustMakeSequence (fromIntegral start) (fromIntegral count) outBuffer + if status /= 0 + then fail ("rustMakeSequence returned status " ++ show status) + else do + buffer <- peek outBuffer + values <- + if sharedBufferPtr buffer == nullPtr + then pure [] + else map fromIntegral <$> peekArray (fromIntegral (sharedBufferLen buffer)) (sharedBufferPtr buffer) + rustFreeI32Buffer (sharedBufferPtr buffer) (sharedBufferLen buffer) (sharedBufferCap buffer) + pure values + +callRustBytePattern :: Int -> Int -> IO [Int] +callRustBytePattern seed count = + alloca $ \outBuffer -> do + status <- rustMakeBytePattern (fromIntegral (seed `mod` 256)) (fromIntegral count) outBuffer + if status /= 0 + then fail ("rustMakeBytePattern returned status " ++ show status) + else do + buffer <- peek outBuffer + values <- + if sharedByteBufferPtr buffer == nullPtr + then pure [] + else map fromIntegral <$> peekArray (fromIntegral (sharedByteBufferLen buffer)) (sharedByteBufferPtr buffer) + rustFreeU8Buffer (sharedByteBufferPtr buffer) (sharedByteBufferLen buffer) (sharedByteBufferCap buffer) + pure values diff --git a/haskell/src/Interop/Exports.hs b/haskell/src/Interop/Exports.hs index 2a72da6..a8746ed 100644 --- a/haskell/src/Interop/Exports.hs +++ b/haskell/src/Interop/Exports.hs @@ -1,15 +1,29 @@ module Interop.Exports ( + hsChecksumBytes, hsComputeStats, + hsFreeI32Buffer, + hsFreeU8Buffer, hsFreeString, + hsMakeBytePattern, hsMakeMessage, + hsMakeSequence, + hsSumSlice, ) where import Foreign.C.String (CString, newCString, peekCString) -import Foreign.C.Types (CInt (..)) +import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..)) import Foreign.Marshal.Alloc (free) +import Foreign.Marshal.Array (newArray, peekArray) import Foreign.Ptr (Ptr, nullPtr) import Foreign.Storable (poke) -import Interop.Shared (SharedStats, calculateSummary, formatHaskellMessage, summaryToSharedStats) +import Interop.Shared ( + SharedI32Buffer (..), + SharedStats, + SharedU8Buffer (..), + calculateSummary, + formatHaskellMessage, + summaryToSharedStats, + ) hsComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt hsComputeStats left right outStats @@ -20,6 +34,22 @@ hsComputeStats left right outStats calculateSummary (fromIntegral left) (fromIntegral right) pure 0 +hsSumSlice :: Ptr CInt -> CSize -> IO CInt +hsSumSlice valuesPtr valueCount + | valuesPtr == nullPtr = pure 0 + | otherwise = do + values <- peekArray (fromIntegral valueCount) valuesPtr + pure (sum values) + +hsChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt +hsChecksumBytes bytesPtr byteCount + | bytesPtr == nullPtr = pure 0 + | otherwise = do + values <- peekArray (fromIntegral byteCount) bytesPtr + pure $ + fromIntegral $ + sum (map (fromIntegral :: CUChar -> Int) values) + hsMakeMessage :: CString -> CInt -> CInt -> IO CString hsMakeMessage namePtr left right | namePtr == nullPtr = newCString "Haskell received a null name pointer" @@ -28,11 +58,70 @@ hsMakeMessage namePtr left right let summary = calculateSummary (fromIntegral left) (fromIntegral right) newCString (formatHaskellMessage name summary) +hsMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt +hsMakeSequence start count outBuffer + | outBuffer == nullPtr = pure 1 + | count == 0 = do + poke outBuffer SharedI32Buffer{sharedBufferPtr = nullPtr, sharedBufferLen = 0, sharedBufferCap = 0} + pure 0 + | otherwise = do + let values = + [ start + fromIntegral offset + | offset <- [0 .. (fromIntegral count :: Int) - 1] + ] + valuesPtr <- newArray values + poke + outBuffer + SharedI32Buffer + { sharedBufferPtr = valuesPtr + , sharedBufferLen = count + , sharedBufferCap = count + } + pure 0 + +hsMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt +hsMakeBytePattern seed count outBuffer + | outBuffer == nullPtr = pure 1 + | count == 0 = do + poke outBuffer SharedU8Buffer{sharedByteBufferPtr = nullPtr, sharedByteBufferLen = 0, sharedByteBufferCap = 0} + pure 0 + | otherwise = do + let seedValue = fromIntegral seed :: Int + let values = + [ fromIntegral ((seedValue + offset) `mod` 256) :: CUChar + | offset <- [0 .. (fromIntegral count :: Int) - 1] + ] + valuesPtr <- newArray values + poke + outBuffer + SharedU8Buffer + { sharedByteBufferPtr = valuesPtr + , sharedByteBufferLen = count + , sharedByteBufferCap = count + } + pure 0 + hsFreeString :: CString -> IO () hsFreeString ptr | ptr == nullPtr = pure () | otherwise = free ptr +hsFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO () +hsFreeI32Buffer ptr _ _ + | ptr == nullPtr = pure () + | otherwise = free ptr + +hsFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO () +hsFreeU8Buffer ptr _ _ + | ptr == nullPtr = pure () + | otherwise = free ptr + foreign export ccall "hs_compute_stats" hsComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt +foreign export ccall "hs_sum_slice" hsSumSlice :: Ptr CInt -> CSize -> IO CInt +foreign export ccall "hs_checksum_bytes" hsChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt foreign export ccall "hs_make_message" hsMakeMessage :: CString -> CInt -> CInt -> IO CString +foreign export ccall "hs_make_sequence" hsMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt +foreign export ccall "hs_make_byte_pattern" hsMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt foreign export ccall "hs_free_string" hsFreeString :: CString -> IO () +foreign export ccall "hs_free_i32_buffer" hsFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO () +foreign export ccall "hs_free_u8_buffer" hsFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO () diff --git a/haskell/src/Interop/Shared.hs b/haskell/src/Interop/Shared.hs index 869bd22..f1b5ba5 100644 --- a/haskell/src/Interop/Shared.hs +++ b/haskell/src/Interop/Shared.hs @@ -1,5 +1,7 @@ module Interop.Shared ( + SharedI32Buffer (..), SharedStats (..), + SharedU8Buffer (..), Summary (..), calculateSummary, formatHaskellMessage, @@ -7,7 +9,8 @@ module Interop.Shared ( summaryToSharedStats, ) where -import Foreign.C.Types (CInt) +import Foreign.C.Types (CInt, CSize, CUChar) +import Foreign.Ptr (Ptr) import Foreign.Storable (Storable (..), peekByteOff, pokeByteOff) data Summary = Summary @@ -43,6 +46,20 @@ data SharedStats = SharedStats } deriving (Eq, Show) +data SharedI32Buffer = SharedI32Buffer + { sharedBufferPtr :: Ptr CInt + , sharedBufferLen :: CSize + , sharedBufferCap :: CSize + } + deriving (Eq, Show) + +data SharedU8Buffer = SharedU8Buffer + { sharedByteBufferPtr :: Ptr CUChar + , sharedByteBufferLen :: CSize + , sharedByteBufferCap :: CSize + } + deriving (Eq, Show) + instance Storable SharedStats where sizeOf _ = fieldSize * 3 where @@ -65,6 +82,56 @@ instance Storable SharedStats where where fieldSize = sizeOf (undefined :: CInt) +instance Storable SharedI32Buffer where + sizeOf _ = ptrSize + (fieldSize * 2) + where + ptrSize = sizeOf (undefined :: Ptr CInt) + fieldSize = sizeOf (undefined :: CSize) + + alignment _ = alignment (undefined :: Ptr CInt) + + peek ptr = + SharedI32Buffer + <$> peekByteOff ptr 0 + <*> peekByteOff ptr ptrSize + <*> peekByteOff ptr (ptrSize + fieldSize) + where + ptrSize = sizeOf (undefined :: Ptr CInt) + fieldSize = sizeOf (undefined :: CSize) + + poke ptr value = do + pokeByteOff ptr 0 (sharedBufferPtr value) + pokeByteOff ptr ptrSize (sharedBufferLen value) + pokeByteOff ptr (ptrSize + fieldSize) (sharedBufferCap value) + where + ptrSize = sizeOf (undefined :: Ptr CInt) + fieldSize = sizeOf (undefined :: CSize) + +instance Storable SharedU8Buffer where + sizeOf _ = ptrSize + (fieldSize * 2) + where + ptrSize = sizeOf (undefined :: Ptr CUChar) + fieldSize = sizeOf (undefined :: CSize) + + alignment _ = alignment (undefined :: Ptr CUChar) + + peek ptr = + SharedU8Buffer + <$> peekByteOff ptr 0 + <*> peekByteOff ptr ptrSize + <*> peekByteOff ptr (ptrSize + fieldSize) + where + ptrSize = sizeOf (undefined :: Ptr CUChar) + fieldSize = sizeOf (undefined :: CSize) + + poke ptr value = do + pokeByteOff ptr 0 (sharedByteBufferPtr value) + pokeByteOff ptr ptrSize (sharedByteBufferLen value) + pokeByteOff ptr (ptrSize + fieldSize) (sharedByteBufferCap value) + where + ptrSize = sizeOf (undefined :: Ptr CUChar) + fieldSize = sizeOf (undefined :: CSize) + summaryToSharedStats :: Summary -> SharedStats summaryToSharedStats summary = SharedStats diff --git a/rust-toolchain.toml b/rust-toolchain.toml index cb83ddb..db8c8a6 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.83.0" +channel = "1.92.0" components = ["rustfmt", "clippy", "rust-analyzer"] diff --git a/rust/build.rs b/rust/build.rs deleted file mode 100644 index 021a2b6..0000000 --- a/rust/build.rs +++ /dev/null @@ -1,262 +0,0 @@ -use std::collections::{BTreeSet, HashSet}; -use std::env; -use std::fs; -use std::path::{Path, PathBuf}; -use std::process::Command; - -fn main() { - println!("cargo:rerun-if-env-changed=GHC_LIBDIR"); - println!("cargo:rerun-if-env-changed=GHC_RTS_LIB"); - println!("cargo:rerun-if-changed=rust/build.rs"); - - let libdir = env::var("GHC_LIBDIR").unwrap_or_else(|_| ghc_print_libdir()); - let explicit_rts = env::var("GHC_RTS_LIB").ok().map(PathBuf::from); - let rts_path = explicit_rts.unwrap_or_else(|| find_rts_library(Path::new(&libdir))); - let rts_dir = rts_path - .parent() - .unwrap_or_else(|| Path::new(&libdir)) - .to_path_buf(); - let rts_name = rts_path - .file_stem() - .and_then(|stem| stem.to_str()) - .map(strip_library_prefix) - .unwrap_or_else(|| panic!("failed to resolve GHC RTS library name from {}", rts_path.display())); - - let mut search_dirs = BTreeSet::new(); - let mut haskell_libs = Vec::new(); - let mut seen_haskell_libs = HashSet::new(); - let mut native_libs = BTreeSet::new(); - - search_dirs.insert(rts_dir); - seen_haskell_libs.insert(rts_name.clone()); - haskell_libs.push(rts_name); - - for package in ["base", "ghc-prim", "ghc-bignum"] { - let info = ghc_pkg_describe(package); - for dir in &info.dynamic_library_dirs { - search_dirs.insert(dir.clone()); - } - for library in info.hs_libraries { - let resolved = resolve_dynamic_hs_library(&library, &info.dynamic_library_dirs); - if seen_haskell_libs.insert(resolved.clone()) { - haskell_libs.push(resolved); - } - } - - for library in info.extra_libraries { - native_libs.insert(library); - } - } - - let rts_info = ghc_pkg_describe("rts"); - for dir in rts_info.dynamic_library_dirs { - search_dirs.insert(dir); - } - for library in rts_info.extra_libraries { - native_libs.insert(library); - } - - for dir in search_dirs { - println!("cargo:rustc-link-search=native={}", dir.display()); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", dir.display()); - } - - println!("cargo:rustc-link-arg=-Wl,--no-as-needed"); - for library in haskell_libs { - println!("cargo:rustc-link-lib=dylib={library}"); - } - println!("cargo:rustc-link-arg=-Wl,--as-needed"); - - for library in native_libs { - println!("cargo:rustc-link-lib=dylib={library}"); - } -} - -#[derive(Default)] -struct PackageInfo { - hs_libraries: Vec, - extra_libraries: Vec, - dynamic_library_dirs: Vec, -} - -fn ghc_print_libdir() -> String { - let output = Command::new("ghc") - .arg("--print-libdir") - .output() - .unwrap_or_else(|error| panic!("failed to run `ghc --print-libdir`: {error}")); - if !output.status.success() { - panic!("`ghc --print-libdir` did not exit successfully"); - } - - String::from_utf8_lossy(&output.stdout).trim().to_string() -} - -fn ghc_pkg_describe(package: &str) -> PackageInfo { - let output = Command::new("ghc-pkg") - .args(["describe", package]) - .output() - .unwrap_or_else(|error| panic!("failed to run `ghc-pkg describe {package}`: {error}")); - if !output.status.success() { - panic!("`ghc-pkg describe {package}` did not exit successfully"); - } - - let description = String::from_utf8_lossy(&output.stdout); - parse_package_description(&description) -} - -fn parse_package_description(description: &str) -> PackageInfo { - let mut info = PackageInfo::default(); - let mut current_field = String::new(); - let mut pkgroot = String::new(); - - for raw_line in description.lines() { - let line = raw_line.trim_end(); - if line.is_empty() { - continue; - } - - if raw_line.starts_with(' ') || raw_line.starts_with('\t') { - push_field_values(¤t_field, line.trim(), &mut pkgroot, &mut info); - continue; - } - - if let Some((field, rest)) = line.split_once(':') { - current_field = field.trim().to_string(); - push_field_values(¤t_field, rest.trim(), &mut pkgroot, &mut info); - } - } - - if !pkgroot.is_empty() { - for dir in &mut info.dynamic_library_dirs { - let resolved = dir - .display() - .to_string() - .replace("${pkgroot}", &pkgroot); - *dir = PathBuf::from(resolved); - } - } - - info -} - -fn push_field_values( - field: &str, - values: &str, - pkgroot: &mut String, - info: &mut PackageInfo, -) { - if values.is_empty() { - return; - } - - match field { - "pkgroot" => { - *pkgroot = values.trim_matches('"').to_string(); - } - "hs-libraries" => { - for value in values.split_whitespace() { - info.hs_libraries - .push(strip_library_prefix(value.trim_matches('"'))); - } - } - "extra-libraries" => { - for value in values.split_whitespace() { - info.extra_libraries.push(value.trim_matches('"').to_string()); - } - } - "dynamic-library-dirs" => { - for value in values.split_whitespace() { - info.dynamic_library_dirs - .push(PathBuf::from(value.trim_matches('"'))); - } - } - _ => {} - } -} - -fn resolve_dynamic_hs_library(library: &str, search_dirs: &[PathBuf]) -> String { - for dir in search_dirs { - let Ok(entries) = fs::read_dir(dir) else { - continue; - }; - - for entry in entries.flatten() { - let path = entry.path(); - let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { - continue; - }; - - let exact_name = format!("lib{library}.so"); - let versioned_prefix = format!("lib{library}-"); - if (file_name == exact_name - || (file_name.starts_with(&versioned_prefix) && file_name.ends_with(".so"))) - && path.is_file() - { - if let Some(stem) = path.file_stem().and_then(|value| value.to_str()) { - return strip_library_prefix(stem); - } - } - } - } - - library.to_string() -} - -fn find_rts_library(libdir: &Path) -> PathBuf { - let mut candidates = Vec::new(); - walk_for_rts(libdir, &mut candidates); - candidates.sort_by_key(|path| rts_priority(path)); - - candidates - .into_iter() - .next() - .unwrap_or_else(|| panic!("failed to locate a GHC RTS library under {}", libdir.display())) -} - -fn walk_for_rts(root: &Path, candidates: &mut Vec) { - let Ok(entries) = fs::read_dir(root) else { - return; - }; - - for entry in entries.flatten() { - let path = entry.path(); - if path.is_dir() { - walk_for_rts(&path, candidates); - continue; - } - - if is_threaded_rts_library(&path) { - candidates.push(path); - } - } -} - -fn is_threaded_rts_library(path: &Path) -> bool { - let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { - return false; - }; - - path.extension().and_then(|ext| ext.to_str()) == Some("so") && file_name.starts_with("libHSrts-") -} - -fn rts_priority(path: &Path) -> (u8, String) { - let file_name = path - .file_name() - .and_then(|value| value.to_str()) - .unwrap_or_default() - .to_string(); - - let rank = if file_name.contains("_debug") { - 3 - } else if file_name.contains("_thr") { - 1 - } else { - 0 - }; - - (rank, file_name) -} - -fn strip_library_prefix(stem: &str) -> String { - stem.strip_prefix("lib").unwrap_or(stem).to_string() -} diff --git a/rust/haskell.rs b/rust/haskell.rs index 6b5db01..ab3e7f4 100644 --- a/rust/haskell.rs +++ b/rust/haskell.rs @@ -1,19 +1,22 @@ -use crate::interop::SharedStats; +use crate::interop::{SharedI32Buffer, SharedStats, SharedU8Buffer}; use libloading::Library; use std::env; use std::ffi::{CStr, CString}; use std::fs; -use std::os::raw::{c_char, c_int}; +use std::os::raw::{c_char, c_int, c_uint}; use std::path::{Path, PathBuf}; -unsafe extern "C" { - fn hs_init(argc: *mut c_int, argv: *mut *mut *mut c_char); - fn hs_exit(); -} - type HsComputeStats = unsafe extern "C" fn(c_int, c_int, *mut SharedStats) -> c_int; +type HsSumSlice = unsafe extern "C" fn(*const c_int, usize) -> c_int; +type HsChecksumBytes = unsafe extern "C" fn(*const u8, usize) -> c_uint; type HsMakeMessage = unsafe extern "C" fn(*const c_char, c_int, c_int) -> *mut c_char; +type HsMakeSequence = unsafe extern "C" fn(c_int, usize, *mut SharedI32Buffer) -> c_int; +type HsMakeBytePattern = unsafe extern "C" fn(u8, usize, *mut SharedU8Buffer) -> c_int; type HsFreeString = unsafe extern "C" fn(*mut c_char); +type HsFreeI32Buffer = unsafe extern "C" fn(*mut c_int, usize, usize); +type HsFreeU8Buffer = unsafe extern "C" fn(*mut u8, usize, usize); +type HsInit = unsafe extern "C" fn(*mut c_int, *mut *mut *mut c_char); +type HsExit = unsafe extern "C" fn(); #[derive(Clone, Debug, PartialEq, Eq)] pub struct DemoArgs { @@ -25,18 +28,40 @@ pub struct DemoArgs { pub fn run_haskell_demo(args: &DemoArgs) -> Result { let library_path = resolve_library_path(args.library_path.as_deref())?; + let library = unsafe { Library::new(&library_path) } + .map_err(|error| format!("failed to load {}: {error}", library_path.display()))?; let runtime = HaskellRuntime::start()?; - let output = load_and_run(&library_path, args); + let output = load_and_run(&library_path, &library, args); drop(runtime); output } -struct HaskellRuntime; +struct HaskellRuntime { + _rts_library: Library, + hs_exit: HsExit, +} impl HaskellRuntime { fn start() -> Result { + let rts_library_path = resolve_rts_library_path()?; + let rts_library = unsafe { Library::new(&rts_library_path) }.map_err(|error| { + format!( + "failed to load GHC RTS library {}: {error}", + rts_library_path.display() + ) + })?; + let hs_init: HsInit = unsafe { + *rts_library + .get(b"hs_init\0") + .map_err(|error| format!("failed to load hs_init: {error}"))? + }; + let hs_exit: HsExit = unsafe { + *rts_library + .get(b"hs_exit\0") + .map_err(|error| format!("failed to load hs_exit: {error}"))? + }; let mut argc: c_int = 1; let program_name = CString::new("integrations-hs-runtime") .map_err(|_| "failed to create runtime program name".to_string())?; @@ -47,37 +72,67 @@ impl HaskellRuntime { hs_init(&mut argc, &mut argv_ptr); } - Ok(Self) + Ok(Self { + _rts_library: rts_library, + hs_exit, + }) } } impl Drop for HaskellRuntime { fn drop(&mut self) { unsafe { - hs_exit(); + (self.hs_exit)(); } } } -fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result { - let library = unsafe { Library::new(library_path) } - .map_err(|error| format!("failed to load {}: {error}", library_path.display()))?; - +fn load_and_run(library_path: &Path, library: &Library, args: &DemoArgs) -> Result { let compute_stats: HsComputeStats = unsafe { *library .get(b"hs_compute_stats\0") .map_err(|error| format!("failed to load hs_compute_stats: {error}"))? }; + let sum_slice: HsSumSlice = unsafe { + *library + .get(b"hs_sum_slice\0") + .map_err(|error| format!("failed to load hs_sum_slice: {error}"))? + }; + let checksum_bytes: HsChecksumBytes = unsafe { + *library + .get(b"hs_checksum_bytes\0") + .map_err(|error| format!("failed to load hs_checksum_bytes: {error}"))? + }; let make_message: HsMakeMessage = unsafe { *library .get(b"hs_make_message\0") .map_err(|error| format!("failed to load hs_make_message: {error}"))? }; + let make_sequence: HsMakeSequence = unsafe { + *library + .get(b"hs_make_sequence\0") + .map_err(|error| format!("failed to load hs_make_sequence: {error}"))? + }; + let make_byte_pattern: HsMakeBytePattern = unsafe { + *library + .get(b"hs_make_byte_pattern\0") + .map_err(|error| format!("failed to load hs_make_byte_pattern: {error}"))? + }; let free_string: HsFreeString = unsafe { *library .get(b"hs_free_string\0") .map_err(|error| format!("failed to load hs_free_string: {error}"))? }; + let free_i32_buffer: HsFreeI32Buffer = unsafe { + *library + .get(b"hs_free_i32_buffer\0") + .map_err(|error| format!("failed to load hs_free_i32_buffer: {error}"))? + }; + let free_u8_buffer: HsFreeU8Buffer = unsafe { + *library + .get(b"hs_free_u8_buffer\0") + .map_err(|error| format!("failed to load hs_free_u8_buffer: {error}"))? + }; let mut stats = SharedStats::default(); let status = unsafe { compute_stats(args.left, args.right, &mut stats) }; @@ -99,8 +154,40 @@ fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result free_string(message_ptr); } + let sample_values: [c_int; 4] = [2, 4, 6, 8]; + let slice_sum = unsafe { sum_slice(sample_values.as_ptr(), sample_values.len()) }; + let sample_bytes: [u8; 4] = [72, 0, 105, 255]; + let byte_checksum = unsafe { checksum_bytes(sample_bytes.as_ptr(), sample_bytes.len()) }; + + let mut sequence_buffer = SharedI32Buffer::default(); + let sequence_status = unsafe { make_sequence(args.left, 5, &mut sequence_buffer) }; + if sequence_status != 0 { + return Err(format!("hs_make_sequence returned status {sequence_status}")); + } + + let sequence_values = read_i32_buffer(&sequence_buffer)?; + unsafe { + free_i32_buffer( + sequence_buffer.ptr, + sequence_buffer.len, + sequence_buffer.cap, + ); + } + + let byte_seed = normalize_byte_seed(args.left); + let mut byte_buffer = SharedU8Buffer::default(); + let byte_status = unsafe { make_byte_pattern(byte_seed, 6, &mut byte_buffer) }; + if byte_status != 0 { + return Err(format!("hs_make_byte_pattern returned status {byte_status}")); + } + + let returned_bytes = read_u8_buffer(&byte_buffer)?; + unsafe { + free_u8_buffer(byte_buffer.ptr, byte_buffer.len, byte_buffer.cap); + } + Ok(format!( - "Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}", + "Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}\nSlice sent to Haskell: {:?}\nHaskell summed slice to: {}\nVector returned from Haskell: {:?}\nByte slice sent to Haskell: {:?}\nHaskell checksummed bytes to: {}\nByte buffer returned from Haskell: {:?}", library_path.display(), args.name, args.left, @@ -109,9 +196,43 @@ fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result stats.product, stats.gap, message, + sample_values, + slice_sum, + sequence_values, + sample_bytes, + byte_checksum, + returned_bytes, )) } +fn read_i32_buffer(buffer: &SharedI32Buffer) -> Result, String> { + if buffer.ptr.is_null() { + if buffer.len == 0 { + return Ok(Vec::new()); + } + return Err("received a null buffer pointer with a non-zero length".to_string()); + } + + let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) }; + Ok(values.to_vec()) +} + +fn read_u8_buffer(buffer: &SharedU8Buffer) -> Result, String> { + if buffer.ptr.is_null() { + if buffer.len == 0 { + return Ok(Vec::new()); + } + return Err("received a null byte buffer pointer with a non-zero length".to_string()); + } + + let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) }; + Ok(values.to_vec()) +} + +fn normalize_byte_seed(value: i32) -> u8 { + value.rem_euclid(256) as u8 +} + fn resolve_library_path(explicit_path: Option<&str>) -> Result { if let Some(path) = explicit_path { return Ok(PathBuf::from(path)); @@ -133,6 +254,33 @@ fn resolve_library_path(explicit_path: Option<&str>) -> Result }) } +fn resolve_rts_library_path() -> Result { + if let Ok(path) = env::var("GHC_RTS_LIB") { + return Ok(PathBuf::from(path)); + } + + let libdir = if let Ok(path) = env::var("GHC_LIBDIR") { + PathBuf::from(path) + } else { + let output = std::process::Command::new("ghc") + .arg("--print-libdir") + .output() + .map_err(|error| format!("failed to run `ghc --print-libdir`: {error}"))?; + if !output.status.success() { + return Err("`ghc --print-libdir` did not exit successfully".to_string()); + } + PathBuf::from(String::from_utf8_lossy(&output.stdout).trim()) + }; + + let mut matches = Vec::new(); + collect_matching_rts_libraries(&libdir, &mut matches)?; + matches.sort_by_key(|path| rts_priority(path)); + matches + .into_iter() + .next() + .ok_or_else(|| format!("could not find a GHC RTS library under {}", libdir.display())) +} + fn collect_matching_libraries(root: &Path, matches: &mut Vec) -> Result<(), String> { if !root.exists() { return Ok(()); @@ -156,6 +304,29 @@ fn collect_matching_libraries(root: &Path, matches: &mut Vec) -> Result Ok(()) } +fn collect_matching_rts_libraries(root: &Path, matches: &mut Vec) -> Result<(), String> { + if !root.exists() { + return Ok(()); + } + + let entries = fs::read_dir(root) + .map_err(|error| format!("failed to read {}: {error}", root.display()))?; + for entry in entries { + let entry = entry.map_err(|error| format!("failed to inspect {}: {error}", root.display()))?; + let path = entry.path(); + if path.is_dir() { + collect_matching_rts_libraries(&path, matches)?; + continue; + } + + if is_rts_library(&path) { + matches.push(path); + } + } + + Ok(()) +} + fn is_haskell_foreign_library(path: &Path) -> bool { let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { return false; @@ -167,3 +338,30 @@ fn is_haskell_foreign_library(path: &Path) -> bool { is_library && file_name.contains("interop_hs") } + +fn is_rts_library(path: &Path) -> bool { + let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { + return false; + }; + + path.extension().and_then(|ext| ext.to_str()) == Some("so") + && file_name.starts_with("libHSrts-") +} + +fn rts_priority(path: &Path) -> (u8, String) { + let file_name = path + .file_name() + .and_then(|value| value.to_str()) + .unwrap_or_default() + .to_string(); + + let rank = if file_name.contains("_debug") { + 3 + } else if file_name.contains("_thr") { + 1 + } else { + 0 + }; + + (rank, file_name) +} diff --git a/rust/interop.rs b/rust/interop.rs index 5afc230..3f013c4 100644 --- a/rust/interop.rs +++ b/rust/interop.rs @@ -1,5 +1,5 @@ use std::ffi::{CStr, CString}; -use std::os::raw::{c_char, c_int}; +use std::os::raw::{c_char, c_int, c_uint}; #[repr(C)] #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] @@ -9,6 +9,22 @@ pub struct SharedStats { pub gap: c_int, } +#[repr(C)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct SharedI32Buffer { + pub ptr: *mut c_int, + pub len: usize, + pub cap: usize, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct SharedU8Buffer { + pub ptr: *mut u8, + pub len: usize, + pub cap: usize, +} + pub fn compute_stats(left: c_int, right: c_int) -> SharedStats { SharedStats { total: left.saturating_add(right), @@ -17,6 +33,19 @@ pub fn compute_stats(left: c_int, right: c_int) -> SharedStats { } } +pub fn sum_slice(values: &[c_int]) -> c_int { + values + .iter() + .copied() + .fold(0, |accumulator, value| accumulator.saturating_add(value)) +} + +pub fn checksum_bytes(values: &[u8]) -> c_uint { + values.iter().fold(0_u32, |accumulator, value| { + accumulator.saturating_add(u32::from(*value)) + }) +} + pub fn make_rust_message(name: &str, left: c_int, right: c_int) -> String { let stats = compute_stats(left, right); format!( @@ -25,7 +54,31 @@ pub fn make_rust_message(name: &str, left: c_int, right: c_int) -> String { ) } -#[no_mangle] +pub fn make_sequence_values(start: c_int, count: usize) -> Vec { + (0..count) + .map(|index| { + let offset = match c_int::try_from(index) { + Ok(value) => value, + Err(_) => c_int::MAX, + }; + start.saturating_add(offset) + }) + .collect() +} + +pub fn make_byte_pattern(seed: u8, count: usize) -> Vec { + (0..count) + .map(|index| { + let offset = match u8::try_from(index) { + Ok(value) => value, + Err(_) => u8::MAX, + }; + seed.wrapping_add(offset) + }) + .collect() +} + +#[unsafe(no_mangle)] pub unsafe extern "C" fn rust_compute_stats( left: c_int, right: c_int, @@ -35,11 +88,23 @@ pub unsafe extern "C" fn rust_compute_stats( return 1; } - out_stats.write(compute_stats(left, right)); + unsafe { + out_stats.write(compute_stats(left, right)); + } 0 } -#[no_mangle] +#[unsafe(no_mangle)] +pub unsafe extern "C" fn rust_sum_slice(ptr: *const c_int, len: usize) -> c_int { + if ptr.is_null() { + return 0; + } + + let values = unsafe { std::slice::from_raw_parts(ptr, len) }; + sum_slice(values) +} + +#[unsafe(no_mangle)] pub unsafe extern "C" fn rust_make_message( name: *const c_char, left: c_int, @@ -49,17 +114,99 @@ pub unsafe extern "C" fn rust_make_message( return string_into_raw("Rust received a null name pointer".to_string()); } - let name = CStr::from_ptr(name).to_string_lossy(); + let name = unsafe { CStr::from_ptr(name) }.to_string_lossy(); string_into_raw(make_rust_message(name.as_ref(), left, right)) } -#[no_mangle] +#[unsafe(no_mangle)] +pub unsafe extern "C" fn rust_checksum_bytes(ptr: *const u8, len: usize) -> c_uint { + if ptr.is_null() { + return 0; + } + + let values = unsafe { std::slice::from_raw_parts(ptr, len) }; + checksum_bytes(values) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn rust_make_sequence( + start: c_int, + count: usize, + out_buffer: *mut SharedI32Buffer, +) -> c_int { + if out_buffer.is_null() { + return 1; + } + + let mut values = make_sequence_values(start, count); + let buffer = SharedI32Buffer { + ptr: values.as_mut_ptr(), + len: values.len(), + cap: values.capacity(), + }; + + std::mem::forget(values); + unsafe { + out_buffer.write(buffer); + } + 0 +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn rust_make_byte_pattern( + seed: u8, + count: usize, + out_buffer: *mut SharedU8Buffer, +) -> c_int { + if out_buffer.is_null() { + return 1; + } + + let mut values = make_byte_pattern(seed, count); + let buffer = SharedU8Buffer { + ptr: values.as_mut_ptr(), + len: values.len(), + cap: values.capacity(), + }; + + std::mem::forget(values); + unsafe { + out_buffer.write(buffer); + } + 0 +} + +#[unsafe(no_mangle)] pub unsafe extern "C" fn rust_free_string(ptr: *mut c_char) { if ptr.is_null() { return; } - drop(CString::from_raw(ptr)); + unsafe { + drop(CString::from_raw(ptr)); + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn rust_free_i32_buffer(ptr: *mut c_int, len: usize, cap: usize) { + if ptr.is_null() { + return; + } + + unsafe { + drop(Vec::from_raw_parts(ptr, len, cap)); + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn rust_free_u8_buffer(ptr: *mut u8, len: usize, cap: usize) { + if ptr.is_null() { + return; + } + + unsafe { + drop(Vec::from_raw_parts(ptr, len, cap)); + } } fn string_into_raw(message: String) -> *mut c_char { @@ -95,4 +242,24 @@ mod tests { assert!(message.contains("product=35")); assert!(message.contains("gap=2")); } + + #[test] + fn sum_slice_handles_multiple_values() { + assert_eq!(sum_slice(&[2, 4, 6, 8]), 20); + } + + #[test] + fn make_sequence_values_builds_contiguous_output() { + assert_eq!(make_sequence_values(7, 5), vec![7, 8, 9, 10, 11]); + } + + #[test] + fn checksum_bytes_handles_embedded_zeroes() { + assert_eq!(checksum_bytes(&[72, 0, 105, 255]), 432); + } + + #[test] + fn make_byte_pattern_wraps_like_bytes() { + assert_eq!(make_byte_pattern(254, 4), vec![254, 255, 0, 1]); + } }