feat(checker): add is_subtype method
This commit is contained in:
4
midas/checker/builtins.py
Normal file
4
midas/checker/builtins.py
Normal file
@@ -0,0 +1,4 @@
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
@@ -6,10 +6,19 @@ from typing import Optional
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
Function,
|
||||
Type,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
@@ -168,6 +177,45 @@ class Checker(
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
self.ctx.resolve(stmts)
|
||||
|
||||
def unfold_type(self, type: Type) -> Type:
|
||||
match type:
|
||||
case AliasType(type=ref_type):
|
||||
return self.unfold_type(ref_type)
|
||||
case _:
|
||||
return type
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
||||
|
||||
Args:
|
||||
type1 (Type): the potential subtype
|
||||
type2 (Type): the potential supertype
|
||||
|
||||
Returns:
|
||||
bool: whether `type1` is a subtype of `type2`
|
||||
"""
|
||||
|
||||
type1 = self.unfold_type(type1)
|
||||
type2 = self.unfold_type(type2)
|
||||
|
||||
if type1 == type2:
|
||||
return True
|
||||
|
||||
match (type1, type2):
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
if k not in props1:
|
||||
return False
|
||||
if self.is_subtype(props1[k], t):
|
||||
return False
|
||||
return True
|
||||
|
||||
return False
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.type_of(stmt.expr)
|
||||
|
||||
@@ -266,7 +314,7 @@ class Checker(
|
||||
self.env.define(stmt.name, type)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
value: Type = self.type_of(stmt.value)
|
||||
value_type: Type = self.type_of(stmt.value)
|
||||
for target in stmt.targets:
|
||||
if not isinstance(target, p.VariableExpr):
|
||||
self.logger.warning(f"Unsupported assignment to {target}")
|
||||
@@ -276,13 +324,15 @@ class Checker(
|
||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||
|
||||
if var_type is None:
|
||||
self.env.define(name, value)
|
||||
self.env.define(name, value_type)
|
||||
else:
|
||||
# TODO: implement real comparison method
|
||||
if var_type != value:
|
||||
# S <: T
|
||||
# Γ, x: T v: S
|
||||
# x = v
|
||||
if not self.is_subtype(value_type, var_type):
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Cannot assign {value} to {name} of type {var_type}",
|
||||
f"Cannot assign {value_type} to {name} of type {var_type}",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
@@ -357,7 +407,7 @@ class Checker(
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if arg.type != arg.argument.type:
|
||||
if not self.is_subtype(arg.type, arg.argument.type):
|
||||
self.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
@@ -386,13 +436,17 @@ class Checker(
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||
left: Type = expr.left.accept(self)
|
||||
right: Type = expr.right.accept(self)
|
||||
# TODO: union type
|
||||
if left != right:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Operands must be of the same type, left={left} != right={right}",
|
||||
)
|
||||
return left
|
||||
|
||||
if self.is_subtype(left, right):
|
||||
return right
|
||||
if self.is_subtype(right, left):
|
||||
return left
|
||||
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Incompatible operand types, {left=} and {right=}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
||||
|
||||
@@ -410,13 +464,16 @@ class Checker(
|
||||
|
||||
true_type: Type = expr.if_true.accept(self)
|
||||
false_type: Type = expr.if_false.accept(self)
|
||||
if true_type != false_type:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
|
||||
)
|
||||
return UnknownType()
|
||||
return true_type
|
||||
if self.is_subtype(true_type, false_type):
|
||||
return false_type
|
||||
if self.is_subtype(false_type, true_type):
|
||||
return true_type
|
||||
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||
return self.ctx.get_type(node.base)
|
||||
|
||||
Reference in New Issue
Block a user