Compare commits
35 Commits
main
...
feat/gener
| Author | SHA1 | Date | |
|---|---|---|---|
|
c6ead886ec
|
|||
|
9de03bf2b5
|
|||
|
a26b9293be
|
|||
|
efa5454776
|
|||
|
b8bb8190c4
|
|||
|
a4f5db7ece
|
|||
|
fc67f01f34
|
|||
|
0a748a36a3
|
|||
|
89fdd1b47e
|
|||
|
0cde53ac6e
|
|||
|
f3ec3606c2
|
|||
|
67ec029529
|
|||
|
e2aef7a811
|
|||
|
86ba4e658a
|
|||
|
7eccf59558
|
|||
|
9dd7801d2d
|
|||
|
154cb8b314
|
|||
|
c64ab434b5
|
|||
|
25e6410546
|
|||
|
8a22acc17c
|
|||
|
e0179bc442
|
|||
|
e665d03533
|
|||
|
b8cb2b4273
|
|||
|
d278dc5f5b
|
|||
|
59e73f0fd9
|
|||
|
3e0dc60283
|
|||
|
c24eb5125e
|
|||
|
25bd895dde
|
|||
|
bccd75317e
|
|||
|
f0e3f7574f
|
|||
|
5d44081847
|
|||
|
2a2bb0aec7
|
|||
|
67c40a3909
|
|||
|
1c30188122
|
|||
|
82a0f13242
|
11
examples/01_simple_type_checking/04_complex_types.midas
Normal file
11
examples/01_simple_type_checking/04_complex_types.midas
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
type Meter = float
|
||||||
|
|
||||||
|
extend Meter {
|
||||||
|
op __add__(Meter) -> Meter
|
||||||
|
op __sub__(Meter) -> Meter
|
||||||
|
}
|
||||||
|
|
||||||
|
type Coordinate = {
|
||||||
|
x: Meter
|
||||||
|
y: Meter
|
||||||
|
}
|
||||||
14
examples/01_simple_type_checking/04_complex_types.py
Normal file
14
examples/01_simple_type_checking/04_complex_types.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable [F821]
|
||||||
|
p1: Coordinate
|
||||||
|
p2: Coordinate
|
||||||
|
|
||||||
|
diff_x = p2.x - p1.x
|
||||||
|
diff_y = p2.y - p1.y
|
||||||
|
|
||||||
|
dist = diff_x + diff_y
|
||||||
|
|
||||||
|
p2.x += cast(Meter, 1)
|
||||||
|
p2.y = True
|
||||||
|
p2.z = 3
|
||||||
|
p2.x.a = 3
|
||||||
15
gen/gen.py
15
gen/gen.py
@@ -30,6 +30,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
{preamble}
|
||||||
{sections}
|
{sections}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
|
|||||||
re.MULTILINE | re.DOTALL,
|
re.MULTILINE | re.DOTALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PREAMBLE_REGEX = re.compile(
|
||||||
|
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
|
||||||
|
re.MULTILINE | re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def snake_case(text: str) -> str:
|
def snake_case(text: str) -> str:
|
||||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||||
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||||
|
print(f" Generating {full_name}")
|
||||||
visitor_methods: list[str] = []
|
visitor_methods: list[str] = []
|
||||||
classes: list[str] = []
|
classes: list[str] = []
|
||||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||||
for cls in definitions:
|
for cls in definitions:
|
||||||
cls = cls.strip("\n")
|
cls = cls.strip("\n")
|
||||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||||
print(f"Processing {name}")
|
print(f" Processing {name}")
|
||||||
visitor_methods.append(make_visitor_method(name, param))
|
visitor_methods.append(make_visitor_method(name, param))
|
||||||
classes.append(make_class(name, cls, base))
|
classes.append(make_class(name, cls, base))
|
||||||
|
|
||||||
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def generate(definitions_path: Path, out_path: Path):
|
def generate(definitions_path: Path, out_path: Path):
|
||||||
|
print(f"Processing generating {out_path} from {definitions_path}")
|
||||||
root_dir: Path = Path(__file__).parent.parent
|
root_dir: Path = Path(__file__).parent.parent
|
||||||
rel_path: Path = definitions_path.relative_to(root_dir)
|
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||||
src: str = definitions_path.read_text()
|
src: str = definitions_path.read_text()
|
||||||
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
|
|||||||
if m := IMPORTS_REGEX.search(src):
|
if m := IMPORTS_REGEX.search(src):
|
||||||
imports = m.group("body").strip("\n")
|
imports = m.group("body").strip("\n")
|
||||||
|
|
||||||
|
preamble: str = ""
|
||||||
|
if m := PREAMBLE_REGEX.search(src):
|
||||||
|
preamble = m.group("body")
|
||||||
|
|
||||||
for section_m in SECTION_REGEX.finditer(src):
|
for section_m in SECTION_REGEX.finditer(src):
|
||||||
full_name: str = section_m.group("name")
|
full_name: str = section_m.group("name")
|
||||||
base: str = section_m.group("base")
|
base: str = section_m.group("base")
|
||||||
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
|
|||||||
gen_path=Path(__file__).relative_to(root_dir),
|
gen_path=Path(__file__).relative_to(root_dir),
|
||||||
),
|
),
|
||||||
imports=imports,
|
imports=imports,
|
||||||
|
preamble=preamble,
|
||||||
sections="\n\n\n".join(sections),
|
sections="\n\n\n".join(sections),
|
||||||
)
|
)
|
||||||
out_path.write_text(result)
|
out_path.write_text(result)
|
||||||
|
|||||||
35
gen/midas.py
35
gen/midas.py
@@ -12,18 +12,23 @@ from midas.lexer.token import Token
|
|||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
###> Preamble
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypeParam:
|
||||||
|
location: Location
|
||||||
|
name: Token
|
||||||
|
bound: Optional[Type]
|
||||||
|
|
||||||
|
|
||||||
|
###<
|
||||||
|
|
||||||
|
|
||||||
###> Stmt | Statements
|
###> Stmt | Statements
|
||||||
class TypeStmt:
|
class TypeStmt:
|
||||||
name: Token
|
name: Token
|
||||||
params: list[Param]
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
|
||||||
class Param:
|
|
||||||
location: Location
|
|
||||||
name: Token
|
|
||||||
bound: Optional[Type]
|
|
||||||
|
|
||||||
|
|
||||||
class PropertyStmt:
|
class PropertyStmt:
|
||||||
name: Token
|
name: Token
|
||||||
@@ -31,6 +36,7 @@ class PropertyStmt:
|
|||||||
|
|
||||||
|
|
||||||
class ExtendStmt:
|
class ExtendStmt:
|
||||||
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
operations: list[OpStmt]
|
operations: list[OpStmt]
|
||||||
|
|
||||||
@@ -103,7 +109,7 @@ class NamedType:
|
|||||||
|
|
||||||
class GenericType:
|
class GenericType:
|
||||||
type: Type
|
type: Type
|
||||||
params: list[Type]
|
args: list[Type]
|
||||||
|
|
||||||
|
|
||||||
class ConstraintType:
|
class ConstraintType:
|
||||||
@@ -115,4 +121,17 @@ class ComplexType:
|
|||||||
properties: list[PropertyStmt]
|
properties: list[PropertyStmt]
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionType:
|
||||||
|
pos_args: list[Argument]
|
||||||
|
kw_args: list[Argument]
|
||||||
|
returns: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Optional[Token]
|
||||||
|
type: Type
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -128,12 +128,6 @@ class LogicalExpr:
|
|||||||
right: Expr
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
class SetExpr:
|
|
||||||
object: Expr
|
|
||||||
name: str
|
|
||||||
value: Expr
|
|
||||||
|
|
||||||
|
|
||||||
class CastExpr:
|
class CastExpr:
|
||||||
type: MidasType
|
type: MidasType
|
||||||
expr: Expr
|
expr: Expr
|
||||||
@@ -145,4 +139,8 @@ class TernaryExpr:
|
|||||||
if_false: Expr
|
if_false: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class ListExpr:
|
||||||
|
items: list[Expr]
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -14,6 +14,13 @@ from midas.lexer.token import Token
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypeParam:
|
||||||
|
location: Location
|
||||||
|
name: Token
|
||||||
|
bound: Optional[Type]
|
||||||
|
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Statements #
|
# Statements #
|
||||||
##############
|
##############
|
||||||
@@ -46,15 +53,9 @@ class Stmt(ABC):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TypeStmt(Stmt):
|
class TypeStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
params: list[Param]
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
|
||||||
class Param:
|
|
||||||
location: Location
|
|
||||||
name: Token
|
|
||||||
bound: Optional[Type]
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_type_stmt(self)
|
return visitor.visit_type_stmt(self)
|
||||||
|
|
||||||
@@ -70,6 +71,7 @@ class PropertyStmt(Stmt):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ExtendStmt(Stmt):
|
class ExtendStmt(Stmt):
|
||||||
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
operations: list[OpStmt]
|
operations: list[OpStmt]
|
||||||
|
|
||||||
@@ -231,6 +233,9 @@ class Type(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class NamedType(Type):
|
class NamedType(Type):
|
||||||
@@ -243,7 +248,7 @@ class NamedType(Type):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GenericType(Type):
|
class GenericType(Type):
|
||||||
type: Type
|
type: Type
|
||||||
params: list[Type]
|
args: list[Type]
|
||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
return visitor.visit_generic_type(self)
|
return visitor.visit_generic_type(self)
|
||||||
@@ -264,3 +269,20 @@ class ComplexType(Type):
|
|||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
return visitor.visit_complex_type(self)
|
return visitor.visit_complex_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FunctionType(Type):
|
||||||
|
pos_args: list[Argument]
|
||||||
|
kw_args: list[Argument]
|
||||||
|
returns: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Optional[Token]
|
||||||
|
type: Type
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_function_type(self)
|
||||||
|
|||||||
@@ -100,12 +100,12 @@ class MidasAstPrinter(
|
|||||||
self._idx = i
|
self._idx = i
|
||||||
if i == len(stmt.params) - 1:
|
if i == len(stmt.params) - 1:
|
||||||
self._mark_last()
|
self._mark_last()
|
||||||
self._print_type_stmt_param(param)
|
self._print_type_param(param)
|
||||||
self._write_line("type", last=True)
|
self._write_line("type", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
|
|
||||||
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
|
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||||
self._write_line("Param")
|
self._write_line("Param")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f'name: "{param.name.lexeme}"')
|
self._write_line(f'name: "{param.name.lexeme}"')
|
||||||
@@ -122,6 +122,13 @@ class MidasAstPrinter(
|
|||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
self._write_line("ExtendStmt")
|
self._write_line("ExtendStmt")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
|
self._write_line("params")
|
||||||
|
with self._child_level():
|
||||||
|
for i, param in enumerate(stmt.params):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.params) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_type_param(param)
|
||||||
self._write_line("type")
|
self._write_line("type")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
@@ -234,11 +241,11 @@ class MidasAstPrinter(
|
|||||||
self._write_line("type")
|
self._write_line("type")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
type.type.accept(self)
|
type.type.accept(self)
|
||||||
self._write_line("params", last=True)
|
self._write_line("args", last=True)
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
for i, param in enumerate(type.params):
|
for i, param in enumerate(type.args):
|
||||||
self._idx = i
|
self._idx = i
|
||||||
if i == len(type.params) - 1:
|
if i == len(type.args) - 1:
|
||||||
self._mark_last()
|
self._mark_last()
|
||||||
param.accept(self)
|
param.accept(self)
|
||||||
|
|
||||||
@@ -263,6 +270,41 @@ class MidasAstPrinter(
|
|||||||
self._mark_last()
|
self._mark_last()
|
||||||
prop.accept(self)
|
prop.accept(self)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
|
self._write_line("FunctionType")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("pos_args")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(type.pos_args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.pos_args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("kw_args")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(type.kw_args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.kw_args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("returns", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||||
|
self._write_line("Argument")
|
||||||
|
with self._child_level():
|
||||||
|
name: str = "None"
|
||||||
|
if arg.name is not None:
|
||||||
|
name = f'"{arg.name.lexeme}"'
|
||||||
|
self._write_line(f"name: {name}")
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
arg.type.accept(self)
|
||||||
|
self._write_line(f"required: {arg.required}", last=True)
|
||||||
|
|
||||||
|
|
||||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||||
def __init__(self, indent: int = 4):
|
def __init__(self, indent: int = 4):
|
||||||
@@ -279,14 +321,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||||
template: str = ""
|
template: str = ""
|
||||||
if len(stmt.params) != 0:
|
if len(stmt.params) != 0:
|
||||||
params: list[str] = [
|
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||||
self._print_type_template_param(param) for param in stmt.params
|
|
||||||
]
|
|
||||||
template = f"[{', '.join(params)}]"
|
template = f"[{', '.join(params)}]"
|
||||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||||
return self.indented(res)
|
return self.indented(res)
|
||||||
|
|
||||||
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
|
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||||
res: str = param.name.lexeme
|
res: str = param.name.lexeme
|
||||||
if param.bound is not None:
|
if param.bound is not None:
|
||||||
res += "<:" + param.bound.accept(self)
|
res += "<:" + param.bound.accept(self)
|
||||||
@@ -358,9 +398,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
|
|
||||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||||
res: str = type.type.accept(self)
|
res: str = type.type.accept(self)
|
||||||
if len(type.params) != 0:
|
if len(type.args) != 0:
|
||||||
params: list[str] = [param.accept(self) for param in type.params]
|
args: list[str] = [param.accept(self) for param in type.args]
|
||||||
res += f"[{', '.join(params)}]"
|
res += f"[{', '.join(args)}]"
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||||
@@ -378,6 +418,29 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
res += self.indented("}")
|
res += self.indented("}")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||||
|
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||||
|
kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||||
|
args: list[str] = pos_args
|
||||||
|
|
||||||
|
if len(pos_args) != 0:
|
||||||
|
args.append("/")
|
||||||
|
if len(kw_args) != 0:
|
||||||
|
args.append("*")
|
||||||
|
args += kw_args
|
||||||
|
|
||||||
|
return f"({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||||
|
|
||||||
|
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||||
|
res: str = ""
|
||||||
|
if arg.name is not None:
|
||||||
|
res += arg.name.lexeme
|
||||||
|
res += ": "
|
||||||
|
res += arg.type.accept(self)
|
||||||
|
if not arg.required:
|
||||||
|
res += "?"
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class PythonAstPrinter(
|
class PythonAstPrinter(
|
||||||
AstPrinter,
|
AstPrinter,
|
||||||
@@ -602,17 +665,6 @@ class PythonAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.right.accept(self)
|
expr.right.accept(self)
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
|
||||||
self._write_line("SetExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line("object")
|
|
||||||
with self._child_level(single=True):
|
|
||||||
expr.object.accept(self)
|
|
||||||
self._write_line(f"name: {expr.name}")
|
|
||||||
self._write_line("value", last=True)
|
|
||||||
with self._child_level(single=True):
|
|
||||||
expr.value.accept(self)
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||||
self._write_line("CastExpr")
|
self._write_line("CastExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -637,3 +689,14 @@ class PythonAstPrinter(
|
|||||||
self._write_line("if_false", last=True)
|
self._write_line("if_false", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.if_false.accept(self)
|
expr.if_false.accept(self)
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||||
|
self._write_line("ListExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("items", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, item in enumerate(expr.items):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.items) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
item.accept(self)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from midas.ast.location import Location
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Type annotations #
|
# Type annotations #
|
||||||
####################
|
####################
|
||||||
@@ -214,15 +215,15 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_set_expr(self, expr: SetExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_list_expr(self, expr: ListExpr) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BinaryExpr(Expr):
|
class BinaryExpr(Expr):
|
||||||
@@ -298,16 +299,6 @@ class LogicalExpr(Expr):
|
|||||||
return visitor.visit_logical_expr(self)
|
return visitor.visit_logical_expr(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class SetExpr(Expr):
|
|
||||||
object: Expr
|
|
||||||
name: str
|
|
||||||
value: Expr
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_set_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class CastExpr(Expr):
|
class CastExpr(Expr):
|
||||||
type: MidasType
|
type: MidasType
|
||||||
@@ -325,3 +316,11 @@ class TernaryExpr(Expr):
|
|||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
return visitor.visit_ternary_expr(self)
|
return visitor.visit_ternary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ListExpr(Expr):
|
||||||
|
items: list[Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_list_expr(self)
|
||||||
|
|||||||
112
midas/checker/builtins.py
Normal file
112
midas/checker/builtins.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from midas.checker.types import (
|
||||||
|
BaseType,
|
||||||
|
ComplexType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnitType,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
|
||||||
|
|
||||||
|
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||||
|
"float": {"int"},
|
||||||
|
"int": {"bool"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type):
|
||||||
|
reg.define_operation(
|
||||||
|
left=t1,
|
||||||
|
operator=operator,
|
||||||
|
right=t2,
|
||||||
|
result=t3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def basic_op(reg: TypesRegistry, type: Type, op: str):
|
||||||
|
reg.define_operation(
|
||||||
|
left=type,
|
||||||
|
operator=op,
|
||||||
|
right=type,
|
||||||
|
result=type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def define_builtins(reg: TypesRegistry):
|
||||||
|
"""Define builtin types and operations"""
|
||||||
|
unit = reg.define_type("None", UnitType())
|
||||||
|
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||||
|
int = reg.define_type("int", BaseType(name="int"))
|
||||||
|
float = reg.define_type("float", BaseType(name="float"))
|
||||||
|
str = reg.define_type("str", BaseType(name="str"))
|
||||||
|
|
||||||
|
basic_op(reg, int, "__add__") # int + int = int
|
||||||
|
basic_op(reg, int, "__sub__") # int - int = int
|
||||||
|
basic_op(reg, int, "__mul__") # int * int = int
|
||||||
|
basic_op(reg, int, "__pow__") # int ** int = int
|
||||||
|
basic_op(reg, int, "__mod__") # int % int = int
|
||||||
|
basic_op(reg, int, "__and__") # int & int = int
|
||||||
|
basic_op(reg, int, "__or__") # int | int = int
|
||||||
|
basic_op(reg, int, "__xor__") # int ^ int = int
|
||||||
|
op(reg, int, "__lt__", int, bool) # int < int = bool
|
||||||
|
op(reg, int, "__gt__", int, bool) # int > int = bool
|
||||||
|
op(reg, int, "__le__", int, bool) # int <= int = bool
|
||||||
|
op(reg, int, "__ge__", int, bool) # int >= int = bool
|
||||||
|
op(reg, int, "__eq__", int, bool) # int == int = bool
|
||||||
|
basic_op(reg, float, "__add__") # float + float = float
|
||||||
|
basic_op(reg, float, "__sub__") # float - float = float
|
||||||
|
basic_op(reg, float, "__mul__") # float * float = float
|
||||||
|
basic_op(reg, float, "__truediv__") # float / float = float
|
||||||
|
op(reg, float, "__lt__", float, bool) # float < float = bool
|
||||||
|
op(reg, float, "__gt__", float, bool) # float > float = bool
|
||||||
|
op(reg, float, "__le__", float, bool) # float <= float = bool
|
||||||
|
op(reg, float, "__ge__", float, bool) # float >= float = bool
|
||||||
|
op(reg, float, "__eq__", float, bool) # float == float = bool
|
||||||
|
basic_op(reg, str, "__add__") # str + str = str
|
||||||
|
op(reg, str, "__eq__", str, bool) # str == str = bool
|
||||||
|
|
||||||
|
op(reg, int, "__lt__", float, bool) # int < float = bool
|
||||||
|
op(reg, int, "__gt__", float, bool) # int > float = bool
|
||||||
|
op(reg, int, "__le__", float, bool) # int <= float = bool
|
||||||
|
op(reg, int, "__ge__", float, bool) # int >= float = bool
|
||||||
|
op(reg, int, "__eq__", float, bool) # int == float = bool
|
||||||
|
|
||||||
|
op(reg, float, "__lt__", int, bool) # float < int = bool
|
||||||
|
op(reg, float, "__gt__", int, bool) # float > int = bool
|
||||||
|
op(reg, float, "__le__", int, bool) # float <= int = bool
|
||||||
|
op(reg, float, "__ge__", int, bool) # float >= int = bool
|
||||||
|
op(reg, float, "__eq__", int, bool) # float == int = bool
|
||||||
|
|
||||||
|
list = reg.define_type(
|
||||||
|
"list",
|
||||||
|
GenericType(
|
||||||
|
name="list",
|
||||||
|
params=[TypeVar(name="T", bound=None)],
|
||||||
|
body=ComplexType(
|
||||||
|
properties={
|
||||||
|
"append": Function(
|
||||||
|
name="append",
|
||||||
|
pos_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="object",
|
||||||
|
type=TypeVar(name="T", bound=None),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
args=[],
|
||||||
|
kw_args=[],
|
||||||
|
returns=UnitType(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
@@ -1,540 +1,35 @@
|
|||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import midas.ast.midas as m
|
from midas.checker.diagnostic import Diagnostic
|
||||||
import midas.ast.python as p
|
from midas.checker.midas import MidasTyper
|
||||||
from midas.ast.location import Location
|
from midas.checker.python import PythonTyper
|
||||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.reporter import Reporter
|
||||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
|
||||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
|
||||||
from midas.lexer.midas import MidasLexer
|
|
||||||
from midas.lexer.token import Token
|
|
||||||
from midas.parser.midas import MidasParser
|
|
||||||
from midas.resolver.midas import MidasResolver
|
|
||||||
|
|
||||||
|
|
||||||
class ReturnException(Exception):
|
class TypeChecker:
|
||||||
pass
|
def __init__(self):
|
||||||
|
self.types: TypesRegistry = TypesRegistry()
|
||||||
|
self.reporter: Reporter = Reporter()
|
||||||
|
|
||||||
|
self.midas_typer = MidasTyper(self.types, self.reporter)
|
||||||
|
self.python_typer = PythonTyper(self.types, self.reporter)
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
def import_midas(self, path: Path):
|
||||||
class MappedArgument:
|
source: str = path.read_text()
|
||||||
expr: p.Expr
|
return self.import_midas_source(source, path=str(path))
|
||||||
type: Type
|
|
||||||
argument: Function.Argument
|
|
||||||
|
|
||||||
|
def import_midas_source(self, source: str, path: Optional[str] = None):
|
||||||
|
self.midas_typer.process(source, path)
|
||||||
|
|
||||||
class Checker(
|
def type_check(self, path: Path):
|
||||||
p.Stmt.Visitor[None],
|
source: str = path.read_text()
|
||||||
p.Expr.Visitor[Type],
|
return self.type_check_source(source, path=str(path))
|
||||||
p.MidasType.Visitor[Type],
|
|
||||||
):
|
|
||||||
"""A type checker which can use custom type definitions"""
|
|
||||||
|
|
||||||
def __init__(
|
def type_check_source(self, source: str, path: Optional[str] = None):
|
||||||
self,
|
self.python_typer.process(source, path)
|
||||||
locals: dict[p.Expr, int],
|
|
||||||
source_path: Path,
|
|
||||||
types_paths: list[Path],
|
|
||||||
):
|
|
||||||
self.logger: logging.Logger = logging.getLogger("Checker")
|
|
||||||
self.source_path: Path = source_path
|
|
||||||
self.types_paths: list[Path] = types_paths
|
|
||||||
self.ctx: MidasResolver = MidasResolver()
|
|
||||||
self.global_env: Environment = Environment()
|
|
||||||
self.env: Environment = self.global_env
|
|
||||||
self.locals: dict[p.Expr, int] = locals
|
|
||||||
self.diagnostics: list[Diagnostic] = []
|
|
||||||
|
|
||||||
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
@property
|
||||||
self.diagnostics.append(
|
def diagnostics(self) -> list[Diagnostic]:
|
||||||
Diagnostic(
|
return self.reporter.diagnostics
|
||||||
file_path=self.source_path,
|
|
||||||
location=location,
|
|
||||||
type=type,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def error(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.ERROR,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def warning(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.WARNING,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def info(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.INFO,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def type_of(self, expr: p.Expr) -> Type:
|
|
||||||
"""Evaluate the type of an expression
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expr (p.Expr): the expression to evaluate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the type of the given expression
|
|
||||||
"""
|
|
||||||
return expr.accept(self)
|
|
||||||
|
|
||||||
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
|
||||||
"""Evaluate a sequence of statements
|
|
||||||
|
|
||||||
Args:
|
|
||||||
block (list[p.Stmt]): the statements to evaluate
|
|
||||||
env (Environment): the environment in which to evaluate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: whether a return statement is present in the block
|
|
||||||
"""
|
|
||||||
previous_env: Environment = self.env
|
|
||||||
self.env = env
|
|
||||||
returned: bool = False
|
|
||||||
for i, stmt in enumerate(block):
|
|
||||||
try:
|
|
||||||
stmt.accept(self)
|
|
||||||
except ReturnException:
|
|
||||||
returned = True
|
|
||||||
if i < len(block) - 1:
|
|
||||||
self.warning(block[i + 1].location, "Unreachable statement")
|
|
||||||
break
|
|
||||||
self.env = previous_env
|
|
||||||
return returned
|
|
||||||
|
|
||||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
|
||||||
"""Type check a sequence of statements and returns diagnostics
|
|
||||||
|
|
||||||
Args:
|
|
||||||
statements (list[p.Stmt]): the statements to evaluate and check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
|
||||||
"""
|
|
||||||
self.diagnostics = []
|
|
||||||
|
|
||||||
for path in self.types_paths:
|
|
||||||
self.import_midas(path)
|
|
||||||
self.logger.debug(f"Midas types: {self.ctx._types}")
|
|
||||||
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
|
||||||
|
|
||||||
for stmt in statements:
|
|
||||||
stmt.accept(self)
|
|
||||||
|
|
||||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
|
||||||
return self.diagnostics
|
|
||||||
|
|
||||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
|
||||||
"""Look up a variable in the environment it was declared
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): the name of the variable
|
|
||||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Type]: the type of the variable, or None if it was not found
|
|
||||||
"""
|
|
||||||
distance: Optional[int] = self.locals.get(expr)
|
|
||||||
if distance is not None:
|
|
||||||
return self.env.get_at(distance, name)
|
|
||||||
return self.global_env.get(name)
|
|
||||||
|
|
||||||
def import_midas(self, path: Path) -> None:
|
|
||||||
"""Import Midas definitions from a path
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (Path): the import path
|
|
||||||
"""
|
|
||||||
self.logger.debug(f"Importing type definitions from {path}")
|
|
||||||
lexer: MidasLexer = MidasLexer(path.read_text())
|
|
||||||
tokens: list[Token] = lexer.process()
|
|
||||||
parser: MidasParser = MidasParser(tokens)
|
|
||||||
stmts: list[m.Stmt] = parser.parse()
|
|
||||||
self.ctx.resolve(stmts)
|
|
||||||
|
|
||||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
|
||||||
self.type_of(stmt.expr)
|
|
||||||
|
|
||||||
def visit_function(self, stmt: p.Function) -> None:
|
|
||||||
env: Environment = Environment(self.env)
|
|
||||||
pos_args: list[Function.Argument] = []
|
|
||||||
args: list[Function.Argument] = []
|
|
||||||
kw_args: list[Function.Argument] = []
|
|
||||||
|
|
||||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
|
||||||
if arg.type is not None:
|
|
||||||
return arg.type.accept(self)
|
|
||||||
if arg.default is not None:
|
|
||||||
return arg.default.accept(self)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
for arg in stmt.posonlyargs:
|
|
||||||
pos_args.append(
|
|
||||||
Function.Argument(
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for arg in stmt.args:
|
|
||||||
args.append(
|
|
||||||
Function.Argument(
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for arg in stmt.kwonlyargs:
|
|
||||||
kw_args.append(
|
|
||||||
Function.Argument(
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for arg in pos_args + args + kw_args:
|
|
||||||
env.define(arg.name, arg.type)
|
|
||||||
|
|
||||||
returns_hint: Optional[Type] = None
|
|
||||||
if stmt.returns is not None:
|
|
||||||
returns_hint = stmt.returns.accept(self)
|
|
||||||
# Early define to handle simple fully-typed recursion
|
|
||||||
inside_function: Function = Function(
|
|
||||||
name=stmt.name,
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=returns_hint,
|
|
||||||
)
|
|
||||||
self.env.define(stmt.name, inside_function)
|
|
||||||
|
|
||||||
returned: bool = self.process_block(stmt.body, env)
|
|
||||||
inferred_return: Type = UnknownType()
|
|
||||||
if not returned:
|
|
||||||
env.return_types.append(UnitType())
|
|
||||||
return_types: set[Type] = set(env.return_types)
|
|
||||||
if len(return_types) == 1:
|
|
||||||
inferred_return = list(return_types)[0]
|
|
||||||
elif len(return_types) > 1:
|
|
||||||
self.error(
|
|
||||||
stmt.location,
|
|
||||||
f"Mixed return types: {env.return_types}",
|
|
||||||
)
|
|
||||||
|
|
||||||
returns: Type = UnknownType()
|
|
||||||
if returns_hint is not None:
|
|
||||||
assert stmt.returns is not None
|
|
||||||
returns = returns_hint
|
|
||||||
if returns != inferred_return:
|
|
||||||
self.error(
|
|
||||||
stmt.returns.location,
|
|
||||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
returns = inferred_return
|
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
|
||||||
function: Function = Function(
|
|
||||||
name=stmt.name,
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=returns,
|
|
||||||
)
|
|
||||||
self.env.define(stmt.name, function)
|
|
||||||
|
|
||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
|
||||||
# TODO check not yet defined locally
|
|
||||||
type: Type = stmt.type.accept(self)
|
|
||||||
self.env.define(stmt.name, type)
|
|
||||||
|
|
||||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
|
||||||
value: Type = self.type_of(stmt.value)
|
|
||||||
for target in stmt.targets:
|
|
||||||
if not isinstance(target, p.VariableExpr):
|
|
||||||
self.logger.warning(f"Unsupported assignment to {target}")
|
|
||||||
self.warning(target.location, f"Unsupported assignment to {target}")
|
|
||||||
continue
|
|
||||||
name: str = target.name
|
|
||||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
|
||||||
|
|
||||||
if var_type is None:
|
|
||||||
self.env.define(name, value)
|
|
||||||
else:
|
|
||||||
# TODO: implement real comparison method
|
|
||||||
if var_type != value:
|
|
||||||
self.error(
|
|
||||||
stmt.location,
|
|
||||||
f"Cannot assign {value} to {name} of type {var_type}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
|
||||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
|
||||||
self.env.return_types.append(type)
|
|
||||||
raise ReturnException()
|
|
||||||
|
|
||||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
|
||||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
|
||||||
# For example:
|
|
||||||
# if (m := 1 + 1) < 2:
|
|
||||||
# ...
|
|
||||||
# print(m) # <- m is still defined
|
|
||||||
test_type: Type = stmt.test.accept(self)
|
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
|
||||||
if test_type != self.ctx.get_type("bool"):
|
|
||||||
self.error(
|
|
||||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
env: Environment = Environment(self.env)
|
|
||||||
body_returned: bool = self.process_block(stmt.body, env)
|
|
||||||
else_returned: bool = self.process_block(stmt.orelse, env)
|
|
||||||
self.env.return_types.extend(env.return_types)
|
|
||||||
if body_returned and else_returned:
|
|
||||||
raise ReturnException()
|
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
|
||||||
if method is None:
|
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
|
||||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
|
||||||
return UnknownType()
|
|
||||||
left: Type = self.type_of(expr.left)
|
|
||||||
right: Type = self.type_of(expr.right)
|
|
||||||
|
|
||||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
|
||||||
if result is None:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
|
||||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
|
||||||
if method is None:
|
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
|
||||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
|
||||||
return UnknownType()
|
|
||||||
left: Type = self.type_of(expr.left)
|
|
||||||
right: Type = self.type_of(expr.right)
|
|
||||||
|
|
||||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
|
||||||
if result is None:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
|
||||||
callee: Type = self.type_of(expr.callee)
|
|
||||||
if not isinstance(callee, Function):
|
|
||||||
self.error(expr.callee.location, "Callee is not a function")
|
|
||||||
return UnknownType()
|
|
||||||
function: Function = callee
|
|
||||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
|
||||||
for arg in mapped:
|
|
||||||
if arg.type != arg.argument.type:
|
|
||||||
self.error(
|
|
||||||
arg.expr.location,
|
|
||||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
|
||||||
)
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
|
||||||
match expr.value:
|
|
||||||
case bool(): # Must be before int
|
|
||||||
return self.ctx.get_type("bool")
|
|
||||||
case int():
|
|
||||||
return self.ctx.get_type("int")
|
|
||||||
case float():
|
|
||||||
return self.ctx.get_type("float")
|
|
||||||
case str():
|
|
||||||
return self.ctx.get_type("str")
|
|
||||||
case _:
|
|
||||||
self.warning(expr.location, f"Unknown literal {expr}")
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
|
||||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
|
||||||
left: Type = expr.left.accept(self)
|
|
||||||
right: Type = expr.right.accept(self)
|
|
||||||
# TODO: union type
|
|
||||||
if left != right:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Operands must be of the same type, left={left} != right={right}",
|
|
||||||
)
|
|
||||||
return left
|
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
|
||||||
return expr.type.accept(self)
|
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
|
||||||
test_type: Type = expr.test.accept(self)
|
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
|
||||||
if test_type != self.ctx.get_type("bool"):
|
|
||||||
self.error(
|
|
||||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
true_type: Type = expr.if_true.accept(self)
|
|
||||||
false_type: Type = expr.if_false.accept(self)
|
|
||||||
if true_type != false_type:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return true_type
|
|
||||||
|
|
||||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
|
||||||
return self.ctx.get_type(node.base)
|
|
||||||
|
|
||||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
|
||||||
|
|
||||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
|
||||||
|
|
||||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
|
||||||
|
|
||||||
def map_call_arguments(
|
|
||||||
self, function: Function, call: p.CallExpr
|
|
||||||
) -> list[MappedArgument]:
|
|
||||||
"""Map call arguments to function parameters as defined in its signature
|
|
||||||
|
|
||||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
|
||||||
with the arguments passed at the call site
|
|
||||||
|
|
||||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function (Function): the function definition
|
|
||||||
call (p.CallExpr): the call expression
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[MappedArgument]: the list of mapped arguments
|
|
||||||
"""
|
|
||||||
positional: list[tuple[p.Expr, Type]] = [
|
|
||||||
(arg, self.type_of(arg)) for arg in call.arguments
|
|
||||||
]
|
|
||||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
|
||||||
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
|
||||||
}
|
|
||||||
set_args: set[str] = set()
|
|
||||||
|
|
||||||
required_positional: list[str] = [
|
|
||||||
arg.name for arg in function.pos_args + function.args if arg.required
|
|
||||||
]
|
|
||||||
required_keyword: list[str] = [
|
|
||||||
arg.name for arg in function.kw_args if arg.required
|
|
||||||
]
|
|
||||||
|
|
||||||
mapped: list[MappedArgument] = []
|
|
||||||
|
|
||||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
|
||||||
mixed_params: list[Function.Argument] = list(function.args)
|
|
||||||
kw_params: dict[str, Function.Argument] = {
|
|
||||||
arg.name: arg for arg in function.kw_args
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
|
||||||
for arg in positional:
|
|
||||||
param: Function.Argument
|
|
||||||
if len(pos_params) != 0:
|
|
||||||
param = pos_params.pop(0)
|
|
||||||
elif len(mixed_params) != 0:
|
|
||||||
param = mixed_params.pop(0)
|
|
||||||
else:
|
|
||||||
self.error(arg[0].location, "Too many positional arguments")
|
|
||||||
break
|
|
||||||
name: str = param.name
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
|
||||||
for name, arg in keywords.items():
|
|
||||||
param: Function.Argument
|
|
||||||
if name not in kw_params:
|
|
||||||
if name in set_args:
|
|
||||||
self.error(
|
|
||||||
arg[0].location, f"Multiple values for argument '{name}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
|
||||||
continue
|
|
||||||
param = kw_params.pop(name)
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def join_args(args: list[str]) -> str:
|
|
||||||
args = list(map(lambda a: f"'{a}'", args))
|
|
||||||
if len(args) == 0:
|
|
||||||
return ""
|
|
||||||
if len(args) == 1:
|
|
||||||
return args[0]
|
|
||||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
|
||||||
|
|
||||||
if len(required_positional) != 0:
|
|
||||||
plural: str = "" if len(required_positional) == 1 else "s"
|
|
||||||
args: str = join_args(required_positional)
|
|
||||||
self.error(
|
|
||||||
call.location,
|
|
||||||
f"Missing required positional argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(required_keyword) != 0:
|
|
||||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
|
||||||
args: str = join_args(required_keyword)
|
|
||||||
self.error(
|
|
||||||
call.location,
|
|
||||||
f"Missing required keyword argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
|
|
||||||
return mapped
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
@@ -14,12 +13,13 @@ class DiagnosticType(StrEnum):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Diagnostic:
|
class Diagnostic:
|
||||||
file_path: Path
|
file_path: Optional[str]
|
||||||
location: Location
|
location: Location
|
||||||
type: DiagnosticType
|
type: DiagnosticType
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
def __str__(self) -> str:
|
@property
|
||||||
|
def location_str(self) -> str:
|
||||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||||
end_loc: Optional[str] = ""
|
end_loc: Optional[str] = ""
|
||||||
if (
|
if (
|
||||||
@@ -27,7 +27,16 @@ class Diagnostic:
|
|||||||
and self.location.end_col_offset is not None
|
and self.location.end_col_offset is not None
|
||||||
):
|
):
|
||||||
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
||||||
loc: str = (
|
|
||||||
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
|
loc: str = ""
|
||||||
)
|
if self.file_path is not None:
|
||||||
return f"{self.type} in {self.file_path} {loc}: {self.message}"
|
loc += f" in {self.file_path}"
|
||||||
|
if end_loc is None:
|
||||||
|
loc += f" at {start_loc}"
|
||||||
|
else:
|
||||||
|
loc += f" from {start_loc} to {end_loc}"
|
||||||
|
|
||||||
|
return f"{self.type}{loc}"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.location_str}: {self.message}"
|
||||||
|
|||||||
171
midas/checker/midas.py
Normal file
171
midas/checker/midas.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.builtins import define_builtins
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
ComplexType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
|
from midas.lexer.midas import MidasLexer
|
||||||
|
from midas.lexer.token import Token
|
||||||
|
from midas.parser.midas import MidasParser
|
||||||
|
|
||||||
|
|
||||||
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||||
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
|
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||||
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
|
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
|
define_builtins(self.types)
|
||||||
|
|
||||||
|
def process(self, source: str, path: Optional[str]):
|
||||||
|
self.reporter = self.reporter.for_file(path)
|
||||||
|
lexer: MidasLexer = MidasLexer(source)
|
||||||
|
tokens: list[Token] = lexer.process()
|
||||||
|
parser: MidasParser = MidasParser(tokens)
|
||||||
|
stmts: list[m.Stmt] = parser.parse()
|
||||||
|
self.resolve(stmts)
|
||||||
|
|
||||||
|
def get_type(self, name: str) -> Type:
|
||||||
|
"""Get a type from its name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NameError: if the type is not defined
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type
|
||||||
|
"""
|
||||||
|
if name in self._local_variables:
|
||||||
|
return self._local_variables[name]
|
||||||
|
return self.types.get_type(name)
|
||||||
|
|
||||||
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
|
"""Process a sequence of statements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stmts (list[m.Stmt]): the statements
|
||||||
|
"""
|
||||||
|
for stmt in stmts:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
|
params: list[TypeVar] = self._resolve_type_params(stmt.params)
|
||||||
|
|
||||||
|
name: str = stmt.name.lexeme
|
||||||
|
type: Type = stmt.type.accept(self)
|
||||||
|
if len(params) != 0:
|
||||||
|
type = GenericType(name=name, params=params, body=type)
|
||||||
|
else:
|
||||||
|
type = AliasType(name=name, type=type)
|
||||||
|
self.types.define_type(name, type)
|
||||||
|
self._local_variables.clear()
|
||||||
|
|
||||||
|
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
|
self._resolve_type_params(stmt.params)
|
||||||
|
base: Type = stmt.type.accept(self)
|
||||||
|
for op in stmt.operations:
|
||||||
|
right: Type = op.operand.accept(self)
|
||||||
|
result: Type = op.result.accept(self)
|
||||||
|
self.types.define_operation(
|
||||||
|
left=base,
|
||||||
|
operator=op.name.lexeme,
|
||||||
|
right=right,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||||
|
return expr.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||||
|
return self.get_type(type.name.lexeme)
|
||||||
|
|
||||||
|
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||||
|
type_: Type = type.type.accept(self)
|
||||||
|
args: list[Type] = [arg.accept(self) for arg in type.args]
|
||||||
|
return self.types.apply_generic(type_, args)
|
||||||
|
|
||||||
|
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||||
|
type_: Type = type.type.accept(self)
|
||||||
|
type.constraint.accept(self)
|
||||||
|
# TODO
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
||||||
|
return ComplexType(
|
||||||
|
properties={
|
||||||
|
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||||
|
return Function(
|
||||||
|
name="<anonymous>",
|
||||||
|
pos_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||||
|
type=arg.type.accept(self),
|
||||||
|
required=arg.required,
|
||||||
|
)
|
||||||
|
for i, arg in enumerate(type.pos_args)
|
||||||
|
],
|
||||||
|
args=[],
|
||||||
|
kw_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||||
|
type=arg.type.accept(self),
|
||||||
|
required=arg.required,
|
||||||
|
)
|
||||||
|
for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
|
||||||
|
],
|
||||||
|
returns=type.returns.accept(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||||
|
vars: list[TypeVar] = []
|
||||||
|
for param in params:
|
||||||
|
name: str = param.name.lexeme
|
||||||
|
bound: Optional[Type] = None
|
||||||
|
if param.bound is not None:
|
||||||
|
bound = param.bound.accept(self)
|
||||||
|
var = TypeVar(name=name, bound=bound)
|
||||||
|
self._local_variables[name] = var
|
||||||
|
vars.append(var)
|
||||||
|
return vars
|
||||||
646
midas/checker/python.py
Normal file
646
midas/checker/python.py
Normal file
@@ -0,0 +1,646 @@
|
|||||||
|
import ast
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
|
from midas.checker.resolver import Resolver
|
||||||
|
from midas.checker.types import (
|
||||||
|
ComplexType,
|
||||||
|
Function,
|
||||||
|
Operation,
|
||||||
|
Type,
|
||||||
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
|
)
|
||||||
|
from midas.parser.python import PythonParser
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument:
|
||||||
|
expr: p.Expr
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
class PythonTyper(
|
||||||
|
p.Stmt.Visitor[None],
|
||||||
|
p.Expr.Visitor[Type],
|
||||||
|
p.MidasType.Visitor[Type],
|
||||||
|
):
|
||||||
|
"""A type checker which can use custom type definitions"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
types: TypesRegistry,
|
||||||
|
reporter: Reporter,
|
||||||
|
):
|
||||||
|
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||||
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.global_env: Environment = Environment()
|
||||||
|
self.env: Environment = self.global_env
|
||||||
|
self.locals: dict[p.Expr, int] = {}
|
||||||
|
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||||
|
|
||||||
|
def process(self, source: str, path: Optional[str]):
|
||||||
|
self.reporter = self.reporter.for_file(path)
|
||||||
|
|
||||||
|
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
|
||||||
|
parser = PythonParser()
|
||||||
|
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||||
|
resolver = Resolver()
|
||||||
|
resolver.resolve(*stmts)
|
||||||
|
|
||||||
|
self.env = self.global_env
|
||||||
|
self.locals = resolver.locals
|
||||||
|
self.judgements = []
|
||||||
|
|
||||||
|
self.check(stmts)
|
||||||
|
|
||||||
|
def type_of(self, expr: p.Expr) -> Type:
|
||||||
|
"""Evaluate the type of an expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr (p.Expr): the expression to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type of the given expression
|
||||||
|
"""
|
||||||
|
type: Type = expr.accept(self)
|
||||||
|
self.judgements.append((expr, type))
|
||||||
|
return type
|
||||||
|
|
||||||
|
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||||
|
"""Evaluate a sequence of statements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block (list[p.Stmt]): the statements to evaluate
|
||||||
|
env (Environment): the environment in which to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether a return statement is present in the block
|
||||||
|
"""
|
||||||
|
previous_env: Environment = self.env
|
||||||
|
self.env = env
|
||||||
|
returned: bool = False
|
||||||
|
for i, stmt in enumerate(block):
|
||||||
|
try:
|
||||||
|
stmt.accept(self)
|
||||||
|
except ReturnException:
|
||||||
|
returned = True
|
||||||
|
if i < len(block) - 1:
|
||||||
|
self.reporter.warning(
|
||||||
|
block[i + 1].location, "Unreachable statement"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
self.env = previous_env
|
||||||
|
return returned
|
||||||
|
|
||||||
|
def check(self, statements: list[p.Stmt]) -> None:
|
||||||
|
"""Type check a sequence of statements and returns diagnostics
|
||||||
|
|
||||||
|
Args:
|
||||||
|
statements (list[p.Stmt]): the statements to evaluate and check
|
||||||
|
"""
|
||||||
|
for stmt in statements:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
|
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||||
|
|
||||||
|
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||||
|
"""Look up a variable in the environment it was declared
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Type]: the type of the variable, or None if it was not found
|
||||||
|
"""
|
||||||
|
distance: Optional[int] = self.locals.get(expr)
|
||||||
|
if distance is not None:
|
||||||
|
return self.env.get_at(distance, name)
|
||||||
|
return self.global_env.get(name)
|
||||||
|
|
||||||
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
|
return self.types.is_subtype(type1, type2)
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
|
self.type_of(stmt.expr)
|
||||||
|
|
||||||
|
def visit_function(self, stmt: p.Function) -> None:
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
pos_args: list[Function.Argument] = []
|
||||||
|
args: list[Function.Argument] = []
|
||||||
|
kw_args: list[Function.Argument] = []
|
||||||
|
|
||||||
|
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||||
|
if arg.type is not None:
|
||||||
|
return arg.type.accept(self)
|
||||||
|
if arg.default is not None:
|
||||||
|
return arg.default.accept(self)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
pos: int = 0
|
||||||
|
for arg in stmt.posonlyargs:
|
||||||
|
pos_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos,
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
for arg in stmt.args:
|
||||||
|
args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos,
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
for arg in stmt.kwonlyargs:
|
||||||
|
kw_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos, # not relevant
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
|
||||||
|
for arg in pos_args + args + kw_args:
|
||||||
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
|
returns_hint: Optional[Type] = None
|
||||||
|
if stmt.returns is not None:
|
||||||
|
returns_hint = stmt.returns.accept(self)
|
||||||
|
# Early define to handle simple fully-typed recursion
|
||||||
|
inside_function: Function = Function(
|
||||||
|
name=stmt.name,
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns_hint,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, inside_function)
|
||||||
|
|
||||||
|
returned: bool = self.process_block(stmt.body, env)
|
||||||
|
inferred_return: Type = UnknownType()
|
||||||
|
if not returned:
|
||||||
|
env.return_types.append(UnitType())
|
||||||
|
return_types: list[Type] = self.types.reduce_types(env.return_types)
|
||||||
|
if len(return_types) == 1:
|
||||||
|
inferred_return = return_types[0]
|
||||||
|
elif len(return_types) > 1:
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.location,
|
||||||
|
f"Mixed return types: {return_types}",
|
||||||
|
)
|
||||||
|
|
||||||
|
returns: Type = UnknownType()
|
||||||
|
if returns_hint is not None:
|
||||||
|
assert stmt.returns is not None
|
||||||
|
returns = returns_hint
|
||||||
|
if returns != inferred_return:
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.returns.location,
|
||||||
|
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
returns = inferred_return
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
function: Function = Function(
|
||||||
|
name=stmt.name,
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, function)
|
||||||
|
|
||||||
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
|
# TODO check not yet defined locally
|
||||||
|
type: Type = stmt.type.accept(self)
|
||||||
|
self.env.define(stmt.name, type)
|
||||||
|
|
||||||
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
|
value_type: Type = self.type_of(stmt.value)
|
||||||
|
for target in stmt.targets:
|
||||||
|
self._assign(stmt.location, target, value_type)
|
||||||
|
|
||||||
|
def _assign(self, location: Location, target: p.Expr, value_type: Type):
|
||||||
|
match target:
|
||||||
|
case p.VariableExpr():
|
||||||
|
self._assign_var(location, target, value_type)
|
||||||
|
|
||||||
|
case p.GetExpr():
|
||||||
|
self._assign_attr(location, target, value_type)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
if not isinstance(target, p.VariableExpr):
|
||||||
|
self.logger.warning(f"Unsupported assignment to {target}")
|
||||||
|
self.reporter.warning(
|
||||||
|
target.location, f"Unsupported assignment to {target}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
|
||||||
|
name: str = target.name
|
||||||
|
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||||
|
|
||||||
|
if var_type is None:
|
||||||
|
self.env.define(name, value_type)
|
||||||
|
else:
|
||||||
|
# S <: T
|
||||||
|
# Γ, x: T v: S
|
||||||
|
# x = v
|
||||||
|
if not self.is_subtype(value_type, var_type):
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {value_type} to variable '{name}' of type {var_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
|
||||||
|
object: Type = self.type_of(target.object)
|
||||||
|
base_object: Type = unfold_type(object)
|
||||||
|
match base_object:
|
||||||
|
case ComplexType(properties=properties):
|
||||||
|
if target.name not in properties:
|
||||||
|
self.reporter.error(
|
||||||
|
target.location, f"Unknown property '{object}.{target.name}'"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
prop_type: Type = properties[target.name]
|
||||||
|
if not self.is_subtype(value_type, prop_type):
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {value_type} to property '{object}.{target.name}' of type {prop_type}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
pass
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
target.location,
|
||||||
|
f"Cannot assign {value_type} to unknown property '{object}.{target.name}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||||
|
self.env.return_types.append(type)
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||||
|
# For example:
|
||||||
|
# if (m := 1 + 1) < 2:
|
||||||
|
# ...
|
||||||
|
# print(m) # <- m is still defined
|
||||||
|
test_type: Type = stmt.test.accept(self)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.types.get_type("bool"):
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
body_returned: bool = self.process_block(stmt.body, env)
|
||||||
|
else_returned: bool = self.process_block(stmt.orelse, env)
|
||||||
|
self.env.return_types.extend(env.return_types)
|
||||||
|
if body_returned and else_returned:
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
|
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
left: Type = self.type_of(expr.left)
|
||||||
|
right: Type = self.type_of(expr.right)
|
||||||
|
|
||||||
|
operations: list[Operation] = self.types.get_operations_by_name(method)
|
||||||
|
valid_operations: list[Operation] = []
|
||||||
|
for op in operations:
|
||||||
|
sig: Operation.CallSignature = op.signature
|
||||||
|
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
|
||||||
|
valid_operations.append(op)
|
||||||
|
|
||||||
|
if len(valid_operations) == 0:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
elif len(valid_operations) == 1:
|
||||||
|
self.logger.debug(f"Unique operation {method} between {left} and {right}")
|
||||||
|
return valid_operations[0].result
|
||||||
|
|
||||||
|
for i, op1 in enumerate(valid_operations):
|
||||||
|
sig1: Operation.CallSignature = op1.signature
|
||||||
|
best_match: bool = True
|
||||||
|
for j, op2 in enumerate(valid_operations):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
sig2: Operation.CallSignature = op2.signature
|
||||||
|
|
||||||
|
# If op1 is not a full overload of op2 (i.e. operands of op1 are subtypes of op2's)
|
||||||
|
# ambiguity -> not best match
|
||||||
|
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
|
||||||
|
sig1.right, sig2.right
|
||||||
|
):
|
||||||
|
best_match = False
|
||||||
|
break
|
||||||
|
self.logger.debug(f"{op1} is a full overload of {op2}")
|
||||||
|
if best_match:
|
||||||
|
return op1.result
|
||||||
|
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(map(str, valid_operations))}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
|
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
left: Type = self.type_of(expr.left)
|
||||||
|
right: Type = self.type_of(expr.right)
|
||||||
|
|
||||||
|
result: Optional[Type] = self.types.get_operation_result(left, method, right)
|
||||||
|
if result is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
|
callee: Type = self.type_of(expr.callee)
|
||||||
|
if not isinstance(callee, Function):
|
||||||
|
self.reporter.error(expr.callee.location, "Callee is not a function")
|
||||||
|
return UnknownType()
|
||||||
|
function: Function = callee
|
||||||
|
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||||
|
for arg in mapped:
|
||||||
|
if not self.is_subtype(arg.type, arg.argument.type):
|
||||||
|
self.reporter.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||||
|
object: Type = self.type_of(expr.object)
|
||||||
|
base_object: Type = unfold_type(object)
|
||||||
|
match base_object:
|
||||||
|
case ComplexType(properties=properties):
|
||||||
|
if expr.name not in properties:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown property '{expr.name} on {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return properties[expr.name]
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Cannot get property '{expr.name}' on {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||||
|
match expr.value:
|
||||||
|
case bool(): # Must be before int
|
||||||
|
return self.types.get_type("bool")
|
||||||
|
case int():
|
||||||
|
return self.types.get_type("int")
|
||||||
|
case float():
|
||||||
|
return self.types.get_type("float")
|
||||||
|
case str():
|
||||||
|
return self.types.get_type("str")
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||||
|
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||||
|
left: Type = expr.left.accept(self)
|
||||||
|
right: Type = expr.right.accept(self)
|
||||||
|
|
||||||
|
if self.is_subtype(left, right):
|
||||||
|
return right
|
||||||
|
if self.is_subtype(right, left):
|
||||||
|
return left
|
||||||
|
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Incompatible operand types, {left=} and {right=}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||||
|
return expr.type.accept(self)
|
||||||
|
|
||||||
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
|
test_type: Type = expr.test.accept(self)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.types.get_type("bool"):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
true_type: Type = expr.if_true.accept(self)
|
||||||
|
false_type: Type = expr.if_false.accept(self)
|
||||||
|
if self.is_subtype(true_type, false_type):
|
||||||
|
return false_type
|
||||||
|
if self.is_subtype(false_type, true_type):
|
||||||
|
return true_type
|
||||||
|
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: p.ListExpr) -> Type:
|
||||||
|
list_type: Type = self.types.get_type("list")
|
||||||
|
item_types: list[Type] = [self.type_of(item) for item in expr.items]
|
||||||
|
item_types = self.types.reduce_types(item_types)
|
||||||
|
|
||||||
|
if len(item_types) == 0:
|
||||||
|
return list_type
|
||||||
|
|
||||||
|
if len(item_types) == 1:
|
||||||
|
item_type: Type = item_types[0]
|
||||||
|
return self.types.apply_generic(list_type, [item_type])
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Heterogeneous list items: {item_types}",
|
||||||
|
)
|
||||||
|
return self.types.apply_generic(list_type, [UnknownType()])
|
||||||
|
|
||||||
|
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||||
|
base: Type = self.types.get_type(node.base)
|
||||||
|
if node.param is not None:
|
||||||
|
param: Type = node.param.accept(self)
|
||||||
|
return self.types.apply_generic(base, [param])
|
||||||
|
return base
|
||||||
|
|
||||||
|
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||||
|
|
||||||
|
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||||
|
|
||||||
|
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self, function: Function, call: p.CallExpr
|
||||||
|
) -> list[MappedArgument]:
|
||||||
|
"""Map call arguments to function parameters as defined in its signature
|
||||||
|
|
||||||
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
|
with the arguments passed at the call site
|
||||||
|
|
||||||
|
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (Function): the function definition
|
||||||
|
call (p.CallExpr): the call expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[MappedArgument]: the list of mapped arguments
|
||||||
|
"""
|
||||||
|
positional: list[tuple[p.Expr, Type]] = [
|
||||||
|
(arg, self.type_of(arg)) for arg in call.arguments
|
||||||
|
]
|
||||||
|
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||||
|
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
||||||
|
}
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: list[str] = [
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
]
|
||||||
|
required_keyword: list[str] = [
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
]
|
||||||
|
|
||||||
|
mapped: list[MappedArgument] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
self.reporter.error(arg[0].location, "Too many positional arguments")
|
||||||
|
break
|
||||||
|
name: str = param.name
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if name in set_args:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def join_args(args: list[str]) -> str:
|
||||||
|
args = list(map(lambda a: f"'{a}'", args))
|
||||||
|
if len(args) == 0:
|
||||||
|
return ""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
|
args: str = join_args(required_positional)
|
||||||
|
self.reporter.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
|
args: str = join_args(required_keyword)
|
||||||
|
self.reporter.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return mapped
|
||||||
313
midas/checker/registry.py
Normal file
313
midas/checker/registry.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
AppliedType,
|
||||||
|
BaseType,
|
||||||
|
ComplexType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
Operation,
|
||||||
|
Type,
|
||||||
|
substitute_typevars,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TypesRegistry:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._types: dict[str, Type] = {}
|
||||||
|
self._operations: dict[Operation.CallSignature, Type] = {}
|
||||||
|
|
||||||
|
def get_type(self, name: str) -> Type:
|
||||||
|
"""Get a type from its name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NameError: if the type is not defined
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type
|
||||||
|
"""
|
||||||
|
if name in self._types:
|
||||||
|
return self._types[name]
|
||||||
|
raise NameError(f"Undefined type {name}")
|
||||||
|
|
||||||
|
def get_operation_result(
|
||||||
|
self, left: Type, operator: str, right: Type
|
||||||
|
) -> Optional[Type]:
|
||||||
|
"""Get the resulting type of an operation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
left (Type): the type of the left operand
|
||||||
|
operator (str): the operation name
|
||||||
|
right (Type): the type of the right operand
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Type]: the result type, or None if no matching operation was found
|
||||||
|
"""
|
||||||
|
signature: Operation.CallSignature = Operation.CallSignature(
|
||||||
|
left=left,
|
||||||
|
method=operator,
|
||||||
|
right=right,
|
||||||
|
)
|
||||||
|
result: Optional[Type] = self._operations.get(signature)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_operations_by_name(self, name: str) -> list[Operation]:
|
||||||
|
operations: list[Operation] = []
|
||||||
|
for signature, result in self._operations.items():
|
||||||
|
if signature.method == name:
|
||||||
|
operations.append(
|
||||||
|
Operation(
|
||||||
|
signature=signature,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return operations
|
||||||
|
|
||||||
|
def define_type(self, name: str, type: Type) -> Type:
|
||||||
|
"""Define a type in the registry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
type (Type): the type to define
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if a type is already defined with that name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the defined type
|
||||||
|
"""
|
||||||
|
if name in self._types:
|
||||||
|
raise ValueError(f"Type {name} already defined")
|
||||||
|
self._types[name] = type
|
||||||
|
return type
|
||||||
|
|
||||||
|
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
||||||
|
"""Define an operation in the registry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
left (Type): the type of the left operand
|
||||||
|
operator (str): the operation name
|
||||||
|
right (Type): the type of the right operand
|
||||||
|
result (Type): the result type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if an operation is already defined with these operands and name
|
||||||
|
"""
|
||||||
|
signature: Operation.CallSignature = Operation.CallSignature(
|
||||||
|
left=left,
|
||||||
|
method=operator,
|
||||||
|
right=right,
|
||||||
|
)
|
||||||
|
if signature in self._operations:
|
||||||
|
raise ValueError(
|
||||||
|
f"Operation {operator} already defined between {left} and {right}"
|
||||||
|
)
|
||||||
|
self._operations[signature] = result
|
||||||
|
|
||||||
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
|
|
||||||
|
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type1 (Type): the potential subtype
|
||||||
|
type2 (Type): the potential supertype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether `type1` is a subtype of `type2`
|
||||||
|
"""
|
||||||
|
|
||||||
|
if type1 == type2:
|
||||||
|
return True
|
||||||
|
|
||||||
|
match (type1, type2):
|
||||||
|
case (AliasType(type=base1), _):
|
||||||
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
|
case (BaseType(name=name1), BaseType(name=name2)):
|
||||||
|
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||||
|
|
||||||
|
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||||
|
for k, t in props2.items():
|
||||||
|
if k not in props1:
|
||||||
|
return False
|
||||||
|
if not self.is_subtype(props1[k], t):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
case (Function(), Function()):
|
||||||
|
return self.is_func_subtype(type1, type2)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
# TODO: verify the logic in here
|
||||||
|
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||||
|
"""Check whether a function is a subtype of another
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func1 (Function): the potential function subtype
|
||||||
|
func2 (Function): the potential function supertype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether `func1` is a subtype of `func2`
|
||||||
|
"""
|
||||||
|
if not self.is_subtype(func1.returns, func2.returns):
|
||||||
|
return False
|
||||||
|
|
||||||
|
pos1: list[Function.Argument] = func1.pos_args
|
||||||
|
mixed1: list[Function.Argument] = func1.args
|
||||||
|
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
||||||
|
pos2: list[Function.Argument] = func2.pos_args
|
||||||
|
mixed2: list[Function.Argument] = func2.args
|
||||||
|
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
||||||
|
|
||||||
|
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
||||||
|
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
||||||
|
|
||||||
|
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
||||||
|
if not self.is_subtype(sub.type, sup.type):
|
||||||
|
return False
|
||||||
|
if not sup.required and sub.required:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
for arg1 in pos1:
|
||||||
|
arg2: Function.Argument
|
||||||
|
if arg1.pos < len(pos2):
|
||||||
|
arg2 = pos2[arg1.pos]
|
||||||
|
elif arg1.pos in mixed_by_pos:
|
||||||
|
arg2 = mixed_by_pos[arg1.pos]
|
||||||
|
elif not arg1.required:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
if not is_arg_subtype(arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for name, arg1 in kw1.items():
|
||||||
|
arg2: Function.Argument
|
||||||
|
if name in kw2:
|
||||||
|
arg2 = kw2[name]
|
||||||
|
elif name in mixed_by_name:
|
||||||
|
arg2 = mixed_by_name[name]
|
||||||
|
elif not arg1.required:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
if not is_arg_subtype(arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for arg1 in mixed1:
|
||||||
|
pos_arg2: Optional[Function.Argument] = None
|
||||||
|
kw_arg2: Optional[Function.Argument] = None
|
||||||
|
if arg1.name in kw2:
|
||||||
|
kw_arg2 = kw2[arg1.name]
|
||||||
|
elif arg1.name in mixed_by_name:
|
||||||
|
kw_arg2 = mixed_by_name[arg1.name]
|
||||||
|
if arg1.pos < len(pos2):
|
||||||
|
pos_arg2 = pos2[arg1.pos]
|
||||||
|
elif arg1.pos in mixed_by_pos:
|
||||||
|
pos_arg2 = mixed_by_pos[arg1.pos]
|
||||||
|
|
||||||
|
# No match in func2 and arg is required
|
||||||
|
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Matching keyword argument
|
||||||
|
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Matching positional argument
|
||||||
|
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
mixed_positions: set[int] = {a.pos for a in mixed1}
|
||||||
|
mixed_names: set[str] = {a.name for a in mixed1}
|
||||||
|
for arg2 in pos2:
|
||||||
|
if not arg2.required:
|
||||||
|
continue
|
||||||
|
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for name, arg2 in kw2.items():
|
||||||
|
if not arg2.required:
|
||||||
|
continue
|
||||||
|
if name not in kw1 and name not in mixed_names:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for arg2 in mixed2:
|
||||||
|
if arg2.required:
|
||||||
|
continue
|
||||||
|
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
||||||
|
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
||||||
|
if not pos_match or not kw_match:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
||||||
|
match type:
|
||||||
|
case AliasType(name=name, type=base):
|
||||||
|
return AliasType(name=name, type=self.apply_generic(base, args))
|
||||||
|
|
||||||
|
case GenericType(name=name, params=type_vars, body=body):
|
||||||
|
n_args: int = len(args)
|
||||||
|
n_type_vars: int = len(type_vars)
|
||||||
|
if n_args < n_type_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
|
||||||
|
)
|
||||||
|
if n_args > n_type_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
|
||||||
|
)
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
for arg, type_var in zip(args, type_vars):
|
||||||
|
if type_var.bound is not None and not self.is_subtype(
|
||||||
|
arg, type_var.bound
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Type argument {arg} is not a subtype of {type_var.bound}"
|
||||||
|
)
|
||||||
|
substitutions[type_var.name] = arg
|
||||||
|
return AppliedType(
|
||||||
|
name=name,
|
||||||
|
args=args,
|
||||||
|
body=substitute_typevars(body, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"{type} is not a generic type")
|
||||||
|
|
||||||
|
def reduce_types(self, types: list[Type]) -> list[Type]:
|
||||||
|
"""Reduce a list of types to remove subtypes and only keep the highest types
|
||||||
|
|
||||||
|
Args:
|
||||||
|
types (list[Type]): the types to reduce
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Type]: the reduced list of types
|
||||||
|
"""
|
||||||
|
|
||||||
|
reduced: bool = True
|
||||||
|
keep: list[int] = list(range(len(types)))
|
||||||
|
while reduced:
|
||||||
|
reduced = False
|
||||||
|
for i, i1 in enumerate(keep):
|
||||||
|
type1: Type = types[i1]
|
||||||
|
for i2 in keep[i + 1 :]:
|
||||||
|
type2 = types[i2]
|
||||||
|
if self.is_subtype(type1, type2):
|
||||||
|
keep.remove(i1)
|
||||||
|
elif self.is_subtype(type2, type1):
|
||||||
|
keep.remove(i2)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
reduced = True
|
||||||
|
break
|
||||||
|
return [types[i] for i in keep]
|
||||||
63
midas/checker/reporter.py
Normal file
63
midas/checker/reporter.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
|
|
||||||
|
|
||||||
|
class Reporter:
|
||||||
|
def __init__(self):
|
||||||
|
self.diagnostics: list[Diagnostic] = []
|
||||||
|
|
||||||
|
def report(
|
||||||
|
self,
|
||||||
|
path: Optional[str],
|
||||||
|
type: DiagnosticType,
|
||||||
|
location: Location,
|
||||||
|
message: str,
|
||||||
|
):
|
||||||
|
self.diagnostics.append(
|
||||||
|
Diagnostic(
|
||||||
|
file_path=path,
|
||||||
|
location=location,
|
||||||
|
type=type,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def for_file(self, path: Optional[str]) -> FileReporter:
|
||||||
|
return FileReporter(self, path)
|
||||||
|
|
||||||
|
|
||||||
|
class FileReporter:
|
||||||
|
def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
|
||||||
|
self.base_reporter: Reporter = base_reporter
|
||||||
|
self.path: Optional[str] = path
|
||||||
|
|
||||||
|
def for_file(self, path: Optional[str]) -> FileReporter:
|
||||||
|
return FileReporter(self.base_reporter, path)
|
||||||
|
|
||||||
|
def report(self, type: DiagnosticType, location: Location, message: str):
|
||||||
|
self.base_reporter.report(self.path, type, location, message)
|
||||||
|
|
||||||
|
def error(self, location: Location, message: str):
|
||||||
|
self.report(
|
||||||
|
type=DiagnosticType.ERROR,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def warning(self, location: Location, message: str):
|
||||||
|
self.report(
|
||||||
|
type=DiagnosticType.WARNING,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def info(self, location: Location, message: str):
|
||||||
|
self.report(
|
||||||
|
type=DiagnosticType.INFO,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
@@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
self.scopes: list[dict[str, bool]] = []
|
self.scopes: list[dict[str, bool]] = [{}]
|
||||||
|
|
||||||
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
||||||
"""Resolve the given statements or expressions"""
|
"""Resolve the given statements or expressions"""
|
||||||
@@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
self.locals[expr] = i
|
self.locals[expr] = i
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def is_defined(self, name: str) -> bool:
|
||||||
|
for scope in self.scopes:
|
||||||
|
if name in scope:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def resolve_function(self, function: p.Function) -> None:
|
def resolve_function(self, function: p.Function) -> None:
|
||||||
"""Resolve a function definition
|
"""Resolve a function definition
|
||||||
|
|
||||||
@@ -112,8 +118,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
for target in stmt.targets:
|
for target in stmt.targets:
|
||||||
match target:
|
match target:
|
||||||
case p.VariableExpr(name=name):
|
case p.VariableExpr(name=name):
|
||||||
self.resolve_local(target, name)
|
if not self.is_defined(name):
|
||||||
# TODO: declare if not found
|
self.declare(name)
|
||||||
|
self.define(name)
|
||||||
|
target.accept(self)
|
||||||
|
|
||||||
|
case p.GetExpr():
|
||||||
|
target.accept(self)
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"Unsupported assignment to {target}")
|
raise Exception(f"Unsupported assignment to {target}")
|
||||||
|
|
||||||
@@ -174,10 +185,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
self.resolve(expr.left)
|
self.resolve(expr.left)
|
||||||
self.resolve(expr.right)
|
self.resolve(expr.right)
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
|
||||||
self.resolve(expr.value)
|
|
||||||
self.resolve(expr.object)
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||||
self.resolve(expr.expr)
|
self.resolve(expr.expr)
|
||||||
|
|
||||||
@@ -185,3 +192,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
self.resolve(expr.test)
|
self.resolve(expr.test)
|
||||||
self.resolve(expr.if_true)
|
self.resolve(expr.if_true)
|
||||||
self.resolve(expr.if_false)
|
self.resolve(expr.if_false)
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||||
|
for item in expr.items:
|
||||||
|
self.resolve(item)
|
||||||
@@ -1,27 +1,36 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class BaseType:
|
class BaseType:
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class AliasType:
|
class AliasType:
|
||||||
name: str
|
name: str
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class UnknownType:
|
class UnknownType:
|
||||||
pass
|
def __str__(self) -> str:
|
||||||
|
return "<Unknown>"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class UnitType:
|
class UnitType:
|
||||||
pass
|
def __str__(self) -> str:
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
@@ -32,16 +41,159 @@ class Function:
|
|||||||
kw_args: list[Argument]
|
kw_args: list[Argument]
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
args: list[str] = []
|
||||||
|
if len(self.pos_args) != 0:
|
||||||
|
args += list(map(str, self.pos_args))
|
||||||
|
if len(self.args) + len(self.kw_args) != 0:
|
||||||
|
args.append("/")
|
||||||
|
|
||||||
|
if len(self.args) != 0:
|
||||||
|
args += list(map(str, self.args))
|
||||||
|
|
||||||
|
if len(self.kw_args) != 0:
|
||||||
|
if len(args) != 0:
|
||||||
|
args.append("*")
|
||||||
|
args += list(map(str, self.kw_args))
|
||||||
|
|
||||||
|
return f"{self.name}({', '.join(args)}) -> {self.returns}"
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class Argument:
|
class Argument:
|
||||||
|
pos: int
|
||||||
name: str
|
name: str
|
||||||
type: Type
|
type: Type
|
||||||
required: bool
|
required: bool
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
opt: str = "" if self.required else "?"
|
||||||
|
return f"{self.name}: {self.type}{opt}"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class ComplexType:
|
class ComplexType:
|
||||||
properties: dict[str, Type]
|
properties: dict[str, Type]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
props: list[str] = [f"{name}: {type}" for name, type in self.properties.items()]
|
||||||
|
return f"{{{', '.join(props)}}}"
|
||||||
|
|
||||||
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Operation:
|
||||||
|
signature: CallSignature
|
||||||
|
result: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.signature} -> {self.result}"
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class CallSignature:
|
||||||
|
left: Type
|
||||||
|
method: str
|
||||||
|
right: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.method}({self.left}, {self.right})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypeVar:
|
||||||
|
name: str
|
||||||
|
bound: Optional[Type]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if self.bound is not None:
|
||||||
|
return f"{self.name} <: {self.bound}"
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class GenericType:
|
||||||
|
name: str
|
||||||
|
params: list[TypeVar]
|
||||||
|
body: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.name}[{', '.join(map(str, self.params))}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class AppliedType:
|
||||||
|
name: str
|
||||||
|
args: list[Type]
|
||||||
|
body: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||||
|
|
||||||
|
|
||||||
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||||
|
def sub_argument(arg: Function.Argument):
|
||||||
|
return Function.Argument(
|
||||||
|
pos=arg.pos,
|
||||||
|
name=arg.name,
|
||||||
|
type=substitute_typevars(arg.type, substitutions),
|
||||||
|
required=arg.required,
|
||||||
|
)
|
||||||
|
|
||||||
|
match type:
|
||||||
|
case BaseType(name=name) if name in substitutions:
|
||||||
|
return substitutions[name]
|
||||||
|
|
||||||
|
case AliasType(name=name, type=type2):
|
||||||
|
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
|
||||||
|
|
||||||
|
case Function(
|
||||||
|
name=name,
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns,
|
||||||
|
):
|
||||||
|
return Function(
|
||||||
|
name=name,
|
||||||
|
pos_args=list(map(sub_argument, pos_args)),
|
||||||
|
args=list(map(sub_argument, args)),
|
||||||
|
kw_args=list(map(sub_argument, kw_args)),
|
||||||
|
returns=substitute_typevars(returns, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ComplexType(properties=properties):
|
||||||
|
properties2: dict[str, Type] = {
|
||||||
|
name: substitute_typevars(prop, substitutions)
|
||||||
|
for name, prop in properties.items()
|
||||||
|
}
|
||||||
|
return ComplexType(properties=properties2)
|
||||||
|
|
||||||
|
case TypeVar(name=name):
|
||||||
|
if name in substitutions:
|
||||||
|
return substitutions[name]
|
||||||
|
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||||
|
|
||||||
|
case UnknownType() | UnitType():
|
||||||
|
return type
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(f"Unsupported type {type}")
|
||||||
|
|
||||||
|
|
||||||
|
def unfold_type(type: Type) -> Type:
|
||||||
|
match type:
|
||||||
|
case AliasType(type=ref_type):
|
||||||
|
return unfold_type(ref_type)
|
||||||
|
case _:
|
||||||
|
return type
|
||||||
|
|
||||||
|
|
||||||
|
Type = (
|
||||||
|
BaseType
|
||||||
|
| AliasType
|
||||||
|
| UnknownType
|
||||||
|
| UnitType
|
||||||
|
| Function
|
||||||
|
| ComplexType
|
||||||
|
| TypeVar
|
||||||
|
| GenericType
|
||||||
|
| AppliedType
|
||||||
|
)
|
||||||
|
|||||||
41
midas/cli/ansi.py
Normal file
41
midas/cli/ansi.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
class Ansi:
|
||||||
|
CTRL = "\x1b["
|
||||||
|
RESET = CTRL + "0m"
|
||||||
|
BOLD = CTRL + "1m"
|
||||||
|
DIM = CTRL + "2m"
|
||||||
|
ITALIC = CTRL + "3m"
|
||||||
|
UNDERLINE = CTRL + "4m"
|
||||||
|
|
||||||
|
BLACK = 0
|
||||||
|
RED = 1
|
||||||
|
GREEN = 2
|
||||||
|
YELLOW = 3
|
||||||
|
BLUE = 4
|
||||||
|
MAGENTA = 5
|
||||||
|
CYAN = 6
|
||||||
|
WHITE = 7
|
||||||
|
|
||||||
|
BRIGHT_BLACK = 60
|
||||||
|
BRIGHT_RED = 61
|
||||||
|
BRIGHT_GREEN = 62
|
||||||
|
BRIGHT_YELLOW = 63
|
||||||
|
BRIGHT_BLUE = 64
|
||||||
|
BRIGHT_MAGENTA = 65
|
||||||
|
BRIGHT_CYAN = 66
|
||||||
|
BRIGHT_WHITE = 67
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def FG(cls, col: int) -> str:
|
||||||
|
return f"{cls.CTRL}{30 + col}m"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def BG(cls, col: int) -> str:
|
||||||
|
return f"{cls.CTRL}{40 + col}m"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def FG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||||
|
return f"{cls.CTRL}38;2;{r};{g};{b}m"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def BG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||||
|
return f"{cls.CTRL}48;2;{r};{g};{b}m"
|
||||||
@@ -210,12 +210,14 @@ class PythonHighlighter(
|
|||||||
|
|
||||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||||
|
for item in expr.items:
|
||||||
|
item.accept(self)
|
||||||
|
|
||||||
|
|
||||||
class MidasHighlighter(
|
class MidasHighlighter(
|
||||||
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
||||||
@@ -286,8 +288,8 @@ class MidasHighlighter(
|
|||||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||||
self.wrap(type, "generic-type")
|
self.wrap(type, "generic-type")
|
||||||
type.type.accept(self)
|
type.type.accept(self)
|
||||||
for param in type.params:
|
for arg in type.args:
|
||||||
param.accept(self)
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||||
self.wrap(type, "constraint-type")
|
self.wrap(type, "constraint-type")
|
||||||
@@ -299,6 +301,12 @@ class MidasHighlighter(
|
|||||||
for prop in type.properties:
|
for prop in type.properties:
|
||||||
prop.accept(self)
|
prop.accept(self)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
|
self.wrap(type, "function")
|
||||||
|
for arg in type.pos_args + type.kw_args:
|
||||||
|
arg.type.accept(self)
|
||||||
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
|
||||||
class DiagnosticsHighlighter(Highlighter):
|
class DiagnosticsHighlighter(Highlighter):
|
||||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ import click
|
|||||||
|
|
||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
||||||
from midas.checker.checker import Checker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.diagnostic import Diagnostic
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
from midas.checker.types import Type
|
from midas.checker.types import Type
|
||||||
|
from midas.cli.ansi import Ansi
|
||||||
from midas.cli.highlighter import (
|
from midas.cli.highlighter import (
|
||||||
DiagnosticsHighlighter,
|
DiagnosticsHighlighter,
|
||||||
Highlighter,
|
Highlighter,
|
||||||
@@ -23,7 +25,6 @@ from midas.lexer.midas import MidasLexer
|
|||||||
from midas.lexer.token import Token, TokenType
|
from midas.lexer.token import Token, TokenType
|
||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
from midas.parser.python import PythonParser
|
from midas.parser.python import PythonParser
|
||||||
from midas.resolver.resolver import Resolver
|
|
||||||
from midas.utils import UniversalJSONDumper
|
from midas.utils import UniversalJSONDumper
|
||||||
|
|
||||||
|
|
||||||
@@ -32,38 +33,92 @@ def midas():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
||||||
|
"""Pretty-print a diagnostic, showing some context if possible
|
||||||
|
|
||||||
|
If the diagnostic concerns a specific part of one line, the line is shown
|
||||||
|
with the affected part highlighted. The message is clearly printed under the
|
||||||
|
line with an underline further indicating the target expression.
|
||||||
|
|
||||||
|
If multiple lines are concerned, no context is shown, only the
|
||||||
|
diagnostic type, location and message
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lines (list[str]): source code lines
|
||||||
|
diagnostic (Diagnostic): the diagnostic to print
|
||||||
|
indent (int, optional): the number of spaces added before the target line to indent if from the location header. Defaults to 4.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loc: Location = diagnostic.location
|
||||||
|
if loc.lineno != loc.end_lineno:
|
||||||
|
print(diagnostic)
|
||||||
|
return
|
||||||
|
|
||||||
|
start_offset: int = loc.col_offset
|
||||||
|
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||||
|
|
||||||
|
line: str = lines[loc.lineno - 1]
|
||||||
|
before: str = line[:start_offset]
|
||||||
|
after: str = line[end_offset:]
|
||||||
|
|
||||||
|
color: int = {
|
||||||
|
DiagnosticType.ERROR: Ansi.RED,
|
||||||
|
DiagnosticType.WARNING: Ansi.YELLOW,
|
||||||
|
DiagnosticType.INFO: Ansi.CYAN,
|
||||||
|
}.get(diagnostic.type, Ansi.WHITE)
|
||||||
|
|
||||||
|
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
||||||
|
cursor: str = (
|
||||||
|
" " * start_offset
|
||||||
|
+ Ansi.FG(color)
|
||||||
|
+ "~" * (end_offset - start_offset)
|
||||||
|
+ "> "
|
||||||
|
+ diagnostic.message
|
||||||
|
+ Ansi.RESET
|
||||||
|
)
|
||||||
|
|
||||||
|
indent_str: str = " " * indent
|
||||||
|
print(diagnostic.location_str + ":")
|
||||||
|
print(indent_str + before + subject + after)
|
||||||
|
print(indent_str + cursor)
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
@midas.command()
|
@midas.command()
|
||||||
@click.option("-l", "--highlight", type=click.File("w"))
|
@click.option("-l", "--highlight", type=click.File("w"))
|
||||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||||
|
@click.option("-v", "--verbose", is_flag=True)
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]):
|
def compile(
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
highlight: Optional[TextIO],
|
||||||
|
types: tuple[TextIO],
|
||||||
|
verbose: bool,
|
||||||
|
file: TextIO,
|
||||||
|
):
|
||||||
|
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
|
||||||
source: str = file.read()
|
source: str = file.read()
|
||||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
|
||||||
parser = PythonParser()
|
|
||||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
|
||||||
resolver = Resolver()
|
|
||||||
resolver.resolve(*stmts)
|
|
||||||
types_paths: list[Path] = [Path(t.name).resolve() for t in types]
|
|
||||||
checker = Checker(
|
|
||||||
resolver.locals,
|
|
||||||
source_path=Path(file.name).resolve(),
|
|
||||||
types_paths=types_paths,
|
|
||||||
)
|
|
||||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
|
||||||
for diagnostic in diagnostics:
|
|
||||||
print(diagnostic)
|
|
||||||
|
|
||||||
print(
|
checker = TypeChecker()
|
||||||
json.dumps(
|
for path in types:
|
||||||
UniversalJSONDumper.dump(
|
checker.import_midas(Path(path.name).resolve())
|
||||||
checker.global_env,
|
|
||||||
[("Environment", "_children")],
|
checker.type_check_source(source, str(Path(file.name).resolve()))
|
||||||
lambda obj: isinstance(obj, get_args(Type)),
|
diagnostics: list[Diagnostic] = checker.diagnostics
|
||||||
),
|
lines: list[str] = source.split("\n")
|
||||||
indent=4,
|
for diagnostic in diagnostics:
|
||||||
|
print_diagnostic(lines, diagnostic)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
json.dumps(
|
||||||
|
UniversalJSONDumper.dump(
|
||||||
|
checker.python_typer.global_env,
|
||||||
|
[("Environment", "_children")],
|
||||||
|
lambda obj: isinstance(obj, get_args(Type)),
|
||||||
|
),
|
||||||
|
indent=4,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if highlight is not None:
|
if highlight is not None:
|
||||||
highlighter = DiagnosticsHighlighter(source)
|
highlighter = DiagnosticsHighlighter(source)
|
||||||
highlighter.highlight(diagnostics)
|
highlighter.highlight(diagnostics)
|
||||||
|
|||||||
@@ -50,12 +50,14 @@ class MidasLexer(Lexer):
|
|||||||
# self.add_token(TokenType.PLUS)
|
# self.add_token(TokenType.PLUS)
|
||||||
case "-":
|
case "-":
|
||||||
self.add_token(TokenType.MINUS)
|
self.add_token(TokenType.MINUS)
|
||||||
# case "*":
|
case "*":
|
||||||
# self.add_token(TokenType.STAR)
|
self.add_token(TokenType.STAR)
|
||||||
case "/" if self.match("/"):
|
case "/" if self.match("/"):
|
||||||
self.scan_comment()
|
self.scan_comment()
|
||||||
case "/" if self.match("*"):
|
case "/" if self.match("*"):
|
||||||
self.scan_comment_multiline()
|
self.scan_comment_multiline()
|
||||||
|
case "/":
|
||||||
|
self.add_token(TokenType.SLASH)
|
||||||
case "\n":
|
case "\n":
|
||||||
self.add_token(TokenType.NEWLINE)
|
self.add_token(TokenType.NEWLINE)
|
||||||
case " " | "\r" | "\t":
|
case " " | "\r" | "\t":
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ class TokenType(Enum):
|
|||||||
# Operators
|
# Operators
|
||||||
# PLUS = auto()
|
# PLUS = auto()
|
||||||
MINUS = auto()
|
MINUS = auto()
|
||||||
# STAR = auto()
|
STAR = auto()
|
||||||
# SLASH = auto()
|
SLASH = auto()
|
||||||
GREATER = auto()
|
GREATER = auto()
|
||||||
GREATER_EQUAL = auto()
|
GREATER_EQUAL = auto()
|
||||||
LESS = auto()
|
LESS = auto()
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from midas.ast.midas import (
|
|||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
GroupingExpr,
|
GroupingExpr,
|
||||||
@@ -18,12 +19,13 @@ from midas.ast.midas import (
|
|||||||
PropertyStmt,
|
PropertyStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
|
TypeParam,
|
||||||
TypeStmt,
|
TypeStmt,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
WildcardExpr,
|
WildcardExpr,
|
||||||
)
|
)
|
||||||
from midas.lexer.token import Token, TokenType
|
from midas.lexer.token import KEYWORDS, Token, TokenType
|
||||||
from midas.parser.base import Parser
|
from midas.parser.base import Parser
|
||||||
from midas.parser.errors import ParsingError
|
from midas.parser.errors import ParsingError
|
||||||
|
|
||||||
@@ -107,10 +109,8 @@ class MidasParser(Parser):
|
|||||||
TypeStmt: the parsed type declaration statement
|
TypeStmt: the parsed type declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
params: list[TypeStmt.Param] = []
|
params: list[TypeParam] = self.type_params()
|
||||||
if self.check(TokenType.LEFT_BRACKET):
|
|
||||||
params = self.type_stmt_params()
|
|
||||||
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||||
|
|
||||||
@@ -123,24 +123,27 @@ class MidasParser(Parser):
|
|||||||
type=type,
|
type=type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
def type_params(self) -> list[TypeParam]:
|
||||||
"""Parse a generic template expression
|
"""Parse a list of type parameters
|
||||||
|
|
||||||
A template is written `[TypeExpr]`
|
Type parameters are a comma-separated list of type variables wrapped in brackets.
|
||||||
|
Each type variable is either a simple variable, or a bounded variable written `S <: T`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TemplateExpr: the parsed template expression
|
list[TypeParam]: the list of type parameters, if any, or an empty list
|
||||||
"""
|
"""
|
||||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
if not self.match(TokenType.LEFT_BRACKET):
|
||||||
params: list[TypeStmt.Param] = []
|
return []
|
||||||
|
|
||||||
|
params: list[TypeParam] = []
|
||||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
name: Token = self.consume_identifier("Expected type variable")
|
||||||
bound: Optional[Type] = None
|
bound: Optional[Type] = None
|
||||||
if self.match(TokenType.LESS):
|
if self.match(TokenType.LESS):
|
||||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||||
bound = self.type_expr()
|
bound = self.type_expr()
|
||||||
params.append(
|
params.append(
|
||||||
TypeStmt.Param(
|
TypeParam(
|
||||||
location=name.location_to(self.previous()),
|
location=name.location_to(self.previous()),
|
||||||
name=name,
|
name=name,
|
||||||
bound=bound,
|
bound=bound,
|
||||||
@@ -148,7 +151,7 @@ class MidasParser(Parser):
|
|||||||
)
|
)
|
||||||
if not self.match(TokenType.COMMA):
|
if not self.match(TokenType.COMMA):
|
||||||
break
|
break
|
||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def type_expr(self) -> Type:
|
def type_expr(self) -> Type:
|
||||||
@@ -187,26 +190,26 @@ class MidasParser(Parser):
|
|||||||
def generic_type(self) -> Type:
|
def generic_type(self) -> Type:
|
||||||
type: Type = self.named_type()
|
type: Type = self.named_type()
|
||||||
if self.check(TokenType.LEFT_BRACKET):
|
if self.check(TokenType.LEFT_BRACKET):
|
||||||
params: list[Type] = self.type_params()
|
args: list[Type] = self.type_args()
|
||||||
return GenericType(
|
return GenericType(
|
||||||
location=Location.span(type.location, self.previous().get_location()),
|
location=Location.span(type.location, self.previous().get_location()),
|
||||||
type=type,
|
type=type,
|
||||||
params=params,
|
args=args,
|
||||||
)
|
)
|
||||||
return type
|
return type
|
||||||
|
|
||||||
def type_params(self) -> list[Type]:
|
def type_args(self) -> list[Type]:
|
||||||
params: list[Type] = []
|
args: list[Type] = []
|
||||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
|
||||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||||
params.append(self.type_expr())
|
args.append(self.type_expr())
|
||||||
if not self.match(TokenType.COMMA):
|
if not self.match(TokenType.COMMA):
|
||||||
break
|
break
|
||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||||
return params
|
return args
|
||||||
|
|
||||||
def named_type(self) -> Type:
|
def named_type(self) -> Type:
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
return NamedType(
|
return NamedType(
|
||||||
location=name.get_location(),
|
location=name.get_location(),
|
||||||
name=name,
|
name=name,
|
||||||
@@ -322,9 +325,7 @@ class MidasParser(Parser):
|
|||||||
"""
|
"""
|
||||||
expr: Expr = self.primary()
|
expr: Expr = self.primary()
|
||||||
while self.match(TokenType.DOT):
|
while self.match(TokenType.DOT):
|
||||||
name: Token = self.consume(
|
name: Token = self.consume_identifier("Expected property name after '.'")
|
||||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
|
||||||
)
|
|
||||||
location: Location = Location.span(expr.location, name.get_location())
|
location: Location = Location.span(expr.location, name.get_location())
|
||||||
expr = GetExpr(location=location, expr=expr, name=name)
|
expr = GetExpr(location=location, expr=expr, name=name)
|
||||||
return expr
|
return expr
|
||||||
@@ -348,7 +349,7 @@ class MidasParser(Parser):
|
|||||||
if self.match(TokenType.NUMBER):
|
if self.match(TokenType.NUMBER):
|
||||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||||
|
|
||||||
if self.match(TokenType.IDENTIFIER):
|
if self.match_identifier():
|
||||||
return VariableExpr(location=token.get_location(), name=token)
|
return VariableExpr(location=token.get_location(), name=token)
|
||||||
|
|
||||||
if self.match(TokenType.UNDERSCORE):
|
if self.match(TokenType.UNDERSCORE):
|
||||||
@@ -361,6 +362,20 @@ class MidasParser(Parser):
|
|||||||
|
|
||||||
raise self.error(self.peek(), "Expected expression")
|
raise self.error(self.peek(), "Expected expression")
|
||||||
|
|
||||||
|
def consume_identifier(self, message: str = "Expected identifier") -> Token:
|
||||||
|
if not self.match_identifier():
|
||||||
|
raise self.error(self.peek(), message)
|
||||||
|
return self.previous()
|
||||||
|
|
||||||
|
def match_identifier(self) -> bool:
|
||||||
|
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
|
||||||
|
|
||||||
|
def check_identifier(self) -> bool:
|
||||||
|
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
|
||||||
|
if self.check(tt):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def property_stmt(self) -> PropertyStmt:
|
def property_stmt(self) -> PropertyStmt:
|
||||||
"""Parse a property statement
|
"""Parse a property statement
|
||||||
|
|
||||||
@@ -369,7 +384,7 @@ class MidasParser(Parser):
|
|||||||
Returns:
|
Returns:
|
||||||
PropertyStmt: the parsed property statement
|
PropertyStmt: the parsed property statement
|
||||||
"""
|
"""
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
name: Token = self.consume_identifier("Expected property name")
|
||||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||||
type: Type = self.type_expr()
|
type: Type = self.type_expr()
|
||||||
return PropertyStmt(
|
return PropertyStmt(
|
||||||
@@ -381,12 +396,14 @@ class MidasParser(Parser):
|
|||||||
def extend_declaration(self) -> ExtendStmt:
|
def extend_declaration(self) -> ExtendStmt:
|
||||||
"""Parse an extension definition
|
"""Parse an extension definition
|
||||||
|
|
||||||
An extension is written `extend Type { operations }`
|
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ExtendStmt: the parsed extension statement
|
ExtendStmt: the parsed extension statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
|
params: list[TypeParam] = self.type_params()
|
||||||
|
|
||||||
type: Type = self.type_expr()
|
type: Type = self.type_expr()
|
||||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||||
operations: list[OpStmt] = []
|
operations: list[OpStmt] = []
|
||||||
@@ -394,7 +411,12 @@ class MidasParser(Parser):
|
|||||||
operations.append(self.op_declaration())
|
operations.append(self.op_declaration())
|
||||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||||
location: Location = keyword.location_to(self.previous())
|
location: Location = keyword.location_to(self.previous())
|
||||||
return ExtendStmt(location=location, type=type, operations=operations)
|
return ExtendStmt(
|
||||||
|
location=location,
|
||||||
|
params=params,
|
||||||
|
type=type,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
def op_declaration(self) -> OpStmt:
|
def op_declaration(self) -> OpStmt:
|
||||||
"""Parse an operation definition
|
"""Parse an operation definition
|
||||||
@@ -430,9 +452,9 @@ class MidasParser(Parser):
|
|||||||
PredicateStmt: the parsed predicate declaration statement
|
PredicateStmt: the parsed predicate declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
name: Token = self.consume_identifier("Expected predicate name")
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
subject: Token = self.consume_identifier("Expected subject name")
|
||||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||||
type: Type = self.type_expr()
|
type: Type = self.type_expr()
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||||
@@ -445,3 +467,48 @@ class MidasParser(Parser):
|
|||||||
type=type,
|
type=type,
|
||||||
condition=condition,
|
condition=condition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def function(self) -> FunctionType:
|
||||||
|
l_paren: Token = self.consume(
|
||||||
|
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||||
|
)
|
||||||
|
pos_args: list[FunctionType.Argument] = []
|
||||||
|
kw_args: list[FunctionType.Argument] = []
|
||||||
|
|
||||||
|
positional: bool = True
|
||||||
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||||
|
if positional and (
|
||||||
|
self.match(TokenType.STAR) or self.match(TokenType.SLASH)
|
||||||
|
):
|
||||||
|
positional = False
|
||||||
|
else:
|
||||||
|
name: Optional[Token] = None
|
||||||
|
if self.check_identifier() and self.check_next(TokenType.COLON):
|
||||||
|
name = self.advance()
|
||||||
|
self.advance()
|
||||||
|
type: Type = self.type_expr()
|
||||||
|
required: bool = self.match(TokenType.QMARK)
|
||||||
|
arg = FunctionType.Argument(
|
||||||
|
location=None,
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
required=required,
|
||||||
|
)
|
||||||
|
if positional:
|
||||||
|
pos_args.append(arg)
|
||||||
|
else:
|
||||||
|
kw_args.append(arg)
|
||||||
|
|
||||||
|
if not self.match(TokenType.COMMA):
|
||||||
|
break
|
||||||
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||||
|
|
||||||
|
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||||
|
result: Type = self.type_expr()
|
||||||
|
|
||||||
|
return FunctionType(
|
||||||
|
location=l_paren.location_to(self.previous()),
|
||||||
|
pos_args=pos_args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=result,
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from midas.ast.python import (
|
|||||||
Function,
|
Function,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
IfStmt,
|
IfStmt,
|
||||||
|
ListExpr,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
@@ -416,6 +417,12 @@ class PythonParser:
|
|||||||
case ast.Name(id=name):
|
case ast.Name(id=name):
|
||||||
return VariableExpr(location=location, name=name)
|
return VariableExpr(location=location, name=name)
|
||||||
|
|
||||||
|
case ast.List(elts=items):
|
||||||
|
return ListExpr(
|
||||||
|
location=location,
|
||||||
|
items=[self.parse_expr(item) for item in items],
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise UnsupportedSyntaxError(node)
|
raise UnsupportedSyntaxError(node)
|
||||||
|
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from midas.checker.types import BaseType, Type, UnitType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from midas.resolver.midas import MidasResolver
|
|
||||||
|
|
||||||
|
|
||||||
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
|
||||||
ctx.define_operation(
|
|
||||||
left=t1,
|
|
||||||
operator=operator,
|
|
||||||
right=t2,
|
|
||||||
result=t3,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
|
||||||
ctx.define_operation(
|
|
||||||
left=type,
|
|
||||||
operator=op,
|
|
||||||
right=type,
|
|
||||||
result=type,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def define_builtins(ctx: MidasResolver):
|
|
||||||
"""Define builtin types and operations"""
|
|
||||||
unit = ctx.define_type("None", UnitType())
|
|
||||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
|
||||||
int = ctx.define_type("int", BaseType(name="int"))
|
|
||||||
float = ctx.define_type("float", BaseType(name="float"))
|
|
||||||
str = ctx.define_type("str", BaseType(name="str"))
|
|
||||||
|
|
||||||
basic_op(ctx, int, "__add__") # int + int = int
|
|
||||||
basic_op(ctx, int, "__sub__") # int - int = int
|
|
||||||
basic_op(ctx, int, "__mul__") # int * int = int
|
|
||||||
basic_op(ctx, int, "__pow__") # int ** int = int
|
|
||||||
basic_op(ctx, int, "__mod__") # int % int = int
|
|
||||||
basic_op(ctx, int, "__and__") # int & int = int
|
|
||||||
basic_op(ctx, int, "__or__") # int | int = int
|
|
||||||
basic_op(ctx, int, "__xor__") # int ^ int = int
|
|
||||||
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
|
||||||
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
|
||||||
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
|
||||||
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
|
||||||
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
|
||||||
basic_op(ctx, float, "__add__") # float + float = float
|
|
||||||
basic_op(ctx, float, "__sub__") # float - float = float
|
|
||||||
basic_op(ctx, float, "__mul__") # float * float = float
|
|
||||||
basic_op(ctx, float, "__truediv__") # float / float = float
|
|
||||||
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
|
||||||
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
|
||||||
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
|
||||||
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
|
||||||
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
|
||||||
basic_op(ctx, str, "__add__") # str + str = str
|
|
||||||
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
|
||||||
|
|
||||||
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
|
||||||
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
|
||||||
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
|
||||||
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
|
||||||
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
|
||||||
|
|
||||||
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
|
||||||
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
|
||||||
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
|
||||||
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
|
||||||
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import midas.ast.midas as m
|
|
||||||
from midas.checker.types import (
|
|
||||||
AliasType,
|
|
||||||
Type,
|
|
||||||
UnknownType,
|
|
||||||
)
|
|
||||||
from midas.resolver.builtin import define_builtins
|
|
||||||
|
|
||||||
|
|
||||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._types: dict[str, Type] = {}
|
|
||||||
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
|
||||||
|
|
||||||
define_builtins(self)
|
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
|
||||||
"""Get a type from its name
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): the name of the type
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NameError: if the type is not defined
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the type
|
|
||||||
"""
|
|
||||||
type: Optional[Type] = self._types.get(name)
|
|
||||||
if type is None:
|
|
||||||
raise NameError(f"Undefined type {name}")
|
|
||||||
return type
|
|
||||||
|
|
||||||
def get_operation_result(
|
|
||||||
self, left: Type, operator: str, right: Type
|
|
||||||
) -> Optional[Type]:
|
|
||||||
"""Get the resulting type of an operation
|
|
||||||
|
|
||||||
Args:
|
|
||||||
left (Type): the type of the left operand
|
|
||||||
operator (str): the operation name
|
|
||||||
right (Type): the type of the right operand
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Type]: the result type, or None if no matching operation was found
|
|
||||||
"""
|
|
||||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
|
||||||
result: Optional[Type] = self._operations.get(operation)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def define_type(self, name: str, type: Type) -> Type:
|
|
||||||
"""Define a type in the registry
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): the name of the type
|
|
||||||
type (Type): the type to define
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if a type is already defined with that name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the defined type
|
|
||||||
"""
|
|
||||||
if name in self._types:
|
|
||||||
raise ValueError(f"Type {name} already defined")
|
|
||||||
self._types[name] = type
|
|
||||||
return type
|
|
||||||
|
|
||||||
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
|
||||||
"""Define an operation in the registry
|
|
||||||
|
|
||||||
Args:
|
|
||||||
left (Type): the type of the left operand
|
|
||||||
operator (str): the operation name
|
|
||||||
right (Type): the type of the right operand
|
|
||||||
result (Type): the result type
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if an operation is already defined with these operands and name
|
|
||||||
"""
|
|
||||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
|
||||||
if operation in self._operations:
|
|
||||||
raise ValueError(
|
|
||||||
f"Operation {operator} already defined between {left} and {right}"
|
|
||||||
)
|
|
||||||
self._operations[operation] = result
|
|
||||||
|
|
||||||
def resolve(self, stmts: list[m.Stmt]):
|
|
||||||
"""Process a sequence of statements
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stmts (list[m.Stmt]): the statements
|
|
||||||
"""
|
|
||||||
for stmt in stmts:
|
|
||||||
stmt.accept(self)
|
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
|
||||||
type: Type = stmt.type.accept(self)
|
|
||||||
for param in stmt.params:
|
|
||||||
if param.bound is not None:
|
|
||||||
param.bound.accept(self)
|
|
||||||
name: str = stmt.name.lexeme
|
|
||||||
self.define_type(name, AliasType(name=name, type=type))
|
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
|
||||||
base: Type = stmt.type.accept(self)
|
|
||||||
for op in stmt.operations:
|
|
||||||
right: Type = op.operand.accept(self)
|
|
||||||
result: Type = op.result.accept(self)
|
|
||||||
self.define_operation(
|
|
||||||
left=base,
|
|
||||||
operator=op.name.lexeme,
|
|
||||||
right=right,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
|
||||||
return expr.expr.accept(self)
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
|
||||||
return self.get_type(type.name.lexeme)
|
|
||||||
|
|
||||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
|
||||||
type_: Type = type.type.accept(self)
|
|
||||||
params: list[Type] = [param.accept(self) for param in type.params]
|
|
||||||
# TODO
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
|
||||||
type_: Type = type.type.accept(self)
|
|
||||||
type.constraint.accept(self)
|
|
||||||
# TODO
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
|
||||||
for prop in type.properties:
|
|
||||||
prop.accept(self)
|
|
||||||
# TODO
|
|
||||||
return UnknownType()
|
|
||||||
@@ -29,7 +29,7 @@ class Tester(ABC):
|
|||||||
def _list_tests(self) -> list[Path]: ...
|
def _list_tests(self) -> list[Path]: ...
|
||||||
|
|
||||||
def run_all_tests(self) -> bool:
|
def run_all_tests(self) -> bool:
|
||||||
paths: list[Path] = self._list_tests()
|
paths: list[Path] = sorted(self._list_tests())
|
||||||
return self.run_tests(paths)
|
return self.run_tests(paths)
|
||||||
|
|
||||||
def run_tests(self, tests: list[Path]) -> bool:
|
def run_tests(self, tests: list[Path]) -> bool:
|
||||||
@@ -40,7 +40,7 @@ class Tester(ABC):
|
|||||||
|
|
||||||
print(rule)
|
print(rule)
|
||||||
for i, test in enumerate(tests):
|
for i, test in enumerate(tests):
|
||||||
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
|
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||||
success: bool = self._run_test(test)
|
success: bool = self._run_test(test)
|
||||||
if success:
|
if success:
|
||||||
successes += 1
|
successes += 1
|
||||||
@@ -78,7 +78,7 @@ class Tester(ABC):
|
|||||||
def _exec_case(self, path: Path) -> CaseResult: ...
|
def _exec_case(self, path: Path) -> CaseResult: ...
|
||||||
|
|
||||||
def update_all_tests(self):
|
def update_all_tests(self):
|
||||||
paths: list[Path] = self._list_tests()
|
paths: list[Path] = sorted(self._list_tests())
|
||||||
return self.update_tests(paths)
|
return self.update_tests(paths)
|
||||||
|
|
||||||
def update_tests(self, tests: list[Path]):
|
def update_tests(self, tests: list[Path]):
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
{
|
{
|
||||||
"diagnostics": []
|
"diagnostics": [],
|
||||||
|
"judgments": []
|
||||||
}
|
}
|
||||||
@@ -12,35 +12,168 @@
|
|||||||
13
|
13
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
|
"message": "Cannot assign str to variable 'c' of type int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L1:9",
|
||||||
|
"to": "L1:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "Error",
|
|
||||||
"location": {
|
"location": {
|
||||||
"start": [
|
"from": "L2:9",
|
||||||
9,
|
"to": "L2:10"
|
||||||
4
|
|
||||||
],
|
|
||||||
"end": [
|
|
||||||
9,
|
|
||||||
9
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 4
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "Error",
|
|
||||||
"location": {
|
"location": {
|
||||||
"start": [
|
"from": "L4:4",
|
||||||
11,
|
"to": "L4:5"
|
||||||
0
|
|
||||||
],
|
|
||||||
"end": [
|
|
||||||
11,
|
|
||||||
12
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L4:8",
|
||||||
|
"to": "L4:9"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L4:4",
|
||||||
|
"to": "L4:9"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:4",
|
||||||
|
"to": "L6:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "invalid"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:4",
|
||||||
|
"to": "L8:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": true
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:4",
|
||||||
|
"to": "L9:5"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "d"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:8",
|
||||||
|
"to": "L9:9"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "d"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:4",
|
||||||
|
"to": "L9:9"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "d"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "d"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:11",
|
||||||
|
"to": "L11:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,109 @@
|
|||||||
{
|
{
|
||||||
"diagnostics": []
|
"diagnostics": [],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L4:18",
|
||||||
|
"to": "L4:37"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CastExpr",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Meter",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 123.45
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Meter",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L5:15",
|
||||||
|
"to": "L5:32"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CastExpr",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Second",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 6.7
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Second",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:8",
|
||||||
|
"to": "L6:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "distance"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Meter",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:19",
|
||||||
|
"to": "L6:23"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "time"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Second",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:8",
|
||||||
|
"to": "L6:23"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "distance"
|
||||||
|
},
|
||||||
|
"operator": "/",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "time"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "MeterPerSecond",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
@@ -42,5 +42,215 @@
|
|||||||
},
|
},
|
||||||
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||||
}
|
}
|
||||||
|
],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:11",
|
||||||
|
"to": "L2:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:15",
|
||||||
|
"to": "L2:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L5:7",
|
||||||
|
"to": "L5:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L5:11",
|
||||||
|
"to": "L5:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:15",
|
||||||
|
"to": "L6:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:19",
|
||||||
|
"to": "L6:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:15",
|
||||||
|
"to": "L8:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:19",
|
||||||
|
"to": "L8:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:7",
|
||||||
|
"to": "L15:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:11",
|
||||||
|
"to": "L15:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 10
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:15",
|
||||||
|
"to": "L16:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:19",
|
||||||
|
"to": "L16:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 10
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L22:7",
|
||||||
|
"to": "L22:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L22:11",
|
||||||
|
"to": "L22:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L23:15",
|
||||||
|
"to": "L23:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L23:19",
|
||||||
|
"to": "L23:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
12
tests/cases/checker/06_subtyping.py
Normal file
12
tests/cases/checker/06_subtyping.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
v1: int = 3
|
||||||
|
v2: float = 4
|
||||||
|
|
||||||
|
|
||||||
|
def maximum(a: float, b: float):
|
||||||
|
if b > a:
|
||||||
|
return b
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
v3 = maximum(v1, v2)
|
||||||
|
v3 = v1 + v2
|
||||||
193
tests/cases/checker/06_subtyping.py.ref.json
Normal file
193
tests/cases/checker/06_subtyping.py.ref.json
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": [],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L1:10",
|
||||||
|
"to": "L1:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:12",
|
||||||
|
"to": "L2:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 4
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:7",
|
||||||
|
"to": "L6:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:11",
|
||||||
|
"to": "L6:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:5",
|
||||||
|
"to": "L11:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "maximum"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "maximum",
|
||||||
|
"pos_args": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "a",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"pos": 1,
|
||||||
|
"name": "b",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:13",
|
||||||
|
"to": "L11:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:17",
|
||||||
|
"to": "L11:19"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:5",
|
||||||
|
"to": "L11:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "maximum"
|
||||||
|
},
|
||||||
|
"arguments": [
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"keywords": {}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:5",
|
||||||
|
"to": "L12:7"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:10",
|
||||||
|
"to": "L12:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:5",
|
||||||
|
"to": "L12:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -2385,7 +2385,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "GeoLocation"
|
"name": "GeoLocation"
|
||||||
@@ -2416,7 +2416,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Latitude"
|
"name": "Latitude"
|
||||||
@@ -2433,7 +2433,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Longitude"
|
"name": "Longitude"
|
||||||
@@ -2464,7 +2464,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Latitude"
|
"name": "Latitude"
|
||||||
@@ -2494,7 +2494,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Longitude"
|
"name": "Longitude"
|
||||||
@@ -2638,7 +2638,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Optional"
|
"name": "Optional"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "ConstraintType",
|
"_type": "ConstraintType",
|
||||||
"type": {
|
"type": {
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
import ast
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.checker.checker import Checker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.diagnostic import Diagnostic
|
from midas.checker.diagnostic import Diagnostic
|
||||||
from midas.parser.python import PythonParser
|
from midas.checker.types import Type
|
||||||
from midas.resolver.resolver import Resolver
|
|
||||||
from tests.base import Tester
|
from tests.base import Tester
|
||||||
|
from tests.serializer.python import PythonAstJsonSerializer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CaseResult:
|
class CaseResult:
|
||||||
diagnostics: list[dict] = field(default_factory=list)
|
diagnostics: list[dict] = field(default_factory=list)
|
||||||
|
judgments: list = field(default_factory=list)
|
||||||
|
|
||||||
def dumps(self) -> str:
|
def dumps(self) -> str:
|
||||||
return json.dumps(asdict(self), indent=2)
|
return json.dumps(asdict(self), indent=2)
|
||||||
@@ -33,23 +33,16 @@ class CheckerTester(Tester):
|
|||||||
if not path.is_file():
|
if not path.is_file():
|
||||||
raise TypeError(f"Test '{path}' is not a file")
|
raise TypeError(f"Test '{path}' is not a file")
|
||||||
|
|
||||||
types_paths: list[Path] = []
|
result: CaseResult = CaseResult()
|
||||||
|
|
||||||
|
checker = TypeChecker()
|
||||||
types_path: Path = path.with_suffix(".midas")
|
types_path: Path = path.with_suffix(".midas")
|
||||||
if types_path.exists():
|
if types_path.exists():
|
||||||
types_paths.append(types_path)
|
checker.import_midas(types_path)
|
||||||
source: str = path.read_text()
|
|
||||||
tree: ast.Module = ast.parse(source, filename=path)
|
checker.type_check(path)
|
||||||
parser = PythonParser()
|
|
||||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
diagnostics: list[Diagnostic] = checker.diagnostics
|
||||||
resolver = Resolver()
|
|
||||||
resolver.resolve(*stmts)
|
|
||||||
result: CaseResult = CaseResult()
|
|
||||||
checker = Checker(
|
|
||||||
resolver.locals,
|
|
||||||
source_path=path,
|
|
||||||
types_paths=types_paths,
|
|
||||||
)
|
|
||||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
|
||||||
for diagnostic in diagnostics:
|
for diagnostic in diagnostics:
|
||||||
result.diagnostics.append(
|
result.diagnostics.append(
|
||||||
{
|
{
|
||||||
@@ -68,6 +61,21 @@ class CheckerTester(Tester):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements
|
||||||
|
serializer = PythonAstJsonSerializer()
|
||||||
|
for expr, type in judgements:
|
||||||
|
loc = expr.location
|
||||||
|
result.judgments.append(
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": f"L{loc.lineno}:{loc.col_offset}",
|
||||||
|
"to": f"L{loc.end_lineno}:{loc.end_col_offset}",
|
||||||
|
},
|
||||||
|
"expr": expr.accept(serializer),
|
||||||
|
"type": asdict(type),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from midas.ast.midas import (
|
|||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
GroupingExpr,
|
GroupingExpr,
|
||||||
@@ -17,6 +18,7 @@ from midas.ast.midas import (
|
|||||||
PropertyStmt,
|
PropertyStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
|
TypeParam,
|
||||||
TypeStmt,
|
TypeStmt,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
@@ -46,13 +48,11 @@ class MidasAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "TypeStmt",
|
"_type": "TypeStmt",
|
||||||
"name": stmt.name.lexeme,
|
"name": stmt.name.lexeme,
|
||||||
"params": [
|
"params": [self._serialize_type_param(param) for param in stmt.params],
|
||||||
self._serialize_type_stmt_template_param(param) for param in stmt.params
|
|
||||||
],
|
|
||||||
"type": stmt.type.accept(self),
|
"type": stmt.type.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
|
def _serialize_type_param(self, param: TypeParam) -> dict:
|
||||||
return {
|
return {
|
||||||
"name": param.name.lexeme,
|
"name": param.name.lexeme,
|
||||||
"bound": self._serialize_optional(param.bound),
|
"bound": self._serialize_optional(param.bound),
|
||||||
@@ -150,7 +150,7 @@ class MidasAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "GenericType",
|
"_type": "GenericType",
|
||||||
"type": type.type.accept(self),
|
"type": type.type.accept(self),
|
||||||
"params": self._serialize_list(type.params),
|
"args": self._serialize_list(type.args),
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
||||||
@@ -165,3 +165,18 @@ class MidasAstJsonSerializer(
|
|||||||
"_type": "ComplexType",
|
"_type": "ComplexType",
|
||||||
"properties": self._serialize_list(type.properties),
|
"properties": self._serialize_list(type.properties),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_function_type(self, type: FunctionType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "FunctionType",
|
||||||
|
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||||
|
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||||
|
"returns": type.returns.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||||
|
return {
|
||||||
|
"name": arg.name,
|
||||||
|
"type": arg.type.accept(self),
|
||||||
|
"required": arg.required,
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ from midas.ast.python import (
|
|||||||
Function,
|
Function,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
IfStmt,
|
IfStmt,
|
||||||
|
ListExpr,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
ReturnStmt,
|
ReturnStmt,
|
||||||
SetExpr,
|
|
||||||
Stmt,
|
Stmt,
|
||||||
TernaryExpr,
|
TernaryExpr,
|
||||||
TypeAssign,
|
TypeAssign,
|
||||||
@@ -232,14 +232,6 @@ class PythonAstJsonSerializer(
|
|||||||
"right": expr.right.accept(self),
|
"right": expr.right.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_set_expr(self, expr: SetExpr) -> dict:
|
|
||||||
return {
|
|
||||||
"_type": "SetExpr",
|
|
||||||
"object": expr.object.accept(self),
|
|
||||||
"name": expr.name,
|
|
||||||
"value": expr.value.accept(self),
|
|
||||||
}
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "CastExpr",
|
"_type": "CastExpr",
|
||||||
@@ -254,3 +246,9 @@ class PythonAstJsonSerializer(
|
|||||||
"if_true": expr.if_true.accept(self),
|
"if_true": expr.if_true.accept(self),
|
||||||
"if_false": expr.if_false.accept(self),
|
"if_false": expr.if_false.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: ListExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ListExpr",
|
||||||
|
"items": [item.accept(self) for item in expr.items],
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user