use super::Context;
use super::Expr;
use super::Path;
use super::loc::Span;
use std::sync::Arc;
pub use crate::ast::Kind;
pub use crate::ast::WordLit;

/// Bit width of a word type, measured in bits.
pub type Width = u64;

/// Number of elements in a `Vec` type.
pub type Length = u64;
/// A type in this language.
///
/// `Word`, `Vec` and `Valid` are built-in type constructors. `Enum`, `Struct` and `Alt`
/// refer to type definitions held behind an `Arc`; `Type::equals` compares those by
/// pointer identity.
#[derive(Clone)]
pub enum Type {
    /// A word of the given bit width.
    Word(Width),
    /// `Length` elements of the inner type; its bit width is element width times length.
    Vec(Box<Self>, Length),
    /// The inner type plus one extra bit of width.
    Valid(Box<Self>),
    /// An enum type, identified by its shared definition.
    Enum(Arc<EnumTypeDef>),
    /// A struct type, identified by its shared definition.
    Struct(Arc<StructTypeDef>),
    /// An alt type, identified by its shared definition, applied to the given parameters.
    Alt(Arc<AltTypeDef>, Vec<TypeParam>),
}
impl Type {
    /// Returns a short name for this type.
    ///
    /// Built-in constructors report just the constructor name (`"Word"`, `"Vec"`,
    /// `"Valid"`), without their parameters. Enum, struct and alt types report the name
    /// stored in their definition.
    pub fn name(&self) -> &str {
        match self {
            Type::Word(_width) => "Word",
            Type::Vec(_typ, _length) => "Vec",
            Type::Valid(_typ) => "Valid",
            Type::Enum(typedef) => &typedef.name,
            Type::Struct(typedef) => &typedef.name,
            Type::Alt(typedef, _params) => &typedef.name,
        }
    }

    /// Type equality.
    ///
    /// `Word`, `Vec` and `Valid` compare their parameters structurally (recursively).
    /// `Enum`, `Struct` and `Alt` compare their definitions by pointer (`Arc::ptr_eq`).
    /// Two separately allocated definitions with identical contents are therefore *not*
    /// equal.
    ///
    /// `Alt` types additionally require the same number of type parameters, matching
    /// pairwise. A `Nat` parameter never equals a `Type` parameter.
    #[rustfmt::skip]
    pub fn equals(&self, other: &Type) -> bool {
        match (self, other) {
            (Type::Word(width1),     Type::Word(width2))     => width1 == width2,
            (Type::Vec(typ1, len1),  Type::Vec(typ2, len2))  => len1 == len2 && typ1.equals(typ2),
            (Type::Valid(typ1),      Type::Valid(typ2))      => typ1.equals(typ2),
            (Type::Enum(typedef1),   Type::Enum(typedef2))   => Arc::ptr_eq(typedef1, typedef2),
            (Type::Struct(typedef1), Type::Struct(typedef2)) => Arc::ptr_eq(typedef1, typedef2),
            (Type::Alt(typedef1, params1), Type::Alt(typedef2, params2)) => {
                Arc::ptr_eq(typedef1, typedef2)
                    && params1.len() == params2.len()
                    && params1.iter().zip(params2.iter()).all(|pair| match pair {
                        (TypeParam::Nat(n1),    TypeParam::Nat(n2))    => n1 == n2,
                        (TypeParam::Type(typ1), TypeParam::Type(typ2)) => typ1.equals(typ2),
                        // Parameters of different kinds can never be equal. The previous
                        // catch-all silently accepted e.g. `Nat(8)` vs `Type(Word[8])`.
                        _ => false,
                    })
            },
            _ => false,
        }
    }

    /// Constructs `Word[w]`.
    pub fn word(w: Width) -> Type {
        Type::Word(w)
    }

    /// Constructs `Vec[typ, n]`.
    pub fn vec(typ: Type, n: Length) -> Type {
        Type::Vec(Box::new(typ), n)
    }

    /// Constructs `Valid[typ]`.
    pub fn valid(typ: Type) -> Type {
        Type::Valid(Box::new(typ))
    }

