Compare commits

6 Commits

Author SHA1 Message Date
2a8b7d559c tests: add simple gen test 2026-06-16 14:56:59 +02:00
da38cad23d feat(tests): add generator tester 2026-06-16 14:56:59 +02:00
591012d059 fix(checker): allow calling AppliedType and UnknownType 2026-06-16 14:56:58 +02:00
4b1087d6b9 fix(cli): improve dump-registry command output 2026-06-16 14:56:57 +02:00
732f7b0796 feat(checker): add environment preamble
this adds some builtin functions such as the builtin type constructors
2026-06-16 14:56:56 +02:00
c4062c9595 fix(checker): allow inferred return to be subtype of hint 2026-06-16 14:56:47 +02:00
10 changed files with 331 additions and 22 deletions

121
midas/checker/preamble.py Normal file
View File

@@ -0,0 +1,121 @@
from dataclasses import dataclass
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
@dataclass(frozen=True)
class Param:
name: str
type: Type
required: bool = True
class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._def_type_constructor("object")
self._def_type_constructor("float")
self._def_type_constructor("int")
self._def_type_constructor("bool")
self._def_type_constructor("str")
self._def_function(
name="list",
pos=[Param("object", TopType())],
returns=self._list_of(TopType()),
)
# TODO: use sink
self._def_function(
name="print",
pos=[Param("object", TopType())],
returns=UnitType(),
)
map_in = TypeVar(name="T", bound=None)
map_out = TypeVar(name="U", bound=None)
mapper = self._make_function(
name="MapTransform",
pos=[Param("v", map_in)],
returns=map_out,
)
self._def_function(
name="map",
pos=[
Param("transform", mapper),
Param(
"iterable",
self._list_of(map_in), # TODO: replace with Iterable[T]
),
],
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str):
# TODO: more specific arg types
self._def_function(
name=name,
pos=[Param("object", TopType())],
returns=self._types.get_type(name),
)
def _make_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
) -> Type:
def map_args(params: list[Param], offset: int) -> list[Function.Argument]:
return [
Function.Argument(
pos=i + offset,
name=param.name,
type=param.type,
required=param.required,
)
for i, param in enumerate(params)
]
function = Function(
pos_args=map_args(pos, 0),
args=map_args(mixed, len(pos)),
kw_args=map_args(kw, len(pos) + len(mixed)),
returns=returns,
)
if len(type_vars) != 0:
function = GenericType(
name=name,
params=type_vars,
body=function,
)
return function
def _def_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
):
function: Type = self._make_function(
name=name,
pos=pos,
mixed=mixed,
kw=kw,
returns=returns,
type_vars=type_vars,
)
self.define(name, function)

View File

