feat(types): add type params to extend statement
This commit is contained in:
@@ -36,6 +36,7 @@ class PropertyStmt:
|
|||||||
|
|
||||||
|
|
||||||
class ExtendStmt:
|
class ExtendStmt:
|
||||||
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
operations: list[OpStmt]
|
operations: list[OpStmt]
|
||||||
|
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ class PropertyStmt(Stmt):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ExtendStmt(Stmt):
|
class ExtendStmt(Stmt):
|
||||||
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
operations: list[OpStmt]
|
operations: list[OpStmt]
|
||||||
|
|
||||||
|
|||||||
@@ -100,12 +100,12 @@ class MidasAstPrinter(
|
|||||||
self._idx = i
|
self._idx = i
|
||||||
if i == len(stmt.params) - 1:
|
if i == len(stmt.params) - 1:
|
||||||
self._mark_last()
|
self._mark_last()
|
||||||
self._print_type_stmt_param(param)
|
self._print_type_param(param)
|
||||||
self._write_line("type", last=True)
|
self._write_line("type", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
|
|
||||||
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
|
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||||
self._write_line("Param")
|
self._write_line("Param")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f'name: "{param.name.lexeme}"')
|
self._write_line(f'name: "{param.name.lexeme}"')
|
||||||
@@ -122,6 +122,13 @@ class MidasAstPrinter(
|
|||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
self._write_line("ExtendStmt")
|
self._write_line("ExtendStmt")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
|
self._write_line("params")
|
||||||
|
with self._child_level():
|
||||||
|
for i, param in enumerate(stmt.params):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.params) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_type_param(param)
|
||||||
self._write_line("type")
|
self._write_line("type")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
@@ -234,11 +241,11 @@ class MidasAstPrinter(
|
|||||||
self._write_line("type")
|
self._write_line("type")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
type.type.accept(self)
|
type.type.accept(self)
|
||||||
self._write_line("params", last=True)
|
self._write_line("args", last=True)
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
for i, param in enumerate(type.params):
|
for i, param in enumerate(type.args):
|
||||||
self._idx = i
|
self._idx = i
|
||||||
if i == len(type.params) - 1:
|
if i == len(type.args) - 1:
|
||||||
self._mark_last()
|
self._mark_last()
|
||||||
param.accept(self)
|
param.accept(self)
|
||||||
|
|
||||||
@@ -279,14 +286,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||||
template: str = ""
|
template: str = ""
|
||||||
if len(stmt.params) != 0:
|
if len(stmt.params) != 0:
|
||||||
params: list[str] = [
|
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||||
self._print_type_template_param(param) for param in stmt.params
|
|
||||||
]
|
|
||||||
template = f"[{', '.join(params)}]"
|
template = f"[{', '.join(params)}]"
|
||||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||||
return self.indented(res)
|
return self.indented(res)
|
||||||
|
|
||||||
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
|
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||||
res: str = param.name.lexeme
|
res: str = param.name.lexeme
|
||||||
if param.bound is not None:
|
if param.bound is not None:
|
||||||
res += "<:" + param.bound.accept(self)
|
res += "<:" + param.bound.accept(self)
|
||||||
@@ -358,9 +363,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
|
|
||||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||||
res: str = type.type.accept(self)
|
res: str = type.type.accept(self)
|
||||||
if len(type.params) != 0:
|
if len(type.args) != 0:
|
||||||
params: list[str] = [param.accept(self) for param in type.params]
|
args: list[str] = [param.accept(self) for param in type.args]
|
||||||
res += f"[{', '.join(params)}]"
|
res += f"[{', '.join(args)}]"
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||||
|
|||||||
@@ -64,15 +64,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
stmt.accept(self)
|
stmt.accept(self)
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
params: list[TypeVar] = []
|
params: list[TypeVar] = self._resolve_type_params(stmt.params)
|
||||||
for param in stmt.params:
|
|
||||||
name: str = param.name.lexeme
|
|
||||||
bound: Optional[Type] = None
|
|
||||||
if param.bound is not None:
|
|
||||||
bound = param.bound.accept(self)
|
|
||||||
var = TypeVar(name=name, bound=bound)
|
|
||||||
self._local_variables[name] = var
|
|
||||||
params.append(var)
|
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
type: Type = stmt.type.accept(self)
|
type: Type = stmt.type.accept(self)
|
||||||
if len(params) != 0:
|
if len(params) != 0:
|
||||||
@@ -85,6 +78,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
|
self._resolve_type_params(stmt.params)
|
||||||
base: Type = stmt.type.accept(self)
|
base: Type = stmt.type.accept(self)
|
||||||
for op in stmt.operations:
|
for op in stmt.operations:
|
||||||
right: Type = op.operand.accept(self)
|
right: Type = op.operand.accept(self)
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ class TypesRegistry:
|
|||||||
case AliasType(name=name, type=base):
|
case AliasType(name=name, type=base):
|
||||||
return AliasType(name=name, type=self.apply_generic(base, args))
|
return AliasType(name=name, type=self.apply_generic(base, args))
|
||||||
|
|
||||||
case GenericType(name=name, args=type_vars, body=body):
|
case GenericType(name=name, params=type_vars, body=body):
|
||||||
n_args: int = len(args)
|
n_args: int = len(args)
|
||||||
n_type_vars: int = len(type_vars)
|
n_type_vars: int = len(type_vars)
|
||||||
if n_args < n_type_vars:
|
if n_args < n_type_vars:
|
||||||
|
|||||||
@@ -383,12 +383,14 @@ class MidasParser(Parser):
|
|||||||
def extend_declaration(self) -> ExtendStmt:
|
def extend_declaration(self) -> ExtendStmt:
|
||||||
"""Parse an extension definition
|
"""Parse an extension definition
|
||||||
|
|
||||||
An extension is written `extend Type { operations }`
|
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ExtendStmt: the parsed extension statement
|
ExtendStmt: the parsed extension statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
|
params: list[TypeParam] = self.type_params()
|
||||||
|
|
||||||
type: Type = self.type_expr()
|
type: Type = self.type_expr()
|
||||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||||
operations: list[OpStmt] = []
|
operations: list[OpStmt] = []
|
||||||
@@ -396,7 +398,12 @@ class MidasParser(Parser):
|
|||||||
operations.append(self.op_declaration())
|
operations.append(self.op_declaration())
|
||||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||||
location: Location = keyword.location_to(self.previous())
|
location: Location = keyword.location_to(self.previous())
|
||||||
return ExtendStmt(location=location, type=type, operations=operations)
|
return ExtendStmt(
|
||||||
|
location=location,
|
||||||
|
params=params,
|
||||||
|
type=type,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
def op_declaration(self) -> OpStmt:
|
def op_declaration(self) -> OpStmt:
|
||||||
"""Parse an operation definition
|
"""Parse an operation definition
|
||||||
|
|||||||
Reference in New Issue
Block a user