refactor(checker): move unfold_type to types.py

This commit is contained in:
2026-06-08 10:56:27 +02:00
parent 25e6410546
commit c64ab434b5
2 changed files with 11 additions and 9 deletions

View File

@@ -19,6 +19,7 @@ from midas.checker.types import (
Type, Type,
UnitType, UnitType,
UnknownType, UnknownType,
unfold_type,
) )
from midas.lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token from midas.lexer.token import Token
@@ -178,13 +179,6 @@ class Checker(
stmts: list[m.Stmt] = parser.parse() stmts: list[m.Stmt] = parser.parse()
self.ctx.resolve(stmts) self.ctx.resolve(stmts)
def unfold_type(self, type: Type) -> Type:
match type:
case AliasType(type=ref_type):
return self.unfold_type(ref_type)
case _:
return type
def is_subtype(self, type1: Type, type2: Type) -> bool: def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2` """Check whether `type1` is a subtype of `type2`
@@ -470,7 +464,7 @@ class Checker(
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
object: Type = self.type_of(target.object) object: Type = self.type_of(target.object)
base_object: Type = self.unfold_type(object) base_object: Type = unfold_type(object)
match base_object: match base_object:
case ComplexType(properties=properties): case ComplexType(properties=properties):
if target.name not in properties: if target.name not in properties:
@@ -611,7 +605,7 @@ class Checker(
def visit_get_expr(self, expr: p.GetExpr) -> Type: def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
base_object: Type = self.unfold_type(object) base_object: Type = unfold_type(object)
match base_object: match base_object:
case ComplexType(properties=properties): case ComplexType(properties=properties):
if expr.name not in properties: if expr.name not in properties:

View File

@@ -120,6 +120,14 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
raise NotImplementedError(f"Unsupported type {type}") raise NotImplementedError(f"Unsupported type {type}")
def unfold_type(type: Type) -> Type:
match type:
case AliasType(type=ref_type):
return unfold_type(ref_type)
case _:
return type
Type = ( Type = (
BaseType BaseType
| AliasType | AliasType