Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move from string_interner to symbol_table #66

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ env_logger = "0.10"
log = "0.4"
logos = "0.14.0"
inetnum = "0.1.0"
string-interner = "0.17.0"
symbol_table = { version = "0.3.0", features = ["global"] }

[dev-dependencies]
bytes = "1"
Expand Down
36 changes: 34 additions & 2 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
//! A [`SyntaxTree`] is the output of the Roto parser. It contains a
//! representation of the Roto script as Rust types for further processing.

use std::fmt::Display;

use inetnum::asn::Asn;
use string_interner::symbol::SymbolU32;
use symbol_table::GlobalSymbol;

use crate::parser::meta::Meta;

Expand Down Expand Up @@ -208,7 +210,37 @@ pub struct OutputStream {
/// It is a word composed of a leading alphabetic Unicode character, followed
/// by alphanumeric Unicode characters or underscore or hyphen.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Identifier(pub SymbolU32);
pub struct Identifier(GlobalSymbol);

impl Display for Identifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

impl Identifier {
pub fn as_str(&self) -> &'static str {
self.0.as_str()
}
}

impl From<&str> for Identifier {
fn from(value: &str) -> Self {
Self(value.into())
}
}

impl From<&String> for Identifier {
fn from(value: &String) -> Self {
Self(value.into())
}
}

impl From<String> for Identifier {
fn from(value: String) -> Self {
Self(value.into())
}
}

