Compare commits
7 Commits
35ec0d0db8
...
d0f1178c17
| Author | SHA1 | Date | |
|---|---|---|---|
|
d0f1178c17
|
|||
|
0eca23b894
|
|||
|
f664fb4a4f
|
|||
|
32330243c6
|
|||
|
96e76065cf
|
|||
|
7b7d87e59a
|
|||
|
1eb90164e6
|
@@ -31,7 +31,7 @@ class TypedParamSpec:
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
@@ -41,6 +41,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
self.types: TypesRegistry = types
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
self._predicate_params: dict[str, Type] = {}
|
||||
|
||||
self._current_name: Optional[str] = None
|
||||
|
||||
define_builtins(self.types)
|
||||
@@ -59,6 +61,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
self.reporter.error(error.token.get_location(), error.message)
|
||||
self.resolve(stmts)
|
||||
|
||||
def type_of(self, expr: m.Expr) -> Type:
|
||||
type: Type = expr.accept(self)
|
||||
return type
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
@@ -75,6 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
return self._local_variables[name]
|
||||
return self.types.get_type(name)
|
||||
|
||||
def get_variable(self, name: str) -> Type:
|
||||
if name in self._predicate_params:
|
||||
return self._predicate_params[name]
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is None:
|
||||
raise NameError(f"Unknown variable '{name}'")
|
||||
return predicate.type
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
@@ -84,6 +98,12 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def assert_bool(self, expr: m.Expr):
|
||||
type: Type = self.type_of(expr)
|
||||
if not self.types.is_subtype(type, self._bool):
|
||||
# TODO: change back to error when operations are type checked
|
||||
self.reporter.warning(expr.location, f"Must be a boolean but is {type}")
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
name: str = stmt.name.lexeme
|
||||
self._current_name = name
|
||||
@@ -118,10 +138,24 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
for spec in stmt.params:
|
||||
for param in spec.mixed:
|
||||
assert param.name is not None
|
||||
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||
|
||||
type: Type = self.type_of(stmt.body)
|
||||
params: list[TypedParamSpec] = [
|
||||
self._visit_param_spec(spec) for spec in stmt.params
|
||||
]
|
||||
type: Type = self._bool
|
||||
|
||||
if not self._is_valid_predicate(type):
|
||||
# TODO: change back to error when operations are type checked
|
||||
self.reporter.warning(
|
||||
stmt.body.location,
|
||||
f"Predicate function body must evaluate to a boolean, got {type}",
|
||||
)
|
||||
if len(params) != 0:
|
||||
type = self._bool
|
||||
for spec in reversed(params):
|
||||
type = Function(
|
||||
pos_args=spec.pos,
|
||||
@@ -129,40 +163,100 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
kw_args=spec.kw,
|
||||
returns=type,
|
||||
)
|
||||
self._predicate_params = {}
|
||||
self.types.define_predicate(
|
||||
stmt.name.lexeme,
|
||||
Predicate(
|
||||
type=type,
|
||||
body=stmt.body,
|
||||
alias=len(params) == 0,
|
||||
),
|
||||
)
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
||||
def _is_valid_predicate(self, body: Type) -> bool:
|
||||
match body:
|
||||
case Function(returns=returns):
|
||||
return self._is_valid_predicate(returns)
|
||||
case _ if self.types.is_subtype(body, self._bool):
|
||||
return True
|
||||
case _:
|
||||
return False
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
|
||||
self.assert_bool(expr.left)
|
||||
self.assert_bool(expr.right)
|
||||
return self._bool
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||
# TODO
|
||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
||||
return UnknownType()
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||
# TODO
|
||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
||||
return UnknownType()
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||
self.reporter.warning(expr.location, "CallExpr not yet supported")
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||
callee: Type = expr.callee.accept(self)
|
||||
if not isinstance(callee, Function):
|
||||
self.reporter.error(expr.location, f"Cannot call {callee}")
|
||||
return UnknownType()
|
||||
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
||||
n_args: int = len(args)
|
||||
n_params: int = len(callee.args)
|
||||
if n_args != n_params:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Wrong number of argument, expected {n_params}, got {n_args}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
||||
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
||||
valid: bool = True
|
||||
for arg, param in zip(args, callee.args):
|
||||
if not self.types.is_subtype(arg, param.type):
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
|
||||
)
|
||||
valid = False
|
||||
if not valid:
|
||||
return UnknownType()
|
||||
return callee.returns
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||
object: Type = expr.expr.accept(self)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
|
||||
if member is None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
return member
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
|
||||
return self.get_variable(expr.name.lexeme)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.types.get_type("bool")
|
||||
case int():
|
||||
return self.types.get_type("int")
|
||||
case float():
|
||||
return self.types.get_type("float")
|
||||
case str():
|
||||
return self.types.get_type("str")
|
||||
case _:
|
||||
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
|
||||
return self.get_variable("_")
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
name: str = type.name.lexeme
|
||||
|
||||
@@ -652,7 +652,7 @@ class PythonTyper(
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accomodate
|
||||
The function doesn't take the raw expression as a parameter to accommodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -7,6 +7,7 @@ from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
@@ -130,6 +131,9 @@ class TypesRegistry:
|
||||
return False
|
||||
return self.is_subtype(bound, type2)
|
||||
|
||||
case (ConstraintType(type=base1), _):
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
return False
|
||||
|
||||
# TODO: verify the logic in here
|
||||
|
||||
@@ -238,10 +238,63 @@ def unfold_type(type: Type) -> Type:
|
||||
return type
|
||||
|
||||
|
||||
def to_annotation(type: Type) -> str:
|
||||
def _args_annotation(func: Function) -> str:
|
||||
if len(func.kw_args) != 0:
|
||||
return "..."
|
||||
|
||||
args: str = ", ".join(
|
||||
to_annotation(arg.type) for arg in func.pos_args + func.args
|
||||
)
|
||||
return f"[{args}]"
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
return "Any"
|
||||
|
||||
case BaseType(name=name):
|
||||
return name
|
||||
|
||||
case AliasType(name=name):
|
||||
return name
|
||||
|
||||
case UnknownType():
|
||||
return "Any"
|
||||
|
||||
case UnitType():
|
||||
return "None"
|
||||
|
||||
case Function(returns=returns):
|
||||
params_annot: str = _args_annotation(type)
|
||||
return f"Callable[{params_annot}, {to_annotation(returns)}]"
|
||||
|
||||
case OverloadedFunction():
|
||||
return "Callable"
|
||||
|
||||
case ComplexType() | ExtensionType():
|
||||
raise NotImplementedError
|
||||
|
||||
case TypeVar(name=name):
|
||||
return name
|
||||
|
||||
case GenericType(name=name, params=params):
|
||||
return f"{name}[{', '.join(map(to_annotation, params))}]"
|
||||
|
||||
case AppliedType(name=name, args=args):
|
||||
return f"{name}[{', '.join(map(to_annotation, args))}]"
|
||||
|
||||
case ConstraintType():
|
||||
return str(type)
|
||||
|
||||
case _:
|
||||
assert_never(type)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Predicate:
|
||||
type: Type
|
||||
body: m.Expr
|
||||
alias: bool
|
||||
|
||||
|
||||
Type = (
|
||||
|
||||
@@ -3,7 +3,12 @@ from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import Function, Predicate, Type
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
Predicate,
|
||||
Type,
|
||||
to_annotation,
|
||||
)
|
||||
from midas.lexer.token import TokenType
|
||||
|
||||
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
|
||||
@@ -63,26 +68,55 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
)
|
||||
alias: str = self.make_alias(None)
|
||||
definition: ast.stmt = self.make_definition(
|
||||
alias, Predicate(type=func, body=expr)
|
||||
alias, Predicate(type=func, body=expr, alias=False)
|
||||
)
|
||||
self._definitions.append(definition)
|
||||
return ast.Name(id=alias)
|
||||
|
||||
def make_alias(self, name: Optional[str]) -> str:
|
||||
suffix: str = f"_{name}" if name is not None else ""
|
||||
alias: str = f"__midas_p{self._id}{suffix}__"
|
||||
suffix: str
|
||||
if name is None:
|
||||
suffix = f"p{self._id}"
|
||||
self._id += 1
|
||||
else:
|
||||
suffix = name
|
||||
alias: str = f"__midas_{suffix}__"
|
||||
return alias
|
||||
|
||||
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
|
||||
return self.make_func(name, body, predicate.type)
|
||||
body: ast.expr = predicate.body.accept(self)
|
||||
if predicate.alias:
|
||||
return ast.Assign(
|
||||
targets=[
|
||||
ast.Name(id=name),
|
||||
],
|
||||
value=body,
|
||||
)
|
||||
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||
|
||||
def make_args(self, func: Function) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args],
|
||||
args=[ast.arg(arg=arg.name) for arg in func.args],
|
||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args],
|
||||
posonlyargs=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
)
|
||||
for arg in func.pos_args
|
||||
],
|
||||
args=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
)
|
||||
for arg in func.args
|
||||
],
|
||||
kwonlyargs=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
)
|
||||
for arg in func.kw_args
|
||||
],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
)
|
||||
@@ -100,6 +134,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||
ast.Return(value=ast.Name(id=inner_name)),
|
||||
],
|
||||
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
@@ -108,11 +143,12 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
body=inner_body,
|
||||
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Expected function, got {type}")
|
||||
raise ValueError(f"Expected function, got {type!r}")
|
||||
|
||||
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||
if name not in self._aliases:
|
||||
|
||||
@@ -380,7 +380,7 @@ class MidasParser(Parser):
|
||||
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
|
||||
)
|
||||
return CallExpr(
|
||||
location=l_paren.location_to(r_paren),
|
||||
location=Location.span(callee.location, r_paren.get_location()),
|
||||
callee=callee,
|
||||
arguments=pos_args,
|
||||
keywords=kw_args,
|
||||
|
||||
Reference in New Issue
Block a user