From 61514d036c128800950b6abb2c3bb783972e522c Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 10:38:09 +0200 Subject: [PATCH 1/9] feat(passer): add raw statements and expressions --- gen/python.py | 8 +++++++ midas/ast/printer.py | 10 +++++++++ midas/ast/python.py | 22 +++++++++++++++++++ midas/parser/python.py | 7 ++++-- .../python-parser/01_simple_types.py.ref.json | 4 ++++ .../python-parser/02_custom_types.py.ref.json | 4 ++++ .../python-parser/03_functions.py.ref.json | 4 ++++ tests/serializer/python.py | 14 ++++++++++++ 8 files changed, 71 insertions(+), 2 deletions(-) diff --git a/gen/python.py b/gen/python.py index 99b926d..f67f540 100644 --- a/gen/python.py +++ b/gen/python.py @@ -92,6 +92,10 @@ class ForStmt: body: list[Stmt] +class RawStmt: + stmt: ast.stmt + + ###< @@ -164,4 +168,8 @@ class SliceExpr: step: Optional[Expr] +class RawExpr: + expr: ast.expr + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 364da35..68ff7ba 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -613,6 +613,11 @@ class PythonAstPrinter( self._mark_last() body_stmt.accept(self) + def visit_raw_stmt(self, stmt: p.RawStmt) -> None: + self._write_line("RawStmt") + with self._child_level(single=True): + self._write_line(f"stmt: {ast.unparse(stmt.stmt)}") + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self._write_line("BinaryExpr") with self._child_level(): @@ -756,3 +761,8 @@ class PythonAstPrinter( self._write_optional_child("lower", expr.lower) self._write_optional_child("upper", expr.upper) self._write_optional_child("step", expr.step, last=True) + + def visit_raw_expr(self, expr: p.RawExpr) -> None: + self._write_line("RawExpr") + with self._child_level(single=True): + self._write_line(f"expr: {ast.unparse(expr.expr)}") diff --git a/midas/ast/python.py b/midas/ast/python.py index b781d2e..73d49e5 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -113,6 +113,9 @@ class Stmt(ABC): @abstractmethod def visit_for_stmt(self, stmt: ForStmt) -> T: ... + @abstractmethod + def visit_raw_stmt(self, stmt: RawStmt) -> T: ... + @dataclass(frozen=True) class ExpressionStmt(Stmt): @@ -202,6 +205,14 @@ class ForStmt(Stmt): return visitor.visit_for_stmt(self) +@dataclass(frozen=True) +class RawStmt(Stmt): + stmt: ast.stmt + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_raw_stmt(self) + + ############### # Expressions # ############### @@ -254,6 +265,9 @@ class Expr(ABC): @abstractmethod def visit_slice_expr(self, expr: SliceExpr) -> T: ... + @abstractmethod + def visit_raw_expr(self, expr: RawExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -373,3 +387,11 @@ class SliceExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_slice_expr(self) + + +@dataclass(frozen=True) +class RawExpr(Expr): + expr: ast.expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_raw_expr(self) diff --git a/midas/parser/python.py b/midas/parser/python.py index 55edd34..90c029a 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -22,6 +22,8 @@ from midas.ast.python import ( LiteralExpr, LogicalExpr, MidasType, + RawExpr, + RawStmt, ReturnStmt, SliceExpr, Stmt, @@ -99,7 +101,7 @@ class PythonParser: case _: print(f"Unsupported statement: {ast.unparse(node)}") - return None + return RawStmt(location=location, stmt=node) def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]: statements: list[Stmt] = [] @@ -461,7 +463,8 @@ class PythonParser: ) case _: - raise UnsupportedSyntaxError(node) + print(f"Unsupported expression: {ast.unparse(node)}") + return RawExpr(location=location, expr=node) def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr: op: ast.boolop = node.op diff --git a/tests/cases/python-parser/01_simple_types.py.ref.json b/tests/cases/python-parser/01_simple_types.py.ref.json index e4fd591..452b9c0 100644 --- a/tests/cases/python-parser/01_simple_types.py.ref.json +++ b/tests/cases/python-parser/01_simple_types.py.ref.json @@ -1,5 +1,9 @@ { "stmts": [ + { + "_type": "RawStmt", + "stmt": "from __future__ import annotations" + }, { "_type": "TypeAssign", "name": "df", diff --git a/tests/cases/python-parser/02_custom_types.py.ref.json b/tests/cases/python-parser/02_custom_types.py.ref.json index 82c726c..9d77ebd 100644 --- a/tests/cases/python-parser/02_custom_types.py.ref.json +++ b/tests/cases/python-parser/02_custom_types.py.ref.json @@ -1,5 +1,9 @@ { "stmts": [ + { + "_type": "RawStmt", + "stmt": "from __future__ import annotations" + }, { "_type": "TypeAssign", "name": "df", diff --git a/tests/cases/python-parser/03_functions.py.ref.json b/tests/cases/python-parser/03_functions.py.ref.json index 529455b..a8f261f 100644 --- a/tests/cases/python-parser/03_functions.py.ref.json +++ b/tests/cases/python-parser/03_functions.py.ref.json @@ -1,5 +1,9 @@ { "stmts": [ + { + "_type": "RawStmt", + "stmt": "from __future__ import annotations" + }, { "_type": "Function", "name": "func", diff --git a/tests/serializer/python.py b/tests/serializer/python.py index 56171b8..45951df 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -22,6 +22,8 @@ from midas.ast.python import ( LogicalExpr, MidasType, Pass, + RawExpr, + RawStmt, ReturnStmt, SliceExpr, Stmt, @@ -191,6 +193,12 @@ class PythonAstJsonSerializer( "body": self._serialize_list(stmt.body), } + def visit_raw_stmt(self, stmt: RawStmt) -> dict: + return { + "_type": "RawStmt", + "stmt": ast.unparse(stmt.stmt), + } + def visit_binary_expr(self, expr: BinaryExpr) -> dict: return { "_type": "BinaryExpr", @@ -284,3 +292,9 @@ class PythonAstJsonSerializer( "upper": self._serialize_optional(expr.upper), "step": self._serialize_optional(expr.step), } + + def visit_raw_expr(self, expr: RawExpr) -> dict: + return { + "_type": "RawExpr", + "expr": ast.unparse(expr.expr), + } From 0a8e0fb6c2a6b82393eed6129c5cb319fc7f5dcf Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 10:39:26 +0200 Subject: [PATCH 2/9] feat(checker): handle raw expr/stmt --- midas/checker/python.py | 6 ++++++ midas/checker/resolver.py | 6 ++++++ midas/generator/generator.py | 6 ++++++ 3 files changed, 18 insertions(+) diff --git a/midas/checker/python.py b/midas/checker/python.py index 316836e..9af8e73 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -370,6 +370,9 @@ class PythonTyper( if body_returned: raise ReturnException() + def visit_raw_stmt(self, stmt: p.RawStmt) -> None: + pass + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) if method is None: @@ -566,6 +569,9 @@ class PythonTyper( def visit_slice_expr(self, expr: p.SliceExpr) -> Type: return self.types.get_type("slice") + def visit_raw_expr(self, expr: p.RawExpr) -> Type: + return UnknownType() + def visit_base_type(self, node: p.BaseType) -> Type: base: Type try: diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index c99a18d..3226faf 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -163,6 +163,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.resolve(*stmt.body) self.end_scope() + def visit_raw_stmt(self, stmt: p.RawStmt) -> None: + pass + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: self.resolve(expr.left) self.resolve(expr.right) @@ -221,3 +224,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.resolve(expr.upper) if expr.step is not None: self.resolve(expr.step) + + def visit_raw_expr(self, expr: p.RawExpr) -> None: + pass diff --git a/midas/generator/generator.py b/midas/generator/generator.py index bef6e16..ddab606 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -108,6 +108,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): step=expr.step.accept(self) if expr.step is not None else None, ) + def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr: + return expr.expr + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt: return ast.Expr( value=stmt.expr.accept(self), @@ -169,5 +172,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): orelse=[], ) + def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt: + return stmt.stmt + def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]: return [stmt.accept(self) for stmt in stmts] From c3229b557cf64381187741e9e5ce6e8acd7005f8 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 12:49:36 +0200 Subject: [PATCH 3/9] feat(gen): add basic cast assertions on base type --- midas/generator/generator.py | 160 +++++++++++++++++++++++++++++++++-- 1 file changed, 154 insertions(+), 6 deletions(-) diff --git a/midas/generator/generator.py b/midas/generator/generator.py index ddab606..0c086d8 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -1,11 +1,33 @@ import ast import shutil +from dataclasses import dataclass, field from pathlib import Path +from midas.ast.location import Location import midas.ast.python as p +from midas.checker.types import ( + AliasType, + AppliedType, + BaseType, + ComplexType, + ExtensionType, + Function, + GenericType, + OverloadedFunction, + TopType, + Type, + TypeVar, + UnitType, +) from midas.utils import TypedAST +@dataclass +class Scope: + pre_assertions: list[ast.stmt] = field(default_factory=list) + aliases: list[str] = field(default_factory=list) + + class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): def __init__(self, workdir: Path) -> None: self.workdir: Path = workdir.resolve() @@ -13,19 +35,28 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): if self.build_dir.exists(): shutil.rmtree(self.build_dir) self.build_dir.mkdir(parents=True, exist_ok=True) + self.rel_src_path: Path = Path() + + self._typed_ast: TypedAST = TypedAST( + stmts=[], + judgements=[], + ) + self._alias_count: int = 0 + self._scopes: list[Scope] = [] def generate(self, typed_ast: TypedAST, src_path: Path) -> Path: + self.rel_src_path = src_path.relative_to(self.workdir) + self._typed_ast = typed_ast body: list[ast.stmt] = self._visit_body(typed_ast.stmts) module = ast.Module(body=body, type_ignores=[]) module = ast.fix_missing_locations(module) compiled: str = ast.unparse(module) - rel_src_path: Path = src_path.relative_to(self.workdir) - out_path: Path = (self.build_dir / rel_src_path).resolve() + out_path: Path = (self.build_dir / self.rel_src_path).resolve() try: _ = out_path.relative_to(self.build_dir) except ValueError: raise ValueError( - f"Directory traversal, {rel_src_path} points outside of parent directory" + f"Directory traversal, {self.rel_src_path} points outside of parent directory" ) out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(compiled) @@ -80,8 +111,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): ) def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr: - # TODO: insert assertion - return expr.expr.accept(self) + expr2: ast.expr = expr.expr.accept(self) + alias: ast.expr = self._make_alias(expr2) + + type: Type = self._get_expr_type(expr) + self._make_cast_asserts(expr.location, alias, type) + + return alias def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr: return ast.IfExp( @@ -176,4 +212,116 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): return stmt.stmt def _visit_body(self, stmts: list[p.Stmt]) -> list[ast.stmt]: - return [stmt.accept(self) for stmt in stmts] + generated: list[ast.stmt] = [] + for stmt in stmts: + scope = Scope() + self._scopes.append(scope) + + stmt2 = stmt.accept(self) + generated.extend(scope.pre_assertions) + generated.append(stmt2) + if len(scope.aliases) != 0: + generated.append( + ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases]) + ) + self._scopes.pop() + + # Remove redundant pass statements + if len(generated) > 1: + generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)] + return generated + + def _make_alias(self, expr: ast.expr) -> ast.expr: + name: str = f"__midas_alias_{self._alias_count}__" + alias = ast.Name(id=name) + self._alias_count += 1 + self._scopes[-1].aliases.append(name) + self._scopes[-1].pre_assertions.append( + ast.Assign( + targets=[alias], + value=expr, + ) + ) + return alias + + def _add_assert(self, expr: ast.expr, message: str | ast.expr): + if isinstance(message, str): + message = ast.Constant(value=message) + self._scopes[-1].pre_assertions.append( + ast.Assert( + test=expr, + msg=message, + ) + ) + + def _get_expr_type(self, query: p.Expr) -> Type: + for expr, type in self._typed_ast.judgements: + if expr == query: + return type + raise RuntimeError(f"Cannot get type judgement for {query}") + + def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type): + match type: + case BaseType(name=name): + self._add_assert( + ast.Call( + func=ast.Name(id="isinstance"), + args=[expr, ast.Name(id=name)], + keywords=[], + ), + self._make_cast_assert_message(src_location, expr, type), + ) + + case AliasType(type=base): + self._make_cast_asserts(src_location, expr, base) + + case UnitType(): + self._add_assert( + ast.Compare( + left=expr, + ops=[ast.Is()], + comparators=[ + ast.Constant(value=None), + ], + ), + self._make_cast_assert_message(src_location, expr, type), + ) + + case AppliedType(): + self._make_cast_asserts(src_location, expr, type.body) + + case ( + TopType() + | Function() + | OverloadedFunction() + | ComplexType() + | ExtensionType() + | GenericType() + ): + raise NotImplementedError(f"Can't make assertion for type {type}") + + case TypeVar(): + raise RuntimeError("Unexpected TypeVar") + + def _make_cast_assert_message( + self, location: Location, expr: ast.expr, type: Type + ) -> ast.expr: + loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}" + # f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type" + return ast.JoinedStr( + values=[ + ast.Constant(f"{loc_str}: CastError: Cannot cast "), + ast.FormattedValue( + value=ast.Attribute( + value=ast.Call( + func=ast.Name(id="type"), + args=[expr], + keywords=[], + ), + attr="__name__", + ), + conversion=-1, + ), + ast.Constant(f" to {type}"), + ] + ) From c4062c95950293e957cbd119c597b22dff73d283 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 12:55:05 +0200 Subject: [PATCH 4/9] fix(checker): allow inferred return to be subtype of hint --- midas/checker/python.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/midas/checker/python.py b/midas/checker/python.py index 9af8e73..b4da4ba 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -252,7 +252,7 @@ class PythonTyper( if returns_hint is not None: assert stmt.returns is not None returns = returns_hint - if returns != inferred_return: + if not self.is_subtype(inferred_return, returns): self.reporter.error( stmt.returns.location, f"Return type mismatch, annotated {returns} but returns {inferred_return}", From 732f7b079667644e1ea65af77ebd79aa34a322f4 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 14:02:45 +0200 Subject: [PATCH 5/9] feat(checker): add environment preamble this adds some builtin functions such as the builtin type constructors --- midas/checker/preamble.py | 121 ++++++++++++++++++++++++++++++++++++++ midas/checker/python.py | 3 +- midas/checker/types.py | 8 +-- 3 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 midas/checker/preamble.py diff --git a/midas/checker/preamble.py b/midas/checker/preamble.py new file mode 100644 index 0000000..a543dd9 --- /dev/null +++ b/midas/checker/preamble.py @@ -0,0 +1,121 @@ +from dataclasses import dataclass + +from midas.checker.environment import Environment +from midas.checker.registry import TypesRegistry +from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType + + +@dataclass(frozen=True) +class Param: + name: str + type: Type + required: bool = True + + +class Preamble(Environment): + def __init__(self, types: TypesRegistry) -> None: + super().__init__() + self._types: TypesRegistry = types + + self._def_type_constructor("object") + self._def_type_constructor("float") + self._def_type_constructor("int") + self._def_type_constructor("bool") + self._def_type_constructor("str") + self._def_function( + name="list", + pos=[Param("object", TopType())], + returns=self._list_of(TopType()), + ) + + # TODO: use sink + self._def_function( + name="print", + pos=[Param("object", TopType())], + returns=UnitType(), + ) + + map_in = TypeVar(name="T", bound=None) + map_out = TypeVar(name="U", bound=None) + mapper = self._make_function( + name="MapTransform", + pos=[Param("v", map_in)], + returns=map_out, + ) + self._def_function( + name="map", + pos=[ + Param("transform", mapper), + Param( + "iterable", + self._list_of(map_in), # TODO: replace with Iterable[T] + ), + ], + returns=self._list_of(map_out), # TODO: replace with Iterable[U] + ) + + def _list_of(self, item_type: Type) -> Type: + return self._types.apply_generic(self._types.get_type("list"), [item_type]) + + def _def_type_constructor(self, name: str): + # TODO: more specific arg types + self._def_function( + name=name, + pos=[Param("object", TopType())], + returns=self._types.get_type(name), + ) + + def _make_function( + self, + *, + name: str, + pos: list[Param] = [], + mixed: list[Param] = [], + kw: list[Param] = [], + returns: Type = UnitType(), + type_vars: list[TypeVar] = [], + ) -> Type: + def map_args(params: list[Param], offset: int) -> list[Function.Argument]: + return [ + Function.Argument( + pos=i + offset, + name=param.name, + type=param.type, + required=param.required, + ) + for i, param in enumerate(params) + ] + + function = Function( + pos_args=map_args(pos, 0), + args=map_args(mixed, len(pos)), + kw_args=map_args(kw, len(pos) + len(mixed)), + returns=returns, + ) + if len(type_vars) != 0: + function = GenericType( + name=name, + params=type_vars, + body=function, + ) + return function + + def _def_function( + self, + *, + name: str, + pos: list[Param] = [], + mixed: list[Param] = [], + kw: list[Param] = [], + returns: Type = UnitType(), + type_vars: list[TypeVar] = [], + ): + function: Type = self._make_function( + name=name, + pos=pos, + mixed=mixed, + kw=kw, + returns=returns, + type_vars=type_vars, + ) + self.define(name, function) diff --git a/midas/checker/python.py b/midas/checker/python.py index b4da4ba..6e6ea1a 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -7,6 +7,7 @@ import midas.ast.python as p from midas.ast.location import Location from midas.checker.environment import Environment from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS +from midas.checker.preamble import Preamble from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver @@ -56,7 +57,7 @@ class PythonTyper( self.logger: logging.Logger = logging.getLogger("PythonTyper") self.reporter: FileReporter = reporter.for_file(None) self.types: TypesRegistry = types - self.global_env: Environment = Environment() + self.global_env: Environment = Preamble(self.types) self.env: Environment = self.global_env self.locals: dict[p.Expr, int] = {} self.judgements: list[tuple[p.Expr, Type]] = [] diff --git a/midas/checker/types.py b/midas/checker/types.py index c6d41d1..f2fef3b 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional @@ -41,9 +41,9 @@ class UnitType: @dataclass(frozen=True, kw_only=True) class Function: - pos_args: list[Argument] - args: list[Argument] - kw_args: list[Argument] + pos_args: list[Argument] = field(default_factory=list) + args: list[Argument] = field(default_factory=list) + kw_args: list[Argument] = field(default_factory=list) returns: Type def __str__(self) -> str: From 4b1087d6b9a7165e39f4c7a4b822fa294ebae618 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 14:03:13 +0200 Subject: [PATCH 6/9] fix(cli): improve dump-registry command output --- midas/checker/types.py | 6 ++---- midas/cli/commands/registry.py | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/midas/checker/types.py b/midas/checker/types.py index f2fef3b..708d68b 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -50,15 +50,13 @@ class Function: args: list[str] = [] if len(self.pos_args) != 0: args += list(map(str, self.pos_args)) - if len(self.args) + len(self.kw_args) != 0: - args.append("/") + args.append("/") if len(self.args) != 0: args += list(map(str, self.args)) if len(self.kw_args) != 0: - if len(args) != 0: - args.append("*") + args.append("*") args += list(map(str, self.kw_args)) return f"({', '.join(args)}) -> {self.returns}" diff --git a/midas/cli/commands/registry.py b/midas/cli/commands/registry.py index 4e830be..d978ad9 100644 --- a/midas/cli/commands/registry.py +++ b/midas/cli/commands/registry.py @@ -9,7 +9,21 @@ from typing import TextIO import click from midas.checker.checker import TypeChecker -from midas.checker.types import Type +from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type + + +def base_type(type: Type) -> Type: + match type: + case BaseType(): + return type + case AliasType(type=base): + return base + case AppliedType(body=body): + return body + case GenericType(body=body): + return body + case _: + return type @click.command(help="Dump types registry") @@ -23,7 +37,7 @@ def dump_registry( for name, type in checker.types._types.items(): members: dict[str, Type] = checker.types._members.get(name, {}) - print(f"{name} = {type}") + print(f"{name} = {base_type(type)}") if len(members) != 0: print(" " * 4 + "Members:") for member_name, member_type in members.items(): From 591012d0596157c16e8e333753a39f5080b928a8 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 14:04:20 +0200 Subject: [PATCH 7/9] fix(checker): allow calling AppliedType and UnknownType --- midas/checker/python.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/midas/checker/python.py b/midas/checker/python.py index 6e6ea1a..5679b19 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -12,6 +12,7 @@ from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from midas.checker.types import ( + AppliedType, Function, OverloadedFunction, Type, @@ -644,6 +645,15 @@ class PythonTyper( if function is None: return None return function.returns + + case AppliedType(body=body): + return self._get_call_result( + location, body, positional, keywords, report_errors + ) + + case UnknownType(): + return UnknownType() + case _: if report_errors: self.reporter.error(location, f"{callee} is not callable") From da38cad23d3b6034720bc8d0283e2f2ea5d2aae7 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 14:35:52 +0200 Subject: [PATCH 8/9] feat(tests): add generator tester --- midas/generator/generator.py | 26 +++++++++++------ tests/base.py | 6 +++- tests/generator.py | 55 ++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 10 deletions(-) create mode 100644 tests/generator.py diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 0c086d8..a2eb5c2 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -2,6 +2,7 @@ import ast import shutil from dataclasses import dataclass, field from pathlib import Path +from typing import Optional from midas.ast.location import Location import midas.ast.python as p @@ -44,21 +45,28 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._alias_count: int = 0 self._scopes: list[Scope] = [] - def generate(self, typed_ast: TypedAST, src_path: Path) -> Path: + def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST: self.rel_src_path = src_path.relative_to(self.workdir) self._typed_ast = typed_ast body: list[ast.stmt] = self._visit_body(typed_ast.stmts) module = ast.Module(body=body, type_ignores=[]) module = ast.fix_missing_locations(module) + return module + + def generate( + self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None + ) -> Path: + module: ast.AST = self.generate_ast(typed_ast, src_path) compiled: str = ast.unparse(module) - out_path: Path = (self.build_dir / self.rel_src_path).resolve() - try: - _ = out_path.relative_to(self.build_dir) - except ValueError: - raise ValueError( - f"Directory traversal, {self.rel_src_path} points outside of parent directory" - ) - out_path.parent.mkdir(parents=True, exist_ok=True) + if out_path is None: + out_path = (self.build_dir / self.rel_src_path).resolve() + try: + _ = out_path.relative_to(self.build_dir) + except ValueError: + raise ValueError( + f"Directory traversal, {self.rel_src_path} points outside of parent directory" + ) + out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(compiled) return out_path diff --git a/tests/base.py b/tests/base.py index 0749d79..c8bad17 100644 --- a/tests/base.py +++ b/tests/base.py @@ -21,6 +21,10 @@ class Tester(ABC): @abstractmethod def namespace(self) -> str: ... + @property + def extension(self) -> str: + return "json" + @property def base_dir(self) -> Path: return self.CASES_DIR / self.namespace @@ -99,7 +103,7 @@ class Tester(ABC): return True def _result_path(self, test_path: Path) -> Path: - return test_path.parent / (test_path.name + ".ref.json") + return test_path.parent / (test_path.name + f".ref.{self.extension}") def _print_diff(self, diff: Iterator[str]): for line in diff: diff --git a/tests/generator.py b/tests/generator.py new file mode 100644 index 0000000..72b7002 --- /dev/null +++ b/tests/generator.py @@ -0,0 +1,55 @@ +import ast +from dataclasses import dataclass +from pathlib import Path + +from midas.checker.checker import TypeChecker +from midas.checker.diagnostic import DiagnosticType +from midas.generator.generator import Generator +from midas.utils import TypedAST +from tests.base import Tester + + +@dataclass +class CaseResult: + compiled_ast: ast.AST = ast.Module([], []) + + def dumps(self) -> str: + return ast.dump(self.compiled_ast, indent=2) + + +class GeneratorTester(Tester): + @property + def namespace(self) -> str: + return "generator" + + @property + def extension(self) -> str: + return "txt" + + def _list_tests(self) -> list[Path]: + return list(self.base_dir.rglob("*.py")) + + def _exec_case(self, path: Path) -> CaseResult: + if not path.exists(): + raise FileNotFoundError(f"Could not find test '{path}'") + if not path.is_file(): + raise TypeError(f"Test '{path}' is not a file") + + result: CaseResult = CaseResult() + + checker = TypeChecker() + types_path: Path = path.with_suffix(".midas") + if types_path.exists(): + checker.import_midas(types_path) + + typed_ast: TypedAST = checker.type_check(path) + + if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics): + generator = Generator(workdir=path.parent) + result.compiled_ast = generator.generate_ast(typed_ast, path) + + return result + + +if __name__ == "__main__": + GeneratorTester.main() From 2a8b7d559ce33c5fd623e8d6c81903d6c6f1fe24 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 16 Jun 2026 14:36:06 +0200 Subject: [PATCH 9/9] tests: add simple gen test --- tests/cases/generator/01_simple_types.midas | 14 ++++ tests/cases/generator/01_simple_types.py | 5 ++ .../generator/01_simple_types.py.ref.txt | 79 +++++++++++++++++++ 3 files changed, 98 insertions(+) create mode 100644 tests/cases/generator/01_simple_types.midas create mode 100644 tests/cases/generator/01_simple_types.py create mode 100644 tests/cases/generator/01_simple_types.py.ref.txt diff --git a/tests/cases/generator/01_simple_types.midas b/tests/cases/generator/01_simple_types.midas new file mode 100644 index 0000000..ff4edb1 --- /dev/null +++ b/tests/cases/generator/01_simple_types.midas @@ -0,0 +1,14 @@ +type Meter = float +type Second = float +type MeterPerSecond = float + +extend Meter { + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter + def __truediv__: fn(Second, /) -> MeterPerSecond +} + +extend Second { + def __add__: fn(Second, /) -> Second + def __sub__: fn(Second, /) -> Second +} diff --git a/tests/cases/generator/01_simple_types.py b/tests/cases/generator/01_simple_types.py new file mode 100644 index 0000000..5d6b399 --- /dev/null +++ b/tests/cases/generator/01_simple_types.py @@ -0,0 +1,5 @@ +from midas import cast, Meter, Second + +distance: Meter = cast(Meter, 123.45) +time: Second = cast(Second, 6.7) +speed = distance / time diff --git a/tests/cases/generator/01_simple_types.py.ref.txt b/tests/cases/generator/01_simple_types.py.ref.txt new file mode 100644 index 0000000..aa5f964 --- /dev/null +++ b/tests/cases/generator/01_simple_types.py.ref.txt @@ -0,0 +1,79 @@ +Module( + body=[ + ImportFrom( + module='midas', + names=[ + alias(name='cast'), + alias(name='Meter'), + alias(name='Second')], + level=0), + Assign( + targets=[ + Name(id='__midas_alias_0__')], + value=Constant(value=123.45)), + Assert( + test=Call( + func=Name(id='isinstance'), + args=[ + Name(id='__midas_alias_0__'), + Name(id='float')], + keywords=[]), + msg=JoinedStr( + values=[ + Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '), + FormattedValue( + value=Attribute( + value=Call( + func=Name(id='type'), + args=[ + Name(id='__midas_alias_0__')], + keywords=[]), + attr='__name__'), + conversion=-1), + Constant(value=' to float')])), + Assign( + targets=[ + Name(id='distance')], + value=Name(id='__midas_alias_0__')), + Delete( + targets=[ + Name(id='__midas_alias_0__')]), + Assign( + targets=[ + Name(id='__midas_alias_1__')], + value=Constant(value=6.7)), + Assert( + test=Call( + func=Name(id='isinstance'), + args=[ + Name(id='__midas_alias_1__'), + Name(id='float')], + keywords=[]), + msg=JoinedStr( + values=[ + Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '), + FormattedValue( + value=Attribute( + value=Call( + func=Name(id='type'), + args=[ + Name(id='__midas_alias_1__')], + keywords=[]), + attr='__name__'), + conversion=-1), + Constant(value=' to float')])), + Assign( + targets=[ + Name(id='time')], + value=Name(id='__midas_alias_1__')), + Delete( + targets=[ + Name(id='__midas_alias_1__')]), + Assign( + targets=[ + Name(id='speed')], + value=BinOp( + left=Name(id='distance'), + op=Div(), + right=Name(id='time')))], + type_ignores=[]) \ No newline at end of file