11 Commits

18 changed files with 482 additions and 23 deletions

View File

@@ -157,6 +157,11 @@ class ListExpr:
items: list[Expr] items: list[Expr]
class DictExpr:
keys: list[Optional[Expr]]
values: list[Expr]
class SubscriptExpr: class SubscriptExpr:
object: Expr object: Expr
index: Expr index: Expr

View File

@@ -745,6 +745,27 @@ class PythonAstPrinter(
self._mark_last() self._mark_last()
item.accept(self) item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
self._write_line("DictExpr")
with self._child_level():
self._write_line("keys")
with self._child_level():
for i, key in enumerate(expr.keys):
self._idx = i
if i == len(expr.keys) - 1:
self._mark_last()
if key is None:
self._write_line("None")
else:
key.accept(self)
self._write_line("values", last=True)
with self._child_level():
for i, value in enumerate(expr.values):
self._idx = i
if i == len(expr.values) - 1:
self._mark_last()
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self._write_line("SubscriptExpr") self._write_line("SubscriptExpr")
with self._child_level(): with self._child_level():

View File

@@ -259,6 +259,9 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_list_expr(self, expr: ListExpr) -> T: ... def visit_list_expr(self, expr: ListExpr) -> T: ...
@abstractmethod
def visit_dict_expr(self, expr: DictExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ... def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
@@ -370,6 +373,15 @@ class ListExpr(Expr):
return visitor.visit_list_expr(self) return visitor.visit_list_expr(self)
@dataclass(frozen=True)
class DictExpr(Expr):
keys: list[Optional[Expr]]
values: list[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_dict_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class SubscriptExpr(Expr): class SubscriptExpr(Expr):
object: Expr object: Expr

View File

@@ -150,3 +150,32 @@ extend list[T] {
prop __doc__: str prop __doc__: str
} }
extend dict[K, V] {
def copy: fn() -> dict[K, V]
def keys: fn() -> list[K] // TODO: use builtin types
def values: fn() -> list[V] // TODO: use builtin types
// def items: fn() -> list[tuple[K, V]] // TODO: use builtin types
// def get: fn(key: K, default: None = None, /) -> V | None
def get: fn(key: K, default: V, /) -> V
// def get: fn[T](key: K, default: T, /) -> V | T
def pop: fn(key: K, /) -> V
def pop: fn(key: K, default: V, /) -> V
// def pop: fn[T](key: K, default: T, /) -> V | T
def __len__: fn() -> int
def __getitem__: fn(key: K, /) -> V
def __setitem__: fn(key: K, value: V, /) -> None
def __delitem__: fn(key: K, /) -> None
// def __iter__: fn() -> Iterator[K]
def __eq__: fn(value: object, /) -> bool
// def __reversed__: fn() -> Iterator[K]
def __or__: fn(value: dict[K, V], /) -> dict[K, V]
// def __or__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
def __ror__: fn(value: dict[K, V], /) -> dict[K, V]
// def __ror__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
}

View File

@@ -39,3 +39,14 @@ def define_builtins(reg: TypesRegistry):
body=BaseType(name="list"), body=BaseType(name="list"),
), ),
) )
dict = reg.define_type(
"dict",
GenericType(
name="dict",
params=[
TypeVar(name="K", bound=None),
TypeVar(name="V", bound=None),
],
body=BaseType(name="dict"),
),
)

121
midas/checker/preamble.py Normal file
View File

