feat(checker): map and check function call arguments
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -19,6 +20,13 @@ class ReturnException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument:
|
||||||
|
expr: p.Expr
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
class Checker(
|
class Checker(
|
||||||
p.Stmt.Visitor[None],
|
p.Stmt.Visitor[None],
|
||||||
p.Expr.Visitor[Type],
|
p.Expr.Visitor[Type],
|
||||||
@@ -126,15 +134,18 @@ class Checker(
|
|||||||
kw_args: list[Function.Argument] = []
|
kw_args: list[Function.Argument] = []
|
||||||
|
|
||||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||||
if arg.type is None:
|
if arg.type is not None:
|
||||||
return UnknownType()
|
return arg.type.accept(self)
|
||||||
return arg.type.accept(self)
|
if arg.default is not None:
|
||||||
|
return arg.default.accept(self)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
for arg in stmt.posonlyargs:
|
for arg in stmt.posonlyargs:
|
||||||
pos_args.append(
|
pos_args.append(
|
||||||
Function.Argument(
|
Function.Argument(
|
||||||
name=arg.name,
|
name=arg.name,
|
||||||
type=eval_arg_type(arg),
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for arg in stmt.args:
|
for arg in stmt.args:
|
||||||
@@ -142,6 +153,7 @@ class Checker(
|
|||||||
Function.Argument(
|
Function.Argument(
|
||||||
name=arg.name,
|
name=arg.name,
|
||||||
type=eval_arg_type(arg),
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for arg in stmt.kwonlyargs:
|
for arg in stmt.kwonlyargs:
|
||||||
@@ -149,6 +161,7 @@ class Checker(
|
|||||||
Function.Argument(
|
Function.Argument(
|
||||||
name=arg.name,
|
name=arg.name,
|
||||||
type=eval_arg_type(arg),
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -175,7 +188,9 @@ class Checker(
|
|||||||
else:
|
else:
|
||||||
returns = inferred_return
|
returns = inferred_return
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
function: Function = Function(
|
function: Function = Function(
|
||||||
|
name=stmt.name,
|
||||||
pos_args=pos_args,
|
pos_args=pos_args,
|
||||||
args=args,
|
args=args,
|
||||||
kw_args=kw_args,
|
kw_args=kw_args,
|
||||||
@@ -240,11 +255,18 @@ class Checker(
|
|||||||
self.import_midas(path)
|
self.import_midas(path)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
callee: Type = self.evaluate(expr.callee)
|
callee: Type = self.evaluate(expr.callee)
|
||||||
arguments: list[Type] = [self.evaluate(arg) for arg in expr.arguments]
|
if not isinstance(callee, Function):
|
||||||
keywords: dict[str, Type] = {
|
self.error(expr.callee.location, "Callee is not a function")
|
||||||
name: self.evaluate(arg) for name, arg in expr.keywords.items()
|
return UnknownType()
|
||||||
}
|
function: Function = callee
|
||||||
return UnknownType()
|
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||||
|
for arg in mapped:
|
||||||
|
if arg.type != arg.argument.type:
|
||||||
|
self.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
return function.returns
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
||||||
|
|
||||||
@@ -277,3 +299,87 @@ class Checker(
|
|||||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||||
|
|
||||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self, function: Function, call: p.CallExpr
|
||||||
|
) -> list[MappedArgument]:
|
||||||
|
positional: list[tuple[p.Expr, Type]] = [
|
||||||
|
(arg, self.evaluate(arg)) for arg in call.arguments
|
||||||
|
]
|
||||||
|
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||||
|
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items()
|
||||||
|
}
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: set[str] = {
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
}
|
||||||
|
required_keyword: set[str] = {
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
}
|
||||||
|
|
||||||
|
mapped: list[MappedArgument] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
self.error(arg[0].location, "Too many positional arguments")
|
||||||
|
break
|
||||||
|
required_positional.discard(param.name)
|
||||||
|
required_keyword.discard(param.name)
|
||||||
|
set_args.add(param.name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if name in set_args:
|
||||||
|
self.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
required_positional.discard(name)
|
||||||
|
required_keyword.discard(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
self.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required positional arguments: {required_positional}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
self.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required keyword arguments: {required_keyword}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return mapped
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class UnitType:
|
|||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class Function:
|
class Function:
|
||||||
|
name: str
|
||||||
pos_args: list[Argument]
|
pos_args: list[Argument]
|
||||||
args: list[Argument]
|
args: list[Argument]
|
||||||
kw_args: list[Argument]
|
kw_args: list[Argument]
|
||||||
@@ -35,6 +36,7 @@ class Function:
|
|||||||
class Argument:
|
class Argument:
|
||||||
name: str
|
name: str
|
||||||
type: Type
|
type: Type
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
|
||||||
Type = BaseType | SimpleType | UnknownType | UnitType | Function
|
Type = BaseType | SimpleType | UnknownType | UnitType | Function
|
||||||
|
|||||||
Reference in New Issue
Block a user