Compare commits
35 Commits
main
...
feat/gener
| Author | SHA1 | Date | |
|---|---|---|---|
|
c6ead886ec
|
|||
|
9de03bf2b5
|
|||
|
a26b9293be
|
|||
|
efa5454776
|
|||
|
b8bb8190c4
|
|||
|
a4f5db7ece
|
|||
|
fc67f01f34
|
|||
|
0a748a36a3
|
|||
|
89fdd1b47e
|
|||
|
0cde53ac6e
|
|||
|
f3ec3606c2
|
|||
|
67ec029529
|
|||
|
e2aef7a811
|
|||
|
86ba4e658a
|
|||
|
7eccf59558
|
|||
|
9dd7801d2d
|
|||
|
154cb8b314
|
|||
|
c64ab434b5
|
|||
|
25e6410546
|
|||
|
8a22acc17c
|
|||
|
e0179bc442
|
|||
|
e665d03533
|
|||
|
b8cb2b4273
|
|||
|
d278dc5f5b
|
|||
|
59e73f0fd9
|
|||
|
3e0dc60283
|
|||
|
c24eb5125e
|
|||
|
25bd895dde
|
|||
|
bccd75317e
|
|||
|
f0e3f7574f
|
|||
|
5d44081847
|
|||
|
2a2bb0aec7
|
|||
|
67c40a3909
|
|||
|
1c30188122
|
|||
|
82a0f13242
|
11
examples/01_simple_type_checking/04_complex_types.midas
Normal file
11
examples/01_simple_type_checking/04_complex_types.midas
Normal file
@@ -0,0 +1,11 @@
|
||||
type Meter = float
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
}
|
||||
|
||||
type Coordinate = {
|
||||
x: Meter
|
||||
y: Meter
|
||||
}
|
||||
14
examples/01_simple_type_checking/04_complex_types.py
Normal file
14
examples/01_simple_type_checking/04_complex_types.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
p1: Coordinate
|
||||
p2: Coordinate
|
||||
|
||||
diff_x = p2.x - p1.x
|
||||
diff_y = p2.y - p1.y
|
||||
|
||||
dist = diff_x + diff_y
|
||||
|
||||
p2.x += cast(Meter, 1)
|
||||
p2.y = True
|
||||
p2.z = 3
|
||||
p2.x.a = 3
|
||||
15
gen/gen.py
15
gen/gen.py
@@ -30,6 +30,7 @@ from __future__ import annotations
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
{preamble}
|
||||
{sections}
|
||||
"""
|
||||
|
||||
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
PREAMBLE_REGEX = re.compile(
|
||||
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def snake_case(text: str) -> str:
|
||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
|
||||
|
||||
|
||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
print(f" Generating {full_name}")
|
||||
visitor_methods: list[str] = []
|
||||
classes: list[str] = []
|
||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||
for cls in definitions:
|
||||
cls = cls.strip("\n")
|
||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||
print(f"Processing {name}")
|
||||
print(f" Processing {name}")
|
||||
visitor_methods.append(make_visitor_method(name, param))
|
||||
classes.append(make_class(name, cls, base))
|
||||
|
||||
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
|
||||
|
||||
def generate(definitions_path: Path, out_path: Path):
|
||||
print(f"Processing generating {out_path} from {definitions_path}")
|
||||
root_dir: Path = Path(__file__).parent.parent
|
||||
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||
src: str = definitions_path.read_text()
|
||||
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
|
||||
if m := IMPORTS_REGEX.search(src):
|
||||
imports = m.group("body").strip("\n")
|
||||
|
||||
preamble: str = ""
|
||||
if m := PREAMBLE_REGEX.search(src):
|
||||
preamble = m.group("body")
|
||||
|
||||
for section_m in SECTION_REGEX.finditer(src):
|
||||
full_name: str = section_m.group("name")
|
||||
base: str = section_m.group("base")
|
||||
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
|
||||
gen_path=Path(__file__).relative_to(root_dir),
|
||||
),
|
||||
imports=imports,
|
||||
preamble=preamble,
|
||||
sections="\n\n\n".join(sections),
|
||||
)
|
||||
out_path.write_text(result)
|
||||
|
||||
35
gen/midas.py
35
gen/midas.py
@@ -12,18 +12,23 @@ from midas.lexer.token import Token
|
||||
###<
|
||||
|
||||
|
||||
###> Preamble
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeParam:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class TypeStmt:
|
||||
name: Token
|
||||
params: list[Param]
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class PropertyStmt:
|
||||
name: Token
|
||||
@@ -31,6 +36,7 @@ class PropertyStmt:
|
||||
|
||||
|
||||
class ExtendStmt:
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
@@ -103,7 +109,7 @@ class NamedType:
|
||||
|
||||
class GenericType:
|
||||
type: Type
|
||||
params: list[Type]
|
||||
args: list[Type]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
@@ -115,4 +121,17 @@ class ComplexType:
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
|
||||
class FunctionType:
|
||||
pos_args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -128,12 +128,6 @@ class LogicalExpr:
|
||||
right: Expr
|
||||
|
||||
|
||||
class SetExpr:
|
||||
object: Expr
|
||||
name: str
|
||||
value: Expr
|
||||
|
||||
|
||||
class CastExpr:
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
@@ -145,4 +139,8 @@ class TernaryExpr:
|
||||
if_false: Expr
|
||||
|
||||
|
||||
class ListExpr:
|
||||
items: list[Expr]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -14,6 +14,13 @@ from midas.lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeParam:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
@@ -46,15 +53,9 @@ class Stmt(ABC):
|
||||
@dataclass(frozen=True)
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
params: list[Param]
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
@@ -70,6 +71,7 @@ class PropertyStmt(Stmt):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtendStmt(Stmt):
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
@@ -231,6 +233,9 @@ class Type(ABC):
|
||||
@abstractmethod
|
||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedType(Type):
|
||||
@@ -243,7 +248,7 @@ class NamedType(Type):
|
||||
@dataclass(frozen=True)
|
||||
class GenericType(Type):
|
||||
type: Type
|
||||
params: list[Type]
|
||||
args: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_generic_type(self)
|
||||
@@ -264,3 +269,20 @@ class ComplexType(Type):
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionType(Type):
|
||||
pos_args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_function_type(self)
|
||||
|
||||
@@ -100,12 +100,12 @@ class MidasAstPrinter(
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_stmt_param(param)
|
||||
self._print_type_param(param)
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
|
||||
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||
self._write_line("Param")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{param.name.lexeme}"')
|
||||
@@ -122,6 +122,13 @@ class MidasAstPrinter(
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
with self._child_level():
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_param(param)
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
@@ -234,11 +241,11 @@ class MidasAstPrinter(
|
||||
self._write_line("type")
|
||||
with self._child_level():
|
||||
type.type.accept(self)
|
||||
self._write_line("params", last=True)
|
||||
self._write_line("args", last=True)
|
||||
with self._child_level():
|
||||
for i, param in enumerate(type.params):
|
||||
for i, param in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.params) - 1:
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
param.accept(self)
|
||||
|
||||
@@ -263,6 +270,41 @@ class MidasAstPrinter(
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("pos_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.pos_args):
|
||||
self._idx = i
|
||||
if i == len(type.pos_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.kw_args):
|
||||
self._idx = i
|
||||
if i == len(type.kw_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||
self._write_line("Argument")
|
||||
with self._child_level():
|
||||
name: str = "None"
|
||||
if arg.name is not None:
|
||||
name = f'"{arg.name.lexeme}"'
|
||||
self._write_line(f"name: {name}")
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
arg.type.accept(self)
|
||||
self._write_line(f"required: {arg.required}", last=True)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
@@ -279,14 +321,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [
|
||||
self._print_type_template_param(param) for param in stmt.params
|
||||
]
|
||||
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
|
||||
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||
res: str = param.name.lexeme
|
||||
if param.bound is not None:
|
||||
res += "<:" + param.bound.accept(self)
|
||||
@@ -358,9 +398,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
if len(type.params) != 0:
|
||||
params: list[str] = [param.accept(self) for param in type.params]
|
||||
res += f"[{', '.join(params)}]"
|
||||
if len(type.args) != 0:
|
||||
args: list[str] = [param.accept(self) for param in type.args]
|
||||
res += f"[{', '.join(args)}]"
|
||||
return res
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||
@@ -378,6 +418,29 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
args: list[str] = pos_args
|
||||
|
||||
if len(pos_args) != 0:
|
||||
args.append("/")
|
||||
if len(kw_args) != 0:
|
||||
args.append("*")
|
||||
args += kw_args
|
||||
|
||||
return f"({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||
|
||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||
res: str = ""
|
||||
if arg.name is not None:
|
||||
res += arg.name.lexeme
|
||||
res += ": "
|
||||
res += arg.type.accept(self)
|
||||
if not arg.required:
|
||||
res += "?"
|
||||
return res
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
@@ -602,17 +665,6 @@ class PythonAstPrinter(
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||
self._write_line("SetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line(f"name: {expr.name}")
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.value.accept(self)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self._write_line("CastExpr")
|
||||
with self._child_level():
|
||||
@@ -637,3 +689,14 @@ class PythonAstPrinter(
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
self._write_line("ListExpr")
|
||||
with self._child_level():
|
||||
self._write_line("items", last=True)
|
||||
with self._child_level():
|
||||
for i, item in enumerate(expr.items):
|
||||
self._idx = i
|
||||
if i == len(expr.items) - 1:
|
||||
self._mark_last()
|
||||
item.accept(self)
|
||||
|
||||
@@ -14,6 +14,7 @@ from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
####################
|
||||
@@ -214,15 +215,15 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_set_expr(self, expr: SetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_list_expr(self, expr: ListExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinaryExpr(Expr):
|
||||
@@ -298,16 +299,6 @@ class LogicalExpr(Expr):
|
||||
return visitor.visit_logical_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SetExpr(Expr):
|
||||
object: Expr
|
||||
name: str
|
||||
value: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_set_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CastExpr(Expr):
|
||||
type: MidasType
|
||||
@@ -325,3 +316,11 @@ class TernaryExpr(Expr):
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_ternary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListExpr(Expr):
|
||||
items: list[Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_list_expr(self)
|
||||
|
||||
112
midas/checker/builtins.py
Normal file
112
midas/checker/builtins.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from midas.checker.types import (
|
||||
BaseType,
|
||||
ComplexType,
|
||||
Function,
|
||||
GenericType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.registry import TypesRegistry
|
||||
|
||||
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
|
||||
|
||||
def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type):
|
||||
reg.define_operation(
|
||||
left=t1,
|
||||
operator=operator,
|
||||
right=t2,
|
||||
result=t3,
|
||||
)
|
||||
|
||||
|
||||
def basic_op(reg: TypesRegistry, type: Type, op: str):
|
||||
reg.define_operation(
|
||||
left=type,
|
||||
operator=op,
|
||||
right=type,
|
||||
result=type,
|
||||
)
|
||||
|
||||
|
||||
def define_builtins(reg: TypesRegistry):
|
||||
"""Define builtin types and operations"""
|
||||
unit = reg.define_type("None", UnitType())
|
||||
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||
int = reg.define_type("int", BaseType(name="int"))
|
||||
float = reg.define_type("float", BaseType(name="float"))
|
||||
str = reg.define_type("str", BaseType(name="str"))
|
||||
|
||||
basic_op(reg, int, "__add__") # int + int = int
|
||||
basic_op(reg, int, "__sub__") # int - int = int
|
||||
basic_op(reg, int, "__mul__") # int * int = int
|
||||
basic_op(reg, int, "__pow__") # int ** int = int
|
||||
basic_op(reg, int, "__mod__") # int % int = int
|
||||
basic_op(reg, int, "__and__") # int & int = int
|
||||
basic_op(reg, int, "__or__") # int | int = int
|
||||
basic_op(reg, int, "__xor__") # int ^ int = int
|
||||
op(reg, int, "__lt__", int, bool) # int < int = bool
|
||||
op(reg, int, "__gt__", int, bool) # int > int = bool
|
||||
op(reg, int, "__le__", int, bool) # int <= int = bool
|
||||
op(reg, int, "__ge__", int, bool) # int >= int = bool
|
||||
op(reg, int, "__eq__", int, bool) # int == int = bool
|
||||
basic_op(reg, float, "__add__") # float + float = float
|
||||
basic_op(reg, float, "__sub__") # float - float = float
|
||||
basic_op(reg, float, "__mul__") # float * float = float
|
||||
basic_op(reg, float, "__truediv__") # float / float = float
|
||||
op(reg, float, "__lt__", float, bool) # float < float = bool
|
||||
op(reg, float, "__gt__", float, bool) # float > float = bool
|
||||
op(reg, float, "__le__", float, bool) # float <= float = bool
|
||||
op(reg, float, "__ge__", float, bool) # float >= float = bool
|
||||
op(reg, float, "__eq__", float, bool) # float == float = bool
|
||||
basic_op(reg, str, "__add__") # str + str = str
|
||||
op(reg, str, "__eq__", str, bool) # str == str = bool
|
||||
|
||||
op(reg, int, "__lt__", float, bool) # int < float = bool
|
||||
op(reg, int, "__gt__", float, bool) # int > float = bool
|
||||
op(reg, int, "__le__", float, bool) # int <= float = bool
|
||||
op(reg, int, "__ge__", float, bool) # int >= float = bool
|
||||
op(reg, int, "__eq__", float, bool) # int == float = bool
|
||||
|
||||
op(reg, float, "__lt__", int, bool) # float < int = bool
|
||||
op(reg, float, "__gt__", int, bool) # float > int = bool
|
||||
op(reg, float, "__le__", int, bool) # float <= int = bool
|
||||
op(reg, float, "__ge__", int, bool) # float >= int = bool
|
||||
op(reg, float, "__eq__", int, bool) # float == int = bool
|
||||
|
||||
list = reg.define_type(
|
||||
"list",
|
||||
GenericType(
|
||||
name="list",
|
||||
params=[TypeVar(name="T", bound=None)],
|
||||
body=ComplexType(
|
||||
properties={
|
||||
"append": Function(
|
||||
name="append",
|
||||
pos_args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="object",
|
||||
type=TypeVar(name="T", bound=None),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
args=[],
|
||||
kw_args=[],
|
||||
returns=UnitType(),
|
||||
)
|
||||
}
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -1,540 +1,35 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.resolver.midas import MidasResolver
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.checker.midas import MidasTyper
|
||||
from midas.checker.python import PythonTyper
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import Reporter
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
class TypeChecker:
|
||||
def __init__(self):
|
||||
self.types: TypesRegistry = TypesRegistry()
|
||||
self.reporter: Reporter = Reporter()
|
||||
|
||||
self.midas_typer = MidasTyper(self.types, self.reporter)
|
||||
self.python_typer = PythonTyper(self.types, self.reporter)
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
def import_midas(self, path: Path):
|
||||
source: str = path.read_text()
|
||||
return self.import_midas_source(source, path=str(path))
|
||||
|
||||
def import_midas_source(self, source: str, path: Optional[str] = None):
|
||||
self.midas_typer.process(source, path)
|
||||
|
||||
class Checker(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
p.MidasType.Visitor[Type],
|
||||
):
|
||||
"""A type checker which can use custom type definitions"""
|
||||
def type_check(self, path: Path):
|
||||
source: str = path.read_text()
|
||||
return self.type_check_source(source, path=str(path))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
locals: dict[p.Expr, int],
|
||||
source_path: Path,
|
||||
types_paths: list[Path],
|
||||
):
|
||||
self.logger: logging.Logger = logging.getLogger("Checker")
|
||||
self.source_path: Path = source_path
|
||||
self.types_paths: list[Path] = types_paths
|
||||
self.ctx: MidasResolver = MidasResolver()
|
||||
self.global_env: Environment = Environment()
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = locals
|
||||
self.diagnostics: list[Diagnostic] = []
|
||||
def type_check_source(self, source: str, path: Optional[str] = None):
|
||||
self.python_typer.process(source, path)
|
||||
|
||||
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
||||
self.diagnostics.append(
|
||||
Diagnostic(
|
||||
file_path=self.source_path,
|
||||
location=location,
|
||||
type=type,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def error(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.ERROR,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def warning(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.WARNING,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def info(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.INFO,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def type_of(self, expr: p.Expr) -> Type:
|
||||
"""Evaluate the type of an expression
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the expression to evaluate
|
||||
|
||||
Returns:
|
||||
Type: the type of the given expression
|
||||
"""
|
||||
return expr.accept(self)
|
||||
|
||||
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||
"""Evaluate a sequence of statements
|
||||
|
||||
Args:
|
||||
block (list[p.Stmt]): the statements to evaluate
|
||||
env (Environment): the environment in which to evaluate
|
||||
|
||||
Returns:
|
||||
bool: whether a return statement is present in the block
|
||||
"""
|
||||
previous_env: Environment = self.env
|
||||
self.env = env
|
||||
returned: bool = False
|
||||
for i, stmt in enumerate(block):
|
||||
try:
|
||||
stmt.accept(self)
|
||||
except ReturnException:
|
||||
returned = True
|
||||
if i < len(block) - 1:
|
||||
self.warning(block[i + 1].location, "Unreachable statement")
|
||||
break
|
||||
self.env = previous_env
|
||||
return returned
|
||||
|
||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||
"""Type check a sequence of statements and returns diagnostics
|
||||
|
||||
Args:
|
||||
statements (list[p.Stmt]): the statements to evaluate and check
|
||||
|
||||
Returns:
|
||||
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
||||
"""
|
||||
self.diagnostics = []
|
||||
|
||||
for path in self.types_paths:
|
||||
self.import_midas(path)
|
||||
self.logger.debug(f"Midas types: {self.ctx._types}")
|
||||
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
||||
|
||||
for stmt in statements:
|
||||
stmt.accept(self)
|
||||
|
||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||
return self.diagnostics
|
||||
|
||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||
"""Look up a variable in the environment it was declared
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the type of the variable, or None if it was not found
|
||||
"""
|
||||
distance: Optional[int] = self.locals.get(expr)
|
||||
if distance is not None:
|
||||
return self.env.get_at(distance, name)
|
||||
return self.global_env.get(name)
|
||||
|
||||
def import_midas(self, path: Path) -> None:
|
||||
"""Import Midas definitions from a path
|
||||
|
||||
Args:
|
||||
path (Path): the import path
|
||||
"""
|
||||
self.logger.debug(f"Importing type definitions from {path}")
|
||||
lexer: MidasLexer = MidasLexer(path.read_text())
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
self.ctx.resolve(stmts)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.type_of(stmt.expr)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is not None:
|
||||
return arg.type.accept(self)
|
||||
if arg.default is not None:
|
||||
return arg.default.accept(self)
|
||||
return UnknownType()
|
||||
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = stmt.returns.accept(self)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
|
||||
returned: bool = self.process_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
if not returned:
|
||||
env.return_types.append(UnitType())
|
||||
return_types: set[Type] = set(env.return_types)
|
||||
if len(return_types) == 1:
|
||||
inferred_return = list(return_types)[0]
|
||||
elif len(return_types) > 1:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Mixed return types: {env.return_types}",
|
||||
)
|
||||
|
||||
returns: Type = UnknownType()
|
||||
if returns_hint is not None:
|
||||
assert stmt.returns is not None
|
||||
returns = returns_hint
|
||||
if returns != inferred_return:
|
||||
self.error(
|
||||
stmt.returns.location,
|
||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||
)
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
# TODO check not yet defined locally
|
||||
type: Type = stmt.type.accept(self)
|
||||
self.env.define(stmt.name, type)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
value: Type = self.type_of(stmt.value)
|
||||
for target in stmt.targets:
|
||||
if not isinstance(target, p.VariableExpr):
|
||||
self.logger.warning(f"Unsupported assignment to {target}")
|
||||
self.warning(target.location, f"Unsupported assignment to {target}")
|
||||
continue
|
||||
name: str = target.name
|
||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||
|
||||
if var_type is None:
|
||||
self.env.define(name, value)
|
||||
else:
|
||||
# TODO: implement real comparison method
|
||||
if var_type != value:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Cannot assign {value} to {name} of type {var_type}",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||
self.env.return_types.append(type)
|
||||
raise ReturnException()
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
test_type: Type = stmt.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
env: Environment = Environment(self.env)
|
||||
body_returned: bool = self.process_block(stmt.body, env)
|
||||
else_returned: bool = self.process_block(stmt.orelse, env)
|
||||
self.env.return_types.extend(env.return_types)
|
||||
if body_returned and else_returned:
|
||||
raise ReturnException()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
if not isinstance(callee, Function):
|
||||
self.error(expr.callee.location, "Callee is not a function")
|
||||
return UnknownType()
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if arg.type != arg.argument.type:
|
||||
self.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
return function.returns
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.ctx.get_type("bool")
|
||||
case int():
|
||||
return self.ctx.get_type("int")
|
||||
case float():
|
||||
return self.ctx.get_type("float")
|
||||
case str():
|
||||
return self.ctx.get_type("str")
|
||||
case _:
|
||||
self.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||
left: Type = expr.left.accept(self)
|
||||
right: Type = expr.right.accept(self)
|
||||
# TODO: union type
|
||||
if left != right:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Operands must be of the same type, left={left} != right={right}",
|
||||
)
|
||||
return left
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||
return expr.type.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||
test_type: Type = expr.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
true_type: Type = expr.if_true.accept(self)
|
||||
false_type: Type = expr.if_false.accept(self)
|
||||
if true_type != false_type:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
|
||||
)
|
||||
return UnknownType()
|
||||
return true_type
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||
return self.ctx.get_type(node.base)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||
|
||||
def map_call_arguments(
|
||||
self, function: Function, call: p.CallExpr
|
||||
) -> list[MappedArgument]:
|
||||
"""Map call arguments to function parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
call (p.CallExpr): the call expression
|
||||
|
||||
Returns:
|
||||
list[MappedArgument]: the list of mapped arguments
|
||||
"""
|
||||
positional: list[tuple[p.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in call.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
||||
}
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
self.error(arg[0].location, "Too many positional arguments")
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if name in set_args:
|
||||
self.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
|
||||
return mapped
|
||||
@property
|
||||
def diagnostics(self) -> list[Diagnostic]:
|
||||
return self.reporter.diagnostics
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
@@ -14,12 +13,13 @@ class DiagnosticType(StrEnum):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Diagnostic:
|
||||
file_path: Path
|
||||
file_path: Optional[str]
|
||||
location: Location
|
||||
type: DiagnosticType
|
||||
message: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
@property
|
||||
def location_str(self) -> str:
|
||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||
end_loc: Optional[str] = ""
|
||||
if (
|
||||
@@ -27,7 +27,16 @@ class Diagnostic:
|
||||
and self.location.end_col_offset is not None
|
||||
):
|
||||
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
||||
loc: str = (
|
||||
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
|
||||
)
|
||||
return f"{self.type} in {self.file_path} {loc}: {self.message}"
|
||||
|
||||
loc: str = ""
|
||||
if self.file_path is not None:
|
||||
loc += f" in {self.file_path}"
|
||||
if end_loc is None:
|
||||
loc += f" at {start_loc}"
|
||||
else:
|
||||
loc += f" from {start_loc} to {end_loc}"
|
||||
|
||||
return f"{self.type}{loc}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.location_str}: {self.message}"
|
||||
|
||||
171
midas/checker/midas.py
Normal file
171
midas/checker/midas.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.builtins import define_builtins
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
ComplexType,
|
||||
Function,
|
||||
GenericType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
|
||||
self.types: TypesRegistry = types
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
define_builtins(self.types)
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
lexer: MidasLexer = MidasLexer(source)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
self.resolve(stmts)
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
if name in self._local_variables:
|
||||
return self._local_variables[name]
|
||||
return self.types.get_type(name)
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Args:
|
||||
stmts (list[m.Stmt]): the statements
|
||||
"""
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
params: list[TypeVar] = self._resolve_type_params(stmt.params)
|
||||
|
||||
name: str = stmt.name.lexeme
|
||||
type: Type = stmt.type.accept(self)
|
||||
if len(params) != 0:
|
||||
type = GenericType(name=name, params=params, body=type)
|
||||
else:
|
||||
type = AliasType(name=name, type=type)
|
||||
self.types.define_type(name, type)
|
||||
self._local_variables.clear()
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._resolve_type_params(stmt.params)
|
||||
base: Type = stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
right: Type = op.operand.accept(self)
|
||||
result: Type = op.result.accept(self)
|
||||
self.types.define_operation(
|
||||
left=base,
|
||||
operator=op.name.lexeme,
|
||||
right=right,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
return self.get_type(type.name.lexeme)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
args: list[Type] = [arg.accept(self) for arg in type.args]
|
||||
return self.types.apply_generic(type_, args)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
||||
return ComplexType(
|
||||
properties={
|
||||
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
|
||||
}
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
return Function(
|
||||
name="<anonymous>",
|
||||
pos_args=[
|
||||
Function.Argument(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
)
|
||||
for i, arg in enumerate(type.pos_args)
|
||||
],
|
||||
args=[],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
)
|
||||
for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
|
||||
],
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
vars: list[TypeVar] = []
|
||||
for param in params:
|
||||
name: str = param.name.lexeme
|
||||
bound: Optional[Type] = None
|
||||
if param.bound is not None:
|
||||
bound = param.bound.accept(self)
|
||||
var = TypeVar(name=name, bound=bound)
|
||||
self._local_variables[name] = var
|
||||
vars.append(var)
|
||||
return vars
|
||||
646
midas/checker/python.py
Normal file
646
midas/checker/python.py
Normal file
@@ -0,0 +1,646 @@
|
||||
import ast
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.resolver import Resolver
|
||||
from midas.checker.types import (
|
||||
ComplexType,
|
||||
Function,
|
||||
Operation,
|
||||
Type,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.parser.python import PythonParser
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
class PythonTyper(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
p.MidasType.Visitor[Type],
|
||||
):
|
||||
"""A type checker which can use custom type definitions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
types: TypesRegistry,
|
||||
reporter: Reporter,
|
||||
):
|
||||
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
self.types: TypesRegistry = types
|
||||
self.global_env: Environment = Environment()
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
|
||||
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
|
||||
self.env = self.global_env
|
||||
self.locals = resolver.locals
|
||||
self.judgements = []
|
||||
|
||||
self.check(stmts)
|
||||
|
||||
def type_of(self, expr: p.Expr) -> Type:
|
||||
"""Evaluate the type of an expression
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the expression to evaluate
|
||||
|
||||
Returns:
|
||||
Type: the type of the given expression
|
||||
"""
|
||||
type: Type = expr.accept(self)
|
||||
self.judgements.append((expr, type))
|
||||
return type
|
||||
|
||||
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||
"""Evaluate a sequence of statements
|
||||
|
||||
Args:
|
||||
block (list[p.Stmt]): the statements to evaluate
|
||||
env (Environment): the environment in which to evaluate
|
||||
|
||||
Returns:
|
||||
bool: whether a return statement is present in the block
|
||||
"""
|
||||
previous_env: Environment = self.env
|
||||
self.env = env
|
||||
returned: bool = False
|
||||
for i, stmt in enumerate(block):
|
||||
try:
|
||||
stmt.accept(self)
|
||||
except ReturnException:
|
||||
returned = True
|
||||
if i < len(block) - 1:
|
||||
self.reporter.warning(
|
||||
block[i + 1].location, "Unreachable statement"
|
||||
)
|
||||
break
|
||||
self.env = previous_env
|
||||
return returned
|
||||
|
||||
def check(self, statements: list[p.Stmt]) -> None:
|
||||
"""Type check a sequence of statements and returns diagnostics
|
||||
|
||||
Args:
|
||||
statements (list[p.Stmt]): the statements to evaluate and check
|
||||
"""
|
||||
for stmt in statements:
|
||||
stmt.accept(self)
|
||||
|
||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||
|
||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||
"""Look up a variable in the environment it was declared
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the type of the variable, or None if it was not found
|
||||
"""
|
||||
distance: Optional[int] = self.locals.get(expr)
|
||||
if distance is not None:
|
||||
return self.env.get_at(distance, name)
|
||||
return self.global_env.get(name)
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
return self.types.is_subtype(type1, type2)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.type_of(stmt.expr)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is not None:
|
||||
return arg.type.accept(self)
|
||||
if arg.default is not None:
|
||||
return arg.default.accept(self)
|
||||
return UnknownType()
|
||||
|
||||
pos: int = 0
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
pos=pos,
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
pos=pos,
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
pos=pos, # not relevant
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = stmt.returns.accept(self)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
|
||||
returned: bool = self.process_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
if not returned:
|
||||
env.return_types.append(UnitType())
|
||||
return_types: list[Type] = self.types.reduce_types(env.return_types)
|
||||
if len(return_types) == 1:
|
||||
inferred_return = return_types[0]
|
||||
elif len(return_types) > 1:
|
||||
self.reporter.error(
|
||||
stmt.location,
|
||||
f"Mixed return types: {return_types}",
|
||||
)
|
||||
|
||||
returns: Type = UnknownType()
|
||||
if returns_hint is not None:
|
||||
assert stmt.returns is not None
|
||||
returns = returns_hint
|
||||
if returns != inferred_return:
|
||||
self.reporter.error(
|
||||
stmt.returns.location,
|
||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||
)
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
# TODO check not yet defined locally
|
||||
type: Type = stmt.type.accept(self)
|
||||
self.env.define(stmt.name, type)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
value_type: Type = self.type_of(stmt.value)
|
||||
for target in stmt.targets:
|
||||
self._assign(stmt.location, target, value_type)
|
||||
|
||||
def _assign(self, location: Location, target: p.Expr, value_type: Type):
|
||||
match target:
|
||||
case p.VariableExpr():
|
||||
self._assign_var(location, target, value_type)
|
||||
|
||||
case p.GetExpr():
|
||||
self._assign_attr(location, target, value_type)
|
||||
|
||||
case _:
|
||||
if not isinstance(target, p.VariableExpr):
|
||||
self.logger.warning(f"Unsupported assignment to {target}")
|
||||
self.reporter.warning(
|
||||
target.location, f"Unsupported assignment to {target}"
|
||||
)
|
||||
|
||||
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
|
||||
name: str = target.name
|
||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||
|
||||
if var_type is None:
|
||||
self.env.define(name, value_type)
|
||||
else:
|
||||
# S <: T
|
||||
# Γ, x: T v: S
|
||||
# x = v
|
||||
if not self.is_subtype(value_type, var_type):
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Cannot assign {value_type} to variable '{name}' of type {var_type}",
|
||||
)
|
||||
|
||||
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
|
||||
object: Type = self.type_of(target.object)
|
||||
base_object: Type = unfold_type(object)
|
||||
match base_object:
|
||||
case ComplexType(properties=properties):
|
||||
if target.name not in properties:
|
||||
self.reporter.error(
|
||||
target.location, f"Unknown property '{object}.{target.name}'"
|
||||
)
|
||||
return
|
||||
|
||||
prop_type: Type = properties[target.name]
|
||||
if not self.is_subtype(value_type, prop_type):
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Cannot assign {value_type} to property '{object}.{target.name}' of type {prop_type}",
|
||||
)
|
||||
return
|
||||
|
||||
case UnknownType():
|
||||
pass
|
||||
|
||||
case _:
|
||||
self.reporter.error(
|
||||
target.location,
|
||||
f"Cannot assign {value_type} to unknown property '{object}.{target.name}'",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||
self.env.return_types.append(type)
|
||||
raise ReturnException()
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
test_type: Type = stmt.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.types.get_type("bool"):
|
||||
self.reporter.error(
|
||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
env: Environment = Environment(self.env)
|
||||
body_returned: bool = self.process_block(stmt.body, env)
|
||||
else_returned: bool = self.process_block(stmt.orelse, env)
|
||||
self.env.return_types.extend(env.return_types)
|
||||
if body_returned and else_returned:
|
||||
raise ReturnException()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator}"
|
||||
)
|
||||
return UnknownType()
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
operations: list[Operation] = self.types.get_operations_by_name(method)
|
||||
valid_operations: list[Operation] = []
|
||||
for op in operations:
|
||||
sig: Operation.CallSignature = op.signature
|
||||
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
|
||||
valid_operations.append(op)
|
||||
|
||||
if len(valid_operations) == 0:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
elif len(valid_operations) == 1:
|
||||
self.logger.debug(f"Unique operation {method} between {left} and {right}")
|
||||
return valid_operations[0].result
|
||||
|
||||
for i, op1 in enumerate(valid_operations):
|
||||
sig1: Operation.CallSignature = op1.signature
|
||||
best_match: bool = True
|
||||
for j, op2 in enumerate(valid_operations):
|
||||
if i == j:
|
||||
continue
|
||||
sig2: Operation.CallSignature = op2.signature
|
||||
|
||||
# If op1 is not a full overload of op2 (i.e. operands of op1 are subtypes of op2's)
|
||||
# ambiguity -> not best match
|
||||
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
|
||||
sig1.right, sig2.right
|
||||
):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{op1} is a full overload of {op2}")
|
||||
if best_match:
|
||||
return op1.result
|
||||
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(map(str, valid_operations))}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator}"
|
||||
)
|
||||
return UnknownType()
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type] = self.types.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
if not isinstance(callee, Function):
|
||||
self.reporter.error(expr.callee.location, "Callee is not a function")
|
||||
return UnknownType()
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if not self.is_subtype(arg.type, arg.argument.type):
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
return function.returns
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
base_object: Type = unfold_type(object)
|
||||
match base_object:
|
||||
case ComplexType(properties=properties):
|
||||
if expr.name not in properties:
|
||||
self.reporter.error(
|
||||
expr.location, f"Unknown property '{expr.name} on {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
return properties[expr.name]
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case _:
|
||||
self.reporter.error(
|
||||
expr.location, f"Cannot get property '{expr.name}' on {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.types.get_type("bool")
|
||||
case int():
|
||||
return self.types.get_type("int")
|
||||
case float():
|
||||
return self.types.get_type("float")
|
||||
case str():
|
||||
return self.types.get_type("str")
|
||||
case _:
|
||||
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||
left: Type = expr.left.accept(self)
|
||||
right: Type = expr.right.accept(self)
|
||||
|
||||
if self.is_subtype(left, right):
|
||||
return right
|
||||
if self.is_subtype(right, left):
|
||||
return left
|
||||
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Incompatible operand types, {left=} and {right=}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||
return expr.type.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||
test_type: Type = expr.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.types.get_type("bool"):
|
||||
self.reporter.error(
|
||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
true_type: Type = expr.if_true.accept(self)
|
||||
false_type: Type = expr.if_false.accept(self)
|
||||
if self.is_subtype(true_type, false_type):
|
||||
return false_type
|
||||
if self.is_subtype(false_type, true_type):
|
||||
return true_type
|
||||
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> Type:
|
||||
list_type: Type = self.types.get_type("list")
|
||||
item_types: list[Type] = [self.type_of(item) for item in expr.items]
|
||||
item_types = self.types.reduce_types(item_types)
|
||||
|
||||
if len(item_types) == 0:
|
||||
return list_type
|
||||
|
||||
if len(item_types) == 1:
|
||||
item_type: Type = item_types[0]
|
||||
return self.types.apply_generic(list_type, [item_type])
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Heterogeneous list items: {item_types}",
|
||||
)
|
||||
return self.types.apply_generic(list_type, [UnknownType()])
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||
base: Type = self.types.get_type(node.base)
|
||||
if node.param is not None:
|
||||
param: Type = node.param.accept(self)
|
||||
return self.types.apply_generic(base, [param])
|
||||
return base
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||
|
||||
def map_call_arguments(
|
||||
self, function: Function, call: p.CallExpr
|
||||
) -> list[MappedArgument]:
|
||||
"""Map call arguments to function parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
call (p.CallExpr): the call expression
|
||||
|
||||
Returns:
|
||||
list[MappedArgument]: the list of mapped arguments
|
||||
"""
|
||||
positional: list[tuple[p.Expr, Type]] = [
|
||||
(arg, self.type_of(arg)) for arg in call.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
||||
}
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
self.reporter.error(arg[0].location, "Too many positional arguments")
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
self.reporter.error(
|
||||
call.location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
self.reporter.error(
|
||||
call.location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
|
||||
return mapped
|
||||
313
midas/checker/registry.py
Normal file
313
midas/checker/registry.py
Normal file
@@ -0,0 +1,313 @@
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
Function,
|
||||
GenericType,
|
||||
Operation,
|
||||
Type,
|
||||
substitute_typevars,
|
||||
)
|
||||
|
||||
|
||||
class TypesRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._types: dict[str, Type] = {}
|
||||
self._operations: dict[Operation.CallSignature, Type] = {}
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
if name in self._types:
|
||||
return self._types[name]
|
||||
raise NameError(f"Undefined type {name}")
|
||||
|
||||
def get_operation_result(
|
||||
self, left: Type, operator: str, right: Type
|
||||
) -> Optional[Type]:
|
||||
"""Get the resulting type of an operation
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the result type, or None if no matching operation was found
|
||||
"""
|
||||
signature: Operation.CallSignature = Operation.CallSignature(
|
||||
left=left,
|
||||
method=operator,
|
||||
right=right,
|
||||
)
|
||||
result: Optional[Type] = self._operations.get(signature)
|
||||
return result
|
||||
|
||||
def get_operations_by_name(self, name: str) -> list[Operation]:
|
||||
operations: list[Operation] = []
|
||||
for signature, result in self._operations.items():
|
||||
if signature.method == name:
|
||||
operations.append(
|
||||
Operation(
|
||||
signature=signature,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
return operations
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
|
||||
"""Define a type in the registry
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
type (Type): the type to define
|
||||
|
||||
Raises:
|
||||
ValueError: if a type is already defined with that name
|
||||
|
||||
Returns:
|
||||
Type: the defined type
|
||||
"""
|
||||
if name in self._types:
|
||||
raise ValueError(f"Type {name} already defined")
|
||||
self._types[name] = type
|
||||
return type
|
||||
|
||||
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
||||
"""Define an operation in the registry
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
result (Type): the result type
|
||||
|
||||
Raises:
|
||||
ValueError: if an operation is already defined with these operands and name
|
||||
"""
|
||||
signature: Operation.CallSignature = Operation.CallSignature(
|
||||
left=left,
|
||||
method=operator,
|
||||
right=right,
|
||||
)
|
||||
if signature in self._operations:
|
||||
raise ValueError(
|
||||
f"Operation {operator} already defined between {left} and {right}"
|
||||
)
|
||||
self._operations[signature] = result
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
||||
|
||||
Args:
|
||||
type1 (Type): the potential subtype
|
||||
type2 (Type): the potential supertype
|
||||
|
||||
Returns:
|
||||
bool: whether `type1` is a subtype of `type2`
|
||||
"""
|
||||
|
||||
if type1 == type2:
|
||||
return True
|
||||
|
||||
match (type1, type2):
|
||||
case (AliasType(type=base1), _):
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
if k not in props1:
|
||||
return False
|
||||
if not self.is_subtype(props1[k], t):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (Function(), Function()):
|
||||
return self.is_func_subtype(type1, type2)
|
||||
|
||||
return False
|
||||
|
||||
# TODO: verify the logic in here
|
||||
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||
"""Check whether a function is a subtype of another
|
||||
|
||||
Args:
|
||||
func1 (Function): the potential function subtype
|
||||
func2 (Function): the potential function supertype
|
||||
|
||||
Returns:
|
||||
bool: whether `func1` is a subtype of `func2`
|
||||
"""
|
||||
if not self.is_subtype(func1.returns, func2.returns):
|
||||
return False
|
||||
|
||||
pos1: list[Function.Argument] = func1.pos_args
|
||||
mixed1: list[Function.Argument] = func1.args
|
||||
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
||||
pos2: list[Function.Argument] = func2.pos_args
|
||||
mixed2: list[Function.Argument] = func2.args
|
||||
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
||||
|
||||
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
||||
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
||||
|
||||
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
||||
if not self.is_subtype(sub.type, sup.type):
|
||||
return False
|
||||
if not sup.required and sub.required:
|
||||
return False
|
||||
return True
|
||||
|
||||
for arg1 in pos1:
|
||||
arg2: Function.Argument
|
||||
if arg1.pos < len(pos2):
|
||||
arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
arg2 = mixed_by_pos[arg1.pos]
|
||||
elif not arg1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
return False
|
||||
|
||||
for name, arg1 in kw1.items():
|
||||
arg2: Function.Argument
|
||||
if name in kw2:
|
||||
arg2 = kw2[name]
|
||||
elif name in mixed_by_name:
|
||||
arg2 = mixed_by_name[name]
|
||||
elif not arg1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
return False
|
||||
|
||||
for arg1 in mixed1:
|
||||
pos_arg2: Optional[Function.Argument] = None
|
||||
kw_arg2: Optional[Function.Argument] = None
|
||||
if arg1.name in kw2:
|
||||
kw_arg2 = kw2[arg1.name]
|
||||
elif arg1.name in mixed_by_name:
|
||||
kw_arg2 = mixed_by_name[arg1.name]
|
||||
if arg1.pos < len(pos2):
|
||||
pos_arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
pos_arg2 = mixed_by_pos[arg1.pos]
|
||||
|
||||
# No match in func2 and arg is required
|
||||
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
||||
return False
|
||||
|
||||
# Matching keyword argument
|
||||
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
||||
return False
|
||||
|
||||
# Matching positional argument
|
||||
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
||||
return False
|
||||
|
||||
mixed_positions: set[int] = {a.pos for a in mixed1}
|
||||
mixed_names: set[str] = {a.name for a in mixed1}
|
||||
for arg2 in pos2:
|
||||
if not arg2.required:
|
||||
continue
|
||||
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
||||
return False
|
||||
|
||||
for name, arg2 in kw2.items():
|
||||
if not arg2.required:
|
||||
continue
|
||||
if name not in kw1 and name not in mixed_names:
|
||||
return False
|
||||
|
||||
for arg2 in mixed2:
|
||||
if arg2.required:
|
||||
continue
|
||||
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
||||
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
||||
if not pos_match or not kw_match:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
||||
match type:
|
||||
case AliasType(name=name, type=base):
|
||||
return AliasType(name=name, type=self.apply_generic(base, args))
|
||||
|
||||
case GenericType(name=name, params=type_vars, body=body):
|
||||
n_args: int = len(args)
|
||||
n_type_vars: int = len(type_vars)
|
||||
if n_args < n_type_vars:
|
||||
raise ValueError(
|
||||
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
|
||||
)
|
||||
if n_args > n_type_vars:
|
||||
raise ValueError(
|
||||
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
|
||||
)
|
||||
substitutions: dict[str, Type] = {}
|
||||
for arg, type_var in zip(args, type_vars):
|
||||
if type_var.bound is not None and not self.is_subtype(
|
||||
arg, type_var.bound
|
||||
):
|
||||
raise ValueError(
|
||||
f"Type argument {arg} is not a subtype of {type_var.bound}"
|
||||
)
|
||||
substitutions[type_var.name] = arg
|
||||
return AppliedType(
|
||||
name=name,
|
||||
args=args,
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"{type} is not a generic type")
|
||||
|
||||
def reduce_types(self, types: list[Type]) -> list[Type]:
|
||||
"""Reduce a list of types to remove subtypes and only keep the highest types
|
||||
|
||||
Args:
|
||||
types (list[Type]): the types to reduce
|
||||
|
||||
Returns:
|
||||
list[Type]: the reduced list of types
|
||||
"""
|
||||
|
||||
reduced: bool = True
|
||||
keep: list[int] = list(range(len(types)))
|
||||
while reduced:
|
||||
reduced = False
|
||||
for i, i1 in enumerate(keep):
|
||||
type1: Type = types[i1]
|
||||
for i2 in keep[i + 1 :]:
|
||||
type2 = types[i2]
|
||||
if self.is_subtype(type1, type2):
|
||||
keep.remove(i1)
|
||||
elif self.is_subtype(type2, type1):
|
||||
keep.remove(i2)
|
||||
else:
|
||||
continue
|
||||
reduced = True
|
||||
break
|
||||
return [types[i] for i in keep]
|
||||
63
midas/checker/reporter.py
Normal file
63
midas/checker/reporter.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
|
||||
|
||||
class Reporter:
|
||||
def __init__(self):
|
||||
self.diagnostics: list[Diagnostic] = []
|
||||
|
||||
def report(
|
||||
self,
|
||||
path: Optional[str],
|
||||
type: DiagnosticType,
|
||||
location: Location,
|
||||
message: str,
|
||||
):
|
||||
self.diagnostics.append(
|
||||
Diagnostic(
|
||||
file_path=path,
|
||||
location=location,
|
||||
type=type,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def for_file(self, path: Optional[str]) -> FileReporter:
|
||||
return FileReporter(self, path)
|
||||
|
||||
|
||||
class FileReporter:
|
||||
def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
|
||||
self.base_reporter: Reporter = base_reporter
|
||||
self.path: Optional[str] = path
|
||||
|
||||
def for_file(self, path: Optional[str]) -> FileReporter:
|
||||
return FileReporter(self.base_reporter, path)
|
||||
|
||||
def report(self, type: DiagnosticType, location: Location, message: str):
|
||||
self.base_reporter.report(self.path, type, location, message)
|
||||
|
||||
def error(self, location: Location, message: str):
|
||||
self.report(
|
||||
type=DiagnosticType.ERROR,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def warning(self, location: Location, message: str):
|
||||
self.report(
|
||||
type=DiagnosticType.WARNING,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def info(self, location: Location, message: str):
|
||||
self.report(
|
||||
type=DiagnosticType.INFO,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
@@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
|
||||
def __init__(self):
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
self.scopes: list[dict[str, bool]] = []
|
||||
self.scopes: list[dict[str, bool]] = [{}]
|
||||
|
||||
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
||||
"""Resolve the given statements or expressions"""
|
||||
@@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
self.locals[expr] = i
|
||||
return
|
||||
|
||||
def is_defined(self, name: str) -> bool:
|
||||
for scope in self.scopes:
|
||||
if name in scope:
|
||||
return True
|
||||
return False
|
||||
|
||||
def resolve_function(self, function: p.Function) -> None:
|
||||
"""Resolve a function definition
|
||||
|
||||
@@ -112,8 +118,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
for target in stmt.targets:
|
||||
match target:
|
||||
case p.VariableExpr(name=name):
|
||||
self.resolve_local(target, name)
|
||||
# TODO: declare if not found
|
||||
if not self.is_defined(name):
|
||||
self.declare(name)
|
||||
self.define(name)
|
||||
target.accept(self)
|
||||
|
||||
case p.GetExpr():
|
||||
target.accept(self)
|
||||
case _:
|
||||
raise Exception(f"Unsupported assignment to {target}")
|
||||
|
||||
@@ -174,10 +185,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||
self.resolve(expr.value)
|
||||
self.resolve(expr.object)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self.resolve(expr.expr)
|
||||
|
||||
@@ -185,3 +192,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
self.resolve(expr.test)
|
||||
self.resolve(expr.if_true)
|
||||
self.resolve(expr.if_false)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
for item in expr.items:
|
||||
self.resolve(item)
|
||||
@@ -1,27 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class BaseType:
|
||||
name: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AliasType:
|
||||
name: str
|
||||
type: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnknownType:
|
||||
pass
|
||||
def __str__(self) -> str:
|
||||
return "<Unknown>"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnitType:
|
||||
pass
|
||||
def __str__(self) -> str:
|
||||
return "None"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -32,16 +41,159 @@ class Function:
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
args: list[str] = []
|
||||
if len(self.pos_args) != 0:
|
||||
args += list(map(str, self.pos_args))
|
||||
if len(self.args) + len(self.kw_args) != 0:
|
||||
args.append("/")
|
||||
|
||||
if len(self.args) != 0:
|
||||
args += list(map(str, self.args))
|
||||
|
||||
if len(self.kw_args) != 0:
|
||||
if len(args) != 0:
|
||||
args.append("*")
|
||||
args += list(map(str, self.kw_args))
|
||||
|
||||
return f"{self.name}({', '.join(args)}) -> {self.returns}"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
pos: int
|
||||
name: str
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
def __str__(self) -> str:
|
||||
opt: str = "" if self.required else "?"
|
||||
return f"{self.name}: {self.type}{opt}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ComplexType:
|
||||
properties: dict[str, Type]
|
||||
|
||||
def __str__(self) -> str:
|
||||
props: list[str] = [f"{name}: {type}" for name, type in self.properties.items()]
|
||||
return f"{{{', '.join(props)}}}"
|
||||
|
||||
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Operation:
|
||||
signature: CallSignature
|
||||
result: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.signature} -> {self.result}"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class CallSignature:
|
||||
left: Type
|
||||
method: str
|
||||
right: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.method}({self.left}, {self.right})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeVar:
|
||||
name: str
|
||||
bound: Optional[Type]
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.bound is not None:
|
||||
return f"{self.name} <: {self.bound}"
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class GenericType:
|
||||
name: str
|
||||
params: list[TypeVar]
|
||||
body: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}[{', '.join(map(str, self.params))}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AppliedType:
|
||||
name: str
|
||||
args: list[Type]
|
||||
body: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
pos=arg.pos,
|
||||
name=arg.name,
|
||||
type=substitute_typevars(arg.type, substitutions),
|
||||
required=arg.required,
|
||||
)
|
||||
|
||||
match type:
|
||||
case BaseType(name=name) if name in substitutions:
|
||||
return substitutions[name]
|
||||
|
||||
case AliasType(name=name, type=type2):
|
||||
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
|
||||
|
||||
case Function(
|
||||
name=name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
):
|
||||
return Function(
|
||||
name=name,
|
||||
pos_args=list(map(sub_argument, pos_args)),
|
||||
args=list(map(sub_argument, args)),
|
||||
kw_args=list(map(sub_argument, kw_args)),
|
||||
returns=substitute_typevars(returns, substitutions),
|
||||
)
|
||||
|
||||
case ComplexType(properties=properties):
|
||||
properties2: dict[str, Type] = {
|
||||
name: substitute_typevars(prop, substitutions)
|
||||
for name, prop in properties.items()
|
||||
}
|
||||
return ComplexType(properties=properties2)
|
||||
|
||||
case TypeVar(name=name):
|
||||
if name in substitutions:
|
||||
return substitutions[name]
|
||||
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
case _:
|
||||
raise NotImplementedError(f"Unsupported type {type}")
|
||||
|
||||
|
||||
def unfold_type(type: Type) -> Type:
|
||||
match type:
|
||||
case AliasType(type=ref_type):
|
||||
return unfold_type(ref_type)
|
||||
case _:
|
||||
return type
|
||||
|
||||
|
||||
Type = (
|
||||
BaseType
|
||||
| AliasType
|
||||
| UnknownType
|
||||
| UnitType
|
||||
| Function
|
||||
| ComplexType
|
||||
| TypeVar
|
||||
| GenericType
|
||||
| AppliedType
|
||||
)
|
||||
|
||||
41
midas/cli/ansi.py
Normal file
41
midas/cli/ansi.py
Normal file
@@ -0,0 +1,41 @@
|
||||
class Ansi:
|
||||
CTRL = "\x1b["
|
||||
RESET = CTRL + "0m"
|
||||
BOLD = CTRL + "1m"
|
||||
DIM = CTRL + "2m"
|
||||
ITALIC = CTRL + "3m"
|
||||
UNDERLINE = CTRL + "4m"
|
||||
|
||||
BLACK = 0
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
YELLOW = 3
|
||||
BLUE = 4
|
||||
MAGENTA = 5
|
||||
CYAN = 6
|
||||
WHITE = 7
|
||||
|
||||
BRIGHT_BLACK = 60
|
||||
BRIGHT_RED = 61
|
||||
BRIGHT_GREEN = 62
|
||||
BRIGHT_YELLOW = 63
|
||||
BRIGHT_BLUE = 64
|
||||
BRIGHT_MAGENTA = 65
|
||||
BRIGHT_CYAN = 66
|
||||
BRIGHT_WHITE = 67
|
||||
|
||||
@classmethod
|
||||
def FG(cls, col: int) -> str:
|
||||
return f"{cls.CTRL}{30 + col}m"
|
||||
|
||||
@classmethod
|
||||
def BG(cls, col: int) -> str:
|
||||
return f"{cls.CTRL}{40 + col}m"
|
||||
|
||||
@classmethod
|
||||
def FG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||
return f"{cls.CTRL}38;2;{r};{g};{b}m"
|
||||
|
||||
@classmethod
|
||||
def BG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||
return f"{cls.CTRL}48;2;{r};{g};{b}m"
|
||||
@@ -210,12 +210,14 @@ class PythonHighlighter(
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
for item in expr.items:
|
||||
item.accept(self)
|
||||
|
||||
|
||||
class MidasHighlighter(
|
||||
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
||||
@@ -286,8 +288,8 @@ class MidasHighlighter(
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self.wrap(type, "generic-type")
|
||||
type.type.accept(self)
|
||||
for param in type.params:
|
||||
param.accept(self)
|
||||
for arg in type.args:
|
||||
arg.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self.wrap(type, "constraint-type")
|
||||
@@ -299,6 +301,12 @@ class MidasHighlighter(
|
||||
for prop in type.properties:
|
||||
prop.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self.wrap(type, "function")
|
||||
for arg in type.pos_args + type.kw_args:
|
||||
arg.type.accept(self)
|
||||
type.returns.accept(self)
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||
|
||||
@@ -8,10 +8,12 @@ import click
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.types import Type
|
||||
from midas.cli.ansi import Ansi
|
||||
from midas.cli.highlighter import (
|
||||
DiagnosticsHighlighter,
|
||||
Highlighter,
|
||||
@@ -23,7 +25,6 @@ from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from midas.utils import UniversalJSONDumper
|
||||
|
||||
|
||||
@@ -32,38 +33,92 @@ def midas():
|
||||
pass
|
||||
|
||||
|
||||
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
||||
"""Pretty-print a diagnostic, showing some context if possible
|
||||
|
||||
If the diagnostic concerns a specific part of one line, the line is shown
|
||||
with the affected part highlighted. The message is clearly printed under the
|
||||
line with an underline further indicating the target expression.
|
||||
|
||||
If multiple lines are concerned, no context is shown, only the
|
||||
diagnostic type, location and message
|
||||
|
||||
Args:
|
||||
lines (list[str]): source code lines
|
||||
diagnostic (Diagnostic): the diagnostic to print
|
||||
indent (int, optional): the number of spaces added before the target line to indent if from the location header. Defaults to 4.
|
||||
"""
|
||||
|
||||
loc: Location = diagnostic.location
|
||||
if loc.lineno != loc.end_lineno:
|
||||
print(diagnostic)
|
||||
return
|
||||
|
||||
start_offset: int = loc.col_offset
|
||||
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||
|
||||
line: str = lines[loc.lineno - 1]
|
||||
before: str = line[:start_offset]
|
||||
after: str = line[end_offset:]
|
||||
|
||||
color: int = {
|
||||
DiagnosticType.ERROR: Ansi.RED,
|
||||
DiagnosticType.WARNING: Ansi.YELLOW,
|
||||
DiagnosticType.INFO: Ansi.CYAN,
|
||||
}.get(diagnostic.type, Ansi.WHITE)
|
||||
|
||||
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
||||
cursor: str = (
|
||||
" " * start_offset
|
||||
+ Ansi.FG(color)
|
||||
+ "~" * (end_offset - start_offset)
|
||||
+ "> "
|
||||
+ diagnostic.message
|
||||
+ Ansi.RESET
|
||||
)
|
||||
|
||||
indent_str: str = " " * indent
|
||||
print(diagnostic.location_str + ":")
|
||||
print(indent_str + before + subject + after)
|
||||
print(indent_str + cursor)
|
||||
print()
|
||||
|
||||
|
||||
@midas.command()
|
||||
@click.option("-l", "--highlight", type=click.File("w"))
|
||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||
@click.option("-v", "--verbose", is_flag=True)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
def compile(
|
||||
highlight: Optional[TextIO],
|
||||
types: tuple[TextIO],
|
||||
verbose: bool,
|
||||
file: TextIO,
|
||||
):
|
||||
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
|
||||
source: str = file.read()
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
types_paths: list[Path] = [Path(t.name).resolve() for t in types]
|
||||
checker = Checker(
|
||||
resolver.locals,
|
||||
source_path=Path(file.name).resolve(),
|
||||
types_paths=types_paths,
|
||||
)
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
for diagnostic in diagnostics:
|
||||
print(diagnostic)
|
||||
|
||||
print(
|
||||
json.dumps(
|
||||
UniversalJSONDumper.dump(
|
||||
checker.global_env,
|
||||
[("Environment", "_children")],
|
||||
lambda obj: isinstance(obj, get_args(Type)),
|
||||
),
|
||||
indent=4,
|
||||
checker = TypeChecker()
|
||||
for path in types:
|
||||
checker.import_midas(Path(path.name).resolve())
|
||||
|
||||
checker.type_check_source(source, str(Path(file.name).resolve()))
|
||||
diagnostics: list[Diagnostic] = checker.diagnostics
|
||||
lines: list[str] = source.split("\n")
|
||||
for diagnostic in diagnostics:
|
||||
print_diagnostic(lines, diagnostic)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
json.dumps(
|
||||
UniversalJSONDumper.dump(
|
||||
checker.python_typer.global_env,
|
||||
[("Environment", "_children")],
|
||||
lambda obj: isinstance(obj, get_args(Type)),
|
||||
),
|
||||
indent=4,
|
||||
)
|
||||
)
|
||||
)
|
||||
if highlight is not None:
|
||||
highlighter = DiagnosticsHighlighter(source)
|
||||
highlighter.highlight(diagnostics)
|
||||
|
||||
@@ -50,12 +50,14 @@ class MidasLexer(Lexer):
|
||||
# self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
# case "*":
|
||||
# self.add_token(TokenType.STAR)
|
||||
case "*":
|
||||
self.add_token(TokenType.STAR)
|
||||
case "/" if self.match("/"):
|
||||
self.scan_comment()
|
||||
case "/" if self.match("*"):
|
||||
self.scan_comment_multiline()
|
||||
case "/":
|
||||
self.add_token(TokenType.SLASH)
|
||||
case "\n":
|
||||
self.add_token(TokenType.NEWLINE)
|
||||
case " " | "\r" | "\t":
|
||||
|
||||
@@ -27,8 +27,8 @@ class TokenType(Enum):
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
MINUS = auto()
|
||||
# STAR = auto()
|
||||
# SLASH = auto()
|
||||
STAR = auto()
|
||||
SLASH = auto()
|
||||
GREATER = auto()
|
||||
GREATER_EQUAL = auto()
|
||||
LESS = auto()
|
||||
|
||||
@@ -7,6 +7,7 @@ from midas.ast.midas import (
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
@@ -18,12 +19,13 @@ from midas.ast.midas import (
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
TypeParam,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.lexer.token import KEYWORDS, Token, TokenType
|
||||
from midas.parser.base import Parser
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
@@ -107,10 +109,8 @@ class MidasParser(Parser):
|
||||
TypeStmt: the parsed type declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
params: list[TypeStmt.Param] = []
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params = self.type_stmt_params()
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
params: list[TypeParam] = self.type_params()
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||
|
||||
@@ -123,24 +123,27 @@ class MidasParser(Parser):
|
||||
type=type,
|
||||
)
|
||||
|
||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
||||
"""Parse a generic template expression
|
||||
def type_params(self) -> list[TypeParam]:
|
||||
"""Parse a list of type parameters
|
||||
|
||||
A template is written `[TypeExpr]`
|
||||
Type parameters are a comma-separated list of type variables wrapped in brackets.
|
||||
Each type variable is either a simple variable, or a bounded variable written `S <: T`
|
||||
|
||||
Returns:
|
||||
TemplateExpr: the parsed template expression
|
||||
list[TypeParam]: the list of type parameters, if any, or an empty list
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
||||
params: list[TypeStmt.Param] = []
|
||||
if not self.match(TokenType.LEFT_BRACKET):
|
||||
return []
|
||||
|
||||
params: list[TypeParam] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
||||
name: Token = self.consume_identifier("Expected type variable")
|
||||
bound: Optional[Type] = None
|
||||
if self.match(TokenType.LESS):
|
||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||
bound = self.type_expr()
|
||||
params.append(
|
||||
TypeStmt.Param(
|
||||
TypeParam(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
bound=bound,
|
||||
@@ -148,7 +151,7 @@ class MidasParser(Parser):
|
||||
)
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
||||
return params
|
||||
|
||||
def type_expr(self) -> Type:
|
||||
@@ -187,26 +190,26 @@ class MidasParser(Parser):
|
||||
def generic_type(self) -> Type:
|
||||
type: Type = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params: list[Type] = self.type_params()
|
||||
args: list[Type] = self.type_args()
|
||||
return GenericType(
|
||||
location=Location.span(type.location, self.previous().get_location()),
|
||||
type=type,
|
||||
params=params,
|
||||
args=args,
|
||||
)
|
||||
return type
|
||||
|
||||
def type_params(self) -> list[Type]:
|
||||
params: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
||||
def type_args(self) -> list[Type]:
|
||||
args: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
params.append(self.type_expr())
|
||||
args.append(self.type_expr())
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
||||
return params
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||
return args
|
||||
|
||||
def named_type(self) -> Type:
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
name=name,
|
||||
@@ -322,9 +325,7 @@ class MidasParser(Parser):
|
||||
"""
|
||||
expr: Expr = self.primary()
|
||||
while self.match(TokenType.DOT):
|
||||
name: Token = self.consume(
|
||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
||||
)
|
||||
name: Token = self.consume_identifier("Expected property name after '.'")
|
||||
location: Location = Location.span(expr.location, name.get_location())
|
||||
expr = GetExpr(location=location, expr=expr, name=name)
|
||||
return expr
|
||||
@@ -348,7 +349,7 @@ class MidasParser(Parser):
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||
|
||||
if self.match(TokenType.IDENTIFIER):
|
||||
if self.match_identifier():
|
||||
return VariableExpr(location=token.get_location(), name=token)
|
||||
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
@@ -361,6 +362,20 @@ class MidasParser(Parser):
|
||||
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def consume_identifier(self, message: str = "Expected identifier") -> Token:
|
||||
if not self.match_identifier():
|
||||
raise self.error(self.peek(), message)
|
||||
return self.previous()
|
||||
|
||||
def match_identifier(self) -> bool:
|
||||
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
|
||||
|
||||
def check_identifier(self) -> bool:
|
||||
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
|
||||
if self.check(tt):
|
||||
return True
|
||||
return False
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
|
||||
@@ -369,7 +384,7 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
PropertyStmt: the parsed property statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
name: Token = self.consume_identifier("Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
type: Type = self.type_expr()
|
||||
return PropertyStmt(
|
||||
@@ -381,12 +396,14 @@ class MidasParser(Parser):
|
||||
def extend_declaration(self) -> ExtendStmt:
|
||||
"""Parse an extension definition
|
||||
|
||||
An extension is written `extend Type { operations }`
|
||||
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
|
||||
|
||||
Returns:
|
||||
ExtendStmt: the parsed extension statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
params: list[TypeParam] = self.type_params()
|
||||
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||
operations: list[OpStmt] = []
|
||||
@@ -394,7 +411,12 @@ class MidasParser(Parser):
|
||||
operations.append(self.op_declaration())
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||
location: Location = keyword.location_to(self.previous())
|
||||
return ExtendStmt(location=location, type=type, operations=operations)
|
||||
return ExtendStmt(
|
||||
location=location,
|
||||
params=params,
|
||||
type=type,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def op_declaration(self) -> OpStmt:
|
||||
"""Parse an operation definition
|
||||
@@ -430,9 +452,9 @@ class MidasParser(Parser):
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
||||
name: Token = self.consume_identifier("Expected predicate name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
||||
subject: Token = self.consume_identifier("Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
@@ -445,3 +467,48 @@ class MidasParser(Parser):
|
||||
type=type,
|
||||
condition=condition,
|
||||
)
|
||||
|
||||
def function(self) -> FunctionType:
|
||||
l_paren: Token = self.consume(
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
pos_args: list[FunctionType.Argument] = []
|
||||
kw_args: list[FunctionType.Argument] = []
|
||||
|
||||
positional: bool = True
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||
if positional and (
|
||||
self.match(TokenType.STAR) or self.match(TokenType.SLASH)
|
||||
):
|
||||
positional = False
|
||||
else:
|
||||
name: Optional[Token] = None
|
||||
if self.check_identifier() and self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
self.advance()
|
||||
type: Type = self.type_expr()
|
||||
required: bool = self.match(TokenType.QMARK)
|
||||
arg = FunctionType.Argument(
|
||||
location=None,
|
||||
name=name,
|
||||
type=type,
|
||||
required=required,
|
||||
)
|
||||
if positional:
|
||||
pos_args.append(arg)
|
||||
else:
|
||||
kw_args.append(arg)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=l_paren.location_to(self.previous()),
|
||||
pos_args=pos_args,
|
||||
kw_args=kw_args,
|
||||
returns=result,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from midas.ast.python import (
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
ListExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
@@ -416,6 +417,12 @@ class PythonParser:
|
||||
case ast.Name(id=name):
|
||||
return VariableExpr(location=location, name=name)
|
||||
|
||||
case ast.List(elts=items):
|
||||
return ListExpr(
|
||||
location=location,
|
||||
items=[self.parse_expr(item) for item in items],
|
||||
)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(node)
|
||||
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from midas.checker.types import BaseType, Type, UnitType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
||||
ctx.define_operation(
|
||||
left=t1,
|
||||
operator=operator,
|
||||
right=t2,
|
||||
result=t3,
|
||||
)
|
||||
|
||||
|
||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
||||
ctx.define_operation(
|
||||
left=type,
|
||||
operator=op,
|
||||
right=type,
|
||||
result=type,
|
||||
)
|
||||
|
||||
|
||||
def define_builtins(ctx: MidasResolver):
|
||||
"""Define builtin types and operations"""
|
||||
unit = ctx.define_type("None", UnitType())
|
||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
||||
int = ctx.define_type("int", BaseType(name="int"))
|
||||
float = ctx.define_type("float", BaseType(name="float"))
|
||||
str = ctx.define_type("str", BaseType(name="str"))
|
||||
|
||||
basic_op(ctx, int, "__add__") # int + int = int
|
||||
basic_op(ctx, int, "__sub__") # int - int = int
|
||||
basic_op(ctx, int, "__mul__") # int * int = int
|
||||
basic_op(ctx, int, "__pow__") # int ** int = int
|
||||
basic_op(ctx, int, "__mod__") # int % int = int
|
||||
basic_op(ctx, int, "__and__") # int & int = int
|
||||
basic_op(ctx, int, "__or__") # int | int = int
|
||||
basic_op(ctx, int, "__xor__") # int ^ int = int
|
||||
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
||||
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
||||
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
||||
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
||||
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
||||
basic_op(ctx, float, "__add__") # float + float = float
|
||||
basic_op(ctx, float, "__sub__") # float - float = float
|
||||
basic_op(ctx, float, "__mul__") # float * float = float
|
||||
basic_op(ctx, float, "__truediv__") # float / float = float
|
||||
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
||||
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
||||
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
||||
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
||||
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
||||
basic_op(ctx, str, "__add__") # str + str = str
|
||||
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
||||
|
||||
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
||||
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
||||
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
||||
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
||||
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
||||
|
||||
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
||||
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
||||
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
||||
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
||||
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
||||
@@ -1,163 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.resolver.builtin import define_builtins
|
||||
|
||||
|
||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._types: dict[str, Type] = {}
|
||||
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
||||
|
||||
define_builtins(self)
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
type: Optional[Type] = self._types.get(name)
|
||||
if type is None:
|
||||
raise NameError(f"Undefined type {name}")
|
||||
return type
|
||||
|
||||
def get_operation_result(
|
||||
self, left: Type, operator: str, right: Type
|
||||
) -> Optional[Type]:
|
||||
"""Get the resulting type of an operation
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the result type, or None if no matching operation was found
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
result: Optional[Type] = self._operations.get(operation)
|
||||
return result
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
|
||||
"""Define a type in the registry
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
type (Type): the type to define
|
||||
|
||||
Raises:
|
||||
ValueError: if a type is already defined with that name
|
||||
|
||||
Returns:
|
||||
Type: the defined type
|
||||
"""
|
||||
if name in self._types:
|
||||
raise ValueError(f"Type {name} already defined")
|
||||
self._types[name] = type
|
||||
return type
|
||||
|
||||
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
||||
"""Define an operation in the registry
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
result (Type): the result type
|
||||
|
||||
Raises:
|
||||
ValueError: if an operation is already defined with these operands and name
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
if operation in self._operations:
|
||||
raise ValueError(
|
||||
f"Operation {operator} already defined between {left} and {right}"
|
||||
)
|
||||
self._operations[operation] = result
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Args:
|
||||
stmts (list[m.Stmt]): the statements
|
||||
"""
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
type: Type = stmt.type.accept(self)
|
||||
for param in stmt.params:
|
||||
if param.bound is not None:
|
||||
param.bound.accept(self)
|
||||
name: str = stmt.name.lexeme
|
||||
self.define_type(name, AliasType(name=name, type=type))
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
base: Type = stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
right: Type = op.operand.accept(self)
|
||||
result: Type = op.result.accept(self)
|
||||
self.define_operation(
|
||||
left=base,
|
||||
operator=op.name.lexeme,
|
||||
right=right,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
return self.get_type(type.name.lexeme)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
params: list[Type] = [param.accept(self) for param in type.params]
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
||||
for prop in type.properties:
|
||||
prop.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
@@ -29,7 +29,7 @@ class Tester(ABC):
|
||||
def _list_tests(self) -> list[Path]: ...
|
||||
|
||||
def run_all_tests(self) -> bool:
|
||||
paths: list[Path] = self._list_tests()
|
||||
paths: list[Path] = sorted(self._list_tests())
|
||||
return self.run_tests(paths)
|
||||
|
||||
def run_tests(self, tests: list[Path]) -> bool:
|
||||
@@ -40,7 +40,7 @@ class Tester(ABC):
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
|
||||
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||
success: bool = self._run_test(test)
|
||||
if success:
|
||||
successes += 1
|
||||
@@ -78,7 +78,7 @@ class Tester(ABC):
|
||||
def _exec_case(self, path: Path) -> CaseResult: ...
|
||||
|
||||
def update_all_tests(self):
|
||||
paths: list[Path] = self._list_tests()
|
||||
paths: list[Path] = sorted(self._list_tests())
|
||||
return self.update_tests(paths)
|
||||
|
||||
def update_tests(self, tests: list[Path]):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{
|
||||
"diagnostics": []
|
||||
"diagnostics": [],
|
||||
"judgments": []
|
||||
}
|
||||
@@ -12,35 +12,168 @@
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
|
||||
"message": "Cannot assign str to variable 'c' of type int"
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L1:9",
|
||||
"to": "L1:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
9,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
9,
|
||||
9
|
||||
]
|
||||
"from": "L2:9",
|
||||
"to": "L2:10"
|
||||
},
|
||||
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 4
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
12
|
||||
]
|
||||
"from": "L4:4",
|
||||
"to": "L4:5"
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:8",
|
||||
"to": "L4:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:4",
|
||||
"to": "L4:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:4",
|
||||
"to": "L6:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "invalid"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:4",
|
||||
"to": "L8:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": true
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:4",
|
||||
"to": "L9:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:8",
|
||||
"to": "L9:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:4",
|
||||
"to": "L9:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:11",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,109 @@
|
||||
{
|
||||
"diagnostics": []
|
||||
"diagnostics": [],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:18",
|
||||
"to": "L4:37"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CastExpr",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Meter",
|
||||
"param": null
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 123.45
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "Meter",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:15",
|
||||
"to": "L5:32"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CastExpr",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Second",
|
||||
"param": null
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 6.7
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "Second",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:8",
|
||||
"to": "L6:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "distance"
|
||||
},
|
||||
"type": {
|
||||
"name": "Meter",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:19",
|
||||
"to": "L6:23"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "time"
|
||||
},
|
||||
"type": {
|
||||
"name": "Second",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:8",
|
||||
"to": "L6:23"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "distance"
|
||||
},
|
||||
"operator": "/",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "time"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "MeterPerSecond",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -42,5 +42,215 @@
|
||||
},
|
||||
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:11",
|
||||
"to": "L2:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:15",
|
||||
"to": "L2:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:7",
|
||||
"to": "L5:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:11",
|
||||
"to": "L5:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:15",
|
||||
"to": "L6:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:19",
|
||||
"to": "L6:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:15",
|
||||
"to": "L8:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:19",
|
||||
"to": "L8:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:7",
|
||||
"to": "L15:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:11",
|
||||
"to": "L15:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 10
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:15",
|
||||
"to": "L16:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:19",
|
||||
"to": "L16:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 10
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L22:7",
|
||||
"to": "L22:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L22:11",
|
||||
"to": "L22:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L23:15",
|
||||
"to": "L23:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L23:19",
|
||||
"to": "L23:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
12
tests/cases/checker/06_subtyping.py
Normal file
12
tests/cases/checker/06_subtyping.py
Normal file
@@ -0,0 +1,12 @@
|
||||
v1: int = 3
|
||||
v2: float = 4
|
||||
|
||||
|
||||
def maximum(a: float, b: float):
|
||||
if b > a:
|
||||
return b
|
||||
return a
|
||||
|
||||
|
||||
v3 = maximum(v1, v2)
|
||||
v3 = v1 + v2
|
||||
193
tests/cases/checker/06_subtyping.py.ref.json
Normal file
193
tests/cases/checker/06_subtyping.py.ref.json
Normal file
@@ -0,0 +1,193 @@
|
||||
{
|
||||
"diagnostics": [],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L1:10",
|
||||
"to": "L1:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:12",
|
||||
"to": "L2:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 4
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:7",
|
||||
"to": "L6:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
"to": "L6:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "maximum"
|
||||
},
|
||||
"type": {
|
||||
"name": "maximum",
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:13",
|
||||
"to": "L11:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:17",
|
||||
"to": "L11:19"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
"to": "L11:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CallExpr",
|
||||
"callee": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "maximum"
|
||||
},
|
||||
"arguments": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
}
|
||||
],
|
||||
"keywords": {}
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:5",
|
||||
"to": "L12:7"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:10",
|
||||
"to": "L12:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:5",
|
||||
"to": "L12:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -2385,7 +2385,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
@@ -2416,7 +2416,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "Latitude"
|
||||
@@ -2433,7 +2433,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "Longitude"
|
||||
@@ -2464,7 +2464,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "Latitude"
|
||||
@@ -2494,7 +2494,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "Longitude"
|
||||
@@ -2638,7 +2638,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Optional"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from midas.checker.types import Type
|
||||
from tests.base import Tester
|
||||
from tests.serializer.python import PythonAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
diagnostics: list[dict] = field(default_factory=list)
|
||||
judgments: list = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
@@ -33,23 +33,16 @@ class CheckerTester(Tester):
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
types_paths: list[Path] = []
|
||||
result: CaseResult = CaseResult()
|
||||
|
||||
checker = TypeChecker()
|
||||
types_path: Path = path.with_suffix(".midas")
|
||||
if types_path.exists():
|
||||
types_paths.append(types_path)
|
||||
source: str = path.read_text()
|
||||
tree: ast.Module = ast.parse(source, filename=path)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
result: CaseResult = CaseResult()
|
||||
checker = Checker(
|
||||
resolver.locals,
|
||||
source_path=path,
|
||||
types_paths=types_paths,
|
||||
)
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
checker.import_midas(types_path)
|
||||
|
||||
checker.type_check(path)
|
||||
|
||||
diagnostics: list[Diagnostic] = checker.diagnostics
|
||||
for diagnostic in diagnostics:
|
||||
result.diagnostics.append(
|
||||
{
|
||||
@@ -68,6 +61,21 @@ class CheckerTester(Tester):
|
||||
}
|
||||
)
|
||||
|
||||
judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements
|
||||
serializer = PythonAstJsonSerializer()
|
||||
for expr, type in judgements:
|
||||
loc = expr.location
|
||||
result.judgments.append(
|
||||
{
|
||||
"location": {
|
||||
"from": f"L{loc.lineno}:{loc.col_offset}",
|
||||
"to": f"L{loc.end_lineno}:{loc.end_col_offset}",
|
||||
},
|
||||
"expr": expr.accept(serializer),
|
||||
"type": asdict(type),
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from midas.ast.midas import (
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
@@ -17,6 +18,7 @@ from midas.ast.midas import (
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
TypeParam,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -46,13 +48,11 @@ class MidasAstJsonSerializer(
|
||||
return {
|
||||
"_type": "TypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"params": [
|
||||
self._serialize_type_stmt_template_param(param) for param in stmt.params
|
||||
],
|
||||
"params": [self._serialize_type_param(param) for param in stmt.params],
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
|
||||
def _serialize_type_param(self, param: TypeParam) -> dict:
|
||||
return {
|
||||
"name": param.name.lexeme,
|
||||
"bound": self._serialize_optional(param.bound),
|
||||
@@ -150,7 +150,7 @@ class MidasAstJsonSerializer(
|
||||
return {
|
||||
"_type": "GenericType",
|
||||
"type": type.type.accept(self),
|
||||
"params": self._serialize_list(type.params),
|
||||
"args": self._serialize_list(type.args),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
||||
@@ -165,3 +165,18 @@ class MidasAstJsonSerializer(
|
||||
"_type": "ComplexType",
|
||||
"properties": self._serialize_list(type.properties),
|
||||
}
|
||||
|
||||
def visit_function_type(self, type: FunctionType) -> dict:
|
||||
return {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||
"returns": type.returns.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"type": arg.type.accept(self),
|
||||
"required": arg.required,
|
||||
}
|
||||
|
||||
@@ -16,11 +16,11 @@ from midas.ast.python import (
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
ListExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
SetExpr,
|
||||
Stmt,
|
||||
TernaryExpr,
|
||||
TypeAssign,
|
||||
@@ -232,14 +232,6 @@ class PythonAstJsonSerializer(
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_set_expr(self, expr: SetExpr) -> dict:
|
||||
return {
|
||||
"_type": "SetExpr",
|
||||
"object": expr.object.accept(self),
|
||||
"name": expr.name,
|
||||
"value": expr.value.accept(self),
|
||||
}
|
||||
|
||||
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
||||
return {
|
||||
"_type": "CastExpr",
|
||||
@@ -254,3 +246,9 @@ class PythonAstJsonSerializer(
|
||||
"if_true": expr.if_true.accept(self),
|
||||
"if_false": expr.if_false.accept(self),
|
||||
}
|
||||
|
||||
def visit_list_expr(self, expr: ListExpr) -> dict:
|
||||
return {
|
||||
"_type": "ListExpr",
|
||||
"items": [item.accept(self) for item in expr.items],
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user