feat(checker): type check subscripts
This commit is contained in:
@@ -28,3 +28,8 @@ bar: list[list[Meter]]
|
|||||||
bar.append([p2.x])
|
bar.append([p2.x])
|
||||||
|
|
||||||
foo2 = foo + foo
|
foo2 = foo + foo
|
||||||
|
|
||||||
|
a = foo[0]
|
||||||
|
b = bar[0][1]
|
||||||
|
c = bar[0][1][2] # invalid, not method __getitem__ on Meter
|
||||||
|
c = bar[""] # invalid, wrong index type
|
||||||
|
|||||||
@@ -356,7 +356,7 @@ class PythonTyper(
|
|||||||
|
|
||||||
match operation:
|
match operation:
|
||||||
case Function() as function:
|
case Function() as function:
|
||||||
if not self._is_binary_function(function):
|
if not self._check_arity(function, 1, 0, 0):
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
location,
|
location,
|
||||||
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
|
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
|
||||||
@@ -395,7 +395,7 @@ class PythonTyper(
|
|||||||
|
|
||||||
match operation:
|
match operation:
|
||||||
case Function() as function:
|
case Function() as function:
|
||||||
if not self._is_unary_function(function):
|
if not self._check_arity(function, 0, 0, 0):
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}",
|
f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}",
|
||||||
@@ -512,6 +512,41 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
return self.types.apply_generic(list_type, [UnknownType()])
|
return self.types.apply_generic(list_type, [UnknownType()])
|
||||||
|
|
||||||
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
||||||
|
object: Type = self.type_of(expr.object)
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined method __getitem__ on {object}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
index: Type = self.type_of(expr.index)
|
||||||
|
|
||||||
|
match operation:
|
||||||
|
case Function() as function:
|
||||||
|
if not self._check_arity(function, 1, 0, 0):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Wrong definition of __getitem__. Expected function with 1 positional-only parameters, got {function}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
index_arg: Function.Argument = function.pos_args[0]
|
||||||
|
if not self.is_subtype(index, index_arg.type):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Wrong index type, expected {index_arg.type}, got {index}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return function.returns
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operation {operation}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||||
base: Type
|
base: Type
|
||||||
try:
|
try:
|
||||||
@@ -654,20 +689,17 @@ class PythonTyper(
|
|||||||
|
|
||||||
return mapped
|
return mapped
|
||||||
|
|
||||||
def _is_binary_function(self, function: Function) -> bool:
|
def _check_arity(
|
||||||
if len(function.pos_args) != 1:
|
self,
|
||||||
|
function: Function,
|
||||||
|
n_pos: Optional[int] = None,
|
||||||
|
n_mixed: Optional[int] = None,
|
||||||
|
n_keyword: Optional[int] = None,
|
||||||
|
) -> bool:
|
||||||
|
if n_pos is not None and len(function.pos_args) != n_pos:
|
||||||
return False
|
return False
|
||||||
if len(function.args) != 0:
|
if n_mixed is not None and len(function.args) != n_mixed:
|
||||||
return False
|
return False
|
||||||
if len(function.kw_args) != 0:
|
if n_keyword is not None and len(function.kw_args) != n_keyword:
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _is_unary_function(self, function: Function) -> bool:
|
|
||||||
if len(function.pos_args) != 0:
|
|
||||||
return False
|
|
||||||
if len(function.args) != 0:
|
|
||||||
return False
|
|
||||||
if len(function.kw_args) != 0:
|
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -196,3 +196,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||||
for item in expr.items:
|
for item in expr.items:
|
||||||
self.resolve(item)
|
self.resolve(item)
|
||||||
|
|
||||||
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||||
|
self.resolve(expr.object)
|
||||||
|
self.resolve(expr.index)
|
||||||
|
|||||||
Reference in New Issue
Block a user