diff --git a/midas/checker/types.py b/midas/checker/types.py index 1e2f149..d9d91b7 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -9,9 +9,9 @@ class BaseType: @dataclass(frozen=True, kw_only=True) -class SimpleType: +class AliasType: name: str - base: BaseType | SimpleType + type: Type @dataclass(frozen=True, kw_only=True) @@ -39,4 +39,16 @@ class Function: required: bool -Type = BaseType | SimpleType | UnknownType | UnitType | Function +@dataclass(frozen=True, kw_only=True) +class ComplexType: + properties: dict[str, Type] + + +@dataclass(frozen=True, kw_only=True) +class UnionType: + alternatives: list[Type] + + +Type = ( + BaseType | AliasType | UnknownType | UnitType | Function | ComplexType | UnionType +) diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py index d57ca70..ff97dbc 100644 --- a/midas/resolver/midas.py +++ b/midas/resolver/midas.py @@ -1,11 +1,15 @@ from typing import Optional import midas.ast.midas as m -from midas.checker.types import BaseType, SimpleType, Type +from midas.checker.types import ( + Type, + UnionType, + UnknownType, +) from midas.resolver.builtin import define_builtins -class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]): +class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): """A resolver which evaluates Midas type definitions and build a registry""" def __init__(self) -> None: @@ -94,20 +98,12 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]): for stmt in stmts: stmt.accept(self) - def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None: - # TODO generics, optional, constraint - base: Type = self.get_type(stmt.base.name.lexeme) - match base: - case BaseType() | SimpleType(): - type = SimpleType( - name=stmt.name.lexeme, - base=base, - ) - self.define_type(type.name, type) - case _: - raise TypeError(f"Invalid base {base} for simple type") - - def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ... + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + type: Type = stmt.type.accept(self) + for param in stmt.params: + if param.bound is not None: + param.bound.accept(self) + self.define_type(stmt.name.lexeme, type) def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... @@ -127,27 +123,44 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]): def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... - def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type: - return self.get_type(expr.name.lexeme) + def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... - def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ... + def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ... - def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ... + def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ... - def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ... + def visit_get_expr(self, expr: m.GetExpr) -> None: ... - def visit_get_expr(self, expr: m.GetExpr) -> Type: ... + def visit_variable_expr(self, expr: m.VariableExpr) -> None: ... - def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ... - - def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type: + def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: return expr.expr.accept(self) - def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ... + def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... - def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ... + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... - def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ... + def visit_named_type(self, type: m.NamedType) -> Type: + return self.get_type(type.name.lexeme) - def visit_type_expr(self, expr: m.TypeExpr) -> Type: - return self.get_type(expr.name.lexeme) + def visit_generic_type(self, type: m.GenericType) -> Type: + type_: Type = type.type.accept(self) + params: list[Type] = [param.accept(self) for param in type.params] + # TODO + return UnknownType() + + def visit_constraint_type(self, type: m.ConstraintType) -> Type: + type_: Type = type.type.accept(self) + type.constraint.accept(self) + # TODO + return UnknownType() + + def visit_union_type(self, type: m.UnionType) -> Type: + types: list[Type] = [type_.accept(self) for type_ in type.types] + return UnionType(alternatives=types) + + def visit_complex_type(self, type: m.ComplexType) -> Type: + for prop in type.properties: + prop.accept(self) + # TODO + return UnknownType()