From a26b9293becb1ac72980a53236f8ab0eb0e4a15d Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 9 Jun 2026 15:30:45 +0200 Subject: [PATCH] refactor(types): extract TypeParams also rename generic type params to type args (when calling a generic) --- gen/gen.py | 15 ++++++- gen/midas.py | 21 ++++++---- midas/ast/midas.py | 17 ++++---- midas/ast/python.py | 1 + midas/checker/midas.py | 4 +- midas/checker/registry.py | 26 ++++++------ midas/cli/highlighter.py | 4 +- midas/parser/midas.py | 40 ++++++++++--------- .../01_simple_types.midas.ref.json | 12 +++--- tests/serializer/midas.py | 9 ++--- 10 files changed, 85 insertions(+), 64 deletions(-) diff --git a/gen/gen.py b/gen/gen.py index e78c872..50c9c9d 100644 --- a/gen/gen.py +++ b/gen/gen.py @@ -30,6 +30,7 @@ from __future__ import annotations T = TypeVar("T") +{preamble} {sections} """ @@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile( re.MULTILINE | re.DOTALL, ) +PREAMBLE_REGEX = re.compile( + r"^###>\s*Preamble\s*?\n(?P.*?)\n###<$", + re.MULTILINE | re.DOTALL, +) + def snake_case(text: str) -> str: return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_") @@ -88,13 +94,14 @@ def make_banner(text: str) -> str: def make_section(full_name: str, base: str, param: str, body: str) -> str: + print(f" Generating {full_name}") visitor_methods: list[str] = [] classes: list[str] = [] definitions: list[str] = body.strip("\n").split("\n\n\n") for cls in definitions: cls = cls.strip("\n") name: str = re.match("class (.*?):", cls).group(1) # type: ignore - print(f"Processing {name}") + print(f" Processing {name}") visitor_methods.append(make_visitor_method(name, param)) classes.append(make_class(name, cls, base)) @@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str: def generate(definitions_path: Path, out_path: Path): + print(f"Processing generating {out_path} from {definitions_path}") root_dir: Path = Path(__file__).parent.parent rel_path: Path = definitions_path.relative_to(root_dir) src: str = definitions_path.read_text() @@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path): if m := IMPORTS_REGEX.search(src): imports = m.group("body").strip("\n") + preamble: str = "" + if m := PREAMBLE_REGEX.search(src): + preamble = m.group("body") + for section_m in SECTION_REGEX.finditer(src): full_name: str = section_m.group("name") base: str = section_m.group("base") @@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path): gen_path=Path(__file__).relative_to(root_dir), ), imports=imports, + preamble=preamble, sections="\n\n\n".join(sections), ) out_path.write_text(result) diff --git a/gen/midas.py b/gen/midas.py index e1c304d..cca6f39 100644 --- a/gen/midas.py +++ b/gen/midas.py @@ -12,18 +12,23 @@ from midas.lexer.token import Token ###< +###> Preamble +@dataclass(frozen=True, kw_only=True) +class TypeParam: + location: Location + name: Token + bound: Optional[Type] + + +###< + + ###> Stmt | Statements class TypeStmt: name: Token - params: list[Param] + params: list[TypeParam] type: Type - @dataclass(frozen=True, kw_only=True) - class Param: - location: Location - name: Token - bound: Optional[Type] - class PropertyStmt: name: Token @@ -103,7 +108,7 @@ class NamedType: class GenericType: type: Type - params: list[Type] + args: list[Type] class ConstraintType: diff --git a/midas/ast/midas.py b/midas/ast/midas.py index 335e5cf..4459b52 100644 --- a/midas/ast/midas.py +++ b/midas/ast/midas.py @@ -14,6 +14,13 @@ from midas.lexer.token import Token T = TypeVar("T") +@dataclass(frozen=True, kw_only=True) +class TypeParam: + location: Location + name: Token + bound: Optional[Type] + + ############## # Statements # ############## @@ -46,15 +53,9 @@ class Stmt(ABC): @dataclass(frozen=True) class TypeStmt(Stmt): name: Token - params: list[Param] + params: list[TypeParam] type: Type - @dataclass(frozen=True, kw_only=True) - class Param: - location: Location - name: Token - bound: Optional[Type] - def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_type_stmt(self) @@ -243,7 +244,7 @@ class NamedType(Type): @dataclass(frozen=True) class GenericType(Type): type: Type - params: list[Type] + args: list[Type] def accept(self, visitor: Type.Visitor[T]) -> T: return visitor.visit_generic_type(self) diff --git a/midas/ast/python.py b/midas/ast/python.py index 1aea8ed..350dbb1 100644 --- a/midas/ast/python.py +++ b/midas/ast/python.py @@ -14,6 +14,7 @@ from midas.ast.location import Location T = TypeVar("T") + #################### # Type annotations # #################### diff --git a/midas/checker/midas.py b/midas/checker/midas.py index c55123f..2cb4ab1 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -122,8 +122,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type def visit_generic_type(self, type: m.GenericType) -> Type: type_: Type = type.type.accept(self) - params: list[Type] = [param.accept(self) for param in type.params] - return self.types.apply_generic(type_, params) + args: list[Type] = [arg.accept(self) for arg in type.args] + return self.types.apply_generic(type_, args) def visit_constraint_type(self, type: m.ConstraintType) -> Type: type_: Type = type.type.accept(self) diff --git a/midas/checker/registry.py b/midas/checker/registry.py index 585e3af..455c565 100644 --- a/midas/checker/registry.py +++ b/midas/checker/registry.py @@ -250,34 +250,34 @@ class TypesRegistry: return True - def apply_generic(self, type: Type, params: list[Type]) -> Type: + def apply_generic(self, type: Type, args: list[Type]) -> Type: match type: case AliasType(name=name, type=base): - return AliasType(name=name, type=self.apply_generic(base, params)) + return AliasType(name=name, type=self.apply_generic(base, args)) - case GenericType(name=name, params=type_vars, body=body): - n_params: int = len(params) + case GenericType(name=name, args=type_vars, body=body): + n_args: int = len(args) n_type_vars: int = len(type_vars) - if n_params < n_type_vars: + if n_args < n_type_vars: raise ValueError( - f"Missing type parameters, expected {n_type_vars} but only {n_params} provided" + f"Missing type arguments, expected {n_type_vars} but only {n_args} provided" ) - if n_params > n_type_vars: + if n_args > n_type_vars: raise ValueError( - f"Too many type parameters, expected {n_type_vars} but {n_params} provided" + f"Too many type arguments, expected {n_type_vars} but {n_args} provided" ) substitutions: dict[str, Type] = {} - for param, type_var in zip(params, type_vars): + for arg, type_var in zip(args, type_vars): if type_var.bound is not None and not self.is_subtype( - param, type_var.bound + arg, type_var.bound ): raise ValueError( - f"Type parameter {param} is not a subtype of {type_var.bound}" + f"Type argument {arg} is not a subtype of {type_var.bound}" ) - substitutions[type_var.name] = param + substitutions[type_var.name] = arg return AppliedType( name=name, - args=params, + args=args, body=substitute_typevars(body, substitutions), ) diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py index 0d6a018..af0fb4d 100644 --- a/midas/cli/highlighter.py +++ b/midas/cli/highlighter.py @@ -288,8 +288,8 @@ class MidasHighlighter( def visit_generic_type(self, type: m.GenericType) -> None: self.wrap(type, "generic-type") type.type.accept(self) - for param in type.params: - param.accept(self) + for arg in type.args: + arg.accept(self) def visit_constraint_type(self, type: m.ConstraintType) -> None: self.wrap(type, "constraint-type") diff --git a/midas/parser/midas.py b/midas/parser/midas.py index 5d09b83..cd83b84 100644 --- a/midas/parser/midas.py +++ b/midas/parser/midas.py @@ -18,6 +18,7 @@ from midas.ast.midas import ( PropertyStmt, Stmt, Type, + TypeParam, TypeStmt, UnaryExpr, VariableExpr, @@ -108,9 +109,7 @@ class MidasParser(Parser): """ keyword: Token = self.previous() name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - params: list[TypeStmt.Param] = [] - if self.check(TokenType.LEFT_BRACKET): - params = self.type_stmt_params() + params: list[TypeParam] = self.type_params() self.consume(TokenType.EQUAL, "Expected '=' before type definition") @@ -123,16 +122,19 @@ class MidasParser(Parser): type=type, ) - def type_stmt_params(self) -> list[TypeStmt.Param]: - """Parse a generic template expression + def type_params(self) -> list[TypeParam]: + """Parse a list of type parameters - A template is written `[TypeExpr]` + Type parameters are a comma-separated list of type variables wrapped in brackets. + Each type variable is either a simple variable, or a bounded variable written `S <: T` Returns: - TemplateExpr: the parsed template expression + list[TypeParam]: the list of type parameters, if any, or an empty list """ - self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") - params: list[TypeStmt.Param] = [] + if not self.match(TokenType.LEFT_BRACKET): + return [] + + params: list[TypeParam] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable") bound: Optional[Type] = None @@ -140,7 +142,7 @@ class MidasParser(Parser): self.consume(TokenType.COLON, "Expected ':' after '<'") bound = self.type_expr() params.append( - TypeStmt.Param( + TypeParam( location=name.location_to(self.previous()), name=name, bound=bound, @@ -148,7 +150,7 @@ class MidasParser(Parser): ) if not self.match(TokenType.COMMA): break - self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters") return params def type_expr(self) -> Type: @@ -187,23 +189,23 @@ class MidasParser(Parser): def generic_type(self) -> Type: type: Type = self.named_type() if self.check(TokenType.LEFT_BRACKET): - params: list[Type] = self.type_params() + args: list[Type] = self.type_args() return GenericType( location=Location.span(type.location, self.previous().get_location()), type=type, - params=params, + args=args, ) return type - def type_params(self) -> list[Type]: - params: list[Type] = [] - self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters") + def type_args(self) -> list[Type]: + args: list[Type] = [] + self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments") while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): - params.append(self.type_expr()) + args.append(self.type_expr()) if not self.match(TokenType.COMMA): break - self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters") - return params + self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments") + return args def named_type(self) -> Type: name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") diff --git a/tests/cases/midas-parser/01_simple_types.midas.ref.json b/tests/cases/midas-parser/01_simple_types.midas.ref.json index 55b4813..1d94718 100644 --- a/tests/cases/midas-parser/01_simple_types.midas.ref.json +++ b/tests/cases/midas-parser/01_simple_types.midas.ref.json @@ -2385,7 +2385,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "GeoLocation" @@ -2416,7 +2416,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Latitude" @@ -2433,7 +2433,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Longitude" @@ -2464,7 +2464,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Latitude" @@ -2494,7 +2494,7 @@ "_type": "NamedType", "name": "Difference" }, - "params": [ + "args": [ { "_type": "NamedType", "name": "Longitude" @@ -2638,7 +2638,7 @@ "_type": "NamedType", "name": "Optional" }, - "params": [ + "args": [ { "_type": "ConstraintType", "type": { diff --git a/tests/serializer/midas.py b/tests/serializer/midas.py index 919dc66..947641e 100644 --- a/tests/serializer/midas.py +++ b/tests/serializer/midas.py @@ -17,6 +17,7 @@ from midas.ast.midas import ( PropertyStmt, Stmt, Type, + TypeParam, TypeStmt, UnaryExpr, VariableExpr, @@ -46,13 +47,11 @@ class MidasAstJsonSerializer( return { "_type": "TypeStmt", "name": stmt.name.lexeme, - "params": [ - self._serialize_type_stmt_template_param(param) for param in stmt.params - ], + "params": [self._serialize_type_param(param) for param in stmt.params], "type": stmt.type.accept(self), } - def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict: + def _serialize_type_param(self, param: TypeParam) -> dict: return { "name": param.name.lexeme, "bound": self._serialize_optional(param.bound), @@ -150,7 +149,7 @@ class MidasAstJsonSerializer( return { "_type": "GenericType", "type": type.type.accept(self), - "params": self._serialize_list(type.params), + "args": self._serialize_list(type.args), } def visit_constraint_type(self, type: ConstraintType) -> dict: