refactor(types): extract TypeParams

also rename generic type params to type args (when calling a generic)
This commit is contained in:
2026-06-09 15:30:45 +02:00
parent efa5454776
commit a26b9293be
10 changed files with 85 additions and 64 deletions

View File

@@ -30,6 +30,7 @@ from __future__ import annotations
T = TypeVar("T")
{preamble}
{sections}
"""
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
re.MULTILINE | re.DOTALL,
)
PREAMBLE_REGEX = re.compile(
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
def snake_case(text: str) -> str:
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
def make_section(full_name: str, base: str, param: str, body: str) -> str:
print(f" Generating {full_name}")
visitor_methods: list[str] = []
classes: list[str] = []
definitions: list[str] = body.strip("\n").split("\n\n\n")
for cls in definitions:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f"Processing {name}")
print(f" Processing {name}")
visitor_methods.append(make_visitor_method(name, param))
classes.append(make_class(name, cls, base))
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
def generate(definitions_path: Path, out_path: Path):
print(f"Processing generating {out_path} from {definitions_path}")
root_dir: Path = Path(__file__).parent.parent
rel_path: Path = definitions_path.relative_to(root_dir)
src: str = definitions_path.read_text()
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
if m := IMPORTS_REGEX.search(src):
imports = m.group("body").strip("\n")
preamble: str = ""
if m := PREAMBLE_REGEX.search(src):
preamble = m.group("body")
for section_m in SECTION_REGEX.finditer(src):
full_name: str = section_m.group("name")
base: str = section_m.group("base")
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
gen_path=Path(__file__).relative_to(root_dir),
),
imports=imports,
preamble=preamble,
sections="\n\n\n".join(sections),
)
out_path.write_text(result)

View File

@@ -12,19 +12,24 @@ from midas.lexer.token import Token
###<
###> Stmt | Statements
class TypeStmt:
name: Token
params: list[Param]
type: Type
@dataclass(frozen=True, kw_only=True)
class Param:
###> Preamble
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
###<
###> Stmt | Statements
class TypeStmt:
name: Token
params: list[TypeParam]
type: Type
class PropertyStmt:
name: Token
type: Type
@@ -103,7 +108,7 @@ class NamedType:
class GenericType:
type: Type
params: list[Type]
args: list[Type]
class ConstraintType:

View File

@@ -14,6 +14,13 @@ from midas.lexer.token import Token
T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
##############
# Statements #
##############
@@ -46,15 +53,9 @@ class Stmt(ABC):
@dataclass(frozen=True)
class TypeStmt(Stmt):
name: Token
params: list[Param]
params: list[TypeParam]
type: Type
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
bound: Optional[Type]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_type_stmt(self)
@@ -243,7 +244,7 @@ class NamedType(Type):
@dataclass(frozen=True)
class GenericType(Type):
type: Type
params: list[Type]
args: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)

View File

@@ -14,6 +14,7 @@ from midas.ast.location import Location
T = TypeVar("T")
####################
# Type annotations #
####################

View File

@@ -122,8 +122,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
def visit_generic_type(self, type: m.GenericType) -> Type:
type_: Type = type.type.accept(self)
params: list[Type] = [param.accept(self) for param in type.params]
return self.types.apply_generic(type_, params)
args: list[Type] = [arg.accept(self) for arg in type.args]
return self.types.apply_generic(type_, args)
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
type_: Type = type.type.accept(self)

View File

@@ -250,34 +250,34 @@ class TypesRegistry:
return True
def apply_generic(self, type: Type, params: list[Type]) -> Type:
def apply_generic(self, type: Type, args: list[Type]) -> Type:
match type:
case AliasType(name=name, type=base):
return AliasType(name=name, type=self.apply_generic(base, params))
return AliasType(name=name, type=self.apply_generic(base, args))
case GenericType(name=name, params=type_vars, body=body):
n_params: int = len(params)
case GenericType(name=name, args=type_vars, body=body):
n_args: int = len(args)
n_type_vars: int = len(type_vars)
if n_params < n_type_vars:
if n_args < n_type_vars:
raise ValueError(
f"Missing type parameters, expected {n_type_vars} but only {n_params} provided"
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
)
if n_params > n_type_vars:
if n_args > n_type_vars:
raise ValueError(
f"Too many type parameters, expected {n_type_vars} but {n_params} provided"
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
)
substitutions: dict[str, Type] = {}
for param, type_var in zip(params, type_vars):
for arg, type_var in zip(args, type_vars):
if type_var.bound is not None and not self.is_subtype(
param, type_var.bound
arg, type_var.bound
):
raise ValueError(
f"Type parameter {param} is not a subtype of {type_var.bound}"
f"Type argument {arg} is not a subtype of {type_var.bound}"
)
substitutions[type_var.name] = param
substitutions[type_var.name] = arg
return AppliedType(
name=name,
args=params,
args=args,
body=substitute_typevars(body, substitutions),
)

View File

@@ -288,8 +288,8 @@ class MidasHighlighter(
def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type")
type.type.accept(self)
for param in type.params:
param.accept(self)
for arg in type.args:
arg.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")

View File

@@ -18,6 +18,7 @@ from midas.ast.midas import (
PropertyStmt,
Stmt,
Type,
TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
@@ -108,9 +109,7 @@ class MidasParser(Parser):
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
params: list[TypeStmt.Param] = []
if self.check(TokenType.LEFT_BRACKET):
params = self.type_stmt_params()
params: list[TypeParam] = self.type_params()
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
@@ -123,16 +122,19 @@ class MidasParser(Parser):
type=type,
)
def type_stmt_params(self) -> list[TypeStmt.Param]:
"""Parse a generic template expression
def type_params(self) -> list[TypeParam]:
"""Parse a list of type parameters
A template is written `[TypeExpr]`
Type parameters are a comma-separated list of type variables wrapped in brackets.
Each type variable is either a simple variable, or a bounded variable written `S <: T`
Returns:
TemplateExpr: the parsed template expression
list[TypeParam]: the list of type parameters, if any, or an empty list
"""
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
params: list[TypeStmt.Param] = []
if not self.match(TokenType.LEFT_BRACKET):
return []
params: list[TypeParam] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
bound: Optional[Type] = None
@@ -140,7 +142,7 @@ class MidasParser(Parser):
self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr()
params.append(
TypeStmt.Param(
TypeParam(
location=name.location_to(self.previous()),
name=name,
bound=bound,
@@ -148,7 +150,7 @@ class MidasParser(Parser):
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params
def type_expr(self) -> Type:
@@ -187,23 +189,23 @@ class MidasParser(Parser):
def generic_type(self) -> Type:
type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
params: list[Type] = self.type_params()
args: list[Type] = self.type_args()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
type=type,
params=params,
args=args,
)
return type
def type_params(self) -> list[Type]:
params: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
def type_args(self) -> list[Type]:
args: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
params.append(self.type_expr())
args.append(self.type_expr())
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
return params
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
return args
def named_type(self) -> Type:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")

View File

@@ -2385,7 +2385,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "GeoLocation"
@@ -2416,7 +2416,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "Latitude"
@@ -2433,7 +2433,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "Longitude"
@@ -2464,7 +2464,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "Latitude"
@@ -2494,7 +2494,7 @@
"_type": "NamedType",
"name": "Difference"
},
"params": [
"args": [
{
"_type": "NamedType",
"name": "Longitude"
@@ -2638,7 +2638,7 @@
"_type": "NamedType",
"name": "Optional"
},
"params": [
"args": [
{
"_type": "ConstraintType",
"type": {

View File

@@ -17,6 +17,7 @@ from midas.ast.midas import (
PropertyStmt,
Stmt,
Type,
TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
@@ -46,13 +47,11 @@ class MidasAstJsonSerializer(
return {
"_type": "TypeStmt",
"name": stmt.name.lexeme,
"params": [
self._serialize_type_stmt_template_param(param) for param in stmt.params
],
"params": [self._serialize_type_param(param) for param in stmt.params],
"type": stmt.type.accept(self),
}
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
def _serialize_type_param(self, param: TypeParam) -> dict:
return {
"name": param.name.lexeme,
"bound": self._serialize_optional(param.bound),
@@ -150,7 +149,7 @@ class MidasAstJsonSerializer(
return {
"_type": "GenericType",
"type": type.type.accept(self),
"params": self._serialize_list(type.params),
"args": self._serialize_list(type.args),
}
def visit_constraint_type(self, type: ConstraintType) -> dict: