35 Commits

Author SHA1 Message Date
c6ead886ec feat: add function type to midas syntax 2026-06-09 23:48:06 +02:00
9de03bf2b5 feat(types): add type params to extend statement 2026-06-09 23:40:57 +02:00
a26b9293be refactor(types): extract TypeParams
also rename generic type params to type args (when calling a generic)
2026-06-09 15:30:45 +02:00
efa5454776 feat(types): add human-friendly string rep
add `__str__` methods on type structures to improve readability of diagnostics
2026-06-09 12:59:36 +02:00
b8bb8190c4 fix(resolver): define variable on assignment
if a variable is not already defined when an assignment is visited, it is then defined in the current scope
2026-06-09 08:06:46 +02:00
a4f5db7ece fix(checker): use reduce_types to infer return type 2026-06-09 08:05:31 +02:00
fc67f01f34 refactor(checker): extract reduce_types function 2026-06-09 08:04:45 +02:00
0a748a36a3 feat(types): WIP add AppliedType 2026-06-08 18:26:11 +02:00
89fdd1b47e feat(checker): WIP add lists 2026-06-08 18:25:37 +02:00
0cde53ac6e feat(types): add name to generic type 2026-06-08 18:21:40 +02:00
f3ec3606c2 fix: avoid circular import in builtins.py 2026-06-08 13:48:46 +02:00
67ec029529 refactor(resolver): move resolver to checker module 2026-06-08 13:45:48 +02:00
e2aef7a811 refactor(checker): unify builtins definitions 2026-06-08 13:44:26 +02:00
86ba4e658a refactor(checker): restructure around shared registry
restructure the type checker with a shared TypesRegistry used by MidasTyper and PythonTyper

this commit also relocates some methods in more appropriate places, such as is_subtype and apply_generic (now in TypesRegistry)
2026-06-08 13:41:42 +02:00
7eccf59558 feat(checker): add reporter class 2026-06-08 13:38:35 +02:00
9dd7801d2d feat(resolver): handle generic application 2026-06-08 10:59:01 +02:00
154cb8b314 refactor(checker): move is_subtype to resolver 2026-06-08 10:57:50 +02:00
c64ab434b5 refactor(checker): move unfold_type to types.py 2026-06-08 10:56:27 +02:00
25e6410546 feat(resolver): handle generics definition 2026-06-08 10:55:15 +02:00
8a22acc17c feat(checker): add generic type structure 2026-06-08 10:52:34 +02:00
e0179bc442 feat(checker): handle assignments to attributes 2026-06-07 17:50:56 +02:00
e665d03533 fix: remove unused SetExpr 2026-06-07 17:48:31 +02:00
b8cb2b4273 feat(checker): handle attribute getter 2026-06-07 15:07:24 +02:00
d278dc5f5b tests: update tests with operation overloads 2026-06-07 14:28:36 +02:00
59e73f0fd9 fix(checker): invert property subtype check 2026-06-07 14:00:02 +02:00
3e0dc60283 fix(checker): only unfold alias on subtype 2026-06-07 13:59:27 +02:00
c24eb5125e feat(checker): resolve operation overloads with subtypes 2026-06-07 13:43:43 +02:00
25bd895dde feat(cli): improve diagnostic printing 2026-06-07 13:42:15 +02:00
bccd75317e tests: add subtyping test 2026-06-06 16:59:49 +02:00
f0e3f7574f feat(tests): add judgements to test results
add type judgements to checker test results and update all tests (including the new subtyping rules)
2026-06-06 16:58:13 +02:00
5d44081847 feat(checker): implement function subtyping
the logic for checking function subtypes is a WIP and has not been fully tested, there may be some errors and unhandled edge cases
Claude helped lay out and verify the overall steps

Co-authored-by: Claude <noreply@anthropic.com>
2026-06-06 16:53:52 +02:00
2a2bb0aec7 feat(checker): store function param position 2026-06-06 16:50:42 +02:00
67c40a3909 feat(checker): add is_subtype method 2026-06-06 16:30:04 +02:00
1c30188122 feat(checker): record type judgements 2026-06-06 16:25:33 +02:00
82a0f13242 feat(cli): add verbose flag to compile 2026-06-05 14:17:24 +02:00
38 changed files with 3919 additions and 989 deletions

View File

@@ -0,0 +1,11 @@
type Meter = float
extend Meter {
op __add__(Meter) -> Meter
op __sub__(Meter) -> Meter
}
type Coordinate = {
x: Meter
y: Meter
}

View File

@@ -0,0 +1,14 @@
# type: ignore
# ruff: disable [F821]
p1: Coordinate
p2: Coordinate
diff_x = p2.x - p1.x
diff_y = p2.y - p1.y
dist = diff_x + diff_y
p2.x += cast(Meter, 1)
p2.y = True
p2.z = 3
p2.x.a = 3

View File

@@ -30,6 +30,7 @@ from __future__ import annotations
T = TypeVar("T")
{preamble}
{sections}
"""
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
re.MULTILINE | re.DOTALL,
)
PREAMBLE_REGEX = re.compile(
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
def snake_case(text: str) -> str:
    """Convert a CamelCase identifier to snake_case

    Every ASCII uppercase letter is replaced by an underscore followed by its
    lowercase form, the result is lowercased, and leading/trailing
    underscores are stripped

    Args:
        text (str): the identifier to convert

    Returns:
        str: the snake_case version of the identifier
    """
    pieces: list[str] = []
    for char in text:
        # Only ASCII capitals start a new word, like the `[A-Z]` class would
        if "A" <= char <= "Z":
            pieces.append("_" + char.lower())
        else:
            pieces.append(char)
    return "".join(pieces).lower().strip("_")
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
def make_section(full_name: str, base: str, param: str, body: str) -> str:
print(f" Generating {full_name}")
visitor_methods: list[str] = []
classes: list[str] = []
definitions: list[str] = body.strip("\n").split("\n\n\n")
for cls in definitions:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f"Processing {name}")
print(f" Processing {name}")
visitor_methods.append(make_visitor_method(name, param))
classes.append(make_class(name, cls, base))
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
def generate(definitions_path: Path, out_path: Path):
print(f"Processing generating {out_path} from {definitions_path}")
root_dir: Path = Path(__file__).parent.parent
rel_path: Path = definitions_path.relative_to(root_dir)
src: str = definitions_path.read_text()
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
if m := IMPORTS_REGEX.search(src):
imports = m.group("body").strip("\n")
preamble: str = ""
if m := PREAMBLE_REGEX.search(src):
preamble = m.group("body")
for section_m in SECTION_REGEX.finditer(src):
full_name: str = section_m.group("name")
base: str = section_m.group("base")
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
gen_path=Path(__file__).relative_to(root_dir),
),
imports=imports,
preamble=preamble,
sections="\n\n\n".join(sections),
)
out_path.write_text(result)

View File

@@ -12,25 +12,31 @@ from midas.lexer.token import Token
###<
###> Stmt | Statements
class TypeStmt:
name: Token
params: list[Param]
type: Type
@dataclass(frozen=True, kw_only=True)
class Param:
###> Preamble
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
###<
###> Stmt | Statements
class TypeStmt:
name: Token
params: list[TypeParam]
type: Type
class PropertyStmt:
name: Token
type: Type
class ExtendStmt:
params: list[TypeParam]
type: Type
operations: list[OpStmt]
@@ -103,7 +109,7 @@ class NamedType:
class GenericType:
type: Type
params: list[Type]
args: list[Type]
class ConstraintType:
@@ -115,4 +121,17 @@ class ComplexType:
properties: list[PropertyStmt]
class FunctionType:
pos_args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[Token]
type: Type
required: bool
###<

View File

@@ -128,12 +128,6 @@ class LogicalExpr:
right: Expr
class SetExpr:
object: Expr
name: str
value: Expr
class CastExpr:
type: MidasType
expr: Expr
@@ -145,4 +139,8 @@ class TernaryExpr:
if_false: Expr
class ListExpr:
items: list[Expr]
###<

View File

@@ -14,6 +14,13 @@ from midas.lexer.token import Token
T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
##############
# Statements #
##############
@@ -46,15 +53,9 @@ class Stmt(ABC):
@dataclass(frozen=True)
class TypeStmt(Stmt):
name: Token
params: list[Param]
params: list[TypeParam]
type: Type
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
bound: Optional[Type]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_type_stmt(self)
@@ -70,6 +71,7 @@ class PropertyStmt(Stmt):
@dataclass(frozen=True)
class ExtendStmt(Stmt):
params: list[TypeParam]
type: Type
operations: list[OpStmt]
@@ -231,6 +233,9 @@ class Type(ABC):
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ...
@dataclass(frozen=True)
class NamedType(Type):
@@ -243,7 +248,7 @@ class NamedType(Type):
@dataclass(frozen=True)
class GenericType(Type):
type: Type
params: list[Type]
args: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)
@@ -264,3 +269,20 @@ class ComplexType(Type):
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[Token]
type: Type
required: bool
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_function_type(self)

View File

@@ -100,12 +100,12 @@ class MidasAstPrinter(
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._print_type_stmt_param(param)
self._print_type_param(param)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
@@ -122,6 +122,13 @@ class MidasAstPrinter(
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
with self._child_level():
self._write_line("params")
with self._child_level():
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._print_type_param(param)
self._write_line("type")
with self._child_level(single=True):
stmt.type.accept(self)
@@ -234,11 +241,11 @@ class MidasAstPrinter(
self._write_line("type")
with self._child_level():
type.type.accept(self)
self._write_line("params", last=True)
self._write_line("args", last=True)
with self._child_level():
for i, param in enumerate(type.params):
for i, param in enumerate(type.args):
self._idx = i
if i == len(type.params) - 1:
if i == len(type.args) - 1:
self._mark_last()
param.accept(self)
@@ -263,6 +270,41 @@ class MidasAstPrinter(
self._mark_last()
prop.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
# Print a FunctionType node: a "FunctionType" header line, then a "pos_args"
# group and a "kw_args" group (one Argument entry per argument, printed by
# _print_function_arg), and finally the "returns" type.
self._write_line("FunctionType")
with self._child_level():
self._write_line("pos_args")
with self._child_level():
for i, arg in enumerate(type.pos_args):
# _idx records the index of the entry being printed; _mark_last() is called
# just before the final entry of the group.
# NOTE(review): presumably used to draw the closing tree branch — confirm
# against AstPrinter.
self._idx = i
if i == len(type.pos_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
self._idx = i
if i == len(type.kw_args) - 1:
self._mark_last()
self._print_function_arg(arg)
# "returns" is the last child of the node, hence last=True
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
# Print a single function type argument as an "Argument" node with three
# children: its name, its type and whether it is required.
self._write_line("Argument")
with self._child_level():
# The name is printed as the quoted lexeme, or as the bare text None when
# the argument has no name
name: str = "None"
if arg.name is not None:
name = f'"{arg.name.lexeme}"'
self._write_line(f"name: {name}")
self._write_line("type")
with self._child_level(single=True):
arg.type.accept(self)
# "required" is the last child of the node, hence last=True
self._write_line(f"required: {arg.required}", last=True)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
@@ -279,14 +321,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [
self._print_type_template_param(param) for param in stmt.params
]
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
def _print_type_param(self, param: m.TypeParam) -> str:
res: str = param.name.lexeme
if param.bound is not None:
res += "<:" + param.bound.accept(self)
@@ -358,9 +398,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_generic_type(self, type: m.GenericType) -> str:
res: str = type.type.accept(self)
if len(type.params) != 0:
params: list[str] = [param.accept(self) for param in type.params]
res += f"[{', '.join(params)}]"
if len(type.args) != 0:
args: list[str] = [param.accept(self) for param in type.args]
res += f"[{', '.join(args)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
@@ -378,6 +418,29 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res += self.indented("}")
return res
def visit_function_type(self, type: m.FunctionType) -> str:
    """Render a function type as `(pos..., /, *, kw...) -> returns`

    Positional arguments are listed first and followed by a `/` marker,
    keyword arguments are preceded by a `*` marker. Each marker is only
    emitted when its group is not empty

    Args:
        type (m.FunctionType): the function type to print

    Returns:
        str: the textual representation of the function type
    """
    pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
    # Keyword arguments must be read from kw_args; reading pos_args here
    # duplicated the positional arguments and dropped the keyword ones
    kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]

    # Work on a copy so the rendered positional list is left untouched
    args: list[str] = list(pos_args)
    if pos_args:
        args.append("/")
    if kw_args:
        args.append("*")
        args += kw_args
    return f"({', '.join(args)}) -> {type.returns.accept(self)}"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""
if arg.name is not None:
res += arg.name.lexeme
res += ": "
res += arg.type.accept(self)
if not arg.required:
res += "?"
return res
class PythonAstPrinter(
AstPrinter,
@@ -602,17 +665,6 @@ class PythonAstPrinter(
with self._child_level(single=True):
expr.right.accept(self)
def visit_set_expr(self, expr: p.SetExpr) -> None:
self._write_line("SetExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line(f"name: {expr.name}")
self._write_line("value", last=True)
with self._child_level(single=True):
expr.value.accept(self)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self._write_line("CastExpr")
with self._child_level():
@@ -637,3 +689,14 @@ class PythonAstPrinter(
self._write_line("if_false", last=True)
with self._child_level(single=True):
expr.if_false.accept(self)
def visit_list_expr(self, expr: p.ListExpr) -> None:
# Print a ListExpr node: a "ListExpr" header line followed by a single
# "items" group in which every element prints itself through accept().
self._write_line("ListExpr")
with self._child_level():
# "items" is the only child of the node, hence last=True
self._write_line("items", last=True)
with self._child_level():
for i, item in enumerate(expr.items):
# _idx records the index of the element being printed; _mark_last() is
# called just before the final element.
# NOTE(review): presumably used to draw the closing tree branch — confirm
# against AstPrinter.
self._idx = i
if i == len(expr.items) - 1:
self._mark_last()
item.accept(self)

View File

@@ -14,6 +14,7 @@ from midas.ast.location import Location
T = TypeVar("T")
####################
# Type annotations #
####################
@@ -214,15 +215,15 @@ class Expr(ABC):
@abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@abstractmethod
def visit_set_expr(self, expr: SetExpr) -> T: ...
@abstractmethod
def visit_cast_expr(self, expr: CastExpr) -> T: ...
@abstractmethod
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
@abstractmethod
def visit_list_expr(self, expr: ListExpr) -> T: ...
@dataclass(frozen=True)
class BinaryExpr(Expr):
@@ -298,16 +299,6 @@ class LogicalExpr(Expr):
return visitor.visit_logical_expr(self)
@dataclass(frozen=True)
class SetExpr(Expr):
object: Expr
name: str
value: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_set_expr(self)
@dataclass(frozen=True)
class CastExpr(Expr):
type: MidasType
@@ -325,3 +316,11 @@ class TernaryExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_ternary_expr(self)
@dataclass(frozen=True)
class ListExpr(Expr):
items: list[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_list_expr(self)

112
midas/checker/builtins.py Normal file
View File

@@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from midas.checker.types import (
BaseType,
ComplexType,
Function,
GenericType,
Type,
TypeVar,
UnitType,
)
if TYPE_CHECKING:
from midas.checker.registry import TypesRegistry
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"float": {"int"},
"int": {"bool"},
}
def op(reg: TypesRegistry, t1: Type, operator: str, t2: Type, t3: Type):
    """Register a binary operation `t1 <operator> t2 = t3`

    Args:
        reg (TypesRegistry): the registry in which to define the operation
        t1 (Type): the type of the left operand
        operator (str): the name of the operator method (e.g. `__add__`)
        t2 (Type): the type of the right operand
        t3 (Type): the type of the result
    """
    signature = {
        "left": t1,
        "operator": operator,
        "right": t2,
        "result": t3,
    }
    reg.define_operation(**signature)
def basic_op(reg: TypesRegistry, type: Type, op: str):
    """Register a binary operation whose operands and result share one type

    Args:
        reg (TypesRegistry): the registry in which to define the operation
        type (Type): the type of both operands and of the result
        op (str): the name of the operator method (e.g. `__add__`)
    """
    operand = type
    reg.define_operation(left=operand, operator=op, right=operand, result=operand)
def define_builtins(reg: TypesRegistry):
    """Define builtin types and operations

    Registers the `None`, `bool`, `int`, `float` and `str` types, the
    arithmetic/bitwise/comparison operations between them, and a generic
    `list` type exposing an `append` method. The values returned by
    `define_type` are used as the operand and result types of the
    operations. Operations are registered in a fixed order: int, float, str,
    then the mixed int/float comparisons

    Args:
        reg (TypesRegistry): the registry to populate
    """
    # Local names carry a `_type` suffix so they do not shadow Python builtins
    reg.define_type("None", UnitType())
    bool_type = reg.define_type("bool", BaseType(name="bool"))
    int_type = reg.define_type("int", BaseType(name="int"))
    float_type = reg.define_type("float", BaseType(name="float"))
    str_type = reg.define_type("str", BaseType(name="str"))

    comparators = ("__lt__", "__gt__", "__le__", "__ge__", "__eq__")

    # int <op> int = int
    for method in (
        "__add__",
        "__sub__",
        "__mul__",
        "__pow__",
        "__mod__",
        "__and__",
        "__or__",
        "__xor__",
    ):
        basic_op(reg, int_type, method)
    # int <cmp> int = bool
    for method in comparators:
        op(reg, int_type, method, int_type, bool_type)

    # float <op> float = float
    for method in ("__add__", "__sub__", "__mul__", "__truediv__"):
        basic_op(reg, float_type, method)
    # float <cmp> float = bool
    for method in comparators:
        op(reg, float_type, method, float_type, bool_type)

    basic_op(reg, str_type, "__add__")  # str + str = str
    op(reg, str_type, "__eq__", str_type, bool_type)  # str == str = bool

    # int <cmp> float = bool
    for method in comparators:
        op(reg, int_type, method, float_type, bool_type)
    # float <cmp> int = bool
    for method in comparators:
        op(reg, float_type, method, int_type, bool_type)

    # list[T] with a single method: append(object: T, /) -> None
    reg.define_type(
        "list",
        GenericType(
            name="list",
            params=[TypeVar(name="T", bound=None)],
            body=ComplexType(
                properties={
                    "append": Function(
                        name="append",
                        pos_args=[
                            Function.Argument(
                                pos=0,
                                name="object",
                                type=TypeVar(name="T", bound=None),
                                required=True,
                            )
                        ],
                        args=[],
                        kw_args=[],
                        returns=UnitType(),
                    )
                }
            ),
        ),
    )

View File

@@ -1,540 +1,35 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
from midas.resolver.midas import MidasResolver
from midas.checker.diagnostic import Diagnostic
from midas.checker.midas import MidasTyper
from midas.checker.python import PythonTyper
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import Reporter
class ReturnException(Exception):
pass
class TypeChecker:
def __init__(self):
self.types: TypesRegistry = TypesRegistry()
self.reporter: Reporter = Reporter()
self.midas_typer = MidasTyper(self.types, self.reporter)
self.python_typer = PythonTyper(self.types, self.reporter)
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
expr: p.Expr
type: Type
argument: Function.Argument
def import_midas(self, path: Path):
source: str = path.read_text()
return self.import_midas_source(source, path=str(path))
def import_midas_source(self, source: str, path: Optional[str] = None):
self.midas_typer.process(source, path)
class Checker(
p.Stmt.Visitor[None],
p.Expr.Visitor[Type],
p.MidasType.Visitor[Type],
):
"""A type checker which can use custom type definitions"""
def type_check(self, path: Path):
source: str = path.read_text()
return self.type_check_source(source, path=str(path))
def __init__(
self,
locals: dict[p.Expr, int],
source_path: Path,
types_paths: list[Path],
):
self.logger: logging.Logger = logging.getLogger("Checker")
self.source_path: Path = source_path
self.types_paths: list[Path] = types_paths
self.ctx: MidasResolver = MidasResolver()
self.global_env: Environment = Environment()
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = locals
self.diagnostics: list[Diagnostic] = []
def type_check_source(self, source: str, path: Optional[str] = None):
self.python_typer.process(source, path)
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
self.diagnostics.append(
Diagnostic(
file_path=self.source_path,
location=location,
type=type,
message=message,
)
)
def error(self, location: Location, message: str):
self.diagnostic(
type=DiagnosticType.ERROR,
location=location,
message=message,
)
def warning(self, location: Location, message: str):
self.diagnostic(
type=DiagnosticType.WARNING,
location=location,
message=message,
)
def info(self, location: Location, message: str):
self.diagnostic(
type=DiagnosticType.INFO,
location=location,
message=message,
)
def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression
Args:
expr (p.Expr): the expression to evaluate
Returns:
Type: the type of the given expression
"""
return expr.accept(self)
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
"""Evaluate a sequence of statements
Args:
block (list[p.Stmt]): the statements to evaluate
env (Environment): the environment in which to evaluate
Returns:
bool: whether a return statement is present in the block
"""
previous_env: Environment = self.env
self.env = env
returned: bool = False
for i, stmt in enumerate(block):
try:
stmt.accept(self)
except ReturnException:
returned = True
if i < len(block) - 1:
self.warning(block[i + 1].location, "Unreachable statement")
break
self.env = previous_env
return returned
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
"""Type check a sequence of statements and returns diagnostics
Args:
statements (list[p.Stmt]): the statements to evaluate and check
Returns:
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
"""
self.diagnostics = []
for path in self.types_paths:
self.import_midas(path)
self.logger.debug(f"Midas types: {self.ctx._types}")
self.logger.debug(f"Midas operations: {self.ctx._operations}")
for stmt in statements:
stmt.accept(self)
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
return self.diagnostics
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
"""Look up a variable in the environment it was declared
Args:
name (str): the name of the variable
expr (p.Expr): the variable expression, used to lookup the scope distance
Returns:
Optional[Type]: the type of the variable, or None if it was not found
"""
distance: Optional[int] = self.locals.get(expr)
if distance is not None:
return self.env.get_at(distance, name)
return self.global_env.get(name)
def import_midas(self, path: Path) -> None:
"""Import Midas definitions from a path
Args:
path (Path): the import path
"""
self.logger.debug(f"Importing type definitions from {path}")
lexer: MidasLexer = MidasLexer(path.read_text())
tokens: list[Token] = lexer.process()
parser: MidasParser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
self.ctx.resolve(stmts)
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.type_of(stmt.expr)
def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env)
pos_args: list[Function.Argument] = []
args: list[Function.Argument] = []
kw_args: list[Function.Argument] = []
def eval_arg_type(arg: p.Function.Argument) -> Type:
if arg.type is not None:
return arg.type.accept(self)
if arg.default is not None:
return arg.default.accept(self)
return UnknownType()
for arg in stmt.posonlyargs:
pos_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
for arg in stmt.args:
args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
for arg in stmt.kwonlyargs:
kw_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type)
returns_hint: Optional[Type] = None
if stmt.returns is not None:
returns_hint = stmt.returns.accept(self)
# Early define to handle simple fully-typed recursion
inside_function: Function = Function(
name=stmt.name,
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns_hint,
)
self.env.define(stmt.name, inside_function)
returned: bool = self.process_block(stmt.body, env)
inferred_return: Type = UnknownType()
if not returned:
env.return_types.append(UnitType())
return_types: set[Type] = set(env.return_types)
if len(return_types) == 1:
inferred_return = list(return_types)[0]
elif len(return_types) > 1:
self.error(
stmt.location,
f"Mixed return types: {env.return_types}",
)
returns: Type = UnknownType()
if returns_hint is not None:
assert stmt.returns is not None
returns = returns_hint
if returns != inferred_return:
self.error(
stmt.returns.location,
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
)
else:
returns = inferred_return
# TODO: handle *args and **kwargs sinks
function: Function = Function(
name=stmt.name,
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
# TODO check not yet defined locally
type: Type = stmt.type.accept(self)
self.env.define(stmt.name, type)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
value: Type = self.type_of(stmt.value)
for target in stmt.targets:
if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}")
self.warning(target.location, f"Unsupported assignment to {target}")
continue
name: str = target.name
var_type: Optional[Type] = self.look_up_variable(name, target)
if var_type is None:
self.env.define(name, value)
else:
# TODO: implement real comparison method
if var_type != value:
self.error(
stmt.location,
f"Cannot assign {value} to {name} of type {var_type}",
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
self.env.return_types.append(type)
raise ReturnException()
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not evaluated in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
test_type: Type = stmt.test.accept(self)
# TODO Allow subtypes or any type
if test_type != self.ctx.get_type("bool"):
self.error(
stmt.test.location, f"If test must be a boolean, got {test_type}"
)
env: Environment = Environment(self.env)
body_returned: bool = self.process_block(stmt.body, env)
else_returned: bool = self.process_block(stmt.orelse, env)
self.env.return_types.extend(env.return_types)
if body_returned and else_returned:
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType()
left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
self.error(
expr.location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return result
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType()
left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
self.error(
expr.location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return result
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
def visit_call_expr(self, expr: p.CallExpr) -> Type:
callee: Type = self.type_of(expr.callee)
if not isinstance(callee, Function):
self.error(expr.callee.location, "Callee is not a function")
return UnknownType()
function: Function = callee
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
for arg in mapped:
if arg.type != arg.argument.type:
self.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
return function.returns
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
match expr.value:
case bool(): # Must be before int
return self.ctx.get_type("bool")
case int():
return self.ctx.get_type("int")
case float():
return self.ctx.get_type("float")
case str():
return self.ctx.get_type("str")
case _:
self.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
return self.look_up_variable(expr.name, expr) or UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
left: Type = expr.left.accept(self)
right: Type = expr.right.accept(self)
# TODO: union type
if left != right:
self.error(
expr.location,
f"Operands must be of the same type, left={left} != right={right}",
)
return left
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return expr.type.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = expr.test.accept(self)
# TODO Allow subtypes or any type
if test_type != self.ctx.get_type("bool"):
self.error(
expr.test.location, f"If test must be a boolean, got {test_type}"
)
true_type: Type = expr.if_true.accept(self)
false_type: Type = expr.if_false.accept(self)
if true_type != false_type:
self.error(
expr.location,
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
)
return UnknownType()
return true_type
def visit_base_type(self, node: p.BaseType) -> Type:
return self.ctx.get_type(node.base)
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
def visit_frame_type(self, node: p.FrameType) -> Type: ...
def map_call_arguments(
    self, function: Function, call: p.CallExpr
) -> list[MappedArgument]:
    """Map call arguments to function parameters as defined in its signature

    This method maps positional-only, keyword-only and mixed parameter definitions
    with the arguments passed at the call site.
    Any mismatched, missing or unexpected argument is reported as a diagnostic.

    The types of all arguments are evaluated up front (positional ones first,
    then keywords). Arguments that cannot be matched to a parameter are
    reported and left out of the result.

    Args:
        function (Function): the function definition
        call (p.CallExpr): the call expression

    Returns:
        list[MappedArgument]: the list of mapped arguments
    """
    # Evaluate every argument once, keeping the expression next to its type
    positional: list[tuple[p.Expr, Type]] = [
        (arg, self.type_of(arg)) for arg in call.arguments
    ]
    keywords: dict[str, tuple[p.Expr, Type]] = {
        name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
    }
    # Names of the parameters that already received a value
    set_args: set[str] = set()

    # Names still to be provided; entries are removed as arguments are bound
    required_positional: list[str] = [
        arg.name for arg in function.pos_args + function.args if arg.required
    ]
    required_keyword: list[str] = [
        arg.name for arg in function.kw_args if arg.required
    ]

    mapped: list[MappedArgument] = []

    # Working copies, consumed while binding so the function is not mutated
    pos_params: list[Function.Argument] = list(function.pos_args)
    mixed_params: list[Function.Argument] = list(function.args)
    kw_params: dict[str, Function.Argument] = {
        arg.name: arg for arg in function.kw_args
    }

    # TODO: handle *args and **kwargs sinks
    # Positional arguments fill positional-only parameters first, then mixed ones
    for arg in positional:
        param: Function.Argument
        if len(pos_params) != 0:
            param = pos_params.pop(0)
        elif len(mixed_params) != 0:
            param = mixed_params.pop(0)
        else:
            # Reported once; the remaining positional arguments are ignored
            self.error(arg[0].location, "Too many positional arguments")
            break
        name: str = param.name
        if name in required_positional:
            required_positional.remove(name)
        if name in required_keyword:
            required_keyword.remove(name)
        set_args.add(name)
        mapped.append(
            MappedArgument(
                expr=arg[0],
                type=arg[1],
                argument=param,
            )
        )

    # Mixed parameters not consumed positionally can still be passed by keyword
    kw_params.update({arg.name: arg for arg in mixed_params})

    for name, arg in keywords.items():
        param: Function.Argument
        if name not in kw_params:
            # Either already bound (positionally or by an earlier keyword) or unknown
            if name in set_args:
                self.error(
                    arg[0].location, f"Multiple values for argument '{name}'"
                )
            else:
                self.error(arg[0].location, f"Unknown keyword argument '{name}'")
            continue
        # pop() so that the same parameter cannot be bound twice
        param = kw_params.pop(name)
        if name in required_positional:
            required_positional.remove(name)
        if name in required_keyword:
            required_keyword.remove(name)
        set_args.add(name)
        mapped.append(
            MappedArgument(
                expr=arg[0],
                type=arg[1],
                argument=param,
            )
        )

    def join_args(args: list[str]) -> str:
        # Quote each name and join them as "'a', 'b' and 'c'"
        args = list(map(lambda a: f"'{a}'", args))
        if len(args) == 0:
            return ""
        if len(args) == 1:
            return args[0]
        return ", ".join(args[:-1]) + " and " + args[-1]

    # Whatever is left in the required lists was never provided
    if len(required_positional) != 0:
        plural: str = "" if len(required_positional) == 1 else "s"
        args: str = join_args(required_positional)
        self.error(
            call.location,
            f"Missing required positional argument{plural}: {args}",
        )
    if len(required_keyword) != 0:
        plural: str = "" if len(required_keyword) == 1 else "s"
        args: str = join_args(required_keyword)
        self.error(
            call.location,
            f"Missing required keyword argument{plural}: {args}",
        )
    return mapped
@property
def diagnostics(self) -> list[Diagnostic]:
    """The diagnostics held by this object's reporter"""
    reporter = self.reporter
    return reporter.diagnostics

View File

@@ -1,6 +1,5 @@
from dataclasses import dataclass
from enum import StrEnum
from pathlib import Path
from typing import Optional
from midas.ast.location import Location
@@ -14,12 +13,13 @@ class DiagnosticType(StrEnum):
@dataclass(frozen=True)
class Diagnostic:
file_path: Path
file_path: Optional[str]
location: Location
type: DiagnosticType
message: str
def __str__(self) -> str:
@property
def location_str(self) -> str:
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
end_loc: Optional[str] = ""
if (
@@ -27,7 +27,16 @@ class Diagnostic:
and self.location.end_col_offset is not None
):
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
loc: str = (
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
)
return f"{self.type} in {self.file_path} {loc}: {self.message}"
loc: str = ""
if self.file_path is not None:
loc += f" in {self.file_path}"
if end_loc is None:
loc += f" at {start_loc}"
else:
loc += f" from {start_loc} to {end_loc}"
return f"{self.type}{loc}"
def __str__(self) -> str:
return f"{self.location_str}: {self.message}"

171
midas/checker/midas.py Normal file
View File

@@ -0,0 +1,171 @@
import logging
from typing import Optional
import midas.ast.midas as m
from midas.checker.builtins import define_builtins
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
ComplexType,
Function,
GenericType,
Type,
TypeVar,
UnknownType,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
    """Create a typer working on a shared registry

    The builtin definitions are registered in `types` as part of the
    construction.

    Args:
        types (TypesRegistry): the registry in which types are looked up and defined
        reporter (Reporter): the reporter from which a file reporter is derived
    """
    self.types: TypesRegistry = types
    # Type variables of the statement being resolved, by name
    self._local_variables: dict[str, TypeVar] = dict()
    self.logger: logging.Logger = logging.getLogger("MidasTyper")
    self.reporter: FileReporter = reporter.for_file(None)
    define_builtins(types)
def process(self, source: str, path: Optional[str]):
    """Lex, parse and resolve a Midas source text

    Args:
        source (str): the Midas source code
        path (Optional[str]): the path handed to the reporter for this source
    """
    self.reporter = self.reporter.for_file(path)
    tokens: list[Token] = MidasLexer(source).process()
    stmts: list[m.Stmt] = MidasParser(tokens).parse()
    self.resolve(stmts)
def get_type(self, name: str) -> Type:
    """Get a type from its name

    Type variables of the statement being resolved shadow the registry.

    Args:
        name (str): the name of the type

    Raises:
        NameError: if the type is not defined

    Returns:
        Type: the type
    """
    local_scope = self._local_variables
    if name not in local_scope:
        return self.types.get_type(name)
    return local_scope[name]
def resolve(self, stmts: list[m.Stmt]):
    """Visit each statement of a sequence, in order

    Args:
        stmts (list[m.Stmt]): the statements
    """
    for statement in stmts:
        statement.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
    """Register the type introduced by a `type` statement

    A statement with type parameters defines a `GenericType`, otherwise an
    `AliasType`. The type parameters are dropped from the local scope once
    the definition is registered.
    """
    type_vars: list[TypeVar] = self._resolve_type_params(stmt.params)
    type_name: str = stmt.name.lexeme
    body: Type = stmt.type.accept(self)

    definition: Type = (
        GenericType(name=type_name, params=type_vars, body=body)
        if type_vars
        else AliasType(name=type_name, type=body)
    )
    self.types.define_type(type_name, definition)
    self._local_variables.clear()
# Not implemented yet: property statements are currently ignored.
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
    """Register the operations declared by an `extend` statement

    Each operation is defined in the registry with the extended type as left
    operand. The statement's type parameters are only visible while the
    statement is resolved.

    Args:
        stmt (m.ExtendStmt): the extend statement
    """
    self._resolve_type_params(stmt.params)
    try:
        base: Type = stmt.type.accept(self)
        for op in stmt.operations:
            right: Type = op.operand.accept(self)
            result: Type = op.result.accept(self)
            self.types.define_operation(
                left=base,
                operator=op.name.lexeme,
                right=right,
                result=result,
            )
    finally:
        # Same scoping as `type` statements: without this the type variables
        # would stay resolvable in the statements that follow
        self._local_variables.clear()
# Placeholders: these statement and expression visitors are not implemented
# yet; their `...` body makes them no-ops.
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
    """Visit the expression wrapped by the grouping and forward its result"""
    inner = expr.expr
    return inner.accept(self)
# Placeholders: literal and wildcard expressions are not handled yet (no-ops).
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_named_type(self, type: m.NamedType) -> Type:
    """Resolve a type referenced by name (local type variables first, then the registry)"""
    type_name: str = type.name.lexeme
    return self.get_type(type_name)
def visit_generic_type(self, type: m.GenericType) -> Type:
    """Apply a generic type to its type arguments

    The generic itself is resolved first, then each argument in order.
    """
    generic: Type = type.type.accept(self)
    type_args: list[Type] = []
    for type_arg in type.args:
        type_args.append(type_arg.accept(self))
    return self.types.apply_generic(generic, type_args)
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
    """Visit a constrained type

    The base type and the constraint are both visited, but their results are
    not used yet: the resulting type is always `UnknownType`.
    """
    for part in (type.type, type.constraint):
        part.accept(self)
    # TODO
    return UnknownType()
def visit_complex_type(self, type: m.ComplexType) -> Type:
    """Build a `ComplexType` mapping each property name to its resolved type"""
    resolved: dict[str, Type] = {}
    for prop in type.properties:
        resolved[prop.name.lexeme] = prop.type.accept(self)
    return ComplexType(properties=resolved)
def visit_function_type(self, type: m.FunctionType) -> Type:
    """Build an anonymous `Function` from a function type expression

    Positional parameters are numbered from 0 and keyword parameters continue
    the numbering after them. A parameter without a name is named after its
    position. No mixed (positional-or-keyword) parameters are produced.
    """

    def build(index: int, param) -> Function.Argument:
        label: str = str(index) if param.name is None else param.name.lexeme
        return Function.Argument(
            pos=index,
            name=label,
            type=param.type.accept(self),
            required=param.required,
        )

    positional: list[Function.Argument] = [
        build(i, arg) for i, arg in enumerate(type.pos_args)
    ]
    keyword: list[Function.Argument] = [
        build(i, arg) for i, arg in enumerate(type.kw_args, start=len(type.pos_args))
    ]
    result_type: Type = type.returns.accept(self)
    return Function(
        name="<anonymous>",
        pos_args=positional,
        args=[],
        kw_args=keyword,
        returns=result_type,
    )
def _resolve_type_params(self, params: list[m.TypeParam]):
    """Turn type parameters into type variables and bring them into scope

    Each variable is registered in the local scope as soon as it is created,
    so the bound of a later parameter can refer to an earlier one.

    Args:
        params (list[m.TypeParam]): the declared type parameters

    Returns:
        list[TypeVar]: the type variables, in declaration order
    """
    type_vars: list[TypeVar] = []
    for param in params:
        upper_bound: Optional[Type] = (
            None if param.bound is None else param.bound.accept(self)
        )
        type_var = TypeVar(name=param.name.lexeme, bound=upper_bound)
        self._local_variables[type_var.name] = type_var
        type_vars.append(type_var)
    return type_vars

646
midas/checker/python.py Normal file
View File

@@ -0,0 +1,646 @@
import ast
import logging
from dataclasses import dataclass
from typing import Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
ComplexType,
Function,
Operation,
Type,
UnitType,
UnknownType,
unfold_type,
)
from midas.parser.python import PythonParser
class ReturnException(Exception):
    """Control-flow signal used by the type checker for `return` statements

    Raised when a return statement is visited (and by an `if` whose two
    branches both returned); caught by `PythonTyper.process_block`.
    """
    pass
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
    """A call-site argument paired with the parameter it was matched to"""

    # The argument expression as written at the call site
    expr: p.Expr
    # The type evaluated for `expr`
    type: Type
    # The parameter of the callee receiving the argument
    argument: Function.Argument
class PythonTyper(
p.Stmt.Visitor[None],
p.Expr.Visitor[Type],
p.MidasType.Visitor[Type],
):
"""A type checker which can use custom type definitions"""
def __init__(
    self,
    types: TypesRegistry,
    reporter: Reporter,
):
    """Create a type checker working on a shared registry

    Args:
        types (TypesRegistry): the registry used to look up types and operations
        reporter (Reporter): the reporter from which a file reporter is derived
    """
    self.logger: logging.Logger = logging.getLogger("PythonTyper")
    self.reporter: FileReporter = reporter.for_file(None)
    self.types: TypesRegistry = types

    root_env: Environment = Environment()
    self.global_env: Environment = root_env
    # Current scope; starts as the global one
    self.env: Environment = root_env

    # Scope distance of each resolved variable expression
    self.locals: dict[p.Expr, int] = dict()
    # Every (expression, type) pair evaluated through `type_of`
    self.judgements: list[tuple[p.Expr, Type]] = list()
def process(self, source: str, path: Optional[str]):
    """Parse, resolve and type check a Python source text

    The current scope is reset to the global environment and previous
    judgements are discarded before checking.

    Args:
        source (str): the Python source code
        path (Optional[str]): the path of the source, if any
    """
    self.reporter = self.reporter.for_file(path)

    module: ast.Module = ast.parse(source, filename=path or "<unknown>")
    stmts: list[p.Stmt] = PythonParser().parse_module(module)

    resolver = Resolver()
    resolver.resolve(*stmts)

    self.env = self.global_env
    self.locals = resolver.locals
    self.judgements = []
    self.check(stmts)
def type_of(self, expr: p.Expr) -> Type:
    """Evaluate the type of an expression and record the judgement

    Args:
        expr (p.Expr): the expression to evaluate

    Returns:
        Type: the type of the given expression
    """
    inferred: Type = expr.accept(self)
    judgement = (expr, inferred)
    self.judgements.append(judgement)
    return inferred
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
    """Evaluate a sequence of statements

    The statements are visited with `env` as the current environment. The
    previous environment is restored afterwards, even if a statement raises.
    Evaluation stops at the first statement that returns; a warning is
    reported on the statement following it, if any.

    Args:
        block (list[p.Stmt]): the statements to evaluate
        env (Environment): the environment in which to evaluate

    Returns:
        bool: whether a return statement is present in the block
    """
    previous_env: Environment = self.env
    self.env = env
    try:
        for i, stmt in enumerate(block):
            try:
                stmt.accept(self)
            except ReturnException:
                if i < len(block) - 1:
                    self.reporter.warning(
                        block[i + 1].location, "Unreachable statement"
                    )
                return True
        return False
    finally:
        # Never leave the checker stuck in the inner scope
        self.env = previous_env
def check(self, statements: list[p.Stmt]) -> None:
    """Type check a sequence of statements

    Diagnostics are sent to the reporter; nothing is returned. The final
    environment is logged at debug level.

    Args:
        statements (list[p.Stmt]): the statements to evaluate and check
    """
    for statement in statements:
        statement.accept(self)
    self.logger.debug(f"Final environment: {self.env.flat_dict()}")
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
    """Look up a variable in the environment it was declared

    Expressions without a recorded scope distance are looked up in the
    global environment.

    Args:
        name (str): the name of the variable
        expr (p.Expr): the variable expression, used to lookup the scope distance

    Returns:
        Optional[Type]: the type of the variable, or None if it was not found
    """
    distance: Optional[int] = self.locals.get(expr)
    if distance is None:
        return self.global_env.get(name)
    return self.env.get_at(distance, name)
def is_subtype(self, type1: Type, type2: Type) -> bool:
    """Check whether `type1` is a subtype of `type2`, as decided by the registry"""
    registry = self.types
    return registry.is_subtype(type1, type2)
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
    """Evaluate the type of the wrapped expression; the result is discarded"""
    expression = stmt.expr
    self.type_of(expression)
def visit_function(self, stmt: p.Function) -> None:
    """Type check a function definition and bind its signature in the current scope

    The parameters are defined in a child environment in which the body is
    checked. The return type is inferred from the collected return types (a
    body that does not always return also contributes `UnitType`). When a
    return annotation is present it is used as the return type, and a
    mismatch with the inferred type is reported.

    Args:
        stmt (p.Function): the function definition
    """
    env: Environment = Environment(self.env)

    def eval_arg_type(arg: p.Function.Argument) -> Type:
        # Annotation first, then the type of the default value, else unknown
        if arg.type is not None:
            return arg.type.accept(self)
        if arg.default is not None:
            return arg.default.accept(self)
        return UnknownType()

    def build_arguments(params, first_pos: int) -> list[Function.Argument]:
        # Positions are numbered consecutively starting at `first_pos`
        return [
            Function.Argument(
                pos=first_pos + offset,
                name=param.name,
                type=eval_arg_type(param),
                required=param.default is None,
            )
            for offset, param in enumerate(params)
        ]

    pos_args: list[Function.Argument] = build_arguments(stmt.posonlyargs, 0)
    args: list[Function.Argument] = build_arguments(stmt.args, len(pos_args))
    # The position of keyword-only parameters is not relevant
    kw_args: list[Function.Argument] = build_arguments(
        stmt.kwonlyargs, len(pos_args) + len(args)
    )

    for arg in pos_args + args + kw_args:
        env.define(arg.name, arg.type)

    returns_hint: Optional[Type] = None
    if stmt.returns is not None:
        returns_hint = stmt.returns.accept(self)

    # Early define to handle simple fully-typed recursion
    # Without an annotation the return type is not known yet: use UnknownType
    # rather than None so recursive calls still evaluate to a Type
    inside_function: Function = Function(
        name=stmt.name,
        pos_args=pos_args,
        args=args,
        kw_args=kw_args,
        returns=returns_hint if returns_hint is not None else UnknownType(),
    )
    self.env.define(stmt.name, inside_function)

    returned: bool = self.process_block(stmt.body, env)
    inferred_return: Type = UnknownType()
    if not returned:
        # Falling off the end of the body returns the unit type
        env.return_types.append(UnitType())
    return_types: list[Type] = self.types.reduce_types(env.return_types)
    if len(return_types) == 1:
        inferred_return = return_types[0]
    elif len(return_types) > 1:
        self.reporter.error(
            stmt.location,
            f"Mixed return types: {return_types}",
        )

    returns: Type = UnknownType()
    if returns_hint is not None:
        assert stmt.returns is not None
        returns = returns_hint
        if returns != inferred_return:
            self.reporter.error(
                stmt.returns.location,
                f"Return type mismatch, annotated {returns} but returns {inferred_return}",
            )
    else:
        returns = inferred_return

    # TODO: handle *args and **kwargs sinks
    # Final definition, replacing the early one with the resolved return type
    function: Function = Function(
        name=stmt.name,
        pos_args=pos_args,
        args=args,
        kw_args=kw_args,
        returns=returns,
    )
    self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
    """Declare a variable with the type given by its annotation"""
    # TODO check not yet defined locally
    declared: Type = stmt.type.accept(self)
    self.env.define(stmt.name, declared)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
    """Evaluate the assigned value once and assign it to every target in order"""
    assigned: Type = self.type_of(stmt.value)
    location: Location = stmt.location
    for target in stmt.targets:
        self._assign(location, target, assigned)
def _assign(self, location: Location, target: p.Expr, value_type: Type):
match target:
case p.VariableExpr():
self._assign_var(location, target, value_type)
case p.GetExpr():
self._assign_attr(location, target, value_type)
case _:
if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}")
self.reporter.warning(
target.location, f"Unsupported assignment to {target}"
)
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
name: str = target.name
var_type: Optional[Type] = self.look_up_variable(name, target)
if var_type is None:
self.env.define(name, value_type)
else:
# S <: T
# Γ, x: T v: S
# x = v
if not self.is_subtype(value_type, var_type):
self.reporter.error(
location,
f"Cannot assign {value_type} to variable '{name}' of type {var_type}",
)
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
    """Check an assignment to a property of an object

    The object type is unfolded first. On a `ComplexType` the property must
    exist and accept the value; on an `UnknownType` nothing is checked; any
    other object type is reported as an error.

    Args:
        location (Location): the location of the assignment statement
        target (p.GetExpr): the assigned property access
        value_type (Type): the type of the assigned value
    """
    object: Type = self.type_of(target.object)
    unfolded: Type = unfold_type(object)

    if isinstance(unfolded, ComplexType):
        known_properties = unfolded.properties
        if target.name not in known_properties:
            self.reporter.error(
                target.location, f"Unknown property '{object}.{target.name}'"
            )
            return
        expected: Type = known_properties[target.name]
        if not self.is_subtype(value_type, expected):
            self.reporter.error(
                location,
                f"Cannot assign {value_type} to property '{object}.{target.name}' of type {expected}",
            )
        return

    if isinstance(unfolded, UnknownType):
        return

    self.reporter.error(
        target.location,
        f"Cannot assign {value_type} to unknown property '{object}.{target.name}'",
    )
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
    """Record the returned type in the current environment and stop the block

    A bare `return` records `UnitType`.

    Raises:
        ReturnException: always, to signal the enclosing block
    """
    returned: Type
    if stmt.value is None:
        returned = UnitType()
    else:
        returned = stmt.value.accept(self)
    self.env.return_types.append(returned)
    raise ReturnException()
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
    """Type check an `if` statement

    The test must be exactly `bool`. Both branches are evaluated in one
    shared child environment whose return types are then added to the
    current environment.

    Raises:
        ReturnException: if both branches returned
    """
    # Not evaluated in sub-environment because assignments in the test leak out of the if
    # For example:
    # if (m := 1 + 1) < 2:
    #     ...
    # print(m)  # <- m is still defined
    condition_type: Type = stmt.test.accept(self)

    # TODO Allow subtypes or any type
    if condition_type != self.types.get_type("bool"):
        self.reporter.error(
            stmt.test.location, f"If test must be a boolean, got {condition_type}"
        )

    branch_env: Environment = Environment(self.env)
    outcomes: list[bool] = [
        self.process_block(branch, branch_env)
        for branch in (stmt.body, stmt.orelse)
    ]
    self.env.return_types.extend(branch_env.return_types)
    if all(outcomes):
        raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
    """Type a binary operation by resolving the matching operation overload

    Every registered operation with the operator's method name whose operand
    types accept the actual operands is a candidate. A single candidate wins
    directly. With several, the winner is the first candidate whose operand
    types are subtypes of those of all the other candidates. Having no
    candidate or no winner is reported as an error.

    Returns:
        Type: the result type of the selected operation, or `UnknownType`
    """
    method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
    if method is None:
        self.logger.warning(f"Unsupported operator {expr.operator}")
        self.reporter.warning(
            expr.location, f"Unsupported operator {expr.operator}"
        )
        return UnknownType()

    left: Type = self.type_of(expr.left)
    right: Type = self.type_of(expr.right)

    candidates: list[Operation] = [
        op
        for op in self.types.get_operations_by_name(method)
        if self.is_subtype(left, op.signature.left)
        and self.is_subtype(right, op.signature.right)
    ]

    if not candidates:
        self.reporter.error(
            expr.location,
            f"Undefined operation {method} between {left} and {right}",
        )
        return UnknownType()
    if len(candidates) == 1:
        self.logger.debug(f"Unique operation {method} between {left} and {right}")
        return candidates[0].result

    for idx, candidate in enumerate(candidates):
        own: Operation.CallSignature = candidate.signature
        for other_idx, other in enumerate(candidates):
            if other_idx == idx:
                continue
            theirs: Operation.CallSignature = other.signature
            # If the candidate is not a full overload of the other one (i.e. its
            # operands are subtypes of the other's) -> ambiguity, not the best match
            if not (
                self.is_subtype(own.left, theirs.left)
                and self.is_subtype(own.right, theirs.right)
            ):
                break
            self.logger.debug(f"{candidate} is a full overload of {other}")
        else:
            # Never broke out: the candidate specializes all the others
            return candidate.result

    self.reporter.error(
        expr.location,
        f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(map(str, candidates))}",
    )
    return UnknownType()
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
    """Type a comparison by looking up the operation for the exact operand types

    Returns:
        Type: the registered result type, or `UnknownType` if the operator is
        unsupported or no operation is registered for these operands
    """
    method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
    if method is None:
        self.logger.warning(f"Unsupported operator {expr.operator}")
        self.reporter.warning(
            expr.location, f"Unsupported operator {expr.operator}"
        )
        return UnknownType()

    left: Type = self.type_of(expr.left)
    right: Type = self.type_of(expr.right)
    outcome: Optional[Type] = self.types.get_operation_result(left, method, right)
    if outcome is not None:
        return outcome

    self.reporter.error(
        expr.location,
        f"Undefined operation {method} between {left} and {right}",
    )
    return UnknownType()
