diff --git a/gen/python.py b/gen/python.py index e6d08c9..79ba8b0 100644 --- a/gen/python.py +++ b/gen/python.py @@ -139,4 +139,8 @@ class TernaryExpr: if_false: Expr +class ListExpr: + items: list[Expr] + + ###< diff --git a/midas/ast/printer.py b/midas/ast/printer.py index f8fb411..dc2e64c 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -626,3 +626,14 @@ class PythonAstPrinter( self._write_line("if_false", last=True) with self._child_level(single=True): expr.if_false.accept(self) + + def visit_list_expr(self, expr: p.ListExpr) -> None: + self._write_line("ListExpr") + with self._child_level(): + self._write_line("items", last=True) + with self._child_level(): + for i, item in enumerate(expr.items): + self._idx = i + if i == len(expr.items) - 1: + self._mark_last() + item.accept(self) diff --git a/midas/ast/python.py b/midas/ast/python.py index dd5d905..1aea8ed 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -220,6 +220,9 @@ class Expr(ABC): @abstractmethod def visit_ternary_expr(self, expr: TernaryExpr) -> T: ... + @abstractmethod + def visit_list_expr(self, expr: ListExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -312,3 +315,11 @@ class TernaryExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_ternary_expr(self) + + +@dataclass(frozen=True) +class ListExpr(Expr): + items: list[Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_list_expr(self) diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index 24dc288..f20eb50 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -2,7 +2,15 @@ from __future__ import annotations from typing import TYPE_CHECKING -from midas.checker.types import BaseType, Type, UnitType +from midas.checker.types import ( + BaseType, + ComplexType, + Function, + GenericType, + Type, + TypeVar, + UnitType, +) if TYPE_CHECKING: from midas.checker.registry import TypesRegistry @@ -76,3 +84,29 @@ def define_builtins(reg: TypesRegistry): op(reg, float, "__le__", int, bool) # float <= int = bool op(reg, float, "__ge__", int, bool) # float >= int = bool op(reg, float, "__eq__", int, bool) # float == int = bool + + list = reg.define_type( + "list", + GenericType( + name="list", + params=[TypeVar(name="T", bound=None)], + body=ComplexType( + properties={ + "append": Function( + name="append", + pos_args=[ + Function.Argument( + pos=0, + name="object", + type=TypeVar(name="T", bound=None), + required=True, + ) + ], + args=[], + kw_args=[], + returns=UnitType(), + ) + } + ), + ), + ) diff --git a/midas/checker/python.py b/midas/checker/python.py index e1812d3..63a076e 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -499,6 +499,40 @@ class PythonTyper( ) return UnknownType() + def visit_list_expr(self, expr: p.ListExpr) -> Type: + list_type: Type = self.types.get_type("list") + item_types: list[Type] = [self.type_of(item) for item in expr.items] + + # Try to reduce types with subsumption + reduced: bool = True + keep: list[int] = list(range(len(item_types))) + while reduced: + reduced = False + for i, i1 in enumerate(keep): + type1: Type = item_types[i1] + for i2 in keep[i + 1 :]: + type2 = item_types[i2] + if self.types.is_subtype(type1, type2): + keep.remove(i1) + elif self.types.is_subtype(type2, type1): + keep.remove(i2) + else: + continue + reduced = True + break + + if len(keep) == 0: + return list_type + + if len(keep) == 1: + item_type: Type = item_types[keep[0]] + return self.types.apply_generic(list_type, [item_type]) + self.reporter.error( + expr.location, + f"Heterogeneous list items: {[item_types[i] for i in keep]}", + ) + return self.types.apply_generic(list_type, [UnknownType()]) + def visit_base_type(self, node: p.BaseType) -> Type: return self.types.get_type(node.base) diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 18fcba4..0b7d990 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -180,3 +180,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.resolve(expr.test) self.resolve(expr.if_true) self.resolve(expr.if_false) + + def visit_list_expr(self, expr: p.ListExpr) -> None: + for item in expr.items: + self.resolve(item) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index e4a9556..0d6a018 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -214,6 +214,10 @@ class PythonHighlighter( def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ... + def visit_list_expr(self, expr: p.ListExpr) -> None: + for item in expr.items: + item.accept(self) + class MidasHighlighter( Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] diff --git a/midas/parser/python.py b/midas/parser/python.py index 79011bc..bbe23c8 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -17,6 +17,7 @@ from midas.ast.python import ( Function, GetExpr, IfStmt, + ListExpr, LiteralExpr, LogicalExpr, MidasType, @@ -416,6 +417,12 @@ class PythonParser: case ast.Name(id=name): return VariableExpr(location=location, name=name) + case ast.List(elts=items): + return ListExpr( + location=location, + items=[self.parse_expr(item) for item in items], + ) + case _: raise UnsupportedSyntaxError(node) diff --git a/tests/serializer/python.py b/tests/serializer/python.py index bab3f8c..833d4e4 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -16,6 +16,7 @@ from midas.ast.python import ( Function, GetExpr, IfStmt, + ListExpr, LiteralExpr, LogicalExpr, MidasType, @@ -245,3 +246,9 @@ class PythonAstJsonSerializer( "if_true": expr.if_true.accept(self), "if_false": expr.if_false.accept(self), } + + def visit_list_expr(self, expr: ListExpr) -> dict: + return { + "_type": "ListExpr", + "items": [item.accept(self) for item in expr.items], + }