Add Rust/Haskell slice and buffer interop demos

This commit is contained in:
Hassan Abedi 2026-03-27 09:20:31 +01:00
parent 317f1d0a0c
commit 728322ef0e
11 changed files with 659 additions and 293 deletions

View File

@ -4,7 +4,6 @@ version = "0.1.0"
edition = "2024"
license = "MIT OR Apache-2.0"
publish = false
build = "rust/build.rs"
[lib]
path = "rust/lib.rs"

View File

@ -23,8 +23,20 @@ The boundary is deliberately C-shaped:
- integers
- a shared struct layout
- borrowed slices as `ptr + len`
- owned buffers as `ptr + len + cap`
- owned C strings with explicit free functions on each side
The current examples cover:
- scalar values and shared structs
- borrowed `i32` slices
- owned `i32` buffers returned across the boundary
- borrowed raw byte buffers
- owned raw byte buffers returned across the boundary
The byte examples are separate from C strings on purpose. They demonstrate handling embedded zero bytes safely.
## Why The Build Uses Both Static And Shared Libraries
This demo uses different library styles for the two directions because that keeps each path simpler:
@ -70,6 +82,8 @@ CABAL_DIR=$PWD/../.cabal XDG_STATE_HOME=$PWD/../.cabal/state XDG_CACHE_HOME=$PWD
- the ABI boundary must stay simple and explicit
- Rust and Haskell do not share ownership rules automatically
- struct layout must match on both sides
- borrowed data and owned returned buffers need different FFI patterns
- raw bytes are not the same thing as C strings
- Rust calling Haskell is the harder direction because it must initialize the GHC runtime correctly
- build tooling is part of the integration problem, not just an implementation detail

View File

@ -19,10 +19,20 @@ The code keeps the FFI surface small on purpose. The boundary uses:
- integers
- a fixed C-shaped struct
- borrowed slices as `ptr + len`
- owned returned buffers as `ptr + len + cap`
- owned C strings with an explicit free function on each side
That is enough to demonstrate the main challenges without pulling in code generation or large bindings.
The current examples include:
- scalar stats and message passing
- borrowed `i32` slices sent across the boundary
- owned `i32` buffers returned across the boundary
- borrowed raw byte buffers with embedded zero bytes
- owned raw byte buffers returned across the boundary
## Build And Run
From the repository root:
@ -56,6 +66,8 @@ cargo run -- rust-calls-haskell Ada 7 5
- The boundary must stay C-shaped. Rich Rust and Haskell types do not cross directly.
- Strings need explicit ownership rules. Each side exports its own free function.
- Borrowed slices and owned returned buffers are different patterns and must be modeled differently.
- Raw byte buffers should be treated separately from C strings.
- Struct layout must be mirrored carefully on both sides.
- Rust calling Haskell is the harder direction because it must initialize and shut down the GHC runtime correctly.
- Build order is part of the design. Haskell links against the Rust static library, and Rust loads the Haskell foreign library after Cabal builds it.

View File

@ -1,7 +1,14 @@
module Main (main) where
import Interop.Shared (Summary (..))
import RustClient (callRustMessage, callRustSummary)
import RustClient (
callRustByteChecksum,
callRustBytePattern,
callRustMessage,
callRustSequence,
callRustSliceSum,
callRustSummary,
)
import System.Environment (getArgs)
import Text.Read (readMaybe)
@ -11,11 +18,23 @@ main = do
let (name, left, right) = parseArgs args
summary <- callRustSummary left right
message <- callRustMessage name left right
let sliceValues = [2, 4, 6, 8]
sliceSum <- callRustSliceSum sliceValues
sequenceValues <- callRustSequence left 5
let byteValues = [72, 0, 105, 255]
byteChecksum <- callRustByteChecksum byteValues
bytePattern <- callRustBytePattern left 6
putStrLn "Haskell -> Rust demo"
putStrLn $ "Inputs: name=" ++ name ++ ", left=" ++ show left ++ ", right=" ++ show right
putStrLn $ "Stats from Rust: " ++ renderSummary summary
putStrLn $ "Message from Rust: " ++ message
putStrLn $ "Slice sent to Rust: " ++ show sliceValues
putStrLn $ "Rust summed slice to: " ++ show sliceSum
putStrLn $ "Vector returned from Rust: " ++ show sequenceValues
putStrLn $ "Byte slice sent to Rust: " ++ show byteValues
putStrLn $ "Rust checksummed bytes to: " ++ show byteChecksum
putStrLn $ "Byte buffer returned from Rust: " ++ show bytePattern
parseArgs :: [String] -> (String, Int, Int)
parseArgs args =

View File