# Not implemented yet: the `...` body means unary expressions currently evaluate to None.
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
def visit_call_expr(self, expr: p.CallExpr) -> Type:
    """Type a call expression

    The callee must evaluate to a `Function`. Arguments are mapped to its
    parameters and each mapped argument must be a subtype of its parameter.

    Returns:
        Type: the return type of the function, or `UnknownType` if the callee
        is not a function
    """
    callee: Type = self.type_of(expr.callee)
    if not isinstance(callee, Function):
        self.reporter.error(expr.callee.location, "Callee is not a function")
        return UnknownType()

    for binding in self.map_call_arguments(callee, expr):
        expected: Type = binding.argument.type
        if self.is_subtype(binding.type, expected):
            continue
        self.reporter.error(
            binding.expr.location,
            f"Wrong type for argument '{binding.argument.name}', expected {binding.argument.type}, got {binding.type}",
        )
    return callee.returns
def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object)
base_object: Type = unfold_type(object)
match base_object:
case ComplexType(properties=properties):
if expr.name not in properties:
self.reporter.error(
expr.location, f"Unknown property '{expr.name} on {object}"
)
return UnknownType()
return properties[expr.name]
case UnknownType():
return UnknownType()
case _:
self.reporter.error(
expr.location, f"Cannot get property '{expr.name}' on {object}"
)
return UnknownType()
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
    """Type a literal from the Python type of its value

    Supports `bool`, `int`, `float` and `str`; any other value is reported
    with a warning and typed as `UnknownType`.
    """
    value = expr.value
    # bool must be tested before int, since True/False are also ints
    for python_type, type_name in (
        (bool, "bool"),
        (int, "int"),
        (float, "float"),
        (str, "str"),
    ):
        if isinstance(value, python_type):
            return self.types.get_type(type_name)
    self.reporter.warning(expr.location, f"Unknown literal {expr}")
    return UnknownType()
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
    """Return the type recorded for a variable reference

    Falls back to `UnknownType` when the lookup yields nothing (or a falsy value).
    """
    resolved: Optional[Type] = self.look_up_variable(expr.name, expr)
    if resolved:
        return resolved
    return UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
    """Type a logical expression as the wider of its two operand types

    If one operand type is a subtype of the other, the supertype is the
    result. Unrelated operand types are reported and yield `UnknownType`.
    """
    left: Type = expr.left.accept(self)
    right: Type = expr.right.accept(self)
    for narrower, wider in ((left, right), (right, left)):
        if self.is_subtype(narrower, wider):
            return wider
    self.reporter.error(
        expr.location,
        f"Incompatible operand types, {left=} and {right=}",
    )
    return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
    """Return the type named by the cast; nothing else on the expression is inspected"""
    target_type = expr.type
    return target_type.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
    """Type a conditional expression as the wider of its two branch types

    The test must be exactly `bool`. If one branch type is a subtype of the
    other, the supertype is the result. Unrelated branch types are reported
    and yield `UnknownType`.
    """
    condition_type: Type = expr.test.accept(self)
    # TODO Allow subtypes or any type
    if condition_type != self.types.get_type("bool"):
        self.reporter.error(
            expr.test.location, f"If test must be a boolean, got {condition_type}"
        )

    then_type: Type = expr.if_true.accept(self)
    else_type: Type = expr.if_false.accept(self)
    for narrower, wider in ((then_type, else_type), (else_type, then_type)):
        if self.is_subtype(narrower, wider):
            return wider

    self.reporter.error(
        expr.location,
        f"Incompatible types in ternary if branches: true={then_type} and false={else_type}",
    )
    return UnknownType()
