integrations/rust/haskell.rs
2026-03-27 09:20:31 +01:00

368 lines
12 KiB
Rust

use crate::interop::{SharedI32Buffer, SharedStats, SharedU8Buffer};
use libloading::Library;
use std::env;
use std::ffi::{CStr, CString};
use std::fs;
use std::os::raw::{c_char, c_int, c_uint};
use std::path::{Path, PathBuf};
// C ABI signatures of the symbols resolved at runtime with libloading. The
// `hs_*` demo exports come from the Haskell foreign library; `HsInit` and
// `HsExit` come from the GHC RTS shared library.
/// `hs_compute_stats(left, right, out)`: writes results into `*out`; a
/// non-zero return value is treated as failure by the caller.
type HsComputeStats = unsafe extern "C" fn(c_int, c_int, *mut SharedStats) -> c_int;
/// `hs_sum_slice(ptr, len)`: called with a pointer to `len` `c_int` values;
/// the result is reported as the sum of that slice.
type HsSumSlice = unsafe extern "C" fn(*const c_int, usize) -> c_int;
/// `hs_checksum_bytes(ptr, len)`: called with a pointer to `len` bytes; the
/// result is reported as a checksum of those bytes.
type HsChecksumBytes = unsafe extern "C" fn(*const u8, usize) -> c_uint;
/// `hs_make_message(name, left, right)`: returns a NUL-terminated string
/// allocated on the Haskell side (null signals failure) that must be released
/// with `hs_free_string`.
type HsMakeMessage = unsafe extern "C" fn(*const c_char, c_int, c_int) -> *mut c_char;
/// `hs_make_sequence(seed, count, out)`: fills `*out` with a buffer that must
/// be released with `hs_free_i32_buffer`; non-zero return means failure.
/// NOTE(review): the second argument is presumably the requested length — confirm.
type HsMakeSequence = unsafe extern "C" fn(c_int, usize, *mut SharedI32Buffer) -> c_int;
/// `hs_make_byte_pattern(seed, count, out)`: fills `*out` with a buffer that
/// must be released with `hs_free_u8_buffer`; non-zero return means failure.
type HsMakeBytePattern = unsafe extern "C" fn(u8, usize, *mut SharedU8Buffer) -> c_int;
/// `hs_free_string(ptr)`: releases a string returned by `hs_make_message`.
type HsFreeString = unsafe extern "C" fn(*mut c_char);
/// `hs_free_i32_buffer(ptr, len, cap)`: releases a buffer from `hs_make_sequence`.
type HsFreeI32Buffer = unsafe extern "C" fn(*mut c_int, usize, usize);
/// `hs_free_u8_buffer(ptr, len, cap)`: releases a buffer from `hs_make_byte_pattern`.
type HsFreeU8Buffer = unsafe extern "C" fn(*mut u8, usize, usize);
/// GHC RTS `hs_init(int *argc, char **argv[])`: initialises the Haskell runtime.
type HsInit = unsafe extern "C" fn(*mut c_int, *mut *mut *mut c_char);
/// GHC RTS `hs_exit()`: shuts the Haskell runtime down.
type HsExit = unsafe extern "C" fn();
/// Inputs for a single run of the Rust -> Haskell demo.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DemoArgs {
    /// Name forwarded to `hs_make_message`; interior NUL bytes are replaced
    /// with `?` before crossing the FFI boundary.
    pub name: String,
    /// First operand; also seeds the sequence and byte-pattern exports.
    pub left: i32,
    /// Second operand.
    pub right: i32,
    /// Explicit foreign-library path; when `None` the library is located via
    /// the environment or by scanning the build directory.
    pub library_path: Option<String>,
}
/// Runs the Rust -> Haskell FFI demo and returns a human-readable report.
///
/// The foreign library is resolved and loaded, the GHC runtime is started,
/// the demo exports are exercised by `load_and_run`, and the runtime is shut
/// down again while the foreign library is still loaded (`library` is only
/// dropped when this function returns).
///
/// # Errors
/// Returns a message if the foreign library or GHC RTS cannot be found or
/// loaded, if a Haskell export reports failure, or if an earlier call in this
/// process already shut the GHC runtime down completely. GHC does not support
/// calling `hs_init` again after `hs_exit`.
pub fn run_haskell_demo(args: &DemoArgs) -> Result<String, String> {
    // (number of calls currently holding the runtime, whether the runtime has
    // been fully shut down). The counter mirrors GHC's own `hs_init`/`hs_exit`
    // nesting so concurrent callers keep working; only once the last active
    // caller has shut the RTS down are further starts refused.
    static RUNTIME_STATE: std::sync::Mutex<(usize, bool)> = std::sync::Mutex::new((0, false));

    let library_path = resolve_library_path(args.library_path.as_deref())?;
    let library = unsafe { Library::new(&library_path) }
        .map_err(|error| format!("failed to load {}: {error}", library_path.display()))?;

    let runtime = {
        let mut state = RUNTIME_STATE
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if state.1 {
            return Err(
                "the GHC runtime was already shut down in this process and cannot be restarted"
                    .to_string(),
            );
        }
        let runtime = HaskellRuntime::start()?;
        state.0 += 1;
        runtime
    };

    let output = load_and_run(&library_path, &library, args);

    {
        let mut state = RUNTIME_STATE
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Calls `hs_exit` while `library` is still loaded.
        drop(runtime);
        state.0 -= 1;
        if state.0 == 0 {
            state.1 = true;
        }
    }
    output
}
struct HaskellRuntime {
_rts_library: Library,
hs_exit: HsExit,
}
impl HaskellRuntime {
fn start() -> Result<Self, String> {
let rts_library_path = resolve_rts_library_path()?;
let rts_library = unsafe { Library::new(&rts_library_path) }.map_err(|error| {
format!(
"failed to load GHC RTS library {}: {error}",
rts_library_path.display()
)
})?;
let hs_init: HsInit = unsafe {
*rts_library
.get(b"hs_init\0")
.map_err(|error| format!("failed to load hs_init: {error}"))?
};
let hs_exit: HsExit = unsafe {
*rts_library
.get(b"hs_exit\0")
.map_err(|error| format!("failed to load hs_exit: {error}"))?
};
let mut argc: c_int = 1;
let program_name = CString::new("integrations-hs-runtime")
.map_err(|_| "failed to create runtime program name".to_string())?;
let mut argv = vec![program_name.as_ptr() as *mut c_char, std::ptr::null_mut()];
let mut argv_ptr = argv.as_mut_ptr();
unsafe {
hs_init(&mut argc, &mut argv_ptr);
}
Ok(Self {
_rts_library: rts_library,
hs_exit,
})
}
}
impl Drop for HaskellRuntime {
fn drop(&mut self) {
unsafe {
(self.hs_exit)();
}
}
}
/// Resolves every demo export from `library`, exercises each one, and formats
/// the results into a multi-line report.
///
/// Haskell-allocated results (the message string and both buffers) are copied
/// into Rust-owned values and then handed back to the matching `hs_free_*`
/// export.
///
/// # Errors
/// Returns a message if a symbol is missing, if an export returns a non-zero
/// status or a null message, or if a returned buffer is null with a non-zero
/// length.
fn load_and_run(library_path: &Path, library: &Library, args: &DemoArgs) -> Result<String, String> {
    // Resolves the NUL-terminated export `symbol` as the function-pointer type
    // `T`, naming the export (without its NUL) in the error message.
    fn export<T: Copy>(library: &Library, symbol: &[u8]) -> Result<T, String> {
        let name = String::from_utf8_lossy(symbol.strip_suffix(b"\0").unwrap_or(symbol));
        // SAFETY: every call site pairs the symbol with the `Hs*` alias for its C signature.
        let resolved = unsafe { library.get::<T>(symbol) }
            .map_err(|error| format!("failed to load {name}: {error}"))?;
        Ok(*resolved)
    }

    // Resolve everything up front so a missing export fails before any Haskell code runs.
    let compute_stats: HsComputeStats = export(library, b"hs_compute_stats\0")?;
    let sum_slice: HsSumSlice = export(library, b"hs_sum_slice\0")?;
    let checksum_bytes: HsChecksumBytes = export(library, b"hs_checksum_bytes\0")?;
    let make_message: HsMakeMessage = export(library, b"hs_make_message\0")?;
    let make_sequence: HsMakeSequence = export(library, b"hs_make_sequence\0")?;
    let make_byte_pattern: HsMakeBytePattern = export(library, b"hs_make_byte_pattern\0")?;
    let free_string: HsFreeString = export(library, b"hs_free_string\0")?;
    let free_i32_buffer: HsFreeI32Buffer = export(library, b"hs_free_i32_buffer\0")?;
    let free_u8_buffer: HsFreeU8Buffer = export(library, b"hs_free_u8_buffer\0")?;

    // Scalar round trip through an out-parameter struct.
    let mut stats = SharedStats::default();
    // SAFETY: `stats` is exclusively borrowed and live for the whole call.
    let stats_status = unsafe { compute_stats(args.left, args.right, &mut stats) };
    if stats_status != 0 {
        return Err(format!("hs_compute_stats returned status {}", stats_status));
    }

    // String round trip: copy the text out, then let Haskell release it.
    let c_name = CString::new(args.name.replace('\0', "?"))
        .map_err(|_| "failed to prepare demo name".to_string())?;
    // SAFETY: `c_name` is a valid NUL-terminated string that outlives the call.
    let raw_message = unsafe { make_message(c_name.as_ptr(), args.left, args.right) };
    if raw_message.is_null() {
        return Err("hs_make_message returned a null pointer".to_string());
    }
    // SAFETY: non-null; assumes the export returns a NUL-terminated string (TODO confirm).
    let message = unsafe { CStr::from_ptr(raw_message) }
        .to_string_lossy()
        .into_owned();
    // SAFETY: `raw_message` came from `hs_make_message` and is released exactly once.
    unsafe { free_string(raw_message) };

    // Borrowed-slice inputs: Haskell only reads these.
    let sample_values: [c_int; 4] = [2, 4, 6, 8];
    let sample_bytes: [u8; 4] = [72, 0, 105, 255];
    // SAFETY: pointer/length pairs describe live local arrays.
    let slice_sum = unsafe { sum_slice(sample_values.as_ptr(), sample_values.len()) };
    // SAFETY: as above.
    let byte_checksum = unsafe { checksum_bytes(sample_bytes.as_ptr(), sample_bytes.len()) };

    // Haskell-allocated i32 buffer: copy it, then return it for release.
    let mut sequence_buffer = SharedI32Buffer::default();
    // SAFETY: `sequence_buffer` is a live out-parameter.
    let sequence_status = unsafe { make_sequence(args.left, 5, &mut sequence_buffer) };
    if sequence_status != 0 {
        return Err(format!("hs_make_sequence returned status {}", sequence_status));
    }
    let sequence_values = read_i32_buffer(&sequence_buffer)?;
    // SAFETY: the buffer fields are passed back unchanged, exactly once.
    unsafe { free_i32_buffer(sequence_buffer.ptr, sequence_buffer.len, sequence_buffer.cap) };

    // Haskell-allocated byte buffer: same ownership dance as above.
    let byte_seed = normalize_byte_seed(args.left);
    let mut byte_buffer = SharedU8Buffer::default();
    // SAFETY: `byte_buffer` is a live out-parameter.
    let byte_status = unsafe { make_byte_pattern(byte_seed, 6, &mut byte_buffer) };
    if byte_status != 0 {
        return Err(format!("hs_make_byte_pattern returned status {}", byte_status));
    }
    let returned_bytes = read_u8_buffer(&byte_buffer)?;
    // SAFETY: the buffer fields are passed back unchanged, exactly once.
    unsafe { free_u8_buffer(byte_buffer.ptr, byte_buffer.len, byte_buffer.cap) };

    Ok(format!(
        "Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}\nSlice sent to Haskell: {:?}\nHaskell summed slice to: {}\nVector returned from Haskell: {:?}\nByte slice sent to Haskell: {:?}\nHaskell checksummed bytes to: {}\nByte buffer returned from Haskell: {:?}",
        library_path.display(),
        args.name,
        args.left,
        args.right,
        stats.total,
        stats.product,
        stats.gap,
        message,
        sample_values,
        slice_sum,
        sequence_values,
        sample_bytes,
        byte_checksum,
        returned_bytes,
    ))
}
/// Copies a Haskell-provided `i32` buffer into a Rust-owned `Vec`.
///
/// A null pointer is accepted only together with a zero length, which yields
/// an empty vector; a null pointer with a non-zero length is an error.
fn read_i32_buffer(buffer: &SharedI32Buffer) -> Result<Vec<c_int>, String> {
    match (buffer.ptr.is_null(), buffer.len) {
        (true, 0) => Ok(Vec::new()),
        (true, _) => Err("received a null buffer pointer with a non-zero length".to_string()),
        (false, count) => {
            // SAFETY: non-null; assumes the producer guarantees `count`
            // initialised elements at `ptr` (TODO confirm against the export).
            let items = unsafe { std::slice::from_raw_parts(buffer.ptr, count) };
            Ok(items.to_vec())
        }
    }
}
/// Copies a Haskell-provided byte buffer into a Rust-owned `Vec<u8>`.
///
/// A null pointer is accepted only together with a zero length, which yields
/// an empty vector; a null pointer with a non-zero length is an error.
fn read_u8_buffer(buffer: &SharedU8Buffer) -> Result<Vec<u8>, String> {
    if !buffer.ptr.is_null() {
        // SAFETY: non-null; assumes the producer guarantees `len` initialised
        // bytes at `ptr` (TODO confirm against the export).
        return Ok(unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) }.to_vec());
    }
    if buffer.len != 0 {
        return Err("received a null byte buffer pointer with a non-zero length".to_string());
    }
    Ok(Vec::new())
}
/// Maps any `i32` onto `0..=255` by wrapping modulo 256. Negative values wrap
/// upward, e.g. `-1 -> 255`.
fn normalize_byte_seed(value: i32) -> u8 {
    // In two's complement the least-significant byte is exactly the value's
    // Euclidean remainder modulo 256.
    value.to_le_bytes()[0]
}
/// Determines which Haskell foreign library to load.
///
/// Resolution order:
/// 1. `explicit_path`, when provided;
/// 2. the `HASKELL_FOREIGN_LIB` environment variable, when set and non-empty;
/// 3. the lexicographically smallest matching library found by recursively
///    scanning `haskell/dist-newstyle` under this crate's manifest directory.
///
/// # Errors
/// Returns a message if the build directory cannot be read or contains no
/// matching library.
fn resolve_library_path(explicit_path: Option<&str>) -> Result<PathBuf, String> {
    if let Some(path) = explicit_path {
        return Ok(PathBuf::from(path));
    }
    // `var_os` accepts non-UTF-8 paths. `env::var` would reject them and the
    // code would silently fall back to the directory search, possibly loading
    // a different library. An empty value is treated as unset rather than as
    // the (meaningless) empty path.
    if let Some(path) = env::var_os("HASKELL_FOREIGN_LIB").filter(|value| !value.is_empty()) {
        return Ok(PathBuf::from(path));
    }
    let dist_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("haskell")
        .join("dist-newstyle");
    let mut matches = Vec::new();
    collect_matching_libraries(&dist_dir, &mut matches)?;
    // The smallest path is chosen so the result does not depend on directory
    // iteration order.
    matches.into_iter().min().ok_or_else(|| {
        "could not find the Haskell foreign library under haskell/dist-newstyle; run `make haskell-build` first or pass an explicit path".to_string()
    })
}
/// Determines which GHC RTS shared library to load.
///
/// Uses `GHC_RTS_LIB` if it is set and non-empty. Otherwise it scans the GHC
/// library directory recursively; that directory comes from `GHC_LIBDIR` (if
/// set and non-empty) or from `ghc --print-libdir`. Among the candidates, the
/// one with the lowest `rts_priority` wins; ties keep discovery order.
///
/// # Errors
/// Returns a message if `ghc` cannot be run or fails, if the library
/// directory cannot be read, or if no RTS library is found there.
fn resolve_rts_library_path() -> Result<PathBuf, String> {
    // `var_os` keeps non-UTF-8 paths instead of silently ignoring them; empty
    // values are treated as unset.
    if let Some(path) = env::var_os("GHC_RTS_LIB").filter(|value| !value.is_empty()) {
        return Ok(PathBuf::from(path));
    }
    let libdir = match env::var_os("GHC_LIBDIR").filter(|value| !value.is_empty()) {
        Some(path) => PathBuf::from(path),
        None => {
            let output = std::process::Command::new("ghc")
                .arg("--print-libdir")
                .output()
                .map_err(|error| format!("failed to run `ghc --print-libdir`: {error}"))?;
            if !output.status.success() {
                return Err("`ghc --print-libdir` did not exit successfully".to_string());
            }
            PathBuf::from(String::from_utf8_lossy(&output.stdout).trim())
        }
    };
    let mut matches = Vec::new();
    collect_matching_rts_libraries(&libdir, &mut matches)?;
    // `rts_priority` allocates a String per call, so compute each key once.
    // Like `sort_by_key`, this sort is stable.
    matches.sort_by_cached_key(|path| rts_priority(path));
    matches
        .into_iter()
        .next()
        .ok_or_else(|| format!("could not find a GHC RTS library under {}", libdir.display()))
}
/// Recursively walks `root` and appends every file accepted by
/// `is_haskell_foreign_library` to `matches`.
///
/// A missing `root` is not an error; it contributes nothing. Directories are
/// detected with `Path::is_dir`, which follows symlinks.
fn collect_matching_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result<(), String> {
    if !root.exists() {
        return Ok(());
    }
    fs::read_dir(root)
        .map_err(|error| format!("failed to read {}: {error}", root.display()))?
        .try_for_each(|item| -> Result<(), String> {
            let candidate = item
                .map_err(|error| format!("failed to inspect {}: {error}", root.display()))?
                .path();
            if candidate.is_dir() {
                collect_matching_libraries(&candidate, matches)
            } else {
                if is_haskell_foreign_library(&candidate) {
                    matches.push(candidate);
                }
                Ok(())
            }
        })
}
/// Recursively walks `root` and appends every file accepted by
/// `is_rts_library` to `matches`.
///
/// A missing `root` is not an error; it contributes nothing. Directories are
/// detected with `Path::is_dir`, which follows symlinks.
fn collect_matching_rts_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result<(), String> {
    if !root.exists() {
        return Ok(());
    }
    let listing =
        fs::read_dir(root).map_err(|error| format!("failed to read {}: {error}", root.display()))?;
    for item in listing {
        let candidate = match item {
            Ok(entry) => entry.path(),
            Err(error) => {
                return Err(format!("failed to inspect {}: {error}", root.display()));
            }
        };
        if candidate.is_dir() {
            collect_matching_rts_libraries(&candidate, matches)?;
        } else if is_rts_library(&candidate) {
            matches.push(candidate);
        }
    }
    Ok(())
}
/// Returns `true` when `path` names a shared library (`.so`, `.dylib` or
/// `.dll`) whose file name contains `interop_hs`.
///
/// Paths without a UTF-8 file name never match.
fn is_haskell_foreign_library(path: &Path) -> bool {
    const LIBRARY_SUFFIXES: [&str; 3] = [".so", ".dylib", ".dll"];
    match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => {
            LIBRARY_SUFFIXES.iter().any(|suffix| name.ends_with(suffix))
                && name.contains("interop_hs")
        }
        None => false,
    }
}
/// Returns `true` when `path` names a GHC RTS shared library, i.e. a file
/// starting with `libHSrts-` that has a shared-library extension.
///
/// `.dylib` and `.dll` are accepted alongside `.so`, matching
/// `is_haskell_foreign_library`. Previously only `.so` matched, so the RTS
/// could never be discovered on macOS without setting `GHC_RTS_LIB`.
fn is_rts_library(path: &Path) -> bool {
    let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
        return false;
    };
    let is_shared_library = matches!(
        path.extension().and_then(|ext| ext.to_str()),
        Some("so" | "dylib" | "dll")
    );
    is_shared_library && file_name.starts_with("libHSrts-")
}
/// Sort key for RTS candidates; smaller keys are preferred.
///
/// The rank is 3 for debug builds (`_debug`, checked first), 1 for threaded
/// builds (`_thr`) and 0 otherwise. Ties are broken by file name, and a path
/// without a UTF-8 file name gets an empty name.
fn rts_priority(path: &Path) -> (u8, String) {
    const VARIANT_RANKS: [(&str, u8); 2] = [("_debug", 3), ("_thr", 1)];
    let name: String = path
        .file_name()
        .and_then(|os_name| os_name.to_str())
        .map(str::to_owned)
        .unwrap_or_default();
    let rank = VARIANT_RANKS
        .iter()
        .find(|(marker, _)| name.contains(marker))
        .map_or(0, |&(_, variant_rank)| variant_rank);
    (rank, name)
}