refactor(checker): replace all accept calls
make visitor accept calls more explicit with type_of(), resolve_type_expr() and process_stmt()
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
type Meter = float
|
type Meter = float
|
||||||
|
|
||||||
extend Meter {
|
extend Meter {
|
||||||
def __add__: fn(Meter) -> Meter
|
def __add__: fn(Meter, /) -> Meter
|
||||||
def __sub__: fn(Meter) -> Meter
|
def __sub__: fn(Meter, /) -> Meter
|
||||||
}
|
}
|
||||||
|
|
||||||
type Coordinate = object
|
type Coordinate = object
|
||||||
|
|||||||
@@ -78,6 +78,12 @@ class PythonTyper(
|
|||||||
self.judgements.append((expr, type))
|
self.judgements.append((expr, type))
|
||||||
return type
|
return type
|
||||||
|
|
||||||
|
def resolve_type_expr(self, expr: p.MidasType) -> Type:
|
||||||
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def process_stmt(self, stmt: p.Stmt) -> None:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||||
"""Evaluate a sequence of statements
|
"""Evaluate a sequence of statements
|
||||||
|
|
||||||
@@ -93,7 +99,7 @@ class PythonTyper(
|
|||||||
returned: bool = False
|
returned: bool = False
|
||||||
for i, stmt in enumerate(block):
|
for i, stmt in enumerate(block):
|
||||||
try:
|
try:
|
||||||
stmt.accept(self)
|
self.process_stmt(stmt)
|
||||||
except ReturnException:
|
except ReturnException:
|
||||||
returned = True
|
returned = True
|
||||||
if i < len(block) - 1:
|
if i < len(block) - 1:
|
||||||
@@ -111,7 +117,7 @@ class PythonTyper(
|
|||||||
statements (list[p.Stmt]): the statements to evaluate and check
|
statements (list[p.Stmt]): the statements to evaluate and check
|
||||||
"""
|
"""
|
||||||
for stmt in statements:
|
for stmt in statements:
|
||||||
stmt.accept(self)
|
self.process_stmt(stmt)
|
||||||
|
|
||||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||||
|
|
||||||
@@ -144,9 +150,9 @@ class PythonTyper(
|
|||||||
|
|
||||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||||
if arg.type is not None:
|
if arg.type is not None:
|
||||||
return arg.type.accept(self)
|
return self.resolve_type_expr(arg.type)
|
||||||
if arg.default is not None:
|
if arg.default is not None:
|
||||||
return arg.default.accept(self)
|
return self.type_of(arg.default)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
pos: int = 0
|
pos: int = 0
|
||||||
@@ -186,7 +192,7 @@ class PythonTyper(
|
|||||||
|
|
||||||
returns_hint: Optional[Type] = None
|
returns_hint: Optional[Type] = None
|
||||||
if stmt.returns is not None:
|
if stmt.returns is not None:
|
||||||
returns_hint = stmt.returns.accept(self)
|
returns_hint = self.resolve_type_expr(stmt.returns)
|
||||||
# Early define to handle simple fully-typed recursion
|
# Early define to handle simple fully-typed recursion
|
||||||
inside_function: Function = Function(
|
inside_function: Function = Function(
|
||||||
pos_args=pos_args,
|
pos_args=pos_args,
|
||||||
@@ -232,7 +238,7 @@ class PythonTyper(
|
|||||||
|
|
||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
# TODO check not yet defined locally
|
# TODO check not yet defined locally
|
||||||
type: Type = stmt.type.accept(self)
|
type: Type = self.resolve_type_expr(stmt.type)
|
||||||
self.env.define(stmt.name, type)
|
self.env.define(stmt.name, type)
|
||||||
|
|
||||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
@@ -287,7 +293,7 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
||||||
self.env.return_types.append(type)
|
self.env.return_types.append(type)
|
||||||
raise ReturnException()
|
raise ReturnException()
|
||||||
|
|
||||||
@@ -297,7 +303,7 @@ class PythonTyper(
|
|||||||
# if (m := 1 + 1) < 2:
|
# if (m := 1 + 1) < 2:
|
||||||
# ...
|
# ...
|
||||||
# print(m) # <- m is still defined
|
# print(m) # <- m is still defined
|
||||||
test_type: Type = stmt.test.accept(self)
|
test_type: Type = self.type_of(stmt.test)
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
# TODO Allow subtypes or any type
|
||||||
if test_type != self.types.get_type("bool"):
|
if test_type != self.types.get_type("bool"):
|
||||||
@@ -419,8 +425,8 @@ class PythonTyper(
|
|||||||
return type or UnknownType()
|
return type or UnknownType()
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||||
left: Type = expr.left.accept(self)
|
left: Type = self.type_of(expr.left)
|
||||||
right: Type = expr.right.accept(self)
|
right: Type = self.type_of(expr.right)
|
||||||
|
|
||||||
if self.is_subtype(left, right):
|
if self.is_subtype(left, right):
|
||||||
return right
|
return right
|
||||||
@@ -434,10 +440,10 @@ class PythonTyper(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||||
return expr.type.accept(self)
|
return self.resolve_type_expr(expr.type)
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
test_type: Type = expr.test.accept(self)
|
test_type: Type = self.type_of(expr.test)
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
# TODO Allow subtypes or any type
|
||||||
if test_type != self.types.get_type("bool"):
|
if test_type != self.types.get_type("bool"):
|
||||||
@@ -445,8 +451,8 @@ class PythonTyper(
|
|||||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
true_type: Type = expr.if_true.accept(self)
|
true_type: Type = self.type_of(expr.if_true)
|
||||||
false_type: Type = expr.if_false.accept(self)
|
false_type: Type = self.type_of(expr.if_false)
|
||||||
if self.is_subtype(true_type, false_type):
|
if self.is_subtype(true_type, false_type):
|
||||||
return false_type
|
return false_type
|
||||||
if self.is_subtype(false_type, true_type):
|
if self.is_subtype(false_type, true_type):
|
||||||
@@ -484,7 +490,7 @@ class PythonTyper(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
if node.param is not None:
|
if node.param is not None:
|
||||||
param: Type = node.param.accept(self)
|
param: Type = self.resolve_type_expr(node.param)
|
||||||
return self.types.apply_generic(base, [param])
|
return self.types.apply_generic(base, [param])
|
||||||
return base
|
return base
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user