integrations/rust/interop.rs
2026-03-27 09:20:31 +01:00

266 lines
6.0 KiB
Rust

use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int, c_uint};
/// Pair statistics produced by [`compute_stats`] and written to C through
/// `rust_compute_stats`'s out-pointer.
///
/// `#[repr(C)]` makes the layout match a C struct with the same three `int`
/// fields in the same order; the field order is part of the ABI, so do not
/// reorder or insert fields without updating the C side.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct SharedStats {
/// Saturating sum `left + right`.
pub total: c_int,
/// Saturating product `left * right`.
pub product: c_int,
/// Distance between `left` and `right`, as computed by `compute_stats`.
pub gap: c_int,
}
/// Rust-allocated buffer of `c_int` values handed to C by `rust_make_sequence`.
///
/// `ptr`, `len` and `cap` are the raw parts of a Rust `Vec<c_int>`. The
/// receiver must pass all three back unchanged to `rust_free_i32_buffer`,
/// which rebuilds the `Vec` with `Vec::from_raw_parts` and drops it.
/// `#[repr(C)]` fixes the field order as part of the C ABI; do not reorder.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct SharedI32Buffer {
/// Pointer to the first element of the allocation.
pub ptr: *mut c_int,
/// Number of initialized elements starting at `ptr`.
pub len: usize,
/// Allocated capacity in elements; required to free the allocation correctly.
pub cap: usize,
}
/// Rust-allocated byte buffer handed to C by `rust_make_byte_pattern`.
///
/// `ptr`, `len` and `cap` are the raw parts of a Rust `Vec<u8>`. The receiver
/// must pass all three back unchanged to `rust_free_u8_buffer`, which rebuilds
/// the `Vec` with `Vec::from_raw_parts` and drops it.
/// `#[repr(C)]` fixes the field order as part of the C ABI; do not reorder.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct SharedU8Buffer {
/// Pointer to the first byte of the allocation.
pub ptr: *mut u8,
/// Number of initialized bytes starting at `ptr`.
pub len: usize,
/// Allocated capacity in bytes; required to free the allocation correctly.
pub cap: usize,
}
/// Computes summary statistics for a pair of integers.
///
/// Every field saturates at the `c_int` bounds instead of overflowing:
/// * `total` is `left + right`,
/// * `product` is `left * right`,
/// * `gap` is the absolute difference `|left - right|`.
///
/// This function never panics. That matters because it is called from the
/// `extern "C"` entry points in this module, where a panic cannot unwind
/// back into the C caller.
pub fn compute_stats(left: c_int, right: c_int) -> SharedStats {
    SharedStats {
        total: left.saturating_add(right),
        product: left.saturating_mul(right),
        // `saturating_sub` may produce `c_int::MIN`, whose plain `abs()`
        // overflows (panic in debug builds, a negative "gap" in release).
        // `saturating_abs` clamps that case to `c_int::MAX` instead.
        gap: left.saturating_sub(right).saturating_abs(),
    }
}
/// Adds up `values`, clamping the running total at the `c_int` bounds.
///
/// The clamp is applied after each element, left to right. Once the total
/// saturates, later elements can move it back away from the bound. An empty
/// slice sums to `0`.
pub fn sum_slice(values: &[c_int]) -> c_int {
    let mut running: c_int = 0;
    for &value in values {
        running = running.saturating_add(value);
    }
    running
}
/// Sums the byte values of `values` as unsigned 32-bit integers.
///
/// Zero bytes are counted like any other byte (they contribute nothing but
/// do not terminate the input). The sum saturates at `u32::MAX` rather than
/// wrapping.
pub fn checksum_bytes(values: &[u8]) -> c_uint {
    values
        .iter()
        .map(|&byte| u32::from(byte))
        .fold(0, u32::saturating_add)
}
pub fn make_rust_message(name: &str, left: c_int, right: c_int) -> String {
let stats = compute_stats(left, right);
format!(
"Rust handled {name}: total={}, product={}, gap={}",
stats.total, stats.product, stats.gap
)
}
/// Returns `count` consecutive integers beginning at `start`.
///
/// Each value is `start + index`, saturating at `c_int::MAX` instead of
/// wrapping. An index too large to fit in `c_int` is treated as an offset of
/// `c_int::MAX`.
pub fn make_sequence_values(start: c_int, count: usize) -> Vec<c_int> {
    let mut sequence = Vec::with_capacity(count);
    for index in 0..count {
        let offset = c_int::try_from(index).unwrap_or(c_int::MAX);
        sequence.push(start.saturating_add(offset));
    }
    sequence
}
/// Returns `count` bytes counting upward from `seed`, wrapping modulo 256.
///
/// Byte `i` equals `seed` plus `i mod 256` (with byte wrap-around). For
/// example, `make_byte_pattern(254, 4)` is `[254, 255, 0, 1]`. For `count`
/// larger than 256 the pattern keeps cycling through all byte values instead
/// of sticking on a single value.
pub fn make_byte_pattern(seed: u8, count: usize) -> Vec<u8> {
    // Cycling over every `u8` offset makes the modulo-256 reduction of the
    // index explicit and avoids a truncating `as` cast. The previous
    // `u8::try_from(index)` fallback clamped every index >= 256 to offset 255.
    (0..=u8::MAX)
        .cycle()
        .take(count)
        .map(|offset| seed.wrapping_add(offset))
        .collect()
}
/// C entry point for [`compute_stats`].
///
/// Writes the statistics for `left` and `right` into `*out_stats` and
/// returns `0`. Returns `1` without writing anything when `out_stats` is null.
///
/// # Safety
///
/// `out_stats` must be null or a valid, properly aligned pointer to writable
/// memory for one `SharedStats`. The previous contents are overwritten
/// without being read, so they may be uninitialized.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_compute_stats(
    left: c_int,
    right: c_int,
    out_stats: *mut SharedStats,
) -> c_int {
    if !out_stats.is_null() {
        let stats = compute_stats(left, right);
        // SAFETY: `out_stats` is non-null (checked above), and the caller
        // guarantees it is aligned and writable for one `SharedStats`.
        unsafe { out_stats.write(stats) };
        return 0;
    }
    1
}
/// C entry point for [`sum_slice`].
///
/// Returns the saturating sum of the `len` integers starting at `ptr`, or `0`
/// when `ptr` is null.
///
/// # Safety
///
/// When non-null, `ptr` must be properly aligned and point to `len`
/// initialized `c_int` values that remain valid and unmodified for the
/// duration of the call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_sum_slice(ptr: *const c_int, len: usize) -> c_int {
    if ptr.is_null() {
        0
    } else {
        // SAFETY: `ptr` is non-null, and the caller guarantees it addresses
        // `len` initialized, aligned `c_int` values.
        sum_slice(unsafe { std::slice::from_raw_parts(ptr, len) })
    }
}
/// C entry point for [`make_rust_message`].
///
/// Returns a newly allocated, NUL-terminated string that the caller must
/// release with `rust_free_string`. A null `name` produces a fixed
/// diagnostic message instead. Invalid UTF-8 in `name` is replaced with
/// U+FFFD. Callers should still handle a null return, which
/// `string_into_raw` uses to signal a failed conversion.
///
/// # Safety
///
/// When non-null, `name` must point to a NUL-terminated C string that stays
/// valid for the duration of the call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_make_message(
    name: *const c_char,
    left: c_int,
    right: c_int,
) -> *mut c_char {
    let message = if name.is_null() {
        "Rust received a null name pointer".to_string()
    } else {
        // SAFETY: `name` is non-null, and the caller guarantees it is a valid
        // NUL-terminated string for the duration of this call.
        let decoded = unsafe { CStr::from_ptr(name) }.to_string_lossy();
        make_rust_message(&decoded, left, right)
    };
    string_into_raw(message)
}
/// C entry point for [`checksum_bytes`].
///
/// Returns the saturating byte sum of the `len` bytes starting at `ptr`, or
/// `0` when `ptr` is null. Zero bytes inside the range are summed, not
/// treated as terminators.
///
/// # Safety
///
/// When non-null, `ptr` must point to `len` initialized bytes that remain
/// valid and unmodified for the duration of the call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_checksum_bytes(ptr: *const u8, len: usize) -> c_uint {
    if ptr.is_null() {
        0
    } else {
        // SAFETY: `ptr` is non-null, and the caller guarantees it addresses
        // `len` initialized bytes.
        let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
        checksum_bytes(bytes)
    }
}
/// C entry point for [`make_sequence_values`].
///
/// On success, writes a Rust-owned buffer describing the sequence into
/// `*out_buffer` and returns `0`. Returns `1` without allocating when
/// `out_buffer` is null. The caller then owns the buffer and must release it
/// exactly once by passing its `ptr`, `len` and `cap` unchanged to
/// `rust_free_i32_buffer`.
///
/// # Safety
///
/// `out_buffer` must be null or a valid, properly aligned pointer to writable
/// memory for one `SharedI32Buffer`. The previous contents are overwritten
/// without being read.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_make_sequence(
    start: c_int,
    count: usize,
    out_buffer: *mut SharedI32Buffer,
) -> c_int {
    if out_buffer.is_null() {
        return 1;
    }
    // Disarm the destructor *before* taking the raw parts. The
    // `std::mem::forget` docs recommend `ManuallyDrop` over moving the Vec into
    // `forget` after its buffer has been handed out: once ownership has been
    // transferred, the Vec itself is no longer a valid value to move around.
    let mut values = std::mem::ManuallyDrop::new(make_sequence_values(start, count));
    let buffer = SharedI32Buffer {
        ptr: values.as_mut_ptr(),
        len: values.len(),
        cap: values.capacity(),
    };
    // SAFETY: `out_buffer` is non-null (checked above), and the caller
    // guarantees it is aligned and writable. `write` neither reads nor drops
    // the previous contents.
    unsafe { out_buffer.write(buffer) };
    0
}
/// C entry point for [`make_byte_pattern`].
///
/// On success, writes a Rust-owned byte buffer into `*out_buffer` and returns
/// `0`. Returns `1` without allocating when `out_buffer` is null. The caller
/// then owns the buffer and must release it exactly once by passing its
/// `ptr`, `len` and `cap` unchanged to `rust_free_u8_buffer`.
///
/// # Safety
///
/// `out_buffer` must be null or a valid, properly aligned pointer to writable
/// memory for one `SharedU8Buffer`. The previous contents are overwritten
/// without being read.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_make_byte_pattern(
    seed: u8,
    count: usize,
    out_buffer: *mut SharedU8Buffer,
) -> c_int {
    if out_buffer.is_null() {
        return 1;
    }
    // Disarm the destructor *before* taking the raw parts. The
    // `std::mem::forget` docs recommend `ManuallyDrop` over moving the Vec into
    // `forget` after its buffer has been handed out: once ownership has been
    // transferred, the Vec itself is no longer a valid value to move around.
    let mut values = std::mem::ManuallyDrop::new(make_byte_pattern(seed, count));
    let buffer = SharedU8Buffer {
        ptr: values.as_mut_ptr(),
        len: values.len(),
        cap: values.capacity(),
    };
    // SAFETY: `out_buffer` is non-null (checked above), and the caller
    // guarantees it is aligned and writable. `write` neither reads nor drops
    // the previous contents.
    unsafe { out_buffer.write(buffer) };
    0
}
/// Releases a string previously returned by `rust_make_message`.
///
/// Passing null is a no-op.
///
/// # Safety
///
/// A non-null `ptr` must have come from this library's `CString::into_raw`
/// (i.e. been returned by `rust_make_message`). It must not have been freed
/// already and must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_free_string(ptr: *mut c_char) {
    if !ptr.is_null() {
        // SAFETY: non-null, and the caller guarantees it was produced by
        // `CString::into_raw` and has not been freed yet.
        let reclaimed = unsafe { CString::from_raw(ptr) };
        drop(reclaimed);
    }
}
/// Releases a `c_int` buffer previously produced by `rust_make_sequence`.
///
/// Passing a null `ptr` is a no-op.
///
/// # Safety
///
/// A non-null `ptr`, together with `len` and `cap`, must be exactly the values
/// `rust_make_sequence` wrote into its `SharedI32Buffer`. The buffer must not
/// have been freed already and must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_free_i32_buffer(ptr: *mut c_int, len: usize, cap: usize) {
    if !ptr.is_null() {
        // SAFETY: the caller guarantees these are the unmodified raw parts of
        // a live `Vec<c_int>` allocated by this library.
        let reclaimed: Vec<c_int> = unsafe { Vec::from_raw_parts(ptr, len, cap) };
        drop(reclaimed);
    }
}
/// Releases a byte buffer previously produced by `rust_make_byte_pattern`.
///
/// Passing a null `ptr` is a no-op.
///
/// # Safety
///
/// A non-null `ptr`, together with `len` and `cap`, must be exactly the values
/// `rust_make_byte_pattern` wrote into its `SharedU8Buffer`. The buffer must
/// not have been freed already and must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_free_u8_buffer(ptr: *mut u8, len: usize, cap: usize) {
    if !ptr.is_null() {
        // SAFETY: the caller guarantees these are the unmodified raw parts of
        // a live `Vec<u8>` allocated by this library.
        let reclaimed: Vec<u8> = unsafe { Vec::from_raw_parts(ptr, len, cap) };
        drop(reclaimed);
    }
}
/// Converts `message` into a heap-allocated C string whose ownership passes
/// to the caller (release it with `rust_free_string`).
///
/// Interior NUL bytes cannot be represented in a C string, so each one is
/// replaced with `'?'` first. Returns null only if `CString::new` still
/// rejects the text.
fn string_into_raw(message: String) -> *mut c_char {
    let without_nuls = message.replace('\0', "?");
    CString::new(without_nuls).map_or(std::ptr::null_mut(), CString::into_raw)
}
/// Unit tests for the safe helpers behind the C entry points.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_stats_matches_expected_values() {
        let expected = SharedStats {
            total: 13,
            product: 36,
            gap: 5,
        };
        assert_eq!(compute_stats(9, 4), expected);
    }

    #[test]
    fn message_contains_name_and_values() {
        let message = make_rust_message("Ada", 7, 5);
        for fragment in ["Ada", "total=12", "product=35", "gap=2"] {
            assert!(
                message.contains(fragment),
                "{message:?} does not contain {fragment:?}"
            );
        }
    }

    #[test]
    fn sum_slice_handles_multiple_values() {
        let values = [2, 4, 6, 8];
        assert_eq!(sum_slice(&values), 20);
    }

    #[test]
    fn make_sequence_values_builds_contiguous_output() {
        let expected: Vec<c_int> = (7..12).collect();
        assert_eq!(make_sequence_values(7, 5), expected);
    }

    #[test]
    fn checksum_bytes_handles_embedded_zeroes() {
        // The zero byte must be summed, not treated as a terminator.
        let bytes = [72_u8, 0, 105, 255];
        assert_eq!(checksum_bytes(&bytes), 432);
    }

    #[test]
    fn make_byte_pattern_wraps_like_bytes() {
        let expected: Vec<u8> = vec![254, 255, 0, 1];
        assert_eq!(make_byte_pattern(254, 4), expected);
    }
}