Implement functions, both top-level and anonymous
Implement both top-level and anonymous functions, but not closures in either case.
This commit is contained in:
		
							parent
							
								
									80f8ede0bb
								
							
						
					
					
						commit
						1ea2d8ba9f
					
				
					 10 changed files with 503 additions and 127 deletions
				
			
		
							
								
								
									
										10
									
								
								Cargo.lock
									
										
									
										generated
									
									
									
								
							
							
						
						
									
										10
									
								
								Cargo.lock
									
										
									
										generated
									
									
									
								
							|  | @ -8,6 +8,7 @@ dependencies = [ | |||
|  "clap", | ||||
|  "derive_more", | ||||
|  "inkwell", | ||||
|  "itertools", | ||||
|  "llvm-sys", | ||||
|  "nom", | ||||
|  "nom-trace", | ||||
|  | @ -245,6 +246,15 @@ dependencies = [ | |||
|  "cfg-if", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "itertools" | ||||
| version = "0.10.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "37d572918e350e82412fe766d24b15e6682fb2ed2bbe018280caa810397cb319" | ||||
| dependencies = [ | ||||
|  "either", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "lazy_static" | ||||
| version = "1.4.0" | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ anyhow = "1.0.38" | |||
| clap = "3.0.0-beta.2" | ||||
| derive_more = "0.99.11" | ||||
| inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } | ||||
| itertools = "0.10.0" | ||||
| llvm-sys = "110.0.1" | ||||
| nom = "6.1.2" | ||||
| nom-trace = { git = "https://github.com/glittershark/nom-trace", branch = "nom-6" } | ||||
|  |  | |||
|  | @ -2,10 +2,12 @@ use std::borrow::Cow; | |||
| use std::convert::TryFrom; | ||||
| use std::fmt::{self, Display, Formatter}; | ||||
| 
 | ||||
| use itertools::Itertools; | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| pub struct InvalidIdentifier<'a>(Cow<'a, str>); | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq, Hash)] | ||||
| #[derive(Debug, PartialEq, Eq, Hash, Clone)] | ||||
| pub struct Ident<'a>(pub Cow<'a, str>); | ||||
| 
 | ||||
