diff --git a/gen/python.py b/gen/python.py index f67f540..4af901a 100644 --- a/gen/python.py +++ b/gen/python.py @@ -157,6 +157,11 @@ class ListExpr: items: list[Expr] +class DictExpr: + keys: list[Optional[Expr]] + values: list[Expr] + + class SubscriptExpr: object: Expr index: Expr diff --git a/midas/ast/printer.py b/midas/ast/printer.py index 68ff7ba..694c272 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -745,6 +745,27 @@ class PythonAstPrinter( self._mark_last() item.accept(self) + def visit_dict_expr(self, expr: p.DictExpr) -> None: + self._write_line("DictExpr") + with self._child_level(): + self._write_line("keys") + with self._child_level(): + for i, key in enumerate(expr.keys): + self._idx = i + if i == len(expr.keys) - 1: + self._mark_last() + if key is None: + self._write_line("None") + else: + key.accept(self) + self._write_line("values", last=True) + with self._child_level(): + for i, value in enumerate(expr.values): + self._idx = i + if i == len(expr.values) - 1: + self._mark_last() + value.accept(self) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: self._write_line("SubscriptExpr") with self._child_level(): diff --git a/midas/ast/python.py b/midas/ast/python.py index 73d49e5..7770de6 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -259,6 +259,9 @@ class Expr(ABC): @abstractmethod def visit_list_expr(self, expr: ListExpr) -> T: ... + @abstractmethod + def visit_dict_expr(self, expr: DictExpr) -> T: ... + @abstractmethod def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ... @@ -370,6 +373,15 @@ class ListExpr(Expr): return visitor.visit_list_expr(self) +@dataclass(frozen=True) +class DictExpr(Expr): + keys: list[Optional[Expr]] + values: list[Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_dict_expr(self) + + @dataclass(frozen=True) class SubscriptExpr(Expr): object: Expr diff --git a/midas/checker/builtins.midas b/midas/checker/builtins.midas index 6e89172..3110cbf 100644 --- a/midas/checker/builtins.midas +++ b/midas/checker/builtins.midas @@ -150,3 +150,32 @@ extend list[T] { prop __doc__: str } + +extend dict[K, V] { + def copy: fn() -> dict[K, V] + def keys: fn() -> list[K] // TODO: use builtin types + def values: fn() -> list[V] // TODO: use builtin types + // def items: fn() -> list[tuple[K, V]] // TODO: use builtin types + + // def get: fn(key: K, default: None = None, /) -> V | None + def get: fn(key: K, default: V, /) -> V + // def get: fn[T](key: K, default: T, /) -> V | T + def pop: fn(key: K, /) -> V + def pop: fn(key: K, default: V, /) -> V + // def pop: fn[T](key: K, default: T, /) -> V | T + def __len__: fn() -> int + def __getitem__: fn(key: K, /) -> V + def __setitem__: fn(key: K, value: V, /) -> None + def __delitem__: fn(key: K, /) -> None + // def __iter__: fn() -> Iterator[K] + def __eq__: fn(value: object, /) -> bool + // def __reversed__: fn() -> Iterator[K] + + def __or__: fn(value: dict[K, V], /) -> dict[K, V] + // def __or__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2] + def __ror__: fn(value: dict[K, V], /) -> dict[K, V] + // def __ror__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2] + // def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V] + // def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V] + +} \ No newline at end of file diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index b1adf6d..7bf1d98 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -39,3 +39,14 @@ def define_builtins(reg: TypesRegistry): body=BaseType(name="list"), ), ) + dict = reg.define_type( + "dict", + GenericType( + name="dict", + params=[ + TypeVar(name="K", bound=None), + TypeVar(name="V", bound=None), + ], + body=BaseType(name="dict"), + ), + ) diff --git a/midas/checker/preamble.py b/midas/checker/preamble.py index a543dd9..ea7001b 100644 --- a/midas/checker/preamble.py +++ b/midas/checker/preamble.py @@ -61,7 +61,7 @@ class Preamble(Environment): # TODO: more specific arg types self._def_function( name=name, - pos=[Param("object", TopType())], + pos=[Param("object", TopType(), required=False)], returns=self._types.get_type(name), ) diff --git a/midas/checker/python.py b/midas/checker/python.py index 5679b19..c4bffff 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -552,6 +552,46 @@ class PythonTyper( ) return self.types.apply_generic(list_type, [UnknownType()]) + def visit_dict_expr(self, expr: p.DictExpr) -> Type: + dict_type: Type = self.types.get_type("dict") + + key_types: list[Type] = [] + value_types: list[Type] = [] + for key, value in zip(expr.keys, expr.values): + if key is None: + self.reporter.warning( + value.location, "Dictionary unpacking not supported" + ) + continue + key_types.append(self.type_of(key)) + value_types.append(self.type_of(value)) + + key_types = self.types.reduce_types(key_types) + value_types = self.types.reduce_types(value_types) + + if len(key_types) == 0 or len(value_types) == 0: + return dict_type + + key_type: Type = UnknownType() + value_type: Type = UnknownType() + + if len(key_types) == 1: + key_type = key_types[0] + else: + self.reporter.error( + expr.location, + f"Heterogeneous dict keys: {key_types}", + ) + + if len(value_types) == 1: + value_type = value_types[0] + else: + self.reporter.error( + expr.location, + f"Heterogeneous dict values: {value_types}", + ) + return self.types.apply_generic(dict_type, [key_type, value_type]) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type: object: Type = self.type_of(expr.object) operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 3226faf..3bf73d7 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -213,6 +213,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): for item in expr.items: self.resolve(item) + def visit_dict_expr(self, expr: p.DictExpr) -> None: + for key in expr.keys: + if key is not None: + self.resolve(key) + for value in expr.values: + self.resolve(value) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: self.resolve(expr.object) self.resolve(expr.index) diff --git a/midas/generator/generator.py b/midas/generator/generator.py index a2eb5c2..7575ca5 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -4,8 +4,8 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Optional -from midas.ast.location import Location import midas.ast.python as p +from midas.ast.location import Location from midas.checker.types import ( AliasType, AppliedType, @@ -139,6 +139,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): elts=[item.accept(self) for item in expr.items], ) + def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr: + return ast.Dict( + keys=[key.accept(self) if key is not None else None for key in expr.keys], + values=[value.accept(self) for value in expr.values], + ) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr: return ast.Subscript( value=expr.object.accept(self), diff --git a/midas/parser/python.py b/midas/parser/python.py index 90c029a..4110feb 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -10,6 +10,7 @@ from midas.ast.python import ( CastExpr, CompareExpr, ConstraintType, + DictExpr, Expr, ExpressionStmt, ForStmt, @@ -447,6 +448,16 @@ class PythonParser: items=[self.parse_expr(item) for item in items], ) + case ast.Dict(keys=keys, values=values): + return DictExpr( + location=location, + keys=[ + self.parse_expr(key) if key is not None else None + for key in keys + ], + values=[self.parse_expr(value) for value in values], + ) + case ast.Subscript(value=value, slice=index): return SubscriptExpr( location=location, diff --git a/tests/serializer/python.py b/tests/serializer/python.py index 45951df..038b496 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -9,6 +9,7 @@ from midas.ast.python import ( CastExpr, CompareExpr, ConstraintType, + DictExpr, Expr, ExpressionStmt, ForStmt, @@ -278,6 +279,13 @@ class PythonAstJsonSerializer( "items": [item.accept(self) for item in expr.items], } + def visit_dict_expr(self, expr: DictExpr) -> dict: + return { + "_type": "DictExpr", + "keys": [self._serialize_optional(key) for key in expr.keys], + "values": self._serialize_list(expr.values), + } + def visit_subscript_expr(self, expr: SubscriptExpr) -> dict: return { "_type": "SubscriptExpr",