Compare commits

...

8 Commits

Author SHA1 Message Date
George Thomas
c1d30b0e61 refactor 2026-04-13 15:50:03 +01:00
George Thomas
ea5f816425 basic enum example 2026-04-13 15:48:41 +01:00
George Thomas
5c3aa088f7 Add maybe/optional example 2026-04-13 13:04:06 +01:00
George Thomas
6203e7570f Add vector/slice example 2026-04-13 13:04:06 +01:00
George Thomas
9e8bccdafd Pass all non-primitive types by (immutable) reference 2026-04-13 13:04:06 +01:00
George Thomas
8cb0bf9c5e Deduplicate base dependency 2026-04-13 13:04:06 +01:00
George Thomas
2ff3ad93e5 Simplify build script
Seeing as we now no longer need to modify generated header files.
2026-04-13 13:04:06 +01:00
George Thomas
8a5bf61d6b wip file dependency reload stuff 2026-04-13 13:04:06 +01:00
5 changed files with 102 additions and 26 deletions

View File

@ -1,13 +1,20 @@
module Main (main) where module Main (main) where
import Data.Vector.Storable qualified as V
import GarnetRs.Wrapped import GarnetRs.Wrapped
import System.IO
main :: IO () main :: IO ()
main = do main = do
hello "Haskell" hello "Haskell"
helloStruct T{a = True, b = 42} helloStruct T{a = True, b = 42}
helloStruct T{a = False, b = maxBound} helloStruct T{a = False, b = maxBound}
helloEnum E1
helloEnum E3
helloShape $ Circle 3.14 helloShape $ Circle 3.14
helloShape $ Rectangle 10.0 5.0 helloShape $ Rectangle 10.0 5.0
putStrLn $ "3 + 4 = " <> show (add 3 4) putStrLn $ "3 + 4 = " <> show (add 3 4)
putStrLn $ "Tree sum: " <> show (sumTree (Fork (Fork (Leaf 1) (Fork (Leaf 2) (Leaf 3))) (Leaf 4))) putStrLn $ "Tree sum: " <> show (sumTree (Fork (Fork (Leaf 1) (Fork (Leaf 2) (Leaf 3))) (Leaf 4)))
putStrLn $ "Slice sum: " <> show (sumSlice $ V.fromList [0 .. 5])
putStrLn "Nothing." >> printOptional Nothing
putStr "Something: " >> hFlush stdout >> printOptional (Just 67)

View File

@ -6,6 +6,13 @@ author: Patrick Aldis
maintainer: maintainer:
george.thomas@obsidian.systems george.thomas@obsidian.systems
patrick.aldis@obsidian.systems patrick.aldis@obsidian.systems
-- aha, nice, this does fix recompilation checking
extra-source-files:
rust/target/debug/garnet_rs.h
-- that could be problematic given file is gitignored? unfortunately this doesn't work
-- tbf, I haven't even looked up the docs, just saw the autocompletion
-- extra-tmp-files:
-- rust/target/debug/garnet_rs.h
common common common common
default-language: GHC2024 default-language: GHC2024
@ -28,11 +35,13 @@ common common
-Wall -Wall
-fdefer-type-errors -fdefer-type-errors
build-depends: build-depends:
base,
bytestring, bytestring,
extra, extra,
mtl, mtl,
process, process,
text, text,
vector,
library library
import: common import: common
@ -41,9 +50,11 @@ library
GarnetRs.Wrapped GarnetRs.Wrapped
hs-source-dirs: lib hs-source-dirs: lib
include-dirs: rust/target/debug include-dirs: rust/target/debug
-- HLS gives up entirely when the header is malformed if we do this
-- and anyway, I don't think it gives us dependency tracking like `extra-source-files` does
-- includes: garnet_rs.h
extra-bundled-libraries: Cgarnet_rs extra-bundled-libraries: Cgarnet_rs
build-depends: build-depends:
base,
hs-bindgen, hs-bindgen,
hs-bindgen-runtime, hs-bindgen-runtime,
template-haskell, template-haskell,
@ -57,5 +68,4 @@ executable garnet
-rtsopts -rtsopts
-with-rtsopts=-N -with-rtsopts=-N
build-depends: build-depends:
base,
garnet, garnet,

View File

