use super::*;
use crate::types::*;
use std::sync::Arc;
impl Expr {
pub fn typecheck(self: &Arc<Self>, type_expected: Type, ctx: Context<Path, Type>) -> Result<(), TypeError> {
if let Some(type_actual) = self.typeinfer(ctx.clone()) {
if type_actual.equals(&type_expected) {
return Ok(());
} else {
return Err(TypeError::NotExpectedType(type_expected.clone(), type_actual.clone(), self.clone()));
}
}
let result = match (type_expected.clone(), &**self) {
(_type_expected, Expr::Reference(_span, _typ, _path)) => Err(TypeError::UndefinedReference(self.clone())),
(Type::Word(width_expected), Expr::Word(_span, _typ, width_actual, n)) => {
if let Some(width_actual) = width_actual {
if *width_actual == width_expected {
Err(TypeError::Other(self.clone(), format!("Not the expected width")))
} else if n >> *width_actual != 0 {
Err(TypeError::Other(self.clone(), format!("Doesn't fit")))
} else {
Ok(())
}
} else {
if n >> width_expected != 0 {
Err(TypeError::Other(self.clone(), format!("Doesn't fit")))
} else {
Ok(())
}
}
},
(_type_expected, Expr::Enum(_span, _typ, typedef, _name)) => {
if type_expected.equals(typedef) {
Ok(())
} else {
Err(TypeError::Other(self.clone(), format!("Type Error")))
}
},
(Type::Valid(_typ2), Expr::Ctor(_span, _typ, _name, es)) => {
if let Type::Valid(ref typ) = type_expected {
if es.len() == 1 {
es[0].typecheck(*typ.clone(), ctx.clone())
} else if es.len() > 1 {
Err(TypeError::Other(self.clone(), format!("Error")))
} else {
Ok(())
}
} else {
Err(TypeError::Other(self.clone(), format!("Not a Valid<T>: {self:?} is not {type_expected:?}")))
}
},
(Type::Alt(typedef, params), Expr::Ctor(_span, _typ, name, es)) => {
if let Some(typs) = typedef.alt(name) {
if es.len() == typs.len() {
for (e, typ) in es.iter().zip(typs.iter()) {
e.typecheck(typ.clone(), ctx.clone())?;
}
Ok(())
} else {
Err(TypeError::Other(self.clone(), format!("Invalid alt type")))
}
} else {
Err(TypeError::Other(self.clone(), format!("Invalid alt type")))
}
},
(Type::Struct(typedef), Expr::Struct(_span, _typ, fields)) => {
for (name, e) in fields {
if let Some(typ) = typedef.type_of_field(name) {
e.typecheck(typ, ctx.clone())?;
} else {
let typename = &typedef.name;
return Err(TypeError::Other(self.clone(), format!("struct type {typename} has no field {name}")))
}
}
Ok(())
},
(_type_expected, Expr::Let(_span, _typ, name, ascription, e, b)) => {
if let Some(typ) = ascription {
e.typecheck(typ.clone(), ctx.clone())?;
b.typecheck(type_expected.clone(), ctx.extend(name.clone().into(), typ.clone()))
} else if let Some(typ) = e.typeinfer(ctx.clone()) {
b.typecheck(type_expected.clone(), ctx.extend(name.clone().into(), typ))
} else {
Err(TypeError::Other(self.clone(), format!("Can infer type of {e:?} in let expression.")))
}
},
(_type_expected, Expr::Match(_span, _typ, subject, arms)) => {
if let Some(subject_typ) = subject.typeinfer(ctx.clone()) {
let invalid_arms: Vec<&MatchArm> = arms.into_iter().filter(|MatchArm(pat, _e)| !subject_typ.valid_pat(pat)).collect();
if invalid_arms.len() > 0 {
return Err(TypeError::Other(self.clone(), format!("Invalid patterns for {subject_typ:?}")));
}
for MatchArm(pat, e) in arms {
let new_ctx = subject_typ.extend_context_for_pat(ctx.clone(), pat);
e.typecheck(type_expected.clone(), new_ctx)?;
}
} else {
return Err(TypeError::Other(self.clone(), format!("Match: Can't infer subject type")));
}
Ok(())
},
(_type_expected, Expr::UnOp(_span, _typ, UnOp::Not, e)) => e.typecheck(type_expected.clone(), ctx.clone()),
(Type::Word(1), Expr::BinOp(_span, _typ, BinOp::Eq | BinOp::Neq | BinOp::Lt, e1, e2)) => {
if let Some(typ1) = e1.typeinfer(ctx.clone()) {
e2.typecheck(typ1, ctx.clone())?;
Ok(())
} else {
Err(TypeError::Other(self.clone(), format!("Can't infer type.")))
}
},
(Type::Word(_n), Expr::BinOp(_span, _typ, BinOp::Add | BinOp::Sub | BinOp::And | BinOp::Or | BinOp::Xor, e1, e2)) => {
e1.typecheck(type_expected.clone(), ctx.clone())?;
e2.typecheck(type_expected.clone(), ctx.clone())?;
Ok(())
},
(Type::Word(n), Expr::BinOp(_span, _typ, BinOp::AddCarry, e1, e2)) => {
if let (Some(typ1), Some(typ2)) = (e1.typeinfer(ctx.clone()), e2.typeinfer(ctx.clone())) {
if n > 0 && typ1.equals(&typ2) && typ1.equals(&Type::Word(n - 1)) {
Ok(())
} else {
Err(TypeError::Other(self.clone(), format!("Types don't match")))
}
} else {
Err(TypeError::Other(self.clone(), format!("Can't infer type.")))
}
},
(_type_expected, Expr::If(_span, _typ, cond, e1, e2)) => {
cond.typecheck(Type::word(1), ctx.clone())?;
e1.typecheck(type_expected.clone(), ctx.clone())?;
e2.typecheck(type_expected.clone(), ctx.clone())?;
Ok(())
},
(_type_expected, Expr::Mux(_span, _typ, cond, e1, e2)) => {
cond.typecheck(Type::word(1), ctx.clone())?;
e1.typecheck(type_expected.clone(), ctx.clone())?;
e2.typecheck(type_expected.clone(), ctx.clone())?;
Ok(())
},
(Type::Word(width_expected), Expr::Sext(_span, _typ, e)) => {
if let Some(type_actual) = e.typeinfer(ctx.clone()) {
if let Type::Word(m) = type_actual {
if m == 0 {
Err(TypeError::Other(self.clone(), format!("Can't sext a Word[0]")))
} else if width_expected >= m {
Ok(())
} else {
Err(TypeError::Other(self.clone(), format!("Can't sext a Word[{m}] to a a Word[{width_expected}]")))
}
} else {
Err(TypeError::Other(self.clone(), format!("Unknown?")))
}
} else {
Err(TypeError::CantInferType(self.clone()))
}
},
(Type::Word(width_expected), Expr::Zext(_span, _typ, e)) => {
if let Some(type_actual) = e.typeinfer(ctx.clone()) {
if let Type::Word(m) = type_actual {
if width_expected >= m {
Ok(())
} else {
Err(TypeError::Other(self.clone(), format!("Can't zext a Word[{m}] to a a Word[{width_expected}]")))
}
} else {
Err(TypeError::Other(self.clone(), format!("Unknown?")))
}
} else {
Err(TypeError::CantInferType(self.clone()))
}
},
(Type::Valid(inner_type), Expr::TryCast(_span, _typ, e)) => {
if let Type::Enum(typedef) = &*inner_type {
let w = typedef.bitwidth();
e.typecheck(Type::Word(w), ctx.clone())?;
Ok(())
} else {
Err(TypeError::Other(self.clone(), format!("trycast(e) has type Valid<T> for an enum type T")))
}
},
(Type::Word(n), Expr::ToWord(_span, _typ, e)) => {
let typ = e.typeinfer(ctx.clone()).unwrap();
if let Type::Enum(typedef) = typ {
let width = typedef.bitwidth();
if n == width {
Ok(())
} else {
let name = &typedef.name;
Err(TypeError::Other(self.clone(), format!("enum type {name} has bitwidth {width} which cannot be cast to Word[{n}]")))
}
} else {
unreachable!()
}
},
(Type::Vec(typ, n), Expr::Vec(_span, _typ, es)) => {
for e in es {
e.typecheck(*typ.clone(), ctx.clone())?;
}
if es.len() != n as usize {
let type_actual = Type::vec(*typ.clone(), es.len().try_into().unwrap());
Err(TypeError::NotExpectedType(type_expected.clone(), type_actual.clone(), self.clone()))
} else {
Ok(())
}
},
(_type_expected, Expr::Call(_span, _typ, fndef, es)) => {
if fndef.args.len() != es.len() {
let fn_name = &fndef.name;
let m = fndef.args.len();
let n = es.len();
Err(TypeError::Other(self.clone(), format!("{fn_name} takes {n} args, but found {m} instead")))
} else {
let mut errors = vec![];
for ((arg_name, arg_typ), e) in fndef.args.iter().zip(es.iter()) {
if let Err(error) = e.typecheck(Type::clone(&arg_typ), ctx.clone()) {
errors.push(error);
}
}
if errors.len() > 0 {
return Err(errors[0].clone());
}
Ok(())
}
},
(_type_expected, Expr::Hole(_span, _typ, _opt_name)) => Ok(()),
_ => Err(TypeError::Other(self.clone(), format!("{self:?} is not the expected type {type_expected:?}"))),
};
if let Ok(()) = &result {
self.annotate_type(type_expected.clone());
}
result
}
/// Tries to synthesize the type of this expression under the bindings in `ctx`.
///
/// Returns `None` when the expression form is not inferable here or when one of its parts
/// fails to infer or check. When a type is produced, the expression is annotated with it.
///
/// # Panics
///
/// Panics when called on an `Expr::Net`.
pub fn typeinfer(&self, ctx: Context<Path, Type>) -> Option<Type> {
    let inferred = match self {
        Expr::Reference(_span, _typ, path) => ctx.lookup(path),
        Expr::Net(_span, _typ, _netid) => panic!("Can't typecheck a net"),
        // A literal without an explicit width can only be checked, not inferred.
        Expr::Word(_span, _typ, None, _n) => None,
        // A literal with a width is inferable only when the value fits in that width.
        Expr::Word(_span, _typ, Some(w), n) => (n >> w == 0).then(|| Type::word(*w)),
        Expr::Enum(_span, _typ, typedef, _name) => Some(typedef.clone()),
        Expr::BinOp(_span, _typ, BinOp::Eq | BinOp::Neq | BinOp::Lt, e1, e2) => {
            // The left operand fixes the type the right operand is checked against.
            e1.typeinfer(ctx.clone()).and_then(|lhs_type| {
                e2.typecheck(lhs_type, ctx.clone()).ok().map(|()| Type::Word(1))
            })
        },
        Expr::Cat(_span, _typ, es) => {
            // The width of a concatenation is the sum of its parts, all of which must be words.
            es.iter()
                .try_fold(0u64, |total, e| match e.typeinfer(ctx.clone()) {
                    Some(Type::Word(m)) => Some(total + m),
                    _ => None,
                })
                .map(|total| Type::word(total))
        },
        Expr::Call(_span, _typ, fndef, es) => {
            if fndef.args.len() == es.len() {
                // Every argument is checked, even after an earlier one has failed.
                let mut all_args_ok = true;
                for ((_arg_name, arg_typ), e) in fndef.args.iter().zip(es.iter()) {
                    if e.typecheck(Type::clone(&arg_typ), ctx.clone()).is_err() {
                        all_args_ok = false;
                    }
                }
                if all_args_ok {
                    Some(fndef.ret.clone())
                } else {
                    None
                }
            } else {
                None
            }
        },
        Expr::IdxField(_span, _typ, e, field) => match e.typeinfer(ctx.clone()) {
            Some(Type::Struct(typedef)) => typedef.type_of_field(field).map(|field_type| field_type.clone()),
            _ => None,
        },
        Expr::Idx(_span, _typ, e, i) => match e.typeinfer(ctx.clone()) {
            // Indexing a word yields one bit, provided the index is in range.
            Some(Type::Word(n)) => (*i < n).then(|| Type::word(1)),
            _ => None,
        },
        Expr::IdxRange(_span, _typ, e, j, i) => match e.typeinfer(ctx.clone()) {
            // The slice is `j - i` bits wide; the bounds must be ordered and lie within the word.
            Some(Type::Word(n)) => (n >= *j && *j >= *i).then(|| Type::word(*j - *i)),
            _ => None,
        },
        _ => None,
    };

    if let Some(type_actual) = &inferred {
        self.annotate_type(type_actual.clone());
    }
    inferred
}
/// Stores `typ` in this expression's type cell, when the expression has one.
///
/// The outcome of the store is deliberately discarded.
/// NOTE(review): presumably the cell is write-once, making a repeated annotation a no-op —
/// confirm against `type_of_cell`.
fn annotate_type(&self, typ: Type) {
    let cell = match self.type_of_cell() {
        Some(cell) => cell,
        None => return,
    };
    let _ = cell.set(typ);
}
}
impl Type {
/// Reports whether `pat` is a well-formed pattern for a value of this type.
///
/// Binding and wildcard patterns are valid for every type. A constructor pattern is valid
/// only against a `Valid`, `Enum` or `Alt` type:
///
/// * `Valid`: `Invalid` with no sub-patterns, or `Valid` with exactly one sub-pattern that is
///   itself valid for the inner type.
/// * `Enum`: the constructor must name one of the enum's values.
/// * `Alt`: the constructor must name an alternative, with one sub-pattern per payload type,
///   each valid for its payload type.
///
/// `extend_context_for_pat` relies on these guarantees: it asserts the `Alt` arity and treats
/// any other combination as unreachable, so arity and nesting must be rejected here.
fn valid_pat(&self, pat: &Pat) -> bool {
    match pat {
        Pat::At(ctor, subpats) => match self {
            Type::Valid(inner_type) => {
                if ctor == "Invalid" && subpats.len() == 0 {
                    true
                } else if ctor == "Valid" && subpats.len() == 1 {
                    inner_type.valid_pat(&subpats[0])
                } else {
                    false
                }
            },
            // NOTE(review): sub-patterns on an enum constructor are not rejected — confirm
            // whether they should be.
            Type::Enum(typedef) => typedef.values.iter().any(|(name, _val)| name == ctor),
            Type::Alt(typedef, _params) => match typedef.alt(ctor) {
                Some(typs) => {
                    subpats.len() == typs.len()
                        && subpats.iter().zip(typs.iter()).all(|(subpat, typ)| typ.valid_pat(subpat))
                },
                None => false,
            },
            _ => false,
        },
        Pat::Bind(_x) => true,
        Pat::Otherwise => true,
    }
}
/// Returns `ctx` extended with the variables that `pat` binds when matched against a value
/// of this type.
///
/// A binding pattern adds its name at this type; a wildcard adds nothing. A constructor
/// pattern recurses into its sub-patterns using the payload types of the constructor.
///
/// # Panics
///
/// Panics when a constructor pattern does not fit this type: an unknown constructor, a
/// constructor applied to a type other than `Valid`, `Enum` or `Alt`, or an `Alt` constructor
/// whose number of sub-patterns differs from its number of payload types.
fn extend_context_for_pat(&self, ctx: Context<Path, Type>, pat: &Pat) -> Context<Path, Type> {
    match pat {
        Pat::Bind(x) => ctx.extend(x.clone().into(), self.clone()),
        Pat::Otherwise => ctx,
        Pat::At(ctor, subpats) => match self {
            Type::Valid(inner_type) => {
                if ctor == "Valid" && subpats.len() == 1 {
                    inner_type.extend_context_for_pat(ctx, &subpats[0])
                } else if ctor == "Invalid" {
                    ctx
                } else {
                    unreachable!()
                }
            },
            // Enum constructors carry no payload, so nothing is bound.
            Type::Enum(_typedef) => ctx,
            Type::Alt(typedef, _params) => match typedef.alt(ctor) {
                Some(typs) => {
                    assert_eq!(subpats.len(), typs.len());
                    // Thread the context through the sub-patterns from left to right.
                    subpats
                        .iter()
                        .zip(typs.iter())
                        .fold(ctx, |acc, (subpat, typ)| typ.extend_context_for_pat(acc, subpat))
                },
                None => unreachable!(),
            },
            _ => unreachable!(),
        },
    }
}
}