diff --git a/midas/checker/checker.py b/midas/checker/checker.py index 0c54fa4..c26f0aa 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -1,661 +1,35 @@ -import logging -from dataclasses import dataclass from pathlib import Path from typing import Optional -import midas.ast.midas as m -import midas.ast.python as p -from midas.ast.location import Location -from midas.checker.diagnostic import Diagnostic, DiagnosticType -from midas.checker.environment import Environment -from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS -from midas.checker.types import ( - ComplexType, - Function, - Operation, - Type, - UnitType, - UnknownType, - unfold_type, -) -from midas.lexer.midas import MidasLexer -from midas.lexer.token import Token -from midas.parser.midas import MidasParser -from midas.resolver.midas import MidasResolver +from midas.checker.diagnostic import Diagnostic +from midas.checker.midas import MidasTyper +from midas.checker.python import PythonTyper +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import Reporter -class ReturnException(Exception): - pass +class TypeChecker: + def __init__(self): + self.types: TypesRegistry = TypesRegistry() + self.reporter: Reporter = Reporter() + self.midas_typer = MidasTyper(self.types, self.reporter) + self.python_typer = PythonTyper(self.types, self.reporter) -@dataclass(frozen=True, kw_only=True) -class MappedArgument: - expr: p.Expr - type: Type - argument: Function.Argument + def import_midas(self, path: Path): + source: str = path.read_text() + return self.import_midas_source(source, path=str(path)) + def import_midas_source(self, source: str, path: Optional[str] = None): + self.midas_typer.process(source, path) -class Checker( - p.Stmt.Visitor[None], - p.Expr.Visitor[Type], - p.MidasType.Visitor[Type], -): - """A type checker which can use custom type definitions""" + def type_check(self, path: Path): + source: str = path.read_text() + return self.type_check_source(source, path=str(path)) - def __init__( - self, - locals: dict[p.Expr, int], - source_path: Path, - types_paths: list[Path], - ): - self.logger: logging.Logger = logging.getLogger("Checker") - self.source_path: Path = source_path - self.types_paths: list[Path] = types_paths - self.ctx: MidasResolver = MidasResolver() - self.global_env: Environment = Environment() - self.env: Environment = self.global_env - self.locals: dict[p.Expr, int] = locals - self.diagnostics: list[Diagnostic] = [] - self.judgements: list[tuple[p.Expr, Type]] = [] + def type_check_source(self, source: str, path: Optional[str] = None): + self.python_typer.process(source, path) - def diagnostic(self, type: DiagnosticType, location: Location, message: str): - self.diagnostics.append( - Diagnostic( - file_path=self.source_path, - location=location, - type=type, - message=message, - ) - ) - - def error(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.ERROR, - location=location, - message=message, - ) - - def warning(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.WARNING, - location=location, - message=message, - ) - - def info(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.INFO, - location=location, - message=message, - ) - - def type_of(self, expr: p.Expr) -> Type: - """Evaluate the type of an expression - - Args: - expr (p.Expr): the expression to evaluate - - Returns: - Type: the type of the given expression - """ - type: Type = expr.accept(self) - self.judgements.append((expr, type)) - return type - - def process_block(self, block: list[p.Stmt], env: Environment) -> bool: - """Evaluate a sequence of statements - - Args: - block (list[p.Stmt]): the statements to evaluate - env (Environment): the environment in which to evaluate - - Returns: - bool: whether a return statement is present in the block - """ - previous_env: Environment = self.env - self.env = env - returned: bool = False - for i, stmt in enumerate(block): - try: - stmt.accept(self) - except ReturnException: - returned = True - if i < len(block) - 1: - self.warning(block[i + 1].location, "Unreachable statement") - break - self.env = previous_env - return returned - - def check(self, statements: list[p.Stmt]) -> list[Diagnostic]: - """Type check a sequence of statements and returns diagnostics - - Args: - statements (list[p.Stmt]): the statements to evaluate and check - - Returns: - list[Diagnostic]: the list of diagnostics (errors, warning, etc.) - """ - self.diagnostics = [] - - for path in self.types_paths: - self.import_midas(path) - self.logger.debug(f"Midas types: {self.ctx._types}") - self.logger.debug(f"Midas operations: {self.ctx._operations}") - - for stmt in statements: - stmt.accept(self) - - self.logger.debug(f"Final environment: {self.env.flat_dict()}") - return self.diagnostics - - def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]: - """Look up a variable in the environment it was declared - - Args: - name (str): the name of the variable - expr (p.Expr): the variable expression, used to lookup the scope distance - - Returns: - Optional[Type]: the type of the variable, or None if it was not found - """ - distance: Optional[int] = self.locals.get(expr) - if distance is not None: - return self.env.get_at(distance, name) - return self.global_env.get(name) - - def import_midas(self, path: Path) -> None: - """Import Midas definitions from a path - - Args: - path (Path): the import path - """ - self.logger.debug(f"Importing type definitions from {path}") - lexer: MidasLexer = MidasLexer(path.read_text()) - tokens: list[Token] = lexer.process() - parser: MidasParser = MidasParser(tokens) - stmts: list[m.Stmt] = parser.parse() - self.ctx.resolve(stmts) - - def is_subtype(self, type1: Type, type2: Type) -> bool: - return self.ctx.is_subtype(type1, type2) - - def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: - self.type_of(stmt.expr) - - def visit_function(self, stmt: p.Function) -> None: - env: Environment = Environment(self.env) - pos_args: list[Function.Argument] = [] - args: list[Function.Argument] = [] - kw_args: list[Function.Argument] = [] - - def eval_arg_type(arg: p.Function.Argument) -> Type: - if arg.type is not None: - return arg.type.accept(self) - if arg.default is not None: - return arg.default.accept(self) - return UnknownType() - - pos: int = 0 - for arg in stmt.posonlyargs: - pos_args.append( - Function.Argument( - pos=pos, - name=arg.name, - type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - for arg in stmt.args: - args.append( - Function.Argument( - pos=pos, - name=arg.name, - type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - for arg in stmt.kwonlyargs: - kw_args.append( - Function.Argument( - pos=pos, # not relevant - name=arg.name, - type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - - for arg in pos_args + args + kw_args: - env.define(arg.name, arg.type) - - returns_hint: Optional[Type] = None - if stmt.returns is not None: - returns_hint = stmt.returns.accept(self) - # Early define to handle simple fully-typed recursion - inside_function: Function = Function( - name=stmt.name, - pos_args=pos_args, - args=args, - kw_args=kw_args, - returns=returns_hint, - ) - self.env.define(stmt.name, inside_function) - - returned: bool = self.process_block(stmt.body, env) - inferred_return: Type = UnknownType() - if not returned: - env.return_types.append(UnitType()) - return_types: set[Type] = set(env.return_types) - if len(return_types) == 1: - inferred_return = list(return_types)[0] - elif len(return_types) > 1: - self.error( - stmt.location, - f"Mixed return types: {env.return_types}", - ) - - returns: Type = UnknownType() - if returns_hint is not None: - assert stmt.returns is not None - returns = returns_hint - if returns != inferred_return: - self.error( - stmt.returns.location, - f"Return type mismatch, annotated {returns} but returns {inferred_return}", - ) - else: - returns = inferred_return - - # TODO: handle *args and **kwargs sinks - function: Function = Function( - name=stmt.name, - pos_args=pos_args, - args=args, - kw_args=kw_args, - returns=returns, - ) - self.env.define(stmt.name, function) - - def visit_type_assign(self, stmt: p.TypeAssign) -> None: - # TODO check not yet defined locally - type: Type = stmt.type.accept(self) - self.env.define(stmt.name, type) - - def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: - value_type: Type = self.type_of(stmt.value) - for target in stmt.targets: - self._assign(stmt.location, target, value_type) - - def _assign(self, location: Location, target: p.Expr, value_type: Type): - match target: - case p.VariableExpr(): - self._assign_var(location, target, value_type) - - case p.GetExpr(): - self._assign_attr(location, target, value_type) - - case _: - if not isinstance(target, p.VariableExpr): - self.logger.warning(f"Unsupported assignment to {target}") - self.warning(target.location, f"Unsupported assignment to {target}") - - def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type): - name: str = target.name - var_type: Optional[Type] = self.look_up_variable(name, target) - - if var_type is None: - self.env.define(name, value_type) - else: - # S <: T - # Γ, x: T v: S - # x = v - if not self.is_subtype(value_type, var_type): - self.error( - location, - f"Cannot assign {value_type} to {name} of type {var_type}", - ) - - def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): - object: Type = self.type_of(target.object) - base_object: Type = unfold_type(object) - match base_object: - case ComplexType(properties=properties): - if target.name not in properties: - self.error( - target.location, f"Unknown property '{target.name} on {object}" - ) - return - - prop_type: Type = properties[target.name] - if not self.is_subtype(value_type, prop_type): - self.error( - location, - f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}", - ) - return - - case UnknownType(): - pass - - case _: - self.error( - target.location, - f"Cannot assign {value_type} to unknown property '{target.name}' on {object}", - ) - - def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: - type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() - self.env.return_types.append(type) - raise ReturnException() - - def visit_if_stmt(self, stmt: p.IfStmt) -> None: - # Not evaluated in sub-environment because assignments in the test leak out of the if - # For example: - # if (m := 1 + 1) < 2: - # ... - # print(m) # <- m is still defined - test_type: Type = stmt.test.accept(self) - - # TODO Allow subtypes or any type - if test_type != self.ctx.get_type("bool"): - self.error( - stmt.test.location, f"If test must be a boolean, got {test_type}" - ) - - env: Environment = Environment(self.env) - body_returned: bool = self.process_block(stmt.body, env) - else_returned: bool = self.process_block(stmt.orelse, env) - self.env.return_types.extend(env.return_types) - if body_returned and else_returned: - raise ReturnException() - - def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: - method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) - if method is None: - self.logger.warning(f"Unsupported operator {expr.operator}") - self.warning(expr.location, f"Unsupported operator {expr.operator}") - return UnknownType() - left: Type = self.type_of(expr.left) - right: Type = self.type_of(expr.right) - - operations: list[Operation] = self.ctx.get_operations_by_name(method) - valid_operations: list[Operation] = [] - for op in operations: - sig: Operation.CallSignature = op.signature - if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right): - valid_operations.append(op) - - if len(valid_operations) == 0: - self.error( - expr.location, - f"Undefined operation {method} between {left} and {right}", - ) - return UnknownType() - elif len(valid_operations) == 1: - self.logger.debug(f"Unique operation {method} between {left} and {right}") - return valid_operations[0].result - - for i, op1 in enumerate(valid_operations): - sig1: Operation.CallSignature = op1.signature - best_match: bool = True - for j, op2 in enumerate(valid_operations): - if i == j: - continue - sig2: Operation.CallSignature = op2.signature - if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype( - sig1.right, sig2.right - ): - best_match = False - break - self.logger.debug(f"{op1} is a full overload of {op2}") - if best_match: - return op1.result - - overloads: list[str] = [ - f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}" - for op in valid_operations - ] - self.error( - expr.location, - f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}", - ) - return UnknownType() - - def visit_compare_expr(self, expr: p.CompareExpr) -> Type: - method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) - if method is None: - self.logger.warning(f"Unsupported operator {expr.operator}") - self.warning(expr.location, f"Unsupported operator {expr.operator}") - return UnknownType() - left: Type = self.type_of(expr.left) - right: Type = self.type_of(expr.right) - - result: Optional[Type] = self.ctx.get_operation_result(left, method, right) - if result is None: - self.error( - expr.location, - f"Undefined operation {method} between {left} and {right}", - ) - return UnknownType() - return result - - def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... - - def visit_call_expr(self, expr: p.CallExpr) -> Type: - callee: Type = self.type_of(expr.callee) - if not isinstance(callee, Function): - self.error(expr.callee.location, "Callee is not a function") - return UnknownType() - function: Function = callee - mapped: list[MappedArgument] = self.map_call_arguments(function, expr) - for arg in mapped: - if not self.is_subtype(arg.type, arg.argument.type): - self.error( - arg.expr.location, - f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", - ) - return function.returns - - def visit_get_expr(self, expr: p.GetExpr) -> Type: - object: Type = self.type_of(expr.object) - base_object: Type = unfold_type(object) - match base_object: - case ComplexType(properties=properties): - if expr.name not in properties: - self.error( - expr.location, f"Unknown property '{expr.name} on {object}" - ) - return UnknownType() - return properties[expr.name] - - case UnknownType(): - return UnknownType() - - case _: - self.error( - expr.location, f"Cannot get property '{expr.name}' on {object}" - ) - return UnknownType() - - def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: - match expr.value: - case bool(): # Must be before int - return self.ctx.get_type("bool") - case int(): - return self.ctx.get_type("int") - case float(): - return self.ctx.get_type("float") - case str(): - return self.ctx.get_type("str") - case _: - self.warning(expr.location, f"Unknown literal {expr}") - return UnknownType() - - def visit_variable_expr(self, expr: p.VariableExpr) -> Type: - return self.look_up_variable(expr.name, expr) or UnknownType() - - def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: - left: Type = expr.left.accept(self) - right: Type = expr.right.accept(self) - - if self.is_subtype(left, right): - return right - if self.is_subtype(right, left): - return left - - self.error( - expr.location, - f"Incompatible operand types, {left=} and {right=}", - ) - return UnknownType() - - def visit_cast_expr(self, expr: p.CastExpr) -> Type: - return expr.type.accept(self) - - def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: - test_type: Type = expr.test.accept(self) - - # TODO Allow subtypes or any type - if test_type != self.ctx.get_type("bool"): - self.error( - expr.test.location, f"If test must be a boolean, got {test_type}" - ) - - true_type: Type = expr.if_true.accept(self) - false_type: Type = expr.if_false.accept(self) - if self.is_subtype(true_type, false_type): - return false_type - if self.is_subtype(false_type, true_type): - return true_type - - self.error( - expr.location, - f"Incompatible types in ternary if branches: true={true_type} and false={false_type}", - ) - return UnknownType() - - def visit_base_type(self, node: p.BaseType) -> Type: - return self.ctx.get_type(node.base) - - def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... - - def visit_frame_column(self, node: p.FrameColumn) -> Type: ... - - def visit_frame_type(self, node: p.FrameType) -> Type: ... - - def map_call_arguments( - self, function: Function, call: p.CallExpr - ) -> list[MappedArgument]: - """Map call arguments to function parameters as defined in its signature - - This method maps positional-only, keyword-only and mixed parameter definitions - with the arguments passed at the call site - - Any mismatched, missing or unexpected argument is reported as a diagnostic - - Args: - function (Function): the function definition - call (p.CallExpr): the call expression - - Returns: - list[MappedArgument]: the list of mapped arguments - """ - positional: list[tuple[p.Expr, Type]] = [ - (arg, self.type_of(arg)) for arg in call.arguments - ] - keywords: dict[str, tuple[p.Expr, Type]] = { - name: (arg, self.type_of(arg)) for name, arg in call.keywords.items() - } - set_args: set[str] = set() - - required_positional: list[str] = [ - arg.name for arg in function.pos_args + function.args if arg.required - ] - required_keyword: list[str] = [ - arg.name for arg in function.kw_args if arg.required - ] - - mapped: list[MappedArgument] = [] - - pos_params: list[Function.Argument] = list(function.pos_args) - mixed_params: list[Function.Argument] = list(function.args) - kw_params: dict[str, Function.Argument] = { - arg.name: arg for arg in function.kw_args - } - - # TODO: handle *args and **kwargs sinks - for arg in positional: - param: Function.Argument - if len(pos_params) != 0: - param = pos_params.pop(0) - elif len(mixed_params) != 0: - param = mixed_params.pop(0) - else: - self.error(arg[0].location, "Too many positional arguments") - break - name: str = param.name - if name in required_positional: - required_positional.remove(name) - if name in required_keyword: - required_keyword.remove(name) - set_args.add(name) - mapped.append( - MappedArgument( - expr=arg[0], - type=arg[1], - argument=param, - ) - ) - - kw_params.update({arg.name: arg for arg in mixed_params}) - for name, arg in keywords.items(): - param: Function.Argument - if name not in kw_params: - if name in set_args: - self.error( - arg[0].location, f"Multiple values for argument '{name}'" - ) - else: - self.error(arg[0].location, f"Unknown keyword argument '{name}'") - continue - param = kw_params.pop(name) - if name in required_positional: - required_positional.remove(name) - if name in required_keyword: - required_keyword.remove(name) - set_args.add(name) - mapped.append( - MappedArgument( - expr=arg[0], - type=arg[1], - argument=param, - ) - ) - - def join_args(args: list[str]) -> str: - args = list(map(lambda a: f"'{a}'", args)) - if len(args) == 0: - return "" - if len(args) == 1: - return args[0] - return ", ".join(args[:-1]) + " and " + args[-1] - - if len(required_positional) != 0: - plural: str = "" if len(required_positional) == 1 else "s" - args: str = join_args(required_positional) - self.error( - call.location, - f"Missing required positional argument{plural}: {args}", - ) - - if len(required_keyword) != 0: - plural: str = "" if len(required_keyword) == 1 else "s" - args: str = join_args(required_keyword) - self.error( - call.location, - f"Missing required keyword argument{plural}: {args}", - ) - - return mapped + @property + def diagnostics(self) -> list[Diagnostic]: + return self.reporter.diagnostics diff --git a/midas/checker/diagnostic.py b/midas/checker/diagnostic.py index 2925653..f4b3d12 100644 --- a/midas/checker/diagnostic.py +++ b/midas/checker/diagnostic.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from enum import StrEnum -from pathlib import Path from typing import Optional from midas.ast.location import Location @@ -14,7 +13,7 @@ class DiagnosticType(StrEnum): @dataclass(frozen=True) class Diagnostic: - file_path: Optional[str | Path] + file_path: Optional[str] location: Location type: DiagnosticType message: str diff --git a/midas/checker/midas.py b/midas/checker/midas.py new file mode 100644 index 0000000..37a856d --- /dev/null +++ b/midas/checker/midas.py @@ -0,0 +1,137 @@ +import logging +from typing import Optional + +import midas.ast.midas as m +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter, Reporter +from midas.checker.types import ( + AliasType, + ComplexType, + GenericType, + Type, + TypeVar, + UnknownType, +) +from midas.lexer.midas import MidasLexer +from midas.lexer.token import Token +from midas.parser.midas import MidasParser +from midas.resolver.builtin import define_builtins + + +class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): + """A resolver which evaluates Midas type definitions and build a registry""" + + def __init__(self, types: TypesRegistry, reporter: Reporter) -> None: + self.logger: logging.Logger = logging.getLogger("MidasTyper") + self.reporter: FileReporter = reporter.for_file(None) + + self.types: TypesRegistry = types + self._local_variables: dict[str, TypeVar] = {} + + define_builtins(self.types) + + def process(self, source: str, path: Optional[str]): + self.reporter = self.reporter.for_file(path) + lexer: MidasLexer = MidasLexer(source) + tokens: list[Token] = lexer.process() + parser: MidasParser = MidasParser(tokens) + stmts: list[m.Stmt] = parser.parse() + self.resolve(stmts) + + def get_type(self, name: str) -> Type: + """Get a type from its name + + Args: + name (str): the name of the type + + Raises: + NameError: if the type is not defined + + Returns: + Type: the type + """ + if name in self._local_variables: + return self._local_variables[name] + return self.types.get_type(name) + + def resolve(self, stmts: list[m.Stmt]): + """Process a sequence of statements + + Args: + stmts (list[m.Stmt]): the statements + """ + for stmt in stmts: + stmt.accept(self) + + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + params: list[TypeVar] = [] + for param in stmt.params: + name: str = param.name.lexeme + bound: Optional[Type] = None + if param.bound is not None: + bound = param.bound.accept(self) + var = TypeVar(name=name, bound=bound) + self._local_variables[name] = var + params.append(var) + type: Type = stmt.type.accept(self) + if len(params) != 0: + type = GenericType(params=params, body=type) + name: str = stmt.name.lexeme + self.types.define_type(name, AliasType(name=name, type=type)) + self._local_variables.clear() + + def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... + + def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: + base: Type = stmt.type.accept(self) + for op in stmt.operations: + right: Type = op.operand.accept(self) + result: Type = op.result.accept(self) + self.types.define_operation( + left=base, + operator=op.name.lexeme, + right=right, + result=result, + ) + + def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... + + def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... + + def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... + + def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ... + + def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ... + + def visit_get_expr(self, expr: m.GetExpr) -> None: ... + + def visit_variable_expr(self, expr: m.VariableExpr) -> None: ... + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: + return expr.expr.accept(self) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... + + def visit_named_type(self, type: m.NamedType) -> Type: + return self.get_type(type.name.lexeme) + + def visit_generic_type(self, type: m.GenericType) -> Type: + type_: Type = type.type.accept(self) + params: list[Type] = [param.accept(self) for param in type.params] + return self.types.apply_generic(type_, params) + + def visit_constraint_type(self, type: m.ConstraintType) -> Type: + type_: Type = type.type.accept(self) + type.constraint.accept(self) + # TODO + return UnknownType() + + def visit_complex_type(self, type: m.ComplexType) -> Type: + return ComplexType( + properties={ + prop.name.lexeme: prop.type.accept(self) for prop in type.properties + } + ) diff --git a/midas/checker/python.py b/midas/checker/python.py new file mode 100644 index 0000000..751497d --- /dev/null +++ b/midas/checker/python.py @@ -0,0 +1,626 @@ +import ast +import logging +from dataclasses import dataclass +from typing import Optional + +import midas.ast.python as p +from midas.ast.location import Location +from midas.checker.environment import Environment +from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter, Reporter +from midas.checker.types import ( + ComplexType, + Function, + Operation, + Type, + UnitType, + UnknownType, + unfold_type, +) +from midas.parser.python import PythonParser +from midas.resolver.resolver import Resolver + + +class ReturnException(Exception): + pass + + +@dataclass(frozen=True, kw_only=True) +class MappedArgument: + expr: p.Expr + type: Type + argument: Function.Argument + + +class PythonTyper( + p.Stmt.Visitor[None], + p.Expr.Visitor[Type], + p.MidasType.Visitor[Type], +): + """A type checker which can use custom type definitions""" + + def __init__( + self, + types: TypesRegistry, + reporter: Reporter, + ): + self.logger: logging.Logger = logging.getLogger("PythonTyper") + self.reporter: FileReporter = reporter.for_file(None) + self.types: TypesRegistry = types + self.global_env: Environment = Environment() + self.env: Environment = self.global_env + self.locals: dict[p.Expr, int] = {} + self.judgements: list[tuple[p.Expr, Type]] = [] + + def process(self, source: str, path: Optional[str]): + self.reporter = self.reporter.for_file(path) + + tree: ast.Module = ast.parse(source, filename=path or "") + parser = PythonParser() + stmts: list[p.Stmt] = parser.parse_module(tree) + resolver = Resolver() + resolver.resolve(*stmts) + + self.env = self.global_env + self.locals = resolver.locals + self.judgements = [] + + self.check(stmts) + + def type_of(self, expr: p.Expr) -> Type: + """Evaluate the type of an expression + + Args: + expr (p.Expr): the expression to evaluate + + Returns: + Type: the type of the given expression + """ + type: Type = expr.accept(self) + self.judgements.append((expr, type)) + return type + + def process_block(self, block: list[p.Stmt], env: Environment) -> bool: + """Evaluate a sequence of statements + + Args: + block (list[p.Stmt]): the statements to evaluate + env (Environment): the environment in which to evaluate + + Returns: + bool: whether a return statement is present in the block + """ + previous_env: Environment = self.env + self.env = env + returned: bool = False + for i, stmt in enumerate(block): + try: + stmt.accept(self) + except ReturnException: + returned = True + if i < len(block) - 1: + self.reporter.warning( + block[i + 1].location, "Unreachable statement" + ) + break + self.env = previous_env + return returned + + def check(self, statements: list[p.Stmt]) -> None: + """Type check a sequence of statements and returns diagnostics + + Args: + statements (list[p.Stmt]): the statements to evaluate and check + """ + for stmt in statements: + stmt.accept(self) + + self.logger.debug(f"Final environment: {self.env.flat_dict()}") + + def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]: + """Look up a variable in the environment it was declared + + Args: + name (str): the name of the variable + expr (p.Expr): the variable expression, used to lookup the scope distance + + Returns: + Optional[Type]: the type of the variable, or None if it was not found + """ + distance: Optional[int] = self.locals.get(expr) + if distance is not None: + return self.env.get_at(distance, name) + return self.global_env.get(name) + + def is_subtype(self, type1: Type, type2: Type) -> bool: + return self.types.is_subtype(type1, type2) + + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: + self.type_of(stmt.expr) + + def visit_function(self, stmt: p.Function) -> None: + env: Environment = Environment(self.env) + pos_args: list[Function.Argument] = [] + args: list[Function.Argument] = [] + kw_args: list[Function.Argument] = [] + + def eval_arg_type(arg: p.Function.Argument) -> Type: + if arg.type is not None: + return arg.type.accept(self) + if arg.default is not None: + return arg.default.accept(self) + return UnknownType() + + pos: int = 0 + for arg in stmt.posonlyargs: + pos_args.append( + Function.Argument( + pos=pos, + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + for arg in stmt.args: + args.append( + Function.Argument( + pos=pos, + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + for arg in stmt.kwonlyargs: + kw_args.append( + Function.Argument( + pos=pos, # not relevant + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + + for arg in pos_args + args + kw_args: + env.define(arg.name, arg.type) + + returns_hint: Optional[Type] = None + if stmt.returns is not None: + returns_hint = stmt.returns.accept(self) + # Early define to handle simple fully-typed recursion + inside_function: Function = Function( + name=stmt.name, + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns_hint, + ) + self.env.define(stmt.name, inside_function) + + returned: bool = self.process_block(stmt.body, env) + inferred_return: Type = UnknownType() + if not returned: + env.return_types.append(UnitType()) + return_types: set[Type] = set(env.return_types) + if len(return_types) == 1: + inferred_return = list(return_types)[0] + elif len(return_types) > 1: + self.reporter.error( + stmt.location, + f"Mixed return types: {env.return_types}", + ) + + returns: Type = UnknownType() + if returns_hint is not None: + assert stmt.returns is not None + returns = returns_hint + if returns != inferred_return: + self.reporter.error( + stmt.returns.location, + f"Return type mismatch, annotated {returns} but returns {inferred_return}", + ) + else: + returns = inferred_return + + # TODO: handle *args and **kwargs sinks + function: Function = Function( + name=stmt.name, + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns, + ) + self.env.define(stmt.name, function) + + def visit_type_assign(self, stmt: p.TypeAssign) -> None: + # TODO check not yet defined locally + type: Type = stmt.type.accept(self) + self.env.define(stmt.name, type) + + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: + value_type: Type = self.type_of(stmt.value) + for target in stmt.targets: + self._assign(stmt.location, target, value_type) + + def _assign(self, location: Location, target: p.Expr, value_type: Type): + match target: + case p.VariableExpr(): + self._assign_var(location, target, value_type) + + case p.GetExpr(): + self._assign_attr(location, target, value_type) + + case _: + if not isinstance(target, p.VariableExpr): + self.logger.warning(f"Unsupported assignment to {target}") + self.reporter.warning( + target.location, f"Unsupported assignment to {target}" + ) + + def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type): + name: str = target.name + var_type: Optional[Type] = self.look_up_variable(name, target) + + if var_type is None: + self.env.define(name, value_type) + else: + # S <: T + # Γ, x: T v: S + # x = v + if not self.is_subtype(value_type, var_type): + self.reporter.error( + location, + f"Cannot assign {value_type} to {name} of type {var_type}", + ) + + def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): + object: Type = self.type_of(target.object) + base_object: Type = unfold_type(object) + match base_object: + case ComplexType(properties=properties): + if target.name not in properties: + self.reporter.error( + target.location, f"Unknown property '{target.name} on {object}" + ) + return + + prop_type: Type = properties[target.name] + if not self.is_subtype(value_type, prop_type): + self.reporter.error( + location, + f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}", + ) + return + + case UnknownType(): + pass + + case _: + self.reporter.error( + target.location, + f"Cannot assign {value_type} to unknown property '{target.name}' on {object}", + ) + + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: + type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() + self.env.return_types.append(type) + raise ReturnException() + + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + # Not evaluated in sub-environment because assignments in the test leak out of the if + # For example: + # if (m := 1 + 1) < 2: + # ... + # print(m) # <- m is still defined + test_type: Type = stmt.test.accept(self) + + # TODO Allow subtypes or any type + if test_type != self.types.get_type("bool"): + self.reporter.error( + stmt.test.location, f"If test must be a boolean, got {test_type}" + ) + + env: Environment = Environment(self.env) + body_returned: bool = self.process_block(stmt.body, env) + else_returned: bool = self.process_block(stmt.orelse, env) + self.env.return_types.extend(env.return_types) + if body_returned and else_returned: + raise ReturnException() + + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: + method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) + + operations: list[Operation] = self.types.get_operations_by_name(method) + valid_operations: list[Operation] = [] + for op in operations: + sig: Operation.CallSignature = op.signature + if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right): + valid_operations.append(op) + + if len(valid_operations) == 0: + self.reporter.error( + expr.location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + elif len(valid_operations) == 1: + self.logger.debug(f"Unique operation {method} between {left} and {right}") + return valid_operations[0].result + + for i, op1 in enumerate(valid_operations): + sig1: Operation.CallSignature = op1.signature + best_match: bool = True + for j, op2 in enumerate(valid_operations): + if i == j: + continue + sig2: Operation.CallSignature = op2.signature + if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype( + sig1.right, sig2.right + ): + best_match = False + break + self.logger.debug(f"{op1} is a full overload of {op2}") + if best_match: + return op1.result + + overloads: list[str] = [ + f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}" + for op in valid_operations + ] + self.reporter.error( + expr.location, + f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}", + ) + return UnknownType() + + def visit_compare_expr(self, expr: p.CompareExpr) -> Type: + method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) + + result: Optional[Type] = self.types.get_operation_result(left, method, right) + if result is None: + self.reporter.error( + expr.location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + return result + + def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... + + def visit_call_expr(self, expr: p.CallExpr) -> Type: + callee: Type = self.type_of(expr.callee) + if not isinstance(callee, Function): + self.reporter.error(expr.callee.location, "Callee is not a function") + return UnknownType() + function: Function = callee + mapped: list[MappedArgument] = self.map_call_arguments(function, expr) + for arg in mapped: + if not self.is_subtype(arg.type, arg.argument.type): + self.reporter.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + return function.returns + + def visit_get_expr(self, expr: p.GetExpr) -> Type: + object: Type = self.type_of(expr.object) + base_object: Type = unfold_type(object) + match base_object: + case ComplexType(properties=properties): + if expr.name not in properties: + self.reporter.error( + expr.location, f"Unknown property '{expr.name} on {object}" + ) + return UnknownType() + return properties[expr.name] + + case UnknownType(): + return UnknownType() + + case _: + self.reporter.error( + expr.location, f"Cannot get property '{expr.name}' on {object}" + ) + return UnknownType() + + def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: + match expr.value: + case bool(): # Must be before int + return self.types.get_type("bool") + case int(): + return self.types.get_type("int") + case float(): + return self.types.get_type("float") + case str(): + return self.types.get_type("str") + case _: + self.reporter.warning(expr.location, f"Unknown literal {expr}") + return UnknownType() + + def visit_variable_expr(self, expr: p.VariableExpr) -> Type: + return self.look_up_variable(expr.name, expr) or UnknownType() + + def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: + left: Type = expr.left.accept(self) + right: Type = expr.right.accept(self) + + if self.is_subtype(left, right): + return right + if self.is_subtype(right, left): + return left + + self.reporter.error( + expr.location, + f"Incompatible operand types, {left=} and {right=}", + ) + return UnknownType() + + def visit_cast_expr(self, expr: p.CastExpr) -> Type: + return expr.type.accept(self) + + def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: + test_type: Type = expr.test.accept(self) + + # TODO Allow subtypes or any type + if test_type != self.types.get_type("bool"): + self.reporter.error( + expr.test.location, f"If test must be a boolean, got {test_type}" + ) + + true_type: Type = expr.if_true.accept(self) + false_type: Type = expr.if_false.accept(self) + if self.is_subtype(true_type, false_type): + return false_type + if self.is_subtype(false_type, true_type): + return true_type + + self.reporter.error( + expr.location, + f"Incompatible types in ternary if branches: true={true_type} and false={false_type}", + ) + return UnknownType() + + def visit_base_type(self, node: p.BaseType) -> Type: + return self.types.get_type(node.base) + + def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... + + def visit_frame_column(self, node: p.FrameColumn) -> Type: ... + + def visit_frame_type(self, node: p.FrameType) -> Type: ... + + def map_call_arguments( + self, function: Function, call: p.CallExpr + ) -> list[MappedArgument]: + """Map call arguments to function parameters as defined in its signature + + This method maps positional-only, keyword-only and mixed parameter definitions + with the arguments passed at the call site + + Any mismatched, missing or unexpected argument is reported as a diagnostic + + Args: + function (Function): the function definition + call (p.CallExpr): the call expression + + Returns: + list[MappedArgument]: the list of mapped arguments + """ + positional: list[tuple[p.Expr, Type]] = [ + (arg, self.type_of(arg)) for arg in call.arguments + ] + keywords: dict[str, tuple[p.Expr, Type]] = { + name: (arg, self.type_of(arg)) for name, arg in call.keywords.items() + } + set_args: set[str] = set() + + required_positional: list[str] = [ + arg.name for arg in function.pos_args + function.args if arg.required + ] + required_keyword: list[str] = [ + arg.name for arg in function.kw_args if arg.required + ] + + mapped: list[MappedArgument] = [] + + pos_params: list[Function.Argument] = list(function.pos_args) + mixed_params: list[Function.Argument] = list(function.args) + kw_params: dict[str, Function.Argument] = { + arg.name: arg for arg in function.kw_args + } + + # TODO: handle *args and **kwargs sinks + for arg in positional: + param: Function.Argument + if len(pos_params) != 0: + param = pos_params.pop(0) + elif len(mixed_params) != 0: + param = mixed_params.pop(0) + else: + self.reporter.error(arg[0].location, "Too many positional arguments") + break + name: str = param.name + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + kw_params.update({arg.name: arg for arg in mixed_params}) + for name, arg in keywords.items(): + param: Function.Argument + if name not in kw_params: + if name in set_args: + self.reporter.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.reporter.error( + arg[0].location, f"Unknown keyword argument '{name}'" + ) + continue + param = kw_params.pop(name) + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + def join_args(args: list[str]) -> str: + args = list(map(lambda a: f"'{a}'", args)) + if len(args) == 0: + return "" + if len(args) == 1: + return args[0] + return ", ".join(args[:-1]) + " and " + args[-1] + + if len(required_positional) != 0: + plural: str = "" if len(required_positional) == 1 else "s" + args: str = join_args(required_positional) + self.reporter.error( + call.location, + f"Missing required positional argument{plural}: {args}", + ) + + if len(required_keyword) != 0: + plural: str = "" if len(required_keyword) == 1 else "s" + args: str = join_args(required_keyword) + self.reporter.error( + call.location, + f"Missing required keyword argument{plural}: {args}", + ) + + return mapped diff --git a/midas/resolver/midas.py b/midas/checker/registry.py similarity index 74% rename from midas/resolver/midas.py rename to midas/checker/registry.py index 6872569..1324bbd 100644 --- a/midas/resolver/midas.py +++ b/midas/checker/registry.py @@ -1,6 +1,5 @@ from typing import Optional -import midas.ast.midas as m from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, @@ -10,24 +9,15 @@ from midas.checker.types import ( GenericType, Operation, Type, - TypeVar, - UnknownType, substitute_typevars, ) -from midas.resolver.builtin import define_builtins -class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): - """A resolver which evaluates Midas type definitions and build a registry""" - +class TypesRegistry: def __init__(self) -> None: self._types: dict[str, Type] = {} self._operations: dict[Operation.CallSignature, Type] = {} - self._local_variables: dict[str, TypeVar] = {} - - define_builtins(self) - def get_type(self, name: str) -> Type: """Get a type from its name @@ -40,8 +30,6 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T Returns: Type: the type """ - if name in self._local_variables: - return self._local_variables[name] if name in self._types: return self._types[name] raise NameError(f"Undefined type {name}") @@ -120,117 +108,6 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T ) self._operations[signature] = result - def resolve(self, stmts: list[m.Stmt]): - """Process a sequence of statements - - Args: - stmts (list[m.Stmt]): the statements - """ - for stmt in stmts: - stmt.accept(self) - - def visit_type_stmt(self, stmt: m.TypeStmt) -> None: - params: list[TypeVar] = [] - for param in stmt.params: - name: str = param.name.lexeme - bound: Optional[Type] = None - if param.bound is not None: - bound = param.bound.accept(self) - var = TypeVar(name=name, bound=bound) - self._local_variables[name] = var - params.append(var) - type: Type = stmt.type.accept(self) - if len(params) != 0: - type = GenericType(params=params, body=type) - name: str = stmt.name.lexeme - self.define_type(name, AliasType(name=name, type=type)) - self._local_variables.clear() - - def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... - - def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: - base: Type = stmt.type.accept(self) - for op in stmt.operations: - right: Type = op.operand.accept(self) - result: Type = op.result.accept(self) - self.define_operation( - left=base, - operator=op.name.lexeme, - right=right, - result=result, - ) - - def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... - - def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... - - def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... - - def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ... - - def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ... - - def visit_get_expr(self, expr: m.GetExpr) -> None: ... - - def visit_variable_expr(self, expr: m.VariableExpr) -> None: ... - - def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: - return expr.expr.accept(self) - - def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... - - def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... - - def visit_named_type(self, type: m.NamedType) -> Type: - return self.get_type(type.name.lexeme) - - def visit_generic_type(self, type: m.GenericType) -> Type: - type_: Type = type.type.accept(self) - params: list[Type] = [param.accept(self) for param in type.params] - return self.apply_generic(type_, params) - - def apply_generic(self, type: Type, params: list[Type]) -> Type: - match type: - case AliasType(name=name, type=base): - return AliasType(name=name, type=self.apply_generic(base, params)) - - case GenericType(params=type_vars, body=body): - n_params: int = len(params) - n_type_vars: int = len(type_vars) - if n_params < n_type_vars: - raise ValueError( - f"Missing type parameters, expected {n_type_vars} but only {n_params} provided" - ) - if n_params > n_type_vars: - raise ValueError( - f"Too many type parameters, expected {n_type_vars} but {n_params} provided" - ) - substitutions: dict[str, Type] = {} - for param, type_var in zip(params, type_vars): - if type_var.bound is not None and not self.is_subtype( - param, type_var.bound - ): - raise ValueError( - f"Type parameter {param} is not a subtype of {type_var.bound}" - ) - substitutions[type_var.name] = param - return substitute_typevars(body, substitutions) - case _: - raise ValueError(f"{type} is not a generic type") - - def visit_constraint_type(self, type: m.ConstraintType) -> Type: - type_: Type = type.type.accept(self) - type.constraint.accept(self) - # TODO - return UnknownType() - - def visit_complex_type(self, type: m.ComplexType) -> Type: - return ComplexType( - properties={ - prop.name.lexeme: prop.type.accept(self) for prop in type.properties - } - ) - def is_subtype(self, type1: Type, type2: Type) -> bool: """Check whether `type1` is a subtype of `type2` @@ -371,3 +248,33 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T return False return True + + def apply_generic(self, type: Type, params: list[Type]) -> Type: + match type: + case AliasType(name=name, type=base): + return AliasType(name=name, type=self.apply_generic(base, params)) + + case GenericType(params=type_vars, body=body): + n_params: int = len(params) + n_type_vars: int = len(type_vars) + if n_params < n_type_vars: + raise ValueError( + f"Missing type parameters, expected {n_type_vars} but only {n_params} provided" + ) + if n_params > n_type_vars: + raise ValueError( + f"Too many type parameters, expected {n_type_vars} but {n_params} provided" + ) + substitutions: dict[str, Type] = {} + for param, type_var in zip(params, type_vars): + if type_var.bound is not None and not self.is_subtype( + param, type_var.bound + ): + raise ValueError( + f"Type parameter {param} is not a subtype of {type_var.bound}" + ) + substitutions[type_var.name] = param + return substitute_typevars(body, substitutions) + + case _: + raise ValueError(f"{type} is not a generic type") diff --git a/midas/cli/main.py b/midas/cli/main.py index ae4295b..cafeeaf 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -10,7 +10,7 @@ import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter -from midas.checker.checker import Checker +from midas.checker.checker import TypeChecker from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.types import Type from midas.cli.ansi import Ansi @@ -25,7 +25,6 @@ from midas.lexer.midas import MidasLexer from midas.lexer.token import Token, TokenType from midas.parser.midas import MidasParser from midas.parser.python import PythonParser -from midas.resolver.resolver import Resolver from midas.utils import UniversalJSONDumper @@ -98,18 +97,13 @@ def compile( ): logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN) source: str = file.read() - tree: ast.Module = ast.parse(source, filename=file.name) - parser = PythonParser() - stmts: list[p.Stmt] = parser.parse_module(tree) - resolver = Resolver() - resolver.resolve(*stmts) - types_paths: list[Path] = [Path(t.name).resolve() for t in types] - checker = Checker( - resolver.locals, - source_path=Path(file.name).resolve(), - types_paths=types_paths, - ) - diagnostics: list[Diagnostic] = checker.check(stmts) + + checker = TypeChecker() + for path in types: + checker.import_midas(Path(path.name).resolve()) + + checker.type_check_source(source, str(Path(file.name).resolve())) + diagnostics: list[Diagnostic] = checker.diagnostics lines: list[str] = source.split("\n") for diagnostic in diagnostics: print_diagnostic(lines, diagnostic) @@ -118,7 +112,7 @@ def compile( print( json.dumps( UniversalJSONDumper.dump( - checker.global_env, + checker.python_typer.global_env, [("Environment", "_children")], lambda obj: isinstance(obj, get_args(Type)), ), diff --git a/midas/resolver/builtin.py b/midas/resolver/builtin.py index 04bc6e3..c3c7e65 100644 --- a/midas/resolver/builtin.py +++ b/midas/resolver/builtin.py @@ -1,15 +1,9 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - +from midas.checker.registry import TypesRegistry from midas.checker.types import BaseType, Type, UnitType -if TYPE_CHECKING: - from midas.resolver.midas import MidasResolver - -def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type): - ctx.define_operation( +def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type): + reg.define_operation( left=t1, operator=operator, right=t2, @@ -17,8 +11,8 @@ def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type): ) -def basic_op(ctx: MidasResolver, type: Type, op: str): - ctx.define_operation( +def basic_op(reg: TypesRegistry, type: Type, op: str): + reg.define_operation( left=type, operator=op, right=type, @@ -26,47 +20,47 @@ def basic_op(ctx: MidasResolver, type: Type, op: str): ) -def define_builtins(ctx: MidasResolver): +def define_builtins(reg: TypesRegistry): """Define builtin types and operations""" - unit = ctx.define_type("None", UnitType()) - bool = ctx.define_type("bool", BaseType(name="bool")) - int = ctx.define_type("int", BaseType(name="int")) - float = ctx.define_type("float", BaseType(name="float")) - str = ctx.define_type("str", BaseType(name="str")) + unit = reg.define_type("None", UnitType()) + bool = reg.define_type("bool", BaseType(name="bool")) + int = reg.define_type("int", BaseType(name="int")) + float = reg.define_type("float", BaseType(name="float")) + str = reg.define_type("str", BaseType(name="str")) - basic_op(ctx, int, "__add__") # int + int = int - basic_op(ctx, int, "__sub__") # int - int = int - basic_op(ctx, int, "__mul__") # int * int = int - basic_op(ctx, int, "__pow__") # int ** int = int - basic_op(ctx, int, "__mod__") # int % int = int - basic_op(ctx, int, "__and__") # int & int = int - basic_op(ctx, int, "__or__") # int | int = int - basic_op(ctx, int, "__xor__") # int ^ int = int - op(ctx, int, "__lt__", int, bool) # int < int = bool - op(ctx, int, "__gt__", int, bool) # int > int = bool - op(ctx, int, "__le__", int, bool) # int <= int = bool - op(ctx, int, "__ge__", int, bool) # int >= int = bool - op(ctx, int, "__eq__", int, bool) # int == int = bool - basic_op(ctx, float, "__add__") # float + float = float - basic_op(ctx, float, "__sub__") # float - float = float - basic_op(ctx, float, "__mul__") # float * float = float - basic_op(ctx, float, "__truediv__") # float / float = float - op(ctx, float, "__lt__", float, bool) # float < float = bool - op(ctx, float, "__gt__", float, bool) # float > float = bool - op(ctx, float, "__le__", float, bool) # float <= float = bool - op(ctx, float, "__ge__", float, bool) # float >= float = bool - op(ctx, float, "__eq__", float, bool) # float == float = bool - basic_op(ctx, str, "__add__") # str + str = str - op(ctx, str, "__eq__", str, bool) # str == str = bool + basic_op(reg, int, "__add__") # int + int = int + basic_op(reg, int, "__sub__") # int - int = int + basic_op(reg, int, "__mul__") # int * int = int + basic_op(reg, int, "__pow__") # int ** int = int + basic_op(reg, int, "__mod__") # int % int = int + basic_op(reg, int, "__and__") # int & int = int + basic_op(reg, int, "__or__") # int | int = int + basic_op(reg, int, "__xor__") # int ^ int = int + op(reg, int, "__lt__", int, bool) # int < int = bool + op(reg, int, "__gt__", int, bool) # int > int = bool + op(reg, int, "__le__", int, bool) # int <= int = bool + op(reg, int, "__ge__", int, bool) # int >= int = bool + op(reg, int, "__eq__", int, bool) # int == int = bool + basic_op(reg, float, "__add__") # float + float = float + basic_op(reg, float, "__sub__") # float - float = float + basic_op(reg, float, "__mul__") # float * float = float + basic_op(reg, float, "__truediv__") # float / float = float + op(reg, float, "__lt__", float, bool) # float < float = bool + op(reg, float, "__gt__", float, bool) # float > float = bool + op(reg, float, "__le__", float, bool) # float <= float = bool + op(reg, float, "__ge__", float, bool) # float >= float = bool + op(reg, float, "__eq__", float, bool) # float == float = bool + basic_op(reg, str, "__add__") # str + str = str + op(reg, str, "__eq__", str, bool) # str == str = bool - op(ctx, int, "__lt__", float, bool) # int < float = bool - op(ctx, int, "__gt__", float, bool) # int > float = bool - op(ctx, int, "__le__", float, bool) # int <= float = bool - op(ctx, int, "__ge__", float, bool) # int >= float = bool - op(ctx, int, "__eq__", float, bool) # int == float = bool + op(reg, int, "__lt__", float, bool) # int < float = bool + op(reg, int, "__gt__", float, bool) # int > float = bool + op(reg, int, "__le__", float, bool) # int <= float = bool + op(reg, int, "__ge__", float, bool) # int >= float = bool + op(reg, int, "__eq__", float, bool) # int == float = bool - op(ctx, float, "__lt__", int, bool) # float < int = bool - op(ctx, float, "__gt__", int, bool) # float > int = bool - op(ctx, float, "__le__", int, bool) # float <= int = bool - op(ctx, float, "__ge__", int, bool) # float >= int = bool - op(ctx, float, "__eq__", int, bool) # float == int = bool + op(reg, float, "__lt__", int, bool) # float < int = bool + op(reg, float, "__gt__", int, bool) # float > int = bool + op(reg, float, "__le__", int, bool) # float <= int = bool + op(reg, float, "__ge__", int, bool) # float >= int = bool + op(reg, float, "__eq__", int, bool) # float == int = bool diff --git a/tests/checker.py b/tests/checker.py index 27a94cb..3ceb34e 100644 --- a/tests/checker.py +++ b/tests/checker.py @@ -1,14 +1,11 @@ -import ast import json from dataclasses import asdict, dataclass, field from pathlib import Path import midas.ast.python as p -from midas.checker.checker import Checker +from midas.checker.checker import TypeChecker from midas.checker.diagnostic import Diagnostic from midas.checker.types import Type -from midas.parser.python import PythonParser -from midas.resolver.resolver import Resolver from tests.base import Tester from tests.serializer.python import PythonAstJsonSerializer @@ -36,24 +33,16 @@ class CheckerTester(Tester): if not path.is_file(): raise TypeError(f"Test '{path}' is not a file") - types_paths: list[Path] = [] + result: CaseResult = CaseResult() + + checker = TypeChecker() types_path: Path = path.with_suffix(".midas") if types_path.exists(): - types_paths.append(types_path) - source: str = path.read_text() - tree: ast.Module = ast.parse(source, filename=path) - parser = PythonParser() - stmts: list[p.Stmt] = parser.parse_module(tree) - resolver = Resolver() - resolver.resolve(*stmts) - result: CaseResult = CaseResult() - checker = Checker( - resolver.locals, - source_path=path, - types_paths=types_paths, - ) + checker.import_midas(types_path) - diagnostics: list[Diagnostic] = checker.check(stmts) + checker.type_check(path) + + diagnostics: list[Diagnostic] = checker.diagnostics for diagnostic in diagnostics: result.diagnostics.append( { @@ -72,7 +61,7 @@ class CheckerTester(Tester): } ) - judgements: list[tuple[p.Expr, Type]] = checker.judgements + judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements serializer = PythonAstJsonSerializer() for expr, type in judgements: loc = expr.location