diff --git a/gen/midas.py b/gen/midas.py index cca6f39..2184f86 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -36,6 +36,7 @@ class PropertyStmt: class ExtendStmt: + params: list[TypeParam] type: Type operations: list[OpStmt] diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 4459b52..d759079 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -71,6 +71,7 @@ class PropertyStmt(Stmt): @dataclass(frozen=True) class ExtendStmt(Stmt): + params: list[TypeParam] type: Type operations: list[OpStmt] diff --git a/midas/ast/printer.py b/midas/ast/printer.py index dc2e64c..82bd0b4 100644 --- a/midas/ast/printer.py +++ b/midas/ast/printer.py @@ -100,12 +100,12 @@ class MidasAstPrinter( self._idx = i if i == len(stmt.params) - 1: self._mark_last() - self._print_type_stmt_param(param) + self._print_type_param(param) self._write_line("type", last=True) with self._child_level(single=True): stmt.type.accept(self) - def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None: + def _print_type_param(self, param: m.TypeParam) -> None: self._write_line("Param") with self._child_level(): self._write_line(f'name: "{param.name.lexeme}"') @@ -122,6 +122,13 @@ class MidasAstPrinter( def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._write_line("ExtendStmt") with self._child_level(): + self._write_line("params") + with self._child_level(): + for i, param in enumerate(stmt.params): + self._idx = i + if i == len(stmt.params) - 1: + self._mark_last() + self._print_type_param(param) self._write_line("type") with self._child_level(single=True): stmt.type.accept(self) @@ -234,11 +241,11 @@ class MidasAstPrinter( self._write_line("type") with self._child_level(): type.type.accept(self) - self._write_line("params", last=True) + self._write_line("args", last=True) with self._child_level(): - for i, param in enumerate(type.params): + for i, param in enumerate(type.args): self._idx = i - if i == len(type.params) - 1: + if i == len(type.args) - 1: self._mark_last() param.accept(self) @@ -279,14 +286,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_type_stmt(self, stmt: m.TypeStmt) -> str: template: str = "" if len(stmt.params) != 0: - params: list[str] = [ - self._print_type_template_param(param) for param in stmt.params - ] + params: list[str] = [self._print_type_param(param) for param in stmt.params] template = f"[{', '.join(params)}]" res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}" return self.indented(res) - def _print_type_template_param(self, param: m.TypeStmt.Param) -> str: + def _print_type_param(self, param: m.TypeParam) -> str: res: str = param.name.lexeme if param.bound is not None: res += "<:" + param.bound.accept(self) @@ -358,9 +363,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str] def visit_generic_type(self, type: m.GenericType) -> str: res: str = type.type.accept(self) - if len(type.params) != 0: - params: list[str] = [param.accept(self) for param in type.params] - res += f"[{', '.join(params)}]" + if len(type.args) != 0: + args: list[str] = [param.accept(self) for param in type.args] + res += f"[{', '.join(args)}]" return res def visit_constraint_type(self, type: m.ConstraintType) -> str: diff --git a/midas/checker/midas.py b/midas/checker/midas.py index 2cb4ab1..a6d86a9 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -64,15 +64,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type stmt.accept(self) def visit_type_stmt(self, stmt: m.TypeStmt) -> None: - params: list[TypeVar] = [] - for param in stmt.params: - name: str = param.name.lexeme - bound: Optional[Type] = None - if param.bound is not None: - bound = param.bound.accept(self) - var = TypeVar(name=name, bound=bound) - self._local_variables[name] = var - params.append(var) + params: list[TypeVar] = self._resolve_type_params(stmt.params) + name: str = stmt.name.lexeme type: Type = stmt.type.accept(self) if len(params) != 0: @@ -85,6 +78,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: + self._resolve_type_params(stmt.params) base: Type = stmt.type.accept(self) for op in stmt.operations: right: Type = op.operand.accept(self) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 455c565..d5c432a 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -255,7 +255,7 @@ class TypesRegistry: case AliasType(name=name, type=base): return AliasType(name=name, type=self.apply_generic(base, args)) - case GenericType(name=name, args=type_vars, body=body): + case GenericType(name=name, params=type_vars, body=body): n_args: int = len(args) n_type_vars: int = len(type_vars) if n_args < n_type_vars: diff --git a/midas/parser/midas.py b/midas/parser/midas.py index cd83b84..35c7a97 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -383,12 +383,14 @@ class MidasParser(Parser): def extend_declaration(self) -> ExtendStmt: """Parse an extension definition - An extension is written `extend Type { operations }` + An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }` Returns: ExtendStmt: the parsed extension statement """ keyword: Token = self.previous() + params: list[TypeParam] = self.type_params() + type: Type = self.type_expr() self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") operations: list[OpStmt] = [] @@ -396,7 +398,12 @@ class MidasParser(Parser): operations.append(self.op_declaration()) self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") location: Location = keyword.location_to(self.previous()) - return ExtendStmt(location=location, type=type, operations=operations) + return ExtendStmt( + location=location, + params=params, + type=type, + operations=operations, + ) def op_declaration(self) -> OpStmt: """Parse an operation definition