| impl<'a> From<&'a Ident<'a>> for &'a str { | ||||
|  | @ -69,7 +71,7 @@ impl<'a> TryFrom<String> for Ident<'a> { | |||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||||
| pub enum BinaryOperator { | ||||
|     /// `+`
 | ||||
|     Add, | ||||
|  | @ -93,7 +95,7 @@ pub enum BinaryOperator { | |||
|     Neq, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||||
| pub enum UnaryOperator { | ||||
|     /// !
 | ||||
|     Not, | ||||
|  | @ -102,12 +104,12 @@ pub enum UnaryOperator { | |||
|     Neg, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| #[derive(Debug, PartialEq, Eq, Clone)] | ||||
| pub enum Literal { | ||||
|     Int(u64), | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| #[derive(Debug, PartialEq, Eq, Clone)] | ||||
| pub enum Expr<'a> { | ||||
|     Ident(Ident<'a>), | ||||
| 
 | ||||
|  | @ -134,33 +136,101 @@ pub enum Expr<'a> { | |||
|         then: Box<Expr<'a>>, | ||||
|         else_: Box<Expr<'a>>, | ||||
|     }, | ||||
| 
 | ||||
|     Fun(Box<Fun<'a>>), | ||||
| 
 | ||||
|     Call { | ||||
|         fun: Box<Expr<'a>>, | ||||
|         args: Vec<Expr<'a>>, | ||||
|     }, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| impl<'a> Expr<'a> { | ||||
|     pub fn to_owned(&self) -> Expr<'static> { | ||||
|         match self { | ||||
|             Expr::Ident(ref id) => Expr::Ident(id.to_owned()), | ||||
|             Expr::Literal(ref lit) => Expr::Literal(lit.clone()), | ||||
|             Expr::UnaryOp { op, rhs } => Expr::UnaryOp { | ||||
|                 op: *op, | ||||
|                 rhs: Box::new((**rhs).to_owned()), | ||||
|             }, | ||||
|             Expr::BinaryOp { lhs, op, rhs } => Expr::BinaryOp { | ||||
|                 lhs: Box::new((**lhs).to_owned()), | ||||
|                 op: *op, | ||||
|                 rhs: Box::new((**rhs).to_owned()), | ||||
|             }, | ||||
|             Expr::Let { bindings, body } => Expr::Let { | ||||
|                 bindings: bindings | ||||
|                     .iter() | ||||
|                     .map(|(id, expr)| (id.to_owned(), expr.to_owned())) | ||||
|                     .collect(), | ||||
|                 body: Box::new((**body).to_owned()), | ||||
|             }, | ||||
|             Expr::If { | ||||
|                 condition, | ||||
|                 then, | ||||
|                 else_, | ||||
|             } => Expr::If { | ||||
|                 condition: Box::new((**condition).to_owned()), | ||||
|                 then: Box::new((**then).to_owned()), | ||||
|                 else_: Box::new((**else_).to_owned()), | ||||
|             }, | ||||
|             Expr::Fun(fun) => Expr::Fun(Box::new((**fun).to_owned())), | ||||
|             Expr::Call { fun, args } => Expr::Call { | ||||
|                 fun: Box::new((**fun).to_owned()), | ||||
|                 args: args.iter().map(|arg| arg.to_owned()).collect(), | ||||
|             }, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq, Clone)] | ||||
| pub struct Fun<'a> { | ||||
|     pub name: Ident<'a>, | ||||
|     pub args: Vec<Ident<'a>>, | ||||
|     pub body: Expr<'a>, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| pub enum Decl<'a> { | ||||
|     Fun(Fun<'a>), | ||||
| impl<'a> Fun<'a> { | ||||
|     fn to_owned(&self) -> Fun<'static> { | ||||
|         Fun { | ||||
|             args: self.args.iter().map(|arg| arg.to_owned()).collect(), | ||||
|             body: self.body.to_owned(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq, Clone, Copy)] | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| pub enum Decl<'a> { | ||||
|     Fun { name: Ident<'a>, body: Fun<'a> }, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq, Clone)] | ||||
| pub struct FunctionType { | ||||
|     pub args: Vec<Type>, | ||||
|     pub ret: Box<Type>, | ||||
| } | ||||
| 
 | ||||
| impl Display for FunctionType { | ||||
|     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { | ||||
|         write!(f, "fn {} -> {}", self.args.iter().join(", "), self.ret) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Eq, Clone)] | ||||
| pub enum Type { | ||||
|     Int, | ||||
|     Float, | ||||
|     Bool, | ||||
|     Function(FunctionType), | ||||
| } | ||||
| 
 | ||||
| impl Display for Type { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         match self { | ||||
|             Self::Int => f.write_str("int"), | ||||
|             Self::Float => f.write_str("float"), | ||||
|             Self::Bool => f.write_str("bool"), | ||||
|             Type::Int => f.write_str("int"), | ||||
|             Type::Float => f.write_str("float"), | ||||
|             Type::Bool => f.write_str("bool"), | ||||
|             Type::Function(ft) => ft.fmt(f), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -1,3 +1,4 @@ | |||
| use std::convert::{TryFrom, TryInto}; | ||||
| use std::path::Path; | ||||
| use std::result; | ||||
| 
 | ||||
|  | @ -7,7 +8,7 @@ pub use inkwell::context::Context; | |||
| use inkwell::module::Module; | ||||
| use inkwell::support::LLVMString; | ||||
| use inkwell::types::FunctionType; | ||||
| use inkwell::values::{BasicValueEnum, FunctionValue}; | ||||
| use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue}; | ||||
| use inkwell::IntPredicate; | ||||
| use thiserror::Error; | ||||
| 
 | ||||
|  | @ -35,8 +36,9 @@ pub struct Codegen<'ctx, 'ast> { | |||
|     context: &'ctx Context, | ||||
|     pub module: Module<'ctx>, | ||||
|     builder: Builder<'ctx>, | ||||
|     env: Env<'ast, BasicValueEnum<'ctx>>, | ||||
|     function: Option<FunctionValue<'ctx>>, | ||||
|     env: Env<'ast, AnyValueEnum<'ctx>>, | ||||
|     function_stack: Vec<FunctionValue<'ctx>>, | ||||
|     identifier_counter: u32, | ||||
| } | ||||
| 
 | ||||
| impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | ||||
|  | @ -48,7 +50,8 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
|             module, | ||||
|             builder, | ||||
|             env: Default::default(), | ||||
|             function: None, | ||||
|             function_stack: Default::default(), | ||||
|             identifier_counter: 0, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  | @ -57,22 +60,24 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
|         name: &str, | ||||
|         ty: FunctionType<'ctx>, | ||||
|     ) -> &'a FunctionValue<'ctx> { | ||||
|         self.function = Some(self.module.add_function(name, ty, None)); | ||||
|         self.function_stack | ||||
|             .push(self.module.add_function(name, ty, None)); | ||||
|         let basic_block = self.append_basic_block("entry"); | ||||
|         self.builder.position_at_end(basic_block); | ||||
|         self.function.as_ref().unwrap() | ||||
|         self.function_stack.last().unwrap() | ||||
|     } | ||||
| 
 | ||||
|     pub fn finish_function(&self, res: &BasicValueEnum<'ctx>) { | ||||
|     pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> { | ||||
|         self.builder.build_return(Some(res)); | ||||
|         self.function_stack.pop().unwrap() | ||||
|     } | ||||
| 
 | ||||
|     pub fn append_basic_block(&self, name: &str) -> BasicBlock<'ctx> { | ||||
|         self.context | ||||
|             .append_basic_block(self.function.unwrap(), name) | ||||
|             .append_basic_block(*self.function_stack.last().unwrap(), name) | ||||
|     } | ||||
| 
 | ||||
|     pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<BasicValueEnum<'ctx>> { | ||||
|     pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> { | ||||
|         match expr { | ||||
|             Expr::Ident(id) => self | ||||
|                 .env | ||||
|  | @ -81,13 +86,13 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
|                 .ok_or_else(|| Error::UndefinedVariable(id.to_owned())), | ||||
|             Expr::Literal(Literal::Int(i)) => { | ||||
|                 let ty = self.context.i64_type(); | ||||
|                 Ok(BasicValueEnum::IntValue(ty.const_int(*i, false))) | ||||
|                 Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))) | ||||
|             } | ||||
|             Expr::UnaryOp { op, rhs } => { | ||||
|                 let rhs = self.codegen_expr(rhs)?; | ||||
|                 match op { | ||||
|                     UnaryOperator::Not => unimplemented!(), | ||||
|                     UnaryOperator::Neg => Ok(BasicValueEnum::IntValue( | ||||
|                     UnaryOperator::Neg => Ok(AnyValueEnum::IntValue( | ||||
|                         self.builder.build_int_neg(rhs.into_int_value(), "neg"), | ||||
|                     )), | ||||
|                 } | ||||
|  | @ -96,29 +101,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
|                 let lhs = self.codegen_expr(lhs)?; | ||||
|                 let rhs = self.codegen_expr(rhs)?; | ||||
|                 match op { | ||||
|                     BinaryOperator::Add => { | ||||
|                         Ok(BasicValueEnum::IntValue(self.builder.build_int_add( | ||||
|                     BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add( | ||||
|                         lhs.into_int_value(), | ||||
|                         rhs.into_int_value(), | ||||
|                         "add", | ||||
|                         ))) | ||||
|                     } | ||||
|                     BinaryOperator::Sub => { | ||||
|                         Ok(BasicValueEnum::IntValue(self.builder.build_int_sub( | ||||
|                     ))), | ||||
|                     BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub( | ||||
|                         lhs.into_int_value(), | ||||
|                         rhs.into_int_value(), | ||||
|                         "add", | ||||
|                         ))) | ||||
|                     } | ||||
|                     BinaryOperator::Mul => { | ||||
|                         Ok(BasicValueEnum::IntValue(self.builder.build_int_sub( | ||||
|                     ))), | ||||
|                     BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub( | ||||
|                         lhs.into_int_value(), | ||||
|                         rhs.into_int_value(), | ||||
|                         "add", | ||||
|                         ))) | ||||
|                     } | ||||
|                     ))), | ||||
|                     BinaryOperator::Div => { | ||||
|                         Ok(BasicValueEnum::IntValue(self.builder.build_int_signed_div( | ||||
|                         Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div( | ||||
|                             lhs.into_int_value(), | ||||
|                             rhs.into_int_value(), | ||||
|                             "add", | ||||
|  | @ -126,7 +125,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
|                     } | ||||
|                     BinaryOperator::Pow => unimplemented!(), | ||||
|                     BinaryOperator::Equ => { | ||||
|                         Ok(BasicValueEnum::IntValue(self.builder.build_int_compare( | ||||
|                         Ok(AnyValueEnum::IntValue(self.builder.build_int_compare( | ||||
|                             IntPredicate::EQ, | ||||
|                             lhs.into_int_value(), | ||||
|                             rhs.into_int_value(), | ||||
|  | @ -170,18 +169,56 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
| 
 | ||||
|                 self.builder.position_at_end(join_block); | ||||
|                 let phi = self.builder.build_phi(self.context.i64_type(), "join"); | ||||
|                 phi.add_incoming(&[(&then_res, then_block), (&else_res, else_block)]); | ||||
|                 Ok(phi.as_basic_value()) | ||||
|                 phi.add_incoming(&[ | ||||
|                     (&BasicValueEnum::try_from(then_res).unwrap(), then_block), | ||||
|                     (&BasicValueEnum::try_from(else_res).unwrap(), else_block), | ||||
|                 ]); | ||||
|                 Ok(phi.as_basic_value().into()) | ||||
|             } | ||||
|             Expr::Call { fun, args } => { | ||||
|                 if let Expr::Ident(id) = &**fun { | ||||
|                     let function = self | ||||
|                         .module | ||||
|                         .get_function(id.into()) | ||||
|                         .or_else(|| self.env.resolve(id)?.clone().try_into().ok()) | ||||
|                         .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?; | ||||
|                     let args = args | ||||
|                         .iter() | ||||
|                         .map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap())) | ||||
|                         .collect::<Result<Vec<_>>>()?; | ||||
|                     Ok(self | ||||
|                         .builder | ||||
|                         .build_call(function, &args, "call") | ||||
|                         .try_as_basic_value() | ||||
|                         .left() | ||||
|                         .unwrap() | ||||
|                         .into()) | ||||
|                 } else { | ||||
|                     todo!() | ||||
|                 } | ||||
|             } | ||||
|             Expr::Fun(fun) => { | ||||
|                 let Fun { args, body } = &**fun; | ||||
|                 let fname = self.fresh_ident("f"); | ||||
|                 let cur_block = self.builder.get_insert_block().unwrap(); | ||||
|                 let env = self.env.save(); // TODO: closures
 | ||||
|                 let function = self.codegen_function(&fname, args, body)?; | ||||
|                 self.builder.position_at_end(cur_block); | ||||
|                 self.env.restore(env); | ||||
|                 Ok(function.into()) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> { | ||||
|         match decl { | ||||
|             Decl::Fun(Fun { name, args, body }) => { | ||||
|     pub fn codegen_function( | ||||
|         &mut self, | ||||
|         name: &str, | ||||
|         args: &'ast [Ident<'ast>], | ||||
|         body: &'ast Expr<'ast>, | ||||
|     ) -> Result<FunctionValue<'ctx>> { | ||||
|         let i64_type = self.context.i64_type(); | ||||
|         self.new_function( | ||||
|                     name.into(), | ||||
|             name, | ||||
|             i64_type.fn_type( | ||||
|                 args.iter() | ||||
|                     .map(|_| i64_type.into()) | ||||
|  | @ -192,12 +229,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
|         ); | ||||
|         self.env.push(); | ||||
|         for (i, arg) in args.iter().enumerate() { | ||||
|                     self.env | ||||
|                         .set(arg, self.function.unwrap().get_nth_param(i as u32).unwrap()); | ||||
|             self.env.set( | ||||
|                 arg, | ||||
|                 self.cur_function().get_nth_param(i as u32).unwrap().into(), | ||||
|             ); | ||||
|         } | ||||
|                 let res = self.codegen_expr(body)?; | ||||
|         let res = self.codegen_expr(body)?.try_into().unwrap(); | ||||
|         self.env.pop(); | ||||
|                 self.finish_function(&res); | ||||
|         Ok(self.finish_function(&res)) | ||||
|     } | ||||
| 
 | ||||
|     pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> { | ||||
|         match decl { | ||||
|             Decl::Fun { | ||||
|                 name, | ||||
|                 body: Fun { args, body }, | ||||
|             } => { | ||||
|                 self.codegen_function(name.into(), args, body)?; | ||||
|                 Ok(()) | ||||
|             } | ||||
|         } | ||||
|  | @ -205,7 +253,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
| 
 | ||||
|     pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> { | ||||
|         self.new_function("main", self.context.i64_type().fn_type(&[], false)); | ||||
|         let res = self.codegen_expr(expr)?; | ||||
|         let res = self.codegen_expr(expr)?.try_into().unwrap(); | ||||
|         self.finish_function(&res); | ||||
|         Ok(()) | ||||
|     } | ||||
|  | @ -229,6 +277,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { | |||
|             )) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn fresh_ident(&mut self, prefix: &str) -> String { | ||||
|         self.identifier_counter += 1; | ||||
|         format!("{}{}", prefix, self.identifier_counter) | ||||
|     } | ||||
| 
 | ||||
|     fn cur_function(&self) -> &FunctionValue<'ctx> { | ||||
|         self.function_stack.last().unwrap() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
|  | @ -248,9 +305,7 @@ mod tests { | |||
|             .create_jit_execution_engine(OptimizationLevel::None) | ||||
|             .unwrap(); | ||||
| 
 | ||||
|         codegen.new_function("test", context.i64_type().fn_type(&[], false)); | ||||
|         let res = codegen.codegen_expr(&expr)?; | ||||
|         codegen.finish_function(&res); | ||||
|         codegen.codegen_function("test", &[], &expr)?; | ||||
| 
 | ||||
|         unsafe { | ||||
|             let fun: JitFunction<unsafe extern "C" fn() -> T> = | ||||
|  | @ -279,4 +334,10 @@ mod tests { | |||
|             2 | ||||
|         ); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn function_call() { | ||||
|         let res = jit_eval::<i64>("let id = fn x = x in id 1").unwrap(); | ||||
|         assert_eq!(res, 1); | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -14,9 +14,7 @@ pub fn jit_eval<T>(expr: &Expr) -> Result<T> { | |||
|         .module | ||||
|         .create_jit_execution_engine(OptimizationLevel::None) | ||||
|         .map_err(Error::from)?; | ||||
|     codegen.new_function("eval", context.i64_type().fn_type(&[], false)); | ||||
|     let res = codegen.codegen_expr(&expr)?; | ||||
|     codegen.finish_function(&res); | ||||
|     codegen.codegen_function("test", &[], &expr)?; | ||||
| 
 | ||||
|     unsafe { | ||||
|         let fun: JitFunction<unsafe extern "C" fn() -> T> = | ||||
|  |  | |||
|  | @ -5,10 +5,13 @@ use clap::Clap; | |||
| use crate::common::Result; | ||||
| use crate::compiler::{self, CompilerOptions}; | ||||
| 
 | ||||
| /// Compile a source file
 | ||||
| #[derive(Clap)] | ||||
| pub struct Compile { | ||||
|     /// File to compile
 | ||||
|     file: PathBuf, | ||||
| 
 | ||||
|     /// Output file
 | ||||
|     #[clap(short = 'o')] | ||||
|     out_file: PathBuf, | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| use std::collections::HashMap; | ||||
| use std::mem; | ||||
| 
 | ||||
| use crate::ast::Ident; | ||||
| 
 | ||||
|  | @ -25,6 +26,14 @@ impl<'ast, V> Env<'ast, V> { | |||
|         self.0.pop(); | ||||
|     } | ||||
| 
 | ||||
|     pub fn save(&mut self) -> Self { | ||||
|         mem::take(self) | ||||
|     } | ||||
| 
 | ||||
|     pub fn restore(&mut self, saved: Self) { | ||||
|         *self = saved; | ||||
|     } | ||||
| 
 | ||||
|     pub fn set(&mut self, k: &'ast Ident<'ast>, v: V) { | ||||
|         self.0.last_mut().unwrap().insert(k, v); | ||||
|     } | ||||
|  |  | |||
|  | @ -2,13 +2,13 @@ mod error; | |||
| mod value; | ||||
| 
 | ||||
| pub use self::error::{Error, Result}; | ||||
| pub use self::value::Value; | ||||
| use crate::ast::{BinaryOperator, Expr, Ident, Literal, UnaryOperator}; | ||||
| pub use self::value::{Function, Value}; | ||||
| use crate::ast::{BinaryOperator, Expr, FunctionType, Ident, Literal, Type, UnaryOperator}; | ||||
| use crate::common::env::Env; | ||||
| 
 | ||||
| #[derive(Debug, Default)] | ||||
| pub struct Interpreter<'a> { | ||||
|     env: Env<'a, Value>, | ||||
|     env: Env<'a, Value<'a>>, | ||||
| } | ||||
| 
 | ||||
| impl<'a> Interpreter<'a> { | ||||
|  | @ -16,14 +16,14 @@ impl<'a> Interpreter<'a> { | |||
|         Self::default() | ||||
|     } | ||||
| 
 | ||||
|     fn resolve(&self, var: &'a Ident<'a>) -> Result<Value> { | ||||
|     fn resolve(&self, var: &'a Ident<'a>) -> Result<Value<'a>> { | ||||
|         self.env | ||||
|             .resolve(var) | ||||
|             .cloned() | ||||
|             .ok_or_else(|| Error::UndefinedVariable(var.to_owned())) | ||||
|     } | ||||
| 
 | ||||
|     pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value> { | ||||
|     pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> { | ||||
|         match expr { | ||||
|             Expr::Ident(id) => self.resolve(id), | ||||
|             Expr::Literal(Literal::Int(i)) => Ok((*i).into()), | ||||
|  | @ -63,12 +63,44 @@ impl<'a> Interpreter<'a> { | |||
|                 else_, | ||||
|             } => { | ||||
|                 let condition = self.eval(condition)?; | ||||
|                 if *(condition.into_type::<bool>()?) { | ||||
|                 if *(condition.as_type::<bool>()?) { | ||||
|                     self.eval(then) | ||||
|                 } else { | ||||
|                     self.eval(else_) | ||||
|                 } | ||||
|             } | ||||
|             Expr::Call { ref fun, args } => { | ||||
|                 let fun = self.eval(fun)?; | ||||
|                 let expected_type = FunctionType { | ||||
|                     args: args.iter().map(|_| Type::Int).collect(), | ||||
|                     ret: Box::new(Type::Int), | ||||
|                 }; | ||||
| 
 | ||||
|                 let Function { | ||||
|                     args: function_args, | ||||
|                     body, | ||||
|                     .. | ||||
|                 } = fun.as_function(expected_type)?; | ||||
|                 let arg_values = function_args.iter().zip( | ||||
|                     args.iter() | ||||
|                         .map(|v| self.eval(v)) | ||||
|                         .collect::<Result<Vec<_>>>()?, | ||||
|                 ); | ||||
|                 let mut interpreter = Interpreter::new(); | ||||
|                 for (arg_name, arg_value) in arg_values { | ||||
|                     interpreter.env.set(arg_name, arg_value); | ||||
|                 } | ||||
|                 Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?)) | ||||
|             } | ||||
|             Expr::Fun(fun) => Ok(Value::from(value::Function { | ||||
|                 // TODO
 | ||||
|                 type_: FunctionType { | ||||
|                     args: fun.args.iter().map(|_| Type::Int).collect(), | ||||
|                     ret: Box::new(Type::Int), | ||||
|                 }, | ||||
|                 args: fun.args.iter().map(|arg| arg.to_owned()).collect(), | ||||
|                 body: fun.body.to_owned(), | ||||
|             })), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | @ -92,12 +124,12 @@ mod tests { | |||
| 
 | ||||
|     fn parse_eval<T>(src: &str) -> T | ||||
|     where | ||||
|         for<'a> &'a T: TryFrom<&'a Val>, | ||||
|         for<'a> &'a T: TryFrom<&'a Val<'a>>, | ||||
|         T: Clone + TypeOf, | ||||
|     { | ||||
|         let expr = crate::parser::expr(src).unwrap().1; | ||||
|         let res = eval(&expr).unwrap(); | ||||
|         res.into_type::<T>().unwrap().clone() | ||||
|         res.as_type::<T>().unwrap().clone() | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|  | @ -108,7 +140,7 @@ mod tests { | |||
|             rhs: int_lit(2), | ||||
|         }; | ||||
|         let res = eval(&expr).unwrap(); | ||||
|         assert_eq!(*res.into_type::<i64>().unwrap(), 2); | ||||
|         assert_eq!(*res.as_type::<i64>().unwrap(), 2); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|  | @ -122,4 +154,10 @@ mod tests { | |||
|         let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4"); | ||||
|         assert_eq!(res, 2); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn function_call() { | ||||
|         let res = parse_eval::<i64>("let id = fn x = x in id 1"); | ||||
|         assert_eq!(res, 1); | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -6,108 +6,153 @@ use std::rc::Rc; | |||
| use derive_more::{Deref, From, TryInto}; | ||||
| 
 | ||||
| use super::{Error, Result}; | ||||
| use crate::ast::Type; | ||||
| use crate::ast::{Expr, FunctionType, Ident, Type}; | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, From, TryInto)] | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct Function<'a> { | ||||
|     pub type_: FunctionType, | ||||
|     pub args: Vec<Ident<'a>>, | ||||
|     pub body: Expr<'a>, | ||||
| } | ||||
| 
 | ||||
| #[derive(From, TryInto)] | ||||
| #[try_into(owned, ref)] | ||||
| pub enum Val { | ||||
| pub enum Val<'a> { | ||||
|     Int(i64), | ||||
|     Float(f64), | ||||
|     Bool(bool), | ||||
|     Function(Function<'a>), | ||||
| } | ||||
| 
 | ||||
| impl From<u64> for Val { | ||||
| impl<'a> fmt::Debug for Val<'a> { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         match self { | ||||
|             Val::Int(x) => f.debug_tuple("Int").field(x).finish(), | ||||
|             Val::Float(x) => f.debug_tuple("Float").field(x).finish(), | ||||
|             Val::Bool(x) => f.debug_tuple("Bool").field(x).finish(), | ||||
|             Val::Function(Function { type_, .. }) => { | ||||
|                 f.debug_struct("Function").field("type_", type_).finish() | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<'a> PartialEq for Val<'a> { | ||||
|     fn eq(&self, other: &Self) -> bool { | ||||
|         match (self, other) { | ||||
|             (Val::Int(x), Val::Int(y)) => x == y, | ||||
|             (Val::Float(x), Val::Float(y)) => x == y, | ||||
|             (Val::Bool(x), Val::Bool(y)) => x == y, | ||||
|             (Val::Function(_), Val::Function(_)) => false, | ||||
|             (_, _) => false, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<'a> From<u64> for Val<'a> { | ||||
|     fn from(i: u64) -> Self { | ||||
|         Self::from(i as i64) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Display for Val { | ||||
| impl<'a> Display for Val<'a> { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         match self { | ||||
|             Val::Int(x) => x.fmt(f), | ||||
|             Val::Float(x) => x.fmt(f), | ||||
|             Val::Bool(x) => x.fmt(f), | ||||
|             Val::Function(Function { type_, .. }) => write!(f, "<{}>", type_), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Val { | ||||
| impl<'a> Val<'a> { | ||||
|     pub fn type_(&self) -> Type { | ||||
|         match self { | ||||
|             Val::Int(_) => Type::Int, | ||||
|             Val::Float(_) => Type::Float, | ||||
|             Val::Bool(_) => Type::Bool, | ||||
|             Val::Function(Function { type_, .. }) => Type::Function(type_.clone()), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn into_type<'a, T>(&'a self) -> Result<&'a T> | ||||
|     pub fn as_type<'b, T>(&'b self) -> Result<&'b T> | ||||
|     where | ||||
|         T: TypeOf + 'a + Clone, | ||||
|         &'a T: TryFrom<&'a Self>, | ||||
|         T: TypeOf + 'b + Clone, | ||||
|         &'b T: TryFrom<&'b Self>, | ||||
|     { | ||||
|         <&T>::try_from(self).map_err(|_| Error::InvalidType { | ||||
|             actual: self.type_(), | ||||
|             expected: <T as TypeOf>::type_of(), | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn as_function<'b>(&'b self, function_type: FunctionType) -> Result<&'b Function<'a>> { | ||||
|         match self { | ||||
|             Val::Function(f) if f.type_ == function_type => Ok(&f), | ||||
|             _ => Err(Error::InvalidType { | ||||
|                 actual: self.type_(), | ||||
|                 expected: Type::Function(function_type), | ||||
|             }), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Clone, Deref)] | ||||
| pub struct Value(Rc<Val>); | ||||
| pub struct Value<'a>(Rc<Val<'a>>); | ||||
| 
 | ||||
| impl Display for Value { | ||||
| impl<'a> Display for Value<'a> { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         self.0.fmt(f) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<T> From<T> for Value | ||||
| impl<'a, T> From<T> for Value<'a> | ||||
| where | ||||
|     Val: From<T>, | ||||
|     Val<'a>: From<T>, | ||||
| { | ||||
|     fn from(x: T) -> Self { | ||||
|         Self(Rc::new(x.into())) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Neg for Value { | ||||
|     type Output = Result<Value>; | ||||
| impl<'a> Neg for Value<'a> { | ||||
|     type Output = Result<Value<'a>>; | ||||
| 
 | ||||
|     fn neg(self) -> Self::Output { | ||||
|         Ok((-self.into_type::<i64>()?).into()) | ||||
|         Ok((-self.as_type::<i64>()?).into()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Add for Value { | ||||
|     type Output = Result<Value>; | ||||
| impl<'a> Add for Value<'a> { | ||||
|     type Output = Result<Value<'a>>; | ||||
| 
 | ||||
|     fn add(self, rhs: Self) -> Self::Output { | ||||
|         Ok((self.into_type::<i64>()? + rhs.into_type::<i64>()?).into()) | ||||
|         Ok((self.as_type::<i64>()? + rhs.as_type::<i64>()?).into()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Sub for Value { | ||||
|     type Output = Result<Value>; | ||||
| impl<'a> Sub for Value<'a> { | ||||
|     type Output = Result<Value<'a>>; | ||||
| 
 | ||||
|     fn sub(self, rhs: Self) -> Self::Output { | ||||
|         Ok((self.into_type::<i64>()? - rhs.into_type::<i64>()?).into()) | ||||
|         Ok((self.as_type::<i64>()? - rhs.as_type::<i64>()?).into()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Mul for Value { | ||||
|     type Output = Result<Value>; | ||||
| impl<'a> Mul for Value<'a> { | ||||
|     type Output = Result<Value<'a>>; | ||||
| 
 | ||||
|     fn mul(self, rhs: Self) -> Self::Output { | ||||
|         Ok((self.into_type::<i64>()? * rhs.into_type::<i64>()?).into()) | ||||
|         Ok((self.as_type::<i64>()? * rhs.as_type::<i64>()?).into()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Div for Value { | ||||
|     type Output = Result<Value>; | ||||
| impl<'a> Div for Value<'a> { | ||||
|     type Output = Result<Value<'a>>; | ||||
| 
 | ||||
|     fn div(self, rhs: Self) -> Self::Output { | ||||
|         Ok((self.into_type::<f64>()? / rhs.into_type::<f64>()?).into()) | ||||
|         Ok((self.as_type::<f64>()? / rhs.as_type::<f64>()?).into()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -156,6 +156,10 @@ where | |||
|     } | ||||
| } | ||||
| 
 | ||||
| fn is_reserved(s: &str) -> bool { | ||||
|     matches!(s, "if" | "then" | "else" | "let" | "in" | "fn") | ||||
| } | ||||
| 
 | ||||
| fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E> | ||||
| where | ||||
|     E: ParseError<&'a str>, | ||||
|  | @ -170,7 +174,12 @@ where | |||
|                 } | ||||
|                 idx += 1; | ||||
|             } | ||||
|             Ok((&i[idx..], Ident::from_str_unchecked(&i[..idx]))) | ||||
|             let id = &i[..idx]; | ||||
|             if is_reserved(id) { | ||||
|                 Err(nom::Err::Error(E::from_error_kind(i, ErrorKind::Satisfy))) | ||||
|             } else { | ||||
|                 Ok((&i[idx..], Ident::from_str_unchecked(id))) | ||||
|             } | ||||
|         } else { | ||||
|             Err(nom::Err::Error(E::from_error_kind(i, ErrorKind::Satisfy))) | ||||
|         } | ||||
|  | @ -228,14 +237,65 @@ named!(if_(&str) -> Expr, do_parse! ( | |||
| 
 | ||||
| named!(ident_expr(&str) -> Expr, map!(ident, Expr::Ident)); | ||||
| 
 | ||||
| named!(paren_expr(&str) -> Expr, | ||||
|        delimited!(complete!(tag!("(")), expr, complete!(tag!(")")))); | ||||
| 
 | ||||
| named!(funcref(&str) -> Expr, alt!( | ||||
|     ident_expr | | ||||
|     paren_expr | ||||
| )); | ||||
| 
 | ||||
| named!(no_arg_call(&str) -> Expr, do_parse!( | ||||
|     fun: funcref | ||||
|         >> multispace0 | ||||
|         >> complete!(tag!("()")) | ||||
|         >> (Expr::Call { | ||||
|             fun: Box::new(fun), | ||||
|             args: vec![], | ||||
|         }) | ||||
| )); | ||||
| 
 | ||||
| named!(fun_expr(&str) -> Expr, do_parse!( | ||||
|     tag!("fn") | ||||
|         >> multispace1 | ||||
|         >> args: separated_list0!(multispace1, ident) | ||||
|         >> multispace0 | ||||
|         >> char!('=') | ||||
|         >> multispace0 | ||||
|         >> body: expr | ||||
|         >> (Expr::Fun(Box::new(Fun { | ||||
|             args, | ||||
|             body | ||||
|         }))) | ||||
| )); | ||||
| 
 | ||||
| named!(arg(&str) -> Expr, alt!( | ||||
|     ident_expr | | ||||
|     literal | | ||||
|     paren_expr | ||||
| )); | ||||
| 
 | ||||
| named!(call_with_args(&str) -> Expr, do_parse!( | ||||
|     fun: funcref | ||||
|         >> multispace1 | ||||
|         >> args: separated_list1!(multispace1, arg) | ||||
|         >> (Expr::Call { | ||||
|             fun: Box::new(fun), | ||||
|             args | ||||
|         }) | ||||
| )); | ||||
| 
 | ||||
| named!(simple_expr(&str) -> Expr, alt!( | ||||
|     let_ | | ||||
|     if_ | | ||||
|     fun_expr | | ||||
|     literal | | ||||
|     ident_expr | ||||
| )); | ||||
| 
 | ||||
| named!(pub expr(&str) -> Expr, alt!( | ||||
|     no_arg_call | | ||||
|     call_with_args | | ||||
|     map!(token_tree, |tt| { | ||||
|         ExprParser.parse(&mut tt.into_iter()).unwrap() | ||||
|     }) | | ||||
|  | @ -243,8 +303,8 @@ named!(pub expr(&str) -> Expr, alt!( | |||
| 
 | ||||
| //////
 | ||||
| 
 | ||||
| named!(fun(&str) -> Fun, do_parse!( | ||||
|     tag!("fn") | ||||
| named!(fun_decl(&str) -> Decl, do_parse!( | ||||
|     complete!(tag!("fn")) | ||||
|         >> multispace0 | ||||
|         >> name: ident | ||||
|         >> multispace1 | ||||
|  | @ -253,21 +313,24 @@ named!(fun(&str) -> Fun, do_parse!( | |||
|         >> char!('=') | ||||
|         >> multispace0 | ||||
|         >> body: expr | ||||
|         >> (Fun { | ||||
|         >> (Decl::Fun { | ||||
|             name, | ||||
|             body: Fun { | ||||
|                 args, | ||||
|                 body | ||||
|             } | ||||
|         }) | ||||
| )); | ||||
| 
 | ||||
| named!(pub decl(&str) -> Decl, alt!( | ||||
|     fun => { |f| Decl::Fun(f) } | ||||
|     fun_decl | ||||
| )); | ||||
| 
 | ||||
| named!(pub toplevel(&str) -> Vec<Decl>, separated_list0!(multispace1, decl)); | ||||
| named!(pub toplevel(&str) -> Vec<Decl>, many0!(decl)); | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use nom_trace::print_trace; | ||||
|     use std::convert::{TryFrom, TryInto}; | ||||
| 
 | ||||
|     use super::*; | ||||
|  | @ -281,7 +344,9 @@ mod tests { | |||
| 
 | ||||
|     macro_rules! test_parse { | ||||
|         ($parser: ident, $src: expr) => {{ | ||||
|             let (rem, res) = $parser($src).unwrap(); | ||||
|             let res = $parser($src); | ||||
|             print_trace!(); | ||||
|             let (rem, res) = res.unwrap(); | ||||
|             assert!( | ||||
|                 rem.is_empty(), | ||||
|                 "non-empty remainder: \"{}\", parsed: {:?}", | ||||
|  | @ -435,11 +500,87 @@ mod tests { | |||
|         let res = test_parse!(decl, "fn id x = x"); | ||||
|         assert_eq!( | ||||
|             res, | ||||
|             Decl::Fun(Fun { | ||||
|             Decl::Fun { | ||||
|                 name: "id".try_into().unwrap(), | ||||
|                 body: Fun { | ||||
|                     args: vec!["x".try_into().unwrap()], | ||||
|                     body: *ident_expr("x"), | ||||
|             }) | ||||
|                 } | ||||
|             } | ||||
|         ) | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn no_arg_call() { | ||||
|         let res = test_parse!(expr, "f()"); | ||||
|         assert_eq!( | ||||
|             res, | ||||
|             Expr::Call { | ||||
|                 fun: ident_expr("f"), | ||||
|                 args: vec![] | ||||
|             } | ||||
|         ); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn call_with_args() { | ||||
|         let res = test_parse!(expr, "f x 1"); | ||||
|         assert_eq!( | ||||
|             res, | ||||
|             Expr::Call { | ||||
|                 fun: ident_expr("f"), | ||||
|                 args: vec![*ident_expr("x"), Expr::Literal(Literal::Int(1))] | ||||
|             } | ||||
|         ) | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn call_funcref() { | ||||
|         let res = test_parse!(expr, "(let x = 1 in x) 2"); | ||||
|         assert_eq!( | ||||
|             res, | ||||
|             Expr::Call { | ||||
|                 fun: Box::new(Expr::Let { | ||||
|                     bindings: vec![( | ||||
|                         Ident::try_from("x").unwrap(), | ||||
|                         Expr::Literal(Literal::Int(1)) | ||||
|                     )], | ||||
|                     body: ident_expr("x") | ||||
|                 }), | ||||
|                 args: vec![Expr::Literal(Literal::Int(2))] | ||||
|             } | ||||
|         ) | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn anon_function() { | ||||
|         let res = test_parse!(expr, "let id = fn x = x in id 1"); | ||||
|         assert_eq!( | ||||
|             res, | ||||
|             Expr::Let { | ||||
|                 bindings: vec![( | ||||
|                     Ident::try_from("id").unwrap(), | ||||
|                     Expr::Fun(Box::new(Fun { | ||||
|                         args: vec![Ident::try_from("x").unwrap()], | ||||
|                         body: *ident_expr("x") | ||||
|                     })) | ||||
|                 )], | ||||
|                 body: Box::new(Expr::Call { | ||||
|                     fun: ident_expr("id"), | ||||
|                     args: vec![Expr::Literal(Literal::Int(1))], | ||||
|                 }) | ||||
|             } | ||||
|         ); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn multiple_decls() { | ||||
|         let res = test_parse!( | ||||
|             toplevel, | ||||
|             "fn id x = x
 | ||||
|              fn plus x y = x + y | ||||
|              fn main = plus (id 2) 7" | ||||
|         ); | ||||
|         assert_eq!(res.len(), 3); | ||||
|     } | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue