feat(checker): add is_subtype method

This commit is contained in:
2026-06-06 16:30:04 +02:00
parent 1c30188122
commit 67c40a3909
2 changed files with 82 additions and 21 deletions

View File

@@ -0,0 +1,4 @@
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"float": {"int"},
"int": {"bool"},
}

View File

@@ -6,10 +6,19 @@ from typing import Optional
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.checker.types import (
AliasType,
BaseType,
ComplexType,
Function,
Type,
UnitType,
UnknownType,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@@ -168,6 +177,45 @@ class Checker(
stmts: list[m.Stmt] = parser.parse()
self.ctx.resolve(stmts)
def unfold_type(self, type: Type) -> Type:
match type:
case AliasType(type=ref_type):
return self.unfold_type(ref_type)
case _:
return type
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
For more details on the rules checked here, see TAPL Chap. 15-16-17
Args:
type1 (Type): the potential subtype
type2 (Type): the potential supertype
Returns:
bool: whether `type1` is a subtype of `type2`
"""
type1 = self.unfold_type(type1)
type2 = self.unfold_type(type2)
if type1 == type2:
return True
match (type1, type2):
case (BaseType(name=name1), BaseType(name=name2)):
return name1 in BUILTIN_SUBTYPES.get(name2, set())
case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items():
if k not in props1:
return False
if self.is_subtype(props1[k], t):
return False
return True
return False
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.type_of(stmt.expr)
@@ -266,7 +314,7 @@ class Checker(
self.env.define(stmt.name, type)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
value: Type = self.type_of(stmt.value)
value_type: Type = self.type_of(stmt.value)
for target in stmt.targets:
if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}")
@@ -276,13 +324,15 @@ class Checker(
var_type: Optional[Type] = self.look_up_variable(name, target)
if var_type is None:
self.env.define(name, value)
self.env.define(name, value_type)
else:
# TODO: implement real comparison method
if var_type != value:
# S <: T
# Γ, x: T v: S
# x = v
if not self.is_subtype(value_type, var_type):
self.error(
stmt.location,
f"Cannot assign {value} to {name} of type {var_type}",
f"Cannot assign {value_type} to {name} of type {var_type}",
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
@@ -357,7 +407,7 @@ class Checker(
function: Function = callee
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
for arg in mapped:
if arg.type != arg.argument.type:
if not self.is_subtype(arg.type, arg.argument.type):
self.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
@@ -386,13 +436,17 @@ class Checker(
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
left: Type = expr.left.accept(self)
right: Type = expr.right.accept(self)
# TODO: union type
if left != right:
self.error(
expr.location,
f"Operands must be of the same type, left={left} != right={right}",
)
return left
if self.is_subtype(left, right):
return right
if self.is_subtype(right, left):
return left
self.error(
expr.location,
f"Incompatible operand types, {left=} and {right=}",
)
return UnknownType()
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
@@ -410,13 +464,16 @@ class Checker(
true_type: Type = expr.if_true.accept(self)
false_type: Type = expr.if_false.accept(self)
if true_type != false_type:
self.error(
expr.location,
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
)
return UnknownType()
return true_type
if self.is_subtype(true_type, false_type):
return false_type
if self.is_subtype(false_type, true_type):
return true_type
self.error(
expr.location,
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
)
return UnknownType()
def visit_base_type(self, node: p.BaseType) -> Type:
return self.ctx.get_type(node.base)