refactor(types): extract TypeParams

also rename generic type params to type args (when calling a generic)
This commit is contained in:
2026-06-09 15:30:45 +02:00
parent efa5454776
commit a26b9293be
10 changed files with 85 additions and 64 deletions

View File

@@ -30,6 +30,7 @@ from __future__ import annotations
T = TypeVar("T") T = TypeVar("T")
{preamble}
{sections} {sections}
""" """
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
re.MULTILINE | re.DOTALL, re.MULTILINE | re.DOTALL,
) )
PREAMBLE_REGEX = re.compile(
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
def snake_case(text: str) -> str: def snake_case(text: str) -> str:
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_") return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
@@ -88,6 +94,7 @@ def make_banner(text: str) -> str:
def make_section(full_name: str, base: str, param: str, body: str) -> str: def make_section(full_name: str, base: str, param: str, body: str) -> str:
print(f" Generating {full_name}")
visitor_methods: list[str] = [] visitor_methods: list[str] = []
classes: list[str] = [] classes: list[str] = []
definitions: list[str] = body.strip("\n").split("\n\n\n") definitions: list[str] = body.strip("\n").split("\n\n\n")
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
def generate(definitions_path: Path, out_path: Path): def generate(definitions_path: Path, out_path: Path):
print(f"Processing generating {out_path} from {definitions_path}")
root_dir: Path = Path(__file__).parent.parent root_dir: Path = Path(__file__).parent.parent
rel_path: Path = definitions_path.relative_to(root_dir) rel_path: Path = definitions_path.relative_to(root_dir)
src: str = definitions_path.read_text() src: str = definitions_path.read_text()
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
if m := IMPORTS_REGEX.search(src): if m := IMPORTS_REGEX.search(src):
imports = m.group("body").strip("\n") imports = m.group("body").strip("\n")
preamble: str = ""
if m := PREAMBLE_REGEX.search(src):
preamble = m.group("body")
for section_m in SECTION_REGEX.finditer(src): for section_m in SECTION_REGEX.finditer(src):
full_name: str = section_m.group("name") full_name: str = section_m.group("name")
base: str = section_m.group("base") base: str = section_m.group("base")
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
gen_path=Path(__file__).relative_to(root_dir), gen_path=Path(__file__).relative_to(root_dir),
), ),
imports=imports, imports=imports,
preamble=preamble,
sections="\n\n\n".join(sections), sections="\n\n\n".join(sections),
) )
out_path.write_text(result) out_path.write_text(result)

View File

@@ -12,19 +12,24 @@ from midas.lexer.token import Token
###< ###<
###> Stmt | Statements ###> Preamble
class TypeStmt:
name: Token
params: list[Param]
type: Type
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Param: class TypeParam:
location: Location location: Location
name: Token name: Token
bound: Optional[Type] bound: Optional[Type]
###<
###> Stmt | Statements
class TypeStmt:
name: Token
params: list[TypeParam]
type: Type
class PropertyStmt: class PropertyStmt:
name: Token name: Token
type: Type type: Type
@@ -103,7 +108,7 @@ class NamedType:
class GenericType: class GenericType:
type: Type type: Type
params: list[Type] args: list[Type]
class ConstraintType: class ConstraintType:

View File

@@ -14,6 +14,13 @@ from midas.lexer.token import Token
T = TypeVar("T") T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
############## ##############
# Statements # # Statements #
############## ##############
@@ -46,15 +53,9 @@ class Stmt(ABC):
@dataclass(frozen=True) @dataclass(frozen=True)
class TypeStmt(Stmt): class TypeStmt(Stmt):
name: Token name: Token
params: list[Param] params: list[TypeParam]
type: Type type: Type
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
bound: Optional[Type]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_type_stmt(self) return visitor.visit_type_stmt(self)
@@ -243,7 +244,7 @@ class NamedType(Type):
@dataclass(frozen=True) @dataclass(frozen=True)
class GenericType(Type): class GenericType(Type):
type: Type type: Type
params: list[Type] args: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T: def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self) return visitor.visit_generic_type(self)

View File

@@ -14,6 +14,7 @@ from midas.ast.location import Location
T = TypeVar("T") T = TypeVar("T")
#################### ####################
# Type annotations # # Type annotations #
#################### ####################

View File

@@ -122,8 +122,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
def visit_generic_type(self, type: m.GenericType) -> Type: def visit_generic_type(self, type: m.GenericType) -> Type:
type_: Type = type.type.accept(self) type_: Type = type.type.accept(self)
params: list[Type] = [param.accept(self) for param in type.params] args: list[Type] = [arg.accept(self) for arg in type.args]
return self.types.apply_generic(type_, params) return self.types.apply_generic(type_, args)
def visit_constraint_type(self, type: m.ConstraintType) -> Type: def visit_constraint_type(self, type: m.ConstraintType) -> Type:
type_: Type = type.type.accept(self) type_: Type = type.type.accept(self)

View File

@@ -250,34 +250,34 @@ class TypesRegistry:
return True return True
def apply_generic(self, type: Type, params: list[Type]) -> Type: def apply_generic(self, type: Type, args: list[Type]) -> Type:
match type: match type:
case AliasType(name=name, type=base): case AliasType(name=name, type=base):
return AliasType(name=name, type=self.apply_generic(base, params)) return AliasType(name=name, type=self.apply_generic(base, args))
case GenericType(name=name, params=type_vars, body=body): case GenericType(name=name, args=type_vars, body=body):
n_params: int = len(params) n_args: int = len(args)
n_type_vars: int = len(type_vars) n_type_vars: int = len(type_vars)
if n_params < n_type_vars: if n_args < n_type_vars:
raise ValueError( raise ValueError(
f"Missing type parameters, expected {n_type_vars} but only {n_params} provided" f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
) )
if n_params > n_type_vars: if n_args > n_type_vars:
raise ValueError( raise ValueError(
f"Too many type parameters, expected {n_type_vars} but {n_params} provided" f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
) )
substitutions: dict[str, Type] = {} substitutions: dict[str, Type] = {}
for param, type_var in zip(params, type_vars): for arg, type_var in zip(args, type_vars):
if type_var.bound is not None and not self.is_subtype( if type_var.bound is not None and not self.is_subtype(
param, type_var.bound arg, type_var.bound
): ):
raise ValueError( raise ValueError(
f"Type parameter {param} is not a subtype of {type_var.bound}" f"Type argument {arg} is not a subtype of {type_var.bound}"
) )
substitutions[type_var.name] = param substitutions[type_var.name] = arg
return AppliedType( return AppliedType(
name=name, name=name,
args=params, args=args,
body=substitute_typevars(body, substitutions), body=substitute_typevars(body, substitutions),
) )

View File