@ -1,25 +1,48 @@
module RustClient (
callRustByteChecksum,
callRustBytePattern,
callRustMessage,
callRustSequence,
callRustSliceSum,
callRustSummary,
) where
import Control.Exception (bracket)
import Foreign.C.String (CString, peekCString, withCString)
import Foreign.C.Types (CInt (..))
import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..))
import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (peekArray, withArrayLen)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable (peek)
import Interop.Shared (Summary, SharedStats, summaryFromSharedStats)
import Interop.Shared (SharedI32Buffer (..), SharedU8Buffer (..), Summary, SharedStats, summaryFromSharedStats)
foreign import ccall unsafe "rust_compute_stats"
rustComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt
foreign import ccall unsafe "rust_sum_slice"
rustSumSlice :: Ptr CInt -> CSize -> IO CInt
foreign import ccall unsafe "rust_checksum_bytes"
rustChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt
foreign import ccall unsafe "rust_make_message"
rustMakeMessage :: CString -> CInt -> CInt -> IO CString
foreign import ccall unsafe "rust_make_sequence"
rustMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt
foreign import ccall unsafe "rust_make_byte_pattern"
rustMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt
foreign import ccall unsafe "rust_free_string"
rustFreeString :: CString -> IO ()
foreign import ccall unsafe "rust_free_i32_buffer"
rustFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO ()
foreign import ccall unsafe "rust_free_u8_buffer"
rustFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO ()
callRustSummary :: Int -> Int -> IO Summary
callRustSummary left right =
alloca $ \outStats -> do
@ -28,6 +51,16 @@ callRustSummary left right =
then fail ("rustComputeStats returned status " ++ show status)
else summaryFromSharedStats <$> peek outStats
callRustSliceSum :: [Int] -> IO Int
callRustSliceSum values =
withArrayLen (map fromIntegral values) $ \valueCount valuesPtr ->
fmap fromIntegral (rustSumSlice valuesPtr (fromIntegral valueCount))
callRustByteChecksum :: [Int] -> IO Int
callRustByteChecksum values =
withArrayLen (map (fromIntegral . (`mod` 256)) values) $ \valueCount valuesPtr ->
fmap fromIntegral (rustChecksumBytes valuesPtr (fromIntegral valueCount))
callRustMessage :: String -> Int -> Int -> IO String
callRustMessage name left right =
withCString name $ \namePtr -> do
@ -39,3 +72,33 @@ callRustMessage name left right =
(pure messagePtr)
rustFreeString
peekCString
callRustSequence :: Int -> Int -> IO [Int]
callRustSequence start count =
alloca $ \outBuffer -> do
status <- rustMakeSequence (fromIntegral start) (fromIntegral count) outBuffer
if status /= 0
then fail ("rustMakeSequence returned status " ++ show status)
else do
buffer <- peek outBuffer
values <-
if sharedBufferPtr buffer == nullPtr
then pure []
else map fromIntegral <$> peekArray (fromIntegral (sharedBufferLen buffer)) (sharedBufferPtr buffer)
rustFreeI32Buffer (sharedBufferPtr buffer) (sharedBufferLen buffer) (sharedBufferCap buffer)
pure values
callRustBytePattern :: Int -> Int -> IO [Int]
callRustBytePattern seed count =
alloca $ \outBuffer -> do
status <- rustMakeBytePattern (fromIntegral (seed `mod` 256)) (fromIntegral count) outBuffer
if status /= 0
then fail ("rustMakeBytePattern returned status " ++ show status)
else do
buffer <- peek outBuffer
values <-
if sharedByteBufferPtr buffer == nullPtr
then pure []
else map fromIntegral <$> peekArray (fromIntegral (sharedByteBufferLen buffer)) (sharedByteBufferPtr buffer)
rustFreeU8Buffer (sharedByteBufferPtr buffer) (sharedByteBufferLen buffer) (sharedByteBufferCap buffer)
pure values

View File

