Compare commits

...

7 Commits

Author SHA1 Message Date
d0f1178c17 fix(checker): change some diagnostics to warnings
temporarily change type errors in predicates to warnings until operations are fully type checked
2026-06-19 14:41:43 +02:00
0eca23b894 feat(gen): generate type hints for functions 2026-06-19 14:11:38 +02:00
f664fb4a4f feat(gen): handle predicate aliases
handle cases where a predicate is defined as an alias, i.e. without any parameters
2026-06-19 14:05:34 +02:00
32330243c6 fix(parser): fix call expr location span 2026-06-19 13:57:49 +02:00
96e76065cf feat(types): detect constraint base subtyping 2026-06-19 13:57:21 +02:00
7b7d87e59a feat(checker): type check predicate body 2026-06-19 13:55:32 +02:00
1eb90164e6 fix(gen): remove id from named predicate function 2026-06-19 10:15:09 +02:00
6 changed files with 223 additions and 36 deletions

View File

@@ -31,7 +31,7 @@ class TypedParamSpec:
kw: list[Function.Argument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
@@ -41,6 +41,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.types: TypesRegistry = types
self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
self._current_name: Optional[str] = None
define_builtins(self.types)
@@ -59,6 +61,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.reporter.error(error.token.get_location(), error.message)
self.resolve(stmts)
def type_of(self, expr: m.Expr) -> Type:
type: Type = expr.accept(self)
return type
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -75,6 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return self._local_variables[name]
return self.types.get_type(name)
def get_variable(self, name: str) -> Type:
if name in self._predicate_params:
return self._predicate_params[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is None:
raise NameError(f"Unknown variable '{name}'")
return predicate.type
def resolve(self, stmts: list[m.Stmt]):
"""Process a sequence of statements
@@ -84,6 +98,12 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
for stmt in stmts:
stmt.accept(self)
def assert_bool(self, expr: m.Expr):
type: Type = self.type_of(expr)
if not self.types.is_subtype(type, self._bool):
# TODO: change back to error when operations are type checked
self.reporter.warning(expr.location, f"Must be a boolean but is {type}")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
@@ -118,51 +138,125 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
for spec in stmt.params:
for param in spec.mixed:
assert param.name is not None
self._predicate_params[param.name.lexeme] = param.type.accept(self)
type: Type = self.type_of(stmt.body)
params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params
]
type: Type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
if not self._is_valid_predicate(type):
# TODO: change back to error when operations are type checked
self.reporter.warning(
stmt.body.location,
f"Predicate function body must evaluate to a boolean, got {type}",
)
if len(params) != 0:
type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self._predicate_params = {}
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
type=type,
body=stmt.body,
alias=len(params) == 0,
),
)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
def _is_valid_predicate(self, body: Type) -> bool:
match body:
case Function(returns=returns):
return self._is_valid_predicate(returns)
case _ if self.types.is_subtype(body, self._bool):
return True
case _:
return False
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
self.assert_bool(expr.left)
self.assert_bool(expr.right)
return self._bool
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
# TODO
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
return UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
# TODO
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
return UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.reporter.warning(expr.location, "CallExpr not yet supported")
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
if not isinstance(callee, Function):
self.reporter.error(expr.location, f"Cannot call {callee}")
return UnknownType()
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.reporter.warning(expr.location, "GetExpr not yet supported")
n_args: int = len(args)
n_params: int = len(callee.args)
if n_args != n_params:
self.reporter.error(
expr.location,
f"Wrong number of argument, expected {n_params}, got {n_args}",
)
return UnknownType()
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
self.reporter.warning(expr.location, "VariableExpr not yet supported")
valid: bool = True
for arg, param in zip(args, callee.args):
if not self.types.is_subtype(arg, param.type):
self.reporter.error(
expr.location,
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
)
valid = False
if not valid:
return UnknownType()
return callee.returns
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
if member is None:
self.reporter.error(
expr.location, f"Unknown member '{expr.name}' of {object}"
)
return UnknownType()
return member
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
return self.get_variable(expr.name.lexeme)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
match expr.value:
case bool(): # Must be before int
return self.types.get_type("bool")
case int():
return self.types.get_type("int")
case float():
return self.types.get_type("float")
case str():
return self.types.get_type("str")
case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
return self.get_variable("_")
def visit_named_type(self, type: m.NamedType) -> Type:
name: str = type.name.lexeme

View File

@@ -652,7 +652,7 @@ class PythonTyper(
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accomodate
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:

View File

@@ -7,6 +7,7 @@ from midas.checker.types import (
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
@@ -130,6 +131,9 @@ class TypesRegistry:
return False
return self.is_subtype(bound, type2)
case (ConstraintType(type=base1), _):
return self.is_subtype(base1, type2)
return False
# TODO: verify the logic in here

View File

@@ -238,10 +238,63 @@ def unfold_type(type: Type) -> Type:
return type
def to_annotation(type: Type) -> str:
def _args_annotation(func: Function) -> str:
if len(func.kw_args) != 0:
return "..."
args: str = ", ".join(
to_annotation(arg.type) for arg in func.pos_args + func.args
)
return f"[{args}]"
match type:
case TopType():
return "Any"
case BaseType(name=name):
return name
case AliasType(name=name):
return name
case UnknownType():
return "Any"
case UnitType():
return "None"
case Function(returns=returns):
params_annot: str = _args_annotation(type)
return f"Callable[{params_annot}, {to_annotation(returns)}]"
case OverloadedFunction():
return "Callable"
case ComplexType() | ExtensionType():
raise NotImplementedError
case TypeVar(name=name):
return name
case GenericType(name=name, params=params):
return f"{name}[{', '.join(map(to_annotation, params))}]"
case AppliedType(name=name, args=args):
return f"{name}[{', '.join(map(to_annotation, args))}]"
case ConstraintType():
return str(type)
case _:
assert_never(type)
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type
body: m.Expr
alias: bool
Type = (

View File

@@ -3,7 +3,12 @@ from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, Predicate, Type
from midas.checker.types import (
Function,
Predicate,
Type,
to_annotation,
)
from midas.lexer.token import TokenType
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
@@ -63,26 +68,55 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
)
alias: str = self.make_alias(None)
definition: ast.stmt = self.make_definition(
alias, Predicate(type=func, body=expr)
alias, Predicate(type=func, body=expr, alias=False)
)
self._definitions.append(definition)
return ast.Name(id=alias)
def make_alias(self, name: Optional[str]) -> str:
suffix: str = f"_{name}" if name is not None else ""
alias: str = f"__midas_p{self._id}{suffix}__"
self._id += 1
suffix: str
if name is None:
suffix = f"p{self._id}"
self._id += 1
else:
suffix = name
alias: str = f"__midas_{suffix}__"
return alias
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
return self.make_func(name, body, predicate.type)
body: ast.expr = predicate.body.accept(self)
if predicate.alias:
return ast.Assign(
targets=[
ast.Name(id=name),
],
value=body,
)
return self.make_func(name, [ast.Return(value=body)], predicate.type)
def make_args(self, func: Function) -> ast.arguments:
return ast.arguments(
posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args],
args=[ast.arg(arg=arg.name) for arg in func.args],
kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args],
posonlyargs=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.pos_args
],
args=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.args
],
kwonlyargs=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.kw_args
],
defaults=[],
kw_defaults=[],
)
@@ -100,6 +134,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
self.make_func(inner_name, inner_body, type.returns, level + 1),
ast.Return(value=ast.Name(id=inner_name)),
],
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
@@ -108,11 +143,12 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
name=name,
args=self.make_args(type),
body=inner_body,
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
case _:
raise ValueError(f"Expected function, got {type}")
raise ValueError(f"Expected function, got {type!r}")
def get_predicate(self, name: str) -> Optional[ast.expr]:
if name not in self._aliases:

View File

@@ -380,7 +380,7 @@ class MidasParser(Parser):
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
)
return CallExpr(
location=l_paren.location_to(r_paren),
location=Location.span(callee.location, r_paren.get_location()),
callee=callee,
arguments=pos_args,
keywords=kw_args,