From f89722fad809d7d681b946705ebf9a7811e7ac02 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 10:52:34 +0200 Subject: [PATCH 01/64] feat(checker): add generic type structure --- midas/checker/types.py | 74 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/midas/checker/types.py b/midas/checker/types.py index 83707b6..15079e0 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional @dataclass(frozen=True, kw_only=True) @@ -57,4 +58,75 @@ class Operation: right: Type -Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType +@dataclass(frozen=True, kw_only=True) +class TypeVar: + name: str + bound: Optional[Type] + + +@dataclass(frozen=True, kw_only=True) +class GenericType: + params: list[TypeVar] + body: Type + + +def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: + def sub_argument(arg: Function.Argument): + return Function.Argument( + pos=arg.pos, + name=arg.name, + type=substitute_typevars(arg.type, substitutions), + required=arg.required, + ) + + match type: + case BaseType(name=name) if name in substitutions: + return substitutions[name] + + case AliasType(name=name, type=type2): + return AliasType(name=name, type=substitute_typevars(type2, substitutions)) + + case Function( + name=name, + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns, + ): + return Function( + name=name, + pos_args=list(map(sub_argument, pos_args)), + args=list(map(sub_argument, args)), + kw_args=list(map(sub_argument, kw_args)), + returns=substitute_typevars(returns, substitutions), + ) + + case ComplexType(properties=properties): + properties2: dict[str, Type] = { + name: substitute_typevars(prop, substitutions) + for name, prop in properties.items() + } + return ComplexType(properties=properties2) + + case TypeVar(name=name): + 
if name in substitutions: + return substitutions[name] + raise ValueError(f"Missing TypeVar substitution for {name}") + + case UnknownType() | UnitType(): + return type + + case _: + raise NotImplementedError(f"Unsupported type {type}") + + +Type = ( + BaseType + | AliasType + | UnknownType + | UnitType + | Function + | ComplexType + | TypeVar + | GenericType +) From 1d00875a8cd9e0d45be14aa536f873aaaeac886b Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 10:55:15 +0200 Subject: [PATCH 02/64] feat(resolver): handle generics definition --- midas/resolver/midas.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py index 468f59a..063a752 100644 --- a/midas/resolver/midas.py +++ b/midas/resolver/midas.py @@ -4,8 +4,10 @@ import midas.ast.midas as m from midas.checker.types import ( AliasType, ComplexType, + GenericType, Operation, Type, + TypeVar, UnknownType, ) from midas.resolver.builtin import define_builtins @@ -18,6 +20,8 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T self._types: dict[str, Type] = {} self._operations: dict[Operation.CallSignature, Type] = {} + self._local_variables: dict[str, TypeVar] = {} + define_builtins(self) def get_type(self, name: str) -> Type: @@ -32,10 +36,11 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T Returns: Type: the type """ - type: Optional[Type] = self._types.get(name) - if type is None: - raise NameError(f"Undefined type {name}") - return type + if name in self._local_variables: + return self._local_variables[name] + if name in self._types: + return self._types[name] + raise NameError(f"Undefined type {name}") def get_operation_result( self, left: Type, operator: str, right: Type @@ -121,12 +126,21 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T stmt.accept(self) def visit_type_stmt(self, stmt: m.TypeStmt) -> 
None: - type: Type = stmt.type.accept(self) + params: list[TypeVar] = [] for param in stmt.params: + name: str = param.name.lexeme + bound: Optional[Type] = None if param.bound is not None: - param.bound.accept(self) + bound = param.bound.accept(self) + var = TypeVar(name=name, bound=bound) + self._local_variables[name] = var + params.append(var) + type: Type = stmt.type.accept(self) + if len(params) != 0: + type = GenericType(params=params, body=type) name: str = stmt.name.lexeme self.define_type(name, AliasType(name=name, type=type)) + self._local_variables.clear() def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... From d51d24f8657c184ae1171c99ef6c3889498fff08 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 10:56:27 +0200 Subject: [PATCH 03/64] refactor(checker): move unfold_type to types.py --- midas/checker/checker.py | 12 +++--------- midas/checker/types.py | 8 ++++++++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/midas/checker/checker.py b/midas/checker/checker.py index ab7261c..3100587 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -19,6 +19,7 @@ from midas.checker.types import ( Type, UnitType, UnknownType, + unfold_type, ) from midas.lexer.midas import MidasLexer from midas.lexer.token import Token @@ -178,13 +179,6 @@ class Checker( stmts: list[m.Stmt] = parser.parse() self.ctx.resolve(stmts) - def unfold_type(self, type: Type) -> Type: - match type: - case AliasType(type=ref_type): - return self.unfold_type(ref_type) - case _: - return type - def is_subtype(self, type1: Type, type2: Type) -> bool: """Check whether `type1` is a subtype of `type2` @@ -470,7 +464,7 @@ class Checker( def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): object: Type = self.type_of(target.object) - base_object: Type = self.unfold_type(object) + base_object: Type = unfold_type(object) match base_object: case ComplexType(properties=properties): if target.name not in 
properties: @@ -611,7 +605,7 @@ class Checker( def visit_get_expr(self, expr: p.GetExpr) -> Type: object: Type = self.type_of(expr.object) - base_object: Type = self.unfold_type(object) + base_object: Type = unfold_type(object) match base_object: case ComplexType(properties=properties): if expr.name not in properties: diff --git a/midas/checker/types.py b/midas/checker/types.py index 15079e0..ee41e14 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -120,6 +120,14 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: raise NotImplementedError(f"Unsupported type {type}") +def unfold_type(type: Type) -> Type: + match type: + case AliasType(type=ref_type): + return unfold_type(ref_type) + case _: + return type + + Type = ( BaseType | AliasType From f9c15abaf47ef797196da2b231461f84913affe2 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 10:57:50 +0200 Subject: [PATCH 04/64] refactor(checker): move is_subtype to resolver --- midas/checker/checker.py | 147 +-------------------------------------- midas/resolver/midas.py | 144 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 146 deletions(-) diff --git a/midas/checker/checker.py b/midas/checker/checker.py index 3100587..0c54fa4 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -6,13 +6,10 @@ from typing import Optional import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location -from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.environment import Environment from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS from midas.checker.types import ( - AliasType, - BaseType, ComplexType, Function, Operation, @@ -180,149 +177,7 @@ class Checker( self.ctx.resolve(stmts) def is_subtype(self, type1: Type, type2: Type) -> bool: - """Check whether `type1` is a subtype of `type2` - - For 
more details on the rules checked here, see TAPL Chap. 15-16-17 - - Args: - type1 (Type): the potential subtype - type2 (Type): the potential supertype - - Returns: - bool: whether `type1` is a subtype of `type2` - """ - - if type1 == type2: - return True - - match (type1, type2): - case (AliasType(type=base1), _): - return self.is_subtype(base1, type2) - - case (BaseType(name=name1), BaseType(name=name2)): - return name1 in BUILTIN_SUBTYPES.get(name2, set()) - - case (ComplexType(properties=props1), ComplexType(properties=props2)): - for k, t in props2.items(): - if k not in props1: - return False - if not self.is_subtype(props1[k], t): - return False - return True - - case (Function(returns=return1), Function(returns=return2)): - if not self.is_func_subtype(type1, type2): - return False - if not self.is_subtype(return1, return2): - return False - return True - - return False - - # TODO: verify the logic in here - def is_func_subtype(self, func1: Function, func2: Function) -> bool: - """Check whether a function is a subtype of another - - Args: - func1 (Function): the potential function subtype - func2 (Function): the potential function supertype - - Returns: - bool: whether `func1` is a subtype of `func2` - """ - if not self.is_subtype(func1.returns, func2.returns): - return False - - pos1: list[Function.Argument] = func1.pos_args - mixed1: list[Function.Argument] = func1.args - kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args} - pos2: list[Function.Argument] = func2.pos_args - mixed2: list[Function.Argument] = func2.args - kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args} - - mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2} - mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2} - - def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool: - if not self.is_subtype(sub.type, sup.type): - return False - if not sup.required and sub.required: - return 
False - return True - - for arg1 in pos1: - arg2: Function.Argument - if arg1.pos < len(pos2): - arg2 = pos2[arg1.pos] - elif arg1.pos in mixed_by_pos: - arg2 = mixed_by_pos[arg1.pos] - elif not arg1.required: - continue - else: - return False - if not is_arg_subtype(arg2, arg1): - return False - - for name, arg1 in kw1.items(): - arg2: Function.Argument - if name in kw2: - arg2 = kw2[name] - elif name in mixed_by_name: - arg2 = mixed_by_name[name] - elif not arg1.required: - continue - else: - return False - if not is_arg_subtype(arg2, arg1): - return False - - for arg1 in mixed1: - pos_arg2: Optional[Function.Argument] = None - kw_arg2: Optional[Function.Argument] = None - if arg1.name in kw2: - kw_arg2 = kw2[arg1.name] - elif arg1.name in mixed_by_name: - kw_arg2 = mixed_by_name[arg1.name] - if arg1.pos < len(pos2): - pos_arg2 = pos2[arg1.pos] - elif arg1.pos in mixed_by_pos: - pos_arg2 = mixed_by_pos[arg1.pos] - - # No match in func2 and arg is required - if pos_arg2 is None and kw_arg2 is None and arg1.required: - return False - - # Matching keyword argument - if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1): - return False - - # Matching positional argument - if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1): - return False - - mixed_positions: set[int] = {a.pos for a in mixed1} - mixed_names: set[str] = {a.name for a in mixed1} - for arg2 in pos2: - if not arg2.required: - continue - if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions: - return False - - for name, arg2 in kw2.items(): - if not arg2.required: - continue - if name not in kw1 and name not in mixed_names: - return False - - for arg2 in mixed2: - if arg2.required: - continue - pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions - kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names - if not pos_match or not kw_match: - return False - - return True + return self.ctx.is_subtype(type1, type2) def visit_expression_stmt(self, stmt: 
p.ExpressionStmt) -> None: self.type_of(stmt.expr) diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py index 063a752..c7b168d 100644 --- a/midas/resolver/midas.py +++ b/midas/resolver/midas.py @@ -1,9 +1,12 @@ from typing import Optional import midas.ast.midas as m +from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, + BaseType, ComplexType, + Function, GenericType, Operation, Type, @@ -198,3 +201,144 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T prop.name.lexeme: prop.type.accept(self) for prop in type.properties } ) + + def is_subtype(self, type1: Type, type2: Type) -> bool: + """Check whether `type1` is a subtype of `type2` + + For more details on the rules checked here, see TAPL Chap. 15-16-17 + + Args: + type1 (Type): the potential subtype + type2 (Type): the potential supertype + + Returns: + bool: whether `type1` is a subtype of `type2` + """ + + if type1 == type2: + return True + + match (type1, type2): + case (AliasType(type=base1), _): + return self.is_subtype(base1, type2) + + case (BaseType(name=name1), BaseType(name=name2)): + return name1 in BUILTIN_SUBTYPES.get(name2, set()) + + case (ComplexType(properties=props1), ComplexType(properties=props2)): + for k, t in props2.items(): + if k not in props1: + return False + if not self.is_subtype(props1[k], t): + return False + return True + + case (Function(), Function()): + return self.is_func_subtype(type1, type2) + + return False + + # TODO: verify the logic in here + def is_func_subtype(self, func1: Function, func2: Function) -> bool: + """Check whether a function is a subtype of another + + Args: + func1 (Function): the potential function subtype + func2 (Function): the potential function supertype + + Returns: + bool: whether `func1` is a subtype of `func2` + """ + if not self.is_subtype(func1.returns, func2.returns): + return False + + pos1: list[Function.Argument] = func1.pos_args + mixed1: 
list[Function.Argument] = func1.args + kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args} + pos2: list[Function.Argument] = func2.pos_args + mixed2: list[Function.Argument] = func2.args + kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args} + + mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2} + mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2} + + def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool: + if not self.is_subtype(sub.type, sup.type): + return False + if not sup.required and sub.required: + return False + return True + + for arg1 in pos1: + arg2: Function.Argument + if arg1.pos < len(pos2): + arg2 = pos2[arg1.pos] + elif arg1.pos in mixed_by_pos: + arg2 = mixed_by_pos[arg1.pos] + elif not arg1.required: + continue + else: + return False + if not is_arg_subtype(arg2, arg1): + return False + + for name, arg1 in kw1.items(): + arg2: Function.Argument + if name in kw2: + arg2 = kw2[name] + elif name in mixed_by_name: + arg2 = mixed_by_name[name] + elif not arg1.required: + continue + else: + return False + if not is_arg_subtype(arg2, arg1): + return False + + for arg1 in mixed1: + pos_arg2: Optional[Function.Argument] = None + kw_arg2: Optional[Function.Argument] = None + if arg1.name in kw2: + kw_arg2 = kw2[arg1.name] + elif arg1.name in mixed_by_name: + kw_arg2 = mixed_by_name[arg1.name] + if arg1.pos < len(pos2): + pos_arg2 = pos2[arg1.pos] + elif arg1.pos in mixed_by_pos: + pos_arg2 = mixed_by_pos[arg1.pos] + + # No match in func2 and arg is required + if pos_arg2 is None and kw_arg2 is None and arg1.required: + return False + + # Matching keyword argument + if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1): + return False + + # Matching positional argument + if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1): + return False + + mixed_positions: set[int] = {a.pos for a in mixed1} + mixed_names: set[str] = {a.name 
for a in mixed1} + for arg2 in pos2: + if not arg2.required: + continue + if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions: + return False + + for name, arg2 in kw2.items(): + if not arg2.required: + continue + if name not in kw1 and name not in mixed_names: + return False + + for arg2 in mixed2: + if arg2.required: + continue + pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions + kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names + if not pos_match or not kw_match: + return False + + return True From c4c142482a828452fa15ebda6e5dc18f651b7998 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 10:59:01 +0200 Subject: [PATCH 05/64] feat(resolver): handle generic application --- midas/resolver/midas.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py index c7b168d..6872569 100644 --- a/midas/resolver/midas.py +++ b/midas/resolver/midas.py @@ -12,6 +12,7 @@ from midas.checker.types import ( Type, TypeVar, UnknownType, + substitute_typevars, ) from midas.resolver.builtin import define_builtins @@ -186,8 +187,36 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T def visit_generic_type(self, type: m.GenericType) -> Type: type_: Type = type.type.accept(self) params: list[Type] = [param.accept(self) for param in type.params] - # TODO - return UnknownType() + return self.apply_generic(type_, params) + + def apply_generic(self, type: Type, params: list[Type]) -> Type: + match type: + case AliasType(name=name, type=base): + return AliasType(name=name, type=self.apply_generic(base, params)) + + case GenericType(params=type_vars, body=body): + n_params: int = len(params) + n_type_vars: int = len(type_vars) + if n_params < n_type_vars: + raise ValueError( + f"Missing type parameters, expected {n_type_vars} but only {n_params} provided" + ) + if n_params > n_type_vars: + raise ValueError( + 
f"Too many type parameters, expected {n_type_vars} but {n_params} provided" + ) + substitutions: dict[str, Type] = {} + for param, type_var in zip(params, type_vars): + if type_var.bound is not None and not self.is_subtype( + param, type_var.bound + ): + raise ValueError( + f"Type parameter {param} is not a subtype of {type_var.bound}" + ) + substitutions[type_var.name] = param + return substitute_typevars(body, substitutions) + case _: + raise ValueError(f"{type} is not a generic type") def visit_constraint_type(self, type: m.ConstraintType) -> Type: type_: Type = type.type.accept(self) From 111afe4dd4ed212e9e5b1005b429e885ba208aa0 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 13:38:35 +0200 Subject: [PATCH 06/64] feat(checker): add reporter class --- midas/checker/diagnostic.py | 16 +++++++--- midas/checker/reporter.py | 63 +++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 5 deletions(-) create mode 100644 midas/checker/reporter.py diff --git a/midas/checker/diagnostic.py b/midas/checker/diagnostic.py index 77f687e..2925653 100644 --- a/midas/checker/diagnostic.py +++ b/midas/checker/diagnostic.py @@ -14,7 +14,7 @@ class DiagnosticType(StrEnum): @dataclass(frozen=True) class Diagnostic: - file_path: Path + file_path: Optional[str | Path] location: Location type: DiagnosticType message: str @@ -28,10 +28,16 @@ class Diagnostic: and self.location.end_col_offset is not None ): end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}" - loc: str = ( - f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}" - ) - return f"{self.type} in {self.file_path} {loc}" + + loc: str = "" + if self.file_path is not None: + loc += f" in {self.file_path}" + if end_loc is None: + loc += f" at {start_loc}" + else: + loc += f" from {start_loc} to {end_loc}" + + return f"{self.type}{loc}" def __str__(self) -> str: return f"{self.location_str}: {self.message}" diff --git a/midas/checker/reporter.py 
b/midas/checker/reporter.py new file mode 100644 index 0000000..b68766a --- /dev/null +++ b/midas/checker/reporter.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import Optional + +from midas.ast.location import Location +from midas.checker.diagnostic import Diagnostic, DiagnosticType + + +class Reporter: + def __init__(self): + self.diagnostics: list[Diagnostic] = [] + + def report( + self, + path: Optional[str], + type: DiagnosticType, + location: Location, + message: str, + ): + self.diagnostics.append( + Diagnostic( + file_path=path, + location=location, + type=type, + message=message, + ) + ) + + def for_file(self, path: Optional[str]) -> FileReporter: + return FileReporter(self, path) + + +class FileReporter: + def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None: + self.base_reporter: Reporter = base_reporter + self.path: Optional[str] = path + + def for_file(self, path: Optional[str]) -> FileReporter: + return FileReporter(self.base_reporter, path) + + def report(self, type: DiagnosticType, location: Location, message: str): + self.base_reporter.report(self.path, type, location, message) + + def error(self, location: Location, message: str): + self.report( + type=DiagnosticType.ERROR, + location=location, + message=message, + ) + + def warning(self, location: Location, message: str): + self.report( + type=DiagnosticType.WARNING, + location=location, + message=message, + ) + + def info(self, location: Location, message: str): + self.report( + type=DiagnosticType.INFO, + location=location, + message=message, + ) From 2ff1f27614c43b8e853eb325a706d84bab58c4f6 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 13:41:42 +0200 Subject: [PATCH 07/64] refactor(checker): restructure around shared registry restructure the type checker with a shared TypesRegistry used by MidasTyper and PythonTyper this commit also relocates some methods in more appropriate places, such as is_subtype and apply_generic (now in 
TypesRegistry) --- midas/checker/checker.py | 674 +----------------- midas/checker/diagnostic.py | 3 +- midas/checker/midas.py | 137 ++++ midas/checker/python.py | 626 ++++++++++++++++ .../midas.py => checker/registry.py} | 155 +--- midas/cli/main.py | 24 +- midas/resolver/builtin.py | 96 ++- tests/checker.py | 29 +- 8 files changed, 882 insertions(+), 862 deletions(-) create mode 100644 midas/checker/midas.py create mode 100644 midas/checker/python.py rename midas/{resolver/midas.py => checker/registry.py} (74%) diff --git a/midas/checker/checker.py b/midas/checker/checker.py index 0c54fa4..c26f0aa 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -1,661 +1,35 @@ -import logging -from dataclasses import dataclass from pathlib import Path from typing import Optional -import midas.ast.midas as m -import midas.ast.python as p -from midas.ast.location import Location -from midas.checker.diagnostic import Diagnostic, DiagnosticType -from midas.checker.environment import Environment -from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS -from midas.checker.types import ( - ComplexType, - Function, - Operation, - Type, - UnitType, - UnknownType, - unfold_type, -) -from midas.lexer.midas import MidasLexer -from midas.lexer.token import Token -from midas.parser.midas import MidasParser -from midas.resolver.midas import MidasResolver +from midas.checker.diagnostic import Diagnostic +from midas.checker.midas import MidasTyper +from midas.checker.python import PythonTyper +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import Reporter -class ReturnException(Exception): - pass +class TypeChecker: + def __init__(self): + self.types: TypesRegistry = TypesRegistry() + self.reporter: Reporter = Reporter() + self.midas_typer = MidasTyper(self.types, self.reporter) + self.python_typer = PythonTyper(self.types, self.reporter) -@dataclass(frozen=True, kw_only=True) -class MappedArgument: - expr: p.Expr - type: 
Type - argument: Function.Argument + def import_midas(self, path: Path): + source: str = path.read_text() + return self.import_midas_source(source, path=str(path)) + def import_midas_source(self, source: str, path: Optional[str] = None): + self.midas_typer.process(source, path) -class Checker( - p.Stmt.Visitor[None], - p.Expr.Visitor[Type], - p.MidasType.Visitor[Type], -): - """A type checker which can use custom type definitions""" + def type_check(self, path: Path): + source: str = path.read_text() + return self.type_check_source(source, path=str(path)) - def __init__( - self, - locals: dict[p.Expr, int], - source_path: Path, - types_paths: list[Path], - ): - self.logger: logging.Logger = logging.getLogger("Checker") - self.source_path: Path = source_path - self.types_paths: list[Path] = types_paths - self.ctx: MidasResolver = MidasResolver() - self.global_env: Environment = Environment() - self.env: Environment = self.global_env - self.locals: dict[p.Expr, int] = locals - self.diagnostics: list[Diagnostic] = [] - self.judgements: list[tuple[p.Expr, Type]] = [] + def type_check_source(self, source: str, path: Optional[str] = None): + self.python_typer.process(source, path) - def diagnostic(self, type: DiagnosticType, location: Location, message: str): - self.diagnostics.append( - Diagnostic( - file_path=self.source_path, - location=location, - type=type, - message=message, - ) - ) - - def error(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.ERROR, - location=location, - message=message, - ) - - def warning(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.WARNING, - location=location, - message=message, - ) - - def info(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.INFO, - location=location, - message=message, - ) - - def type_of(self, expr: p.Expr) -> Type: - """Evaluate the type of an expression - - Args: - expr (p.Expr): the expression to evaluate - - 
Returns: - Type: the type of the given expression - """ - type: Type = expr.accept(self) - self.judgements.append((expr, type)) - return type - - def process_block(self, block: list[p.Stmt], env: Environment) -> bool: - """Evaluate a sequence of statements - - Args: - block (list[p.Stmt]): the statements to evaluate - env (Environment): the environment in which to evaluate - - Returns: - bool: whether a return statement is present in the block - """ - previous_env: Environment = self.env - self.env = env - returned: bool = False - for i, stmt in enumerate(block): - try: - stmt.accept(self) - except ReturnException: - returned = True - if i < len(block) - 1: - self.warning(block[i + 1].location, "Unreachable statement") - break - self.env = previous_env - return returned - - def check(self, statements: list[p.Stmt]) -> list[Diagnostic]: - """Type check a sequence of statements and returns diagnostics - - Args: - statements (list[p.Stmt]): the statements to evaluate and check - - Returns: - list[Diagnostic]: the list of diagnostics (errors, warning, etc.) 
- """ - self.diagnostics = [] - - for path in self.types_paths: - self.import_midas(path) - self.logger.debug(f"Midas types: {self.ctx._types}") - self.logger.debug(f"Midas operations: {self.ctx._operations}") - - for stmt in statements: - stmt.accept(self) - - self.logger.debug(f"Final environment: {self.env.flat_dict()}") - return self.diagnostics - - def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]: - """Look up a variable in the environment it was declared - - Args: - name (str): the name of the variable - expr (p.Expr): the variable expression, used to lookup the scope distance - - Returns: - Optional[Type]: the type of the variable, or None if it was not found - """ - distance: Optional[int] = self.locals.get(expr) - if distance is not None: - return self.env.get_at(distance, name) - return self.global_env.get(name) - - def import_midas(self, path: Path) -> None: - """Import Midas definitions from a path - - Args: - path (Path): the import path - """ - self.logger.debug(f"Importing type definitions from {path}") - lexer: MidasLexer = MidasLexer(path.read_text()) - tokens: list[Token] = lexer.process() - parser: MidasParser = MidasParser(tokens) - stmts: list[m.Stmt] = parser.parse() - self.ctx.resolve(stmts) - - def is_subtype(self, type1: Type, type2: Type) -> bool: - return self.ctx.is_subtype(type1, type2) - - def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: - self.type_of(stmt.expr) - - def visit_function(self, stmt: p.Function) -> None: - env: Environment = Environment(self.env) - pos_args: list[Function.Argument] = [] - args: list[Function.Argument] = [] - kw_args: list[Function.Argument] = [] - - def eval_arg_type(arg: p.Function.Argument) -> Type: - if arg.type is not None: - return arg.type.accept(self) - if arg.default is not None: - return arg.default.accept(self) - return UnknownType() - - pos: int = 0 - for arg in stmt.posonlyargs: - pos_args.append( - Function.Argument( - pos=pos, - name=arg.name, - 
type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - for arg in stmt.args: - args.append( - Function.Argument( - pos=pos, - name=arg.name, - type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - for arg in stmt.kwonlyargs: - kw_args.append( - Function.Argument( - pos=pos, # not relevant - name=arg.name, - type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - - for arg in pos_args + args + kw_args: - env.define(arg.name, arg.type) - - returns_hint: Optional[Type] = None - if stmt.returns is not None: - returns_hint = stmt.returns.accept(self) - # Early define to handle simple fully-typed recursion - inside_function: Function = Function( - name=stmt.name, - pos_args=pos_args, - args=args, - kw_args=kw_args, - returns=returns_hint, - ) - self.env.define(stmt.name, inside_function) - - returned: bool = self.process_block(stmt.body, env) - inferred_return: Type = UnknownType() - if not returned: - env.return_types.append(UnitType()) - return_types: set[Type] = set(env.return_types) - if len(return_types) == 1: - inferred_return = list(return_types)[0] - elif len(return_types) > 1: - self.error( - stmt.location, - f"Mixed return types: {env.return_types}", - ) - - returns: Type = UnknownType() - if returns_hint is not None: - assert stmt.returns is not None - returns = returns_hint - if returns != inferred_return: - self.error( - stmt.returns.location, - f"Return type mismatch, annotated {returns} but returns {inferred_return}", - ) - else: - returns = inferred_return - - # TODO: handle *args and **kwargs sinks - function: Function = Function( - name=stmt.name, - pos_args=pos_args, - args=args, - kw_args=kw_args, - returns=returns, - ) - self.env.define(stmt.name, function) - - def visit_type_assign(self, stmt: p.TypeAssign) -> None: - # TODO check not yet defined locally - type: Type = stmt.type.accept(self) - self.env.define(stmt.name, type) - - def visit_assign_stmt(self, stmt: p.AssignStmt) -> 
None: - value_type: Type = self.type_of(stmt.value) - for target in stmt.targets: - self._assign(stmt.location, target, value_type) - - def _assign(self, location: Location, target: p.Expr, value_type: Type): - match target: - case p.VariableExpr(): - self._assign_var(location, target, value_type) - - case p.GetExpr(): - self._assign_attr(location, target, value_type) - - case _: - if not isinstance(target, p.VariableExpr): - self.logger.warning(f"Unsupported assignment to {target}") - self.warning(target.location, f"Unsupported assignment to {target}") - - def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type): - name: str = target.name - var_type: Optional[Type] = self.look_up_variable(name, target) - - if var_type is None: - self.env.define(name, value_type) - else: - # S <: T - # Γ, x: T v: S - # x = v - if not self.is_subtype(value_type, var_type): - self.error( - location, - f"Cannot assign {value_type} to {name} of type {var_type}", - ) - - def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): - object: Type = self.type_of(target.object) - base_object: Type = unfold_type(object) - match base_object: - case ComplexType(properties=properties): - if target.name not in properties: - self.error( - target.location, f"Unknown property '{target.name} on {object}" - ) - return - - prop_type: Type = properties[target.name] - if not self.is_subtype(value_type, prop_type): - self.error( - location, - f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}", - ) - return - - case UnknownType(): - pass - - case _: - self.error( - target.location, - f"Cannot assign {value_type} to unknown property '{target.name}' on {object}", - ) - - def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: - type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() - self.env.return_types.append(type) - raise ReturnException() - - def visit_if_stmt(self, stmt: p.IfStmt) -> None: - 
# Not evaluated in sub-environment because assignments in the test leak out of the if - # For example: - # if (m := 1 + 1) < 2: - # ... - # print(m) # <- m is still defined - test_type: Type = stmt.test.accept(self) - - # TODO Allow subtypes or any type - if test_type != self.ctx.get_type("bool"): - self.error( - stmt.test.location, f"If test must be a boolean, got {test_type}" - ) - - env: Environment = Environment(self.env) - body_returned: bool = self.process_block(stmt.body, env) - else_returned: bool = self.process_block(stmt.orelse, env) - self.env.return_types.extend(env.return_types) - if body_returned and else_returned: - raise ReturnException() - - def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: - method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) - if method is None: - self.logger.warning(f"Unsupported operator {expr.operator}") - self.warning(expr.location, f"Unsupported operator {expr.operator}") - return UnknownType() - left: Type = self.type_of(expr.left) - right: Type = self.type_of(expr.right) - - operations: list[Operation] = self.ctx.get_operations_by_name(method) - valid_operations: list[Operation] = [] - for op in operations: - sig: Operation.CallSignature = op.signature - if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right): - valid_operations.append(op) - - if len(valid_operations) == 0: - self.error( - expr.location, - f"Undefined operation {method} between {left} and {right}", - ) - return UnknownType() - elif len(valid_operations) == 1: - self.logger.debug(f"Unique operation {method} between {left} and {right}") - return valid_operations[0].result - - for i, op1 in enumerate(valid_operations): - sig1: Operation.CallSignature = op1.signature - best_match: bool = True - for j, op2 in enumerate(valid_operations): - if i == j: - continue - sig2: Operation.CallSignature = op2.signature - if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype( - sig1.right, sig2.right - ): - best_match 
= False - break - self.logger.debug(f"{op1} is a full overload of {op2}") - if best_match: - return op1.result - - overloads: list[str] = [ - f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}" - for op in valid_operations - ] - self.error( - expr.location, - f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}", - ) - return UnknownType() - - def visit_compare_expr(self, expr: p.CompareExpr) -> Type: - method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) - if method is None: - self.logger.warning(f"Unsupported operator {expr.operator}") - self.warning(expr.location, f"Unsupported operator {expr.operator}") - return UnknownType() - left: Type = self.type_of(expr.left) - right: Type = self.type_of(expr.right) - - result: Optional[Type] = self.ctx.get_operation_result(left, method, right) - if result is None: - self.error( - expr.location, - f"Undefined operation {method} between {left} and {right}", - ) - return UnknownType() - return result - - def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... 
- - def visit_call_expr(self, expr: p.CallExpr) -> Type: - callee: Type = self.type_of(expr.callee) - if not isinstance(callee, Function): - self.error(expr.callee.location, "Callee is not a function") - return UnknownType() - function: Function = callee - mapped: list[MappedArgument] = self.map_call_arguments(function, expr) - for arg in mapped: - if not self.is_subtype(arg.type, arg.argument.type): - self.error( - arg.expr.location, - f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", - ) - return function.returns - - def visit_get_expr(self, expr: p.GetExpr) -> Type: - object: Type = self.type_of(expr.object) - base_object: Type = unfold_type(object) - match base_object: - case ComplexType(properties=properties): - if expr.name not in properties: - self.error( - expr.location, f"Unknown property '{expr.name} on {object}" - ) - return UnknownType() - return properties[expr.name] - - case UnknownType(): - return UnknownType() - - case _: - self.error( - expr.location, f"Cannot get property '{expr.name}' on {object}" - ) - return UnknownType() - - def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: - match expr.value: - case bool(): # Must be before int - return self.ctx.get_type("bool") - case int(): - return self.ctx.get_type("int") - case float(): - return self.ctx.get_type("float") - case str(): - return self.ctx.get_type("str") - case _: - self.warning(expr.location, f"Unknown literal {expr}") - return UnknownType() - - def visit_variable_expr(self, expr: p.VariableExpr) -> Type: - return self.look_up_variable(expr.name, expr) or UnknownType() - - def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: - left: Type = expr.left.accept(self) - right: Type = expr.right.accept(self) - - if self.is_subtype(left, right): - return right - if self.is_subtype(right, left): - return left - - self.error( - expr.location, - f"Incompatible operand types, {left=} and {right=}", - ) - return UnknownType() - - def 
visit_cast_expr(self, expr: p.CastExpr) -> Type: - return expr.type.accept(self) - - def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: - test_type: Type = expr.test.accept(self) - - # TODO Allow subtypes or any type - if test_type != self.ctx.get_type("bool"): - self.error( - expr.test.location, f"If test must be a boolean, got {test_type}" - ) - - true_type: Type = expr.if_true.accept(self) - false_type: Type = expr.if_false.accept(self) - if self.is_subtype(true_type, false_type): - return false_type - if self.is_subtype(false_type, true_type): - return true_type - - self.error( - expr.location, - f"Incompatible types in ternary if branches: true={true_type} and false={false_type}", - ) - return UnknownType() - - def visit_base_type(self, node: p.BaseType) -> Type: - return self.ctx.get_type(node.base) - - def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... - - def visit_frame_column(self, node: p.FrameColumn) -> Type: ... - - def visit_frame_type(self, node: p.FrameType) -> Type: ... 
- - def map_call_arguments( - self, function: Function, call: p.CallExpr - ) -> list[MappedArgument]: - """Map call arguments to function parameters as defined in its signature - - This method maps positional-only, keyword-only and mixed parameter definitions - with the arguments passed at the call site - - Any mismatched, missing or unexpected argument is reported as a diagnostic - - Args: - function (Function): the function definition - call (p.CallExpr): the call expression - - Returns: - list[MappedArgument]: the list of mapped arguments - """ - positional: list[tuple[p.Expr, Type]] = [ - (arg, self.type_of(arg)) for arg in call.arguments - ] - keywords: dict[str, tuple[p.Expr, Type]] = { - name: (arg, self.type_of(arg)) for name, arg in call.keywords.items() - } - set_args: set[str] = set() - - required_positional: list[str] = [ - arg.name for arg in function.pos_args + function.args if arg.required - ] - required_keyword: list[str] = [ - arg.name for arg in function.kw_args if arg.required - ] - - mapped: list[MappedArgument] = [] - - pos_params: list[Function.Argument] = list(function.pos_args) - mixed_params: list[Function.Argument] = list(function.args) - kw_params: dict[str, Function.Argument] = { - arg.name: arg for arg in function.kw_args - } - - # TODO: handle *args and **kwargs sinks - for arg in positional: - param: Function.Argument - if len(pos_params) != 0: - param = pos_params.pop(0) - elif len(mixed_params) != 0: - param = mixed_params.pop(0) - else: - self.error(arg[0].location, "Too many positional arguments") - break - name: str = param.name - if name in required_positional: - required_positional.remove(name) - if name in required_keyword: - required_keyword.remove(name) - set_args.add(name) - mapped.append( - MappedArgument( - expr=arg[0], - type=arg[1], - argument=param, - ) - ) - - kw_params.update({arg.name: arg for arg in mixed_params}) - for name, arg in keywords.items(): - param: Function.Argument - if name not in kw_params: - if name 
in set_args: - self.error( - arg[0].location, f"Multiple values for argument '{name}'" - ) - else: - self.error(arg[0].location, f"Unknown keyword argument '{name}'") - continue - param = kw_params.pop(name) - if name in required_positional: - required_positional.remove(name) - if name in required_keyword: - required_keyword.remove(name) - set_args.add(name) - mapped.append( - MappedArgument( - expr=arg[0], - type=arg[1], - argument=param, - ) - ) - - def join_args(args: list[str]) -> str: - args = list(map(lambda a: f"'{a}'", args)) - if len(args) == 0: - return "" - if len(args) == 1: - return args[0] - return ", ".join(args[:-1]) + " and " + args[-1] - - if len(required_positional) != 0: - plural: str = "" if len(required_positional) == 1 else "s" - args: str = join_args(required_positional) - self.error( - call.location, - f"Missing required positional argument{plural}: {args}", - ) - - if len(required_keyword) != 0: - plural: str = "" if len(required_keyword) == 1 else "s" - args: str = join_args(required_keyword) - self.error( - call.location, - f"Missing required keyword argument{plural}: {args}", - ) - - return mapped + @property + def diagnostics(self) -> list[Diagnostic]: + return self.reporter.diagnostics diff --git a/midas/checker/diagnostic.py b/midas/checker/diagnostic.py index 2925653..f4b3d12 100644 --- a/midas/checker/diagnostic.py +++ b/midas/checker/diagnostic.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from enum import StrEnum -from pathlib import Path from typing import Optional from midas.ast.location import Location @@ -14,7 +13,7 @@ class DiagnosticType(StrEnum): @dataclass(frozen=True) class Diagnostic: - file_path: Optional[str | Path] + file_path: Optional[str] location: Location type: DiagnosticType message: str diff --git a/midas/checker/midas.py b/midas/checker/midas.py new file mode 100644 index 0000000..37a856d --- /dev/null +++ b/midas/checker/midas.py @@ -0,0 +1,137 @@ +import logging +from typing import Optional + +import 
midas.ast.midas as m +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter, Reporter +from midas.checker.types import ( + AliasType, + ComplexType, + GenericType, + Type, + TypeVar, + UnknownType, +) +from midas.lexer.midas import MidasLexer +from midas.lexer.token import Token +from midas.parser.midas import MidasParser +from midas.resolver.builtin import define_builtins + + +class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): + """A resolver which evaluates Midas type definitions and build a registry""" + + def __init__(self, types: TypesRegistry, reporter: Reporter) -> None: + self.logger: logging.Logger = logging.getLogger("MidasTyper") + self.reporter: FileReporter = reporter.for_file(None) + + self.types: TypesRegistry = types + self._local_variables: dict[str, TypeVar] = {} + + define_builtins(self.types) + + def process(self, source: str, path: Optional[str]): + self.reporter = self.reporter.for_file(path) + lexer: MidasLexer = MidasLexer(source) + tokens: list[Token] = lexer.process() + parser: MidasParser = MidasParser(tokens) + stmts: list[m.Stmt] = parser.parse() + self.resolve(stmts) + + def get_type(self, name: str) -> Type: + """Get a type from its name + + Args: + name (str): the name of the type + + Raises: + NameError: if the type is not defined + + Returns: + Type: the type + """ + if name in self._local_variables: + return self._local_variables[name] + return self.types.get_type(name) + + def resolve(self, stmts: list[m.Stmt]): + """Process a sequence of statements + + Args: + stmts (list[m.Stmt]): the statements + """ + for stmt in stmts: + stmt.accept(self) + + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + params: list[TypeVar] = [] + for param in stmt.params: + name: str = param.name.lexeme + bound: Optional[Type] = None + if param.bound is not None: + bound = param.bound.accept(self) + var = TypeVar(name=name, bound=bound) + self._local_variables[name] = 
var + params.append(var) + type: Type = stmt.type.accept(self) + if len(params) != 0: + type = GenericType(params=params, body=type) + name: str = stmt.name.lexeme + self.types.define_type(name, AliasType(name=name, type=type)) + self._local_variables.clear() + + def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... + + def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: + base: Type = stmt.type.accept(self) + for op in stmt.operations: + right: Type = op.operand.accept(self) + result: Type = op.result.accept(self) + self.types.define_operation( + left=base, + operator=op.name.lexeme, + right=right, + result=result, + ) + + def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... + + def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... + + def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... + + def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ... + + def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ... + + def visit_get_expr(self, expr: m.GetExpr) -> None: ... + + def visit_variable_expr(self, expr: m.VariableExpr) -> None: ... + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: + return expr.expr.accept(self) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... 
+ + def visit_named_type(self, type: m.NamedType) -> Type: + return self.get_type(type.name.lexeme) + + def visit_generic_type(self, type: m.GenericType) -> Type: + type_: Type = type.type.accept(self) + params: list[Type] = [param.accept(self) for param in type.params] + return self.types.apply_generic(type_, params) + + def visit_constraint_type(self, type: m.ConstraintType) -> Type: + type_: Type = type.type.accept(self) + type.constraint.accept(self) + # TODO + return UnknownType() + + def visit_complex_type(self, type: m.ComplexType) -> Type: + return ComplexType( + properties={ + prop.name.lexeme: prop.type.accept(self) for prop in type.properties + } + ) diff --git a/midas/checker/python.py b/midas/checker/python.py new file mode 100644 index 0000000..751497d --- /dev/null +++ b/midas/checker/python.py @@ -0,0 +1,626 @@ +import ast +import logging +from dataclasses import dataclass +from typing import Optional + +import midas.ast.python as p +from midas.ast.location import Location +from midas.checker.environment import Environment +from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter, Reporter +from midas.checker.types import ( + ComplexType, + Function, + Operation, + Type, + UnitType, + UnknownType, + unfold_type, +) +from midas.parser.python import PythonParser +from midas.resolver.resolver import Resolver + + +class ReturnException(Exception): + pass + + +@dataclass(frozen=True, kw_only=True) +class MappedArgument: + expr: p.Expr + type: Type + argument: Function.Argument + + +class PythonTyper( + p.Stmt.Visitor[None], + p.Expr.Visitor[Type], + p.MidasType.Visitor[Type], +): + """A type checker which can use custom type definitions""" + + def __init__( + self, + types: TypesRegistry, + reporter: Reporter, + ): + self.logger: logging.Logger = logging.getLogger("PythonTyper") + self.reporter: FileReporter = reporter.for_file(None) + 
self.types: TypesRegistry = types + self.global_env: Environment = Environment() + self.env: Environment = self.global_env + self.locals: dict[p.Expr, int] = {} + self.judgements: list[tuple[p.Expr, Type]] = [] + + def process(self, source: str, path: Optional[str]): + self.reporter = self.reporter.for_file(path) + + tree: ast.Module = ast.parse(source, filename=path or "") + parser = PythonParser() + stmts: list[p.Stmt] = parser.parse_module(tree) + resolver = Resolver() + resolver.resolve(*stmts) + + self.env = self.global_env + self.locals = resolver.locals + self.judgements = [] + + self.check(stmts) + + def type_of(self, expr: p.Expr) -> Type: + """Evaluate the type of an expression + + Args: + expr (p.Expr): the expression to evaluate + + Returns: + Type: the type of the given expression + """ + type: Type = expr.accept(self) + self.judgements.append((expr, type)) + return type + + def process_block(self, block: list[p.Stmt], env: Environment) -> bool: + """Evaluate a sequence of statements + + Args: + block (list[p.Stmt]): the statements to evaluate + env (Environment): the environment in which to evaluate + + Returns: + bool: whether a return statement is present in the block + """ + previous_env: Environment = self.env + self.env = env + returned: bool = False + for i, stmt in enumerate(block): + try: + stmt.accept(self) + except ReturnException: + returned = True + if i < len(block) - 1: + self.reporter.warning( + block[i + 1].location, "Unreachable statement" + ) + break + self.env = previous_env + return returned + + def check(self, statements: list[p.Stmt]) -> None: + """Type check a sequence of statements and returns diagnostics + + Args: + statements (list[p.Stmt]): the statements to evaluate and check + """ + for stmt in statements: + stmt.accept(self) + + self.logger.debug(f"Final environment: {self.env.flat_dict()}") + + def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]: + """Look up a variable in the environment it was 
declared + + Args: + name (str): the name of the variable + expr (p.Expr): the variable expression, used to lookup the scope distance + + Returns: + Optional[Type]: the type of the variable, or None if it was not found + """ + distance: Optional[int] = self.locals.get(expr) + if distance is not None: + return self.env.get_at(distance, name) + return self.global_env.get(name) + + def is_subtype(self, type1: Type, type2: Type) -> bool: + return self.types.is_subtype(type1, type2) + + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: + self.type_of(stmt.expr) + + def visit_function(self, stmt: p.Function) -> None: + env: Environment = Environment(self.env) + pos_args: list[Function.Argument] = [] + args: list[Function.Argument] = [] + kw_args: list[Function.Argument] = [] + + def eval_arg_type(arg: p.Function.Argument) -> Type: + if arg.type is not None: + return arg.type.accept(self) + if arg.default is not None: + return arg.default.accept(self) + return UnknownType() + + pos: int = 0 + for arg in stmt.posonlyargs: + pos_args.append( + Function.Argument( + pos=pos, + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + for arg in stmt.args: + args.append( + Function.Argument( + pos=pos, + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + for arg in stmt.kwonlyargs: + kw_args.append( + Function.Argument( + pos=pos, # not relevant + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + + for arg in pos_args + args + kw_args: + env.define(arg.name, arg.type) + + returns_hint: Optional[Type] = None + if stmt.returns is not None: + returns_hint = stmt.returns.accept(self) + # Early define to handle simple fully-typed recursion + inside_function: Function = Function( + name=stmt.name, + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns_hint, + ) + self.env.define(stmt.name, inside_function) + + returned: bool = 
self.process_block(stmt.body, env) + inferred_return: Type = UnknownType() + if not returned: + env.return_types.append(UnitType()) + return_types: set[Type] = set(env.return_types) + if len(return_types) == 1: + inferred_return = list(return_types)[0] + elif len(return_types) > 1: + self.reporter.error( + stmt.location, + f"Mixed return types: {env.return_types}", + ) + + returns: Type = UnknownType() + if returns_hint is not None: + assert stmt.returns is not None + returns = returns_hint + if returns != inferred_return: + self.reporter.error( + stmt.returns.location, + f"Return type mismatch, annotated {returns} but returns {inferred_return}", + ) + else: + returns = inferred_return + + # TODO: handle *args and **kwargs sinks + function: Function = Function( + name=stmt.name, + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns, + ) + self.env.define(stmt.name, function) + + def visit_type_assign(self, stmt: p.TypeAssign) -> None: + # TODO check not yet defined locally + type: Type = stmt.type.accept(self) + self.env.define(stmt.name, type) + + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: + value_type: Type = self.type_of(stmt.value) + for target in stmt.targets: + self._assign(stmt.location, target, value_type) + + def _assign(self, location: Location, target: p.Expr, value_type: Type): + match target: + case p.VariableExpr(): + self._assign_var(location, target, value_type) + + case p.GetExpr(): + self._assign_attr(location, target, value_type) + + case _: + if not isinstance(target, p.VariableExpr): + self.logger.warning(f"Unsupported assignment to {target}") + self.reporter.warning( + target.location, f"Unsupported assignment to {target}" + ) + + def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type): + name: str = target.name + var_type: Optional[Type] = self.look_up_variable(name, target) + + if var_type is None: + self.env.define(name, value_type) + else: + # S <: T + # Γ, x: T v: S + # x = v + if 
not self.is_subtype(value_type, var_type): + self.reporter.error( + location, + f"Cannot assign {value_type} to {name} of type {var_type}", + ) + + def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): + object: Type = self.type_of(target.object) + base_object: Type = unfold_type(object) + match base_object: + case ComplexType(properties=properties): + if target.name not in properties: + self.reporter.error( + target.location, f"Unknown property '{target.name} on {object}" + ) + return + + prop_type: Type = properties[target.name] + if not self.is_subtype(value_type, prop_type): + self.reporter.error( + location, + f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}", + ) + return + + case UnknownType(): + pass + + case _: + self.reporter.error( + target.location, + f"Cannot assign {value_type} to unknown property '{target.name}' on {object}", + ) + + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: + type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() + self.env.return_types.append(type) + raise ReturnException() + + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + # Not evaluated in sub-environment because assignments in the test leak out of the if + # For example: + # if (m := 1 + 1) < 2: + # ... 
+ # print(m) # <- m is still defined + test_type: Type = stmt.test.accept(self) + + # TODO Allow subtypes or any type + if test_type != self.types.get_type("bool"): + self.reporter.error( + stmt.test.location, f"If test must be a boolean, got {test_type}" + ) + + env: Environment = Environment(self.env) + body_returned: bool = self.process_block(stmt.body, env) + else_returned: bool = self.process_block(stmt.orelse, env) + self.env.return_types.extend(env.return_types) + if body_returned and else_returned: + raise ReturnException() + + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: + method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) + + operations: list[Operation] = self.types.get_operations_by_name(method) + valid_operations: list[Operation] = [] + for op in operations: + sig: Operation.CallSignature = op.signature + if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right): + valid_operations.append(op) + + if len(valid_operations) == 0: + self.reporter.error( + expr.location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + elif len(valid_operations) == 1: + self.logger.debug(f"Unique operation {method} between {left} and {right}") + return valid_operations[0].result + + for i, op1 in enumerate(valid_operations): + sig1: Operation.CallSignature = op1.signature + best_match: bool = True + for j, op2 in enumerate(valid_operations): + if i == j: + continue + sig2: Operation.CallSignature = op2.signature + if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype( + sig1.right, sig2.right + ): + best_match = False + break + self.logger.debug(f"{op1} is a full overload of {op2}") + if best_match: + 
return op1.result + + overloads: list[str] = [ + f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}" + for op in valid_operations + ] + self.reporter.error( + expr.location, + f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}", + ) + return UnknownType() + + def visit_compare_expr(self, expr: p.CompareExpr) -> Type: + method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) + + result: Optional[Type] = self.types.get_operation_result(left, method, right) + if result is None: + self.reporter.error( + expr.location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + return result + + def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... 
+ + def visit_call_expr(self, expr: p.CallExpr) -> Type: + callee: Type = self.type_of(expr.callee) + if not isinstance(callee, Function): + self.reporter.error(expr.callee.location, "Callee is not a function") + return UnknownType() + function: Function = callee + mapped: list[MappedArgument] = self.map_call_arguments(function, expr) + for arg in mapped: + if not self.is_subtype(arg.type, arg.argument.type): + self.reporter.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + return function.returns + + def visit_get_expr(self, expr: p.GetExpr) -> Type: + object: Type = self.type_of(expr.object) + base_object: Type = unfold_type(object) + match base_object: + case ComplexType(properties=properties): + if expr.name not in properties: + self.reporter.error( + expr.location, f"Unknown property '{expr.name} on {object}" + ) + return UnknownType() + return properties[expr.name] + + case UnknownType(): + return UnknownType() + + case _: + self.reporter.error( + expr.location, f"Cannot get property '{expr.name}' on {object}" + ) + return UnknownType() + + def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: + match expr.value: + case bool(): # Must be before int + return self.types.get_type("bool") + case int(): + return self.types.get_type("int") + case float(): + return self.types.get_type("float") + case str(): + return self.types.get_type("str") + case _: + self.reporter.warning(expr.location, f"Unknown literal {expr}") + return UnknownType() + + def visit_variable_expr(self, expr: p.VariableExpr) -> Type: + return self.look_up_variable(expr.name, expr) or UnknownType() + + def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: + left: Type = expr.left.accept(self) + right: Type = expr.right.accept(self) + + if self.is_subtype(left, right): + return right + if self.is_subtype(right, left): + return left + + self.reporter.error( + expr.location, + f"Incompatible operand types, 
{left=} and {right=}", + ) + return UnknownType() + + def visit_cast_expr(self, expr: p.CastExpr) -> Type: + return expr.type.accept(self) + + def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: + test_type: Type = expr.test.accept(self) + + # TODO Allow subtypes or any type + if test_type != self.types.get_type("bool"): + self.reporter.error( + expr.test.location, f"If test must be a boolean, got {test_type}" + ) + + true_type: Type = expr.if_true.accept(self) + false_type: Type = expr.if_false.accept(self) + if self.is_subtype(true_type, false_type): + return false_type + if self.is_subtype(false_type, true_type): + return true_type + + self.reporter.error( + expr.location, + f"Incompatible types in ternary if branches: true={true_type} and false={false_type}", + ) + return UnknownType() + + def visit_base_type(self, node: p.BaseType) -> Type: + return self.types.get_type(node.base) + + def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... + + def visit_frame_column(self, node: p.FrameColumn) -> Type: ... + + def visit_frame_type(self, node: p.FrameType) -> Type: ... 
+ + def map_call_arguments( + self, function: Function, call: p.CallExpr + ) -> list[MappedArgument]: + """Map call arguments to function parameters as defined in its signature + + This method maps positional-only, keyword-only and mixed parameter definitions + with the arguments passed at the call site + + Any mismatched, missing or unexpected argument is reported as a diagnostic + + Args: + function (Function): the function definition + call (p.CallExpr): the call expression + + Returns: + list[MappedArgument]: the list of mapped arguments + """ + positional: list[tuple[p.Expr, Type]] = [ + (arg, self.type_of(arg)) for arg in call.arguments + ] + keywords: dict[str, tuple[p.Expr, Type]] = { + name: (arg, self.type_of(arg)) for name, arg in call.keywords.items() + } + set_args: set[str] = set() + + required_positional: list[str] = [ + arg.name for arg in function.pos_args + function.args if arg.required + ] + required_keyword: list[str] = [ + arg.name for arg in function.kw_args if arg.required + ] + + mapped: list[MappedArgument] = [] + + pos_params: list[Function.Argument] = list(function.pos_args) + mixed_params: list[Function.Argument] = list(function.args) + kw_params: dict[str, Function.Argument] = { + arg.name: arg for arg in function.kw_args + } + + # TODO: handle *args and **kwargs sinks + for arg in positional: + param: Function.Argument + if len(pos_params) != 0: + param = pos_params.pop(0) + elif len(mixed_params) != 0: + param = mixed_params.pop(0) + else: + self.reporter.error(arg[0].location, "Too many positional arguments") + break + name: str = param.name + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + kw_params.update({arg.name: arg for arg in mixed_params}) + for name, arg in keywords.items(): + param: Function.Argument + if name not in kw_params: + 
if name in set_args: + self.reporter.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.reporter.error( + arg[0].location, f"Unknown keyword argument '{name}'" + ) + continue + param = kw_params.pop(name) + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + def join_args(args: list[str]) -> str: + args = list(map(lambda a: f"'{a}'", args)) + if len(args) == 0: + return "" + if len(args) == 1: + return args[0] + return ", ".join(args[:-1]) + " and " + args[-1] + + if len(required_positional) != 0: + plural: str = "" if len(required_positional) == 1 else "s" + args: str = join_args(required_positional) + self.reporter.error( + call.location, + f"Missing required positional argument{plural}: {args}", + ) + + if len(required_keyword) != 0: + plural: str = "" if len(required_keyword) == 1 else "s" + args: str = join_args(required_keyword) + self.reporter.error( + call.location, + f"Missing required keyword argument{plural}: {args}", + ) + + return mapped diff --git a/midas/resolver/midas.py b/midas/checker/registry.py similarity index 74% rename from midas/resolver/midas.py rename to midas/checker/registry.py index 6872569..1324bbd 100644 --- a/midas/resolver/midas.py +++ b/midas/checker/registry.py @@ -1,6 +1,5 @@ from typing import Optional -import midas.ast.midas as m from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, @@ -10,24 +9,15 @@ from midas.checker.types import ( GenericType, Operation, Type, - TypeVar, - UnknownType, substitute_typevars, ) -from midas.resolver.builtin import define_builtins -class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): - """A resolver which evaluates Midas type definitions and build a registry""" - +class TypesRegistry: def 
__init__(self) -> None: self._types: dict[str, Type] = {} self._operations: dict[Operation.CallSignature, Type] = {} - self._local_variables: dict[str, TypeVar] = {} - - define_builtins(self) - def get_type(self, name: str) -> Type: """Get a type from its name @@ -40,8 +30,6 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T Returns: Type: the type """ - if name in self._local_variables: - return self._local_variables[name] if name in self._types: return self._types[name] raise NameError(f"Undefined type {name}") @@ -120,117 +108,6 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T ) self._operations[signature] = result - def resolve(self, stmts: list[m.Stmt]): - """Process a sequence of statements - - Args: - stmts (list[m.Stmt]): the statements - """ - for stmt in stmts: - stmt.accept(self) - - def visit_type_stmt(self, stmt: m.TypeStmt) -> None: - params: list[TypeVar] = [] - for param in stmt.params: - name: str = param.name.lexeme - bound: Optional[Type] = None - if param.bound is not None: - bound = param.bound.accept(self) - var = TypeVar(name=name, bound=bound) - self._local_variables[name] = var - params.append(var) - type: Type = stmt.type.accept(self) - if len(params) != 0: - type = GenericType(params=params, body=type) - name: str = stmt.name.lexeme - self.define_type(name, AliasType(name=name, type=type)) - self._local_variables.clear() - - def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... - - def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: - base: Type = stmt.type.accept(self) - for op in stmt.operations: - right: Type = op.operand.accept(self) - result: Type = op.result.accept(self) - self.define_operation( - left=base, - operator=op.name.lexeme, - right=right, - result=result, - ) - - def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... - - def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... 
- - def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... - - def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ... - - def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ... - - def visit_get_expr(self, expr: m.GetExpr) -> None: ... - - def visit_variable_expr(self, expr: m.VariableExpr) -> None: ... - - def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: - return expr.expr.accept(self) - - def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... - - def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... - - def visit_named_type(self, type: m.NamedType) -> Type: - return self.get_type(type.name.lexeme) - - def visit_generic_type(self, type: m.GenericType) -> Type: - type_: Type = type.type.accept(self) - params: list[Type] = [param.accept(self) for param in type.params] - return self.apply_generic(type_, params) - - def apply_generic(self, type: Type, params: list[Type]) -> Type: - match type: - case AliasType(name=name, type=base): - return AliasType(name=name, type=self.apply_generic(base, params)) - - case GenericType(params=type_vars, body=body): - n_params: int = len(params) - n_type_vars: int = len(type_vars) - if n_params < n_type_vars: - raise ValueError( - f"Missing type parameters, expected {n_type_vars} but only {n_params} provided" - ) - if n_params > n_type_vars: - raise ValueError( - f"Too many type parameters, expected {n_type_vars} but {n_params} provided" - ) - substitutions: dict[str, Type] = {} - for param, type_var in zip(params, type_vars): - if type_var.bound is not None and not self.is_subtype( - param, type_var.bound - ): - raise ValueError( - f"Type parameter {param} is not a subtype of {type_var.bound}" - ) - substitutions[type_var.name] = param - return substitute_typevars(body, substitutions) - case _: - raise ValueError(f"{type} is not a generic type") - - def visit_constraint_type(self, type: m.ConstraintType) -> Type: - type_: Type = type.type.accept(self) - type.constraint.accept(self) - 
# TODO - return UnknownType() - - def visit_complex_type(self, type: m.ComplexType) -> Type: - return ComplexType( - properties={ - prop.name.lexeme: prop.type.accept(self) for prop in type.properties - } - ) - def is_subtype(self, type1: Type, type2: Type) -> bool: """Check whether `type1` is a subtype of `type2` @@ -371,3 +248,33 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T return False return True + + def apply_generic(self, type: Type, params: list[Type]) -> Type: + match type: + case AliasType(name=name, type=base): + return AliasType(name=name, type=self.apply_generic(base, params)) + + case GenericType(params=type_vars, body=body): + n_params: int = len(params) + n_type_vars: int = len(type_vars) + if n_params < n_type_vars: + raise ValueError( + f"Missing type parameters, expected {n_type_vars} but only {n_params} provided" + ) + if n_params > n_type_vars: + raise ValueError( + f"Too many type parameters, expected {n_type_vars} but {n_params} provided" + ) + substitutions: dict[str, Type] = {} + for param, type_var in zip(params, type_vars): + if type_var.bound is not None and not self.is_subtype( + param, type_var.bound + ): + raise ValueError( + f"Type parameter {param} is not a subtype of {type_var.bound}" + ) + substitutions[type_var.name] = param + return substitute_typevars(body, substitutions) + + case _: + raise ValueError(f"{type} is not a generic type") diff --git a/midas/cli/main.py b/midas/cli/main.py index ae4295b..cafeeaf 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -10,7 +10,7 @@ import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter -from midas.checker.checker import Checker +from midas.checker.checker import TypeChecker from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.types import Type from midas.cli.ansi import Ansi @@ -25,7 +25,6 @@ 
from midas.lexer.midas import MidasLexer from midas.lexer.token import Token, TokenType from midas.parser.midas import MidasParser from midas.parser.python import PythonParser -from midas.resolver.resolver import Resolver from midas.utils import UniversalJSONDumper @@ -98,18 +97,13 @@ def compile( ): logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN) source: str = file.read() - tree: ast.Module = ast.parse(source, filename=file.name) - parser = PythonParser() - stmts: list[p.Stmt] = parser.parse_module(tree) - resolver = Resolver() - resolver.resolve(*stmts) - types_paths: list[Path] = [Path(t.name).resolve() for t in types] - checker = Checker( - resolver.locals, - source_path=Path(file.name).resolve(), - types_paths=types_paths, - ) - diagnostics: list[Diagnostic] = checker.check(stmts) + + checker = TypeChecker() + for path in types: + checker.import_midas(Path(path.name).resolve()) + + checker.type_check_source(source, str(Path(file.name).resolve())) + diagnostics: list[Diagnostic] = checker.diagnostics lines: list[str] = source.split("\n") for diagnostic in diagnostics: print_diagnostic(lines, diagnostic) @@ -118,7 +112,7 @@ def compile( print( json.dumps( UniversalJSONDumper.dump( - checker.global_env, + checker.python_typer.global_env, [("Environment", "_children")], lambda obj: isinstance(obj, get_args(Type)), ), diff --git a/midas/resolver/builtin.py b/midas/resolver/builtin.py index 04bc6e3..c3c7e65 100644 --- a/midas/resolver/builtin.py +++ b/midas/resolver/builtin.py @@ -1,15 +1,9 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - +from midas.checker.registry import TypesRegistry from midas.checker.types import BaseType, Type, UnitType -if TYPE_CHECKING: - from midas.resolver.midas import MidasResolver - -def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type): - ctx.define_operation( +def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type): + reg.define_operation( left=t1, 
operator=operator, right=t2, @@ -17,8 +11,8 @@ def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type): ) -def basic_op(ctx: MidasResolver, type: Type, op: str): - ctx.define_operation( +def basic_op(reg: TypesRegistry, type: Type, op: str): + reg.define_operation( left=type, operator=op, right=type, @@ -26,47 +20,47 @@ def basic_op(ctx: MidasResolver, type: Type, op: str): ) -def define_builtins(ctx: MidasResolver): +def define_builtins(reg: TypesRegistry): """Define builtin types and operations""" - unit = ctx.define_type("None", UnitType()) - bool = ctx.define_type("bool", BaseType(name="bool")) - int = ctx.define_type("int", BaseType(name="int")) - float = ctx.define_type("float", BaseType(name="float")) - str = ctx.define_type("str", BaseType(name="str")) + unit = reg.define_type("None", UnitType()) + bool = reg.define_type("bool", BaseType(name="bool")) + int = reg.define_type("int", BaseType(name="int")) + float = reg.define_type("float", BaseType(name="float")) + str = reg.define_type("str", BaseType(name="str")) - basic_op(ctx, int, "__add__") # int + int = int - basic_op(ctx, int, "__sub__") # int - int = int - basic_op(ctx, int, "__mul__") # int * int = int - basic_op(ctx, int, "__pow__") # int ** int = int - basic_op(ctx, int, "__mod__") # int % int = int - basic_op(ctx, int, "__and__") # int & int = int - basic_op(ctx, int, "__or__") # int | int = int - basic_op(ctx, int, "__xor__") # int ^ int = int - op(ctx, int, "__lt__", int, bool) # int < int = bool - op(ctx, int, "__gt__", int, bool) # int > int = bool - op(ctx, int, "__le__", int, bool) # int <= int = bool - op(ctx, int, "__ge__", int, bool) # int >= int = bool - op(ctx, int, "__eq__", int, bool) # int == int = bool - basic_op(ctx, float, "__add__") # float + float = float - basic_op(ctx, float, "__sub__") # float - float = float - basic_op(ctx, float, "__mul__") # float * float = float - basic_op(ctx, float, "__truediv__") # float / float = float - op(ctx, float, "__lt__", 
float, bool) # float < float = bool - op(ctx, float, "__gt__", float, bool) # float > float = bool - op(ctx, float, "__le__", float, bool) # float <= float = bool - op(ctx, float, "__ge__", float, bool) # float >= float = bool - op(ctx, float, "__eq__", float, bool) # float == float = bool - basic_op(ctx, str, "__add__") # str + str = str - op(ctx, str, "__eq__", str, bool) # str == str = bool + basic_op(reg, int, "__add__") # int + int = int + basic_op(reg, int, "__sub__") # int - int = int + basic_op(reg, int, "__mul__") # int * int = int + basic_op(reg, int, "__pow__") # int ** int = int + basic_op(reg, int, "__mod__") # int % int = int + basic_op(reg, int, "__and__") # int & int = int + basic_op(reg, int, "__or__") # int | int = int + basic_op(reg, int, "__xor__") # int ^ int = int + op(reg, int, "__lt__", int, bool) # int < int = bool + op(reg, int, "__gt__", int, bool) # int > int = bool + op(reg, int, "__le__", int, bool) # int <= int = bool + op(reg, int, "__ge__", int, bool) # int >= int = bool + op(reg, int, "__eq__", int, bool) # int == int = bool + basic_op(reg, float, "__add__") # float + float = float + basic_op(reg, float, "__sub__") # float - float = float + basic_op(reg, float, "__mul__") # float * float = float + basic_op(reg, float, "__truediv__") # float / float = float + op(reg, float, "__lt__", float, bool) # float < float = bool + op(reg, float, "__gt__", float, bool) # float > float = bool + op(reg, float, "__le__", float, bool) # float <= float = bool + op(reg, float, "__ge__", float, bool) # float >= float = bool + op(reg, float, "__eq__", float, bool) # float == float = bool + basic_op(reg, str, "__add__") # str + str = str + op(reg, str, "__eq__", str, bool) # str == str = bool - op(ctx, int, "__lt__", float, bool) # int < float = bool - op(ctx, int, "__gt__", float, bool) # int > float = bool - op(ctx, int, "__le__", float, bool) # int <= float = bool - op(ctx, int, "__ge__", float, bool) # int >= float = bool - op(ctx, int, "__eq__", 
float, bool) # int == float = bool + op(reg, int, "__lt__", float, bool) # int < float = bool + op(reg, int, "__gt__", float, bool) # int > float = bool + op(reg, int, "__le__", float, bool) # int <= float = bool + op(reg, int, "__ge__", float, bool) # int >= float = bool + op(reg, int, "__eq__", float, bool) # int == float = bool - op(ctx, float, "__lt__", int, bool) # float < int = bool - op(ctx, float, "__gt__", int, bool) # float > int = bool - op(ctx, float, "__le__", int, bool) # float <= int = bool - op(ctx, float, "__ge__", int, bool) # float >= int = bool - op(ctx, float, "__eq__", int, bool) # float == int = bool + op(reg, float, "__lt__", int, bool) # float < int = bool + op(reg, float, "__gt__", int, bool) # float > int = bool + op(reg, float, "__le__", int, bool) # float <= int = bool + op(reg, float, "__ge__", int, bool) # float >= int = bool + op(reg, float, "__eq__", int, bool) # float == int = bool diff --git a/tests/checker.py b/tests/checker.py index 27a94cb..3ceb34e 100644 --- a/tests/checker.py +++ b/tests/checker.py @@ -1,14 +1,11 @@ -import ast import json from dataclasses import asdict, dataclass, field from pathlib import Path import midas.ast.python as p -from midas.checker.checker import Checker +from midas.checker.checker import TypeChecker from midas.checker.diagnostic import Diagnostic from midas.checker.types import Type -from midas.parser.python import PythonParser -from midas.resolver.resolver import Resolver from tests.base import Tester from tests.serializer.python import PythonAstJsonSerializer @@ -36,24 +33,16 @@ class CheckerTester(Tester): if not path.is_file(): raise TypeError(f"Test '{path}' is not a file") - types_paths: list[Path] = [] + result: CaseResult = CaseResult() + + checker = TypeChecker() types_path: Path = path.with_suffix(".midas") if types_path.exists(): - types_paths.append(types_path) - source: str = path.read_text() - tree: ast.Module = ast.parse(source, filename=path) - parser = PythonParser() - stmts: 
list[p.Stmt] = parser.parse_module(tree) - resolver = Resolver() - resolver.resolve(*stmts) - result: CaseResult = CaseResult() - checker = Checker( - resolver.locals, - source_path=path, - types_paths=types_paths, - ) + checker.import_midas(types_path) - diagnostics: list[Diagnostic] = checker.check(stmts) + checker.type_check(path) + + diagnostics: list[Diagnostic] = checker.diagnostics for diagnostic in diagnostics: result.diagnostics.append( { @@ -72,7 +61,7 @@ class CheckerTester(Tester): } ) - judgements: list[tuple[p.Expr, Type]] = checker.judgements + judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements serializer = PythonAstJsonSerializer() for expr, type in judgements: loc = expr.location From 7236749bd54a6978ac5e2ae7d59bf5e8b953d42e Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 13:44:26 +0200 Subject: [PATCH 08/64] refactor(checker): unify builtins definitions --- midas/checker/builtins.py | 67 +++++++++++++++++++++++++++++++++++++++ midas/checker/midas.py | 2 +- midas/resolver/builtin.py | 66 -------------------------------------- 3 files changed, 68 insertions(+), 67 deletions(-) delete mode 100644 midas/resolver/builtin.py diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index bc80084..ac3e737 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -1,4 +1,71 @@ +from midas.checker.registry import TypesRegistry +from midas.checker.types import BaseType, Type, UnitType + BUILTIN_SUBTYPES: dict[str, set[str]] = { "float": {"int"}, "int": {"bool"}, } + + +def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type): + reg.define_operation( + left=t1, + operator=operator, + right=t2, + result=t3, + ) + + +def basic_op(reg: TypesRegistry, type: Type, op: str): + reg.define_operation( + left=type, + operator=op, + right=type, + result=type, + ) + + +def define_builtins(reg: TypesRegistry): + """Define builtin types and operations""" + unit = reg.define_type("None", 
UnitType()) + bool = reg.define_type("bool", BaseType(name="bool")) + int = reg.define_type("int", BaseType(name="int")) + float = reg.define_type("float", BaseType(name="float")) + str = reg.define_type("str", BaseType(name="str")) + + basic_op(reg, int, "__add__") # int + int = int + basic_op(reg, int, "__sub__") # int - int = int + basic_op(reg, int, "__mul__") # int * int = int + basic_op(reg, int, "__pow__") # int ** int = int + basic_op(reg, int, "__mod__") # int % int = int + basic_op(reg, int, "__and__") # int & int = int + basic_op(reg, int, "__or__") # int | int = int + basic_op(reg, int, "__xor__") # int ^ int = int + op(reg, int, "__lt__", int, bool) # int < int = bool + op(reg, int, "__gt__", int, bool) # int > int = bool + op(reg, int, "__le__", int, bool) # int <= int = bool + op(reg, int, "__ge__", int, bool) # int >= int = bool + op(reg, int, "__eq__", int, bool) # int == int = bool + basic_op(reg, float, "__add__") # float + float = float + basic_op(reg, float, "__sub__") # float - float = float + basic_op(reg, float, "__mul__") # float * float = float + basic_op(reg, float, "__truediv__") # float / float = float + op(reg, float, "__lt__", float, bool) # float < float = bool + op(reg, float, "__gt__", float, bool) # float > float = bool + op(reg, float, "__le__", float, bool) # float <= float = bool + op(reg, float, "__ge__", float, bool) # float >= float = bool + op(reg, float, "__eq__", float, bool) # float == float = bool + basic_op(reg, str, "__add__") # str + str = str + op(reg, str, "__eq__", str, bool) # str == str = bool + + op(reg, int, "__lt__", float, bool) # int < float = bool + op(reg, int, "__gt__", float, bool) # int > float = bool + op(reg, int, "__le__", float, bool) # int <= float = bool + op(reg, int, "__ge__", float, bool) # int >= float = bool + op(reg, int, "__eq__", float, bool) # int == float = bool + + op(reg, float, "__lt__", int, bool) # float < int = bool + op(reg, float, "__gt__", int, bool) # float > int = bool + 
op(reg, float, "__le__", int, bool) # float <= int = bool + op(reg, float, "__ge__", int, bool) # float >= int = bool + op(reg, float, "__eq__", int, bool) # float == int = bool diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 37a856d..a7ce36f 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -2,6 +2,7 @@ import logging from typing import Optional import midas.ast.midas as m +from midas.checker.builtins import define_builtins from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( @@ -15,7 +16,6 @@ from midas.checker.types import ( from midas.lexer.midas import MidasLexer from midas.lexer.token import Token from midas.parser.midas import MidasParser -from midas.resolver.builtin import define_builtins class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): diff --git a/midas/resolver/builtin.py b/midas/resolver/builtin.py deleted file mode 100644 index c3c7e65..0000000 --- a/midas/resolver/builtin.py +++ /dev/null @@ -1,66 +0,0 @@ -from midas.checker.registry import TypesRegistry -from midas.checker.types import BaseType, Type, UnitType - - -def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type): - reg.define_operation( - left=t1, - operator=operator, - right=t2, - result=t3, - ) - - -def basic_op(reg: TypesRegistry, type: Type, op: str): - reg.define_operation( - left=type, - operator=op, - right=type, - result=type, - ) - - -def define_builtins(reg: TypesRegistry): - """Define builtin types and operations""" - unit = reg.define_type("None", UnitType()) - bool = reg.define_type("bool", BaseType(name="bool")) - int = reg.define_type("int", BaseType(name="int")) - float = reg.define_type("float", BaseType(name="float")) - str = reg.define_type("str", BaseType(name="str")) - - basic_op(reg, int, "__add__") # int + int = int - basic_op(reg, int, "__sub__") # int - int = int - basic_op(reg, int, "__mul__") 
# int * int = int - basic_op(reg, int, "__pow__") # int ** int = int - basic_op(reg, int, "__mod__") # int % int = int - basic_op(reg, int, "__and__") # int & int = int - basic_op(reg, int, "__or__") # int | int = int - basic_op(reg, int, "__xor__") # int ^ int = int - op(reg, int, "__lt__", int, bool) # int < int = bool - op(reg, int, "__gt__", int, bool) # int > int = bool - op(reg, int, "__le__", int, bool) # int <= int = bool - op(reg, int, "__ge__", int, bool) # int >= int = bool - op(reg, int, "__eq__", int, bool) # int == int = bool - basic_op(reg, float, "__add__") # float + float = float - basic_op(reg, float, "__sub__") # float - float = float - basic_op(reg, float, "__mul__") # float * float = float - basic_op(reg, float, "__truediv__") # float / float = float - op(reg, float, "__lt__", float, bool) # float < float = bool - op(reg, float, "__gt__", float, bool) # float > float = bool - op(reg, float, "__le__", float, bool) # float <= float = bool - op(reg, float, "__ge__", float, bool) # float >= float = bool - op(reg, float, "__eq__", float, bool) # float == float = bool - basic_op(reg, str, "__add__") # str + str = str - op(reg, str, "__eq__", str, bool) # str == str = bool - - op(reg, int, "__lt__", float, bool) # int < float = bool - op(reg, int, "__gt__", float, bool) # int > float = bool - op(reg, int, "__le__", float, bool) # int <= float = bool - op(reg, int, "__ge__", float, bool) # int >= float = bool - op(reg, int, "__eq__", float, bool) # int == float = bool - - op(reg, float, "__lt__", int, bool) # float < int = bool - op(reg, float, "__gt__", int, bool) # float > int = bool - op(reg, float, "__le__", int, bool) # float <= int = bool - op(reg, float, "__ge__", int, bool) # float >= int = bool - op(reg, float, "__eq__", int, bool) # float == int = bool From 314d4d344bdd230e03cc3644ab9f405e57bb5857 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 13:45:39 +0200 Subject: [PATCH 09/64] refactor(resolver): move resolver to 
checker module --- midas/checker/python.py | 2 +- midas/{resolver => checker}/resolver.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename midas/{resolver => checker}/resolver.py (100%) diff --git a/midas/checker/python.py b/midas/checker/python.py index 751497d..e1812d3 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -9,6 +9,7 @@ from midas.checker.environment import Environment from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter +from midas.checker.resolver import Resolver from midas.checker.types import ( ComplexType, Function, @@ -19,7 +20,6 @@ from midas.checker.types import ( unfold_type, ) from midas.parser.python import PythonParser -from midas.resolver.resolver import Resolver class ReturnException(Exception): diff --git a/midas/resolver/resolver.py b/midas/checker/resolver.py similarity index 100% rename from midas/resolver/resolver.py rename to midas/checker/resolver.py From 098bbc35c5d24fe7f6a3c00743d20c6f0f5aad5f Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 13:48:46 +0200 Subject: [PATCH 10/64] fix: avoid circular import in builtins.py --- midas/checker/builtins.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index ac3e737..24dc288 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -1,6 +1,13 @@ -from midas.checker.registry import TypesRegistry +from __future__ import annotations + +from typing import TYPE_CHECKING + from midas.checker.types import BaseType, Type, UnitType +if TYPE_CHECKING: + from midas.checker.registry import TypesRegistry + + BUILTIN_SUBTYPES: dict[str, set[str]] = { "float": {"int"}, "int": {"bool"}, From c1f95edc96ac9ca856922f2383e6148f5c562e7e Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 18:21:40 +0200 Subject: [PATCH 11/64] 
feat(types): add name to generic type --- midas/checker/midas.py | 8 +++++--- midas/checker/types.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index a7ce36f..c55123f 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -73,11 +73,13 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type var = TypeVar(name=name, bound=bound) self._local_variables[name] = var params.append(var) + name: str = stmt.name.lexeme type: Type = stmt.type.accept(self) if len(params) != 0: - type = GenericType(params=params, body=type) - name: str = stmt.name.lexeme - self.types.define_type(name, AliasType(name=name, type=type)) + type = GenericType(name=name, params=params, body=type) + else: + type = AliasType(name=name, type=type) + self.types.define_type(name, type) self._local_variables.clear() def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... diff --git a/midas/checker/types.py b/midas/checker/types.py index ee41e14..8c95134 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -66,6 +66,7 @@ class TypeVar: @dataclass(frozen=True, kw_only=True) class GenericType: + name: str params: list[TypeVar] body: Type From 5a6a279eafce1a12ce13ee009fa5c30067a16010 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 18:25:37 +0200 Subject: [PATCH 12/64] feat(checker): WIP add lists --- gen/python.py | 4 ++++ midas/ast/printer.py | 11 +++++++++++ midas/ast/python.py | 11 +++++++++++ midas/checker/builtins.py | 36 +++++++++++++++++++++++++++++++++++- midas/checker/python.py | 34 ++++++++++++++++++++++++++++++++++ midas/checker/resolver.py | 4 ++++ midas/cli/highlighter.py | 4 ++++ midas/parser/python.py | 7 +++++++ tests/serializer/python.py | 7 +++++++ 9 files changed, 117 insertions(+), 1 deletion(-) diff --git a/gen/python.py b/gen/python.py index e6d08c9..79ba8b0 100644 --- a/gen/python.py +++ b/gen/python.py @@ -139,4 +139,8 @@ class 
TernaryExpr: if_false: Expr +class ListExpr: + items: list[Expr] + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index f8fb411..dc2e64c 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -626,3 +626,14 @@ class PythonAstPrinter( self._write_line("if_false", last=True) with self._child_level(single=True): expr.if_false.accept(self) + + def visit_list_expr(self, expr: p.ListExpr) -> None: + self._write_line("ListExpr") + with self._child_level(): + self._write_line("items", last=True) + with self._child_level(): + for i, item in enumerate(expr.items): + self._idx = i + if i == len(expr.items) - 1: + self._mark_last() + item.accept(self) diff --git a/midas/ast/python.py b/midas/ast/python.py index dd5d905..1aea8ed 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -220,6 +220,9 @@ class Expr(ABC): @abstractmethod def visit_ternary_expr(self, expr: TernaryExpr) -> T: ... + @abstractmethod + def visit_list_expr(self, expr: ListExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -312,3 +315,11 @@ class TernaryExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_ternary_expr(self) + + +@dataclass(frozen=True) +class ListExpr(Expr): + items: list[Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_list_expr(self) diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index 24dc288..f20eb50 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -2,7 +2,15 @@ from __future__ import annotations from typing import TYPE_CHECKING -from midas.checker.types import BaseType, Type, UnitType +from midas.checker.types import ( + BaseType, + ComplexType, + Function, + GenericType, + Type, + TypeVar, + UnitType, +) if TYPE_CHECKING: from midas.checker.registry import TypesRegistry @@ -76,3 +84,29 @@ def define_builtins(reg: TypesRegistry): op(reg, float, "__le__", int, bool) # float <= int = bool op(reg, float, "__ge__", int, bool) # float 
>= int = bool op(reg, float, "__eq__", int, bool) # float == int = bool + + list = reg.define_type( + "list", + GenericType( + name="list", + params=[TypeVar(name="T", bound=None)], + body=ComplexType( + properties={ + "append": Function( + name="append", + pos_args=[ + Function.Argument( + pos=0, + name="object", + type=TypeVar(name="T", bound=None), + required=True, + ) + ], + args=[], + kw_args=[], + returns=UnitType(), + ) + } + ), + ), + ) diff --git a/midas/checker/python.py b/midas/checker/python.py index e1812d3..63a076e 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -499,6 +499,40 @@ class PythonTyper( ) return UnknownType() + def visit_list_expr(self, expr: p.ListExpr) -> Type: + list_type: Type = self.types.get_type("list") + item_types: list[Type] = [self.type_of(item) for item in expr.items] + + # Try to reduce types with subsumption + reduced: bool = True + keep: list[int] = list(range(len(item_types))) + while reduced: + reduced = False + for i, i1 in enumerate(keep): + type1: Type = item_types[i1] + for i2 in keep[i + 1 :]: + type2 = item_types[i2] + if self.types.is_subtype(type1, type2): + keep.remove(i1) + elif self.types.is_subtype(type2, type1): + keep.remove(i2) + else: + continue + reduced = True + break + + if len(keep) == 0: + return list_type + + if len(keep) == 1: + item_type: Type = item_types[keep[0]] + return self.types.apply_generic(list_type, [item_type]) + self.reporter.error( + expr.location, + f"Heterogeneous list items: {[item_types[i] for i in keep]}", + ) + return self.types.apply_generic(list_type, [UnknownType()]) + def visit_base_type(self, node: p.BaseType) -> Type: return self.types.get_type(node.base) diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 18fcba4..0b7d990 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -180,3 +180,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.resolve(expr.test) self.resolve(expr.if_true) 
self.resolve(expr.if_false) + + def visit_list_expr(self, expr: p.ListExpr) -> None: + for item in expr.items: + self.resolve(item) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index e4a9556..0d6a018 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -214,6 +214,10 @@ class PythonHighlighter( def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ... + def visit_list_expr(self, expr: p.ListExpr) -> None: + for item in expr.items: + item.accept(self) + class MidasHighlighter( Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] diff --git a/midas/parser/python.py b/midas/parser/python.py index 79011bc..bbe23c8 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -17,6 +17,7 @@ from midas.ast.python import ( Function, GetExpr, IfStmt, + ListExpr, LiteralExpr, LogicalExpr, MidasType, @@ -416,6 +417,12 @@ class PythonParser: case ast.Name(id=name): return VariableExpr(location=location, name=name) + case ast.List(elts=items): + return ListExpr( + location=location, + items=[self.parse_expr(item) for item in items], + ) + case _: raise UnsupportedSyntaxError(node) diff --git a/tests/serializer/python.py b/tests/serializer/python.py index bab3f8c..833d4e4 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -16,6 +16,7 @@ from midas.ast.python import ( Function, GetExpr, IfStmt, + ListExpr, LiteralExpr, LogicalExpr, MidasType, @@ -245,3 +246,9 @@ class PythonAstJsonSerializer( "if_true": expr.if_true.accept(self), "if_false": expr.if_false.accept(self), } + + def visit_list_expr(self, expr: ListExpr) -> dict: + return { + "_type": "ListExpr", + "items": [item.accept(self) for item in expr.items], + } From 9474a7336a54cb45887c751a997bcad2cb9e8ad3 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Mon, 8 Jun 2026 18:26:11 +0200 Subject: [PATCH 13/64] feat(types): WIP add AppliedType --- midas/checker/python.py | 6 +++++- midas/checker/registry.py | 9 +++++++-- 
midas/checker/types.py | 8 ++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 63a076e..a920990 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -534,7 +534,11 @@ class PythonTyper( return self.types.apply_generic(list_type, [UnknownType()]) def visit_base_type(self, node: p.BaseType) -> Type: - return self.types.get_type(node.base) + base: Type = self.types.get_type(node.base) + if node.param is not None: + param: Type = node.param.accept(self) + return self.types.apply_generic(base, [param]) + return base def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 1324bbd..da5a7ee 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -3,6 +3,7 @@ from typing import Optional from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.types import ( AliasType, + AppliedType, BaseType, ComplexType, Function, @@ -254,7 +255,7 @@ class TypesRegistry: case AliasType(name=name, type=base): return AliasType(name=name, type=self.apply_generic(base, params)) - case GenericType(params=type_vars, body=body): + case GenericType(name=name, params=type_vars, body=body): n_params: int = len(params) n_type_vars: int = len(type_vars) if n_params < n_type_vars: @@ -274,7 +275,11 @@ class TypesRegistry: f"Type parameter {param} is not a subtype of {type_var.bound}" ) substitutions[type_var.name] = param - return substitute_typevars(body, substitutions) + return AppliedType( + name=name, + args=params, + body=substitute_typevars(body, substitutions), + ) case _: raise ValueError(f"{type} is not a generic type") diff --git a/midas/checker/types.py b/midas/checker/types.py index 8c95134..9081a95 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -71,6 +71,13 @@ class GenericType: body: Type +@dataclass(frozen=True, kw_only=True) +class AppliedType: + name: str + 
args: list[Type] + body: Type + + def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: def sub_argument(arg: Function.Argument): return Function.Argument( @@ -138,4 +145,5 @@ Type = ( | ComplexType | TypeVar | GenericType + | AppliedType ) From 32207c3d6f0b1cd4379790022e6190cf7a6fb841 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 9 Jun 2026 08:04:45 +0200 Subject: [PATCH 14/64] refactor(checker): extract reduce_types function --- midas/checker/python.py | 27 +++++---------------------- midas/checker/registry.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index a920990..84963bb 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -502,34 +502,17 @@ class PythonTyper( def visit_list_expr(self, expr: p.ListExpr) -> Type: list_type: Type = self.types.get_type("list") item_types: list[Type] = [self.type_of(item) for item in expr.items] + item_types = self.types.reduce_types(item_types) - # Try to reduce types with subsumption - reduced: bool = True - keep: list[int] = list(range(len(item_types))) - while reduced: - reduced = False - for i, i1 in enumerate(keep): - type1: Type = item_types[i1] - for i2 in keep[i + 1 :]: - type2 = item_types[i2] - if self.types.is_subtype(type1, type2): - keep.remove(i1) - elif self.types.is_subtype(type2, type1): - keep.remove(i2) - else: - continue - reduced = True - break - - if len(keep) == 0: + if len(item_types) == 0: return list_type - if len(keep) == 1: - item_type: Type = item_types[keep[0]] + if len(item_types) == 1: + item_type: Type = item_types[0] return self.types.apply_generic(list_type, [item_type]) self.reporter.error( expr.location, - f"Heterogeneous list items: {[item_types[i] for i in keep]}", + f"Heterogeneous list items: {item_types}", ) return self.types.apply_generic(list_type, [UnknownType()]) diff --git a/midas/checker/registry.py b/midas/checker/registry.py 
index da5a7ee..585e3af 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -283,3 +283,31 @@ class TypesRegistry: case _: raise ValueError(f"{type} is not a generic type") + + def reduce_types(self, types: list[Type]) -> list[Type]: + """Reduce a list of types to remove subtypes and only keep the highest types + + Args: + types (list[Type]): the types to reduce + + Returns: + list[Type]: the reduced list of types + """ + + reduced: bool = True + keep: list[int] = list(range(len(types))) + while reduced: + reduced = False + for i, i1 in enumerate(keep): + type1: Type = types[i1] + for i2 in keep[i + 1 :]: + type2 = types[i2] + if self.is_subtype(type1, type2): + keep.remove(i1) + elif self.is_subtype(type2, type1): + keep.remove(i2) + else: + continue + reduced = True + break + return [types[i] for i in keep] From 3581b7600bb0a87edd89cb01088d3c753f0bfeed Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 9 Jun 2026 08:05:31 +0200 Subject: [PATCH 15/64] fix(checker): use reduce_types to infer return type --- midas/checker/python.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 84963bb..59f6881 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -204,13 +204,13 @@ class PythonTyper( inferred_return: Type = UnknownType() if not returned: env.return_types.append(UnitType()) - return_types: set[Type] = set(env.return_types) + return_types: list[Type] = self.types.reduce_types(env.return_types) if len(return_types) == 1: - inferred_return = list(return_types)[0] + inferred_return = return_types[0] elif len(return_types) > 1: self.reporter.error( stmt.location, - f"Mixed return types: {env.return_types}", + f"Mixed return types: {return_types}", ) returns: Type = UnknownType() From a78aee16395d9320afedfa3903c7eaaef0e8e4f9 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 9 Jun 2026 08:06:46 +0200 Subject: [PATCH 16/64] fix(resolver): define 
variable on assignment if a variable is not already defined when an assignment is visited, it is then defined in the current scope --- midas/checker/resolver.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 0b7d990..02fcbbc 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def __init__(self): self.locals: dict[p.Expr, int] = {} - self.scopes: list[dict[str, bool]] = [] + self.scopes: list[dict[str, bool]] = [{}] def resolve(self, *objects: p.Stmt | p.Expr) -> None: """Resolve the given statements or expressions""" @@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.locals[expr] = i return + def is_defined(self, name: str) -> bool: + for scope in self.scopes: + if name in scope: + return True + return False + def resolve_function(self, function: p.Function) -> None: """Resolve a function definition @@ -111,7 +117,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.resolve(stmt.value) for target in stmt.targets: match target: - case p.VariableExpr() | p.GetExpr(): + case p.VariableExpr(name=name): + if not self.is_defined(name): + self.declare(name) + self.define(name) + target.accept(self) + + case p.GetExpr(): target.accept(self) case _: raise Exception(f"Unsupported assignment to {target}") From 4715318913a9d5addb50b0589ad6d6d782f9fba5 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 9 Jun 2026 12:59:36 +0200 Subject: [PATCH 17/64] feat(types): add human-friendly string rep add `__str__` methods on type structures to improve readability of diagnostics --- .../04_complex_types.py | 3 ++ midas/checker/python.py | 17 +++--- midas/checker/types.py | 54 ++++++++++++++++++- .../checker/02_simple_operations.py.ref.json | 2 +- tests/cases/checker/03_functions.py.ref.json | 4 +- 5 files changed, 66 insertions(+), 14 
deletions(-) diff --git a/examples/01_simple_type_checking/04_complex_types.py b/examples/01_simple_type_checking/04_complex_types.py index f36ef52..63fd1e7 100644 --- a/examples/01_simple_type_checking/04_complex_types.py +++ b/examples/01_simple_type_checking/04_complex_types.py @@ -9,3 +9,6 @@ diff_y = p2.y - p1.y dist = diff_x + diff_y p2.x += cast(Meter, 1) +p2.y = True +p2.z = 3 +p2.x.a = 3 diff --git a/midas/checker/python.py b/midas/checker/python.py index 59f6881..c11ee22 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -273,7 +273,7 @@ class PythonTyper( if not self.is_subtype(value_type, var_type): self.reporter.error( location, - f"Cannot assign {value_type} to {name} of type {var_type}", + f"Cannot assign {value_type} to variable '{name}' of type {var_type}", ) def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): @@ -283,7 +283,7 @@ class PythonTyper( case ComplexType(properties=properties): if target.name not in properties: self.reporter.error( - target.location, f"Unknown property '{target.name} on {object}" + target.location, f"Unknown property '{object}.{target.name}'" ) return @@ -291,7 +291,7 @@ class PythonTyper( if not self.is_subtype(value_type, prop_type): self.reporter.error( location, - f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}", + f"Cannot assign {value_type} to property '{object}.{target.name}' of type {prop_type}", ) return @@ -301,7 +301,7 @@ class PythonTyper( case _: self.reporter.error( target.location, - f"Cannot assign {value_type} to unknown property '{target.name}' on {object}", + f"Cannot assign {value_type} to unknown property '{object}.{target.name}'", ) def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: @@ -365,6 +365,9 @@ class PythonTyper( if i == j: continue sig2: Operation.CallSignature = op2.signature + + # If op1 is not a full overload of op2 (i.e. 
operands of op1 are subtypes of op2's) + # ambiguity -> not best match if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype( sig1.right, sig2.right ): @@ -374,13 +377,9 @@ class PythonTyper( if best_match: return op1.result - overloads: list[str] = [ - f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}" - for op in valid_operations - ] self.reporter.error( expr.location, - f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}", + f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(map(str, valid_operations))}", ) return UnknownType() diff --git a/midas/checker/types.py b/midas/checker/types.py index 9081a95..41ad786 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -8,21 +8,29 @@ from typing import Optional class BaseType: name: str + def __str__(self) -> str: + return self.name + @dataclass(frozen=True, kw_only=True) class AliasType: name: str type: Type + def __str__(self) -> str: + return self.name + @dataclass(frozen=True, kw_only=True) class UnknownType: - pass + def __str__(self) -> str: + return "" @dataclass(frozen=True, kw_only=True) class UnitType: - pass + def __str__(self) -> str: + return "None" @dataclass(frozen=True, kw_only=True) @@ -33,6 +41,23 @@ class Function: kw_args: list[Argument] returns: Type + def __str__(self) -> str: + args: list[str] = [] + if len(self.pos_args) != 0: + args += list(map(str, self.pos_args)) + if len(self.args) + len(self.kw_args) != 0: + args.append("/") + + if len(self.args) != 0: + args += list(map(str, self.args)) + + if len(self.kw_args) != 0: + if len(args) != 0: + args.append("*") + args += list(map(str, self.kw_args)) + + return f"{self.name}({', '.join(args)}) -> {self.returns}" + @dataclass(frozen=True, kw_only=True) class Argument: pos: int @@ -40,29 +65,48 @@ class Function: type: Type required: bool + def __str__(self) -> str: + opt: str = "" 
if self.required else "?" + return f"{self.name}: {self.type}{opt}" + @dataclass(frozen=True, kw_only=True) class ComplexType: properties: dict[str, Type] + def __str__(self) -> str: + props: list[str] = [f"{name}: {type}" for name, type in self.properties.items()] + return f"{{{', '.join(props)}}}" + @dataclass(frozen=True, kw_only=True) class Operation: signature: CallSignature result: Type + def __str__(self) -> str: + return f"{self.signature} -> {self.result}" + @dataclass(frozen=True, kw_only=True) class CallSignature: left: Type method: str right: Type + def __str__(self) -> str: + return f"{self.method}({self.left}, {self.right})" + @dataclass(frozen=True, kw_only=True) class TypeVar: name: str bound: Optional[Type] + def __str__(self) -> str: + if self.bound is not None: + return f"{self.name} <: {self.bound}" + return self.name + @dataclass(frozen=True, kw_only=True) class GenericType: @@ -70,6 +114,9 @@ class GenericType: params: list[TypeVar] body: Type + def __str__(self) -> str: + return f"{self.name}[{', '.join(map(str, self.params))}]" + @dataclass(frozen=True, kw_only=True) class AppliedType: @@ -77,6 +124,9 @@ class AppliedType: args: list[Type] body: Type + def __str__(self) -> str: + return f"{self.name}[{', '.join(map(str, self.args))}]" + def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: def sub_argument(arg: Function.Argument): diff --git a/tests/cases/checker/02_simple_operations.py.ref.json b/tests/cases/checker/02_simple_operations.py.ref.json index 654af17..e3881e0 100644 --- a/tests/cases/checker/02_simple_operations.py.ref.json +++ b/tests/cases/checker/02_simple_operations.py.ref.json @@ -12,7 +12,7 @@ 13 ] }, - "message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')" + "message": "Cannot assign str to variable 'c' of type int" } ], "judgments": [ diff --git a/tests/cases/checker/03_functions.py.ref.json b/tests/cases/checker/03_functions.py.ref.json index cd0ce42..3442bca 100644 --- 
a/tests/cases/checker/03_functions.py.ref.json +++ b/tests/cases/checker/03_functions.py.ref.json @@ -236,7 +236,7 @@ 13 ] }, - "message": "Wrong type for argument 'a', expected BaseType(name='int'), got BaseType(name='str')" + "message": "Wrong type for argument 'a', expected int, got str" }, { "type": "Error", @@ -250,7 +250,7 @@ 25 ] }, - "message": "Wrong type for argument 'c', expected BaseType(name='str'), got BaseType(name='bool')" + "message": "Wrong type for argument 'c', expected str, got bool" } ], "judgments": [ From 380753ca7a9f3dc1ea8116a356fc526795529d7f Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 9 Jun 2026 15:30:45 +0200 Subject: [PATCH 18/64] refactor(types): extract TypeParams also rename generic type params to type args (when calling a generic) --- gen/gen.py | 15 ++++++- gen/midas.py | 21 ++++++---- midas/ast/midas.py | 17 ++++---- midas/ast/python.py | 1 + midas/checker/midas.py | 4 +- midas/checker/registry.py | 26 ++++++------ midas/cli/highlighter.py | 4 +- midas/parser/midas.py | 40 ++++++++++--------- .../01_simple_types.midas.ref.json | 12 +++--- tests/serializer/midas.py | 9 ++--- 10 files changed, 85 insertions(+), 64 deletions(-) diff --git a/gen/gen.py b/gen/gen.py index e78c872..50c9c9d 100644 --- a/gen/gen.py +++ b/gen/gen.py @@ -30,6 +30,7 @@ from __future__ import annotations T = TypeVar("T") +{preamble} {sections} """ @@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile( re.MULTILINE | re.DOTALL, ) +PREAMBLE_REGEX = re.compile( + r"^###>\s*Preamble\s*?\n(?P.*?)\n###<$", + re.MULTILINE | re.DOTALL, +) + def snake_case(text: str) -> str: return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_") @@ -88,13 +94,14 @@ def make_banner(text: str) -> str: def make_section(full_name: str, base: str, param: str, body: str) -> str: + print(f" Generating {full_name}") visitor_methods: list[str] = [] classes: list[str] = [] definitions: list[str] = body.strip("\n").split("\n\n\n") for cls in definitions: cls = 
cls.strip("\n") name: str = re.match("class (.*?):", cls).group(1) # type: ignore - print(f"Processing {name}") + print(f" Processing {name}") visitor_methods.append(make_visitor_method(name, param)) classes.append(make_class(name, cls, base)) @@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str: def generate(definitions_path: Path, out_path: Path): + print(f"Processing generating {out_path} from {definitions_path}") root_dir: Path = Path(__file__).parent.parent rel_path: Path = definitions_path.relative_to(root_dir) src: str = definitions_path.read_text() @@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path): if m := IMPORTS_REGEX.search(src): imports = m.group("body").strip("\n") + preamble: str = "" + if m := PREAMBLE_REGEX.search(src): + preamble = m.group("body") + for section_m in SECTION_REGEX.finditer(src): full_name: str = section_m.group("name") base: str = section_m.group("base") @@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path): gen_path=Path(__file__).relative_to(root_dir), ), imports=imports, + preamble=preamble, sections="\n\n\n".join(sections), ) out_path.write_text(result) diff --git a/gen/midas.py b/gen/midas.py index e1c304d..cca6f39 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -12,18 +12,23 @@ from midas.lexer.token import Token ###< +###> Preamble +@dataclass(frozen=True, kw_only=True) +class TypeParam: + location: Location + name: Token + bound: Optional[Type] + + +###< + + ###> Stmt | Statements class TypeStmt: name: Token - params: list[Param] + params: list[TypeParam] type: Type - @dataclass(frozen=True, kw_only=True) - class Param: - location: Location - name: Token - bound: Optional[Type] - class PropertyStmt: name: Token @@ -103,7 +108,7 @@ class NamedType: class GenericType: type: Type - params: list[Type] + args: list[Type] class ConstraintType: diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 335e5cf..4459b52 100644 --- 
a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -14,6 +14,13 @@ from midas.lexer.token import Token T = TypeVar("T") +@dataclass(frozen=True, kw_only=True) +class TypeParam: + location: Location + name: Token + bound: Optional[Type] + + ############## # Statements # ############## @@ -46,15 +53,9 @@ class Stmt(ABC): @dataclass(frozen=True) class TypeStmt(Stmt): name: Token - params: list[Param] + params: list[TypeParam] type: Type - @dataclass(frozen=True, kw_only=True) - class Param: - location: Location - name: Token - bound: Optional[Type] - def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_type_stmt(self) @@ -243,7 +244,7 @@ class NamedType(Type): @dataclass(frozen=True) class GenericType(Type): type: Type - params: list[Type] + args: list[Type] def accept(self, visitor: Type.Visitor[T]) -> T: return visitor.visit_generic_type(self) diff --git a/midas/ast/python.py b/midas/ast/python.py index 1aea8ed..350dbb1 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -14,6 +14,7 @@ from midas.ast.location import Location T = TypeVar("T") + #################### # Type annotations # #################### diff --git a/midas/checker/midas.py b/midas/checker/midas.py index c55123f..2cb4ab1 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -122,8 +122,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type def visit_generic_type(self, type: m.GenericType) -> Type: type_: Type = type.type.accept(self) - params: list[Type] = [param.accept(self) for param in type.params] - return self.types.apply_generic(type_, params) + args: list[Type] = [arg.accept(self) for arg in type.args] + return self.types.apply_generic(type_, args) def visit_constraint_type(self, type: m.ConstraintType) -> Type: type_: Type = type.type.accept(self) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 585e3af..455c565 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -250,34 
+250,34 @@ class TypesRegistry: return True - def apply_generic(self, type: Type, params: list[Type]) -> Type: + def apply_generic(self, type: Type, args: list[Type]) -> Type: match type: case AliasType(name=name, type=base): - return AliasType(name=name, type=self.apply_generic(base, params)) + return AliasType(name=name, type=self.apply_generic(base, args)) - case GenericType(name=name, params=type_vars, body=body): - n_params: int = len(params) + case GenericType(name=name, args=type_vars, body=body): + n_args: int = len(args) n_type_vars: int = len(type_vars) - if n_params < n_type_vars: + if n_args < n_type_vars: raise ValueError( - f"Missing type parameters, expected {n_type_vars} but only {n_params} provided" + f"Missing type arguments, expected {n_type_vars} but only {n_args} provided" ) - if n_params > n_type_vars: + if n_args > n_type_vars: raise ValueError( - f"Too many type parameters, expected {n_type_vars} but {n_params} provided" + f"Too many type arguments, expected {n_type_vars} but {n_args} provided" ) substitutions: dict[str, Type] = {} - for param, type_var in zip(params, type_vars): + for arg, type_var in zip(args, type_vars): if type_var.bound is not None and not self.is_subtype( - param, type_var.bound + arg, type_var.bound ): raise ValueError( - f"Type parameter {param} is not a subtype of {type_var.bound}" + f"Type argument {arg} is not a subtype of {type_var.bound}" ) - substitutions[type_var.name] = param + substitutions[type_var.name] = arg return AppliedType( name=name, - args=params, + args=args, body=substitute_typevars(body, substitutions), ) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index 0d6a018..af0fb4d 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -288,8 +288,8 @@ class MidasHighlighter( def visit_generic_type(self, type: m.GenericType) -> None: self.wrap(type, "generic-type") type.type.accept(self) - for param in type.params: - param.accept(self) + for arg in type.args: + 
arg.accept(self) def visit_constraint_type(self, type: m.ConstraintType) -> None: self.wrap(type, "constraint-type") diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 5d09b83..cd83b84 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -18,6 +18,7 @@ from midas.ast.midas import ( PropertyStmt, Stmt, Type, + TypeParam, TypeStmt, UnaryExpr, VariableExpr, @@ -108,9 +109,7 @@ class MidasParser(Parser): """ keyword: Token = self.previous() name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - params: list[TypeStmt.Param] = [] - if self.check(TokenType.LEFT_BRACKET): - params = self.type_stmt_params() + params: list[TypeParam] = self.type_params() self.consume(TokenType.EQUAL, "Expected '=' before type definition") @@ -123,16 +122,19 @@ class MidasParser(Parser): type=type, ) - def type_stmt_params(self) -> list[TypeStmt.Param]: - """Parse a generic template expression + def type_params(self) -> list[TypeParam]: + """Parse a list of type parameters - A template is written `[TypeExpr]` + Type parameters are a comma-separated list of type variables wrapped in brackets. 
+ Each type variable is either a simple variable, or a bounded variable written `S <: T` Returns: - TemplateExpr: the parsed template expression + list[TypeParam]: the list of type parameters, if any, or an empty list """ - self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") - params: list[TypeStmt.Param] = [] + if not self.match(TokenType.LEFT_BRACKET): + return [] + + params: list[TypeParam] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable") bound: Optional[Type] = None @@ -140,7 +142,7 @@ class MidasParser(Parser): self.consume(TokenType.COLON, "Expected ':' after '<'") bound = self.type_expr() params.append( - TypeStmt.Param( + TypeParam( location=name.location_to(self.previous()), name=name, bound=bound, @@ -148,7 +150,7 @@ class MidasParser(Parser): ) if not self.match(TokenType.COMMA): break - self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters") return params def type_expr(self) -> Type: @@ -187,23 +189,23 @@ class MidasParser(Parser): def generic_type(self) -> Type: type: Type = self.named_type() if self.check(TokenType.LEFT_BRACKET): - params: list[Type] = self.type_params() + args: list[Type] = self.type_args() return GenericType( location=Location.span(type.location, self.previous().get_location()), type=type, - params=params, + args=args, ) return type - def type_params(self) -> list[Type]: - params: list[Type] = [] - self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters") + def type_args(self) -> list[Type]: + args: list[Type] = [] + self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments") while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): - params.append(self.type_expr()) + args.append(self.type_expr()) if not self.match(TokenType.COMMA): break - 
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters") - return params + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments") + return args def named_type(self) -> Type: name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") diff --git a/tests/cases/midas-parser/01_simple_types.midas.ref.json b/tests/cases/midas-parser/01_simple_types.midas.ref.json index 55b4813..1d94718 100644 --- a/tests/cases/midas-parser/01_simple_types.midas.ref.json +++ b/tests/cases/midas-parser/01_simple_types.midas.ref.json @@ -2385,7 +2385,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "GeoLocation" @@ -2416,7 +2416,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Latitude" @@ -2433,7 +2433,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Longitude" @@ -2464,7 +2464,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Latitude" @@ -2494,7 +2494,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Longitude" @@ -2638,7 +2638,7 @@ "_type": "NamedType", "name": "Optional" }, - "params": [ + "args": [ { "_type": "ConstraintType", "type": { diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index 919dc66..947641e 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -17,6 +17,7 @@ from midas.ast.midas import ( PropertyStmt, Stmt, Type, + TypeParam, TypeStmt, UnaryExpr, VariableExpr, @@ -46,13 +47,11 @@ class MidasAstJsonSerializer( return { "_type": "TypeStmt", "name": stmt.name.lexeme, - "params": [ - self._serialize_type_stmt_template_param(param) for param in stmt.params - ], + "params": [self._serialize_type_param(param) for param in stmt.params], "type": stmt.type.accept(self), } - def 
_serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict: + def _serialize_type_param(self, param: TypeParam) -> dict: return { "name": param.name.lexeme, "bound": self._serialize_optional(param.bound), @@ -150,7 +149,7 @@ class MidasAstJsonSerializer( return { "_type": "GenericType", "type": type.type.accept(self), - "params": self._serialize_list(type.params), + "args": self._serialize_list(type.args), } def visit_constraint_type(self, type: ConstraintType) -> dict: From f8897dd075c6569c0eb96d13a2587112211963b9 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 9 Jun 2026 23:22:45 +0200 Subject: [PATCH 19/64] feat(types): add type params to extend statement --- gen/midas.py | 1 + midas/ast/midas.py | 1 + midas/ast/printer.py | 29 +++++++++++++++++------------ midas/checker/midas.py | 12 +++--------- midas/checker/registry.py | 2 +- midas/parser/midas.py | 11 +++++++++-- 6 files changed, 32 insertions(+), 24 deletions(-) diff --git a/gen/midas.py b/gen/midas.py index cca6f39..2184f86 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -36,6 +36,7 @@ class PropertyStmt: class ExtendStmt: + params: list[TypeParam] type: Type operations: list[OpStmt] diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 4459b52..d759079 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -71,6 +71,7 @@ class PropertyStmt(Stmt): @dataclass(frozen=True) class ExtendStmt(Stmt): + params: list[TypeParam] type: Type operations: list[OpStmt] diff --git a/midas/ast/printer.py b/midas/ast/printer.py index dc2e64c..82bd0b4 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -100,12 +100,12 @@ class MidasAstPrinter( self._idx = i if i == len(stmt.params) - 1: self._mark_last() - self._print_type_stmt_param(param) + self._print_type_param(param) self._write_line("type", last=True) with self._child_level(single=True): stmt.type.accept(self) - def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None: + def _print_type_param(self, param: 
m.TypeParam) -> None: self._write_line("Param") with self._child_level(): self._write_line(f'name: "{param.name.lexeme}"') @@ -122,6 +122,13 @@ class MidasAstPrinter( def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._write_line("ExtendStmt") with self._child_level(): + self._write_line("params") + with self._child_level(): + for i, param in enumerate(stmt.params): + self._idx = i + if i == len(stmt.params) - 1: + self._mark_last() + self._print_type_param(param) self._write_line("type") with self._child_level(single=True): stmt.type.accept(self) @@ -234,11 +241,11 @@ class MidasAstPrinter( self._write_line("type") with self._child_level(): type.type.accept(self) - self._write_line("params", last=True) + self._write_line("args", last=True) with self._child_level(): - for i, param in enumerate(type.params): + for i, param in enumerate(type.args): self._idx = i - if i == len(type.params) - 1: + if i == len(type.args) - 1: self._mark_last() param.accept(self) @@ -279,14 +286,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_type_stmt(self, stmt: m.TypeStmt) -> str: template: str = "" if len(stmt.params) != 0: - params: list[str] = [ - self._print_type_template_param(param) for param in stmt.params - ] + params: list[str] = [self._print_type_param(param) for param in stmt.params] template = f"[{', '.join(params)}]" res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}" return self.indented(res) - def _print_type_template_param(self, param: m.TypeStmt.Param) -> str: + def _print_type_param(self, param: m.TypeParam) -> str: res: str = param.name.lexeme if param.bound is not None: res += "<:" + param.bound.accept(self) @@ -358,9 +363,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_generic_type(self, type: m.GenericType) -> str: res: str = type.type.accept(self) - if len(type.params) != 0: - params: list[str] = [param.accept(self) for param in 
type.params] - res += f"[{', '.join(params)}]" + if len(type.args) != 0: + args: list[str] = [param.accept(self) for param in type.args] + res += f"[{', '.join(args)}]" return res def visit_constraint_type(self, type: m.ConstraintType) -> str: diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 2cb4ab1..a6d86a9 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -64,15 +64,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type stmt.accept(self) def visit_type_stmt(self, stmt: m.TypeStmt) -> None: - params: list[TypeVar] = [] - for param in stmt.params: - name: str = param.name.lexeme - bound: Optional[Type] = None - if param.bound is not None: - bound = param.bound.accept(self) - var = TypeVar(name=name, bound=bound) - self._local_variables[name] = var - params.append(var) + params: list[TypeVar] = self._resolve_type_params(stmt.params) + name: str = stmt.name.lexeme type: Type = stmt.type.accept(self) if len(params) != 0: @@ -85,6 +78,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... 
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: + self._resolve_type_params(stmt.params) base: Type = stmt.type.accept(self) for op in stmt.operations: right: Type = op.operand.accept(self) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 455c565..d5c432a 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -255,7 +255,7 @@ class TypesRegistry: case AliasType(name=name, type=base): return AliasType(name=name, type=self.apply_generic(base, args)) - case GenericType(name=name, args=type_vars, body=body): + case GenericType(name=name, params=type_vars, body=body): n_args: int = len(args) n_type_vars: int = len(type_vars) if n_args < n_type_vars: diff --git a/midas/parser/midas.py b/midas/parser/midas.py index cd83b84..35c7a97 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -383,12 +383,14 @@ class MidasParser(Parser): def extend_declaration(self) -> ExtendStmt: """Parse an extension definition - An extension is written `extend Type { operations }` + An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }` Returns: ExtendStmt: the parsed extension statement """ keyword: Token = self.previous() + params: list[TypeParam] = self.type_params() + type: Type = self.type_expr() self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") operations: list[OpStmt] = [] @@ -396,7 +398,12 @@ class MidasParser(Parser): operations.append(self.op_declaration()) self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") location: Location = keyword.location_to(self.previous()) - return ExtendStmt(location=location, type=type, operations=operations) + return ExtendStmt( + location=location, + params=params, + type=type, + operations=operations, + ) def op_declaration(self) -> OpStmt: """Parse an operation definition From 9fde115016dc3e6e6f04dec8eb665f537f08cb6f Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 9 Jun 2026 23:48:06 +0200 Subject: [PATCH 20/64] 
feat: add function type to midas syntax --- gen/midas.py | 13 +++++++ midas/ast/midas.py | 20 ++++++++++ midas/ast/printer.py | 58 ++++++++++++++++++++++++++++ midas/checker/midas.py | 38 +++++++++++++++++++ midas/cli/highlighter.py | 6 +++ midas/lexer/midas.py | 6 ++- midas/lexer/token.py | 4 +- midas/parser/midas.py | 80 +++++++++++++++++++++++++++++++++------ tests/serializer/midas.py | 16 ++++++++ 9 files changed, 226 insertions(+), 15 deletions(-) diff --git a/gen/midas.py b/gen/midas.py index 2184f86..16a4dd8 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -121,4 +121,17 @@ class ComplexType: properties: list[PropertyStmt] +class FunctionType: + pos_args: list[Argument] + kw_args: list[Argument] + returns: Type + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[Token] + type: Type + required: bool + + ###< diff --git a/midas/ast/midas.py b/midas/ast/midas.py index d759079..00e71c8 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -233,6 +233,9 @@ class Type(ABC): @abstractmethod def visit_complex_type(self, type: ComplexType) -> T: ... + @abstractmethod + def visit_function_type(self, type: FunctionType) -> T: ... 
+ @dataclass(frozen=True) class NamedType(Type): @@ -266,3 +269,20 @@ class ComplexType(Type): def accept(self, visitor: Type.Visitor[T]) -> T: return visitor.visit_complex_type(self) + + +@dataclass(frozen=True) +class FunctionType(Type): + pos_args: list[Argument] + kw_args: list[Argument] + returns: Type + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[Token] + type: Type + required: bool + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_function_type(self) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 82bd0b4..5d109ef 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -270,6 +270,41 @@ class MidasAstPrinter( self._mark_last() prop.accept(self) + def visit_function_type(self, type: m.FunctionType) -> None: + self._write_line("FunctionType") + with self._child_level(): + self._write_line("pos_args") + with self._child_level(): + for i, arg in enumerate(type.pos_args): + self._idx = i + if i == len(type.pos_args) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("kw_args") + with self._child_level(): + for i, arg in enumerate(type.kw_args): + self._idx = i + if i == len(type.kw_args) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("returns", last=True) + with self._child_level(single=True): + type.returns.accept(self) + + def _print_function_arg(self, arg: m.FunctionType.Argument) -> None: + self._write_line("Argument") + with self._child_level(): + name: str = "None" + if arg.name is not None: + name = f'"{arg.name.lexeme}"' + self._write_line(f"name: {name}") + self._write_line("type") + with self._child_level(single=True): + arg.type.accept(self) + self._write_line(f"required: {arg.required}", last=True) + class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]): def __init__(self, indent: int = 4): @@ -383,6 +418,29 @@ class MidasPrinter(m.Expr.Visitor[str], 
m.Stmt.Visitor[str], m.Type.Visitor[str] res += self.indented("}") return res + def visit_function_type(self, type: m.FunctionType) -> str: + pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + args: list[str] = pos_args + + if len(pos_args) != 0: + args.append("/") + if len(kw_args) != 0: + args.append("*") + args += kw_args + + return f"({', '.join(args)}) -> {type.returns.accept(self)}" + + def _print_arg(self, arg: m.FunctionType.Argument) -> str: + res: str = "" + if arg.name is not None: + res += arg.name.lexeme + res += ": " + res += arg.type.accept(self) + if not arg.required: + res += "?" + return res + class PythonAstPrinter( AstPrinter, diff --git a/midas/checker/midas.py b/midas/checker/midas.py index a6d86a9..decb40c 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -8,6 +8,7 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( AliasType, ComplexType, + Function, GenericType, Type, TypeVar, @@ -131,3 +132,40 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type prop.name.lexeme: prop.type.accept(self) for prop in type.properties } ) + + def visit_function_type(self, type: m.FunctionType) -> Type: + return Function( + name="", + pos_args=[ + Function.Argument( + pos=i, + name=arg.name.lexeme if arg.name is not None else str(i), + type=arg.type.accept(self), + required=arg.required, + ) + for i, arg in enumerate(type.pos_args) + ], + args=[], + kw_args=[ + Function.Argument( + pos=i, + name=arg.name.lexeme if arg.name is not None else str(i), + type=arg.type.accept(self), + required=arg.required, + ) + for i, arg in enumerate(type.kw_args, start=len(type.pos_args)) + ], + returns=type.returns.accept(self), + ) + + def _resolve_type_params(self, params: list[m.TypeParam]): + vars: list[TypeVar] = [] + for param in params: + name: str = param.name.lexeme + bound: Optional[Type] 
= None + if param.bound is not None: + bound = param.bound.accept(self) + var = TypeVar(name=name, bound=bound) + self._local_variables[name] = var + vars.append(var) + return vars diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index af0fb4d..16fdf94 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -301,6 +301,12 @@ class MidasHighlighter( for prop in type.properties: prop.accept(self) + def visit_function_type(self, type: m.FunctionType) -> None: + self.wrap(type, "function") + for arg in type.pos_args + type.kw_args: + arg.type.accept(self) + type.returns.accept(self) + class DiagnosticsHighlighter(Highlighter): EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css" diff --git a/midas/lexer/midas.py b/midas/lexer/midas.py index 124ea09..c3246fc 100644 --- a/midas/lexer/midas.py +++ b/midas/lexer/midas.py @@ -50,12 +50,14 @@ class MidasLexer(Lexer): # self.add_token(TokenType.PLUS) case "-": self.add_token(TokenType.MINUS) - # case "*": - # self.add_token(TokenType.STAR) + case "*": + self.add_token(TokenType.STAR) case "/" if self.match("/"): self.scan_comment() case "/" if self.match("*"): self.scan_comment_multiline() + case "/": + self.add_token(TokenType.SLASH) case "\n": self.add_token(TokenType.NEWLINE) case " " | "\r" | "\t": diff --git a/midas/lexer/token.py b/midas/lexer/token.py index f08964a..74bf7b0 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -27,8 +27,8 @@ class TokenType(Enum): # Operators # PLUS = auto() MINUS = auto() - # STAR = auto() - # SLASH = auto() + STAR = auto() + SLASH = auto() GREATER = auto() GREATER_EQUAL = auto() LESS = auto() diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 35c7a97..ce5d3f9 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -7,6 +7,7 @@ from midas.ast.midas import ( ConstraintType, Expr, ExtendStmt, + FunctionType, GenericType, GetExpr, GroupingExpr, @@ -24,7 +25,7 @@ from midas.ast.midas import ( 
VariableExpr, WildcardExpr, ) -from midas.lexer.token import Token, TokenType +from midas.lexer.token import KEYWORDS, Token, TokenType from midas.parser.base import Parser from midas.parser.errors import ParsingError @@ -108,7 +109,7 @@ class MidasParser(Parser): TypeStmt: the parsed type declaration statement """ keyword: Token = self.previous() - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") + name: Token = self.consume_identifier("Expected type name") params: list[TypeParam] = self.type_params() self.consume(TokenType.EQUAL, "Expected '=' before type definition") @@ -136,7 +137,7 @@ class MidasParser(Parser): params: list[TypeParam] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable") + name: Token = self.consume_identifier("Expected type variable") bound: Optional[Type] = None if self.match(TokenType.LESS): self.consume(TokenType.COLON, "Expected ':' after '<'") @@ -208,7 +209,7 @@ class MidasParser(Parser): return args def named_type(self) -> Type: - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") + name: Token = self.consume_identifier("Expected type name") return NamedType( location=name.get_location(), name=name, @@ -324,9 +325,7 @@ class MidasParser(Parser): """ expr: Expr = self.primary() while self.match(TokenType.DOT): - name: Token = self.consume( - TokenType.IDENTIFIER, "Expected property name after '.'" - ) + name: Token = self.consume_identifier("Expected property name after '.'") location: Location = Location.span(expr.location, name.get_location()) expr = GetExpr(location=location, expr=expr, name=name) return expr @@ -350,7 +349,7 @@ class MidasParser(Parser): if self.match(TokenType.NUMBER): return LiteralExpr(location=token.get_location(), value=token.value) - if self.match(TokenType.IDENTIFIER): + if self.match_identifier(): return VariableExpr(location=token.get_location(), name=token) if 
self.match(TokenType.UNDERSCORE): @@ -363,6 +362,20 @@ class MidasParser(Parser): raise self.error(self.peek(), "Expected expression") + def consume_identifier(self, message: str = "Expected identifier") -> Token: + if not self.match_identifier(): + raise self.error(self.peek(), message) + return self.previous() + + def match_identifier(self) -> bool: + return self.match(TokenType.IDENTIFIER, *KEYWORDS.values()) + + def check_identifier(self) -> bool: + for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]: + if self.check(tt): + return True + return False + def property_stmt(self) -> PropertyStmt: """Parse a property statement @@ -371,7 +384,7 @@ class MidasParser(Parser): Returns: PropertyStmt: the parsed property statement """ - name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name") + name: Token = self.consume_identifier("Expected property name") self.consume(TokenType.COLON, "Expected ':' after property name") type: Type = self.type_expr() return PropertyStmt( @@ -439,9 +452,9 @@ class MidasParser(Parser): PredicateStmt: the parsed predicate declaration statement """ keyword: Token = self.previous() - name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name") + name: Token = self.consume_identifier("Expected predicate name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") - subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") + subject: Token = self.consume_identifier("Expected subject name") self.consume(TokenType.COLON, "Expected ':' after subject name") type: Type = self.type_expr() self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") @@ -454,3 +467,48 @@ class MidasParser(Parser): type=type, condition=condition, ) + + def function(self) -> FunctionType: + l_paren: Token = self.consume( + TokenType.LEFT_PAREN, "Expected '(' before function parameters" + ) + pos_args: list[FunctionType.Argument] = [] + kw_args: list[FunctionType.Argument] = [] + + 
positional: bool = True + while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN): + if positional and ( + self.match(TokenType.STAR) or self.match(TokenType.SLASH) + ): + positional = False + else: + name: Optional[Token] = None + if self.check_identifier() and self.check_next(TokenType.COLON): + name = self.advance() + self.advance() + type: Type = self.type_expr() + required: bool = self.match(TokenType.QMARK) + arg = FunctionType.Argument( + location=None, + name=name, + type=type, + required=required, + ) + if positional: + pos_args.append(arg) + else: + kw_args.append(arg) + + if not self.match(TokenType.COMMA): + break + self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") + + self.consume(TokenType.ARROW, "Expected '->' before result type") + result: Type = self.type_expr() + + return FunctionType( + location=l_paren.location_to(self.previous()), + pos_args=pos_args, + kw_args=kw_args, + returns=result, + ) diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index 947641e..2a5daf5 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -6,6 +6,7 @@ from midas.ast.midas import ( ConstraintType, Expr, ExtendStmt, + FunctionType, GenericType, GetExpr, GroupingExpr, @@ -164,3 +165,18 @@ class MidasAstJsonSerializer( "_type": "ComplexType", "properties": self._serialize_list(type.properties), } + + def visit_function_type(self, type: FunctionType) -> dict: + return { + "_type": "FunctionType", + "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], + "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args], + "returns": type.returns.accept(self), + } + + def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict: + return { + "name": arg.name, + "type": arg.type.accept(self), + "required": arg.required, + } From 3d5f97a0f47b95e5408056967909d070abd16621 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 11 Jun 2026 13:42:19 +0200 Subject: [PATCH 21/64] 
feat(parser): add extension type and rename properties --- gen/midas.py | 9 +++++++-- midas/ast/midas.py | 20 ++++++++++++++++---- midas/ast/printer.py | 35 ++++++++++++++++++++++++----------- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/gen/midas.py b/gen/midas.py index 16a4dd8..4141217 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -30,7 +30,7 @@ class TypeStmt: type: Type -class PropertyStmt: +class MemberStmt: name: Token type: Type @@ -118,7 +118,12 @@ class ConstraintType: class ComplexType: - properties: list[PropertyStmt] + members: list[MemberStmt] + + +class ExtensionType: + base: Type + extension: ComplexType class FunctionType: diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 00e71c8..36d959b 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -38,7 +38,7 @@ class Stmt(ABC): def visit_type_stmt(self, stmt: TypeStmt) -> T: ... @abstractmethod - def visit_property_stmt(self, stmt: PropertyStmt) -> T: ... + def visit_member_stmt(self, stmt: MemberStmt) -> T: ... @abstractmethod def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ... @@ -61,12 +61,12 @@ class TypeStmt(Stmt): @dataclass(frozen=True) -class PropertyStmt(Stmt): +class MemberStmt(Stmt): name: Token type: Type def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_property_stmt(self) + return visitor.visit_member_stmt(self) @dataclass(frozen=True) @@ -233,6 +233,9 @@ class Type(ABC): @abstractmethod def visit_complex_type(self, type: ComplexType) -> T: ... + @abstractmethod + def visit_extension_type(self, type: ExtensionType) -> T: ... + @abstractmethod def visit_function_type(self, type: FunctionType) -> T: ... 
@@ -265,12 +268,21 @@ class ConstraintType(Type): @dataclass(frozen=True) class ComplexType(Type): - properties: list[PropertyStmt] + members: list[MemberStmt] def accept(self, visitor: Type.Visitor[T]) -> T: return visitor.visit_complex_type(self) +@dataclass(frozen=True) +class ExtensionType(Type): + base: Type + extension: ComplexType + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_extension_type(self) + + @dataclass(frozen=True) class FunctionType(Type): pos_args: list[Argument] diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 5d109ef..2a5eec3 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -111,8 +111,8 @@ class MidasAstPrinter( self._write_line(f'name: "{param.name.lexeme}"') self._write_optional_child("bound", param.bound, last=True) - def visit_property_stmt(self, stmt: m.PropertyStmt): - self._write_line("PropertyStmt") + def visit_member_stmt(self, stmt: m.MemberStmt): + self._write_line("MemberStmt") with self._child_level(): self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line("type", last=True) @@ -262,13 +262,23 @@ class MidasAstPrinter( def visit_complex_type(self, type: m.ComplexType) -> None: self._write_line("ComplexType") with self._child_level(): - self._write_line("properties", last=True) + self._write_line("members", last=True) with self._child_level(): - for i, prop in enumerate(type.properties): + for i, member in enumerate(type.members): self._idx = i - if i == len(type.properties) - 1: + if i == len(type.members) - 1: self._mark_last() - prop.accept(self) + member.accept(self) + + def visit_extension_type(self, type: m.ExtensionType) -> None: + self._write_line("ExtensionType") + with self._child_level(): + self._write_line("base") + with self._child_level(single=True): + type.base.accept(self) + self._write_line("extension", last=True) + with self._child_level(single=True): + type.extension.accept(self) def visit_function_type(self, type: m.FunctionType) -> None: 
self._write_line("FunctionType") @@ -332,7 +342,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] res += "<:" + param.bound.accept(self) return res - def visit_property_stmt(self, stmt: m.PropertyStmt): + def visit_member_stmt(self, stmt: m.MemberStmt): res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" return self.indented(res) @@ -411,16 +421,19 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_complex_type(self, type: m.ComplexType) -> str: res: str = "{\n" self.level += 1 - for prop in type.properties: - res += prop.accept(self) + for member in type.members: + res += member.accept(self) res += "\n" self.level -= 1 res += self.indented("}") return res + def visit_extension_type(self, type: m.ExtensionType) -> str: + return f"{type.base.accept(self)} & {type.extension.accept(self)}" + def visit_function_type(self, type: m.FunctionType) -> str: pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] - kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args] args: list[str] = pos_args if len(pos_args) != 0: @@ -429,7 +442,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] args.append("*") args += kw_args - return f"({', '.join(args)}) -> {type.returns.accept(self)}" + return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}" def _print_arg(self, arg: m.FunctionType.Argument) -> str: res: str = "" From 900be47d34721deee1f5c335d330e6b444e49017 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 11 Jun 2026 13:49:47 +0200 Subject: [PATCH 22/64] feat(parser): add new ast nodes to parser --- midas/lexer/token.py | 2 ++ midas/parser/midas.py | 55 +++++++++++++++++++++++++++---------------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/midas/lexer/token.py b/midas/lexer/token.py index 74bf7b0..60b6e47 100644 --- 
a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -50,6 +50,7 @@ class TokenType(Enum): PREDICATE = auto() EXTEND = auto() WHERE = auto() + FUNC = auto() # Misc COMMENT = auto() @@ -67,6 +68,7 @@ KEYWORDS: dict[str, TokenType] = { "true": TokenType.TRUE, "false": TokenType.FALSE, "none": TokenType.NONE, + "fn": TokenType.FUNC, } diff --git a/midas/parser/midas.py b/midas/parser/midas.py index ce5d3f9..0d0cbde 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -7,16 +7,17 @@ from midas.ast.midas import ( ConstraintType, Expr, ExtendStmt, + ExtensionType, FunctionType, GenericType, GetExpr, GroupingExpr, LiteralExpr, LogicalExpr, + MemberStmt, NamedType, OpStmt, PredicateStmt, - PropertyStmt, Stmt, Type, TypeParam, @@ -163,7 +164,19 @@ class MidasParser(Parser): Returns: TypeExpr: the parsed type expression """ - return self.constraint_type() + base: Type + if self.match(TokenType.FUNC): + base = self.function() + else: + base = self.constraint_type() + if self.match(TokenType.AND): + extension: ComplexType = self.complex_type() + return ExtensionType( + location=Location.span(base.location, extension.location), + base=base, + extension=extension, + ) + return base def constraint_type(self) -> Type: type: Type = self.base_type() @@ -215,30 +228,32 @@ class MidasParser(Parser): name=name, ) - def complex_type(self) -> Type: + def complex_type(self) -> ComplexType: """Parse a type definition body A type definition body is a set of whitespace-separated property statements enclosed in curly braces Returns: - list[PropertyStmt]: the parsed type properties + ComplexType: the parsed complex type """ left: Token = self.consume( TokenType.LEFT_BRACE, "Expected '{' to start type body" ) - properties: list[PropertyStmt] = [] + members: list[MemberStmt] = [] + # TODO: add keyword to differentiate properties and methods, + # and allow multiple methods with the same name but not properties names: set[str] = set() while not self.check(TokenType.RIGHT_BRACE) and 
not self.is_at_end(): - prop: PropertyStmt = self.property_stmt() - if prop.name.lexeme in names: - raise self.error(prop.name, "Duplicate property") - names.add(prop.name.lexeme) - properties.append(prop) + member: MemberStmt = self.member_stmt() + # if member.name.lexeme in names: + # raise self.error(member.name, "Duplicate property") + # names.add(member.name.lexeme) + members.append(member) right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body") return ComplexType( location=left.location_to(right), - properties=properties, + members=members, ) def constraint(self) -> Expr: @@ -376,18 +391,18 @@ class MidasParser(Parser): return True return False - def property_stmt(self) -> PropertyStmt: - """Parse a property statement + def member_stmt(self) -> MemberStmt: + """Parse a member statement - A type property statement is written `name: Type` or `name: Type where Condition` + A type member statement is written `name: Type` Returns: - PropertyStmt: the parsed property statement + MemberStmt: the parsed member statement """ - name: Token = self.consume_identifier("Expected property name") - self.consume(TokenType.COLON, "Expected ':' after property name") + name: Token = self.consume_identifier("Expected member name") + self.consume(TokenType.COLON, "Expected ':' after member name") type: Type = self.type_expr() - return PropertyStmt( + return MemberStmt( location=name.location_to(self.previous()), name=name, type=type, @@ -487,12 +502,12 @@ class MidasParser(Parser): name = self.advance() self.advance() type: Type = self.type_expr() - required: bool = self.match(TokenType.QMARK) + optional: bool = self.match(TokenType.QMARK) arg = FunctionType.Argument( location=None, name=name, type=type, - required=required, + required=not optional, ) if positional: pos_args.append(arg) From d9100d83005c7a9f64d5b6c9705a35e07458cb0e Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 11 Jun 2026 17:12:50 +0200 Subject: [PATCH 23/64] feat(checker): adapt typers 
to members and extension type --- midas/checker/midas.py | 16 +++++--- midas/checker/python.py | 81 +++++++++++++++++++++++++++-------------- midas/checker/types.py | 38 ++++++++++++++----- 3 files changed, 92 insertions(+), 43 deletions(-) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index decb40c..6a528a0 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -8,6 +8,7 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( AliasType, ComplexType, + ExtensionType, Function, GenericType, Type, @@ -76,7 +77,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self.types.define_type(name, type) self._local_variables.clear() - def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... + def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ... def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._resolve_type_params(stmt.params) @@ -126,16 +127,21 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type # TODO return UnknownType() - def visit_complex_type(self, type: m.ComplexType) -> Type: + def visit_complex_type(self, type: m.ComplexType) -> ComplexType: return ComplexType( - properties={ - prop.name.lexeme: prop.type.accept(self) for prop in type.properties + members={ + member.name.lexeme: member.type.accept(self) for member in type.members } ) + def visit_extension_type(self, type: m.ExtensionType) -> Type: + return ExtensionType( + base=type.base.accept(self), + extension=self.visit_complex_type(type.extension), + ) + def visit_function_type(self, type: m.FunctionType) -> Type: return Function( - name="", pos_args=[ Function.Argument( pos=i, diff --git a/midas/checker/python.py b/midas/checker/python.py index c11ee22..3d44975 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -12,6 +12,7 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from 
midas.checker.types import ( ComplexType, + ExtensionType, Function, Operation, Type, @@ -192,7 +193,6 @@ class PythonTyper( returns_hint = stmt.returns.accept(self) # Early define to handle simple fully-typed recursion inside_function: Function = Function( - name=stmt.name, pos_args=pos_args, args=args, kw_args=kw_args, @@ -227,7 +227,6 @@ class PythonTyper( # TODO: handle *args and **kwargs sinks function: Function = Function( - name=stmt.name, pos_args=pos_args, args=args, kw_args=kw_args, @@ -250,8 +249,9 @@ class PythonTyper( case p.VariableExpr(): self._assign_var(location, target, value_type) - case p.GetExpr(): - self._assign_attr(location, target, value_type) + case p.GetExpr(object=object, name=name): + object_type: Type = self.type_of(object) + self._assign_attr(location, object_type, name, value_type) case _: if not isinstance(target, p.VariableExpr): @@ -276,32 +276,43 @@ class PythonTyper( f"Cannot assign {value_type} to variable '{name}' of type {var_type}", ) - def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): - object: Type = self.type_of(target.object) + def _assign_attr( + self, location: Location, object: Type, name: str, value_type: Type + ): + # TODO: improve recursion to have better error messages base_object: Type = unfold_type(object) match base_object: - case ComplexType(properties=properties): - if target.name not in properties: + case ComplexType(members=members): + if name not in members: + self.reporter.error(location, f"Unknown member '{object}.{name}'") + return + + member_type: Type = members[name] + if not self.is_subtype(value_type, member_type): self.reporter.error( - target.location, f"Unknown property '{object}.{target.name}'" + location, + f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}", ) return - prop_type: Type = properties[target.name] - if not self.is_subtype(value_type, prop_type): - self.reporter.error( - location, - f"Cannot assign {value_type} to property 
'{object}.{target.name}' of type {prop_type}", - ) - return + case ExtensionType(base=base, extension=ComplexType(members=members)): + if name in members: + member_type: Type = members[name] + if not self.is_subtype(value_type, member_type): + self.reporter.error( + location, + f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}", + ) + return + return self._assign_attr(location, base, name, value_type) case UnknownType(): pass case _: self.reporter.error( - target.location, - f"Cannot assign {value_type} to unknown property '{object}.{target.name}'", + location, + f"Cannot assign {value_type} to unknown property '{object}.{name}'", ) def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: @@ -422,23 +433,37 @@ class PythonTyper( def visit_get_expr(self, expr: p.GetExpr) -> Type: object: Type = self.type_of(expr.object) + member: Optional[Type] = self._get_member(object, expr.name) + if member is None: + self.reporter.error( + expr.location, f"Unknown property '{expr.name}' on {object}" + ) + return UnknownType() + self.logger.debug(f"Property '{expr.name}' on {object} has type {member}") + return member + + def _get_member(self, object: Type, name: str) -> Optional[Type]: base_object: Type = unfold_type(object) match base_object: - case ComplexType(properties=properties): - if expr.name not in properties: - self.reporter.error( - expr.location, f"Unknown property '{expr.name} on {object}" - ) - return UnknownType() - return properties[expr.name] + case ComplexType(members=members): + if name in members: + return members[name] + self.logger.debug(f"No property '{name}' in {base_object}") + return None + + case ExtensionType(base=base, extension=ComplexType(members=members)): + if name in members: + return members[name] + self.logger.debug( + f"No property '{name}' on {base_object}, looking up in base" + ) + return self._get_member(base, name) case UnknownType(): return UnknownType() case _: - self.reporter.error( - expr.location, 
f"Cannot get property '{expr.name}' on {object}" - ) + self.logger.debug(f"Can't get property on {base_object}") return UnknownType() def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: diff --git a/midas/checker/types.py b/midas/checker/types.py index 41ad786..9057e4c 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -35,7 +35,6 @@ class UnitType: @dataclass(frozen=True, kw_only=True) class Function: - name: str pos_args: list[Argument] args: list[Argument] kw_args: list[Argument] @@ -56,7 +55,7 @@ class Function: args.append("*") args += list(map(str, self.kw_args)) - return f"{self.name}({', '.join(args)}) -> {self.returns}" + return f"({', '.join(args)}) -> {self.returns}" @dataclass(frozen=True, kw_only=True) class Argument: @@ -72,13 +71,22 @@ class Function: @dataclass(frozen=True, kw_only=True) class ComplexType: - properties: dict[str, Type] + members: dict[str, Type] def __str__(self) -> str: - props: list[str] = [f"{name}: {type}" for name, type in self.properties.items()] + props: list[str] = [f"{name}: {type}" for name, type in self.members.items()] return f"{{{', '.join(props)}}}" +@dataclass(frozen=True, kw_only=True) +class ExtensionType: + base: Type + extension: ComplexType + + def __str__(self) -> str: + return f"{self.base} & {self.extension}" + + @dataclass(frozen=True, kw_only=True) class Operation: signature: CallSignature @@ -145,26 +153,35 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: return AliasType(name=name, type=substitute_typevars(type2, substitutions)) case Function( - name=name, pos_args=pos_args, args=args, kw_args=kw_args, returns=returns, ): return Function( - name=name, pos_args=list(map(sub_argument, pos_args)), args=list(map(sub_argument, args)), kw_args=list(map(sub_argument, kw_args)), returns=substitute_typevars(returns, substitutions), ) - case ComplexType(properties=properties): - properties2: dict[str, Type] = { + case ComplexType(members=members): + members2: 
dict[str, Type] = { name: substitute_typevars(prop, substitutions) - for name, prop in properties.items() + for name, prop in members.items() } - return ComplexType(properties=properties2) + return ComplexType(members=members2) + + case ExtensionType(base=base, extension=ComplexType(members=members)): + return ExtensionType( + base=substitute_typevars(base, substitutions), + extension=ComplexType( + members={ + name: substitute_typevars(prop, substitutions) + for name, prop in members.items() + } + ), + ) case TypeVar(name=name): if name in substitutions: @@ -193,6 +210,7 @@ Type = ( | UnitType | Function | ComplexType + | ExtensionType | TypeVar | GenericType | AppliedType From ae0bd75f3b0f758b8fbc6aa92c2e9e2c0321b923 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 11 Jun 2026 17:15:28 +0200 Subject: [PATCH 24/64] fix(checker): improve error for recursive type ref --- midas/checker/midas.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 6a528a0..8c0fede 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -30,6 +30,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self.types: TypesRegistry = types self._local_variables: dict[str, TypeVar] = {} + self._current_name: Optional[str] = None + define_builtins(self.types) def process(self, source: str, path: Optional[str]): @@ -66,9 +68,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type stmt.accept(self) def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + name: str = stmt.name.lexeme + self._current_name = name params: list[TypeVar] = self._resolve_type_params(stmt.params) - name: str = stmt.name.lexeme type: Type = stmt.type.accept(self) if len(params) != 0: type = GenericType(name=name, params=params, body=type) @@ -76,6 +79,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type type = 
AliasType(name=name, type=type) self.types.define_type(name, type) self._local_variables.clear() + self._current_name = None def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ... @@ -114,12 +118,24 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... def visit_named_type(self, type: m.NamedType) -> Type: - return self.get_type(type.name.lexeme) + name: str = type.name.lexeme + try: + return self.get_type(name) + except NameError: + msg: str = f"Undefined type {name}" + if self._current_name == name: + msg += ". Recursive types are not supported, use an extend block" + self.reporter.error(type.name.get_location(), msg) + return UnknownType() def visit_generic_type(self, type: m.GenericType) -> Type: type_: Type = type.type.accept(self) args: list[Type] = [arg.accept(self) for arg in type.args] - return self.types.apply_generic(type_, args) + try: + return self.types.apply_generic(type_, args) + except Exception as e: + self.reporter.error(type.location, f"Cannot apply generic type: {e}") + return UnknownType() def visit_constraint_type(self, type: m.ConstraintType) -> Type: type_: Type = type.type.accept(self) From efea1b29e7abb573cb82d84ef468610ead587b32 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 14:43:27 +0200 Subject: [PATCH 25/64] fix(cli): show diagnostics from different files --- midas/cli/main.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/midas/cli/main.py b/midas/cli/main.py index cafeeaf..8b4d722 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -97,15 +97,27 @@ def compile( ): logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN) source: str = file.read() + source_path: Path = Path(file.name).resolve() checker = TypeChecker() - for path in types: - checker.import_midas(Path(path.name).resolve()) + for types_file in types: + 
checker.import_midas(Path(types_file.name).resolve()) - checker.type_check_source(source, str(Path(file.name).resolve())) + checker.type_check_source(source, str(source_path)) diagnostics: list[Diagnostic] = checker.diagnostics lines: list[str] = source.split("\n") + files: dict[Optional[str], list[str]] = {None: []} + for diagnostic in diagnostics: + filename: Optional[str] = diagnostic.file_path + if filename is not None and filename not in files: + path: Path = Path(filename) + if path.exists() and path.is_file(): + files[filename] = path.read_text().split("\n") + else: + files[filename] = [] + + lines: list[str] = files[filename] print_diagnostic(lines, diagnostic) if verbose: From 650f60e70c78fb50bd1ddd19a29354d4d8340601 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 14:44:02 +0200 Subject: [PATCH 26/64] feat(cli): add option to show type judgements --- midas/cli/main.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/midas/cli/main.py b/midas/cli/main.py index 8b4d722..af95abd 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -88,11 +88,13 @@ def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4): @click.option("-l", "--highlight", type=click.File("w")) @click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-v", "--verbose", is_flag=True) +@click.option("-j", "--show-judgements", is_flag=True) @click.argument("file", type=click.File("r")) def compile( highlight: Optional[TextIO], types: tuple[TextIO], verbose: bool, + show_judgements: bool, file: TextIO, ): logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN) @@ -104,10 +106,22 @@ def compile( checker.import_midas(Path(types_file.name).resolve()) checker.type_check_source(source, str(source_path)) - diagnostics: list[Diagnostic] = checker.diagnostics + diagnostics: list[Diagnostic] = checker.diagnostics.copy() lines: list[str] = source.split("\n") files: dict[Optional[str], 
list[str]] = {None: []} + if show_judgements: + for expr, type in checker.python_typer.judgements: + print(f"Judged that {expr} at {expr.location} is of type {type}") + diagnostics.append( + Diagnostic( + file_path=str(source_path), + location=expr.location, + type=DiagnosticType.INFO, + message=f"Type: {type}", + ) + ) + for diagnostic in diagnostics: filename: Optional[str] = diagnostic.file_path if filename is not None and filename not in files: From 42284704de6124d9e54db36d33088f309da328a7 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 16:41:33 +0200 Subject: [PATCH 27/64] feat(parser): accept props and methods in extend --- gen/midas.py | 11 +++++++++-- midas/ast/midas.py | 11 +++++++++-- midas/ast/printer.py | 36 +++++++++++++++++++++++++----------- midas/lexer/token.py | 4 ++++ midas/parser/midas.py | 23 +++++++++++++++++------ 5 files changed, 64 insertions(+), 21 deletions(-) diff --git a/gen/midas.py b/gen/midas.py index 4141217..72813d4 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -4,6 +4,7 @@ ###> Imports from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum, auto from typing import Any, Generic, Optional, TypeVar from midas.ast.location import Location @@ -20,6 +21,11 @@ class TypeParam: bound: Optional[Type] +class MemberKind(Enum): + PROPERTY = auto() + METHOD = auto() + + ###< @@ -33,12 +39,13 @@ class TypeStmt: class MemberStmt: name: Token type: Type + kind: MemberKind class ExtendStmt: + name: Token params: list[TypeParam] - type: Type - operations: list[OpStmt] + members: list[MemberStmt] class OpStmt: diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 36d959b..affd768 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum, auto from typing import Any, Generic, Optional, TypeVar from midas.ast.location import Location @@ -21,6 
+22,11 @@ class TypeParam: bound: Optional[Type] +class MemberKind(Enum): + PROPERTY = auto() + METHOD = auto() + + ############## # Statements # ############## @@ -64,6 +70,7 @@ class TypeStmt(Stmt): class MemberStmt(Stmt): name: Token type: Type + kind: MemberKind def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_member_stmt(self) @@ -71,9 +78,9 @@ class MemberStmt(Stmt): @dataclass(frozen=True) class ExtendStmt(Stmt): + name: Token params: list[TypeParam] - type: Type - operations: list[OpStmt] + members: list[MemberStmt] def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_extend_stmt(self) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 2a5eec3..c9a9d33 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -114,6 +114,7 @@ class MidasAstPrinter( def visit_member_stmt(self, stmt: m.MemberStmt): self._write_line("MemberStmt") with self._child_level(): + self._write_line(f"kind: {stmt.kind.name}") self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line("type", last=True) with self._child_level(single=True): @@ -129,16 +130,21 @@ class MidasAstPrinter( if i == len(stmt.params) - 1: self._mark_last() self._print_type_param(param) - self._write_line("type") - with self._child_level(single=True): - stmt.type.accept(self) - self._write_line("operations", last=True) + self._write_line(f'name: "{stmt.name.lexeme}"') + self._write_line("params") with self._child_level(): - for i, op in enumerate(stmt.operations): + for i, param in enumerate(stmt.params): self._idx = i - if i == len(stmt.operations) - 1: + if i == len(stmt.params) - 1: self._mark_last() - op.accept(self) + self._print_type_param(param) + self._write_line("members", last=True) + with self._child_level(): + for i, member in enumerate(stmt.members): + self._idx = i + if i == len(stmt.members) - 1: + self._mark_last() + member.accept(self) def visit_op_stmt(self, stmt: m.OpStmt) -> None: self._write_line("OpStmt") @@ -343,15 +349,23 @@ 
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] return res def visit_member_stmt(self, stmt: m.MemberStmt): - res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" + keyword: str = { + m.MemberKind.PROPERTY: "prop", + m.MemberKind.METHOD: "def", + }.get(stmt.kind, "") + res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}" return self.indented(res) def visit_extend_stmt(self, stmt: m.ExtendStmt): - res: str = self.indented(f"extend {stmt.type.accept(self)}") + template: str = "" + if len(stmt.params) != 0: + params: list[str] = [self._print_type_param(param) for param in stmt.params] + template = f"[{', '.join(params)}]" + res: str = self.indented(f"extend {stmt.name.lexeme}{template}") res += " {\n" self.level += 1 - for op in stmt.operations: - res += op.accept(self) + for member in stmt.members: + res += member.accept(self) + "\n" self.level -= 1 res += self.indented("}") return res diff --git a/midas/lexer/token.py b/midas/lexer/token.py index 60b6e47..95e0c18 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -50,6 +50,8 @@ class TokenType(Enum): PREDICATE = auto() EXTEND = auto() WHERE = auto() + PROP = auto() + DEF = auto() FUNC = auto() # Misc @@ -68,6 +70,8 @@ KEYWORDS: dict[str, TokenType] = { "true": TokenType.TRUE, "false": TokenType.FALSE, "none": TokenType.NONE, + "prop": TokenType.PROP, + "def": TokenType.DEF, "fn": TokenType.FUNC, } diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 0d0cbde..ce94b2d 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -14,6 +14,7 @@ from midas.ast.midas import ( GroupingExpr, LiteralExpr, LogicalExpr, + MemberKind, MemberStmt, NamedType, OpStmt, @@ -394,18 +395,28 @@ class MidasParser(Parser): def member_stmt(self) -> MemberStmt: """Parse a member statement - A type member statement is written `name: Type` + A type member statement is written `prop name: Type` or `def name: Type` Returns: MemberStmt: the parsed member 
statement """ + kind: MemberKind + if self.match(TokenType.PROP): + kind = MemberKind.PROPERTY + elif self.match(TokenType.DEF): + kind = MemberKind.METHOD + else: + raise self.error(self.peek(), "Expected 'prop' or 'def'") + name: Token = self.consume_identifier("Expected member name") self.consume(TokenType.COLON, "Expected ':' after member name") + type: Type = self.type_expr() return MemberStmt( location=name.location_to(self.previous()), name=name, type=type, + kind=kind, ) def extend_declaration(self) -> ExtendStmt: @@ -417,20 +428,20 @@ class MidasParser(Parser): ExtendStmt: the parsed extension statement """ keyword: Token = self.previous() + name: Token = self.consume_identifier("Expected type name") params: list[TypeParam] = self.type_params() - type: Type = self.type_expr() self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") - operations: list[OpStmt] = [] + members: list[MemberStmt] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): - operations.append(self.op_declaration()) + members.append(self.member_stmt()) self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") location: Location = keyword.location_to(self.previous()) return ExtendStmt( location=location, + name=name, params=params, - type=type, - operations=operations, + members=members, ) def op_declaration(self) -> OpStmt: From b3665c646223ee0d8a8fdad7c66f23dc9258f3c9 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 16:42:25 +0200 Subject: [PATCH 28/64] fix(cli): update highlighter --- midas/cli/highlighter.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index 16fdf94..af81890 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -232,15 +232,14 @@ class MidasHighlighter( self.wrap(LocatableToken(stmt.name), "type-name") stmt.type.accept(self) - def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: - 
self.wrap(stmt, "property") + def visit_member_stmt(self, stmt: m.MemberStmt) -> None: + self.wrap(stmt, "member") stmt.type.accept(self) def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self.wrap(stmt, "extend") - stmt.type.accept(self) - for op in stmt.operations: - op.accept(self) + for member in stmt.members: + member.accept(self) def visit_op_stmt(self, stmt: m.OpStmt) -> None: self.wrap(stmt, "op") @@ -298,8 +297,8 @@ class MidasHighlighter( def visit_complex_type(self, type: m.ComplexType) -> None: self.wrap(type, "complex-type") - for prop in type.properties: - prop.accept(self) + for member in type.members: + member.accept(self) def visit_function_type(self, type: m.FunctionType) -> None: self.wrap(type, "function") @@ -307,6 +306,11 @@ class MidasHighlighter( arg.type.accept(self) type.returns.accept(self) + def visit_extension_type(self, type: m.ExtensionType) -> None: + self.wrap(type, "extension") + type.base.accept(self) + type.extension.accept(self) + class DiagnosticsHighlighter(Highlighter): EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css" From 179b88bfed929039dbce5dbde200e98aece3e92c Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 16:45:11 +0200 Subject: [PATCH 29/64] feat(checker): add members registry --- midas/checker/midas.py | 22 +++++++++++++--------- midas/checker/registry.py | 26 ++++++++++++++++++++++++++ midas/checker/types.py | 9 +++++++++ 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 8c0fede..cc2cec2 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -85,15 +85,19 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._resolve_type_params(stmt.params) - base: Type = stmt.type.accept(self) - for op in stmt.operations: - right: Type = op.operand.accept(self) - result: Type = op.result.accept(self) - 
self.types.define_operation( - left=base, - operator=op.name.lexeme, - right=right, - result=result, + base_name: str = stmt.name.lexeme + try: + _ = self.get_type(base_name) + except NameError: + self.reporter.error(stmt.name.get_location(), f"Unknown type '{base_name}'") + + for member in stmt.members: + member_type: Type = member.type.accept(self) + self.types.define_member( + base_name, + member.name.lexeme, + member_type, + member.kind == m.MemberKind.METHOD, ) def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... diff --git a/midas/checker/registry.py b/midas/checker/registry.py index d5c432a..16c36a5 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from midas.checker.builtins import BUILTIN_SUBTYPES @@ -9,6 +10,7 @@ from midas.checker.types import ( Function, GenericType, Operation, + OverloadedFunction, Type, substitute_typevars, ) @@ -16,7 +18,9 @@ from midas.checker.types import ( class TypesRegistry: def __init__(self) -> None: + self.logger: logging.Logger = logging.getLogger("TypesRegistry") self._types: dict[str, Type] = {} + self._members: dict[str, dict[str, Type]] = {} self._operations: dict[Operation.CallSignature, Type] = {} def get_type(self, name: str) -> Type: @@ -86,6 +90,28 @@ class TypesRegistry: self._types[name] = type return type + def define_member( + self, type_name: str, member_name: str, member_type: Type, is_method: bool + ): + members: dict[str, Type] = self._members.setdefault(type_name, {}) + if member_name in members: + if not is_method: + self.logger.error( + f"Member '{member_name}' already defined for type {type_name}" + ) + return + current: Type = members[member_name] + combined: Type + match current: + case OverloadedFunction(overloads=overloads): + combined = OverloadedFunction(overloads=overloads + [member_type]) + case _: + combined = OverloadedFunction(overloads=[current, member_type]) + members[member_name] = combined + + else: + 
members[member_name] = member_type + def define_operation(self, left: Type, operator: str, right: Type, result: Type): """Define an operation in the registry diff --git a/midas/checker/types.py b/midas/checker/types.py index 9057e4c..dd8c173 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -69,6 +69,14 @@ class Function: return f"{self.name}: {self.type}{opt}" +@dataclass(frozen=True, kw_only=True) +class OverloadedFunction: + overloads: list[Type] + + def __str__(self) -> str: + return "" + + @dataclass(frozen=True, kw_only=True) class ComplexType: members: dict[str, Type] @@ -209,6 +217,7 @@ Type = ( | UnknownType | UnitType | Function + | OverloadedFunction | ComplexType | ExtensionType | TypeVar From b5de28e291bfe0eba3ce471df4ff9debcd26416a Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 16:51:18 +0200 Subject: [PATCH 30/64] feat(checker): implement lookup_member method --- midas/checker/python.py | 84 ++++++++------------------------------- midas/checker/registry.py | 50 +++++++++++++++++++++++ 2 files changed, 66 insertions(+), 68 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 3d44975..f1cecb8 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -11,14 +11,11 @@ from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from midas.checker.types import ( - ComplexType, - ExtensionType, Function, Operation, Type, UnitType, UnknownType, - unfold_type, ) from midas.parser.python import PythonParser @@ -250,8 +247,7 @@ class PythonTyper( self._assign_var(location, target, value_type) case p.GetExpr(object=object, name=name): - object_type: Type = self.type_of(object) - self._assign_attr(location, object_type, name, value_type) + self._assign_attr(location, object, name, value_type) case _: if not isinstance(target, p.VariableExpr): @@ -277,43 +273,19 @@ class PythonTyper( ) def 
_assign_attr( - self, location: Location, object: Type, name: str, value_type: Type + self, location: Location, object: p.Expr, name: str, value_type: Type ): - # TODO: improve recursion to have better error messages - base_object: Type = unfold_type(object) - match base_object: - case ComplexType(members=members): - if name not in members: - self.reporter.error(location, f"Unknown member '{object}.{name}'") - return - - member_type: Type = members[name] - if not self.is_subtype(value_type, member_type): - self.reporter.error( - location, - f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}", - ) - return - - case ExtensionType(base=base, extension=ComplexType(members=members)): - if name in members: - member_type: Type = members[name] - if not self.is_subtype(value_type, member_type): - self.reporter.error( - location, - f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}", - ) - return - return self._assign_attr(location, base, name, value_type) - - case UnknownType(): - pass - - case _: - self.reporter.error( - location, - f"Cannot assign {value_type} to unknown property '{object}.{name}'", - ) + object_type: Type = self.type_of(object) + member: Optional[Type] = self.types.lookup_member(object_type, name) + if member is None: + self.reporter.error(location, f"Unknown member '{name}' of {object_type}") + return + self.logger.debug(f"Member '{name}' of {object_type} has type {member}") + if not self.is_subtype(value_type, member): + self.reporter.error( + location, + f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}", + ) def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() @@ -433,39 +405,15 @@ class PythonTyper( def visit_get_expr(self, expr: p.GetExpr) -> Type: object: Type = self.type_of(expr.object) - member: Optional[Type] = self._get_member(object, expr.name) + member: Optional[Type] = 
self.types.lookup_member(object, expr.name) if member is None: self.reporter.error( - expr.location, f"Unknown property '{expr.name}' on {object}" + expr.location, f"Unknown member '{expr.name}' of {object}" ) return UnknownType() - self.logger.debug(f"Property '{expr.name}' on {object} has type {member}") + self.logger.debug(f"Member '{expr.name}' of {object} has type {member}") return member - def _get_member(self, object: Type, name: str) -> Optional[Type]: - base_object: Type = unfold_type(object) - match base_object: - case ComplexType(members=members): - if name in members: - return members[name] - self.logger.debug(f"No property '{name}' in {base_object}") - return None - - case ExtensionType(base=base, extension=ComplexType(members=members)): - if name in members: - return members[name] - self.logger.debug( - f"No property '{name}' on {base_object}, looking up in base" - ) - return self._get_member(base, name) - - case UnknownType(): - return UnknownType() - - case _: - self.logger.debug(f"Can't get property on {base_object}") - return UnknownType() - def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: match expr.value: case bool(): # Must be before int diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 16c36a5..db2f972 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -7,11 +7,13 @@ from midas.checker.types import ( AppliedType, BaseType, ComplexType, + ExtensionType, Function, GenericType, Operation, OverloadedFunction, Type, + UnknownType, substitute_typevars, ) @@ -337,3 +339,51 @@ class TypesRegistry: reduced = True break return [types[i] for i in keep] + + def lookup_member(self, type: Type, member_name: str) -> Optional[Type]: + match type: + case AliasType(name=name, type=base): + if name in self._members: + if member_name in self._members[name]: + return self._members[name][member_name] + return self.lookup_member(base, member_name) + + case AppliedType(name=name, body=body, args=args): + 
generic: Type = self.get_type(name) + + if not isinstance(generic, GenericType): + raise ValueError("AppliedType not derived from a GenericType") + + substitutions = { + type_var.name: arg for arg, type_var in zip(args, generic.params) + } + if name in self._members: + if member_name in self._members[name]: + member_type: Type = self._members[name][member_name] + return substitute_typevars(member_type, substitutions) + + member_type2: Optional[Type] = self.lookup_member(body, member_name) + if member_type2 is not None: + member_type2 = substitute_typevars(member_type2, substitutions) + return member_type2 + + case ComplexType(members=members): + if member_name in members: + return members[member_name] + self.logger.debug(f"No member '{member_name}' in {type}") + return None + + case ExtensionType(base=base, extension=ComplexType(members=members)): + if member_name in members: + return members[member_name] + self.logger.debug( + f"No member '{member_name}' on {type}, looking up in base" + ) + return self.lookup_member(base, member_name) + + case UnknownType(): + return UnknownType() + + case _: + self.logger.debug(f"Can't get member on {type}") + return None From 01ff5ca8d5a627d4a1e625aecbf2a59d6a36c5f8 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 16:53:34 +0200 Subject: [PATCH 31/64] fix(checker): handle nested generic members --- midas/checker/registry.py | 6 ++++++ midas/checker/types.py | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index db2f972..7ed206d 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -13,6 +13,7 @@ from midas.checker.types import ( Operation, OverloadedFunction, Type, + TypeVar, UnknownType, substitute_typevars, ) @@ -171,6 +172,11 @@ class TypesRegistry: case (Function(), Function()): return self.is_func_subtype(type1, type2) + case (TypeVar(bound=bound), _): + if bound is None: + return False + return self.is_subtype(bound, 
type2) + return False # TODO: verify the logic in here diff --git a/midas/checker/types.py b/midas/checker/types.py index dd8c173..d0dbe2f 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -191,6 +191,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: ), ) + case AppliedType(name=name, args=args, body=body): + return AppliedType( + name=name, + args=[substitute_typevars(arg, substitutions) for arg in args], + body=substitute_typevars(body, substitutions), + ) + case TypeVar(name=name): if name in substitutions: return substitutions[name] From 2e898ab1e9b2f0d2ceb7354cc47458bef1fa968f Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 16:55:01 +0200 Subject: [PATCH 32/64] fix(checker): update binary operation lookup --- midas/checker/python.py | 66 ++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index f1cecb8..1e57d35 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -324,47 +324,36 @@ class PythonTyper( left: Type = self.type_of(expr.left) right: Type = self.type_of(expr.right) - operations: list[Operation] = self.types.get_operations_by_name(method) - valid_operations: list[Operation] = [] - for op in operations: - sig: Operation.CallSignature = op.signature - if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right): - valid_operations.append(op) - - if len(valid_operations) == 0: + operation: Optional[Type] = self.types.lookup_member(left, method) + if operation is None: self.reporter.error( expr.location, f"Undefined operation {method} between {left} and {right}", ) return UnknownType() - elif len(valid_operations) == 1: - self.logger.debug(f"Unique operation {method} between {left} and {right}") - return valid_operations[0].result - for i, op1 in enumerate(valid_operations): - sig1: Operation.CallSignature = op1.signature - best_match: bool = True - for j, 
op2 in enumerate(valid_operations): - if i == j: - continue - sig2: Operation.CallSignature = op2.signature + match operation: + case Function() as function: + if not self._is_binary_function(function): + self.reporter.error( + expr.location, + f"Wrong definition of binary operation. Expected function with 2 positional-only parameters, got {function}", + ) + return UnknownType() - # If op1 is not a full overload of op2 (i.e. operands of op1 are subtypes of op2's) - # ambiguity -> not best match - if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype( - sig1.right, sig2.right - ): - best_match = False - break - self.logger.debug(f"{op1} is a full overload of {op2}") - if best_match: - return op1.result - - self.reporter.error( - expr.location, - f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(map(str, valid_operations))}", - ) - return UnknownType() + rhs: Function.Argument = function.pos_args[0] + if not self.is_subtype(right, rhs.type): + self.reporter.error( + expr.location, + f"Wrong type for right-hand side, expected {rhs.type}, got {right}", + ) + return UnknownType() + return function.returns + case _: + self.reporter.warning( + expr.location, f"Unsupported operation {operation}" + ) + return UnknownType() def visit_compare_expr(self, expr: p.CompareExpr) -> Type: method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) @@ -617,3 +606,12 @@ class PythonTyper( ) return mapped + + def _is_binary_function(self, function: Function) -> bool: + if len(function.pos_args) != 1: + return False + if len(function.args) != 0: + return False + if len(function.kw_args) != 0: + return False + return True From 52981f12f201b31b1be734e557af82d060c245df Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 16:56:03 +0200 Subject: [PATCH 33/64] fix(checker): minor fix when using base type in generic --- midas/checker/types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/midas/checker/types.py b/midas/checker/types.py index d0dbe2f..8767764 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -157,6 +157,9 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: case BaseType(name=name) if name in substitutions: return substitutions[name] + case BaseType(): + return type + case AliasType(name=name, type=type2): return AliasType(name=name, type=substitute_typevars(type2, substitutions)) From 2935c71366de0e260e80ac38c9f341865f5e5f8b Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 17:01:02 +0200 Subject: [PATCH 34/64] fix(checker): give warning on unknown variable --- midas/checker/python.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 1e57d35..7d10392 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -12,7 +12,6 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from midas.checker.types import ( Function, - Operation, Type, UnitType, UnknownType, @@ -418,7 +417,11 @@ class PythonTyper( return UnknownType() def visit_variable_expr(self, expr: p.VariableExpr) -> Type: - return self.look_up_variable(expr.name, expr) or UnknownType() + type: Optional[Type] = self.look_up_variable(expr.name, expr) + if type is None: + self.logger.debug(f"Unknown variable {expr.name} in {self.env.flat_dict()}") + self.reporter.warning(expr.location, "Unknown variable") + return type or UnknownType() def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: left: Type = expr.left.accept(self) From 50eaafc38807709471cf26d93deea958c1d7e8ee Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 12 Jun 2026 17:01:19 +0200 Subject: [PATCH 35/64] feat(tests): update serializer --- tests/serializer/midas.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py 
index 2a5daf5..af4c0b1 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -6,16 +6,17 @@ from midas.ast.midas import ( ConstraintType, Expr, ExtendStmt, + ExtensionType, FunctionType, GenericType, GetExpr, GroupingExpr, LiteralExpr, LogicalExpr, + MemberStmt, NamedType, OpStmt, PredicateStmt, - PropertyStmt, Stmt, Type, TypeParam, @@ -58,9 +59,10 @@ class MidasAstJsonSerializer( "bound": self._serialize_optional(param.bound), } - def visit_property_stmt(self, stmt: PropertyStmt) -> dict: + def visit_member_stmt(self, stmt: MemberStmt) -> dict: return { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": stmt.kind.name, "name": stmt.name.lexeme, "type": stmt.type.accept(self), } @@ -68,8 +70,9 @@ class MidasAstJsonSerializer( def visit_extend_stmt(self, stmt: ExtendStmt) -> dict: return { "_type": "ExtendStmt", - "type": stmt.type.accept(self), - "operations": self._serialize_list(stmt.operations), + "name": stmt.name.lexeme, + "params": [self._serialize_type_param(param) for param in stmt.params], + "members": self._serialize_list(stmt.members), } def visit_op_stmt(self, stmt: OpStmt) -> dict: @@ -163,7 +166,7 @@ class MidasAstJsonSerializer( def visit_complex_type(self, type: ComplexType) -> dict: return { "_type": "ComplexType", - "properties": self._serialize_list(type.properties), + "members": self._serialize_list(type.members), } def visit_function_type(self, type: FunctionType) -> dict: @@ -180,3 +183,10 @@ class MidasAstJsonSerializer( "type": arg.type.accept(self), "required": arg.required, } + + def visit_extension_type(self, type: ExtensionType) -> dict: + return { + "_type": "ExtensionType", + "base": type.base.accept(self), + "extension": type.extension.accept(self), + } From 6d6bb66c545757329b2aa8ef3dca137d131ab5c0 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 12:39:46 +0200 Subject: [PATCH 36/64] feat(checker): define members on builtin types --- midas/checker/builtins.midas | 152 
+++++++++++++++++++++++++++++++++++ midas/checker/builtins.py | 59 +------------- midas/checker/midas.py | 3 + 3 files changed, 157 insertions(+), 57 deletions(-) create mode 100644 midas/checker/builtins.midas diff --git a/midas/checker/builtins.midas b/midas/checker/builtins.midas new file mode 100644 index 0000000..ccc1502 --- /dev/null +++ b/midas/checker/builtins.midas @@ -0,0 +1,152 @@ +extend float { + def hex: fn() -> str + def is_integer: fn() -> bool + prop real: float + prop imag: float + def conjugate: fn() -> float + def __add__: fn(value: float, /) -> float + def __sub__: fn(value: float, /) -> float + def __mul__: fn(value: float, /) -> float + def __floordiv__: fn(value: float, /) -> float + def __truediv__: fn(value: float, /) -> float + def __mod__: fn(value: float, /) -> float + // def __divmod__: fn(value: float, /) -> tuple[float, float] + + def __pow__: fn(value: int, /) -> float + // positive __value -> float; negative __value -> complex + // return type must be Any as `float | complex` causes too many false-positive errors + def __pow__: fn(value: float, /) -> Any + def __radd__: fn(value: float, /) -> float + def __rsub__: fn(value: float, /) -> float + def __rmul__: fn(value: float, /) -> float + def __rfloordiv__: fn(value: float, /) -> float + def __rtruediv__: fn(value: float, /) -> float + def __rmod__: fn(value: float, /) -> float + // def __rdivmod__: fn(value: float, /) -> tuple[float, float] + // def __rpow__: fn(value: _PositiveInteger, mod: None = None, /) -> float + // def __rpow__: fn(value: _NegativeInteger, mod: None = None, /) -> complex + // Returning `complex` for the general case gives too many false-positive errors. 
+ // def __rpow__: fn(value: float, mod: None = None, /) -> Any + // def __getnewargs__: fn() -> tuple[float] + def __trunc__: fn() -> int + def __ceil__: fn() -> int + def __floor__: fn() -> int + def __round__: fn(ndigits: None?, /) -> int + def __round__: fn(ndigits: int, /) -> float + def __eq__: fn(value: object, /) -> bool + def __ne__: fn(value: object, /) -> bool + def __lt__: fn(value: float, /) -> bool + def __le__: fn(value: float, /) -> bool + def __gt__: fn(value: float, /) -> bool + def __ge__: fn(value: float, /) -> bool + def __neg__: fn() -> float + def __pos__: fn() -> float + def __int__: fn() -> int + def __float__: fn() -> float + def __abs__: fn() -> float + def __hash__: fn() -> int + def __bool__: fn() -> bool + def __format__: fn(format_spec: str, /) -> str +} + +extend int { + prop real: int + prop imag: int + prop numerator: int + prop denominator: int + def conjugate: fn() -> int + def bit_length: fn() -> int + def bit_count: fn() -> int + def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) 
-> bytes + + def __add__: fn(value: int, /) -> int + def __sub__: fn(value: int, /) -> int + def __mul__: fn(value: int, /) -> int + def __floordiv__: fn(value: int, /) -> int + def __truediv__: fn(value: int, /) -> float + def __mod__: fn(value: int, /) -> int + // def __divmod__: fn(value: int, /) -> tuple[int, int] + def __radd__: fn(value: int, /) -> int + def __rsub__: fn(value: int, /) -> int + def __rmul__: fn(value: int, /) -> int + def __rfloordiv__: fn(value: int, /) -> int + def __rtruediv__: fn(value: int, /) -> float + def __rmod__: fn(value: int, /) -> int + // def __rdivmod__: fn(value: int, /) -> tuple[int, int] + def __pow__: fn(value: int, /) -> int + // def __pow__: fn(value: _PositiveInteger, mod: None = None, /) -> int + // def __pow__: fn(value: _NegativeInteger, mod: None = None, /) -> float + // positive __value -> int; negative __value -> float + // return type must be Any as `int | float` causes too many false-positive errors + // def __pow__: fn(value: int, mod: None = None, /) -> Any + // def __pow__: fn(value: int, mod: int, /) -> int + def __rpow__: fn(value: int, /) -> Any + def __and__: fn(value: int, /) -> int + def __or__: fn(value: int, /) -> int + def __xor__: fn(value: int, /) -> int + def __lshift__: fn(value: int, /) -> int + def __rshift__: fn(value: int, /) -> int + def __rand__: fn(value: int, /) -> int + def __ror__: fn(value: int, /) -> int + def __rxor__: fn(value: int, /) -> int + def __rlshift__: fn(value: int, /) -> int + def __rrshift__: fn(value: int, /) -> int + def __neg__: fn() -> int + def __pos__: fn() -> int + def __invert__: fn() -> int + def __trunc__: fn() -> int + def __ceil__: fn() -> int + def __floor__: fn() -> int + def __round__: fn(ndigits: None?, /) -> int + def __round__: fn(ndigits: int, /) -> int + + // def __getnewargs__: fn() -> tuple[int] + def __eq__: fn(value: object, /) -> bool + def __ne__: fn(value: object, /) -> bool + def __lt__: fn(value: int, /) -> bool + def __le__: fn(value: int, /) 
-> bool + def __gt__: fn(value: int, /) -> bool + def __ge__: fn(value: int, /) -> bool + def __float__: fn() -> float + def __int__: fn() -> int + def __abs__: fn() -> int + def __hash__: fn() -> int + def __bool__: fn() -> bool + def __index__: fn() -> int + def __format__: fn(format_spec: str, /) -> str +} + +extend list[T] { + def copy: fn () -> list[T] + def append: fn (object: T, /) -> None + def extend: fn (iterable: list[T], /) -> None + def pop: fn (index: int?, /) -> T + def index: fn (value: T, start: int?, stop: int?, /) -> int + def count: fn (value: T, /) -> int + def insert: fn (index: int, object: T, /) -> None + def remove: fn (value: T, /) -> None + def sort: fn (*, reverse: bool?) -> None + def __len__: fn () -> int + // def __iter__: fn () -> Iterator[T] + def __getitem__: fn (i: int, /) -> T + //__getitem__: fn (s: slice, /) -> list[T] + def __setitem__: fn (key: int, value: T, /) -> None + //__setitem__: fn (key: slice, value: list[T], /) -> None + def __delitem__: fn (key: int, /) -> None + // def __delitem__: fn (key: slice, /) -> None + // def __add__: fn[S <: T] (value: list[S], /) -> list[T] + def __add__: fn (value: list[T], /) -> list[T] + def __iadd__: fn (value: list[T], /) -> list[T] + def __mul__: fn (value: int, /) -> list[T] + def __rmul__: fn (value: int, /) -> list[T] + def __imul__: fn (value: int, /) -> list[T] + def __contains__: fn (key: object, /) -> bool + // def __reversed__: fn (self) -> Iterator[_T] + def __gt__: fn (value: list[T], /) -> bool + def __ge__: fn (value: list[T], /) -> bool + def __lt__: fn (value: list[T], /) -> bool + def __le__: fn (value: list[T], /) -> bool + def __eq__: fn (value: object, /) -> bool + + prop __doc__: str +} diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index f20eb50..1b62eeb 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -4,8 +4,6 @@ from typing import TYPE_CHECKING from midas.checker.types import ( BaseType, - ComplexType, - Function, 
GenericType, Type, TypeVar, @@ -43,70 +41,17 @@ def basic_op(reg: TypesRegistry, type: Type, op: str): def define_builtins(reg: TypesRegistry): """Define builtin types and operations""" unit = reg.define_type("None", UnitType()) + object = reg.define_type("object", BaseType(name="object")) bool = reg.define_type("bool", BaseType(name="bool")) int = reg.define_type("int", BaseType(name="int")) float = reg.define_type("float", BaseType(name="float")) str = reg.define_type("str", BaseType(name="str")) - basic_op(reg, int, "__add__") # int + int = int - basic_op(reg, int, "__sub__") # int - int = int - basic_op(reg, int, "__mul__") # int * int = int - basic_op(reg, int, "__pow__") # int ** int = int - basic_op(reg, int, "__mod__") # int % int = int - basic_op(reg, int, "__and__") # int & int = int - basic_op(reg, int, "__or__") # int | int = int - basic_op(reg, int, "__xor__") # int ^ int = int - op(reg, int, "__lt__", int, bool) # int < int = bool - op(reg, int, "__gt__", int, bool) # int > int = bool - op(reg, int, "__le__", int, bool) # int <= int = bool - op(reg, int, "__ge__", int, bool) # int >= int = bool - op(reg, int, "__eq__", int, bool) # int == int = bool - basic_op(reg, float, "__add__") # float + float = float - basic_op(reg, float, "__sub__") # float - float = float - basic_op(reg, float, "__mul__") # float * float = float - basic_op(reg, float, "__truediv__") # float / float = float - op(reg, float, "__lt__", float, bool) # float < float = bool - op(reg, float, "__gt__", float, bool) # float > float = bool - op(reg, float, "__le__", float, bool) # float <= float = bool - op(reg, float, "__ge__", float, bool) # float >= float = bool - op(reg, float, "__eq__", float, bool) # float == float = bool - basic_op(reg, str, "__add__") # str + str = str - op(reg, str, "__eq__", str, bool) # str == str = bool - - op(reg, int, "__lt__", float, bool) # int < float = bool - op(reg, int, "__gt__", float, bool) # int > float = bool - op(reg, int, "__le__", float, bool) 
# int <= float = bool - op(reg, int, "__ge__", float, bool) # int >= float = bool - op(reg, int, "__eq__", float, bool) # int == float = bool - - op(reg, float, "__lt__", int, bool) # float < int = bool - op(reg, float, "__gt__", int, bool) # float > int = bool - op(reg, float, "__le__", int, bool) # float <= int = bool - op(reg, float, "__ge__", int, bool) # float >= int = bool - op(reg, float, "__eq__", int, bool) # float == int = bool - list = reg.define_type( "list", GenericType( name="list", params=[TypeVar(name="T", bound=None)], - body=ComplexType( - properties={ - "append": Function( - name="append", - pos_args=[ - Function.Argument( - pos=0, - name="object", - type=TypeVar(name="T", bound=None), - required=True, - ) - ], - args=[], - kw_args=[], - returns=UnitType(), - ) - } - ), + body=BaseType(name="list"), ), ) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index cc2cec2..25096ae 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from typing import Optional import midas.ast.midas as m @@ -33,6 +34,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self._current_name: Optional[str] = None define_builtins(self.types) + builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve() + self.process(builtins_path.read_text(), str(builtins_path)) def process(self, source: str, path: Optional[str]): self.reporter = self.reporter.for_file(path) From 84a5f41e62030d855d1f06b8df1ea87f1045f6f9 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 12:40:26 +0200 Subject: [PATCH 37/64] fix: extend example of complex types --- .../04_complex_types.midas | 20 ++++++++++++----- .../04_complex_types.py | 22 ++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/examples/01_simple_type_checking/04_complex_types.midas b/examples/01_simple_type_checking/04_complex_types.midas index b920c37..b561cef 
100644 --- a/examples/01_simple_type_checking/04_complex_types.midas +++ b/examples/01_simple_type_checking/04_complex_types.midas @@ -1,11 +1,21 @@ type Meter = float extend Meter { - op __add__(Meter) -> Meter - op __sub__(Meter) -> Meter + def __add__: fn(Meter) -> Meter + def __sub__: fn(Meter) -> Meter } -type Coordinate = { - x: Meter - y: Meter +type Coordinate = object + +extend Coordinate { + prop x: Meter + prop y: Meter +} + +type Difference[T <: float] = T +type MeterDifference = Difference[Meter] + +type CompDiff[T <: float] = { + prop d1: Difference[T] + prop d2: Difference[T] } \ No newline at end of file diff --git a/examples/01_simple_type_checking/04_complex_types.py b/examples/01_simple_type_checking/04_complex_types.py index 63fd1e7..ebe958f 100644 --- a/examples/01_simple_type_checking/04_complex_types.py +++ b/examples/01_simple_type_checking/04_complex_types.py @@ -1,5 +1,6 @@ # type: ignore # ruff: disable [F821] + p1: Coordinate p2: Coordinate @@ -9,6 +10,21 @@ diff_y = p2.y - p1.y dist = diff_x + diff_y p2.x += cast(Meter, 1) -p2.y = True -p2.z = 3 -p2.x.a = 3 +p2.y = True # invalid, wrong type +p2.z = 3 # invalid, no property 'z' on Coordinate +p2.x.a = 3 # invalid, no properties on Meter + +foo: list[float] = [] + +append = foo.append + +foo.append("") # invalid, must be float +foo.append(2) +append(True) # invalid, must be float +append(2) + +bar: list[list[Meter]] + +bar.append([p2.x]) + +foo2 = foo + foo From 4c9cbd9faabac9783555c41af281fde88cd0c39e Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 12:45:40 +0200 Subject: [PATCH 38/64] feat(checker): add top type (Any) --- midas/checker/builtins.py | 2 ++ midas/checker/registry.py | 4 ++++ midas/checker/types.py | 9 ++++++++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index 1b62eeb..6fc46d2 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -5,6 +5,7 @@ from typing import 
TYPE_CHECKING from midas.checker.types import ( BaseType, GenericType, + TopType, Type, TypeVar, UnitType, @@ -40,6 +41,7 @@ def basic_op(reg: TypesRegistry, type: Type, op: str): def define_builtins(reg: TypesRegistry): """Define builtin types and operations""" + any = reg.define_type("Any", TopType()) unit = reg.define_type("None", UnitType()) object = reg.define_type("object", BaseType(name="object")) bool = reg.define_type("bool", BaseType(name="bool")) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 7ed206d..8d60b27 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -12,6 +12,7 @@ from midas.checker.types import ( GenericType, Operation, OverloadedFunction, + TopType, Type, TypeVar, UnknownType, @@ -155,6 +156,9 @@ class TypesRegistry: return True match (type1, type2): + case (_, TopType()): + return True + case (AliasType(type=base1), _): return self.is_subtype(base1, type2) diff --git a/midas/checker/types.py b/midas/checker/types.py index 8767764..444e868 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -4,6 +4,12 @@ from dataclasses import dataclass from typing import Optional +@dataclass(frozen=True, kw_only=True) +class TopType: + def __str__(self) -> str: + return "Any" + + @dataclass(frozen=True, kw_only=True) class BaseType: name: str @@ -222,7 +228,8 @@ def unfold_type(type: Type) -> Type: Type = ( - BaseType + TopType + | BaseType | AliasType | UnknownType | UnitType From 99924ee6c2f084990f661e80b63acfd6811d3d3d Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 13:16:24 +0200 Subject: [PATCH 39/64] feat(parser): add mixed arguments in midas functions --- gen/midas.py | 1 + midas/ast/midas.py | 1 + midas/ast/printer.py | 10 ++++++++ midas/checker/midas.py | 32 ++++++++++++------------- midas/checker/python.py | 2 +- midas/parser/midas.py | 50 ++++++++++++++++++++++----------------- tests/serializer/midas.py | 1 + 7 files changed, 57 insertions(+), 40 deletions(-) 
diff --git a/gen/midas.py b/gen/midas.py index 72813d4..5405b6c 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -135,6 +135,7 @@ class ExtensionType: class FunctionType: pos_args: list[Argument] + args: list[Argument] kw_args: list[Argument] returns: Type diff --git a/midas/ast/midas.py b/midas/ast/midas.py index affd768..c35e856 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -293,6 +293,7 @@ class ExtensionType(Type): @dataclass(frozen=True) class FunctionType(Type): pos_args: list[Argument] + args: list[Argument] kw_args: list[Argument] returns: Type diff --git a/midas/ast/printer.py b/midas/ast/printer.py index c9a9d33..d094d42 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -297,6 +297,14 @@ class MidasAstPrinter( self._mark_last() self._print_function_arg(arg) + self._write_line("args") + with self._child_level(): + for i, arg in enumerate(type.args): + self._idx = i + if i == len(type.args) - 1: + self._mark_last() + self._print_function_arg(arg) + self._write_line("kw_args") with self._child_level(): for i, arg in enumerate(type.kw_args): @@ -447,11 +455,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_function_type(self, type: m.FunctionType) -> str: pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + mixed_args: list[str] = [self._print_arg(arg) for arg in type.args] kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args] args: list[str] = pos_args if len(pos_args) != 0: args.append("/") + args += mixed_args if len(kw_args) != 0: args.append("*") args += kw_args diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 25096ae..e874e63 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -164,25 +164,23 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type ) def visit_function_type(self, type: m.FunctionType) -> Type: + n_pos_args: int = len(type.pos_args) + n_args: int = len(type.args) + + def 
process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument: + return Function.Argument( + pos=i, + name=arg.name.lexeme if arg.name is not None else str(i), + type=arg.type.accept(self), + required=arg.required, + ) + return Function( - pos_args=[ - Function.Argument( - pos=i, - name=arg.name.lexeme if arg.name is not None else str(i), - type=arg.type.accept(self), - required=arg.required, - ) - for i, arg in enumerate(type.pos_args) - ], - args=[], + pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)], + args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)], kw_args=[ - Function.Argument( - pos=i, - name=arg.name.lexeme if arg.name is not None else str(i), - type=arg.type.accept(self), - required=arg.required, - ) - for i, arg in enumerate(type.kw_args, start=len(type.pos_args)) + process_arg(arg, i + n_pos_args + n_args) + for i, arg in enumerate(type.kw_args) ], returns=type.returns.accept(self), ) diff --git a/midas/checker/python.py b/midas/checker/python.py index 7d10392..57e6687 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -336,7 +336,7 @@ class PythonTyper( if not self._is_binary_function(function): self.reporter.error( expr.location, - f"Wrong definition of binary operation. Expected function with 2 positional-only parameters, got {function}", + f"Wrong definition of binary operation. 
Expected function with 1 positional-only parameters, got {function}", ) return UnknownType() diff --git a/midas/parser/midas.py b/midas/parser/midas.py index ce94b2d..06f44a4 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -499,34 +499,39 @@ class MidasParser(Parser): TokenType.LEFT_PAREN, "Expected '(' before function parameters" ) pos_args: list[FunctionType.Argument] = [] + args: list[FunctionType.Argument] = [] kw_args: list[FunctionType.Argument] = [] - positional: bool = True + section: int = 0 while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN): - if positional and ( - self.match(TokenType.STAR) or self.match(TokenType.SLASH) - ): - positional = False - else: - name: Optional[Token] = None - if self.check_identifier() and self.check_next(TokenType.COLON): - name = self.advance() - self.advance() - type: Type = self.type_expr() - optional: bool = self.match(TokenType.QMARK) - arg = FunctionType.Argument( - location=None, - name=name, - type=type, - required=not optional, - ) - if positional: - pos_args.append(arg) - else: - kw_args.append(arg) + match section: + case 0 if self.match(TokenType.SLASH): + pos_args = args + args = [] + section = 1 + case 0 | 1 if self.match(TokenType.STAR): + section = 2 + case _: + name: Optional[Token] = None + if self.check_identifier() and self.check_next(TokenType.COLON): + name = self.advance() + self.advance() + type: Type = self.type_expr() + optional: bool = self.match(TokenType.QMARK) + arg = FunctionType.Argument( + location=None, + name=name, + type=type, + required=not optional, + ) + if section == 2: + kw_args.append(arg) + else: + args.append(arg) if not self.match(TokenType.COMMA): break + self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") self.consume(TokenType.ARROW, "Expected '->' before result type") @@ -535,6 +540,7 @@ class MidasParser(Parser): return FunctionType( location=l_paren.location_to(self.previous()), pos_args=pos_args, + args=args, 
kw_args=kw_args, returns=result, ) diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index af4c0b1..f1e55da 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -173,6 +173,7 @@ class MidasAstJsonSerializer( return { "_type": "FunctionType", "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], + "args": [self._serialize_func_arg(arg) for arg in type.args], "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args], "returns": type.returns.accept(self), } From 109c8eb35a1edfd76c0903fcc248de7c820f2c71 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 13:43:16 +0200 Subject: [PATCH 40/64] fix(parser): make name required for mixed and keyword args --- midas/parser/midas.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 06f44a4..2fc46cf 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -502,20 +502,33 @@ class MidasParser(Parser): args: list[FunctionType.Argument] = [] kw_args: list[FunctionType.Argument] = [] + args_first_tokens: list[Token] = [] + section: int = 0 while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN): match section: case 0 if self.match(TokenType.SLASH): pos_args = args args = [] + args_first_tokens = [] section = 1 case 0 | 1 if self.match(TokenType.STAR): section = 2 case _: + # Record first token of mixed argument for errors if unnamed + if section != 2: + args_first_tokens.append(self.peek()) + name: Optional[Token] = None - if self.check_identifier() and self.check_next(TokenType.COLON): + if section == 2: + name = self.consume_identifier("Expected keyword argument name") + self.consume( + TokenType.COLON, "Expected ':' after argument name" + ) + elif self.check_identifier() and self.check_next(TokenType.COLON): name = self.advance() self.advance() + type: Type = self.type_expr() optional: bool = self.match(TokenType.QMARK) arg = 
FunctionType.Argument( @@ -532,6 +545,11 @@ class MidasParser(Parser): if not self.match(TokenType.COMMA): break + for arg, token in zip(args, args_first_tokens): + if arg.name is None: + # Not raised because we can keep parsing + self.error(token, "Unnamed mixed argument") + self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") self.consume(TokenType.ARROW, "Expected '->' before result type") From 6f5d971c66b90e32051bc86c44b658fe666c9c37 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 13:43:33 +0200 Subject: [PATCH 41/64] fix(checker): gracefully handle unknown type --- midas/checker/python.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 57e6687..f8f049a 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -481,7 +481,13 @@ class PythonTyper( return self.types.apply_generic(list_type, [UnknownType()]) def visit_base_type(self, node: p.BaseType) -> Type: - base: Type = self.types.get_type(node.base) + base: Type + try: + base = self.types.get_type(node.base) + except NameError: + self.reporter.warning(node.location, f"Unknown type '{node.base}'") + return UnknownType() + if node.param is not None: param: Type = node.param.accept(self) return self.types.apply_generic(base, [param]) From eb223c6cb731b405ba6cdf37ef76dc852ac1f8ec Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 13:44:05 +0200 Subject: [PATCH 42/64] fix(checker): forward parsing errors as diagnostics --- midas/checker/midas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index e874e63..e27ca97 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -43,6 +43,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type tokens: list[Token] = lexer.process() parser: MidasParser = MidasParser(tokens) stmts: list[m.Stmt] = parser.parse() + for error in
parser.errors: + self.reporter.error(error.token.get_location(), error.message) self.resolve(stmts) def get_type(self, name: str) -> Type: From 3ee1161680c8207fd6e95d135ef38a153ffcb58f Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 13:49:57 +0200 Subject: [PATCH 43/64] fix: remove unused op statement --- gen/midas.py | 6 ------ midas/ast/midas.py | 13 ------------- midas/ast/printer.py | 18 ------------------ midas/checker/midas.py | 2 -- midas/cli/highlighter.py | 8 +------- midas/lexer/token.py | 2 -- midas/parser/midas.py | 29 ++--------------------------- tests/serializer/midas.py | 9 --------- 8 files changed, 3 insertions(+), 84 deletions(-) diff --git a/gen/midas.py b/gen/midas.py index 5405b6c..42caf4f 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -48,12 +48,6 @@ class ExtendStmt: members: list[MemberStmt] -class OpStmt: - name: Token - operand: Type - result: Type - - class PredicateStmt: name: Token subject: Token diff --git a/midas/ast/midas.py b/midas/ast/midas.py index c35e856..e71aff9 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -49,9 +49,6 @@ class Stmt(ABC): @abstractmethod def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ... - @abstractmethod - def visit_op_stmt(self, stmt: OpStmt) -> T: ... - @abstractmethod def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ... 
@@ -86,16 +83,6 @@ class ExtendStmt(Stmt): return visitor.visit_extend_stmt(self) -@dataclass(frozen=True) -class OpStmt(Stmt): - name: Token - operand: Type - result: Type - - def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_op_stmt(self) - - @dataclass(frozen=True) class PredicateStmt(Stmt): name: Token diff --git a/midas/ast/printer.py b/midas/ast/printer.py index d094d42..2124778 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -146,19 +146,6 @@ class MidasAstPrinter( self._mark_last() member.accept(self) - def visit_op_stmt(self, stmt: m.OpStmt) -> None: - self._write_line("OpStmt") - with self._child_level(): - self._write_line(f'name: "{stmt.name.lexeme}"') - - self._write_line("operand") - with self._child_level(single=True): - stmt.operand.accept(self) - - self._write_line("result", last=True) - with self._child_level(single=True): - stmt.result.accept(self) - def visit_predicate_stmt(self, stmt: m.PredicateStmt): self._write_line("PredicateStmt") with self._child_level(): @@ -378,11 +365,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] res += self.indented("}") return res - def visit_op_stmt(self, stmt: m.OpStmt): - operand: str = stmt.operand.accept(self) - result: str = stmt.result.accept(self) - return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n") - def visit_predicate_stmt(self, stmt: m.PredicateStmt): name: str = stmt.name.lexeme subject: str = stmt.subject.lexeme diff --git a/midas/checker/midas.py b/midas/checker/midas.py index e27ca97..f54d6ab 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -105,8 +105,6 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type member.kind == m.MemberKind.METHOD, ) - def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... - def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... 
diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index af81890..00c8dcf 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -241,12 +241,6 @@ class MidasHighlighter( for member in stmt.members: member.accept(self) - def visit_op_stmt(self, stmt: m.OpStmt) -> None: - self.wrap(stmt, "op") - self.wrap(LocatableToken(stmt.name), "op-name") - stmt.operand.accept(self) - stmt.result.accept(self) - def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: self.wrap(stmt, "predicate") self.wrap(LocatableToken(stmt.name), "predicate-name") @@ -302,7 +296,7 @@ class MidasHighlighter( def visit_function_type(self, type: m.FunctionType) -> None: self.wrap(type, "function") - for arg in type.pos_args + type.kw_args: + for arg in type.pos_args + type.args + type.kw_args: arg.type.accept(self) type.returns.accept(self) diff --git a/midas/lexer/token.py b/midas/lexer/token.py index 95e0c18..f0c08a1 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -46,7 +46,6 @@ class TokenType(Enum): # Keywords TYPE = auto() - OP = auto() PREDICATE = auto() EXTEND = auto() WHERE = auto() @@ -63,7 +62,6 @@ class TokenType(Enum): KEYWORDS: dict[str, TokenType] = { "type": TokenType.TYPE, - "op": TokenType.OP, "predicate": TokenType.PREDICATE, "extend": TokenType.EXTEND, "where": TokenType.WHERE, diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 2fc46cf..33069f3 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -17,7 +17,6 @@ from midas.ast.midas import ( MemberKind, MemberStmt, NamedType, - OpStmt, PredicateStmt, Stmt, Type, @@ -37,9 +36,10 @@ class MidasParser(Parser): SYNC_BOUNDARY: set[TokenType] = { TokenType.TYPE, - TokenType.OP, TokenType.EXTEND, TokenType.PREDICATE, + TokenType.PROP, + TokenType.FUNC, } def parse(self) -> list[Stmt]: @@ -444,31 +444,6 @@ class MidasParser(Parser): members=members, ) - def op_declaration(self) -> OpStmt: - """Parse an operation definition - - An operation is written `op 
name(Type) -> Type` - - Returns: - OpStmt: the parsed operation statement - """ - keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword") - - name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") - self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") - operand: Type = self.type_expr() - self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type") - - self.consume(TokenType.ARROW, "Expected '->' before result type") - result: Type = self.type_expr() - - return OpStmt( - location=keyword.location_to(self.previous()), - name=name, - operand=operand, - result=result, - ) - def predicate_declaration(self) -> PredicateStmt: """Parse a predicate declaration diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index f1e55da..8bffdb3 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -15,7 +15,6 @@ from midas.ast.midas import ( LogicalExpr, MemberStmt, NamedType, - OpStmt, PredicateStmt, Stmt, Type, @@ -75,14 +74,6 @@ class MidasAstJsonSerializer( "members": self._serialize_list(stmt.members), } - def visit_op_stmt(self, stmt: OpStmt) -> dict: - return { - "_type": "OpStmt", - "name": stmt.name.lexeme, - "operand": stmt.operand.accept(self), - "result": stmt.result.accept(self), - } - def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict: return { "_type": "PredicateStmt", From df2e609c6075a9b8c834a9184cfcd6b10605b30e Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 14:00:23 +0200 Subject: [PATCH 44/64] fix(checker): handle members on base type --- midas/checker/registry.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 8d60b27..5529ff6 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -352,6 +352,12 @@ class TypesRegistry: def lookup_member(self, type: Type, member_name: str) -> Optional[Type]: match type: + case BaseType(name=name): + if name in 
self._members: + if member_name in self._members[name]: + return self._members[name][member_name] + return None + case AliasType(name=name, type=base): if name in self._members: if member_name in self._members[name]: From 9a227b6d4cf7dfdd4bf5bb597abdcab21d1ed2dc Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 14:00:50 +0200 Subject: [PATCH 45/64] fix(checker): remove int.to_bytes --- midas/checker/builtins.midas | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/midas/checker/builtins.midas b/midas/checker/builtins.midas index ccc1502..ba8f18b 100644 --- a/midas/checker/builtins.midas +++ b/midas/checker/builtins.midas @@ -57,7 +57,7 @@ extend int { def conjugate: fn() -> int def bit_length: fn() -> int def bit_count: fn() -> int - def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) -> bytes + // def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) -> bytes def __add__: fn(value: int, /) -> int def __sub__: fn(value: int, /) -> int From 221b5ca926b761d07cd15106cb496fe0feef9937 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 17:44:40 +0200 Subject: [PATCH 46/64] fix(checker): adapt comparison to lookup method --- midas/checker/python.py | 71 +++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index f8f049a..cf25593 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -320,39 +320,8 @@ class PythonTyper( expr.location, f"Unsupported operator {expr.operator}" ) return UnknownType() - left: Type = self.type_of(expr.left) - right: Type = self.type_of(expr.right) - operation: Optional[Type] = self.types.lookup_member(left, method) - if operation is None: - self.reporter.error( - expr.location, - f"Undefined operation {method} between {left} and {right}", - ) - return UnknownType() - - match operation: - case Function() as function: - if not
self._is_binary_function(function): - self.reporter.error( - expr.location, - f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}", - ) - return UnknownType() - - rhs: Function.Argument = function.pos_args[0] - if not self.is_subtype(right, rhs.type): - self.reporter.error( - expr.location, - f"Wrong type for right-hand side, expected {rhs.type}, got {right}", - ) - return UnknownType() - return function.returns - case _: - self.reporter.warning( - expr.location, f"Unsupported operation {operation}" - ) - return UnknownType() + return self._visit_binary_expr(expr.location, expr.left, expr.right, method) def visit_compare_expr(self, expr: p.CompareExpr) -> Type: method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) @@ -362,17 +331,43 @@ class PythonTyper( expr.location, f"Unsupported operator {expr.operator}" ) return UnknownType() - left: Type = self.type_of(expr.left) - right: Type = self.type_of(expr.right) - result: Optional[Type] = self.types.get_operation_result(left, method, right) - if result is None: + return self._visit_binary_expr(expr.location, expr.left, expr.right, method) + + def _visit_binary_expr( + self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str + ) -> Type: + left: Type = self.type_of(left_expr) + right: Type = self.type_of(right_expr) + + operation: Optional[Type] = self.types.lookup_member(left, method) + if operation is None: self.reporter.error( - expr.location, + location, f"Undefined operation {method} between {left} and {right}", ) return UnknownType() - return result + + match operation: + case Function() as function: + if not self._is_binary_function(function): + self.reporter.error( + location, + f"Wrong definition of binary operation. 
Expected function with 1 positional-only parameters, got {function}", + ) + return UnknownType() + + rhs: Function.Argument = function.pos_args[0] + if not self.is_subtype(right, rhs.type): + self.reporter.error( + location, + f"Wrong type for right-hand side, expected {rhs.type}, got {right}", + ) + return UnknownType() + return function.returns + case _: + self.reporter.warning(location, f"Unsupported operation {operation}") + return UnknownType() def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... From 0d0115534bc32d59824f6830d6f7af629bdbb8c6 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 17:46:15 +0200 Subject: [PATCH 47/64] tests: update tests --- .../checker/02_simple_operations.py.ref.json | 18 +- tests/cases/checker/03_functions.py.ref.json | 12 - tests/cases/checker/04_custom_types.midas | 10 +- tests/cases/checker/06_subtyping.py | 2 +- tests/cases/checker/06_subtyping.py.ref.json | 13 +- .../cases/midas-parser/01_simple_types.midas | 22 +- .../01_simple_types.midas.ref.json | 716 ++++++++++++------ 7 files changed, 517 insertions(+), 276 deletions(-) diff --git a/tests/cases/checker/02_simple_operations.py.ref.json b/tests/cases/checker/02_simple_operations.py.ref.json index e3881e0..a2c5569 100644 --- a/tests/cases/checker/02_simple_operations.py.ref.json +++ b/tests/cases/checker/02_simple_operations.py.ref.json @@ -13,6 +13,20 @@ ] }, "message": "Cannot assign str to variable 'c' of type int" + }, + { + "type": "Error", + "location": { + "start": [ + 9, + 4 + ], + "end": [ + 9, + 9 + ] + }, + "message": "Undefined operation __add__ between bool and bool" } ], "judgments": [ @@ -158,9 +172,7 @@ "name": "d" } }, - "type": { - "name": "int" - } + "type": {} }, { "location": { diff --git a/tests/cases/checker/03_functions.py.ref.json b/tests/cases/checker/03_functions.py.ref.json index 3442bca..756a143 100644 --- a/tests/cases/checker/03_functions.py.ref.json +++ b/tests/cases/checker/03_functions.py.ref.json @@ -264,7 +264,6 
@@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -328,7 +327,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -410,7 +408,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -509,7 +506,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -609,7 +605,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -725,7 +720,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -842,7 +836,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -924,7 +917,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -1006,7 +998,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -1123,7 +1114,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -1240,7 +1230,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -1357,7 +1346,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, diff --git a/tests/cases/checker/04_custom_types.midas b/tests/cases/checker/04_custom_types.midas index 6a1a6a2..ff4edb1 100644 --- a/tests/cases/checker/04_custom_types.midas +++ b/tests/cases/checker/04_custom_types.midas @@ -3,12 +3,12 @@ type Second = float type MeterPerSecond = float extend Meter { - op __add__(Meter) -> Meter - op __sub__(Meter) -> Meter - op __truediv__(Second) -> MeterPerSecond + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter + def __truediv__: fn(Second, /) -> MeterPerSecond } extend Second { - op __add__(Second) -> Second - op __sub__(Second) -> Second + def __add__: fn(Second, /) -> Second + def __sub__: fn(Second, /) -> Second } diff --git a/tests/cases/checker/06_subtyping.py b/tests/cases/checker/06_subtyping.py index c334ab8..7ab9dd7 100644 --- a/tests/cases/checker/06_subtyping.py +++ b/tests/cases/checker/06_subtyping.py @@ -9,4 +9,4 @@ def maximum(a: 
float, b: float): v3 = maximum(v1, v2) -v3 = v1 + v2 +v3 = v2 + v1 diff --git a/tests/cases/checker/06_subtyping.py.ref.json b/tests/cases/checker/06_subtyping.py.ref.json index 689402e..0659939 100644 --- a/tests/cases/checker/06_subtyping.py.ref.json +++ b/tests/cases/checker/06_subtyping.py.ref.json @@ -63,7 +63,6 @@ "name": "maximum" }, "type": { - "name": "maximum", "pos_args": [], "args": [ { @@ -149,10 +148,10 @@ }, "expr": { "_type": "VariableExpr", - "name": "v1" + "name": "v2" }, "type": { - "name": "int" + "name": "float" } }, { @@ -162,10 +161,10 @@ }, "expr": { "_type": "VariableExpr", - "name": "v2" + "name": "v1" }, "type": { - "name": "float" + "name": "int" } }, { @@ -177,12 +176,12 @@ "_type": "BinaryExpr", "left": { "_type": "VariableExpr", - "name": "v1" + "name": "v2" }, "operator": "+", "right": { "_type": "VariableExpr", - "name": "v2" + "name": "v1" } }, "type": { diff --git a/tests/cases/midas-parser/01_simple_types.midas b/tests/cases/midas-parser/01_simple_types.midas index 6446790..f0df3e2 100644 --- a/tests/cases/midas-parser/01_simple_types.midas +++ b/tests/cases/midas-parser/01_simple_types.midas @@ -10,8 +10,8 @@ type Difference[T] = T // Complex custom type, containing two values accessible through properties type GeoLocation = { - lat: Latitude - lon: Longitude + prop lat: Latitude + prop lon: Longitude } // Define operations on our custom type @@ -19,23 +19,23 @@ extend GeoLocation { // This type is compatible with the `-` operation with another GeoLocation // i.e. 
you can subtract a GeoLocation from another GeoLocation, resulting // in a Difference of GeoLocations - op __sub__(GeoLocation) -> Difference[GeoLocation] + def __sub__: fn(GeoLocation, /) -> Difference[GeoLocation] } // For complex generics, you need to specify how the genericity the properties // are handled type Difference[GeoLocation] = { - lat: Difference[Latitude] - lon: Difference[Longitude] + prop lat: Difference[Latitude] + prop lon: Difference[Longitude] } // Simple operation defined on our custom types extend Latitude { - op __sub__(Latitude) -> Difference[Latitude] + def __sub__: fn(Latitude, /) -> Difference[Latitude] } extend Longitude { - op __sub__(Longitude) -> Difference[Longitude] + def __sub__: fn(Longitude, /) -> Difference[Longitude] } // Predefined custom predicates that can be referenced in other definitions @@ -45,13 +45,13 @@ predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10) predicate Arctic(loc: GeoLocation) = (loc.lat >= 66) type Person = { - name: str + prop name: str // Property with an inline constraint - age: Optional[int where (0 <= _ < 150)] + prop age: Optional[int where (0 <= _ < 150)] // Property referencing a predicate - height: float where StrictlyPositive + prop height: float where StrictlyPositive - home: GeoLocation + prop home: GeoLocation } diff --git a/tests/cases/midas-parser/01_simple_types.midas.ref.json b/tests/cases/midas-parser/01_simple_types.midas.ref.json index 1d94718..be45687 100644 --- a/tests/cases/midas-parser/01_simple_types.midas.ref.json +++ b/tests/cases/midas-parser/01_simple_types.midas.ref.json @@ -511,17 +511,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "lat", + "type": "PROP", + "lexeme": "prop", "line": 13, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 13, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -530,15 +524,33 @@ }, { "type": "IDENTIFIER", - "lexeme": "Latitude", + "lexeme": "lat", "line": 13, "column": 10 }, + { + "type": 
"COLON", + "lexeme": ":", + "line": 13, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 13, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Latitude", + "line": 13, + "column": 15 + }, { "type": "NEWLINE", "lexeme": "\n", "line": 13, - "column": 18 + "column": 23 }, { "type": "WHITESPACE", @@ -547,17 +559,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "lon", + "type": "PROP", + "lexeme": "prop", "line": 14, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 14, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -566,15 +572,33 @@ }, { "type": "IDENTIFIER", - "lexeme": "Longitude", + "lexeme": "lon", "line": 14, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 14, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 14, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Longitude", + "line": 14, + "column": 15 + }, { "type": "NEWLINE", "lexeme": "\n", "line": 14, - "column": 19 + "column": 24 }, { "type": "RIGHT_BRACE", @@ -703,8 +727,8 @@ "column": 1 }, { - "type": "OP", - "lexeme": "op", + "type": "DEF", + "lexeme": "def", "line": 22, "column": 5 }, @@ -712,79 +736,115 @@ "type": "WHITESPACE", "lexeme": " ", "line": 22, - "column": 7 + "column": 8 }, { "type": "IDENTIFIER", "lexeme": "__sub__", "line": 22, - "column": 8 + "column": 9 + }, + { + "type": "COLON", + "lexeme": ":", + "line": 22, + "column": 16 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 22, + "column": 17 + }, + { + "type": "FUNC", + "lexeme": "fn", + "line": 22, + "column": 18 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 22, - "column": 15 + "column": 20 }, { "type": "IDENTIFIER", "lexeme": "GeoLocation", "line": 22, - "column": 16 + "column": 21 + }, + { + "type": "COMMA", + "lexeme": ",", + "line": 22, + "column": 32 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 22, + "column": 33 + }, + { + "type": "SLASH", + "lexeme": "/", + "line": 22, + 
"column": 34 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 22, - "column": 27 + "column": 35 }, { "type": "WHITESPACE", "lexeme": " ", "line": 22, - "column": 28 + "column": 36 }, { "type": "ARROW", "lexeme": "->", "line": 22, - "column": 29 + "column": 37 }, { "type": "WHITESPACE", "lexeme": " ", "line": 22, - "column": 31 + "column": 39 }, { "type": "IDENTIFIER", "lexeme": "Difference", "line": 22, - "column": 32 + "column": 40 }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 22, - "column": 42 + "column": 50 }, { "type": "IDENTIFIER", "lexeme": "GeoLocation", "line": 22, - "column": 43 + "column": 51 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 22, - "column": 54 + "column": 62 }, { "type": "NEWLINE", "lexeme": "\n", "line": 22, - "column": 55 + "column": 63 }, { "type": "RIGHT_BRACE", @@ -901,17 +961,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "lat", + "type": "PROP", + "lexeme": "prop", "line": 28, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 28, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -920,33 +974,51 @@ }, { "type": "IDENTIFIER", - "lexeme": "Difference", + "lexeme": "lat", "line": 28, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 28, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 28, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Difference", + "line": 28, + "column": 15 + }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 28, - "column": 20 + "column": 25 }, { "type": "IDENTIFIER", "lexeme": "Latitude", "line": 28, - "column": 21 + "column": 26 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 28, - "column": 29 + "column": 34 }, { "type": "NEWLINE", "lexeme": "\n", "line": 28, - "column": 30 + "column": 35 }, { "type": "WHITESPACE", @@ -955,17 +1027,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "lon", + "type": "PROP", + "lexeme": "prop", "line": 29, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - 
"line": 29, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -974,33 +1040,51 @@ }, { "type": "IDENTIFIER", - "lexeme": "Difference", + "lexeme": "lon", "line": 29, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 29, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 29, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Difference", + "line": 29, + "column": 15 + }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 29, - "column": 20 + "column": 25 }, { "type": "IDENTIFIER", "lexeme": "Longitude", "line": 29, - "column": 21 + "column": 26 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 29, - "column": 30 + "column": 35 }, { "type": "NEWLINE", "lexeme": "\n", "line": 29, - "column": 31 + "column": 36 }, { "type": "RIGHT_BRACE", @@ -1075,8 +1159,8 @@ "column": 1 }, { - "type": "OP", - "lexeme": "op", + "type": "DEF", + "lexeme": "def", "line": 34, "column": 5 }, @@ -1084,79 +1168,115 @@ "type": "WHITESPACE", "lexeme": " ", "line": 34, - "column": 7 + "column": 8 }, { "type": "IDENTIFIER", "lexeme": "__sub__", "line": 34, - "column": 8 + "column": 9 + }, + { + "type": "COLON", + "lexeme": ":", + "line": 34, + "column": 16 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 34, + "column": 17 + }, + { + "type": "FUNC", + "lexeme": "fn", + "line": 34, + "column": 18 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 34, - "column": 15 + "column": 20 }, { "type": "IDENTIFIER", "lexeme": "Latitude", "line": 34, - "column": 16 + "column": 21 + }, + { + "type": "COMMA", + "lexeme": ",", + "line": 34, + "column": 29 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 34, + "column": 30 + }, + { + "type": "SLASH", + "lexeme": "/", + "line": 34, + "column": 31 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 34, - "column": 24 + "column": 32 }, { "type": "WHITESPACE", "lexeme": " ", "line": 34, - "column": 25 + "column": 33 }, { "type": "ARROW", "lexeme": "->", "line": 34, - "column": 26 + 
"column": 34 }, { "type": "WHITESPACE", "lexeme": " ", "line": 34, - "column": 28 + "column": 36 }, { "type": "IDENTIFIER", "lexeme": "Difference", "line": 34, - "column": 29 + "column": 37 }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 34, - "column": 39 + "column": 47 }, { "type": "IDENTIFIER", "lexeme": "Latitude", "line": 34, - "column": 40 + "column": 48 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 34, - "column": 48 + "column": 56 }, { "type": "NEWLINE", "lexeme": "\n", "line": 34, - "column": 49 + "column": 57 }, { "type": "RIGHT_BRACE", @@ -1219,8 +1339,8 @@ "column": 1 }, { - "type": "OP", - "lexeme": "op", + "type": "DEF", + "lexeme": "def", "line": 38, "column": 5 }, @@ -1228,79 +1348,115 @@ "type": "WHITESPACE", "lexeme": " ", "line": 38, - "column": 7 + "column": 8 }, { "type": "IDENTIFIER", "lexeme": "__sub__", "line": 38, - "column": 8 + "column": 9 + }, + { + "type": "COLON", + "lexeme": ":", + "line": 38, + "column": 16 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 38, + "column": 17 + }, + { + "type": "FUNC", + "lexeme": "fn", + "line": 38, + "column": 18 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 38, - "column": 15 + "column": 20 }, { "type": "IDENTIFIER", "lexeme": "Longitude", "line": 38, - "column": 16 + "column": 21 + }, + { + "type": "COMMA", + "lexeme": ",", + "line": 38, + "column": 30 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 38, + "column": 31 + }, + { + "type": "SLASH", + "lexeme": "/", + "line": 38, + "column": 32 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 38, - "column": 25 + "column": 33 }, { "type": "WHITESPACE", "lexeme": " ", "line": 38, - "column": 26 + "column": 34 }, { "type": "ARROW", "lexeme": "->", "line": 38, - "column": 27 + "column": 35 }, { "type": "WHITESPACE", "lexeme": " ", "line": 38, - "column": 29 + "column": 37 }, { "type": "IDENTIFIER", "lexeme": "Difference", "line": 38, - "column": 30 + "column": 38 }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 
38, - "column": 40 + "column": 48 }, { "type": "IDENTIFIER", "lexeme": "Longitude", "line": 38, - "column": 41 + "column": 49 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 38, - "column": 50 + "column": 58 }, { "type": "NEWLINE", "lexeme": "\n", "line": 38, - "column": 51 + "column": 59 }, { "type": "RIGHT_BRACE", @@ -1903,34 +2059,46 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "name", + "type": "PROP", + "lexeme": "prop", "line": 48, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 48, - "column": 9 - }, { "type": "WHITESPACE", "lexeme": " ", "line": 48, + "column": 9 + }, + { + "type": "IDENTIFIER", + "lexeme": "name", + "line": 48, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 48, + "column": 14 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 48, + "column": 15 + }, { "type": "IDENTIFIER", "lexeme": "str", "line": 48, - "column": 11 + "column": 16 }, { "type": "NEWLINE", "lexeme": "\n", "line": 48, - "column": 14 + "column": 19 }, { "type": "NEWLINE", @@ -1963,17 +2131,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "age", + "type": "PROP", + "lexeme": "prop", "line": 51, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 51, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -1982,74 +2144,68 @@ }, { "type": "IDENTIFIER", - "lexeme": "Optional", + "lexeme": "age", "line": 51, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 51, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Optional", + "line": 51, + "column": 15 + }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 51, - "column": 18 + "column": 23 }, { "type": "IDENTIFIER", "lexeme": "int", "line": 51, - "column": 19 + "column": 24 }, { "type": "WHITESPACE", "lexeme": " ", "line": 51, - "column": 22 + "column": 27 }, { "type": "WHERE", "lexeme": "where", "line": 51, - "column": 23 + "column": 28 }, { 
"type": "WHITESPACE", "lexeme": " ", "line": 51, - "column": 28 + "column": 33 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 51, - "column": 29 + "column": 34 }, { "type": "NUMBER", "lexeme": "0", "line": 51, - "column": 30 - }, - { - "type": "WHITESPACE", - "lexeme": " ", - "line": 51, - "column": 31 - }, - { - "type": "LESS_EQUAL", - "lexeme": "<=", - "line": 51, - "column": 32 - }, - { - "type": "WHITESPACE", - "lexeme": " ", - "line": 51, - "column": 34 - }, - { - "type": "UNDERSCORE", - "lexeme": "_", - "line": 51, "column": 35 }, { @@ -2059,8 +2215,8 @@ "column": 36 }, { - "type": "LESS", - "lexeme": "<", + "type": "LESS_EQUAL", + "lexeme": "<=", "line": 51, "column": 37 }, @@ -2068,31 +2224,55 @@ "type": "WHITESPACE", "lexeme": " ", "line": 51, - "column": 38 + "column": 39 + }, + { + "type": "UNDERSCORE", + "lexeme": "_", + "line": 51, + "column": 40 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + "column": 41 + }, + { + "type": "LESS", + "lexeme": "<", + "line": 51, + "column": 42 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + "column": 43 }, { "type": "NUMBER", "lexeme": "150", "line": 51, - "column": 39 + "column": 44 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 51, - "column": 42 + "column": 47 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 51, - "column": 43 + "column": 48 }, { "type": "NEWLINE", "lexeme": "\n", "line": 51, - "column": 44 + "column": 49 }, { "type": "NEWLINE", @@ -2124,59 +2304,71 @@ "line": 54, "column": 1 }, + { + "type": "PROP", + "lexeme": "prop", + "line": 54, + "column": 5 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 54, + "column": 9 + }, { "type": "IDENTIFIER", "lexeme": "height", "line": 54, - "column": 5 + "column": 10 }, { "type": "COLON", "lexeme": ":", "line": 54, - "column": 11 + "column": 16 }, { "type": "WHITESPACE", "lexeme": " ", "line": 54, - "column": 12 + "column": 17 }, { "type": "IDENTIFIER", "lexeme": "float", "line": 54, - "column": 13 + 
"column": 18 }, { "type": "WHITESPACE", "lexeme": " ", "line": 54, - "column": 18 + "column": 23 }, { "type": "WHERE", "lexeme": "where", "line": 54, - "column": 19 + "column": 24 }, { "type": "WHITESPACE", "lexeme": " ", "line": 54, - "column": 24 + "column": 29 }, { "type": "IDENTIFIER", "lexeme": "StrictlyPositive", "line": 54, - "column": 25 + "column": 30 }, { "type": "NEWLINE", "lexeme": "\n", "line": 54, - "column": 41 + "column": 46 }, { "type": "NEWLINE", @@ -2191,34 +2383,46 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "home", + "type": "PROP", + "lexeme": "prop", "line": 56, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 56, - "column": 9 - }, { "type": "WHITESPACE", "lexeme": " ", "line": 56, + "column": 9 + }, + { + "type": "IDENTIFIER", + "lexeme": "home", + "line": 56, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 56, + "column": 14 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 56, + "column": 15 + }, { "type": "IDENTIFIER", "lexeme": "GeoLocation", "line": 56, - "column": 11 + "column": 16 }, { "type": "NEWLINE", "lexeme": "\n", "line": 56, - "column": 22 + "column": 27 }, { "type": "RIGHT_BRACE", @@ -2345,9 +2549,10 @@ "params": [], "type": { "_type": "ComplexType", - "properties": [ + "members": [ { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "lat", "type": { "_type": "NamedType", @@ -2355,7 +2560,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "lon", "type": { "_type": "NamedType", @@ -2367,30 +2573,40 @@ }, { "_type": "ExtendStmt", - "type": { - "_type": "NamedType", - "name": "GeoLocation" - }, - "operations": [ + "name": "GeoLocation", + "params": [], + "members": [ { - "_type": "OpStmt", + "_type": "MemberStmt", + "kind": "METHOD", "name": "__sub__", - "operand": { - "_type": "NamedType", - "name": "GeoLocation" - }, - "result": { - "_type": "GenericType", - "type": { - "_type": "NamedType", - 
"name": "Difference" - }, - "args": [ + "type": { + "_type": "FunctionType", + "pos_args": [ { - "_type": "NamedType", - "name": "GeoLocation" + "name": null, + "type": { + "_type": "NamedType", + "name": "GeoLocation" + }, + "required": true } - ] + ], + "args": [], + "kw_args": [], + "returns": { + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" + }, + "args": [ + { + "_type": "NamedType", + "name": "GeoLocation" + } + ] + } } } ] @@ -2406,9 +2622,10 @@ ], "type": { "_type": "ComplexType", - "properties": [ + "members": [ { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "lat", "type": { "_type": "GenericType", @@ -2425,7 +2642,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "lon", "type": { "_type": "GenericType", @@ -2446,60 +2664,80 @@ }, { "_type": "ExtendStmt", - "type": { - "_type": "NamedType", - "name": "Latitude" - }, - "operations": [ + "name": "Latitude", + "params": [], + "members": [ { - "_type": "OpStmt", + "_type": "MemberStmt", + "kind": "METHOD", "name": "__sub__", - "operand": { - "_type": "NamedType", - "name": "Latitude" - }, - "result": { - "_type": "GenericType", - "type": { - "_type": "NamedType", - "name": "Difference" - }, - "args": [ + "type": { + "_type": "FunctionType", + "pos_args": [ { - "_type": "NamedType", - "name": "Latitude" + "name": null, + "type": { + "_type": "NamedType", + "name": "Latitude" + }, + "required": true } - ] + ], + "args": [], + "kw_args": [], + "returns": { + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" + }, + "args": [ + { + "_type": "NamedType", + "name": "Latitude" + } + ] + } } } ] }, { "_type": "ExtendStmt", - "type": { - "_type": "NamedType", - "name": "Longitude" - }, - "operations": [ + "name": "Longitude", + "params": [], + "members": [ { - "_type": "OpStmt", + "_type": "MemberStmt", + "kind": "METHOD", "name": "__sub__", - "operand": { - "_type": 
"NamedType", - "name": "Longitude" - }, - "result": { - "_type": "GenericType", - "type": { - "_type": "NamedType", - "name": "Difference" - }, - "args": [ + "type": { + "_type": "FunctionType", + "pos_args": [ { - "_type": "NamedType", - "name": "Longitude" + "name": null, + "type": { + "_type": "NamedType", + "name": "Longitude" + }, + "required": true } - ] + ], + "args": [], + "kw_args": [], + "returns": { + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" + }, + "args": [ + { + "_type": "NamedType", + "name": "Longitude" + } + ] + } } } ] @@ -2620,9 +2858,10 @@ "params": [], "type": { "_type": "ComplexType", - "properties": [ + "members": [ { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "name", "type": { "_type": "NamedType", @@ -2630,7 +2869,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "age", "type": { "_type": "GenericType", @@ -2672,7 +2912,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "height", "type": { "_type": "ConstraintType", @@ -2687,7 +2928,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "home", "type": { "_type": "NamedType", From 890e2f035a86b3100f7e565d0f9b2102a4c093cb Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 18:01:02 +0200 Subject: [PATCH 48/64] refactor(checker): replace all accept calls make visitor accept calls more explicit with type_of(), resolve_type_expr() and process_stmt() --- .../04_complex_types.midas | 4 +-- midas/checker/python.py | 36 +++++++++++-------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/examples/01_simple_type_checking/04_complex_types.midas b/examples/01_simple_type_checking/04_complex_types.midas index b561cef..adc76b3 100644 --- a/examples/01_simple_type_checking/04_complex_types.midas +++ b/examples/01_simple_type_checking/04_complex_types.midas @@ -1,8 +1,8 @@ 
type Meter = float extend Meter { - def __add__: fn(Meter) -> Meter - def __sub__: fn(Meter) -> Meter + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter } type Coordinate = object diff --git a/midas/checker/python.py b/midas/checker/python.py index cf25593..88ecde0 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -78,6 +78,12 @@ class PythonTyper( self.judgements.append((expr, type)) return type + def resolve_type_expr(self, expr: p.MidasType) -> Type: + return expr.accept(self) + + def process_stmt(self, stmt: p.Stmt) -> None: + stmt.accept(self) + def process_block(self, block: list[p.Stmt], env: Environment) -> bool: """Evaluate a sequence of statements @@ -93,7 +99,7 @@ class PythonTyper( returned: bool = False for i, stmt in enumerate(block): try: - stmt.accept(self) + self.process_stmt(stmt) except ReturnException: returned = True if i < len(block) - 1: @@ -111,7 +117,7 @@ class PythonTyper( statements (list[p.Stmt]): the statements to evaluate and check """ for stmt in statements: - stmt.accept(self) + self.process_stmt(stmt) self.logger.debug(f"Final environment: {self.env.flat_dict()}") @@ -144,9 +150,9 @@ class PythonTyper( def eval_arg_type(arg: p.Function.Argument) -> Type: if arg.type is not None: - return arg.type.accept(self) + return self.resolve_type_expr(arg.type) if arg.default is not None: - return arg.default.accept(self) + return self.type_of(arg.default) return UnknownType() pos: int = 0 @@ -186,7 +192,7 @@ class PythonTyper( returns_hint: Optional[Type] = None if stmt.returns is not None: - returns_hint = stmt.returns.accept(self) + returns_hint = self.resolve_type_expr(stmt.returns) # Early define to handle simple fully-typed recursion inside_function: Function = Function( pos_args=pos_args, @@ -232,7 +238,7 @@ class PythonTyper( def visit_type_assign(self, stmt: p.TypeAssign) -> None: # TODO check not yet defined locally - type: Type = stmt.type.accept(self) + type: Type = 
self.resolve_type_expr(stmt.type) self.env.define(stmt.name, type) def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: @@ -287,7 +293,7 @@ class PythonTyper( ) def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: - type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() + type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType() self.env.return_types.append(type) raise ReturnException() @@ -297,7 +303,7 @@ class PythonTyper( # if (m := 1 + 1) < 2: # ... # print(m) # <- m is still defined - test_type: Type = stmt.test.accept(self) + test_type: Type = self.type_of(stmt.test) # TODO Allow subtypes or any type if test_type != self.types.get_type("bool"): @@ -419,8 +425,8 @@ class PythonTyper( return type or UnknownType() def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: - left: Type = expr.left.accept(self) - right: Type = expr.right.accept(self) + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) if self.is_subtype(left, right): return right @@ -434,10 +440,10 @@ class PythonTyper( return UnknownType() def visit_cast_expr(self, expr: p.CastExpr) -> Type: - return expr.type.accept(self) + return self.resolve_type_expr(expr.type) def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: - test_type: Type = expr.test.accept(self) + test_type: Type = self.type_of(expr.test) # TODO Allow subtypes or any type if test_type != self.types.get_type("bool"): @@ -445,8 +451,8 @@ class PythonTyper( expr.test.location, f"If test must be a boolean, got {test_type}" ) - true_type: Type = expr.if_true.accept(self) - false_type: Type = expr.if_false.accept(self) + true_type: Type = self.type_of(expr.if_true) + false_type: Type = self.type_of(expr.if_false) if self.is_subtype(true_type, false_type): return false_type if self.is_subtype(false_type, true_type): @@ -484,7 +490,7 @@ class PythonTyper( return UnknownType() if node.param is not None: - param: Type = node.param.accept(self) + param: 
Type = self.resolve_type_expr(node.param) return self.types.apply_generic(base, [param]) return base From 064702fe13a00559a91dfeea134543aa179a577b Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 18:01:32 +0200 Subject: [PATCH 49/64] tests: update with newly reported judgements --- tests/cases/checker/03_functions.py.ref.json | 13 ++ .../cases/checker/05_control_flow.py.ref.json | 194 ++++++++++++++++++ tests/cases/checker/06_subtyping.py.ref.json | 47 +++++ 3 files changed, 254 insertions(+) diff --git a/tests/cases/checker/03_functions.py.ref.json b/tests/cases/checker/03_functions.py.ref.json index 756a143..814bc35 100644 --- a/tests/cases/checker/03_functions.py.ref.json +++ b/tests/cases/checker/03_functions.py.ref.json @@ -254,6 +254,19 @@ } ], "judgments": [ + { + "location": { + "from": "L2:11", + "to": "L2:15" + }, + "expr": { + "_type": "LiteralExpr", + "value": true + }, + "type": { + "name": "bool" + } + }, { "location": { "from": "L5:5", diff --git a/tests/cases/checker/05_control_flow.py.ref.json b/tests/cases/checker/05_control_flow.py.ref.json index 8f031f2..be86030 100644 --- a/tests/cases/checker/05_control_flow.py.ref.json +++ b/tests/cases/checker/05_control_flow.py.ref.json @@ -70,6 +70,27 @@ "name": "int" } }, + { + "location": { + "from": "L2:11", + "to": "L2:16" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "+", + "right": { + "_type": "VariableExpr", + "name": "b" + } + }, + "type": { + "name": "int" + } + }, { "location": { "from": "L5:7", @@ -96,6 +117,27 @@ "name": "int" } }, + { + "location": { + "from": "L5:7", + "to": "L5:12" + }, + "expr": { + "_type": "CompareExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "<", + "right": { + "_type": "VariableExpr", + "name": "b" + } + }, + "type": { + "name": "bool" + } + }, { "location": { "from": "L6:15", @@ -122,6 +164,27 @@ "name": "int" } }, + { + "location": { + "from": 
"L6:15", + "to": "L6:20" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "b" + }, + "operator": "-", + "right": { + "_type": "VariableExpr", + "name": "a" + } + }, + "type": { + "name": "int" + } + }, { "location": { "from": "L8:15", @@ -148,6 +211,27 @@ "name": "int" } }, + { + "location": { + "from": "L8:15", + "to": "L8:20" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "-", + "right": { + "_type": "VariableExpr", + "name": "b" + } + }, + "type": { + "name": "int" + } + }, { "location": { "from": "L15:7", @@ -174,6 +258,27 @@ "name": "int" } }, + { + "location": { + "from": "L15:7", + "to": "L15:13" + }, + "expr": { + "_type": "CompareExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": ">", + "right": { + "_type": "LiteralExpr", + "value": 10 + } + }, + "type": { + "name": "bool" + } + }, { "location": { "from": "L16:15", @@ -200,6 +305,40 @@ "name": "int" } }, + { + "location": { + "from": "L16:15", + "to": "L16:21" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "-", + "right": { + "_type": "LiteralExpr", + "value": 10 + } + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L18:15", + "to": "L18:16" + }, + "expr": { + "_type": "VariableExpr", + "name": "a" + }, + "type": { + "name": "int" + } + }, { "location": { "from": "L22:7", @@ -226,6 +365,27 @@ "name": "int" } }, + { + "location": { + "from": "L22:7", + "to": "L22:12" + }, + "expr": { + "_type": "CompareExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "<", + "right": { + "_type": "VariableExpr", + "name": "b" + } + }, + "type": { + "name": "bool" + } + }, { "location": { "from": "L23:15", @@ -251,6 +411,40 @@ "type": { "name": "int" } + }, + { + "location": { + "from": "L23:15", + "to": "L23:20" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + 
"_type": "VariableExpr", + "name": "b" + }, + "operator": "-", + "right": { + "_type": "VariableExpr", + "name": "a" + } + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L25:15", + "to": "L25:21" + }, + "expr": { + "_type": "LiteralExpr", + "value": "oops" + }, + "type": { + "name": "str" + } } ] } \ No newline at end of file diff --git a/tests/cases/checker/06_subtyping.py.ref.json b/tests/cases/checker/06_subtyping.py.ref.json index 0659939..3435f45 100644 --- a/tests/cases/checker/06_subtyping.py.ref.json +++ b/tests/cases/checker/06_subtyping.py.ref.json @@ -53,6 +53,53 @@ "name": "float" } }, + { + "location": { + "from": "L6:7", + "to": "L6:12" + }, + "expr": { + "_type": "CompareExpr", + "left": { + "_type": "VariableExpr", + "name": "b" + }, + "operator": ">", + "right": { + "_type": "VariableExpr", + "name": "a" + } + }, + "type": { + "name": "bool" + } + }, + { + "location": { + "from": "L7:15", + "to": "L7:16" + }, + "expr": { + "_type": "VariableExpr", + "name": "b" + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L8:11", + "to": "L8:12" + }, + "expr": { + "_type": "VariableExpr", + "name": "a" + }, + "type": { + "name": "float" + } + }, { "location": { "from": "L11:5", From 1c71badf24dfe745fcdb93110e3bae40fb33c0c0 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 18:11:14 +0200 Subject: [PATCH 50/64] fix(checker): report unsupported features --- midas/checker/midas.py | 24 ++++++++++++++++-------- midas/checker/python.py | 16 ++++++++++++---- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/midas/checker/midas.py b/midas/checker/midas.py index f54d6ab..3764c03 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -105,24 +105,32 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type member.kind == m.MemberKind.METHOD, ) - def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... 
+ def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: + self.reporter.warning(stmt.location, "PredicateStmt not yet supported") - def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... + def visit_logical_expr(self, expr: m.LogicalExpr) -> None: + self.reporter.warning(expr.location, "LogicalExpr not yet supported") - def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ... + def visit_binary_expr(self, expr: m.BinaryExpr) -> None: + self.reporter.warning(expr.location, "BinaryExpr not yet supported") - def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ... + def visit_unary_expr(self, expr: m.UnaryExpr) -> None: + self.reporter.warning(expr.location, "UnaryExpr not yet supported") - def visit_get_expr(self, expr: m.GetExpr) -> None: ... + def visit_get_expr(self, expr: m.GetExpr) -> None: + self.reporter.warning(expr.location, "GetExpr not yet supported") - def visit_variable_expr(self, expr: m.VariableExpr) -> None: ... + def visit_variable_expr(self, expr: m.VariableExpr) -> None: + self.reporter.warning(expr.location, "VariableExpr not yet supported") def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: return expr.expr.accept(self) - def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... + def visit_literal_expr(self, expr: m.LiteralExpr) -> None: + self.reporter.warning(expr.location, "LiteralExpr not yet supported") - def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... 
+ def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: + self.reporter.warning(expr.location, "WildcardExpr not yet supported") def visit_named_type(self, type: m.NamedType) -> Type: name: str = type.name.lexeme diff --git a/midas/checker/python.py b/midas/checker/python.py index 88ecde0..5149bb7 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -375,7 +375,9 @@ class PythonTyper( self.reporter.warning(location, f"Unsupported operation {operation}") return UnknownType() - def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... + def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: + self.reporter.warning(expr.location, "UnaryExpr not yet supported") + return UnknownType() def visit_call_expr(self, expr: p.CallExpr) -> Type: callee: Type = self.type_of(expr.callee) @@ -494,11 +496,17 @@ class PythonTyper( return self.types.apply_generic(base, [param]) return base - def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... + def visit_constraint_type(self, node: p.ConstraintType) -> Type: + self.reporter.warning(node.location, "ConstraintType not yet supported") + return UnknownType() - def visit_frame_column(self, node: p.FrameColumn) -> Type: ... + def visit_frame_column(self, node: p.FrameColumn) -> Type: + self.reporter.warning(node.location, "FrameColumn not yet supported") + return UnknownType() - def visit_frame_type(self, node: p.FrameType) -> Type: ... 
+ def visit_frame_type(self, node: p.FrameType) -> Type: + self.reporter.warning(node.location, "FrameType not yet supported") + return UnknownType() def map_call_arguments( self, function: Function, call: p.CallExpr From 6577241af9b20cb92fb7c135691a6c4ed1ddfe21 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 18:19:50 +0200 Subject: [PATCH 51/64] feat(checker): handle unary operations --- .../01_simple_operations.py | 2 + .../02_simple_types.midas | 10 ++--- midas/checker/operators.py | 7 +++ midas/checker/python.py | 44 +++++++++++++++++-- 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/examples/01_simple_type_checking/01_simple_operations.py b/examples/01_simple_type_checking/01_simple_operations.py index a3ac707..4e767f2 100644 --- a/examples/01_simple_type_checking/01_simple_operations.py +++ b/examples/01_simple_type_checking/01_simple_operations.py @@ -9,3 +9,5 @@ d = True e = d + d f: float = a + +f = -f diff --git a/examples/01_simple_type_checking/02_simple_types.midas b/examples/01_simple_type_checking/02_simple_types.midas index 6a1a6a2..ff4edb1 100644 --- a/examples/01_simple_type_checking/02_simple_types.midas +++ b/examples/01_simple_type_checking/02_simple_types.midas @@ -3,12 +3,12 @@ type Second = float type MeterPerSecond = float extend Meter { - op __add__(Meter) -> Meter - op __sub__(Meter) -> Meter - op __truediv__(Second) -> MeterPerSecond + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter + def __truediv__: fn(Second, /) -> MeterPerSecond } extend Second { - op __add__(Second) -> Second - op __sub__(Second) -> Second + def __add__: fn(Second, /) -> Second + def __sub__: fn(Second, /) -> Second } diff --git a/midas/checker/operators.py b/midas/checker/operators.py index e65ab07..58af88c 100644 --- a/midas/checker/operators.py +++ b/midas/checker/operators.py @@ -29,3 +29,10 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { # ast.In: "__in__", # ast.NotIn: "__notin__", } + 
+UNARY_METHODS: dict[Type[ast.unaryop], str] = { + ast.Invert: "__invert__", + # ast.Not: "", + ast.UAdd: "__pos__", + ast.USub: "__neg__", +} diff --git a/midas/checker/python.py b/midas/checker/python.py index 5149bb7..9c788c8 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -6,7 +6,7 @@ from typing import Optional import midas.ast.python as p from midas.ast.location import Location from midas.checker.environment import Environment -from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS +from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver @@ -376,8 +376,37 @@ class PythonTyper( return UnknownType() def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: - self.reporter.warning(expr.location, "UnaryExpr not yet supported") - return UnknownType() + method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + + operand: Type = self.type_of(expr.right) + operation: Optional[Type] = self.types.lookup_member(operand, method) + if operation is None: + self.reporter.error( + expr.location, + f"Undefined operation {method} for {operand}", + ) + return UnknownType() + + match operation: + case Function() as function: + if not self._is_unary_function(function): + self.reporter.error( + expr.location, + f"Wrong definition of unary operation. 
Expected function with 0 parameters, got {function}", + ) + return UnknownType() + return function.returns + case _: + self.reporter.warning( + expr.location, f"Unsupported operation {operation}" + ) + return UnknownType() def visit_call_expr(self, expr: p.CallExpr) -> Type: callee: Type = self.type_of(expr.callee) @@ -633,3 +662,12 @@ class PythonTyper( if len(function.kw_args) != 0: return False return True + + def _is_unary_function(self, function: Function) -> bool: + if len(function.pos_args) != 0: + return False + if len(function.args) != 0: + return False + if len(function.kw_args) != 0: + return False + return True From c92b6b5c184c411c9b58ba5c5bd780b9514c627b Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 18:44:19 +0200 Subject: [PATCH 52/64] feat(parser): add subscript expressions --- gen/python.py | 5 +++++ midas/ast/printer.py | 12 +++++++++++- midas/ast/python.py | 12 ++++++++++++ midas/cli/highlighter.py | 4 ++++ midas/parser/python.py | 8 ++++++++ tests/serializer/python.py | 8 ++++++++ 6 files changed, 48 insertions(+), 1 deletion(-) diff --git a/gen/python.py b/gen/python.py index 79ba8b0..b7c38ec 100644 --- a/gen/python.py +++ b/gen/python.py @@ -143,4 +143,9 @@ class ListExpr: items: list[Expr] +class SubscriptExpr: + object: Expr + index: Expr + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 2124778..3495883 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -664,7 +664,7 @@ class PythonAstPrinter( def visit_literal_expr(self, expr: p.LiteralExpr) -> None: self._write_line("LiteralExpr") with self._child_level(single=True): - self._write_line(f"value: {expr.value}") + self._write_line(f"value: {expr.value!r}") def visit_variable_expr(self, expr: p.VariableExpr) -> None: self._write_line("VariableExpr") @@ -719,3 +719,13 @@ class PythonAstPrinter( if i == len(expr.items) - 1: self._mark_last() item.accept(self) + + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: + 
self._write_line("SubscriptExpr") + with self._child_level(): + self._write_line("object") + with self._child_level(single=True): + expr.object.accept(self) + self._write_line("index", last=True) + with self._child_level(single=True): + expr.index.accept(self) diff --git a/midas/ast/python.py b/midas/ast/python.py index 350dbb1..a199b89 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -224,6 +224,9 @@ class Expr(ABC): @abstractmethod def visit_list_expr(self, expr: ListExpr) -> T: ... + @abstractmethod + def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -324,3 +327,12 @@ class ListExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_list_expr(self) + + +@dataclass(frozen=True) +class SubscriptExpr(Expr): + object: Expr + index: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_subscript_expr(self) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index 00c8dcf..3c3f07e 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -218,6 +218,10 @@ class PythonHighlighter( for item in expr.items: item.accept(self) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: + expr.object.accept(self) + expr.index.accept(self) + class MidasHighlighter( Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] diff --git a/midas/parser/python.py b/midas/parser/python.py index bbe23c8..8c1f5a7 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -23,6 +23,7 @@ from midas.ast.python import ( MidasType, ReturnStmt, Stmt, + SubscriptExpr, TernaryExpr, TypeAssign, UnaryExpr, @@ -423,6 +424,13 @@ class PythonParser: items=[self.parse_expr(item) for item in items], ) + case ast.Subscript(value=value, slice=index): + return SubscriptExpr( + location=location, + object=self.parse_expr(value), + index=self.parse_expr(index), + ) + case _: raise UnsupportedSyntaxError(node) diff 
--git a/tests/serializer/python.py b/tests/serializer/python.py index 833d4e4..73b8b84 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -22,6 +22,7 @@ from midas.ast.python import ( MidasType, ReturnStmt, Stmt, + SubscriptExpr, TernaryExpr, TypeAssign, UnaryExpr, @@ -252,3 +253,10 @@ class PythonAstJsonSerializer( "_type": "ListExpr", "items": [item.accept(self) for item in expr.items], } + + def visit_subscript_expr(self, expr: SubscriptExpr) -> dict: + return { + "_type": "SubscriptExpr", + "object": expr.object.accept(self), + "index": expr.index.accept(self), + } From 178e24cd02fbbfa21c75bb7746800c0ba324d952 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 18:44:49 +0200 Subject: [PATCH 53/64] feat(checker): type check subscripts --- .../04_complex_types.py | 5 ++ midas/checker/python.py | 62 ++++++++++++++----- midas/checker/resolver.py | 4 ++ 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/examples/01_simple_type_checking/04_complex_types.py b/examples/01_simple_type_checking/04_complex_types.py index ebe958f..4f21cc6 100644 --- a/examples/01_simple_type_checking/04_complex_types.py +++ b/examples/01_simple_type_checking/04_complex_types.py @@ -28,3 +28,8 @@ bar: list[list[Meter]] bar.append([p2.x]) foo2 = foo + foo + +a = foo[0] +b = bar[0][1] +c = bar[0][1][2] # invalid, not method __getitem__ on Meter +c = bar[""] # invalid, wrong index type diff --git a/midas/checker/python.py b/midas/checker/python.py index 9c788c8..45b9a33 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -356,7 +356,7 @@ class PythonTyper( match operation: case Function() as function: - if not self._is_binary_function(function): + if not self._check_arity(function, 1, 0, 0): self.reporter.error( location, f"Wrong definition of binary operation. 
Expected function with 1 positional-only parameters, got {function}", @@ -395,7 +395,7 @@ class PythonTyper( match operation: case Function() as function: - if not self._is_unary_function(function): + if not self._check_arity(function, 0, 0, 0): self.reporter.error( expr.location, f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}", @@ -512,6 +512,41 @@ class PythonTyper( ) return self.types.apply_generic(list_type, [UnknownType()]) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type: + object: Type = self.type_of(expr.object) + operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") + if operation is None: + self.reporter.error( + expr.location, + f"Undefined method __getitem__ on {object}", + ) + return UnknownType() + + index: Type = self.type_of(expr.index) + + match operation: + case Function() as function: + if not self._check_arity(function, 1, 0, 0): + self.reporter.error( + expr.location, + f"Wrong definition of __getitem__. 
Expected function with 1 positional-only parameters, got {function}", + ) + return UnknownType() + + index_arg: Function.Argument = function.pos_args[0] + if not self.is_subtype(index, index_arg.type): + self.reporter.error( + expr.location, + f"Wrong index type, expected {index_arg.type}, got {index}", + ) + return UnknownType() + return function.returns + case _: + self.reporter.warning( + expr.location, f"Unsupported operation {operation}" + ) + return UnknownType() + def visit_base_type(self, node: p.BaseType) -> Type: base: Type try: @@ -654,20 +689,17 @@ class PythonTyper( return mapped - def _is_binary_function(self, function: Function) -> bool: - if len(function.pos_args) != 1: + def _check_arity( + self, + function: Function, + n_pos: Optional[int] = None, + n_mixed: Optional[int] = None, + n_keyword: Optional[int] = None, + ) -> bool: + if n_pos is not None and len(function.pos_args) != n_pos: return False - if len(function.args) != 0: + if n_mixed is not None and len(function.args) != n_mixed: return False - if len(function.kw_args) != 0: - return False - return True - - def _is_unary_function(self, function: Function) -> bool: - if len(function.pos_args) != 0: - return False - if len(function.args) != 0: - return False - if len(function.kw_args) != 0: + if n_keyword is not None and len(function.kw_args) != n_keyword: return False return True diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 02fcbbc..636ccfe 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -196,3 +196,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def visit_list_expr(self, expr: p.ListExpr) -> None: for item in expr.items: self.resolve(item) + + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: + self.resolve(expr.object) + self.resolve(expr.index) From 2df0380815eae0264712157b021caff2c0b0b40a Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 18:57:02 +0200 Subject: [PATCH 54/64] fix(types): 
remove unused operation structures --- midas/checker/builtins.py | 19 ------------- midas/checker/registry.py | 58 --------------------------------------- midas/checker/types.py | 18 ------------ 3 files changed, 95 deletions(-) diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index 6fc46d2..961545a 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -6,7 +6,6 @@ from midas.checker.types import ( BaseType, GenericType, TopType, - Type, TypeVar, UnitType, ) @@ -21,24 +20,6 @@ BUILTIN_SUBTYPES: dict[str, set[str]] = { } -def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type): - reg.define_operation( - left=t1, - operator=operator, - right=t2, - result=t3, - ) - - -def basic_op(reg: TypesRegistry, type: Type, op: str): - reg.define_operation( - left=type, - operator=op, - right=type, - result=type, - ) - - def define_builtins(reg: TypesRegistry): """Define builtin types and operations""" any = reg.define_type("Any", TopType()) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 5529ff6..6591548 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -10,7 +10,6 @@ from midas.checker.types import ( ExtensionType, Function, GenericType, - Operation, OverloadedFunction, TopType, Type, @@ -25,7 +24,6 @@ class TypesRegistry: self.logger: logging.Logger = logging.getLogger("TypesRegistry") self._types: dict[str, Type] = {} self._members: dict[str, dict[str, Type]] = {} - self._operations: dict[Operation.CallSignature, Type] = {} def get_type(self, name: str) -> Type: """Get a type from its name @@ -43,39 +41,6 @@ class TypesRegistry: return self._types[name] raise NameError(f"Undefined type {name}") - def get_operation_result( - self, left: Type, operator: str, right: Type - ) -> Optional[Type]: - """Get the resulting type of an operation - - Args: - left (Type): the type of the left operand - operator (str): the operation name - right (Type): the type of the right operand - - 
Returns: - Optional[Type]: the result type, or None if no matching operation was found - """ - signature: Operation.CallSignature = Operation.CallSignature( - left=left, - method=operator, - right=right, - ) - result: Optional[Type] = self._operations.get(signature) - return result - - def get_operations_by_name(self, name: str) -> list[Operation]: - operations: list[Operation] = [] - for signature, result in self._operations.items(): - if signature.method == name: - operations.append( - Operation( - signature=signature, - result=result, - ) - ) - return operations - def define_type(self, name: str, type: Type) -> Type: """Define a type in the registry @@ -116,29 +81,6 @@ class TypesRegistry: else: members[member_name] = member_type - def define_operation(self, left: Type, operator: str, right: Type, result: Type): - """Define an operation in the registry - - Args: - left (Type): the type of the left operand - operator (str): the operation name - right (Type): the type of the right operand - result (Type): the result type - - Raises: - ValueError: if an operation is already defined with these operands and name - """ - signature: Operation.CallSignature = Operation.CallSignature( - left=left, - method=operator, - right=right, - ) - if signature in self._operations: - raise ValueError( - f"Operation {operator} already defined between {left} and {right}" - ) - self._operations[signature] = result - def is_subtype(self, type1: Type, type2: Type) -> bool: """Check whether `type1` is a subtype of `type2` diff --git a/midas/checker/types.py b/midas/checker/types.py index 444e868..0bf8ea2 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -101,24 +101,6 @@ class ExtensionType: return f"{self.base} & {self.extension}" -@dataclass(frozen=True, kw_only=True) -class Operation: - signature: CallSignature - result: Type - - def __str__(self) -> str: - return f"{self.signature} -> {self.result}" - - @dataclass(frozen=True, kw_only=True) - class CallSignature: - 
left: Type - method: str - right: Type - - def __str__(self) -> str: - return f"{self.method}({self.left}, {self.right})" - - @dataclass(frozen=True, kw_only=True) class TypeVar: name: str From 759868172975027dd265678740a27ea1da657c18 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 15:48:31 +0200 Subject: [PATCH 55/64] feat(checker): handle overloaded function calls --- midas/checker/python.py | 220 +++++++++++++++++++++++++++------------- 1 file changed, 148 insertions(+), 72 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 45b9a33..dbe8c8e 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -12,12 +12,16 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from midas.checker.types import ( Function, + OverloadedFunction, Type, UnitType, UnknownType, + unfold_type, ) from midas.parser.python import PythonParser +TypedExpr = tuple[p.Expr, Type] + class ReturnException(Exception): pass @@ -354,26 +358,7 @@ class PythonTyper( ) return UnknownType() - match operation: - case Function() as function: - if not self._check_arity(function, 1, 0, 0): - self.reporter.error( - location, - f"Wrong definition of binary operation. 
Expected function with 1 positional-only parameters, got {function}", - ) - return UnknownType() - - rhs: Function.Argument = function.pos_args[0] - if not self.is_subtype(right, rhs.type): - self.reporter.error( - location, - f"Wrong type for right-hand side, expected {rhs.type}, got {right}", - ) - return UnknownType() - return function.returns - case _: - self.reporter.warning(location, f"Unsupported operation {operation}") - return UnknownType() + return self._get_call_result(location, operation, [(right_expr, right)], {}) def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) @@ -393,35 +378,24 @@ class PythonTyper( ) return UnknownType() - match operation: - case Function() as function: - if not self._check_arity(function, 0, 0, 0): - self.reporter.error( - expr.location, - f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}", - ) - return UnknownType() - return function.returns - case _: - self.reporter.warning( - expr.location, f"Unsupported operation {operation}" - ) - return UnknownType() + return self._get_call_result( + expr.location, operation, [(expr.right, operand)], {} + ) def visit_call_expr(self, expr: p.CallExpr) -> Type: callee: Type = self.type_of(expr.callee) - if not isinstance(callee, Function): - self.reporter.error(expr.callee.location, "Callee is not a function") - return UnknownType() - function: Function = callee - mapped: list[MappedArgument] = self.map_call_arguments(function, expr) - for arg in mapped: - if not self.is_subtype(arg.type, arg.argument.type): - self.reporter.error( - arg.expr.location, - f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", - ) - return function.returns + positional: list[TypedExpr] = [ + (arg, self.type_of(arg)) for arg in expr.arguments + ] + keywords: dict[str, TypedExpr] = { + name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() + } + 
return self._get_call_result( + location=expr.location, + callee=callee, + positional=positional, + keywords=keywords, + ) def visit_get_expr(self, expr: p.GetExpr) -> Type: object: Type = self.type_of(expr.object) @@ -572,9 +546,105 @@ class PythonTyper( self.reporter.warning(node.location, "FrameType not yet supported") return UnknownType() + def _get_call_result( + self, + location: Location, + callee: Type, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Type: + match callee: + case Function() as function: + valid: bool + mapped: list[MappedArgument] + valid, mapped = self.map_call_arguments( + function, location, positional, keywords + ) + valid = valid and self._are_arguments_valid(mapped) + if not valid: + return UnknownType() + return function.returns + + case OverloadedFunction(overloads=overloads): + function = self._match_overload( + overloads, location, positional, keywords + ) + if function is None: + return UnknownType() + return function.returns + case _: + self.reporter.error(location, f"{callee} is not callable") + return UnknownType() + + def _are_arguments_valid( + self, + arguments: list[MappedArgument], + report_errors: bool = True, + ) -> bool: + valid: bool = True + for arg in arguments: + if not self.is_subtype(arg.type, arg.argument.type): + if report_errors: + self.reporter.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + valid = False + return valid + + def _match_overload( + self, + overloads: list[Type], + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Optional[Function]: + candidates: list[Function] = [] + for overload in overloads: + function: Type = unfold_type(overload) + if not isinstance(function, Function): + self.logger.error( + f"Overload is not a function: {overload} is {function}" + ) + continue + valid, mapped = self.map_call_arguments( + function=function, + 
location=location, + positional=positional, + keywords=keywords, + report_errors=False, + ) + if valid and self._are_arguments_valid(mapped, report_errors=False): + candidates.append(function) + + pos_types: str = ", ".join(str(type) for _, type in positional) + kw_types: str = ", ".join( + f"{name}: {type}" for name, (_, type) in keywords.items() + ) + for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" + + if len(candidates) == 0: + self.reporter.error( + location, + f"No matching overload in {overloads} {for_args}", + ) + return None + if len(candidates) > 1: + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}", + ) + return None + return candidates[0] + def map_call_arguments( - self, function: Function, call: p.CallExpr - ) -> list[MappedArgument]: + self, + function: Function, + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> tuple[bool, list[MappedArgument]]: """Map call arguments to function parameters as defined in its signature This method maps positional-only, keyword-only and mixed parameter definitions @@ -589,12 +659,6 @@ class PythonTyper( Returns: list[MappedArgument]: the list of mapped arguments """ - positional: list[tuple[p.Expr, Type]] = [ - (arg, self.type_of(arg)) for arg in call.arguments - ] - keywords: dict[str, tuple[p.Expr, Type]] = { - name: (arg, self.type_of(arg)) for name, arg in call.keywords.items() - } set_args: set[str] = set() required_positional: list[str] = [ @@ -612,6 +676,8 @@ class PythonTyper( arg.name: arg for arg in function.kw_args } + valid_call: bool = True + # TODO: handle *args and **kwargs sinks for arg in positional: param: Function.Argument @@ -620,7 +686,11 @@ class PythonTyper( elif len(mixed_params) != 0: param = mixed_params.pop(0) else: - self.reporter.error(arg[0].location, "Too many positional arguments") + if report_errors: + self.reporter.error( + 
arg[0].location, "Too many positional arguments" + ) + valid_call = False break name: str = param.name if name in required_positional: @@ -640,14 +710,16 @@ class PythonTyper( for name, arg in keywords.items(): param: Function.Argument if name not in kw_params: - if name in set_args: - self.reporter.error( - arg[0].location, f"Multiple values for argument '{name}'" - ) - else: - self.reporter.error( - arg[0].location, f"Unknown keyword argument '{name}'" - ) + if report_errors: + if name in set_args: + self.reporter.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.reporter.error( + arg[0].location, f"Unknown keyword argument '{name}'" + ) + valid_call = False continue param = kw_params.pop(name) if name in required_positional: @@ -674,20 +746,24 @@ class PythonTyper( if len(required_positional) != 0: plural: str = "" if len(required_positional) == 1 else "s" args: str = join_args(required_positional) - self.reporter.error( - call.location, - f"Missing required positional argument{plural}: {args}", - ) + if report_errors: + self.reporter.error( + location, + f"Missing required positional argument{plural}: {args}", + ) + valid_call = False if len(required_keyword) != 0: plural: str = "" if len(required_keyword) == 1 else "s" args: str = join_args(required_keyword) - self.reporter.error( - call.location, - f"Missing required keyword argument{plural}: {args}", - ) + if report_errors: + self.reporter.error( + location, + f"Missing required keyword argument{plural}: {args}", + ) + valid_call = False - return mapped + return valid_call, mapped def _check_arity( self, From 46a22797b691c752ac548f17f26baec81db134ec Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 15:50:20 +0200 Subject: [PATCH 56/64] chore: add examples for functions and overloads --- .../01_simple_type_checking/05_functions.py | 28 +++++++++++++++++++ .../06_overloads.midas | 8 ++++++ .../01_simple_type_checking/06_overloads.py | 14 ++++++++++ 3 files 
changed, 50 insertions(+) create mode 100644 examples/01_simple_type_checking/05_functions.py create mode 100644 examples/01_simple_type_checking/06_overloads.midas create mode 100644 examples/01_simple_type_checking/06_overloads.py diff --git a/examples/01_simple_type_checking/05_functions.py b/examples/01_simple_type_checking/05_functions.py new file mode 100644 index 0000000..9c04813 --- /dev/null +++ b/examples/01_simple_type_checking/05_functions.py @@ -0,0 +1,28 @@ +def incr(value: int): + return value + 1 + + +def decr(value: int): + return value - 1 + + +def foo(a: int, /, b: float, *, c: str): + return True + + +r1 = foo() # foo() missing 2 required positional arguments: 'a' and 'b' +r2 = foo(1) # foo() missing 1 required positional argument: 'b' +r3 = foo(1, 2.0) # foo() missing 1 required keyword-only argument: 'c' +r4 = foo(1, b=2.0) # foo() missing 1 required keyword-only argument: 'c' +r5 = foo(1, 2.0, "test") # foo() takes 2 positional arguments but 3 were given +r6 = foo(1, 2.0, b=3.0) # foo() got multiple values for argument 'b' +r7 = foo( + a=1 +) # foo() got some positional-only arguments passed as keyword arguments: 'a' +r8 = foo(g="test") # foo() got an unexpected keyword argument 'g' + +r9a = foo(1, 2.0, c="test") +r9b = foo(1, b=2.0, c="test") +r9c = foo(1, c="test", b=2.0) + +r10 = foo("a", 3, c=False) # wrong argument types diff --git a/examples/01_simple_type_checking/06_overloads.midas b/examples/01_simple_type_checking/06_overloads.midas new file mode 100644 index 0000000..47c80e0 --- /dev/null +++ b/examples/01_simple_type_checking/06_overloads.midas @@ -0,0 +1,8 @@ +type T1 = object +type T2 = object +type Foo = object + +extend Foo { + def bar: fn(T1, /) -> int + def bar: fn(T2, /) -> float +} diff --git a/examples/01_simple_type_checking/06_overloads.py b/examples/01_simple_type_checking/06_overloads.py new file mode 100644 index 0000000..105d5ce --- /dev/null +++ b/examples/01_simple_type_checking/06_overloads.py @@ -0,0 +1,14 @@ +# 
type: ignore +# ruff: disable [F821] + +foo: Foo +t1: T1 +t2: T2 + +a = foo.bar(t1) +b = foo.bar(t2) + +func = foo.bar + +c = func(t1) +d = func(t2) From 2a579c06b1b73c109b9595b593c62bf4133ebc6f Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 16:07:41 +0200 Subject: [PATCH 57/64] refactor(checker): unify call check for subscript --- midas/checker/python.py | 41 +++-------------------------------------- 1 file changed, 3 insertions(+), 38 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index dbe8c8e..11b2e77 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -497,29 +497,9 @@ class PythonTyper( return UnknownType() index: Type = self.type_of(expr.index) - - match operation: - case Function() as function: - if not self._check_arity(function, 1, 0, 0): - self.reporter.error( - expr.location, - f"Wrong definition of __getitem__. Expected function with 1 positional-only parameters, got {function}", - ) - return UnknownType() - - index_arg: Function.Argument = function.pos_args[0] - if not self.is_subtype(index, index_arg.type): - self.reporter.error( - expr.location, - f"Wrong index type, expected {index_arg.type}, got {index}", - ) - return UnknownType() - return function.returns - case _: - self.reporter.warning( - expr.location, f"Unsupported operation {operation}" - ) - return UnknownType() + return self._get_call_result( + expr.location, operation, [(expr.index, index)], {} + ) def visit_base_type(self, node: p.BaseType) -> Type: base: Type @@ -764,18 +744,3 @@ class PythonTyper( valid_call = False return valid_call, mapped - - def _check_arity( - self, - function: Function, - n_pos: Optional[int] = None, - n_mixed: Optional[int] = None, - n_keyword: Optional[int] = None, - ) -> bool: - if n_pos is not None and len(function.pos_args) != n_pos: - return False - if n_mixed is not None and len(function.args) != n_mixed: - return False - if n_keyword is not None and len(function.kw_args) != n_keyword: - 
return False - return True From e1da87eaa0275ff366628bfbf41997f56d93bd16 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 16:08:34 +0200 Subject: [PATCH 58/64] doc(checker): add docstrings to new call checks --- midas/checker/python.py | 52 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 11b2e77..25d5d8b 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -533,6 +533,24 @@ class PythonTyper( positional: list[TypedExpr], keywords: dict[str, TypedExpr], ) -> Type: + """Get the result type of a function call + + If the function has overloads, the function will try to resolve the + appropriate signature. + Argument types are matched to the defined parameters. + The function doesn't take the raw expression as a parameter to accomodate + for desugared calls such as for operators. + + Args: + location (Location): the call location + callee (Type): the called function + positional (list[TypedExpr]): the list positional arguments + keywords (dict[str, TypedExpr]): the map of keyword arguments + + Returns: + Type: the return type of the call, or `UnknownType` if either + the call is invalid or no overload matched the arguments uniquely + """ match callee: case Function() as function: valid: bool @@ -561,6 +579,15 @@ class PythonTyper( arguments: list[MappedArgument], report_errors: bool = True, ) -> bool: + """Check whether the passed argument types correspond to their matched parameter definitions + + Args: + arguments (list[MappedArgument]): the list of argument/parameter pairs + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. 
+ + Returns: + bool: True if all arguments fit the matching parameter definitions, False otherwise + """ valid: bool = True for arg in arguments: if not self.is_subtype(arg.type, arg.argument.type): @@ -579,6 +606,18 @@ class PythonTyper( positional: list[TypedExpr], keywords: dict[str, TypedExpr], ) -> Optional[Function]: + """Try and resolve the appropriate overload for the given arguments + + Args: + overloads (list[Type]): the list of possible overloads + location (Location): the call location + positional (list[TypedExpr]): the list of positional arguments + keywords (dict[str, TypedExpr]): the map of keywords arguments + + Returns: + Optional[Function]: the resolved function signature if it can be + determined unambigously, or `None`. + """ candidates: list[Function] = [] for overload in overloads: function: Type = unfold_type(overload) @@ -625,19 +664,24 @@ class PythonTyper( keywords: dict[str, TypedExpr], report_errors: bool = True, ) -> tuple[bool, list[MappedArgument]]: - """Map call arguments to function parameters as defined in its signature + """Map call arguments to a function's parameters as defined in its signature This method maps positional-only, keyword-only and mixed parameter definitions with the arguments passed at the call site - Any mismatched, missing or unexpected argument is reported as a diagnostic + Any mismatched, missing or unexpected argument is reported as a diagnostic, + unless `report_errors` is set to `False` Args: function (Function): the function definition - call (p.CallExpr): the call expression + location (Location): the call location + positional (list[TypedExpr]): the list of positional arguments + keywords (dict[str, TypedExpr]): the map of keyword arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. 
Returns: - list[MappedArgument]: the list of mapped arguments + tuple[bool, list[MappedArgument]]: a boolean reporting whether + the call is valid and the list of mapped arguments """ set_args: set[str] = set() From 0a35563aaffdef45e140c137a970881341b04503 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 16:36:10 +0200 Subject: [PATCH 59/64] feat(checker): resolve overloads with subtypes try to find the most specific overload if multiple matches are found --- .../06_overloads.midas | 2 + .../01_simple_type_checking/06_overloads.py | 4 + midas/checker/python.py | 85 ++++++++++++++++--- 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/examples/01_simple_type_checking/06_overloads.midas b/examples/01_simple_type_checking/06_overloads.midas index 47c80e0..777c410 100644 --- a/examples/01_simple_type_checking/06_overloads.midas +++ b/examples/01_simple_type_checking/06_overloads.midas @@ -1,8 +1,10 @@ type T1 = object type T2 = object type Foo = object +type T2b = T2 extend Foo { def bar: fn(T1, /) -> int def bar: fn(T2, /) -> float + def bar: fn(T2b, /) -> int } diff --git a/examples/01_simple_type_checking/06_overloads.py b/examples/01_simple_type_checking/06_overloads.py index 105d5ce..86406e0 100644 --- a/examples/01_simple_type_checking/06_overloads.py +++ b/examples/01_simple_type_checking/06_overloads.py @@ -12,3 +12,7 @@ func = foo.bar c = func(t1) d = func(t2) + +t2b: T2b + +e = foo.bar(t2b) diff --git a/midas/checker/python.py b/midas/checker/python.py index 25d5d8b..9ce9399 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -34,6 +34,12 @@ class MappedArgument: argument: Function.Argument +@dataclass(frozen=True, kw_only=True) +class OverloadCandidate: + function: Function + mapped: list[MappedArgument] + + class PythonTyper( p.Stmt.Visitor[None], p.Expr.Visitor[Type], @@ -618,7 +624,7 @@ class PythonTyper( Optional[Function]: the resolved function signature if it can be determined unambigously, or `None`. 
""" - candidates: list[Function] = [] + candidates: list[OverloadCandidate] = [] for overload in overloads: function: Type = unfold_type(overload) if not isinstance(function, Function): @@ -634,7 +640,12 @@ class PythonTyper( report_errors=False, ) if valid and self._are_arguments_valid(mapped, report_errors=False): - candidates.append(function) + candidates.append( + OverloadCandidate( + function=function, + mapped=mapped, + ) + ) pos_types: str = ", ".join(str(type) for _, type in positional) kw_types: str = ", ".join( @@ -642,19 +653,43 @@ class PythonTyper( ) for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" - if len(candidates) == 0: + n_candidates: int = len(candidates) + + # Exactly 1 match -> return it + if n_candidates == 1: + return candidates[0].function + + # No match -> invalid call + if n_candidates == 0: self.reporter.error( location, f"No matching overload in {overloads} {for_args}", ) return None - if len(candidates) > 1: - self.reporter.error( - location, - f"Multiple matching overloads {for_args}: {', '.join(map(str, candidates))}", - ) - return None - return candidates[0] + + # Multiple matches -> see if one <: all others (more specific) + for i1, c1 in enumerate(candidates): + mapped1: list[MappedArgument] = c1.mapped + best_match: bool = True + for i2, c2 in enumerate(candidates): + if i1 == i2: + continue + mapped2: list[MappedArgument] = c2.mapped + if not self._are_mapped_subtypes(mapped1, mapped2): + best_match = False + break + self.logger.debug(f"{c1.function} is a full overload of {c2.function}") + if best_match: + return c1.function + + candidates_str: str = ", ".join( + str(candidate.function) for candidate in candidates + ) + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {candidates_str}", + ) + return None def map_call_arguments( self, @@ -788,3 +823,33 @@ class PythonTyper( valid_call = False return valid_call, mapped + + def _are_mapped_subtypes( + self, mapped1: 
list[MappedArgument], mapped2: list[MappedArgument] + ) -> bool: + """Check whether the given argument mappings are subtype/supertype of one another + + This function checks whether the argument mappings `mapped1` are subtypes + of `mapped2`. If any of the parameter type in `mapped1` is not a subtype + of the corresponding parameter in `mapped2`, `False` is returned. + + This is used to check whether a given overload is + a more specific function/ a subtype of another. + + Args: + mapped1 (list[MappedArgument]): the first argument mappings (subtype) + mapped2 (list[MappedArgument]): the second argument mappings (supertype) + + Returns: + bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise + """ + by_expr: dict[p.Expr, Type] = {} + for arg in mapped1: + by_expr[arg.expr] = arg.argument.type + + for arg in mapped2: + type2: Type = arg.argument.type + type1: Type = by_expr[arg.expr] + if not self.is_subtype(type1, type2): + return False + return True From 35798e5752103a50d7bec10944df0f9b0d542789 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 16:45:13 +0200 Subject: [PATCH 60/64] tests: update with new subscript and call checks invalid function calls now return UnknownType even if the function has a return type --- .../cases/checker/01_simple_types.py.ref.json | 17 ++- tests/cases/checker/03_functions.py.ref.json | 36 ++--- .../python-parser/02_custom_types.py.ref.json | 132 ++++++++++++++++++ 3 files changed, 157 insertions(+), 28 deletions(-) diff --git a/tests/cases/checker/01_simple_types.py.ref.json b/tests/cases/checker/01_simple_types.py.ref.json index 3c4d0b9..ac24fcd 100644 --- a/tests/cases/checker/01_simple_types.py.ref.json +++ b/tests/cases/checker/01_simple_types.py.ref.json @@ -1,4 +1,19 @@ { - "diagnostics": [], + "diagnostics": [ + { + "type": "Warning", + "location": { + "start": [ + 6, + 4 + ], + "end": [ + 13, + 5 + ] + }, + "message": "FrameType not yet supported" + } + ], "judgments": [] } \ No newline at 
end of file diff --git a/tests/cases/checker/03_functions.py.ref.json b/tests/cases/checker/03_functions.py.ref.json index 814bc35..fa06642 100644 --- a/tests/cases/checker/03_functions.py.ref.json +++ b/tests/cases/checker/03_functions.py.ref.json @@ -326,9 +326,7 @@ "arguments": [], "keywords": {} }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -407,9 +405,7 @@ ], "keywords": {} }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -505,9 +501,7 @@ ], "keywords": {} }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -604,9 +598,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -719,9 +711,7 @@ ], "keywords": {} }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -835,9 +825,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -916,9 +904,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -997,9 +983,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -1461,9 +1445,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} } ] } \ No newline at end of file diff --git a/tests/cases/python-parser/02_custom_types.py.ref.json b/tests/cases/python-parser/02_custom_types.py.ref.json index 639610d..82c726c 100644 --- a/tests/cases/python-parser/02_custom_types.py.ref.json +++ b/tests/cases/python-parser/02_custom_types.py.ref.json @@ -18,6 +18,80 @@ ] } }, + { + "_type": "TypeAssign", + "name": "lat", + "type": { + "_type": "BaseType", + "base": "Column", + "param": { + "_type": "BaseType", + "base": "GeoLocation", + "param": null + } + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lat" + } + ], + "value": { + "_type": "GetExpr", + "object": { + "_type": "SubscriptExpr", + "object": { + "_type": "VariableExpr", + "name": "df" + }, + "index": { + "_type": "LiteralExpr", + "value": "location" + } + }, + "name": "lat" + } + }, + { + 
"_type": "TypeAssign", + "name": "lon", + "type": { + "_type": "BaseType", + "base": "Column", + "param": { + "_type": "BaseType", + "base": "GeoLocation", + "param": null + } + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lon" + } + ], + "value": { + "_type": "GetExpr", + "object": { + "_type": "SubscriptExpr", + "object": { + "_type": "VariableExpr", + "name": "df" + }, + "index": { + "_type": "LiteralExpr", + "value": "location" + } + }, + "name": "lon" + } + }, { "_type": "ExpressionStmt", "expr": { @@ -33,6 +107,64 @@ } } }, + { + "_type": "TypeAssign", + "name": "lat1", + "type": { + "_type": "BaseType", + "base": "Latitude", + "param": null + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lat1" + } + ], + "value": { + "_type": "SubscriptExpr", + "object": { + "_type": "VariableExpr", + "name": "lat" + }, + "index": { + "_type": "LiteralExpr", + "value": 0 + } + } + }, + { + "_type": "TypeAssign", + "name": "lat2", + "type": { + "_type": "BaseType", + "base": "Latitude", + "param": null + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lat2" + } + ], + "value": { + "_type": "SubscriptExpr", + "object": { + "_type": "VariableExpr", + "name": "lat" + }, + "index": { + "_type": "LiteralExpr", + "value": 1 + } + } + }, { "_type": "TypeAssign", "name": "lat_diff", From 1eedcff5aa6d02e51bcaded65fcb10269b924df8 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 16:53:38 +0200 Subject: [PATCH 61/64] feat(parser): add slice expression --- gen/python.py | 6 ++++++ midas/ast/printer.py | 7 +++++++ midas/ast/python.py | 13 +++++++++++++ midas/cli/highlighter.py | 8 ++++++++ midas/parser/python.py | 9 +++++++++ tests/serializer/python.py | 9 +++++++++ 6 files changed, 52 insertions(+) diff --git a/gen/python.py b/gen/python.py index b7c38ec..35908f7 100644 --- a/gen/python.py +++ b/gen/python.py @@ -148,4 +148,10 @@ 
class SubscriptExpr: index: Expr +class SliceExpr: + lower: Optional[Expr] + upper: Optional[Expr] + step: Optional[Expr] + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 3495883..e52472c 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -729,3 +729,10 @@ class PythonAstPrinter( self._write_line("index", last=True) with self._child_level(single=True): expr.index.accept(self) + + def visit_slice_expr(self, expr: p.SliceExpr) -> None: + self._write_line("SliceExpr") + with self._child_level(): + self._write_optional_child("lower", expr.lower) + self._write_optional_child("upper", expr.upper) + self._write_optional_child("step", expr.step, last=True) diff --git a/midas/ast/python.py b/midas/ast/python.py index a199b89..f025e2f 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -227,6 +227,9 @@ class Expr(ABC): @abstractmethod def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ... + @abstractmethod + def visit_slice_expr(self, expr: SliceExpr) -> T: ... 
+ @dataclass(frozen=True) class BinaryExpr(Expr): @@ -336,3 +339,13 @@ class SubscriptExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_subscript_expr(self) + + +@dataclass(frozen=True) +class SliceExpr(Expr): + lower: Optional[Expr] + upper: Optional[Expr] + step: Optional[Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_slice_expr(self) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index 3c3f07e..bc7727c 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -222,6 +222,14 @@ class PythonHighlighter( expr.object.accept(self) expr.index.accept(self) + def visit_slice_expr(self, expr: p.SliceExpr) -> None: + if expr.lower is not None: + expr.lower.accept(self) + if expr.upper is not None: + expr.upper.accept(self) + if expr.step is not None: + expr.step.accept(self) + class MidasHighlighter( Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] diff --git a/midas/parser/python.py b/midas/parser/python.py index 8c1f5a7..a0726da 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -22,6 +22,7 @@ from midas.ast.python import ( LogicalExpr, MidasType, ReturnStmt, + SliceExpr, Stmt, SubscriptExpr, TernaryExpr, @@ -431,6 +432,14 @@ class PythonParser: index=self.parse_expr(index), ) + case ast.Slice(lower=lower, upper=upper, step=step): + return SliceExpr( + location=location, + lower=self.parse_expr(lower) if lower is not None else None, + upper=self.parse_expr(upper) if upper is not None else None, + step=self.parse_expr(step) if step is not None else None, + ) + case _: raise UnsupportedSyntaxError(node) diff --git a/tests/serializer/python.py b/tests/serializer/python.py index 73b8b84..b090eea 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -21,6 +21,7 @@ from midas.ast.python import ( LogicalExpr, MidasType, ReturnStmt, + SliceExpr, Stmt, SubscriptExpr, TernaryExpr, @@ -260,3 +261,11 @@ class 
PythonAstJsonSerializer( "object": expr.object.accept(self), "index": expr.index.accept(self), } + + def visit_slice_expr(self, expr: SliceExpr) -> dict: + return { + "_type": "SliceExpr", + "lower": self._serialize_optional(expr.lower), + "upper": self._serialize_optional(expr.upper), + "step": self._serialize_optional(expr.step), + } From 37a464d2bc7ea3f166b878c27121c191c8d655c9 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 16:58:20 +0200 Subject: [PATCH 62/64] feat(checker): type check slice expressions --- midas/checker/builtins.py | 1 + midas/checker/python.py | 3 +++ midas/checker/resolver.py | 8 ++++++++ 3 files changed, 12 insertions(+) diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index 961545a..b1adf6d 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -29,6 +29,7 @@ def define_builtins(reg: TypesRegistry): int = reg.define_type("int", BaseType(name="int")) float = reg.define_type("float", BaseType(name="float")) str = reg.define_type("str", BaseType(name="str")) + slice = reg.define_type("slice", BaseType(name="slice")) list = reg.define_type( "list", diff --git a/midas/checker/python.py b/midas/checker/python.py index 9ce9399..f28ce77 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -507,6 +507,9 @@ class PythonTyper( expr.location, operation, [(expr.index, index)], {} ) + def visit_slice_expr(self, expr: p.SliceExpr) -> Type: + return self.types.get_type("slice") + def visit_base_type(self, node: p.BaseType) -> Type: base: Type try: diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 636ccfe..12f18cf 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -200,3 +200,11 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: self.resolve(expr.object) self.resolve(expr.index) + + def visit_slice_expr(self, expr: p.SliceExpr) -> None: + if expr.lower is not 
None: + self.resolve(expr.lower) + if expr.upper is not None: + self.resolve(expr.upper) + if expr.step is not None: + self.resolve(expr.step) From bd0421b5d83ee4383ffffc62252388cb03361fa2 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 17:04:10 +0200 Subject: [PATCH 63/64] fix(checker): handle generic overloads --- midas/checker/python.py | 3 ++- midas/checker/types.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index f28ce77..a0f7a06 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -664,9 +664,10 @@ class PythonTyper( # No match -> invalid call if n_candidates == 0: + overloads_str: str = ", ".join(map(str, overloads)) self.reporter.error( location, - f"No matching overload in {overloads} {for_args}", + f"No matching overload in [{overloads_str}] {for_args}", ) return None diff --git a/midas/checker/types.py b/midas/checker/types.py index 0bf8ea2..c6d41d1 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -164,6 +164,14 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: returns=substitute_typevars(returns, substitutions), ) + case OverloadedFunction(overloads=overloads): + return OverloadedFunction( + overloads=[ + substitute_typevars(overload, substitutions) + for overload in overloads + ] + ) + case ComplexType(members=members): members2: dict[str, Type] = { name: substitute_typevars(prop, substitutions) From 635bf7353179f67bc399f7bddd6066dfafb3efcc Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sun, 14 Jun 2026 17:04:29 +0200 Subject: [PATCH 64/64] feat(checker): add slice overloads on lists --- examples/01_simple_type_checking/04_complex_types.py | 2 ++ midas/checker/builtins.midas | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/01_simple_type_checking/04_complex_types.py b/examples/01_simple_type_checking/04_complex_types.py index 4f21cc6..f1d1215 100644 
--- a/examples/01_simple_type_checking/04_complex_types.py +++ b/examples/01_simple_type_checking/04_complex_types.py @@ -33,3 +33,5 @@ a = foo[0] b = bar[0][1] c = bar[0][1][2] # invalid, not method __getitem__ on Meter c = bar[""] # invalid, wrong index type + +d = foo[1:2] diff --git a/midas/checker/builtins.midas b/midas/checker/builtins.midas index ba8f18b..6e89172 100644 --- a/midas/checker/builtins.midas +++ b/midas/checker/builtins.midas @@ -129,11 +129,11 @@ extend list[T] { def __len__: fn () -> int // def __iter__: fn () -> Iterator[T] def __getitem__: fn (i: int, /) -> T - //__getitem__: fn (s: slice, /) -> list[T] + def __getitem__: fn (s: slice, /) -> list[T] def __setitem__: fn (key: int, value: T, /) -> None - //__setitem__: fn (key: slice, value: list[T], /) -> None + def __setitem__: fn (key: slice, value: list[T], /) -> None def __delitem__: fn (key: int, /) -> None - // def __delitem__: fn (key: slice, /) -> None + def __delitem__: fn (key: slice, /) -> None // def __add__: fn[S <: T] (value: list[S], /) -> list[T] def __add__: fn (value: list[T], /) -> list[T] def __iadd__: fn (value: list[T], /) -> list[T]