@@ -7,10 +7,12 @@ import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AppliedType,
Function, Function,
OverloadedFunction, OverloadedFunction,
Type, Type,
@@ -56,7 +58,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper") self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.global_env: Environment = Environment() self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = [] self.judgements: list[tuple[p.Expr, Type]] = []
@@ -252,7 +254,7 @@ class PythonTyper(
if returns_hint is not None: if returns_hint is not None:
assert stmt.returns is not None assert stmt.returns is not None
returns = returns_hint returns = returns_hint
if returns != inferred_return: if not self.is_subtype(inferred_return, returns):
self.reporter.error( self.reporter.error(
stmt.returns.location, stmt.returns.location,
f"Return type mismatch, annotated {returns} but returns {inferred_return}", f"Return type mismatch, annotated {returns} but returns {inferred_return}",
@@ -643,6 +645,15 @@ class PythonTyper(
if function is None: if function is None:
return None return None
return function.returns return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case _: case _:
if report_errors: if report_errors:
self.reporter.error(location, f"{callee} is not callable") self.reporter.error(location, f"{callee} is not callable")

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@@ -41,23 +41,21 @@ class UnitType:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Function: class Function:
pos_args: list[Argument] pos_args: list[Argument] = field(default_factory=list)
args: list[Argument] args: list[Argument] = field(default_factory=list)
kw_args: list[Argument] kw_args: list[Argument] = field(default_factory=list)
returns: Type returns: Type
def __str__(self) -> str: def __str__(self) -> str:
args: list[str] = [] args: list[str] = []
if len(self.pos_args) != 0: if len(self.pos_args) != 0:
args += list(map(str, self.pos_args)) args += list(map(str, self.pos_args))
if len(self.args) + len(self.kw_args) != 0:
args.append("/") args.append("/")
if len(self.args) != 0: if len(self.args) != 0:
args += list(map(str, self.args)) args += list(map(str, self.args))
if len(self.kw_args) != 0: if len(self.kw_args) != 0:
if len(args) != 0:
args.append("*") args.append("*")
args += list(map(str, self.kw_args)) args += list(map(str, self.kw_args))

View File

@@ -9,7 +9,21 @@ from typing import TextIO
import click import click
from midas.checker.checker import TypeChecker from midas.checker.checker import TypeChecker
from midas.checker.types import Type from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
def base_type(type: Type) -> Type:
match type:
case BaseType():
return type
case AliasType(type=base):
return base
case AppliedType(body=body):
return body
case GenericType(body=body):
return body
case _:
return type
@click.command(help="Dump types registry") @click.command(help="Dump types registry")
@@ -23,7 +37,7 @@ def dump_registry(
for name, type in checker.types._types.items(): for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {}) members: dict[str, Type] = checker.types._members.get(name, {})
print(f"{name} = {type}") print(f"{name} = {base_type(type)}")
if len(members) != 0: if len(members) != 0:
print(" " * 4 + "Members:") print(" " * 4 + "Members:")
for member_name, member_type in members.items(): for member_name, member_type in members.items():

View File

@@ -2,6 +2,7 @@ import ast
import shutil import shutil
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional
from midas.ast.location import Location from midas.ast.location import Location
import midas.ast.python as p import midas.ast.python as p
@@ -44,14 +45,21 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._alias_count: int = 0 self._alias_count: int = 0
self._scopes: list[Scope] = [] self._scopes: list[Scope] = []
def generate(self, typed_ast: TypedAST, src_path: Path) -> Path: def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.relative_to(self.workdir) self.rel_src_path = src_path.relative_to(self.workdir)
self._typed_ast = typed_ast self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts) body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[]) module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module) module = ast.fix_missing_locations(module)
return module
def generate(
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
) -> Path:
module: ast.AST = self.generate_ast(typed_ast, src_path)
compiled: str = ast.unparse(module) compiled: str = ast.unparse(module)
out_path: Path = (self.build_dir / self.rel_src_path).resolve() if out_path is None:
out_path = (self.build_dir / self.rel_src_path).resolve()
try: try:
_ = out_path.relative_to(self.build_dir) _ = out_path.relative_to(self.build_dir)
except ValueError: except ValueError:

View File

@@ -21,6 +21,10 @@ class Tester(ABC):
@abstractmethod @abstractmethod
def namespace(self) -> str: ... def namespace(self) -> str: ...
@property
def extension(self) -> str:
return "json"
@property @property
def base_dir(self) -> Path: def base_dir(self) -> Path:
return self.CASES_DIR / self.namespace return self.CASES_DIR / self.namespace
@@ -99,7 +103,7 @@ class Tester(ABC):
return True return True
def _result_path(self, test_path: Path) -> Path: def _result_path(self, test_path: Path) -> Path:
return test_path.parent / (test_path.name + ".ref.json") return test_path.parent / (test_path.name + f".ref.{self.extension}")
def _print_diff(self, diff: Iterator[str]): def _print_diff(self, diff: Iterator[str]):
for line in diff: for line in diff:

View File

@@ -0,0 +1,14 @@
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
def __truediv__: fn(Second, /) -> MeterPerSecond
}
extend Second {
def __add__: fn(Second, /) -> Second
def __sub__: fn(Second, /) -> Second
}

View File

@@ -0,0 +1,5 @@
from midas import cast, Meter, Second
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -0,0 +1,79 @@
Module(
body=[
ImportFrom(
module='midas',
names=[
alias(name='cast'),
alias(name='Meter'),
alias(name='Second')],
level=0),
Assign(
targets=[
Name(id='__midas_alias_0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='distance')],
value=Name(id='__midas_alias_0__')),
Delete(
targets=[
Name(id='__midas_alias_0__')]),
Assign(
targets=[
Name(id='__midas_alias_1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='time')],
value=Name(id='__midas_alias_1__')),
Delete(
targets=[
Name(id='__midas_alias_1__')]),
Assign(
targets=[
Name(id='speed')],
value=BinOp(
left=Name(id='distance'),
op=Div(),
right=Name(id='time')))],
type_ignores=[])

55
tests/generator.py Normal file
View File

@@ -0,0 +1,55 @@
import ast
from dataclasses import dataclass
from pathlib import Path
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import DiagnosticType
from midas.generator.generator import Generator
from midas.utils import TypedAST
from tests.base import Tester
@dataclass
class CaseResult:
compiled_ast: ast.AST = ast.Module([], [])
def dumps(self) -> str:
return ast.dump(self.compiled_ast, indent=2)
class GeneratorTester(Tester):
@property
def namespace(self) -> str:
return "generator"
@property
def extension(self) -> str:
return "txt"
def _list_tests(self) -> list[Path]:
return list(self.base_dir.rglob("*.py"))
def _exec_case(self, path: Path) -> CaseResult:
if not path.exists():
raise FileNotFoundError(f"Could not find test '{path}'")
if not path.is_file():
raise TypeError(f"Test '{path}' is not a file")
result: CaseResult = CaseResult()
checker = TypeChecker()
types_path: Path = path.with_suffix(".midas")
if types_path.exists():
checker.import_midas(types_path)
typed_ast: TypedAST = checker.type_check(path)
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent)
result.compiled_ast = generator.generate_ast(typed_ast, path)
return result
if __name__ == "__main__":
GeneratorTester.main()