@@ -0,0 +1,121 @@
from dataclasses import dataclass
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
@dataclass(frozen=True)
class Param:
name: str
type: Type
required: bool = True
class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._def_type_constructor("object")
self._def_type_constructor("float")
self._def_type_constructor("int")
self._def_type_constructor("bool")
self._def_type_constructor("str")
self._def_function(
name="list",
pos=[Param("object", TopType())],
returns=self._list_of(TopType()),
)
# TODO: use sink
self._def_function(
name="print",
pos=[Param("object", TopType())],
returns=UnitType(),
)
map_in = TypeVar(name="T", bound=None)
map_out = TypeVar(name="U", bound=None)
mapper = self._make_function(
name="MapTransform",
pos=[Param("v", map_in)],
returns=map_out,
)
self._def_function(
name="map",
pos=[
Param("transform", mapper),
Param(
"iterable",
self._list_of(map_in), # TODO: replace with Iterable[T]
),
],
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str):
# TODO: more specific arg types
self._def_function(
name=name,
pos=[Param("object", TopType(), required=False)],
returns=self._types.get_type(name),
)
def _make_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
) -> Type:
def map_args(params: list[Param], offset: int) -> list[Function.Argument]:
return [
Function.Argument(
pos=i + offset,
name=param.name,
type=param.type,
required=param.required,
)
for i, param in enumerate(params)
]
function = Function(
pos_args=map_args(pos, 0),
args=map_args(mixed, len(pos)),
kw_args=map_args(kw, len(pos) + len(mixed)),
returns=returns,
)
if len(type_vars) != 0:
function = GenericType(
name=name,
params=type_vars,
body=function,
)
return function
def _def_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
):
function: Type = self._make_function(
name=name,
pos=pos,
mixed=mixed,
kw=kw,
returns=returns,
type_vars=type_vars,
)
self.define(name, function)

View File