def visit_list_expr(self, expr: p.ListExpr) -> Type:
    """Type a list literal

    The item types are reduced to their highest types. An empty list is
    typed as the bare `list` type, a single remaining item type gives
    `list` applied to it, and several remaining types are reported and give
    `list` applied to `UnknownType`.
    """
    generic_list: Type = self.types.get_type("list")
    item_types: list[Type] = self.types.reduce_types(
        [self.type_of(item) for item in expr.items]
    )
    if not item_types:
        return generic_list

    element: Type
    if len(item_types) == 1:
        element = item_types[0]
    else:
        self.reporter.error(
            expr.location,
            f"Heterogeneous list items: {item_types}",
        )
        element = UnknownType()
    return self.types.apply_generic(generic_list, [element])
def visit_base_type(self, node: p.BaseType) -> Type:
    """Resolve a base type annotation, applying its type parameter if it has one"""
    base: Type = self.types.get_type(node.base)
    if node.param is None:
        return base
    type_arg: Type = node.param.accept(self)
    return self.types.apply_generic(base, [type_arg])
# Placeholders: these type visitors are not implemented yet. Their body is
# `...`, so each of them currently returns None.
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
def visit_frame_type(self, node: p.FrameType) -> Type: ...
def map_call_arguments(
    self, function: Function, call: p.CallExpr
) -> list[MappedArgument]:
    """Map call arguments to function parameters as defined in its signature

    Positional arguments fill the positional-only parameters, then the mixed
    ones. Keyword arguments are matched by name against the keyword-only
    parameters and the mixed parameters not already filled positionally.
    Any surplus, duplicate, unknown or missing argument is reported as a
    diagnostic and left out of the result.

    Args:
        function (Function): the function definition
        call (p.CallExpr): the call expression

    Returns:
        list[MappedArgument]: the list of mapped arguments
    """
    # Evaluate all argument types up front: positional first, then keywords
    positional: list[tuple[p.Expr, Type]] = [
        (arg, self.type_of(arg)) for arg in call.arguments
    ]
    keywords: dict[str, tuple[p.Expr, Type]] = {
        name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
    }

    # Required parameters not provided yet; shrinks as arguments are bound
    missing_positional: list[str] = [
        param.name for param in function.pos_args + function.args if param.required
    ]
    missing_keyword: list[str] = [
        param.name for param in function.kw_args if param.required
    ]
    bound_names: set[str] = set()
    mapped: list[MappedArgument] = []

    def bind(value: tuple, param) -> None:
        # Record that `param` receives `value` and is no longer missing
        name: str = param.name
        if name in missing_positional:
            missing_positional.remove(name)
        if name in missing_keyword:
            missing_keyword.remove(name)
        bound_names.add(name)
        mapped.append(
            MappedArgument(
                expr=value[0],
                type=value[1],
                argument=param,
            )
        )

    # TODO: handle *args and **kwargs sinks
    by_position: list = [*function.pos_args, *function.args]
    consumed: int = 0
    for value in positional:
        if consumed >= len(by_position):
            self.reporter.error(value[0].location, "Too many positional arguments")
            break
        bind(value, by_position[consumed])
        consumed += 1

    # Mixed parameters not filled positionally can still be passed by keyword
    first_free_mixed: int = max(consumed - len(function.pos_args), 0)
    by_name: dict = {param.name: param for param in function.kw_args}
    by_name.update(
        {param.name: param for param in list(function.args)[first_free_mixed:]}
    )

    for name, value in keywords.items():
        # pop() so that the same parameter cannot be bound twice
        param = by_name.pop(name, None)
        if param is not None:
            bind(value, param)
        elif name in bound_names:
            self.reporter.error(
                value[0].location, f"Multiple values for argument '{name}'"
            )
        else:
            self.reporter.error(
                value[0].location, f"Unknown keyword argument '{name}'"
            )

    def join_args(names: list[str]) -> str:
        # "'a'", "'a' and 'b'", "'a', 'b' and 'c'"
        quoted: list[str] = [f"'{name}'" for name in names]
        if len(quoted) <= 1:
            return "".join(quoted)
        return ", ".join(quoted[:-1]) + " and " + quoted[-1]

    if missing_positional:
        plural: str = "" if len(missing_positional) == 1 else "s"
        args: str = join_args(missing_positional)
        self.reporter.error(
            call.location,
            f"Missing required positional argument{plural}: {args}",
        )
    if missing_keyword:
        plural = "" if len(missing_keyword) == 1 else "s"
        args = join_args(missing_keyword)
        self.reporter.error(
            call.location,
            f"Missing required keyword argument{plural}: {args}",
        )
    return mapped

