feat(checker): add is_subtype method
This commit is contained in:
4
midas/checker/builtins.py
Normal file
4
midas/checker/builtins.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||||
|
"float": {"int"},
|
||||||
|
"int": {"bool"},
|
||||||
|
}
|
||||||
@@ -6,10 +6,19 @@ from typing import Optional
|
|||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
BaseType,
|
||||||
|
ComplexType,
|
||||||
|
Function,
|
||||||
|
Type,
|
||||||
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
from midas.lexer.token import Token
|
from midas.lexer.token import Token
|
||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
@@ -168,6 +177,45 @@ class Checker(
|
|||||||
stmts: list[m.Stmt] = parser.parse()
|
stmts: list[m.Stmt] = parser.parse()
|
||||||
self.ctx.resolve(stmts)
|
self.ctx.resolve(stmts)
|
||||||
|
|
||||||
|
def unfold_type(self, type: Type) -> Type:
|
||||||
|
match type:
|
||||||
|
case AliasType(type=ref_type):
|
||||||
|
return self.unfold_type(ref_type)
|
||||||
|
case _:
|
||||||
|
return type
|
||||||
|
|
||||||
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
|
|
||||||
|
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type1 (Type): the potential subtype
|
||||||
|
type2 (Type): the potential supertype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether `type1` is a subtype of `type2`
|
||||||
|
"""
|
||||||
|
|
||||||
|
type1 = self.unfold_type(type1)
|
||||||
|
type2 = self.unfold_type(type2)
|
||||||
|
|
||||||
|
if type1 == type2:
|
||||||
|
return True
|
||||||
|
|
||||||
|
match (type1, type2):
|
||||||
|
case (BaseType(name=name1), BaseType(name=name2)):
|
||||||
|
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||||
|
|
||||||
|
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||||
|
for k, t in props2.items():
|
||||||
|
if k not in props1:
|
||||||
|
return False
|
||||||
|
if self.is_subtype(props1[k], t):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
self.type_of(stmt.expr)
|
self.type_of(stmt.expr)
|
||||||
|
|
||||||
@@ -266,7 +314,7 @@ class Checker(
|
|||||||
self.env.define(stmt.name, type)
|
self.env.define(stmt.name, type)
|
||||||
|
|
||||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
value: Type = self.type_of(stmt.value)
|
value_type: Type = self.type_of(stmt.value)
|
||||||
for target in stmt.targets:
|
for target in stmt.targets:
|
||||||
if not isinstance(target, p.VariableExpr):
|
if not isinstance(target, p.VariableExpr):
|
||||||
self.logger.warning(f"Unsupported assignment to {target}")
|
self.logger.warning(f"Unsupported assignment to {target}")
|
||||||
@@ -276,13 +324,15 @@ class Checker(
|
|||||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||||
|
|
||||||
if var_type is None:
|
if var_type is None:
|
||||||
self.env.define(name, value)
|
self.env.define(name, value_type)
|
||||||
else:
|
else:
|
||||||
# TODO: implement real comparison method
|
# S <: T
|
||||||
if var_type != value:
|
# Γ, x: T v: S
|
||||||
|
# x = v
|
||||||
|
if not self.is_subtype(value_type, var_type):
|
||||||
self.error(
|
self.error(
|
||||||
stmt.location,
|
stmt.location,
|
||||||
f"Cannot assign {value} to {name} of type {var_type}",
|
f"Cannot assign {value_type} to {name} of type {var_type}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
@@ -357,7 +407,7 @@ class Checker(
|
|||||||
function: Function = callee
|
function: Function = callee
|
||||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||||
for arg in mapped:
|
for arg in mapped:
|
||||||
if arg.type != arg.argument.type:
|
if not self.is_subtype(arg.type, arg.argument.type):
|
||||||
self.error(
|
self.error(
|
||||||
arg.expr.location,
|
arg.expr.location,
|
||||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
@@ -386,13 +436,17 @@ class Checker(
|
|||||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||||
left: Type = expr.left.accept(self)
|
left: Type = expr.left.accept(self)
|
||||||
right: Type = expr.right.accept(self)
|
right: Type = expr.right.accept(self)
|
||||||
# TODO: union type
|
|
||||||
if left != right:
|
if self.is_subtype(left, right):
|
||||||
|
return right
|
||||||
|
if self.is_subtype(right, left):
|
||||||
|
return left
|
||||||
|
|
||||||
self.error(
|
self.error(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Operands must be of the same type, left={left} != right={right}",
|
f"Incompatible operand types, {left=} and {right=}",
|
||||||
)
|
)
|
||||||
return left
|
return UnknownType()
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
||||||
|
|
||||||
@@ -410,13 +464,16 @@ class Checker(
|
|||||||
|
|
||||||
true_type: Type = expr.if_true.accept(self)
|
true_type: Type = expr.if_true.accept(self)
|
||||||
false_type: Type = expr.if_false.accept(self)
|
false_type: Type = expr.if_false.accept(self)
|
||||||
if true_type != false_type:
|
if self.is_subtype(true_type, false_type):
|
||||||
|
return false_type
|
||||||
|
if self.is_subtype(false_type, true_type):
|
||||||
|
return true_type
|
||||||
|
|
||||||
self.error(
|
self.error(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
|
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
return true_type
|
|
||||||
|
|
||||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||||
return self.ctx.get_type(node.base)
|
return self.ctx.get_type(node.base)
|
||||||
|
|||||||
Reference in New Issue
Block a user