Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store the module data in an arc to simplify a lot of things #65

Merged
merged 6 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
rust: [1.76.0, stable, beta, nightly]
rust: [1.80.0, stable, beta, nightly]

# Test with no features, default features ("") and all features.
# Ordered fewest features to most features.
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ homepage = "https://www.nlnetlabs.nl/projects/routing/rotonda/"
keywords = ["routing", "bgp"]
categories = ["network-programming"]
license = "BSD-3-Clause"
rust-version = "1.76"
rust-version = "1.80"

[dependencies]
# ariadne is set to git because of some unpublished contributions made on
Expand Down
4 changes: 4 additions & 0 deletions examples/simple.roto
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ filter-map main(bla: Bla) {
}
}
}

filter-map just_reject(x: u32) {
apply { reject }
}
33 changes: 25 additions & 8 deletions examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,33 @@ fn main() -> Result<(), roto::RotoReport> {
.inspect_err(|e| eprintln!("{e}"))
.unwrap();

let func2 = compiled
.get_function::<(u32,), Verdict<(), ()>>("just_reject")
.inspect_err(|e| eprintln!("{e}"))
.unwrap();

// We should now be able to drop this safely, because each func has an Arc
// to the data it references.
drop(compiled);

for y in 0..20 {
let mut bla = Bla { _x: 1, y, _z: 1 };
let res = func.call(&mut bla as *mut _);

let expected = if y > 10 {
Verdict::Accept(y * 2)
} else {
Verdict::Reject(())
};
println!("main({y}) = {res:?} (expected: {expected:?})");

let func = func.clone();
std::thread::spawn(move || {
let res = func.call(&mut bla as *mut _);
let expected = if y > 10 {
Verdict::Accept(y * 2)
} else {
Verdict::Reject(())
};
println!("main({y}) = {res:?} (expected: {expected:?})");
})
.join()
.unwrap();

let res = func2.call(y);
println!("{res:?}");
}
Ok(())
}
12 changes: 6 additions & 6 deletions src/codegen/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use inetnum::asn::Asn;
use string_interner::{backend::StringBackend, StringInterner};

use crate::{
runtime::ty::{Reflect, TypeDescription, TypeRegistry},
runtime::ty::{
Reflect, TypeDescription, TypeRegistry, GLOBAL_TYPE_REGISTRY,
},
typechecker::{
info::TypeInfo,
types::{type_to_string, Primitive, Type},
Expand Down Expand Up @@ -56,13 +58,13 @@ impl Display for FunctionRetrievalError {
}

pub fn check_roto_type_reflect<T: Reflect>(
registry: &mut TypeRegistry,
type_info: &mut TypeInfo,
identifiers: &StringInterner<StringBackend>,
roto_ty: &Type,
) -> Result<(), TypeMismatch> {
let mut registry = GLOBAL_TYPE_REGISTRY.lock().unwrap();
let rust_ty = registry.resolve::<T>().type_id;
check_roto_type(registry, type_info, identifiers, rust_ty, roto_ty)
check_roto_type(&registry, type_info, identifiers, rust_ty, roto_ty)
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -170,7 +172,6 @@ pub fn return_type_by_ref(registry: &TypeRegistry, rust_ty: TypeId) -> bool {

pub trait RotoParams {
fn check(
registry: &mut TypeRegistry,
type_info: &mut TypeInfo,
identifiers: &StringInterner<StringBackend>,
ty: &[Type],
Expand Down Expand Up @@ -199,7 +200,6 @@ macro_rules! params {
$($t: Reflect,)*
{
fn check(
registry: &mut TypeRegistry,
type_info: &mut TypeInfo,
identifiers: &StringInterner<StringBackend>,
ty: &[Type]
Expand All @@ -216,7 +216,7 @@ macro_rules! params {
let mut i = 0;
$(
i += 1;
check_roto_type_reflect::<$t>(registry, type_info, identifiers, $t)
check_roto_type_reflect::<$t>(type_info, identifiers, $t)
.map_err(|e| FunctionRetrievalError::TypeMismatch(format!("argument {i}"), e))?;
)*
Ok(())
Expand Down
82 changes: 61 additions & 21 deletions src/codegen/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Machine code generation via cranelift

use std::{
any::TypeId, collections::HashMap, marker::PhantomData, num::NonZeroU8,
sync::Arc,
any::TypeId, collections::HashMap, marker::PhantomData,
mem::ManuallyDrop, num::NonZeroU8, sync::Arc,
};

use crate::{
Expand All @@ -13,7 +13,7 @@ use crate::{
value::IrType,
IrFunction,
},
runtime::ty::{Reflect, TypeRegistry},
runtime::ty::{Reflect, GLOBAL_TYPE_REGISTRY},
typechecker::{info::TypeInfo, scope::ScopeRef, types},
IrValue,
};
Expand Down Expand Up @@ -45,27 +45,68 @@ pub mod check;
#[cfg(test)]
mod tests;

#[derive(Clone)]
pub struct ModuleData(Arc<ManuallyDrop<JITModule>>);

impl From<JITModule> for ModuleData {
fn from(value: JITModule) -> Self {
#[allow(clippy::arc_with_non_send_sync)]
Self(Arc::new(ManuallyDrop::new(value)))
}
}

impl Drop for ModuleData {
fn drop(&mut self) {
// get_mut returns None if we are not the last Arc, so the JITModule
// shouldn't be dropped yet.
let Some(module) = Arc::get_mut(&mut self.0) else {
return;
};

// SAFETY: We only give out functions that hold a ModuleData and
// therefore an Arc to this module. By `get_mut`, we know that we are
// the last Arc to this memory and hence it is safe to free its
// memory. New Arcs cannot have been created in the meantime because
// that requires access to the last Arc, which we know that we have.
unsafe {
let inner = ManuallyDrop::take(module);
inner.free_memory();
};
}
}

unsafe impl Send for ModuleData {}
unsafe impl Sync for ModuleData {}

/// A compiled, ready-to-run Roto module
pub struct Module {
/// The set of public functions and their signatures.
functions: HashMap<String, FunctionInfo>,

/// The inner cranelift module
inner: JITModule,
inner: ModuleData,

/// Info from the typechecker for checking types against Rust types
type_info: TypeInfo,
}

pub struct TypedFunc<'module, Params, Return> {
#[derive(Clone)]
pub struct TypedFunc<Params, Return> {
func: *const u8,
return_by_ref: bool,
_ty: PhantomData<&'module (Params, Return)>,

// The module holds the data for this function, that's why we need
// to ensure that it doesn't get dropped. This field is ESSENTIAL
// for the safety of calling this function. Without it, the data that
// the `func` pointer points to might have been dropped.
_module: ModuleData,
_ty: PhantomData<(Params, Return)>,
}

impl<'module, Params: RotoParams, Return: Reflect>
TypedFunc<'module, Params, Return>
{
unsafe impl<Params, Return> Send for TypedFunc<Params, Return> {}
unsafe impl<Params, Return> Sync for TypedFunc<Params, Return> {}

impl<Params: RotoParams, Return: Reflect> TypedFunc<Params, Return> {
pub fn call_tuple(&self, params: Params) -> Return {
unsafe {
Params::invoke::<Return>(self.func, params, self.return_by_ref)
Expand All @@ -75,17 +116,18 @@ impl<'module, Params: RotoParams, Return: Reflect>

macro_rules! call_impl {
($($ty:ident),*) => {
impl<'module, $($ty,)* Return: Reflect> TypedFunc<'module, ($($ty,)*), Return>
impl<$($ty,)* Return: Reflect> TypedFunc<($($ty,)*), Return>
where
($($ty,)*): RotoParams,
{
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
pub fn call(&self, $($ty: $ty,)*) -> Return {
self.call_tuple(($($ty,)*))
}

#[allow(non_snake_case)]
pub fn as_func(self) -> impl Fn($($ty,)*) -> Return + 'module {
pub fn into_func(self) -> impl Fn($($ty,)*) -> Return {
move |$($ty,)*| self.call($($ty,)*)
}
}
Expand Down Expand Up @@ -342,7 +384,7 @@ impl ModuleBuilder<'_> {
self.inner.finalize_definitions().unwrap();
Module {
functions: self.functions,
inner: self.inner,
inner: self.inner.into(),
type_info: self.type_info,
}
}
Expand Down Expand Up @@ -749,13 +791,11 @@ impl<'a, 'c> FuncGen<'a, 'c> {
}

impl Module {
pub fn get_function<'module, Params: RotoParams, Return: Reflect>(
&'module mut self,
type_registry: &mut TypeRegistry,
pub fn get_function<Params: RotoParams, Return: Reflect>(
&mut self,
identifiers: &StringInterner<StringBackend>,
name: &str,
) -> Result<TypedFunc<'module, Params, Return>, FunctionRetrievalError>
{
) -> Result<TypedFunc<Params, Return>, FunctionRetrievalError> {
let function_info = self.functions.get(name).ok_or_else(|| {
FunctionRetrievalError::DoesNotExist {
name: name.to_string(),
Expand All @@ -767,14 +807,12 @@ impl Module {
let id = function_info.id;

Params::check(
type_registry,
&mut self.type_info,
identifiers,
&sig.parameter_types,
)?;

check_roto_type_reflect::<Return>(
type_registry,
&mut self.type_info,
identifiers,
&sig.return_type,
Expand All @@ -786,13 +824,15 @@ impl Module {
)
})?;

let registry = GLOBAL_TYPE_REGISTRY.lock().unwrap();
let return_by_ref =
return_type_by_ref(type_registry, TypeId::of::<Return>());
return_type_by_ref(&registry, TypeId::of::<Return>());

let func_ptr = self.inner.get_finalized_function(id);
let func_ptr = self.inner.0.get_finalized_function(id);
Ok(TypedFunc {
func: func_ptr,
return_by_ref,
_module: self.inner.clone(),
_ty: PhantomData,
})
}
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ fn record_with_fields_flipped() {
let f = p
.get_function::<(i32,), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)")
.as_func();
.into_func();

for x in 0..100 {
let expected = if x == 20 {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod typechecker;
pub(crate) mod pipeline;
pub(crate) mod runtime;

pub use codegen::TypedFunc;
pub use lower::eval::Memory;
pub use lower::value::IrValue;
pub use pipeline::*;
Expand Down
4 changes: 2 additions & 2 deletions src/lower/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,8 @@ fn call_runtime_function(
func: &RuntimeFunction,
args: Vec<IrValue>,
) -> Option<IrValue> {
assert_eq!(func.description.parameter_types.len(), args.len());
(func.description.wrapped)(args)
assert_eq!(func.description.parameter_types().len(), args.len());
(func.description.wrapped())(args)
}

fn eval_operand<'a>(
Expand Down
4 changes: 2 additions & 2 deletions src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ impl<'r> Lowerer<'r> {

let ir_func = IrFunction {
name: ident.node,
ptr: runtime_func.description.pointer,
ptr: runtime_func.description.pointer(),
params,
ret: if self.type_info.size_of(&ret) > 0 {
Some(self.lower_type(&ret))
Expand Down Expand Up @@ -610,7 +610,7 @@ impl<'r> Lowerer<'r> {

let ir_func = IrFunction {
name: m.node,
ptr: runtime_func.description.pointer,
ptr: runtime_func.description.pointer(),
params,
ret: if self.type_info.size_of(&ret) > 0 {
Some(self.lower_type(&ret))
Expand Down
Loading
Loading