diff --git a/.gitattributes b/.gitattributes index c5c14f3..5ed19ec 100644 --- a/.gitattributes +++ b/.gitattributes @@ -51,3 +51,5 @@ *.out filter=lfs diff=lfs merge=lfs -text *.a filter=lfs diff=lfs merge=lfs -text *.o filter=lfs diff=lfs merge=lfs -text + +Makefile linguist-vendored diff --git a/.gitignore b/.gitignore index 71ea34a..2f4d7ba 100644 --- a/.gitignore +++ b/.gitignore @@ -72,6 +72,9 @@ poetry.lock .cargo-ok cobertura.xml tarpaulin-report.html +/haskell/dist-newstyle/ +/haskell/.ghc.environment.* +/.cabal/ # Comment out the next line if you want to checkin your lock file for Cargo Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index cd7c249..5a058af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,18 +1,21 @@ [package] name = "integrations" version = "0.1.0" -edition = "2021" +edition = "2024" license = "MIT OR Apache-2.0" publish = false +build = "rust/build.rs" [lib] path = "rust/lib.rs" +crate-type = ["rlib", "staticlib"] [[bin]] name = "integrations" path = "rust/main.rs" [dependencies] -ctor = "0.2" +ctor = "0.6.3" +libloading = "0.9.0" tracing = "0.1" tracing-subscriber = "0.3" diff --git a/Makefile b/Makefile index 0c6c7c8..7145ce9 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ DEBUG_PROJ := 0 RUST_BACKTRACE := 1 ASSET_DIR := assets TEST_DATA_DIR := tests/testdata -SHELL := /bin/bash +SHELL := bash # Default target .DEFAULT_GOAL := help @@ -17,6 +17,18 @@ help: ## Show help messages for all available targets @grep -E '^[a-zA-Z_-]+:.*## .*$$' Makefile | \ awk 'BEGIN {FS = ":.*## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' +.PHONY: haskell-build +haskell-build: ## Build the Haskell interop project and the Rust static library it links against + @$(MAKE) -C haskell build + +.PHONY: haskell-run +haskell-run: ## Run the Haskell -> Rust demo + @$(MAKE) -C haskell run + +.PHONY: rust-calls-haskell +rust-calls-haskell: haskell-build ## Run the Rust -> Haskell demo + @cargo run -- rust-calls-haskell + .PHONY: format format: ## Format Rust files @echo "Formatting Rust files..." diff --git a/README.md b/README.md index e9a8e44..2cadf4d 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,12 @@ --- -To be added. +Integrating components written in different languages (like Haskell/Rust/C). + +Current demo work lives in `haskell/` and `rust/`. + +- `haskell/` contains a small Cabal project with a Haskell executable that calls into Rust and a foreign library that Rust can call back into. +- `rust/` contains the Rust C ABI exports plus a CLI path for Rust calling the Haskell foreign library. ### License diff --git a/haskell/Makefile b/haskell/Makefile new file mode 100644 index 0000000..74e7e9f --- /dev/null +++ b/haskell/Makefile @@ -0,0 +1,26 @@ +SHELL := bash +CABAL_DIR := $(abspath ../.cabal) +CABAL_ENV := CABAL_DIR=$(CABAL_DIR) XDG_STATE_HOME=$(CABAL_DIR)/state XDG_CACHE_HOME=$(CABAL_DIR)/cache XDG_CONFIG_HOME=$(CABAL_DIR)/config + +.DEFAULT_GOAL := help + +.PHONY: help +help: ## Show available Haskell interop commands + @grep -E '^[a-zA-Z_-]+:.*## .*$$' Makefile | \ + awk 'BEGIN {FS = ":.*## "}; {printf "\033[36m%-24s\033[0m %s\n", $$1, $$2}' + +.PHONY: rust-lib +rust-lib: ## Build the Rust static library used by the Haskell executable + @cargo build --manifest-path ../Cargo.toml --lib + +.PHONY: build +build: rust-lib ## Build the Haskell project, including the foreign library for Rust + @$(CABAL_ENV) cabal build --project-file=cabal.project all + +.PHONY: run +run: rust-lib ## Run the Haskell -> Rust demo executable + @$(CABAL_ENV) cabal run --project-file=cabal.project haskell-calls-rust -- Ada 7 5 + +.PHONY: test +test: ## Run the pure Haskell tests + @$(CABAL_ENV) cabal test --project-file=cabal.project diff --git a/haskell/README.md b/haskell/README.md new file mode 100644 index 0000000..0bcbd02 --- /dev/null +++ b/haskell/README.md @@ -0,0 +1,61 @@ +# Haskell Interop Demo + +This directory contains a small Haskell project for the repository's Rust/Haskell integration work. + +The project demonstrates both directions: + +- Haskell calling into Rust through a C ABI exposed by the Rust crate. +- Rust calling back into Haskell through a Cabal `foreign-library`, with explicit GHC RTS initialization. + +## Layout + +- `src/Interop/Shared.hs` - pure shared logic plus the C-compatible struct layout used at the boundary. +- `src/Interop/Exports.hs` - Haskell functions exported to C for Rust to call. +- `app/RustClient.hs` - Haskell imports for the Rust C ABI. +- `app/Main.hs` - Haskell executable that demonstrates the Haskell -> Rust path. +- `test/Spec.hs` - small pure tests that avoid crossing the FFI boundary. + +The code keeps the FFI surface small on purpose. The boundary uses: + +- integers +- a fixed C-shaped struct +- owned C strings with an explicit free function on each side + +That is enough to demonstrate the main challenges without pulling in code generation or large bindings. + +## Build And Run + +From the repository root: + +```sh +make haskell-build +``` + +Run the Haskell executable that calls Rust: + +```sh +make haskell-run +``` + +Run the Rust executable that calls the Haskell foreign library: + +```sh +make rust-calls-haskell +``` + +You can also run the commands manually: + +```sh +cargo build --lib +cabal build --project-file=haskell/cabal.project all +cabal run --project-file=haskell/cabal.project haskell-calls-rust -- Ada 7 5 +cargo run -- rust-calls-haskell Ada 7 5 +``` + +## What This Demonstrates + +- The boundary must stay C-shaped. Rich Rust and Haskell types do not cross directly. +- Strings need explicit ownership rules. Each side exports its own free function. +- Struct layout must be mirrored carefully on both sides. +- Rust calling Haskell is the harder direction because it must initialize and shut down the GHC runtime correctly. +- Build order is part of the design. Haskell links against the Rust static library, and Rust loads the Haskell foreign library after Cabal builds it. diff --git a/haskell/app/Main.hs b/haskell/app/Main.hs new file mode 100644 index 0000000..eaa8e7d --- /dev/null +++ b/haskell/app/Main.hs @@ -0,0 +1,48 @@ +module Main (main) where + +import Interop.Shared (Summary (..)) +import RustClient (callRustMessage, callRustSummary) +import System.Environment (getArgs) +import Text.Read (readMaybe) + +main :: IO () +main = do + args <- getArgs + let (name, left, right) = parseArgs args + summary <- callRustSummary left right + message <- callRustMessage name left right + + putStrLn "Haskell -> Rust demo" + putStrLn $ "Inputs: name=" ++ name ++ ", left=" ++ show left ++ ", right=" ++ show right + putStrLn $ "Stats from Rust: " ++ renderSummary summary + putStrLn $ "Message from Rust: " ++ message + +parseArgs :: [String] -> (String, Int, Int) +parseArgs args = + let name = case args of + value : _ -> value + [] -> "Ada" + left = maybe 7 id (pickNumber 1 args) + right = maybe 5 id (pickNumber 2 args) + in (name, left, right) + +pickNumber :: Int -> [String] -> Maybe Int +pickNumber index values = do + value <- safeIndex index values + readMaybe value + +safeIndex :: Int -> [a] -> Maybe a +safeIndex index values + | index < 0 = Nothing + | otherwise = case drop index values of + value : _ -> Just value + [] -> Nothing + +renderSummary :: Summary -> String +renderSummary summary = + "total=" + ++ show (total summary) + ++ ", product=" + ++ show (combinedProduct summary) + ++ ", gap=" + ++ show (gap summary) diff --git a/haskell/app/RustClient.hs b/haskell/app/RustClient.hs new file mode 100644 index 0000000..ac2a09a --- /dev/null +++ b/haskell/app/RustClient.hs @@ -0,0 +1,41 @@ +module RustClient ( + callRustMessage, + callRustSummary, +) where + +import Control.Exception (bracket) +import Foreign.C.String (CString, peekCString, withCString) +import Foreign.C.Types (CInt (..)) +import Foreign.Marshal.Alloc (alloca) +import Foreign.Ptr (Ptr, nullPtr) +import Foreign.Storable (peek) +import Interop.Shared (Summary, SharedStats, summaryFromSharedStats) + +foreign import ccall unsafe "rust_compute_stats" + rustComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt + +foreign import ccall unsafe "rust_make_message" + rustMakeMessage :: CString -> CInt -> CInt -> IO CString + +foreign import ccall unsafe "rust_free_string" + rustFreeString :: CString -> IO () + +callRustSummary :: Int -> Int -> IO Summary +callRustSummary left right = + alloca $ \outStats -> do + status <- rustComputeStats (fromIntegral left) (fromIntegral right) outStats + if status /= 0 + then fail ("rustComputeStats returned status " ++ show status) + else summaryFromSharedStats <$> peek outStats + +callRustMessage :: String -> Int -> Int -> IO String +callRustMessage name left right = + withCString name $ \namePtr -> do + messagePtr <- rustMakeMessage namePtr (fromIntegral left) (fromIntegral right) + if messagePtr == nullPtr + then fail "rustMakeMessage returned a null pointer" + else + bracket + (pure messagePtr) + rustFreeString + peekCString diff --git a/haskell/cabal.project b/haskell/cabal.project new file mode 100644 index 0000000..e6fdbad --- /dev/null +++ b/haskell/cabal.project @@ -0,0 +1 @@ +packages: . diff --git a/haskell/interop-demo.cabal b/haskell/interop-demo.cabal new file mode 100644 index 0000000..837af8d --- /dev/null +++ b/haskell/interop-demo.cabal @@ -0,0 +1,69 @@ +cabal-version: 3.8 +name: haskell-interop-demo +version: 0.1.0.0 +license: MIT +build-type: Simple + +common common + default-language: GHC2021 + ghc-options: + -Wall + -Wcompat + -Widentities + -Wincomplete-record-updates + -Wincomplete-uni-patterns + -Wmissing-export-lists + -Wmissing-home-modules + -Wpartial-fields + -Wredundant-constraints + build-depends: + base >= 4.18 && < 5 + +library + import: common + hs-source-dirs: src + exposed-modules: + Interop.Exports + Interop.Shared + +foreign-library interop_hs + import: common + type: native-shared + options: standalone + ghc-options: + -dynamic + hs-source-dirs: src + other-modules: + Interop.Exports + Interop.Shared + +executable haskell-calls-rust + import: common + main-is: Main.hs + hs-source-dirs: + app + src + other-modules: + RustClient + Interop.Shared + ghc-options: + -threaded + -rtsopts + -with-rtsopts=-N + extra-lib-dirs: + ../target/debug + ../target/release + extra-libraries: + integrations + +test-suite haskell-interop-demo-test + import: common + type: exitcode-stdio-1.0 + main-is: Spec.hs + hs-source-dirs: + test + src + other-modules: + Interop.Shared + build-depends: + haskell-interop-demo diff --git a/haskell/src/Interop/Exports.hs b/haskell/src/Interop/Exports.hs new file mode 100644 index 0000000..2a72da6 --- /dev/null +++ b/haskell/src/Interop/Exports.hs @@ -0,0 +1,38 @@ +module Interop.Exports ( + hsComputeStats, + hsFreeString, + hsMakeMessage, +) where + +import Foreign.C.String (CString, newCString, peekCString) +import Foreign.C.Types (CInt (..)) +import Foreign.Marshal.Alloc (free) +import Foreign.Ptr (Ptr, nullPtr) +import Foreign.Storable (poke) +import Interop.Shared (SharedStats, calculateSummary, formatHaskellMessage, summaryToSharedStats) + +hsComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt +hsComputeStats left right outStats + | outStats == nullPtr = pure 1 + | otherwise = do + poke outStats $ + summaryToSharedStats $ + calculateSummary (fromIntegral left) (fromIntegral right) + pure 0 + +hsMakeMessage :: CString -> CInt -> CInt -> IO CString +hsMakeMessage namePtr left right + | namePtr == nullPtr = newCString "Haskell received a null name pointer" + | otherwise = do + name <- peekCString namePtr + let summary = calculateSummary (fromIntegral left) (fromIntegral right) + newCString (formatHaskellMessage name summary) + +hsFreeString :: CString -> IO () +hsFreeString ptr + | ptr == nullPtr = pure () + | otherwise = free ptr + +foreign export ccall "hs_compute_stats" hsComputeStats :: CInt -> CInt -> Ptr SharedStats -> IO CInt +foreign export ccall "hs_make_message" hsMakeMessage :: CString -> CInt -> CInt -> IO CString +foreign export ccall "hs_free_string" hsFreeString :: CString -> IO () diff --git a/haskell/src/Interop/Shared.hs b/haskell/src/Interop/Shared.hs new file mode 100644 index 0000000..869bd22 --- /dev/null +++ b/haskell/src/Interop/Shared.hs @@ -0,0 +1,82 @@ +module Interop.Shared ( + SharedStats (..), + Summary (..), + calculateSummary, + formatHaskellMessage, + summaryFromSharedStats, + summaryToSharedStats, +) where + +import Foreign.C.Types (CInt) +import Foreign.Storable (Storable (..), peekByteOff, pokeByteOff) + +data Summary = Summary + { total :: Int + , combinedProduct :: Int + , gap :: Int + } + deriving (Eq, Show) + +calculateSummary :: Int -> Int -> Summary +calculateSummary left right = + Summary + { total = left + right + , combinedProduct = left * right + , gap = abs (left - right) + } + +formatHaskellMessage :: String -> Summary -> String +formatHaskellMessage name summary = + "Haskell handled " + ++ name + ++ ": total=" + ++ show (total summary) + ++ ", product=" + ++ show (combinedProduct summary) + ++ ", gap=" + ++ show (gap summary) + +data SharedStats = SharedStats + { sharedTotal :: CInt + , sharedProduct :: CInt + , sharedGap :: CInt + } + deriving (Eq, Show) + +instance Storable SharedStats where + sizeOf _ = fieldSize * 3 + where + fieldSize = sizeOf (undefined :: CInt) + + alignment _ = alignment (undefined :: CInt) + + peek ptr = + SharedStats + <$> peekByteOff ptr 0 + <*> peekByteOff ptr fieldSize + <*> peekByteOff ptr (fieldSize * 2) + where + fieldSize = sizeOf (undefined :: CInt) + + poke ptr value = do + pokeByteOff ptr 0 (sharedTotal value) + pokeByteOff ptr fieldSize (sharedProduct value) + pokeByteOff ptr (fieldSize * 2) (sharedGap value) + where + fieldSize = sizeOf (undefined :: CInt) + +summaryToSharedStats :: Summary -> SharedStats +summaryToSharedStats summary = + SharedStats + { sharedTotal = fromIntegral (total summary) + , sharedProduct = fromIntegral (combinedProduct summary) + , sharedGap = fromIntegral (gap summary) + } + +summaryFromSharedStats :: SharedStats -> Summary +summaryFromSharedStats stats = + Summary + { total = fromIntegral (sharedTotal stats) + , combinedProduct = fromIntegral (sharedProduct stats) + , gap = fromIntegral (sharedGap stats) + } diff --git a/haskell/test/Spec.hs b/haskell/test/Spec.hs new file mode 100644 index 0000000..d7ee01d --- /dev/null +++ b/haskell/test/Spec.hs @@ -0,0 +1,35 @@ +module Main (main) where + +import Interop.Shared (calculateSummary, combinedProduct, formatHaskellMessage, gap, total) +import System.Exit (exitFailure) + +main :: IO () +main = do + assertEqual "summary total" 12 (totalValue 7 5) + assertEqual "summary product" 35 (productValue 7 5) + assertEqual "summary gap" 2 (gapValue 7 5) + assertEqual + "message rendering" + "Haskell handled Ada: total=12, product=35, gap=2" + (formatHaskellMessage "Ada" (calculateSummary 7 5)) + +totalValue :: Int -> Int -> Int +totalValue left right = total (calculateSummary left right) + +productValue :: Int -> Int -> Int +productValue left right = combinedProduct (calculateSummary left right) + +gapValue :: Int -> Int -> Int +gapValue left right = gap (calculateSummary left right) + +assertEqual :: (Eq a, Show a) => String -> a -> a -> IO () +assertEqual label expected actual + | expected == actual = pure () + | otherwise = do + putStrLn $ + label + ++ " expected " + ++ show expected + ++ " but got " + ++ show actual + exitFailure diff --git a/rust/build.rs b/rust/build.rs new file mode 100644 index 0000000..021a2b6 --- /dev/null +++ b/rust/build.rs @@ -0,0 +1,262 @@ +use std::collections::{BTreeSet, HashSet}; +use std::env; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::Command; + +fn main() { + println!("cargo:rerun-if-env-changed=GHC_LIBDIR"); + println!("cargo:rerun-if-env-changed=GHC_RTS_LIB"); + println!("cargo:rerun-if-changed=rust/build.rs"); + + let libdir = env::var("GHC_LIBDIR").unwrap_or_else(|_| ghc_print_libdir()); + let explicit_rts = env::var("GHC_RTS_LIB").ok().map(PathBuf::from); + let rts_path = explicit_rts.unwrap_or_else(|| find_rts_library(Path::new(&libdir))); + let rts_dir = rts_path + .parent() + .unwrap_or_else(|| Path::new(&libdir)) + .to_path_buf(); + let rts_name = rts_path + .file_stem() + .and_then(|stem| stem.to_str()) + .map(strip_library_prefix) + .unwrap_or_else(|| panic!("failed to resolve GHC RTS library name from {}", rts_path.display())); + + let mut search_dirs = BTreeSet::new(); + let mut haskell_libs = Vec::new(); + let mut seen_haskell_libs = HashSet::new(); + let mut native_libs = BTreeSet::new(); + + search_dirs.insert(rts_dir); + seen_haskell_libs.insert(rts_name.clone()); + haskell_libs.push(rts_name); + + for package in ["base", "ghc-prim", "ghc-bignum"] { + let info = ghc_pkg_describe(package); + for dir in &info.dynamic_library_dirs { + search_dirs.insert(dir.clone()); + } + for library in info.hs_libraries { + let resolved = resolve_dynamic_hs_library(&library, &info.dynamic_library_dirs); + if seen_haskell_libs.insert(resolved.clone()) { + haskell_libs.push(resolved); + } + } + + for library in info.extra_libraries { + native_libs.insert(library); + } + } + + let rts_info = ghc_pkg_describe("rts"); + for dir in rts_info.dynamic_library_dirs { + search_dirs.insert(dir); + } + for library in rts_info.extra_libraries { + native_libs.insert(library); + } + + for dir in search_dirs { + println!("cargo:rustc-link-search=native={}", dir.display()); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", dir.display()); + } + + println!("cargo:rustc-link-arg=-Wl,--no-as-needed"); + for library in haskell_libs { + println!("cargo:rustc-link-lib=dylib={library}"); + } + println!("cargo:rustc-link-arg=-Wl,--as-needed"); + + for library in native_libs { + println!("cargo:rustc-link-lib=dylib={library}"); + } +} + +#[derive(Default)] +struct PackageInfo { + hs_libraries: Vec, + extra_libraries: Vec, + dynamic_library_dirs: Vec, +} + +fn ghc_print_libdir() -> String { + let output = Command::new("ghc") + .arg("--print-libdir") + .output() + .unwrap_or_else(|error| panic!("failed to run `ghc --print-libdir`: {error}")); + if !output.status.success() { + panic!("`ghc --print-libdir` did not exit successfully"); + } + + String::from_utf8_lossy(&output.stdout).trim().to_string() +} + +fn ghc_pkg_describe(package: &str) -> PackageInfo { + let output = Command::new("ghc-pkg") + .args(["describe", package]) + .output() + .unwrap_or_else(|error| panic!("failed to run `ghc-pkg describe {package}`: {error}")); + if !output.status.success() { + panic!("`ghc-pkg describe {package}` did not exit successfully"); + } + + let description = String::from_utf8_lossy(&output.stdout); + parse_package_description(&description) +} + +fn parse_package_description(description: &str) -> PackageInfo { + let mut info = PackageInfo::default(); + let mut current_field = String::new(); + let mut pkgroot = String::new(); + + for raw_line in description.lines() { + let line = raw_line.trim_end(); + if line.is_empty() { + continue; + } + + if raw_line.starts_with(' ') || raw_line.starts_with('\t') { + push_field_values(¤t_field, line.trim(), &mut pkgroot, &mut info); + continue; + } + + if let Some((field, rest)) = line.split_once(':') { + current_field = field.trim().to_string(); + push_field_values(¤t_field, rest.trim(), &mut pkgroot, &mut info); + } + } + + if !pkgroot.is_empty() { + for dir in &mut info.dynamic_library_dirs { + let resolved = dir + .display() + .to_string() + .replace("${pkgroot}", &pkgroot); + *dir = PathBuf::from(resolved); + } + } + + info +} + +fn push_field_values( + field: &str, + values: &str, + pkgroot: &mut String, + info: &mut PackageInfo, +) { + if values.is_empty() { + return; + } + + match field { + "pkgroot" => { + *pkgroot = values.trim_matches('"').to_string(); + } + "hs-libraries" => { + for value in values.split_whitespace() { + info.hs_libraries + .push(strip_library_prefix(value.trim_matches('"'))); + } + } + "extra-libraries" => { + for value in values.split_whitespace() { + info.extra_libraries.push(value.trim_matches('"').to_string()); + } + } + "dynamic-library-dirs" => { + for value in values.split_whitespace() { + info.dynamic_library_dirs + .push(PathBuf::from(value.trim_matches('"'))); + } + } + _ => {} + } +} + +fn resolve_dynamic_hs_library(library: &str, search_dirs: &[PathBuf]) -> String { + for dir in search_dirs { + let Ok(entries) = fs::read_dir(dir) else { + continue; + }; + + for entry in entries.flatten() { + let path = entry.path(); + let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { + continue; + }; + + let exact_name = format!("lib{library}.so"); + let versioned_prefix = format!("lib{library}-"); + if (file_name == exact_name + || (file_name.starts_with(&versioned_prefix) && file_name.ends_with(".so"))) + && path.is_file() + { + if let Some(stem) = path.file_stem().and_then(|value| value.to_str()) { + return strip_library_prefix(stem); + } + } + } + } + + library.to_string() +} + +fn find_rts_library(libdir: &Path) -> PathBuf { + let mut candidates = Vec::new(); + walk_for_rts(libdir, &mut candidates); + candidates.sort_by_key(|path| rts_priority(path)); + + candidates + .into_iter() + .next() + .unwrap_or_else(|| panic!("failed to locate a GHC RTS library under {}", libdir.display())) +} + +fn walk_for_rts(root: &Path, candidates: &mut Vec) { + let Ok(entries) = fs::read_dir(root) else { + return; + }; + + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + walk_for_rts(&path, candidates); + continue; + } + + if is_threaded_rts_library(&path) { + candidates.push(path); + } + } +} + +fn is_threaded_rts_library(path: &Path) -> bool { + let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { + return false; + }; + + path.extension().and_then(|ext| ext.to_str()) == Some("so") && file_name.starts_with("libHSrts-") +} + +fn rts_priority(path: &Path) -> (u8, String) { + let file_name = path + .file_name() + .and_then(|value| value.to_str()) + .unwrap_or_default() + .to_string(); + + let rank = if file_name.contains("_debug") { + 3 + } else if file_name.contains("_thr") { + 1 + } else { + 0 + }; + + (rank, file_name) +} + +fn strip_library_prefix(stem: &str) -> String { + stem.strip_prefix("lib").unwrap_or(stem).to_string() +} diff --git a/rust/cli.rs b/rust/cli.rs index 2290390..9f29ba2 100644 --- a/rust/cli.rs +++ b/rust/cli.rs @@ -1,32 +1,137 @@ +use crate::haskell::{run_haskell_demo, DemoArgs}; use std::ffi::OsString; use tracing::error; pub fn run(args: impl IntoIterator) -> Result<(), i32> { - let _args: Vec = args.into_iter().collect(); - if _args.len() < 2 { - error!("Expecting at least 2 arguments"); - return Err(1); + let args: Vec = args + .into_iter() + .map(|arg| arg.to_string_lossy().into_owned()) + .collect(); + + match parse_command(&args) { + Ok(Command::Help) => { + print_usage(args.first().map(String::as_str).unwrap_or("integrations")); + Ok(()) + } + Ok(Command::RustCallsHaskell(demo_args)) => match run_haskell_demo(&demo_args) { + Ok(output) => { + println!("{output}"); + Ok(()) + } + Err(message) => { + error!("{message}"); + Err(1) + } + }, + Err(message) => { + error!("{message}"); + print_usage(args.first().map(String::as_str).unwrap_or("integrations")); + Err(1) + } } - Ok(()) } -// Unit tests +#[derive(Debug, PartialEq, Eq)] +enum Command { + Help, + RustCallsHaskell(DemoArgs), +} + +fn parse_command(args: &[String]) -> Result { + match args { + [_program] => Ok(Command::Help), + [_program, command] if command == "help" || command == "--help" || command == "-h" => { + Ok(Command::Help) + } + [_program, command, rest @ ..] if command == "rust-calls-haskell" => { + parse_demo_args(rest).map(Command::RustCallsHaskell) + } + [_program, command, ..] => Err(format!("unknown command: {command}")), + [] => Ok(Command::Help), + } +} + +fn parse_demo_args(args: &[String]) -> Result { + let name = args + .first() + .cloned() + .unwrap_or_else(|| "Ada".to_string()); + let left = parse_i32_arg(args.get(1), "left operand", 7)?; + let right = parse_i32_arg(args.get(2), "right operand", 5)?; + let library_path = args.get(3).cloned(); + + Ok(DemoArgs { + name, + left, + right, + library_path, + }) +} + +fn parse_i32_arg(raw: Option<&String>, label: &str, default: i32) -> Result { + match raw { + Some(value) => value + .parse::() + .map_err(|_| format!("invalid {label}: {value}")), + None => Ok(default), + } +} + +fn print_usage(program: &str) { + println!("Usage:"); + println!(" {program} rust-calls-haskell [name] [left] [right] [haskell-lib-path]"); + println!(); + println!("Examples:"); + println!(" {program} rust-calls-haskell"); + println!(" {program} rust-calls-haskell Grace 8 3"); +} + #[cfg(test)] mod tests { use super::*; - use std::ffi::OsString; #[test] - fn test_run_with_valid_args() { - let args = vec![OsString::from("arg1"), OsString::from("arg2")]; - let result = run(args); - assert!(result.is_ok()); + fn parse_demo_args_uses_defaults() { + let args: Vec = Vec::new(); + let parsed = parse_demo_args(&args).expect("defaults should parse"); + + assert_eq!( + parsed, + DemoArgs { + name: "Ada".to_string(), + left: 7, + right: 5, + library_path: None, + } + ); } #[test] - fn test_run_with_invalid_args() { - let args = vec![OsString::from("invalid_arg")]; - let result = run(args); - assert!(result.is_err()); + fn parse_demo_args_accepts_explicit_values() { + let args = vec![ + "Linus".to_string(), + "9".to_string(), + "4".to_string(), + "haskell/dist-newstyle/demo/libinterop_hs.so".to_string(), + ]; + let parsed = parse_demo_args(&args).expect("explicit args should parse"); + + assert_eq!( + parsed, + DemoArgs { + name: "Linus".to_string(), + left: 9, + right: 4, + library_path: Some("haskell/dist-newstyle/demo/libinterop_hs.so".to_string()), + } + ); + } + + #[test] + fn parse_command_rejects_unknown_commands() { + let args = vec!["integrations".to_string(), "wat".to_string()]; + let parsed = parse_command(&args); + + assert!(parsed.is_err()); } } diff --git a/rust/haskell.rs b/rust/haskell.rs new file mode 100644 index 0000000..6b5db01 --- /dev/null +++ b/rust/haskell.rs @@ -0,0 +1,169 @@ +use crate::interop::SharedStats; +use libloading::Library; +use std::env; +use std::ffi::{CStr, CString}; +use std::fs; +use std::os::raw::{c_char, c_int}; +use std::path::{Path, PathBuf}; + +unsafe extern "C" { + fn hs_init(argc: *mut c_int, argv: *mut *mut *mut c_char); + fn hs_exit(); +} + +type HsComputeStats = unsafe extern "C" fn(c_int, c_int, *mut SharedStats) -> c_int; +type HsMakeMessage = unsafe extern "C" fn(*const c_char, c_int, c_int) -> *mut c_char; +type HsFreeString = unsafe extern "C" fn(*mut c_char); + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct DemoArgs { + pub name: String, + pub left: i32, + pub right: i32, + pub library_path: Option, +} + +pub fn run_haskell_demo(args: &DemoArgs) -> Result { + let library_path = resolve_library_path(args.library_path.as_deref())?; + let runtime = HaskellRuntime::start()?; + + let output = load_and_run(&library_path, args); + + drop(runtime); + output +} + +struct HaskellRuntime; + +impl HaskellRuntime { + fn start() -> Result { + let mut argc: c_int = 1; + let program_name = CString::new("integrations-hs-runtime") + .map_err(|_| "failed to create runtime program name".to_string())?; + let mut argv = vec![program_name.as_ptr() as *mut c_char, std::ptr::null_mut()]; + let mut argv_ptr = argv.as_mut_ptr(); + + unsafe { + hs_init(&mut argc, &mut argv_ptr); + } + + Ok(Self) + } +} + +impl Drop for HaskellRuntime { + fn drop(&mut self) { + unsafe { + hs_exit(); + } + } +} + +fn load_and_run(library_path: &Path, args: &DemoArgs) -> Result { + let library = unsafe { Library::new(library_path) } + .map_err(|error| format!("failed to load {}: {error}", library_path.display()))?; + + let compute_stats: HsComputeStats = unsafe { + *library + .get(b"hs_compute_stats\0") + .map_err(|error| format!("failed to load hs_compute_stats: {error}"))? + }; + let make_message: HsMakeMessage = unsafe { + *library + .get(b"hs_make_message\0") + .map_err(|error| format!("failed to load hs_make_message: {error}"))? + }; + let free_string: HsFreeString = unsafe { + *library + .get(b"hs_free_string\0") + .map_err(|error| format!("failed to load hs_free_string: {error}"))? + }; + + let mut stats = SharedStats::default(); + let status = unsafe { compute_stats(args.left, args.right, &mut stats) }; + if status != 0 { + return Err(format!("hs_compute_stats returned status {status}")); + } + + let name = CString::new(args.name.replace('\0', "?")) + .map_err(|_| "failed to prepare demo name".to_string())?; + let message_ptr = unsafe { make_message(name.as_ptr(), args.left, args.right) }; + if message_ptr.is_null() { + return Err("hs_make_message returned a null pointer".to_string()); + } + + let message = unsafe { CStr::from_ptr(message_ptr) } + .to_string_lossy() + .into_owned(); + unsafe { + free_string(message_ptr); + } + + Ok(format!( + "Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}", + library_path.display(), + args.name, + args.left, + args.right, + stats.total, + stats.product, + stats.gap, + message, + )) +} + +fn resolve_library_path(explicit_path: Option<&str>) -> Result { + if let Some(path) = explicit_path { + return Ok(PathBuf::from(path)); + } + + if let Ok(path) = env::var("HASKELL_FOREIGN_LIB") { + return Ok(PathBuf::from(path)); + } + + let dist_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("haskell") + .join("dist-newstyle"); + let mut matches = Vec::new(); + collect_matching_libraries(&dist_dir, &mut matches)?; + matches.sort(); + + matches.into_iter().next().ok_or_else(|| { + "could not find the Haskell foreign library under haskell/dist-newstyle; run `make haskell-build` first or pass an explicit path".to_string() + }) +} + +fn collect_matching_libraries(root: &Path, matches: &mut Vec) -> Result<(), String> { + if !root.exists() { + return Ok(()); + } + + let entries = fs::read_dir(root) + .map_err(|error| format!("failed to read {}: {error}", root.display()))?; + for entry in entries { + let entry = entry.map_err(|error| format!("failed to inspect {}: {error}", root.display()))?; + let path = entry.path(); + if path.is_dir() { + collect_matching_libraries(&path, matches)?; + continue; + } + + if is_haskell_foreign_library(&path) { + matches.push(path); + } + } + + Ok(()) +} + +fn is_haskell_foreign_library(path: &Path) -> bool { + let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { + return false; + }; + + let is_library = file_name.ends_with(".so") + || file_name.ends_with(".dylib") + || file_name.ends_with(".dll"); + + is_library && file_name.contains("interop_hs") +} diff --git a/rust/interop.rs b/rust/interop.rs new file mode 100644 index 0000000..5afc230 --- /dev/null +++ b/rust/interop.rs @@ -0,0 +1,98 @@ +use std::ffi::{CStr, CString}; +use std::os::raw::{c_char, c_int}; + +#[repr(C)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct SharedStats { + pub total: c_int, + pub product: c_int, + pub gap: c_int, +} + +pub fn compute_stats(left: c_int, right: c_int) -> SharedStats { + SharedStats { + total: left.saturating_add(right), + product: left.saturating_mul(right), + gap: left.saturating_sub(right).abs(), + } +} + +pub fn make_rust_message(name: &str, left: c_int, right: c_int) -> String { + let stats = compute_stats(left, right); + format!( + "Rust handled {name}: total={}, product={}, gap={}", + stats.total, stats.product, stats.gap + ) +} + +#[no_mangle] +pub unsafe extern "C" fn rust_compute_stats( + left: c_int, + right: c_int, + out_stats: *mut SharedStats, +) -> c_int { + if out_stats.is_null() { + return 1; + } + + out_stats.write(compute_stats(left, right)); + 0 +} + +#[no_mangle] +pub unsafe extern "C" fn rust_make_message( + name: *const c_char, + left: c_int, + right: c_int, +) -> *mut c_char { + if name.is_null() { + return string_into_raw("Rust received a null name pointer".to_string()); + } + + let name = CStr::from_ptr(name).to_string_lossy(); + string_into_raw(make_rust_message(name.as_ref(), left, right)) +} + +#[no_mangle] +pub unsafe extern "C" fn rust_free_string(ptr: *mut c_char) { + if ptr.is_null() { + return; + } + + drop(CString::from_raw(ptr)); +} + +fn string_into_raw(message: String) -> *mut c_char { + let sanitized = message.replace('\0', "?"); + match CString::new(sanitized) { + Ok(c_string) => c_string.into_raw(), + Err(_) => std::ptr::null_mut(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn compute_stats_matches_expected_values() { + assert_eq!( + compute_stats(9, 4), + SharedStats { + total: 13, + product: 36, + gap: 5, + } + ); + } + + #[test] + fn message_contains_name_and_values() { + let message = make_rust_message("Ada", 7, 5); + + assert!(message.contains("Ada")); + assert!(message.contains("total=12")); + assert!(message.contains("product=35")); + assert!(message.contains("gap=2")); + } +} diff --git a/rust/lib.rs b/rust/lib.rs index 2b8e049..d0bab95 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -1,2 +1,4 @@ pub mod cli; +pub mod haskell; +pub mod interop; pub mod logging;