use crate::interop::{SharedI32Buffer, SharedStats, SharedU8Buffer}; use libloading::Library; use std::env; use std::ffi::{CStr, CString}; use std::fs; use std::os::raw::{c_char, c_int, c_uint}; use std::path::{Path, PathBuf}; type HsComputeStats = unsafe extern "C" fn(c_int, c_int, *mut SharedStats) -> c_int; type HsSumSlice = unsafe extern "C" fn(*const c_int, usize) -> c_int; type HsChecksumBytes = unsafe extern "C" fn(*const u8, usize) -> c_uint; type HsMakeMessage = unsafe extern "C" fn(*const c_char, c_int, c_int) -> *mut c_char; type HsMakeSequence = unsafe extern "C" fn(c_int, usize, *mut SharedI32Buffer) -> c_int; type HsMakeBytePattern = unsafe extern "C" fn(u8, usize, *mut SharedU8Buffer) -> c_int; type HsFreeString = unsafe extern "C" fn(*mut c_char); type HsFreeI32Buffer = unsafe extern "C" fn(*mut c_int, usize, usize); type HsFreeU8Buffer = unsafe extern "C" fn(*mut u8, usize, usize); type HsInit = unsafe extern "C" fn(*mut c_int, *mut *mut *mut c_char); type HsExit = unsafe extern "C" fn(); #[derive(Clone, Debug, PartialEq, Eq)] pub struct DemoArgs { pub name: String, pub left: i32, pub right: i32, pub library_path: Option, } pub fn run_haskell_demo(args: &DemoArgs) -> Result { let library_path = resolve_library_path(args.library_path.as_deref())?; let library = unsafe { Library::new(&library_path) } .map_err(|error| format!("failed to load {}: {error}", library_path.display()))?; let runtime = HaskellRuntime::start()?; let output = load_and_run(&library_path, &library, args); drop(runtime); output } struct HaskellRuntime { _rts_library: Library, hs_exit: HsExit, } impl HaskellRuntime { fn start() -> Result { let rts_library_path = resolve_rts_library_path()?; let rts_library = unsafe { Library::new(&rts_library_path) }.map_err(|error| { format!( "failed to load GHC RTS library {}: {error}", rts_library_path.display() ) })?; let hs_init: HsInit = unsafe { *rts_library .get(b"hs_init\0") .map_err(|error| format!("failed to load hs_init: {error}"))? }; let hs_exit: HsExit = unsafe { *rts_library .get(b"hs_exit\0") .map_err(|error| format!("failed to load hs_exit: {error}"))? }; let mut argc: c_int = 1; let program_name = CString::new("integrations-hs-runtime") .map_err(|_| "failed to create runtime program name".to_string())?; let mut argv = vec![program_name.as_ptr() as *mut c_char, std::ptr::null_mut()]; let mut argv_ptr = argv.as_mut_ptr(); unsafe { hs_init(&mut argc, &mut argv_ptr); } Ok(Self { _rts_library: rts_library, hs_exit, }) } } impl Drop for HaskellRuntime { fn drop(&mut self) { unsafe { (self.hs_exit)(); } } } fn load_and_run(library_path: &Path, library: &Library, args: &DemoArgs) -> Result { let compute_stats: HsComputeStats = unsafe { *library .get(b"hs_compute_stats\0") .map_err(|error| format!("failed to load hs_compute_stats: {error}"))? }; let sum_slice: HsSumSlice = unsafe { *library .get(b"hs_sum_slice\0") .map_err(|error| format!("failed to load hs_sum_slice: {error}"))? }; let checksum_bytes: HsChecksumBytes = unsafe { *library .get(b"hs_checksum_bytes\0") .map_err(|error| format!("failed to load hs_checksum_bytes: {error}"))? }; let make_message: HsMakeMessage = unsafe { *library .get(b"hs_make_message\0") .map_err(|error| format!("failed to load hs_make_message: {error}"))? }; let make_sequence: HsMakeSequence = unsafe { *library .get(b"hs_make_sequence\0") .map_err(|error| format!("failed to load hs_make_sequence: {error}"))? }; let make_byte_pattern: HsMakeBytePattern = unsafe { *library .get(b"hs_make_byte_pattern\0") .map_err(|error| format!("failed to load hs_make_byte_pattern: {error}"))? }; let free_string: HsFreeString = unsafe { *library .get(b"hs_free_string\0") .map_err(|error| format!("failed to load hs_free_string: {error}"))? }; let free_i32_buffer: HsFreeI32Buffer = unsafe { *library .get(b"hs_free_i32_buffer\0") .map_err(|error| format!("failed to load hs_free_i32_buffer: {error}"))? }; let free_u8_buffer: HsFreeU8Buffer = unsafe { *library .get(b"hs_free_u8_buffer\0") .map_err(|error| format!("failed to load hs_free_u8_buffer: {error}"))? }; let mut stats = SharedStats::default(); let status = unsafe { compute_stats(args.left, args.right, &mut stats) }; if status != 0 { return Err(format!("hs_compute_stats returned status {status}")); } let name = CString::new(args.name.replace('\0', "?")) .map_err(|_| "failed to prepare demo name".to_string())?; let message_ptr = unsafe { make_message(name.as_ptr(), args.left, args.right) }; if message_ptr.is_null() { return Err("hs_make_message returned a null pointer".to_string()); } let message = unsafe { CStr::from_ptr(message_ptr) } .to_string_lossy() .into_owned(); unsafe { free_string(message_ptr); } let sample_values: [c_int; 4] = [2, 4, 6, 8]; let slice_sum = unsafe { sum_slice(sample_values.as_ptr(), sample_values.len()) }; let sample_bytes: [u8; 4] = [72, 0, 105, 255]; let byte_checksum = unsafe { checksum_bytes(sample_bytes.as_ptr(), sample_bytes.len()) }; let mut sequence_buffer = SharedI32Buffer::default(); let sequence_status = unsafe { make_sequence(args.left, 5, &mut sequence_buffer) }; if sequence_status != 0 { return Err(format!("hs_make_sequence returned status {sequence_status}")); } let sequence_values = read_i32_buffer(&sequence_buffer)?; unsafe { free_i32_buffer( sequence_buffer.ptr, sequence_buffer.len, sequence_buffer.cap, ); } let byte_seed = normalize_byte_seed(args.left); let mut byte_buffer = SharedU8Buffer::default(); let byte_status = unsafe { make_byte_pattern(byte_seed, 6, &mut byte_buffer) }; if byte_status != 0 { return Err(format!("hs_make_byte_pattern returned status {byte_status}")); } let returned_bytes = read_u8_buffer(&byte_buffer)?; unsafe { free_u8_buffer(byte_buffer.ptr, byte_buffer.len, byte_buffer.cap); } Ok(format!( "Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}\nSlice sent to Haskell: {:?}\nHaskell summed slice to: {}\nVector returned from Haskell: {:?}\nByte slice sent to Haskell: {:?}\nHaskell checksummed bytes to: {}\nByte buffer returned from Haskell: {:?}", library_path.display(), args.name, args.left, args.right, stats.total, stats.product, stats.gap, message, sample_values, slice_sum, sequence_values, sample_bytes, byte_checksum, returned_bytes, )) } fn read_i32_buffer(buffer: &SharedI32Buffer) -> Result, String> { if buffer.ptr.is_null() { if buffer.len == 0 { return Ok(Vec::new()); } return Err("received a null buffer pointer with a non-zero length".to_string()); } let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) }; Ok(values.to_vec()) } fn read_u8_buffer(buffer: &SharedU8Buffer) -> Result, String> { if buffer.ptr.is_null() { if buffer.len == 0 { return Ok(Vec::new()); } return Err("received a null byte buffer pointer with a non-zero length".to_string()); } let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) }; Ok(values.to_vec()) } fn normalize_byte_seed(value: i32) -> u8 { value.rem_euclid(256) as u8 } fn resolve_library_path(explicit_path: Option<&str>) -> Result { if let Some(path) = explicit_path { return Ok(PathBuf::from(path)); } if let Ok(path) = env::var("HASKELL_FOREIGN_LIB") { return Ok(PathBuf::from(path)); } let dist_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("haskell") .join("dist-newstyle"); let mut matches = Vec::new(); collect_matching_libraries(&dist_dir, &mut matches)?; matches.sort(); matches.into_iter().next().ok_or_else(|| { "could not find the Haskell foreign library under haskell/dist-newstyle; run `make haskell-build` first or pass an explicit path".to_string() }) } fn resolve_rts_library_path() -> Result { if let Ok(path) = env::var("GHC_RTS_LIB") { return Ok(PathBuf::from(path)); } let libdir = if let Ok(path) = env::var("GHC_LIBDIR") { PathBuf::from(path) } else { let output = std::process::Command::new("ghc") .arg("--print-libdir") .output() .map_err(|error| format!("failed to run `ghc --print-libdir`: {error}"))?; if !output.status.success() { return Err("`ghc --print-libdir` did not exit successfully".to_string()); } PathBuf::from(String::from_utf8_lossy(&output.stdout).trim()) }; let mut matches = Vec::new(); collect_matching_rts_libraries(&libdir, &mut matches)?; matches.sort_by_key(|path| rts_priority(path)); matches .into_iter() .next() .ok_or_else(|| format!("could not find a GHC RTS library under {}", libdir.display())) } fn collect_matching_libraries(root: &Path, matches: &mut Vec) -> Result<(), String> { if !root.exists() { return Ok(()); } let entries = fs::read_dir(root) .map_err(|error| format!("failed to read {}: {error}", root.display()))?; for entry in entries { let entry = entry.map_err(|error| format!("failed to inspect {}: {error}", root.display()))?; let path = entry.path(); if path.is_dir() { collect_matching_libraries(&path, matches)?; continue; } if is_haskell_foreign_library(&path) { matches.push(path); } } Ok(()) } fn collect_matching_rts_libraries(root: &Path, matches: &mut Vec) -> Result<(), String> { if !root.exists() { return Ok(()); } let entries = fs::read_dir(root) .map_err(|error| format!("failed to read {}: {error}", root.display()))?; for entry in entries { let entry = entry.map_err(|error| format!("failed to inspect {}: {error}", root.display()))?; let path = entry.path(); if path.is_dir() { collect_matching_rts_libraries(&path, matches)?; continue; } if is_rts_library(&path) { matches.push(path); } } Ok(()) } fn is_haskell_foreign_library(path: &Path) -> bool { let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { return false; }; let is_library = file_name.ends_with(".so") || file_name.ends_with(".dylib") || file_name.ends_with(".dll"); is_library && file_name.contains("interop_hs") } fn is_rts_library(path: &Path) -> bool { let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { return false; }; path.extension().and_then(|ext| ext.to_str()) == Some("so") && file_name.starts_with("libHSrts-") } fn rts_priority(path: &Path) -> (u8, String) { let file_name = path .file_name() .and_then(|value| value.to_str()) .unwrap_or_default() .to_string(); let rank = if file_name.contains("_debug") { 3 } else if file_name.contains("_thr") { 1 } else { 0 }; (rank, file_name) }