@ -1,15 +1,29 @@
module Interop.Exports (
hsChecksumBytes,
hsComputeStats,
hsFreeI32Buffer,
hsFreeU8Buffer,
hsFreeString,
hsMakeBytePattern,
hsMakeMessage,
hsMakeSequence,
hsSumSlice,
) where
import Foreign.C.String (CString, newCString, peekCString)
import Foreign.C.Types (CInt (..))
import Foreign.C.Types (CInt (..), CSize (..), CUChar (..), CUInt (..))
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (newArray, peekArray)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable (poke)
import Interop.Shared (SharedStats, calculateSummary, formatHaskellMessage, summaryToSharedStats)
import Interop.Shared (
SharedI32Buffer (..),
SharedStats,
SharedU8Buffer (..),
calculateSummary,
formatHaskellMessage,
summaryToSharedStats,
)
hsComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt
hsComputeStats left right outStats
@ -20,6 +34,22 @@ hsComputeStats left right outStats
calculateSummary (fromIntegral left) (fromIntegral right)
pure 0
hsSumSlice :: Ptr CInt -> CSize -> IO CInt
hsSumSlice valuesPtr valueCount
| valuesPtr == nullPtr = pure 0
| otherwise = do
values <- peekArray (fromIntegral valueCount) valuesPtr
pure (sum values)
hsChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt
hsChecksumBytes bytesPtr byteCount
| bytesPtr == nullPtr = pure 0
| otherwise = do
values <- peekArray (fromIntegral byteCount) bytesPtr
pure $
fromIntegral $
sum (map (fromIntegral :: CUChar -> Int) values)
hsMakeMessage :: CString -> CInt -> CInt -> IO CString
hsMakeMessage namePtr left right
| namePtr == nullPtr = newCString "Haskell received a null name pointer"
@ -28,11 +58,70 @@ hsMakeMessage namePtr left right
let summary = calculateSummary (fromIntegral left) (fromIntegral right)
newCString (formatHaskellMessage name summary)
hsMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt
hsMakeSequence start count outBuffer
| outBuffer == nullPtr = pure 1
| count == 0 = do
poke outBuffer SharedI32Buffer{sharedBufferPtr = nullPtr, sharedBufferLen = 0, sharedBufferCap = 0}
pure 0
| otherwise = do
let values =
[ start + fromIntegral offset
| offset <- [0 .. (fromIntegral count :: Int) - 1]
]
valuesPtr <- newArray values
poke
outBuffer
SharedI32Buffer
{ sharedBufferPtr = valuesPtr
, sharedBufferLen = count
, sharedBufferCap = count
}
pure 0
hsMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt
hsMakeBytePattern seed count outBuffer
| outBuffer == nullPtr = pure 1
| count == 0 = do
poke outBuffer SharedU8Buffer{sharedByteBufferPtr = nullPtr, sharedByteBufferLen = 0, sharedByteBufferCap = 0}
pure 0
| otherwise = do
let seedValue = fromIntegral seed :: Int
let values =
[ fromIntegral ((seedValue + offset) `mod` 256) :: CUChar
| offset <- [0 .. (fromIntegral count :: Int) - 1]
]
valuesPtr <- newArray values
poke
outBuffer
SharedU8Buffer
{ sharedByteBufferPtr = valuesPtr
, sharedByteBufferLen = count
, sharedByteBufferCap = count
}
pure 0
hsFreeString :: CString -> IO ()
hsFreeString ptr
| ptr == nullPtr = pure ()
| otherwise = free ptr
hsFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO ()
hsFreeI32Buffer ptr _ _
| ptr == nullPtr = pure ()
| otherwise = free ptr
hsFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO ()
hsFreeU8Buffer ptr _ _
| ptr == nullPtr = pure ()
| otherwise = free ptr
foreign export ccall "hs_compute_stats" hsComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt
foreign export ccall "hs_sum_slice" hsSumSlice :: Ptr CInt -> CSize -> IO CInt
foreign export ccall "hs_checksum_bytes" hsChecksumBytes :: Ptr CUChar -> CSize -> IO CUInt
foreign export ccall "hs_make_message" hsMakeMessage :: CString -> CInt -> CInt -> IO CString
foreign export ccall "hs_make_sequence" hsMakeSequence :: CInt -> CSize -> Ptr SharedI32Buffer -> IO CInt
foreign export ccall "hs_make_byte_pattern" hsMakeBytePattern :: CUChar -> CSize -> Ptr SharedU8Buffer -> IO CInt
foreign export ccall "hs_free_string" hsFreeString :: CString -> IO ()
foreign export ccall "hs_free_i32_buffer" hsFreeI32Buffer :: Ptr CInt -> CSize -> CSize -> IO ()
foreign export ccall "hs_free_u8_buffer" hsFreeU8Buffer :: Ptr CUChar -> CSize -> CSize -> IO ()

View File

@ -1,5 +1,7 @@
module Interop.Shared (
SharedI32Buffer (..),
SharedStats (..),
SharedU8Buffer (..),
Summary (..),
calculateSummary,
formatHaskellMessage,
@ -7,7 +9,8 @@ module Interop.Shared (
summaryToSharedStats,
) where
import Foreign.C.Types (CInt)
import Foreign.C.Types (CInt, CSize, CUChar)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable (..), peekByteOff, pokeByteOff)
data Summary = Summary
@ -43,6 +46,20 @@ data SharedStats = SharedStats
}
deriving (Eq, Show)
data SharedI32Buffer = SharedI32Buffer
{ sharedBufferPtr :: Ptr CInt
, sharedBufferLen :: CSize
, sharedBufferCap :: CSize
}
deriving (Eq, Show)
data SharedU8Buffer = SharedU8Buffer
{ sharedByteBufferPtr :: Ptr CUChar
, sharedByteBufferLen :: CSize
, sharedByteBufferCap :: CSize
}
deriving (Eq, Show)
instance Storable SharedStats where
sizeOf _ = fieldSize * 3
where
@ -65,6 +82,56 @@ instance Storable SharedStats where
where
fieldSize = sizeOf (undefined :: CInt)
instance Storable SharedI32Buffer where
sizeOf _ = ptrSize + (fieldSize * 2)
where
ptrSize = sizeOf (undefined :: Ptr CInt)
fieldSize = sizeOf (undefined :: CSize)
alignment _ = alignment (undefined :: Ptr CInt)
peek ptr =
SharedI32Buffer
<$> peekByteOff ptr 0
<*> peekByteOff ptr ptrSize
<*> peekByteOff ptr (ptrSize + fieldSize)
where
ptrSize = sizeOf (undefined :: Ptr CInt)
fieldSize = sizeOf (undefined :: CSize)
poke ptr value = do
pokeByteOff ptr 0 (sharedBufferPtr value)
pokeByteOff ptr ptrSize (sharedBufferLen value)
pokeByteOff ptr (ptrSize + fieldSize) (sharedBufferCap value)
where
ptrSize = sizeOf (undefined :: Ptr CInt)
fieldSize = sizeOf (undefined :: CSize)
instance Storable SharedU8Buffer where
sizeOf _ = ptrSize + (fieldSize * 2)
where
ptrSize = sizeOf (undefined :: Ptr CUChar)
fieldSize = sizeOf (undefined :: CSize)
alignment _ = alignment (undefined :: Ptr CUChar)
peek ptr =
SharedU8Buffer
<$> peekByteOff ptr 0
<*> peekByteOff ptr ptrSize
<*> peekByteOff ptr (ptrSize + fieldSize)
where
ptrSize = sizeOf (undefined :: Ptr CUChar)
fieldSize = sizeOf (undefined :: CSize)
poke ptr value = do
pokeByteOff ptr 0 (sharedByteBufferPtr value)
pokeByteOff ptr ptrSize (sharedByteBufferLen value)
pokeByteOff ptr (ptrSize + fieldSize) (sharedByteBufferCap value)
where
ptrSize = sizeOf (undefined :: Ptr CUChar)
fieldSize = sizeOf (undefined :: CSize)
summaryToSharedStats :: Summary -> SharedStats
summaryToSharedStats summary =
SharedStats

View File

@ -1,3 +1,3 @@
[toolchain]
channel = "1.83.0"
channel = "1.92.0"
components = ["rustfmt", "clippy", "rust-analyzer"]

View File

@ -1,262 +0,0 @@
use std::collections::{BTreeSet, HashSet};
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
fn main() {
println!("cargo:rerun-if-env-changed=GHC_LIBDIR");
println!("cargo:rerun-if-env-changed=GHC_RTS_LIB");
println!("cargo:rerun-if-changed=rust/build.rs");
let libdir = env::var("GHC_LIBDIR").unwrap_or_else(|_| ghc_print_libdir());
let explicit_rts = env::var("GHC_RTS_LIB").ok().map(PathBuf::from);
let rts_path = explicit_rts.unwrap_or_else(|| find_rts_library(Path::new(&libdir)));
let rts_dir = rts_path
.parent()
.unwrap_or_else(|| Path::new(&libdir))
.to_path_buf();
let rts_name = rts_path
.file_stem()
.and_then(|stem| stem.to_str())
.map(strip_library_prefix)
.unwrap_or_else(|| panic!("failed to resolve GHC RTS library name from {}", rts_path.display()));
let mut search_dirs = BTreeSet::new();
let mut haskell_libs = Vec::new();
let mut seen_haskell_libs = HashSet::new();
let mut native_libs = BTreeSet::new();
search_dirs.insert(rts_dir);
seen_haskell_libs.insert(rts_name.clone());
haskell_libs.push(rts_name);
for package in ["base", "ghc-prim", "ghc-bignum"] {
let info = ghc_pkg_describe(package);
for dir in &info.dynamic_library_dirs {
search_dirs.insert(dir.clone());
}
for library in info.hs_libraries {
let resolved = resolve_dynamic_hs_library(&library, &info.dynamic_library_dirs);
if seen_haskell_libs.insert(resolved.clone()) {
haskell_libs.push(resolved);
}
}
for library in info.extra_libraries {
native_libs.insert(library);
}
}
let rts_info = ghc_pkg_describe("rts");
for dir in rts_info.dynamic_library_dirs {
search_dirs.insert(dir);
}
for library in rts_info.extra_libraries {
native_libs.insert(library);
}
for dir in search_dirs {
println!("cargo:rustc-link-search=native={}", dir.display());
println!("cargo:rustc-link-arg=-Wl,-rpath,{}", dir.display());
}
println!("cargo:rustc-link-arg=-Wl,--no-as-needed");
for library in haskell_libs {
println!("cargo:rustc-link-lib=dylib={library}");
}
println!("cargo:rustc-link-arg=-Wl,--as-needed");
for library in native_libs {
println!("cargo:rustc-link-lib=dylib={library}");
}
}
#[derive(Default)]
struct PackageInfo {
hs_libraries: Vec<String>,
extra_libraries: Vec<String>,
dynamic_library_dirs: Vec<PathBuf>,
}
fn ghc_print_libdir() -> String {
let output = Command::new("ghc")
.arg("--print-libdir")
.output()
.unwrap_or_else(|error| panic!("failed to run `ghc --print-libdir`: {error}"));
if !output.status.success() {
panic!("`ghc --print-libdir` did not exit successfully");
}
String::from_utf8_lossy(&output.stdout).trim().to_string()
}
fn ghc_pkg_describe(package: &str) -> PackageInfo {
let output = Command::new("ghc-pkg")
.args(["describe", package])
.output()
.unwrap_or_else(|error| panic!("failed to run `ghc-pkg describe {package}`: {error}"));
if !output.status.success() {
panic!("`ghc-pkg describe {package}` did not exit successfully");
}
let description = String::from_utf8_lossy(&output.stdout);
parse_package_description(&description)
}
fn parse_package_description(description: &str) -> PackageInfo {
let mut info = PackageInfo::default();
let mut current_field = String::new();
let mut pkgroot = String::new();
for raw_line in description.lines() {
let line = raw_line.trim_end();
if line.is_empty() {
continue;
}
if raw_line.starts_with(' ') || raw_line.starts_with('\t') {
push_field_values(&current_field, line.trim(), &mut pkgroot, &mut info);
continue;
}
if let Some((field, rest)) = line.split_once(':') {
current_field = field.trim().to_string();
push_field_values(&current_field, rest.trim(), &mut pkgroot, &mut info);
}
}
if !pkgroot.is_empty() {
for dir in &mut info.dynamic_library_dirs {
let resolved = dir
.display()
.to_string()
.replace("${pkgroot}", &pkgroot);
*dir = PathBuf::from(resolved);
}
}
info
}
fn push_field_values(
field: &str,
values: &str,
pkgroot: &mut String,
info: &mut PackageInfo,
) {
if values.is_empty() {
return;
}
match field {
"pkgroot" => {
*pkgroot = values.trim_matches('"').to_string();
}
"hs-libraries" => {
for value in values.split_whitespace() {
info.hs_libraries
.push(strip_library_prefix(value.trim_matches('"')));
}
}
"extra-libraries" => {
for value in values.split_whitespace() {
info.extra_libraries.push(value.trim_matches('"').to_string());
}
}
"dynamic-library-dirs" => {
for value in values.split_whitespace() {
info.dynamic_library_dirs
.push(PathBuf::from(value.trim_matches('"')));
}
}
_ => {}
}
}
fn resolve_dynamic_hs_library(library: &str, search_dirs: &[PathBuf]) -> String {
for dir in search_dirs {
let Ok(entries) = fs::read_dir(dir) else {
continue;
};
for entry in entries.flatten() {
let path = entry.path();
let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
continue;
};
let exact_name = format!("lib{library}.so");
let versioned_prefix = format!("lib{library}-");
if (file_name == exact_name
|| (file_name.starts_with(&versioned_prefix) && file_name.ends_with(".so")))
&& path.is_file()
{
if let Some(stem) = path.file_stem().and_then(|value| value.to_str()) {
return strip_library_prefix(stem);
}
}
}
}
library.to_string()
}
fn find_rts_library(libdir: &Path) -> PathBuf {
let mut candidates = Vec::new();
walk_for_rts(libdir, &mut candidates);
candidates.sort_by_key(|path| rts_priority(path));
candidates
.into_iter()
.next()
.unwrap_or_else(|| panic!("failed to locate a GHC RTS library under {}", libdir.display()))
}
fn walk_for_rts(root: &Path, candidates: &mut Vec<PathBuf>) {
let Ok(entries) = fs::read_dir(root) else {
return;
};
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
walk_for_rts(&path, candidates);
continue;
}
if is_threaded_rts_library(&path) {
candidates.push(path);
}
}
}
fn is_threaded_rts_library(path: &Path) -> bool {
let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
return false;
};
path.extension().and_then(|ext| ext.to_str()) == Some("so") && file_name.starts_with("libHSrts-")
}
fn rts_priority(path: &Path) -> (u8, String) {
let file_name = path
.file_name()
.and_then(|value| value.to_str())
.unwrap_or_default()
.to_string();
let rank = if file_name.contains("_debug") {
3
} else if file_name.contains("_thr") {
1
} else {
0
};
(rank, file_name)
}
fn strip_library_prefix(stem: &str) -> String {
stem.strip_prefix("lib").unwrap_or(stem).to_string()
}

View File

@ -1,19 +1,22 @@
use crate::interop::SharedStats;
use crate::interop::{SharedI32Buffer, SharedStats, SharedU8Buffer};
use libloading::Library;
use std::env;
use std::ffi::{CStr, CString};
use std::fs;
use std::os::raw::{c_char, c_int};
use std::os::raw::{c_char, c_int, c_uint};
use std::path::{Path, PathBuf};
unsafe extern "C" {
fn hs_init(argc: *mut c_int, argv: *mut *mut *mut c_char);
fn hs_exit();
}
type HsComputeStats = unsafe extern "C" fn(c_int, c_int, *mut SharedStats) -> c_int;
type HsSumSlice = unsafe extern "C" fn(*const c_int, usize) -> c_int;
type HsChecksumBytes = unsafe extern "C" fn(*const u8, usize) -> c_uint;
type HsMakeMessage = unsafe extern "C" fn(*const c_char, c_int, c_int) -> *mut c_char;
type HsMakeSequence = unsafe extern "C" fn(c_int, usize, *mut SharedI32Buffer) -> c_int;
type HsMakeBytePattern = unsafe extern "C" fn(u8, usize, *mut SharedU8Buffer) -> c_int;
type HsFreeString = unsafe extern "C" fn(*mut c_char);
type HsFreeI32Buffer = unsafe extern "C" fn(*mut c_int, usize, usize);
type HsFreeU8Buffer = unsafe extern "C" fn(*mut u8, usize, usize);
type HsInit = unsafe extern "C" fn(*mut c_int, *mut *mut *mut c_char);
type HsExit = unsafe extern "C" fn();
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DemoArgs {
@ -25,18 +28,40 @@ pub struct DemoArgs {
pub fn run_haskell_demo(args: &DemoArgs) -> Result<String, String> {
let library_path = resolve_library_path(args.library_path.as_deref())?;
let library = unsafe { Library::new(&library_path) }
.map_err(|error| format!("failed to load {}: {error}", library_path.display()))?;
let runtime = HaskellRuntime::start()?;
let output = load_and_run(&library_path, args);
let output = load_and_run(&library_path, &library, args);
drop(runtime);
output
}
struct HaskellRuntime;
struct HaskellRuntime {
_rts_library: Library,
hs_exit: HsExit,
}
impl HaskellRuntime {
fn start() -> Result<Self, String> {
let rts_library_path = resolve_rts_library_path()?;
let rts_library = unsafe { Library::new(&rts_library_path) }.map_err(|error| {
format!(
"failed to load GHC RTS library {}: {error}",
rts_library_path.display()
)
})?;
let hs_init: HsInit = unsafe {
*rts_library
.get(b"hs_init\0")
.map_err(|error| format!("failed to load hs_init: {error}"))?
};
let hs_exit: HsExit = unsafe {
*rts_library
.get(b"hs_exit\0")
.map_err(|error| format!("failed to load hs_exit: {error}"))?
};
let mut argc: c_int = 1;
let program_name = CString::new("integrations-hs-runtime")
.map_err(|_| "failed to create runtime program name".to_string())?;
@ -47,37 +72,67 @@ impl HaskellRuntime {
hs_init(&mut argc, &mut argv_ptr);
}
Ok(Self)
Ok(Self {
_rts_library: rts_library,
hs_exit,
})
}
}
impl Drop for HaskellRuntime {
fn drop(&mut self) {
unsafe {
hs_exit();
(self.hs_exit)();
}
}
}
fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result<String, String> {
let library = unsafe { Library::new(library_path) }
.map_err(|error| format!("failed to load {}: {error}", library_path.display()))?;
fn load_and_run(library_path: &Path, library: &Library, args: &DemoArgs) -> Result<String, String> {
let compute_stats: HsComputeStats = unsafe {
*library
.get(b"hs_compute_stats\0")
.map_err(|error| format!("failed to load hs_compute_stats: {error}"))?
};
let sum_slice: HsSumSlice = unsafe {
*library
.get(b"hs_sum_slice\0")
.map_err(|error| format!("failed to load hs_sum_slice: {error}"))?
};
let checksum_bytes: HsChecksumBytes = unsafe {
*library
.get(b"hs_checksum_bytes\0")
.map_err(|error| format!("failed to load hs_checksum_bytes: {error}"))?
};
let make_message: HsMakeMessage = unsafe {
*library
.get(b"hs_make_message\0")
.map_err(|error| format!("failed to load hs_make_message: {error}"))?
};
let make_sequence: HsMakeSequence = unsafe {
*library
.get(b"hs_make_sequence\0")
.map_err(|error| format!("failed to load hs_make_sequence: {error}"))?
};
let make_byte_pattern: HsMakeBytePattern = unsafe {
*library
.get(b"hs_make_byte_pattern\0")
.map_err(|error| format!("failed to load hs_make_byte_pattern: {error}"))?
};
let free_string: HsFreeString = unsafe {
*library
.get(b"hs_free_string\0")
.map_err(|error| format!("failed to load hs_free_string: {error}"))?
};
let free_i32_buffer: HsFreeI32Buffer = unsafe {
*library
.get(b"hs_free_i32_buffer\0")
.map_err(|error| format!("failed to load hs_free_i32_buffer: {error}"))?
};
let free_u8_buffer: HsFreeU8Buffer = unsafe {
*library
.get(b"hs_free_u8_buffer\0")
.map_err(|error| format!("failed to load hs_free_u8_buffer: {error}"))?
};
let mut stats = SharedStats::default();
let status = unsafe { compute_stats(args.left, args.right, &mut stats) };
@ -99,8 +154,40 @@ fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result<String, String>
free_string(message_ptr);
}
let sample_values: [c_int; 4] = [2, 4, 6, 8];
let slice_sum = unsafe { sum_slice(sample_values.as_ptr(), sample_values.len()) };
let sample_bytes: [u8; 4] = [72, 0, 105, 255];
let byte_checksum = unsafe { checksum_bytes(sample_bytes.as_ptr(), sample_bytes.len()) };
let mut sequence_buffer = SharedI32Buffer::default();
let sequence_status = unsafe { make_sequence(args.left, 5, &mut sequence_buffer) };
if sequence_status != 0 {
return Err(format!("hs_make_sequence returned status {sequence_status}"));
}
let sequence_values = read_i32_buffer(&sequence_buffer)?;
unsafe {
free_i32_buffer(
sequence_buffer.ptr,
sequence_buffer.len,
sequence_buffer.cap,
);
}
let byte_seed = normalize_byte_seed(args.left);
let mut byte_buffer = SharedU8Buffer::default();
let byte_status = unsafe { make_byte_pattern(byte_seed, 6, &mut byte_buffer) };
if byte_status != 0 {
return Err(format!("hs_make_byte_pattern returned status {byte_status}"));
}
let returned_bytes = read_u8_buffer(&byte_buffer)?;
unsafe {
free_u8_buffer(byte_buffer.ptr, byte_buffer.len, byte_buffer.cap);
}
Ok(format!(
"Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}",
"Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}\nSlice sent to Haskell: {:?}\nHaskell summed slice to: {}\nVector returned from Haskell: {:?}\nByte slice sent to Haskell: {:?}\nHaskell checksummed bytes to: {}\nByte buffer returned from Haskell: {:?}",
library_path.display(),
args.name,
args.left,
@ -109,9 +196,43 @@ fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result<String, String>
stats.product,
stats.gap,
message,
sample_values,
slice_sum,
sequence_values,
sample_bytes,
byte_checksum,
returned_bytes,
))
}
fn read_i32_buffer(buffer: &SharedI32Buffer) -> Result<Vec<c_int>, String> {
if buffer.ptr.is_null() {
if buffer.len == 0 {
return Ok(Vec::new());
}
return Err("received a null buffer pointer with a non-zero length".to_string());
}
let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) };
Ok(values.to_vec())
}
fn read_u8_buffer(buffer: &SharedU8Buffer) -> Result<Vec<u8>, String> {
if buffer.ptr.is_null() {
if buffer.len == 0 {
return Ok(Vec::new());
}
return Err("received a null byte buffer pointer with a non-zero length".to_string());
}
let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) };
Ok(values.to_vec())
}
fn normalize_byte_seed(value: i32) -> u8 {
value.rem_euclid(256) as u8
}
fn resolve_library_path(explicit_path: Option<&str>) -> Result<PathBuf, String> {
if let Some(path) = explicit_path {
return Ok(PathBuf::from(path));
@ -133,6 +254,33 @@ fn resolve_library_path(explicit_path: Option<&str>) -> Result<PathBuf, String>
})
}
fn resolve_rts_library_path() -> Result<PathBuf, String> {
if let Ok(path) = env::var("GHC_RTS_LIB") {
return Ok(PathBuf::from(path));
}
let libdir = if let Ok(path) = env::var("GHC_LIBDIR") {
PathBuf::from(path)
} else {
let output = std::process::Command::new("ghc")
.arg("--print-libdir")
.output()
.map_err(|error| format!("failed to run `ghc --print-libdir`: {error}"))?;
if !output.status.success() {
return Err("`ghc --print-libdir` did not exit successfully".to_string());
}
PathBuf::from(String::from_utf8_lossy(&output.stdout).trim())
};
let mut matches = Vec::new();
collect_matching_rts_libraries(&libdir, &mut matches)?;
matches.sort_by_key(|path| rts_priority(path));
matches
.into_iter()
.next()
.ok_or_else(|| format!("could not find a GHC RTS library under {}", libdir.display()))
}
fn collect_matching_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result<(), String> {
if !root.exists() {
return Ok(());
@ -156,6 +304,29 @@ fn collect_matching_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result
Ok(())
}
fn collect_matching_rts_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result<(), String> {
if !root.exists() {
return Ok(());
}
let entries = fs::read_dir(root)
.map_err(|error| format!("failed to read {}: {error}", root.display()))?;
for entry in entries {
let entry = entry.map_err(|error| format!("failed to inspect {}: {error}", root.display()))?;
let path = entry.path();
if path.is_dir() {
collect_matching_rts_libraries(&path, matches)?;
continue;
}
if is_rts_library(&path) {
matches.push(path);
}
}
Ok(())
}
fn is_haskell_foreign_library(path: &Path) -> bool {
let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
return false;
@ -167,3 +338,30 @@ fn is_haskell_foreign_library(path: &Path) -> bool {
is_library && file_name.contains("interop_hs")
}
fn is_rts_library(path: &Path) -> bool {
let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
return false;
};
path.extension().and_then(|ext| ext.to_str()) == Some("so")
&& file_name.starts_with("libHSrts-")
}
fn rts_priority(path: &Path) -> (u8, String) {
let file_name = path
.file_name()
.and_then(|value| value.to_str())
.unwrap_or_default()
.to_string();
let rank = if file_name.contains("_debug") {
3
} else if file_name.contains("_thr") {
1
} else {
0
};
(rank, file_name)
}

View File

@ -1,5 +1,5 @@
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int};
use std::os::raw::{c_char, c_int, c_uint};
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
@ -9,6 +9,22 @@ pub struct SharedStats {
pub gap: c_int,
}
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct SharedI32Buffer {
pub ptr: *mut c_int,
pub len: usize,
pub cap: usize,
}
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct SharedU8Buffer {
pub ptr: *mut u8,
pub len: usize,
pub cap: usize,
}
pub fn compute_stats(left: c_int, right: c_int) -> SharedStats {
SharedStats {
total: left.saturating_add(right),
@ -17,6 +33,19 @@ pub fn compute_stats(left: c_int, right: c_int) -> SharedStats {
}
}
pub fn sum_slice(values: &[c_int]) -> c_int {
values
.iter()
.copied()
.fold(0, |accumulator, value| accumulator.saturating_add(value))
}
pub fn checksum_bytes(values: &[u8]) -> c_uint {
values.iter().fold(0_u32, |accumulator, value| {
accumulator.saturating_add(u32::from(*value))
})
}
pub fn make_rust_message(name: &str, left: c_int, right: c_int) -> String {
let stats = compute_stats(left, right);
format!(
@ -25,7 +54,31 @@ pub fn make_rust_message(name: &str, left: c_int, right: c_int) -> String {
)
}
#[no_mangle]
pub fn make_sequence_values(start: c_int, count: usize) -> Vec<c_int> {
(0..count)
.map(|index| {
let offset = match c_int::try_from(index) {
Ok(value) => value,
Err(_) => c_int::MAX,
};
start.saturating_add(offset)
})
.collect()
}
pub fn make_byte_pattern(seed: u8, count: usize) -> Vec<u8> {
(0..count)
.map(|index| {
let offset = match u8::try_from(index) {
Ok(value) => value,
Err(_) => u8::MAX,
};
seed.wrapping_add(offset)
})
.collect()
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_compute_stats(
left: c_int,
right: c_int,
@ -35,11 +88,23 @@ pub unsafe extern "C" fn rust_compute_stats(
return 1;
}
out_stats.write(compute_stats(left, right));
unsafe {
out_stats.write(compute_stats(left, right));
}
0
}
#[no_mangle]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_sum_slice(ptr: *const c_int, len: usize) -> c_int {
if ptr.is_null() {
return 0;
}
let values = unsafe { std::slice::from_raw_parts(ptr, len) };
sum_slice(values)
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_make_message(
name: *const c_char,
left: c_int,
@ -49,17 +114,99 @@ pub unsafe extern "C" fn rust_make_message(
return string_into_raw("Rust received a null name pointer".to_string());
}
let name = CStr::from_ptr(name).to_string_lossy();
let name = unsafe { CStr::from_ptr(name) }.to_string_lossy();
string_into_raw(make_rust_message(name.as_ref(), left, right))
}
#[no_mangle]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_checksum_bytes(ptr: *const u8, len: usize) -> c_uint {
if ptr.is_null() {
return 0;
}
let values = unsafe { std::slice::from_raw_parts(ptr, len) };
checksum_bytes(values)
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_make_sequence(
start: c_int,
count: usize,
out_buffer: *mut SharedI32Buffer,
) -> c_int {
if out_buffer.is_null() {
return 1;
}
let mut values = make_sequence_values(start, count);
let buffer = SharedI32Buffer {
ptr: values.as_mut_ptr(),
len: values.len(),
cap: values.capacity(),
};
std::mem::forget(values);
unsafe {
out_buffer.write(buffer);
}
0
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_make_byte_pattern(
seed: u8,
count: usize,
out_buffer: *mut SharedU8Buffer,
) -> c_int {
if out_buffer.is_null() {
return 1;
}
let mut values = make_byte_pattern(seed, count);
let buffer = SharedU8Buffer {
ptr: values.as_mut_ptr(),
len: values.len(),
cap: values.capacity(),
};
std::mem::forget(values);
unsafe {
out_buffer.write(buffer);
}
0
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_free_string(ptr: *mut c_char) {
if ptr.is_null() {
return;
}
drop(CString::from_raw(ptr));
unsafe {
drop(CString::from_raw(ptr));
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_free_i32_buffer(ptr: *mut c_int, len: usize, cap: usize) {
if ptr.is_null() {
return;
}
unsafe {
drop(Vec::from_raw_parts(ptr, len, cap));
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_free_u8_buffer(ptr: *mut u8, len: usize, cap: usize) {
if ptr.is_null() {
return;
}
unsafe {
drop(Vec::from_raw_parts(ptr, len, cap));
}
}
fn string_into_raw(message: String) -> *mut c_char {
@ -95,4 +242,24 @@ mod tests {
assert!(message.contains("product=35"));
assert!(message.contains("gap=2"));
}
#[test]
fn sum_slice_handles_multiple_values() {
assert_eq!(sum_slice(&[2, 4, 6, 8]), 20);
}
#[test]
fn make_sequence_values_builds_contiguous_output() {
assert_eq!(make_sequence_values(7, 5), vec![7, 8, 9, 10, 11]);
}
#[test]
fn checksum_bytes_handles_embedded_zeroes() {
assert_eq!(checksum_bytes(&[72, 0, 105, 255]), 432);
}
#[test]
fn make_byte_pattern_wraps_like_bytes() {
assert_eq!(make_byte_pattern(254, 4), vec![254, 255, 0, 1]);
}
}