Implement universally quantified type variables, both explicitly given by the user and inferred by the type inference algorithm.
		
			
				
	
	
		
			664 lines
		
	
	
	
		
			22 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			664 lines
		
	
	
	
		
			22 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| use bimap::BiMap;
 | |
| use derive_more::From;
 | |
| use itertools::Itertools;
 | |
| use std::cell::RefCell;
 | |
| use std::collections::HashMap;
 | |
| use std::convert::{TryFrom, TryInto};
 | |
| use std::fmt::{self, Display};
 | |
| use std::{mem, result};
 | |
| use thiserror::Error;
 | |
| 
 | |
| use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal};
 | |
| use crate::common::env::Env;
 | |
| use crate::common::{Namer, NamerOf};
 | |
| 
 | |
| #[derive(Debug, Error)]
 | |
| pub enum Error {
 | |
|     #[error("Undefined variable {0}")]
 | |
|     UndefinedVariable(Ident<'static>),
 | |
| 
 | |
|     #[error("Mismatched types: expected {expected}, but got {actual}")]
 | |
|     TypeMismatch { expected: Type, actual: Type },
 | |
| 
 | |
|     #[error("Mismatched types, expected numeric type, but got {0}")]
 | |
|     NonNumeric(Type),
 | |
| 
 | |
|     #[error("Ambiguous type {0}")]
 | |
|     AmbiguousType(TyVar),
 | |
| }
 | |
| 
 | |
| pub type Result<T> = result::Result<T, Error>;
 | |
| 
 | |
/// A type variable, identified by a numeric id.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct TyVar(u64);

impl Display for TyVar {
    /// Renders the variable as `t` followed by its numeric id, e.g. `t3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TyVar(id) = self;
        write!(f, "t{}", id)
    }
}
 | |
| 
 | |
/// A type identified only by its name.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct NullaryType(String);

impl Display for NullaryType {
    /// Writes the type's name verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let NullaryType(name) = self;
        f.write_str(name)
    }
}
 | |
| 
 | |
/// The primitive (built-in) types known to the typechecker.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum PrimType {
    /// Displayed as `int`.
    Int,
    /// Displayed as `float`.
    Float,
    /// Displayed as `bool`.
    Bool,
    /// Displayed as `cstring`.
    CString,
}
 | |
| 
 | |