@ -1,20 +1,32 @@
{-# LANGUAGE PatternSynonyms #-}
-- TODO automate this sort of high level wrapper boilerplate -- TODO automate this sort of high level wrapper boilerplate
-- or look at upstream plans: https://github.com/well-typed/hs-bindgen/issues?q=state%3Aopen%20label%3A%22highlevel%22 -- or look at upstream plans: https://github.com/well-typed/hs-bindgen/issues?q=state%3Aopen%20label%3A%22highlevel%22
module GarnetRs.Wrapped ( module GarnetRs.Wrapped (
T (..), T (..),
Raw.E (..),
-- TODO hmm, we don't really want to have to list all of these...
-- is there an option to make them not be patterns at all?
pattern Raw.E1,
pattern Raw.E2,
pattern Raw.E3,
Shape (..), Shape (..),
BTree (..), BTree (..),
hello, hello,
helloStruct, helloStruct,
helloEnum,
helloShape, helloShape,
add, add,
sumTree, sumTree,
sumSlice,
printOptional,
) where ) where
import Control.Monad.Cont import Control.Monad.Cont
import Control.Monad.Trans import Control.Monad.Trans
import Data.ByteString import Data.ByteString
import Data.Function import Data.Function
import Data.Vector.Storable qualified as V
import Data.Word import Data.Word
import Foreign import Foreign
import Foreign.C import Foreign.C
@ -54,16 +66,27 @@ withBTree =
Raw.Fork_Body (unsafeFromPtr lPtr) (unsafeFromPtr rPtr) Raw.Fork_Body (unsafeFromPtr lPtr) (unsafeFromPtr rPtr)
hello :: ByteString -> IO () hello :: ByteString -> IO ()
hello s = useAsCString s $ Raw.hello . unsafeFromPtr hello = flip useAsCString $ Raw.hello . unsafeFromPtr
helloStruct :: T -> IO () helloStruct :: T -> IO ()
helloStruct = Raw.hello_struct . convertT helloStruct = flip with (Raw.hello_struct . unsafeFromPtr) . convertT
helloEnum :: Raw.E -> IO ()
helloEnum = flip with (Raw.hello_enum . unsafeFromPtr)
helloShape :: Shape -> IO () helloShape :: Shape -> IO ()
helloShape = Raw.hello_shape . convertShape helloShape = flip with (Raw.hello_shape . unsafeFromPtr) . convertShape
add :: Int64 -> Int64 -> Int64 add :: Int64 -> Int64 -> Int64
add = Raw.add add = Raw.add
sumTree :: BTree Int64 -> Int64 sumTree :: BTree Int64 -> Int64
sumTree = unsafePerformIO . flip withBTree Raw.sum_tree sumTree = unsafePerformIO . flip withBTree (flip with $ Raw.sum_tree . unsafeFromPtr)
sumSlice :: V.Vector Int64 -> Int64
sumSlice v = unsafePerformIO $ V.unsafeWith v \p -> Raw.sum_slice (unsafeFromPtr p) (fromIntegral $ V.length v)
printOptional :: Maybe Int8 -> IO ()
printOptional = \case
Nothing -> Raw.print_optional (unsafeFromPtr nullPtr)
Just t -> with t (Raw.print_optional . unsafeFromPtr)

View File

@ -1,27 +1,36 @@
use std::env; use std::env;
use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
fn main() { fn main() {
// this doesn't make _much_ difference really, since this is our only Rust source file
// but it seems it's probably better than not having it
// what we really want is to tell Rust to only regenerate the header file if the Rust code actually compiles
// but we don't have that flexibility
// and it's an issue because cbindgen tries to be fault-tolerant in some ways that don't even seem to make sense
//
// e.g. mis-spell "Option" as "Option" and you get
// void print_optional(Optio<const int8_t*> x);
// instead of
// void print_optional(const int8_t *x);
// and that's only an issue because in HLS TH dependent-file watching gives up after an error
// i.e. once the containing splice has thrown an exception once, the containing file needs a manual edit to kick it
// and that's really not helped by Rust Analyzer mostly only showing diagnostics on save
// P.S. strings to stdout?! what a terrible API
// don't get me started in the discoverability of actually then using the terminal for debugging:
// println!("cargo::warning={:?}", env::var("OUT_DIR"));
println!("cargo::rerun-if-changed=lib.rs");
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let profile = env::var("PROFILE").unwrap(); let profile = env::var("PROFILE").unwrap();
cbindgen::Builder::new()
let bindings = cbindgen::Builder::new()
.with_crate(&crate_dir) .with_crate(&crate_dir)
.with_language(cbindgen::Language::C) .with_language(cbindgen::Language::C)
.with_style(cbindgen::Style::Tag) .with_style(cbindgen::Style::Tag)
.generate() .generate()
.expect("Unable to generate bindings"); .expect("Unable to generate bindings")
.write_to_file(
let mut buf = Vec::new();
bindings.write(&mut buf);
let header = String::from_utf8(buf).unwrap();
fs::write(
PathBuf::from(&crate_dir) PathBuf::from(&crate_dir)
.join("target") .join("target")
.join(&profile) .join(&profile)
.join("garnet_rs.h"), .join("garnet_rs.h"),
header, );
)
.unwrap();
} }

View File

@ -3,6 +3,7 @@
use std::{ use std::{
ffi::{CStr, c_char}, ffi::{CStr, c_char},
ops::Add, ops::Add,
slice,
}; };
fn say_hello(name: &str) { fn say_hello(name: &str) {
@ -22,10 +23,23 @@ struct T {
} }
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn hello_struct(t: T) -> () { extern "C" fn hello_struct(t: &T) -> () {
say_hello(&format!("{:?}", t)) say_hello(&format!("{:?}", t))
} }
#[repr(C)]
#[derive(Debug)]
enum E {
E1,
E2,
E3,
}
#[unsafe(no_mangle)]
extern "C" fn hello_enum(e: &E) {
say_hello(&format!("{:?}", e))
}
#[repr(C)] #[repr(C)]
#[derive(Debug)] #[derive(Debug)]
enum Shape { enum Shape {
@ -34,7 +48,7 @@ enum Shape {
} }
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn hello_shape(s: Shape) -> () { extern "C" fn hello_shape(s: &Shape) -> () {
say_hello(&format!("{:?}", s)) say_hello(&format!("{:?}", s))
} }
@ -78,6 +92,19 @@ enum BTreeC {
} }
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn sum_tree(t: BTreeC) -> i64 { extern "C" fn sum_tree(t: &BTreeC) -> i64 {
(unsafe { std::mem::transmute::<_, &BTree<i64>>(&t) }).sum() (unsafe { std::mem::transmute::<_, &BTree<i64>>(t) }).sum()
}
#[unsafe(no_mangle)]
extern "C" fn sum_slice(v: *const i64, s: usize) -> i64 {
unsafe { slice::from_raw_parts(v, s) }.iter().sum()
}
#[unsafe(no_mangle)]
extern "C" fn print_optional(x: Option<&i8>) -> () {
match x {
Some(x) => println!("{}", x / 2),
None => {}
}
} }