@@ -288,8 +288,8 @@ class MidasHighlighter(
def visit_generic_type(self, type: m.GenericType) -> None: def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type") self.wrap(type, "generic-type")
type.type.accept(self) type.type.accept(self)
for param in type.params: for arg in type.args:
param.accept(self) arg.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None: def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type") self.wrap(type, "constraint-type")

View File

@@ -18,6 +18,7 @@ from midas.ast.midas import (
PropertyStmt, PropertyStmt,
Stmt, Stmt,
Type, Type,
TypeParam,
TypeStmt, TypeStmt,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
@@ -108,9 +109,7 @@ class MidasParser(Parser):
""" """
keyword: Token = self.previous() keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
params: list[TypeStmt.Param] = [] params: list[TypeParam] = self.type_params()
if self.check(TokenType.LEFT_BRACKET):
params = self.type_stmt_params()
self.consume(TokenType.EQUAL, "Expected '=' before type definition") self.consume(TokenType.EQUAL, "Expected '=' before type definition")
@@ -123,16 +122,19 @@ class MidasParser(Parser):
type=type, type=type,
) )
def type_stmt_params(self) -> list[TypeStmt.Param]: def type_params(self) -> list[TypeParam]:
"""Parse a generic template expression """Parse a list of type parameters
A template is written `[TypeExpr]` Type parameters are a comma-separated list of type variables wrapped in brackets.
Each type variable is either a simple variable, or a bounded variable written `S <: T`
Returns: Returns:
TemplateExpr: the parsed template expression list[TypeParam]: the list of type parameters, if any, or an empty list
""" """
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") if not self.match(TokenType.LEFT_BRACKET):
params: list[TypeStmt.Param] = [] return []
params: list[TypeParam] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
bound: Optional[Type] = None bound: Optional[Type] = None
@@ -140,7 +142,7 @@ class MidasParser(Parser):
self.consume(TokenType.COLON, "Expected ':' after '<'") self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr() bound = self.type_expr()
params.append( params.append(
TypeStmt.Param( TypeParam(
location=name.location_to(self.previous()), location=name.location_to(self.previous()),
name=name, name=name,
bound=bound, bound=bound,
@@ -148,7 +150,7 @@ class MidasParser(Parser):
) )
if not self.match(TokenType.COMMA): if not self.match(TokenType.COMMA):
break break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params return params
def type_expr(self) -> Type: def type_expr(self) -> Type:
@@ -187,23 +189,23 @@ class MidasParser(Parser):
def generic_type(self) -> Type: def generic_type(self) -> Type:
type: Type = self.named_type() type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
params: list[Type] = self.type_params() args: list[Type] = self.type_args()
return GenericType( return GenericType(
location=Location.span(type.location, self.previous().get_location()), location=Location.span(type.location, self.previous().get_location()),
type=type, type=type,
params=params, args=args,
) )
return type return type
def type_params(self) -> list[Type]: def type_args(self) -> list[Type]:
params: list[Type] = [] args: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters") self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
params.append(self.type_expr()) args.append(self.type_expr())
if not self.match(TokenType.COMMA): if not self.match(TokenType.COMMA):
break break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters") self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
return params return args
def named_type(self) -> Type: def named_type(self) -> Type:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")

View File

@@ -2385,7 +2385,7 @@
"_type": "NamedType", "_type": "NamedType",
"name": "Difference" "name": "Difference"
}, },
"params": [ "args": [
{ {
"_type": "NamedType", "_type": "NamedType",
"name": "GeoLocation" "name": "GeoLocation"
@@ -2416,7 +2416,7 @@
"_type": "NamedType", "_type": "NamedType",
"name": "Difference" "name": "Difference"
}, },
"params": [ "args": [
{ {
"_type": "NamedType", "_type": "NamedType",
"name": "Latitude" "name": "Latitude"
@@ -2433,7 +2433,7 @@
"_type": "NamedType", "_type": "NamedType",
"name": "Difference" "name": "Difference"
}, },
"params": [ "args": [
{ {
"_type": "NamedType", "_type": "NamedType",
"name": "Longitude" "name": "Longitude"
@@ -2464,7 +2464,7 @@
"_type": "NamedType", "_type": "NamedType",
"name": "Difference" "name": "Difference"
}, },
"params": [ "args": [
{ {
"_type": "NamedType", "_type": "NamedType",
"name": "Latitude" "name": "Latitude"
@@ -2494,7 +2494,7 @@
"_type": "NamedType", "_type": "NamedType",
"name": "Difference" "name": "Difference"
}, },
"params": [ "args": [
{ {
"_type": "NamedType", "_type": "NamedType",
"name": "Longitude" "name": "Longitude"
@@ -2638,7 +2638,7 @@
"_type": "NamedType", "_type": "NamedType",
"name": "Optional" "name": "Optional"
}, },
"params": [ "args": [
{ {
"_type": "ConstraintType", "_type": "ConstraintType",
"type": { "type": {

View File

@@ -17,6 +17,7 @@ from midas.ast.midas import (
PropertyStmt, PropertyStmt,
Stmt, Stmt,
Type, Type,
TypeParam,
TypeStmt, TypeStmt,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
@@ -46,13 +47,11 @@ class MidasAstJsonSerializer(
return { return {
"_type": "TypeStmt", "_type": "TypeStmt",
"name": stmt.name.lexeme, "name": stmt.name.lexeme,
"params": [ "params": [self._serialize_type_param(param) for param in stmt.params],
self._serialize_type_stmt_template_param(param) for param in stmt.params
],
"type": stmt.type.accept(self), "type": stmt.type.accept(self),
} }
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict: def _serialize_type_param(self, param: TypeParam) -> dict:
return { return {
"name": param.name.lexeme, "name": param.name.lexeme,
"bound": self._serialize_optional(param.bound), "bound": self._serialize_optional(param.bound),
@@ -150,7 +149,7 @@ class MidasAstJsonSerializer(
return { return {
"_type": "GenericType", "_type": "GenericType",
"type": type.type.accept(self), "type": type.type.accept(self),
"params": self._serialize_list(type.params), "args": self._serialize_list(type.args),
} }
def visit_constraint_type(self, type: ConstraintType) -> dict: def visit_constraint_type(self, type: ConstraintType) -> dict: