15 Commits

Author SHA1 Message Date
e0179bc442 feat(checker): handle assignments to attributes 2026-06-07 17:50:56 +02:00
e665d03533 fix: remove unused SetExpr 2026-06-07 17:48:31 +02:00
b8cb2b4273 feat(checker): handle attribute getter 2026-06-07 15:07:24 +02:00
d278dc5f5b tests: update tests with operation overloads 2026-06-07 14:28:36 +02:00
59e73f0fd9 fix(checker): invert property subtype check 2026-06-07 14:00:02 +02:00
3e0dc60283 fix(checker): only unfold alias on subtype 2026-06-07 13:59:27 +02:00
c24eb5125e feat(checker): resolve operation overloads with subtypes 2026-06-07 13:43:43 +02:00
25bd895dde feat(cli): improve diagnostic printing 2026-06-07 13:42:15 +02:00
bccd75317e tests: add subtyping test 2026-06-06 16:59:49 +02:00
f0e3f7574f feat(tests): add judgements to test results
add type judgements to checker test results and update all tests (including the new subtyping rules)
2026-06-06 16:58:13 +02:00
5d44081847 feat(checker): implement function subtyping
The logic for checking function subtypes is a WIP and has not been fully tested; there may be some errors and unhandled edge cases.
Claude helped lay out and verify the overall steps

Co-authored-by: Claude <noreply@anthropic.com>
2026-06-06 16:53:52 +02:00
2a2bb0aec7 feat(checker): store function param position 2026-06-06 16:50:42 +02:00
67c40a3909 feat(checker): add is_subtype method 2026-06-06 16:30:04 +02:00
1c30188122 feat(checker): record type judgements 2026-06-06 16:25:33 +02:00
82a0f13242 feat(cli): add verbose flag to compile 2026-06-05 14:17:24 +02:00
24 changed files with 2414 additions and 148 deletions

View File

@@ -0,0 +1,11 @@
type Meter = float
extend Meter {
op __add__(Meter) -> Meter
op __sub__(Meter) -> Meter
}
type Coordinate = {
x: Meter
y: Meter
}

View File

@@ -0,0 +1,11 @@
# type: ignore
# ruff: disable [F821]
p1: Coordinate
p2: Coordinate
diff_x = p2.x - p1.x
diff_y = p2.y - p1.y
dist = diff_x + diff_y
p2.x += cast(Meter, 1)

View File

@@ -128,12 +128,6 @@ class LogicalExpr:
right: Expr right: Expr
class SetExpr:
object: Expr
name: str
value: Expr
class CastExpr: class CastExpr:
type: MidasType type: MidasType
expr: Expr expr: Expr

View File

