diff --git a/examples/01_simple_type_checking/01_simple_operations.py b/examples/01_simple_type_checking/01_simple_operations.py
index a3ac707..4e767f2 100644
--- a/examples/01_simple_type_checking/01_simple_operations.py
+++ b/examples/01_simple_type_checking/01_simple_operations.py
@@ -9,3 +9,5 @@ d = True
e = d + d
f: float = a
+
+f = -f
diff --git a/examples/01_simple_type_checking/02_simple_types.midas b/examples/01_simple_type_checking/02_simple_types.midas
index 6a1a6a2..ff4edb1 100644
--- a/examples/01_simple_type_checking/02_simple_types.midas
+++ b/examples/01_simple_type_checking/02_simple_types.midas
@@ -3,12 +3,12 @@ type Second = float
type MeterPerSecond = float
extend Meter {
- op __add__(Meter) -> Meter
- op __sub__(Meter) -> Meter
- op __truediv__(Second) -> MeterPerSecond
+ def __add__: fn(Meter, /) -> Meter
+ def __sub__: fn(Meter, /) -> Meter
+ def __truediv__: fn(Second, /) -> MeterPerSecond
}
extend Second {
- op __add__(Second) -> Second
- op __sub__(Second) -> Second
+ def __add__: fn(Second, /) -> Second
+ def __sub__: fn(Second, /) -> Second
}
diff --git a/examples/01_simple_type_checking/04_complex_types.midas b/examples/01_simple_type_checking/04_complex_types.midas
index b920c37..adc76b3 100644
--- a/examples/01_simple_type_checking/04_complex_types.midas
+++ b/examples/01_simple_type_checking/04_complex_types.midas
@@ -1,11 +1,21 @@
type Meter = float
extend Meter {
- op __add__(Meter) -> Meter
- op __sub__(Meter) -> Meter
+ def __add__: fn(Meter, /) -> Meter
+ def __sub__: fn(Meter, /) -> Meter
}
-type Coordinate = {
- x: Meter
- y: Meter
+type Coordinate = object
+
+extend Coordinate {
+ prop x: Meter
+ prop y: Meter
+}
+
+type Difference[T <: float] = T
+type MeterDifference = Difference[Meter]
+
+type CompDiff[T <: float] = {
+ prop d1: Difference[T]
+ prop d2: Difference[T]
}
\ No newline at end of file
diff --git a/examples/01_simple_type_checking/04_complex_types.py b/examples/01_simple_type_checking/04_complex_types.py
index f36ef52..f1d1215 100644
--- a/examples/01_simple_type_checking/04_complex_types.py
+++ b/examples/01_simple_type_checking/04_complex_types.py
@@ -1,5 +1,6 @@
# type: ignore
# ruff: disable [F821]
+
p1: Coordinate
p2: Coordinate
@@ -9,3 +10,28 @@ diff_y = p2.y - p1.y
dist = diff_x + diff_y
p2.x += cast(Meter, 1)
+p2.y = True # invalid, wrong type
+p2.z = 3 # invalid, no property 'z' on Coordinate
+p2.x.a = 3 # invalid, no properties on Meter
+
+foo: list[float] = []
+
+append = foo.append
+
+foo.append("") # invalid, must be float
+foo.append(2)
+append(True) # invalid, must be float
+append(2)
+
+bar: list[list[Meter]]
+
+bar.append([p2.x])
+
+foo2 = foo + foo
+
+a = foo[0]
+b = bar[0][1]
+c = bar[0][1][2] # invalid, not method __getitem__ on Meter
+c = bar[""] # invalid, wrong index type
+
+d = foo[1:2]
diff --git a/examples/01_simple_type_checking/05_functions.py b/examples/01_simple_type_checking/05_functions.py
new file mode 100644
index 0000000..9c04813
--- /dev/null
+++ b/examples/01_simple_type_checking/05_functions.py
@@ -0,0 +1,28 @@
+def incr(value: int):
+ return value + 1
+
+
+def decr(value: int):
+ return value - 1
+
+
+def foo(a: int, /, b: float, *, c: str):
+ return True
+
+
+r1 = foo() # foo() missing 2 required positional arguments: 'a' and 'b'
+r2 = foo(1) # foo() missing 1 required positional argument: 'b'
+r3 = foo(1, 2.0) # foo() missing 1 required keyword-only argument: 'c'
+r4 = foo(1, b=2.0) # foo() missing 1 required keyword-only argument: 'c'
+r5 = foo(1, 2.0, "test") # foo() takes 2 positional arguments but 3 were given
+r6 = foo(1, 2.0, b=3.0) # foo() got multiple values for argument 'b'
+r7 = foo(
+ a=1
+) # foo() got some positional-only arguments passed as keyword arguments: 'a'
+r8 = foo(g="test") # foo() got an unexpected keyword argument 'g'
+
+r9a = foo(1, 2.0, c="test")
+r9b = foo(1, b=2.0, c="test")
+r9c = foo(1, c="test", b=2.0)
+
+r10 = foo("a", 3, c=False) # wrong argument types
diff --git a/examples/01_simple_type_checking/06_overloads.midas b/examples/01_simple_type_checking/06_overloads.midas
new file mode 100644
index 0000000..777c410
--- /dev/null
+++ b/examples/01_simple_type_checking/06_overloads.midas
@@ -0,0 +1,10 @@
+type T1 = object
+type T2 = object
+type Foo = object
+type T2b = T2
+
+extend Foo {
+ def bar: fn(T1, /) -> int
+ def bar: fn(T2, /) -> float
+ def bar: fn(T2b, /) -> int
+}
diff --git a/examples/01_simple_type_checking/06_overloads.py b/examples/01_simple_type_checking/06_overloads.py
new file mode 100644
index 0000000..86406e0
--- /dev/null
+++ b/examples/01_simple_type_checking/06_overloads.py
@@ -0,0 +1,18 @@
+# type: ignore
+# ruff: disable [F821]
+
+foo: Foo
+t1: T1
+t2: T2
+
+a = foo.bar(t1)
+b = foo.bar(t2)
+
+func = foo.bar
+
+c = func(t1)
+d = func(t2)
+
+t2b: T2b
+
+e = foo.bar(t2b)
diff --git a/gen/gen.py b/gen/gen.py
index e78c872..50c9c9d 100644
--- a/gen/gen.py
+++ b/gen/gen.py
@@ -30,6 +30,7 @@ from __future__ import annotations
T = TypeVar("T")
+{preamble}
{sections}
"""
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
re.MULTILINE | re.DOTALL,
)
+PREAMBLE_REGEX = re.compile(
+ r"^###>\s*Preamble\s*?\n(?P
.*?)\n###<$",
+ re.MULTILINE | re.DOTALL,
+)
+
def snake_case(text: str) -> str:
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
def make_section(full_name: str, base: str, param: str, body: str) -> str:
+ print(f" Generating {full_name}")
visitor_methods: list[str] = []
classes: list[str] = []
definitions: list[str] = body.strip("\n").split("\n\n\n")
for cls in definitions:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
- print(f"Processing {name}")
+ print(f" Processing {name}")
visitor_methods.append(make_visitor_method(name, param))
classes.append(make_class(name, cls, base))
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
def generate(definitions_path: Path, out_path: Path):
+ print(f"Processing generating {out_path} from {definitions_path}")
root_dir: Path = Path(__file__).parent.parent
rel_path: Path = definitions_path.relative_to(root_dir)
src: str = definitions_path.read_text()
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
if m := IMPORTS_REGEX.search(src):
imports = m.group("body").strip("\n")
+ preamble: str = ""
+ if m := PREAMBLE_REGEX.search(src):
+ preamble = m.group("body")
+
for section_m in SECTION_REGEX.finditer(src):
full_name: str = section_m.group("name")
base: str = section_m.group("base")
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
gen_path=Path(__file__).relative_to(root_dir),
),
imports=imports,
+ preamble=preamble,
sections="\n\n\n".join(sections),
)
out_path.write_text(result)
diff --git a/gen/midas.py b/gen/midas.py
index e1c304d..42caf4f 100644
--- a/gen/midas.py
+++ b/gen/midas.py
@@ -4,6 +4,7 @@
###> Imports
from abc import ABC, abstractmethod
from dataclasses import dataclass
+from enum import Enum, auto
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
@@ -12,33 +13,39 @@ from midas.lexer.token import Token
###<
+###> Preamble
+@dataclass(frozen=True, kw_only=True)
+class TypeParam:
+ location: Location
+ name: Token
+ bound: Optional[Type]
+
+
+class MemberKind(Enum):
+ PROPERTY = auto()
+ METHOD = auto()
+
+
+###<
+
+
###> Stmt | Statements
class TypeStmt:
name: Token
- params: list[Param]
+ params: list[TypeParam]
type: Type
- @dataclass(frozen=True, kw_only=True)
- class Param:
- location: Location
- name: Token
- bound: Optional[Type]
-
-class PropertyStmt:
+class MemberStmt:
name: Token
type: Type
+ kind: MemberKind
class ExtendStmt:
- type: Type
- operations: list[OpStmt]
-
-
-class OpStmt:
name: Token
- operand: Type
- result: Type
+ params: list[TypeParam]
+ members: list[MemberStmt]
class PredicateStmt:
@@ -103,7 +110,7 @@ class NamedType:
class GenericType:
type: Type
- params: list[Type]
+ args: list[Type]
class ConstraintType:
@@ -112,7 +119,26 @@ class ConstraintType:
class ComplexType:
- properties: list[PropertyStmt]
+ members: list[MemberStmt]
+
+
+class ExtensionType:
+ base: Type
+ extension: ComplexType
+
+
+class FunctionType:
+ pos_args: list[Argument]
+ args: list[Argument]
+ kw_args: list[Argument]
+ returns: Type
+
+ @dataclass(frozen=True, kw_only=True)
+ class Argument:
+ location: Optional[Location] = None
+ name: Optional[Token]
+ type: Type
+ required: bool
###<
diff --git a/gen/python.py b/gen/python.py
index e6d08c9..35908f7 100644
--- a/gen/python.py
+++ b/gen/python.py
@@ -139,4 +139,19 @@ class TernaryExpr:
if_false: Expr
+class ListExpr:
+ items: list[Expr]
+
+
+class SubscriptExpr:
+ object: Expr
+ index: Expr
+
+
+class SliceExpr:
+ lower: Optional[Expr]
+ upper: Optional[Expr]
+ step: Optional[Expr]
+
+
###<
diff --git a/midas/ast/midas.py b/midas/ast/midas.py
index 335e5cf..e71aff9 100644
--- a/midas/ast/midas.py
+++ b/midas/ast/midas.py
@@ -7,6 +7,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
+from enum import Enum, auto
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
@@ -14,6 +15,18 @@ from midas.lexer.token import Token
T = TypeVar("T")
+@dataclass(frozen=True, kw_only=True)
+class TypeParam:
+ location: Location
+ name: Token
+ bound: Optional[Type]
+
+
+class MemberKind(Enum):
+ PROPERTY = auto()
+ METHOD = auto()
+
+
##############
# Statements #
##############
@@ -31,14 +44,11 @@ class Stmt(ABC):
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
- def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
+ def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
@abstractmethod
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
- @abstractmethod
- def visit_op_stmt(self, stmt: OpStmt) -> T: ...
-
@abstractmethod
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
@@ -46,47 +56,33 @@ class Stmt(ABC):
@dataclass(frozen=True)
class TypeStmt(Stmt):
name: Token
- params: list[Param]
+ params: list[TypeParam]
type: Type
- @dataclass(frozen=True, kw_only=True)
- class Param:
- location: Location
- name: Token
- bound: Optional[Type]
-
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
-class PropertyStmt(Stmt):
+class MemberStmt(Stmt):
name: Token
type: Type
+ kind: MemberKind
def accept(self, visitor: Stmt.Visitor[T]) -> T:
- return visitor.visit_property_stmt(self)
+ return visitor.visit_member_stmt(self)
@dataclass(frozen=True)
class ExtendStmt(Stmt):
- type: Type
- operations: list[OpStmt]
+ name: Token
+ params: list[TypeParam]
+ members: list[MemberStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_extend_stmt(self)
-@dataclass(frozen=True)
-class OpStmt(Stmt):
- name: Token
- operand: Type
- result: Type
-
- def accept(self, visitor: Stmt.Visitor[T]) -> T:
- return visitor.visit_op_stmt(self)
-
-
@dataclass(frozen=True)
class PredicateStmt(Stmt):
name: Token
@@ -231,6 +227,12 @@ class Type(ABC):
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
+ @abstractmethod
+ def visit_extension_type(self, type: ExtensionType) -> T: ...
+
+ @abstractmethod
+ def visit_function_type(self, type: FunctionType) -> T: ...
+
@dataclass(frozen=True)
class NamedType(Type):
@@ -243,7 +245,7 @@ class NamedType(Type):
@dataclass(frozen=True)
class GenericType(Type):
type: Type
- params: list[Type]
+ args: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)
@@ -260,7 +262,34 @@ class ConstraintType(Type):
@dataclass(frozen=True)
class ComplexType(Type):
- properties: list[PropertyStmt]
+ members: list[MemberStmt]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)
+
+
+@dataclass(frozen=True)
+class ExtensionType(Type):
+ base: Type
+ extension: ComplexType
+
+ def accept(self, visitor: Type.Visitor[T]) -> T:
+ return visitor.visit_extension_type(self)
+
+
+@dataclass(frozen=True)
+class FunctionType(Type):
+ pos_args: list[Argument]
+ args: list[Argument]
+ kw_args: list[Argument]
+ returns: Type
+
+ @dataclass(frozen=True, kw_only=True)
+ class Argument:
+ location: Optional[Location] = None
+ name: Optional[Token]
+ type: Type
+ required: bool
+
+ def accept(self, visitor: Type.Visitor[T]) -> T:
+ return visitor.visit_function_type(self)
diff --git a/midas/ast/printer.py b/midas/ast/printer.py
index f8fb411..e52472c 100644
--- a/midas/ast/printer.py
+++ b/midas/ast/printer.py
@@ -100,20 +100,21 @@ class MidasAstPrinter(
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
- self._print_type_stmt_param(param)
+ self._print_type_param(param)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
- def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
+ def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
- def visit_property_stmt(self, stmt: m.PropertyStmt):
- self._write_line("PropertyStmt")
+ def visit_member_stmt(self, stmt: m.MemberStmt):
+ self._write_line("MemberStmt")
with self._child_level():
+ self._write_line(f"kind: {stmt.kind.name}")
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
@@ -122,29 +123,28 @@ class MidasAstPrinter(
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
with self._child_level():
- self._write_line("type")
- with self._child_level(single=True):
- stmt.type.accept(self)
- self._write_line("operations", last=True)
+ self._write_line("params")
with self._child_level():
- for i, op in enumerate(stmt.operations):
+ for i, param in enumerate(stmt.params):
self._idx = i
- if i == len(stmt.operations) - 1:
+ if i == len(stmt.params) - 1:
self._mark_last()
- op.accept(self)
-
- def visit_op_stmt(self, stmt: m.OpStmt) -> None:
- self._write_line("OpStmt")
- with self._child_level():
+ self._print_type_param(param)
self._write_line(f'name: "{stmt.name.lexeme}"')
-
- self._write_line("operand")
- with self._child_level(single=True):
- stmt.operand.accept(self)
-
- self._write_line("result", last=True)
- with self._child_level(single=True):
- stmt.result.accept(self)
+ self._write_line("params")
+ with self._child_level():
+ for i, param in enumerate(stmt.params):
+ self._idx = i
+ if i == len(stmt.params) - 1:
+ self._mark_last()
+ self._print_type_param(param)
+ self._write_line("members", last=True)
+ with self._child_level():
+ for i, member in enumerate(stmt.members):
+ self._idx = i
+ if i == len(stmt.members) - 1:
+ self._mark_last()
+ member.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
self._write_line("PredicateStmt")
@@ -234,11 +234,11 @@ class MidasAstPrinter(
self._write_line("type")
with self._child_level():
type.type.accept(self)
- self._write_line("params", last=True)
+ self._write_line("args", last=True)
with self._child_level():
- for i, param in enumerate(type.params):
+ for i, param in enumerate(type.args):
self._idx = i
- if i == len(type.params) - 1:
+ if i == len(type.args) - 1:
self._mark_last()
param.accept(self)
@@ -255,13 +255,66 @@ class MidasAstPrinter(
def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType")
with self._child_level():
- self._write_line("properties", last=True)
+ self._write_line("members", last=True)
with self._child_level():
- for i, prop in enumerate(type.properties):
+ for i, member in enumerate(type.members):
self._idx = i
- if i == len(type.properties) - 1:
+ if i == len(type.members) - 1:
self._mark_last()
- prop.accept(self)
+ member.accept(self)
+
+ def visit_extension_type(self, type: m.ExtensionType) -> None:
+ self._write_line("ExtensionType")
+ with self._child_level():
+ self._write_line("base")
+ with self._child_level(single=True):
+ type.base.accept(self)
+ self._write_line("extension", last=True)
+ with self._child_level(single=True):
+ type.extension.accept(self)
+
+ def visit_function_type(self, type: m.FunctionType) -> None:
+ self._write_line("FunctionType")
+ with self._child_level():
+ self._write_line("pos_args")
+ with self._child_level():
+ for i, arg in enumerate(type.pos_args):
+ self._idx = i
+ if i == len(type.pos_args) - 1:
+ self._mark_last()
+ self._print_function_arg(arg)
+
+ self._write_line("args")
+ with self._child_level():
+ for i, arg in enumerate(type.args):
+ self._idx = i
+ if i == len(type.args) - 1:
+ self._mark_last()
+ self._print_function_arg(arg)
+
+ self._write_line("kw_args")
+ with self._child_level():
+ for i, arg in enumerate(type.kw_args):
+ self._idx = i
+ if i == len(type.kw_args) - 1:
+ self._mark_last()
+ self._print_function_arg(arg)
+
+ self._write_line("returns", last=True)
+ with self._child_level(single=True):
+ type.returns.accept(self)
+
+ def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
+ self._write_line("Argument")
+ with self._child_level():
+ name: str = "None"
+ if arg.name is not None:
+ name = f'"{arg.name.lexeme}"'
+ self._write_line(f"name: {name}")
+ self._write_line("type")
+ with self._child_level(single=True):
+ arg.type.accept(self)
+ self._write_line(f"required: {arg.required}", last=True)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
@@ -279,38 +332,39 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = ""
if len(stmt.params) != 0:
- params: list[str] = [
- self._print_type_template_param(param) for param in stmt.params
- ]
+ params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
- def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
+ def _print_type_param(self, param: m.TypeParam) -> str:
res: str = param.name.lexeme
if param.bound is not None:
res += "<:" + param.bound.accept(self)
return res
- def visit_property_stmt(self, stmt: m.PropertyStmt):
- res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
+ def visit_member_stmt(self, stmt: m.MemberStmt):
+ keyword: str = {
+ m.MemberKind.PROPERTY: "prop",
+ m.MemberKind.METHOD: "def",
+ }.get(stmt.kind, "")
+ res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt):
- res: str = self.indented(f"extend {stmt.type.accept(self)}")
+ template: str = ""
+ if len(stmt.params) != 0:
+ params: list[str] = [self._print_type_param(param) for param in stmt.params]
+ template = f"[{', '.join(params)}]"
+ res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
res += " {\n"
self.level += 1
- for op in stmt.operations:
- res += op.accept(self)
+ for member in stmt.members:
+ res += member.accept(self) + "\n"
self.level -= 1
res += self.indented("}")
return res
- def visit_op_stmt(self, stmt: m.OpStmt):
- operand: str = stmt.operand.accept(self)
- result: str = stmt.result.accept(self)
- return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n")
-
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
subject: str = stmt.subject.lexeme
@@ -358,9 +412,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_generic_type(self, type: m.GenericType) -> str:
res: str = type.type.accept(self)
- if len(type.params) != 0:
- params: list[str] = [param.accept(self) for param in type.params]
- res += f"[{', '.join(params)}]"
+ if len(type.args) != 0:
+ args: list[str] = [param.accept(self) for param in type.args]
+ res += f"[{', '.join(args)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
@@ -371,13 +425,41 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
- for prop in type.properties:
- res += prop.accept(self)
+ for member in type.members:
+ res += member.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
+ def visit_extension_type(self, type: m.ExtensionType) -> str:
+ return f"{type.base.accept(self)} & {type.extension.accept(self)}"
+
+ def visit_function_type(self, type: m.FunctionType) -> str:
+ pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
+ mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
+ kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
+ args: list[str] = pos_args
+
+ if len(pos_args) != 0:
+ args.append("/")
+ args += mixed_args
+ if len(kw_args) != 0:
+ args.append("*")
+ args += kw_args
+
+ return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
+
+ def _print_arg(self, arg: m.FunctionType.Argument) -> str:
+ res: str = ""
+ if arg.name is not None:
+ res += arg.name.lexeme
+ res += ": "
+ res += arg.type.accept(self)
+ if not arg.required:
+ res += "?"
+ return res
+
class PythonAstPrinter(
AstPrinter,
@@ -582,7 +664,7 @@ class PythonAstPrinter(
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level(single=True):
- self._write_line(f"value: {expr.value}")
+ self._write_line(f"value: {expr.value!r}")
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
self._write_line("VariableExpr")
@@ -626,3 +708,31 @@ class PythonAstPrinter(
self._write_line("if_false", last=True)
with self._child_level(single=True):
expr.if_false.accept(self)
+
+ def visit_list_expr(self, expr: p.ListExpr) -> None:
+ self._write_line("ListExpr")
+ with self._child_level():
+ self._write_line("items", last=True)
+ with self._child_level():
+ for i, item in enumerate(expr.items):
+ self._idx = i
+ if i == len(expr.items) - 1:
+ self._mark_last()
+ item.accept(self)
+
+ def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
+ self._write_line("SubscriptExpr")
+ with self._child_level():
+ self._write_line("object")
+ with self._child_level(single=True):
+ expr.object.accept(self)
+ self._write_line("index", last=True)
+ with self._child_level(single=True):
+ expr.index.accept(self)
+
+ def visit_slice_expr(self, expr: p.SliceExpr) -> None:
+ self._write_line("SliceExpr")
+ with self._child_level():
+ self._write_optional_child("lower", expr.lower)
+ self._write_optional_child("upper", expr.upper)
+ self._write_optional_child("step", expr.step, last=True)
diff --git a/midas/ast/python.py b/midas/ast/python.py
index dd5d905..f025e2f 100644
--- a/midas/ast/python.py
+++ b/midas/ast/python.py
@@ -14,6 +14,7 @@ from midas.ast.location import Location
T = TypeVar("T")
+
####################
# Type annotations #
####################
@@ -220,6 +221,15 @@ class Expr(ABC):
@abstractmethod
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
+ @abstractmethod
+ def visit_list_expr(self, expr: ListExpr) -> T: ...
+
+ @abstractmethod
+ def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
+
+ @abstractmethod
+ def visit_slice_expr(self, expr: SliceExpr) -> T: ...
+
@dataclass(frozen=True)
class BinaryExpr(Expr):
@@ -312,3 +322,30 @@ class TernaryExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_ternary_expr(self)
+
+
+@dataclass(frozen=True)
+class ListExpr(Expr):
+ items: list[Expr]
+
+ def accept(self, visitor: Expr.Visitor[T]) -> T:
+ return visitor.visit_list_expr(self)
+
+
+@dataclass(frozen=True)
+class SubscriptExpr(Expr):
+ object: Expr
+ index: Expr
+
+ def accept(self, visitor: Expr.Visitor[T]) -> T:
+ return visitor.visit_subscript_expr(self)
+
+
+@dataclass(frozen=True)
+class SliceExpr(Expr):
+ lower: Optional[Expr]
+ upper: Optional[Expr]
+ step: Optional[Expr]
+
+ def accept(self, visitor: Expr.Visitor[T]) -> T:
+ return visitor.visit_slice_expr(self)
diff --git a/midas/checker/builtins.midas b/midas/checker/builtins.midas
new file mode 100644
index 0000000..6e89172
--- /dev/null
+++ b/midas/checker/builtins.midas
@@ -0,0 +1,152 @@
+extend float {
+ def hex: fn() -> str
+ def is_integer: fn() -> bool
+ prop real: float
+ prop imag: float
+ def conjugate: fn() -> float
+ def __add__: fn(value: float, /) -> float
+ def __sub__: fn(value: float, /) -> float
+ def __mul__: fn(value: float, /) -> float
+ def __floordiv__: fn(value: float, /) -> float
+ def __truediv__: fn(value: float, /) -> float
+ def __mod__: fn(value: float, /) -> float
+ // def __divmod__: fn(value: float, /) -> tuple[float, float]
+
+ def __pow__: fn(value: int, /) -> float
+ // positive __value -> float; negative __value -> complex
+ // return type must be Any as `float | complex` causes too many false-positive errors
+ def __pow__: fn(value: float, /) -> Any
+ def __radd__: fn(value: float, /) -> float
+ def __rsub__: fn(value: float, /) -> float
+ def __rmul__: fn(value: float, /) -> float
+ def __rfloordiv__: fn(value: float, /) -> float
+ def __rtruediv__: fn(value: float, /) -> float
+ def __rmod__: fn(value: float, /) -> float
+ // def __rdivmod__: fn(value: float, /) -> tuple[float, float]
+ // def __rpow__: fn(value: _PositiveInteger, mod: None = None, /) -> float
+ // def __rpow__: fn(value: _NegativeInteger, mod: None = None, /) -> complex
+ // Returning `complex` for the general case gives too many false-positive errors.
+ // def __rpow__: fn(value: float, mod: None = None, /) -> Any
+ // def __getnewargs__: fn() -> tuple[float]
+ def __trunc__: fn() -> int
+ def __ceil__: fn() -> int
+ def __floor__: fn() -> int
+ def __round__: fn(ndigits: None?, /) -> int
+ def __round__: fn(ndigits: int, /) -> float
+ def __eq__: fn(value: object, /) -> bool
+ def __ne__: fn(value: object, /) -> bool
+ def __lt__: fn(value: float, /) -> bool
+ def __le__: fn(value: float, /) -> bool
+ def __gt__: fn(value: float, /) -> bool
+ def __ge__: fn(value: float, /) -> bool
+ def __neg__: fn() -> float
+ def __pos__: fn() -> float
+ def __int__: fn() -> int
+ def __float__: fn() -> float
+ def __abs__: fn() -> float
+ def __hash__: fn() -> int
+ def __bool__: fn() -> bool
+ def __format__: fn(format_spec: str, /) -> str
+}
+
+extend int {
+ prop real: int
+ prop imag: int
+ prop numerator: int
+ prop denominator: int
+ def conjugate: fn() -> int
+ def bit_length: fn() -> int
+ def bit_count: fn() -> int
+ // def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) -> bytes
+
+ def __add__: fn(value: int, /) -> int
+ def __sub__: fn(value: int, /) -> int
+ def __mul__: fn(value: int, /) -> int
+ def __floordiv__: fn(value: int, /) -> int
+ def __truediv__: fn(value: int, /) -> float
+ def __mod__: fn(value: int, /) -> int
+ // def __divmod__: fn(value: int, /) -> tuple[int, int]
+ def __radd__: fn(value: int, /) -> int
+ def __rsub__: fn(value: int, /) -> int
+ def __rmul__: fn(value: int, /) -> int
+ def __rfloordiv__: fn(value: int, /) -> int
+ def __rtruediv__: fn(value: int, /) -> float
+ def __rmod__: fn(value: int, /) -> int
+ // def __rdivmod__: fn(value: int, /) -> tuple[int, int]
+ def __pow__: fn(value: int, /) -> int
+ // def __pow__: fn(value: _PositiveInteger, mod: None = None, /) -> int
+ // def __pow__: fn(value: _NegativeInteger, mod: None = None, /) -> float
+ // positive __value -> int; negative __value -> float
+ // return type must be Any as `int | float` causes too many false-positive errors
+ // def __pow__: fn(value: int, mod: None = None, /) -> Any
+ // def __pow__: fn(value: int, mod: int, /) -> int
+ def __rpow__: fn(value: int, /) -> Any
+ def __and__: fn(value: int, /) -> int
+ def __or__: fn(value: int, /) -> int
+ def __xor__: fn(value: int, /) -> int
+ def __lshift__: fn(value: int, /) -> int
+ def __rshift__: fn(value: int, /) -> int
+ def __rand__: fn(value: int, /) -> int
+ def __ror__: fn(value: int, /) -> int
+ def __rxor__: fn(value: int, /) -> int
+ def __rlshift__: fn(value: int, /) -> int
+ def __rrshift__: fn(value: int, /) -> int
+ def __neg__: fn() -> int
+ def __pos__: fn() -> int
+ def __invert__: fn() -> int
+ def __trunc__: fn() -> int
+ def __ceil__: fn() -> int
+ def __floor__: fn() -> int
+ def __round__: fn(ndigits: None?, /) -> int
+ def __round__: fn(ndigits: int, /) -> int
+
+ // def __getnewargs__: fn() -> tuple[int]
+ def __eq__: fn(value: object, /) -> bool
+ def __ne__: fn(value: object, /) -> bool
+ def __lt__: fn(value: int, /) -> bool
+ def __le__: fn(value: int, /) -> bool
+ def __gt__: fn(value: int, /) -> bool
+ def __ge__: fn(value: int, /) -> bool
+ def __float__: fn() -> float
+ def __int__: fn() -> int
+ def __abs__: fn() -> int
+ def __hash__: fn() -> int
+ def __bool__: fn() -> bool
+ def __index__: fn() -> int
+ def __format__: fn(format_spec: str, /) -> str
+}
+
+extend list[T] {
+ def copy: fn () -> list[T]
+ def append: fn (object: T, /) -> None
+ def extend: fn (iterable: list[T], /) -> None
+ def pop: fn (index: int?, /) -> T
+ def index: fn (value: T, start: int?, stop: int?, /) -> int
+ def count: fn (value: T, /) -> int
+ def insert: fn (index: int, object: T, /) -> None
+ def remove: fn (value: T, /) -> None
+ def sort: fn (*, reverse: bool?) -> None
+ def __len__: fn () -> int
+ // def __iter__: fn () -> Iterator[T]
+ def __getitem__: fn (i: int, /) -> T
+ def __getitem__: fn (s: slice, /) -> list[T]
+ def __setitem__: fn (key: int, value: T, /) -> None
+ def __setitem__: fn (key: slice, value: list[T], /) -> None
+ def __delitem__: fn (key: int, /) -> None
+ def __delitem__: fn (key: slice, /) -> None
+ // def __add__: fn[S <: T] (value: list[S], /) -> list[T]
+ def __add__: fn (value: list[T], /) -> list[T]
+ def __iadd__: fn (value: list[T], /) -> list[T]
+ def __mul__: fn (value: int, /) -> list[T]
+ def __rmul__: fn (value: int, /) -> list[T]
+ def __imul__: fn (value: int, /) -> list[T]
+ def __contains__: fn (key: object, /) -> bool
+ // def __reversed__: fn (self) -> Iterator[_T]
+ def __gt__: fn (value: list[T], /) -> bool
+ def __ge__: fn (value: list[T], /) -> bool
+ def __lt__: fn (value: list[T], /) -> bool
+ def __le__: fn (value: list[T], /) -> bool
+ def __eq__: fn (value: object, /) -> bool
+
+ prop __doc__: str
+}
diff --git a/midas/checker/builtins.py b/midas/checker/builtins.py
index bc80084..b1adf6d 100644
--- a/midas/checker/builtins.py
+++ b/midas/checker/builtins.py
@@ -1,4 +1,41 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from midas.checker.types import (
+ BaseType,
+ GenericType,
+ TopType,
+ TypeVar,
+ UnitType,
+)
+
+if TYPE_CHECKING:
+ from midas.checker.registry import TypesRegistry
+
+
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"float": {"int"},
"int": {"bool"},
}
+
+
+def define_builtins(reg: TypesRegistry):
+ """Define builtin types and operations"""
+ any = reg.define_type("Any", TopType())
+ unit = reg.define_type("None", UnitType())
+ object = reg.define_type("object", BaseType(name="object"))
+ bool = reg.define_type("bool", BaseType(name="bool"))
+ int = reg.define_type("int", BaseType(name="int"))
+ float = reg.define_type("float", BaseType(name="float"))
+ str = reg.define_type("str", BaseType(name="str"))
+ slice = reg.define_type("slice", BaseType(name="slice"))
+
+ list = reg.define_type(
+ "list",
+ GenericType(
+ name="list",
+ params=[TypeVar(name="T", bound=None)],
+ body=BaseType(name="list"),
+ ),
+ )
diff --git a/midas/checker/checker.py b/midas/checker/checker.py
index ab7261c..c26f0aa 100644
--- a/midas/checker/checker.py
+++ b/midas/checker/checker.py
@@ -1,812 +1,35 @@
-import logging
-from dataclasses import dataclass
from pathlib import Path
from typing import Optional
-import midas.ast.midas as m
-import midas.ast.python as p
-from midas.ast.location import Location
-from midas.checker.builtins import BUILTIN_SUBTYPES
-from midas.checker.diagnostic import Diagnostic, DiagnosticType
-from midas.checker.environment import Environment
-from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
-from midas.checker.types import (
- AliasType,
- BaseType,
- ComplexType,
- Function,
- Operation,
- Type,
- UnitType,
- UnknownType,
-)
-from midas.lexer.midas import MidasLexer
-from midas.lexer.token import Token
-from midas.parser.midas import MidasParser
-from midas.resolver.midas import MidasResolver
+from midas.checker.diagnostic import Diagnostic
+from midas.checker.midas import MidasTyper
+from midas.checker.python import PythonTyper
+from midas.checker.registry import TypesRegistry
+from midas.checker.reporter import Reporter
-class ReturnException(Exception):
- pass
+class TypeChecker:
+ def __init__(self):
+ self.types: TypesRegistry = TypesRegistry()
+ self.reporter: Reporter = Reporter()
+ self.midas_typer = MidasTyper(self.types, self.reporter)
+ self.python_typer = PythonTyper(self.types, self.reporter)
-@dataclass(frozen=True, kw_only=True)
-class MappedArgument:
- expr: p.Expr
- type: Type
- argument: Function.Argument
+ def import_midas(self, path: Path):
+ source: str = path.read_text()
+ return self.import_midas_source(source, path=str(path))
+ def import_midas_source(self, source: str, path: Optional[str] = None):
+ self.midas_typer.process(source, path)
-class Checker(
- p.Stmt.Visitor[None],
- p.Expr.Visitor[Type],
- p.MidasType.Visitor[Type],
-):
- """A type checker which can use custom type definitions"""
+ def type_check(self, path: Path):
+ source: str = path.read_text()
+ return self.type_check_source(source, path=str(path))
- def __init__(
- self,
- locals: dict[p.Expr, int],
- source_path: Path,
- types_paths: list[Path],
- ):
- self.logger: logging.Logger = logging.getLogger("Checker")
- self.source_path: Path = source_path
- self.types_paths: list[Path] = types_paths
- self.ctx: MidasResolver = MidasResolver()
- self.global_env: Environment = Environment()
- self.env: Environment = self.global_env
- self.locals: dict[p.Expr, int] = locals
- self.diagnostics: list[Diagnostic] = []
- self.judgements: list[tuple[p.Expr, Type]] = []
+ def type_check_source(self, source: str, path: Optional[str] = None):
+ self.python_typer.process(source, path)
- def diagnostic(self, type: DiagnosticType, location: Location, message: str):
- self.diagnostics.append(
- Diagnostic(
- file_path=self.source_path,
- location=location,
- type=type,
- message=message,
- )
- )
-
- def error(self, location: Location, message: str):
- self.diagnostic(
- type=DiagnosticType.ERROR,
- location=location,
- message=message,
- )
-
- def warning(self, location: Location, message: str):
- self.diagnostic(
- type=DiagnosticType.WARNING,
- location=location,
- message=message,
- )
-
- def info(self, location: Location, message: str):
- self.diagnostic(
- type=DiagnosticType.INFO,
- location=location,
- message=message,
- )
-
- def type_of(self, expr: p.Expr) -> Type:
- """Evaluate the type of an expression
-
- Args:
- expr (p.Expr): the expression to evaluate
-
- Returns:
- Type: the type of the given expression
- """
- type: Type = expr.accept(self)
- self.judgements.append((expr, type))
- return type
-
- def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
- """Evaluate a sequence of statements
-
- Args:
- block (list[p.Stmt]): the statements to evaluate
- env (Environment): the environment in which to evaluate
-
- Returns:
- bool: whether a return statement is present in the block
- """
- previous_env: Environment = self.env
- self.env = env
- returned: bool = False
- for i, stmt in enumerate(block):
- try:
- stmt.accept(self)
- except ReturnException:
- returned = True
- if i < len(block) - 1:
- self.warning(block[i + 1].location, "Unreachable statement")
- break
- self.env = previous_env
- return returned
-
- def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
- """Type check a sequence of statements and returns diagnostics
-
- Args:
- statements (list[p.Stmt]): the statements to evaluate and check
-
- Returns:
- list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
- """
- self.diagnostics = []
-
- for path in self.types_paths:
- self.import_midas(path)
- self.logger.debug(f"Midas types: {self.ctx._types}")
- self.logger.debug(f"Midas operations: {self.ctx._operations}")
-
- for stmt in statements:
- stmt.accept(self)
-
- self.logger.debug(f"Final environment: {self.env.flat_dict()}")
- return self.diagnostics
-
- def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
- """Look up a variable in the environment it was declared
-
- Args:
- name (str): the name of the variable
- expr (p.Expr): the variable expression, used to lookup the scope distance
-
- Returns:
- Optional[Type]: the type of the variable, or None if it was not found
- """
- distance: Optional[int] = self.locals.get(expr)
- if distance is not None:
- return self.env.get_at(distance, name)
- return self.global_env.get(name)
-
- def import_midas(self, path: Path) -> None:
- """Import Midas definitions from a path
-
- Args:
- path (Path): the import path
- """
- self.logger.debug(f"Importing type definitions from {path}")
- lexer: MidasLexer = MidasLexer(path.read_text())
- tokens: list[Token] = lexer.process()
- parser: MidasParser = MidasParser(tokens)
- stmts: list[m.Stmt] = parser.parse()
- self.ctx.resolve(stmts)
-
- def unfold_type(self, type: Type) -> Type:
- match type:
- case AliasType(type=ref_type):
- return self.unfold_type(ref_type)
- case _:
- return type
-
- def is_subtype(self, type1: Type, type2: Type) -> bool:
- """Check whether `type1` is a subtype of `type2`
-
- For more details on the rules checked here, see TAPL Chap. 15-16-17
-
- Args:
- type1 (Type): the potential subtype
- type2 (Type): the potential supertype
-
- Returns:
- bool: whether `type1` is a subtype of `type2`
- """
-
- if type1 == type2:
- return True
-
- match (type1, type2):
- case (AliasType(type=base1), _):
- return self.is_subtype(base1, type2)
-
- case (BaseType(name=name1), BaseType(name=name2)):
- return name1 in BUILTIN_SUBTYPES.get(name2, set())
-
- case (ComplexType(properties=props1), ComplexType(properties=props2)):
- for k, t in props2.items():
- if k not in props1:
- return False
- if not self.is_subtype(props1[k], t):
- return False
- return True
-
- case (Function(returns=return1), Function(returns=return2)):
- if not self.is_func_subtype(type1, type2):
- return False
- if not self.is_subtype(return1, return2):
- return False
- return True
-
- return False
-
- # TODO: verify the logic in here
- def is_func_subtype(self, func1: Function, func2: Function) -> bool:
- """Check whether a function is a subtype of another
-
- Args:
- func1 (Function): the potential function subtype
- func2 (Function): the potential function supertype
-
- Returns:
- bool: whether `func1` is a subtype of `func2`
- """
- if not self.is_subtype(func1.returns, func2.returns):
- return False
-
- pos1: list[Function.Argument] = func1.pos_args
- mixed1: list[Function.Argument] = func1.args
- kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
- pos2: list[Function.Argument] = func2.pos_args
- mixed2: list[Function.Argument] = func2.args
- kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
-
- mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
- mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
-
- def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
- if not self.is_subtype(sub.type, sup.type):
- return False
- if not sup.required and sub.required:
- return False
- return True
-
- for arg1 in pos1:
- arg2: Function.Argument
- if arg1.pos < len(pos2):
- arg2 = pos2[arg1.pos]
- elif arg1.pos in mixed_by_pos:
- arg2 = mixed_by_pos[arg1.pos]
- elif not arg1.required:
- continue
- else:
- return False
- if not is_arg_subtype(arg2, arg1):
- return False
-
- for name, arg1 in kw1.items():
- arg2: Function.Argument
- if name in kw2:
- arg2 = kw2[name]
- elif name in mixed_by_name:
- arg2 = mixed_by_name[name]
- elif not arg1.required:
- continue
- else:
- return False
- if not is_arg_subtype(arg2, arg1):
- return False
-
- for arg1 in mixed1:
- pos_arg2: Optional[Function.Argument] = None
- kw_arg2: Optional[Function.Argument] = None
- if arg1.name in kw2:
- kw_arg2 = kw2[arg1.name]
- elif arg1.name in mixed_by_name:
- kw_arg2 = mixed_by_name[arg1.name]
- if arg1.pos < len(pos2):
- pos_arg2 = pos2[arg1.pos]
- elif arg1.pos in mixed_by_pos:
- pos_arg2 = mixed_by_pos[arg1.pos]
-
- # No match in func2 and arg is required
- if pos_arg2 is None and kw_arg2 is None and arg1.required:
- return False
-
- # Matching keyword argument
- if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
- return False
-
- # Matching positional argument
- if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
- return False
-
- mixed_positions: set[int] = {a.pos for a in mixed1}
- mixed_names: set[str] = {a.name for a in mixed1}
- for arg2 in pos2:
- if not arg2.required:
- continue
- if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
- return False
-
- for name, arg2 in kw2.items():
- if not arg2.required:
- continue
- if name not in kw1 and name not in mixed_names:
- return False
-
- for arg2 in mixed2:
- if arg2.required:
- continue
- pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
- kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
- if not pos_match or not kw_match:
- return False
-
- return True
-
- def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
- self.type_of(stmt.expr)
-
- def visit_function(self, stmt: p.Function) -> None:
- env: Environment = Environment(self.env)
- pos_args: list[Function.Argument] = []
- args: list[Function.Argument] = []
- kw_args: list[Function.Argument] = []
-
- def eval_arg_type(arg: p.Function.Argument) -> Type:
- if arg.type is not None:
- return arg.type.accept(self)
- if arg.default is not None:
- return arg.default.accept(self)
- return UnknownType()
-
- pos: int = 0
- for arg in stmt.posonlyargs:
- pos_args.append(
- Function.Argument(
- pos=pos,
- name=arg.name,
- type=eval_arg_type(arg),
- required=arg.default is None,
- )
- )
- pos += 1
- for arg in stmt.args:
- args.append(
- Function.Argument(
- pos=pos,
- name=arg.name,
- type=eval_arg_type(arg),
- required=arg.default is None,
- )
- )
- pos += 1
- for arg in stmt.kwonlyargs:
- kw_args.append(
- Function.Argument(
- pos=pos, # not relevant
- name=arg.name,
- type=eval_arg_type(arg),
- required=arg.default is None,
- )
- )
- pos += 1
-
- for arg in pos_args + args + kw_args:
- env.define(arg.name, arg.type)
-
- returns_hint: Optional[Type] = None
- if stmt.returns is not None:
- returns_hint = stmt.returns.accept(self)
- # Early define to handle simple fully-typed recursion
- inside_function: Function = Function(
- name=stmt.name,
- pos_args=pos_args,
- args=args,
- kw_args=kw_args,
- returns=returns_hint,
- )
- self.env.define(stmt.name, inside_function)
-
- returned: bool = self.process_block(stmt.body, env)
- inferred_return: Type = UnknownType()
- if not returned:
- env.return_types.append(UnitType())
- return_types: set[Type] = set(env.return_types)
- if len(return_types) == 1:
- inferred_return = list(return_types)[0]
- elif len(return_types) > 1:
- self.error(
- stmt.location,
- f"Mixed return types: {env.return_types}",
- )
-
- returns: Type = UnknownType()
- if returns_hint is not None:
- assert stmt.returns is not None
- returns = returns_hint
- if returns != inferred_return:
- self.error(
- stmt.returns.location,
- f"Return type mismatch, annotated {returns} but returns {inferred_return}",
- )
- else:
- returns = inferred_return
-
- # TODO: handle *args and **kwargs sinks
- function: Function = Function(
- name=stmt.name,
- pos_args=pos_args,
- args=args,
- kw_args=kw_args,
- returns=returns,
- )
- self.env.define(stmt.name, function)
-
- def visit_type_assign(self, stmt: p.TypeAssign) -> None:
- # TODO check not yet defined locally
- type: Type = stmt.type.accept(self)
- self.env.define(stmt.name, type)
-
- def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
- value_type: Type = self.type_of(stmt.value)
- for target in stmt.targets:
- self._assign(stmt.location, target, value_type)
-
- def _assign(self, location: Location, target: p.Expr, value_type: Type):
- match target:
- case p.VariableExpr():
- self._assign_var(location, target, value_type)
-
- case p.GetExpr():
- self._assign_attr(location, target, value_type)
-
- case _:
- if not isinstance(target, p.VariableExpr):
- self.logger.warning(f"Unsupported assignment to {target}")
- self.warning(target.location, f"Unsupported assignment to {target}")
-
- def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
- name: str = target.name
- var_type: Optional[Type] = self.look_up_variable(name, target)
-
- if var_type is None:
- self.env.define(name, value_type)
- else:
- # S <: T
- # Γ, x: T v: S
- # x = v
- if not self.is_subtype(value_type, var_type):
- self.error(
- location,
- f"Cannot assign {value_type} to {name} of type {var_type}",
- )
-
- def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
- object: Type = self.type_of(target.object)
- base_object: Type = self.unfold_type(object)
- match base_object:
- case ComplexType(properties=properties):
- if target.name not in properties:
- self.error(
- target.location, f"Unknown property '{target.name} on {object}"
- )
- return
-
- prop_type: Type = properties[target.name]
- if not self.is_subtype(value_type, prop_type):
- self.error(
- location,
- f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}",
- )
- return
-
- case UnknownType():
- pass
-
- case _:
- self.error(
- target.location,
- f"Cannot assign {value_type} to unknown property '{target.name}' on {object}",
- )
-
- def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
- type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
- self.env.return_types.append(type)
- raise ReturnException()
-
- def visit_if_stmt(self, stmt: p.IfStmt) -> None:
- # Not evaluated in sub-environment because assignments in the test leak out of the if
- # For example:
- # if (m := 1 + 1) < 2:
- # ...
- # print(m) # <- m is still defined
- test_type: Type = stmt.test.accept(self)
-
- # TODO Allow subtypes or any type
- if test_type != self.ctx.get_type("bool"):
- self.error(
- stmt.test.location, f"If test must be a boolean, got {test_type}"
- )
-
- env: Environment = Environment(self.env)
- body_returned: bool = self.process_block(stmt.body, env)
- else_returned: bool = self.process_block(stmt.orelse, env)
- self.env.return_types.extend(env.return_types)
- if body_returned and else_returned:
- raise ReturnException()
-
- def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
- method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
- if method is None:
- self.logger.warning(f"Unsupported operator {expr.operator}")
- self.warning(expr.location, f"Unsupported operator {expr.operator}")
- return UnknownType()
- left: Type = self.type_of(expr.left)
- right: Type = self.type_of(expr.right)
-
- operations: list[Operation] = self.ctx.get_operations_by_name(method)
- valid_operations: list[Operation] = []
- for op in operations:
- sig: Operation.CallSignature = op.signature
- if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
- valid_operations.append(op)
-
- if len(valid_operations) == 0:
- self.error(
- expr.location,
- f"Undefined operation {method} between {left} and {right}",
- )
- return UnknownType()
- elif len(valid_operations) == 1:
- self.logger.debug(f"Unique operation {method} between {left} and {right}")
- return valid_operations[0].result
-
- for i, op1 in enumerate(valid_operations):
- sig1: Operation.CallSignature = op1.signature
- best_match: bool = True
- for j, op2 in enumerate(valid_operations):
- if i == j:
- continue
- sig2: Operation.CallSignature = op2.signature
- if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
- sig1.right, sig2.right
- ):
- best_match = False
- break
- self.logger.debug(f"{op1} is a full overload of {op2}")
- if best_match:
- return op1.result
-
- overloads: list[str] = [
- f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}"
- for op in valid_operations
- ]
- self.error(
- expr.location,
- f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}",
- )
- return UnknownType()
-
- def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
- method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
- if method is None:
- self.logger.warning(f"Unsupported operator {expr.operator}")
- self.warning(expr.location, f"Unsupported operator {expr.operator}")
- return UnknownType()
- left: Type = self.type_of(expr.left)
- right: Type = self.type_of(expr.right)
-
- result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
- if result is None:
- self.error(
- expr.location,
- f"Undefined operation {method} between {left} and {right}",
- )
- return UnknownType()
- return result
-
- def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
-
- def visit_call_expr(self, expr: p.CallExpr) -> Type:
- callee: Type = self.type_of(expr.callee)
- if not isinstance(callee, Function):
- self.error(expr.callee.location, "Callee is not a function")
- return UnknownType()
- function: Function = callee
- mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
- for arg in mapped:
- if not self.is_subtype(arg.type, arg.argument.type):
- self.error(
- arg.expr.location,
- f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
- )
- return function.returns
-
- def visit_get_expr(self, expr: p.GetExpr) -> Type:
- object: Type = self.type_of(expr.object)
- base_object: Type = self.unfold_type(object)
- match base_object:
- case ComplexType(properties=properties):
- if expr.name not in properties:
- self.error(
- expr.location, f"Unknown property '{expr.name} on {object}"
- )
- return UnknownType()
- return properties[expr.name]
-
- case UnknownType():
- return UnknownType()
-
- case _:
- self.error(
- expr.location, f"Cannot get property '{expr.name}' on {object}"
- )
- return UnknownType()
-
- def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
- match expr.value:
- case bool(): # Must be before int
- return self.ctx.get_type("bool")
- case int():
- return self.ctx.get_type("int")
- case float():
- return self.ctx.get_type("float")
- case str():
- return self.ctx.get_type("str")
- case _:
- self.warning(expr.location, f"Unknown literal {expr}")
- return UnknownType()
-
- def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
- return self.look_up_variable(expr.name, expr) or UnknownType()
-
- def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
- left: Type = expr.left.accept(self)
- right: Type = expr.right.accept(self)
-
- if self.is_subtype(left, right):
- return right
- if self.is_subtype(right, left):
- return left
-
- self.error(
- expr.location,
- f"Incompatible operand types, {left=} and {right=}",
- )
- return UnknownType()
-
- def visit_cast_expr(self, expr: p.CastExpr) -> Type:
- return expr.type.accept(self)
-
- def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
- test_type: Type = expr.test.accept(self)
-
- # TODO Allow subtypes or any type
- if test_type != self.ctx.get_type("bool"):
- self.error(
- expr.test.location, f"If test must be a boolean, got {test_type}"
- )
-
- true_type: Type = expr.if_true.accept(self)
- false_type: Type = expr.if_false.accept(self)
- if self.is_subtype(true_type, false_type):
- return false_type
- if self.is_subtype(false_type, true_type):
- return true_type
-
- self.error(
- expr.location,
- f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
- )
- return UnknownType()
-
- def visit_base_type(self, node: p.BaseType) -> Type:
- return self.ctx.get_type(node.base)
-
- def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
-
- def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
-
- def visit_frame_type(self, node: p.FrameType) -> Type: ...
-
- def map_call_arguments(
- self, function: Function, call: p.CallExpr
- ) -> list[MappedArgument]:
- """Map call arguments to function parameters as defined in its signature
-
- This method maps positional-only, keyword-only and mixed parameter definitions
- with the arguments passed at the call site
-
- Any mismatched, missing or unexpected argument is reported as a diagnostic
-
- Args:
- function (Function): the function definition
- call (p.CallExpr): the call expression
-
- Returns:
- list[MappedArgument]: the list of mapped arguments
- """
- positional: list[tuple[p.Expr, Type]] = [
- (arg, self.type_of(arg)) for arg in call.arguments
- ]
- keywords: dict[str, tuple[p.Expr, Type]] = {
- name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
- }
- set_args: set[str] = set()
-
- required_positional: list[str] = [
- arg.name for arg in function.pos_args + function.args if arg.required
- ]
- required_keyword: list[str] = [
- arg.name for arg in function.kw_args if arg.required
- ]
-
- mapped: list[MappedArgument] = []
-
- pos_params: list[Function.Argument] = list(function.pos_args)
- mixed_params: list[Function.Argument] = list(function.args)
- kw_params: dict[str, Function.Argument] = {
- arg.name: arg for arg in function.kw_args
- }
-
- # TODO: handle *args and **kwargs sinks
- for arg in positional:
- param: Function.Argument
- if len(pos_params) != 0:
- param = pos_params.pop(0)
- elif len(mixed_params) != 0:
- param = mixed_params.pop(0)
- else:
- self.error(arg[0].location, "Too many positional arguments")
- break
- name: str = param.name
- if name in required_positional:
- required_positional.remove(name)
- if name in required_keyword:
- required_keyword.remove(name)
- set_args.add(name)
- mapped.append(
- MappedArgument(
- expr=arg[0],
- type=arg[1],
- argument=param,
- )
- )
-
- kw_params.update({arg.name: arg for arg in mixed_params})
- for name, arg in keywords.items():
- param: Function.Argument
- if name not in kw_params:
- if name in set_args:
- self.error(
- arg[0].location, f"Multiple values for argument '{name}'"
- )
- else:
- self.error(arg[0].location, f"Unknown keyword argument '{name}'")
- continue
- param = kw_params.pop(name)
- if name in required_positional:
- required_positional.remove(name)
- if name in required_keyword:
- required_keyword.remove(name)
- set_args.add(name)
- mapped.append(
- MappedArgument(
- expr=arg[0],
- type=arg[1],
- argument=param,
- )
- )
-
- def join_args(args: list[str]) -> str:
- args = list(map(lambda a: f"'{a}'", args))
- if len(args) == 0:
- return ""
- if len(args) == 1:
- return args[0]
- return ", ".join(args[:-1]) + " and " + args[-1]
-
- if len(required_positional) != 0:
- plural: str = "" if len(required_positional) == 1 else "s"
- args: str = join_args(required_positional)
- self.error(
- call.location,
- f"Missing required positional argument{plural}: {args}",
- )
-
- if len(required_keyword) != 0:
- plural: str = "" if len(required_keyword) == 1 else "s"
- args: str = join_args(required_keyword)
- self.error(
- call.location,
- f"Missing required keyword argument{plural}: {args}",
- )
-
- return mapped
+ @property
+ def diagnostics(self) -> list[Diagnostic]:
+ return self.reporter.diagnostics
diff --git a/midas/checker/diagnostic.py b/midas/checker/diagnostic.py
index 77f687e..f4b3d12 100644
--- a/midas/checker/diagnostic.py
+++ b/midas/checker/diagnostic.py
@@ -1,6 +1,5 @@
from dataclasses import dataclass
from enum import StrEnum
-from pathlib import Path
from typing import Optional
from midas.ast.location import Location
@@ -14,7 +13,7 @@ class DiagnosticType(StrEnum):
@dataclass(frozen=True)
class Diagnostic:
- file_path: Path
+ file_path: Optional[str]
location: Location
type: DiagnosticType
message: str
@@ -28,10 +27,16 @@ class Diagnostic:
and self.location.end_col_offset is not None
):
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
- loc: str = (
- f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
- )
- return f"{self.type} in {self.file_path} {loc}"
+
+ loc: str = ""
+ if self.file_path is not None:
+ loc += f" in {self.file_path}"
+ if end_loc is None:
+ loc += f" at {start_loc}"
+ else:
+ loc += f" from {start_loc} to {end_loc}"
+
+ return f"{self.type}{loc}"
def __str__(self) -> str:
return f"{self.location_str}: {self.message}"
diff --git a/midas/checker/midas.py b/midas/checker/midas.py
new file mode 100644
index 0000000..3764c03
--- /dev/null
+++ b/midas/checker/midas.py
@@ -0,0 +1,206 @@
+import logging
+from pathlib import Path
+from typing import Optional
+
+import midas.ast.midas as m
+from midas.checker.builtins import define_builtins
+from midas.checker.registry import TypesRegistry
+from midas.checker.reporter import FileReporter, Reporter
+from midas.checker.types import (
+ AliasType,
+ ComplexType,
+ ExtensionType,
+ Function,
+ GenericType,
+ Type,
+ TypeVar,
+ UnknownType,
+)
+from midas.lexer.midas import MidasLexer
+from midas.lexer.token import Token
+from midas.parser.midas import MidasParser
+
+
+class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
+ """A resolver which evaluates Midas type definitions and build a registry"""
+
+ def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
+ self.logger: logging.Logger = logging.getLogger("MidasTyper")
+ self.reporter: FileReporter = reporter.for_file(None)
+
+ self.types: TypesRegistry = types
+ self._local_variables: dict[str, TypeVar] = {}
+
+ self._current_name: Optional[str] = None
+
+ define_builtins(self.types)
+ builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
+ self.process(builtins_path.read_text(), str(builtins_path))
+
+ def process(self, source: str, path: Optional[str]):
+ self.reporter = self.reporter.for_file(path)
+ lexer: MidasLexer = MidasLexer(source)
+ tokens: list[Token] = lexer.process()
+ parser: MidasParser = MidasParser(tokens)
+ stmts: list[m.Stmt] = parser.parse()
+ for error in parser.errors:
+ self.reporter.error(error.token.get_location(), error.message)
+ self.resolve(stmts)
+
+ def get_type(self, name: str) -> Type:
+ """Get a type from its name
+
+ Args:
+ name (str): the name of the type
+
+ Raises:
+ NameError: if the type is not defined
+
+ Returns:
+ Type: the type
+ """
+ if name in self._local_variables:
+ return self._local_variables[name]
+ return self.types.get_type(name)
+
+ def resolve(self, stmts: list[m.Stmt]):
+ """Process a sequence of statements
+
+ Args:
+ stmts (list[m.Stmt]): the statements
+ """
+ for stmt in stmts:
+ stmt.accept(self)
+
+ def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
+ name: str = stmt.name.lexeme
+ self._current_name = name
+ params: list[TypeVar] = self._resolve_type_params(stmt.params)
+
+ type: Type = stmt.type.accept(self)
+ if len(params) != 0:
+ type = GenericType(name=name, params=params, body=type)
+ else:
+ type = AliasType(name=name, type=type)
+ self.types.define_type(name, type)
+ self._local_variables.clear()
+ self._current_name = None
+
+ def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
+
+ def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
+ self._resolve_type_params(stmt.params)
+ base_name: str = stmt.name.lexeme
+ try:
+ _ = self.get_type(base_name)
+ except NameError:
+ self.reporter.error(stmt.name.get_location(), f"Unknown type '{base_name}'")
+
+ for member in stmt.members:
+ member_type: Type = member.type.accept(self)
+ self.types.define_member(
+ base_name,
+ member.name.lexeme,
+ member_type,
+ member.kind == m.MemberKind.METHOD,
+ )
+
+ def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
+ self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
+
+ def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
+ self.reporter.warning(expr.location, "LogicalExpr not yet supported")
+
+ def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
+ self.reporter.warning(expr.location, "BinaryExpr not yet supported")
+
+ def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
+ self.reporter.warning(expr.location, "UnaryExpr not yet supported")
+
+ def visit_get_expr(self, expr: m.GetExpr) -> None:
+ self.reporter.warning(expr.location, "GetExpr not yet supported")
+
+ def visit_variable_expr(self, expr: m.VariableExpr) -> None:
+ self.reporter.warning(expr.location, "VariableExpr not yet supported")
+
+ def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
+ return expr.expr.accept(self)
+
+ def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
+ self.reporter.warning(expr.location, "LiteralExpr not yet supported")
+
+ def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
+ self.reporter.warning(expr.location, "WildcardExpr not yet supported")
+
+ def visit_named_type(self, type: m.NamedType) -> Type:
+ name: str = type.name.lexeme
+ try:
+ return self.get_type(name)
+ except NameError:
+ msg: str = f"Undefined type {name}"
+ if self._current_name == name:
+ msg += ". Recursive types are not supported, use an extend block"
+ self.reporter.error(type.name.get_location(), msg)
+ return UnknownType()
+
+ def visit_generic_type(self, type: m.GenericType) -> Type:
+ type_: Type = type.type.accept(self)
+ args: list[Type] = [arg.accept(self) for arg in type.args]
+ try:
+ return self.types.apply_generic(type_, args)
+ except Exception as e:
+ self.reporter.error(type.location, f"Cannot apply generic type: {e}")
+ return UnknownType()
+
+ def visit_constraint_type(self, type: m.ConstraintType) -> Type:
+ type_: Type = type.type.accept(self)
+ type.constraint.accept(self)
+ # TODO
+ return UnknownType()
+
+ def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
+ return ComplexType(
+ members={
+ member.name.lexeme: member.type.accept(self) for member in type.members
+ }
+ )
+
+ def visit_extension_type(self, type: m.ExtensionType) -> Type:
+ return ExtensionType(
+ base=type.base.accept(self),
+ extension=self.visit_complex_type(type.extension),
+ )
+
+ def visit_function_type(self, type: m.FunctionType) -> Type:
+ n_pos_args: int = len(type.pos_args)
+ n_args: int = len(type.args)
+
+ def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
+ return Function.Argument(
+ pos=i,
+ name=arg.name.lexeme if arg.name is not None else str(i),
+ type=arg.type.accept(self),
+ required=arg.required,
+ )
+
+ return Function(
+ pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
+ args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
+ kw_args=[
+ process_arg(arg, i + n_pos_args + n_args)
+ for i, arg in enumerate(type.kw_args)
+ ],
+ returns=type.returns.accept(self),
+ )
+
+ def _resolve_type_params(self, params: list[m.TypeParam]):
+ vars: list[TypeVar] = []
+ for param in params:
+ name: str = param.name.lexeme
+ bound: Optional[Type] = None
+ if param.bound is not None:
+ bound = param.bound.accept(self)
+ var = TypeVar(name=name, bound=bound)
+ self._local_variables[name] = var
+ vars.append(var)
+ return vars
diff --git a/midas/checker/operators.py b/midas/checker/operators.py
index e65ab07..58af88c 100644
--- a/midas/checker/operators.py
+++ b/midas/checker/operators.py
@@ -29,3 +29,10 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
# ast.In: "__in__",
# ast.NotIn: "__notin__",
}
+
+UNARY_METHODS: dict[Type[ast.unaryop], str] = {
+ ast.Invert: "__invert__",
+ # ast.Not: "",
+ ast.UAdd: "__pos__",
+ ast.USub: "__neg__",
+}
diff --git a/midas/checker/python.py b/midas/checker/python.py
new file mode 100644
index 0000000..a0f7a06
--- /dev/null
+++ b/midas/checker/python.py
@@ -0,0 +1,859 @@
+import ast
+import logging
+from dataclasses import dataclass
+from typing import Optional
+
+import midas.ast.python as p
+from midas.ast.location import Location
+from midas.checker.environment import Environment
+from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
+from midas.checker.registry import TypesRegistry
+from midas.checker.reporter import FileReporter, Reporter
+from midas.checker.resolver import Resolver
+from midas.checker.types import (
+ Function,
+ OverloadedFunction,
+ Type,
+ UnitType,
+ UnknownType,
+ unfold_type,
+)
+from midas.parser.python import PythonParser
+
+TypedExpr = tuple[p.Expr, Type]
+
+
+class ReturnException(Exception):
+ pass
+
+
+@dataclass(frozen=True, kw_only=True)
+class MappedArgument:
+ expr: p.Expr
+ type: Type
+ argument: Function.Argument
+
+
+@dataclass(frozen=True, kw_only=True)
+class OverloadCandidate:
+ function: Function
+ mapped: list[MappedArgument]
+
+
+class PythonTyper(
+ p.Stmt.Visitor[None],
+ p.Expr.Visitor[Type],
+ p.MidasType.Visitor[Type],
+):
+ """A type checker which can use custom type definitions"""
+
+ def __init__(
+ self,
+ types: TypesRegistry,
+ reporter: Reporter,
+ ):
+ self.logger: logging.Logger = logging.getLogger("PythonTyper")
+ self.reporter: FileReporter = reporter.for_file(None)
+ self.types: TypesRegistry = types
+ self.global_env: Environment = Environment()
+ self.env: Environment = self.global_env
+ self.locals: dict[p.Expr, int] = {}
+ self.judgements: list[tuple[p.Expr, Type]] = []
+
+ def process(self, source: str, path: Optional[str]):
+ self.reporter = self.reporter.for_file(path)
+
+ tree: ast.Module = ast.parse(source, filename=path or "")
+ parser = PythonParser()
+ stmts: list[p.Stmt] = parser.parse_module(tree)
+ resolver = Resolver()
+ resolver.resolve(*stmts)
+
+ self.env = self.global_env
+ self.locals = resolver.locals
+ self.judgements = []
+
+ self.check(stmts)
+
+ def type_of(self, expr: p.Expr) -> Type:
+ """Evaluate the type of an expression
+
+ Args:
+ expr (p.Expr): the expression to evaluate
+
+ Returns:
+ Type: the type of the given expression
+ """
+ type: Type = expr.accept(self)
+ self.judgements.append((expr, type))
+ return type
+
+ def resolve_type_expr(self, expr: p.MidasType) -> Type:
+ return expr.accept(self)
+
+ def process_stmt(self, stmt: p.Stmt) -> None:
+ stmt.accept(self)
+
+ def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
+ """Evaluate a sequence of statements
+
+ Args:
+ block (list[p.Stmt]): the statements to evaluate
+ env (Environment): the environment in which to evaluate
+
+ Returns:
+ bool: whether a return statement is present in the block
+ """
+ previous_env: Environment = self.env
+ self.env = env
+ returned: bool = False
+ for i, stmt in enumerate(block):
+ try:
+ self.process_stmt(stmt)
+ except ReturnException:
+ returned = True
+ if i < len(block) - 1:
+ self.reporter.warning(
+ block[i + 1].location, "Unreachable statement"
+ )
+ break
+ self.env = previous_env
+ return returned
+
+ def check(self, statements: list[p.Stmt]) -> None:
+ """Type check a sequence of statements and returns diagnostics
+
+ Args:
+ statements (list[p.Stmt]): the statements to evaluate and check
+ """
+ for stmt in statements:
+ self.process_stmt(stmt)
+
+ self.logger.debug(f"Final environment: {self.env.flat_dict()}")
+
+ def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
+ """Look up a variable in the environment it was declared
+
+ Args:
+ name (str): the name of the variable
+ expr (p.Expr): the variable expression, used to lookup the scope distance
+
+ Returns:
+ Optional[Type]: the type of the variable, or None if it was not found
+ """
+ distance: Optional[int] = self.locals.get(expr)
+ if distance is not None:
+ return self.env.get_at(distance, name)
+ return self.global_env.get(name)
+
+ def is_subtype(self, type1: Type, type2: Type) -> bool:
+ return self.types.is_subtype(type1, type2)
+
+ def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
+ self.type_of(stmt.expr)
+
+ def visit_function(self, stmt: p.Function) -> None:
+ env: Environment = Environment(self.env)
+ pos_args: list[Function.Argument] = []
+ args: list[Function.Argument] = []
+ kw_args: list[Function.Argument] = []
+
+ def eval_arg_type(arg: p.Function.Argument) -> Type:
+ if arg.type is not None:
+ return self.resolve_type_expr(arg.type)
+ if arg.default is not None:
+ return self.type_of(arg.default)
+ return UnknownType()
+
+ pos: int = 0
+ for arg in stmt.posonlyargs:
+ pos_args.append(
+ Function.Argument(
+ pos=pos,
+ name=arg.name,
+ type=eval_arg_type(arg),
+ required=arg.default is None,
+ )
+ )
+ pos += 1
+ for arg in stmt.args:
+ args.append(
+ Function.Argument(
+ pos=pos,
+ name=arg.name,
+ type=eval_arg_type(arg),
+ required=arg.default is None,
+ )
+ )
+ pos += 1
+ for arg in stmt.kwonlyargs:
+ kw_args.append(
+ Function.Argument(
+ pos=pos, # not relevant
+ name=arg.name,
+ type=eval_arg_type(arg),
+ required=arg.default is None,
+ )
+ )
+ pos += 1
+
+ for arg in pos_args + args + kw_args:
+ env.define(arg.name, arg.type)
+
+ returns_hint: Optional[Type] = None
+ if stmt.returns is not None:
+ returns_hint = self.resolve_type_expr(stmt.returns)
+ # Early define to handle simple fully-typed recursion
+ inside_function: Function = Function(
+ pos_args=pos_args,
+ args=args,
+ kw_args=kw_args,
+ returns=returns_hint,
+ )
+ self.env.define(stmt.name, inside_function)
+
+ returned: bool = self.process_block(stmt.body, env)
+ inferred_return: Type = UnknownType()
+ if not returned:
+ env.return_types.append(UnitType())
+ return_types: list[Type] = self.types.reduce_types(env.return_types)
+ if len(return_types) == 1:
+ inferred_return = return_types[0]
+ elif len(return_types) > 1:
+ self.reporter.error(
+ stmt.location,
+ f"Mixed return types: {return_types}",
+ )
+
+ returns: Type = UnknownType()
+ if returns_hint is not None:
+ assert stmt.returns is not None
+ returns = returns_hint
+ if returns != inferred_return:
+ self.reporter.error(
+ stmt.returns.location,
+ f"Return type mismatch, annotated {returns} but returns {inferred_return}",
+ )
+ else:
+ returns = inferred_return
+
+ # TODO: handle *args and **kwargs sinks
+ function: Function = Function(
+ pos_args=pos_args,
+ args=args,
+ kw_args=kw_args,
+ returns=returns,
+ )
+ self.env.define(stmt.name, function)
+
+ def visit_type_assign(self, stmt: p.TypeAssign) -> None:
+ # TODO check not yet defined locally
+ type: Type = self.resolve_type_expr(stmt.type)
+ self.env.define(stmt.name, type)
+
+ def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
+ value_type: Type = self.type_of(stmt.value)
+ for target in stmt.targets:
+ self._assign(stmt.location, target, value_type)
+
+ def _assign(self, location: Location, target: p.Expr, value_type: Type):
+ match target:
+ case p.VariableExpr():
+ self._assign_var(location, target, value_type)
+
+ case p.GetExpr(object=object, name=name):
+ self._assign_attr(location, object, name, value_type)
+
+ case _:
+ if not isinstance(target, p.VariableExpr):
+ self.logger.warning(f"Unsupported assignment to {target}")
+ self.reporter.warning(
+ target.location, f"Unsupported assignment to {target}"
+ )
+
+ def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
+ name: str = target.name
+ var_type: Optional[Type] = self.look_up_variable(name, target)
+
+ if var_type is None:
+ self.env.define(name, value_type)
+ else:
+ # S <: T
+ # Γ, x: T v: S
+ # x = v
+ if not self.is_subtype(value_type, var_type):
+ self.reporter.error(
+ location,
+ f"Cannot assign {value_type} to variable '{name}' of type {var_type}",
+ )
+
+ def _assign_attr(
+ self, location: Location, object: p.Expr, name: str, value_type: Type
+ ):
+ object_type: Type = self.type_of(object)
+ member: Optional[Type] = self.types.lookup_member(object_type, name)
+ if member is None:
+ self.reporter.error(location, f"Unknown member '{name}' of {object_type}")
+ return
+ self.logger.debug(f"Member '{name}' of {object_type} has type {member}")
+ if not self.is_subtype(value_type, member):
+ self.reporter.error(
+ location,
+ f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
+ )
+
+ def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
+ type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
+ self.env.return_types.append(type)
+ raise ReturnException()
+
+ def visit_if_stmt(self, stmt: p.IfStmt) -> None:
+ # Not evaluated in sub-environment because assignments in the test leak out of the if
+ # For example:
+ # if (m := 1 + 1) < 2:
+ # ...
+ # print(m) # <- m is still defined
+ test_type: Type = self.type_of(stmt.test)
+
+ # TODO Allow subtypes or any type
+ if test_type != self.types.get_type("bool"):
+ self.reporter.error(
+ stmt.test.location, f"If test must be a boolean, got {test_type}"
+ )
+
+ env: Environment = Environment(self.env)
+ body_returned: bool = self.process_block(stmt.body, env)
+ else_returned: bool = self.process_block(stmt.orelse, env)
+ self.env.return_types.extend(env.return_types)
+ if body_returned and else_returned:
+ raise ReturnException()
+
+ def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
+ method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
+ if method is None:
+ self.logger.warning(f"Unsupported operator {expr.operator}")
+ self.reporter.warning(
+ expr.location, f"Unsupported operator {expr.operator}"
+ )
+ return UnknownType()
+
+ return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
+
+ def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
+ method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
+ if method is None:
+ self.logger.warning(f"Unsupported operator {expr.operator}")
+ self.reporter.warning(
+ expr.location, f"Unsupported operator {expr.operator}"
+ )
+ return UnknownType()
+
+ return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
+
+ def _visit_binary_expr(
+ self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
+ ) -> Type:
+ left: Type = self.type_of(left_expr)
+ right: Type = self.type_of(right_expr)
+
+ operation: Optional[Type] = self.types.lookup_member(left, method)
+ if operation is None:
+ self.reporter.error(
+ location,
+ f"Undefined operation {method} between {left} and {right}",
+ )
+ return UnknownType()
+
+ return self._get_call_result(location, operation, [(right_expr, right)], {})
+
+ def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
+ method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
+ if method is None:
+ self.logger.warning(f"Unsupported operator {expr.operator}")
+ self.reporter.warning(
+ expr.location, f"Unsupported operator {expr.operator}"
+ )
+ return UnknownType()
+
+ operand: Type = self.type_of(expr.right)
+ operation: Optional[Type] = self.types.lookup_member(operand, method)
+ if operation is None:
+ self.reporter.error(
+ expr.location,
+ f"Undefined operation {method} for {operand}",
+ )
+ return UnknownType()
+
+ return self._get_call_result(
+ expr.location, operation, [(expr.right, operand)], {}
+ )
+
+ def visit_call_expr(self, expr: p.CallExpr) -> Type:
+ callee: Type = self.type_of(expr.callee)
+ positional: list[TypedExpr] = [
+ (arg, self.type_of(arg)) for arg in expr.arguments
+ ]
+ keywords: dict[str, TypedExpr] = {
+ name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
+ }
+ return self._get_call_result(
+ location=expr.location,
+ callee=callee,
+ positional=positional,
+ keywords=keywords,
+ )
+
+ def visit_get_expr(self, expr: p.GetExpr) -> Type:
+ object: Type = self.type_of(expr.object)
+ member: Optional[Type] = self.types.lookup_member(object, expr.name)
+ if member is None:
+ self.reporter.error(
+ expr.location, f"Unknown member '{expr.name}' of {object}"
+ )
+ return UnknownType()
+ self.logger.debug(f"Member '{expr.name}' of {object} has type {member}")
+ return member
+
+ def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
+ match expr.value:
+ case bool(): # Must be before int
+ return self.types.get_type("bool")
+ case int():
+ return self.types.get_type("int")
+ case float():
+ return self.types.get_type("float")
+ case str():
+ return self.types.get_type("str")
+ case _:
+ self.reporter.warning(expr.location, f"Unknown literal {expr}")
+ return UnknownType()
+
+ def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
+ type: Optional[Type] = self.look_up_variable(expr.name, expr)
+ if type is None:
+ self.logger.debug(f"Unknown variable {expr.name} in {self.env.flat_dict()}")
+ self.reporter.warning(expr.location, "Unknown variable")
+ return type or UnknownType()
+
+ def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
+ left: Type = self.type_of(expr.left)
+ right: Type = self.type_of(expr.right)
+
+ if self.is_subtype(left, right):
+ return right
+ if self.is_subtype(right, left):
+ return left
+
+ self.reporter.error(
+ expr.location,
+ f"Incompatible operand types, {left=} and {right=}",
+ )
+ return UnknownType()
+
+ def visit_cast_expr(self, expr: p.CastExpr) -> Type:
+ return self.resolve_type_expr(expr.type)
+
+ def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
+ test_type: Type = self.type_of(expr.test)
+
+ # TODO Allow subtypes or any type
+ if test_type != self.types.get_type("bool"):
+ self.reporter.error(
+ expr.test.location, f"If test must be a boolean, got {test_type}"
+ )
+
+ true_type: Type = self.type_of(expr.if_true)
+ false_type: Type = self.type_of(expr.if_false)
+ if self.is_subtype(true_type, false_type):
+ return false_type
+ if self.is_subtype(false_type, true_type):
+ return true_type
+
+ self.reporter.error(
+ expr.location,
+ f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
+ )
+ return UnknownType()
+
+ def visit_list_expr(self, expr: p.ListExpr) -> Type:
+ list_type: Type = self.types.get_type("list")
+ item_types: list[Type] = [self.type_of(item) for item in expr.items]
+ item_types = self.types.reduce_types(item_types)
+
+ if len(item_types) == 0:
+ return list_type
+
+ if len(item_types) == 1:
+ item_type: Type = item_types[0]
+ return self.types.apply_generic(list_type, [item_type])
+ self.reporter.error(
+ expr.location,
+ f"Heterogeneous list items: {item_types}",
+ )
+ return self.types.apply_generic(list_type, [UnknownType()])
+
+ def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
+ object: Type = self.type_of(expr.object)
+ operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
+ if operation is None:
+ self.reporter.error(
+ expr.location,
+ f"Undefined method __getitem__ on {object}",
+ )
+ return UnknownType()
+
+ index: Type = self.type_of(expr.index)
+ return self._get_call_result(
+ expr.location, operation, [(expr.index, index)], {}
+ )
+
+ def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
+ return self.types.get_type("slice")
+
+ def visit_base_type(self, node: p.BaseType) -> Type:
+ base: Type
+ try:
+ base = self.types.get_type(node.base)
+ except NameError:
+ self.reporter.warning(node.location, f"Unknown type '{node.base}'")
+ return UnknownType()
+
+ if node.param is not None:
+ param: Type = self.resolve_type_expr(node.param)
+ return self.types.apply_generic(base, [param])
+ return base
+
+ def visit_constraint_type(self, node: p.ConstraintType) -> Type:
+ self.reporter.warning(node.location, "ConstraintType not yet supported")
+ return UnknownType()
+
+ def visit_frame_column(self, node: p.FrameColumn) -> Type:
+ self.reporter.warning(node.location, "FrameColumn not yet supported")
+ return UnknownType()
+
+ def visit_frame_type(self, node: p.FrameType) -> Type:
+ self.reporter.warning(node.location, "FrameType not yet supported")
+ return UnknownType()
+
+ def _get_call_result(
+ self,
+ location: Location,
+ callee: Type,
+ positional: list[TypedExpr],
+ keywords: dict[str, TypedExpr],
+ ) -> Type:
+ """Get the result type of a function call
+
+ If the function has overloads, the function will try to resolve the
+ appropriate signature.
+ Argument types are matched to the defined parameters.
+ The function doesn't take the raw expression as a parameter to accomodate
+ for desugared calls such as for operators.
+
+ Args:
+ location (Location): the call location
+ callee (Type): the called function
+ positional (list[TypedExpr]): the list positional arguments
+ keywords (dict[str, TypedExpr]): the map of keyword arguments
+
+ Returns:
+ Type: the return type of the call, or `UnknownType` if either
+ the call is invalid or no overload matched the arguments uniquely
+ """
+ match callee:
+ case Function() as function:
+ valid: bool
+ mapped: list[MappedArgument]
+ valid, mapped = self.map_call_arguments(
+ function, location, positional, keywords
+ )
+ valid = valid and self._are_arguments_valid(mapped)
+ if not valid:
+ return UnknownType()
+ return function.returns
+
+ case OverloadedFunction(overloads=overloads):
+ function = self._match_overload(
+ overloads, location, positional, keywords
+ )
+ if function is None:
+ return UnknownType()
+ return function.returns
+ case _:
+ self.reporter.error(location, f"{callee} is not callable")
+ return UnknownType()
+
+ def _are_arguments_valid(
+ self,
+ arguments: list[MappedArgument],
+ report_errors: bool = True,
+ ) -> bool:
+ """Check whether the passed argument types correspond to their matched parameter definitions
+
+ Args:
+ arguments (list[MappedArgument]): the list of argument/parameter pairs
+ report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
+
+ Returns:
+ bool: True if all arguments fit the matching parameter definitions, False otherwise
+ """
+ valid: bool = True
+ for arg in arguments:
+ if not self.is_subtype(arg.type, arg.argument.type):
+ if report_errors:
+ self.reporter.error(
+ arg.expr.location,
+ f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
+ )
+ valid = False
+ return valid
+
+ def _match_overload(
+ self,
+ overloads: list[Type],
+ location: Location,
+ positional: list[TypedExpr],
+ keywords: dict[str, TypedExpr],
+ ) -> Optional[Function]:
+ """Try and resolve the appropriate overload for the given arguments
+
+ Args:
+ overloads (list[Type]): the list of possible overloads
+ location (Location): the call location
+ positional (list[TypedExpr]): the list of positional arguments
+ keywords (dict[str, TypedExpr]): the map of keywords arguments
+
+ Returns:
+ Optional[Function]: the resolved function signature if it can be
+ determined unambigously, or `None`.
+ """
+ candidates: list[OverloadCandidate] = []
+ for overload in overloads:
+ function: Type = unfold_type(overload)
+ if not isinstance(function, Function):
+ self.logger.error(
+ f"Overload is not a function: {overload} is {function}"
+ )
+ continue
+ valid, mapped = self.map_call_arguments(
+ function=function,
+ location=location,
+ positional=positional,
+ keywords=keywords,
+ report_errors=False,
+ )
+ if valid and self._are_arguments_valid(mapped, report_errors=False):
+ candidates.append(
+ OverloadCandidate(
+ function=function,
+ mapped=mapped,
+ )
+ )
+
+ pos_types: str = ", ".join(str(type) for _, type in positional)
+ kw_types: str = ", ".join(
+ f"{name}: {type}" for name, (_, type) in keywords.items()
+ )
+ for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
+
+ n_candidates: int = len(candidates)
+
+ # Exactly 1 match -> return it
+ if n_candidates == 1:
+ return candidates[0].function
+
+ # No match -> invalid call
+ if n_candidates == 0:
+ overloads_str: str = ", ".join(map(str, overloads))
+ self.reporter.error(
+ location,
+ f"No matching overload in [{overloads_str}] {for_args}",
+ )
+ return None
+
+ # Multiple matches -> see if one <: all others (more specific)
+ for i1, c1 in enumerate(candidates):
+ mapped1: list[MappedArgument] = c1.mapped
+ best_match: bool = True
+ for i2, c2 in enumerate(candidates):
+ if i1 == i2:
+ continue
+ mapped2: list[MappedArgument] = c2.mapped
+ if not self._are_mapped_subtypes(mapped1, mapped2):
+ best_match = False
+ break
+ self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
+ if best_match:
+ return c1.function
+
+ candidates_str: str = ", ".join(
+ str(candidate.function) for candidate in candidates
+ )
+ self.reporter.error(
+ location,
+ f"Multiple matching overloads {for_args}: {candidates_str}",
+ )
+ return None
+
+ def map_call_arguments(
+ self,
+ function: Function,
+ location: Location,
+ positional: list[TypedExpr],
+ keywords: dict[str, TypedExpr],
+ report_errors: bool = True,
+ ) -> tuple[bool, list[MappedArgument]]:
+ """Map call arguments to a function's parameters as defined in its signature
+
+ This method maps positional-only, keyword-only and mixed parameter definitions
+ with the arguments passed at the call site
+
+ Any mismatched, missing or unexpected argument is reported as a diagnostic,
+ unless `report_errors` is set to `False`
+
+ Args:
+ function (Function): the function definition
+ location (Location): the call location
+ positional (list[TypedExpr]): the list of positional arguments
+ keywords (dict[str, TypedExpr]): the map of keyword arguments
+ report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
+
+ Returns:
+ tuple[bool, list[MappedArgument]]: a boolean reporting whether
+ the call is valid and the list of mapped arguments
+ """
+ set_args: set[str] = set()
+
+ required_positional: list[str] = [
+ arg.name for arg in function.pos_args + function.args if arg.required
+ ]
+ required_keyword: list[str] = [
+ arg.name for arg in function.kw_args if arg.required
+ ]
+
+ mapped: list[MappedArgument] = []
+
+ pos_params: list[Function.Argument] = list(function.pos_args)
+ mixed_params: list[Function.Argument] = list(function.args)
+ kw_params: dict[str, Function.Argument] = {
+ arg.name: arg for arg in function.kw_args
+ }
+
+ valid_call: bool = True
+
+ # TODO: handle *args and **kwargs sinks
+ for arg in positional:
+ param: Function.Argument
+ if len(pos_params) != 0:
+ param = pos_params.pop(0)
+ elif len(mixed_params) != 0:
+ param = mixed_params.pop(0)
+ else:
+ if report_errors:
+ self.reporter.error(
+ arg[0].location, "Too many positional arguments"
+ )
+ valid_call = False
+ break
+ name: str = param.name
+ if name in required_positional:
+ required_positional.remove(name)
+ if name in required_keyword:
+ required_keyword.remove(name)
+ set_args.add(name)
+ mapped.append(
+ MappedArgument(
+ expr=arg[0],
+ type=arg[1],
+ argument=param,
+ )
+ )
+
+ kw_params.update({arg.name: arg for arg in mixed_params})
+ for name, arg in keywords.items():
+ param: Function.Argument
+ if name not in kw_params:
+ if report_errors:
+ if name in set_args:
+ self.reporter.error(
+ arg[0].location, f"Multiple values for argument '{name}'"
+ )
+ else:
+ self.reporter.error(
+ arg[0].location, f"Unknown keyword argument '{name}'"
+ )
+ valid_call = False
+ continue
+ param = kw_params.pop(name)
+ if name in required_positional:
+ required_positional.remove(name)
+ if name in required_keyword:
+ required_keyword.remove(name)
+ set_args.add(name)
+ mapped.append(
+ MappedArgument(
+ expr=arg[0],
+ type=arg[1],
+ argument=param,
+ )
+ )
+
+ def join_args(args: list[str]) -> str:
+ args = list(map(lambda a: f"'{a}'", args))
+ if len(args) == 0:
+ return ""
+ if len(args) == 1:
+ return args[0]
+ return ", ".join(args[:-1]) + " and " + args[-1]
+
+ if len(required_positional) != 0:
+ plural: str = "" if len(required_positional) == 1 else "s"
+ args: str = join_args(required_positional)
+ if report_errors:
+ self.reporter.error(
+ location,
+ f"Missing required positional argument{plural}: {args}",
+ )
+ valid_call = False
+
+ if len(required_keyword) != 0:
+ plural: str = "" if len(required_keyword) == 1 else "s"
+ args: str = join_args(required_keyword)
+ if report_errors:
+ self.reporter.error(
+ location,
+ f"Missing required keyword argument{plural}: {args}",
+ )
+ valid_call = False
+
+ return valid_call, mapped
+
+ def _are_mapped_subtypes(
+ self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
+ ) -> bool:
+ """Check whether the given argument mappings are subtype/supertype of one another
+
+ This function checks whether the argument mappings `mapped1` are subtypes
+ of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
+ of the corresponding parameter in `mapped2`, `False` is returned.
+
+ This is used to check whether a given overload is
+ a more specific function/ a subtype of another.
+
+ Args:
+ mapped1 (list[MappedArgument]): the first argument mappings (subtype)
+ mapped2 (list[MappedArgument]): the second argument mappings (supertype)
+
+ Returns:
+ bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
+ """
+ by_expr: dict[p.Expr, Type] = {}
+ for arg in mapped1:
+ by_expr[arg.expr] = arg.argument.type
+
+ for arg in mapped2:
+ type2: Type = arg.argument.type
+ type1: Type = by_expr[arg.expr]
+ if not self.is_subtype(type1, type2):
+ return False
+ return True
diff --git a/midas/checker/registry.py b/midas/checker/registry.py
new file mode 100644
index 0000000..6591548
--- /dev/null
+++ b/midas/checker/registry.py
@@ -0,0 +1,347 @@
+import logging
+from typing import Optional
+
+from midas.checker.builtins import BUILTIN_SUBTYPES
+from midas.checker.types import (
+ AliasType,
+ AppliedType,
+ BaseType,
+ ComplexType,
+ ExtensionType,
+ Function,
+ GenericType,
+ OverloadedFunction,
+ TopType,
+ Type,
+ TypeVar,
+ UnknownType,
+ substitute_typevars,
+)
+
+
+class TypesRegistry:
+ def __init__(self) -> None:
+ self.logger: logging.Logger = logging.getLogger("TypesRegistry")
+ self._types: dict[str, Type] = {}
+ self._members: dict[str, dict[str, Type]] = {}
+
+ def get_type(self, name: str) -> Type:
+ """Get a type from its name
+
+ Args:
+ name (str): the name of the type
+
+ Raises:
+ NameError: if the type is not defined
+
+ Returns:
+ Type: the type
+ """
+ if name in self._types:
+ return self._types[name]
+ raise NameError(f"Undefined type {name}")
+
+ def define_type(self, name: str, type: Type) -> Type:
+ """Define a type in the registry
+
+ Args:
+ name (str): the name of the type
+ type (Type): the type to define
+
+ Raises:
+ ValueError: if a type is already defined with that name
+
+ Returns:
+ Type: the defined type
+ """
+ if name in self._types:
+ raise ValueError(f"Type {name} already defined")
+ self._types[name] = type
+ return type
+
+ def define_member(
+ self, type_name: str, member_name: str, member_type: Type, is_method: bool
+ ):
+ members: dict[str, Type] = self._members.setdefault(type_name, {})
+ if member_name in members:
+ if not is_method:
+ self.logger.error(
+ f"Member '{member_name}' already defined for type {type_name}"
+ )
+ return
+ current: Type = members[member_name]
+ combined: Type
+ match current:
+ case OverloadedFunction(overloads=overloads):
+ combined = OverloadedFunction(overloads=overloads + [member_type])
+ case _:
+ combined = OverloadedFunction(overloads=[current, member_type])
+ members[member_name] = combined
+
+ else:
+ members[member_name] = member_type
+
+ def is_subtype(self, type1: Type, type2: Type) -> bool:
+ """Check whether `type1` is a subtype of `type2`
+
+ For more details on the rules checked here, see TAPL Chap. 15-16-17
+
+ Args:
+ type1 (Type): the potential subtype
+ type2 (Type): the potential supertype
+
+ Returns:
+ bool: whether `type1` is a subtype of `type2`
+ """
+
+ if type1 == type2:
+ return True
+
+ match (type1, type2):
+ case (_, TopType()):
+ return True
+
+ case (AliasType(type=base1), _):
+ return self.is_subtype(base1, type2)
+
+ case (BaseType(name=name1), BaseType(name=name2)):
+ return name1 in BUILTIN_SUBTYPES.get(name2, set())
+
+ case (ComplexType(properties=props1), ComplexType(properties=props2)):
+ for k, t in props2.items():
+ if k not in props1:
+ return False
+ if not self.is_subtype(props1[k], t):
+ return False
+ return True
+
+ case (Function(), Function()):
+ return self.is_func_subtype(type1, type2)
+
+ case (TypeVar(bound=bound), _):
+ if bound is None:
+ return False
+ return self.is_subtype(bound, type2)
+
+ return False
+
+ # TODO: verify the logic in here
+ def is_func_subtype(self, func1: Function, func2: Function) -> bool:
+ """Check whether a function is a subtype of another
+
+ Args:
+ func1 (Function): the potential function subtype
+ func2 (Function): the potential function supertype
+
+ Returns:
+ bool: whether `func1` is a subtype of `func2`
+ """
+ if not self.is_subtype(func1.returns, func2.returns):
+ return False
+
+ pos1: list[Function.Argument] = func1.pos_args
+ mixed1: list[Function.Argument] = func1.args
+ kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
+ pos2: list[Function.Argument] = func2.pos_args
+ mixed2: list[Function.Argument] = func2.args
+ kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
+
+ mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
+ mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
+
+ def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
+ if not self.is_subtype(sub.type, sup.type):
+ return False
+ if not sup.required and sub.required:
+ return False
+ return True
+
+ for arg1 in pos1:
+ arg2: Function.Argument
+ if arg1.pos < len(pos2):
+ arg2 = pos2[arg1.pos]
+ elif arg1.pos in mixed_by_pos:
+ arg2 = mixed_by_pos[arg1.pos]
+ elif not arg1.required:
+ continue
+ else:
+ return False
+ if not is_arg_subtype(arg2, arg1):
+ return False
+
+ for name, arg1 in kw1.items():
+ arg2: Function.Argument
+ if name in kw2:
+ arg2 = kw2[name]
+ elif name in mixed_by_name:
+ arg2 = mixed_by_name[name]
+ elif not arg1.required:
+ continue
+ else:
+ return False
+ if not is_arg_subtype(arg2, arg1):
+ return False
+
+ for arg1 in mixed1:
+ pos_arg2: Optional[Function.Argument] = None
+ kw_arg2: Optional[Function.Argument] = None
+ if arg1.name in kw2:
+ kw_arg2 = kw2[arg1.name]
+ elif arg1.name in mixed_by_name:
+ kw_arg2 = mixed_by_name[arg1.name]
+ if arg1.pos < len(pos2):
+ pos_arg2 = pos2[arg1.pos]
+ elif arg1.pos in mixed_by_pos:
+ pos_arg2 = mixed_by_pos[arg1.pos]
+
+ # No match in func2 and arg is required
+ if pos_arg2 is None and kw_arg2 is None and arg1.required:
+ return False
+
+ # Matching keyword argument
+ if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
+ return False
+
+ # Matching positional argument
+ if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
+ return False
+
+ mixed_positions: set[int] = {a.pos for a in mixed1}
+ mixed_names: set[str] = {a.name for a in mixed1}
+ for arg2 in pos2:
+ if not arg2.required:
+ continue
+ if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
+ return False
+
+ for name, arg2 in kw2.items():
+ if not arg2.required:
+ continue
+ if name not in kw1 and name not in mixed_names:
+ return False
+
+ for arg2 in mixed2:
+ if arg2.required:
+ continue
+ pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
+ kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
+ if not pos_match or not kw_match:
+ return False
+
+ return True
+
+ def apply_generic(self, type: Type, args: list[Type]) -> Type:
+ match type:
+ case AliasType(name=name, type=base):
+ return AliasType(name=name, type=self.apply_generic(base, args))
+
+ case GenericType(name=name, params=type_vars, body=body):
+ n_args: int = len(args)
+ n_type_vars: int = len(type_vars)
+ if n_args < n_type_vars:
+ raise ValueError(
+ f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
+ )
+ if n_args > n_type_vars:
+ raise ValueError(
+ f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
+ )
+ substitutions: dict[str, Type] = {}
+ for arg, type_var in zip(args, type_vars):
+ if type_var.bound is not None and not self.is_subtype(
+ arg, type_var.bound
+ ):
+ raise ValueError(
+ f"Type argument {arg} is not a subtype of {type_var.bound}"
+ )
+ substitutions[type_var.name] = arg
+ return AppliedType(
+ name=name,
+ args=args,
+ body=substitute_typevars(body, substitutions),
+ )
+
+ case _:
+ raise ValueError(f"{type} is not a generic type")
+
+ def reduce_types(self, types: list[Type]) -> list[Type]:
+ """Reduce a list of types to remove subtypes and only keep the highest types
+
+ Args:
+ types (list[Type]): the types to reduce
+
+ Returns:
+ list[Type]: the reduced list of types
+ """
+
+ reduced: bool = True
+ keep: list[int] = list(range(len(types)))
+ while reduced:
+ reduced = False
+ for i, i1 in enumerate(keep):
+ type1: Type = types[i1]
+ for i2 in keep[i + 1 :]:
+ type2 = types[i2]
+ if self.is_subtype(type1, type2):
+ keep.remove(i1)
+ elif self.is_subtype(type2, type1):
+ keep.remove(i2)
+ else:
+ continue
+ reduced = True
+ break
+ return [types[i] for i in keep]
+
+ def lookup_member(self, type: Type, member_name: str) -> Optional[Type]:
+ match type:
+ case BaseType(name=name):
+ if name in self._members:
+ if member_name in self._members[name]:
+ return self._members[name][member_name]
+ return None
+
+ case AliasType(name=name, type=base):
+ if name in self._members:
+ if member_name in self._members[name]:
+ return self._members[name][member_name]
+ return self.lookup_member(base, member_name)
+
+ case AppliedType(name=name, body=body, args=args):
+ generic: Type = self.get_type(name)
+
+ if not isinstance(generic, GenericType):
+ raise ValueError("AppliedType not derived from a GenericType")
+
+ substitutions = {
+ type_var.name: arg for arg, type_var in zip(args, generic.params)
+ }
+ if name in self._members:
+ if member_name in self._members[name]:
+ member_type: Type = self._members[name][member_name]
+ return substitute_typevars(member_type, substitutions)
+
+ member_type2: Optional[Type] = self.lookup_member(body, member_name)
+ if member_type2 is not None:
+ member_type2 = substitute_typevars(member_type2, substitutions)
+ return member_type2
+
+ case ComplexType(members=members):
+ if member_name in members:
+ return members[member_name]
+ self.logger.debug(f"No member '{member_name}' in {type}")
+ return None
+
+ case ExtensionType(base=base, extension=ComplexType(members=members)):
+ if member_name in members:
+ return members[member_name]
+ self.logger.debug(
+ f"No member '{member_name}' on {type}, looking up in base"
+ )
+ return self.lookup_member(base, member_name)
+
+ case UnknownType():
+ return UnknownType()
+
+ case _:
+ self.logger.debug(f"Can't get member on {type}")
+ return None
diff --git a/midas/checker/reporter.py b/midas/checker/reporter.py
new file mode 100644
index 0000000..b68766a
--- /dev/null
+++ b/midas/checker/reporter.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from midas.ast.location import Location
+from midas.checker.diagnostic import Diagnostic, DiagnosticType
+
+
+class Reporter:
+ def __init__(self):
+ self.diagnostics: list[Diagnostic] = []
+
+ def report(
+ self,
+ path: Optional[str],
+ type: DiagnosticType,
+ location: Location,
+ message: str,
+ ):
+ self.diagnostics.append(
+ Diagnostic(
+ file_path=path,
+ location=location,
+ type=type,
+ message=message,
+ )
+ )
+
+ def for_file(self, path: Optional[str]) -> FileReporter:
+ return FileReporter(self, path)
+
+
+class FileReporter:
+ def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
+ self.base_reporter: Reporter = base_reporter
+ self.path: Optional[str] = path
+
+ def for_file(self, path: Optional[str]) -> FileReporter:
+ return FileReporter(self.base_reporter, path)
+
+ def report(self, type: DiagnosticType, location: Location, message: str):
+ self.base_reporter.report(self.path, type, location, message)
+
+ def error(self, location: Location, message: str):
+ self.report(
+ type=DiagnosticType.ERROR,
+ location=location,
+ message=message,
+ )
+
+ def warning(self, location: Location, message: str):
+ self.report(
+ type=DiagnosticType.WARNING,
+ location=location,
+ message=message,
+ )
+
+ def info(self, location: Location, message: str):
+ self.report(
+ type=DiagnosticType.INFO,
+ location=location,
+ message=message,
+ )
diff --git a/midas/resolver/resolver.py b/midas/checker/resolver.py
similarity index 85%
rename from midas/resolver/resolver.py
rename to midas/checker/resolver.py
index 18fcba4..12f18cf 100644
--- a/midas/resolver/resolver.py
+++ b/midas/checker/resolver.py
@@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def __init__(self):
self.locals: dict[p.Expr, int] = {}
- self.scopes: list[dict[str, bool]] = []
+ self.scopes: list[dict[str, bool]] = [{}]
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
"""Resolve the given statements or expressions"""
@@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.locals[expr] = i
return
+ def is_defined(self, name: str) -> bool:
+ for scope in self.scopes:
+ if name in scope:
+ return True
+ return False
+
def resolve_function(self, function: p.Function) -> None:
"""Resolve a function definition
@@ -111,7 +117,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(stmt.value)
for target in stmt.targets:
match target:
- case p.VariableExpr() | p.GetExpr():
+ case p.VariableExpr(name=name):
+ if not self.is_defined(name):
+ self.declare(name)
+ self.define(name)
+ target.accept(self)
+
+ case p.GetExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")
@@ -180,3 +192,19 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(expr.test)
self.resolve(expr.if_true)
self.resolve(expr.if_false)
+
+ def visit_list_expr(self, expr: p.ListExpr) -> None:
+ for item in expr.items:
+ self.resolve(item)
+
+ def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
+ self.resolve(expr.object)
+ self.resolve(expr.index)
+
+ def visit_slice_expr(self, expr: p.SliceExpr) -> None:
+ if expr.lower is not None:
+ self.resolve(expr.lower)
+ if expr.upper is not None:
+ self.resolve(expr.upper)
+ if expr.step is not None:
+ self.resolve(expr.step)
diff --git a/midas/checker/types.py b/midas/checker/types.py
index 83707b6..c6d41d1 100644
--- a/midas/checker/types.py
+++ b/midas/checker/types.py
@@ -1,37 +1,68 @@
from __future__ import annotations
from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass(frozen=True, kw_only=True)
+class TopType:
+ def __str__(self) -> str:
+ return "Any"
@dataclass(frozen=True, kw_only=True)
class BaseType:
name: str
+ def __str__(self) -> str:
+ return self.name
+
@dataclass(frozen=True, kw_only=True)
class AliasType:
name: str
type: Type
+ def __str__(self) -> str:
+ return self.name
+
@dataclass(frozen=True, kw_only=True)
class UnknownType:
- pass
+ def __str__(self) -> str:
+ return ""
@dataclass(frozen=True, kw_only=True)
class UnitType:
- pass
+ def __str__(self) -> str:
+ return "None"
@dataclass(frozen=True, kw_only=True)
class Function:
- name: str
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
+ def __str__(self) -> str:
+ args: list[str] = []
+ if len(self.pos_args) != 0:
+ args += list(map(str, self.pos_args))
+ if len(self.args) + len(self.kw_args) != 0:
+ args.append("/")
+
+ if len(self.args) != 0:
+ args += list(map(str, self.args))
+
+ if len(self.kw_args) != 0:
+ if len(args) != 0:
+ args.append("*")
+ args += list(map(str, self.kw_args))
+
+ return f"({', '.join(args)}) -> {self.returns}"
+
@dataclass(frozen=True, kw_only=True)
class Argument:
pos: int
@@ -39,22 +70,164 @@ class Function:
type: Type
required: bool
+ def __str__(self) -> str:
+ opt: str = "" if self.required else "?"
+ return f"{self.name}: {self.type}{opt}"
+
+
+@dataclass(frozen=True, kw_only=True)
+class OverloadedFunction:
+ overloads: list[Type]
+
+ def __str__(self) -> str:
+ return ""
+
@dataclass(frozen=True, kw_only=True)
class ComplexType:
- properties: dict[str, Type]
+ members: dict[str, Type]
+
+ def __str__(self) -> str:
+ props: list[str] = [f"{name}: {type}" for name, type in self.members.items()]
+ return f"{{{', '.join(props)}}}"
@dataclass(frozen=True, kw_only=True)
-class Operation:
- signature: CallSignature
- result: Type
+class ExtensionType:
+ base: Type
+ extension: ComplexType
- @dataclass(frozen=True, kw_only=True)
- class CallSignature:
- left: Type
- method: str
- right: Type
+ def __str__(self) -> str:
+ return f"{self.base} & {self.extension}"
-Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
+@dataclass(frozen=True, kw_only=True)
+class TypeVar:
+ name: str
+ bound: Optional[Type]
+
+ def __str__(self) -> str:
+ if self.bound is not None:
+ return f"{self.name} <: {self.bound}"
+ return self.name
+
+
+@dataclass(frozen=True, kw_only=True)
+class GenericType:
+ name: str
+ params: list[TypeVar]
+ body: Type
+
+ def __str__(self) -> str:
+ return f"{self.name}[{', '.join(map(str, self.params))}]"
+
+
+@dataclass(frozen=True, kw_only=True)
+class AppliedType:
+ name: str
+ args: list[Type]
+ body: Type
+
+ def __str__(self) -> str:
+ return f"{self.name}[{', '.join(map(str, self.args))}]"
+
+
+def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
+ def sub_argument(arg: Function.Argument):
+ return Function.Argument(
+ pos=arg.pos,
+ name=arg.name,
+ type=substitute_typevars(arg.type, substitutions),
+ required=arg.required,
+ )
+
+ match type:
+ case BaseType(name=name) if name in substitutions:
+ return substitutions[name]
+
+ case BaseType():
+ return type
+
+ case AliasType(name=name, type=type2):
+ return AliasType(name=name, type=substitute_typevars(type2, substitutions))
+
+ case Function(
+ pos_args=pos_args,
+ args=args,
+ kw_args=kw_args,
+ returns=returns,
+ ):
+ return Function(
+ pos_args=list(map(sub_argument, pos_args)),
+ args=list(map(sub_argument, args)),
+ kw_args=list(map(sub_argument, kw_args)),
+ returns=substitute_typevars(returns, substitutions),
+ )
+
+ case OverloadedFunction(overloads=overloads):
+ return OverloadedFunction(
+ overloads=[
+ substitute_typevars(overload, substitutions)
+ for overload in overloads
+ ]
+ )
+
+ case ComplexType(members=members):
+ members2: dict[str, Type] = {
+ name: substitute_typevars(prop, substitutions)
+ for name, prop in members.items()
+ }
+ return ComplexType(members=members2)
+
+ case ExtensionType(base=base, extension=ComplexType(members=members)):
+ return ExtensionType(
+ base=substitute_typevars(base, substitutions),
+ extension=ComplexType(
+ members={
+ name: substitute_typevars(prop, substitutions)
+ for name, prop in members.items()
+ }
+ ),
+ )
+
+ case AppliedType(name=name, args=args, body=body):
+ return AppliedType(
+ name=name,
+ args=[substitute_typevars(arg, substitutions) for arg in args],
+ body=substitute_typevars(body, substitutions),
+ )
+
+ case TypeVar(name=name):
+ if name in substitutions:
+ return substitutions[name]
+ raise ValueError(f"Missing TypeVar substitution for {name}")
+
+ case UnknownType() | UnitType():
+ return type
+
+ case _:
+ raise NotImplementedError(f"Unsupported type {type}")
+
+
+def unfold_type(type: Type) -> Type:
+ match type:
+ case AliasType(type=ref_type):
+ return unfold_type(ref_type)
+ case _:
+ return type
+
+
+Type = (
+ TopType
+ | BaseType
+ | AliasType
+ | UnknownType
+ | UnitType
+ | Function
+ | OverloadedFunction
+ | ComplexType
+ | ExtensionType
+ | TypeVar
+ | GenericType
+ | AppliedType
+)
diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py
index e4a9556..bc7727c 100644
--- a/midas/cli/highlighter.py
+++ b/midas/cli/highlighter.py
@@ -214,6 +214,22 @@ class PythonHighlighter(
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
+ def visit_list_expr(self, expr: p.ListExpr) -> None:
+ for item in expr.items:
+ item.accept(self)
+
+ def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
+ expr.object.accept(self)
+ expr.index.accept(self)
+
+ def visit_slice_expr(self, expr: p.SliceExpr) -> None:
+ if expr.lower is not None:
+ expr.lower.accept(self)
+ if expr.upper is not None:
+ expr.upper.accept(self)
+ if expr.step is not None:
+ expr.step.accept(self)
+
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
@@ -228,21 +244,14 @@ class MidasHighlighter(
self.wrap(LocatableToken(stmt.name), "type-name")
stmt.type.accept(self)
- def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
- self.wrap(stmt, "property")
+ def visit_member_stmt(self, stmt: m.MemberStmt) -> None:
+ self.wrap(stmt, "member")
stmt.type.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self.wrap(stmt, "extend")
- stmt.type.accept(self)
- for op in stmt.operations:
- op.accept(self)
-
- def visit_op_stmt(self, stmt: m.OpStmt) -> None:
- self.wrap(stmt, "op")
- self.wrap(LocatableToken(stmt.name), "op-name")
- stmt.operand.accept(self)
- stmt.result.accept(self)
+ for member in stmt.members:
+ member.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
@@ -284,8 +293,8 @@ class MidasHighlighter(
def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type")
type.type.accept(self)
- for param in type.params:
- param.accept(self)
+ for arg in type.args:
+ arg.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
@@ -294,8 +303,19 @@ class MidasHighlighter(
def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type")
- for prop in type.properties:
- prop.accept(self)
+ for member in type.members:
+ member.accept(self)
+
+ def visit_function_type(self, type: m.FunctionType) -> None:
+ self.wrap(type, "function")
+ for arg in type.pos_args + type.args + type.kw_args:
+ arg.type.accept(self)
+ type.returns.accept(self)
+
+ def visit_extension_type(self, type: m.ExtensionType) -> None:
+ self.wrap(type, "extension")
+ type.base.accept(self)
+ type.extension.accept(self)
class DiagnosticsHighlighter(Highlighter):
diff --git a/midas/cli/main.py b/midas/cli/main.py
index ae4295b..af95abd 100644
--- a/midas/cli/main.py
+++ b/midas/cli/main.py
@@ -10,7 +10,7 @@ import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
-from midas.checker.checker import Checker
+from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.types import Type
from midas.cli.ansi import Ansi
@@ -25,7 +25,6 @@ from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
-from midas.resolver.resolver import Resolver
from midas.utils import UniversalJSONDumper
@@ -89,36 +88,57 @@ def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
@click.option("-l", "--highlight", type=click.File("w"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-v", "--verbose", is_flag=True)
+@click.option("-j", "--show-judgements", is_flag=True)
@click.argument("file", type=click.File("r"))
def compile(
highlight: Optional[TextIO],
types: tuple[TextIO],
verbose: bool,
+ show_judgements: bool,
file: TextIO,
):
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
source: str = file.read()
- tree: ast.Module = ast.parse(source, filename=file.name)
- parser = PythonParser()
- stmts: list[p.Stmt] = parser.parse_module(tree)
- resolver = Resolver()
- resolver.resolve(*stmts)
- types_paths: list[Path] = [Path(t.name).resolve() for t in types]
- checker = Checker(
- resolver.locals,
- source_path=Path(file.name).resolve(),
- types_paths=types_paths,
- )
- diagnostics: list[Diagnostic] = checker.check(stmts)
+ source_path: Path = Path(file.name).resolve()
+
+ checker = TypeChecker()
+ for types_file in types:
+ checker.import_midas(Path(types_file.name).resolve())
+
+ checker.type_check_source(source, str(source_path))
+ diagnostics: list[Diagnostic] = checker.diagnostics.copy()
lines: list[str] = source.split("\n")
+ files: dict[Optional[str], list[str]] = {None: []}
+
+ if show_judgements:
+ for expr, type in checker.python_typer.judgements:
+ print(f"Judged that {expr} at {expr.location} is of type {type}")
+ diagnostics.append(
+ Diagnostic(
+ file_path=str(source_path),
+ location=expr.location,
+ type=DiagnosticType.INFO,
+ message=f"Type: {type}",
+ )
+ )
+
for diagnostic in diagnostics:
+ filename: Optional[str] = diagnostic.file_path
+ if filename is not None and filename not in files:
+ path: Path = Path(filename)
+ if path.exists() and path.is_file():
+ files[filename] = path.read_text().split("\n")
+ else:
+ files[filename] = []
+
+ lines: list[str] = files[filename]
print_diagnostic(lines, diagnostic)
if verbose:
print(
json.dumps(
UniversalJSONDumper.dump(
- checker.global_env,
+ checker.python_typer.global_env,
[("Environment", "_children")],
lambda obj: isinstance(obj, get_args(Type)),
),
diff --git a/midas/lexer/midas.py b/midas/lexer/midas.py
index 124ea09..c3246fc 100644
--- a/midas/lexer/midas.py
+++ b/midas/lexer/midas.py
@@ -50,12 +50,14 @@ class MidasLexer(Lexer):
# self.add_token(TokenType.PLUS)
case "-":
self.add_token(TokenType.MINUS)
- # case "*":
- # self.add_token(TokenType.STAR)
+ case "*":
+ self.add_token(TokenType.STAR)
case "/" if self.match("/"):
self.scan_comment()
case "/" if self.match("*"):
self.scan_comment_multiline()
+ case "/":
+ self.add_token(TokenType.SLASH)
case "\n":
self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t":
diff --git a/midas/lexer/token.py b/midas/lexer/token.py
index f08964a..f0c08a1 100644
--- a/midas/lexer/token.py
+++ b/midas/lexer/token.py
@@ -27,8 +27,8 @@ class TokenType(Enum):
# Operators
# PLUS = auto()
MINUS = auto()
- # STAR = auto()
- # SLASH = auto()
+ STAR = auto()
+ SLASH = auto()
GREATER = auto()
GREATER_EQUAL = auto()
LESS = auto()
@@ -46,10 +46,12 @@ class TokenType(Enum):
# Keywords
TYPE = auto()
- OP = auto()
PREDICATE = auto()
EXTEND = auto()
WHERE = auto()
+ PROP = auto()
+ DEF = auto()
+ FUNC = auto()
# Misc
COMMENT = auto()
@@ -60,13 +62,15 @@ class TokenType(Enum):
KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE,
- "op": TokenType.OP,
"predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND,
"where": TokenType.WHERE,
"true": TokenType.TRUE,
"false": TokenType.FALSE,
"none": TokenType.NONE,
+ "prop": TokenType.PROP,
+ "def": TokenType.DEF,
+ "fn": TokenType.FUNC,
}
diff --git a/midas/parser/midas.py b/midas/parser/midas.py
index 5d09b83..33069f3 100644
--- a/midas/parser/midas.py
+++ b/midas/parser/midas.py
@@ -7,23 +7,26 @@ from midas.ast.midas import (
ConstraintType,
Expr,
ExtendStmt,
+ ExtensionType,
+ FunctionType,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
+ MemberKind,
+ MemberStmt,
NamedType,
- OpStmt,
PredicateStmt,
- PropertyStmt,
Stmt,
Type,
+ TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
)
-from midas.lexer.token import Token, TokenType
+from midas.lexer.token import KEYWORDS, Token, TokenType
from midas.parser.base import Parser
from midas.parser.errors import ParsingError
@@ -33,9 +36,10 @@ class MidasParser(Parser):
SYNC_BOUNDARY: set[TokenType] = {
TokenType.TYPE,
- TokenType.OP,
TokenType.EXTEND,
TokenType.PREDICATE,
+ TokenType.PROP,
+ TokenType.FUNC,
}
def parse(self) -> list[Stmt]:
@@ -107,10 +111,8 @@ class MidasParser(Parser):
TypeStmt: the parsed type declaration statement
"""
keyword: Token = self.previous()
- name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
- params: list[TypeStmt.Param] = []
- if self.check(TokenType.LEFT_BRACKET):
- params = self.type_stmt_params()
+ name: Token = self.consume_identifier("Expected type name")
+ params: list[TypeParam] = self.type_params()
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
@@ -123,24 +125,27 @@ class MidasParser(Parser):
type=type,
)
- def type_stmt_params(self) -> list[TypeStmt.Param]:
- """Parse a generic template expression
+ def type_params(self) -> list[TypeParam]:
+ """Parse a list of type parameters
- A template is written `[TypeExpr]`
+ Type parameters are a comma-separated list of type variables wrapped in brackets.
+ Each type variable is either a simple variable, or a bounded variable written `S <: T`
Returns:
- TemplateExpr: the parsed template expression
+ list[TypeParam]: the list of type parameters, if any, or an empty list
"""
- self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
- params: list[TypeStmt.Param] = []
+ if not self.match(TokenType.LEFT_BRACKET):
+ return []
+
+ params: list[TypeParam] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
- name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
+ name: Token = self.consume_identifier("Expected type variable")
bound: Optional[Type] = None
if self.match(TokenType.LESS):
self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr()
params.append(
- TypeStmt.Param(
+ TypeParam(
location=name.location_to(self.previous()),
name=name,
bound=bound,
@@ -148,7 +153,7 @@ class MidasParser(Parser):
)
if not self.match(TokenType.COMMA):
break
- self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
+ self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params
def type_expr(self) -> Type:
@@ -160,7 +165,19 @@ class MidasParser(Parser):
Returns:
TypeExpr: the parsed type expression
"""
- return self.constraint_type()
+ base: Type
+ if self.match(TokenType.FUNC):
+ base = self.function()
+ else:
+ base = self.constraint_type()
+ if self.match(TokenType.AND):
+ extension: ComplexType = self.complex_type()
+ return ExtensionType(
+ location=Location.span(base.location, extension.location),
+ base=base,
+ extension=extension,
+ )
+ return base
def constraint_type(self) -> Type:
type: Type = self.base_type()
@@ -187,55 +204,57 @@ class MidasParser(Parser):
def generic_type(self) -> Type:
type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
- params: list[Type] = self.type_params()
+ args: list[Type] = self.type_args()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
type=type,
- params=params,
+ args=args,
)
return type
- def type_params(self) -> list[Type]:
- params: list[Type] = []
- self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
+ def type_args(self) -> list[Type]:
+ args: list[Type] = []
+ self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
- params.append(self.type_expr())
+ args.append(self.type_expr())
if not self.match(TokenType.COMMA):
break
- self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
- return params
+ self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
+ return args
def named_type(self) -> Type:
- name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
+ name: Token = self.consume_identifier("Expected type name")
return NamedType(
location=name.get_location(),
name=name,
)
- def complex_type(self) -> Type:
+ def complex_type(self) -> ComplexType:
"""Parse a type definition body
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns:
- list[PropertyStmt]: the parsed type properties
+ ComplexType: the parsed complex type
"""
left: Token = self.consume(
TokenType.LEFT_BRACE, "Expected '{' to start type body"
)
- properties: list[PropertyStmt] = []
+ members: list[MemberStmt] = []
+ # TODO: add keyword to differentiate properties and methods,
+ # and allow multiple methods with the same name but not properties
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
- prop: PropertyStmt = self.property_stmt()
- if prop.name.lexeme in names:
- raise self.error(prop.name, "Duplicate property")
- names.add(prop.name.lexeme)
- properties.append(prop)
+ member: MemberStmt = self.member_stmt()
+ # if member.name.lexeme in names:
+ # raise self.error(member.name, "Duplicate property")
+ # names.add(member.name.lexeme)
+ members.append(member)
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return ComplexType(
location=left.location_to(right),
- properties=properties,
+ members=members,
)
def constraint(self) -> Expr:
@@ -322,9 +341,7 @@ class MidasParser(Parser):
"""
expr: Expr = self.primary()
while self.match(TokenType.DOT):
- name: Token = self.consume(
- TokenType.IDENTIFIER, "Expected property name after '.'"
- )
+ name: Token = self.consume_identifier("Expected property name after '.'")
location: Location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr
@@ -348,7 +365,7 @@ class MidasParser(Parser):
if self.match(TokenType.NUMBER):
return LiteralExpr(location=token.get_location(), value=token.value)
- if self.match(TokenType.IDENTIFIER):
+ if self.match_identifier():
return VariableExpr(location=token.get_location(), name=token)
if self.match(TokenType.UNDERSCORE):
@@ -361,64 +378,70 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression")
- def property_stmt(self) -> PropertyStmt:
- """Parse a property statement
+ def consume_identifier(self, message: str = "Expected identifier") -> Token:
+ if not self.match_identifier():
+ raise self.error(self.peek(), message)
+ return self.previous()
- A type property statement is written `name: Type` or `name: Type where Condition`
+ def match_identifier(self) -> bool:
+ return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
+
+ def check_identifier(self) -> bool:
+ for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
+ if self.check(tt):
+ return True
+ return False
+
+ def member_stmt(self) -> MemberStmt:
+ """Parse a member statement
+
+ A type member statement is written `prop name: Type` or `def name: Type`
Returns:
- PropertyStmt: the parsed property statement
+ MemberStmt: the parsed member statement
"""
- name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
- self.consume(TokenType.COLON, "Expected ':' after property name")
+ kind: MemberKind
+ if self.match(TokenType.PROP):
+ kind = MemberKind.PROPERTY
+ elif self.match(TokenType.DEF):
+ kind = MemberKind.METHOD
+ else:
+ raise self.error(self.peek(), "Expected 'prop' or 'def'")
+
+ name: Token = self.consume_identifier("Expected member name")
+ self.consume(TokenType.COLON, "Expected ':' after member name")
+
type: Type = self.type_expr()
- return PropertyStmt(
+ return MemberStmt(
location=name.location_to(self.previous()),
name=name,
type=type,
+ kind=kind,
)
def extend_declaration(self) -> ExtendStmt:
"""Parse an extension definition
- An extension is written `extend Type { operations }`
+ An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
Returns:
ExtendStmt: the parsed extension statement
"""
keyword: Token = self.previous()
- type: Type = self.type_expr()
+ name: Token = self.consume_identifier("Expected type name")
+ params: list[TypeParam] = self.type_params()
+
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
- operations: list[OpStmt] = []
+ members: list[MemberStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
- operations.append(self.op_declaration())
+ members.append(self.member_stmt())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
location: Location = keyword.location_to(self.previous())
- return ExtendStmt(location=location, type=type, operations=operations)
-
- def op_declaration(self) -> OpStmt:
- """Parse an operation definition
-
- An operation is written `op name(Type) -> Type`
-
- Returns:
- OpStmt: the parsed operation statement
- """
- keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
-
- name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
- self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
- operand: Type = self.type_expr()
- self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
-
- self.consume(TokenType.ARROW, "Expected '->' before result type")
- result: Type = self.type_expr()
-
- return OpStmt(
- location=keyword.location_to(self.previous()),
+ return ExtendStmt(
+ location=location,
name=name,
- operand=operand,
- result=result,
+ params=params,
+ members=members,
)
def predicate_declaration(self) -> PredicateStmt:
@@ -430,9 +453,9 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
- name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
+ name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
- subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
+ subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
@@ -445,3 +468,72 @@ class MidasParser(Parser):
type=type,
condition=condition,
)
+
+ def function(self) -> FunctionType:
+ l_paren: Token = self.consume(
+ TokenType.LEFT_PAREN, "Expected '(' before function parameters"
+ )
+ pos_args: list[FunctionType.Argument] = []
+ args: list[FunctionType.Argument] = []
+ kw_args: list[FunctionType.Argument] = []
+
+ args_first_tokens: list[Token] = []
+
+ section: int = 0
+ while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
+ match section:
+ case 0 if self.match(TokenType.SLASH):
+ pos_args = args
+ args = []
+ args_first_tokens = []
+ section = 1
+ case 0 | 1 if self.match(TokenType.STAR):
+ section = 2
+ case _:
+ # Record first token of mixed argument for errors if unnamed
+ if section != 2:
+ args_first_tokens.append(self.peek())
+
+ name: Optional[Token] = None
+ if section == 2:
+ name = self.consume_identifier("Expected keyword argument name")
+ self.consume(
+ TokenType.COLON, "Expected ':' after argument name"
+ )
+ elif self.check_identifier() and self.check_next(TokenType.COLON):
+ name = self.advance()
+ self.advance()
+
+ type: Type = self.type_expr()
+ optional: bool = self.match(TokenType.QMARK)
+ arg = FunctionType.Argument(
+ location=None,
+ name=name,
+ type=type,
+ required=not optional,
+ )
+ if section == 2:
+ kw_args.append(arg)
+ else:
+ args.append(arg)
+
+ if not self.match(TokenType.COMMA):
+ break
+
+ for arg, token in zip(args, args_first_tokens):
+ if arg.name is None:
+ # Not raised because we can keep parsing
+ self.error(token, "Unnamed mixed argument")
+
+ self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
+
+ self.consume(TokenType.ARROW, "Expected '->' before result type")
+ result: Type = self.type_expr()
+
+ return FunctionType(
+ location=l_paren.location_to(self.previous()),
+ pos_args=pos_args,
+ args=args,
+ kw_args=kw_args,
+ returns=result,
+ )
diff --git a/midas/parser/python.py b/midas/parser/python.py
index 79011bc..a0726da 100644
--- a/midas/parser/python.py
+++ b/midas/parser/python.py
@@ -17,11 +17,14 @@ from midas.ast.python import (
Function,
GetExpr,
IfStmt,
+ ListExpr,
LiteralExpr,
LogicalExpr,
MidasType,
ReturnStmt,
+ SliceExpr,
Stmt,
+ SubscriptExpr,
TernaryExpr,
TypeAssign,
UnaryExpr,
@@ -416,6 +419,27 @@ class PythonParser:
case ast.Name(id=name):
return VariableExpr(location=location, name=name)
+ case ast.List(elts=items):
+ return ListExpr(
+ location=location,
+ items=[self.parse_expr(item) for item in items],
+ )
+
+ case ast.Subscript(value=value, slice=index):
+ return SubscriptExpr(
+ location=location,
+ object=self.parse_expr(value),
+ index=self.parse_expr(index),
+ )
+
+ case ast.Slice(lower=lower, upper=upper, step=step):
+ return SliceExpr(
+ location=location,
+ lower=self.parse_expr(lower) if lower is not None else None,
+ upper=self.parse_expr(upper) if upper is not None else None,
+ step=self.parse_expr(step) if step is not None else None,
+ )
+
case _:
raise UnsupportedSyntaxError(node)
diff --git a/midas/resolver/builtin.py b/midas/resolver/builtin.py
deleted file mode 100644
index 04bc6e3..0000000
--- a/midas/resolver/builtin.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-from midas.checker.types import BaseType, Type, UnitType
-
-if TYPE_CHECKING:
- from midas.resolver.midas import MidasResolver
-
-
-def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
- ctx.define_operation(
- left=t1,
- operator=operator,
- right=t2,
- result=t3,
- )
-
-
-def basic_op(ctx: MidasResolver, type: Type, op: str):
- ctx.define_operation(
- left=type,
- operator=op,
- right=type,
- result=type,
- )
-
-
-def define_builtins(ctx: MidasResolver):
- """Define builtin types and operations"""
- unit = ctx.define_type("None", UnitType())
- bool = ctx.define_type("bool", BaseType(name="bool"))
- int = ctx.define_type("int", BaseType(name="int"))
- float = ctx.define_type("float", BaseType(name="float"))
- str = ctx.define_type("str", BaseType(name="str"))
-
- basic_op(ctx, int, "__add__") # int + int = int
- basic_op(ctx, int, "__sub__") # int - int = int
- basic_op(ctx, int, "__mul__") # int * int = int
- basic_op(ctx, int, "__pow__") # int ** int = int
- basic_op(ctx, int, "__mod__") # int % int = int
- basic_op(ctx, int, "__and__") # int & int = int
- basic_op(ctx, int, "__or__") # int | int = int
- basic_op(ctx, int, "__xor__") # int ^ int = int
- op(ctx, int, "__lt__", int, bool) # int < int = bool
- op(ctx, int, "__gt__", int, bool) # int > int = bool
- op(ctx, int, "__le__", int, bool) # int <= int = bool
- op(ctx, int, "__ge__", int, bool) # int >= int = bool
- op(ctx, int, "__eq__", int, bool) # int == int = bool
- basic_op(ctx, float, "__add__") # float + float = float
- basic_op(ctx, float, "__sub__") # float - float = float
- basic_op(ctx, float, "__mul__") # float * float = float
- basic_op(ctx, float, "__truediv__") # float / float = float
- op(ctx, float, "__lt__", float, bool) # float < float = bool
- op(ctx, float, "__gt__", float, bool) # float > float = bool
- op(ctx, float, "__le__", float, bool) # float <= float = bool
- op(ctx, float, "__ge__", float, bool) # float >= float = bool
- op(ctx, float, "__eq__", float, bool) # float == float = bool
- basic_op(ctx, str, "__add__") # str + str = str
- op(ctx, str, "__eq__", str, bool) # str == str = bool
-
- op(ctx, int, "__lt__", float, bool) # int < float = bool
- op(ctx, int, "__gt__", float, bool) # int > float = bool
- op(ctx, int, "__le__", float, bool) # int <= float = bool
- op(ctx, int, "__ge__", float, bool) # int >= float = bool
- op(ctx, int, "__eq__", float, bool) # int == float = bool
-
- op(ctx, float, "__lt__", int, bool) # float < int = bool
- op(ctx, float, "__gt__", int, bool) # float > int = bool
- op(ctx, float, "__le__", int, bool) # float <= int = bool
- op(ctx, float, "__ge__", int, bool) # float >= int = bool
- op(ctx, float, "__eq__", int, bool) # float == int = bool
diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py
deleted file mode 100644
index 468f59a..0000000
--- a/midas/resolver/midas.py
+++ /dev/null
@@ -1,186 +0,0 @@
-from typing import Optional
-
-import midas.ast.midas as m
-from midas.checker.types import (
- AliasType,
- ComplexType,
- Operation,
- Type,
- UnknownType,
-)
-from midas.resolver.builtin import define_builtins
-
-
-class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
- """A resolver which evaluates Midas type definitions and build a registry"""
-
- def __init__(self) -> None:
- self._types: dict[str, Type] = {}
- self._operations: dict[Operation.CallSignature, Type] = {}
-
- define_builtins(self)
-
- def get_type(self, name: str) -> Type:
- """Get a type from its name
-
- Args:
- name (str): the name of the type
-
- Raises:
- NameError: if the type is not defined
-
- Returns:
- Type: the type
- """
- type: Optional[Type] = self._types.get(name)
- if type is None:
- raise NameError(f"Undefined type {name}")
- return type
-
- def get_operation_result(
- self, left: Type, operator: str, right: Type
- ) -> Optional[Type]:
- """Get the resulting type of an operation
-
- Args:
- left (Type): the type of the left operand
- operator (str): the operation name
- right (Type): the type of the right operand
-
- Returns:
- Optional[Type]: the result type, or None if no matching operation was found
- """
- signature: Operation.CallSignature = Operation.CallSignature(
- left=left,
- method=operator,
- right=right,
- )
- result: Optional[Type] = self._operations.get(signature)
- return result
-
- def get_operations_by_name(self, name: str) -> list[Operation]:
- operations: list[Operation] = []
- for signature, result in self._operations.items():
- if signature.method == name:
- operations.append(
- Operation(
- signature=signature,
- result=result,
- )
- )
- return operations
-
- def define_type(self, name: str, type: Type) -> Type:
- """Define a type in the registry
-
- Args:
- name (str): the name of the type
- type (Type): the type to define
-
- Raises:
- ValueError: if a type is already defined with that name
-
- Returns:
- Type: the defined type
- """
- if name in self._types:
- raise ValueError(f"Type {name} already defined")
- self._types[name] = type
- return type
-
- def define_operation(self, left: Type, operator: str, right: Type, result: Type):
- """Define an operation in the registry
-
- Args:
- left (Type): the type of the left operand
- operator (str): the operation name
- right (Type): the type of the right operand
- result (Type): the result type
-
- Raises:
- ValueError: if an operation is already defined with these operands and name
- """
- signature: Operation.CallSignature = Operation.CallSignature(
- left=left,
- method=operator,
- right=right,
- )
- if signature in self._operations:
- raise ValueError(
- f"Operation {operator} already defined between {left} and {right}"
- )
- self._operations[signature] = result
-
- def resolve(self, stmts: list[m.Stmt]):
- """Process a sequence of statements
-
- Args:
- stmts (list[m.Stmt]): the statements
- """
- for stmt in stmts:
- stmt.accept(self)
-
- def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
- type: Type = stmt.type.accept(self)
- for param in stmt.params:
- if param.bound is not None:
- param.bound.accept(self)
- name: str = stmt.name.lexeme
- self.define_type(name, AliasType(name=name, type=type))
-
- def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
-
- def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
- base: Type = stmt.type.accept(self)
- for op in stmt.operations:
- right: Type = op.operand.accept(self)
- result: Type = op.result.accept(self)
- self.define_operation(
- left=base,
- operator=op.name.lexeme,
- right=right,
- result=result,
- )
-
- def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
-
- def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
-
- def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
-
- def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
-
- def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
-
- def visit_get_expr(self, expr: m.GetExpr) -> None: ...
-
- def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
-
- def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
- return expr.expr.accept(self)
-
- def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
-
- def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
-
- def visit_named_type(self, type: m.NamedType) -> Type:
- return self.get_type(type.name.lexeme)
-
- def visit_generic_type(self, type: m.GenericType) -> Type:
- type_: Type = type.type.accept(self)
- params: list[Type] = [param.accept(self) for param in type.params]
- # TODO
- return UnknownType()
-
- def visit_constraint_type(self, type: m.ConstraintType) -> Type:
- type_: Type = type.type.accept(self)
- type.constraint.accept(self)
- # TODO
- return UnknownType()
-
- def visit_complex_type(self, type: m.ComplexType) -> Type:
- return ComplexType(
- properties={
- prop.name.lexeme: prop.type.accept(self) for prop in type.properties
- }
- )
diff --git a/tests/cases/checker/01_simple_types.py.ref.json b/tests/cases/checker/01_simple_types.py.ref.json
index 3c4d0b9..ac24fcd 100644
--- a/tests/cases/checker/01_simple_types.py.ref.json
+++ b/tests/cases/checker/01_simple_types.py.ref.json
@@ -1,4 +1,19 @@
{
- "diagnostics": [],
+ "diagnostics": [
+ {
+ "type": "Warning",
+ "location": {
+ "start": [
+ 6,
+ 4
+ ],
+ "end": [
+ 13,
+ 5
+ ]
+ },
+ "message": "FrameType not yet supported"
+ }
+ ],
"judgments": []
}
\ No newline at end of file
diff --git a/tests/cases/checker/02_simple_operations.py.ref.json b/tests/cases/checker/02_simple_operations.py.ref.json
index 654af17..a2c5569 100644
--- a/tests/cases/checker/02_simple_operations.py.ref.json
+++ b/tests/cases/checker/02_simple_operations.py.ref.json
@@ -12,7 +12,21 @@
13
]
},
- "message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
+ "message": "Cannot assign str to variable 'c' of type int"
+ },
+ {
+ "type": "Error",
+ "location": {
+ "start": [
+ 9,
+ 4
+ ],
+ "end": [
+ 9,
+ 9
+ ]
+ },
+ "message": "Undefined operation __add__ between bool and bool"
}
],
"judgments": [
@@ -158,9 +172,7 @@
"name": "d"
}
},
- "type": {
- "name": "int"
- }
+ "type": {}
},
{
"location": {
diff --git a/tests/cases/checker/03_functions.py.ref.json b/tests/cases/checker/03_functions.py.ref.json
index cd0ce42..fa06642 100644
--- a/tests/cases/checker/03_functions.py.ref.json
+++ b/tests/cases/checker/03_functions.py.ref.json
@@ -236,7 +236,7 @@
13
]
},
- "message": "Wrong type for argument 'a', expected BaseType(name='int'), got BaseType(name='str')"
+ "message": "Wrong type for argument 'a', expected int, got str"
},
{
"type": "Error",
@@ -250,10 +250,23 @@
25
]
},
- "message": "Wrong type for argument 'c', expected BaseType(name='str'), got BaseType(name='bool')"
+ "message": "Wrong type for argument 'c', expected str, got bool"
}
],
"judgments": [
+ {
+ "location": {
+ "from": "L2:11",
+ "to": "L2:15"
+ },
+ "expr": {
+ "_type": "LiteralExpr",
+ "value": true
+ },
+ "type": {
+ "name": "bool"
+ }
+ },
{
"location": {
"from": "L5:5",
@@ -264,7 +277,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -314,9 +326,7 @@
"arguments": [],
"keywords": {}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
},
{
"location": {
@@ -328,7 +338,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -396,9 +405,7 @@
],
"keywords": {}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
},
{
"location": {
@@ -410,7 +417,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -495,9 +501,7 @@
],
"keywords": {}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
},
{
"location": {
@@ -509,7 +513,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -595,9 +598,7 @@
}
}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
},
{
"location": {
@@ -609,7 +610,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -711,9 +711,7 @@
],
"keywords": {}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
},
{
"location": {
@@ -725,7 +723,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -828,9 +825,7 @@
}
}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
},
{
"location": {
@@ -842,7 +837,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -910,9 +904,7 @@
}
}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
},
{
"location": {
@@ -924,7 +916,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -992,9 +983,7 @@
}
}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
},
{
"location": {
@@ -1006,7 +995,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -1123,7 +1111,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -1240,7 +1227,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -1357,7 +1343,6 @@
"name": "foo"
},
"type": {
- "name": "foo",
"pos_args": [
{
"pos": 0,
@@ -1460,9 +1445,7 @@
}
}
},
- "type": {
- "name": "bool"
- }
+ "type": {}
}
]
}
\ No newline at end of file
diff --git a/tests/cases/checker/04_custom_types.midas b/tests/cases/checker/04_custom_types.midas
index 6a1a6a2..ff4edb1 100644
--- a/tests/cases/checker/04_custom_types.midas
+++ b/tests/cases/checker/04_custom_types.midas
@@ -3,12 +3,12 @@ type Second = float
type MeterPerSecond = float
extend Meter {
- op __add__(Meter) -> Meter
- op __sub__(Meter) -> Meter
- op __truediv__(Second) -> MeterPerSecond
+ def __add__: fn(Meter, /) -> Meter
+ def __sub__: fn(Meter, /) -> Meter
+ def __truediv__: fn(Second, /) -> MeterPerSecond
}
extend Second {
- op __add__(Second) -> Second
- op __sub__(Second) -> Second
+ def __add__: fn(Second, /) -> Second
+ def __sub__: fn(Second, /) -> Second
}
diff --git a/tests/cases/checker/05_control_flow.py.ref.json b/tests/cases/checker/05_control_flow.py.ref.json
index 8f031f2..be86030 100644
--- a/tests/cases/checker/05_control_flow.py.ref.json
+++ b/tests/cases/checker/05_control_flow.py.ref.json
@@ -70,6 +70,27 @@
"name": "int"
}
},
+ {
+ "location": {
+ "from": "L2:11",
+ "to": "L2:16"
+ },
+ "expr": {
+ "_type": "BinaryExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "a"
+ },
+ "operator": "+",
+ "right": {
+ "_type": "VariableExpr",
+ "name": "b"
+ }
+ },
+ "type": {
+ "name": "int"
+ }
+ },
{
"location": {
"from": "L5:7",
@@ -96,6 +117,27 @@
"name": "int"
}
},
+ {
+ "location": {
+ "from": "L5:7",
+ "to": "L5:12"
+ },
+ "expr": {
+ "_type": "CompareExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "a"
+ },
+ "operator": "<",
+ "right": {
+ "_type": "VariableExpr",
+ "name": "b"
+ }
+ },
+ "type": {
+ "name": "bool"
+ }
+ },
{
"location": {
"from": "L6:15",
@@ -122,6 +164,27 @@
"name": "int"
}
},
+ {
+ "location": {
+ "from": "L6:15",
+ "to": "L6:20"
+ },
+ "expr": {
+ "_type": "BinaryExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "b"
+ },
+ "operator": "-",
+ "right": {
+ "_type": "VariableExpr",
+ "name": "a"
+ }
+ },
+ "type": {
+ "name": "int"
+ }
+ },
{
"location": {
"from": "L8:15",
@@ -148,6 +211,27 @@
"name": "int"
}
},
+ {
+ "location": {
+ "from": "L8:15",
+ "to": "L8:20"
+ },
+ "expr": {
+ "_type": "BinaryExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "a"
+ },
+ "operator": "-",
+ "right": {
+ "_type": "VariableExpr",
+ "name": "b"
+ }
+ },
+ "type": {
+ "name": "int"
+ }
+ },
{
"location": {
"from": "L15:7",
@@ -174,6 +258,27 @@
"name": "int"
}
},
+ {
+ "location": {
+ "from": "L15:7",
+ "to": "L15:13"
+ },
+ "expr": {
+ "_type": "CompareExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "a"
+ },
+ "operator": ">",
+ "right": {
+ "_type": "LiteralExpr",
+ "value": 10
+ }
+ },
+ "type": {
+ "name": "bool"
+ }
+ },
{
"location": {
"from": "L16:15",
@@ -200,6 +305,40 @@
"name": "int"
}
},
+ {
+ "location": {
+ "from": "L16:15",
+ "to": "L16:21"
+ },
+ "expr": {
+ "_type": "BinaryExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "a"
+ },
+ "operator": "-",
+ "right": {
+ "_type": "LiteralExpr",
+ "value": 10
+ }
+ },
+ "type": {
+ "name": "int"
+ }
+ },
+ {
+ "location": {
+ "from": "L18:15",
+ "to": "L18:16"
+ },
+ "expr": {
+ "_type": "VariableExpr",
+ "name": "a"
+ },
+ "type": {
+ "name": "int"
+ }
+ },
{
"location": {
"from": "L22:7",
@@ -226,6 +365,27 @@
"name": "int"
}
},
+ {
+ "location": {
+ "from": "L22:7",
+ "to": "L22:12"
+ },
+ "expr": {
+ "_type": "CompareExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "a"
+ },
+ "operator": "<",
+ "right": {
+ "_type": "VariableExpr",
+ "name": "b"
+ }
+ },
+ "type": {
+ "name": "bool"
+ }
+ },
{
"location": {
"from": "L23:15",
@@ -251,6 +411,40 @@
"type": {
"name": "int"
}
+ },
+ {
+ "location": {
+ "from": "L23:15",
+ "to": "L23:20"
+ },
+ "expr": {
+ "_type": "BinaryExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "b"
+ },
+ "operator": "-",
+ "right": {
+ "_type": "VariableExpr",
+ "name": "a"
+ }
+ },
+ "type": {
+ "name": "int"
+ }
+ },
+ {
+ "location": {
+ "from": "L25:15",
+ "to": "L25:21"
+ },
+ "expr": {
+ "_type": "LiteralExpr",
+ "value": "oops"
+ },
+ "type": {
+ "name": "str"
+ }
}
]
}
\ No newline at end of file
diff --git a/tests/cases/checker/06_subtyping.py b/tests/cases/checker/06_subtyping.py
index c334ab8..7ab9dd7 100644
--- a/tests/cases/checker/06_subtyping.py
+++ b/tests/cases/checker/06_subtyping.py
@@ -9,4 +9,4 @@ def maximum(a: float, b: float):
v3 = maximum(v1, v2)
-v3 = v1 + v2
+v3 = v2 + v1
diff --git a/tests/cases/checker/06_subtyping.py.ref.json b/tests/cases/checker/06_subtyping.py.ref.json
index 689402e..3435f45 100644
--- a/tests/cases/checker/06_subtyping.py.ref.json
+++ b/tests/cases/checker/06_subtyping.py.ref.json
@@ -53,6 +53,53 @@
"name": "float"
}
},
+ {
+ "location": {
+ "from": "L6:7",
+ "to": "L6:12"
+ },
+ "expr": {
+ "_type": "CompareExpr",
+ "left": {
+ "_type": "VariableExpr",
+ "name": "b"
+ },
+ "operator": ">",
+ "right": {
+ "_type": "VariableExpr",
+ "name": "a"
+ }
+ },
+ "type": {
+ "name": "bool"
+ }
+ },
+ {
+ "location": {
+ "from": "L7:15",
+ "to": "L7:16"
+ },
+ "expr": {
+ "_type": "VariableExpr",
+ "name": "b"
+ },
+ "type": {
+ "name": "float"
+ }
+ },
+ {
+ "location": {
+ "from": "L8:11",
+ "to": "L8:12"
+ },
+ "expr": {
+ "_type": "VariableExpr",
+ "name": "a"
+ },
+ "type": {
+ "name": "float"
+ }
+ },
{
"location": {
"from": "L11:5",
@@ -63,7 +110,6 @@
"name": "maximum"
},
"type": {
- "name": "maximum",
"pos_args": [],
"args": [
{
@@ -149,10 +195,10 @@
},
"expr": {
"_type": "VariableExpr",
- "name": "v1"
+ "name": "v2"
},
"type": {
- "name": "int"
+ "name": "float"
}
},
{
@@ -162,10 +208,10 @@
},
"expr": {
"_type": "VariableExpr",
- "name": "v2"
+ "name": "v1"
},
"type": {
- "name": "float"
+ "name": "int"
}
},
{
@@ -177,12 +223,12 @@
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
- "name": "v1"
+ "name": "v2"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
- "name": "v2"
+ "name": "v1"
}
},
"type": {
diff --git a/tests/cases/midas-parser/01_simple_types.midas b/tests/cases/midas-parser/01_simple_types.midas
index 6446790..f0df3e2 100644
--- a/tests/cases/midas-parser/01_simple_types.midas
+++ b/tests/cases/midas-parser/01_simple_types.midas
@@ -10,8 +10,8 @@ type Difference[T] = T
// Complex custom type, containing two values accessible through properties
type GeoLocation = {
- lat: Latitude
- lon: Longitude
+ prop lat: Latitude
+ prop lon: Longitude
}
// Define operations on our custom type
@@ -19,23 +19,23 @@ extend GeoLocation {
// This type is compatible with the `-` operation with another GeoLocation
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
// in a Difference of GeoLocations
- op __sub__(GeoLocation) -> Difference[GeoLocation]
+ def __sub__: fn(GeoLocation, /) -> Difference[GeoLocation]
}
// For complex generics, you need to specify how the genericity the properties
// are handled
type Difference[GeoLocation] = {
- lat: Difference[Latitude]
- lon: Difference[Longitude]
+ prop lat: Difference[Latitude]
+ prop lon: Difference[Longitude]
}
// Simple operation defined on our custom types
extend Latitude {
- op __sub__(Latitude) -> Difference[Latitude]
+ def __sub__: fn(Latitude, /) -> Difference[Latitude]
}
extend Longitude {
- op __sub__(Longitude) -> Difference[Longitude]
+ def __sub__: fn(Longitude, /) -> Difference[Longitude]
}
// Predefined custom predicates that can be referenced in other definitions
@@ -45,13 +45,13 @@ predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person = {
- name: str
+ prop name: str
// Property with an inline constraint
- age: Optional[int where (0 <= _ < 150)]
+ prop age: Optional[int where (0 <= _ < 150)]
// Property referencing a predicate
- height: float where StrictlyPositive
+ prop height: float where StrictlyPositive
- home: GeoLocation
+ prop home: GeoLocation
}
diff --git a/tests/cases/midas-parser/01_simple_types.midas.ref.json b/tests/cases/midas-parser/01_simple_types.midas.ref.json
index 55b4813..be45687 100644
--- a/tests/cases/midas-parser/01_simple_types.midas.ref.json
+++ b/tests/cases/midas-parser/01_simple_types.midas.ref.json
@@ -511,17 +511,11 @@
"column": 1
},
{
- "type": "IDENTIFIER",
- "lexeme": "lat",
+ "type": "PROP",
+ "lexeme": "prop",
"line": 13,
"column": 5
},
- {
- "type": "COLON",
- "lexeme": ":",
- "line": 13,
- "column": 8
- },
{
"type": "WHITESPACE",
"lexeme": " ",
@@ -530,15 +524,33 @@
},
{
"type": "IDENTIFIER",
- "lexeme": "Latitude",
+ "lexeme": "lat",
"line": 13,
"column": 10
},
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 13,
+ "column": 13
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 13,
+ "column": 14
+ },
+ {
+ "type": "IDENTIFIER",
+ "lexeme": "Latitude",
+ "line": 13,
+ "column": 15
+ },
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 13,
- "column": 18
+ "column": 23
},
{
"type": "WHITESPACE",
@@ -547,17 +559,11 @@
"column": 1
},
{
- "type": "IDENTIFIER",
- "lexeme": "lon",
+ "type": "PROP",
+ "lexeme": "prop",
"line": 14,
"column": 5
},
- {
- "type": "COLON",
- "lexeme": ":",
- "line": 14,
- "column": 8
- },
{
"type": "WHITESPACE",
"lexeme": " ",
@@ -566,15 +572,33 @@
},
{
"type": "IDENTIFIER",
- "lexeme": "Longitude",
+ "lexeme": "lon",
"line": 14,
"column": 10
},
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 14,
+ "column": 13
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 14,
+ "column": 14
+ },
+ {
+ "type": "IDENTIFIER",
+ "lexeme": "Longitude",
+ "line": 14,
+ "column": 15
+ },
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 14,
- "column": 19
+ "column": 24
},
{
"type": "RIGHT_BRACE",
@@ -703,8 +727,8 @@
"column": 1
},
{
- "type": "OP",
- "lexeme": "op",
+ "type": "DEF",
+ "lexeme": "def",
"line": 22,
"column": 5
},
@@ -712,79 +736,115 @@
"type": "WHITESPACE",
"lexeme": " ",
"line": 22,
- "column": 7
+ "column": 8
},
{
"type": "IDENTIFIER",
"lexeme": "__sub__",
"line": 22,
- "column": 8
+ "column": 9
+ },
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 22,
+ "column": 16
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 22,
+ "column": 17
+ },
+ {
+ "type": "FUNC",
+ "lexeme": "fn",
+ "line": 22,
+ "column": 18
},
{
"type": "LEFT_PAREN",
"lexeme": "(",
"line": 22,
- "column": 15
+ "column": 20
},
{
"type": "IDENTIFIER",
"lexeme": "GeoLocation",
"line": 22,
- "column": 16
+ "column": 21
+ },
+ {
+ "type": "COMMA",
+ "lexeme": ",",
+ "line": 22,
+ "column": 32
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 22,
+ "column": 33
+ },
+ {
+ "type": "SLASH",
+ "lexeme": "/",
+ "line": 22,
+ "column": 34
},
{
"type": "RIGHT_PAREN",
"lexeme": ")",
"line": 22,
- "column": 27
+ "column": 35
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 22,
- "column": 28
+ "column": 36
},
{
"type": "ARROW",
"lexeme": "->",
"line": 22,
- "column": 29
+ "column": 37
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 22,
- "column": 31
+ "column": 39
},
{
"type": "IDENTIFIER",
"lexeme": "Difference",
"line": 22,
- "column": 32
+ "column": 40
},
{
"type": "LEFT_BRACKET",
"lexeme": "[",
"line": 22,
- "column": 42
+ "column": 50
},
{
"type": "IDENTIFIER",
"lexeme": "GeoLocation",
"line": 22,
- "column": 43
+ "column": 51
},
{
"type": "RIGHT_BRACKET",
"lexeme": "]",
"line": 22,
- "column": 54
+ "column": 62
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 22,
- "column": 55
+ "column": 63
},
{
"type": "RIGHT_BRACE",
@@ -901,17 +961,11 @@
"column": 1
},
{
- "type": "IDENTIFIER",
- "lexeme": "lat",
+ "type": "PROP",
+ "lexeme": "prop",
"line": 28,
"column": 5
},
- {
- "type": "COLON",
- "lexeme": ":",
- "line": 28,
- "column": 8
- },
{
"type": "WHITESPACE",
"lexeme": " ",
@@ -920,33 +974,51 @@
},
{
"type": "IDENTIFIER",
- "lexeme": "Difference",
+ "lexeme": "lat",
"line": 28,
"column": 10
},
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 28,
+ "column": 13
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 28,
+ "column": 14
+ },
+ {
+ "type": "IDENTIFIER",
+ "lexeme": "Difference",
+ "line": 28,
+ "column": 15
+ },
{
"type": "LEFT_BRACKET",
"lexeme": "[",
"line": 28,
- "column": 20
+ "column": 25
},
{
"type": "IDENTIFIER",
"lexeme": "Latitude",
"line": 28,
- "column": 21
+ "column": 26
},
{
"type": "RIGHT_BRACKET",
"lexeme": "]",
"line": 28,
- "column": 29
+ "column": 34
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 28,
- "column": 30
+ "column": 35
},
{
"type": "WHITESPACE",
@@ -955,17 +1027,11 @@
"column": 1
},
{
- "type": "IDENTIFIER",
- "lexeme": "lon",
+ "type": "PROP",
+ "lexeme": "prop",
"line": 29,
"column": 5
},
- {
- "type": "COLON",
- "lexeme": ":",
- "line": 29,
- "column": 8
- },
{
"type": "WHITESPACE",
"lexeme": " ",
@@ -974,33 +1040,51 @@
},
{
"type": "IDENTIFIER",
- "lexeme": "Difference",
+ "lexeme": "lon",
"line": 29,
"column": 10
},
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 29,
+ "column": 13
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 29,
+ "column": 14
+ },
+ {
+ "type": "IDENTIFIER",
+ "lexeme": "Difference",
+ "line": 29,
+ "column": 15
+ },
{
"type": "LEFT_BRACKET",
"lexeme": "[",
"line": 29,
- "column": 20
+ "column": 25
},
{
"type": "IDENTIFIER",
"lexeme": "Longitude",
"line": 29,
- "column": 21
+ "column": 26
},
{
"type": "RIGHT_BRACKET",
"lexeme": "]",
"line": 29,
- "column": 30
+ "column": 35
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 29,
- "column": 31
+ "column": 36
},
{
"type": "RIGHT_BRACE",
@@ -1075,8 +1159,8 @@
"column": 1
},
{
- "type": "OP",
- "lexeme": "op",
+ "type": "DEF",
+ "lexeme": "def",
"line": 34,
"column": 5
},
@@ -1084,79 +1168,115 @@
"type": "WHITESPACE",
"lexeme": " ",
"line": 34,
- "column": 7
+ "column": 8
},
{
"type": "IDENTIFIER",
"lexeme": "__sub__",
"line": 34,
- "column": 8
+ "column": 9
+ },
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 34,
+ "column": 16
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 34,
+ "column": 17
+ },
+ {
+ "type": "FUNC",
+ "lexeme": "fn",
+ "line": 34,
+ "column": 18
},
{
"type": "LEFT_PAREN",
"lexeme": "(",
"line": 34,
- "column": 15
+ "column": 20
},
{
"type": "IDENTIFIER",
"lexeme": "Latitude",
"line": 34,
- "column": 16
+ "column": 21
+ },
+ {
+ "type": "COMMA",
+ "lexeme": ",",
+ "line": 34,
+ "column": 29
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 34,
+ "column": 30
+ },
+ {
+ "type": "SLASH",
+ "lexeme": "/",
+ "line": 34,
+ "column": 31
},
{
"type": "RIGHT_PAREN",
"lexeme": ")",
"line": 34,
- "column": 24
+ "column": 32
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 34,
- "column": 25
+ "column": 33
},
{
"type": "ARROW",
"lexeme": "->",
"line": 34,
- "column": 26
+ "column": 34
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 34,
- "column": 28
+ "column": 36
},
{
"type": "IDENTIFIER",
"lexeme": "Difference",
"line": 34,
- "column": 29
+ "column": 37
},
{
"type": "LEFT_BRACKET",
"lexeme": "[",
"line": 34,
- "column": 39
+ "column": 47
},
{
"type": "IDENTIFIER",
"lexeme": "Latitude",
"line": 34,
- "column": 40
+ "column": 48
},
{
"type": "RIGHT_BRACKET",
"lexeme": "]",
"line": 34,
- "column": 48
+ "column": 56
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 34,
- "column": 49
+ "column": 57
},
{
"type": "RIGHT_BRACE",
@@ -1219,8 +1339,8 @@
"column": 1
},
{
- "type": "OP",
- "lexeme": "op",
+ "type": "DEF",
+ "lexeme": "def",
"line": 38,
"column": 5
},
@@ -1228,79 +1348,115 @@
"type": "WHITESPACE",
"lexeme": " ",
"line": 38,
- "column": 7
+ "column": 8
},
{
"type": "IDENTIFIER",
"lexeme": "__sub__",
"line": 38,
- "column": 8
+ "column": 9
+ },
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 38,
+ "column": 16
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 38,
+ "column": 17
+ },
+ {
+ "type": "FUNC",
+ "lexeme": "fn",
+ "line": 38,
+ "column": 18
},
{
"type": "LEFT_PAREN",
"lexeme": "(",
"line": 38,
- "column": 15
+ "column": 20
},
{
"type": "IDENTIFIER",
"lexeme": "Longitude",
"line": 38,
- "column": 16
+ "column": 21
+ },
+ {
+ "type": "COMMA",
+ "lexeme": ",",
+ "line": 38,
+ "column": 30
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 38,
+ "column": 31
+ },
+ {
+ "type": "SLASH",
+ "lexeme": "/",
+ "line": 38,
+ "column": 32
},
{
"type": "RIGHT_PAREN",
"lexeme": ")",
"line": 38,
- "column": 25
+ "column": 33
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 38,
- "column": 26
+ "column": 34
},
{
"type": "ARROW",
"lexeme": "->",
"line": 38,
- "column": 27
+ "column": 35
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 38,
- "column": 29
+ "column": 37
},
{
"type": "IDENTIFIER",
"lexeme": "Difference",
"line": 38,
- "column": 30
+ "column": 38
},
{
"type": "LEFT_BRACKET",
"lexeme": "[",
"line": 38,
- "column": 40
+ "column": 48
},
{
"type": "IDENTIFIER",
"lexeme": "Longitude",
"line": 38,
- "column": 41
+ "column": 49
},
{
"type": "RIGHT_BRACKET",
"lexeme": "]",
"line": 38,
- "column": 50
+ "column": 58
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 38,
- "column": 51
+ "column": 59
},
{
"type": "RIGHT_BRACE",
@@ -1903,34 +2059,46 @@
"column": 1
},
{
- "type": "IDENTIFIER",
- "lexeme": "name",
+ "type": "PROP",
+ "lexeme": "prop",
"line": 48,
"column": 5
},
- {
- "type": "COLON",
- "lexeme": ":",
- "line": 48,
- "column": 9
- },
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 48,
+ "column": 9
+ },
+ {
+ "type": "IDENTIFIER",
+ "lexeme": "name",
+ "line": 48,
"column": 10
},
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 48,
+ "column": 14
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 48,
+ "column": 15
+ },
{
"type": "IDENTIFIER",
"lexeme": "str",
"line": 48,
- "column": 11
+ "column": 16
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 48,
- "column": 14
+ "column": 19
},
{
"type": "NEWLINE",
@@ -1963,17 +2131,11 @@
"column": 1
},
{
- "type": "IDENTIFIER",
- "lexeme": "age",
+ "type": "PROP",
+ "lexeme": "prop",
"line": 51,
"column": 5
},
- {
- "type": "COLON",
- "lexeme": ":",
- "line": 51,
- "column": 8
- },
{
"type": "WHITESPACE",
"lexeme": " ",
@@ -1982,74 +2144,68 @@
},
{
"type": "IDENTIFIER",
- "lexeme": "Optional",
+ "lexeme": "age",
"line": 51,
"column": 10
},
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 51,
+ "column": 13
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 51,
+ "column": 14
+ },
+ {
+ "type": "IDENTIFIER",
+ "lexeme": "Optional",
+ "line": 51,
+ "column": 15
+ },
{
"type": "LEFT_BRACKET",
"lexeme": "[",
"line": 51,
- "column": 18
+ "column": 23
},
{
"type": "IDENTIFIER",
"lexeme": "int",
"line": 51,
- "column": 19
+ "column": 24
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 51,
- "column": 22
+ "column": 27
},
{
"type": "WHERE",
"lexeme": "where",
"line": 51,
- "column": 23
+ "column": 28
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 51,
- "column": 28
+ "column": 33
},
{
"type": "LEFT_PAREN",
"lexeme": "(",
"line": 51,
- "column": 29
+ "column": 34
},
{
"type": "NUMBER",
"lexeme": "0",
"line": 51,
- "column": 30
- },
- {
- "type": "WHITESPACE",
- "lexeme": " ",
- "line": 51,
- "column": 31
- },
- {
- "type": "LESS_EQUAL",
- "lexeme": "<=",
- "line": 51,
- "column": 32
- },
- {
- "type": "WHITESPACE",
- "lexeme": " ",
- "line": 51,
- "column": 34
- },
- {
- "type": "UNDERSCORE",
- "lexeme": "_",
- "line": 51,
"column": 35
},
{
@@ -2059,8 +2215,8 @@
"column": 36
},
{
- "type": "LESS",
- "lexeme": "<",
+ "type": "LESS_EQUAL",
+ "lexeme": "<=",
"line": 51,
"column": 37
},
@@ -2068,31 +2224,55 @@
"type": "WHITESPACE",
"lexeme": " ",
"line": 51,
- "column": 38
+ "column": 39
+ },
+ {
+ "type": "UNDERSCORE",
+ "lexeme": "_",
+ "line": 51,
+ "column": 40
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 51,
+ "column": 41
+ },
+ {
+ "type": "LESS",
+ "lexeme": "<",
+ "line": 51,
+ "column": 42
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 51,
+ "column": 43
},
{
"type": "NUMBER",
"lexeme": "150",
"line": 51,
- "column": 39
+ "column": 44
},
{
"type": "RIGHT_PAREN",
"lexeme": ")",
"line": 51,
- "column": 42
+ "column": 47
},
{
"type": "RIGHT_BRACKET",
"lexeme": "]",
"line": 51,
- "column": 43
+ "column": 48
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 51,
- "column": 44
+ "column": 49
},
{
"type": "NEWLINE",
@@ -2124,59 +2304,71 @@
"line": 54,
"column": 1
},
+ {
+ "type": "PROP",
+ "lexeme": "prop",
+ "line": 54,
+ "column": 5
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 54,
+ "column": 9
+ },
{
"type": "IDENTIFIER",
"lexeme": "height",
"line": 54,
- "column": 5
+ "column": 10
},
{
"type": "COLON",
"lexeme": ":",
"line": 54,
- "column": 11
+ "column": 16
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 54,
- "column": 12
+ "column": 17
},
{
"type": "IDENTIFIER",
"lexeme": "float",
"line": 54,
- "column": 13
+ "column": 18
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 54,
- "column": 18
+ "column": 23
},
{
"type": "WHERE",
"lexeme": "where",
"line": 54,
- "column": 19
+ "column": 24
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 54,
- "column": 24
+ "column": 29
},
{
"type": "IDENTIFIER",
"lexeme": "StrictlyPositive",
"line": 54,
- "column": 25
+ "column": 30
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 54,
- "column": 41
+ "column": 46
},
{
"type": "NEWLINE",
@@ -2191,34 +2383,46 @@
"column": 1
},
{
- "type": "IDENTIFIER",
- "lexeme": "home",
+ "type": "PROP",
+ "lexeme": "prop",
"line": 56,
"column": 5
},
- {
- "type": "COLON",
- "lexeme": ":",
- "line": 56,
- "column": 9
- },
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 56,
+ "column": 9
+ },
+ {
+ "type": "IDENTIFIER",
+ "lexeme": "home",
+ "line": 56,
"column": 10
},
+ {
+ "type": "COLON",
+ "lexeme": ":",
+ "line": 56,
+ "column": 14
+ },
+ {
+ "type": "WHITESPACE",
+ "lexeme": " ",
+ "line": 56,
+ "column": 15
+ },
{
"type": "IDENTIFIER",
"lexeme": "GeoLocation",
"line": 56,
- "column": 11
+ "column": 16
},
{
"type": "NEWLINE",
"lexeme": "\n",
"line": 56,
- "column": 22
+ "column": 27
},
{
"type": "RIGHT_BRACE",
@@ -2345,9 +2549,10 @@
"params": [],
"type": {
"_type": "ComplexType",
- "properties": [
+ "members": [
{
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": "PROPERTY",
"name": "lat",
"type": {
"_type": "NamedType",
@@ -2355,7 +2560,8 @@
}
},
{
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": "PROPERTY",
"name": "lon",
"type": {
"_type": "NamedType",
@@ -2367,30 +2573,40 @@
},
{
"_type": "ExtendStmt",
- "type": {
- "_type": "NamedType",
- "name": "GeoLocation"
- },
- "operations": [
+ "name": "GeoLocation",
+ "params": [],
+ "members": [
{
- "_type": "OpStmt",
+ "_type": "MemberStmt",
+ "kind": "METHOD",
"name": "__sub__",
- "operand": {
- "_type": "NamedType",
- "name": "GeoLocation"
- },
- "result": {
- "_type": "GenericType",
- "type": {
- "_type": "NamedType",
- "name": "Difference"
- },
- "params": [
+ "type": {
+ "_type": "FunctionType",
+ "pos_args": [
{
- "_type": "NamedType",
- "name": "GeoLocation"
+ "name": null,
+ "type": {
+ "_type": "NamedType",
+ "name": "GeoLocation"
+ },
+ "required": true
}
- ]
+ ],
+ "args": [],
+ "kw_args": [],
+ "returns": {
+ "_type": "GenericType",
+ "type": {
+ "_type": "NamedType",
+ "name": "Difference"
+ },
+ "args": [
+ {
+ "_type": "NamedType",
+ "name": "GeoLocation"
+ }
+ ]
+ }
}
}
]
@@ -2406,9 +2622,10 @@
],
"type": {
"_type": "ComplexType",
- "properties": [
+ "members": [
{
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": "PROPERTY",
"name": "lat",
"type": {
"_type": "GenericType",
@@ -2416,7 +2633,7 @@
"_type": "NamedType",
"name": "Difference"
},
- "params": [
+ "args": [
{
"_type": "NamedType",
"name": "Latitude"
@@ -2425,7 +2642,8 @@
}
},
{
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": "PROPERTY",
"name": "lon",
"type": {
"_type": "GenericType",
@@ -2433,7 +2651,7 @@
"_type": "NamedType",
"name": "Difference"
},
- "params": [
+ "args": [
{
"_type": "NamedType",
"name": "Longitude"
@@ -2446,60 +2664,80 @@
},
{
"_type": "ExtendStmt",
- "type": {
- "_type": "NamedType",
- "name": "Latitude"
- },
- "operations": [
+ "name": "Latitude",
+ "params": [],
+ "members": [
{
- "_type": "OpStmt",
+ "_type": "MemberStmt",
+ "kind": "METHOD",
"name": "__sub__",
- "operand": {
- "_type": "NamedType",
- "name": "Latitude"
- },
- "result": {
- "_type": "GenericType",
- "type": {
- "_type": "NamedType",
- "name": "Difference"
- },
- "params": [
+ "type": {
+ "_type": "FunctionType",
+ "pos_args": [
{
- "_type": "NamedType",
- "name": "Latitude"
+ "name": null,
+ "type": {
+ "_type": "NamedType",
+ "name": "Latitude"
+ },
+ "required": true
}
- ]
+ ],
+ "args": [],
+ "kw_args": [],
+ "returns": {
+ "_type": "GenericType",
+ "type": {
+ "_type": "NamedType",
+ "name": "Difference"
+ },
+ "args": [
+ {
+ "_type": "NamedType",
+ "name": "Latitude"
+ }
+ ]
+ }
}
}
]
},
{
"_type": "ExtendStmt",
- "type": {
- "_type": "NamedType",
- "name": "Longitude"
- },
- "operations": [
+ "name": "Longitude",
+ "params": [],
+ "members": [
{
- "_type": "OpStmt",
+ "_type": "MemberStmt",
+ "kind": "METHOD",
"name": "__sub__",
- "operand": {
- "_type": "NamedType",
- "name": "Longitude"
- },
- "result": {
- "_type": "GenericType",
- "type": {
- "_type": "NamedType",
- "name": "Difference"
- },
- "params": [
+ "type": {
+ "_type": "FunctionType",
+ "pos_args": [
{
- "_type": "NamedType",
- "name": "Longitude"
+ "name": null,
+ "type": {
+ "_type": "NamedType",
+ "name": "Longitude"
+ },
+ "required": true
}
- ]
+ ],
+ "args": [],
+ "kw_args": [],
+ "returns": {
+ "_type": "GenericType",
+ "type": {
+ "_type": "NamedType",
+ "name": "Difference"
+ },
+ "args": [
+ {
+ "_type": "NamedType",
+ "name": "Longitude"
+ }
+ ]
+ }
}
}
]
@@ -2620,9 +2858,10 @@
"params": [],
"type": {
"_type": "ComplexType",
- "properties": [
+ "members": [
{
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": "PROPERTY",
"name": "name",
"type": {
"_type": "NamedType",
@@ -2630,7 +2869,8 @@
}
},
{
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": "PROPERTY",
"name": "age",
"type": {
"_type": "GenericType",
@@ -2638,7 +2878,7 @@
"_type": "NamedType",
"name": "Optional"
},
- "params": [
+ "args": [
{
"_type": "ConstraintType",
"type": {
@@ -2672,7 +2912,8 @@
}
},
{
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": "PROPERTY",
"name": "height",
"type": {
"_type": "ConstraintType",
@@ -2687,7 +2928,8 @@
}
},
{
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": "PROPERTY",
"name": "home",
"type": {
"_type": "NamedType",
diff --git a/tests/cases/python-parser/02_custom_types.py.ref.json b/tests/cases/python-parser/02_custom_types.py.ref.json
index 639610d..82c726c 100644
--- a/tests/cases/python-parser/02_custom_types.py.ref.json
+++ b/tests/cases/python-parser/02_custom_types.py.ref.json
@@ -18,6 +18,80 @@
]
}
},
+ {
+ "_type": "TypeAssign",
+ "name": "lat",
+ "type": {
+ "_type": "BaseType",
+ "base": "Column",
+ "param": {
+ "_type": "BaseType",
+ "base": "GeoLocation",
+ "param": null
+ }
+ }
+ },
+ {
+ "_type": "AssignStmt",
+ "targets": [
+ {
+ "_type": "VariableExpr",
+ "name": "lat"
+ }
+ ],
+ "value": {
+ "_type": "GetExpr",
+ "object": {
+ "_type": "SubscriptExpr",
+ "object": {
+ "_type": "VariableExpr",
+ "name": "df"
+ },
+ "index": {
+ "_type": "LiteralExpr",
+ "value": "location"
+ }
+ },
+ "name": "lat"
+ }
+ },
+ {
+ "_type": "TypeAssign",
+ "name": "lon",
+ "type": {
+ "_type": "BaseType",
+ "base": "Column",
+ "param": {
+ "_type": "BaseType",
+ "base": "GeoLocation",
+ "param": null
+ }
+ }
+ },
+ {
+ "_type": "AssignStmt",
+ "targets": [
+ {
+ "_type": "VariableExpr",
+ "name": "lon"
+ }
+ ],
+ "value": {
+ "_type": "GetExpr",
+ "object": {
+ "_type": "SubscriptExpr",
+ "object": {
+ "_type": "VariableExpr",
+ "name": "df"
+ },
+ "index": {
+ "_type": "LiteralExpr",
+ "value": "location"
+ }
+ },
+ "name": "lon"
+ }
+ },
{
"_type": "ExpressionStmt",
"expr": {
@@ -33,6 +107,64 @@
}
}
},
+ {
+ "_type": "TypeAssign",
+ "name": "lat1",
+ "type": {
+ "_type": "BaseType",
+ "base": "Latitude",
+ "param": null
+ }
+ },
+ {
+ "_type": "AssignStmt",
+ "targets": [
+ {
+ "_type": "VariableExpr",
+ "name": "lat1"
+ }
+ ],
+ "value": {
+ "_type": "SubscriptExpr",
+ "object": {
+ "_type": "VariableExpr",
+ "name": "lat"
+ },
+ "index": {
+ "_type": "LiteralExpr",
+ "value": 0
+ }
+ }
+ },
+ {
+ "_type": "TypeAssign",
+ "name": "lat2",
+ "type": {
+ "_type": "BaseType",
+ "base": "Latitude",
+ "param": null
+ }
+ },
+ {
+ "_type": "AssignStmt",
+ "targets": [
+ {
+ "_type": "VariableExpr",
+ "name": "lat2"
+ }
+ ],
+ "value": {
+ "_type": "SubscriptExpr",
+ "object": {
+ "_type": "VariableExpr",
+ "name": "lat"
+ },
+ "index": {
+ "_type": "LiteralExpr",
+ "value": 1
+ }
+ }
+ },
{
"_type": "TypeAssign",
"name": "lat_diff",
diff --git a/tests/checker.py b/tests/checker.py
index 27a94cb..3ceb34e 100644
--- a/tests/checker.py
+++ b/tests/checker.py
@@ -1,14 +1,11 @@
-import ast
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
import midas.ast.python as p
-from midas.checker.checker import Checker
+from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.checker.types import Type
-from midas.parser.python import PythonParser
-from midas.resolver.resolver import Resolver
from tests.base import Tester
from tests.serializer.python import PythonAstJsonSerializer
@@ -36,24 +33,16 @@ class CheckerTester(Tester):
if not path.is_file():
raise TypeError(f"Test '{path}' is not a file")
- types_paths: list[Path] = []
+ result: CaseResult = CaseResult()
+
+ checker = TypeChecker()
types_path: Path = path.with_suffix(".midas")
if types_path.exists():
- types_paths.append(types_path)
- source: str = path.read_text()
- tree: ast.Module = ast.parse(source, filename=path)
- parser = PythonParser()
- stmts: list[p.Stmt] = parser.parse_module(tree)
- resolver = Resolver()
- resolver.resolve(*stmts)
- result: CaseResult = CaseResult()
- checker = Checker(
- resolver.locals,
- source_path=path,
- types_paths=types_paths,
- )
+ checker.import_midas(types_path)
- diagnostics: list[Diagnostic] = checker.check(stmts)
+ checker.type_check(path)
+
+ diagnostics: list[Diagnostic] = checker.diagnostics
for diagnostic in diagnostics:
result.diagnostics.append(
{
@@ -72,7 +61,7 @@ class CheckerTester(Tester):
}
)
- judgements: list[tuple[p.Expr, Type]] = checker.judgements
+ judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements
serializer = PythonAstJsonSerializer()
for expr, type in judgements:
loc = expr.location
diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py
index 919dc66..8bffdb3 100644
--- a/tests/serializer/midas.py
+++ b/tests/serializer/midas.py
@@ -6,17 +6,19 @@ from midas.ast.midas import (
ConstraintType,
Expr,
ExtendStmt,
+ ExtensionType,
+ FunctionType,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
+ MemberStmt,
NamedType,
- OpStmt,
PredicateStmt,
- PropertyStmt,
Stmt,
Type,
+ TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
@@ -46,21 +48,20 @@ class MidasAstJsonSerializer(
return {
"_type": "TypeStmt",
"name": stmt.name.lexeme,
- "params": [
- self._serialize_type_stmt_template_param(param) for param in stmt.params
- ],
+ "params": [self._serialize_type_param(param) for param in stmt.params],
"type": stmt.type.accept(self),
}
- def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
+ def _serialize_type_param(self, param: TypeParam) -> dict:
return {
"name": param.name.lexeme,
"bound": self._serialize_optional(param.bound),
}
- def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
+ def visit_member_stmt(self, stmt: MemberStmt) -> dict:
return {
- "_type": "PropertyStmt",
+ "_type": "MemberStmt",
+ "kind": stmt.kind.name,
"name": stmt.name.lexeme,
"type": stmt.type.accept(self),
}
@@ -68,16 +69,9 @@ class MidasAstJsonSerializer(
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
return {
"_type": "ExtendStmt",
- "type": stmt.type.accept(self),
- "operations": self._serialize_list(stmt.operations),
- }
-
- def visit_op_stmt(self, stmt: OpStmt) -> dict:
- return {
- "_type": "OpStmt",
"name": stmt.name.lexeme,
- "operand": stmt.operand.accept(self),
- "result": stmt.result.accept(self),
+ "params": [self._serialize_type_param(param) for param in stmt.params],
+ "members": self._serialize_list(stmt.members),
}
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
@@ -150,7 +144,7 @@ class MidasAstJsonSerializer(
return {
"_type": "GenericType",
"type": type.type.accept(self),
- "params": self._serialize_list(type.params),
+ "args": self._serialize_list(type.args),
}
def visit_constraint_type(self, type: ConstraintType) -> dict:
@@ -163,5 +157,28 @@ class MidasAstJsonSerializer(
def visit_complex_type(self, type: ComplexType) -> dict:
return {
"_type": "ComplexType",
- "properties": self._serialize_list(type.properties),
+ "members": self._serialize_list(type.members),
+ }
+
+ def visit_function_type(self, type: FunctionType) -> dict:
+ return {
+ "_type": "FunctionType",
+ "pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
+ "args": [self._serialize_func_arg(arg) for arg in type.args],
+ "kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
+ "returns": type.returns.accept(self),
+ }
+
+ def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
+ return {
+ "name": arg.name,
+ "type": arg.type.accept(self),
+ "required": arg.required,
+ }
+
+ def visit_extension_type(self, type: ExtensionType) -> dict:
+ return {
+ "_type": "ExtensionType",
+ "base": type.base.accept(self),
+ "extension": type.extension.accept(self),
}
diff --git a/tests/serializer/python.py b/tests/serializer/python.py
index bab3f8c..b090eea 100644
--- a/tests/serializer/python.py
+++ b/tests/serializer/python.py
@@ -16,11 +16,14 @@ from midas.ast.python import (
Function,
GetExpr,
IfStmt,
+ ListExpr,
LiteralExpr,
LogicalExpr,
MidasType,
ReturnStmt,
+ SliceExpr,
Stmt,
+ SubscriptExpr,
TernaryExpr,
TypeAssign,
UnaryExpr,
@@ -245,3 +248,24 @@ class PythonAstJsonSerializer(
"if_true": expr.if_true.accept(self),
"if_false": expr.if_false.accept(self),
}
+
+ def visit_list_expr(self, expr: ListExpr) -> dict:
+ return {
+ "_type": "ListExpr",
+ "items": [item.accept(self) for item in expr.items],
+ }
+
+ def visit_subscript_expr(self, expr: SubscriptExpr) -> dict:
+ return {
+ "_type": "SubscriptExpr",
+ "object": expr.object.accept(self),
+ "index": expr.index.accept(self),
+ }
+
+ def visit_slice_expr(self, expr: SliceExpr) -> dict:
+ return {
+ "_type": "SliceExpr",
+ "lower": self._serialize_optional(expr.lower),
+ "upper": self._serialize_optional(expr.upper),
+ "step": self._serialize_optional(expr.step),
+ }