feat(achilles): Implement a Unit type
Add support for a zero-sized Unit type. This requires some special at the codegen level because LLVM (unsurprisingly) only allows Void types in function return position - to make that a little easier to handle there's a new pass that strips any unit-only expressions and pulls unit-only function arguments up to new `let` bindings, so we never have to actually pass around unit values. Change-Id: I0fc18a516821f2d69172c42a6a5d246b23471e38 Reviewed-on: https://cl.tvl.fyi/c/depot/+/2695 Reviewed-by: glittershark <grfn@gws.fyi> Tested-by: BuildkiteCI
This commit is contained in:
		
							parent
							
								
									db62866d82
								
							
						
					
					
						commit
						8e13b1303a
					
				
					 16 changed files with 447 additions and 88 deletions
				
			
		
							
								
								
									
										47
									
								
								users/glittershark/achilles/Cargo.lock
									
										
									
										generated
									
									
									
								
							
							
						
						
									
										47
									
								
								users/glittershark/achilles/Cargo.lock
									
										
									
										generated
									
									
									
								
							| 
						 | 
				
			
			@ -16,6 +16,7 @@ dependencies = [
 | 
			
		|||
 "nom",
 | 
			
		||||
 "nom-trace",
 | 
			
		||||
 "pratt",
 | 
			
		||||
 "pretty_assertions",
 | 
			
		||||
 "proptest",
 | 
			
		||||
 "test-strategy",
 | 
			
		||||
 "thiserror",
 | 
			
		||||
| 
						 | 
				
			
			@ -31,6 +32,15 @@ dependencies = [
 | 
			
		|||
 "memchr",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "ansi_term"
 | 
			
		||||
version = "0.12.1"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "winapi",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "anyhow"
 | 
			
		||||
version = "1.0.38"
 | 
			
		||||
| 
						 | 
				
			
			@ -155,6 +165,16 @@ version = "0.1.3"
 | 
			
		|||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "59c6fe4622b269032d2c5140a592d67a9c409031d286174fcde172fbed86f0d3"
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "ctor"
 | 
			
		||||
version = "0.1.19"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "e8f45d9ad417bcef4817d614a501ab55cdd96a6fdb24f49aab89a54acfd66b19"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "quote",
 | 
			
		||||
 "syn",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "derive_more"
 | 
			
		||||
version = "0.99.11"
 | 
			
		||||
| 
						 | 
				
			
			@ -166,6 +186,12 @@ dependencies = [
 | 
			
		|||
 "syn",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "diff"
 | 
			
		||||
version = "0.1.12"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499"
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "either"
 | 
			
		||||
version = "1.6.1"
 | 
			
		||||
| 
						 | 
				
			
			@ -366,6 +392,15 @@ version = "2.4.0"
 | 
			
		|||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "afb2e1c3ee07430c2cf76151675e583e0f19985fa6efae47d6848a3e2c824f85"
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "output_vt100"
 | 
			
		||||
version = "0.1.2"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "53cdc5b785b7a58c5aad8216b3dfa114df64b0b06ae6e1501cef91df2fbdf8f9"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "winapi",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "parking_lot"
 | 
			
		||||
version = "0.11.1"
 | 
			
		||||
| 
						 | 
				
			
			@ -412,6 +447,18 @@ version = "0.3.0"
 | 
			
		|||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "e31bbc12f7936a7b195790dd6d9b982b66c54f45ff6766decf25c44cac302dce"
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "pretty_assertions"
 | 
			
		||||
version = "0.7.1"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "f297542c27a7df8d45de2b0e620308ab883ad232d06c14b76ac3e144bda50184"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "ansi_term",
 | 
			
		||||
 "ctor",
 | 
			
		||||
 "diff",
 | 
			
		||||
 "output_vt100",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "proc-macro-error"
 | 
			
		||||
version = "1.0.4"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,3 +23,4 @@ void = "1.0.2"
 | 
			
		|||
 | 
			
		||||
[dev-dependencies]
 | 
			
		||||
crate-root = "0.1.3"
 | 
			
		||||
pretty_assertions = "0.7.1"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										1
									
								
								users/glittershark/achilles/ach/.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								users/glittershark/achilles/ach/.gitignore
									
										
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -4,3 +4,4 @@
 | 
			
		|||
functions
 | 
			
		||||
simple
 | 
			
		||||
externs
 | 
			
		||||
units
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										7
									
								
								users/glittershark/achilles/ach/units.ach
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								users/glittershark/achilles/ach/units.ach
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,7 @@
 | 
			
		|||
extern puts : fn cstring -> int
 | 
			
		||||
 | 
			
		||||
ty print : fn cstring -> ()
 | 
			
		||||
fn print x = let _ = puts x in ()
 | 
			
		||||
 | 
			
		||||
ty main : fn -> int
 | 
			
		||||
fn main = let _ = print "hi" in 0
 | 
			
		||||
| 
						 | 
				
			
			@ -246,7 +246,7 @@ impl<'a, T> Expr<'a, T> {
 | 
			
		|||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
#[derive(Debug, Clone, PartialEq, Eq)]
 | 
			
		||||
pub enum Decl<'a, T> {
 | 
			
		||||
    Fun {
 | 
			
		||||
        name: Ident<'a>,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,6 +30,7 @@ impl<'a> Ident<'a> {
 | 
			
		|||
        Ident(Cow::Owned(self.0.clone().into_owned()))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Construct an identifier from a &str without checking that it's a valid identifier
 | 
			
		||||
    pub fn from_str_unchecked(s: &'a str) -> Self {
 | 
			
		||||
        debug_assert!(is_valid_identifier(s));
 | 
			
		||||
        Self(Cow::Borrowed(s))
 | 
			
		||||
| 
						 | 
				
			
			@ -109,6 +110,7 @@ pub enum UnaryOperator {
 | 
			
		|||
 | 
			
		||||
#[derive(Debug, PartialEq, Eq, Clone)]
 | 
			
		||||
pub enum Literal<'a> {
 | 
			
		||||
    Unit,
 | 
			
		||||
    Int(u64),
 | 
			
		||||
    Bool(bool),
 | 
			
		||||
    String(Cow<'a, str>),
 | 
			
		||||
| 
						 | 
				
			
			@ -120,6 +122,7 @@ impl<'a> Literal<'a> {
 | 
			
		|||
            Literal::Int(i) => Literal::Int(*i),
 | 
			
		||||
            Literal::Bool(b) => Literal::Bool(*b),
 | 
			
		||||
            Literal::String(s) => Literal::String(Cow::Owned(s.clone().into_owned())),
 | 
			
		||||
            Literal::Unit => Literal::Unit,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -308,6 +311,7 @@ pub enum Type<'a> {
 | 
			
		|||
    Float,
 | 
			
		||||
    Bool,
 | 
			
		||||
    CString,
 | 
			
		||||
    Unit,
 | 
			
		||||
    Var(Ident<'a>),
 | 
			
		||||
    Function(FunctionType<'a>),
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -319,6 +323,7 @@ impl<'a> Type<'a> {
 | 
			
		|||
            Type::Float => Type::Float,
 | 
			
		||||
            Type::Bool => Type::Bool,
 | 
			
		||||
            Type::CString => Type::CString,
 | 
			
		||||
            Type::Unit => Type::Unit,
 | 
			
		||||
            Type::Var(v) => Type::Var(v.to_owned()),
 | 
			
		||||
            Type::Function(f) => Type::Function(f.to_owned()),
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			@ -374,6 +379,7 @@ impl<'a> Type<'a> {
 | 
			
		|||
            Type::Float => Type::Float,
 | 
			
		||||
            Type::Bool => Type::Bool,
 | 
			
		||||
            Type::CString => Type::CString,
 | 
			
		||||
            Type::Unit => Type::Unit,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -385,6 +391,7 @@ impl<'a> Display for Type<'a> {
 | 
			
		|||
            Type::Float => f.write_str("float"),
 | 
			
		||||
            Type::Bool => f.write_str("bool"),
 | 
			
		||||
            Type::CString => f.write_str("cstring"),
 | 
			
		||||
            Type::Unit => f.write_str("()"),
 | 
			
		||||
            Type::Var(v) => v.fmt(f),
 | 
			
		||||
            Type::Function(ft) => ft.fmt(f),
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -68,8 +68,12 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
        self.function_stack.last().unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> {
 | 
			
		||||
        self.builder.build_return(Some(res));
 | 
			
		||||
    pub fn finish_function(&mut self, res: Option<&BasicValueEnum<'ctx>>) -> FunctionValue<'ctx> {
 | 
			
		||||
        self.builder.build_return(match res {
 | 
			
		||||
            // lol
 | 
			
		||||
            Some(val) => Some(val),
 | 
			
		||||
            None => None,
 | 
			
		||||
        });
 | 
			
		||||
        self.function_stack.pop().unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -78,79 +82,92 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
            .append_basic_block(*self.function_stack.last().unwrap(), name)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> {
 | 
			
		||||
    pub fn codegen_expr(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        expr: &'ast Expr<'ast, Type>,
 | 
			
		||||
    ) -> Result<Option<AnyValueEnum<'ctx>>> {
 | 
			
		||||
        match expr {
 | 
			
		||||
            Expr::Ident(id, _) => self
 | 
			
		||||
                .env
 | 
			
		||||
                .resolve(id)
 | 
			
		||||
                .cloned()
 | 
			
		||||
                .ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
 | 
			
		||||
                .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))
 | 
			
		||||
                .map(Some),
 | 
			
		||||
            Expr::Literal(lit, ty) => {
 | 
			
		||||
                let ty = self.codegen_int_type(ty);
 | 
			
		||||
                match lit {
 | 
			
		||||
                    Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))),
 | 
			
		||||
                    Literal::Bool(b) => Ok(AnyValueEnum::IntValue(
 | 
			
		||||
                    Literal::Int(i) => Ok(Some(AnyValueEnum::IntValue(ty.const_int(*i, false)))),
 | 
			
		||||
                    Literal::Bool(b) => Ok(Some(AnyValueEnum::IntValue(
 | 
			
		||||
                        ty.const_int(if *b { 1 } else { 0 }, false),
 | 
			
		||||
                    ))),
 | 
			
		||||
                    Literal::String(s) => Ok(Some(
 | 
			
		||||
                        self.builder
 | 
			
		||||
                            .build_global_string_ptr(s, "s")
 | 
			
		||||
                            .as_pointer_value()
 | 
			
		||||
                            .into(),
 | 
			
		||||
                    )),
 | 
			
		||||
                    Literal::String(s) => Ok(self
 | 
			
		||||
                        .builder
 | 
			
		||||
                        .build_global_string_ptr(s, "s")
 | 
			
		||||
                        .as_pointer_value()
 | 
			
		||||
                        .into()),
 | 
			
		||||
                    Literal::Unit => Ok(None),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Expr::UnaryOp { op, rhs, .. } => {
 | 
			
		||||
                let rhs = self.codegen_expr(rhs)?;
 | 
			
		||||
                let rhs = self.codegen_expr(rhs)?.unwrap();
 | 
			
		||||
                match op {
 | 
			
		||||
                    UnaryOperator::Not => unimplemented!(),
 | 
			
		||||
                    UnaryOperator::Neg => Ok(AnyValueEnum::IntValue(
 | 
			
		||||
                    UnaryOperator::Neg => Ok(Some(AnyValueEnum::IntValue(
 | 
			
		||||
                        self.builder.build_int_neg(rhs.into_int_value(), "neg"),
 | 
			
		||||
                    )),
 | 
			
		||||
                    ))),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Expr::BinaryOp { lhs, op, rhs, .. } => {
 | 
			
		||||
                let lhs = self.codegen_expr(lhs)?;
 | 
			
		||||
                let rhs = self.codegen_expr(rhs)?;
 | 
			
		||||
                let lhs = self.codegen_expr(lhs)?.unwrap();
 | 
			
		||||
                let rhs = self.codegen_expr(rhs)?.unwrap();
 | 
			
		||||
                match op {
 | 
			
		||||
                    BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add(
 | 
			
		||||
                        lhs.into_int_value(),
 | 
			
		||||
                        rhs.into_int_value(),
 | 
			
		||||
                        "add",
 | 
			
		||||
                    ))),
 | 
			
		||||
                    BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
 | 
			
		||||
                        lhs.into_int_value(),
 | 
			
		||||
                        rhs.into_int_value(),
 | 
			
		||||
                        "add",
 | 
			
		||||
                    ))),
 | 
			
		||||
                    BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
 | 
			
		||||
                        lhs.into_int_value(),
 | 
			
		||||
                        rhs.into_int_value(),
 | 
			
		||||
                        "add",
 | 
			
		||||
                    ))),
 | 
			
		||||
                    BinaryOperator::Div => {
 | 
			
		||||
                        Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div(
 | 
			
		||||
                    BinaryOperator::Add => {
 | 
			
		||||
                        Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_add(
 | 
			
		||||
                            lhs.into_int_value(),
 | 
			
		||||
                            rhs.into_int_value(),
 | 
			
		||||
                            "add",
 | 
			
		||||
                        )))
 | 
			
		||||
                        ))))
 | 
			
		||||
                    }
 | 
			
		||||
                    BinaryOperator::Sub => {
 | 
			
		||||
                        Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub(
 | 
			
		||||
                            lhs.into_int_value(),
 | 
			
		||||
                            rhs.into_int_value(),
 | 
			
		||||
                            "add",
 | 
			
		||||
                        ))))
 | 
			
		||||
                    }
 | 
			
		||||
                    BinaryOperator::Mul => {
 | 
			
		||||
                        Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub(
 | 
			
		||||
                            lhs.into_int_value(),
 | 
			
		||||
                            rhs.into_int_value(),
 | 
			
		||||
                            "add",
 | 
			
		||||
                        ))))
 | 
			
		||||
                    }
 | 
			
		||||
                    BinaryOperator::Div => Ok(Some(AnyValueEnum::IntValue(
 | 
			
		||||
                        self.builder.build_int_signed_div(
 | 
			
		||||
                            lhs.into_int_value(),
 | 
			
		||||
                            rhs.into_int_value(),
 | 
			
		||||
                            "add",
 | 
			
		||||
                        ),
 | 
			
		||||
                    ))),
 | 
			
		||||
                    BinaryOperator::Pow => unimplemented!(),
 | 
			
		||||
                    BinaryOperator::Equ => {
 | 
			
		||||
                        Ok(AnyValueEnum::IntValue(self.builder.build_int_compare(
 | 
			
		||||
                    BinaryOperator::Equ => Ok(Some(AnyValueEnum::IntValue(
 | 
			
		||||
                        self.builder.build_int_compare(
 | 
			
		||||
                            IntPredicate::EQ,
 | 
			
		||||
                            lhs.into_int_value(),
 | 
			
		||||
                            rhs.into_int_value(),
 | 
			
		||||
                            "eq",
 | 
			
		||||
                        )))
 | 
			
		||||
                    }
 | 
			
		||||
                        ),
 | 
			
		||||
                    ))),
 | 
			
		||||
                    BinaryOperator::Neq => todo!(),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Expr::Let { bindings, body, .. } => {
 | 
			
		||||
                self.env.push();
 | 
			
		||||
                for Binding { ident, body, .. } in bindings {
 | 
			
		||||
                    let val = self.codegen_expr(body)?;
 | 
			
		||||
                    self.env.set(ident, val);
 | 
			
		||||
                    if let Some(val) = self.codegen_expr(body)? {
 | 
			
		||||
                        self.env.set(ident, val);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                let res = self.codegen_expr(body);
 | 
			
		||||
                self.env.pop();
 | 
			
		||||
| 
						 | 
				
			
			@ -165,7 +182,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
                let then_block = self.append_basic_block("then");
 | 
			
		||||
                let else_block = self.append_basic_block("else");
 | 
			
		||||
                let join_block = self.append_basic_block("join");
 | 
			
		||||
                let condition = self.codegen_expr(condition)?;
 | 
			
		||||
                let condition = self.codegen_expr(condition)?.unwrap();
 | 
			
		||||
                self.builder.build_conditional_branch(
 | 
			
		||||
                    condition.into_int_value(),
 | 
			
		||||
                    then_block,
 | 
			
		||||
| 
						 | 
				
			
			@ -180,12 +197,22 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
                self.builder.build_unconditional_branch(join_block);
 | 
			
		||||
 | 
			
		||||
                self.builder.position_at_end(join_block);
 | 
			
		||||
                let phi = self.builder.build_phi(self.codegen_type(type_), "join");
 | 
			
		||||
                phi.add_incoming(&[
 | 
			
		||||
                    (&BasicValueEnum::try_from(then_res).unwrap(), then_block),
 | 
			
		||||
                    (&BasicValueEnum::try_from(else_res).unwrap(), else_block),
 | 
			
		||||
                ]);
 | 
			
		||||
                Ok(phi.as_basic_value().into())
 | 
			
		||||
                if let Some(phi_type) = self.codegen_type(type_) {
 | 
			
		||||
                    let phi = self.builder.build_phi(phi_type, "join");
 | 
			
		||||
                    phi.add_incoming(&[
 | 
			
		||||
                        (
 | 
			
		||||
                            &BasicValueEnum::try_from(then_res.unwrap()).unwrap(),
 | 
			
		||||
                            then_block,
 | 
			
		||||
                        ),
 | 
			
		||||
                        (
 | 
			
		||||
                            &BasicValueEnum::try_from(else_res.unwrap()).unwrap(),
 | 
			
		||||
                            else_block,
 | 
			
		||||
                        ),
 | 
			
		||||
                    ]);
 | 
			
		||||
                    Ok(Some(phi.as_basic_value().into()))
 | 
			
		||||
                } else {
 | 
			
		||||
                    Ok(None)
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Expr::Call { fun, args, .. } => {
 | 
			
		||||
                if let Expr::Ident(id, _) = &**fun {
 | 
			
		||||
| 
						 | 
				
			
			@ -196,15 +223,14 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
                        .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?;
 | 
			
		||||
                    let args = args
 | 
			
		||||
                        .iter()
 | 
			
		||||
                        .map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap()))
 | 
			
		||||
                        .map(|arg| Ok(self.codegen_expr(arg)?.unwrap().try_into().unwrap()))
 | 
			
		||||
                        .collect::<Result<Vec<_>>>()?;
 | 
			
		||||
                    Ok(self
 | 
			
		||||
                        .builder
 | 
			
		||||
                        .build_call(function, &args, "call")
 | 
			
		||||
                        .try_as_basic_value()
 | 
			
		||||
                        .left()
 | 
			
		||||
                        .unwrap()
 | 
			
		||||
                        .into())
 | 
			
		||||
                        .map(|val| val.into()))
 | 
			
		||||
                } else {
 | 
			
		||||
                    todo!()
 | 
			
		||||
                }
 | 
			
		||||
| 
						 | 
				
			
			@ -216,7 +242,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
                let function = self.codegen_function(&fname, args, body)?;
 | 
			
		||||
                self.builder.position_at_end(cur_block);
 | 
			
		||||
                self.env.restore(env);
 | 
			
		||||
                Ok(function.into())
 | 
			
		||||
                Ok(Some(function.into()))
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -227,15 +253,17 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
        args: &'ast [(Ident<'ast>, Type)],
 | 
			
		||||
        body: &'ast Expr<'ast, Type>,
 | 
			
		||||
    ) -> Result<FunctionValue<'ctx>> {
 | 
			
		||||
        let arg_types = args
 | 
			
		||||
            .iter()
 | 
			
		||||
            .filter_map(|(_, at)| self.codegen_type(at))
 | 
			
		||||
            .collect::<Vec<_>>();
 | 
			
		||||
 | 
			
		||||
        self.new_function(
 | 
			
		||||
            name,
 | 
			
		||||
            self.codegen_type(body.type_()).fn_type(
 | 
			
		||||
                args.iter()
 | 
			
		||||
                    .map(|(_, at)| self.codegen_type(at))
 | 
			
		||||
                    .collect::<Vec<_>>()
 | 
			
		||||
                    .as_slice(),
 | 
			
		||||
                false,
 | 
			
		||||
            ),
 | 
			
		||||
            match self.codegen_type(body.type_()) {
 | 
			
		||||
                Some(ret_ty) => ret_ty.fn_type(&arg_types, false),
 | 
			
		||||
                None => self.context.void_type().fn_type(&arg_types, false),
 | 
			
		||||
            },
 | 
			
		||||
        );
 | 
			
		||||
        self.env.push();
 | 
			
		||||
        for (i, (arg, _)) in args.iter().enumerate() {
 | 
			
		||||
| 
						 | 
				
			
			@ -244,9 +272,9 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
                self.cur_function().get_nth_param(i as u32).unwrap().into(),
 | 
			
		||||
            );
 | 
			
		||||
        }
 | 
			
		||||
        let res = self.codegen_expr(body)?.try_into().unwrap();
 | 
			
		||||
        let res = self.codegen_expr(body)?;
 | 
			
		||||
        self.env.pop();
 | 
			
		||||
        Ok(self.finish_function(&res))
 | 
			
		||||
        Ok(self.finish_function(res.map(|av| av.try_into().unwrap()).as_ref()))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn codegen_extern(
 | 
			
		||||
| 
						 | 
				
			
			@ -255,15 +283,16 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
        args: &'ast [Type],
 | 
			
		||||
        ret: &'ast Type,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let arg_types = args
 | 
			
		||||
            .iter()
 | 
			
		||||
            .map(|t| self.codegen_type(t).unwrap())
 | 
			
		||||
            .collect::<Vec<_>>();
 | 
			
		||||
        self.module.add_function(
 | 
			
		||||
            name,
 | 
			
		||||
            self.codegen_type(ret).fn_type(
 | 
			
		||||
                &args
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .map(|t| self.codegen_type(t))
 | 
			
		||||
                    .collect::<Vec<_>>(),
 | 
			
		||||
                false,
 | 
			
		||||
            ),
 | 
			
		||||
            match self.codegen_type(ret) {
 | 
			
		||||
                Some(ret_ty) => ret_ty.fn_type(&arg_types, false),
 | 
			
		||||
                None => self.context.void_type().fn_type(&arg_types, false),
 | 
			
		||||
            },
 | 
			
		||||
            None,
 | 
			
		||||
        );
 | 
			
		||||
        Ok(())
 | 
			
		||||
| 
						 | 
				
			
			@ -287,29 +316,31 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 | 
			
		|||
 | 
			
		||||
    pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
 | 
			
		||||
        self.new_function("main", self.context.i64_type().fn_type(&[], false));
 | 
			
		||||
        let res = self.codegen_expr(expr)?.try_into().unwrap();
 | 
			
		||||
        let res = self.codegen_expr(expr)?;
 | 
			
		||||
        if *expr.type_() != Type::Int {
 | 
			
		||||
            self.builder
 | 
			
		||||
                .build_return(Some(&self.context.i64_type().const_int(0, false)));
 | 
			
		||||
        } else {
 | 
			
		||||
            self.finish_function(&res);
 | 
			
		||||
            self.finish_function(res.map(|r| r.try_into().unwrap()).as_ref());
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> {
 | 
			
		||||
    fn codegen_type(&self, type_: &'ast Type) -> Option<BasicTypeEnum<'ctx>> {
 | 
			
		||||
        // TODO
 | 
			
		||||
        match type_ {
 | 
			
		||||
            Type::Int => self.context.i64_type().into(),
 | 
			
		||||
            Type::Float => self.context.f64_type().into(),
 | 
			
		||||
            Type::Bool => self.context.bool_type().into(),
 | 
			
		||||
            Type::CString => self
 | 
			
		||||
                .context
 | 
			
		||||
                .i8_type()
 | 
			
		||||
                .ptr_type(AddressSpace::Generic)
 | 
			
		||||
                .into(),
 | 
			
		||||
            Type::Int => Some(self.context.i64_type().into()),
 | 
			
		||||
            Type::Float => Some(self.context.f64_type().into()),
 | 
			
		||||
            Type::Bool => Some(self.context.bool_type().into()),
 | 
			
		||||
            Type::CString => Some(
 | 
			
		||||
                self.context
 | 
			
		||||
                    .i8_type()
 | 
			
		||||
                    .ptr_type(AddressSpace::Generic)
 | 
			
		||||
                    .into(),
 | 
			
		||||
            ),
 | 
			
		||||
            Type::Function(_) => todo!(),
 | 
			
		||||
            Type::Var(_) => unreachable!(),
 | 
			
		||||
            Type::Unit => None,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,7 +8,7 @@ use test_strategy::Arbitrary;
 | 
			
		|||
 | 
			
		||||
use crate::codegen::{self, Codegen};
 | 
			
		||||
use crate::common::Result;
 | 
			
		||||
use crate::passes::hir::monomorphize;
 | 
			
		||||
use crate::passes::hir::{monomorphize, strip_positive_units};
 | 
			
		||||
use crate::{parser, tc};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)]
 | 
			
		||||
| 
						 | 
				
			
			@ -55,9 +55,10 @@ pub struct CompilerOptions {
 | 
			
		|||
 | 
			
		||||
pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> {
 | 
			
		||||
    let src = fs::read_to_string(input)?;
 | 
			
		||||
    let (_, decls) = parser::toplevel(&src)?; // TODO: statements
 | 
			
		||||
    let (_, decls) = parser::toplevel(&src)?;
 | 
			
		||||
    let mut decls = tc::typecheck_toplevel(decls)?;
 | 
			
		||||
    monomorphize::run_toplevel(&mut decls);
 | 
			
		||||
    strip_positive_units::run_toplevel(&mut decls);
 | 
			
		||||
 | 
			
		||||
    let context = codegen::Context::create();
 | 
			
		||||
    let mut codegen = Codegen::new(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,6 +30,7 @@ impl<'a> Interpreter<'a> {
 | 
			
		|||
            Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
 | 
			
		||||
            Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
 | 
			
		||||
            Expr::Literal(Literal::String(s), _) => Ok(s.clone().into()),
 | 
			
		||||
            Expr::Literal(Literal::Unit, _) => unreachable!(),
 | 
			
		||||
            Expr::UnaryOp { op, rhs, .. } => {
 | 
			
		||||
                let rhs = self.eval(rhs)?;
 | 
			
		||||
                match op {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -186,7 +186,9 @@ named!(string(&str) -> Literal, preceded!(
 | 
			
		|||
    )
 | 
			
		||||
));
 | 
			
		||||
 | 
			
		||||
named!(literal(&str) -> Literal, alt!(int | bool_ | string));
 | 
			
		||||
named!(unit(&str) -> Literal, map!(complete!(tag!("()")), |_| Literal::Unit));
 | 
			
		||||
 | 
			
		||||
named!(literal(&str) -> Literal, alt!(int | bool_ | string | unit));
 | 
			
		||||
 | 
			
		||||
named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal));
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -270,7 +272,6 @@ named!(funcref(&str) -> Expr, alt!(
 | 
			
		|||
 | 
			
		||||
named!(no_arg_call(&str) -> Expr, do_parse!(
 | 
			
		||||
    fun: funcref
 | 
			
		||||
        >> multispace0
 | 
			
		||||
        >> complete!(tag!("()"))
 | 
			
		||||
        >> (Expr::Call {
 | 
			
		||||
            fun: Box::new(fun),
 | 
			
		||||
| 
						 | 
				
			
			@ -431,6 +432,11 @@ pub(crate) mod tests {
 | 
			
		|||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn unit() {
 | 
			
		||||
        assert_eq!(test_parse!(expr, "()"), Expr::Literal(Literal::Unit));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn bools() {
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
| 
						 | 
				
			
			@ -515,6 +521,18 @@ pub(crate) mod tests {
 | 
			
		|||
        );
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn unit_call() {
 | 
			
		||||
        let res = test_parse!(expr, "f ()");
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            res,
 | 
			
		||||
            Expr::Call {
 | 
			
		||||
                fun: ident_expr("f"),
 | 
			
		||||
                args: vec![Expr::Literal(Literal::Unit)]
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn call_with_args() {
 | 
			
		||||
        let res = test_parse!(expr, "f x 1");
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,9 @@
 | 
			
		|||
use nom::character::complete::{multispace0, multispace1};
 | 
			
		||||
use nom::error::{ErrorKind, ParseError};
 | 
			
		||||
use nom::{alt, char, complete, do_parse, many0, named, separated_list0, tag, terminated};
 | 
			
		||||
use nom::{alt, char, complete, do_parse, eof, many0, named, separated_list0, tag, terminated};
 | 
			
		||||
 | 
			
		||||
#[macro_use]
 | 
			
		||||
mod macros;
 | 
			
		||||
pub(crate) mod macros;
 | 
			
		||||
mod expr;
 | 
			
		||||
mod type_;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -136,7 +136,11 @@ named!(pub decl(&str) -> Decl, alt!(
 | 
			
		|||
    extern_decl
 | 
			
		||||
));
 | 
			
		||||
 | 
			
		||||
named!(pub toplevel(&str) -> Vec<Decl>, terminated!(many0!(decl), multispace0));
 | 
			
		||||
named!(pub toplevel(&str) -> Vec<Decl>, do_parse!(
 | 
			
		||||
    decls: many0!(decl)
 | 
			
		||||
        >> multispace0
 | 
			
		||||
        >> eof!()
 | 
			
		||||
        >> (decls)));
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
mod tests {
 | 
			
		||||
| 
						 | 
				
			
			@ -215,4 +219,21 @@ mod tests {
 | 
			
		|||
            }]
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn return_unit() {
 | 
			
		||||
        assert_eq!(
 | 
			
		||||
            test_parse!(decl, "fn g _ = ()"),
 | 
			
		||||
            Decl::Fun {
 | 
			
		||||
                name: "g".try_into().unwrap(),
 | 
			
		||||
                body: Fun {
 | 
			
		||||
                    args: vec![Arg {
 | 
			
		||||
                        ident: "_".try_into().unwrap(),
 | 
			
		||||
                        type_: None,
 | 
			
		||||
                    }],
 | 
			
		||||
                    body: Expr::Literal(Literal::Unit),
 | 
			
		||||
                },
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,6 +29,7 @@ named!(pub type_(&str) -> Type, alt!(
 | 
			
		|||
    tag!("float") => { |_| Type::Float } |
 | 
			
		||||
    tag!("bool") => { |_| Type::Bool } |
 | 
			
		||||
    tag!("cstring") => { |_| Type::CString } |
 | 
			
		||||
    tag!("()") => { |_| Type::Unit } |
 | 
			
		||||
    function_type => { |ft| Type::Function(ft) }|
 | 
			
		||||
    ident => { |id| Type::Var(id) } |
 | 
			
		||||
    delimited!(
 | 
			
		||||
| 
						 | 
				
			
			@ -51,6 +52,7 @@ mod tests {
 | 
			
		|||
        assert_eq!(test_parse!(type_, "float"), Type::Float);
 | 
			
		||||
        assert_eq!(test_parse!(type_, "bool"), Type::Bool);
 | 
			
		||||
        assert_eq!(test_parse!(type_, "cstring"), Type::CString);
 | 
			
		||||
        assert_eq!(test_parse!(type_, "()"), Type::Unit);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,6 +4,7 @@ use crate::ast::hir::{Binding, Decl, Expr};
 | 
			
		|||
use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator};
 | 
			
		||||
 | 
			
		||||
pub(crate) mod monomorphize;
 | 
			
		||||
pub(crate) mod strip_positive_units;
 | 
			
		||||
 | 
			
		||||
pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
 | 
			
		||||
    type Error;
 | 
			
		||||
| 
						 | 
				
			
			@ -53,7 +54,12 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
 | 
			
		|||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
 | 
			
		||||
        self.pre_visit_expr(expr)?;
 | 
			
		||||
        match expr {
 | 
			
		||||
            Expr::Ident(id, t) => {
 | 
			
		||||
                self.visit_ident(id)?;
 | 
			
		||||
| 
						 | 
				
			
			@ -140,6 +146,17 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
 | 
			
		|||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn post_visit_fun_decl(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        _name: &mut Ident<'ast>,
 | 
			
		||||
        _type_args: &mut Vec<Ident>,
 | 
			
		||||
        _args: &mut Vec<(Ident, T)>,
 | 
			
		||||
        _body: &mut Box<Expr<T>>,
 | 
			
		||||
        _type_: &mut T,
 | 
			
		||||
    ) -> Result<(), Self::Error> {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> {
 | 
			
		||||
        match decl {
 | 
			
		||||
            Decl::Fun {
 | 
			
		||||
| 
						 | 
				
			
			@ -150,15 +167,16 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
 | 
			
		|||
                type_,
 | 
			
		||||
            } => {
 | 
			
		||||
                self.visit_ident(name)?;
 | 
			
		||||
                for type_arg in type_args {
 | 
			
		||||
                for type_arg in type_args.iter_mut() {
 | 
			
		||||
                    self.visit_ident(type_arg)?;
 | 
			
		||||
                }
 | 
			
		||||
                for (arg, t) in args {
 | 
			
		||||
                for (arg, t) in args.iter_mut() {
 | 
			
		||||
                    self.visit_ident(arg)?;
 | 
			
		||||
                    self.visit_type(t)?;
 | 
			
		||||
                }
 | 
			
		||||
                self.visit_expr(body)?;
 | 
			
		||||
                self.visit_type(type_)?;
 | 
			
		||||
                self.post_visit_fun_decl(name, type_args, args, body, type_)?;
 | 
			
		||||
            }
 | 
			
		||||
            Decl::Extern {
 | 
			
		||||
                name,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,189 @@
 | 
			
		|||
use std::collections::HashMap;
 | 
			
		||||
use std::mem;
 | 
			
		||||
 | 
			
		||||
use ast::hir::Binding;
 | 
			
		||||
use ast::Literal;
 | 
			
		||||
use void::{ResultVoidExt, Void};
 | 
			
		||||
 | 
			
		||||
use crate::ast::hir::{Decl, Expr};
 | 
			
		||||
use crate::ast::{self, Ident};
 | 
			
		||||
 | 
			
		||||
use super::Visitor;
 | 
			
		||||
 | 
			
		||||
/// Strip all values with a unit type in positive (non-return) position
 | 
			
		||||
pub(crate) struct StripPositiveUnits {}
 | 
			
		||||
 | 
			
		||||
impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits {
 | 
			
		||||
    type Error = Void;
 | 
			
		||||
 | 
			
		||||
    fn pre_visit_expr(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        expr: &mut Expr<'ast, ast::Type<'ast>>,
 | 
			
		||||
    ) -> Result<(), Self::Error> {
 | 
			
		||||
        let mut extracted = vec![];
 | 
			
		||||
        if let Expr::Call { args, .. } = expr {
 | 
			
		||||
            // TODO(grfn): replace with drain_filter once it's stabilized
 | 
			
		||||
            let mut i = 0;
 | 
			
		||||
            while i != args.len() {
 | 
			
		||||
                if args[i].type_() == &ast::Type::Unit {
 | 
			
		||||
                    let expr = args.remove(i);
 | 
			
		||||
                    if !matches!(expr, Expr::Literal(Literal::Unit, _)) {
 | 
			
		||||
                        extracted.push(expr)
 | 
			
		||||
                    };
 | 
			
		||||
                } else {
 | 
			
		||||
                    i += 1
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if !extracted.is_empty() {
 | 
			
		||||
            let body = mem::replace(expr, Expr::Literal(Literal::Unit, ast::Type::Unit));
 | 
			
		||||
            *expr = Expr::Let {
 | 
			
		||||
                bindings: extracted
 | 
			
		||||
                    .into_iter()
 | 
			
		||||
                    .map(|expr| Binding {
 | 
			
		||||
                        ident: Ident::from_str_unchecked("___discarded"),
 | 
			
		||||
                        type_: expr.type_().clone(),
 | 
			
		||||
                        body: expr,
 | 
			
		||||
                    })
 | 
			
		||||
                    .collect(),
 | 
			
		||||
                type_: body.type_().clone(),
 | 
			
		||||
                body: Box::new(body),
 | 
			
		||||
            };
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn post_visit_call(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        _fun: &mut Expr<'ast, ast::Type<'ast>>,
 | 
			
		||||
        _type_args: &mut HashMap<Ident<'ast>, ast::Type<'ast>>,
 | 
			
		||||
        args: &mut Vec<Expr<'ast, ast::Type<'ast>>>,
 | 
			
		||||
    ) -> Result<(), Self::Error> {
 | 
			
		||||
        args.retain(|arg| arg.type_() != &ast::Type::Unit);
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn visit_type(&mut self, type_: &mut ast::Type<'ast>) -> Result<(), Self::Error> {
 | 
			
		||||
        if let ast::Type::Function(ft) = type_ {
 | 
			
		||||
            ft.args.retain(|a| a != &ast::Type::Unit);
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn post_visit_fun_decl(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        _name: &mut Ident<'ast>,
 | 
			
		||||
        _type_args: &mut Vec<Ident>,
 | 
			
		||||
        args: &mut Vec<(Ident, ast::Type<'ast>)>,
 | 
			
		||||
        _body: &mut Box<Expr<ast::Type<'ast>>>,
 | 
			
		||||
        _type_: &mut ast::Type<'ast>,
 | 
			
		||||
    ) -> Result<(), Self::Error> {
 | 
			
		||||
        args.retain(|(_, ty)| ty != &ast::Type::Unit);
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec<Decl<'a, ast::Type<'a>>>) {
 | 
			
		||||
    let mut pass = StripPositiveUnits {};
 | 
			
		||||
    for decl in toplevel.iter_mut() {
 | 
			
		||||
        pass.visit_decl(decl).void_unwrap();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
mod tests {
 | 
			
		||||
    use super::*;
 | 
			
		||||
    use crate::parser::toplevel;
 | 
			
		||||
    use crate::tc::typecheck_toplevel;
 | 
			
		||||
    use pretty_assertions::assert_eq;
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn unit_only_arg() {
 | 
			
		||||
        let (_, program) = toplevel(
 | 
			
		||||
            "ty f : fn () -> int
 | 
			
		||||
             fn f _ = 1
 | 
			
		||||
 | 
			
		||||
             ty main : fn -> int
 | 
			
		||||
             fn main = f ()",
 | 
			
		||||
        )
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        let (_, expected) = toplevel(
 | 
			
		||||
            "ty f : fn -> int
 | 
			
		||||
             fn f = 1
 | 
			
		||||
 | 
			
		||||
             ty main : fn -> int
 | 
			
		||||
             fn main = f()",
 | 
			
		||||
        )
 | 
			
		||||
        .unwrap();
 | 
			
		||||
        let expected = typecheck_toplevel(expected).unwrap();
 | 
			
		||||
 | 
			
		||||
        let mut program = typecheck_toplevel(program).unwrap();
 | 
			
		||||
        run_toplevel(&mut program);
 | 
			
		||||
 | 
			
		||||
        assert_eq!(program, expected);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn unit_and_other_arg() {
 | 
			
		||||
        let (_, program) = toplevel(
 | 
			
		||||
            "ty f : fn (), int -> int
 | 
			
		||||
             fn f _ x = x
 | 
			
		||||
 | 
			
		||||
             ty main : fn -> int
 | 
			
		||||
             fn main = f () 1",
 | 
			
		||||
        )
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        let (_, expected) = toplevel(
 | 
			
		||||
            "ty f : fn int -> int
 | 
			
		||||
             fn f x = x
 | 
			
		||||
 | 
			
		||||
             ty main : fn -> int
 | 
			
		||||
             fn main = f 1",
 | 
			
		||||
        )
 | 
			
		||||
        .unwrap();
 | 
			
		||||
        let expected = typecheck_toplevel(expected).unwrap();
 | 
			
		||||
 | 
			
		||||
        let mut program = typecheck_toplevel(program).unwrap();
 | 
			
		||||
        run_toplevel(&mut program);
 | 
			
		||||
 | 
			
		||||
        assert_eq!(program, expected);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn unit_expr_and_other_arg() {
 | 
			
		||||
        let (_, program) = toplevel(
 | 
			
		||||
            "ty f : fn (), int -> int
 | 
			
		||||
             fn f _ x = x
 | 
			
		||||
 | 
			
		||||
             ty g : fn int -> ()
 | 
			
		||||
             fn g _ = ()
 | 
			
		||||
 | 
			
		||||
             ty main : fn -> int
 | 
			
		||||
             fn main = f (g 2) 1",
 | 
			
		||||
        )
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        let (_, expected) = toplevel(
 | 
			
		||||
            "ty f : fn int -> int
 | 
			
		||||
             fn f x = x
 | 
			
		||||
 | 
			
		||||
             ty g : fn int -> ()
 | 
			
		||||
             fn g _ = ()
 | 
			
		||||
 | 
			
		||||
             ty main : fn -> int
 | 
			
		||||
             fn main = let ___discarded = g 2 in f 1",
 | 
			
		||||
        )
 | 
			
		||||
        .unwrap();
 | 
			
		||||
        assert_eq!(expected.len(), 6);
 | 
			
		||||
        let expected = typecheck_toplevel(expected).unwrap();
 | 
			
		||||
 | 
			
		||||
        let mut program = typecheck_toplevel(program).unwrap();
 | 
			
		||||
        run_toplevel(&mut program);
 | 
			
		||||
 | 
			
		||||
        assert_eq!(program, expected);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -85,6 +85,7 @@ pub enum Type {
 | 
			
		|||
    Exist(TyVar),
 | 
			
		||||
    Nullary(NullaryType),
 | 
			
		||||
    Prim(PrimType),
 | 
			
		||||
    Unit,
 | 
			
		||||
    Fun {
 | 
			
		||||
        args: Vec<Type>,
 | 
			
		||||
        ret: Box<Type>,
 | 
			
		||||
| 
						 | 
				
			
			@ -96,6 +97,7 @@ impl<'a> TryFrom<Type> for ast::Type<'a> {
 | 
			
		|||
 | 
			
		||||
    fn try_from(value: Type) -> result::Result<Self, Self::Error> {
 | 
			
		||||
        match value {
 | 
			
		||||
            Type::Unit => Ok(ast::Type::Unit),
 | 
			
		||||
            Type::Univ(_) => todo!(),
 | 
			
		||||
            Type::Exist(_) => Err(value),
 | 
			
		||||
            Type::Nullary(_) => todo!(),
 | 
			
		||||
| 
						 | 
				
			
			@ -126,6 +128,7 @@ impl Display for Type {
 | 
			
		|||
            Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
 | 
			
		||||
            Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
 | 
			
		||||
            Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
 | 
			
		||||
            Type::Unit => write!(f, "()"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -171,6 +174,7 @@ impl<'ast> Typechecker<'ast> {
 | 
			
		|||
                    Literal::Int(_) => Type::Prim(PrimType::Int),
 | 
			
		||||
                    Literal::Bool(_) => Type::Prim(PrimType::Bool),
 | 
			
		||||
                    Literal::String(_) => Type::Prim(PrimType::CString),
 | 
			
		||||
                    Literal::Unit => Type::Unit,
 | 
			
		||||
                };
 | 
			
		||||
                Ok(hir::Expr::Literal(lit.to_owned(), type_))
 | 
			
		||||
            }
 | 
			
		||||
| 
						 | 
				
			
			@ -377,6 +381,7 @@ impl<'ast> Typechecker<'ast> {
 | 
			
		|||
 | 
			
		||||
    fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
 | 
			
		||||
        match (ty1, ty2) {
 | 
			
		||||
            (Type::Unit, Type::Unit) => Ok(Type::Unit),
 | 
			
		||||
            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
 | 
			
		||||
                Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
 | 
			
		||||
                Some(var @ ast::Type::Var(_)) => {
 | 
			
		||||
| 
						 | 
				
			
			@ -466,6 +471,7 @@ impl<'ast> Typechecker<'ast> {
 | 
			
		|||
        let ret = match ty {
 | 
			
		||||
            Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
 | 
			
		||||
            Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
 | 
			
		||||
            Type::Unit => Ok(ast::Type::Unit),
 | 
			
		||||
            Type::Nullary(_) => todo!(),
 | 
			
		||||
            Type::Prim(pr) => Ok(pr.into()),
 | 
			
		||||
            Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
 | 
			
		||||
| 
						 | 
				
			
			@ -496,6 +502,7 @@ impl<'ast> Typechecker<'ast> {
 | 
			
		|||
                }
 | 
			
		||||
                Type::Nullary(_) => todo!(),
 | 
			
		||||
                Type::Prim(pr) => break Some((*pr).into()),
 | 
			
		||||
                Type::Unit => break Some(ast::Type::Unit),
 | 
			
		||||
                Type::Fun { args, ret } => todo!(),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			@ -503,6 +510,7 @@ impl<'ast> Typechecker<'ast> {
 | 
			
		|||
 | 
			
		||||
    fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
 | 
			
		||||
        match ast_type {
 | 
			
		||||
            ast::Type::Unit => Type::Unit,
 | 
			
		||||
            ast::Type::Int => INT,
 | 
			
		||||
            ast::Type::Float => FLOAT,
 | 
			
		||||
            ast::Type::Bool => BOOL,
 | 
			
		||||
| 
						 | 
				
			
			@ -570,6 +578,8 @@ impl<'ast> Typechecker<'ast> {
 | 
			
		|||
            }
 | 
			
		||||
            (Type::Univ(_), _) => false,
 | 
			
		||||
            (Type::Exist(_), _) => false,
 | 
			
		||||
            (Type::Unit, ast::Type::Unit) => true,
 | 
			
		||||
            (Type::Unit, _) => false,
 | 
			
		||||
            (Type::Nullary(_), _) => todo!(),
 | 
			
		||||
            (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
 | 
			
		||||
            (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,6 +24,11 @@ const FIXTURES: &[Fixture] = &[
 | 
			
		|||
        exit_code: 0,
 | 
			
		||||
        expected_output: "foobar\n",
 | 
			
		||||
    },
 | 
			
		||||
    Fixture {
 | 
			
		||||
        name: "units",
 | 
			
		||||
        exit_code: 0,
 | 
			
		||||
        expected_output: "hi\n",
 | 
			
		||||
    },
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
#[test]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue