diff --git a/Cargo.toml b/Cargo.toml index b9373bb0..ee2f38c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,11 +20,6 @@ categories = ["network-programming"] # is released ariadne = { version = "0.4.0", git = "https://github.com/zesterer/ariadne.git", rev = "4e5987cd55d954858da4a4130255eca1bf0bee5f" } clap = { version = "4.4.6", features = ["derive"] } -cranelift = "0.107.0" -cranelift-codegen = { version = "0.107.0", features = ["disas"] } -cranelift-jit = "0.107.0" -cranelift-module = "0.107.0" -cranelift-native = "0.107.0" env_logger = "0.10" log = "0.4" logos = "0.14.0" @@ -33,6 +28,18 @@ symbol_table = { version = "0.3.0", features = ["global"] } string-interner = "0.17.0" roto-macros = { path = "macros" } +[dependencies.cranelift] +version = "0.113.0" +features = ["frontend", "jit", "module", "native"] +git = "https://github.com/bytecodealliance/wasmtime.git" +rev = "1af294ea2d6c18c5a8fa9b4f272398b7c98e0c48" + +[dependencies.cranelift-codegen] +version = "0.113.0" +features = ["disas"] +git = "https://github.com/bytecodealliance/wasmtime.git" +rev = "1af294ea2d6c18c5a8fa9b4f272398b7c98e0c48" + [dev-dependencies] bytes = "1" routecore = { version = "0.4.0", features = ["bgp", "bmp"] } diff --git a/examples/simple.rs b/examples/simple.rs index cfe5b113..ce794264 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -1,5 +1,5 @@ use roto::{read_files, Runtime, Verdict}; -use roto_macros::roto_function; +use roto_macros::roto_method; struct Bla { _x: u16, @@ -7,18 +7,17 @@ struct Bla { _z: u32, } -#[roto_function] -fn get_y(bla: *const Bla) -> u32 { - unsafe { &*bla }.y -} - fn main() -> Result<(), roto::RotoReport> { env_logger::init(); let mut runtime = Runtime::basic().unwrap(); runtime.register_type::().unwrap(); - runtime.register_method::("y", get_y).unwrap(); + + #[roto_method(runtime, Bla, y)] + fn get_y(bla: *const Bla) -> u32 { + unsafe { &*bla }.y + } let mut compiled = read_files(["examples/simple.roto"])? .compile(runtime, usize::BITS / 8) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index f738803f..4782627d 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,30 +1,153 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, ItemFn}; - -/// -/// -/// ```rust,no_run -/// fn foo(a1: A1, a2: A2) -> Ret { -/// /* ... */ -/// } -/// ``` -/// -/// +use syn::{parse_macro_input, Token}; + +struct Intermediate { + function: proc_macro2::TokenStream, + name: syn::Ident, + identifier: proc_macro2::TokenStream, +} + +struct FunctionArgs { + runtime_ident: syn::Ident, + name: Option, +} + +impl syn::parse::Parse for FunctionArgs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let runtime_ident = input.parse()?; + + let mut name = None; + if input.peek(Token![,]) { + input.parse::()?; + if input.peek(syn::Ident) { + name = input.parse()?; + } + } + + Ok(Self { + runtime_ident, + name, + }) + } +} + +#[proc_macro_attribute] +pub fn roto_function(attr: TokenStream, item: TokenStream) -> TokenStream { + let item = parse_macro_input!(item as syn::ItemFn); + let Intermediate { + function, + identifier, + name: function_ident, + } = generate_function(item); + + let FunctionArgs { + runtime_ident, + name, + } = syn::parse(attr).unwrap(); + + let name = name.unwrap_or(function_ident); + + let expanded = quote! { + #function + + #runtime_ident.register_function(stringify!(#name), #identifier).unwrap(); + }; + + TokenStream::from(expanded) +} + +struct MethodArgs { + runtime_ident: syn::Ident, + ty: syn::Type, + name: Option, +} + +impl syn::parse::Parse for MethodArgs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let runtime_ident = input.parse()?; + input.parse::()?; + let ty = input.parse()?; + + let mut name = None; + if input.peek(Token![,]) { + input.parse::()?; + if input.peek(syn::Ident) { + name = input.parse()?; + } + } + Ok(Self { + runtime_ident, + ty, + name, + }) + } +} + +#[proc_macro_attribute] +pub fn roto_method(attr: TokenStream, item: TokenStream) -> TokenStream { + let item = parse_macro_input!(item as syn::ItemFn); + let Intermediate { + function, + identifier, + name: function_name, + } = generate_function(item); + + let MethodArgs { + runtime_ident, + ty, + name, + } = parse_macro_input!(attr as MethodArgs); + + let name = name.unwrap_or(function_name); + + let expanded = quote! { + #function + + #runtime_ident.register_method::<#ty, _, _>(stringify!(#name), #identifier).unwrap(); + }; + + TokenStream::from(expanded) +} + #[proc_macro_attribute] -pub fn roto_function(_attr: TokenStream, item: TokenStream) -> TokenStream { - let input = parse_macro_input!(item as ItemFn); +pub fn roto_static_method( + attr: TokenStream, + item: TokenStream, +) -> TokenStream { + let item = parse_macro_input!(item as syn::ItemFn); + let Intermediate { + function, + identifier, + name: function_name, + } = generate_function(item); + + let MethodArgs { + runtime_ident, + ty, + name, + } = parse_macro_input!(attr as MethodArgs); - let ItemFn { + let name = name.unwrap_or(function_name); + + let expanded = quote! { + #function + + #runtime_ident.register_static_method::<#ty, _, _>(stringify!(#name), #identifier).unwrap(); + }; + + TokenStream::from(expanded) +} + +fn generate_function(item: syn::ItemFn) -> Intermediate { + let syn::ItemFn { attrs, vis, sig, block: _, - } = input.clone(); + } = item.clone(); assert!(sig.unsafety.is_none()); - assert!(sig.generics.params.is_empty()); - assert!(sig.generics.where_clause.is_none()); assert!(sig.variadic.is_none()); let ident = sig.ident; @@ -35,6 +158,7 @@ pub fn roto_function(_attr: TokenStream, item: TokenStream) -> TokenStream { pat }); + let generics = sig.generics; let inputs = sig.inputs.clone().into_iter(); let ret = match sig.output { syn::ReturnType::Default => quote!(()), @@ -53,21 +177,25 @@ pub fn roto_function(_attr: TokenStream, item: TokenStream) -> TokenStream { }) .collect(); - let arg_types = quote!(*mut #ret, #(#input_types,)*); + let underscored_types = input_types.iter().map(|_| quote!(_)); + let arg_types = quote!(_, #(#underscored_types,)*); - let expanded = quote! { - #[allow(non_upper_case_globals)] - #vis const #ident: extern "C" fn(#arg_types) = { - #(#attrs)* - extern "C" fn #ident ( out: *mut #ret, #(#inputs,)* ) { - #input + let function = quote! { + #(#attrs)* + #vis extern "C" fn #ident #generics ( out: *mut #ret, #(#inputs,)* ) { + #item - unsafe { *out = #ident(#(#args),*) }; - } + unsafe { *out = #ident(#(#args),*) }; + } + }; - #ident as extern "C" fn(#arg_types) - }; + let identifier = quote! { + #ident as extern "C" fn(#arg_types) }; - TokenStream::from(expanded) + Intermediate { + function, + name: ident, + identifier, + } } diff --git a/src/ast.rs b/src/ast.rs index 00a246b2..e7f9113b 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -252,7 +252,7 @@ pub enum Literal { #[allow(dead_code)] String(String), Asn(Asn), - IpAddress(IpAddress), + IpAddress(std::net::IpAddr), Integer(i64), Bool(bool), } @@ -313,9 +313,3 @@ impl std::fmt::Display for BinOp { ) } } - -#[derive(Clone, Debug)] -pub enum IpAddress { - Ipv4(std::net::Ipv4Addr), - Ipv6(std::net::Ipv6Addr), -} diff --git a/src/codegen/check.rs b/src/codegen/check.rs index 137aadd7..725d6a72 100644 --- a/src/codegen/check.rs +++ b/src/codegen/check.rs @@ -1,4 +1,4 @@ -use inetnum::asn::Asn; +use inetnum::{addr::Prefix, asn::Asn}; use crate::{ runtime::ty::{ @@ -9,7 +9,7 @@ use crate::{ types::{Primitive, Type}, }, }; -use std::{any::TypeId, fmt::Display, mem::MaybeUninit}; +use std::{any::TypeId, fmt::Display, mem::MaybeUninit, net::IpAddr}; #[derive(Debug)] pub enum FunctionRetrievalError { @@ -39,7 +39,7 @@ impl Display for FunctionRetrievalError { expected, got, } => { - writeln!(f, "The numer of arguments do not match")?; + writeln!(f, "The number of arguments do not match")?; writeln!(f, "The Roto function has {expected} arguments, but the Rust function has {got}.") } FunctionRetrievalError::TypeMismatch( @@ -84,6 +84,8 @@ fn check_roto_type( let I64: TypeId = TypeId::of::(); let UNIT: TypeId = TypeId::of::<()>(); let ASN: TypeId = TypeId::of::(); + let IPADDR: TypeId = TypeId::of::(); + let PREFIX: TypeId = TypeId::of::(); let Some(rust_ty) = registry.get(rust_ty) else { return Err(TypeMismatch { @@ -117,6 +119,8 @@ fn check_roto_type( x if x == I64 => Type::Primitive(Primitive::I64), x if x == UNIT => Type::Primitive(Primitive::Unit), x if x == ASN => Type::Primitive(Primitive::Asn), + x if x == IPADDR => Type::Primitive(Primitive::IpAddr), + x if x == PREFIX => Type::Primitive(Primitive::Prefix), _ => panic!(), }; if expected_roto == roto_ty { @@ -151,29 +155,55 @@ pub fn return_type_by_ref(registry: &TypeRegistry, rust_ty: TypeId) -> bool { #[allow(clippy::match_like_matches_macro)] match rust_ty.description { TypeDescription::Verdict(_, _) => true, - _ => false, + _ => todo!(), } } +/// Parameters of a Roto function +/// +/// This trait allows for checking the types against Roto types and converting +/// the values into values appropriate for Roto. +/// +/// The `invoke` method can (unsafely) invoke a pointer as if it were a function +/// with these parameters. +/// +/// This trait is implemented on tuples of various sizes. pub trait RotoParams { + /// This type but with [`Reflect::AsParam`] applied to each element. + type AsParams; + + /// Convert to `Self::AsParams`. + fn as_params(&mut self) -> Self::AsParams; + + /// Check whether these parameters match a parameter list from Roto. fn check( type_info: &mut TypeInfo, ty: &[Type], ) -> Result<(), FunctionRetrievalError>; + /// Call a function pointer as if it were a function with these parameters. + /// + /// This is _extremely_ unsafe, do not pass this arbitrary pointers and + /// always call `RotoParams::check` before calling this function. Don't + /// forget to also check the return type. + /// + /// A [`TypedFunc`](super::TypedFunc) is a safe abstraction around this + /// function. unsafe fn invoke( + self, func_ptr: *const u8, - params: Self, return_by_ref: bool, ) -> R; } +/// Little helper macro to create a unit macro_rules! unit { ($t:tt) => { () }; } +/// Implement the [`RotoParams`] trait for a tuple with some type parameters. macro_rules! params { ($($t:ident),*) => { #[allow(non_snake_case)] @@ -183,6 +213,13 @@ macro_rules! params { where $($t: Reflect,)* { + type AsParams = ($($t::AsParam,)*); + + fn as_params(&mut self) -> Self::AsParams { + let ($($t,)*) = self; + return ($($t.as_param(),)*); + } + fn check( type_info: &mut TypeInfo, ty: &[Type] @@ -205,17 +242,18 @@ macro_rules! params { Ok(()) } - unsafe fn invoke(func_ptr: *const u8, ($($t,)*): Self, return_by_ref: bool) -> R { + unsafe fn invoke(mut self, func_ptr: *const u8, return_by_ref: bool) -> R { + let ($($t,)*) = self.as_params(); if return_by_ref { let func_ptr = unsafe { - std::mem::transmute::<*const u8, fn(*mut R, $($t),*) -> ()>(func_ptr) + std::mem::transmute::<*const u8, fn(*mut R, $($t::AsParam),*) -> ()>(func_ptr) }; let mut ret = MaybeUninit::::uninit(); func_ptr(ret.as_mut_ptr(), $($t),*); unsafe { ret.assume_init() } } else { let func_ptr = unsafe { - std::mem::transmute::<*const u8, fn($($t),*) -> R>(func_ptr) + std::mem::transmute::<*const u8, fn($($t::AsParam),*) -> R>(func_ptr) }; func_ptr($($t),*) } diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index dce57ebc..5414c831 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -2,7 +2,7 @@ use std::{ any::TypeId, collections::HashMap, marker::PhantomData, - mem::ManuallyDrop, num::NonZeroU8, sync::Arc, + mem::ManuallyDrop, sync::Arc, }; use crate::{ @@ -34,10 +34,8 @@ use cranelift::{ frontend::{ FuncInstBuilder, FunctionBuilder, FunctionBuilderContext, Switch, Variable, - }, + }, jit::{JITBuilder, JITModule}, module::{DataDescription, FuncId, Linkage, Module as _}, }; -use cranelift_jit::{JITBuilder, JITModule}; -use cranelift_module::{FuncId, Linkage, Module as _}; use log::info; pub mod check; @@ -117,9 +115,7 @@ unsafe impl Sync for TypedFunc {} impl TypedFunc { pub fn call_tuple(&self, params: Params) -> Return { - unsafe { - Params::invoke::(self.func, params, self.return_by_ref) - } + unsafe { params.invoke::(self.func, self.return_by_ref) } } } @@ -162,7 +158,7 @@ struct ModuleBuilder { functions: HashMap, /// External functions - runtime_functions: HashMap, + runtime_functions: HashMap, /// The inner cranelift module inner: JITModule, @@ -203,7 +199,7 @@ const MEMFLAGS: MemFlags = MemFlags::new().with_aligned(); pub fn codegen( ir: &[ir::Function], - runtime_functions: &HashMap, + runtime_functions: &HashMap, label_store: LabelStore, type_info: TypeInfo, ) -> Module { @@ -214,15 +210,15 @@ pub fn codegen( let mut settings = settings::builder(); settings.set("opt_level", "speed").unwrap(); let flags = settings::Flags::new(settings); - let isa = cranelift_native::builder().unwrap().finish(flags).unwrap(); + let isa = cranelift::native::builder().unwrap().finish(flags).unwrap(); let mut builder = JITBuilder::with_isa( isa.to_owned(), - cranelift_module::default_libcall_names(), + cranelift::module::default_libcall_names(), ); for (name, func) in runtime_functions { - builder.symbol(name, func.ptr); + builder.symbol(format!("runtime_function_{name}"), func.ptr); } let jit = JITModule::new(builder); @@ -237,7 +233,7 @@ pub fn codegen( type_info, }; - for (name, func) in runtime_functions { + for (roto_func_id, func) in runtime_functions { let mut sig = module.inner.make_signature(); for ty in &func.params { sig.params.push(AbiParam::new(module.cranelift_type(ty))); @@ -245,12 +241,14 @@ pub fn codegen( if let Some(ty) = &func.ret { sig.returns.push(AbiParam::new(module.cranelift_type(ty))); } - let Ok(func_id) = - module.inner.declare_function(name, Linkage::Import, &sig) - else { + let Ok(func_id) = module.inner.declare_function( + &format!("runtime_function_{roto_func_id}"), + Linkage::Import, + &sig, + ) else { panic!() }; - module.runtime_functions.insert(name.clone(), func_id); + module.runtime_functions.insert(*roto_func_id, func_id); } // Our functions might call each other, so we declare them before we @@ -397,7 +395,6 @@ impl ModuleBuilder { IrType::U16 | IrType::I16 => I16, IrType::U32 | IrType::I32 | IrType::Asn => I32, IrType::U64 | IrType::I64 => I64, - IrType::IpAddr => I32, IrType::Pointer | IrType::ExtPointer => self.isa.pointer_type(), IrType::ExtValue => todo!(), } @@ -531,7 +528,7 @@ impl<'c> FuncGen<'c> { } } ir::Instruction::CallRuntime { to, func, args } => { - let func_id = self.module.runtime_functions[&func.name]; + let func_id = self.module.runtime_functions[&func.id]; let func_ref = self .module .inner @@ -638,17 +635,76 @@ impl<'c> FuncGen<'c> { }; self.def(var, val) } + ir::Instruction::Extend { to, ty, from } => { + let ty = self.module.cranelift_type(ty); + let (from, _) = self.operand(from); + let val = self.ins().uextend(ty, from); + + let var = self.variable(to, ty); + self.def(var, val) + } ir::Instruction::Eq { .. } => todo!(), - ir::Instruction::Alloc { to, size } => { - let slot = self.builder.create_sized_stack_slot( - StackSlotData::new(StackSlotKind::ExplicitSlot, *size), - ); + ir::Instruction::Alloc { + to, + size, + align_shift, + } => { + let slot = + self.builder.create_sized_stack_slot(StackSlotData::new( + StackSlotKind::ExplicitSlot, + *size, + *align_shift, + )); let pointer_ty = self.module.isa.pointer_type(); let var = self.variable(to, pointer_ty); let p = self.ins().stack_addr(pointer_ty, slot, 0); self.def(var, p); } + ir::Instruction::Initialize { + to, + bytes, + align_shift, + } => { + let pointer_ty = self.module.isa.pointer_type(); + let slot = + self.builder.create_sized_stack_slot(StackSlotData::new( + StackSlotKind::ExplicitSlot, + bytes.len() as u32, + *align_shift, + )); + + let data_id = self + .module + .inner + .declare_anonymous_data(false, false) + .unwrap(); + let mut data_description = DataDescription::new(); + data_description.define(bytes.clone().into_boxed_slice()); + self.module + .inner + .define_data(data_id, &data_description) + .unwrap(); + let global_value = self + .module + .inner + .declare_data_in_func(data_id, self.builder.func); + let value = self.ins().global_value(pointer_ty, global_value); + + let var = self.variable(to, pointer_ty); + let p = self.ins().stack_addr(pointer_ty, slot, 0); + self.builder.emit_small_memory_copy( + self.module.isa.frontend_config(), + p, + value, + bytes.len() as u64, + 0, + 0, + true, + MEMFLAGS, + ); + self.def(var, p); + } ir::Instruction::Write { to, val } => { let (x, _) = self.operand(val); let (to, _) = self.operand(to); @@ -689,20 +745,18 @@ impl<'c> FuncGen<'c> { } => { let (left, _) = self.operand(left); let (right, _) = self.operand(right); + let (size, _) = self.operand(size); // We could pass more precise alignment to cranelift, but // values of 1 should just work. - let val = self.builder.emit_small_memory_compare( + let val = self.builder.call_memcmp( self.module.isa.frontend_config(), - IntCC::Equal, left, right, - *size as u64, - NonZeroU8::new(1).unwrap(), - NonZeroU8::new(1).unwrap(), - MEMFLAGS, + size, ); - let var = self.variable(to, I8); + + let var = self.variable(to, I32); self.def(var, val); } } diff --git a/src/codegen/tests.rs b/src/codegen/tests.rs index 7bf32367..6e4eae3d 100644 --- a/src/codegen/tests.rs +++ b/src/codegen/tests.rs @@ -1,5 +1,7 @@ -use inetnum::asn::Asn; -use roto_macros::roto_function; +use std::net::IpAddr; + +use inetnum::{addr::Prefix, asn::Asn}; +use roto_macros::{roto_function, roto_static_method}; use crate::{ pipeline::Compiled, runtime::tests::routecore_runtime, src, Files, @@ -543,20 +545,18 @@ fn int_var() { #[test] fn issue_52() { - let mut rt = Runtime::basic().unwrap(); - struct Foo { _x: i32, } - #[roto_function] + let mut rt = Runtime::basic().unwrap(); + rt.register_type::().unwrap(); + + #[roto_static_method(rt, Foo)] fn bar(_x: u32) -> u32 { 2 } - rt.register_type::().unwrap(); - rt.register_static_method::("bar", bar).unwrap(); - let s = src!( " filter-map main(foo: Foo) { @@ -666,6 +666,191 @@ fn multiply() { assert_eq!(res, Verdict::Accept(40)); } +#[test] +fn ip_output() { + let s = src!( + " + filter-map main() { + apply { accept 1.2.3.4 } + } + " + ); + + let mut p = compile(s); + let f = p + .get_function::<(), Verdict>("main") + .expect("No function found (or mismatched types)"); + + let ip = IpAddr::from([1, 2, 3, 4]); + let res = f.call(); + assert_eq!(res, Verdict::Accept(ip)); +} + +#[test] +fn ip_passthrough() { + let s = src!( + " + filter-map main(x: IpAddr) { + apply { accept x } + } + " + ); + + let mut p = compile(s); + let f = p + .get_function::<(IpAddr,), Verdict>("main") + .expect("No function found (or mismatched types)"); + + let ip = IpAddr::from([1, 2, 3, 4]); + let res = f.call(ip); + assert_eq!(res, Verdict::Accept(ip)); +} + +#[test] +fn ipv4_compare() { + let s = src!( + " + filter-map main(x: IpAddr) { + apply { + if x == 0.0.0.0 { + accept x + } else if x == 192.168.0.0 { + accept x + } else { + reject x + } + } + } + " + ); + + let mut p = compile(s); + let f = p + .get_function::<(IpAddr,), Verdict>("main") + .expect("No function found (or mismatched types)"); + + let ip = IpAddr::from([0, 0, 0, 0]); + let res = f.call(ip); + assert_eq!(res, Verdict::Accept(ip)); + let ip = IpAddr::from([192, 168, 0, 0]); + let res = f.call(ip); + assert_eq!(res, Verdict::Accept(ip)); + + let ip = IpAddr::from([1, 2, 3, 4]); + let res = f.call(ip); + assert_eq!(res, Verdict::Reject(ip)); + + let ip = IpAddr::from([0, 0, 0, 0, 0, 0, 0, 0]); + let res = f.call(ip); + assert_eq!(res, Verdict::Reject(ip)); +} + +#[test] +fn ipv6_compare() { + let s = src!( + " + filter-map main(x: IpAddr) { + apply { + if x == :: { + accept x + } else if x == 192.168.0.0 { + accept x + } else if x == ::1 { + accept x + } else { + reject x + } + } + } + " + ); + + let mut p = compile(s); + let f = p + .get_function::<(IpAddr,), Verdict>("main") + .expect("No function found (or mismatched types)"); + + let ip = IpAddr::from([0, 0, 0, 0, 0, 0, 0, 0]); + let res = f.call(ip); + assert_eq!(res, Verdict::Accept(ip)); + + let ip = IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1]); + let res = f.call(ip); + assert_eq!(res, Verdict::Accept(ip)); + + let ip = IpAddr::from([192, 168, 0, 0]); + let res = f.call(ip); + assert_eq!(res, Verdict::Accept(ip)); + + let ip = IpAddr::from([1, 2, 3, 4]); + let res = f.call(ip); + assert_eq!(res, Verdict::Reject(ip)); +} + +#[test] +fn construct_prefix() { + let s = src!( + " + filter-map main() { + apply { + accept 192.168.0.0 / 16 + } + } + " + ); + let mut p = compile(s); + let f = p + .get_function::<(), Verdict>("main") + .expect("No function found (or mismatched types)"); + + let p = Prefix::new("192.168.0.0".parse().unwrap(), 16).unwrap(); + let res = f.call(); + assert_eq!(res, Verdict::Accept(p)); +} + +#[test] +fn function_returning_unit() { + let mut runtime = Runtime::basic().unwrap(); + + #[roto_function(runtime)] + fn unit_unit() {} + + let s = src!( + " + filter-map main() { + apply { + accept unit_unit() + } + } + " + ); + + let mut p = compile_with_runtime(s, runtime); + let f = p + .get_function::<(), Verdict<(), ()>>("main") + .expect("No function found (or mismatched types)"); + + let res = f.call(); + assert_eq!(res, Verdict::Accept(())); +} + +#[test] +fn functions_with_lifetimes() { + struct Foo<'a> { + _x: &'a u32, + } + + let mut rt = Runtime::basic().unwrap(); + rt.register_type::().unwrap(); + + #[roto_function(rt)] + fn funcy(_foo: *const Foo) {} + + #[allow(clippy::needless_lifetimes)] + #[roto_function(rt)] + fn funcy2<'a>(_foo: *const Foo<'a>) {} +} + // #[test] // fn bmp_message() { // let s = " diff --git a/src/lower/eval.rs b/src/lower/eval.rs index d8419fc4..6ecd626a 100644 --- a/src/lower/eval.rs +++ b/src/lower/eval.rs @@ -4,6 +4,8 @@ //! fairly slow. This is because all variables at this point are identified //! by strings and therefore stored as a hashmap. +use log::trace; + use super::ir::{Function, Operand, Var}; use crate::{ ast::Identifier, @@ -308,6 +310,7 @@ pub fn eval( loop { let instruction = &instructions[program_counter]; + trace!("{:?}", &instruction); match instruction { Instruction::Jump(b) => { program_counter = block_map[b]; @@ -490,6 +493,11 @@ pub fn eval( debug_assert_eq!(*ty, res.get_type()); vars.insert(to.clone(), res); } + Instruction::Extend { to, ty, from } => { + let val = eval_operand(&vars, from); + let val = val.as_vec(); + vars.insert(to.clone(), IrValue::from_slice(ty, &val)); + } Instruction::Offset { to, from, offset } => { let &IrValue::Pointer(from) = eval_operand(&vars, from) else { @@ -498,10 +506,15 @@ pub fn eval( let new = mem.offset_by(from, *offset as usize); vars.insert(to.clone(), IrValue::Pointer(new)); } - Instruction::Alloc { to, size } => { + Instruction::Alloc { to, size, align_shift: _ } => { let pointer = mem.allocate(*size as usize); vars.insert(to.clone(), IrValue::Pointer(pointer)); } + Instruction::Initialize { to, bytes, align_shift: _ } => { + let pointer = mem.allocate(bytes.len()); + mem.write(pointer, bytes); + vars.insert(to.clone(), IrValue::Pointer(pointer)); + } Instruction::Write { to, val } => { let &IrValue::Pointer(to) = eval_operand(&vars, to) else { panic!() @@ -543,10 +556,18 @@ pub fn eval( else { panic!() }; - let left = mem.read_slice(left, *size as usize); - let right = mem.read_slice(right, *size as usize); - let res = left == right; - vars.insert(to.clone(), IrValue::Bool(res)); + let &IrValue::Pointer(size) = eval_operand(&vars, size) + else { + panic!() + }; + let left = mem.read_slice(left, size); + let right = mem.read_slice(right, size); + let res = match left.cmp(right) { + std::cmp::Ordering::Less => -1isize as usize, + std::cmp::Ordering::Equal => 0, + std::cmp::Ordering::Greater => 1, + }; + vars.insert(to.clone(), IrValue::Pointer(res)); } } diff --git a/src/lower/ir.rs b/src/lower/ir.rs index 68dbcdb0..0db70667 100644 --- a/src/lower/ir.rs +++ b/src/lower/ir.rs @@ -160,17 +160,31 @@ pub enum Instruction { right: Operand, }, - // Add offset to a pointer + Extend { + to: Var, + ty: IrType, + from: Operand, + }, + + /// Add offset to a pointer Offset { to: Var, from: Operand, offset: u32, }, - // Allocate a stack slot + /// Allocate a stack slot Alloc { to: Var, size: u32, + align_shift: u8, + }, + + /// Write literal bytes to a variable + Initialize { + to: Var, + bytes: Vec, + align_shift: u8, }, /// Write to a stack slot @@ -196,7 +210,7 @@ pub enum Instruction { /// Compare chunks of memory MemCmp { to: Var, - size: u32, + size: Operand, left: Operand, right: Operand, }, @@ -455,6 +469,13 @@ impl<'a> IrPrinter<'a> { self.operand(right), ) } + Extend { to, ty, from } => { + format!( + "{}: extend({ty}, {})", + self.var(to), + self.operand(from), + ) + } Jump(to) => { format!("jump {}", self.label(to)) } @@ -474,8 +495,19 @@ impl<'a> IrPrinter<'a> { self.label(default) ) } - Alloc { to, size } => { - format!("{} = mem::alloc({size})", self.var(to)) + Alloc { to, size, align_shift } => { + format!("{} = mem::alloc(size={size}, align_shift={align_shift})", self.var(to)) + } + Initialize { to, bytes, align_shift } => { + format!( + "{} = mem::initialize([{}], align_shift={align_shift})", + self.var(to), + bytes + .iter() + .map(|b| b.to_string()) + .collect::>() + .join(", ") + ) } Offset { to, from, offset } => { format!( @@ -512,10 +544,11 @@ impl<'a> IrPrinter<'a> { right, } => { format!( - "{} = mem::cmp({}, {}, {size})", + "{} = mem::cmp({}, {}, {})", self.var(to), self.operand(left), self.operand(right), + self.operand(size), ) } } diff --git a/src/lower/mod.rs b/src/lower/mod.rs index ec65f6fe..a1fb45c9 100644 --- a/src/lower/mod.rs +++ b/src/lower/mod.rs @@ -13,21 +13,21 @@ mod test_eval; use ir::{Block, Function, Instruction, Operand, Var, VarKind}; use label::{LabelRef, LabelStore}; -use std::{collections::HashMap, net::IpAddr}; +use std::{any::TypeId, collections::HashMap, net::IpAddr}; use value::IrType; use crate::{ - ast::{self, Expr, Identifier, Literal}, - parser::meta::{Meta, MetaId}, - runtime::RuntimeFunction, + ast::{self, Identifier, Literal}, + parser::meta::Meta, + runtime::{self, RuntimeFunction}, typechecker::{ - self, info::TypeInfo, scope::{DefinitionRef, ScopeRef}, types::{ FunctionDefinition, FunctionKind, Primitive, Signature, Type, }, }, + Runtime, }; use self::value::IrValue; @@ -61,25 +61,28 @@ struct Lowerer<'r> { tmp_idx: usize, blocks: Vec, type_info: &'r mut TypeInfo, - runtime_functions: &'r mut HashMap, + runtime_functions: &'r mut HashMap, label_store: &'r mut LabelStore, + runtime: &'r Runtime, } pub fn lower( tree: &ast::SyntaxTree, type_info: &mut TypeInfo, - runtime_functions: &mut HashMap, + runtime_functions: &mut HashMap, label_store: &mut LabelStore, + runtime: &Runtime, ) -> Vec { - Lowerer::tree(type_info, runtime_functions, tree, label_store) + Lowerer::tree(type_info, runtime_functions, tree, label_store, runtime) } impl<'r> Lowerer<'r> { fn new( type_info: &'r mut TypeInfo, - runtime_functions: &'r mut HashMap, + runtime_functions: &'r mut HashMap, function_name: &Meta, label_store: &'r mut LabelStore, + runtime: &'r Runtime, ) -> Self { let function_scope = type_info.function_scope(function_name); Self { @@ -90,6 +93,7 @@ impl<'r> Lowerer<'r> { runtime_functions, blocks: Vec::new(), label_store, + runtime, } } @@ -131,9 +135,10 @@ impl<'r> Lowerer<'r> { /// Lower a syntax tree fn tree( type_info: &mut TypeInfo, - runtime_functions: &mut HashMap, + runtime_functions: &mut HashMap, tree: &ast::SyntaxTree, label_store: &'r mut LabelStore, + runtime: &'r Runtime, ) -> Vec { let ast::SyntaxTree { declarations: expressions, @@ -150,6 +155,7 @@ impl<'r> Lowerer<'r> { runtime_functions, &x.ident, label_store, + runtime, ) .filter_map(x), ); @@ -166,6 +172,7 @@ impl<'r> Lowerer<'r> { runtime_functions, ident, label_store, + runtime, ) .function(ident, params, ret, body), ); @@ -432,14 +439,20 @@ impl<'r> Lowerer<'r> { let func = self.type_info.function(ident).clone(); match &func.definition { - FunctionDefinition::Runtime(runtime_func) => self - .call_runtime_function( - ident, - id, - &func, + FunctionDefinition::Runtime(runtime_func) => { + let args: Vec<_> = args + .iter() + .map(|e| self.expr(e).unwrap()) + .collect(); + + self.call_runtime_function( + **ident, runtime_func, - &args.node, - ), + args, + &func.signature.parameter_types, + &func.signature.return_type, + ) + } FunctionDefinition::Roto => { let ty = self.type_info.type_of(ident); let DefinitionRef(_, ident) = @@ -498,9 +511,13 @@ impl<'r> Lowerer<'r> { let to = self.new_tmp(); let size = self.type_info.size_of(&ty); + let alignment = self.type_info.alignment_of(&ty); + let align_shift = alignment.ilog2() as u8; + self.add(Instruction::Alloc { to: to.clone(), size, + align_shift, }); self.add(Instruction::Write { to: to.clone().into(), @@ -540,14 +557,18 @@ impl<'r> Lowerer<'r> { None }; - let args = receiver_iter.into_iter().chain(&args.node); + let args: Vec<_> = receiver_iter + .into_iter() + .chain(&args.node) + .map(|e| self.expr(e).unwrap()) + .collect(); self.call_runtime_function( - ident, - id, - &func, + **ident, runtime_func, args, + &func.signature.parameter_types, + &func.signature.return_type, ) } ast::Expr::Access(e, field) => { @@ -559,6 +580,8 @@ impl<'r> Lowerer<'r> { }; let size = self.type_info.size_of(&ty); + let alignment = self.type_info.alignment_of(&ty); + let align_shift = alignment.ilog2() as u8; let Some(idx) = variants.iter().position(|(f, _)| field.node == *f) @@ -570,6 +593,7 @@ impl<'r> Lowerer<'r> { self.add(Instruction::Alloc { to: to.clone(), size, + align_shift, }); self.add(Instruction::Write { @@ -604,6 +628,8 @@ impl<'r> Lowerer<'r> { ast::Expr::TypedRecord(_, record) | ast::Expr::Record(record) => { let ty = self.type_info.type_of(id); let size = self.type_info.size_of(&ty); + let alignment = self.type_info.alignment_of(&ty); + let align_shift = alignment.ilog2() as u8; let fields: Vec<_> = record .fields @@ -619,6 +645,7 @@ impl<'r> Lowerer<'r> { self.add(Instruction::Alloc { to: to.clone(), size, + align_shift, }); for (field_name, field_operand) in fields { @@ -666,15 +693,134 @@ impl<'r> Lowerer<'r> { let place = self.new_tmp(); match (op, binop_to_cmp(op, &ty), ty) { + ( + ast::BinOp::Div, + _, + Type::Primitive(Primitive::IpAddr), + ) => { + let function = self.type_info.function(id); + let FunctionDefinition::Runtime(runtime_func) = + function.definition.clone() + else { + panic!() + }; + + let size = self + .type_info + .size_of(&Type::Primitive(Primitive::Prefix)); + let alignment = self.type_info.alignment_of( + &Type::Primitive(Primitive::Prefix), + ); + let align_shift = alignment.ilog2() as u8; + self.add(Instruction::Alloc { + to: place.clone(), + size, + align_shift, + }); + + let ident = Identifier::from("new"); + let ir_func = IrFunction { + name: ident, + ptr: runtime_func.description.pointer(), + params: vec![ + IrType::Pointer, + IrType::Pointer, + IrType::U8, + ], + ret: None, + }; + + self.runtime_functions + .insert(runtime_func.id, ir_func); + + self.add(Instruction::CallRuntime { + to: None, + func: runtime_func, + args: vec![place.clone().into(), left, right], + }); + } + ( + ast::BinOp::Eq, + _, + Type::Primitive(Primitive::IpAddr), + ) => { + let ip_addr_type_id = TypeId::of::(); + let runtime_func = self + .find_runtime_function( + runtime::FunctionKind::Method( + ip_addr_type_id, + ), + "eq", + ) + .clone(); + + let out = self + .call_runtime_function( + Identifier::from("eq"), + &runtime_func, + [left, right], + &[ + Type::Primitive(Primitive::IpAddr), + Type::Primitive(Primitive::IpAddr), + ], + &Type::Primitive(Primitive::Bool), + ) + .unwrap(); + self.add(Instruction::Assign { + to: place.clone(), + val: out, + ty: IrType::Bool, + }) + } + ( + ast::BinOp::Ne, + _, + Type::Primitive(Primitive::IpAddr), + ) => { + let ip_addr_type_id = TypeId::of::(); + let runtime_func = self + .find_runtime_function( + runtime::FunctionKind::Method( + ip_addr_type_id, + ), + "eq", + ) + .clone(); + + let out = self + .call_runtime_function( + Identifier::from("eq"), + &runtime_func, + [left, right], + &[ + Type::Primitive(Primitive::IpAddr), + Type::Primitive(Primitive::IpAddr), + ], + &Type::Primitive(Primitive::Bool), + ) + .unwrap(); + + self.add(Instruction::Not { + to: place.clone(), + val: out, + }) + } (ast::BinOp::Eq, _, ty) if self.is_reference_type(&ty) => { let size = self.type_info.size_of(&ty); + let tmp = self.new_tmp(); self.add(Instruction::MemCmp { + to: tmp.clone(), + size: IrValue::Pointer(size as usize).into(), + left: left.clone(), + right: right.clone(), + }); + self.add(Instruction::Cmp { to: place.clone(), - size, - left, - right, + cmp: ir::IntCmp::Eq, + left: tmp.into(), + right: IrValue::Pointer(0).into(), }) } (ast::BinOp::Ne, _, ty) @@ -684,13 +830,15 @@ impl<'r> Lowerer<'r> { let tmp = self.new_tmp(); self.add(Instruction::MemCmp { to: tmp.clone(), - size, - left, - right, + size: IrValue::Pointer(size as usize).into(), + left: left.clone(), + right: right.clone(), }); - self.add(Instruction::Not { + self.add(Instruction::Cmp { to: place.clone(), - val: tmp.into(), + cmp: ir::IntCmp::Ne, + left: tmp.into(), + right: IrValue::Pointer(0).into(), }) } (_, Some(cmp), _) => { @@ -818,16 +966,37 @@ impl<'r> Lowerer<'r> { } } + fn find_runtime_function( + &self, + kind: runtime::FunctionKind, + name: &str, + ) -> &RuntimeFunction { + self.runtime + .functions + .iter() + .find(|f| f.kind == kind && f.name == name) + .unwrap() + } + /// Lower a literal fn literal(&mut self, lit: &Meta) -> Operand { match &lit.node { Literal::String(_) => todo!(), Literal::Asn(n) => IrValue::Asn(*n).into(), - Literal::IpAddress(addr) => IrValue::IpAddr(match addr { - ast::IpAddress::Ipv4(x) => IpAddr::V4(*x), - ast::IpAddress::Ipv6(x) => IpAddr::V6(*x), - }) - .into(), + Literal::IpAddress(addr) => { + let to = self.new_tmp(); + const SIZE: usize = std::mem::size_of::(); + const ALIGN: usize = std::mem::align_of::(); + let align_shift = ALIGN.ilog2() as u8; + + let x: [u8; SIZE] = unsafe { std::mem::transmute_copy(addr) }; + self.add(Instruction::Initialize { + to: to.clone(), + bytes: x.into(), + align_shift, + }); + to.into() + } Literal::Integer(x) => { let ty = self.type_info.type_of(lit); match ty { @@ -927,6 +1096,7 @@ impl<'r> Lowerer<'r> { | Type::NamedRecord(..) | Type::Enum(..) | Type::Verdict(..) + | Type::Primitive(Primitive::IpAddr | Primitive::Prefix) ) } @@ -943,6 +1113,7 @@ impl<'r> Lowerer<'r> { Type::Primitive(Primitive::I32) => IrType::I32, Type::Primitive(Primitive::I64) => IrType::I64, Type::Primitive(Primitive::Asn) => IrType::Asn, + Type::Primitive(Primitive::IpAddr) => IrType::Pointer, Type::IntVar(_) => IrType::I32, Type::BuiltIn(_, _) => IrType::ExtPointer, x if self.is_reference_type(&x) => IrType::Pointer, @@ -950,43 +1121,41 @@ impl<'r> Lowerer<'r> { } } - fn call_runtime_function<'a>( + fn call_runtime_function( &mut self, - ident: &Meta, - id: MetaId, - func: &typechecker::types::Function, + ident: Identifier, runtime_func: &RuntimeFunction, - args: impl IntoIterator>, + args: impl IntoIterator, + parameter_types: &[Type], + return_type: &Type, ) -> Option { - let ret = &func.signature.return_type; - let ty = self.type_info.type_of(id); let out_ptr = self.new_tmp(); - let size = self.type_info.size_of(ret); + let size = self.type_info.size_of(return_type); self.add(Instruction::Alloc { to: out_ptr.clone(), size, + align_shift: 0, }); - let args = std::iter::once(Operand::Place(out_ptr.clone())) - .chain(args.into_iter().flat_map(|a| self.expr(a))) - .collect(); - let mut params = Vec::new(); params.push(IrType::Pointer); - for ty in &func.signature.parameter_types { + for ty in parameter_types { params.push(self.lower_type(ty)) } + let args = std::iter::once(Operand::Place(out_ptr.clone())) + .chain(args) + .collect(); + let ir_func = IrFunction { - name: ident.node, + name: ident, ptr: runtime_func.description.pointer(), params, ret: None, }; - self.runtime_functions - .insert(ident.as_str().into(), ir_func); + self.runtime_functions.insert(runtime_func.id, ir_func); self.add(Instruction::CallRuntime { to: None, @@ -994,10 +1163,12 @@ impl<'r> Lowerer<'r> { args, }); - if self.is_reference_type(&ty) { + if self.is_reference_type(return_type) { Some(out_ptr.into()) + } else if size > 0 { + Some(self.read_field(out_ptr.into(), 0, return_type)) } else { - Some(self.read_field(out_ptr.into(), 0, &ty)) + None } } } @@ -1010,6 +1181,8 @@ fn binop_to_cmp(op: &ast::BinOp, ty: &Type) -> Option { | Primitive::U16 | Primitive::U8 | Primitive::Asn + | Primitive::IpAddr + | Primitive::Prefix | Primitive::Bool => false, Primitive::I64 | Primitive::I32 diff --git a/src/lower/value.rs b/src/lower/value.rs index 736c9007..b42526f5 100644 --- a/src/lower/value.rs +++ b/src/lower/value.rs @@ -21,7 +21,6 @@ pub enum IrValue { I32(i32), I64(i64), Asn(Asn), - IpAddr(std::net::IpAddr), Pointer(usize), ExtPointer(*mut ()), ExtValue(Vec), @@ -40,7 +39,6 @@ pub enum IrType { I32, I64, Asn, - IpAddr, Pointer, ExtPointer, ExtValue, @@ -55,7 +53,6 @@ impl IrType { U16 | I16 => 2, U32 | I32 | Asn => 4, U64 | I64 => 8, - IpAddr => 4, Pointer | ExtValue | ExtPointer => (usize::BITS / 8) as usize, } } @@ -79,7 +76,6 @@ impl Display for IrType { I32 => "i32", I64 => "i64", Asn => "Asn", - IpAddr => "IpAddr", Pointer => "Pointer", ExtValue => "ExtValue", ExtPointer => "ExtPointer", @@ -100,9 +96,9 @@ impl PartialEq for IrValue { (I16(l), I16(r)) => l == r, (I32(l), I32(r)) => l == r, (Asn(l), Asn(r)) => l == r, + (Pointer(l), Pointer(r)) => l == r, (ExtValue(_), ExtValue(_)) => false, (ExtPointer(_), ExtPointer(_)) => false, - (Pointer(_), Pointer(_)) => panic!("can't compare pointers"), _ => panic!("tried comparing different types"), } } @@ -124,7 +120,6 @@ impl IrValue { I32(_) => IrType::I32, I64(_) => IrType::I64, Asn(_) => IrType::Asn, - IpAddr(_) => IrType::I32, Pointer(_) => IrType::Pointer, ExtValue(_) => IrType::ExtValue, ExtPointer(_) => IrType::ExtPointer, @@ -143,7 +138,6 @@ impl IrValue { Self::I32(x) => x, Self::I64(x) => x, Self::Asn(x) => x, - Self::IpAddr(x) => x, Self::Pointer(x) => x, Self::ExtValue(x) => x, Self::ExtPointer(x) => x, @@ -188,7 +182,6 @@ impl IrValue { Self::I32(x) => x.to_ne_bytes().into(), Self::I64(x) => x.to_ne_bytes().into(), Self::Asn(x) => x.into_u32().to_ne_bytes().into(), - Self::IpAddr(_) => todo!(), Self::Pointer(x) => x.to_ne_bytes().into(), Self::ExtValue(x) => x.clone(), Self::ExtPointer(x) => { @@ -243,16 +236,6 @@ impl IrValue { let val: &[u8; 4] = val.try_into().unwrap(); Self::Asn(Asn::from_u32(u32::from_ne_bytes(*val))) } - IrType::IpAddr => { - let val: &[u8; 32] = val.try_into().unwrap(); - if val[0] == 0 { - let addr: [u8; 4] = val[1..5].try_into().unwrap(); - Self::IpAddr(std::net::IpAddr::from(addr)) - } else { - let addr: [u8; 16] = val[1..17].try_into().unwrap(); - Self::IpAddr(std::net::IpAddr::from(addr)) - } - } IrType::Pointer => { const SIZE: usize = (usize::BITS / 8) as usize; let val: &[u8; SIZE] = val.try_into().unwrap(); @@ -326,7 +309,6 @@ impl Display for IrValue { I32(x) => write!(f, "i32({x})"), I64(x) => write!(f, "i64({x})"), Asn(x) => write!(f, "Asn({x})"), - IpAddr(x) => write!(f, "IpAddr({x})"), Pointer(x) => write!(f, "Pointer({x})"), ExtValue(..) => write!(f, "ExtValue(..)"), ExtPointer(..) => write!(f, "ExtPointer(..)"), diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 20c3d1e8..526cc88b 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -1,9 +1,11 @@ +use std::net::IpAddr; + use inetnum::asn::Asn; use crate::{ ast::{ - BinOp, Block, Expr, IpAddress, Literal, Match, MatchArm, Pattern, - Record, ReturnKind, + BinOp, Block, Expr, Literal, Match, MatchArm, Pattern, Record, + ReturnKind, }, parser::ParseError, }; @@ -387,6 +389,9 @@ impl<'source> Parser<'source, '_> { | Token::Bool(_) | Token::Integer(_) | Token::Hyphen + | Token::IpV4(_) + | Token::IpV6(_) + | Token::Asn(_) ) } @@ -506,19 +511,19 @@ impl<'source> Parser<'source, '_> { self.simple_literal() } - fn ip_address(&mut self) -> ParseResult> { + fn ip_address(&mut self) -> ParseResult> { let (token, span) = self.next()?; let addr = match token { - Token::IpV4(s) => IpAddress::Ipv4( - s.parse::().map_err(|e| { + Token::IpV4(s) => { + IpAddr::V4(s.parse::().map_err(|e| { ParseError::invalid_literal("Ipv4 addresses", s, e, span) - })?, - ), - Token::IpV6(s) => IpAddress::Ipv6( - s.parse::().map_err(|e| { + })?) + } + Token::IpV6(s) => { + IpAddr::V6(s.parse::().map_err(|e| { ParseError::invalid_literal("Ipv6 addresses", s, e, span) - })?, - ), + })?) + } _ => { return Err(ParseError::expected( "an IP address", diff --git a/src/parser/token.rs b/src/parser/token.rs index d0a0ed4f..0bbff4cf 100644 --- a/src/parser/token.rs +++ b/src/parser/token.rs @@ -145,7 +145,7 @@ pub enum Token<'s> { #[regex(r"([0-9]+\.){3}[0-9]+")] IpV4(&'s str), - #[regex(r"([0-9a-zA-Z]*:){2,6}[0-9a-zA-Z]*")] + #[regex(r"[0-9a-zA-Z]*(:[0-9a-zA-Z]*){2,6}")] IpV6(&'s str), // This regex is a super set of all the forms of communities: diff --git a/src/pipeline.rs b/src/pipeline.rs index 7077fe42..b80f5432 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -77,12 +77,13 @@ pub struct TypeChecked { trees: Vec, type_infos: Vec, scope_graph: ScopeGraph, + runtime: Runtime, } /// Compiler stage: HIR pub struct Lowered { pub ir: Vec, - runtime_functions: HashMap, + runtime_functions: HashMap, label_store: LabelStore, type_info: TypeInfo, } @@ -357,6 +358,7 @@ impl Parsed { trees, type_infos, scope_graph, + runtime, }) } else { Err(RotoReport { @@ -374,6 +376,7 @@ impl TypeChecked { trees, mut type_infos, scope_graph, + runtime, } = self; let mut runtime_functions = HashMap::new(); let mut label_store = LabelStore::default(); @@ -382,8 +385,10 @@ impl TypeChecked { &mut type_infos[0], &mut runtime_functions, &mut label_store, + &runtime, ); + let _ = env_logger::try_init(); if log::log_enabled!(log::Level::Info) { let s = IrPrinter { scope_graph: &scope_graph, diff --git a/src/runtime/func.rs b/src/runtime/func.rs index 0515de15..406172a7 100644 --- a/src/runtime/func.rs +++ b/src/runtime/func.rs @@ -84,7 +84,7 @@ pub trait Func: Sized { macro_rules! func_impl { ($($arg:ident),*) => { - impl<$($arg,)* Ret> Func<($($arg,)*), Ret> for extern "C" fn($($arg),*) -> Ret + impl<$($arg,)* Ret> Func<($($arg,)*), Ret> for for<'a> extern "C" fn($($arg),*) -> Ret where $( $arg: Reflect + for<'a> TryFrom<&'a IrValue> + 'static, diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 6c58a108..71ec11b6 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -30,10 +30,11 @@ pub mod func; pub mod ty; pub mod verdict; -use std::any::TypeId; +use std::{any::TypeId, net::IpAddr}; use func::{Func, FunctionDescription}; -use inetnum::asn::Asn; +use inetnum::{addr::Prefix, asn::Asn}; +use roto_macros::roto_method; use ty::{Ty, TypeDescription, TypeRegistry}; /// Provides the types and functions that Roto can access via FFI @@ -71,9 +72,17 @@ pub enum FunctionKind { #[derive(Clone, Debug, PartialEq, Eq)] pub struct RuntimeFunction { + /// Name that the function can be referenced by pub name: String, + + /// Description of the signature of the function pub description: FunctionDescription, + + /// Whether it's a free function, method or a static method pub kind: FunctionKind, + + /// Unique identifier for this function + pub id: usize, } impl Runtime { @@ -187,10 +196,12 @@ impl Runtime { let description = f.to_function_description(self)?; self.check_description(&description)?; + let id = self.functions.len(); self.functions.push(RuntimeFunction { name: name.into(), description, kind: FunctionKind::Free, + id, }); Ok(()) } @@ -229,10 +240,12 @@ impl Runtime { ); } + let id = self.functions.len(); self.functions.push(RuntimeFunction { name: name.into(), description, kind: FunctionKind::Method(std::any::TypeId::of::()), + id, }); Ok(()) @@ -246,10 +259,12 @@ impl Runtime { let description = f.to_function_description(self).unwrap(); self.check_description(&description)?; + let id = self.functions.len(); self.functions.push(RuntimeFunction { name: name.into(), description, kind: FunctionKind::StaticMethod(std::any::TypeId::of::()), + id, }); Ok(()) } @@ -313,6 +328,53 @@ impl Runtime { rt.register_copy_type::()?; rt.register_copy_type::()?; rt.register_copy_type::()?; + rt.register_type::()?; + rt.register_type::()?; + + extern "C" fn prefix_new(out: *mut Prefix, ip: *mut IpAddr, len: u8) { + let ip = unsafe { *ip }; + + let p = Prefix::new(ip, len).unwrap(); + let p = unsafe { + std::mem::transmute::()]>(p) + }; + + let out = out as *mut [u8; std::mem::size_of::()]; + unsafe { + *out = p; + } + } + + rt.register_static_method::( + "new", + prefix_new as extern "C" fn(_, _, _) -> _, + ) + .unwrap(); + + #[roto_method(rt, IpAddr, eq)] + fn ipaddr_eq(a: *const IpAddr, b: *const IpAddr) -> bool { + let a = unsafe { *a }; + let b = unsafe { *b }; + a == b + } + + #[roto_method(rt, IpAddr)] + fn is_ipv4(ip: *const IpAddr) -> bool { + let ip = unsafe { &*ip }; + ip.is_ipv4() + } + + #[roto_method(rt, IpAddr)] + fn is_ipv6(ip: *const IpAddr) -> bool { + let ip = unsafe { &*ip }; + ip.is_ipv6() + } + + #[roto_method(rt, IpAddr)] + fn to_canonical(ip: *const IpAddr) -> IpAddr { + let ip = unsafe { &*ip }; + ip.to_canonical() + } Ok(rt) } @@ -329,26 +391,20 @@ impl Runtime { #[cfg(test)] pub mod tests { - use std::net::IpAddr; - use super::Runtime; - use roto_macros::roto_function; - use routecore::{ - addr::Prefix, - bgp::{ - aspath::{AsPath, HopPath}, - communities::Community, - path_attributes::{ - Aggregator, AtomicAggregate, MultiExitDisc, NextHop, - }, - types::{LocalPref, OriginType}, + use roto_macros::{roto_function, roto_method}; + use routecore::bgp::{ + aspath::{AsPath, HopPath}, + communities::Community, + path_attributes::{ + Aggregator, AtomicAggregate, MultiExitDisc, NextHop, }, + types::{LocalPref, OriginType}, }; pub fn routecore_runtime() -> Result { let mut rt = Runtime::basic()?; - rt.register_type::()?; rt.register_type::()?; rt.register_type::()?; rt.register_type::()?; @@ -356,48 +412,19 @@ pub mod tests { rt.register_type::()?; rt.register_type::()?; rt.register_type::()?; - rt.register_type::()?; rt.register_type::()?; rt.register_type::>>()?; - #[roto_function] + #[roto_function(rt)] fn pow(x: u32, y: u32) -> u32 { x.pow(y) } - rt.register_function("pow", pow)?; - - #[roto_function] + #[roto_method(rt, u32)] fn is_even(x: u32) -> bool { x % 2 == 0 } - rt.register_method::("is_even", is_even)?; - - #[roto_function] - fn is_ipv4(ip: *const IpAddr) -> bool { - let ip = unsafe { &*ip }; - ip.is_ipv4() - } - - rt.register_method::("is_ipv4", is_ipv4)?; - - #[roto_function] - fn is_ipv6(ip: *const IpAddr) -> bool { - let ip = unsafe { &*ip }; - ip.is_ipv6() - } - - rt.register_method::("is_ipv6", is_ipv6)?; - - #[roto_function] - fn to_canonical(ip: *const IpAddr) -> IpAddr { - let ip = unsafe { &*ip }; - ip.to_canonical() - } - - rt.register_method::("to_canonical", to_canonical)?; - Ok(rt) } @@ -422,6 +449,7 @@ pub mod tests { "i64", "Asn", "IpAddr", + "Prefix", "OriginType", "NextHop", "MultiExitDisc", @@ -429,7 +457,6 @@ pub mod tests { "Aggregator", "AtomicAggregate", "Community", - "Prefix", "HopPath", "AsPath" ] diff --git a/src/runtime/ty.rs b/src/runtime/ty.rs index 8ca74b7c..b5727b33 100644 --- a/src/runtime/ty.rs +++ b/src/runtime/ty.rs @@ -11,11 +11,11 @@ use std::{ any::{type_name, TypeId}, collections::HashMap, - ops::DerefMut, + net::IpAddr, sync::{LazyLock, Mutex}, }; -use inetnum::asn::Asn; +use inetnum::{addr::Prefix, asn::Asn}; use super::verdict::Verdict; @@ -73,7 +73,7 @@ impl Ty { pub static GLOBAL_TYPE_REGISTRY: LazyLock> = LazyLock::new(|| Mutex::new(TypeRegistry::default())); -/// A map from TypeId to a [`Ty`], which is a description of the type +/// A map from [`TypeId`] to a [`Ty`], which is a description of the type #[derive(Default)] pub struct TypeRegistry { map: HashMap, @@ -89,21 +89,44 @@ impl TypeRegistry { self.map.get(&id) } + /// Register a type implementing [`Reflect`] pub fn resolve(&mut self) -> Ty { T::resolve(self) } } +/// A type that can register itself into a [`TypeRegistry`]. +/// +/// Via the [`TypeRegistry`], it is then possible to query for information +/// about this type. Reflection is recursive for types such as [`Verdict`], +/// [`Result`] and [`Option`]. +/// +/// Pointers are explicitly _not_ recursive, because they can be used to pass +/// pointers to types that have been registered to Roto and therefore don't +/// need to implement this trait. +/// +/// Additionally, this trait specifies how a type should be passed to Roto, via +/// the `AsParam` associated type. pub trait Reflect: 'static { - fn resolve(registry: &mut TypeRegistry) -> Ty; + /// The type that this type should be converted into when passed to Roto + type AsParam; - fn resolve_global() -> Ty { - let mut reg = GLOBAL_TYPE_REGISTRY.lock().unwrap(); - Self::resolve(reg.deref_mut()) - } + /// Convert the type to its `AsParam` + fn as_param(&mut self) -> Self::AsParam; + + /// Put information about this type into the [`TypeRegistry`] + /// + /// The information is also returned for direct use. + fn resolve(registry: &mut TypeRegistry) -> Ty; } impl Reflect for Verdict { + type AsParam = *mut Self; + + fn as_param(&mut self) -> Self::AsParam { + self as _ + } + fn resolve(registry: &mut TypeRegistry) -> Ty { let t = A::resolve(registry).type_id; let e = R::resolve(registry).type_id; @@ -114,6 +137,12 @@ impl Reflect for Verdict { } impl Reflect for Result { + type AsParam = *mut Self; + + fn as_param(&mut self) -> Self::AsParam { + self as _ + } + fn resolve(registry: &mut TypeRegistry) -> Ty { let t = T::resolve(registry).type_id; let e = E::resolve(registry).type_id; @@ -124,6 +153,12 @@ impl Reflect for Result { } impl Reflect for Option { + type AsParam = *mut Self; + + fn as_param(&mut self) -> Self::AsParam { + self as _ + } + fn resolve(registry: &mut TypeRegistry) -> Ty { let t = T::resolve(registry).type_id; @@ -133,6 +168,12 @@ impl Reflect for Option { } impl Reflect for *mut T { + type AsParam = Self; + + fn as_param(&mut self) -> Self::AsParam { + *self + } + fn resolve(registry: &mut TypeRegistry) -> Ty { let t = registry.store::(TypeDescription::Leaf).type_id; @@ -142,6 +183,12 @@ impl Reflect for *mut T { } impl Reflect for *const T { + type AsParam = Self; + + fn as_param(&mut self) -> Self::AsParam { + *self + } + fn resolve(registry: &mut TypeRegistry) -> Ty { let t = registry.store::(TypeDescription::Leaf).type_id; @@ -150,9 +197,39 @@ impl Reflect for *const T { } } +impl Reflect for IpAddr { + type AsParam = *mut Self; + + fn as_param(&mut self) -> Self::AsParam { + self as _ + } + + fn resolve(registry: &mut TypeRegistry) -> Ty { + registry.store::(TypeDescription::Leaf) + } +} + +impl Reflect for Prefix { + type AsParam = *mut Self; + + fn as_param(&mut self) -> Self::AsParam { + self as _ + } + + fn resolve(registry: &mut TypeRegistry) -> Ty { + registry.store::(TypeDescription::Leaf) + } +} + macro_rules! simple_reflect { ($t:ty) => { impl Reflect for $t { + type AsParam = Self; + + fn as_param(&mut self) -> Self::AsParam { + *self + } + fn resolve(registry: &mut TypeRegistry) -> Ty { registry.store::(TypeDescription::Leaf) } diff --git a/src/typechecker/error.rs b/src/typechecker/error.rs index 59aa8d6b..95d769ff 100644 --- a/src/typechecker/error.rs +++ b/src/typechecker/error.rs @@ -1,3 +1,5 @@ +//! Type errors + use std::fmt::Display; use crate::{ @@ -13,6 +15,7 @@ pub enum Level { Info, } +/// A label is a bit of text attached to a span #[derive(Clone, Debug)] pub struct Label { pub level: Level, @@ -21,6 +24,7 @@ pub struct Label { } impl Label { + /// Create an error label fn error(msg: impl Display, id: MetaId) -> Self { Label { level: Level::Error, @@ -29,6 +33,7 @@ impl Label { } } + /// Create an info label fn info(msg: impl Display, id: MetaId) -> Self { Label { level: Level::Info, @@ -38,6 +43,7 @@ impl Label { } } +/// A type error displayed to the user #[derive(Clone, Debug)] pub struct TypeError { pub description: String, @@ -46,6 +52,7 @@ pub struct TypeError { } impl TypeChecker<'_> { + /// Catch all error with a basic format with just one label pub fn error_simple( &self, description: impl Display, diff --git a/src/typechecker/expr.rs b/src/typechecker/expr.rs index 6478ae41..5862840e 100644 --- a/src/typechecker/expr.rs +++ b/src/typechecker/expr.rs @@ -1,3 +1,5 @@ +//! Type checking of expressions + use std::{borrow::Borrow, collections::HashSet}; use crate::{ @@ -477,7 +479,7 @@ impl TypeChecker<'_> { let t = match lit.node { String(_) => Type::Primitive(Primitive::String), Asn(_) => Type::Primitive(Primitive::Asn), - IpAddress(_) => Type::Name(Identifier::from("IpAddr")), + IpAddress(_) => Type::Primitive(Primitive::IpAddr), Bool(_) => Type::Primitive(Primitive::Bool), Integer(_) => self.fresh_int(), }; @@ -638,6 +640,43 @@ impl TypeChecker<'_> { ) -> TypeResult { use ast::BinOp::*; + // There's a special case: constructing prefixes with `/` + // We do a conservative check on the left hand side to see if it + // could be an ip address. This (hopefully) does not conflict with the + // integer implementation later. + if let Div = op { + let var = self.fresh_var(); + let ctx_left = ctx.with_type(var.clone()); + + let mut diverges = false; + diverges |= self.expr(scope, &ctx_left, left)?; + + let resolved = self.resolve_type(&var); + + if let Type::Primitive(Primitive::IpAddr) = resolved { + let ctx_right = ctx.with_type(Type::Primitive(Primitive::U8)); + diverges |= self.expr(scope, &ctx_right, right)?; + + self.unify( + &ctx.expected_type, + &Type::Primitive(Primitive::Prefix), + span, + None, + )?; + + let name = Identifier::from("new"); + let (function, _sig) = self.find_function( + &FunctionKind::StaticMethod(Type::Primitive( + Primitive::Prefix, + )), + name, + ).unwrap(); + let function = function.clone(); + self.type_info.function_calls.insert(span, function); + return Ok(diverges); + } + }; + match op { And | Or => { self.unify( diff --git a/src/typechecker/info.rs b/src/typechecker/info.rs index cf3fe2a3..c0abd0f8 100644 --- a/src/typechecker/info.rs +++ b/src/typechecker/info.rs @@ -1,10 +1,12 @@ -use std::collections::HashMap; +use std::{collections::HashMap, net::IpAddr}; + +use inetnum::addr::Prefix; use crate::{ast::Identifier, parser::meta::MetaId}; use super::{ scope::{DefinitionRef, ScopeRef}, - types::{Function, Type}, + types::{Function, Primitive, Type}, unionfind::UnionFind, }; @@ -145,6 +147,12 @@ impl TypeInfo { .map(|f| self.alignment_of(f)) .max() .unwrap_or(4), + Type::Primitive(Primitive::IpAddr) => { + std::mem::align_of::() as u32 + } + Type::Primitive(Primitive::Prefix) => { + std::mem::align_of::() as u32 + } ty => self.size_of(&ty), }; // Alignment must be guaranteed to be at least 1 diff --git a/src/typechecker/mod.rs b/src/typechecker/mod.rs index 847a2254..0f28451c 100644 --- a/src/typechecker/mod.rs +++ b/src/typechecker/mod.rs @@ -207,6 +207,7 @@ impl TypeChecker<'_> { name, description, kind, + id: _, } = func; let mut rust_parameters = diff --git a/src/typechecker/types.rs b/src/typechecker/types.rs index 97fdd29b..cd042bcb 100644 --- a/src/typechecker/types.rs +++ b/src/typechecker/types.rs @@ -45,6 +45,8 @@ pub enum Primitive { String, Bool, Asn, + IpAddr, + Prefix, } impl From for Type { @@ -71,6 +73,8 @@ impl Display for Primitive { Primitive::String => "String", Primitive::Bool => "bool", Primitive::Asn => "Asn", + Primitive::IpAddr => "IpAddr", + Primitive::Prefix => "Prefix", } ) } @@ -174,20 +178,16 @@ impl Type { impl Primitive { /// Size of the type in bytes pub fn size(&self) -> u32 { + use Primitive::*; match self { - Primitive::U8 => 1, - Primitive::U16 => 2, - Primitive::U32 => 4, - Primitive::U64 => 8, - Primitive::I8 => 1, - Primitive::I16 => 2, - Primitive::I32 => 4, - Primitive::I64 => 8, - Primitive::Unit => 0, - Primitive::String => 4, - Primitive::Bool => 1, - // Asn has the same size as u32, which is 4 bytes - Primitive::Asn => 4, + U8 | I8 | Bool => 1, + U16 | I16 => 2, + U32 | I32 | Asn => 4, + U64 | I64 => 8, + Unit => 0, + String => 4, + IpAddr => std::mem::size_of::() as u32, + Prefix => std::mem::size_of::() as u32, } } } @@ -287,6 +287,8 @@ pub fn default_types(runtime: &Runtime) -> Vec<(Identifier, Type)> { ("String", String), ("Unit", Unit), ("Asn", Asn), + ("IpAddr", IpAddr), + ("Prefix", Prefix), ]; let mut types = Vec::new();