#[derive(Clone, Debug)]
pub struct RecordType {
Expand Down
31 changes: 7 additions & 24 deletions src/codegen/check.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use inetnum::asn::Asn;
use string_interner::{backend::StringBackend, StringInterner};

use crate::{
runtime::ty::{
Reflect, TypeDescription, TypeRegistry, GLOBAL_TYPE_REGISTRY,
},
typechecker::{
info::TypeInfo,
types::{type_to_string, Primitive, Type},
types::{Primitive, Type},
},
};
use std::{any::TypeId, fmt::Display, mem::MaybeUninit};
Expand Down Expand Up @@ -59,19 +58,17 @@ impl Display for FunctionRetrievalError {

pub fn check_roto_type_reflect<T: Reflect>(
type_info: &mut TypeInfo,
identifiers: &StringInterner<StringBackend>,
roto_ty: &Type,
) -> Result<(), TypeMismatch> {
let mut registry = GLOBAL_TYPE_REGISTRY.lock().unwrap();
let rust_ty = registry.resolve::<T>().type_id;
check_roto_type(&registry, type_info, identifiers, rust_ty, roto_ty)
check_roto_type(&registry, type_info, rust_ty, roto_ty)
}

#[allow(non_snake_case)]
fn check_roto_type(
registry: &TypeRegistry,
type_info: &mut TypeInfo,
identifiers: &StringInterner<StringBackend>,
rust_ty: TypeId,
roto_ty: &Type,
) -> Result<(), TypeMismatch> {
Expand All @@ -91,13 +88,13 @@ fn check_roto_type(
let Some(rust_ty) = registry.get(rust_ty) else {
return Err(TypeMismatch {
rust_ty: "unknown".into(),
roto_ty: type_to_string(identifiers, roto_ty),
roto_ty: roto_ty.to_string(),
});
};

let error_message = TypeMismatch {
rust_ty: rust_ty.rust_name.to_string(),
roto_ty: type_to_string(identifiers, roto_ty),
roto_ty: roto_ty.to_string(),
};

let mut roto_ty = type_info.resolve(roto_ty);
Expand Down Expand Up @@ -134,20 +131,8 @@ fn check_roto_type(
let Type::Verdict(roto_accept, roto_reject) = &roto_ty else {
return Err(error_message);
};
check_roto_type(
registry,
type_info,
identifiers,
rust_accept,
roto_accept,
)?;
check_roto_type(
registry,
type_info,
identifiers,
rust_reject,
roto_reject,
)?;
check_roto_type(registry, type_info, rust_accept, roto_accept)?;
check_roto_type(registry, type_info, rust_reject, roto_reject)?;
Ok(())
}
// We don't do options and results, we should hint towards verdict
Expand All @@ -173,7 +158,6 @@ pub fn return_type_by_ref(registry: &TypeRegistry, rust_ty: TypeId) -> bool {
pub trait RotoParams {
fn check(
type_info: &mut TypeInfo,
identifiers: &StringInterner<StringBackend>,
ty: &[Type],
) -> Result<(), FunctionRetrievalError>;

Expand Down Expand Up @@ -201,7 +185,6 @@ macro_rules! params {
{
fn check(
type_info: &mut TypeInfo,
identifiers: &StringInterner<StringBackend>,
ty: &[Type]
) -> Result<(), FunctionRetrievalError> {
let [$($t),*] = ty else {
Expand All @@ -216,7 +199,7 @@ macro_rules! params {
let mut i = 0;
$(
i += 1;
check_roto_type_reflect::<$t>(type_info, identifiers, $t)
check_roto_type_reflect::<$t>(type_info, $t)
.map_err(|e| FunctionRetrievalError::TypeMismatch(format!("argument {i}"), e))?;
)*
Ok(())
Expand Down
33 changes: 9 additions & 24 deletions src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ use cranelift::{
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{FuncId, Linkage, Module as _};
use log::info;
use string_interner::{backend::StringBackend, StringInterner};

pub mod check;
#[cfg(test)]
Expand Down Expand Up @@ -148,7 +147,7 @@ pub struct FunctionInfo {
signature: types::Signature,
}

struct ModuleBuilder<'a> {
struct ModuleBuilder {
/// The set of public functions and their signatures.
functions: HashMap<String, FunctionInfo>,

Expand All @@ -167,18 +166,15 @@ struct ModuleBuilder<'a> {
/// Instruction set architecture
isa: Arc<dyn TargetIsa>,

/// Identifiers are used for debugging and resolving function names.
identifiers: &'a StringInterner<StringBackend>,

/// To print labels for debugging.
#[allow(unused)]
label_store: LabelStore,

type_info: TypeInfo,
}

struct FuncGen<'a, 'c> {
module: &'c mut ModuleBuilder<'a>,
struct FuncGen<'c> {
module: &'c mut ModuleBuilder,

/// The cranelift function builder
builder: FunctionBuilder<'c>,
Expand All @@ -198,7 +194,6 @@ const MEMFLAGS: MemFlags = MemFlags::new().with_aligned();
pub fn codegen(
ir: &[ir::Function],
runtime_functions: &HashMap<String, IrFunction>,
identifiers: &StringInterner<StringBackend>,
label_store: LabelStore,
type_info: TypeInfo,
) -> Module {
Expand Down Expand Up @@ -228,7 +223,6 @@ pub fn codegen(
inner: jit,
isa,
variable_map: HashMap::new(),
identifiers,
label_store,
type_info,
};
Expand Down Expand Up @@ -259,13 +253,12 @@ pub fn codegen(
let mut builder_context = FunctionBuilderContext::new();
for func in ir {
module.define_function(func, &mut builder_context);
// info!("\n{}", func);
}

module.finalize()
}

impl ModuleBuilder<'_> {
impl ModuleBuilder {
/// Declare a function and its signature (without the body)
fn declare_function(&mut self, func: &ir::Function) {
let ir::Function {
Expand All @@ -286,11 +279,10 @@ impl ModuleBuilder<'_> {
None => Vec::new(),
};

let name = self.identifiers.resolve(name.0).unwrap();
let func_id = self
.inner
.declare_function(
name,
name.as_str(),
if *public {
Linkage::Export
} else {
Expand Down Expand Up @@ -324,8 +316,7 @@ impl ModuleBuilder<'_> {
scope,
..
} = func;
let name = self.identifiers.resolve(name.0).unwrap();
let func_id = self.functions[name].id;
let func_id = self.functions[name.as_str()].id;

let mut ctx = self.inner.make_context();
let mut sig = self.inner.make_signature();
Expand Down Expand Up @@ -403,7 +394,7 @@ impl ModuleBuilder<'_> {
}
}

impl<'a, 'c> FuncGen<'a, 'c> {
impl<'c> FuncGen<'c> {
fn finalize(self) {
self.builder.finalize()
}
Expand Down Expand Up @@ -511,7 +502,7 @@ impl<'a, 'c> FuncGen<'a, 'c> {
self.def(var, val)
}
ir::Instruction::Call { to, func, args } => {
let func = self.module.identifiers.resolve(func.0).unwrap();
let func = func.as_str();
let func_id = self.module.functions[func].id;
let func_ref = self
.module
Expand Down Expand Up @@ -793,7 +784,6 @@ impl<'a, 'c> FuncGen<'a, 'c> {
impl Module {
pub fn get_function<Params: RotoParams, Return: Reflect>(
&mut self,
identifiers: &StringInterner<StringBackend>,
name: &str,
) -> Result<TypedFunc<Params, Return>, FunctionRetrievalError> {
let function_info = self.functions.get(name).ok_or_else(|| {
Expand All @@ -806,15 +796,10 @@ impl Module {
let sig = &function_info.signature;
let id = function_info.id;

Params::check(
&mut self.type_info,
identifiers,
&sig.parameter_types,
)?;
Params::check(&mut self.type_info, &sig.parameter_types)?;

check_roto_type_reflect::<Return>(
&mut self.type_info,
identifiers,
&sig.return_type,
)
.map_err(|e| {
Expand Down
4 changes: 1 addition & 3 deletions src/lower/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use crate::{
runtime::RuntimeFunction,
};
use std::collections::HashMap;
use string_interner::{backend::StringBackend, StringInterner};

/// Memory for the IR evaluation
///
Expand Down Expand Up @@ -232,9 +231,8 @@ pub fn eval(
filter_map: &str,
mem: &mut Memory,
rx: Vec<IrValue>,
identifiers: &StringInterner<StringBackend>,
) -> Option<IrValue> {
let filter_map_ident = Identifier(identifiers.get(filter_map).unwrap());
let filter_map_ident = Identifier::from(filter_map);
let f = p
.iter()
.find(|f| f.name == filter_map_ident)
Expand Down
7 changes: 2 additions & 5 deletions src/lower/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@

use std::fmt::Display;

use string_interner::{backend::StringBackend, StringInterner};

use crate::{
ast::Identifier,
runtime,
Expand Down Expand Up @@ -278,17 +276,16 @@ pub struct Block {

pub struct IrPrinter<'a> {
pub scope_graph: &'a ScopeGraph,
pub identifiers: &'a StringInterner<StringBackend>,
pub label_store: &'a LabelStore,
}

impl<'a> IrPrinter<'a> {
pub fn ident(&self, ident: &Identifier) -> &'a str {
self.identifiers.resolve(ident.0).unwrap()
ident.as_str()
}

pub fn scope(&self, scope: ScopeRef) -> String {
self.scope_graph.print_scope(scope, self.identifiers)
self.scope_graph.print_scope(scope)
}

pub fn var(&self, var: &Var) -> String {
Expand Down
Loading
Loading