368 lines
12 KiB
Rust
368 lines
12 KiB
Rust
use crate::interop::{SharedI32Buffer, SharedStats, SharedU8Buffer};
|
|
use libloading::Library;
|
|
use std::env;
|
|
use std::ffi::{CStr, CString};
|
|
use std::fs;
|
|
use std::os::raw::{c_char, c_int, c_uint};
|
|
use std::path::{Path, PathBuf};
|
|
|
|
type HsComputeStats = unsafe extern "C" fn(c_int, c_int, *mut SharedStats) -> c_int;
|
|
type HsSumSlice = unsafe extern "C" fn(*const c_int, usize) -> c_int;
|
|
type HsChecksumBytes = unsafe extern "C" fn(*const u8, usize) -> c_uint;
|
|
type HsMakeMessage = unsafe extern "C" fn(*const c_char, c_int, c_int) -> *mut c_char;
|
|
type HsMakeSequence = unsafe extern "C" fn(c_int, usize, *mut SharedI32Buffer) -> c_int;
|
|
type HsMakeBytePattern = unsafe extern "C" fn(u8, usize, *mut SharedU8Buffer) -> c_int;
|
|
type HsFreeString = unsafe extern "C" fn(*mut c_char);
|
|
type HsFreeI32Buffer = unsafe extern "C" fn(*mut c_int, usize, usize);
|
|
type HsFreeU8Buffer = unsafe extern "C" fn(*mut u8, usize, usize);
|
|
type HsInit = unsafe extern "C" fn(*mut c_int, *mut *mut *mut c_char);
|
|
type HsExit = unsafe extern "C" fn();
|
|
|
|
/// Inputs for one Rust -> Haskell FFI demo run.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DemoArgs {
    /// Name forwarded to `hs_make_message`; interior NUL bytes are replaced
    /// with `?` before the call.
    pub name: String,
    /// First integer operand; also seeds the sequence and byte-pattern calls.
    pub left: i32,
    /// Second integer operand, passed to `hs_compute_stats` and `hs_make_message`.
    pub right: i32,
    /// Explicit path to the Haskell foreign library. When `None`, the path is
    /// taken from `HASKELL_FOREIGN_LIB` or discovered under `haskell/dist-newstyle`.
    pub library_path: Option<String>,
}
|
|
|
|
pub fn run_haskell_demo(args: &DemoArgs) -> Result<String, String> {
|
|
let library_path = resolve_library_path(args.library_path.as_deref())?;
|
|
let library = unsafe { Library::new(&library_path) }
|
|
.map_err(|error| format!("failed to load {}: {error}", library_path.display()))?;
|
|
let runtime = HaskellRuntime::start()?;
|
|
|
|
let output = load_and_run(&library_path, &library, args);
|
|
|
|
drop(runtime);
|
|
output
|
|
}
|
|
|
|
struct HaskellRuntime {
|
|
_rts_library: Library,
|
|
hs_exit: HsExit,
|
|
}
|
|
|
|
impl HaskellRuntime {
|
|
fn start() -> Result<Self, String> {
|
|
let rts_library_path = resolve_rts_library_path()?;
|
|
let rts_library = unsafe { Library::new(&rts_library_path) }.map_err(|error| {
|
|
format!(
|
|
"failed to load GHC RTS library {}: {error}",
|
|
rts_library_path.display()
|
|
)
|
|
})?;
|
|
let hs_init: HsInit = unsafe {
|
|
*rts_library
|
|
.get(b"hs_init\0")
|
|
.map_err(|error| format!("failed to load hs_init: {error}"))?
|
|
};
|
|
let hs_exit: HsExit = unsafe {
|
|
*rts_library
|
|
.get(b"hs_exit\0")
|
|
.map_err(|error| format!("failed to load hs_exit: {error}"))?
|
|
};
|
|
let mut argc: c_int = 1;
|
|
let program_name = CString::new("integrations-hs-runtime")
|
|
.map_err(|_| "failed to create runtime program name".to_string())?;
|
|
let mut argv = vec![program_name.as_ptr() as *mut c_char, std::ptr::null_mut()];
|
|
let mut argv_ptr = argv.as_mut_ptr();
|
|
|
|
unsafe {
|
|
hs_init(&mut argc, &mut argv_ptr);
|
|
}
|
|
|
|
Ok(Self {
|
|
_rts_library: rts_library,
|
|
hs_exit,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Drop for HaskellRuntime {
|
|
fn drop(&mut self) {
|
|
unsafe {
|
|
(self.hs_exit)();
|
|
}
|
|
}
|
|
}
|
|
|
|
fn load_and_run(library_path: &Path, library: &Library, args: &DemoArgs) -> Result<String, String> {
|
|
let compute_stats: HsComputeStats = unsafe {
|
|
*library
|
|
.get(b"hs_compute_stats\0")
|
|
.map_err(|error| format!("failed to load hs_compute_stats: {error}"))?
|
|
};
|
|
let sum_slice: HsSumSlice = unsafe {
|
|
*library
|
|
.get(b"hs_sum_slice\0")
|
|
.map_err(|error| format!("failed to load hs_sum_slice: {error}"))?
|
|
};
|
|
let checksum_bytes: HsChecksumBytes = unsafe {
|
|
*library
|
|
.get(b"hs_checksum_bytes\0")
|
|
.map_err(|error| format!("failed to load hs_checksum_bytes: {error}"))?
|
|
};
|
|
let make_message: HsMakeMessage = unsafe {
|
|
*library
|
|
.get(b"hs_make_message\0")
|
|
.map_err(|error| format!("failed to load hs_make_message: {error}"))?
|
|
};
|
|
let make_sequence: HsMakeSequence = unsafe {
|
|
*library
|
|
.get(b"hs_make_sequence\0")
|
|
.map_err(|error| format!("failed to load hs_make_sequence: {error}"))?
|
|
};
|
|
let make_byte_pattern: HsMakeBytePattern = unsafe {
|
|
*library
|
|
.get(b"hs_make_byte_pattern\0")
|
|
.map_err(|error| format!("failed to load hs_make_byte_pattern: {error}"))?
|
|
};
|
|
let free_string: HsFreeString = unsafe {
|
|
*library
|
|
.get(b"hs_free_string\0")
|
|
.map_err(|error| format!("failed to load hs_free_string: {error}"))?
|
|
};
|
|
let free_i32_buffer: HsFreeI32Buffer = unsafe {
|
|
*library
|
|
.get(b"hs_free_i32_buffer\0")
|
|
.map_err(|error| format!("failed to load hs_free_i32_buffer: {error}"))?
|
|
};
|
|
let free_u8_buffer: HsFreeU8Buffer = unsafe {
|
|
*library
|
|
.get(b"hs_free_u8_buffer\0")
|
|
.map_err(|error| format!("failed to load hs_free_u8_buffer: {error}"))?
|
|
};
|
|
|
|
let mut stats = SharedStats::default();
|
|
let status = unsafe { compute_stats(args.left, args.right, &mut stats) };
|
|
if status != 0 {
|
|
return Err(format!("hs_compute_stats returned status {status}"));
|
|
}
|
|
|
|
let name = CString::new(args.name.replace('\0', "?"))
|
|
.map_err(|_| "failed to prepare demo name".to_string())?;
|
|
let message_ptr = unsafe { make_message(name.as_ptr(), args.left, args.right) };
|
|
if message_ptr.is_null() {
|
|
return Err("hs_make_message returned a null pointer".to_string());
|
|
}
|
|
|
|
let message = unsafe { CStr::from_ptr(message_ptr) }
|
|
.to_string_lossy()
|
|
.into_owned();
|
|
unsafe {
|
|
free_string(message_ptr);
|
|
}
|
|
|
|
let sample_values: [c_int; 4] = [2, 4, 6, 8];
|
|
let slice_sum = unsafe { sum_slice(sample_values.as_ptr(), sample_values.len()) };
|
|
let sample_bytes: [u8; 4] = [72, 0, 105, 255];
|
|
let byte_checksum = unsafe { checksum_bytes(sample_bytes.as_ptr(), sample_bytes.len()) };
|
|
|
|
let mut sequence_buffer = SharedI32Buffer::default();
|
|
let sequence_status = unsafe { make_sequence(args.left, 5, &mut sequence_buffer) };
|
|
if sequence_status != 0 {
|
|
return Err(format!("hs_make_sequence returned status {sequence_status}"));
|
|
}
|
|
|
|
let sequence_values = read_i32_buffer(&sequence_buffer)?;
|
|
unsafe {
|
|
free_i32_buffer(
|
|
sequence_buffer.ptr,
|
|
sequence_buffer.len,
|
|
sequence_buffer.cap,
|
|
);
|
|
}
|
|
|
|
let byte_seed = normalize_byte_seed(args.left);
|
|
let mut byte_buffer = SharedU8Buffer::default();
|
|
let byte_status = unsafe { make_byte_pattern(byte_seed, 6, &mut byte_buffer) };
|
|
if byte_status != 0 {
|
|
return Err(format!("hs_make_byte_pattern returned status {byte_status}"));
|
|
}
|
|
|
|
let returned_bytes = read_u8_buffer(&byte_buffer)?;
|
|
unsafe {
|
|
free_u8_buffer(byte_buffer.ptr, byte_buffer.len, byte_buffer.cap);
|
|
}
|
|
|
|
Ok(format!(
|
|
"Rust -> Haskell demo\nLibrary: {}\nInputs: name={}, left={}, right={}\nStats from Haskell: total={}, product={}, gap={}\nMessage from Haskell: {}\nSlice sent to Haskell: {:?}\nHaskell summed slice to: {}\nVector returned from Haskell: {:?}\nByte slice sent to Haskell: {:?}\nHaskell checksummed bytes to: {}\nByte buffer returned from Haskell: {:?}",
|
|
library_path.display(),
|
|
args.name,
|
|
args.left,
|
|
args.right,
|
|
stats.total,
|
|
stats.product,
|
|
stats.gap,
|
|
message,
|
|
sample_values,
|
|
slice_sum,
|
|
sequence_values,
|
|
sample_bytes,
|
|
byte_checksum,
|
|
returned_bytes,
|
|
))
|
|
}
|
|
|
|
fn read_i32_buffer(buffer: &SharedI32Buffer) -> Result<Vec<c_int>, String> {
|
|
if buffer.ptr.is_null() {
|
|
if buffer.len == 0 {
|
|
return Ok(Vec::new());
|
|
}
|
|
return Err("received a null buffer pointer with a non-zero length".to_string());
|
|
}
|
|
|
|
let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) };
|
|
Ok(values.to_vec())
|
|
}
|
|
|
|
fn read_u8_buffer(buffer: &SharedU8Buffer) -> Result<Vec<u8>, String> {
|
|
if buffer.ptr.is_null() {
|
|
if buffer.len == 0 {
|
|
return Ok(Vec::new());
|
|
}
|
|
return Err("received a null byte buffer pointer with a non-zero length".to_string());
|
|
}
|
|
|
|
let values = unsafe { std::slice::from_raw_parts(buffer.ptr, buffer.len) };
|
|
Ok(values.to_vec())
|
|
}
|
|
|
|
/// Maps an `i32` onto `0..=255` by taking it modulo 256 (Euclidean, so
/// negative inputs wrap to the top of the range).
///
/// For two's-complement integers this equals the lowest byte of the value.
fn normalize_byte_seed(value: i32) -> u8 {
    let [low_byte, ..] = value.to_le_bytes();
    low_byte
}
|
|
|
|
fn resolve_library_path(explicit_path: Option<&str>) -> Result<PathBuf, String> {
|
|
if let Some(path) = explicit_path {
|
|
return Ok(PathBuf::from(path));
|
|
}
|
|
|
|
if let Ok(path) = env::var("HASKELL_FOREIGN_LIB") {
|
|
return Ok(PathBuf::from(path));
|
|
}
|
|
|
|
let dist_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
|
.join("haskell")
|
|
.join("dist-newstyle");
|
|
let mut matches = Vec::new();
|
|
collect_matching_libraries(&dist_dir, &mut matches)?;
|
|
matches.sort();
|
|
|
|
matches.into_iter().next().ok_or_else(|| {
|
|
"could not find the Haskell foreign library under haskell/dist-newstyle; run `make haskell-build` first or pass an explicit path".to_string()
|
|
})
|
|
}
|
|
|
|
fn resolve_rts_library_path() -> Result<PathBuf, String> {
|
|
if let Ok(path) = env::var("GHC_RTS_LIB") {
|
|
return Ok(PathBuf::from(path));
|
|
}
|
|
|
|
let libdir = if let Ok(path) = env::var("GHC_LIBDIR") {
|
|
PathBuf::from(path)
|
|
} else {
|
|
let output = std::process::Command::new("ghc")
|
|
.arg("--print-libdir")
|
|
.output()
|
|
.map_err(|error| format!("failed to run `ghc --print-libdir`: {error}"))?;
|
|
if !output.status.success() {
|
|
return Err("`ghc --print-libdir` did not exit successfully".to_string());
|
|
}
|
|
PathBuf::from(String::from_utf8_lossy(&output.stdout).trim())
|
|
};
|
|
|
|
let mut matches = Vec::new();
|
|
collect_matching_rts_libraries(&libdir, &mut matches)?;
|
|
matches.sort_by_key(|path| rts_priority(path));
|
|
matches
|
|
.into_iter()
|
|
.next()
|
|
.ok_or_else(|| format!("could not find a GHC RTS library under {}", libdir.display()))
|
|
}
|
|
|
|
fn collect_matching_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result<(), String> {
|
|
if !root.exists() {
|
|
return Ok(());
|
|
}
|
|
|
|
let entries = fs::read_dir(root)
|
|
.map_err(|error| format!("failed to read {}: {error}", root.display()))?;
|
|
for entry in entries {
|
|
let entry = entry.map_err(|error| format!("failed to inspect {}: {error}", root.display()))?;
|
|
let path = entry.path();
|
|
if path.is_dir() {
|
|
collect_matching_libraries(&path, matches)?;
|
|
continue;
|
|
}
|
|
|
|
if is_haskell_foreign_library(&path) {
|
|
matches.push(path);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn collect_matching_rts_libraries(root: &Path, matches: &mut Vec<PathBuf>) -> Result<(), String> {
|
|
if !root.exists() {
|
|
return Ok(());
|
|
}
|
|
|
|
let entries = fs::read_dir(root)
|
|
.map_err(|error| format!("failed to read {}: {error}", root.display()))?;
|
|
for entry in entries {
|
|
let entry = entry.map_err(|error| format!("failed to inspect {}: {error}", root.display()))?;
|
|
let path = entry.path();
|
|
if path.is_dir() {
|
|
collect_matching_rts_libraries(&path, matches)?;
|
|
continue;
|
|
}
|
|
|
|
if is_rts_library(&path) {
|
|
matches.push(path);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Returns `true` when `path`'s file name contains `interop_hs` and ends in a
/// shared-library suffix (`.so`, `.dylib` or `.dll`).
///
/// Paths without a UTF-8 file name never match.
fn is_haskell_foreign_library(path: &Path) -> bool {
    const SHARED_LIBRARY_SUFFIXES: [&str; 3] = [".so", ".dylib", ".dll"];

    path.file_name()
        .and_then(|value| value.to_str())
        .is_some_and(|file_name| {
            SHARED_LIBRARY_SUFFIXES
                .iter()
                .any(|suffix| file_name.ends_with(suffix))
                && file_name.contains("interop_hs")
        })
}
|
|
|
|
/// Returns `true` when `path` names a shared GHC runtime library: a file whose
/// name starts with `libHSrts-` and whose extension is `so` (Linux/ELF) or
/// `dylib` (macOS).
///
/// Previously only `.so` was accepted, so RTS discovery always failed on macOS,
/// even though the foreign-library predicate already accepts `.dylib`.
/// Paths without a UTF-8 file name never match.
fn is_rts_library(path: &Path) -> bool {
    let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else {
        return false;
    };

    let is_shared_object = matches!(
        path.extension().and_then(|ext| ext.to_str()),
        Some("so" | "dylib")
    );

    is_shared_object && file_name.starts_with("libHSrts-")
}
|
|
|
|
/// Sort key for RTS candidates: lower is preferred.
///
/// The rank is 3 for debug builds (any name containing `_debug`, threaded or
/// not), 1 for threaded builds (`_thr`), and 0 otherwise. The file name is the
/// tie-breaker; a missing or non-UTF-8 file name becomes an empty string.
fn rts_priority(path: &Path) -> (u8, String) {
    // Checked in order: `_debug` must win over `_thr` for `_thr_debug` builds.
    const VARIANT_RANKS: [(&str, u8); 2] = [("_debug", 3), ("_thr", 1)];

    let file_name: String = path
        .file_name()
        .and_then(|value| value.to_str())
        .map(str::to_owned)
        .unwrap_or_default();

    let rank = VARIANT_RANKS
        .iter()
        .find(|(marker, _)| file_name.contains(*marker))
        .map_or(0, |&(_, rank)| rank);

    (rank, file_name)
}
|