313
midas/checker/registry.py Normal file
View File

@@ -0,0 +1,313 @@
from typing import Optional
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
Function,
GenericType,
Operation,
Type,
substitute_typevars,
)
class TypesRegistry:
def __init__(self) -> None:
self._types: dict[str, Type] = {}
self._operations: dict[Operation.CallSignature, Type] = {}
def get_type(self, name: str) -> Type:
    """Get a type from its name

    Args:
        name (str): the name of the type

    Raises:
        NameError: if the type is not defined

    Returns:
        Type: the type
    """
    if name not in self._types:
        raise NameError(f"Undefined type {name}")
    return self._types[name]
def get_operation_result(
    self, left: Type, operator: str, right: Type
) -> Optional[Type]:
    """Get the resulting type of an operation

    The lookup uses the exact operand types: no subtyping is applied.

    Args:
        left (Type): the type of the left operand
        operator (str): the operation name
        right (Type): the type of the right operand

    Returns:
        Optional[Type]: the result type, or None if no matching operation was found
    """
    key = Operation.CallSignature(left=left, method=operator, right=right)
    return self._operations.get(key)
def get_operations_by_name(self, name: str) -> list[Operation]:
    """Get all the operations registered under a method name

    Args:
        name (str): the operation name

    Returns:
        list[Operation]: the matching operations, in definition order
    """
    return [
        Operation(signature=signature, result=result)
        for signature, result in self._operations.items()
        if signature.method == name
    ]
def define_type(self, name: str, type: Type) -> Type:
    """Define a type in the registry

    Args:
        name (str): the name of the type
        type (Type): the type to define

    Raises:
        ValueError: if a type is already defined with that name

    Returns:
        Type: the defined type
    """
    known_types = self._types
    if name in known_types:
        raise ValueError(f"Type {name} already defined")
    known_types[name] = type
    return type
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
    """Define an operation in the registry

    Args:
        left (Type): the type of the left operand
        operator (str): the operation name
        right (Type): the type of the right operand
        result (Type): the result type

    Raises:
        ValueError: if an operation is already defined with these operands and name
    """
    key = Operation.CallSignature(left=left, method=operator, right=right)
    if key in self._operations:
        raise ValueError(
            f"Operation {operator} already defined between {left} and {right}"
        )
    self._operations[key] = result
def is_subtype(self, type1: Type, type2: Type) -> bool:
    """Check whether `type1` is a subtype of `type2`

    For more details on the rules checked here, see TAPL Chap. 15-16-17

    Rules, tried in order:
    - equal types are subtypes of each other
    - an alias is a subtype of whatever its aliased type is a subtype of
    - base types follow the builtin subtype table
    - a complex type must provide every property of the supertype, each
      with a subtype of the expected property type
    - functions are compared with `is_func_subtype`

    Args:
        type1 (Type): the potential subtype
        type2 (Type): the potential supertype

    Returns:
        bool: whether `type1` is a subtype of `type2`
    """
    if type1 == type2:
        return True

    if isinstance(type1, AliasType):
        return self.is_subtype(type1.type, type2)

    if isinstance(type1, BaseType) and isinstance(type2, BaseType):
        return type1.name in BUILTIN_SUBTYPES.get(type2.name, set())

    if isinstance(type1, ComplexType) and isinstance(type2, ComplexType):
        provided = type1.properties
        return all(
            key in provided and self.is_subtype(provided[key], expected)
            for key, expected in type2.properties.items()
        )

    if isinstance(type1, Function) and isinstance(type2, Function):
        return self.is_func_subtype(type1, type2)

    return False
# TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
    """Check whether a function is a subtype of another

    The return type of `func1` must be a subtype of that of `func2`. Then
    each parameter of `func1` is paired with the parameter of `func2` found
    at the same position and/or under the same name, and the `func2`
    parameter must be a subtype of the `func1` one. Finally, parameters of
    `func2` are checked to have a counterpart in `func1`.

    Args:
        func1 (Function): the potential function subtype
        func2 (Function): the potential function supertype

    Returns:
        bool: whether `func1` is a subtype of `func2`
    """
    # Return types are compared in the same direction as the functions
    if not self.is_subtype(func1.returns, func2.returns):
        return False

    # Parameters of each function, split by kind
    pos1: list[Function.Argument] = func1.pos_args
    mixed1: list[Function.Argument] = func1.args
    kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}

    pos2: list[Function.Argument] = func2.pos_args
    mixed2: list[Function.Argument] = func2.args
    kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}

    # Mixed parameters of func2 can be matched by position or by name
    mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
    mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}

    def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
        # `sub` must have a subtype of `sup`'s type, and must not be required
        # when `sup` is optional
        if not self.is_subtype(sub.type, sup.type):
            return False
        if not sup.required and sub.required:
            return False
        return True

    # Positional-only parameters of func1: matched by position in func2
    for arg1 in pos1:
        arg2: Function.Argument
        if arg1.pos < len(pos2):
            arg2 = pos2[arg1.pos]
        elif arg1.pos in mixed_by_pos:
            arg2 = mixed_by_pos[arg1.pos]
        elif not arg1.required:
            # Optional parameter without counterpart: nothing to compare
            continue
        else:
            return False

        # Note the order: the func2 parameter is the one tested as subtype
        if not is_arg_subtype(arg2, arg1):
            return False

    # Keyword-only parameters of func1: matched by name in func2
    for name, arg1 in kw1.items():
        arg2: Function.Argument
        if name in kw2:
            arg2 = kw2[name]
        elif name in mixed_by_name:
            arg2 = mixed_by_name[name]
        elif not arg1.required:
            continue
        else:
            return False

        if not is_arg_subtype(arg2, arg1):
            return False

    # Mixed parameters of func1: matched both by name and by position, and
    # compared against every counterpart found
    for arg1 in mixed1:
        pos_arg2: Optional[Function.Argument] = None
        kw_arg2: Optional[Function.Argument] = None
        if arg1.name in kw2:
            kw_arg2 = kw2[arg1.name]
        elif arg1.name in mixed_by_name:
            kw_arg2 = mixed_by_name[arg1.name]

        if arg1.pos < len(pos2):
            pos_arg2 = pos2[arg1.pos]
        elif arg1.pos in mixed_by_pos:
            pos_arg2 = mixed_by_pos[arg1.pos]

        # No match in func2 and arg is required
        if pos_arg2 is None and kw_arg2 is None and arg1.required:
            return False

        # Matching keyword argument
        if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
            return False

        # Matching positional argument
        if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
            return False

    mixed_positions: set[int] = {a.pos for a in mixed1}
    mixed_names: set[str] = {a.name for a in mixed1}

    # Required positional-only parameters of func2 need a parameter at the
    # same position in func1
    for arg2 in pos2:
        if not arg2.required:
            continue
        if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
            return False

    # Required keyword-only parameters of func2 need a parameter with the
    # same name in func1
    for name, arg2 in kw2.items():
        if not arg2.required:
            continue
        if name not in kw1 and name not in mixed_names:
            return False

    # Mixed parameters of func2 need both a position match and a name match
    # NOTE(review): unlike the two loops above, this one skips the *required*
    # parameters and only checks the optional ones - confirm this is intended
    for arg2 in mixed2:
        if arg2.required:
            continue
        pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
        kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
        if not pos_match or not kw_match:
            return False
    return True
def apply_generic(self, type: Type, args: list[Type]) -> Type:
    """Instantiate a generic type with the given type arguments

    Aliases are looked through: the aliased type is applied and the result
    is wrapped back in an alias carrying the same name.

    Args:
        type (Type): the generic type, or an alias of one
        args (list[Type]): the type arguments, matched positionally with the
            type variables of the generic

    Raises:
        ValueError: if `type` is not generic, if the number of arguments
            differs from the number of type variables, or if an argument is
            not a subtype of the bound of its type variable

    Returns:
        Type: the applied type, whose body has its type variables substituted
    """
    if isinstance(type, AliasType):
        applied_base: Type = self.apply_generic(type.type, args)
        return AliasType(name=type.name, type=applied_base)

    if not isinstance(type, GenericType):
        raise ValueError(f"{type} is not a generic type")

    type_vars = type.params
    n_args: int = len(args)
    n_type_vars: int = len(type_vars)
    if n_args < n_type_vars:
        raise ValueError(
            f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
        )
    if n_args > n_type_vars:
        raise ValueError(
            f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
        )

    # Validate every bound first, in declaration order
    for arg, type_var in zip(args, type_vars):
        bound = type_var.bound
        if bound is not None and not self.is_subtype(arg, bound):
            raise ValueError(
                f"Type argument {arg} is not a subtype of {type_var.bound}"
            )

    substitutions: dict[str, Type] = {
        type_var.name: arg for arg, type_var in zip(args, type_vars)
    }
    return AppliedType(
        name=type.name,
        args=args,
        body=substitute_typevars(type.body, substitutions),
    )