| impl<'a> From<PrimType> for ast::Type<'a> {
 | |
|     fn from(pr: PrimType) -> Self {
 | |
|         match pr {
 | |
|             PrimType::Int => ast::Type::Int,
 | |
|             PrimType::Float => ast::Type::Float,
 | |
|             PrimType::Bool => ast::Type::Bool,
 | |
|             PrimType::CString => ast::Type::CString,
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl Display for PrimType {
 | |
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 | |
|         match self {
 | |
|             PrimType::Int => f.write_str("int"),
 | |
|             PrimType::Float => f.write_str("float"),
 | |
|             PrimType::Bool => f.write_str("bool"),
 | |
|             PrimType::CString => f.write_str("cstring"),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| #[derive(Debug, PartialEq, Eq, Clone, From)]
 | |
| pub enum Type {
 | |
|     #[from(ignore)]
 | |
|     Univ(TyVar),
 | |
|     #[from(ignore)]
 | |
|     Exist(TyVar),
 | |
|     Nullary(NullaryType),
 | |
|     Prim(PrimType),
 | |
|     Fun {
 | |
|         args: Vec<Type>,
 | |
|         ret: Box<Type>,
 | |
|     },
 | |
| }
 | |
| 
 | |
| impl<'a> TryFrom<Type> for ast::Type<'a> {
 | |
|     type Error = Type;
 | |
| 
 | |
|     fn try_from(value: Type) -> result::Result<Self, Self::Error> {
 | |
|         match value {
 | |
|             Type::Univ(_) => todo!(),
 | |
|             Type::Exist(_) => Err(value),
 | |
|             Type::Nullary(_) => todo!(),
 | |
|             Type::Prim(p) => Ok(p.into()),
 | |
|             Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
 | |
|                 args: args
 | |
|                     .clone()
 | |
|                     .into_iter()
 | |
|                     .map(Self::try_from)
 | |
|                     .try_collect()
 | |
|                     .map_err(|_| value.clone())?,
 | |
|                 ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?),
 | |
|             })),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| const INT: Type = Type::Prim(PrimType::Int);
 | |
| const FLOAT: Type = Type::Prim(PrimType::Float);
 | |
| const BOOL: Type = Type::Prim(PrimType::Bool);
 | |
| const CSTRING: Type = Type::Prim(PrimType::CString);
 | |
| 
 | |
| impl Display for Type {
 | |
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 | |
|         match self {
 | |
|             Type::Nullary(nt) => nt.fmt(f),
 | |
|             Type::Prim(p) => p.fmt(f),
 | |
|             Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
 | |
|             Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
 | |
|             Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| struct Typechecker<'ast> {
 | |
|     ty_var_namer: NamerOf<TyVar>,
 | |
|     ctx: HashMap<TyVar, Type>,
 | |
|     env: Env<Ident<'ast>, Type>,
 | |
| 
 | |
|     /// AST type var -> type
 | |
|     instantiations: Env<Ident<'ast>, Type>,
 | |
| 
 | |
|     /// AST type-var -> universal TyVar
 | |
|     type_vars: RefCell<(BiMap<Ident<'ast>, TyVar>, NamerOf<Ident<'static>>)>,
 | |
| }
 | |
| 
 | |
| impl<'ast> Typechecker<'ast> {
 | |
|     fn new() -> Self {
 | |
|         Self {
 | |
|             ty_var_namer: Namer::new(TyVar).boxed(),
 | |
|             type_vars: RefCell::new((
 | |
|                 Default::default(),
 | |
|                 Namer::alphabetic().map(|n| Ident::try_from(n).unwrap()),
 | |
|             )),
 | |
|             ctx: Default::default(),
 | |
|             env: Default::default(),
 | |
|             instantiations: Default::default(),
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
 | |
|         match expr {
 | |
|             ast::Expr::Ident(ident) => {
 | |
|                 let type_ = self
 | |
|                     .env
 | |
|                     .resolve(&ident)
 | |
|                     .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))?
 | |
|                     .clone();
 | |
|                 Ok(hir::Expr::Ident(ident, type_))
 | |
|             }
 | |
|             ast::Expr::Literal(lit) => {
 | |
|                 let type_ = match lit {
 | |
|                     Literal::Int(_) => Type::Prim(PrimType::Int),
 | |
|                     Literal::Bool(_) => Type::Prim(PrimType::Bool),
 | |
|                     Literal::String(_) => Type::Prim(PrimType::CString),
 | |
|                 };
 | |
|                 Ok(hir::Expr::Literal(lit.to_owned(), type_))
 | |
|             }
 | |
|             ast::Expr::UnaryOp { op, rhs } => todo!(),
 | |
|             ast::Expr::BinaryOp { lhs, op, rhs } => {
 | |
|                 let lhs = self.tc_expr(*lhs)?;
 | |
|                 let rhs = self.tc_expr(*rhs)?;
 | |
|                 let type_ = match op {
 | |
|                     BinaryOperator::Equ | BinaryOperator::Neq => {
 | |
|                         self.unify(lhs.type_(), rhs.type_())?;
 | |
|                         Type::Prim(PrimType::Bool)
 | |
|                     }
 | |
|                     BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
 | |
|                         let ty = self.unify(lhs.type_(), rhs.type_())?;
 | |
|                         // if !matches!(ty, Type::Int | Type::Float) {
 | |
|                         //     return Err(Error::NonNumeric(ty));
 | |
|                         // }
 | |
|                         ty
 | |
|                     }
 | |
|                     BinaryOperator::Div => todo!(),
 | |
|                     BinaryOperator::Pow => todo!(),
 | |
|                 };
 | |
|                 Ok(hir::Expr::BinaryOp {
 | |
|                     lhs: Box::new(lhs),
 | |
|                     op,
 | |
|                     rhs: Box::new(rhs),
 | |
|                     type_,
 | |
|                 })
 | |
|             }
 | |
|             ast::Expr::Let { bindings, body } => {
 | |
|                 self.env.push();
 | |
|                 let bindings = bindings
 | |
|                     .into_iter()
 | |
|                     .map(
 | |
|                         |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
 | |
|                             let body = self.tc_expr(body)?;
 | |
|                             if let Some(type_) = type_ {
 | |
|                                 let type_ = self.type_from_ast_type(type_);
 | |
|                                 self.unify(body.type_(), &type_)?;
 | |
|                             }
 | |
|                             self.env.set(ident.clone(), body.type_().clone());
 | |
|                             Ok(hir::Binding {
 | |
|                                 ident,
 | |
|                                 type_: body.type_().clone(),
 | |
|                                 body,
 | |
|                             })
 | |
|                         },
 | |
|                     )
 | |
|                     .collect::<Result<Vec<hir::Binding<Type>>>>()?;
 | |
|                 let body = self.tc_expr(*body)?;
 | |
|                 self.env.pop();
 | |
|                 Ok(hir::Expr::Let {
 | |
|                     bindings,
 | |
|                     type_: body.type_().clone(),
 | |
|                     body: Box::new(body),
 | |
|                 })
 | |
|             }
 | |
|             ast::Expr::If {
 | |
|                 condition,
 | |
|                 then,
 | |
|                 else_,
 | |
|             } => {
 | |
|                 let condition = self.tc_expr(*condition)?;
 | |
|                 self.unify(&Type::Prim(PrimType::Bool), condition.type_())?;
 | |
|                 let then = self.tc_expr(*then)?;
 | |
|                 let else_ = self.tc_expr(*else_)?;
 | |
|                 let type_ = self.unify(then.type_(), else_.type_())?;
 | |
|                 Ok(hir::Expr::If {
 | |
|                     condition: Box::new(condition),
 | |
|                     then: Box::new(then),
 | |
|                     else_: Box::new(else_),
 | |
|                     type_,
 | |
|                 })
 | |
|             }
 | |
|             ast::Expr::Fun(f) => {
 | |
|                 let ast::Fun { args, body } = *f;
 | |
|                 self.env.push();
 | |
|                 let args: Vec<_> = args
 | |
|                     .into_iter()
 | |
|                     .map(|Arg { ident, type_ }| {
 | |
|                         let ty = match type_ {
 | |
|                             Some(t) => self.type_from_ast_type(t),
 | |
|                             None => self.fresh_ex(),
 | |
|                         };
 | |
|                         self.env.set(ident.clone(), ty.clone());
 | |
|                         (ident, ty)
 | |
|                     })
 | |
|                     .collect();
 | |
|                 let body = self.tc_expr(body)?;
 | |
|                 self.env.pop();
 | |
|                 Ok(hir::Expr::Fun {
 | |
|                     type_: self.universalize(
 | |
|                         args.iter().map(|(_, ty)| ty.clone()).collect(),
 | |
|                         body.type_().clone(),
 | |
|                     ),
 | |
|                     args,
 | |
|                     body: Box::new(body),
 | |
|                 })
 | |
|             }
 | |
|             ast::Expr::Call { fun, args } => {
 | |
|                 let ret_ty = self.fresh_ex();
 | |
|                 let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>();
 | |
|                 let ft = Type::Fun {
 | |
|                     args: arg_tys.clone(),
 | |
|                     ret: Box::new(ret_ty.clone()),
 | |
|                 };
 | |
|                 let fun = self.tc_expr(*fun)?;
 | |
|                 self.instantiations.push();
 | |
|                 self.unify(&ft, fun.type_())?;
 | |
|                 let args = args
 | |
|                     .into_iter()
 | |
|                     .zip(arg_tys)
 | |
|                     .map(|(arg, ty)| {
 | |
|                         let arg = self.tc_expr(arg)?;
 | |
|                         self.unify(&ty, arg.type_())?;
 | |
|                         Ok(arg)
 | |
|                     })
 | |
|                     .try_collect()?;
 | |
|                 self.commit_instantiations();
 | |
|                 Ok(hir::Expr::Call {
 | |
|                     fun: Box::new(fun),
 | |
|                     args,
 | |
|                     type_: ret_ty,
 | |
|                 })
 | |
|             }
 | |
|             ast::Expr::Ascription { expr, type_ } => {
 | |
|                 let expr = self.tc_expr(*expr)?;
 | |
|                 let type_ = self.type_from_ast_type(type_);
 | |
|                 self.unify(expr.type_(), &type_)?;
 | |
|                 Ok(expr)
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
 | |
|         match decl {
 | |
|             ast::Decl::Fun { name, body } => {
 | |
|                 let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
 | |
|                 let type_ = body.type_().clone();
 | |
|                 self.env.set(name.clone(), type_);
 | |
|                 match body {
 | |
|                     hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
 | |
|                         name,
 | |
|                         args,
 | |
|                         body,
 | |
|                         type_,
 | |
|                     }),
 | |
|                     _ => unreachable!(),
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn fresh_tv(&mut self) -> TyVar {
 | |
|         self.ty_var_namer.make_name()
 | |
|     }
 | |
| 
 | |
|     fn fresh_ex(&mut self) -> Type {
 | |
|         Type::Exist(self.fresh_tv())
 | |
|     }
 | |
| 
 | |
|     fn fresh_univ(&mut self) -> Type {
 | |
|         Type::Univ(self.fresh_tv())
 | |
|     }
 | |
| 
 | |
|     #[allow(clippy::redundant_closure)] // https://github.com/rust-lang/rust-clippy/issues/6903
 | |
|     fn universalize(&mut self, args: Vec<Type>, ret: Type) -> Type {
 | |
|         let mut vars = HashMap::new();
 | |
|         let mut universalize_type = move |ty| match ty {
 | |
|             Type::Exist(tv) if self.resolve_tv(tv).is_none() => vars
 | |
|                 .entry(tv)
 | |
|                 .or_insert_with_key(|tv| {
 | |
|                     let ty = self.fresh_univ();
 | |
|                     self.ctx.insert(*tv, ty.clone());
 | |
|                     ty
 | |
|                 })
 | |
|                 .clone(),
 | |
|             _ => ty,
 | |
|         };
 | |
| 
 | |
|         Type::Fun {
 | |
|             args: args.into_iter().map(|t| universalize_type(t)).collect(),
 | |
|             ret: Box::new(universalize_type(ret)),
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
 | |
|         match (ty1, ty2) {
 | |
|             (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
 | |
|                 Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
 | |
|                 Some(var @ ast::Type::Var(_)) => {
 | |
|                     let var = self.type_from_ast_type(var);
 | |
|                     self.unify(&var, ty)
 | |
|                 }
 | |
|                 Some(existing_ty) => match ty {
 | |
|                     Type::Exist(_) => {
 | |
|                         let rhs = self.type_from_ast_type(existing_ty);
 | |
|                         self.unify(ty, &rhs)
 | |
|                     }
 | |
|                     _ => Err(Error::TypeMismatch {
 | |
|                         expected: ty.clone(),
 | |
|                         actual: self.type_from_ast_type(existing_ty),
 | |
|                     }),
 | |
|                 },
 | |
|                 None => match self.ctx.insert(*tv, ty.clone()) {
 | |
|                     Some(existing) => self.unify(&existing, ty),
 | |
|                     None => Ok(ty.clone()),
 | |
|                 },
 | |
|             },
 | |
|             (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
 | |
|             (Type::Univ(u), ty) | (ty, Type::Univ(u)) => {
 | |
|                 let ident = self.name_univ(*u);
 | |
|                 match self.instantiations.resolve(&ident) {
 | |
|                     Some(existing_ty) if ty == existing_ty => Ok(ty.clone()),
 | |
|                     Some(existing_ty) => Err(Error::TypeMismatch {
 | |
|                         expected: ty.clone(),
 | |
|                         actual: existing_ty.clone(),
 | |
|                     }),
 | |
|                     None => {
 | |
|                         self.instantiations.set(ident, ty.clone());
 | |
|                         Ok(ty.clone())
 | |
|                     }
 | |
|                 }
 | |
|             }
 | |
|             (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
 | |
|             (
 | |
|                 Type::Fun {
 | |
|                     args: args1,
 | |
|                     ret: ret1,
 | |
|                 },
 | |
|                 Type::Fun {
 | |
|                     args: args2,
 | |
|                     ret: ret2,
 | |
|                 },
 | |
|             ) => {
 | |
|                 let args = args1
 | |
|                     .iter()
 | |
|                     .zip(args2)
 | |
|                     .map(|(t1, t2)| self.unify(t1, t2))
 | |
|                     .try_collect()?;
 | |
|                 let ret = self.unify(ret1, ret2)?;
 | |
|                 Ok(Type::Fun {
 | |
|                     args,
 | |
|                     ret: Box::new(ret),
 | |
|                 })
 | |
|             }
 | |
|             (Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(),
 | |
|             _ => Err(Error::TypeMismatch {
 | |
|                 expected: ty1.clone(),
 | |
|                 actual: ty2.clone(),
 | |
|             }),
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn finalize_expr(
 | |
|         &self,
 | |
|         expr: hir::Expr<'ast, Type>,
 | |
|     ) -> Result<hir::Expr<'ast, ast::Type<'ast>>> {
 | |
|         expr.traverse_type(|ty| self.finalize_type(ty))
 | |
|     }
 | |
| 
 | |
|     fn finalize_decl(
 | |
|         &self,
 | |
|         decl: hir::Decl<'ast, Type>,
 | |
|     ) -> Result<hir::Decl<'ast, ast::Type<'ast>>> {
 | |
|         decl.traverse_type(|ty| self.finalize_type(ty))
 | |
|     }
 | |
| 
 | |
|     fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> {
 | |
|         let ret = match ty {
 | |
|             Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
 | |
|             Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
 | |
|             Type::Nullary(_) => todo!(),
 | |
|             Type::Prim(pr) => Ok(pr.into()),
 | |
|             Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
 | |
|                 args: args
 | |
|                     .into_iter()
 | |
|                     .map(|ty| self.finalize_type(ty))
 | |
|                     .try_collect()?,
 | |
|                 ret: Box::new(self.finalize_type(*ret)?),
 | |
|             })),
 | |
|         };
 | |
|         ret
 | |
|     }
 | |
| 
 | |
|     fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> {
 | |
|         let mut res = &Type::Exist(tv);
 | |
|         loop {
 | |
|             match res {
 | |
|                 Type::Exist(tv) => {
 | |
|                     res = self.ctx.get(tv)?;
 | |
|                 }
 | |
|                 Type::Univ(tv) => {
 | |
|                     let ident = self.name_univ(*tv);
 | |
|                     if let Some(r) = self.instantiations.resolve(&ident) {
 | |
|                         res = r;
 | |
|                     } else {
 | |
|                         break Some(ast::Type::Var(ident));
 | |
|                     }
 | |
|                 }
 | |
|                 Type::Nullary(_) => todo!(),
 | |
|                 Type::Prim(pr) => break Some((*pr).into()),
 | |
|                 Type::Fun { args, ret } => todo!(),
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
 | |
|         match ast_type {
 | |
|             ast::Type::Int => INT,
 | |
|             ast::Type::Float => FLOAT,
 | |
|             ast::Type::Bool => BOOL,
 | |
|             ast::Type::CString => CSTRING,
 | |
|             ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
 | |
|                 args: args
 | |
|                     .into_iter()
 | |
|                     .map(|t| self.type_from_ast_type(t))
 | |
|                     .collect(),
 | |
|                 ret: Box::new(self.type_from_ast_type(*ret)),
 | |
|             },
 | |
|             ast::Type::Var(id) => Type::Univ({
 | |
|                 let opt_tv = { self.type_vars.borrow_mut().0.get_by_left(&id).copied() };
 | |
|                 opt_tv.unwrap_or_else(|| {
 | |
|                     let tv = self.fresh_tv();
 | |
|                     self.type_vars
 | |
|                         .borrow_mut()
 | |
|                         .0
 | |
|                         .insert_no_overwrite(id, tv)
 | |
|                         .unwrap();
 | |
|                     tv
 | |
|                 })
 | |
|             }),
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn name_univ(&self, tv: TyVar) -> Ident<'static> {
 | |
|         let mut vars = self.type_vars.borrow_mut();
 | |
|         vars.0
 | |
|             .get_by_right(&tv)
 | |
|             .map(Ident::to_owned)
 | |
|             .unwrap_or_else(|| {
 | |
|                 let name = vars.1.make_name();
 | |
|                 vars.0.insert_no_overwrite(name.clone(), tv).unwrap();
 | |
|                 name
 | |
|             })
 | |
|     }
 | |
| 
 | |
|     fn commit_instantiations(&mut self) {
 | |
|         let mut ctx = mem::take(&mut self.ctx);
 | |
|         for (_, v) in ctx.iter_mut() {
 | |
|             if let Type::Univ(tv) = v {
 | |
|                 if let Some(concrete) = self.instantiations.resolve(&self.name_univ(*tv)) {
 | |
|                     *v = concrete.clone();
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
|         self.ctx = ctx;
 | |
|         self.instantiations.pop();
 | |
|     }
 | |
| 
 | |
|     fn types_match(&self, type_: &Type, ast_type: &ast::Type<'ast>) -> bool {
 | |
|         match (type_, ast_type) {
 | |
|             (Type::Univ(u), ast::Type::Var(v)) => {
 | |
|                 Some(u) == self.type_vars.borrow().0.get_by_left(v)
 | |
|             }
 | |
|             (Type::Univ(_), _) => false,
 | |
|             (Type::Exist(_), _) => false,
 | |
|             (Type::Nullary(_), _) => todo!(),
 | |
|             (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
 | |
|             (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
 | |
|                 args.len() == ft.args.len()
 | |
|                     && args
 | |
|                         .iter()
 | |
|                         .zip(&ft.args)
 | |
|                         .all(|(a1, a2)| self.types_match(a1, &a2))
 | |
|                     && self.types_match(&*ret, &*ft.ret)
 | |
|             }
 | |
|             (Type::Fun { .. }, _) => false,
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
 | |
|     let mut typechecker = Typechecker::new();
 | |
|     let typechecked = typechecker.tc_expr(expr)?;
 | |
|     typechecker.finalize_expr(typechecked)
 | |
| }
 | |
| 
 | |
| pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
 | |
|     let mut typechecker = Typechecker::new();
 | |
|     decls
 | |
|         .into_iter()
 | |
|         .map(|decl| {
 | |
|             let hir_decl = typechecker.tc_decl(decl)?;
 | |
|             let res = typechecker.finalize_decl(hir_decl)?;
 | |
|             typechecker.ctx.clear();
 | |
|             Ok(res)
 | |
|         })
 | |
|         .try_collect()
 | |
| }
 | |
| 
 | |
#[cfg(test)]
mod tests {
    use super::*;

    // Parses `$expr` and `$type`, typechecks the expression, and asserts that
    // the inferred type is alpha-equivalent to the expected one. A type error
    // fails the test with the error's message.
    macro_rules! assert_type {
        ($expr: expr, $type: expr) => {
            use crate::parser::{expr, type_};
            let source = test_parse!(expr, $expr);
            let expected = test_parse!(type_, $type);
            let typed = match typecheck_expr(source) {
                Ok(typed) => typed,
                Err(e) => panic!("{}", e),
            };
            assert!(
                typed.type_().alpha_equiv(&expected),
                "{} inferred type {}, but expected {}",
                $expr,
                typed.type_(),
                $type
            );
        };
    }

    // Parses `$expr` and asserts that typechecking it fails.
    macro_rules! assert_type_error {
        ($expr: expr) => {
            use crate::parser::expr;
            let source = test_parse!(expr, $expr);
            if let Ok(typed) = typecheck_expr(source) {
                panic!("Expected type error, but got type: {}", typed.type_());
            }
        };
    }

    #[test]
    fn literal_int() {
        assert_type!("1", "int");
    }

    #[test]
    fn conditional() {
        assert_type!("if 1 == 2 then 3 else 4", "int");
    }

    // Ignored for now; expects arithmetic on booleans to be rejected.
    #[test]
    #[ignore]
    fn add_bools() {
        assert_type_error!("true + false");
    }

    #[test]
    fn call_generic_function() {
        assert_type!("(fn x = x) 1", "int");
    }

    #[test]
    fn generic_function() {
        assert_type!("fn x = x", "fn x -> x");
    }

    // Ignored for now; uses one let-bound function at two different types.
    #[test]
    #[ignore]
    fn let_generalization() {
        assert_type!("let id = fn x = x in if id true then id 1 else 2", "int");
    }

    #[test]
    fn concrete_function() {
        assert_type!("fn x = x + 1", "fn int -> int");
    }

    #[test]
    fn arg_ascriptions() {
        assert_type!("fn (x: int) = x", "fn int -> int");
    }

    #[test]
    fn call_concrete_function() {
        assert_type!("(fn x = x + 1) 2", "int");
    }

    #[test]
    fn conditional_non_bool() {
        assert_type_error!("if 3 then true else false");
    }

    #[test]
    fn let_int() {
        assert_type!("let x = 1 in x", "int");
    }
}
 |