refactor(checker): restructure around shared registry
restructure the type checker with a shared TypesRegistry used by MidasTyper and PythonTyper this commit also relocates some methods in more appropriate places, such as is_subtype and apply_generic (now in TypesRegistry)
This commit is contained in:
@@ -1,661 +1,35 @@
|
|||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import midas.ast.midas as m
|
from midas.checker.diagnostic import Diagnostic
|
||||||
import midas.ast.python as p
|
from midas.checker.midas import MidasTyper
|
||||||
from midas.ast.location import Location
|
from midas.checker.python import PythonTyper
|
||||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.reporter import Reporter
|
||||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
|
||||||
from midas.checker.types import (
|
|
||||||
ComplexType,
|
|
||||||
Function,
|
|
||||||
Operation,
|
|
||||||
Type,
|
|
||||||
UnitType,
|
|
||||||
UnknownType,
|
|
||||||
unfold_type,
|
|
||||||
)
|
|
||||||
from midas.lexer.midas import MidasLexer
|
|
||||||
from midas.lexer.token import Token
|
|
||||||
from midas.parser.midas import MidasParser
|
|
||||||
from midas.resolver.midas import MidasResolver
|
|
||||||
|
|
||||||
|
|
||||||
class ReturnException(Exception):
|
class TypeChecker:
|
||||||
pass
|
def __init__(self):
|
||||||
|
self.types: TypesRegistry = TypesRegistry()
|
||||||
|
self.reporter: Reporter = Reporter()
|
||||||
|
|
||||||
|
self.midas_typer = MidasTyper(self.types, self.reporter)
|
||||||
|
self.python_typer = PythonTyper(self.types, self.reporter)
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
def import_midas(self, path: Path):
|
||||||
class MappedArgument:
|
source: str = path.read_text()
|
||||||
expr: p.Expr
|
return self.import_midas_source(source, path=str(path))
|
||||||
type: Type
|
|
||||||
argument: Function.Argument
|
|
||||||
|
|
||||||
|
def import_midas_source(self, source: str, path: Optional[str] = None):
|
||||||
|
self.midas_typer.process(source, path)
|
||||||
|
|
||||||
class Checker(
|
def type_check(self, path: Path):
|
||||||
p.Stmt.Visitor[None],
|
source: str = path.read_text()
|
||||||
p.Expr.Visitor[Type],
|
return self.type_check_source(source, path=str(path))
|
||||||
p.MidasType.Visitor[Type],
|
|
||||||
):
|
|
||||||
"""A type checker which can use custom type definitions"""
|
|
||||||
|
|
||||||
def __init__(
|
def type_check_source(self, source: str, path: Optional[str] = None):
|
||||||
self,
|
self.python_typer.process(source, path)
|
||||||
locals: dict[p.Expr, int],
|
|
||||||
source_path: Path,
|
|
||||||
types_paths: list[Path],
|
|
||||||
):
|
|
||||||
self.logger: logging.Logger = logging.getLogger("Checker")
|
|
||||||
self.source_path: Path = source_path
|
|
||||||
self.types_paths: list[Path] = types_paths
|
|
||||||
self.ctx: MidasResolver = MidasResolver()
|
|
||||||
self.global_env: Environment = Environment()
|
|
||||||
self.env: Environment = self.global_env
|
|
||||||
self.locals: dict[p.Expr, int] = locals
|
|
||||||
self.diagnostics: list[Diagnostic] = []
|
|
||||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
|
||||||
|
|
||||||
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
@property
|
||||||
self.diagnostics.append(
|
def diagnostics(self) -> list[Diagnostic]:
|
||||||
Diagnostic(
|
return self.reporter.diagnostics
|
||||||
file_path=self.source_path,
|
|
||||||
location=location,
|
|
||||||
type=type,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def error(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.ERROR,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def warning(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.WARNING,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def info(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.INFO,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def type_of(self, expr: p.Expr) -> Type:
|
|
||||||
"""Evaluate the type of an expression
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expr (p.Expr): the expression to evaluate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the type of the given expression
|
|
||||||
"""
|
|
||||||
type: Type = expr.accept(self)
|
|
||||||
self.judgements.append((expr, type))
|
|
||||||
return type
|
|
||||||
|
|
||||||
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
|
||||||
"""Evaluate a sequence of statements
|
|
||||||
|
|
||||||
Args:
|
|
||||||
block (list[p.Stmt]): the statements to evaluate
|
|
||||||
env (Environment): the environment in which to evaluate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: whether a return statement is present in the block
|
|
||||||
"""
|
|
||||||
previous_env: Environment = self.env
|
|
||||||
self.env = env
|
|
||||||
returned: bool = False
|
|
||||||
for i, stmt in enumerate(block):
|
|
||||||
try:
|
|
||||||
stmt.accept(self)
|
|
||||||
except ReturnException:
|
|
||||||
returned = True
|
|
||||||
if i < len(block) - 1:
|
|
||||||
self.warning(block[i + 1].location, "Unreachable statement")
|
|
||||||
break
|
|
||||||
self.env = previous_env
|
|
||||||
return returned
|
|
||||||
|
|
||||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
|
||||||
"""Type check a sequence of statements and returns diagnostics
|
|
||||||
|
|
||||||
Args:
|
|
||||||
statements (list[p.Stmt]): the statements to evaluate and check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
|
||||||
"""
|
|
||||||
self.diagnostics = []
|
|
||||||
|
|
||||||
for path in self.types_paths:
|
|
||||||
self.import_midas(path)
|
|
||||||
self.logger.debug(f"Midas types: {self.ctx._types}")
|
|
||||||
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
|
||||||
|
|
||||||
for stmt in statements:
|
|
||||||
stmt.accept(self)
|
|
||||||
|
|
||||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
|
||||||
return self.diagnostics
|
|
||||||
|
|
||||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
|
||||||
"""Look up a variable in the environment it was declared
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): the name of the variable
|
|
||||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Type]: the type of the variable, or None if it was not found
|
|
||||||
"""
|
|
||||||
distance: Optional[int] = self.locals.get(expr)
|
|
||||||
if distance is not None:
|
|
||||||
return self.env.get_at(distance, name)
|
|
||||||
return self.global_env.get(name)
|
|
||||||
|
|
||||||
def import_midas(self, path: Path) -> None:
|
|
||||||
"""Import Midas definitions from a path
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (Path): the import path
|
|
||||||
"""
|
|
||||||
self.logger.debug(f"Importing type definitions from {path}")
|
|
||||||
lexer: MidasLexer = MidasLexer(path.read_text())
|
|
||||||
tokens: list[Token] = lexer.process()
|
|
||||||
parser: MidasParser = MidasParser(tokens)
|
|
||||||
stmts: list[m.Stmt] = parser.parse()
|
|
||||||
self.ctx.resolve(stmts)
|
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
|
||||||
return self.ctx.is_subtype(type1, type2)
|
|
||||||
|
|
||||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
|
||||||
self.type_of(stmt.expr)
|
|
||||||
|
|
||||||
def visit_function(self, stmt: p.Function) -> None:
|
|
||||||
env: Environment = Environment(self.env)
|
|
||||||
pos_args: list[Function.Argument] = []
|
|
||||||
args: list[Function.Argument] = []
|
|
||||||
kw_args: list[Function.Argument] = []
|
|
||||||
|
|
||||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
|
||||||
if arg.type is not None:
|
|
||||||
return arg.type.accept(self)
|
|
||||||
if arg.default is not None:
|
|
||||||
return arg.default.accept(self)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
pos: int = 0
|
|
||||||
for arg in stmt.posonlyargs:
|
|
||||||
pos_args.append(
|
|
||||||
Function.Argument(
|
|
||||||
pos=pos,
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
pos += 1
|
|
||||||
for arg in stmt.args:
|
|
||||||
args.append(
|
|
||||||
Function.Argument(
|
|
||||||
pos=pos,
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
pos += 1
|
|
||||||
for arg in stmt.kwonlyargs:
|
|
||||||
kw_args.append(
|
|
||||||
Function.Argument(
|
|
||||||
pos=pos, # not relevant
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
pos += 1
|
|
||||||
|
|
||||||
for arg in pos_args + args + kw_args:
|
|
||||||
env.define(arg.name, arg.type)
|
|
||||||
|
|
||||||
returns_hint: Optional[Type] = None
|
|
||||||
if stmt.returns is not None:
|
|
||||||
returns_hint = stmt.returns.accept(self)
|
|
||||||
# Early define to handle simple fully-typed recursion
|
|
||||||
inside_function: Function = Function(
|
|
||||||
name=stmt.name,
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=returns_hint,
|
|
||||||
)
|
|
||||||
self.env.define(stmt.name, inside_function)
|
|
||||||
|
|
||||||
returned: bool = self.process_block(stmt.body, env)
|
|
||||||
inferred_return: Type = UnknownType()
|
|
||||||
if not returned:
|
|
||||||
env.return_types.append(UnitType())
|
|
||||||
return_types: set[Type] = set(env.return_types)
|
|
||||||
if len(return_types) == 1:
|
|
||||||
inferred_return = list(return_types)[0]
|
|
||||||
elif len(return_types) > 1:
|
|
||||||
self.error(
|
|
||||||
stmt.location,
|
|
||||||
f"Mixed return types: {env.return_types}",
|
|
||||||
)
|
|
||||||
|
|
||||||
returns: Type = UnknownType()
|
|
||||||
if returns_hint is not None:
|
|
||||||
assert stmt.returns is not None
|
|
||||||
returns = returns_hint
|
|
||||||
if returns != inferred_return:
|
|
||||||
self.error(
|
|
||||||
stmt.returns.location,
|
|
||||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
returns = inferred_return
|
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
|
||||||
function: Function = Function(
|
|
||||||
name=stmt.name,
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=returns,
|
|
||||||
)
|
|
||||||
self.env.define(stmt.name, function)
|
|
||||||
|
|
||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
|
||||||
# TODO check not yet defined locally
|
|
||||||
type: Type = stmt.type.accept(self)
|
|
||||||
self.env.define(stmt.name, type)
|
|
||||||
|
|
||||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
|
||||||
value_type: Type = self.type_of(stmt.value)
|
|
||||||
for target in stmt.targets:
|
|
||||||
self._assign(stmt.location, target, value_type)
|
|
||||||
|
|
||||||
def _assign(self, location: Location, target: p.Expr, value_type: Type):
|
|
||||||
match target:
|
|
||||||
case p.VariableExpr():
|
|
||||||
self._assign_var(location, target, value_type)
|
|
||||||
|
|
||||||
case p.GetExpr():
|
|
||||||
self._assign_attr(location, target, value_type)
|
|
||||||
|
|
||||||
case _:
|
|
||||||
if not isinstance(target, p.VariableExpr):
|
|
||||||
self.logger.warning(f"Unsupported assignment to {target}")
|
|
||||||
self.warning(target.location, f"Unsupported assignment to {target}")
|
|
||||||
|
|
||||||
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
|
|
||||||
name: str = target.name
|
|
||||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
|
||||||
|
|
||||||
if var_type is None:
|
|
||||||
self.env.define(name, value_type)
|
|
||||||
else:
|
|
||||||
# S <: T
|
|
||||||
# Γ, x: T v: S
|
|
||||||
# x = v
|
|
||||||
if not self.is_subtype(value_type, var_type):
|
|
||||||
self.error(
|
|
||||||
location,
|
|
||||||
f"Cannot assign {value_type} to {name} of type {var_type}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
|
|
||||||
object: Type = self.type_of(target.object)
|
|
||||||
base_object: Type = unfold_type(object)
|
|
||||||
match base_object:
|
|
||||||
case ComplexType(properties=properties):
|
|
||||||
if target.name not in properties:
|
|
||||||
self.error(
|
|
||||||
target.location, f"Unknown property '{target.name} on {object}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
prop_type: Type = properties[target.name]
|
|
||||||
if not self.is_subtype(value_type, prop_type):
|
|
||||||
self.error(
|
|
||||||
location,
|
|
||||||
f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
case UnknownType():
|
|
||||||
pass
|
|
||||||
|
|
||||||
case _:
|
|
||||||
self.error(
|
|
||||||
target.location,
|
|
||||||
f"Cannot assign {value_type} to unknown property '{target.name}' on {object}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
|
||||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
|
||||||
self.env.return_types.append(type)
|
|
||||||
raise ReturnException()
|
|
||||||
|
|
||||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
|
||||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
|
||||||
# For example:
|
|
||||||
# if (m := 1 + 1) < 2:
|
|
||||||
# ...
|
|
||||||
# print(m) # <- m is still defined
|
|
||||||
test_type: Type = stmt.test.accept(self)
|
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
|
||||||
if test_type != self.ctx.get_type("bool"):
|
|
||||||
self.error(
|
|
||||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
env: Environment = Environment(self.env)
|
|
||||||
body_returned: bool = self.process_block(stmt.body, env)
|
|
||||||
else_returned: bool = self.process_block(stmt.orelse, env)
|
|
||||||
self.env.return_types.extend(env.return_types)
|
|
||||||
if body_returned and else_returned:
|
|
||||||
raise ReturnException()
|
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
|
||||||
if method is None:
|
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
|
||||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
|
||||||
return UnknownType()
|
|
||||||
left: Type = self.type_of(expr.left)
|
|
||||||
right: Type = self.type_of(expr.right)
|
|
||||||
|
|
||||||
operations: list[Operation] = self.ctx.get_operations_by_name(method)
|
|
||||||
valid_operations: list[Operation] = []
|
|
||||||
for op in operations:
|
|
||||||
sig: Operation.CallSignature = op.signature
|
|
||||||
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
|
|
||||||
valid_operations.append(op)
|
|
||||||
|
|
||||||
if len(valid_operations) == 0:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
elif len(valid_operations) == 1:
|
|
||||||
self.logger.debug(f"Unique operation {method} between {left} and {right}")
|
|
||||||
return valid_operations[0].result
|
|
||||||
|
|
||||||
for i, op1 in enumerate(valid_operations):
|
|
||||||
sig1: Operation.CallSignature = op1.signature
|
|
||||||
best_match: bool = True
|
|
||||||
for j, op2 in enumerate(valid_operations):
|
|
||||||
if i == j:
|
|
||||||
continue
|
|
||||||
sig2: Operation.CallSignature = op2.signature
|
|
||||||
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
|
|
||||||
sig1.right, sig2.right
|
|
||||||
):
|
|
||||||
best_match = False
|
|
||||||
break
|
|
||||||
self.logger.debug(f"{op1} is a full overload of {op2}")
|
|
||||||
if best_match:
|
|
||||||
return op1.result
|
|
||||||
|
|
||||||
overloads: list[str] = [
|
|
||||||
f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}"
|
|
||||||
for op in valid_operations
|
|
||||||
]
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
|
||||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
|
||||||
if method is None:
|
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
|
||||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
|
||||||
return UnknownType()
|
|
||||||
left: Type = self.type_of(expr.left)
|
|
||||||
right: Type = self.type_of(expr.right)
|
|
||||||
|
|
||||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
|
||||||
if result is None:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
|
||||||
callee: Type = self.type_of(expr.callee)
|
|
||||||
if not isinstance(callee, Function):
|
|
||||||
self.error(expr.callee.location, "Callee is not a function")
|
|
||||||
return UnknownType()
|
|
||||||
function: Function = callee
|
|
||||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
|
||||||
for arg in mapped:
|
|
||||||
if not self.is_subtype(arg.type, arg.argument.type):
|
|
||||||
self.error(
|
|
||||||
arg.expr.location,
|
|
||||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
|
||||||
)
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
|
||||||
object: Type = self.type_of(expr.object)
|
|
||||||
base_object: Type = unfold_type(object)
|
|
||||||
match base_object:
|
|
||||||
case ComplexType(properties=properties):
|
|
||||||
if expr.name not in properties:
|
|
||||||
self.error(
|
|
||||||
expr.location, f"Unknown property '{expr.name} on {object}"
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return properties[expr.name]
|
|
||||||
|
|
||||||
case UnknownType():
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
case _:
|
|
||||||
self.error(
|
|
||||||
expr.location, f"Cannot get property '{expr.name}' on {object}"
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
|
||||||
match expr.value:
|
|
||||||
case bool(): # Must be before int
|
|
||||||
return self.ctx.get_type("bool")
|
|
||||||
case int():
|
|
||||||
return self.ctx.get_type("int")
|
|
||||||
case float():
|
|
||||||
return self.ctx.get_type("float")
|
|
||||||
case str():
|
|
||||||
return self.ctx.get_type("str")
|
|
||||||
case _:
|
|
||||||
self.warning(expr.location, f"Unknown literal {expr}")
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
|
||||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
|
||||||
left: Type = expr.left.accept(self)
|
|
||||||
right: Type = expr.right.accept(self)
|
|
||||||
|
|
||||||
if self.is_subtype(left, right):
|
|
||||||
return right
|
|
||||||
if self.is_subtype(right, left):
|
|
||||||
return left
|
|
||||||
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Incompatible operand types, {left=} and {right=}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
|
||||||
return expr.type.accept(self)
|
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
|
||||||
test_type: Type = expr.test.accept(self)
|
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
|
||||||
if test_type != self.ctx.get_type("bool"):
|
|
||||||
self.error(
|
|
||||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
true_type: Type = expr.if_true.accept(self)
|
|
||||||
false_type: Type = expr.if_false.accept(self)
|
|
||||||
if self.is_subtype(true_type, false_type):
|
|
||||||
return false_type
|
|
||||||
if self.is_subtype(false_type, true_type):
|
|
||||||
return true_type
|
|
||||||
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
|
||||||
return self.ctx.get_type(node.base)
|
|
||||||
|
|
||||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
|
||||||
|
|
||||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
|
||||||
|
|
||||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
|
||||||
|
|
||||||
def map_call_arguments(
|
|
||||||
self, function: Function, call: p.CallExpr
|
|
||||||
) -> list[MappedArgument]:
|
|
||||||
"""Map call arguments to function parameters as defined in its signature
|
|
||||||
|
|
||||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
|
||||||
with the arguments passed at the call site
|
|
||||||
|
|
||||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function (Function): the function definition
|
|
||||||
call (p.CallExpr): the call expression
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[MappedArgument]: the list of mapped arguments
|
|
||||||
"""
|
|
||||||
positional: list[tuple[p.Expr, Type]] = [
|
|
||||||
(arg, self.type_of(arg)) for arg in call.arguments
|
|
||||||
]
|
|
||||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
|
||||||
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
|
||||||
}
|
|
||||||
set_args: set[str] = set()
|
|
||||||
|
|
||||||
required_positional: list[str] = [
|
|
||||||
arg.name for arg in function.pos_args + function.args if arg.required
|
|
||||||
]
|
|
||||||
required_keyword: list[str] = [
|
|
||||||
arg.name for arg in function.kw_args if arg.required
|
|
||||||
]
|
|
||||||
|
|
||||||
mapped: list[MappedArgument] = []
|
|
||||||
|
|
||||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
|
||||||
mixed_params: list[Function.Argument] = list(function.args)
|
|
||||||
kw_params: dict[str, Function.Argument] = {
|
|
||||||
arg.name: arg for arg in function.kw_args
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
|
||||||
for arg in positional:
|
|
||||||
param: Function.Argument
|
|
||||||
if len(pos_params) != 0:
|
|
||||||
param = pos_params.pop(0)
|
|
||||||
elif len(mixed_params) != 0:
|
|
||||||
param = mixed_params.pop(0)
|
|
||||||
else:
|
|
||||||
self.error(arg[0].location, "Too many positional arguments")
|
|
||||||
break
|
|
||||||
name: str = param.name
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
|
||||||
for name, arg in keywords.items():
|
|
||||||
param: Function.Argument
|
|
||||||
if name not in kw_params:
|
|
||||||
if name in set_args:
|
|
||||||
self.error(
|
|
||||||
arg[0].location, f"Multiple values for argument '{name}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
|
||||||
continue
|
|
||||||
param = kw_params.pop(name)
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def join_args(args: list[str]) -> str:
|
|
||||||
args = list(map(lambda a: f"'{a}'", args))
|
|
||||||
if len(args) == 0:
|
|
||||||
return ""
|
|
||||||
if len(args) == 1:
|
|
||||||
return args[0]
|
|
||||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
|
||||||
|
|
||||||
if len(required_positional) != 0:
|
|
||||||
plural: str = "" if len(required_positional) == 1 else "s"
|
|
||||||
args: str = join_args(required_positional)
|
|
||||||
self.error(
|
|
||||||
call.location,
|
|
||||||
f"Missing required positional argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(required_keyword) != 0:
|
|
||||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
|
||||||
args: str = join_args(required_keyword)
|
|
||||||
self.error(
|
|
||||||
call.location,
|
|
||||||
f"Missing required keyword argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
|
|
||||||
return mapped
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
@@ -14,7 +13,7 @@ class DiagnosticType(StrEnum):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Diagnostic:
|
class Diagnostic:
|
||||||
file_path: Optional[str | Path]
|
file_path: Optional[str]
|
||||||
location: Location
|
location: Location
|
||||||
type: DiagnosticType
|
type: DiagnosticType
|
||||||
message: str
|
message: str
|
||||||
|
|||||||
137
midas/checker/midas.py
Normal file
137
midas/checker/midas.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
ComplexType,
|
||||||
|
GenericType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
|
from midas.lexer.midas import MidasLexer
|
||||||
|
from midas.lexer.token import Token
|
||||||
|
from midas.parser.midas import MidasParser
|
||||||
|
from midas.resolver.builtin import define_builtins
|
||||||
|
|
||||||
|
|
||||||
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||||
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
|
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||||
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
|
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
|
define_builtins(self.types)
|
||||||
|
|
||||||
|
def process(self, source: str, path: Optional[str]):
|
||||||
|
self.reporter = self.reporter.for_file(path)
|
||||||
|
lexer: MidasLexer = MidasLexer(source)
|
||||||
|
tokens: list[Token] = lexer.process()
|
||||||
|
parser: MidasParser = MidasParser(tokens)
|
||||||
|
stmts: list[m.Stmt] = parser.parse()
|
||||||
|
self.resolve(stmts)
|
||||||
|
|
||||||
|
def get_type(self, name: str) -> Type:
|
||||||
|
"""Get a type from its name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NameError: if the type is not defined
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type
|
||||||
|
"""
|
||||||
|
if name in self._local_variables:
|
||||||
|
return self._local_variables[name]
|
||||||
|
return self.types.get_type(name)
|
||||||
|
|
||||||
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
|
"""Process a sequence of statements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stmts (list[m.Stmt]): the statements
|
||||||
|
"""
|
||||||
|
for stmt in stmts:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
|
params: list[TypeVar] = []
|
||||||
|
for param in stmt.params:
|
||||||
|
name: str = param.name.lexeme
|
||||||
|
bound: Optional[Type] = None
|
||||||
|
if param.bound is not None:
|
||||||
|
bound = param.bound.accept(self)
|
||||||
|
var = TypeVar(name=name, bound=bound)
|
||||||
|
self._local_variables[name] = var
|
||||||
|
params.append(var)
|
||||||
|
type: Type = stmt.type.accept(self)
|
||||||
|
if len(params) != 0:
|
||||||
|
type = GenericType(params=params, body=type)
|
||||||
|
name: str = stmt.name.lexeme
|
||||||
|
self.types.define_type(name, AliasType(name=name, type=type))
|
||||||
|
self._local_variables.clear()
|
||||||
|
|
||||||
|
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
|
base: Type = stmt.type.accept(self)
|
||||||
|
for op in stmt.operations:
|
||||||
|
right: Type = op.operand.accept(self)
|
||||||
|
result: Type = op.result.accept(self)
|
||||||
|
self.types.define_operation(
|
||||||
|
left=base,
|
||||||
|
operator=op.name.lexeme,
|
||||||
|
right=right,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||||
|
return expr.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||||
|
return self.get_type(type.name.lexeme)
|
||||||
|
|
||||||
|
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||||
|
type_: Type = type.type.accept(self)
|
||||||
|
params: list[Type] = [param.accept(self) for param in type.params]
|
||||||
|
return self.types.apply_generic(type_, params)
|
||||||
|
|
||||||
|
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||||
|
type_: Type = type.type.accept(self)
|
||||||
|
type.constraint.accept(self)
|
||||||
|
# TODO
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
||||||
|
return ComplexType(
|
||||||
|
properties={
|
||||||
|
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
|
||||||
|
}
|
||||||
|
)
|
||||||
626
midas/checker/python.py
Normal file
626
midas/checker/python.py
Normal file
@@ -0,0 +1,626 @@
|
|||||||
|
import ast
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
|
from midas.checker.types import (
|
||||||
|
ComplexType,
|
||||||
|
Function,
|
||||||
|
Operation,
|
||||||
|
Type,
|
||||||
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
|
)
|
||||||
|
from midas.parser.python import PythonParser
|
||||||
|
from midas.resolver.resolver import Resolver
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument:
|
||||||
|
expr: p.Expr
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
class PythonTyper(
|
||||||
|
p.Stmt.Visitor[None],
|
||||||
|
p.Expr.Visitor[Type],
|
||||||
|
p.MidasType.Visitor[Type],
|
||||||
|
):
|
||||||
|
"""A type checker which can use custom type definitions"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
types: TypesRegistry,
|
||||||
|
reporter: Reporter,
|
||||||
|
):
|
||||||
|
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||||
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.global_env: Environment = Environment()
|
||||||
|
self.env: Environment = self.global_env
|
||||||
|
self.locals: dict[p.Expr, int] = {}
|
||||||
|
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||||
|
|
||||||
|
def process(self, source: str, path: Optional[str]):
|
||||||
|
self.reporter = self.reporter.for_file(path)
|
||||||
|
|
||||||
|
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
|
||||||
|
parser = PythonParser()
|
||||||
|
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||||
|
resolver = Resolver()
|
||||||
|
resolver.resolve(*stmts)
|
||||||
|
|
||||||
|
self.env = self.global_env
|
||||||
|
self.locals = resolver.locals
|
||||||
|
self.judgements = []
|
||||||
|
|
||||||
|
self.check(stmts)
|
||||||
|
|
||||||
|
def type_of(self, expr: p.Expr) -> Type:
|
||||||
|
"""Evaluate the type of an expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr (p.Expr): the expression to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type of the given expression
|
||||||
|
"""
|
||||||
|
type: Type = expr.accept(self)
|
||||||
|
self.judgements.append((expr, type))
|
||||||
|
return type
|
||||||
|
|
||||||
|
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||||
|
"""Evaluate a sequence of statements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block (list[p.Stmt]): the statements to evaluate
|
||||||
|
env (Environment): the environment in which to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether a return statement is present in the block
|
||||||
|
"""
|
||||||
|
previous_env: Environment = self.env
|
||||||
|
self.env = env
|
||||||
|
returned: bool = False
|
||||||
|
for i, stmt in enumerate(block):
|
||||||
|
try:
|
||||||
|
stmt.accept(self)
|
||||||
|
except ReturnException:
|
||||||
|
returned = True
|
||||||
|
if i < len(block) - 1:
|
||||||
|
self.reporter.warning(
|
||||||
|
block[i + 1].location, "Unreachable statement"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
self.env = previous_env
|
||||||
|
return returned
|
||||||
|
|
||||||
|
def check(self, statements: list[p.Stmt]) -> None:
|
||||||
|
"""Type check a sequence of statements and returns diagnostics
|
||||||
|
|
||||||
|
Args:
|
||||||
|
statements (list[p.Stmt]): the statements to evaluate and check
|
||||||
|
"""
|
||||||
|
for stmt in statements:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
|
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||||
|
|
||||||
|
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||||
|
"""Look up a variable in the environment it was declared
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Type]: the type of the variable, or None if it was not found
|
||||||
|
"""
|
||||||
|
distance: Optional[int] = self.locals.get(expr)
|
||||||
|
if distance is not None:
|
||||||
|
return self.env.get_at(distance, name)
|
||||||
|
return self.global_env.get(name)
|
||||||
|
|
||||||
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
|
return self.types.is_subtype(type1, type2)
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
|
self.type_of(stmt.expr)
|
||||||
|
|
||||||
|
def visit_function(self, stmt: p.Function) -> None:
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
pos_args: list[Function.Argument] = []
|
||||||
|
args: list[Function.Argument] = []
|
||||||
|
kw_args: list[Function.Argument] = []
|
||||||
|
|
||||||
|
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||||
|
if arg.type is not None:
|
||||||
|
return arg.type.accept(self)
|
||||||
|
if arg.default is not None:
|
||||||
|
return arg.default.accept(self)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
pos: int = 0
|
||||||
|
for arg in stmt.posonlyargs:
|
||||||
|
pos_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos,
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
for arg in stmt.args:
|
||||||
|
args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos,
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
for arg in stmt.kwonlyargs:
|
||||||
|
kw_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos, # not relevant
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
|
||||||
|
for arg in pos_args + args + kw_args:
|
||||||
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
|
returns_hint: Optional[Type] = None
|
||||||
|
if stmt.returns is not None:
|
||||||
|
returns_hint = stmt.returns.accept(self)
|
||||||
|
# Early define to handle simple fully-typed recursion
|
||||||
|
inside_function: Function = Function(
|
||||||
|
name=stmt.name,
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns_hint,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, inside_function)
|
||||||
|
|
||||||
|
returned: bool = self.process_block(stmt.body, env)
|
||||||
|
inferred_return: Type = UnknownType()
|
||||||
|
if not returned:
|
||||||
|
env.return_types.append(UnitType())
|
||||||
|
return_types: set[Type] = set(env.return_types)
|
||||||
|
if len(return_types) == 1:
|
||||||
|
inferred_return = list(return_types)[0]
|
||||||
|
elif len(return_types) > 1:
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.location,
|
||||||
|
f"Mixed return types: {env.return_types}",
|
||||||
|
)
|
||||||
|
|
||||||
|
returns: Type = UnknownType()
|
||||||
|
if returns_hint is not None:
|
||||||
|
assert stmt.returns is not None
|
||||||
|
returns = returns_hint
|
||||||
|
if returns != inferred_return:
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.returns.location,
|
||||||
|
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
returns = inferred_return
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
function: Function = Function(
|
||||||
|
name=stmt.name,
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, function)
|
||||||
|
|
||||||
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
|
# TODO check not yet defined locally
|
||||||
|
type: Type = stmt.type.accept(self)
|
||||||
|
self.env.define(stmt.name, type)
|
||||||
|
|
||||||
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
|
value_type: Type = self.type_of(stmt.value)
|
||||||
|
for target in stmt.targets:
|
||||||
|
self._assign(stmt.location, target, value_type)
|
||||||
|
|
||||||
|
def _assign(self, location: Location, target: p.Expr, value_type: Type):
|
||||||
|
match target:
|
||||||
|
case p.VariableExpr():
|
||||||
|
self._assign_var(location, target, value_type)
|
||||||
|
|
||||||
|
case p.GetExpr():
|
||||||
|
self._assign_attr(location, target, value_type)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
if not isinstance(target, p.VariableExpr):
|
||||||
|
self.logger.warning(f"Unsupported assignment to {target}")
|
||||||
|
self.reporter.warning(
|
||||||
|
target.location, f"Unsupported assignment to {target}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
|
||||||
|
name: str = target.name
|
||||||
|
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||||
|
|
||||||
|
if var_type is None:
|
||||||
|
self.env.define(name, value_type)
|
||||||
|
else:
|
||||||
|
# S <: T
|
||||||
|
# Γ, x: T v: S
|
||||||
|
# x = v
|
||||||
|
if not self.is_subtype(value_type, var_type):
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {value_type} to {name} of type {var_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
|
||||||
|
object: Type = self.type_of(target.object)
|
||||||
|
base_object: Type = unfold_type(object)
|
||||||
|
match base_object:
|
||||||
|
case ComplexType(properties=properties):
|
||||||
|
if target.name not in properties:
|
||||||
|
self.reporter.error(
|
||||||
|
target.location, f"Unknown property '{target.name} on {object}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
prop_type: Type = properties[target.name]
|
||||||
|
if not self.is_subtype(value_type, prop_type):
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
pass
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
target.location,
|
||||||
|
f"Cannot assign {value_type} to unknown property '{target.name}' on {object}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||||
|
self.env.return_types.append(type)
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||||
|
# For example:
|
||||||
|
# if (m := 1 + 1) < 2:
|
||||||
|
# ...
|
||||||
|
# print(m) # <- m is still defined
|
||||||
|
test_type: Type = stmt.test.accept(self)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.types.get_type("bool"):
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
body_returned: bool = self.process_block(stmt.body, env)
|
||||||
|
else_returned: bool = self.process_block(stmt.orelse, env)
|
||||||
|
self.env.return_types.extend(env.return_types)
|
||||||
|
if body_returned and else_returned:
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
|
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
left: Type = self.type_of(expr.left)
|
||||||
|
right: Type = self.type_of(expr.right)
|
||||||
|
|
||||||
|
operations: list[Operation] = self.types.get_operations_by_name(method)
|
||||||
|
valid_operations: list[Operation] = []
|
||||||
|
for op in operations:
|
||||||
|
sig: Operation.CallSignature = op.signature
|
||||||
|
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
|
||||||
|
valid_operations.append(op)
|
||||||
|
|
||||||
|
if len(valid_operations) == 0:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
elif len(valid_operations) == 1:
|
||||||
|
self.logger.debug(f"Unique operation {method} between {left} and {right}")
|
||||||
|
return valid_operations[0].result
|
||||||
|
|
||||||
|
for i, op1 in enumerate(valid_operations):
|
||||||
|
sig1: Operation.CallSignature = op1.signature
|
||||||
|
best_match: bool = True
|
||||||
|
for j, op2 in enumerate(valid_operations):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
sig2: Operation.CallSignature = op2.signature
|
||||||
|
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
|
||||||
|
sig1.right, sig2.right
|
||||||
|
):
|
||||||
|
best_match = False
|
||||||
|
break
|
||||||
|
self.logger.debug(f"{op1} is a full overload of {op2}")
|
||||||
|
if best_match:
|
||||||
|
return op1.result
|
||||||
|
|
||||||
|
overloads: list[str] = [
|
||||||
|
f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}"
|
||||||
|
for op in valid_operations
|
||||||
|
]
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
|
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
left: Type = self.type_of(expr.left)
|
||||||
|
right: Type = self.type_of(expr.right)
|
||||||
|
|
||||||
|
result: Optional[Type] = self.types.get_operation_result(left, method, right)
|
||||||
|
if result is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
|
callee: Type = self.type_of(expr.callee)
|
||||||
|
if not isinstance(callee, Function):
|
||||||
|
self.reporter.error(expr.callee.location, "Callee is not a function")
|
||||||
|
return UnknownType()
|
||||||
|
function: Function = callee
|
||||||
|
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||||
|
for arg in mapped:
|
||||||
|
if not self.is_subtype(arg.type, arg.argument.type):
|
||||||
|
self.reporter.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||||
|
object: Type = self.type_of(expr.object)
|
||||||
|
base_object: Type = unfold_type(object)
|
||||||
|
match base_object:
|
||||||
|
case ComplexType(properties=properties):
|
||||||
|
if expr.name not in properties:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown property '{expr.name} on {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return properties[expr.name]
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Cannot get property '{expr.name}' on {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||||
|
match expr.value:
|
||||||
|
case bool(): # Must be before int
|
||||||
|
return self.types.get_type("bool")
|
||||||
|
case int():
|
||||||
|
return self.types.get_type("int")
|
||||||
|
case float():
|
||||||
|
return self.types.get_type("float")
|
||||||
|
case str():
|
||||||
|
return self.types.get_type("str")
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||||
|
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||||
|
left: Type = expr.left.accept(self)
|
||||||
|
right: Type = expr.right.accept(self)
|
||||||
|
|
||||||
|
if self.is_subtype(left, right):
|
||||||
|
return right
|
||||||
|
if self.is_subtype(right, left):
|
||||||
|
return left
|
||||||
|
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Incompatible operand types, {left=} and {right=}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||||
|
return expr.type.accept(self)
|
||||||
|
|
||||||
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
|
test_type: Type = expr.test.accept(self)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.types.get_type("bool"):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
true_type: Type = expr.if_true.accept(self)
|
||||||
|
false_type: Type = expr.if_false.accept(self)
|
||||||
|
if self.is_subtype(true_type, false_type):
|
||||||
|
return false_type
|
||||||
|
if self.is_subtype(false_type, true_type):
|
||||||
|
return true_type
|
||||||
|
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||||
|
return self.types.get_type(node.base)
|
||||||
|
|
||||||
|
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||||
|
|
||||||
|
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||||
|
|
||||||
|
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self, function: Function, call: p.CallExpr
|
||||||
|
) -> list[MappedArgument]:
|
||||||
|
"""Map call arguments to function parameters as defined in its signature
|
||||||
|
|
||||||
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
|
with the arguments passed at the call site
|
||||||
|
|
||||||
|
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (Function): the function definition
|
||||||
|
call (p.CallExpr): the call expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[MappedArgument]: the list of mapped arguments
|
||||||
|
"""
|
||||||
|
positional: list[tuple[p.Expr, Type]] = [
|
||||||
|
(arg, self.type_of(arg)) for arg in call.arguments
|
||||||
|
]
|
||||||
|
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||||
|
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
||||||
|
}
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: list[str] = [
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
]
|
||||||
|
required_keyword: list[str] = [
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
]
|
||||||
|
|
||||||
|
mapped: list[MappedArgument] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
self.reporter.error(arg[0].location, "Too many positional arguments")
|
||||||
|
break
|
||||||
|
name: str = param.name
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if name in set_args:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def join_args(args: list[str]) -> str:
|
||||||
|
args = list(map(lambda a: f"'{a}'", args))
|
||||||
|
if len(args) == 0:
|
||||||
|
return ""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
|
args: str = join_args(required_positional)
|
||||||
|
self.reporter.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
|
args: str = join_args(required_keyword)
|
||||||
|
self.reporter.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return mapped
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import midas.ast.midas as m
|
|
||||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
@@ -10,24 +9,15 @@ from midas.checker.types import (
|
|||||||
GenericType,
|
GenericType,
|
||||||
Operation,
|
Operation,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
|
||||||
UnknownType,
|
|
||||||
substitute_typevars,
|
substitute_typevars,
|
||||||
)
|
)
|
||||||
from midas.resolver.builtin import define_builtins
|
|
||||||
|
|
||||||
|
|
||||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
class TypesRegistry:
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._types: dict[str, Type] = {}
|
self._types: dict[str, Type] = {}
|
||||||
self._operations: dict[Operation.CallSignature, Type] = {}
|
self._operations: dict[Operation.CallSignature, Type] = {}
|
||||||
|
|
||||||
self._local_variables: dict[str, TypeVar] = {}
|
|
||||||
|
|
||||||
define_builtins(self)
|
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
|
|
||||||
@@ -40,8 +30,6 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
|||||||
Returns:
|
Returns:
|
||||||
Type: the type
|
Type: the type
|
||||||
"""
|
"""
|
||||||
if name in self._local_variables:
|
|
||||||
return self._local_variables[name]
|
|
||||||
if name in self._types:
|
if name in self._types:
|
||||||
return self._types[name]
|
return self._types[name]
|
||||||
raise NameError(f"Undefined type {name}")
|
raise NameError(f"Undefined type {name}")
|
||||||
@@ -120,117 +108,6 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
|||||||
)
|
)
|
||||||
self._operations[signature] = result
|
self._operations[signature] = result
|
||||||
|
|
||||||
def resolve(self, stmts: list[m.Stmt]):
|
|
||||||
"""Process a sequence of statements
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stmts (list[m.Stmt]): the statements
|
|
||||||
"""
|
|
||||||
for stmt in stmts:
|
|
||||||
stmt.accept(self)
|
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
|
||||||
params: list[TypeVar] = []
|
|
||||||
for param in stmt.params:
|
|
||||||
name: str = param.name.lexeme
|
|
||||||
bound: Optional[Type] = None
|
|
||||||
if param.bound is not None:
|
|
||||||
bound = param.bound.accept(self)
|
|
||||||
var = TypeVar(name=name, bound=bound)
|
|
||||||
self._local_variables[name] = var
|
|
||||||
params.append(var)
|
|
||||||
type: Type = stmt.type.accept(self)
|
|
||||||
if len(params) != 0:
|
|
||||||
type = GenericType(params=params, body=type)
|
|
||||||
name: str = stmt.name.lexeme
|
|
||||||
self.define_type(name, AliasType(name=name, type=type))
|
|
||||||
self._local_variables.clear()
|
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
|
||||||
base: Type = stmt.type.accept(self)
|
|
||||||
for op in stmt.operations:
|
|
||||||
right: Type = op.operand.accept(self)
|
|
||||||
result: Type = op.result.accept(self)
|
|
||||||
self.define_operation(
|
|
||||||
left=base,
|
|
||||||
operator=op.name.lexeme,
|
|
||||||
right=right,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
|
||||||
return expr.expr.accept(self)
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
|
||||||
return self.get_type(type.name.lexeme)
|
|
||||||
|
|
||||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
|
||||||
type_: Type = type.type.accept(self)
|
|
||||||
params: list[Type] = [param.accept(self) for param in type.params]
|
|
||||||
return self.apply_generic(type_, params)
|
|
||||||
|
|
||||||
def apply_generic(self, type: Type, params: list[Type]) -> Type:
|
|
||||||
match type:
|
|
||||||
case AliasType(name=name, type=base):
|
|
||||||
return AliasType(name=name, type=self.apply_generic(base, params))
|
|
||||||
|
|
||||||
case GenericType(params=type_vars, body=body):
|
|
||||||
n_params: int = len(params)
|
|
||||||
n_type_vars: int = len(type_vars)
|
|
||||||
if n_params < n_type_vars:
|
|
||||||
raise ValueError(
|
|
||||||
f"Missing type parameters, expected {n_type_vars} but only {n_params} provided"
|
|
||||||
)
|
|
||||||
if n_params > n_type_vars:
|
|
||||||
raise ValueError(
|
|
||||||
f"Too many type parameters, expected {n_type_vars} but {n_params} provided"
|
|
||||||
)
|
|
||||||
substitutions: dict[str, Type] = {}
|
|
||||||
for param, type_var in zip(params, type_vars):
|
|
||||||
if type_var.bound is not None and not self.is_subtype(
|
|
||||||
param, type_var.bound
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Type parameter {param} is not a subtype of {type_var.bound}"
|
|
||||||
)
|
|
||||||
substitutions[type_var.name] = param
|
|
||||||
return substitute_typevars(body, substitutions)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"{type} is not a generic type")
|
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
|
||||||
type_: Type = type.type.accept(self)
|
|
||||||
type.constraint.accept(self)
|
|
||||||
# TODO
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
|
||||||
return ComplexType(
|
|
||||||
properties={
|
|
||||||
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
"""Check whether `type1` is a subtype of `type2`
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
|
|
||||||
@@ -371,3 +248,33 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def apply_generic(self, type: Type, params: list[Type]) -> Type:
|
||||||
|
match type:
|
||||||
|
case AliasType(name=name, type=base):
|
||||||
|
return AliasType(name=name, type=self.apply_generic(base, params))
|
||||||
|
|
||||||
|
case GenericType(params=type_vars, body=body):
|
||||||
|
n_params: int = len(params)
|
||||||
|
n_type_vars: int = len(type_vars)
|
||||||
|
if n_params < n_type_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing type parameters, expected {n_type_vars} but only {n_params} provided"
|
||||||
|
)
|
||||||
|
if n_params > n_type_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Too many type parameters, expected {n_type_vars} but {n_params} provided"
|
||||||
|
)
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
for param, type_var in zip(params, type_vars):
|
||||||
|
if type_var.bound is not None and not self.is_subtype(
|
||||||
|
param, type_var.bound
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Type parameter {param} is not a subtype of {type_var.bound}"
|
||||||
|
)
|
||||||
|
substitutions[type_var.name] = param
|
||||||
|
return substitute_typevars(body, substitutions)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"{type} is not a generic type")
|
||||||
@@ -10,7 +10,7 @@ import midas.ast.midas as m
|
|||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
||||||
from midas.checker.checker import Checker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
from midas.checker.types import Type
|
from midas.checker.types import Type
|
||||||
from midas.cli.ansi import Ansi
|
from midas.cli.ansi import Ansi
|
||||||
@@ -25,7 +25,6 @@ from midas.lexer.midas import MidasLexer
|
|||||||
from midas.lexer.token import Token, TokenType
|
from midas.lexer.token import Token, TokenType
|
||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
from midas.parser.python import PythonParser
|
from midas.parser.python import PythonParser
|
||||||
from midas.resolver.resolver import Resolver
|
|
||||||
from midas.utils import UniversalJSONDumper
|
from midas.utils import UniversalJSONDumper
|
||||||
|
|
||||||
|
|
||||||
@@ -98,18 +97,13 @@ def compile(
|
|||||||
):
|
):
|
||||||
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
|
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
|
||||||
source: str = file.read()
|
source: str = file.read()
|
||||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
|
||||||
parser = PythonParser()
|
checker = TypeChecker()
|
||||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
for path in types:
|
||||||
resolver = Resolver()
|
checker.import_midas(Path(path.name).resolve())
|
||||||
resolver.resolve(*stmts)
|
|
||||||
types_paths: list[Path] = [Path(t.name).resolve() for t in types]
|
checker.type_check_source(source, str(Path(file.name).resolve()))
|
||||||
checker = Checker(
|
diagnostics: list[Diagnostic] = checker.diagnostics
|
||||||
resolver.locals,
|
|
||||||
source_path=Path(file.name).resolve(),
|
|
||||||
types_paths=types_paths,
|
|
||||||
)
|
|
||||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
|
||||||
lines: list[str] = source.split("\n")
|
lines: list[str] = source.split("\n")
|
||||||
for diagnostic in diagnostics:
|
for diagnostic in diagnostics:
|
||||||
print_diagnostic(lines, diagnostic)
|
print_diagnostic(lines, diagnostic)
|
||||||
@@ -118,7 +112,7 @@ def compile(
|
|||||||
print(
|
print(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
UniversalJSONDumper.dump(
|
UniversalJSONDumper.dump(
|
||||||
checker.global_env,
|
checker.python_typer.global_env,
|
||||||
[("Environment", "_children")],
|
[("Environment", "_children")],
|
||||||
lambda obj: isinstance(obj, get_args(Type)),
|
lambda obj: isinstance(obj, get_args(Type)),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -1,15 +1,9 @@
|
|||||||
from __future__ import annotations
|
from midas.checker.registry import TypesRegistry
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from midas.checker.types import BaseType, Type, UnitType
|
from midas.checker.types import BaseType, Type, UnitType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from midas.resolver.midas import MidasResolver
|
|
||||||
|
|
||||||
|
def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type):
|
||||||
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
reg.define_operation(
|
||||||
ctx.define_operation(
|
|
||||||
left=t1,
|
left=t1,
|
||||||
operator=operator,
|
operator=operator,
|
||||||
right=t2,
|
right=t2,
|
||||||
@@ -17,8 +11,8 @@ def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
def basic_op(reg: TypesRegistry, type: Type, op: str):
|
||||||
ctx.define_operation(
|
reg.define_operation(
|
||||||
left=type,
|
left=type,
|
||||||
operator=op,
|
operator=op,
|
||||||
right=type,
|
right=type,
|
||||||
@@ -26,47 +20,47 @@ def basic_op(ctx: MidasResolver, type: Type, op: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def define_builtins(ctx: MidasResolver):
|
def define_builtins(reg: TypesRegistry):
|
||||||
"""Define builtin types and operations"""
|
"""Define builtin types and operations"""
|
||||||
unit = ctx.define_type("None", UnitType())
|
unit = reg.define_type("None", UnitType())
|
||||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||||
int = ctx.define_type("int", BaseType(name="int"))
|
int = reg.define_type("int", BaseType(name="int"))
|
||||||
float = ctx.define_type("float", BaseType(name="float"))
|
float = reg.define_type("float", BaseType(name="float"))
|
||||||
str = ctx.define_type("str", BaseType(name="str"))
|
str = reg.define_type("str", BaseType(name="str"))
|
||||||
|
|
||||||
basic_op(ctx, int, "__add__") # int + int = int
|
basic_op(reg, int, "__add__") # int + int = int
|
||||||
basic_op(ctx, int, "__sub__") # int - int = int
|
basic_op(reg, int, "__sub__") # int - int = int
|
||||||
basic_op(ctx, int, "__mul__") # int * int = int
|
basic_op(reg, int, "__mul__") # int * int = int
|
||||||
basic_op(ctx, int, "__pow__") # int ** int = int
|
basic_op(reg, int, "__pow__") # int ** int = int
|
||||||
basic_op(ctx, int, "__mod__") # int % int = int
|
basic_op(reg, int, "__mod__") # int % int = int
|
||||||
basic_op(ctx, int, "__and__") # int & int = int
|
basic_op(reg, int, "__and__") # int & int = int
|
||||||
basic_op(ctx, int, "__or__") # int | int = int
|
basic_op(reg, int, "__or__") # int | int = int
|
||||||
basic_op(ctx, int, "__xor__") # int ^ int = int
|
basic_op(reg, int, "__xor__") # int ^ int = int
|
||||||
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
op(reg, int, "__lt__", int, bool) # int < int = bool
|
||||||
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
op(reg, int, "__gt__", int, bool) # int > int = bool
|
||||||
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
op(reg, int, "__le__", int, bool) # int <= int = bool
|
||||||
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
op(reg, int, "__ge__", int, bool) # int >= int = bool
|
||||||
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
op(reg, int, "__eq__", int, bool) # int == int = bool
|
||||||
basic_op(ctx, float, "__add__") # float + float = float
|
basic_op(reg, float, "__add__") # float + float = float
|
||||||
basic_op(ctx, float, "__sub__") # float - float = float
|
basic_op(reg, float, "__sub__") # float - float = float
|
||||||
basic_op(ctx, float, "__mul__") # float * float = float
|
basic_op(reg, float, "__mul__") # float * float = float
|
||||||
basic_op(ctx, float, "__truediv__") # float / float = float
|
basic_op(reg, float, "__truediv__") # float / float = float
|
||||||
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
op(reg, float, "__lt__", float, bool) # float < float = bool
|
||||||
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
op(reg, float, "__gt__", float, bool) # float > float = bool
|
||||||
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
op(reg, float, "__le__", float, bool) # float <= float = bool
|
||||||
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
op(reg, float, "__ge__", float, bool) # float >= float = bool
|
||||||
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
op(reg, float, "__eq__", float, bool) # float == float = bool
|
||||||
basic_op(ctx, str, "__add__") # str + str = str
|
basic_op(reg, str, "__add__") # str + str = str
|
||||||
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
op(reg, str, "__eq__", str, bool) # str == str = bool
|
||||||
|
|
||||||
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
op(reg, int, "__lt__", float, bool) # int < float = bool
|
||||||
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
op(reg, int, "__gt__", float, bool) # int > float = bool
|
||||||
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
op(reg, int, "__le__", float, bool) # int <= float = bool
|
||||||
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
op(reg, int, "__ge__", float, bool) # int >= float = bool
|
||||||
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
op(reg, int, "__eq__", float, bool) # int == float = bool
|
||||||
|
|
||||||
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
op(reg, float, "__lt__", int, bool) # float < int = bool
|
||||||
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
op(reg, float, "__gt__", int, bool) # float > int = bool
|
||||||
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
op(reg, float, "__le__", int, bool) # float <= int = bool
|
||||||
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
op(reg, float, "__ge__", int, bool) # float >= int = bool
|
||||||
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
op(reg, float, "__eq__", int, bool) # float == int = bool
|
||||||
|
|||||||
@@ -1,14 +1,11 @@
|
|||||||
import ast
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.checker.checker import Checker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.diagnostic import Diagnostic
|
from midas.checker.diagnostic import Diagnostic
|
||||||
from midas.checker.types import Type
|
from midas.checker.types import Type
|
||||||
from midas.parser.python import PythonParser
|
|
||||||
from midas.resolver.resolver import Resolver
|
|
||||||
from tests.base import Tester
|
from tests.base import Tester
|
||||||
from tests.serializer.python import PythonAstJsonSerializer
|
from tests.serializer.python import PythonAstJsonSerializer
|
||||||
|
|
||||||
@@ -36,24 +33,16 @@ class CheckerTester(Tester):
|
|||||||
if not path.is_file():
|
if not path.is_file():
|
||||||
raise TypeError(f"Test '{path}' is not a file")
|
raise TypeError(f"Test '{path}' is not a file")
|
||||||
|
|
||||||
types_paths: list[Path] = []
|
result: CaseResult = CaseResult()
|
||||||
|
|
||||||
|
checker = TypeChecker()
|
||||||
types_path: Path = path.with_suffix(".midas")
|
types_path: Path = path.with_suffix(".midas")
|
||||||
if types_path.exists():
|
if types_path.exists():
|
||||||
types_paths.append(types_path)
|
checker.import_midas(types_path)
|
||||||
source: str = path.read_text()
|
|
||||||
tree: ast.Module = ast.parse(source, filename=path)
|
|
||||||
parser = PythonParser()
|
|
||||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
|
||||||
resolver = Resolver()
|
|
||||||
resolver.resolve(*stmts)
|
|
||||||
result: CaseResult = CaseResult()
|
|
||||||
checker = Checker(
|
|
||||||
resolver.locals,
|
|
||||||
source_path=path,
|
|
||||||
types_paths=types_paths,
|
|
||||||
)
|
|
||||||
|
|
||||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
checker.type_check(path)
|
||||||
|
|
||||||
|
diagnostics: list[Diagnostic] = checker.diagnostics
|
||||||
for diagnostic in diagnostics:
|
for diagnostic in diagnostics:
|
||||||
result.diagnostics.append(
|
result.diagnostics.append(
|
||||||
{
|
{
|
||||||
@@ -72,7 +61,7 @@ class CheckerTester(Tester):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
judgements: list[tuple[p.Expr, Type]] = checker.judgements
|
judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements
|
||||||
serializer = PythonAstJsonSerializer()
|
serializer = PythonAstJsonSerializer()
|
||||||
for expr, type in judgements:
|
for expr, type in judgements:
|
||||||
loc = expr.location
|
loc = expr.location
|
||||||
|
|||||||
Reference in New Issue
Block a user