@@ -7,10 +7,12 @@ import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AppliedType,
Function, Function,
OverloadedFunction, OverloadedFunction,
Type, Type,
@@ -56,7 +58,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper") self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.global_env: Environment = Environment() self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = [] self.judgements: list[tuple[p.Expr, Type]] = []
@@ -252,7 +254,7 @@ class PythonTyper(
if returns_hint is not None: if returns_hint is not None:
assert stmt.returns is not None assert stmt.returns is not None
returns = returns_hint returns = returns_hint
if returns != inferred_return: if not self.is_subtype(inferred_return, returns):
self.reporter.error( self.reporter.error(
stmt.returns.location, stmt.returns.location,
f"Return type mismatch, annotated {returns} but returns {inferred_return}", f"Return type mismatch, annotated {returns} but returns {inferred_return}",
@@ -550,6 +552,46 @@ class PythonTyper(
) )
return self.types.apply_generic(list_type, [UnknownType()]) return self.types.apply_generic(list_type, [UnknownType()])
def visit_dict_expr(self, expr: p.DictExpr) -> Type:
dict_type: Type = self.types.get_type("dict")
key_types: list[Type] = []
value_types: list[Type] = []
for key, value in zip(expr.keys, expr.values):
if key is None:
self.reporter.warning(
value.location, "Dictionary unpacking not supported"
)
continue
key_types.append(self.type_of(key))
value_types.append(self.type_of(value))
key_types = self.types.reduce_types(key_types)
value_types = self.types.reduce_types(value_types)
if len(key_types) == 0 or len(value_types) == 0:
return dict_type
key_type: Type = UnknownType()
value_type: Type = UnknownType()
if len(key_types) == 1:
key_type = key_types[0]
else:
self.reporter.error(
expr.location,
f"Heterogeneous dict keys: {key_types}",
)
if len(value_types) == 1:
value_type = value_types[0]
else:
self.reporter.error(
expr.location,
f"Heterogeneous dict values: {value_types}",
)
return self.types.apply_generic(dict_type, [key_type, value_type])
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
@@ -643,6 +685,15 @@ class PythonTyper(
if function is None: if function is None:
return None return None
return function.returns return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case _: case _:
if report_errors: if report_errors:
self.reporter.error(location, f"{callee} is not callable") self.reporter.error(location, f"{callee} is not callable")

View File

@@ -213,6 +213,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
for item in expr.items: for item in expr.items:
self.resolve(item) self.resolve(item)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
self.resolve(key)
for value in expr.values:
self.resolve(value)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self.resolve(expr.object) self.resolve(expr.object)
self.resolve(expr.index) self.resolve(expr.index)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@@ -41,23 +41,21 @@ class UnitType:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Function: class Function:
pos_args: list[Argument] pos_args: list[Argument] = field(default_factory=list)
args: list[Argument] args: list[Argument] = field(default_factory=list)
kw_args: list[Argument] kw_args: list[Argument] = field(default_factory=list)
returns: Type returns: Type
def __str__(self) -> str: def __str__(self) -> str:
args: list[str] = [] args: list[str] = []
if len(self.pos_args) != 0: if len(self.pos_args) != 0:
args += list(map(str, self.pos_args)) args += list(map(str, self.pos_args))
if len(self.args) + len(self.kw_args) != 0:
args.append("/") args.append("/")
if len(self.args) != 0: if len(self.args) != 0:
args += list(map(str, self.args)) args += list(map(str, self.args))
if len(self.kw_args) != 0: if len(self.kw_args) != 0:
if len(args) != 0:
args.append("*") args.append("*")
args += list(map(str, self.kw_args)) args += list(map(str, self.kw_args))

View File

@@ -9,7 +9,21 @@ from typing import TextIO
import click import click
from midas.checker.checker import TypeChecker from midas.checker.checker import TypeChecker
from midas.checker.types import Type from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
def base_type(type: Type) -> Type:
match type:
case BaseType():
return type
case AliasType(type=base):
return base
case AppliedType(body=body):
return body
case GenericType(body=body):
return body
case _:
return type
@click.command(help="Dump types registry") @click.command(help="Dump types registry")
@@ -23,7 +37,7 @@ def dump_registry(
for name, type in checker.types._types.items(): for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {}) members: dict[str, Type] = checker.types._members.get(name, {})
print(f"{name} = {type}") print(f"{name} = {base_type(type)}")
if len(members) != 0: if len(members) != 0:
print(" " * 4 + "Members:") print(" " * 4 + "Members:")
for member_name, member_type in members.items(): for member_name, member_type in members.items():

View File

@@ -2,9 +2,10 @@ import ast
import shutil import shutil
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional
from midas.ast.location import Location
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
AppliedType, AppliedType,
@@ -44,14 +45,21 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._alias_count: int = 0 self._alias_count: int = 0
self._scopes: list[Scope] = [] self._scopes: list[Scope] = []
def generate(self, typed_ast: TypedAST, src_path: Path) -> Path: def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.relative_to(self.workdir) self.rel_src_path = src_path.relative_to(self.workdir)
self._typed_ast = typed_ast self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts) body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[]) module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module) module = ast.fix_missing_locations(module)
return module
def generate(
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None
) -> Path:
module: ast.AST = self.generate_ast(typed_ast, src_path)
compiled: str = ast.unparse(module) compiled: str = ast.unparse(module)
out_path: Path = (self.build_dir / self.rel_src_path).resolve() if out_path is None:
out_path = (self.build_dir / self.rel_src_path).resolve()
try: try:
_ = out_path.relative_to(self.build_dir) _ = out_path.relative_to(self.build_dir)
except ValueError: except ValueError:
@@ -131,6 +139,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
elts=[item.accept(self) for item in expr.items], elts=[item.accept(self) for item in expr.items],
) )
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
return ast.Dict(
keys=[key.accept(self) if key is not None else None for key in expr.keys],
values=[value.accept(self) for value in expr.values],
)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
return ast.Subscript( return ast.Subscript(
value=expr.object.accept(self), value=expr.object.accept(self),

View File

@@ -10,6 +10,7 @@ from midas.ast.python import (
CastExpr, CastExpr,
CompareExpr, CompareExpr,
ConstraintType, ConstraintType,
DictExpr,
Expr, Expr,
ExpressionStmt, ExpressionStmt,
ForStmt, ForStmt,
@@ -447,6 +448,16 @@ class PythonParser:
items=[self.parse_expr(item) for item in items], items=[self.parse_expr(item) for item in items],
) )
case ast.Dict(keys=keys, values=values):
return DictExpr(
location=location,
keys=[
self.parse_expr(key) if key is not None else None
for key in keys
],
values=[self.parse_expr(value) for value in values],
)
case ast.Subscript(value=value, slice=index): case ast.Subscript(value=value, slice=index):
return SubscriptExpr( return SubscriptExpr(
location=location, location=location,

View File

@@ -21,6 +21,10 @@ class Tester(ABC):
@abstractmethod @abstractmethod
def namespace(self) -> str: ... def namespace(self) -> str: ...
@property
def extension(self) -> str:
return "json"
@property @property
def base_dir(self) -> Path: def base_dir(self) -> Path:
return self.CASES_DIR / self.namespace return self.CASES_DIR / self.namespace
@@ -99,7 +103,7 @@ class Tester(ABC):
return True return True
def _result_path(self, test_path: Path) -> Path: def _result_path(self, test_path: Path) -> Path:
return test_path.parent / (test_path.name + ".ref.json") return test_path.parent / (test_path.name + f".ref.{self.extension}")
def _print_diff(self, diff: Iterator[str]): def _print_diff(self, diff: Iterator[str]):
for line in diff: for line in diff:

View File

@@ -0,0 +1,14 @@
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
def __truediv__: fn(Second, /) -> MeterPerSecond
}
extend Second {
def __add__: fn(Second, /) -> Second
def __sub__: fn(Second, /) -> Second
}

View File

@@ -0,0 +1,5 @@
from midas import cast, Meter, Second
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -0,0 +1,79 @@
Module(
body=[
ImportFrom(
module='midas',
names=[
alias(name='cast'),
alias(name='Meter'),
alias(name='Second')],
level=0),
Assign(
targets=[
Name(id='__midas_alias_0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='distance')],
value=Name(id='__midas_alias_0__')),
Delete(
targets=[
Name(id='__midas_alias_0__')]),
Assign(
targets=[
Name(id='__midas_alias_1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='time')],
value=Name(id='__midas_alias_1__')),
Delete(
targets=[
Name(id='__midas_alias_1__')]),
Assign(
targets=[
Name(id='speed')],
value=BinOp(
left=Name(id='distance'),
op=Div(),
right=Name(id='time')))],
type_ignores=[])

55
tests/generator.py Normal file
View File

@@ -0,0 +1,55 @@
import ast
from dataclasses import dataclass
from pathlib import Path
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import DiagnosticType
from midas.generator.generator import Generator
from midas.utils import TypedAST
from tests.base import Tester
@dataclass
class CaseResult:
compiled_ast: ast.AST = ast.Module([], [])
def dumps(self) -> str:
return ast.dump(self.compiled_ast, indent=2)
class GeneratorTester(Tester):
@property
def namespace(self) -> str:
return "generator"
@property
def extension(self) -> str:
return "txt"
def _list_tests(self) -> list[Path]:
return list(self.base_dir.rglob("*.py"))
def _exec_case(self, path: Path) -> CaseResult:
if not path.exists():
raise FileNotFoundError(f"Could not find test '{path}'")
if not path.is_file():
raise TypeError(f"Test '{path}' is not a file")
result: CaseResult = CaseResult()
checker = TypeChecker()
types_path: Path = path.with_suffix(".midas")
if types_path.exists():
checker.import_midas(types_path)
typed_ast: TypedAST = checker.type_check(path)
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent)
result.compiled_ast = generator.generate_ast(typed_ast, path)
return result
if __name__ == "__main__":
GeneratorTester.main()

View File

@@ -9,6 +9,7 @@ from midas.ast.python import (
CastExpr, CastExpr,
CompareExpr, CompareExpr,
ConstraintType, ConstraintType,
DictExpr,
Expr, Expr,
ExpressionStmt, ExpressionStmt,
ForStmt, ForStmt,
@@ -278,6 +279,13 @@ class PythonAstJsonSerializer(
"items": [item.accept(self) for item in expr.items], "items": [item.accept(self) for item in expr.items],
} }
def visit_dict_expr(self, expr: DictExpr) -> dict:
return {
"_type": "DictExpr",
"keys": [self._serialize_optional(key) for key in expr.keys],
"values": self._serialize_list(expr.values),
}
def visit_subscript_expr(self, expr: SubscriptExpr) -> dict: def visit_subscript_expr(self, expr: SubscriptExpr) -> dict:
return { return {
"_type": "SubscriptExpr", "_type": "SubscriptExpr",