@@ -602,17 +602,6 @@ class PythonAstPrinter(
with self._child_level(single=True): with self._child_level(single=True):
expr.right.accept(self) expr.right.accept(self)
def visit_set_expr(self, expr: p.SetExpr) -> None:
self._write_line("SetExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line(f"name: {expr.name}")
self._write_line("value", last=True)
with self._child_level(single=True):
expr.value.accept(self)
def visit_cast_expr(self, expr: p.CastExpr) -> None: def visit_cast_expr(self, expr: p.CastExpr) -> None:
self._write_line("CastExpr") self._write_line("CastExpr")
with self._child_level(): with self._child_level():

View File

@@ -214,9 +214,6 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ... def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@abstractmethod
def visit_set_expr(self, expr: SetExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_cast_expr(self, expr: CastExpr) -> T: ... def visit_cast_expr(self, expr: CastExpr) -> T: ...
@@ -298,16 +295,6 @@ class LogicalExpr(Expr):
return visitor.visit_logical_expr(self) return visitor.visit_logical_expr(self)
@dataclass(frozen=True)
class SetExpr(Expr):
object: Expr
name: str
value: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_set_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class CastExpr(Expr): class CastExpr(Expr):
type: MidasType type: MidasType

View File

@@ -0,0 +1,4 @@
# Direct subtype relations between builtin base types, keyed by the name of
# the supertype: each value is the set of type names that are immediate
# subtypes of the key (e.g. "int" is a subtype of "float").
# Only direct relations are listed here; "bool" is not listed under "float".
BUILTIN_SUBTYPES: dict[str, set[str]] = {
    "float": {"int"},
    "int": {"bool"},
}

View File

@@ -6,10 +6,20 @@ from typing import Optional
import midas.ast.midas as m import midas.ast.midas as m
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType from midas.checker.types import (
AliasType,
BaseType,
ComplexType,
Function,
Operation,
Type,
UnitType,
UnknownType,
)
from midas.lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token from midas.lexer.token import Token
from midas.parser.midas import MidasParser from midas.parser.midas import MidasParser
@@ -48,6 +58,7 @@ class Checker(
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = locals self.locals: dict[p.Expr, int] = locals
self.diagnostics: list[Diagnostic] = [] self.diagnostics: list[Diagnostic] = []
self.judgements: list[tuple[p.Expr, Type]] = []
def diagnostic(self, type: DiagnosticType, location: Location, message: str): def diagnostic(self, type: DiagnosticType, location: Location, message: str):
self.diagnostics.append( self.diagnostics.append(
@@ -89,7 +100,9 @@ class Checker(
Returns: Returns:
Type: the type of the given expression Type: the type of the given expression
""" """
return expr.accept(self) type: Type = expr.accept(self)
self.judgements.append((expr, type))
return type
def process_block(self, block: list[p.Stmt], env: Environment) -> bool: def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
"""Evaluate a sequence of statements """Evaluate a sequence of statements
@@ -165,6 +178,158 @@ class Checker(
stmts: list[m.Stmt] = parser.parse() stmts: list[m.Stmt] = parser.parse()
self.ctx.resolve(stmts) self.ctx.resolve(stmts)
def unfold_type(self, type: Type) -> Type:
    """Strip every alias layer wrapped around a type

    Args:
        type (Type): the type to unfold

    Returns:
        Type: the first non-alias type reached by following `AliasType.type`,
            or `type` itself if it is not an alias
    """
    unfolded: Type = type
    # Follow the chain of aliases until something that is not an alias is found
    while isinstance(unfolded, AliasType):
        unfolded = unfolded.type
    return unfolded
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
For more details on the rules checked here, see TAPL Chap. 15-16-17
Args:
type1 (Type): the potential subtype
type2 (Type): the potential supertype
Returns:
bool: whether `type1` is a subtype of `type2`
"""
if type1 == type2:
return True
match (type1, type2):
case (AliasType(type=base1), _):
return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)):
return name1 in BUILTIN_SUBTYPES.get(name2, set())
case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items():
if k not in props1:
return False
if not self.is_subtype(props1[k], t):
return False
return True
case (Function(returns=return1), Function(returns=return2)):
if not self.is_func_subtype(type1, type2):
return False
if not self.is_subtype(return1, return2):
return False
return True
return False
# TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
    """Check whether a function is a subtype of another

    The check runs in the following steps, returning False as soon as one
    of them fails:

    1. the return type of `func1` must be a subtype of the one of `func2`
    2. every positional-only, keyword-only and mixed argument of `func1` is
       matched against an argument of `func2` (by position and/or by name)
       and compared with `is_arg_subtype`
    3. the required positional-only and keyword-only arguments of `func2`
       must have a counterpart in `func1`
    4. a last pass over the mixed arguments of `func2` (see the review note
       on that loop)

    Args:
        func1 (Function): the potential function subtype
        func2 (Function): the potential function supertype

    Returns:
        bool: whether `func1` is a subtype of `func2`
    """
    if not self.is_subtype(func1.returns, func2.returns):
        return False

    # Arguments of both functions, grouped by kind. Keyword-only arguments
    # are indexed by name, and the mixed (positional-or-keyword) arguments
    # of func2 are indexed both by position and by name for the lookups below
    pos1: list[Function.Argument] = func1.pos_args
    mixed1: list[Function.Argument] = func1.args
    kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
    pos2: list[Function.Argument] = func2.pos_args
    mixed2: list[Function.Argument] = func2.args
    kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
    mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
    mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}

    def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
        # The type of `sub` must be a subtype of the type of `sup`, and a
        # required `sub` is rejected when `sup` is optional.
        # Every call below passes the func2 argument as `sub` and the func1
        # argument as `sup`
        if not self.is_subtype(sub.type, sup.type):
            return False
        if not sup.required and sub.required:
            return False
        return True

    # Positional-only arguments of func1: look for the func2 argument at the
    # same position, first among its positional-only arguments, then among
    # its mixed ones. An unmatched argument is tolerated only if optional
    for arg1 in pos1:
        arg2: Function.Argument
        if arg1.pos < len(pos2):
            arg2 = pos2[arg1.pos]
        elif arg1.pos in mixed_by_pos:
            arg2 = mixed_by_pos[arg1.pos]
        elif not arg1.required:
            continue
        else:
            return False
        if not is_arg_subtype(arg2, arg1):
            return False

    # Keyword-only arguments of func1: same as above, but matched by name
    # among the keyword-only then the mixed arguments of func2
    for name, arg1 in kw1.items():
        arg2: Function.Argument
        if name in kw2:
            arg2 = kw2[name]
        elif name in mixed_by_name:
            arg2 = mixed_by_name[name]
        elif not arg1.required:
            continue
        else:
            return False
        if not is_arg_subtype(arg2, arg1):
            return False

    # Mixed arguments of func1 can be matched by name and by position
    # independently; every match that is found has to pass is_arg_subtype
    for arg1 in mixed1:
        pos_arg2: Optional[Function.Argument] = None
        kw_arg2: Optional[Function.Argument] = None
        if arg1.name in kw2:
            kw_arg2 = kw2[arg1.name]
        elif arg1.name in mixed_by_name:
            kw_arg2 = mixed_by_name[arg1.name]
        if arg1.pos < len(pos2):
            pos_arg2 = pos2[arg1.pos]
        elif arg1.pos in mixed_by_pos:
            pos_arg2 = mixed_by_pos[arg1.pos]

        # No match in func2 and arg is required
        if pos_arg2 is None and kw_arg2 is None and arg1.required:
            return False

        # Matching keyword argument
        if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
            return False

        # Matching positional argument
        if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
            return False

    # Reverse direction: check that arguments of func2 have a counterpart
    # in func1 (types were already compared in the loops above)
    mixed_positions: set[int] = {a.pos for a in mixed1}
    mixed_names: set[str] = {a.name for a in mixed1}

    # Required positional-only arguments of func2 need a func1 argument at
    # the same position (positional-only or mixed)
    for arg2 in pos2:
        if not arg2.required:
            continue
        if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
            return False

    # Required keyword-only arguments of func2 need a func1 argument with
    # the same name (keyword-only or mixed)
    for name, arg2 in kw2.items():
        if not arg2.required:
            continue
        if name not in kw1 and name not in mixed_names:
            return False

    # Mixed arguments of func2: the ones checked here need both a positional
    # and a named counterpart in func1.
    # NOTE(review): this loop skips the *required* arguments and checks the
    # optional ones, whereas the two loops above skip the optional ones;
    # confirm whether `if arg2.required` is intended or inverted
    for arg2 in mixed2:
        if arg2.required:
            continue
        pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
        kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
        if not pos_match or not kw_match:
            return False

    return True
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.type_of(stmt.expr) self.type_of(stmt.expr)
@@ -181,30 +346,37 @@ class Checker(
return arg.default.accept(self) return arg.default.accept(self)
return UnknownType() return UnknownType()
pos: int = 0
for arg in stmt.posonlyargs: for arg in stmt.posonlyargs:
pos_args.append( pos_args.append(
Function.Argument( Function.Argument(
pos=pos,
name=arg.name, name=arg.name,
type=eval_arg_type(arg), type=eval_arg_type(arg),
required=arg.default is None, required=arg.default is None,
) )
) )
pos += 1
for arg in stmt.args: for arg in stmt.args:
args.append( args.append(
Function.Argument( Function.Argument(
pos=pos,
name=arg.name, name=arg.name,
type=eval_arg_type(arg), type=eval_arg_type(arg),
required=arg.default is None, required=arg.default is None,
) )
) )
pos += 1
for arg in stmt.kwonlyargs: for arg in stmt.kwonlyargs:
kw_args.append( kw_args.append(
Function.Argument( Function.Argument(
pos=pos, # not relevant
name=arg.name, name=arg.name,
type=eval_arg_type(arg), type=eval_arg_type(arg),
required=arg.default is None, required=arg.default is None,
) )
) )
pos += 1
for arg in pos_args + args + kw_args: for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type) env.define(arg.name, arg.type)
@@ -263,23 +435,65 @@ class Checker(
self.env.define(stmt.name, type) self.env.define(stmt.name, type)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
value: Type = self.type_of(stmt.value) value_type: Type = self.type_of(stmt.value)
for target in stmt.targets: for target in stmt.targets:
self._assign(stmt.location, target, value_type)
def _assign(self, location: Location, target: p.Expr, value_type: Type):
match target:
case p.VariableExpr():
self._assign_var(location, target, value_type)
case p.GetExpr():
self._assign_attr(location, target, value_type)
case _:
if not isinstance(target, p.VariableExpr): if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}") self.logger.warning(f"Unsupported assignment to {target}")
self.warning(target.location, f"Unsupported assignment to {target}") self.warning(target.location, f"Unsupported assignment to {target}")
continue
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
name: str = target.name name: str = target.name
var_type: Optional[Type] = self.look_up_variable(name, target) var_type: Optional[Type] = self.look_up_variable(name, target)
if var_type is None: if var_type is None:
self.env.define(name, value) self.env.define(name, value_type)
else: else:
# TODO: implement real comparison method # S <: T
if var_type != value: # Γ, x: T v: S
# x = v
if not self.is_subtype(value_type, var_type):
self.error( self.error(
stmt.location, location,
f"Cannot assign {value} to {name} of type {var_type}", f"Cannot assign {value_type} to {name} of type {var_type}",
)
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
object: Type = self.type_of(target.object)
base_object: Type = self.unfold_type(object)
match base_object:
case ComplexType(properties=properties):
if target.name not in properties:
self.error(
target.location, f"Unknown property '{target.name} on {object}"
)
return
prop_type: Type = properties[target.name]
if not self.is_subtype(value_type, prop_type):
self.error(
location,
f"Cannot assign {value_type} to property '{target.name}' of type {prop_type} on {object}",
)
return
case UnknownType():
pass
case _:
self.error(
target.location,
f"Cannot assign {value_type} to unknown property '{target.name}' on {object}",
) )
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
@@ -317,14 +531,48 @@ class Checker(
left: Type = self.type_of(expr.left) left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right) right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right) operations: list[Operation] = self.ctx.get_operations_by_name(method)
if result is None: valid_operations: list[Operation] = []
for op in operations:
sig: Operation.CallSignature = op.signature
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
valid_operations.append(op)
if len(valid_operations) == 0:
self.error( self.error(
expr.location, expr.location,
f"Undefined operation {method} between {left} and {right}", f"Undefined operation {method} between {left} and {right}",
) )
return UnknownType() return UnknownType()
return result elif len(valid_operations) == 1:
self.logger.debug(f"Unique operation {method} between {left} and {right}")
return valid_operations[0].result
for i, op1 in enumerate(valid_operations):
sig1: Operation.CallSignature = op1.signature
best_match: bool = True
for j, op2 in enumerate(valid_operations):
if i == j:
continue
sig2: Operation.CallSignature = op2.signature
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
sig1.right, sig2.right
):
best_match = False
break
self.logger.debug(f"{op1} is a full overload of {op2}")
if best_match:
return op1.result
overloads: list[str] = [
f"({op.signature.left} {op.signature.method} {op.signature.right}) -> {op.result}"
for op in valid_operations
]
self.error(
expr.location,
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(overloads)}",
)
return UnknownType()
def visit_compare_expr(self, expr: p.CompareExpr) -> Type: def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
@@ -354,14 +602,33 @@ class Checker(
function: Function = callee function: Function = callee
mapped: list[MappedArgument] = self.map_call_arguments(function, expr) mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
for arg in mapped: for arg in mapped:
if arg.type != arg.argument.type: if not self.is_subtype(arg.type, arg.argument.type):
self.error( self.error(
arg.expr.location, arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
) )
return function.returns return function.returns
def visit_get_expr(self, expr: p.GetExpr) -> Type: ... def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object)
base_object: Type = self.unfold_type(object)
match base_object:
case ComplexType(properties=properties):
if expr.name not in properties:
self.error(
expr.location, f"Unknown property '{expr.name} on {object}"
)
return UnknownType()
return properties[expr.name]
case UnknownType():
return UnknownType()
case _:
self.error(
expr.location, f"Cannot get property '{expr.name}' on {object}"
)
return UnknownType()
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
match expr.value: match expr.value:
@@ -383,15 +650,17 @@ class Checker(
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
left: Type = expr.left.accept(self) left: Type = expr.left.accept(self)
right: Type = expr.right.accept(self) right: Type = expr.right.accept(self)
# TODO: union type
if left != right: if self.is_subtype(left, right):
self.error( return right
expr.location, if self.is_subtype(right, left):
f"Operands must be of the same type, left={left} != right={right}",
)
return left return left
def visit_set_expr(self, expr: p.SetExpr) -> Type: ... self.error(
expr.location,
f"Incompatible operand types, {left=} and {right=}",
)
return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type: def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return expr.type.accept(self) return expr.type.accept(self)
@@ -407,13 +676,16 @@ class Checker(
true_type: Type = expr.if_true.accept(self) true_type: Type = expr.if_true.accept(self)
false_type: Type = expr.if_false.accept(self) false_type: Type = expr.if_false.accept(self)
if true_type != false_type: if self.is_subtype(true_type, false_type):
return false_type
if self.is_subtype(false_type, true_type):
return true_type
self.error( self.error(
expr.location, expr.location,
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}", f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
) )
return UnknownType() return UnknownType()
return true_type
def visit_base_type(self, node: p.BaseType) -> Type: def visit_base_type(self, node: p.BaseType) -> Type:
return self.ctx.get_type(node.base) return self.ctx.get_type(node.base)

View File

@@ -19,7 +19,8 @@ class Diagnostic:
type: DiagnosticType type: DiagnosticType
message: str message: str
def __str__(self) -> str: @property
def location_str(self) -> str:
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}" start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
end_loc: Optional[str] = "" end_loc: Optional[str] = ""
if ( if (
@@ -30,4 +31,7 @@ class Diagnostic:
loc: str = ( loc: str = (
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}" f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
) )
return f"{self.type} in {self.file_path} {loc}: {self.message}" return f"{self.type} in {self.file_path} {loc}"
def __str__(self) -> str:
return f"{self.location_str}: {self.message}"

View File

@@ -34,6 +34,7 @@ class Function:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Argument: class Argument:
pos: int
name: str name: str
type: Type type: Type
required: bool required: bool
@@ -44,4 +45,16 @@ class ComplexType:
properties: dict[str, Type] properties: dict[str, Type]
@dataclass(frozen=True, kw_only=True)
class Operation:
    """An operation overload: a call signature paired with its result type"""

    # The operands and method name identifying this overload
    signature: CallSignature
    # The type produced by the operation
    result: Type

    @dataclass(frozen=True, kw_only=True)
    class CallSignature:
        """The operand types and method name (e.g. `__add__`) of an operation"""

        # Type of the left operand
        left: Type
        # Name of the operator method
        method: str
        # Type of the right operand
        right: Type
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType

41
midas/cli/ansi.py Normal file
View File

@@ -0,0 +1,41 @@
class Ansi:
    """ANSI escape sequences used to style terminal output

    The text attributes (`RESET`, `BOLD`, ...) are complete escape
    sequences. The colors are plain indices which must be turned into a
    sequence with `FG` (foreground) or `BG` (background).
    """

    # Prefix shared by every sequence
    CTRL = "\x1b["

    # Text attributes
    RESET = f"{CTRL}0m"
    BOLD = f"{CTRL}1m"
    DIM = f"{CTRL}2m"
    ITALIC = f"{CTRL}3m"
    UNDERLINE = f"{CTRL}4m"

    # Color indices, added to 30 (foreground) or 40 (background)
    BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)

    # Bright variants of the colors above
    (
        BRIGHT_BLACK,
        BRIGHT_RED,
        BRIGHT_GREEN,
        BRIGHT_YELLOW,
        BRIGHT_BLUE,
        BRIGHT_MAGENTA,
        BRIGHT_CYAN,
        BRIGHT_WHITE,
    ) = range(60, 68)

    @classmethod
    def FG(cls, col: int) -> str:
        """Build the sequence setting the foreground to the color index `col`"""
        code = 30 + col
        return "{}{}m".format(cls.CTRL, code)

    @classmethod
    def BG(cls, col: int) -> str:
        """Build the sequence setting the background to the color index `col`"""
        code = 40 + col
        return "{}{}m".format(cls.CTRL, code)

    @classmethod
    def FG_RGB(cls, r: int, g: int, b: int) -> str:
        """Build the sequence setting the foreground to the RGB color `r`, `g`, `b`"""
        return "{}38;2;{};{};{}m".format(cls.CTRL, r, g, b)

    @classmethod
    def BG_RGB(cls, r: int, g: int, b: int) -> str:
        """Build the sequence setting the background to the RGB color `r`, `g`, `b`"""
        return "{}48;2;{};{};{}m".format(cls.CTRL, r, g, b)

View File

@@ -210,8 +210,6 @@ class PythonHighlighter(
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ... def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
def visit_cast_expr(self, expr: p.CastExpr) -> None: ... def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ... def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...

View File

@@ -8,10 +8,12 @@ import click
import midas.ast.midas as m import midas.ast.midas as m
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.checker.checker import Checker from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.types import Type from midas.checker.types import Type
from midas.cli.ansi import Ansi
from midas.cli.highlighter import ( from midas.cli.highlighter import (
DiagnosticsHighlighter, DiagnosticsHighlighter,
Highlighter, Highlighter,
@@ -32,12 +34,69 @@ def midas():
pass pass
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
    """Pretty-print a diagnostic, showing some context if possible

    If the diagnostic concerns a specific part of one line, the line is shown
    with the affected part highlighted. The message is clearly printed under the
    line with an underline further indicating the target expression.

    If multiple lines are concerned, or if the line cannot be found in
    `lines`, no context is shown, only the diagnostic type, location and
    message

    Args:
        lines (list[str]): source code lines
        diagnostic (Diagnostic): the diagnostic to print
        indent (int, optional): the number of spaces added before the target line to indent it from the location header. Defaults to 4.
    """

    loc: Location = diagnostic.location
    spans_one_line: bool = loc.lineno == loc.end_lineno
    # Line numbers are 1-based. A line outside of `lines` would either raise
    # an IndexError or (for 0) silently show the last line, so fall back to
    # the plain form instead
    line_available: bool = 1 <= loc.lineno <= len(lines)
    if not spans_one_line or not line_available:
        print(diagnostic)
        return

    # NOTE(review): if these offsets are copied from `ast` nodes they are
    # UTF-8 byte offsets, which differ from string indices on lines with
    # non-ASCII characters; confirm how Location is filled
    start_offset: int = loc.col_offset
    # Without an end column, highlight a single character
    end_offset: int = loc.end_col_offset or (start_offset + 1)
    line: str = lines[loc.lineno - 1]
    before: str = line[:start_offset]
    after: str = line[end_offset:]

    color: int = {
        DiagnosticType.ERROR: Ansi.RED,
        DiagnosticType.WARNING: Ansi.YELLOW,
        DiagnosticType.INFO: Ansi.CYAN,
    }.get(diagnostic.type, Ansi.WHITE)

    subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
    # Underline placed below the subject, followed by the message
    cursor: str = (
        " " * start_offset
        + Ansi.FG(color)
        + "~" * (end_offset - start_offset)
        + "> "
        + diagnostic.message
        + Ansi.RESET
    )
    indent_str: str = " " * indent

    print(diagnostic.location_str + ":")
    print(indent_str + before + subject + after)
    print(indent_str + cursor)
    print()
@midas.command() @midas.command()
@click.option("-l", "--highlight", type=click.File("w")) @click.option("-l", "--highlight", type=click.File("w"))
@click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-v", "--verbose", is_flag=True)
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]): def compile(
logging.basicConfig(level=logging.DEBUG) highlight: Optional[TextIO],
types: tuple[TextIO],
verbose: bool,
file: TextIO,
):
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
source: str = file.read() source: str = file.read()
tree: ast.Module = ast.parse(source, filename=file.name) tree: ast.Module = ast.parse(source, filename=file.name)
parser = PythonParser() parser = PythonParser()
@@ -51,9 +110,11 @@ def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]):
types_paths=types_paths, types_paths=types_paths,
) )
diagnostics: list[Diagnostic] = checker.check(stmts) diagnostics: list[Diagnostic] = checker.check(stmts)
lines: list[str] = source.split("\n")
for diagnostic in diagnostics: for diagnostic in diagnostics:
print(diagnostic) print_diagnostic(lines, diagnostic)
if verbose:
print( print(
json.dumps( json.dumps(
UniversalJSONDumper.dump( UniversalJSONDumper.dump(

View File

@@ -3,6 +3,8 @@ from typing import Optional
import midas.ast.midas as m import midas.ast.midas as m
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
ComplexType,
Operation,
Type, Type,
UnknownType, UnknownType,
) )
@@ -14,7 +16,7 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
def __init__(self) -> None: def __init__(self) -> None:
self._types: dict[str, Type] = {} self._types: dict[str, Type] = {}
self._operations: dict[tuple[Type, str, Type], Type] = {} self._operations: dict[Operation.CallSignature, Type] = {}
define_builtins(self) define_builtins(self)
@@ -48,10 +50,26 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
Returns: Returns:
Optional[Type]: the result type, or None if no matching operation was found Optional[Type]: the result type, or None if no matching operation was found
""" """
operation: tuple[Type, str, Type] = (left, operator, right) signature: Operation.CallSignature = Operation.CallSignature(
result: Optional[Type] = self._operations.get(operation) left=left,
method=operator,
right=right,
)
result: Optional[Type] = self._operations.get(signature)
return result return result
def get_operations_by_name(self, name: str) -> list[Operation]:
    """List every registered operation with the given method name

    Args:
        name (str): the method name to look for (e.g. `__add__`)

    Returns:
        list[Operation]: the operations whose signature uses this method name, in the iteration order of the registry (empty if there is none)
    """
    return [
        Operation(signature=sig, result=result_type)
        for sig, result_type in self._operations.items()
        if sig.method == name
    ]
def define_type(self, name: str, type: Type) -> Type: def define_type(self, name: str, type: Type) -> Type:
"""Define a type in the registry """Define a type in the registry
@@ -82,12 +100,16 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
Raises: Raises:
ValueError: if an operation is already defined with these operands and name ValueError: if an operation is already defined with these operands and name
""" """
operation: tuple[Type, str, Type] = (left, operator, right) signature: Operation.CallSignature = Operation.CallSignature(
if operation in self._operations: left=left,
method=operator,
right=right,
)
if signature in self._operations:
raise ValueError( raise ValueError(
f"Operation {operator} already defined between {left} and {right}" f"Operation {operator} already defined between {left} and {right}"
) )
self._operations[operation] = result self._operations[signature] = result
def resolve(self, stmts: list[m.Stmt]): def resolve(self, stmts: list[m.Stmt]):
"""Process a sequence of statements """Process a sequence of statements
@@ -157,7 +179,8 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
return UnknownType() return UnknownType()
def visit_complex_type(self, type: m.ComplexType) -> Type: def visit_complex_type(self, type: m.ComplexType) -> Type:
for prop in type.properties: return ComplexType(
prop.accept(self) properties={
# TODO prop.name.lexeme: prop.type.accept(self) for prop in type.properties
return UnknownType() }
)

View File

@@ -111,9 +111,8 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(stmt.value) self.resolve(stmt.value)
for target in stmt.targets: for target in stmt.targets:
match target: match target:
case p.VariableExpr(name=name): case p.VariableExpr() | p.GetExpr():
self.resolve_local(target, name) target.accept(self)
# TODO: declare if not found
case _: case _:
raise Exception(f"Unsupported assignment to {target}") raise Exception(f"Unsupported assignment to {target}")
@@ -174,10 +173,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(expr.left) self.resolve(expr.left)
self.resolve(expr.right) self.resolve(expr.right)
def visit_set_expr(self, expr: p.SetExpr) -> None:
self.resolve(expr.value)
self.resolve(expr.object)
def visit_cast_expr(self, expr: p.CastExpr) -> None: def visit_cast_expr(self, expr: p.CastExpr) -> None:
self.resolve(expr.expr) self.resolve(expr.expr)

View File

@@ -29,7 +29,7 @@ class Tester(ABC):
def _list_tests(self) -> list[Path]: ... def _list_tests(self) -> list[Path]: ...
def run_all_tests(self) -> bool: def run_all_tests(self) -> bool:
paths: list[Path] = self._list_tests() paths: list[Path] = sorted(self._list_tests())
return self.run_tests(paths) return self.run_tests(paths)
def run_tests(self, tests: list[Path]) -> bool: def run_tests(self, tests: list[Path]) -> bool:
@@ -40,7 +40,7 @@ class Tester(ABC):
print(rule) print(rule)
for i, test in enumerate(tests): for i, test in enumerate(tests):
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}") print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
success: bool = self._run_test(test) success: bool = self._run_test(test)
if success: if success:
successes += 1 successes += 1
@@ -78,7 +78,7 @@ class Tester(ABC):
def _exec_case(self, path: Path) -> CaseResult: ... def _exec_case(self, path: Path) -> CaseResult: ...
def update_all_tests(self): def update_all_tests(self):
paths: list[Path] = self._list_tests() paths: list[Path] = sorted(self._list_tests())
return self.update_tests(paths) return self.update_tests(paths)
def update_tests(self, tests: list[Path]): def update_tests(self, tests: list[Path]):

View File

@@ -1,3 +1,4 @@
{ {
"diagnostics": [] "diagnostics": [],
"judgments": []
} }

View File

@@ -13,34 +13,167 @@
] ]
}, },
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')" "message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
}
],
"judgments": [
{
"location": {
"from": "L1:9",
"to": "L1:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
}, },
{ {
"type": "Error",
"location": { "location": {
"start": [ "from": "L2:9",
9, "to": "L2:10"
4
],
"end": [
9,
9
]
}, },
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')" "expr": {
"_type": "LiteralExpr",
"value": 4
},
"type": {
"name": "int"
}
}, },
{ {
"type": "Error",
"location": { "location": {
"start": [ "from": "L4:4",
11, "to": "L4:5"
0
],
"end": [
11,
12
]
}, },
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')" "expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L4:8",
"to": "L4:9"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L4:4",
"to": "L4:9"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "b"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:4",
"to": "L6:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "invalid"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L8:4",
"to": "L8:8"
},
"expr": {
"_type": "LiteralExpr",
"value": true
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:4",
"to": "L9:5"
},
"expr": {
"_type": "VariableExpr",
"name": "d"
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:8",
"to": "L9:9"
},
"expr": {
"_type": "VariableExpr",
"name": "d"
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:4",
"to": "L9:9"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "d"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "d"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
} }
] ]
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,109 @@
{ {
"diagnostics": [] "diagnostics": [],
"judgments": [
{
"location": {
"from": "L4:18",
"to": "L4:37"
},
"expr": {
"_type": "CastExpr",
"type": {
"_type": "BaseType",
"base": "Meter",
"param": null
},
"expr": {
"_type": "LiteralExpr",
"value": 123.45
}
},
"type": {
"name": "Meter",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L5:15",
"to": "L5:32"
},
"expr": {
"_type": "CastExpr",
"type": {
"_type": "BaseType",
"base": "Second",
"param": null
},
"expr": {
"_type": "LiteralExpr",
"value": 6.7
}
},
"type": {
"name": "Second",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:8",
"to": "L6:16"
},
"expr": {
"_type": "VariableExpr",
"name": "distance"
},
"type": {
"name": "Meter",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:19",
"to": "L6:23"
},
"expr": {
"_type": "VariableExpr",
"name": "time"
},
"type": {
"name": "Second",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:8",
"to": "L6:23"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "distance"
},
"operator": "/",
"right": {
"_type": "VariableExpr",
"name": "time"
}
},
"type": {
"name": "MeterPerSecond",
"type": {
"name": "float"
}
}
}
]
} }

View File

@@ -42,5 +42,215 @@
}, },
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]" "message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
} }
],
"judgments": [
{
"location": {
"from": "L2:11",
"to": "L2:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:15",
"to": "L2:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L5:7",
"to": "L5:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L5:11",
"to": "L5:12"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:15",
"to": "L6:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:19",
"to": "L6:20"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:15",
"to": "L8:16"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:19",
"to": "L8:20"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:7",
"to": "L15:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:11",
"to": "L15:13"
},
"expr": {
"_type": "LiteralExpr",
"value": 10
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:16"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:19",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": 10
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L22:7",
"to": "L22:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L22:11",
"to": "L22:12"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L23:15",
"to": "L23:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L23:19",
"to": "L23:20"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
}
] ]
} }

View File

@@ -0,0 +1,12 @@
v1: int = 3
v2: float = 4
def maximum(a: float, b: float):
if b > a:
return b
return a
v3 = maximum(v1, v2)
v3 = v1 + v2

View File

@@ -0,0 +1,193 @@
{
"diagnostics": [],
"judgments": [
{
"location": {
"from": "L1:10",
"to": "L1:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:12",
"to": "L2:13"
},
"expr": {
"_type": "LiteralExpr",
"value": 4
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:7",
"to": "L6:8"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",
"to": "L11:12"
},
"expr": {
"_type": "VariableExpr",
"name": "maximum"
},
"type": {
"name": "maximum",
"pos_args": [],
"args": [
{
"pos": 0,
"name": "a",
"type": {
"name": "float"
},
"required": true
},
{
"pos": 1,
"name": "b",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",
"to": "L11:20"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "maximum"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "v1"
},
{
"_type": "VariableExpr",
"name": "v2"
}
],
"keywords": {}
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L12:5",
"to": "L12:7"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L12:10",
"to": "L12:12"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L12:5",
"to": "L12:12"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "v1"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "v2"
}
},
"type": {
"name": "float"
}
}
]
}

View File

@@ -6,14 +6,17 @@ from pathlib import Path
import midas.ast.python as p import midas.ast.python as p
from midas.checker.checker import Checker from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic from midas.checker.diagnostic import Diagnostic
from midas.checker.types import Type
from midas.parser.python import PythonParser from midas.parser.python import PythonParser
from midas.resolver.resolver import Resolver from midas.resolver.resolver import Resolver
from tests.base import Tester from tests.base import Tester
from tests.serializer.python import PythonAstJsonSerializer
@dataclass @dataclass
class CaseResult: class CaseResult:
diagnostics: list[dict] = field(default_factory=list) diagnostics: list[dict] = field(default_factory=list)
judgments: list = field(default_factory=list)
def dumps(self) -> str: def dumps(self) -> str:
return json.dumps(asdict(self), indent=2) return json.dumps(asdict(self), indent=2)
@@ -49,6 +52,7 @@ class CheckerTester(Tester):
source_path=path, source_path=path,
types_paths=types_paths, types_paths=types_paths,
) )
diagnostics: list[Diagnostic] = checker.check(stmts) diagnostics: list[Diagnostic] = checker.check(stmts)
for diagnostic in diagnostics: for diagnostic in diagnostics:
result.diagnostics.append( result.diagnostics.append(
@@ -68,6 +72,21 @@ class CheckerTester(Tester):
} }
) )
judgements: list[tuple[p.Expr, Type]] = checker.judgements
serializer = PythonAstJsonSerializer()
for expr, type in judgements:
loc = expr.location
result.judgments.append(
{
"location": {
"from": f"L{loc.lineno}:{loc.col_offset}",
"to": f"L{loc.end_lineno}:{loc.end_col_offset}",
},
"expr": expr.accept(serializer),
"type": asdict(type),
}
)
return result return result

View File

@@ -20,7 +20,6 @@ from midas.ast.python import (
LogicalExpr, LogicalExpr,
MidasType, MidasType,
ReturnStmt, ReturnStmt,
SetExpr,
Stmt, Stmt,
TernaryExpr, TernaryExpr,
TypeAssign, TypeAssign,
@@ -232,14 +231,6 @@ class PythonAstJsonSerializer(
"right": expr.right.accept(self), "right": expr.right.accept(self),
} }
def visit_set_expr(self, expr: SetExpr) -> dict:
return {
"_type": "SetExpr",
"object": expr.object.accept(self),
"name": expr.name,
"value": expr.value.accept(self),
}
def visit_cast_expr(self, expr: CastExpr) -> dict: def visit_cast_expr(self, expr: CastExpr) -> dict:
return { return {
"_type": "CastExpr", "_type": "CastExpr",