diff --git a/exe/Main.hs b/exe/Main.hs index e950324..00dfc50 100644 --- a/exe/Main.hs +++ b/exe/Main.hs @@ -1,5 +1,6 @@ module Main (main) where +import Data.Vector.Storable qualified as V import GarnetRs.Wrapped main :: IO () @@ -11,3 +12,4 @@ main = do helloShape $ Rectangle 10.0 5.0 putStrLn $ "3 + 4 = " <> show (add 3 4) putStrLn $ "Tree sum: " <> show (sumTree (Fork (Fork (Leaf 1) (Fork (Leaf 2) (Leaf 3))) (Leaf 4))) + putStrLn $ "Slice sum: " <> show (sumSlice $ V.fromList [0 .. 5]) diff --git a/garnet.cabal b/garnet.cabal index 2173eaf..981915b 100644 --- a/garnet.cabal +++ b/garnet.cabal @@ -34,6 +34,7 @@ common common mtl, process, text, + vector, library import: common diff --git a/lib/GarnetRs/Wrapped.hs b/lib/GarnetRs/Wrapped.hs index b6b7c9c..c3a68b4 100644 --- a/lib/GarnetRs/Wrapped.hs +++ b/lib/GarnetRs/Wrapped.hs @@ -9,12 +9,14 @@ module GarnetRs.Wrapped ( helloShape, add, sumTree, + sumSlice, ) where import Control.Monad.Cont import Control.Monad.Trans import Data.ByteString import Data.Function +import Data.Vector.Storable qualified as V import Data.Word import Foreign import Foreign.C @@ -67,3 +69,6 @@ add = Raw.add sumTree :: BTree Int64 -> Int64 sumTree = unsafePerformIO . flip withBTree (flip with $ Raw.sum_tree . unsafeFromPtr) + +sumSlice :: V.Vector Int64 -> Int64 +sumSlice v = unsafePerformIO $ V.unsafeWith v \p -> Raw.sum_slice (unsafeFromPtr p) (fromIntegral $ V.length v) diff --git a/rust/lib.rs b/rust/lib.rs index d80313d..167f24c 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -3,6 +3,7 @@ use std::{ ffi::{CStr, c_char}, ops::Add, + slice, }; fn say_hello(name: &str) { @@ -81,3 +82,8 @@ enum BTreeC { extern "C" fn sum_tree(t: &BTreeC) -> i64 { (unsafe { std::mem::transmute::<_, &BTree>(t) }).sum() } + +#[unsafe(no_mangle)] +extern "C" fn sum_slice(v: *const i64, s: usize) -> i64 { + unsafe { slice::from_raw_parts(v, s) }.iter().sum() +}