diff --git a/examples/01_simple_type_checking/01_simple_operations.py b/examples/01_simple_type_checking/01_simple_operations.py index a3ac707..4e767f2 100644 --- a/examples/01_simple_type_checking/01_simple_operations.py +++ b/examples/01_simple_type_checking/01_simple_operations.py @@ -9,3 +9,5 @@ d = True e = d + d f: float = a + +f = -f diff --git a/examples/01_simple_type_checking/02_simple_types.midas b/examples/01_simple_type_checking/02_simple_types.midas index 6a1a6a2..ff4edb1 100644 --- a/examples/01_simple_type_checking/02_simple_types.midas +++ b/examples/01_simple_type_checking/02_simple_types.midas @@ -3,12 +3,12 @@ type Second = float type MeterPerSecond = float extend Meter { - op __add__(Meter) -> Meter - op __sub__(Meter) -> Meter - op __truediv__(Second) -> MeterPerSecond + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter + def __truediv__: fn(Second, /) -> MeterPerSecond } extend Second { - op __add__(Second) -> Second - op __sub__(Second) -> Second + def __add__: fn(Second, /) -> Second + def __sub__: fn(Second, /) -> Second } diff --git a/examples/01_simple_type_checking/04_complex_types.midas b/examples/01_simple_type_checking/04_complex_types.midas index b920c37..adc76b3 100644 --- a/examples/01_simple_type_checking/04_complex_types.midas +++ b/examples/01_simple_type_checking/04_complex_types.midas @@ -1,11 +1,21 @@ type Meter = float extend Meter { - op __add__(Meter) -> Meter - op __sub__(Meter) -> Meter + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter } -type Coordinate = { - x: Meter - y: Meter +type Coordinate = object + +extend Coordinate { + prop x: Meter + prop y: Meter +} + +type Difference[T <: float] = T +type MeterDifference = Difference[Meter] + +type CompDiff[T <: float] = { + prop d1: Difference[T] + prop d2: Difference[T] } \ No newline at end of file diff --git a/examples/01_simple_type_checking/04_complex_types.py b/examples/01_simple_type_checking/04_complex_types.py index 
f36ef52..f1d1215 100644 --- a/examples/01_simple_type_checking/04_complex_types.py +++ b/examples/01_simple_type_checking/04_complex_types.py @@ -1,5 +1,6 @@ # type: ignore # ruff: disable [F821] + p1: Coordinate p2: Coordinate @@ -9,3 +10,28 @@ diff_y = p2.y - p1.y dist = diff_x + diff_y p2.x += cast(Meter, 1) +p2.y = True # invalid, wrong type +p2.z = 3 # invalid, no property 'z' on Coordinate +p2.x.a = 3 # invalid, no properties on Meter + +foo: list[float] = [] + +append = foo.append + +foo.append("") # invalid, must be float +foo.append(2) +append(True) # invalid, must be float +append(2) + +bar: list[list[Meter]] + +bar.append([p2.x]) + +foo2 = foo + foo + +a = foo[0] +b = bar[0][1] +c = bar[0][1][2] # invalid, not method __getitem__ on Meter +c = bar[""] # invalid, wrong index type + +d = foo[1:2] diff --git a/examples/01_simple_type_checking/05_functions.py b/examples/01_simple_type_checking/05_functions.py new file mode 100644 index 0000000..9c04813 --- /dev/null +++ b/examples/01_simple_type_checking/05_functions.py @@ -0,0 +1,28 @@ +def incr(value: int): + return value + 1 + + +def decr(value: int): + return value - 1 + + +def foo(a: int, /, b: float, *, c: str): + return True + + +r1 = foo() # foo() missing 2 required positional arguments: 'a' and 'b' +r2 = foo(1) # foo() missing 1 required positional argument: 'b' +r3 = foo(1, 2.0) # foo() missing 1 required keyword-only argument: 'c' +r4 = foo(1, b=2.0) # foo() missing 1 required keyword-only argument: 'c' +r5 = foo(1, 2.0, "test") # foo() takes 2 positional arguments but 3 were given +r6 = foo(1, 2.0, b=3.0) # foo() got multiple values for argument 'b' +r7 = foo( + a=1 +) # foo() got some positional-only arguments passed as keyword arguments: 'a' +r8 = foo(g="test") # foo() got an unexpected keyword argument 'g' + +r9a = foo(1, 2.0, c="test") +r9b = foo(1, b=2.0, c="test") +r9c = foo(1, c="test", b=2.0) + +r10 = foo("a", 3, c=False) # wrong argument types diff --git 
a/examples/01_simple_type_checking/06_overloads.midas b/examples/01_simple_type_checking/06_overloads.midas new file mode 100644 index 0000000..777c410 --- /dev/null +++ b/examples/01_simple_type_checking/06_overloads.midas @@ -0,0 +1,10 @@ +type T1 = object +type T2 = object +type Foo = object +type T2b = T2 + +extend Foo { + def bar: fn(T1, /) -> int + def bar: fn(T2, /) -> float + def bar: fn(T2b, /) -> int +} diff --git a/examples/01_simple_type_checking/06_overloads.py b/examples/01_simple_type_checking/06_overloads.py new file mode 100644 index 0000000..86406e0 --- /dev/null +++ b/examples/01_simple_type_checking/06_overloads.py @@ -0,0 +1,18 @@ +# type: ignore +# ruff: disable [F821] + +foo: Foo +t1: T1 +t2: T2 + +a = foo.bar(t1) +b = foo.bar(t2) + +func = foo.bar + +c = func(t1) +d = func(t2) + +t2b: T2b + +e = foo.bar(t2b) diff --git a/gen/gen.py b/gen/gen.py index e78c872..50c9c9d 100644 --- a/gen/gen.py +++ b/gen/gen.py @@ -30,6 +30,7 @@ from __future__ import annotations T = TypeVar("T") +{preamble} {sections} """ @@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile( re.MULTILINE | re.DOTALL, ) +PREAMBLE_REGEX = re.compile( + r"^###>\s*Preamble\s*?\n(?P.*?)\n###<$", + re.MULTILINE | re.DOTALL, +) + def snake_case(text: str) -> str: return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_") @@ -88,13 +94,14 @@ def make_banner(text: str) -> str: def make_section(full_name: str, base: str, param: str, body: str) -> str: + print(f" Generating {full_name}") visitor_methods: list[str] = [] classes: list[str] = [] definitions: list[str] = body.strip("\n").split("\n\n\n") for cls in definitions: cls = cls.strip("\n") name: str = re.match("class (.*?):", cls).group(1) # type: ignore - print(f"Processing {name}") + print(f" Processing {name}") visitor_methods.append(make_visitor_method(name, param)) classes.append(make_class(name, cls, base)) @@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str: def 
generate(definitions_path: Path, out_path: Path): + print(f"Processing generating {out_path} from {definitions_path}") root_dir: Path = Path(__file__).parent.parent rel_path: Path = definitions_path.relative_to(root_dir) src: str = definitions_path.read_text() @@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path): if m := IMPORTS_REGEX.search(src): imports = m.group("body").strip("\n") + preamble: str = "" + if m := PREAMBLE_REGEX.search(src): + preamble = m.group("body") + for section_m in SECTION_REGEX.finditer(src): full_name: str = section_m.group("name") base: str = section_m.group("base") @@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path): gen_path=Path(__file__).relative_to(root_dir), ), imports=imports, + preamble=preamble, sections="\n\n\n".join(sections), ) out_path.write_text(result) diff --git a/gen/midas.py b/gen/midas.py index e1c304d..42caf4f 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -4,6 +4,7 @@ ###> Imports from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum, auto from typing import Any, Generic, Optional, TypeVar from midas.ast.location import Location @@ -12,33 +13,39 @@ from midas.lexer.token import Token ###< +###> Preamble +@dataclass(frozen=True, kw_only=True) +class TypeParam: + location: Location + name: Token + bound: Optional[Type] + + +class MemberKind(Enum): + PROPERTY = auto() + METHOD = auto() + + +###< + + ###> Stmt | Statements class TypeStmt: name: Token - params: list[Param] + params: list[TypeParam] type: Type - @dataclass(frozen=True, kw_only=True) - class Param: - location: Location - name: Token - bound: Optional[Type] - -class PropertyStmt: +class MemberStmt: name: Token type: Type + kind: MemberKind class ExtendStmt: - type: Type - operations: list[OpStmt] - - -class OpStmt: name: Token - operand: Type - result: Type + params: list[TypeParam] + members: list[MemberStmt] class PredicateStmt: @@ -103,7 +110,7 @@ class NamedType: class 
GenericType: type: Type - params: list[Type] + args: list[Type] class ConstraintType: @@ -112,7 +119,26 @@ class ConstraintType: class ComplexType: - properties: list[PropertyStmt] + members: list[MemberStmt] + + +class ExtensionType: + base: Type + extension: ComplexType + + +class FunctionType: + pos_args: list[Argument] + args: list[Argument] + kw_args: list[Argument] + returns: Type + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[Token] + type: Type + required: bool ###< diff --git a/gen/python.py b/gen/python.py index e6d08c9..35908f7 100644 --- a/gen/python.py +++ b/gen/python.py @@ -139,4 +139,19 @@ class TernaryExpr: if_false: Expr +class ListExpr: + items: list[Expr] + + +class SubscriptExpr: + object: Expr + index: Expr + + +class SliceExpr: + lower: Optional[Expr] + upper: Optional[Expr] + step: Optional[Expr] + + ###< diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 335e5cf..e71aff9 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum, auto from typing import Any, Generic, Optional, TypeVar from midas.ast.location import Location @@ -14,6 +15,18 @@ from midas.lexer.token import Token T = TypeVar("T") +@dataclass(frozen=True, kw_only=True) +class TypeParam: + location: Location + name: Token + bound: Optional[Type] + + +class MemberKind(Enum): + PROPERTY = auto() + METHOD = auto() + + ############## # Statements # ############## @@ -31,14 +44,11 @@ class Stmt(ABC): def visit_type_stmt(self, stmt: TypeStmt) -> T: ... @abstractmethod - def visit_property_stmt(self, stmt: PropertyStmt) -> T: ... + def visit_member_stmt(self, stmt: MemberStmt) -> T: ... @abstractmethod def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ... - @abstractmethod - def visit_op_stmt(self, stmt: OpStmt) -> T: ... 
- @abstractmethod def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ... @@ -46,47 +56,33 @@ class Stmt(ABC): @dataclass(frozen=True) class TypeStmt(Stmt): name: Token - params: list[Param] + params: list[TypeParam] type: Type - @dataclass(frozen=True, kw_only=True) - class Param: - location: Location - name: Token - bound: Optional[Type] - def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_type_stmt(self) @dataclass(frozen=True) -class PropertyStmt(Stmt): +class MemberStmt(Stmt): name: Token type: Type + kind: MemberKind def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_property_stmt(self) + return visitor.visit_member_stmt(self) @dataclass(frozen=True) class ExtendStmt(Stmt): - type: Type - operations: list[OpStmt] + name: Token + params: list[TypeParam] + members: list[MemberStmt] def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_extend_stmt(self) -@dataclass(frozen=True) -class OpStmt(Stmt): - name: Token - operand: Type - result: Type - - def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_op_stmt(self) - - @dataclass(frozen=True) class PredicateStmt(Stmt): name: Token @@ -231,6 +227,12 @@ class Type(ABC): @abstractmethod def visit_complex_type(self, type: ComplexType) -> T: ... + @abstractmethod + def visit_extension_type(self, type: ExtensionType) -> T: ... + + @abstractmethod + def visit_function_type(self, type: FunctionType) -> T: ... 
+ @dataclass(frozen=True) class NamedType(Type): @@ -243,7 +245,7 @@ class NamedType(Type): @dataclass(frozen=True) class GenericType(Type): type: Type - params: list[Type] + args: list[Type] def accept(self, visitor: Type.Visitor[T]) -> T: return visitor.visit_generic_type(self) @@ -260,7 +262,34 @@ class ConstraintType(Type): @dataclass(frozen=True) class ComplexType(Type): - properties: list[PropertyStmt] + members: list[MemberStmt] def accept(self, visitor: Type.Visitor[T]) -> T: return visitor.visit_complex_type(self) + + +@dataclass(frozen=True) +class ExtensionType(Type): + base: Type + extension: ComplexType + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_extension_type(self) + + +@dataclass(frozen=True) +class FunctionType(Type): + pos_args: list[Argument] + args: list[Argument] + kw_args: list[Argument] + returns: Type + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[Token] + type: Type + required: bool + + def accept(self, visitor: Type.Visitor[T]) -> T: + return visitor.visit_function_type(self) diff --git a/midas/ast/printer.py b/midas/ast/printer.py index f8fb411..e52472c 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -100,20 +100,21 @@ class MidasAstPrinter( self._idx = i if i == len(stmt.params) - 1: self._mark_last() - self._print_type_stmt_param(param) + self._print_type_param(param) self._write_line("type", last=True) with self._child_level(single=True): stmt.type.accept(self) - def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None: + def _print_type_param(self, param: m.TypeParam) -> None: self._write_line("Param") with self._child_level(): self._write_line(f'name: "{param.name.lexeme}"') self._write_optional_child("bound", param.bound, last=True) - def visit_property_stmt(self, stmt: m.PropertyStmt): - self._write_line("PropertyStmt") + def visit_member_stmt(self, stmt: m.MemberStmt): + self._write_line("MemberStmt") with 
self._child_level(): + self._write_line(f"kind: {stmt.kind.name}") self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line("type", last=True) with self._child_level(single=True): @@ -122,29 +123,28 @@ class MidasAstPrinter( def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._write_line("ExtendStmt") with self._child_level(): - self._write_line("type") - with self._child_level(single=True): - stmt.type.accept(self) - self._write_line("operations", last=True) + self._write_line("params") with self._child_level(): - for i, op in enumerate(stmt.operations): + for i, param in enumerate(stmt.params): self._idx = i - if i == len(stmt.operations) - 1: + if i == len(stmt.params) - 1: self._mark_last() - op.accept(self) - - def visit_op_stmt(self, stmt: m.OpStmt) -> None: - self._write_line("OpStmt") - with self._child_level(): + self._print_type_param(param) self._write_line(f'name: "{stmt.name.lexeme}"') - - self._write_line("operand") - with self._child_level(single=True): - stmt.operand.accept(self) - - self._write_line("result", last=True) - with self._child_level(single=True): - stmt.result.accept(self) + self._write_line("params") + with self._child_level(): + for i, param in enumerate(stmt.params): + self._idx = i + if i == len(stmt.params) - 1: + self._mark_last() + self._print_type_param(param) + self._write_line("members", last=True) + with self._child_level(): + for i, member in enumerate(stmt.members): + self._idx = i + if i == len(stmt.members) - 1: + self._mark_last() + member.accept(self) def visit_predicate_stmt(self, stmt: m.PredicateStmt): self._write_line("PredicateStmt") @@ -234,11 +234,11 @@ class MidasAstPrinter( self._write_line("type") with self._child_level(): type.type.accept(self) - self._write_line("params", last=True) + self._write_line("args", last=True) with self._child_level(): - for i, param in enumerate(type.params): + for i, param in enumerate(type.args): self._idx = i - if i == len(type.params) - 1: + if i == 
len(type.args) - 1: self._mark_last() param.accept(self) @@ -255,13 +255,66 @@ class MidasAstPrinter( def visit_complex_type(self, type: m.ComplexType) -> None: self._write_line("ComplexType") with self._child_level(): - self._write_line("properties", last=True) + self._write_line("members", last=True) with self._child_level(): - for i, prop in enumerate(type.properties): + for i, member in enumerate(type.members): self._idx = i - if i == len(type.properties) - 1: + if i == len(type.members) - 1: self._mark_last() - prop.accept(self) + member.accept(self) + + def visit_extension_type(self, type: m.ExtensionType) -> None: + self._write_line("ExtensionType") + with self._child_level(): + self._write_line("base") + with self._child_level(single=True): + type.base.accept(self) + self._write_line("extension", last=True) + with self._child_level(single=True): + type.extension.accept(self) + + def visit_function_type(self, type: m.FunctionType) -> None: + self._write_line("FunctionType") + with self._child_level(): + self._write_line("pos_args") + with self._child_level(): + for i, arg in enumerate(type.pos_args): + self._idx = i + if i == len(type.pos_args) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("args") + with self._child_level(): + for i, arg in enumerate(type.args): + self._idx = i + if i == len(type.args) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("kw_args") + with self._child_level(): + for i, arg in enumerate(type.kw_args): + self._idx = i + if i == len(type.kw_args) - 1: + self._mark_last() + self._print_function_arg(arg) + + self._write_line("returns", last=True) + with self._child_level(single=True): + type.returns.accept(self) + + def _print_function_arg(self, arg: m.FunctionType.Argument) -> None: + self._write_line("Argument") + with self._child_level(): + name: str = "None" + if arg.name is not None: + name = f'"{arg.name.lexeme}"' + self._write_line(f"name: {name}") + 
self._write_line("type") + with self._child_level(single=True): + arg.type.accept(self) + self._write_line(f"required: {arg.required}", last=True) class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]): @@ -279,38 +332,39 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_type_stmt(self, stmt: m.TypeStmt) -> str: template: str = "" if len(stmt.params) != 0: - params: list[str] = [ - self._print_type_template_param(param) for param in stmt.params - ] + params: list[str] = [self._print_type_param(param) for param in stmt.params] template = f"[{', '.join(params)}]" res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}" return self.indented(res) - def _print_type_template_param(self, param: m.TypeStmt.Param) -> str: + def _print_type_param(self, param: m.TypeParam) -> str: res: str = param.name.lexeme if param.bound is not None: res += "<:" + param.bound.accept(self) return res - def visit_property_stmt(self, stmt: m.PropertyStmt): - res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" + def visit_member_stmt(self, stmt: m.MemberStmt): + keyword: str = { + m.MemberKind.PROPERTY: "prop", + m.MemberKind.METHOD: "def", + }.get(stmt.kind, "") + res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}" return self.indented(res) def visit_extend_stmt(self, stmt: m.ExtendStmt): - res: str = self.indented(f"extend {stmt.type.accept(self)}") + template: str = "" + if len(stmt.params) != 0: + params: list[str] = [self._print_type_param(param) for param in stmt.params] + template = f"[{', '.join(params)}]" + res: str = self.indented(f"extend {stmt.name.lexeme}{template}") res += " {\n" self.level += 1 - for op in stmt.operations: - res += op.accept(self) + for member in stmt.members: + res += member.accept(self) + "\n" self.level -= 1 res += self.indented("}") return res - def visit_op_stmt(self, stmt: m.OpStmt): - operand: str = stmt.operand.accept(self) - result: str = 
stmt.result.accept(self) - return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n") - def visit_predicate_stmt(self, stmt: m.PredicateStmt): name: str = stmt.name.lexeme subject: str = stmt.subject.lexeme @@ -358,9 +412,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_generic_type(self, type: m.GenericType) -> str: res: str = type.type.accept(self) - if len(type.params) != 0: - params: list[str] = [param.accept(self) for param in type.params] - res += f"[{', '.join(params)}]" + if len(type.args) != 0: + args: list[str] = [param.accept(self) for param in type.args] + res += f"[{', '.join(args)}]" return res def visit_constraint_type(self, type: m.ConstraintType) -> str: @@ -371,13 +425,41 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_complex_type(self, type: m.ComplexType) -> str: res: str = "{\n" self.level += 1 - for prop in type.properties: - res += prop.accept(self) + for member in type.members: + res += member.accept(self) res += "\n" self.level -= 1 res += self.indented("}") return res + def visit_extension_type(self, type: m.ExtensionType) -> str: + return f"{type.base.accept(self)} & {type.extension.accept(self)}" + + def visit_function_type(self, type: m.FunctionType) -> str: + pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args] + mixed_args: list[str] = [self._print_arg(arg) for arg in type.args] + kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args] + args: list[str] = pos_args + + if len(pos_args) != 0: + args.append("/") + args += mixed_args + if len(kw_args) != 0: + args.append("*") + args += kw_args + + return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}" + + def _print_arg(self, arg: m.FunctionType.Argument) -> str: + res: str = "" + if arg.name is not None: + res += arg.name.lexeme + res += ": " + res += arg.type.accept(self) + if not arg.required: + res += "?" 
+ return res + class PythonAstPrinter( AstPrinter, @@ -582,7 +664,7 @@ class PythonAstPrinter( def visit_literal_expr(self, expr: p.LiteralExpr) -> None: self._write_line("LiteralExpr") with self._child_level(single=True): - self._write_line(f"value: {expr.value}") + self._write_line(f"value: {expr.value!r}") def visit_variable_expr(self, expr: p.VariableExpr) -> None: self._write_line("VariableExpr") @@ -626,3 +708,31 @@ class PythonAstPrinter( self._write_line("if_false", last=True) with self._child_level(single=True): expr.if_false.accept(self) + + def visit_list_expr(self, expr: p.ListExpr) -> None: + self._write_line("ListExpr") + with self._child_level(): + self._write_line("items", last=True) + with self._child_level(): + for i, item in enumerate(expr.items): + self._idx = i + if i == len(expr.items) - 1: + self._mark_last() + item.accept(self) + + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: + self._write_line("SubscriptExpr") + with self._child_level(): + self._write_line("object") + with self._child_level(single=True): + expr.object.accept(self) + self._write_line("index", last=True) + with self._child_level(single=True): + expr.index.accept(self) + + def visit_slice_expr(self, expr: p.SliceExpr) -> None: + self._write_line("SliceExpr") + with self._child_level(): + self._write_optional_child("lower", expr.lower) + self._write_optional_child("upper", expr.upper) + self._write_optional_child("step", expr.step, last=True) diff --git a/midas/ast/python.py b/midas/ast/python.py index dd5d905..f025e2f 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -14,6 +14,7 @@ from midas.ast.location import Location T = TypeVar("T") + #################### # Type annotations # #################### @@ -220,6 +221,15 @@ class Expr(ABC): @abstractmethod def visit_ternary_expr(self, expr: TernaryExpr) -> T: ... + @abstractmethod + def visit_list_expr(self, expr: ListExpr) -> T: ... 
+ + @abstractmethod + def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ... + + @abstractmethod + def visit_slice_expr(self, expr: SliceExpr) -> T: ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -312,3 +322,30 @@ class TernaryExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_ternary_expr(self) + + +@dataclass(frozen=True) +class ListExpr(Expr): + items: list[Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_list_expr(self) + + +@dataclass(frozen=True) +class SubscriptExpr(Expr): + object: Expr + index: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_subscript_expr(self) + + +@dataclass(frozen=True) +class SliceExpr(Expr): + lower: Optional[Expr] + upper: Optional[Expr] + step: Optional[Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_slice_expr(self) diff --git a/midas/checker/builtins.midas b/midas/checker/builtins.midas new file mode 100644 index 0000000..6e89172 --- /dev/null +++ b/midas/checker/builtins.midas @@ -0,0 +1,152 @@ +extend float { + def hex: fn() -> str + def is_integer: fn() -> bool + prop real: float + prop imag: float + def conjugate: fn() -> float + def __add__: fn(value: float, /) -> float + def __sub__: fn(value: float, /) -> float + def __mul__: fn(value: float, /) -> float + def __floordiv__: fn(value: float, /) -> float + def __truediv__: fn(value: float, /) -> float + def __mod__: fn(value: float, /) -> float + // def __divmod__: fn(value: float, /) -> tuple[float, float] + + def __pow__: fn(value: int, /) -> float + // positive __value -> float; negative __value -> complex + // return type must be Any as `float | complex` causes too many false-positive errors + def __pow__: fn(value: float, /) -> Any + def __radd__: fn(value: float, /) -> float + def __rsub__: fn(value: float, /) -> float + def __rmul__: fn(value: float, /) -> float + def __rfloordiv__: fn(value: float, /) -> float + def 
__rtruediv__: fn(value: float, /) -> float + def __rmod__: fn(value: float, /) -> float + // def __rdivmod__: fn(value: float, /) -> tuple[float, float] + // def __rpow__: fn(value: _PositiveInteger, mod: None = None, /) -> float + // def __rpow__: fn(value: _NegativeInteger, mod: None = None, /) -> complex + // Returning `complex` for the general case gives too many false-positive errors. + // def __rpow__: fn(value: float, mod: None = None, /) -> Any + // def __getnewargs__: fn() -> tuple[float] + def __trunc__: fn() -> int + def __ceil__: fn() -> int + def __floor__: fn() -> int + def __round__: fn(ndigits: None?, /) -> int + def __round__: fn(ndigits: int, /) -> float + def __eq__: fn(value: object, /) -> bool + def __ne__: fn(value: object, /) -> bool + def __lt__: fn(value: float, /) -> bool + def __le__: fn(value: float, /) -> bool + def __gt__: fn(value: float, /) -> bool + def __ge__: fn(value: float, /) -> bool + def __neg__: fn() -> float + def __pos__: fn() -> float + def __int__: fn() -> int + def __float__: fn() -> float + def __abs__: fn() -> float + def __hash__: fn() -> int + def __bool__: fn() -> bool + def __format__: fn(format_spec: str, /) -> str +} + +extend int { + prop real: int + prop imag: int + prop numerator: int + prop denominator: int + def conjugate: fn() -> int + def bit_length: fn() -> int + def bit_count: fn() -> int + // def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) 
-> bytes + + def __add__: fn(value: int, /) -> int + def __sub__: fn(value: int, /) -> int + def __mul__: fn(value: int, /) -> int + def __floordiv__: fn(value: int, /) -> int + def __truediv__: fn(value: int, /) -> float + def __mod__: fn(value: int, /) -> int + // def __divmod__: fn(value: int, /) -> tuple[int, int] + def __radd__: fn(value: int, /) -> int + def __rsub__: fn(value: int, /) -> int + def __rmul__: fn(value: int, /) -> int + def __rfloordiv__: fn(value: int, /) -> int + def __rtruediv__: fn(value: int, /) -> float + def __rmod__: fn(value: int, /) -> int + // def __rdivmod__: fn(value: int, /) -> tuple[int, int] + def __pow__: fn(value: int, /) -> int + // def __pow__: fn(value: _PositiveInteger, mod: None = None, /) -> int + // def __pow__: fn(value: _NegativeInteger, mod: None = None, /) -> float + // positive __value -> int; negative __value -> float + // return type must be Any as `int | float` causes too many false-positive errors + // def __pow__: fn(value: int, mod: None = None, /) -> Any + // def __pow__: fn(value: int, mod: int, /) -> int + def __rpow__: fn(value: int, /) -> Any + def __and__: fn(value: int, /) -> int + def __or__: fn(value: int, /) -> int + def __xor__: fn(value: int, /) -> int + def __lshift__: fn(value: int, /) -> int + def __rshift__: fn(value: int, /) -> int + def __rand__: fn(value: int, /) -> int + def __ror__: fn(value: int, /) -> int + def __rxor__: fn(value: int, /) -> int + def __rlshift__: fn(value: int, /) -> int + def __rrshift__: fn(value: int, /) -> int + def __neg__: fn() -> int + def __pos__: fn() -> int + def __invert__: fn() -> int + def __trunc__: fn() -> int + def __ceil__: fn() -> int + def __floor__: fn() -> int + def __round__: fn(ndigits: None?, /) -> int + def __round__: fn(ndigits: int, /) -> int + + // def __getnewargs__: fn() -> tuple[int] + def __eq__: fn(value: object, /) -> bool + def __ne__: fn(value: object, /) -> bool + def __lt__: fn(value: int, /) -> bool + def __le__: fn(value: int, /) 
-> bool + def __gt__: fn(value: int, /) -> bool + def __ge__: fn(value: int, /) -> bool + def __float__: fn() -> float + def __int__: fn() -> int + def __abs__: fn() -> int + def __hash__: fn() -> int + def __bool__: fn() -> bool + def __index__: fn() -> int + def __format__: fn(format_spec: str, /) -> str +} + +extend list[T] { + def copy: fn () -> list[T] + def append: fn (object: T, /) -> None + def extend: fn (iterable: list[T], /) -> None + def pop: fn (index: int?, /) -> T + def index: fn (value: T, start: int?, stop: int?, /) -> int + def count: fn (value: T, /) -> int + def insert: fn (index: int, object: T, /) -> None + def remove: fn (value: T, /) -> None + def sort: fn (*, reverse: bool?) -> None + def __len__: fn () -> int + // def __iter__: fn () -> Iterator[T] + def __getitem__: fn (i: int, /) -> T + def __getitem__: fn (s: slice, /) -> list[T] + def __setitem__: fn (key: int, value: T, /) -> None + def __setitem__: fn (key: slice, value: list[T], /) -> None + def __delitem__: fn (key: int, /) -> None + def __delitem__: fn (key: slice, /) -> None + // def __add__: fn[S <: T] (value: list[S], /) -> list[T] + def __add__: fn (value: list[T], /) -> list[T] + def __iadd__: fn (value: list[T], /) -> list[T] + def __mul__: fn (value: int, /) -> list[T] + def __rmul__: fn (value: int, /) -> list[T] + def __imul__: fn (value: int, /) -> list[T] + def __contains__: fn (key: object, /) -> bool + // def __reversed__: fn (self) -> Iterator[_T] + def __gt__: fn (value: list[T], /) -> bool + def __ge__: fn (value: list[T], /) -> bool + def __lt__: fn (value: list[T], /) -> bool + def __le__: fn (value: list[T], /) -> bool + def __eq__: fn (value: object, /) -> bool + + prop __doc__: str +} diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py index bc80084..b1adf6d 100644 --- a/midas/checker/builtins.py +++ b/midas/checker/builtins.py @@ -1,4 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from midas.checker.types 
import ( + BaseType, + GenericType, + TopType, + TypeVar, + UnitType, +) + +if TYPE_CHECKING: + from midas.checker.registry import TypesRegistry + + BUILTIN_SUBTYPES: dict[str, set[str]] = { "float": {"int"}, "int": {"bool"}, } + + +def define_builtins(reg: TypesRegistry): + """Define builtin types and operations""" + any = reg.define_type("Any", TopType()) + unit = reg.define_type("None", UnitType()) + object = reg.define_type("object", BaseType(name="object")) + bool = reg.define_type("bool", BaseType(name="bool")) + int = reg.define_type("int", BaseType(name="int")) + float = reg.define_type("float", BaseType(name="float")) + str = reg.define_type("str", BaseType(name="str")) + slice = reg.define_type("slice", BaseType(name="slice")) + + list = reg.define_type( + "list", + GenericType( + name="list", + params=[TypeVar(name="T", bound=None)], + body=BaseType(name="list"), + ), + ) diff --git a/midas/checker/checker.py b/midas/checker/checker.py index ab7261c..c26f0aa 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -1,812 +1,35 @@ -import logging -from dataclasses import dataclass from pathlib import Path from typing import Optional -import midas.ast.midas as m -import midas.ast.python as p -from midas.ast.location import Location -from midas.checker.builtins import BUILTIN_SUBTYPES -from midas.checker.diagnostic import Diagnostic, DiagnosticType -from midas.checker.environment import Environment -from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS -from midas.checker.types import ( - AliasType, - BaseType, - ComplexType, - Function, - Operation, - Type, - UnitType, - UnknownType, -) -from midas.lexer.midas import MidasLexer -from midas.lexer.token import Token -from midas.parser.midas import MidasParser -from midas.resolver.midas import MidasResolver +from midas.checker.diagnostic import Diagnostic +from midas.checker.midas import MidasTyper +from midas.checker.python import PythonTyper +from midas.checker.registry 
import TypesRegistry +from midas.checker.reporter import Reporter -class ReturnException(Exception): - pass +class TypeChecker: + def __init__(self): + self.types: TypesRegistry = TypesRegistry() + self.reporter: Reporter = Reporter() + self.midas_typer = MidasTyper(self.types, self.reporter) + self.python_typer = PythonTyper(self.types, self.reporter) -@dataclass(frozen=True, kw_only=True) -class MappedArgument: - expr: p.Expr - type: Type - argument: Function.Argument + def import_midas(self, path: Path): + source: str = path.read_text() + return self.import_midas_source(source, path=str(path)) + def import_midas_source(self, source: str, path: Optional[str] = None): + self.midas_typer.process(source, path) -class Checker( - p.Stmt.Visitor[None], - p.Expr.Visitor[Type], - p.MidasType.Visitor[Type], -): - """A type checker which can use custom type definitions""" + def type_check(self, path: Path): + source: str = path.read_text() + return self.type_check_source(source, path=str(path)) - def __init__( - self, - locals: dict[p.Expr, int], - source_path: Path, - types_paths: list[Path], - ): - self.logger: logging.Logger = logging.getLogger("Checker") - self.source_path: Path = source_path - self.types_paths: list[Path] = types_paths - self.ctx: MidasResolver = MidasResolver() - self.global_env: Environment = Environment() - self.env: Environment = self.global_env - self.locals: dict[p.Expr, int] = locals - self.diagnostics: list[Diagnostic] = [] - self.judgements: list[tuple[p.Expr, Type]] = [] + def type_check_source(self, source: str, path: Optional[str] = None): + self.python_typer.process(source, path) - def diagnostic(self, type: DiagnosticType, location: Location, message: str): - self.diagnostics.append( - Diagnostic( - file_path=self.source_path, - location=location, - type=type, - message=message, - ) - ) - - def error(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.ERROR, - location=location, - message=message, - ) - - 
def warning(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.WARNING, - location=location, - message=message, - ) - - def info(self, location: Location, message: str): - self.diagnostic( - type=DiagnosticType.INFO, - location=location, - message=message, - ) - - def type_of(self, expr: p.Expr) -> Type: - """Evaluate the type of an expression - - Args: - expr (p.Expr): the expression to evaluate - - Returns: - Type: the type of the given expression - """ - type: Type = expr.accept(self) - self.judgements.append((expr, type)) - return type - - def process_block(self, block: list[p.Stmt], env: Environment) -> bool: - """Evaluate a sequence of statements - - Args: - block (list[p.Stmt]): the statements to evaluate - env (Environment): the environment in which to evaluate - - Returns: - bool: whether a return statement is present in the block - """ - previous_env: Environment = self.env - self.env = env - returned: bool = False - for i, stmt in enumerate(block): - try: - stmt.accept(self) - except ReturnException: - returned = True - if i < len(block) - 1: - self.warning(block[i + 1].location, "Unreachable statement") - break - self.env = previous_env - return returned - - def check(self, statements: list[p.Stmt]) -> list[Diagnostic]: - """Type check a sequence of statements and returns diagnostics - - Args: - statements (list[p.Stmt]): the statements to evaluate and check - - Returns: - list[Diagnostic]: the list of diagnostics (errors, warning, etc.) 
- """ - self.diagnostics = [] - - for path in self.types_paths: - self.import_midas(path) - self.logger.debug(f"Midas types: {self.ctx._types}") - self.logger.debug(f"Midas operations: {self.ctx._operations}") - - for stmt in statements: - stmt.accept(self) - - self.logger.debug(f"Final environment: {self.env.flat_dict()}") - return self.diagnostics - - def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]: - """Look up a variable in the environment it was declared - - Args: - name (str): the name of the variable - expr (p.Expr): the variable expression, used to lookup the scope distance - - Returns: - Optional[Type]: the type of the variable, or None if it was not found - """ - distance: Optional[int] = self.locals.get(expr) - if distance is not None: - return self.env.get_at(distance, name) - return self.global_env.get(name) - - def import_midas(self, path: Path) -> None: - """Import Midas definitions from a path - - Args: - path (Path): the import path - """ - self.logger.debug(f"Importing type definitions from {path}") - lexer: MidasLexer = MidasLexer(path.read_text()) - tokens: list[Token] = lexer.process() - parser: MidasParser = MidasParser(tokens) - stmts: list[m.Stmt] = parser.parse() - self.ctx.resolve(stmts) - - def unfold_type(self, type: Type) -> Type: - match type: - case AliasType(type=ref_type): - return self.unfold_type(ref_type) - case _: - return type - - def is_subtype(self, type1: Type, type2: Type) -> bool: - """Check whether `type1` is a subtype of `type2` - - For more details on the rules checked here, see TAPL Chap. 
15-16-17 - - Args: - type1 (Type): the potential subtype - type2 (Type): the potential supertype - - Returns: - bool: whether `type1` is a subtype of `type2` - """ - - if type1 == type2: - return True - - match (type1, type2): - case (AliasType(type=base1), _): - return self.is_subtype(base1, type2) - - case (BaseType(name=name1), BaseType(name=name2)): - return name1 in BUILTIN_SUBTYPES.get(name2, set()) - - case (ComplexType(properties=props1), ComplexType(properties=props2)): - for k, t in props2.items(): - if k not in props1: - return False - if not self.is_subtype(props1[k], t): - return False - return True - - case (Function(returns=return1), Function(returns=return2)): - if not self.is_func_subtype(type1, type2): - return False - if not self.is_subtype(return1, return2): - return False - return True - - return False - - # TODO: verify the logic in here - def is_func_subtype(self, func1: Function, func2: Function) -> bool: - """Check whether a function is a subtype of another - - Args: - func1 (Function): the potential function subtype - func2 (Function): the potential function supertype - - Returns: - bool: whether `func1` is a subtype of `func2` - """ - if not self.is_subtype(func1.returns, func2.returns): - return False - - pos1: list[Function.Argument] = func1.pos_args - mixed1: list[Function.Argument] = func1.args - kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args} - pos2: list[Function.Argument] = func2.pos_args - mixed2: list[Function.Argument] = func2.args - kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args} - - mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2} - mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2} - - def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool: - if not self.is_subtype(sub.type, sup.type): - return False - if not sup.required and sub.required: - return False - return True - - for arg1 in pos1: - arg2: 
Function.Argument - if arg1.pos < len(pos2): - arg2 = pos2[arg1.pos] - elif arg1.pos in mixed_by_pos: - arg2 = mixed_by_pos[arg1.pos] - elif not arg1.required: - continue - else: - return False - if not is_arg_subtype(arg2, arg1): - return False - - for name, arg1 in kw1.items(): - arg2: Function.Argument - if name in kw2: - arg2 = kw2[name] - elif name in mixed_by_name: - arg2 = mixed_by_name[name] - elif not arg1.required: - continue - else: - return False - if not is_arg_subtype(arg2, arg1): - return False - - for arg1 in mixed1: - pos_arg2: Optional[Function.Argument] = None - kw_arg2: Optional[Function.Argument] = None - if arg1.name in kw2: - kw_arg2 = kw2[arg1.name] - elif arg1.name in mixed_by_name: - kw_arg2 = mixed_by_name[arg1.name] - if arg1.pos < len(pos2): - pos_arg2 = pos2[arg1.pos] - elif arg1.pos in mixed_by_pos: - pos_arg2 = mixed_by_pos[arg1.pos] - - # No match in func2 and arg is required - if pos_arg2 is None and kw_arg2 is None and arg1.required: - return False - - # Matching keyword argument - if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1): - return False - - # Matching positional argument - if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1): - return False - - mixed_positions: set[int] = {a.pos for a in mixed1} - mixed_names: set[str] = {a.name for a in mixed1} - for arg2 in pos2: - if not arg2.required: - continue - if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions: - return False - - for name, arg2 in kw2.items(): - if not arg2.required: - continue - if name not in kw1 and name not in mixed_names: - return False - - for arg2 in mixed2: - if arg2.required: - continue - pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions - kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names - if not pos_match or not kw_match: - return False - - return True - - def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: - self.type_of(stmt.expr) - - def visit_function(self, stmt: 
p.Function) -> None: - env: Environment = Environment(self.env) - pos_args: list[Function.Argument] = [] - args: list[Function.Argument] = [] - kw_args: list[Function.Argument] = [] - - def eval_arg_type(arg: p.Function.Argument) -> Type: - if arg.type is not None: - return arg.type.accept(self) - if arg.default is not None: - return arg.default.accept(self) - return UnknownType() - - pos: int = 0 - for arg in stmt.posonlyargs: - pos_args.append( - Function.Argument( - pos=pos, - name=arg.name, - type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - for arg in stmt.args: - args.append( - Function.Argument( - pos=pos, - name=arg.name, - type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - for arg in stmt.kwonlyargs: - kw_args.append( - Function.Argument( - pos=pos, # not relevant - name=arg.name, - type=eval_arg_type(arg), - required=arg.default is None, - ) - ) - pos += 1 - - for arg in pos_args + args + kw_args: - env.define(arg.name, arg.type) - - returns_hint: Optional[Type] = None - if stmt.returns is not None: - returns_hint = stmt.returns.accept(self) - # Early define to handle simple fully-typed recursion - inside_function: Function = Function( - name=stmt.name, - pos_args=pos_args, - args=args, - kw_args=kw_args, - returns=returns_hint, - ) - self.env.define(stmt.name, inside_function) - - returned: bool = self.process_block(stmt.body, env) - inferred_return: Type = UnknownType() - if not returned: - env.return_types.append(UnitType()) - return_types: set[Type] = set(env.return_types) - if len(return_types) == 1: - inferred_return = list(return_types)[0] - elif len(return_types) > 1: - self.error( - stmt.location, - f"Mixed return types: {env.return_types}", - ) - - returns: Type = UnknownType() - if returns_hint is not None: - assert stmt.returns is not None - returns = returns_hint - if returns != inferred_return: - self.error( - stmt.returns.location, - f"Return type mismatch, annotated {returns} but 
returns {inferred_return}", - ) - else: - returns = inferred_return - - # TODO: handle *args and **kwargs sinks - function: Function = Function( - name=stmt.name, - pos_args=pos_args, - args=args, - kw_args=kw_args, - returns=returns, - ) - self.env.define(stmt.name, function) - - def visit_type_assign(self, stmt: p.TypeAssign) -> None: - # TODO check not yet defined locally - type: Type = stmt.type.accept(self) - self.env.define(stmt.name, type) - - def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: - value_type: Type = self.type_of(stmt.value) - for target in stmt.targets: - self._assign(stmt.location, target, value_type) - - def _assign(self, location: Location, target: p.Expr, value_type: Type): - match target: - case p.VariableExpr(): - self._assign_var(location, target, value_type) - - case p.GetExpr(): - self._assign_attr(location, target, value_type) - - case _: - if not isinstance(target, p.VariableExpr): - self.logger.warning(f"Unsupported assignment to {target}") - self.warning(target.location, f"Unsupported assignment to {target}") - - def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type): - name: str = target.name - var_type: Optional[Type] = self.look_up_variable(name, target) - - if var_type is None: - self.env.define(name, value_type) - else: - # S <: T - # Γ, x: T v: S - # x = v - if not self.is_subtype(value_type, var_type): - self.error( - location, - f"Cannot assign {value_type} to {name} of type {var_type}", - ) - - def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): - object: Type = self.type_of(target.object) - base_object: Type = self.unfold_type(object) - match base_object: - case ComplexType(properties=properties): - if target.name not in properties: - self.error( - target.location, f"Unknown property '{target.name} on {object}" - ) - return - - prop_type: Type = properties[target.name] - if not self.is_subtype(value_type, prop_type): - self.error( - location, - f"Cannot 
assign {value_type} to property '{target.name}' of type {prop_type} on {object}", - ) - return - - case UnknownType(): - pass - - case _: - self.error( - target.location, - f"Cannot assign {value_type} to unknown property '{target.name}' on {object}", - ) - - def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: - type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() - self.env.return_types.append(type) - raise ReturnException() - - def visit_if_stmt(self, stmt: p.IfStmt) -> None: - # Not evaluated in sub-environment because assignments in the test leak out of the if - # For example: - # if (m := 1 + 1) < 2: - # ... - # print(m) # <- m is still defined - test_type: Type = stmt.test.accept(self) - - # TODO Allow subtypes or any type - if test_type != self.ctx.get_type("bool"): - self.error( - stmt.test.location, f"If test must be a boolean, got {test_type}" - ) - - env: Environment = Environment(self.env) - body_returned: bool = self.process_block(stmt.body, env) - else_returned: bool = self.process_block(stmt.orelse, env) - self.env.return_types.extend(env.return_types) - if body_returned and else_returned: - raise ReturnException() - - def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: - method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) - if method is None: - self.logger.warning(f"Unsupported operator {expr.operator}") - self.warning(expr.location, f"Unsupported operator {expr.operator}") - return UnknownType() - left: Type = self.type_of(expr.left) - right: Type = self.type_of(expr.right) - - operations: list[Operation] = self.ctx.get_operations_by_name(method) - valid_operations: list[Operation] = [] - for op in operations: - sig: Operation.CallSignature = op.signature - if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right): - valid_operations.append(op) - - if len(valid_operations) == 0: - self.error( - expr.location, - f"Undefined operation {method} between {left} and {right}", - ) - 
return UnknownType() - elif len(valid_operations) == 1: - self.logger.debug(f"Unique operation {method} between {left} and {right}") - return valid_operations[0].result - - for i, op1 in enumerate(valid_operations): - sig1: Operation.CallSignature = op1.signature - best_match: bool = True - for j, op2 in enumerate(valid_operations): - if i == j: - continue - sig2: Operation.CallSignature = op2.signature - if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype( - sig1.right, sig2.right - ): - best_match = False - break - self.logger.debug(f"{op1} is a full overload of {op2}") - if best_match: - return op1.result - - overloads: list[str] = [ - f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}" - for op in valid_operations - ] - self.error( - expr.location, - f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}", - ) - return UnknownType() - - def visit_compare_expr(self, expr: p.CompareExpr) -> Type: - method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) - if method is None: - self.logger.warning(f"Unsupported operator {expr.operator}") - self.warning(expr.location, f"Unsupported operator {expr.operator}") - return UnknownType() - left: Type = self.type_of(expr.left) - right: Type = self.type_of(expr.right) - - result: Optional[Type] = self.ctx.get_operation_result(left, method, right) - if result is None: - self.error( - expr.location, - f"Undefined operation {method} between {left} and {right}", - ) - return UnknownType() - return result - - def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... 
- - def visit_call_expr(self, expr: p.CallExpr) -> Type: - callee: Type = self.type_of(expr.callee) - if not isinstance(callee, Function): - self.error(expr.callee.location, "Callee is not a function") - return UnknownType() - function: Function = callee - mapped: list[MappedArgument] = self.map_call_arguments(function, expr) - for arg in mapped: - if not self.is_subtype(arg.type, arg.argument.type): - self.error( - arg.expr.location, - f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", - ) - return function.returns - - def visit_get_expr(self, expr: p.GetExpr) -> Type: - object: Type = self.type_of(expr.object) - base_object: Type = self.unfold_type(object) - match base_object: - case ComplexType(properties=properties): - if expr.name not in properties: - self.error( - expr.location, f"Unknown property '{expr.name} on {object}" - ) - return UnknownType() - return properties[expr.name] - - case UnknownType(): - return UnknownType() - - case _: - self.error( - expr.location, f"Cannot get property '{expr.name}' on {object}" - ) - return UnknownType() - - def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: - match expr.value: - case bool(): # Must be before int - return self.ctx.get_type("bool") - case int(): - return self.ctx.get_type("int") - case float(): - return self.ctx.get_type("float") - case str(): - return self.ctx.get_type("str") - case _: - self.warning(expr.location, f"Unknown literal {expr}") - return UnknownType() - - def visit_variable_expr(self, expr: p.VariableExpr) -> Type: - return self.look_up_variable(expr.name, expr) or UnknownType() - - def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: - left: Type = expr.left.accept(self) - right: Type = expr.right.accept(self) - - if self.is_subtype(left, right): - return right - if self.is_subtype(right, left): - return left - - self.error( - expr.location, - f"Incompatible operand types, {left=} and {right=}", - ) - return UnknownType() - - def 
visit_cast_expr(self, expr: p.CastExpr) -> Type: - return expr.type.accept(self) - - def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: - test_type: Type = expr.test.accept(self) - - # TODO Allow subtypes or any type - if test_type != self.ctx.get_type("bool"): - self.error( - expr.test.location, f"If test must be a boolean, got {test_type}" - ) - - true_type: Type = expr.if_true.accept(self) - false_type: Type = expr.if_false.accept(self) - if self.is_subtype(true_type, false_type): - return false_type - if self.is_subtype(false_type, true_type): - return true_type - - self.error( - expr.location, - f"Incompatible types in ternary if branches: true={true_type} and false={false_type}", - ) - return UnknownType() - - def visit_base_type(self, node: p.BaseType) -> Type: - return self.ctx.get_type(node.base) - - def visit_constraint_type(self, node: p.ConstraintType) -> Type: ... - - def visit_frame_column(self, node: p.FrameColumn) -> Type: ... - - def visit_frame_type(self, node: p.FrameType) -> Type: ... 
- - def map_call_arguments( - self, function: Function, call: p.CallExpr - ) -> list[MappedArgument]: - """Map call arguments to function parameters as defined in its signature - - This method maps positional-only, keyword-only and mixed parameter definitions - with the arguments passed at the call site - - Any mismatched, missing or unexpected argument is reported as a diagnostic - - Args: - function (Function): the function definition - call (p.CallExpr): the call expression - - Returns: - list[MappedArgument]: the list of mapped arguments - """ - positional: list[tuple[p.Expr, Type]] = [ - (arg, self.type_of(arg)) for arg in call.arguments - ] - keywords: dict[str, tuple[p.Expr, Type]] = { - name: (arg, self.type_of(arg)) for name, arg in call.keywords.items() - } - set_args: set[str] = set() - - required_positional: list[str] = [ - arg.name for arg in function.pos_args + function.args if arg.required - ] - required_keyword: list[str] = [ - arg.name for arg in function.kw_args if arg.required - ] - - mapped: list[MappedArgument] = [] - - pos_params: list[Function.Argument] = list(function.pos_args) - mixed_params: list[Function.Argument] = list(function.args) - kw_params: dict[str, Function.Argument] = { - arg.name: arg for arg in function.kw_args - } - - # TODO: handle *args and **kwargs sinks - for arg in positional: - param: Function.Argument - if len(pos_params) != 0: - param = pos_params.pop(0) - elif len(mixed_params) != 0: - param = mixed_params.pop(0) - else: - self.error(arg[0].location, "Too many positional arguments") - break - name: str = param.name - if name in required_positional: - required_positional.remove(name) - if name in required_keyword: - required_keyword.remove(name) - set_args.add(name) - mapped.append( - MappedArgument( - expr=arg[0], - type=arg[1], - argument=param, - ) - ) - - kw_params.update({arg.name: arg for arg in mixed_params}) - for name, arg in keywords.items(): - param: Function.Argument - if name not in kw_params: - if name 
in set_args: - self.error( - arg[0].location, f"Multiple values for argument '{name}'" - ) - else: - self.error(arg[0].location, f"Unknown keyword argument '{name}'") - continue - param = kw_params.pop(name) - if name in required_positional: - required_positional.remove(name) - if name in required_keyword: - required_keyword.remove(name) - set_args.add(name) - mapped.append( - MappedArgument( - expr=arg[0], - type=arg[1], - argument=param, - ) - ) - - def join_args(args: list[str]) -> str: - args = list(map(lambda a: f"'{a}'", args)) - if len(args) == 0: - return "" - if len(args) == 1: - return args[0] - return ", ".join(args[:-1]) + " and " + args[-1] - - if len(required_positional) != 0: - plural: str = "" if len(required_positional) == 1 else "s" - args: str = join_args(required_positional) - self.error( - call.location, - f"Missing required positional argument{plural}: {args}", - ) - - if len(required_keyword) != 0: - plural: str = "" if len(required_keyword) == 1 else "s" - args: str = join_args(required_keyword) - self.error( - call.location, - f"Missing required keyword argument{plural}: {args}", - ) - - return mapped + @property + def diagnostics(self) -> list[Diagnostic]: + return self.reporter.diagnostics diff --git a/midas/checker/diagnostic.py b/midas/checker/diagnostic.py index 77f687e..f4b3d12 100644 --- a/midas/checker/diagnostic.py +++ b/midas/checker/diagnostic.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from enum import StrEnum -from pathlib import Path from typing import Optional from midas.ast.location import Location @@ -14,7 +13,7 @@ class DiagnosticType(StrEnum): @dataclass(frozen=True) class Diagnostic: - file_path: Path + file_path: Optional[str] location: Location type: DiagnosticType message: str @@ -28,10 +27,16 @@ class Diagnostic: and self.location.end_col_offset is not None ): end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}" - loc: str = ( - f"at {start_loc}" if end_loc is None else f"from 
{start_loc} to {end_loc}" - ) - return f"{self.type} in {self.file_path} {loc}" + + loc: str = "" + if self.file_path is not None: + loc += f" in {self.file_path}" + if end_loc is None: + loc += f" at {start_loc}" + else: + loc += f" from {start_loc} to {end_loc}" + + return f"{self.type}{loc}" def __str__(self) -> str: return f"{self.location_str}: {self.message}" diff --git a/midas/checker/midas.py b/midas/checker/midas.py new file mode 100644 index 0000000..3764c03 --- /dev/null +++ b/midas/checker/midas.py @@ -0,0 +1,206 @@ +import logging +from pathlib import Path +from typing import Optional + +import midas.ast.midas as m +from midas.checker.builtins import define_builtins +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter, Reporter +from midas.checker.types import ( + AliasType, + ComplexType, + ExtensionType, + Function, + GenericType, + Type, + TypeVar, + UnknownType, +) +from midas.lexer.midas import MidasLexer +from midas.lexer.token import Token +from midas.parser.midas import MidasParser + + +class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): + """A resolver which evaluates Midas type definitions and build a registry""" + + def __init__(self, types: TypesRegistry, reporter: Reporter) -> None: + self.logger: logging.Logger = logging.getLogger("MidasTyper") + self.reporter: FileReporter = reporter.for_file(None) + + self.types: TypesRegistry = types + self._local_variables: dict[str, TypeVar] = {} + + self._current_name: Optional[str] = None + + define_builtins(self.types) + builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve() + self.process(builtins_path.read_text(), str(builtins_path)) + + def process(self, source: str, path: Optional[str]): + self.reporter = self.reporter.for_file(path) + lexer: MidasLexer = MidasLexer(source) + tokens: list[Token] = lexer.process() + parser: MidasParser = MidasParser(tokens) + stmts: list[m.Stmt] = parser.parse() + 
for error in parser.errors: + self.reporter.error(error.token.get_location(), error.message) + self.resolve(stmts) + + def get_type(self, name: str) -> Type: + """Get a type from its name + + Args: + name (str): the name of the type + + Raises: + NameError: if the type is not defined + + Returns: + Type: the type + """ + if name in self._local_variables: + return self._local_variables[name] + return self.types.get_type(name) + + def resolve(self, stmts: list[m.Stmt]): + """Process a sequence of statements + + Args: + stmts (list[m.Stmt]): the statements + """ + for stmt in stmts: + stmt.accept(self) + + def visit_type_stmt(self, stmt: m.TypeStmt) -> None: + name: str = stmt.name.lexeme + self._current_name = name + params: list[TypeVar] = self._resolve_type_params(stmt.params) + + type: Type = stmt.type.accept(self) + if len(params) != 0: + type = GenericType(name=name, params=params, body=type) + else: + type = AliasType(name=name, type=type) + self.types.define_type(name, type) + self._local_variables.clear() + self._current_name = None + + def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ... 
+ + def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: + self._resolve_type_params(stmt.params) + base_name: str = stmt.name.lexeme + try: + _ = self.get_type(base_name) + except NameError: + self.reporter.error(stmt.name.get_location(), f"Unknown type '{base_name}'") + + for member in stmt.members: + member_type: Type = member.type.accept(self) + self.types.define_member( + base_name, + member.name.lexeme, + member_type, + member.kind == m.MemberKind.METHOD, + ) + + def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: + self.reporter.warning(stmt.location, "PredicateStmt not yet supported") + + def visit_logical_expr(self, expr: m.LogicalExpr) -> None: + self.reporter.warning(expr.location, "LogicalExpr not yet supported") + + def visit_binary_expr(self, expr: m.BinaryExpr) -> None: + self.reporter.warning(expr.location, "BinaryExpr not yet supported") + + def visit_unary_expr(self, expr: m.UnaryExpr) -> None: + self.reporter.warning(expr.location, "UnaryExpr not yet supported") + + def visit_get_expr(self, expr: m.GetExpr) -> None: + self.reporter.warning(expr.location, "GetExpr not yet supported") + + def visit_variable_expr(self, expr: m.VariableExpr) -> None: + self.reporter.warning(expr.location, "VariableExpr not yet supported") + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: + return expr.expr.accept(self) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> None: + self.reporter.warning(expr.location, "LiteralExpr not yet supported") + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: + self.reporter.warning(expr.location, "WildcardExpr not yet supported") + + def visit_named_type(self, type: m.NamedType) -> Type: + name: str = type.name.lexeme + try: + return self.get_type(name) + except NameError: + msg: str = f"Undefined type {name}" + if self._current_name == name: + msg += ". 
Recursive types are not supported, use an extend block" + self.reporter.error(type.name.get_location(), msg) + return UnknownType() + + def visit_generic_type(self, type: m.GenericType) -> Type: + type_: Type = type.type.accept(self) + args: list[Type] = [arg.accept(self) for arg in type.args] + try: + return self.types.apply_generic(type_, args) + except Exception as e: + self.reporter.error(type.location, f"Cannot apply generic type: {e}") + return UnknownType() + + def visit_constraint_type(self, type: m.ConstraintType) -> Type: + type_: Type = type.type.accept(self) + type.constraint.accept(self) + # TODO + return UnknownType() + + def visit_complex_type(self, type: m.ComplexType) -> ComplexType: + return ComplexType( + members={ + member.name.lexeme: member.type.accept(self) for member in type.members + } + ) + + def visit_extension_type(self, type: m.ExtensionType) -> Type: + return ExtensionType( + base=type.base.accept(self), + extension=self.visit_complex_type(type.extension), + ) + + def visit_function_type(self, type: m.FunctionType) -> Type: + n_pos_args: int = len(type.pos_args) + n_args: int = len(type.args) + + def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument: + return Function.Argument( + pos=i, + name=arg.name.lexeme if arg.name is not None else str(i), + type=arg.type.accept(self), + required=arg.required, + ) + + return Function( + pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)], + args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)], + kw_args=[ + process_arg(arg, i + n_pos_args + n_args) + for i, arg in enumerate(type.kw_args) + ], + returns=type.returns.accept(self), + ) + + def _resolve_type_params(self, params: list[m.TypeParam]): + vars: list[TypeVar] = [] + for param in params: + name: str = param.name.lexeme + bound: Optional[Type] = None + if param.bound is not None: + bound = param.bound.accept(self) + var = TypeVar(name=name, bound=bound) + 
self._local_variables[name] = var + vars.append(var) + return vars diff --git a/midas/checker/operators.py b/midas/checker/operators.py index e65ab07..58af88c 100644 --- a/midas/checker/operators.py +++ b/midas/checker/operators.py @@ -29,3 +29,10 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { # ast.In: "__in__", # ast.NotIn: "__notin__", } + +UNARY_METHODS: dict[Type[ast.unaryop], str] = { + ast.Invert: "__invert__", + # ast.Not: "", + ast.UAdd: "__pos__", + ast.USub: "__neg__", +} diff --git a/midas/checker/python.py b/midas/checker/python.py new file mode 100644 index 0000000..a0f7a06 --- /dev/null +++ b/midas/checker/python.py @@ -0,0 +1,859 @@ +import ast +import logging +from dataclasses import dataclass +from typing import Optional + +import midas.ast.python as p +from midas.ast.location import Location +from midas.checker.environment import Environment +from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter, Reporter +from midas.checker.resolver import Resolver +from midas.checker.types import ( + Function, + OverloadedFunction, + Type, + UnitType, + UnknownType, + unfold_type, +) +from midas.parser.python import PythonParser + +TypedExpr = tuple[p.Expr, Type] + + +class ReturnException(Exception): + pass + + +@dataclass(frozen=True, kw_only=True) +class MappedArgument: + expr: p.Expr + type: Type + argument: Function.Argument + + +@dataclass(frozen=True, kw_only=True) +class OverloadCandidate: + function: Function + mapped: list[MappedArgument] + + +class PythonTyper( + p.Stmt.Visitor[None], + p.Expr.Visitor[Type], + p.MidasType.Visitor[Type], +): + """A type checker which can use custom type definitions""" + + def __init__( + self, + types: TypesRegistry, + reporter: Reporter, + ): + self.logger: logging.Logger = logging.getLogger("PythonTyper") + self.reporter: FileReporter = reporter.for_file(None) + self.types: 
TypesRegistry = types + self.global_env: Environment = Environment() + self.env: Environment = self.global_env + self.locals: dict[p.Expr, int] = {} + self.judgements: list[tuple[p.Expr, Type]] = [] + + def process(self, source: str, path: Optional[str]): + self.reporter = self.reporter.for_file(path) + + tree: ast.Module = ast.parse(source, filename=path or "") + parser = PythonParser() + stmts: list[p.Stmt] = parser.parse_module(tree) + resolver = Resolver() + resolver.resolve(*stmts) + + self.env = self.global_env + self.locals = resolver.locals + self.judgements = [] + + self.check(stmts) + + def type_of(self, expr: p.Expr) -> Type: + """Evaluate the type of an expression + + Args: + expr (p.Expr): the expression to evaluate + + Returns: + Type: the type of the given expression + """ + type: Type = expr.accept(self) + self.judgements.append((expr, type)) + return type + + def resolve_type_expr(self, expr: p.MidasType) -> Type: + return expr.accept(self) + + def process_stmt(self, stmt: p.Stmt) -> None: + stmt.accept(self) + + def process_block(self, block: list[p.Stmt], env: Environment) -> bool: + """Evaluate a sequence of statements + + Args: + block (list[p.Stmt]): the statements to evaluate + env (Environment): the environment in which to evaluate + + Returns: + bool: whether a return statement is present in the block + """ + previous_env: Environment = self.env + self.env = env + returned: bool = False + for i, stmt in enumerate(block): + try: + self.process_stmt(stmt) + except ReturnException: + returned = True + if i < len(block) - 1: + self.reporter.warning( + block[i + 1].location, "Unreachable statement" + ) + break + self.env = previous_env + return returned + + def check(self, statements: list[p.Stmt]) -> None: + """Type check a sequence of statements and returns diagnostics + + Args: + statements (list[p.Stmt]): the statements to evaluate and check + """ + for stmt in statements: + self.process_stmt(stmt) + + self.logger.debug(f"Final environment: 
{self.env.flat_dict()}") + + def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]: + """Look up a variable in the environment it was declared + + Args: + name (str): the name of the variable + expr (p.Expr): the variable expression, used to lookup the scope distance + + Returns: + Optional[Type]: the type of the variable, or None if it was not found + """ + distance: Optional[int] = self.locals.get(expr) + if distance is not None: + return self.env.get_at(distance, name) + return self.global_env.get(name) + + def is_subtype(self, type1: Type, type2: Type) -> bool: + return self.types.is_subtype(type1, type2) + + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: + self.type_of(stmt.expr) + + def visit_function(self, stmt: p.Function) -> None: + env: Environment = Environment(self.env) + pos_args: list[Function.Argument] = [] + args: list[Function.Argument] = [] + kw_args: list[Function.Argument] = [] + + def eval_arg_type(arg: p.Function.Argument) -> Type: + if arg.type is not None: + return self.resolve_type_expr(arg.type) + if arg.default is not None: + return self.type_of(arg.default) + return UnknownType() + + pos: int = 0 + for arg in stmt.posonlyargs: + pos_args.append( + Function.Argument( + pos=pos, + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + for arg in stmt.args: + args.append( + Function.Argument( + pos=pos, + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + for arg in stmt.kwonlyargs: + kw_args.append( + Function.Argument( + pos=pos, # not relevant + name=arg.name, + type=eval_arg_type(arg), + required=arg.default is None, + ) + ) + pos += 1 + + for arg in pos_args + args + kw_args: + env.define(arg.name, arg.type) + + returns_hint: Optional[Type] = None + if stmt.returns is not None: + returns_hint = self.resolve_type_expr(stmt.returns) + # Early define to handle simple fully-typed recursion + inside_function: Function = 
Function( + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns_hint, + ) + self.env.define(stmt.name, inside_function) + + returned: bool = self.process_block(stmt.body, env) + inferred_return: Type = UnknownType() + if not returned: + env.return_types.append(UnitType()) + return_types: list[Type] = self.types.reduce_types(env.return_types) + if len(return_types) == 1: + inferred_return = return_types[0] + elif len(return_types) > 1: + self.reporter.error( + stmt.location, + f"Mixed return types: {return_types}", + ) + + returns: Type = UnknownType() + if returns_hint is not None: + assert stmt.returns is not None + returns = returns_hint + if returns != inferred_return: + self.reporter.error( + stmt.returns.location, + f"Return type mismatch, annotated {returns} but returns {inferred_return}", + ) + else: + returns = inferred_return + + # TODO: handle *args and **kwargs sinks + function: Function = Function( + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns, + ) + self.env.define(stmt.name, function) + + def visit_type_assign(self, stmt: p.TypeAssign) -> None: + # TODO check not yet defined locally + type: Type = self.resolve_type_expr(stmt.type) + self.env.define(stmt.name, type) + + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: + value_type: Type = self.type_of(stmt.value) + for target in stmt.targets: + self._assign(stmt.location, target, value_type) + + def _assign(self, location: Location, target: p.Expr, value_type: Type): + match target: + case p.VariableExpr(): + self._assign_var(location, target, value_type) + + case p.GetExpr(object=object, name=name): + self._assign_attr(location, object, name, value_type) + + case _: + if not isinstance(target, p.VariableExpr): + self.logger.warning(f"Unsupported assignment to {target}") + self.reporter.warning( + target.location, f"Unsupported assignment to {target}" + ) + + def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type): + name: 
str = target.name + var_type: Optional[Type] = self.look_up_variable(name, target) + + if var_type is None: + self.env.define(name, value_type) + else: + # S <: T + # Γ, x: T v: S + # x = v + if not self.is_subtype(value_type, var_type): + self.reporter.error( + location, + f"Cannot assign {value_type} to variable '{name}' of type {var_type}", + ) + + def _assign_attr( + self, location: Location, object: p.Expr, name: str, value_type: Type + ): + object_type: Type = self.type_of(object) + member: Optional[Type] = self.types.lookup_member(object_type, name) + if member is None: + self.reporter.error(location, f"Unknown member '{name}' of {object_type}") + return + self.logger.debug(f"Member '{name}' of {object_type} has type {member}") + if not self.is_subtype(value_type, member): + self.reporter.error( + location, + f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}", + ) + + def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: + type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType() + self.env.return_types.append(type) + raise ReturnException() + + def visit_if_stmt(self, stmt: p.IfStmt) -> None: + # Not evaluated in sub-environment because assignments in the test leak out of the if + # For example: + # if (m := 1 + 1) < 2: + # ... 
+ # print(m) # <- m is still defined + test_type: Type = self.type_of(stmt.test) + + # TODO Allow subtypes or any type + if test_type != self.types.get_type("bool"): + self.reporter.error( + stmt.test.location, f"If test must be a boolean, got {test_type}" + ) + + env: Environment = Environment(self.env) + body_returned: bool = self.process_block(stmt.body, env) + else_returned: bool = self.process_block(stmt.orelse, env) + self.env.return_types.extend(env.return_types) + if body_returned and else_returned: + raise ReturnException() + + def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: + method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + + return self._visit_binary_expr(expr.location, expr.left, expr.right, method) + + def visit_compare_expr(self, expr: p.CompareExpr) -> Type: + method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + + return self._visit_binary_expr(expr.location, expr.left, expr.right, method) + + def _visit_binary_expr( + self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str + ) -> Type: + left: Type = self.type_of(left_expr) + right: Type = self.type_of(right_expr) + + operation: Optional[Type] = self.types.lookup_member(left, method) + if operation is None: + self.reporter.error( + location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + + return self._get_call_result(location, operation, [(right_expr, right)], {}) + + def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: + method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) + if method is 
None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + + operand: Type = self.type_of(expr.right) + operation: Optional[Type] = self.types.lookup_member(operand, method) + if operation is None: + self.reporter.error( + expr.location, + f"Undefined operation {method} for {operand}", + ) + return UnknownType() + + return self._get_call_result( + expr.location, operation, [(expr.right, operand)], {} + ) + + def visit_call_expr(self, expr: p.CallExpr) -> Type: + callee: Type = self.type_of(expr.callee) + positional: list[TypedExpr] = [ + (arg, self.type_of(arg)) for arg in expr.arguments + ] + keywords: dict[str, TypedExpr] = { + name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() + } + return self._get_call_result( + location=expr.location, + callee=callee, + positional=positional, + keywords=keywords, + ) + + def visit_get_expr(self, expr: p.GetExpr) -> Type: + object: Type = self.type_of(expr.object) + member: Optional[Type] = self.types.lookup_member(object, expr.name) + if member is None: + self.reporter.error( + expr.location, f"Unknown member '{expr.name}' of {object}" + ) + return UnknownType() + self.logger.debug(f"Member '{expr.name}' of {object} has type {member}") + return member + + def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: + match expr.value: + case bool(): # Must be before int + return self.types.get_type("bool") + case int(): + return self.types.get_type("int") + case float(): + return self.types.get_type("float") + case str(): + return self.types.get_type("str") + case _: + self.reporter.warning(expr.location, f"Unknown literal {expr}") + return UnknownType() + + def visit_variable_expr(self, expr: p.VariableExpr) -> Type: + type: Optional[Type] = self.look_up_variable(expr.name, expr) + if type is None: + self.logger.debug(f"Unknown variable {expr.name} in {self.env.flat_dict()}") + 
self.reporter.warning(expr.location, "Unknown variable") + return type or UnknownType() + + def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: + left: Type = self.type_of(expr.left) + right: Type = self.type_of(expr.right) + + if self.is_subtype(left, right): + return right + if self.is_subtype(right, left): + return left + + self.reporter.error( + expr.location, + f"Incompatible operand types, {left=} and {right=}", + ) + return UnknownType() + + def visit_cast_expr(self, expr: p.CastExpr) -> Type: + return self.resolve_type_expr(expr.type) + + def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type: + test_type: Type = self.type_of(expr.test) + + # TODO Allow subtypes or any type + if test_type != self.types.get_type("bool"): + self.reporter.error( + expr.test.location, f"If test must be a boolean, got {test_type}" + ) + + true_type: Type = self.type_of(expr.if_true) + false_type: Type = self.type_of(expr.if_false) + if self.is_subtype(true_type, false_type): + return false_type + if self.is_subtype(false_type, true_type): + return true_type + + self.reporter.error( + expr.location, + f"Incompatible types in ternary if branches: true={true_type} and false={false_type}", + ) + return UnknownType() + + def visit_list_expr(self, expr: p.ListExpr) -> Type: + list_type: Type = self.types.get_type("list") + item_types: list[Type] = [self.type_of(item) for item in expr.items] + item_types = self.types.reduce_types(item_types) + + if len(item_types) == 0: + return list_type + + if len(item_types) == 1: + item_type: Type = item_types[0] + return self.types.apply_generic(list_type, [item_type]) + self.reporter.error( + expr.location, + f"Heterogeneous list items: {item_types}", + ) + return self.types.apply_generic(list_type, [UnknownType()]) + + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type: + object: Type = self.type_of(expr.object) + operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") + if operation is None: + 
self.reporter.error( + expr.location, + f"Undefined method __getitem__ on {object}", + ) + return UnknownType() + + index: Type = self.type_of(expr.index) + return self._get_call_result( + expr.location, operation, [(expr.index, index)], {} + ) + + def visit_slice_expr(self, expr: p.SliceExpr) -> Type: + return self.types.get_type("slice") + + def visit_base_type(self, node: p.BaseType) -> Type: + base: Type + try: + base = self.types.get_type(node.base) + except NameError: + self.reporter.warning(node.location, f"Unknown type '{node.base}'") + return UnknownType() + + if node.param is not None: + param: Type = self.resolve_type_expr(node.param) + return self.types.apply_generic(base, [param]) + return base + + def visit_constraint_type(self, node: p.ConstraintType) -> Type: + self.reporter.warning(node.location, "ConstraintType not yet supported") + return UnknownType() + + def visit_frame_column(self, node: p.FrameColumn) -> Type: + self.reporter.warning(node.location, "FrameColumn not yet supported") + return UnknownType() + + def visit_frame_type(self, node: p.FrameType) -> Type: + self.reporter.warning(node.location, "FrameType not yet supported") + return UnknownType() + + def _get_call_result( + self, + location: Location, + callee: Type, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Type: + """Get the result type of a function call + + If the function has overloads, the function will try to resolve the + appropriate signature. + Argument types are matched to the defined parameters. + The function doesn't take the raw expression as a parameter to accomodate + for desugared calls such as for operators. 
+ + Args: + location (Location): the call location + callee (Type): the called function + positional (list[TypedExpr]): the list positional arguments + keywords (dict[str, TypedExpr]): the map of keyword arguments + + Returns: + Type: the return type of the call, or `UnknownType` if either + the call is invalid or no overload matched the arguments uniquely + """ + match callee: + case Function() as function: + valid: bool + mapped: list[MappedArgument] + valid, mapped = self.map_call_arguments( + function, location, positional, keywords + ) + valid = valid and self._are_arguments_valid(mapped) + if not valid: + return UnknownType() + return function.returns + + case OverloadedFunction(overloads=overloads): + function = self._match_overload( + overloads, location, positional, keywords + ) + if function is None: + return UnknownType() + return function.returns + case _: + self.reporter.error(location, f"{callee} is not callable") + return UnknownType() + + def _are_arguments_valid( + self, + arguments: list[MappedArgument], + report_errors: bool = True, + ) -> bool: + """Check whether the passed argument types correspond to their matched parameter definitions + + Args: + arguments (list[MappedArgument]): the list of argument/parameter pairs + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. 
+ + Returns: + bool: True if all arguments fit the matching parameter definitions, False otherwise + """ + valid: bool = True + for arg in arguments: + if not self.is_subtype(arg.type, arg.argument.type): + if report_errors: + self.reporter.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + valid = False + return valid + + def _match_overload( + self, + overloads: list[Type], + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Optional[Function]: + """Try and resolve the appropriate overload for the given arguments + + Args: + overloads (list[Type]): the list of possible overloads + location (Location): the call location + positional (list[TypedExpr]): the list of positional arguments + keywords (dict[str, TypedExpr]): the map of keywords arguments + + Returns: + Optional[Function]: the resolved function signature if it can be + determined unambigously, or `None`. + """ + candidates: list[OverloadCandidate] = [] + for overload in overloads: + function: Type = unfold_type(overload) + if not isinstance(function, Function): + self.logger.error( + f"Overload is not a function: {overload} is {function}" + ) + continue + valid, mapped = self.map_call_arguments( + function=function, + location=location, + positional=positional, + keywords=keywords, + report_errors=False, + ) + if valid and self._are_arguments_valid(mapped, report_errors=False): + candidates.append( + OverloadCandidate( + function=function, + mapped=mapped, + ) + ) + + pos_types: str = ", ".join(str(type) for _, type in positional) + kw_types: str = ", ".join( + f"{name}: {type}" for name, (_, type) in keywords.items() + ) + for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}" + + n_candidates: int = len(candidates) + + # Exactly 1 match -> return it + if n_candidates == 1: + return candidates[0].function + + # No match -> invalid call + if n_candidates == 0: + 
overloads_str: str = ", ".join(map(str, overloads)) + self.reporter.error( + location, + f"No matching overload in [{overloads_str}] {for_args}", + ) + return None + + # Multiple matches -> see if one <: all others (more specific) + for i1, c1 in enumerate(candidates): + mapped1: list[MappedArgument] = c1.mapped + best_match: bool = True + for i2, c2 in enumerate(candidates): + if i1 == i2: + continue + mapped2: list[MappedArgument] = c2.mapped + if not self._are_mapped_subtypes(mapped1, mapped2): + best_match = False + break + self.logger.debug(f"{c1.function} is a full overload of {c2.function}") + if best_match: + return c1.function + + candidates_str: str = ", ".join( + str(candidate.function) for candidate in candidates + ) + self.reporter.error( + location, + f"Multiple matching overloads {for_args}: {candidates_str}", + ) + return None + + def map_call_arguments( + self, + function: Function, + location: Location, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + report_errors: bool = True, + ) -> tuple[bool, list[MappedArgument]]: + """Map call arguments to a function's parameters as defined in its signature + + This method maps positional-only, keyword-only and mixed parameter definitions + with the arguments passed at the call site + + Any mismatched, missing or unexpected argument is reported as a diagnostic, + unless `report_errors` is set to `False` + + Args: + function (Function): the function definition + location (Location): the call location + positional (list[TypedExpr]): the list of positional arguments + keywords (dict[str, TypedExpr]): the map of keyword arguments + report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True. 
+ + Returns: + tuple[bool, list[MappedArgument]]: a boolean reporting whether + the call is valid and the list of mapped arguments + """ + set_args: set[str] = set() + + required_positional: list[str] = [ + arg.name for arg in function.pos_args + function.args if arg.required + ] + required_keyword: list[str] = [ + arg.name for arg in function.kw_args if arg.required + ] + + mapped: list[MappedArgument] = [] + + pos_params: list[Function.Argument] = list(function.pos_args) + mixed_params: list[Function.Argument] = list(function.args) + kw_params: dict[str, Function.Argument] = { + arg.name: arg for arg in function.kw_args + } + + valid_call: bool = True + + # TODO: handle *args and **kwargs sinks + for arg in positional: + param: Function.Argument + if len(pos_params) != 0: + param = pos_params.pop(0) + elif len(mixed_params) != 0: + param = mixed_params.pop(0) + else: + if report_errors: + self.reporter.error( + arg[0].location, "Too many positional arguments" + ) + valid_call = False + break + name: str = param.name + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + kw_params.update({arg.name: arg for arg in mixed_params}) + for name, arg in keywords.items(): + param: Function.Argument + if name not in kw_params: + if report_errors: + if name in set_args: + self.reporter.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.reporter.error( + arg[0].location, f"Unknown keyword argument '{name}'" + ) + valid_call = False + continue + param = kw_params.pop(name) + if name in required_positional: + required_positional.remove(name) + if name in required_keyword: + required_keyword.remove(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + def join_args(args: list[str]) -> 
str: + args = list(map(lambda a: f"'{a}'", args)) + if len(args) == 0: + return "" + if len(args) == 1: + return args[0] + return ", ".join(args[:-1]) + " and " + args[-1] + + if len(required_positional) != 0: + plural: str = "" if len(required_positional) == 1 else "s" + args: str = join_args(required_positional) + if report_errors: + self.reporter.error( + location, + f"Missing required positional argument{plural}: {args}", + ) + valid_call = False + + if len(required_keyword) != 0: + plural: str = "" if len(required_keyword) == 1 else "s" + args: str = join_args(required_keyword) + if report_errors: + self.reporter.error( + location, + f"Missing required keyword argument{plural}: {args}", + ) + valid_call = False + + return valid_call, mapped + + def _are_mapped_subtypes( + self, mapped1: list[MappedArgument], mapped2: list[MappedArgument] + ) -> bool: + """Check whether the given argument mappings are subtype/supertype of one another + + This function checks whether the argument mappings `mapped1` are subtypes + of `mapped2`. If any of the parameter type in `mapped1` is not a subtype + of the corresponding parameter in `mapped2`, `False` is returned. + + This is used to check whether a given overload is + a more specific function/ a subtype of another. 
+ + Args: + mapped1 (list[MappedArgument]): the first argument mappings (subtype) + mapped2 (list[MappedArgument]): the second argument mappings (supertype) + + Returns: + bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise + """ + by_expr: dict[p.Expr, Type] = {} + for arg in mapped1: + by_expr[arg.expr] = arg.argument.type + + for arg in mapped2: + type2: Type = arg.argument.type + type1: Type = by_expr[arg.expr] + if not self.is_subtype(type1, type2): + return False + return True diff --git a/midas/checker/registry.py b/midas/checker/registry.py new file mode 100644 index 0000000..6591548 --- /dev/null +++ b/midas/checker/registry.py @@ -0,0 +1,347 @@ +import logging +from typing import Optional + +from midas.checker.builtins import BUILTIN_SUBTYPES +from midas.checker.types import ( + AliasType, + AppliedType, + BaseType, + ComplexType, + ExtensionType, + Function, + GenericType, + OverloadedFunction, + TopType, + Type, + TypeVar, + UnknownType, + substitute_typevars, +) + + +class TypesRegistry: + def __init__(self) -> None: + self.logger: logging.Logger = logging.getLogger("TypesRegistry") + self._types: dict[str, Type] = {} + self._members: dict[str, dict[str, Type]] = {} + + def get_type(self, name: str) -> Type: + """Get a type from its name + + Args: + name (str): the name of the type + + Raises: + NameError: if the type is not defined + + Returns: + Type: the type + """ + if name in self._types: + return self._types[name] + raise NameError(f"Undefined type {name}") + + def define_type(self, name: str, type: Type) -> Type: + """Define a type in the registry + + Args: + name (str): the name of the type + type (Type): the type to define + + Raises: + ValueError: if a type is already defined with that name + + Returns: + Type: the defined type + """ + if name in self._types: + raise ValueError(f"Type {name} already defined") + self._types[name] = type + return type + + def define_member( + self, type_name: str, member_name: str, 
member_type: Type, is_method: bool + ): + members: dict[str, Type] = self._members.setdefault(type_name, {}) + if member_name in members: + if not is_method: + self.logger.error( + f"Member '{member_name}' already defined for type {type_name}" + ) + return + current: Type = members[member_name] + combined: Type + match current: + case OverloadedFunction(overloads=overloads): + combined = OverloadedFunction(overloads=overloads + [member_type]) + case _: + combined = OverloadedFunction(overloads=[current, member_type]) + members[member_name] = combined + + else: + members[member_name] = member_type + + def is_subtype(self, type1: Type, type2: Type) -> bool: + """Check whether `type1` is a subtype of `type2` + + For more details on the rules checked here, see TAPL Chap. 15-16-17 + + Args: + type1 (Type): the potential subtype + type2 (Type): the potential supertype + + Returns: + bool: whether `type1` is a subtype of `type2` + """ + + if type1 == type2: + return True + + match (type1, type2): + case (_, TopType()): + return True + + case (AliasType(type=base1), _): + return self.is_subtype(base1, type2) + + case (BaseType(name=name1), BaseType(name=name2)): + return name1 in BUILTIN_SUBTYPES.get(name2, set()) + + case (ComplexType(members=props1), ComplexType(members=props2)): + for k, t in props2.items(): + if k not in props1: + return False + if not self.is_subtype(props1[k], t): + return False + return True + + case (Function(), Function()): + return self.is_func_subtype(type1, type2) + + case (TypeVar(bound=bound), _): + if bound is None: + return False + return self.is_subtype(bound, type2) + + return False + + # TODO: verify the logic in here + def is_func_subtype(self, func1: Function, func2: Function) -> bool: + """Check whether a function is a subtype of another + + Args: + func1 (Function): the potential function subtype + func2 (Function): the potential function supertype + + Returns: + bool: whether `func1` is a subtype of `func2` + """ + if not
self.is_subtype(func1.returns, func2.returns): + return False + + pos1: list[Function.Argument] = func1.pos_args + mixed1: list[Function.Argument] = func1.args + kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args} + pos2: list[Function.Argument] = func2.pos_args + mixed2: list[Function.Argument] = func2.args + kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args} + + mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2} + mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2} + + def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool: + if not self.is_subtype(sub.type, sup.type): + return False + if not sup.required and sub.required: + return False + return True + + for arg1 in pos1: + arg2: Function.Argument + if arg1.pos < len(pos2): + arg2 = pos2[arg1.pos] + elif arg1.pos in mixed_by_pos: + arg2 = mixed_by_pos[arg1.pos] + elif not arg1.required: + continue + else: + return False + if not is_arg_subtype(arg2, arg1): + return False + + for name, arg1 in kw1.items(): + arg2: Function.Argument + if name in kw2: + arg2 = kw2[name] + elif name in mixed_by_name: + arg2 = mixed_by_name[name] + elif not arg1.required: + continue + else: + return False + if not is_arg_subtype(arg2, arg1): + return False + + for arg1 in mixed1: + pos_arg2: Optional[Function.Argument] = None + kw_arg2: Optional[Function.Argument] = None + if arg1.name in kw2: + kw_arg2 = kw2[arg1.name] + elif arg1.name in mixed_by_name: + kw_arg2 = mixed_by_name[arg1.name] + if arg1.pos < len(pos2): + pos_arg2 = pos2[arg1.pos] + elif arg1.pos in mixed_by_pos: + pos_arg2 = mixed_by_pos[arg1.pos] + + # No match in func2 and arg is required + if pos_arg2 is None and kw_arg2 is None and arg1.required: + return False + + # Matching keyword argument + if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1): + return False + + # Matching positional argument + if pos_arg2 is not None and not 
is_arg_subtype(pos_arg2, arg1): + return False + + mixed_positions: set[int] = {a.pos for a in mixed1} + mixed_names: set[str] = {a.name for a in mixed1} + for arg2 in pos2: + if not arg2.required: + continue + if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions: + return False + + for name, arg2 in kw2.items(): + if not arg2.required: + continue + if name not in kw1 and name not in mixed_names: + return False + + for arg2 in mixed2: + if arg2.required: + continue + pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions + kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names + if not pos_match or not kw_match: + return False + + return True + + def apply_generic(self, type: Type, args: list[Type]) -> Type: + match type: + case AliasType(name=name, type=base): + return AliasType(name=name, type=self.apply_generic(base, args)) + + case GenericType(name=name, params=type_vars, body=body): + n_args: int = len(args) + n_type_vars: int = len(type_vars) + if n_args < n_type_vars: + raise ValueError( + f"Missing type arguments, expected {n_type_vars} but only {n_args} provided" + ) + if n_args > n_type_vars: + raise ValueError( + f"Too many type arguments, expected {n_type_vars} but {n_args} provided" + ) + substitutions: dict[str, Type] = {} + for arg, type_var in zip(args, type_vars): + if type_var.bound is not None and not self.is_subtype( + arg, type_var.bound + ): + raise ValueError( + f"Type argument {arg} is not a subtype of {type_var.bound}" + ) + substitutions[type_var.name] = arg + return AppliedType( + name=name, + args=args, + body=substitute_typevars(body, substitutions), + ) + + case _: + raise ValueError(f"{type} is not a generic type") + + def reduce_types(self, types: list[Type]) -> list[Type]: + """Reduce a list of types to remove subtypes and only keep the highest types + + Args: + types (list[Type]): the types to reduce + + Returns: + list[Type]: the reduced list of types + """ + + reduced: bool = True + keep: list[int] = 
list(range(len(types))) + while reduced: + reduced = False + for i, i1 in enumerate(keep): + type1: Type = types[i1] + for i2 in keep[i + 1 :]: + type2 = types[i2] + if self.is_subtype(type1, type2): + keep.remove(i1) + elif self.is_subtype(type2, type1): + keep.remove(i2) + else: + continue + reduced = True + break + return [types[i] for i in keep] + + def lookup_member(self, type: Type, member_name: str) -> Optional[Type]: + match type: + case BaseType(name=name): + if name in self._members: + if member_name in self._members[name]: + return self._members[name][member_name] + return None + + case AliasType(name=name, type=base): + if name in self._members: + if member_name in self._members[name]: + return self._members[name][member_name] + return self.lookup_member(base, member_name) + + case AppliedType(name=name, body=body, args=args): + generic: Type = self.get_type(name) + + if not isinstance(generic, GenericType): + raise ValueError("AppliedType not derived from a GenericType") + + substitutions = { + type_var.name: arg for arg, type_var in zip(args, generic.params) + } + if name in self._members: + if member_name in self._members[name]: + member_type: Type = self._members[name][member_name] + return substitute_typevars(member_type, substitutions) + + member_type2: Optional[Type] = self.lookup_member(body, member_name) + if member_type2 is not None: + member_type2 = substitute_typevars(member_type2, substitutions) + return member_type2 + + case ComplexType(members=members): + if member_name in members: + return members[member_name] + self.logger.debug(f"No member '{member_name}' in {type}") + return None + + case ExtensionType(base=base, extension=ComplexType(members=members)): + if member_name in members: + return members[member_name] + self.logger.debug( + f"No member '{member_name}' on {type}, looking up in base" + ) + return self.lookup_member(base, member_name) + + case UnknownType(): + return UnknownType() + + case _: + self.logger.debug(f"Can't get 
member on {type}") + return None diff --git a/midas/checker/reporter.py b/midas/checker/reporter.py new file mode 100644 index 0000000..b68766a --- /dev/null +++ b/midas/checker/reporter.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import Optional + +from midas.ast.location import Location +from midas.checker.diagnostic import Diagnostic, DiagnosticType + + +class Reporter: + def __init__(self): + self.diagnostics: list[Diagnostic] = [] + + def report( + self, + path: Optional[str], + type: DiagnosticType, + location: Location, + message: str, + ): + self.diagnostics.append( + Diagnostic( + file_path=path, + location=location, + type=type, + message=message, + ) + ) + + def for_file(self, path: Optional[str]) -> FileReporter: + return FileReporter(self, path) + + +class FileReporter: + def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None: + self.base_reporter: Reporter = base_reporter + self.path: Optional[str] = path + + def for_file(self, path: Optional[str]) -> FileReporter: + return FileReporter(self.base_reporter, path) + + def report(self, type: DiagnosticType, location: Location, message: str): + self.base_reporter.report(self.path, type, location, message) + + def error(self, location: Location, message: str): + self.report( + type=DiagnosticType.ERROR, + location=location, + message=message, + ) + + def warning(self, location: Location, message: str): + self.report( + type=DiagnosticType.WARNING, + location=location, + message=message, + ) + + def info(self, location: Location, message: str): + self.report( + type=DiagnosticType.INFO, + location=location, + message=message, + ) diff --git a/midas/resolver/resolver.py b/midas/checker/resolver.py similarity index 85% rename from midas/resolver/resolver.py rename to midas/checker/resolver.py index 18fcba4..12f18cf 100644 --- a/midas/resolver/resolver.py +++ b/midas/checker/resolver.py @@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): 
def __init__(self): self.locals: dict[p.Expr, int] = {} - self.scopes: list[dict[str, bool]] = [] + self.scopes: list[dict[str, bool]] = [{}] def resolve(self, *objects: p.Stmt | p.Expr) -> None: """Resolve the given statements or expressions""" @@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.locals[expr] = i return + def is_defined(self, name: str) -> bool: + for scope in self.scopes: + if name in scope: + return True + return False + def resolve_function(self, function: p.Function) -> None: """Resolve a function definition @@ -111,7 +117,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.resolve(stmt.value) for target in stmt.targets: match target: - case p.VariableExpr() | p.GetExpr(): + case p.VariableExpr(name=name): + if not self.is_defined(name): + self.declare(name) + self.define(name) + target.accept(self) + + case p.GetExpr(): target.accept(self) case _: raise Exception(f"Unsupported assignment to {target}") @@ -180,3 +192,19 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): self.resolve(expr.test) self.resolve(expr.if_true) self.resolve(expr.if_false) + + def visit_list_expr(self, expr: p.ListExpr) -> None: + for item in expr.items: + self.resolve(item) + + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: + self.resolve(expr.object) + self.resolve(expr.index) + + def visit_slice_expr(self, expr: p.SliceExpr) -> None: + if expr.lower is not None: + self.resolve(expr.lower) + if expr.upper is not None: + self.resolve(expr.upper) + if expr.step is not None: + self.resolve(expr.step) diff --git a/midas/checker/types.py b/midas/checker/types.py index 83707b6..c6d41d1 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -1,37 +1,68 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional + + +@dataclass(frozen=True, kw_only=True) +class TopType: + def __str__(self) -> str: + return "Any" @dataclass(frozen=True, 
kw_only=True) class BaseType: name: str + def __str__(self) -> str: + return self.name + @dataclass(frozen=True, kw_only=True) class AliasType: name: str type: Type + def __str__(self) -> str: + return self.name + @dataclass(frozen=True, kw_only=True) class UnknownType: - pass + def __str__(self) -> str: + return "" @dataclass(frozen=True, kw_only=True) class UnitType: - pass + def __str__(self) -> str: + return "None" @dataclass(frozen=True, kw_only=True) class Function: - name: str pos_args: list[Argument] args: list[Argument] kw_args: list[Argument] returns: Type + def __str__(self) -> str: + args: list[str] = [] + if len(self.pos_args) != 0: + args += list(map(str, self.pos_args)) + if len(self.args) + len(self.kw_args) != 0: + args.append("/") + + if len(self.args) != 0: + args += list(map(str, self.args)) + + if len(self.kw_args) != 0: + if len(args) != 0: + args.append("*") + args += list(map(str, self.kw_args)) + + return f"({', '.join(args)}) -> {self.returns}" + @dataclass(frozen=True, kw_only=True) class Argument: pos: int @@ -39,22 +70,164 @@ class Function: type: Type required: bool + def __str__(self) -> str: + opt: str = "" if self.required else "?" 
+ return f"{self.name}: {self.type}{opt}" + + +@dataclass(frozen=True, kw_only=True) +class OverloadedFunction: + overloads: list[Type] + + def __str__(self) -> str: + return "" + @dataclass(frozen=True, kw_only=True) class ComplexType: - properties: dict[str, Type] + members: dict[str, Type] + + def __str__(self) -> str: + props: list[str] = [f"{name}: {type}" for name, type in self.members.items()] + return f"{{{', '.join(props)}}}" @dataclass(frozen=True, kw_only=True) -class Operation: - signature: CallSignature - result: Type +class ExtensionType: + base: Type + extension: ComplexType - @dataclass(frozen=True, kw_only=True) - class CallSignature: - left: Type - method: str - right: Type + def __str__(self) -> str: + return f"{self.base} & {self.extension}" -Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType +@dataclass(frozen=True, kw_only=True) +class TypeVar: + name: str + bound: Optional[Type] + + def __str__(self) -> str: + if self.bound is not None: + return f"{self.name} <: {self.bound}" + return self.name + + +@dataclass(frozen=True, kw_only=True) +class GenericType: + name: str + params: list[TypeVar] + body: Type + + def __str__(self) -> str: + return f"{self.name}[{', '.join(map(str, self.params))}]" + + +@dataclass(frozen=True, kw_only=True) +class AppliedType: + name: str + args: list[Type] + body: Type + + def __str__(self) -> str: + return f"{self.name}[{', '.join(map(str, self.args))}]" + + +def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: + def sub_argument(arg: Function.Argument): + return Function.Argument( + pos=arg.pos, + name=arg.name, + type=substitute_typevars(arg.type, substitutions), + required=arg.required, + ) + + match type: + case BaseType(name=name) if name in substitutions: + return substitutions[name] + + case BaseType(): + return type + + case AliasType(name=name, type=type2): + return AliasType(name=name, type=substitute_typevars(type2, substitutions)) + + case Function( + 
pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=returns, + ): + return Function( + pos_args=list(map(sub_argument, pos_args)), + args=list(map(sub_argument, args)), + kw_args=list(map(sub_argument, kw_args)), + returns=substitute_typevars(returns, substitutions), + ) + + case OverloadedFunction(overloads=overloads): + return OverloadedFunction( + overloads=[ + substitute_typevars(overload, substitutions) + for overload in overloads + ] + ) + + case ComplexType(members=members): + members2: dict[str, Type] = { + name: substitute_typevars(prop, substitutions) + for name, prop in members.items() + } + return ComplexType(members=members2) + + case ExtensionType(base=base, extension=ComplexType(members=members)): + return ExtensionType( + base=substitute_typevars(base, substitutions), + extension=ComplexType( + members={ + name: substitute_typevars(prop, substitutions) + for name, prop in members.items() + } + ), + ) + + case AppliedType(name=name, args=args, body=body): + return AppliedType( + name=name, + args=[substitute_typevars(arg, substitutions) for arg in args], + body=substitute_typevars(body, substitutions), + ) + + case TypeVar(name=name): + if name in substitutions: + return substitutions[name] + raise ValueError(f"Missing TypeVar substitution for {name}") + + case UnknownType() | UnitType(): + return type + + case _: + raise NotImplementedError(f"Unsupported type {type}") + + +def unfold_type(type: Type) -> Type: + match type: + case AliasType(type=ref_type): + return unfold_type(ref_type) + case _: + return type + + +Type = ( + TopType + | BaseType + | AliasType + | UnknownType + | UnitType + | Function + | OverloadedFunction + | ComplexType + | ExtensionType + | TypeVar + | GenericType + | AppliedType +) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index e4a9556..bc7727c 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -214,6 +214,22 @@ class PythonHighlighter( def visit_ternary_expr(self, expr: 
p.TernaryExpr) -> None: ... + def visit_list_expr(self, expr: p.ListExpr) -> None: + for item in expr.items: + item.accept(self) + + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: + expr.object.accept(self) + expr.index.accept(self) + + def visit_slice_expr(self, expr: p.SliceExpr) -> None: + if expr.lower is not None: + expr.lower.accept(self) + if expr.upper is not None: + expr.upper.accept(self) + if expr.step is not None: + expr.step.accept(self) + class MidasHighlighter( Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] @@ -228,21 +244,14 @@ class MidasHighlighter( self.wrap(LocatableToken(stmt.name), "type-name") stmt.type.accept(self) - def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: - self.wrap(stmt, "property") + def visit_member_stmt(self, stmt: m.MemberStmt) -> None: + self.wrap(stmt, "member") stmt.type.accept(self) def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self.wrap(stmt, "extend") - stmt.type.accept(self) - for op in stmt.operations: - op.accept(self) - - def visit_op_stmt(self, stmt: m.OpStmt) -> None: - self.wrap(stmt, "op") - self.wrap(LocatableToken(stmt.name), "op-name") - stmt.operand.accept(self) - stmt.result.accept(self) + for member in stmt.members: + member.accept(self) def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: self.wrap(stmt, "predicate") @@ -284,8 +293,8 @@ class MidasHighlighter( def visit_generic_type(self, type: m.GenericType) -> None: self.wrap(type, "generic-type") type.type.accept(self) - for param in type.params: - param.accept(self) + for arg in type.args: + arg.accept(self) def visit_constraint_type(self, type: m.ConstraintType) -> None: self.wrap(type, "constraint-type") @@ -294,8 +303,19 @@ class MidasHighlighter( def visit_complex_type(self, type: m.ComplexType) -> None: self.wrap(type, "complex-type") - for prop in type.properties: - prop.accept(self) + for member in type.members: + member.accept(self) + + def 
visit_function_type(self, type: m.FunctionType) -> None: + self.wrap(type, "function") + for arg in type.pos_args + type.args + type.kw_args: + arg.type.accept(self) + type.returns.accept(self) + + def visit_extension_type(self, type: m.ExtensionType) -> None: + self.wrap(type, "extension") + type.base.accept(self) + type.extension.accept(self) class DiagnosticsHighlighter(Highlighter): diff --git a/midas/cli/main.py b/midas/cli/main.py index ae4295b..af95abd 100644 --- a/midas/cli/main.py +++ b/midas/cli/main.py @@ -10,7 +10,7 @@ import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter -from midas.checker.checker import Checker +from midas.checker.checker import TypeChecker from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.types import Type from midas.cli.ansi import Ansi @@ -25,7 +25,6 @@ from midas.lexer.midas import MidasLexer from midas.lexer.token import Token, TokenType from midas.parser.midas import MidasParser from midas.parser.python import PythonParser -from midas.resolver.resolver import Resolver from midas.utils import UniversalJSONDumper @@ -89,36 +88,57 @@ def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4): @click.option("-l", "--highlight", type=click.File("w")) @click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-v", "--verbose", is_flag=True) +@click.option("-j", "--show-judgements", is_flag=True) @click.argument("file", type=click.File("r")) def compile( highlight: Optional[TextIO], types: tuple[TextIO], verbose: bool, + show_judgements: bool, file: TextIO, ): logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN) source: str = file.read() - tree: ast.Module = ast.parse(source, filename=file.name) - parser = PythonParser() - stmts: list[p.Stmt] = parser.parse_module(tree) - resolver = Resolver() - resolver.resolve(*stmts) - 
types_paths: list[Path] = [Path(t.name).resolve() for t in types] - checker = Checker( - resolver.locals, - source_path=Path(file.name).resolve(), - types_paths=types_paths, - ) - diagnostics: list[Diagnostic] = checker.check(stmts) + source_path: Path = Path(file.name).resolve() + + checker = TypeChecker() + for types_file in types: + checker.import_midas(Path(types_file.name).resolve()) + + checker.type_check_source(source, str(source_path)) + diagnostics: list[Diagnostic] = checker.diagnostics.copy() lines: list[str] = source.split("\n") + files: dict[Optional[str], list[str]] = {None: []} + + if show_judgements: + for expr, type in checker.python_typer.judgements: + print(f"Judged that {expr} at {expr.location} is of type {type}") + diagnostics.append( + Diagnostic( + file_path=str(source_path), + location=expr.location, + type=DiagnosticType.INFO, + message=f"Type: {type}", + ) + ) + for diagnostic in diagnostics: + filename: Optional[str] = diagnostic.file_path + if filename is not None and filename not in files: + path: Path = Path(filename) + if path.exists() and path.is_file(): + files[filename] = path.read_text().split("\n") + else: + files[filename] = [] + + lines: list[str] = files[filename] print_diagnostic(lines, diagnostic) if verbose: print( json.dumps( UniversalJSONDumper.dump( - checker.global_env, + checker.python_typer.global_env, [("Environment", "_children")], lambda obj: isinstance(obj, get_args(Type)), ), diff --git a/midas/lexer/midas.py b/midas/lexer/midas.py index 124ea09..c3246fc 100644 --- a/midas/lexer/midas.py +++ b/midas/lexer/midas.py @@ -50,12 +50,14 @@ class MidasLexer(Lexer): # self.add_token(TokenType.PLUS) case "-": self.add_token(TokenType.MINUS) - # case "*": - # self.add_token(TokenType.STAR) + case "*": + self.add_token(TokenType.STAR) case "/" if self.match("/"): self.scan_comment() case "/" if self.match("*"): self.scan_comment_multiline() + case "/": + self.add_token(TokenType.SLASH) case "\n": 
self.add_token(TokenType.NEWLINE) case " " | "\r" | "\t": diff --git a/midas/lexer/token.py b/midas/lexer/token.py index f08964a..f0c08a1 100644 --- a/midas/lexer/token.py +++ b/midas/lexer/token.py @@ -27,8 +27,8 @@ class TokenType(Enum): # Operators # PLUS = auto() MINUS = auto() - # STAR = auto() - # SLASH = auto() + STAR = auto() + SLASH = auto() GREATER = auto() GREATER_EQUAL = auto() LESS = auto() @@ -46,10 +46,12 @@ class TokenType(Enum): # Keywords TYPE = auto() - OP = auto() PREDICATE = auto() EXTEND = auto() WHERE = auto() + PROP = auto() + DEF = auto() + FUNC = auto() # Misc COMMENT = auto() @@ -60,13 +62,15 @@ class TokenType(Enum): KEYWORDS: dict[str, TokenType] = { "type": TokenType.TYPE, - "op": TokenType.OP, "predicate": TokenType.PREDICATE, "extend": TokenType.EXTEND, "where": TokenType.WHERE, "true": TokenType.TRUE, "false": TokenType.FALSE, "none": TokenType.NONE, + "prop": TokenType.PROP, + "def": TokenType.DEF, + "fn": TokenType.FUNC, } diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 5d09b83..33069f3 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -7,23 +7,26 @@ from midas.ast.midas import ( ConstraintType, Expr, ExtendStmt, + ExtensionType, + FunctionType, GenericType, GetExpr, GroupingExpr, LiteralExpr, LogicalExpr, + MemberKind, + MemberStmt, NamedType, - OpStmt, PredicateStmt, - PropertyStmt, Stmt, Type, + TypeParam, TypeStmt, UnaryExpr, VariableExpr, WildcardExpr, ) -from midas.lexer.token import Token, TokenType +from midas.lexer.token import KEYWORDS, Token, TokenType from midas.parser.base import Parser from midas.parser.errors import ParsingError @@ -33,9 +36,10 @@ class MidasParser(Parser): SYNC_BOUNDARY: set[TokenType] = { TokenType.TYPE, - TokenType.OP, TokenType.EXTEND, TokenType.PREDICATE, + TokenType.PROP, + TokenType.FUNC, } def parse(self) -> list[Stmt]: @@ -107,10 +111,8 @@ class MidasParser(Parser): TypeStmt: the parsed type declaration statement """ keyword: Token = self.previous() - name: 
Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - params: list[TypeStmt.Param] = [] - if self.check(TokenType.LEFT_BRACKET): - params = self.type_stmt_params() + name: Token = self.consume_identifier("Expected type name") + params: list[TypeParam] = self.type_params() self.consume(TokenType.EQUAL, "Expected '=' before type definition") @@ -123,24 +125,27 @@ class MidasParser(Parser): type=type, ) - def type_stmt_params(self) -> list[TypeStmt.Param]: - """Parse a generic template expression + def type_params(self) -> list[TypeParam]: + """Parse a list of type parameters - A template is written `[TypeExpr]` + Type parameters are a comma-separated list of type variables wrapped in brackets. + Each type variable is either a simple variable, or a bounded variable written `S <: T` Returns: - TemplateExpr: the parsed template expression + list[TypeParam]: the list of type parameters, if any, or an empty list """ - self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") - params: list[TypeStmt.Param] = [] + if not self.match(TokenType.LEFT_BRACKET): + return [] + + params: list[TypeParam] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable") + name: Token = self.consume_identifier("Expected type variable") bound: Optional[Type] = None if self.match(TokenType.LESS): self.consume(TokenType.COLON, "Expected ':' after '<'") bound = self.type_expr() params.append( - TypeStmt.Param( + TypeParam( location=name.location_to(self.previous()), name=name, bound=bound, @@ -148,7 +153,7 @@ class MidasParser(Parser): ) if not self.match(TokenType.COMMA): break - self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters") return params def type_expr(self) -> Type: @@ -160,7 +165,19 @@ class MidasParser(Parser): Returns: TypeExpr: the parsed type expression 
""" - return self.constraint_type() + base: Type + if self.match(TokenType.FUNC): + base = self.function() + else: + base = self.constraint_type() + if self.match(TokenType.AND): + extension: ComplexType = self.complex_type() + return ExtensionType( + location=Location.span(base.location, extension.location), + base=base, + extension=extension, + ) + return base def constraint_type(self) -> Type: type: Type = self.base_type() @@ -187,55 +204,57 @@ class MidasParser(Parser): def generic_type(self) -> Type: type: Type = self.named_type() if self.check(TokenType.LEFT_BRACKET): - params: list[Type] = self.type_params() + args: list[Type] = self.type_args() return GenericType( location=Location.span(type.location, self.previous().get_location()), type=type, - params=params, + args=args, ) return type - def type_params(self) -> list[Type]: - params: list[Type] = [] - self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters") + def type_args(self) -> list[Type]: + args: list[Type] = [] + self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments") while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): - params.append(self.type_expr()) + args.append(self.type_expr()) if not self.match(TokenType.COMMA): break - self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters") - return params + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments") + return args def named_type(self) -> Type: - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") + name: Token = self.consume_identifier("Expected type name") return NamedType( location=name.get_location(), name=name, ) - def complex_type(self) -> Type: + def complex_type(self) -> ComplexType: """Parse a type definition body A type definition body is a set of whitespace-separated property statements enclosed in curly braces Returns: - list[PropertyStmt]: the parsed type properties + ComplexType: the parsed complex type """ 
left: Token = self.consume( TokenType.LEFT_BRACE, "Expected '{' to start type body" ) - properties: list[PropertyStmt] = [] + members: list[MemberStmt] = [] + # TODO: add keyword to differentiate properties and methods, + # and allow multiple methods with the same name but not properties names: set[str] = set() while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): - prop: PropertyStmt = self.property_stmt() - if prop.name.lexeme in names: - raise self.error(prop.name, "Duplicate property") - names.add(prop.name.lexeme) - properties.append(prop) + member: MemberStmt = self.member_stmt() + # if member.name.lexeme in names: + # raise self.error(member.name, "Duplicate property") + # names.add(member.name.lexeme) + members.append(member) right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body") return ComplexType( location=left.location_to(right), - properties=properties, + members=members, ) def constraint(self) -> Expr: @@ -322,9 +341,7 @@ class MidasParser(Parser): """ expr: Expr = self.primary() while self.match(TokenType.DOT): - name: Token = self.consume( - TokenType.IDENTIFIER, "Expected property name after '.'" - ) + name: Token = self.consume_identifier("Expected property name after '.'") location: Location = Location.span(expr.location, name.get_location()) expr = GetExpr(location=location, expr=expr, name=name) return expr @@ -348,7 +365,7 @@ class MidasParser(Parser): if self.match(TokenType.NUMBER): return LiteralExpr(location=token.get_location(), value=token.value) - if self.match(TokenType.IDENTIFIER): + if self.match_identifier(): return VariableExpr(location=token.get_location(), name=token) if self.match(TokenType.UNDERSCORE): @@ -361,64 +378,70 @@ class MidasParser(Parser): raise self.error(self.peek(), "Expected expression") - def property_stmt(self) -> PropertyStmt: - """Parse a property statement + def consume_identifier(self, message: str = "Expected identifier") -> Token: + if not self.match_identifier(): + raise 
self.error(self.peek(), message) + return self.previous() - A type property statement is written `name: Type` or `name: Type where Condition` + def match_identifier(self) -> bool: + return self.match(TokenType.IDENTIFIER, *KEYWORDS.values()) + + def check_identifier(self) -> bool: + for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]: + if self.check(tt): + return True + return False + + def member_stmt(self) -> MemberStmt: + """Parse a member statement + + A type member statement is written `prop name: Type` or `def name: Type` Returns: - PropertyStmt: the parsed property statement + MemberStmt: the parsed member statement """ - name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name") - self.consume(TokenType.COLON, "Expected ':' after property name") + kind: MemberKind + if self.match(TokenType.PROP): + kind = MemberKind.PROPERTY + elif self.match(TokenType.DEF): + kind = MemberKind.METHOD + else: + raise self.error(self.peek(), "Expected 'prop' or 'def'") + + name: Token = self.consume_identifier("Expected member name") + self.consume(TokenType.COLON, "Expected ':' after member name") + type: Type = self.type_expr() - return PropertyStmt( + return MemberStmt( location=name.location_to(self.previous()), name=name, type=type, + kind=kind, ) def extend_declaration(self) -> ExtendStmt: """Parse an extension definition - An extension is written `extend Type { operations }` + An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }` Returns: ExtendStmt: the parsed extension statement """ keyword: Token = self.previous() - type: Type = self.type_expr() + name: Token = self.consume_identifier("Expected type name") + params: list[TypeParam] = self.type_params() + self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") - operations: list[OpStmt] = [] + members: list[MemberStmt] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): - operations.append(self.op_declaration()) + 
members.append(self.member_stmt()) self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") location: Location = keyword.location_to(self.previous()) - return ExtendStmt(location=location, type=type, operations=operations) - - def op_declaration(self) -> OpStmt: - """Parse an operation definition - - An operation is written `op name(Type) -> Type` - - Returns: - OpStmt: the parsed operation statement - """ - keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword") - - name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") - self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") - operand: Type = self.type_expr() - self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type") - - self.consume(TokenType.ARROW, "Expected '->' before result type") - result: Type = self.type_expr() - - return OpStmt( - location=keyword.location_to(self.previous()), + return ExtendStmt( + location=location, name=name, - operand=operand, - result=result, + params=params, + members=members, ) def predicate_declaration(self) -> PredicateStmt: @@ -430,9 +453,9 @@ class MidasParser(Parser): PredicateStmt: the parsed predicate declaration statement """ keyword: Token = self.previous() - name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name") + name: Token = self.consume_identifier("Expected predicate name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") - subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") + subject: Token = self.consume_identifier("Expected subject name") self.consume(TokenType.COLON, "Expected ':' after subject name") type: Type = self.type_expr() self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") @@ -445,3 +468,72 @@ class MidasParser(Parser): type=type, condition=condition, ) + + def function(self) -> FunctionType: + l_paren: Token = self.consume( + TokenType.LEFT_PAREN, "Expected '(' before function parameters" + ) 
+ pos_args: list[FunctionType.Argument] = [] + args: list[FunctionType.Argument] = [] + kw_args: list[FunctionType.Argument] = [] + + args_first_tokens: list[Token] = [] + + section: int = 0 + while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN): + match section: + case 0 if self.match(TokenType.SLASH): + pos_args = args + args = [] + args_first_tokens = [] + section = 1 + case 0 | 1 if self.match(TokenType.STAR): + section = 2 + case _: + # Record first token of mixed argument for errors if unnamed + if section != 2: + args_first_tokens.append(self.peek()) + + name: Optional[Token] = None + if section == 2: + name = self.consume_identifier("Expected keyword argument name") + self.consume( + TokenType.COLON, "Expected ':' after argument name" + ) + elif self.check_identifier() and self.check_next(TokenType.COLON): + name = self.advance() + self.advance() + + type: Type = self.type_expr() + optional: bool = self.match(TokenType.QMARK) + arg = FunctionType.Argument( + location=None, + name=name, + type=type, + required=not optional, + ) + if section == 2: + kw_args.append(arg) + else: + args.append(arg) + + if not self.match(TokenType.COMMA): + break + + for arg, token in zip(args, args_first_tokens): + if arg.name is None: + # Not raised because we can keep parsing + self.error(token, "Unnamed mixed argument") + + self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") + + self.consume(TokenType.ARROW, "Expected '->' before result type") + result: Type = self.type_expr() + + return FunctionType( + location=l_paren.location_to(self.previous()), + pos_args=pos_args, + args=args, + kw_args=kw_args, + returns=result, + ) diff --git a/midas/parser/python.py b/midas/parser/python.py index 79011bc..a0726da 100644 --- a/midas/parser/python.py +++ b/midas/parser/python.py @@ -17,11 +17,14 @@ from midas.ast.python import ( Function, GetExpr, IfStmt, + ListExpr, LiteralExpr, LogicalExpr, MidasType, ReturnStmt, + SliceExpr, Stmt, + 
SubscriptExpr, TernaryExpr, TypeAssign, UnaryExpr, @@ -416,6 +419,27 @@ class PythonParser: case ast.Name(id=name): return VariableExpr(location=location, name=name) + case ast.List(elts=items): + return ListExpr( + location=location, + items=[self.parse_expr(item) for item in items], + ) + + case ast.Subscript(value=value, slice=index): + return SubscriptExpr( + location=location, + object=self.parse_expr(value), + index=self.parse_expr(index), + ) + + case ast.Slice(lower=lower, upper=upper, step=step): + return SliceExpr( + location=location, + lower=self.parse_expr(lower) if lower is not None else None, + upper=self.parse_expr(upper) if upper is not None else None, + step=self.parse_expr(step) if step is not None else None, + ) + case _: raise UnsupportedSyntaxError(node) diff --git a/midas/resolver/builtin.py b/midas/resolver/builtin.py deleted file mode 100644 index 04bc6e3..0000000 --- a/midas/resolver/builtin.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from midas.checker.types import BaseType, Type, UnitType - -if TYPE_CHECKING: - from midas.resolver.midas import MidasResolver - - -def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type): - ctx.define_operation( - left=t1, - operator=operator, - right=t2, - result=t3, - ) - - -def basic_op(ctx: MidasResolver, type: Type, op: str): - ctx.define_operation( - left=type, - operator=op, - right=type, - result=type, - ) - - -def define_builtins(ctx: MidasResolver): - """Define builtin types and operations""" - unit = ctx.define_type("None", UnitType()) - bool = ctx.define_type("bool", BaseType(name="bool")) - int = ctx.define_type("int", BaseType(name="int")) - float = ctx.define_type("float", BaseType(name="float")) - str = ctx.define_type("str", BaseType(name="str")) - - basic_op(ctx, int, "__add__") # int + int = int - basic_op(ctx, int, "__sub__") # int - int = int - basic_op(ctx, int, "__mul__") # int * int = int - basic_op(ctx, 
int, "__pow__") # int ** int = int - basic_op(ctx, int, "__mod__") # int % int = int - basic_op(ctx, int, "__and__") # int & int = int - basic_op(ctx, int, "__or__") # int | int = int - basic_op(ctx, int, "__xor__") # int ^ int = int - op(ctx, int, "__lt__", int, bool) # int < int = bool - op(ctx, int, "__gt__", int, bool) # int > int = bool - op(ctx, int, "__le__", int, bool) # int <= int = bool - op(ctx, int, "__ge__", int, bool) # int >= int = bool - op(ctx, int, "__eq__", int, bool) # int == int = bool - basic_op(ctx, float, "__add__") # float + float = float - basic_op(ctx, float, "__sub__") # float - float = float - basic_op(ctx, float, "__mul__") # float * float = float - basic_op(ctx, float, "__truediv__") # float / float = float - op(ctx, float, "__lt__", float, bool) # float < float = bool - op(ctx, float, "__gt__", float, bool) # float > float = bool - op(ctx, float, "__le__", float, bool) # float <= float = bool - op(ctx, float, "__ge__", float, bool) # float >= float = bool - op(ctx, float, "__eq__", float, bool) # float == float = bool - basic_op(ctx, str, "__add__") # str + str = str - op(ctx, str, "__eq__", str, bool) # str == str = bool - - op(ctx, int, "__lt__", float, bool) # int < float = bool - op(ctx, int, "__gt__", float, bool) # int > float = bool - op(ctx, int, "__le__", float, bool) # int <= float = bool - op(ctx, int, "__ge__", float, bool) # int >= float = bool - op(ctx, int, "__eq__", float, bool) # int == float = bool - - op(ctx, float, "__lt__", int, bool) # float < int = bool - op(ctx, float, "__gt__", int, bool) # float > int = bool - op(ctx, float, "__le__", int, bool) # float <= int = bool - op(ctx, float, "__ge__", int, bool) # float >= int = bool - op(ctx, float, "__eq__", int, bool) # float == int = bool diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py deleted file mode 100644 index 468f59a..0000000 --- a/midas/resolver/midas.py +++ /dev/null @@ -1,186 +0,0 @@ -from typing import Optional - -import midas.ast.midas 
as m -from midas.checker.types import ( - AliasType, - ComplexType, - Operation, - Type, - UnknownType, -) -from midas.resolver.builtin import define_builtins - - -class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]): - """A resolver which evaluates Midas type definitions and build a registry""" - - def __init__(self) -> None: - self._types: dict[str, Type] = {} - self._operations: dict[Operation.CallSignature, Type] = {} - - define_builtins(self) - - def get_type(self, name: str) -> Type: - """Get a type from its name - - Args: - name (str): the name of the type - - Raises: - NameError: if the type is not defined - - Returns: - Type: the type - """ - type: Optional[Type] = self._types.get(name) - if type is None: - raise NameError(f"Undefined type {name}") - return type - - def get_operation_result( - self, left: Type, operator: str, right: Type - ) -> Optional[Type]: - """Get the resulting type of an operation - - Args: - left (Type): the type of the left operand - operator (str): the operation name - right (Type): the type of the right operand - - Returns: - Optional[Type]: the result type, or None if no matching operation was found - """ - signature: Operation.CallSignature = Operation.CallSignature( - left=left, - method=operator, - right=right, - ) - result: Optional[Type] = self._operations.get(signature) - return result - - def get_operations_by_name(self, name: str) -> list[Operation]: - operations: list[Operation] = [] - for signature, result in self._operations.items(): - if signature.method == name: - operations.append( - Operation( - signature=signature, - result=result, - ) - ) - return operations - - def define_type(self, name: str, type: Type) -> Type: - """Define a type in the registry - - Args: - name (str): the name of the type - type (Type): the type to define - - Raises: - ValueError: if a type is already defined with that name - - Returns: - Type: the defined type - """ - if name in self._types: - raise 
ValueError(f"Type {name} already defined") - self._types[name] = type - return type - - def define_operation(self, left: Type, operator: str, right: Type, result: Type): - """Define an operation in the registry - - Args: - left (Type): the type of the left operand - operator (str): the operation name - right (Type): the type of the right operand - result (Type): the result type - - Raises: - ValueError: if an operation is already defined with these operands and name - """ - signature: Operation.CallSignature = Operation.CallSignature( - left=left, - method=operator, - right=right, - ) - if signature in self._operations: - raise ValueError( - f"Operation {operator} already defined between {left} and {right}" - ) - self._operations[signature] = result - - def resolve(self, stmts: list[m.Stmt]): - """Process a sequence of statements - - Args: - stmts (list[m.Stmt]): the statements - """ - for stmt in stmts: - stmt.accept(self) - - def visit_type_stmt(self, stmt: m.TypeStmt) -> None: - type: Type = stmt.type.accept(self) - for param in stmt.params: - if param.bound is not None: - param.bound.accept(self) - name: str = stmt.name.lexeme - self.define_type(name, AliasType(name=name, type=type)) - - def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... - - def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: - base: Type = stmt.type.accept(self) - for op in stmt.operations: - right: Type = op.operand.accept(self) - result: Type = op.result.accept(self) - self.define_operation( - left=base, - operator=op.name.lexeme, - right=right, - result=result, - ) - - def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... - - def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... - - def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ... - - def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ... - - def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ... - - def visit_get_expr(self, expr: m.GetExpr) -> None: ... 
- - def visit_variable_expr(self, expr: m.VariableExpr) -> None: ... - - def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: - return expr.expr.accept(self) - - def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... - - def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... - - def visit_named_type(self, type: m.NamedType) -> Type: - return self.get_type(type.name.lexeme) - - def visit_generic_type(self, type: m.GenericType) -> Type: - type_: Type = type.type.accept(self) - params: list[Type] = [param.accept(self) for param in type.params] - # TODO - return UnknownType() - - def visit_constraint_type(self, type: m.ConstraintType) -> Type: - type_: Type = type.type.accept(self) - type.constraint.accept(self) - # TODO - return UnknownType() - - def visit_complex_type(self, type: m.ComplexType) -> Type: - return ComplexType( - properties={ - prop.name.lexeme: prop.type.accept(self) for prop in type.properties - } - ) diff --git a/tests/cases/checker/01_simple_types.py.ref.json b/tests/cases/checker/01_simple_types.py.ref.json index 3c4d0b9..ac24fcd 100644 --- a/tests/cases/checker/01_simple_types.py.ref.json +++ b/tests/cases/checker/01_simple_types.py.ref.json @@ -1,4 +1,19 @@ { - "diagnostics": [], + "diagnostics": [ + { + "type": "Warning", + "location": { + "start": [ + 6, + 4 + ], + "end": [ + 13, + 5 + ] + }, + "message": "FrameType not yet supported" + } + ], "judgments": [] } \ No newline at end of file diff --git a/tests/cases/checker/02_simple_operations.py.ref.json b/tests/cases/checker/02_simple_operations.py.ref.json index 654af17..a2c5569 100644 --- a/tests/cases/checker/02_simple_operations.py.ref.json +++ b/tests/cases/checker/02_simple_operations.py.ref.json @@ -12,7 +12,21 @@ 13 ] }, - "message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')" + "message": "Cannot assign str to variable 'c' of type int" + }, + { + "type": "Error", + "location": { + "start": [ + 9, + 4 + ], + "end": [ + 9, + 9 + ] + }, 
+ "message": "Undefined operation __add__ between bool and bool" } ], "judgments": [ @@ -158,9 +172,7 @@ "name": "d" } }, - "type": { - "name": "int" - } + "type": {} }, { "location": { diff --git a/tests/cases/checker/03_functions.py.ref.json b/tests/cases/checker/03_functions.py.ref.json index cd0ce42..fa06642 100644 --- a/tests/cases/checker/03_functions.py.ref.json +++ b/tests/cases/checker/03_functions.py.ref.json @@ -236,7 +236,7 @@ 13 ] }, - "message": "Wrong type for argument 'a', expected BaseType(name='int'), got BaseType(name='str')" + "message": "Wrong type for argument 'a', expected int, got str" }, { "type": "Error", @@ -250,10 +250,23 @@ 25 ] }, - "message": "Wrong type for argument 'c', expected BaseType(name='str'), got BaseType(name='bool')" + "message": "Wrong type for argument 'c', expected str, got bool" } ], "judgments": [ + { + "location": { + "from": "L2:11", + "to": "L2:15" + }, + "expr": { + "_type": "LiteralExpr", + "value": true + }, + "type": { + "name": "bool" + } + }, { "location": { "from": "L5:5", @@ -264,7 +277,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -314,9 +326,7 @@ "arguments": [], "keywords": {} }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -328,7 +338,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -396,9 +405,7 @@ ], "keywords": {} }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -410,7 +417,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -495,9 +501,7 @@ ], "keywords": {} }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -509,7 +513,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -595,9 +598,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -609,7 +610,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -711,9 +711,7 @@ ], "keywords": {} }, - "type": { - "name": "bool" - } + 
"type": {} }, { "location": { @@ -725,7 +723,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -828,9 +825,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -842,7 +837,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -910,9 +904,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -924,7 +916,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -992,9 +983,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} }, { "location": { @@ -1006,7 +995,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -1123,7 +1111,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -1240,7 +1227,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -1357,7 +1343,6 @@ "name": "foo" }, "type": { - "name": "foo", "pos_args": [ { "pos": 0, @@ -1460,9 +1445,7 @@ } } }, - "type": { - "name": "bool" - } + "type": {} } ] } \ No newline at end of file diff --git a/tests/cases/checker/04_custom_types.midas b/tests/cases/checker/04_custom_types.midas index 6a1a6a2..ff4edb1 100644 --- a/tests/cases/checker/04_custom_types.midas +++ b/tests/cases/checker/04_custom_types.midas @@ -3,12 +3,12 @@ type Second = float type MeterPerSecond = float extend Meter { - op __add__(Meter) -> Meter - op __sub__(Meter) -> Meter - op __truediv__(Second) -> MeterPerSecond + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter + def __truediv__: fn(Second, /) -> MeterPerSecond } extend Second { - op __add__(Second) -> Second - op __sub__(Second) -> Second + def __add__: fn(Second, /) -> Second + def __sub__: fn(Second, /) -> Second } diff --git a/tests/cases/checker/05_control_flow.py.ref.json b/tests/cases/checker/05_control_flow.py.ref.json index 8f031f2..be86030 100644 --- a/tests/cases/checker/05_control_flow.py.ref.json +++ b/tests/cases/checker/05_control_flow.py.ref.json @@ 
-70,6 +70,27 @@ "name": "int" } }, + { + "location": { + "from": "L2:11", + "to": "L2:16" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "+", + "right": { + "_type": "VariableExpr", + "name": "b" + } + }, + "type": { + "name": "int" + } + }, { "location": { "from": "L5:7", @@ -96,6 +117,27 @@ "name": "int" } }, + { + "location": { + "from": "L5:7", + "to": "L5:12" + }, + "expr": { + "_type": "CompareExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "<", + "right": { + "_type": "VariableExpr", + "name": "b" + } + }, + "type": { + "name": "bool" + } + }, { "location": { "from": "L6:15", @@ -122,6 +164,27 @@ "name": "int" } }, + { + "location": { + "from": "L6:15", + "to": "L6:20" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "b" + }, + "operator": "-", + "right": { + "_type": "VariableExpr", + "name": "a" + } + }, + "type": { + "name": "int" + } + }, { "location": { "from": "L8:15", @@ -148,6 +211,27 @@ "name": "int" } }, + { + "location": { + "from": "L8:15", + "to": "L8:20" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "-", + "right": { + "_type": "VariableExpr", + "name": "b" + } + }, + "type": { + "name": "int" + } + }, { "location": { "from": "L15:7", @@ -174,6 +258,27 @@ "name": "int" } }, + { + "location": { + "from": "L15:7", + "to": "L15:13" + }, + "expr": { + "_type": "CompareExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": ">", + "right": { + "_type": "LiteralExpr", + "value": 10 + } + }, + "type": { + "name": "bool" + } + }, { "location": { "from": "L16:15", @@ -200,6 +305,40 @@ "name": "int" } }, + { + "location": { + "from": "L16:15", + "to": "L16:21" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "-", + "right": { + "_type": "LiteralExpr", + "value": 
10 + } + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L18:15", + "to": "L18:16" + }, + "expr": { + "_type": "VariableExpr", + "name": "a" + }, + "type": { + "name": "int" + } + }, { "location": { "from": "L22:7", @@ -226,6 +365,27 @@ "name": "int" } }, + { + "location": { + "from": "L22:7", + "to": "L22:12" + }, + "expr": { + "_type": "CompareExpr", + "left": { + "_type": "VariableExpr", + "name": "a" + }, + "operator": "<", + "right": { + "_type": "VariableExpr", + "name": "b" + } + }, + "type": { + "name": "bool" + } + }, { "location": { "from": "L23:15", @@ -251,6 +411,40 @@ "type": { "name": "int" } + }, + { + "location": { + "from": "L23:15", + "to": "L23:20" + }, + "expr": { + "_type": "BinaryExpr", + "left": { + "_type": "VariableExpr", + "name": "b" + }, + "operator": "-", + "right": { + "_type": "VariableExpr", + "name": "a" + } + }, + "type": { + "name": "int" + } + }, + { + "location": { + "from": "L25:15", + "to": "L25:21" + }, + "expr": { + "_type": "LiteralExpr", + "value": "oops" + }, + "type": { + "name": "str" + } } ] } \ No newline at end of file diff --git a/tests/cases/checker/06_subtyping.py b/tests/cases/checker/06_subtyping.py index c334ab8..7ab9dd7 100644 --- a/tests/cases/checker/06_subtyping.py +++ b/tests/cases/checker/06_subtyping.py @@ -9,4 +9,4 @@ def maximum(a: float, b: float): v3 = maximum(v1, v2) -v3 = v1 + v2 +v3 = v2 + v1 diff --git a/tests/cases/checker/06_subtyping.py.ref.json b/tests/cases/checker/06_subtyping.py.ref.json index 689402e..3435f45 100644 --- a/tests/cases/checker/06_subtyping.py.ref.json +++ b/tests/cases/checker/06_subtyping.py.ref.json @@ -53,6 +53,53 @@ "name": "float" } }, + { + "location": { + "from": "L6:7", + "to": "L6:12" + }, + "expr": { + "_type": "CompareExpr", + "left": { + "_type": "VariableExpr", + "name": "b" + }, + "operator": ">", + "right": { + "_type": "VariableExpr", + "name": "a" + } + }, + "type": { + "name": "bool" + } + }, + { + "location": { + "from": "L7:15", + 
"to": "L7:16" + }, + "expr": { + "_type": "VariableExpr", + "name": "b" + }, + "type": { + "name": "float" + } + }, + { + "location": { + "from": "L8:11", + "to": "L8:12" + }, + "expr": { + "_type": "VariableExpr", + "name": "a" + }, + "type": { + "name": "float" + } + }, { "location": { "from": "L11:5", @@ -63,7 +110,6 @@ "name": "maximum" }, "type": { - "name": "maximum", "pos_args": [], "args": [ { @@ -149,10 +195,10 @@ }, "expr": { "_type": "VariableExpr", - "name": "v1" + "name": "v2" }, "type": { - "name": "int" + "name": "float" } }, { @@ -162,10 +208,10 @@ }, "expr": { "_type": "VariableExpr", - "name": "v2" + "name": "v1" }, "type": { - "name": "float" + "name": "int" } }, { @@ -177,12 +223,12 @@ "_type": "BinaryExpr", "left": { "_type": "VariableExpr", - "name": "v1" + "name": "v2" }, "operator": "+", "right": { "_type": "VariableExpr", - "name": "v2" + "name": "v1" } }, "type": { diff --git a/tests/cases/midas-parser/01_simple_types.midas b/tests/cases/midas-parser/01_simple_types.midas index 6446790..f0df3e2 100644 --- a/tests/cases/midas-parser/01_simple_types.midas +++ b/tests/cases/midas-parser/01_simple_types.midas @@ -10,8 +10,8 @@ type Difference[T] = T // Complex custom type, containing two values accessible through properties type GeoLocation = { - lat: Latitude - lon: Longitude + prop lat: Latitude + prop lon: Longitude } // Define operations on our custom type @@ -19,23 +19,23 @@ extend GeoLocation { // This type is compatible with the `-` operation with another GeoLocation // i.e. 
you can subtract a GeoLocation from another GeoLocation, resulting // in a Difference of GeoLocations - op __sub__(GeoLocation) -> Difference[GeoLocation] + def __sub__: fn(GeoLocation, /) -> Difference[GeoLocation] } // For complex generics, you need to specify how the genericity the properties // are handled type Difference[GeoLocation] = { - lat: Difference[Latitude] - lon: Difference[Longitude] + prop lat: Difference[Latitude] + prop lon: Difference[Longitude] } // Simple operation defined on our custom types extend Latitude { - op __sub__(Latitude) -> Difference[Latitude] + def __sub__: fn(Latitude, /) -> Difference[Latitude] } extend Longitude { - op __sub__(Longitude) -> Difference[Longitude] + def __sub__: fn(Longitude, /) -> Difference[Longitude] } // Predefined custom predicates that can be referenced in other definitions @@ -45,13 +45,13 @@ predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10) predicate Arctic(loc: GeoLocation) = (loc.lat >= 66) type Person = { - name: str + prop name: str // Property with an inline constraint - age: Optional[int where (0 <= _ < 150)] + prop age: Optional[int where (0 <= _ < 150)] // Property referencing a predicate - height: float where StrictlyPositive + prop height: float where StrictlyPositive - home: GeoLocation + prop home: GeoLocation } diff --git a/tests/cases/midas-parser/01_simple_types.midas.ref.json b/tests/cases/midas-parser/01_simple_types.midas.ref.json index 55b4813..be45687 100644 --- a/tests/cases/midas-parser/01_simple_types.midas.ref.json +++ b/tests/cases/midas-parser/01_simple_types.midas.ref.json @@ -511,17 +511,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "lat", + "type": "PROP", + "lexeme": "prop", "line": 13, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 13, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -530,15 +524,33 @@ }, { "type": "IDENTIFIER", - "lexeme": "Latitude", + "lexeme": "lat", "line": 13, "column": 10 }, + { + "type": 
"COLON", + "lexeme": ":", + "line": 13, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 13, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Latitude", + "line": 13, + "column": 15 + }, { "type": "NEWLINE", "lexeme": "\n", "line": 13, - "column": 18 + "column": 23 }, { "type": "WHITESPACE", @@ -547,17 +559,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "lon", + "type": "PROP", + "lexeme": "prop", "line": 14, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 14, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -566,15 +572,33 @@ }, { "type": "IDENTIFIER", - "lexeme": "Longitude", + "lexeme": "lon", "line": 14, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 14, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 14, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Longitude", + "line": 14, + "column": 15 + }, { "type": "NEWLINE", "lexeme": "\n", "line": 14, - "column": 19 + "column": 24 }, { "type": "RIGHT_BRACE", @@ -703,8 +727,8 @@ "column": 1 }, { - "type": "OP", - "lexeme": "op", + "type": "DEF", + "lexeme": "def", "line": 22, "column": 5 }, @@ -712,79 +736,115 @@ "type": "WHITESPACE", "lexeme": " ", "line": 22, - "column": 7 + "column": 8 }, { "type": "IDENTIFIER", "lexeme": "__sub__", "line": 22, - "column": 8 + "column": 9 + }, + { + "type": "COLON", + "lexeme": ":", + "line": 22, + "column": 16 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 22, + "column": 17 + }, + { + "type": "FUNC", + "lexeme": "fn", + "line": 22, + "column": 18 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 22, - "column": 15 + "column": 20 }, { "type": "IDENTIFIER", "lexeme": "GeoLocation", "line": 22, - "column": 16 + "column": 21 + }, + { + "type": "COMMA", + "lexeme": ",", + "line": 22, + "column": 32 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 22, + "column": 33 + }, + { + "type": "SLASH", + "lexeme": "/", + "line": 22, + 
"column": 34 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 22, - "column": 27 + "column": 35 }, { "type": "WHITESPACE", "lexeme": " ", "line": 22, - "column": 28 + "column": 36 }, { "type": "ARROW", "lexeme": "->", "line": 22, - "column": 29 + "column": 37 }, { "type": "WHITESPACE", "lexeme": " ", "line": 22, - "column": 31 + "column": 39 }, { "type": "IDENTIFIER", "lexeme": "Difference", "line": 22, - "column": 32 + "column": 40 }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 22, - "column": 42 + "column": 50 }, { "type": "IDENTIFIER", "lexeme": "GeoLocation", "line": 22, - "column": 43 + "column": 51 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 22, - "column": 54 + "column": 62 }, { "type": "NEWLINE", "lexeme": "\n", "line": 22, - "column": 55 + "column": 63 }, { "type": "RIGHT_BRACE", @@ -901,17 +961,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "lat", + "type": "PROP", + "lexeme": "prop", "line": 28, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 28, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -920,33 +974,51 @@ }, { "type": "IDENTIFIER", - "lexeme": "Difference", + "lexeme": "lat", "line": 28, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 28, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 28, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Difference", + "line": 28, + "column": 15 + }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 28, - "column": 20 + "column": 25 }, { "type": "IDENTIFIER", "lexeme": "Latitude", "line": 28, - "column": 21 + "column": 26 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 28, - "column": 29 + "column": 34 }, { "type": "NEWLINE", "lexeme": "\n", "line": 28, - "column": 30 + "column": 35 }, { "type": "WHITESPACE", @@ -955,17 +1027,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "lon", + "type": "PROP", + "lexeme": "prop", "line": 29, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - 
"line": 29, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -974,33 +1040,51 @@ }, { "type": "IDENTIFIER", - "lexeme": "Difference", + "lexeme": "lon", "line": 29, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 29, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 29, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Difference", + "line": 29, + "column": 15 + }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 29, - "column": 20 + "column": 25 }, { "type": "IDENTIFIER", "lexeme": "Longitude", "line": 29, - "column": 21 + "column": 26 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 29, - "column": 30 + "column": 35 }, { "type": "NEWLINE", "lexeme": "\n", "line": 29, - "column": 31 + "column": 36 }, { "type": "RIGHT_BRACE", @@ -1075,8 +1159,8 @@ "column": 1 }, { - "type": "OP", - "lexeme": "op", + "type": "DEF", + "lexeme": "def", "line": 34, "column": 5 }, @@ -1084,79 +1168,115 @@ "type": "WHITESPACE", "lexeme": " ", "line": 34, - "column": 7 + "column": 8 }, { "type": "IDENTIFIER", "lexeme": "__sub__", "line": 34, - "column": 8 + "column": 9 + }, + { + "type": "COLON", + "lexeme": ":", + "line": 34, + "column": 16 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 34, + "column": 17 + }, + { + "type": "FUNC", + "lexeme": "fn", + "line": 34, + "column": 18 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 34, - "column": 15 + "column": 20 }, { "type": "IDENTIFIER", "lexeme": "Latitude", "line": 34, - "column": 16 + "column": 21 + }, + { + "type": "COMMA", + "lexeme": ",", + "line": 34, + "column": 29 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 34, + "column": 30 + }, + { + "type": "SLASH", + "lexeme": "/", + "line": 34, + "column": 31 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 34, - "column": 24 + "column": 32 }, { "type": "WHITESPACE", "lexeme": " ", "line": 34, - "column": 25 + "column": 33 }, { "type": "ARROW", "lexeme": "->", "line": 34, - "column": 26 + 
"column": 34 }, { "type": "WHITESPACE", "lexeme": " ", "line": 34, - "column": 28 + "column": 36 }, { "type": "IDENTIFIER", "lexeme": "Difference", "line": 34, - "column": 29 + "column": 37 }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 34, - "column": 39 + "column": 47 }, { "type": "IDENTIFIER", "lexeme": "Latitude", "line": 34, - "column": 40 + "column": 48 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 34, - "column": 48 + "column": 56 }, { "type": "NEWLINE", "lexeme": "\n", "line": 34, - "column": 49 + "column": 57 }, { "type": "RIGHT_BRACE", @@ -1219,8 +1339,8 @@ "column": 1 }, { - "type": "OP", - "lexeme": "op", + "type": "DEF", + "lexeme": "def", "line": 38, "column": 5 }, @@ -1228,79 +1348,115 @@ "type": "WHITESPACE", "lexeme": " ", "line": 38, - "column": 7 + "column": 8 }, { "type": "IDENTIFIER", "lexeme": "__sub__", "line": 38, - "column": 8 + "column": 9 + }, + { + "type": "COLON", + "lexeme": ":", + "line": 38, + "column": 16 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 38, + "column": 17 + }, + { + "type": "FUNC", + "lexeme": "fn", + "line": 38, + "column": 18 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 38, - "column": 15 + "column": 20 }, { "type": "IDENTIFIER", "lexeme": "Longitude", "line": 38, - "column": 16 + "column": 21 + }, + { + "type": "COMMA", + "lexeme": ",", + "line": 38, + "column": 30 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 38, + "column": 31 + }, + { + "type": "SLASH", + "lexeme": "/", + "line": 38, + "column": 32 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 38, - "column": 25 + "column": 33 }, { "type": "WHITESPACE", "lexeme": " ", "line": 38, - "column": 26 + "column": 34 }, { "type": "ARROW", "lexeme": "->", "line": 38, - "column": 27 + "column": 35 }, { "type": "WHITESPACE", "lexeme": " ", "line": 38, - "column": 29 + "column": 37 }, { "type": "IDENTIFIER", "lexeme": "Difference", "line": 38, - "column": 30 + "column": 38 }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 
38, - "column": 40 + "column": 48 }, { "type": "IDENTIFIER", "lexeme": "Longitude", "line": 38, - "column": 41 + "column": 49 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 38, - "column": 50 + "column": 58 }, { "type": "NEWLINE", "lexeme": "\n", "line": 38, - "column": 51 + "column": 59 }, { "type": "RIGHT_BRACE", @@ -1903,34 +2059,46 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "name", + "type": "PROP", + "lexeme": "prop", "line": 48, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 48, - "column": 9 - }, { "type": "WHITESPACE", "lexeme": " ", "line": 48, + "column": 9 + }, + { + "type": "IDENTIFIER", + "lexeme": "name", + "line": 48, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 48, + "column": 14 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 48, + "column": 15 + }, { "type": "IDENTIFIER", "lexeme": "str", "line": 48, - "column": 11 + "column": 16 }, { "type": "NEWLINE", "lexeme": "\n", "line": 48, - "column": 14 + "column": 19 }, { "type": "NEWLINE", @@ -1963,17 +2131,11 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "age", + "type": "PROP", + "lexeme": "prop", "line": 51, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 51, - "column": 8 - }, { "type": "WHITESPACE", "lexeme": " ", @@ -1982,74 +2144,68 @@ }, { "type": "IDENTIFIER", - "lexeme": "Optional", + "lexeme": "age", "line": 51, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 51, + "column": 13 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + "column": 14 + }, + { + "type": "IDENTIFIER", + "lexeme": "Optional", + "line": 51, + "column": 15 + }, { "type": "LEFT_BRACKET", "lexeme": "[", "line": 51, - "column": 18 + "column": 23 }, { "type": "IDENTIFIER", "lexeme": "int", "line": 51, - "column": 19 + "column": 24 }, { "type": "WHITESPACE", "lexeme": " ", "line": 51, - "column": 22 + "column": 27 }, { "type": "WHERE", "lexeme": "where", "line": 51, - "column": 23 + "column": 28 }, { 
"type": "WHITESPACE", "lexeme": " ", "line": 51, - "column": 28 + "column": 33 }, { "type": "LEFT_PAREN", "lexeme": "(", "line": 51, - "column": 29 + "column": 34 }, { "type": "NUMBER", "lexeme": "0", "line": 51, - "column": 30 - }, - { - "type": "WHITESPACE", - "lexeme": " ", - "line": 51, - "column": 31 - }, - { - "type": "LESS_EQUAL", - "lexeme": "<=", - "line": 51, - "column": 32 - }, - { - "type": "WHITESPACE", - "lexeme": " ", - "line": 51, - "column": 34 - }, - { - "type": "UNDERSCORE", - "lexeme": "_", - "line": 51, "column": 35 }, { @@ -2059,8 +2215,8 @@ "column": 36 }, { - "type": "LESS", - "lexeme": "<", + "type": "LESS_EQUAL", + "lexeme": "<=", "line": 51, "column": 37 }, @@ -2068,31 +2224,55 @@ "type": "WHITESPACE", "lexeme": " ", "line": 51, - "column": 38 + "column": 39 + }, + { + "type": "UNDERSCORE", + "lexeme": "_", + "line": 51, + "column": 40 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + "column": 41 + }, + { + "type": "LESS", + "lexeme": "<", + "line": 51, + "column": 42 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 51, + "column": 43 }, { "type": "NUMBER", "lexeme": "150", "line": 51, - "column": 39 + "column": 44 }, { "type": "RIGHT_PAREN", "lexeme": ")", "line": 51, - "column": 42 + "column": 47 }, { "type": "RIGHT_BRACKET", "lexeme": "]", "line": 51, - "column": 43 + "column": 48 }, { "type": "NEWLINE", "lexeme": "\n", "line": 51, - "column": 44 + "column": 49 }, { "type": "NEWLINE", @@ -2124,59 +2304,71 @@ "line": 54, "column": 1 }, + { + "type": "PROP", + "lexeme": "prop", + "line": 54, + "column": 5 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 54, + "column": 9 + }, { "type": "IDENTIFIER", "lexeme": "height", "line": 54, - "column": 5 + "column": 10 }, { "type": "COLON", "lexeme": ":", "line": 54, - "column": 11 + "column": 16 }, { "type": "WHITESPACE", "lexeme": " ", "line": 54, - "column": 12 + "column": 17 }, { "type": "IDENTIFIER", "lexeme": "float", "line": 54, - "column": 13 + 
"column": 18 }, { "type": "WHITESPACE", "lexeme": " ", "line": 54, - "column": 18 + "column": 23 }, { "type": "WHERE", "lexeme": "where", "line": 54, - "column": 19 + "column": 24 }, { "type": "WHITESPACE", "lexeme": " ", "line": 54, - "column": 24 + "column": 29 }, { "type": "IDENTIFIER", "lexeme": "StrictlyPositive", "line": 54, - "column": 25 + "column": 30 }, { "type": "NEWLINE", "lexeme": "\n", "line": 54, - "column": 41 + "column": 46 }, { "type": "NEWLINE", @@ -2191,34 +2383,46 @@ "column": 1 }, { - "type": "IDENTIFIER", - "lexeme": "home", + "type": "PROP", + "lexeme": "prop", "line": 56, "column": 5 }, - { - "type": "COLON", - "lexeme": ":", - "line": 56, - "column": 9 - }, { "type": "WHITESPACE", "lexeme": " ", "line": 56, + "column": 9 + }, + { + "type": "IDENTIFIER", + "lexeme": "home", + "line": 56, "column": 10 }, + { + "type": "COLON", + "lexeme": ":", + "line": 56, + "column": 14 + }, + { + "type": "WHITESPACE", + "lexeme": " ", + "line": 56, + "column": 15 + }, { "type": "IDENTIFIER", "lexeme": "GeoLocation", "line": 56, - "column": 11 + "column": 16 }, { "type": "NEWLINE", "lexeme": "\n", "line": 56, - "column": 22 + "column": 27 }, { "type": "RIGHT_BRACE", @@ -2345,9 +2549,10 @@ "params": [], "type": { "_type": "ComplexType", - "properties": [ + "members": [ { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "lat", "type": { "_type": "NamedType", @@ -2355,7 +2560,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "lon", "type": { "_type": "NamedType", @@ -2367,30 +2573,40 @@ }, { "_type": "ExtendStmt", - "type": { - "_type": "NamedType", - "name": "GeoLocation" - }, - "operations": [ + "name": "GeoLocation", + "params": [], + "members": [ { - "_type": "OpStmt", + "_type": "MemberStmt", + "kind": "METHOD", "name": "__sub__", - "operand": { - "_type": "NamedType", - "name": "GeoLocation" - }, - "result": { - "_type": "GenericType", - "type": { - "_type": "NamedType", - 
"name": "Difference" - }, - "params": [ + "type": { + "_type": "FunctionType", + "pos_args": [ { - "_type": "NamedType", - "name": "GeoLocation" + "name": null, + "type": { + "_type": "NamedType", + "name": "GeoLocation" + }, + "required": true } - ] + ], + "args": [], + "kw_args": [], + "returns": { + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" + }, + "args": [ + { + "_type": "NamedType", + "name": "GeoLocation" + } + ] + } } } ] @@ -2406,9 +2622,10 @@ ], "type": { "_type": "ComplexType", - "properties": [ + "members": [ { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "lat", "type": { "_type": "GenericType", @@ -2416,7 +2633,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Latitude" @@ -2425,7 +2642,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "lon", "type": { "_type": "GenericType", @@ -2433,7 +2651,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Longitude" @@ -2446,60 +2664,80 @@ }, { "_type": "ExtendStmt", - "type": { - "_type": "NamedType", - "name": "Latitude" - }, - "operations": [ + "name": "Latitude", + "params": [], + "members": [ { - "_type": "OpStmt", + "_type": "MemberStmt", + "kind": "METHOD", "name": "__sub__", - "operand": { - "_type": "NamedType", - "name": "Latitude" - }, - "result": { - "_type": "GenericType", - "type": { - "_type": "NamedType", - "name": "Difference" - }, - "params": [ + "type": { + "_type": "FunctionType", + "pos_args": [ { - "_type": "NamedType", - "name": "Latitude" + "name": null, + "type": { + "_type": "NamedType", + "name": "Latitude" + }, + "required": true } - ] + ], + "args": [], + "kw_args": [], + "returns": { + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" + }, + "args": [ + { + "_type": "NamedType", + "name": "Latitude" + } + ] + } } } ] 
}, { "_type": "ExtendStmt", - "type": { - "_type": "NamedType", - "name": "Longitude" - }, - "operations": [ + "name": "Longitude", + "params": [], + "members": [ { - "_type": "OpStmt", + "_type": "MemberStmt", + "kind": "METHOD", "name": "__sub__", - "operand": { - "_type": "NamedType", - "name": "Longitude" - }, - "result": { - "_type": "GenericType", - "type": { - "_type": "NamedType", - "name": "Difference" - }, - "params": [ + "type": { + "_type": "FunctionType", + "pos_args": [ { - "_type": "NamedType", - "name": "Longitude" + "name": null, + "type": { + "_type": "NamedType", + "name": "Longitude" + }, + "required": true } - ] + ], + "args": [], + "kw_args": [], + "returns": { + "_type": "GenericType", + "type": { + "_type": "NamedType", + "name": "Difference" + }, + "args": [ + { + "_type": "NamedType", + "name": "Longitude" + } + ] + } } } ] @@ -2620,9 +2858,10 @@ "params": [], "type": { "_type": "ComplexType", - "properties": [ + "members": [ { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "name", "type": { "_type": "NamedType", @@ -2630,7 +2869,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "age", "type": { "_type": "GenericType", @@ -2638,7 +2878,7 @@ "_type": "NamedType", "name": "Optional" }, - "params": [ + "args": [ { "_type": "ConstraintType", "type": { @@ -2672,7 +2912,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "height", "type": { "_type": "ConstraintType", @@ -2687,7 +2928,8 @@ } }, { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": "PROPERTY", "name": "home", "type": { "_type": "NamedType", diff --git a/tests/cases/python-parser/02_custom_types.py.ref.json b/tests/cases/python-parser/02_custom_types.py.ref.json index 639610d..82c726c 100644 --- a/tests/cases/python-parser/02_custom_types.py.ref.json +++ b/tests/cases/python-parser/02_custom_types.py.ref.json @@ -18,6 +18,80 @@ ] } }, + { + "_type": 
"TypeAssign", + "name": "lat", + "type": { + "_type": "BaseType", + "base": "Column", + "param": { + "_type": "BaseType", + "base": "GeoLocation", + "param": null + } + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lat" + } + ], + "value": { + "_type": "GetExpr", + "object": { + "_type": "SubscriptExpr", + "object": { + "_type": "VariableExpr", + "name": "df" + }, + "index": { + "_type": "LiteralExpr", + "value": "location" + } + }, + "name": "lat" + } + }, + { + "_type": "TypeAssign", + "name": "lon", + "type": { + "_type": "BaseType", + "base": "Column", + "param": { + "_type": "BaseType", + "base": "GeoLocation", + "param": null + } + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lon" + } + ], + "value": { + "_type": "GetExpr", + "object": { + "_type": "SubscriptExpr", + "object": { + "_type": "VariableExpr", + "name": "df" + }, + "index": { + "_type": "LiteralExpr", + "value": "location" + } + }, + "name": "lon" + } + }, { "_type": "ExpressionStmt", "expr": { @@ -33,6 +107,64 @@ } } }, + { + "_type": "TypeAssign", + "name": "lat1", + "type": { + "_type": "BaseType", + "base": "Latitude", + "param": null + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lat1" + } + ], + "value": { + "_type": "SubscriptExpr", + "object": { + "_type": "VariableExpr", + "name": "lat" + }, + "index": { + "_type": "LiteralExpr", + "value": 0 + } + } + }, + { + "_type": "TypeAssign", + "name": "lat2", + "type": { + "_type": "BaseType", + "base": "Latitude", + "param": null + } + }, + { + "_type": "AssignStmt", + "targets": [ + { + "_type": "VariableExpr", + "name": "lat2" + } + ], + "value": { + "_type": "SubscriptExpr", + "object": { + "_type": "VariableExpr", + "name": "lat" + }, + "index": { + "_type": "LiteralExpr", + "value": 1 + } + } + }, { "_type": "TypeAssign", "name": "lat_diff", diff --git a/tests/checker.py b/tests/checker.py index 
27a94cb..3ceb34e 100644 --- a/tests/checker.py +++ b/tests/checker.py @@ -1,14 +1,11 @@ -import ast import json from dataclasses import asdict, dataclass, field from pathlib import Path import midas.ast.python as p -from midas.checker.checker import Checker +from midas.checker.checker import TypeChecker from midas.checker.diagnostic import Diagnostic from midas.checker.types import Type -from midas.parser.python import PythonParser -from midas.resolver.resolver import Resolver from tests.base import Tester from tests.serializer.python import PythonAstJsonSerializer @@ -36,24 +33,16 @@ class CheckerTester(Tester): if not path.is_file(): raise TypeError(f"Test '{path}' is not a file") - types_paths: list[Path] = [] + result: CaseResult = CaseResult() + + checker = TypeChecker() types_path: Path = path.with_suffix(".midas") if types_path.exists(): - types_paths.append(types_path) - source: str = path.read_text() - tree: ast.Module = ast.parse(source, filename=path) - parser = PythonParser() - stmts: list[p.Stmt] = parser.parse_module(tree) - resolver = Resolver() - resolver.resolve(*stmts) - result: CaseResult = CaseResult() - checker = Checker( - resolver.locals, - source_path=path, - types_paths=types_paths, - ) + checker.import_midas(types_path) - diagnostics: list[Diagnostic] = checker.check(stmts) + checker.type_check(path) + + diagnostics: list[Diagnostic] = checker.diagnostics for diagnostic in diagnostics: result.diagnostics.append( { @@ -72,7 +61,7 @@ class CheckerTester(Tester): } ) - judgements: list[tuple[p.Expr, Type]] = checker.judgements + judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements serializer = PythonAstJsonSerializer() for expr, type in judgements: loc = expr.location diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index 919dc66..8bffdb3 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -6,17 +6,19 @@ from midas.ast.midas import ( ConstraintType, Expr, ExtendStmt, + ExtensionType, 
+ FunctionType, GenericType, GetExpr, GroupingExpr, LiteralExpr, LogicalExpr, + MemberStmt, NamedType, - OpStmt, PredicateStmt, - PropertyStmt, Stmt, Type, + TypeParam, TypeStmt, UnaryExpr, VariableExpr, @@ -46,21 +48,20 @@ class MidasAstJsonSerializer( return { "_type": "TypeStmt", "name": stmt.name.lexeme, - "params": [ - self._serialize_type_stmt_template_param(param) for param in stmt.params - ], + "params": [self._serialize_type_param(param) for param in stmt.params], "type": stmt.type.accept(self), } - def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict: + def _serialize_type_param(self, param: TypeParam) -> dict: return { "name": param.name.lexeme, "bound": self._serialize_optional(param.bound), } - def visit_property_stmt(self, stmt: PropertyStmt) -> dict: + def visit_member_stmt(self, stmt: MemberStmt) -> dict: return { - "_type": "PropertyStmt", + "_type": "MemberStmt", + "kind": stmt.kind.name, "name": stmt.name.lexeme, "type": stmt.type.accept(self), } @@ -68,16 +69,9 @@ class MidasAstJsonSerializer( def visit_extend_stmt(self, stmt: ExtendStmt) -> dict: return { "_type": "ExtendStmt", - "type": stmt.type.accept(self), - "operations": self._serialize_list(stmt.operations), - } - - def visit_op_stmt(self, stmt: OpStmt) -> dict: - return { - "_type": "OpStmt", "name": stmt.name.lexeme, - "operand": stmt.operand.accept(self), - "result": stmt.result.accept(self), + "params": [self._serialize_type_param(param) for param in stmt.params], + "members": self._serialize_list(stmt.members), } def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict: @@ -150,7 +144,7 @@ class MidasAstJsonSerializer( return { "_type": "GenericType", "type": type.type.accept(self), - "params": self._serialize_list(type.params), + "args": self._serialize_list(type.args), } def visit_constraint_type(self, type: ConstraintType) -> dict: @@ -163,5 +157,28 @@ class MidasAstJsonSerializer( def visit_complex_type(self, type: ComplexType) -> dict: return { 
"_type": "ComplexType", - "properties": self._serialize_list(type.properties), + "members": self._serialize_list(type.members), + } + + def visit_function_type(self, type: FunctionType) -> dict: + return { + "_type": "FunctionType", + "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args], + "args": [self._serialize_func_arg(arg) for arg in type.args], + "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args], + "returns": type.returns.accept(self), + } + + def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict: + return { + "name": arg.name, + "type": arg.type.accept(self), + "required": arg.required, + } + + def visit_extension_type(self, type: ExtensionType) -> dict: + return { + "_type": "ExtensionType", + "base": type.base.accept(self), + "extension": type.extension.accept(self), } diff --git a/tests/serializer/python.py b/tests/serializer/python.py index bab3f8c..b090eea 100644 --- a/tests/serializer/python.py +++ b/tests/serializer/python.py @@ -16,11 +16,14 @@ from midas.ast.python import ( Function, GetExpr, IfStmt, + ListExpr, LiteralExpr, LogicalExpr, MidasType, ReturnStmt, + SliceExpr, Stmt, + SubscriptExpr, TernaryExpr, TypeAssign, UnaryExpr, @@ -245,3 +248,24 @@ class PythonAstJsonSerializer( "if_true": expr.if_true.accept(self), "if_false": expr.if_false.accept(self), } + + def visit_list_expr(self, expr: ListExpr) -> dict: + return { + "_type": "ListExpr", + "items": [item.accept(self) for item in expr.items], + } + + def visit_subscript_expr(self, expr: SubscriptExpr) -> dict: + return { + "_type": "SubscriptExpr", + "object": expr.object.accept(self), + "index": expr.index.accept(self), + } + + def visit_slice_expr(self, expr: SliceExpr) -> dict: + return { + "_type": "SliceExpr", + "lower": self._serialize_optional(expr.lower), + "upper": self._serialize_optional(expr.upper), + "step": self._serialize_optional(expr.step), + }