From 1618eb486626449cbf3dcc858732aa2889edf103 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vid=20Drobni=C4=8D?= Date: Tue, 4 Jun 2024 20:32:49 +0200 Subject: [PATCH] feat: execute if statements --- runtime/src/bytecode.rs | 1 + runtime/src/compiler/mod.rs | 61 ++++++++++++-- runtime/src/compiler/test.rs | 153 +++++++++++++++++++++++++++++++++++ runtime/src/vm/mod.rs | 1 + runtime/src/vm/test.rs | 35 ++++++++ 5 files changed, 246 insertions(+), 5 deletions(-) diff --git a/runtime/src/bytecode.rs b/runtime/src/bytecode.rs index 4b6b96a..0cad457 100644 --- a/runtime/src/bytecode.rs +++ b/runtime/src/bytecode.rs @@ -5,6 +5,7 @@ use parser::position::Range; #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum Instruction { Pop, + Null, Constant(usize), Array(usize), HashMap(usize), diff --git a/runtime/src/compiler/mod.rs b/runtime/src/compiler/mod.rs index 9309de7..b0e4bb6 100644 --- a/runtime/src/compiler/mod.rs +++ b/runtime/src/compiler/mod.rs @@ -7,7 +7,7 @@ use crate::{ }; use parser::{ - ast::{self, PrefixOperatorKind}, + ast::{self, NodeKind, PrefixOperatorKind}, position::Range, }; @@ -112,7 +112,7 @@ impl Compiler { self.compile_node(&index.index)?; self.emit(Instruction::IndexGet, node.range); } - ast::NodeValue::If(_) => todo!(), + ast::NodeValue::If(if_node) => self.compile_if(if_node)?, ast::NodeValue::While(while_loop) => self.compile_while(while_loop)?, ast::NodeValue::For(for_loop) => self.compile_for(for_loop)?, ast::NodeValue::Break => todo!(), @@ -204,6 +204,33 @@ impl Compiler { Ok(()) } + fn compile_if(&mut self, if_node: &ast::IfNode) -> Result<(), Error> { + self.compile_node(&if_node.condition)?; + + // Jump to alternative + let jump_cons = self.emit(Instruction::JumpNotTruthy(0), if_node.condition.range); + + // Compile consequence and add jump to skip alternative + self.compile_block(&if_node.consequence, true)?; + let jump_alt = self.emit(Instruction::Jump(0), if_node.consequence.range); + + // Fix jump after consequence index. + let cons_index = self.current_scope().instructions.len(); + self.current_scope().instructions[jump_cons] = Instruction::JumpNotTruthy(cons_index); + + match &if_node.alternative { + Some(alt) => self.compile_block(alt, true)?, + None => { + self.emit(Instruction::Null, if_node.consequence.range); + } + } + + let alt_index = self.current_scope().instructions.len(); + self.current_scope().instructions[jump_alt] = Instruction::Jump(alt_index); + + Ok(()) + } + fn compile_while(&mut self, while_loop: &ast::While) -> Result<(), Error> { let start_index = self.current_scope().instructions.len(); self.compile_node(&while_loop.condition)?; @@ -211,7 +238,7 @@ impl Compiler { // Jump position will be fixed after let jump_index = self.emit(Instruction::JumpNotTruthy(0), while_loop.condition.range); - self.compile_block(&while_loop.body)?; + self.compile_block(&while_loop.body, false)?; self.emit(Instruction::Jump(start_index), while_loop.body.range); @@ -231,7 +258,7 @@ impl Compiler { let jump_index = self.emit(Instruction::JumpNotTruthy(0), for_loop.condition.range); // Compile the body - self.compile_block(&for_loop.body)?; + self.compile_block(&for_loop.body, false)?; self.compile_node(&for_loop.after)?; if for_loop.after.kind() == ast::NodeKind::Expression { self.emit(Instruction::Pop, for_loop.after.range); @@ -245,7 +272,14 @@ impl Compiler { Ok(()) } - fn compile_block(&mut self, block: &ast::Block) -> Result<(), Error> { + // Compiles block. If emit_last is true, last statement in the block will be left on stack. + // In case value was not pushed in the last node of the block, null will be pushed. + fn compile_block(&mut self, block: &ast::Block, emit_last: bool) -> Result<(), Error> { + if emit_last && block.nodes.is_empty() { + self.emit(Instruction::Null, block.range); + return Ok(()); + } + for node in &block.nodes { self.compile_node(node)?; @@ -254,6 +288,23 @@ impl Compiler { } } + if !emit_last { + return Ok(()); + } + + // We already handled empty block where emit last is true, so it's safe to unwrap. + let last = block.nodes.last().unwrap(); + match last.kind() { + NodeKind::Expression => { + // Remove the `pop` instruction + self.current_scope().instructions.pop(); + self.current_scope().ranges.pop(); + } + NodeKind::Statement => { + self.emit(Instruction::Null, last.range); + } + } + Ok(()) } diff --git a/runtime/src/compiler/test.rs b/runtime/src/compiler/test.rs index 29f6365..1f41181 100644 --- a/runtime/src/compiler/test.rs +++ b/runtime/src/compiler/test.rs @@ -937,3 +937,156 @@ fn for_loop() { assert_eq!(bytecode, expected); } + +#[test] +fn if_statement() { + let tests = [ + ( + "if (true) {}", + Bytecode { + constants: vec![Object::Boolean(true)], + instructions: vec![ + Instruction::Constant(0), + Instruction::JumpNotTruthy(4), + Instruction::Null, + Instruction::Jump(5), + Instruction::Null, + Instruction::Pop, + ], + ranges: vec![ + Range { + start: Position::new(0, 4), + end: Position::new(0, 8), + }, + Range { + start: Position::new(0, 4), + end: Position::new(0, 8), + }, + Range { + start: Position::new(0, 10), + end: Position::new(0, 12), + }, + Range { + start: Position::new(0, 10), + end: Position::new(0, 12), + }, + Range { + start: Position::new(0, 10), + end: Position::new(0, 12), + }, + Range { + start: Position::new(0, 0), + end: Position::new(0, 12), + }, + ], + }, + ), + ( + "if (true) {} else {}", + Bytecode { + constants: vec![Object::Boolean(true)], + instructions: vec![ + Instruction::Constant(0), + Instruction::JumpNotTruthy(4), + Instruction::Null, + Instruction::Jump(5), + Instruction::Null, + Instruction::Pop, + ], + ranges: vec![ + Range { + start: Position::new(0, 4), + end: Position::new(0, 8), + }, + Range { + start: Position::new(0, 4), + end: Position::new(0, 8), + }, + Range { + start: Position::new(0, 10), + end: Position::new(0, 12), + }, + Range { + start: Position::new(0, 10), + end: Position::new(0, 12), + }, + Range { + start: Position::new(0, 18), + end: Position::new(0, 20), + }, + Range { + start: Position::new(0, 0), + end: Position::new(0, 20), + }, + ], + }, + ), + ]; + + for (input, expected) in tests { + let program = parse(input).unwrap(); + let compiler = Compiler::new(); + let bytecode = compiler.compile(&program).unwrap(); + assert_eq!(bytecode, expected); + } + + // Ignore ranges + let tests = [ + ( + "if (true) {a = 0} else {10}", + Bytecode { + constants: vec![ + Object::Boolean(true), + Object::Integer(0), + Object::Integer(10), + ], + instructions: vec![ + Instruction::Constant(0), + Instruction::JumpNotTruthy(6), + Instruction::Constant(1), + Instruction::StoreGlobal(0), + Instruction::Null, + Instruction::Jump(7), + Instruction::Constant(2), + Instruction::Pop, + ], + ranges: vec![], + }, + ), + ( + "if (true) {a = 0} else if (false) {10}", + Bytecode { + constants: vec![ + Object::Boolean(true), + Object::Integer(0), + Object::Boolean(false), + Object::Integer(10), + ], + instructions: vec![ + Instruction::Constant(0), + Instruction::JumpNotTruthy(6), + Instruction::Constant(1), + Instruction::StoreGlobal(0), + Instruction::Null, + Instruction::Jump(11), + Instruction::Constant(2), + Instruction::JumpNotTruthy(10), + Instruction::Constant(3), + Instruction::Jump(11), + Instruction::Null, + Instruction::Pop, + ], + ranges: vec![], + }, + ), + ]; + + for (input, expected) in tests { + let program = parse(input).unwrap(); + let compiler = Compiler::new(); + let mut bytecode = compiler.compile(&program).unwrap(); + bytecode.ranges = vec![]; + + assert_eq!(bytecode, expected); + } +} diff --git a/runtime/src/vm/mod.rs b/runtime/src/vm/mod.rs index dfb96ce..7fa376a 100644 --- a/runtime/src/vm/mod.rs +++ b/runtime/src/vm/mod.rs @@ -83,6 +83,7 @@ impl VirtualMachine { fn execute_instruction(&mut self, ip: usize, bytecode: &Bytecode) -> Result { match bytecode.instructions[ip] { + Instruction::Null => self.push(Object::Null)?, Instruction::Constant(idx) => self.push(bytecode.constants[idx].clone())?, Instruction::Pop => { self.pop(); diff --git a/runtime/src/vm/test.rs b/runtime/src/vm/test.rs index 13093d0..5a49cbc 100644 --- a/runtime/src/vm/test.rs +++ b/runtime/src/vm/test.rs @@ -315,3 +315,38 @@ fn for_loop() { let input = "for (i = 0; i < 42; i = i + 1) {}\n i"; run_test(input, Ok(Object::Integer(42))); } + +#[test] +fn if_statement() { + let tests = [ + ("if (1 < 2) {10}", Object::Integer(10)), + ("if (1 > 2) {10}", Object::Null), + ("if (true) {10} else {}", Object::Integer(10)), + ("if (false) {10} else {20}", Object::Integer(20)), + ("if (false) {10} else if (false) {20}", Object::Null), + ( + "if (false) {10} else if (false) {20} else {30}", + Object::Integer(30), + ), + ( + "if (false) {10} else if (false) {20} else if (true) {30} else {40}", + Object::Integer(30), + ), + ( + r#" + if (1 * 2 * 3 - 5 == 1) { + a = 10 + a = a * 6 + a + 9 + } else { + 42 + } + "#, + Object::Integer(69), + ), + ]; + + for (input, expected) in tests { + run_test(input, Ok(expected)); + } +}