Compare commits
7 Commits
35ec0d0db8
...
d0f1178c17
| Author | SHA1 | Date | |
|---|---|---|---|
|
d0f1178c17
|
|||
|
0eca23b894
|
|||
|
f664fb4a4f
|
|||
|
32330243c6
|
|||
|
96e76065cf
|
|||
|
7b7d87e59a
|
|||
|
1eb90164e6
|
@@ -31,7 +31,7 @@ class TypedParamSpec:
|
|||||||
kw: list[Function.Argument]
|
kw: list[Function.Argument]
|
||||||
|
|
||||||
|
|
||||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
@@ -41,6 +41,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self._local_variables: dict[str, TypeVar] = {}
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
|
self._predicate_params: dict[str, Type] = {}
|
||||||
|
|
||||||
self._current_name: Optional[str] = None
|
self._current_name: Optional[str] = None
|
||||||
|
|
||||||
define_builtins(self.types)
|
define_builtins(self.types)
|
||||||
@@ -59,6 +61,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self.reporter.error(error.token.get_location(), error.message)
|
self.reporter.error(error.token.get_location(), error.message)
|
||||||
self.resolve(stmts)
|
self.resolve(stmts)
|
||||||
|
|
||||||
|
def type_of(self, expr: m.Expr) -> Type:
|
||||||
|
type: Type = expr.accept(self)
|
||||||
|
return type
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
|
|
||||||
@@ -75,6 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
return self._local_variables[name]
|
return self._local_variables[name]
|
||||||
return self.types.get_type(name)
|
return self.types.get_type(name)
|
||||||
|
|
||||||
|
def get_variable(self, name: str) -> Type:
|
||||||
|
if name in self._predicate_params:
|
||||||
|
return self._predicate_params[name]
|
||||||
|
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||||
|
if predicate is None:
|
||||||
|
raise NameError(f"Unknown variable '{name}'")
|
||||||
|
return predicate.type
|
||||||
|
|
||||||
def resolve(self, stmts: list[m.Stmt]):
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
"""Process a sequence of statements
|
"""Process a sequence of statements
|
||||||
|
|
||||||
@@ -84,6 +98,12 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
for stmt in stmts:
|
for stmt in stmts:
|
||||||
stmt.accept(self)
|
stmt.accept(self)
|
||||||
|
|
||||||
|
def assert_bool(self, expr: m.Expr):
|
||||||
|
type: Type = self.type_of(expr)
|
||||||
|
if not self.types.is_subtype(type, self._bool):
|
||||||
|
# TODO: change back to error when operations are type checked
|
||||||
|
self.reporter.warning(expr.location, f"Must be a boolean but is {type}")
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
self._current_name = name
|
self._current_name = name
|
||||||
@@ -118,10 +138,24 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
|
for spec in stmt.params:
|
||||||
|
for param in spec.mixed:
|
||||||
|
assert param.name is not None
|
||||||
|
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||||
|
|
||||||
|
type: Type = self.type_of(stmt.body)
|
||||||
params: list[TypedParamSpec] = [
|
params: list[TypedParamSpec] = [
|
||||||
self._visit_param_spec(spec) for spec in stmt.params
|
self._visit_param_spec(spec) for spec in stmt.params
|
||||||
]
|
]
|
||||||
type: Type = self._bool
|
|
||||||
|
if not self._is_valid_predicate(type):
|
||||||
|
# TODO: change back to error when operations are type checked
|
||||||
|
self.reporter.warning(
|
||||||
|
stmt.body.location,
|
||||||
|
f"Predicate function body must evaluate to a boolean, got {type}",
|
||||||
|
)
|
||||||
|
if len(params) != 0:
|
||||||
|
type = self._bool
|
||||||
for spec in reversed(params):
|
for spec in reversed(params):
|
||||||
type = Function(
|
type = Function(
|
||||||
pos_args=spec.pos,
|
pos_args=spec.pos,
|
||||||
@@ -129,40 +163,100 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
kw_args=spec.kw,
|
kw_args=spec.kw,
|
||||||
returns=type,
|
returns=type,
|
||||||
)
|
)
|
||||||
|
self._predicate_params = {}
|
||||||
self.types.define_predicate(
|
self.types.define_predicate(
|
||||||
stmt.name.lexeme,
|
stmt.name.lexeme,
|
||||||
Predicate(
|
Predicate(
|
||||||
type=type,
|
type=type,
|
||||||
body=stmt.body,
|
body=stmt.body,
|
||||||
|
alias=len(params) == 0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
def _is_valid_predicate(self, body: Type) -> bool:
|
||||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
match body:
|
||||||
|
case Function(returns=returns):
|
||||||
|
return self._is_valid_predicate(returns)
|
||||||
|
case _ if self.types.is_subtype(body, self._bool):
|
||||||
|
return True
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
|
||||||
|
self.assert_bool(expr.left)
|
||||||
|
self.assert_bool(expr.right)
|
||||||
|
return self._bool
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||||
|
# TODO
|
||||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||||
|
# TODO
|
||||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "CallExpr not yet supported")
|
callee: Type = expr.callee.accept(self)
|
||||||
|
if not isinstance(callee, Function):
|
||||||
|
self.reporter.error(expr.location, f"Cannot call {callee}")
|
||||||
|
return UnknownType()
|
||||||
|
args: list[Type] = [arg.accept(self) for arg in expr.arguments]
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
n_args: int = len(args)
|
||||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
n_params: int = len(callee.args)
|
||||||
|
if n_args != n_params:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Wrong number of argument, expected {n_params}, got {n_args}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
valid: bool = True
|
||||||
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
for arg, param in zip(args, callee.args):
|
||||||
|
if not self.types.is_subtype(arg, param.type):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Invalid argument type at pos {param.pos}, expected {param.type}, got {arg}",
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
if not valid:
|
||||||
|
return UnknownType()
|
||||||
|
return callee.returns
|
||||||
|
|
||||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||||
|
object: Type = expr.expr.accept(self)
|
||||||
|
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
|
||||||
|
if member is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return member
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
|
||||||
|
return self.get_variable(expr.name.lexeme)
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||||
return expr.expr.accept(self)
|
return expr.expr.accept(self)
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
match expr.value:
|
||||||
|
case bool(): # Must be before int
|
||||||
|
return self.types.get_type("bool")
|
||||||
|
case int():
|
||||||
|
return self.types.get_type("int")
|
||||||
|
case float():
|
||||||
|
return self.types.get_type("float")
|
||||||
|
case str():
|
||||||
|
return self.types.get_type("str")
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
return self.get_variable("_")
|
||||||
|
|
||||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||||
name: str = type.name.lexeme
|
name: str = type.name.lexeme
|
||||||
|
|||||||
@@ -652,7 +652,7 @@ class PythonTyper(
|
|||||||
If the function has overloads, the function will try to resolve the
|
If the function has overloads, the function will try to resolve the
|
||||||
appropriate signature.
|
appropriate signature.
|
||||||
Argument types are matched to the defined parameters.
|
Argument types are matched to the defined parameters.
|
||||||
The function doesn't take the raw expression as a parameter to accomodate
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
for desugared calls such as for operators.
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from midas.checker.types import (
|
|||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
@@ -130,6 +131,9 @@ class TypesRegistry:
|
|||||||
return False
|
return False
|
||||||
return self.is_subtype(bound, type2)
|
return self.is_subtype(bound, type2)
|
||||||
|
|
||||||
|
case (ConstraintType(type=base1), _):
|
||||||
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# TODO: verify the logic in here
|
# TODO: verify the logic in here
|
||||||
|
|||||||
@@ -238,10 +238,63 @@ def unfold_type(type: Type) -> Type:
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
|
|
||||||
|
def to_annotation(type: Type) -> str:
|
||||||
|
def _args_annotation(func: Function) -> str:
|
||||||
|
if len(func.kw_args) != 0:
|
||||||
|
return "..."
|
||||||
|
|
||||||
|
args: str = ", ".join(
|
||||||
|
to_annotation(arg.type) for arg in func.pos_args + func.args
|
||||||
|
)
|
||||||
|
return f"[{args}]"
|
||||||
|
|
||||||
|
match type:
|
||||||
|
case TopType():
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
case BaseType(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case AliasType(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
case UnitType():
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
case Function(returns=returns):
|
||||||
|
params_annot: str = _args_annotation(type)
|
||||||
|
return f"Callable[{params_annot}, {to_annotation(returns)}]"
|
||||||
|
|
||||||
|
case OverloadedFunction():
|
||||||
|
return "Callable"
|
||||||
|
|
||||||
|
case ComplexType() | ExtensionType():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
case TypeVar(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case GenericType(name=name, params=params):
|
||||||
|
return f"{name}[{', '.join(map(to_annotation, params))}]"
|
||||||
|
|
||||||
|
case AppliedType(name=name, args=args):
|
||||||
|
return f"{name}[{', '.join(map(to_annotation, args))}]"
|
||||||
|
|
||||||
|
case ConstraintType():
|
||||||
|
return str(type)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class Predicate:
|
class Predicate:
|
||||||
type: Type
|
type: Type
|
||||||
body: m.Expr
|
body: m.Expr
|
||||||
|
alias: bool
|
||||||
|
|
||||||
|
|
||||||
Type = (
|
Type = (
|
||||||
|
|||||||
@@ -3,7 +3,12 @@ from typing import Optional
|
|||||||
|
|
||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.types import Function, Predicate, Type
|
from midas.checker.types import (
|
||||||
|
Function,
|
||||||
|
Predicate,
|
||||||
|
Type,
|
||||||
|
to_annotation,
|
||||||
|
)
|
||||||
from midas.lexer.token import TokenType
|
from midas.lexer.token import TokenType
|
||||||
|
|
||||||
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
|
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
|
||||||
@@ -63,26 +68,55 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
|||||||
)
|
)
|
||||||
alias: str = self.make_alias(None)
|
alias: str = self.make_alias(None)
|
||||||
definition: ast.stmt = self.make_definition(
|
definition: ast.stmt = self.make_definition(
|
||||||
alias, Predicate(type=func, body=expr)
|
alias, Predicate(type=func, body=expr, alias=False)
|
||||||
)
|
)
|
||||||
self._definitions.append(definition)
|
self._definitions.append(definition)
|
||||||
return ast.Name(id=alias)
|
return ast.Name(id=alias)
|
||||||
|
|
||||||
def make_alias(self, name: Optional[str]) -> str:
|
def make_alias(self, name: Optional[str]) -> str:
|
||||||
suffix: str = f"_{name}" if name is not None else ""
|
suffix: str
|
||||||
alias: str = f"__midas_p{self._id}{suffix}__"
|
if name is None:
|
||||||
|
suffix = f"p{self._id}"
|
||||||
self._id += 1
|
self._id += 1
|
||||||
|
else:
|
||||||
|
suffix = name
|
||||||
|
alias: str = f"__midas_{suffix}__"
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||||
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
|
body: ast.expr = predicate.body.accept(self)
|
||||||
return self.make_func(name, body, predicate.type)
|
if predicate.alias:
|
||||||
|
return ast.Assign(
|
||||||
|
targets=[
|
||||||
|
ast.Name(id=name),
|
||||||
|
],
|
||||||
|
value=body,
|
||||||
|
)
|
||||||
|
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||||
|
|
||||||
def make_args(self, func: Function) -> ast.arguments:
|
def make_args(self, func: Function) -> ast.arguments:
|
||||||
return ast.arguments(
|
return ast.arguments(
|
||||||
posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args],
|
posonlyargs=[
|
||||||
args=[ast.arg(arg=arg.name) for arg in func.args],
|
ast.arg(
|
||||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args],
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.pos_args
|
||||||
|
],
|
||||||
|
args=[
|
||||||
|
ast.arg(
|
||||||
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.args
|
||||||
|
],
|
||||||
|
kwonlyargs=[
|
||||||
|
ast.arg(
|
||||||
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.kw_args
|
||||||
|
],
|
||||||
defaults=[],
|
defaults=[],
|
||||||
kw_defaults=[],
|
kw_defaults=[],
|
||||||
)
|
)
|
||||||
@@ -100,6 +134,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
|||||||
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||||
ast.Return(value=ast.Name(id=inner_name)),
|
ast.Return(value=ast.Name(id=inner_name)),
|
||||||
],
|
],
|
||||||
|
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||||
decorator_list=[],
|
decorator_list=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -108,11 +143,12 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
|||||||
name=name,
|
name=name,
|
||||||
args=self.make_args(type),
|
args=self.make_args(type),
|
||||||
body=inner_body,
|
body=inner_body,
|
||||||
|
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||||
decorator_list=[],
|
decorator_list=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Expected function, got {type}")
|
raise ValueError(f"Expected function, got {type!r}")
|
||||||
|
|
||||||
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||||
if name not in self._aliases:
|
if name not in self._aliases:
|
||||||
|
|||||||
@@ -380,7 +380,7 @@ class MidasParser(Parser):
|
|||||||
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
|
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
|
||||||
)
|
)
|
||||||
return CallExpr(
|
return CallExpr(
|
||||||
location=l_paren.location_to(r_paren),
|
location=Location.span(callee.location, r_paren.get_location()),
|
||||||
callee=callee,
|
callee=callee,
|
||||||
arguments=pos_args,
|
arguments=pos_args,
|
||||||
keywords=kw_args,
|
keywords=kw_args,
|
||||||
|
|||||||
Reference in New Issue
Block a user