def reduce_types(self, types: list[Type]) -> list[Type]:
    """Reduce a list of types to remove subtypes and only keep the highest types

    Two types are compared with `is_subtype`; whenever one of them is a
    subtype of the other, the subtype is dropped. When both are subtypes of
    each other, the one appearing first in the list is dropped. The relative
    order of the remaining types is preserved.

    Args:
        types (list[Type]): the types to reduce

    Returns:
        list[Type]: the reduced list of types
    """
    # Indices (into `types`) of the candidates still kept
    keep: list[int] = list(range(len(types)))

    reduced: bool = True
    while reduced:
        reduced = False
        for i, i1 in enumerate(keep):
            type1: Type = types[i1]
            for i2 in keep[i + 1 :]:
                type2: Type = types[i2]
                if self.is_subtype(type1, type2):
                    keep.remove(i1)
                elif self.is_subtype(type2, type1):
                    keep.remove(i2)
                else:
                    continue
                reduced = True
                break
            if reduced:
                # `keep` was just modified: restart the scan instead of
                # continuing to iterate over a list that changed under us
                break
    return [types[i] for i in keep]

63
midas/checker/reporter.py Normal file
View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import Optional
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
class Reporter:
    """Accumulates the diagnostics reported for any number of files"""

    def __init__(self) -> None:
        # Every reported diagnostic, in reporting order
        self.diagnostics: list[Diagnostic] = []

    def for_file(self, path: Optional[str]) -> FileReporter:
        """Create a reporter bound to `path` which forwards to this one

        Args:
            path (Optional[str]): the file path attached to its diagnostics

        Returns:
            FileReporter: the file-bound reporter
        """
        return FileReporter(self, path)

    def report(
        self,
        path: Optional[str],
        type: DiagnosticType,
        location: Location,
        message: str,
    ) -> None:
        """Record a new diagnostic

        Args:
            path (Optional[str]): the path of the file concerned, if any
            type (DiagnosticType): the kind of diagnostic
            location (Location): the location the diagnostic points to
            message (str): the diagnostic message
        """
        diagnostic = Diagnostic(
            file_path=path,
            location=location,
            type=type,
            message=message,
        )
        self.diagnostics.append(diagnostic)
class FileReporter:
    """Reporter bound to one file path, forwarding everything to a base reporter"""

    def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
        self.base_reporter: Reporter = base_reporter
        self.path: Optional[str] = path

    def for_file(self, path: Optional[str]) -> FileReporter:
        """Create a sibling reporter for another path, sharing the same base reporter"""
        return FileReporter(self.base_reporter, path)

    def report(self, type: DiagnosticType, location: Location, message: str) -> None:
        """Forward a diagnostic to the base reporter, tagged with this reporter's path"""
        self.base_reporter.report(self.path, type, location, message)

    def error(self, location: Location, message: str) -> None:
        """Report a diagnostic of type ERROR"""
        self.report(type=DiagnosticType.ERROR, location=location, message=message)

    def warning(self, location: Location, message: str) -> None:
        """Report a diagnostic of type WARNING"""
        self.report(type=DiagnosticType.WARNING, location=location, message=message)

    def info(self, location: Location, message: str) -> None:
        """Report a diagnostic of type INFO"""
        self.report(type=DiagnosticType.INFO, location=location, message=message)

View File

@@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def __init__(self):
self.locals: dict[p.Expr, int] = {}
self.scopes: list[dict[str, bool]] = []
self.scopes: list[dict[str, bool]] = [{}]
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
"""Resolve the given statements or expressions"""
@@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.locals[expr] = i
return
def is_defined(self, name: str) -> bool:
    """Tell whether `name` appears in any scope of the scope stack"""
    return any(name in scope for scope in self.scopes)