    /// Returns the number of bits occupied by a value of this type.
    ///
    /// `Valid[T]` is one bit wider than `T`. `Vec[T, n]` is `n` times the width of `T`.
    /// Enum and struct widths come from their definitions' own `bitwidth` methods.
    ///
    /// # Panics
    ///
    /// Panics for `Alt` types, whose bit width is not implemented yet.
    pub fn bitwidth(&self) -> Width {
        match self {
            Type::Word(n) => *n,
            Type::Valid(typ) => typ.bitwidth() + 1,
            Type::Vec(typ, n) => typ.bitwidth() * n,
            Type::Enum(typedef) => typedef.bitwidth(),
            Type::Struct(typedef) => typedef.bitwidth(),
            Type::Alt(typedef, _params) => todo!("bitwidth of alt type {}", typedef.name),
        }
    }
}
/// A parameter applied to an alt type (see `Type::Alt`).
#[derive(Clone)]
pub enum TypeParam {
    /// A natural-number parameter.
    Nat(u64),
    /// A type parameter.
    Type(Type),
}
/// An enum type definition.
#[derive(Debug, Clone)]
pub struct EnumTypeDef {
    /// Name of the enum type.
    pub name: String,
    /// Each enum value's name paired with its word literal.
    pub values: Vec<(String, WordLit)>,
    /// Source location (`loc::Span`) associated with this definition.
    pub span: Span,
}
/// A struct type definition.
#[derive(Debug, Clone)]
pub struct StructTypeDef {
    /// Name of the struct type.
    pub name: String,
    /// Each field's name paired with its type.
    pub fields: Vec<(String, Type)>,
    /// Source location (`loc::Span`) associated with this definition.
    pub span: Span,
}
/// An alt type definition: a named set of alternatives.
#[derive(Debug, Clone)]
pub struct AltTypeDef {
    /// Name of the alt type.
    pub name: String,
    /// Each alternative's name paired with its list of types.
    pub alts: Vec<(String, Vec<Type>)>,
    /// Source location (`loc::Span`) associated with this definition.
    pub span: Span,
}
impl AltTypeDef {
    /// Returns a copy of the type list of the alternative named `name`.
    ///
    /// Returns `None` when no alternative has that name. If several alternatives share
    /// the name, the first one is used.
    pub fn alt(&self, name: &str) -> Option<Vec<Type>> {
        self.alts
            .iter()
            .find(|(alt_name, _types)| alt_name == name)
            .map(|(_alt_name, types)| types.clone())
    }
}
/// A function definition.
#[derive(Debug, Clone)]
pub struct FnDef {
    /// Source location (`loc::Span`) associated with this definition.
    pub span: Span,
    /// Name of the function.
    pub name: String,
    /// Type arguments, each a name paired with its `Kind`.
    pub type_args: Vec<(String, Kind)>,
    /// Value arguments, each a name paired with its declared type.
    pub args: Vec<(String, Type)>,
    /// Declared return type.
    pub ret: Type,
    /// The function body expression.
    pub body: Arc<Expr>,
}
impl FnDef {
    /// Builds a context that binds each argument name, converted to a `Path`, to its
    /// declared type.
    pub fn context(&self) -> Context<Path, Type> {
        let bindings = self
            .args
            .iter()
            .map(|(name, typ)| (name.clone().into(), typ.clone()))
            .collect::<Vec<_>>();
        Context::from(bindings)
    }
}
impl StructTypeDef {
    /// Total bit width of the struct: the sum of the bit widths of all its fields.
    pub fn bitwidth(&self) -> Width {
        let mut total: Width = 0;
        for (_field_name, field_type) in &self.fields {
            total += field_type.bitwidth();
        }
        total
    }

    /// Returns a copy of the type of the field named `fieldname`.
    ///
    /// Returns `None` when the struct has no such field. If several fields share the
    /// name, the first one is used.
    pub fn type_of_field(&self, fieldname: &str) -> Option<Type> {
        self.fields
            .iter()
            .find_map(|(name, typ)| (name == fieldname).then(|| typ.clone()))
    }
}
impl EnumTypeDef {
    /// Returns the numeric value of the enum value named `name`.
    ///
    /// Returns `None` when the enum has no value with that name. If several values share
    /// the name, the first one is used.
    pub fn value_of(&self, name: &str) -> Option<u64> {
        for (other_name, WordLit(_w, value)) in &self.values {
            if name == other_name {
                return Some(*value);
            }
        }
        None
    }

    /// Returns the bit width shared by all of this enum's explicitly-sized values.
    ///
    /// Values whose literal carries no explicit width are skipped.
    ///
    /// # Panics
    ///
    /// Panics if two values declare different widths, or if no value declares a width.
    pub fn bitwidth(&self) -> Width {
        // Every explicitly-sized value must agree on one width (it is not a maximum).
        let mut width: Option<Width> = None;
        for (value_name, value) in &self.values {
            if let WordLit(Some(w), _n) = value {
                if let Some(expected) = width {
                    assert_eq!(
                        *w, expected,
                        "enum {}: value {} has width {} but earlier values have width {}",
                        self.name, value_name, w, expected
                    );
                } else {
                    width = Some(*w);
                }
            }
        }
        width.unwrap_or_else(|| {
            panic!("enum {} has no value with an explicit bit width", self.name)
        })
    }
}
impl std::fmt::Debug for Type {
    /// Formats the type in bracketed form, e.g. `Word[8]`, `Valid[Word[8]]`,
    /// `Vec[Word[8], 4]`. Named types print their definition's name. Alt types with
    /// parameters print as `Name[p1, p2, ...]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        match self {
            Type::Word(n) => write!(f, "Word[{n}]"),
            Type::Valid(typ) => write!(f, "Valid[{typ:?}]"),
            Type::Vec(typ, n) => write!(f, "Vec[{typ:?}, {n}]"),
            Type::Struct(typedef) => write!(f, "{}", typedef.name),
            Type::Enum(typedef) => write!(f, "{}", typedef.name),
            Type::Alt(typedef, params) if params.is_empty() => write!(f, "{}", typedef.name),
            Type::Alt(typedef, params) => {
                let joined = params
                    .iter()
                    .map(|param| param.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");
                // Plain `{}`: the joined list is already formatted text. Using `{:?}` here
                // would print it as a quoted, escaped string literal (`Foo["8, Word[4]"]`).
                write!(f, "{}[{}]", typedef.name, joined)
            },
        }
    }
}

impl std::fmt::Display for TypeParam {
    /// Formats a natural-number parameter as its number and a type parameter using the
    /// type's `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        match self {
            TypeParam::Nat(n) => write!(f, "{n}"),
            TypeParam::Type(typ) => write!(f, "{typ:?}"),
        }
    }
}