diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py new file mode 100644 index 0000000..bc80084 --- /dev/null +++ b/midas/checker/builtins.py @@ -0,0 +1,4 @@ +BUILTIN_SUBTYPES: dict[str, set[str]] = { + "float": {"int"}, + "int": {"bool"}, +} diff --git a/midas/checker/checker.py b/midas/checker/checker.py index 8ca9b33..aff3782 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -6,10 +6,19 @@ from typing import Optional import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location +from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.environment import Environment from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS -from midas.checker.types import Function, Type, UnitType, UnknownType +from midas.checker.types import ( + AliasType, + BaseType, + ComplexType, + Function, + Type, + UnitType, + UnknownType, +) from midas.lexer.midas import MidasLexer from midas.lexer.token import Token from midas.parser.midas import MidasParser @@ -168,6 +177,45 @@ class Checker( stmts: list[m.Stmt] = parser.parse() self.ctx.resolve(stmts) + def unfold_type(self, type: Type) -> Type: + match type: + case AliasType(type=ref_type): + return self.unfold_type(ref_type) + case _: + return type + + def is_subtype(self, type1: Type, type2: Type) -> bool: + """Check whether `type1` is a subtype of `type2` + + For more details on the rules checked here, see TAPL Chap. 15-16-17 + + Args: + type1 (Type): the potential subtype + type2 (Type): the potential supertype + + Returns: + bool: whether `type1` is a subtype of `type2` + """ + + type1 = self.unfold_type(type1) + type2 = self.unfold_type(type2) + + if type1 == type2: + return True + + match (type1, type2): + case (BaseType(name=name1), BaseType(name=name2)): + return name1 in BUILTIN_SUBTYPES.get(name2, set()) + + case (ComplexType(properties=props1), ComplexType(properties=props2)): + for k, t in props2.items(): + if k not in props1: + return False + if self.is_subtype(props1[k], t): + return False + return True + + return False def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: self.type_of(stmt.expr) @@ -266,7 +314,7 @@ class Checker( self.env.define(stmt.name, type) def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: - value: Type = self.type_of(stmt.value) + value_type: Type = self.type_of(stmt.value) for target in stmt.targets: if not isinstance(target, p.VariableExpr): self.logger.warning(f"Unsupported assignment to {target}") @@ -276,13 +324,15 @@ class Checker( var_type: Optional[Type] = self.look_up_variable(name, target) if var_type is None: - self.env.define(name, value) + self.env.define(name, value_type) else: - # TODO: implement real comparison method - if var_type != value: + # S <: T + # Γ, x: T v: S + # x = v + if not self.is_subtype(value_type, var_type): self.error( stmt.location, - f"Cannot assign {value} to {name} of type {var_type}", + f"Cannot assign {value_type} to {name} of type {var_type}", ) def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: @@ -357,7 +407,7 @@ class Checker( function: Function = callee mapped: list[MappedArgument] = self.map_call_arguments(function, expr) for arg in mapped: - if arg.type != arg.argument.type: + if not self.is_subtype(arg.type, arg.argument.type): self.error( arg.expr.location, f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", @@ -386,13 +436,17 @@ class Checker( def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: left: Type = expr.left.accept(self) right: Type = expr.right.accept(self) - # TODO: union type - if left != right: - self.error( - expr.location, - f"Operands must be of the same type, left={left} != right={right}", - ) - return left + + if self.is_subtype(left, right): + return right + if self.is_subtype(right, left): + return left + + self.error( + expr.location, + f"Incompatible operand types, {left=} and {right=}", + ) + return UnknownType() def visit_set_expr(self, expr: p.SetExpr) -> Type: ... @@ -410,13 +464,16 @@ class Checker( true_type: Type = expr.if_true.accept(self) false_type: Type = expr.if_false.accept(self) - if true_type != false_type: - self.error( - expr.location, - f"Type mismatch in ternary if branches: true={true_type} != false={false_type}", - ) - return UnknownType() - return true_type + if self.is_subtype(true_type, false_type): + return false_type + if self.is_subtype(false_type, true_type): + return true_type + + self.error( + expr.location, + f"Incompatible types in ternary if branches: true={true_type} and false={false_type}", + ) + return UnknownType() def visit_base_type(self, node: p.BaseType) -> Type: return self.ctx.get_type(node.base)