def resolve_function(self, function: p.Function) -> None:
"""Resolve a function definition
@@ -112,8 +118,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
for target in stmt.targets:
match target:
case p.VariableExpr(name=name):
self.resolve_local(target, name)
# TODO: declare if not found
if not self.is_defined(name):
self.declare(name)
self.define(name)
target.accept(self)
case p.GetExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")
@@ -174,10 +185,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(expr.left)
self.resolve(expr.right)
def visit_set_expr(self, expr: p.SetExpr) -> None:
self.resolve(expr.value)
self.resolve(expr.object)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self.resolve(expr.expr)
@@ -185,3 +192,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(expr.test)
self.resolve(expr.if_true)
self.resolve(expr.if_false)
def visit_list_expr(self, expr: p.ListExpr) -> None:
    """Resolve each item of a list expression, in order"""
    for element in expr.items:
        self.resolve(element)

View File

@@ -1,27 +1,36 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
@dataclass(frozen=True, kw_only=True)
class BaseType:
name: str
def __str__(self) -> str:
return self.name
@dataclass(frozen=True, kw_only=True)
class AliasType:
name: str
type: Type
def __str__(self) -> str:
return self.name
@dataclass(frozen=True, kw_only=True)
class UnknownType:
pass
def __str__(self) -> str:
return "<Unknown>"
@dataclass(frozen=True, kw_only=True)
class UnitType:
pass
def __str__(self) -> str:
return "None"
@dataclass(frozen=True, kw_only=True)
@@ -32,16 +41,159 @@ class Function:
kw_args: list[Argument]
returns: Type
def __str__(self) -> str:
    """Render the function as `name(args) -> returns`

    Positional arguments come first, followed by a `/` separator when other
    arguments follow them. Keyword arguments come last, preceded by a `*`
    separator when anything was rendered before them.
    """
    parts: list[str] = [str(arg) for arg in self.pos_args]
    if parts and len(self.args) + len(self.kw_args) != 0:
        parts.append("/")

    parts.extend(str(arg) for arg in self.args)

    if len(self.kw_args) != 0:
        if parts:
            parts.append("*")
        parts.extend(str(arg) for arg in self.kw_args)

    signature: str = ", ".join(parts)
    return f"{self.name}({signature}) -> {self.returns}"
@dataclass(frozen=True, kw_only=True)
class Argument:
pos: int
name: str
type: Type
required: bool
def __str__(self) -> str:
opt: str = "" if self.required else "?"
return f"{self.name}: {self.type}{opt}"
@dataclass(frozen=True, kw_only=True)
class ComplexType:
properties: dict[str, Type]
def __str__(self) -> str:
props: list[str] = [f"{name}: {type}" for name, type in self.properties.items()]
return f"{{{', '.join(props)}}}"
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
@dataclass(frozen=True, kw_only=True)
class Operation:
signature: CallSignature
result: Type
def __str__(self) -> str:
return f"{self.signature} -> {self.result}"
@dataclass(frozen=True, kw_only=True)
class CallSignature:
left: Type
method: str
right: Type
def __str__(self) -> str:
return f"{self.method}({self.left}, {self.right})"
@dataclass(frozen=True, kw_only=True)
class TypeVar:
name: str
bound: Optional[Type]
def __str__(self) -> str:
if self.bound is not None:
return f"{self.name} <: {self.bound}"
return self.name
@dataclass(frozen=True, kw_only=True)
class GenericType:
name: str
params: list[TypeVar]
body: Type
def __str__(self) -> str:
return f"{self.name}[{', '.join(map(str, self.params))}]"
@dataclass(frozen=True, kw_only=True)
class AppliedType:
name: str
args: list[Type]
body: Type
def __str__(self) -> str:
return f"{self.name}[{', '.join(map(str, self.args))}]"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
pos=arg.pos,
name=arg.name,
type=substitute_typevars(arg.type, substitutions),
required=arg.required,
)
match type:
case BaseType(name=name) if name in substitutions:
return substitutions[name]
case AliasType(name=name, type=type2):
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
case Function(
name=name,
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
):
return Function(
name=name,
pos_args=list(map(sub_argument, pos_args)),
args=list(map(sub_argument, args)),
kw_args=list(map(sub_argument, kw_args)),
returns=substitute_typevars(returns, substitutions),
)
case ComplexType(properties=properties):
properties2: dict[str, Type] = {
name: substitute_typevars(prop, substitutions)
for name, prop in properties.items()
}
return ComplexType(properties=properties2)
case TypeVar(name=name):
if name in substitutions:
return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}")
case UnknownType() | UnitType():
return type
case _:
raise NotImplementedError(f"Unsupported type {type}")
def unfold_type(type: Type) -> Type:
    """Follow a chain of aliases down to the first type which is not an alias

    Args:
        type (Type): the type to unfold

    Returns:
        Type: `type` itself if it is not an alias, otherwise the type found
            at the end of the alias chain
    """
    current: Type = type
    while isinstance(current, AliasType):
        current = current.type
    return current
Type = (
BaseType
| AliasType
| UnknownType
| UnitType
| Function
| ComplexType
| TypeVar
| GenericType
| AppliedType
)

41
midas/cli/ansi.py Normal file
View File

@@ -0,0 +1,41 @@
class Ansi:
    """ANSI escape sequences to style terminal output

    Text attributes are ready-to-print strings. Colors are integer offsets
    meant to be passed to `FG` or `BG`.
    """

    # Prefix shared by all sequences
    CTRL = "\x1b["

    # Text attributes
    RESET = f"{CTRL}0m"
    BOLD = f"{CTRL}1m"
    DIM = f"{CTRL}2m"
    ITALIC = f"{CTRL}3m"
    UNDERLINE = f"{CTRL}4m"

    # Base colors
    BLACK = 0
    RED = 1
    GREEN = 2
    YELLOW = 3
    BLUE = 4
    MAGENTA = 5
    CYAN = 6
    WHITE = 7

    # Bright variants: base color + 60
    BRIGHT_BLACK = BLACK + 60
    BRIGHT_RED = RED + 60
    BRIGHT_GREEN = GREEN + 60
    BRIGHT_YELLOW = YELLOW + 60
    BRIGHT_BLUE = BLUE + 60
    BRIGHT_MAGENTA = MAGENTA + 60
    BRIGHT_CYAN = CYAN + 60
    BRIGHT_WHITE = WHITE + 60

    @classmethod
    def FG(cls, col: int) -> str:
        """Build the sequence selecting `col` as foreground color"""
        code: int = 30 + col
        return f"{cls.CTRL}{code}m"

    @classmethod
    def BG(cls, col: int) -> str:
        """Build the sequence selecting `col` as background color"""
        code: int = 40 + col
        return f"{cls.CTRL}{code}m"

    @classmethod
    def FG_RGB(cls, r: int, g: int, b: int) -> str:
        """Build the sequence selecting an RGB foreground color"""
        channels: str = f"{r};{g};{b}"
        return f"{cls.CTRL}38;2;{channels}m"

    @classmethod
    def BG_RGB(cls, r: int, g: int, b: int) -> str:
        """Build the sequence selecting an RGB background color"""
        channels: str = f"{r};{g};{b}"
        return f"{cls.CTRL}48;2;{channels}m"

View File

@@ -210,12 +210,14 @@ class PythonHighlighter(
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
def visit_list_expr(self, expr: p.ListExpr) -> None:
    """Visit each item of a list expression, in order"""
    for element in expr.items:
        element.accept(self)
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
@@ -286,8 +288,8 @@ class MidasHighlighter(
def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type")
type.type.accept(self)
for param in type.params:
param.accept(self)
for arg in type.args:
arg.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
@@ -299,6 +301,12 @@ class MidasHighlighter(
for prop in type.properties:
prop.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
    """Highlight a function type, then visit its argument types and its return type

    Positional arguments are visited before keyword arguments.
    """
    self.wrap(type, "function")
    arguments = type.pos_args + type.kw_args
    for argument in arguments:
        argument.type.accept(self)
    type.returns.accept(self)
class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -8,10 +8,12 @@ import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.types import Type
from midas.cli.ansi import Ansi
from midas.cli.highlighter import (
DiagnosticsHighlighter,
Highlighter,
@@ -23,7 +25,6 @@ from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
from midas.resolver.resolver import Resolver
from midas.utils import UniversalJSONDumper
@@ -32,32 +33,86 @@ def midas():
pass
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
    """Pretty-print a diagnostic, showing some context if possible

    If the diagnostic concerns a specific part of one line, the line is shown
    with the affected part highlighted. The message is clearly printed under the
    line with an underline further indicating the target expression.

    If multiple lines are concerned, or if the line is not part of `lines`,
    no context is shown, only the diagnostic type, location and message

    Args:
        lines (list[str]): source code lines
        diagnostic (Diagnostic): the diagnostic to print
        indent (int, optional): the number of spaces added before the target line to indent it from the location header. Defaults to 4.
    """
    loc: Location = diagnostic.location
    if loc.lineno != loc.end_lineno:
        print(diagnostic)
        return

    # The location may point outside of the given lines (e.g. a diagnostic
    # located in another file): there is no context to show in that case
    if not 1 <= loc.lineno <= len(lines):
        print(diagnostic)
        return

    start_offset: int = loc.col_offset
    # Without an end column, highlight a single character
    end_offset: int = loc.end_col_offset or (start_offset + 1)
    line: str = lines[loc.lineno - 1]
    before: str = line[:start_offset]
    after: str = line[end_offset:]
    color: int = {
        DiagnosticType.ERROR: Ansi.RED,
        DiagnosticType.WARNING: Ansi.YELLOW,
        DiagnosticType.INFO: Ansi.CYAN,
    }.get(diagnostic.type, Ansi.WHITE)
    subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
    # Underline placed below the subject, followed by the message
    cursor: str = (
        " " * start_offset
        + Ansi.FG(color)
        + "~" * (end_offset - start_offset)
        + "> "
        + diagnostic.message
        + Ansi.RESET
    )
    indent_str: str = " " * indent
    print(diagnostic.location_str + ":")
    print(indent_str + before + subject + after)
    print(indent_str + cursor)
    print()
@midas.command()
@click.option("-l", "--highlight", type=click.File("w"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-v", "--verbose", is_flag=True)
@click.argument("file", type=click.File("r"))
def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]):
logging.basicConfig(level=logging.DEBUG)
def compile(
highlight: Optional[TextIO],
types: tuple[TextIO],
verbose: bool,
file: TextIO,
):
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
source: str = file.read()
tree: ast.Module = ast.parse(source, filename=file.name)
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
resolver = Resolver()
resolver.resolve(*stmts)
types_paths: list[Path] = [Path(t.name).resolve() for t in types]
checker = Checker(
resolver.locals,
source_path=Path(file.name).resolve(),
types_paths=types_paths,
)
diagnostics: list[Diagnostic] = checker.check(stmts)
for diagnostic in diagnostics:
print(diagnostic)
checker = TypeChecker()
for path in types:
checker.import_midas(Path(path.name).resolve())
checker.type_check_source(source, str(Path(file.name).resolve()))
diagnostics: list[Diagnostic] = checker.diagnostics
lines: list[str] = source.split("\n")
for diagnostic in diagnostics:
print_diagnostic(lines, diagnostic)
if verbose:
print(
json.dumps(
UniversalJSONDumper.dump(
checker.global_env,
checker.python_typer.global_env,
[("Environment", "_children")],
lambda obj: isinstance(obj, get_args(Type)),
),

View File

@@ -50,12 +50,14 @@ class MidasLexer(Lexer):
# self.add_token(TokenType.PLUS)
case "-":
self.add_token(TokenType.MINUS)
# case "*":
# self.add_token(TokenType.STAR)
case "*":
self.add_token(TokenType.STAR)
case "/" if self.match("/"):
self.scan_comment()
case "/" if self.match("*"):
self.scan_comment_multiline()
case "/":
self.add_token(TokenType.SLASH)
case "\n":
self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t":

View File

@@ -27,8 +27,8 @@ class TokenType(Enum):
# Operators
# PLUS = auto()
MINUS = auto()
# STAR = auto()
# SLASH = auto()
STAR = auto()
SLASH = auto()
GREATER = auto()
GREATER_EQUAL = auto()
LESS = auto()

View File

@@ -7,6 +7,7 @@ from midas.ast.midas import (
ConstraintType,
Expr,
ExtendStmt,
FunctionType,
GenericType,
GetExpr,
GroupingExpr,
@@ -18,12 +19,13 @@ from midas.ast.midas import (
PropertyStmt,
Stmt,
Type,
TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
)
from midas.lexer.token import Token, TokenType
from midas.lexer.token import KEYWORDS, Token, TokenType
from midas.parser.base import Parser
from midas.parser.errors import ParsingError
@@ -107,10 +109,8 @@ class MidasParser(Parser):
TypeStmt: the parsed type declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
params: list[TypeStmt.Param] = []
if self.check(TokenType.LEFT_BRACKET):
params = self.type_stmt_params()
name: Token = self.consume_identifier("Expected type name")
params: list[TypeParam] = self.type_params()
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
@@ -123,24 +123,27 @@ class MidasParser(Parser):
type=type,
)
def type_stmt_params(self) -> list[TypeStmt.Param]:
"""Parse a generic template expression
def type_params(self) -> list[TypeParam]:
"""Parse a list of type parameters
A template is written `[TypeExpr]`
Type parameters are a comma-separated list of type variables wrapped in brackets.
Each type variable is either a simple variable, or a bounded variable written `S <: T`
Returns:
TemplateExpr: the parsed template expression
list[TypeParam]: the list of type parameters, if any, or an empty list
"""
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
params: list[TypeStmt.Param] = []
if not self.match(TokenType.LEFT_BRACKET):
return []
params: list[TypeParam] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
name: Token = self.consume_identifier("Expected type variable")
bound: Optional[Type] = None
if self.match(TokenType.LESS):
self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr()
params.append(
TypeStmt.Param(
TypeParam(
location=name.location_to(self.previous()),
name=name,
bound=bound,
@@ -148,7 +151,7 @@ class MidasParser(Parser):
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params
def type_expr(self) -> Type:
@@ -187,26 +190,26 @@ class MidasParser(Parser):
def generic_type(self) -> Type:
type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
params: list[Type] = self.type_params()
args: list[Type] = self.type_args()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
type=type,
params=params,
args=args,
)
return type
def type_params(self) -> list[Type]:
params: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
def type_args(self) -> list[Type]:
args: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
params.append(self.type_expr())
args.append(self.type_expr())
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
return params
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
return args
def named_type(self) -> Type:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
name: Token = self.consume_identifier("Expected type name")
return NamedType(
location=name.get_location(),
name=name,
@@ -322,9 +325,7 @@ class MidasParser(Parser):
"""
expr: Expr = self.primary()
while self.match(TokenType.DOT):
name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'"
)
name: Token = self.consume_identifier("Expected property name after '.'")
location: Location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr
@@ -348,7 +349,7 @@ class MidasParser(Parser):
if self.match(TokenType.NUMBER):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.IDENTIFIER):
if self.match_identifier():
return VariableExpr(location=token.get_location(), name=token)
if self.match(TokenType.UNDERSCORE):
@@ -361,6 +362,20 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression")
def consume_identifier(self, message: str = "Expected identifier") -> Token:
    """Consume the next token if `match_identifier` accepts it

    Args:
        message (str, optional): the error message used when the next token
            is not accepted. Defaults to "Expected identifier".

    Raises:
        the error built by `self.error` for the current token and `message`

    Returns:
        Token: the consumed token
    """
    if self.match_identifier():
        return self.previous()
    raise self.error(self.peek(), message)
def match_identifier(self) -> bool:
    """Try to match an identifier token or any keyword token

    Returns:
        bool: the result of `self.match` for these token types
    """
    accepted = (TokenType.IDENTIFIER, *KEYWORDS.values())
    return self.match(*accepted)
def check_identifier(self) -> bool:
    """Tell whether `self.check` accepts an identifier token or any keyword token

    Returns:
        bool: True as soon as one of these token types is accepted
    """
    candidates = [TokenType.IDENTIFIER, *KEYWORDS.values()]
    return any(self.check(token_type) for token_type in candidates)
def property_stmt(self) -> PropertyStmt:
"""Parse a property statement
@@ -369,7 +384,7 @@ class MidasParser(Parser):
Returns:
PropertyStmt: the parsed property statement
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
name: Token = self.consume_identifier("Expected property name")
self.consume(TokenType.COLON, "Expected ':' after property name")
type: Type = self.type_expr()
return PropertyStmt(
@@ -381,12 +396,14 @@ class MidasParser(Parser):
def extend_declaration(self) -> ExtendStmt:
"""Parse an extension definition
An extension is written `extend Type { operations }`
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
Returns:
ExtendStmt: the parsed extension statement
"""
keyword: Token = self.previous()
params: list[TypeParam] = self.type_params()
type: Type = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = []
@@ -394,7 +411,12 @@ class MidasParser(Parser):
operations.append(self.op_declaration())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
location: Location = keyword.location_to(self.previous())
return ExtendStmt(location=location, type=type, operations=operations)
return ExtendStmt(
location=location,
params=params,
type=type,
operations=operations,
)
def op_declaration(self) -> OpStmt:
"""Parse an operation definition
@@ -430,9 +452,9 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
@@ -445,3 +467,48 @@ class MidasParser(Parser):
type=type,
condition=condition,
)
def function(self) -> FunctionType:
    """Parse a function type

    A function type is written `(parameters) -> Result`. Parameters are
    comma-separated. Each one is a type, optionally preceded by `name:` and
    optionally followed by `?` to mark it as not required. The first `*` or
    `/` met in place of a parameter ends the positional parameters: the
    parameters parsed after it are collected as keyword arguments.

    Returns:
        FunctionType: the parsed function type
    """
    l_paren: Token = self.consume(
        TokenType.LEFT_PAREN, "Expected '(' before function parameters"
    )

    pos_args: list[FunctionType.Argument] = []
    kw_args: list[FunctionType.Argument] = []
    positional: bool = True
    while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
        if positional and (
            self.match(TokenType.STAR) or self.match(TokenType.SLASH)
        ):
            positional = False
        else:
            name: Optional[Token] = None
            # A parameter is named only if an identifier is directly
            # followed by a colon
            if self.check_identifier() and self.check_next(TokenType.COLON):
                name = self.advance()
                self.advance()
            type: Type = self.type_expr()
            # A trailing '?' marks the parameter as optional
            required: bool = not self.match(TokenType.QMARK)
            arg = FunctionType.Argument(
                location=None,
                name=name,
                type=type,
                required=required,
            )
            if positional:
                pos_args.append(arg)
            else:
                kw_args.append(arg)

        if not self.match(TokenType.COMMA):
            break

    self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
    self.consume(TokenType.ARROW, "Expected '->' before result type")
    result: Type = self.type_expr()
    return FunctionType(
        location=l_paren.location_to(self.previous()),
        pos_args=pos_args,
        kw_args=kw_args,
        returns=result,
    )

View File

@@ -17,6 +17,7 @@ from midas.ast.python import (
Function,
GetExpr,
IfStmt,
ListExpr,
LiteralExpr,
LogicalExpr,
MidasType,
@@ -416,6 +417,12 @@ class PythonParser:
case ast.Name(id=name):
return VariableExpr(location=location, name=name)
case ast.List(elts=items):
return ListExpr(
location=location,
items=[self.parse_expr(item) for item in items],
)
case _:
raise UnsupportedSyntaxError(node)

View File

@@ -1,72 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from midas.checker.types import BaseType, Type, UnitType
if TYPE_CHECKING:
from midas.resolver.midas import MidasResolver
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
    """Define the operation `t1 <operator> t2 = t3` in the given context

    Args:
        ctx (MidasResolver): the registry in which to define the operation
        t1 (Type): the type of the left operand
        operator (str): the operation name
        t2 (Type): the type of the right operand
        t3 (Type): the result type
    """
    ctx.define_operation(left=t1, operator=operator, right=t2, result=t3)
def basic_op(ctx: MidasResolver, type: Type, op: str):
    """Define an operation whose operands and result all share the same type

    Args:
        ctx (MidasResolver): the registry in which to define the operation
        type (Type): the type of both operands and of the result
        op (str): the operation name
    """
    ctx.define_operation(left=type, operator=op, right=type, result=type)
def define_builtins(ctx: MidasResolver):
    """Define builtin types and operations

    Registers the `None`, `bool`, `int`, `float` and `str` types, then the
    arithmetic, bitwise and comparison operations available between them.

    Args:
        ctx (MidasResolver): the registry in which to define the builtins
    """
    ctx.define_type("None", UnitType())
    bool_t = ctx.define_type("bool", BaseType(name="bool"))
    int_t = ctx.define_type("int", BaseType(name="int"))
    float_t = ctx.define_type("float", BaseType(name="float"))
    str_t = ctx.define_type("str", BaseType(name="str"))

    # Comparisons all produce a bool: <, >, <=, >=, ==
    comparisons = ("__lt__", "__gt__", "__le__", "__ge__", "__eq__")

    # int <op> int = int: +, -, *, **, %, &, |, ^
    for method in (
        "__add__",
        "__sub__",
        "__mul__",
        "__pow__",
        "__mod__",
        "__and__",
        "__or__",
        "__xor__",
    ):
        basic_op(ctx, int_t, method)
    for method in comparisons:
        op(ctx, int_t, method, int_t, bool_t)

    # float <op> float = float: +, -, *, /
    for method in ("__add__", "__sub__", "__mul__", "__truediv__"):
        basic_op(ctx, float_t, method)
    for method in comparisons:
        op(ctx, float_t, method, float_t, bool_t)

    basic_op(ctx, str_t, "__add__")  # str + str = str
    op(ctx, str_t, "__eq__", str_t, bool_t)  # str == str = bool

    # Mixed comparisons: int with float first, then float with int
    for method in comparisons:
        op(ctx, int_t, method, float_t, bool_t)
    for method in comparisons:
        op(ctx, float_t, method, int_t, bool_t)

View File

@@ -1,163 +0,0 @@
from typing import Optional
import midas.ast.midas as m
from midas.checker.types import (
AliasType,
Type,
UnknownType,
)
from midas.resolver.builtin import define_builtins
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
    """Visitor which evaluates Midas type definitions to build a registry

    The registry holds the named types and the operations defined between
    types. It is pre-filled with the builtins on creation.
    """

    def __init__(self) -> None:
        self._types: dict[str, Type] = {}
        self._operations: dict[tuple[Type, str, Type], Type] = {}
        define_builtins(self)

    def get_type(self, name: str) -> Type:
        """Look up a type by name

        Args:
            name (str): the name of the type

        Raises:
            NameError: if the type is not defined

        Returns:
            Type: the type
        """
        found: Optional[Type] = self._types.get(name)
        if found is not None:
            return found
        raise NameError(f"Undefined type {name}")

    def get_operation_result(
        self, left: Type, operator: str, right: Type
    ) -> Optional[Type]:
        """Look up the type produced by an operation

        Args:
            left (Type): the type of the left operand
            operator (str): the operation name
            right (Type): the type of the right operand

        Returns:
            Optional[Type]: the result type, or None if no matching operation was found
        """
        return self._operations.get((left, operator, right))

    def define_type(self, name: str, type: Type) -> Type:
        """Register a type under a name

        Args:
            name (str): the name of the type
            type (Type): the type to define

        Raises:
            ValueError: if a type is already defined with that name

        Returns:
            Type: the defined type
        """
        if name in self._types:
            raise ValueError(f"Type {name} already defined")
        self._types[name] = type
        return type

    def define_operation(self, left: Type, operator: str, right: Type, result: Type):
        """Register an operation between two types

        Args:
            left (Type): the type of the left operand
            operator (str): the operation name
            right (Type): the type of the right operand
            result (Type): the result type

        Raises:
            ValueError: if an operation is already defined with these operands and name
        """
        key: tuple[Type, str, Type] = (left, operator, right)
        if key in self._operations:
            raise ValueError(
                f"Operation {operator} already defined between {left} and {right}"
            )
        self._operations[key] = result

    def resolve(self, stmts: list[m.Stmt]):
        """Visit a sequence of statements, in order

        Args:
            stmts (list[m.Stmt]): the statements
        """
        for statement in stmts:
            statement.accept(self)

    # --- Statements ---

    def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
        # The definition is evaluated before the bounds of the parameters
        definition: Type = stmt.type.accept(self)
        for param in stmt.params:
            if param.bound is not None:
                param.bound.accept(self)

        alias_name: str = stmt.name.lexeme
        self.define_type(alias_name, AliasType(name=alias_name, type=definition))

    def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
        # Each operation of the body is registered with the extended type as
        # left operand
        extended: Type = stmt.type.accept(self)
        for operation in stmt.operations:
            operand: Type = operation.operand.accept(self)
            produced: Type = operation.result.accept(self)
            self.define_operation(
                left=extended,
                operator=operation.name.lexeme,
                right=operand,
                result=produced,
            )

    def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...

    def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...

    # --- Expressions ---

    def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...

    def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...

    def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...

    def visit_get_expr(self, expr: m.GetExpr) -> None: ...

    def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...

    def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
        return expr.expr.accept(self)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...

    # --- Types ---

    def visit_named_type(self, type: m.NamedType) -> Type:
        return self.get_type(type.name.lexeme)

    def visit_generic_type(self, type: m.GenericType) -> Type:
        # Sub-types are still visited, but their results are not used yet
        type.type.accept(self)
        for param in type.params:
            param.accept(self)
        # TODO
        return UnknownType()

    def visit_constraint_type(self, type: m.ConstraintType) -> Type:
        type.type.accept(self)
        type.constraint.accept(self)
        # TODO
        return UnknownType()

    def visit_complex_type(self, type: m.ComplexType) -> Type:
        for prop in type.properties:
            prop.accept(self)
        # TODO
        return UnknownType()

View File

@@ -29,7 +29,7 @@ class Tester(ABC):
def _list_tests(self) -> list[Path]: ...
def run_all_tests(self) -> bool:
paths: list[Path] = self._list_tests()
paths: list[Path] = sorted(self._list_tests())
return self.run_tests(paths)
def run_tests(self, tests: list[Path]) -> bool:
@@ -40,7 +40,7 @@ class Tester(ABC):
print(rule)
for i, test in enumerate(tests):
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
success: bool = self._run_test(test)
if success:
successes += 1
@@ -78,7 +78,7 @@ class Tester(ABC):
def _exec_case(self, path: Path) -> CaseResult: ...
def update_all_tests(self):
paths: list[Path] = self._list_tests()
paths: list[Path] = sorted(self._list_tests())
return self.update_tests(paths)
def update_tests(self, tests: list[Path]):

View File

@@ -1,3 +1,4 @@
{
"diagnostics": []
"diagnostics": [],
"judgments": []
}

View File

@@ -12,35 +12,168 @@
13
]
},
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
"message": "Cannot assign str to variable 'c' of type int"
}
],
"judgments": [
{
"location": {
"from": "L1:9",
"to": "L1:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"type": "Error",
"location": {
"start": [
9,
4
],
"end": [
9,
9
]
"from": "L2:9",
"to": "L2:10"
},
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
"expr": {
"_type": "LiteralExpr",
"value": 4
},
"type": {
"name": "int"
}
},
{
"type": "Error",
"location": {
"start": [
11,
0
],
"end": [
11,
12
]
"from": "L4:4",
"to": "L4:5"
},
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L4:8",
"to": "L4:9"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L4:4",
"to": "L4:9"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "b"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:4",
"to": "L6:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "invalid"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L8:4",
"to": "L8:8"
},
"expr": {
"_type": "LiteralExpr",
"value": true
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:4",
"to": "L9:5"
},
"expr": {
"_type": "VariableExpr",
"name": "d"
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:8",
"to": "L9:9"
},
"expr": {
"_type": "VariableExpr",
"name": "d"
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:4",
"to": "L9:9"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "d"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "d"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
}
]
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,109 @@
{
"diagnostics": []
"diagnostics": [],
"judgments": [
{
"location": {
"from": "L4:18",
"to": "L4:37"
},
"expr": {
"_type": "CastExpr",
"type": {
"_type": "BaseType",
"base": "Meter",
"param": null
},
"expr": {
"_type": "LiteralExpr",
"value": 123.45
}
},
"type": {
"name": "Meter",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L5:15",
"to": "L5:32"
},
"expr": {
"_type": "CastExpr",
"type": {
"_type": "BaseType",
"base": "Second",
"param": null
},
"expr": {
"_type": "LiteralExpr",
"value": 6.7
}
},
"type": {
"name": "Second",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:8",
"to": "L6:16"
},
"expr": {
"_type": "VariableExpr",
"name": "distance"
},
"type": {
"name": "Meter",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:19",
"to": "L6:23"
},
"expr": {
"_type": "VariableExpr",
"name": "time"
},
"type": {
"name": "Second",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:8",
"to": "L6:23"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "distance"
},
"operator": "/",
"right": {
"_type": "VariableExpr",
"name": "time"
}
},
"type": {
"name": "MeterPerSecond",
"type": {
"name": "float"
}
}
}
]
}

View File

@@ -42,5 +42,215 @@
},
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
}
],
"judgments": [
{
"location": {
"from": "L2:11",
"to": "L2:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:15",
"to": "L2:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L5:7",
"to": "L5:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L5:11",
"to": "L5:12"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:15",
"to": "L6:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:19",
"to": "L6:20"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:15",
"to": "L8:16"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:19",
"to": "L8:20"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:7",
"to": "L15:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:11",
"to": "L15:13"
},
"expr": {
"_type": "LiteralExpr",
"value": 10
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:16"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:19",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": 10
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L22:7",
"to": "L22:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L22:11",
"to": "L22:12"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L23:15",
"to": "L23:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L23:19",
"to": "L23:20"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
}
]
}

View File

@@ -0,0 +1,12 @@
v1: int = 3
v2: float = 4
def maximum(a: float, b: float):
if b > a:
return b
return a
v3 = maximum(v1, v2)
v3 = v1 + v2

View File

@@ -0,0 +1,193 @@
{
"diagnostics": [],
"judgments": [
{
"location": {
"from": "L1:10",
"to": "L1:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:12",
"to": "L2:13"
},
"expr": {
"_type": "LiteralExpr",
"value": 4
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:7",
"to": "L6:8"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",
"to": "L11:12"
},
"expr": {
"_type": "VariableExpr",
"name": "maximum"
},
"type": {
"name": "maximum",
"pos_args": [],
"args": [
{
"pos": 0,
"name": "a",
"type": {
"name": "float"
},
"required": true
},
{
"pos": 1,
"name": "b",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",
"to": "L11:20"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "maximum"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "v1"
},
{
"_type": "VariableExpr",
"name": "v2"
}
],
"keywords": {}
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L12:5",
"to": "L12:7"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L12:10",
"to": "L12:12"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L12:5",
"to": "L12:12"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "v1"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "v2"
}
},
"type": {
"name": "float"
}
}
]
}

View File

@@ -2385,7 +2385,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "GeoLocation"
@@ -2416,7 +2416,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "Latitude"
@@ -2433,7 +2433,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "Longitude"
@@ -2464,7 +2464,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "Latitude"
@@ -2494,7 +2494,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "Longitude"
@@ -2638,7 +2638,7 @@
"_type": "NamedType",
"name": "Optional"
},
"params": [
"args": [
{
"_type": "ConstraintType",
"type": {

View File

@@ -1,19 +1,19 @@
import ast
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
import midas.ast.python as p
from midas.checker.checker import Checker
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.parser.python import PythonParser
from midas.resolver.resolver import Resolver
from midas.checker.types import Type
from tests.base import Tester
from tests.serializer.python import PythonAstJsonSerializer
@dataclass
class CaseResult:
    """Result of one checker test case, serialisable to JSON."""

    # One dict per reported diagnostic.
    diagnostics: list[dict] = field(default_factory=list)
    # One entry per recorded typing judgment.
    judgments: list = field(default_factory=list)

    def dumps(self) -> str:
        """Return this result as a JSON document indented by two spaces."""
        payload = asdict(self)
        return json.dumps(payload, indent=2)
@@ -33,23 +33,16 @@ class CheckerTester(Tester):
if not path.is_file():
raise TypeError(f"Test '{path}' is not a file")
types_paths: list[Path] = []
result: CaseResult = CaseResult()
checker = TypeChecker()
types_path: Path = path.with_suffix(".midas")
if types_path.exists():
types_paths.append(types_path)
source: str = path.read_text()
tree: ast.Module = ast.parse(source, filename=path)
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
resolver = Resolver()
resolver.resolve(*stmts)
result: CaseResult = CaseResult()
checker = Checker(
resolver.locals,
source_path=path,
types_paths=types_paths,
)
diagnostics: list[Diagnostic] = checker.check(stmts)
checker.import_midas(types_path)
checker.type_check(path)
diagnostics: list[Diagnostic] = checker.diagnostics
for diagnostic in diagnostics:
result.diagnostics.append(
{
@@ -68,6 +61,21 @@ class CheckerTester(Tester):
}
)
judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements
serializer = PythonAstJsonSerializer()
for expr, type in judgements:
loc = expr.location
result.judgments.append(
{
"location": {
"from": f"L{loc.lineno}:{loc.col_offset}",
"to": f"L{loc.end_lineno}:{loc.end_col_offset}",
},
"expr": expr.accept(serializer),
"type": asdict(type),
}
)
return result

View File

@@ -6,6 +6,7 @@ from midas.ast.midas import (
ConstraintType,
Expr,
ExtendStmt,
FunctionType,
GenericType,
GetExpr,
GroupingExpr,
@@ -17,6 +18,7 @@ from midas.ast.midas import (
PropertyStmt,
Stmt,
Type,
TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
@@ -46,13 +48,11 @@ class MidasAstJsonSerializer(
return {
"_type": "TypeStmt",
"name": stmt.name.lexeme,
"params": [
self._serialize_type_stmt_template_param(param) for param in stmt.params
],
"params": [self._serialize_type_param(param) for param in stmt.params],
"type": stmt.type.accept(self),
}
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
def _serialize_type_param(self, param: TypeParam) -> dict:
return {
"name": param.name.lexeme,
"bound": self._serialize_optional(param.bound),
@@ -150,7 +150,7 @@ class MidasAstJsonSerializer(
return {
"_type": "GenericType",
"type": type.type.accept(self),
"params": self._serialize_list(type.params),
"args": self._serialize_list(type.args),
}
def visit_constraint_type(self, type: ConstraintType) -> dict:
@@ -165,3 +165,18 @@ class MidasAstJsonSerializer(
"_type": "ComplexType",
"properties": self._serialize_list(type.properties),
}
def visit_function_type(self, type: FunctionType) -> dict:
    """Serialise a function type node into a JSON-compatible dict.

    Positional and keyword arguments are serialised with
    ``_serialize_func_arg``; the return type is serialised by visiting it.
    """
    positional = []
    for arg in type.pos_args:
        positional.append(self._serialize_func_arg(arg))
    keyword = []
    for arg in type.kw_args:
        keyword.append(self._serialize_func_arg(arg))
    returns = type.returns.accept(self)
    return {
        "_type": "FunctionType",
        "pos_args": positional,
        "kw_args": keyword,
        "returns": returns,
    }
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
    """Serialise one function-type argument.

    The resulting dict holds the argument's name, its visited type and its
    ``required`` flag, in that key order.
    """
    serialized: dict = {"name": arg.name}
    serialized["type"] = arg.type.accept(self)
    serialized["required"] = arg.required
    return serialized

View File

@@ -16,11 +16,11 @@ from midas.ast.python import (
Function,
GetExpr,
IfStmt,
ListExpr,
LiteralExpr,
LogicalExpr,
MidasType,
ReturnStmt,
SetExpr,
Stmt,
TernaryExpr,
TypeAssign,
@@ -232,14 +232,6 @@ class PythonAstJsonSerializer(
"right": expr.right.accept(self),
}
def visit_set_expr(self, expr: SetExpr) -> dict:
    """Serialise a set expression node into a JSON-compatible dict.

    The target object and the assigned value are serialised by visiting them;
    the attribute name is copied as-is.
    """
    target = expr.object.accept(self)
    attribute = expr.name
    assigned = expr.value.accept(self)
    return {
        "_type": "SetExpr",
        "object": target,
        "name": attribute,
        "value": assigned,
    }
def visit_cast_expr(self, expr: CastExpr) -> dict:
return {
"_type": "CastExpr",
@@ -254,3 +246,9 @@ class PythonAstJsonSerializer(
"if_true": expr.if_true.accept(self),
"if_false": expr.if_false.accept(self),
}
def visit_list_expr(self, expr: ListExpr) -> dict:
    """Serialise a list expression node into a JSON-compatible dict.

    Each item is serialised by visiting it, preserving the original order.
    """
    serialized_items = []
    for element in expr.items:
        serialized_items.append(element.accept(self))
    return {"_type": "ListExpr", "items": serialized_items}