diff --git a/midas/checker/python.py b/midas/checker/python.py index 3d44975..f1cecb8 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -11,14 +11,11 @@ from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from midas.checker.types import ( - ComplexType, - ExtensionType, Function, Operation, Type, UnitType, UnknownType, - unfold_type, ) from midas.parser.python import PythonParser @@ -250,8 +247,7 @@ class PythonTyper( self._assign_var(location, target, value_type) case p.GetExpr(object=object, name=name): - object_type: Type = self.type_of(object) - self._assign_attr(location, object_type, name, value_type) + self._assign_attr(location, object, name, value_type) case _: if not isinstance(target, p.VariableExpr): @@ -277,43 +273,19 @@ class PythonTyper( ) def _assign_attr( - self, location: Location, object: Type, name: str, value_type: Type + self, location: Location, object: p.Expr, name: str, value_type: Type ): - # TODO: improve recursion to have better error messages - base_object: Type = unfold_type(object) - match base_object: - case ComplexType(members=members): - if name not in members: - self.reporter.error(location, f"Unknown member '{object}.{name}'") - return - - member_type: Type = members[name] - if not self.is_subtype(value_type, member_type): - self.reporter.error( - location, - f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}", - ) - return - - case ExtensionType(base=base, extension=ComplexType(members=members)): - if name in members: - member_type: Type = members[name] - if not self.is_subtype(value_type, member_type): - self.reporter.error( - location, - f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}", - ) - return - return self._assign_attr(location, base, name, value_type) - - case UnknownType(): - pass - - case _: - self.reporter.error( - location, - f"Cannot assign {value_type} to unknown property '{object}.{name}'", - ) + object_type: Type = self.type_of(object) + member: Optional[Type] = self.types.lookup_member(object_type, name) + if member is None: + self.reporter.error(location, f"Unknown member '{name}' of {object_type}") + return + self.logger.debug(f"Member '{name}' of {object_type} has type {member}") + if not self.is_subtype(value_type, member): + self.reporter.error( + location, + f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}", + ) def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType() @@ -433,39 +405,15 @@ class PythonTyper( def visit_get_expr(self, expr: p.GetExpr) -> Type: object: Type = self.type_of(expr.object) - member: Optional[Type] = self._get_member(object, expr.name) + member: Optional[Type] = self.types.lookup_member(object, expr.name) if member is None: self.reporter.error( - expr.location, f"Unknown property '{expr.name}' on {object}" + expr.location, f"Unknown member '{expr.name}' of {object}" ) return UnknownType() - self.logger.debug(f"Property '{expr.name}' on {object} has type {member}") + self.logger.debug(f"Member '{expr.name}' of {object} has type {member}") return member - def _get_member(self, object: Type, name: str) -> Optional[Type]: - base_object: Type = unfold_type(object) - match base_object: - case ComplexType(members=members): - if name in members: - return members[name] - self.logger.debug(f"No property '{name}' in {base_object}") - return None - - case ExtensionType(base=base, extension=ComplexType(members=members)): - if name in members: - return members[name] - self.logger.debug( - f"No property '{name}' on {base_object}, looking up in base" - ) - return self._get_member(base, name) - - case UnknownType(): - return UnknownType() - - case _: - self.logger.debug(f"Can't get property on {base_object}") - return UnknownType() - def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: match expr.value: case bool(): # Must be before int diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 16c36a5..db2f972 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -7,11 +7,13 @@ from midas.checker.types import ( AppliedType, BaseType, ComplexType, + ExtensionType, Function, GenericType, Operation, OverloadedFunction, Type, + UnknownType, substitute_typevars, ) @@ -337,3 +339,51 @@ class TypesRegistry: reduced = True break return [types[i] for i in keep] + + def lookup_member(self, type: Type, member_name: str) -> Optional[Type]: + match type: + case AliasType(name=name, type=base): + if name in self._members: + if member_name in self._members[name]: + return self._members[name][member_name] + return self.lookup_member(base, member_name) + + case AppliedType(name=name, body=body, args=args): + generic: Type = self.get_type(name) + + if not isinstance(generic, GenericType): + raise ValueError("AppliedType not derived from a GenericType") + + substitutions = { + type_var.name: arg for arg, type_var in zip(args, generic.params) + } + if name in self._members: + if member_name in self._members[name]: + member_type: Type = self._members[name][member_name] + return substitute_typevars(member_type, substitutions) + + member_type2: Optional[Type] = self.lookup_member(body, member_name) + if member_type2 is not None: + member_type2 = substitute_typevars(member_type2, substitutions) + return member_type2 + + case ComplexType(members=members): + if member_name in members: + return members[member_name] + self.logger.debug(f"No member '{member_name}' in {type}") + return None + + case ExtensionType(base=base, extension=ComplexType(members=members)): + if member_name in members: + return members[member_name] + self.logger.debug( + f"No member '{member_name}' on {type}, looking up in base" + ) + return self.lookup_member(base, member_name) + + case UnknownType(): + return UnknownType() + + case _: + self.logger.debug(f"Can't get member on {type}") + return None