refactor(types): extract TypeParams
also rename generic type params to type args (when calling a generic)
This commit is contained in:
15
gen/gen.py
15
gen/gen.py
@@ -30,6 +30,7 @@ from __future__ import annotations
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
{preamble}
|
||||
{sections}
|
||||
"""
|
||||
|
||||
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
PREAMBLE_REGEX = re.compile(
|
||||
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def snake_case(text: str) -> str:
|
||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
|
||||
|
||||
|
||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
print(f" Generating {full_name}")
|
||||
visitor_methods: list[str] = []
|
||||
classes: list[str] = []
|
||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||
for cls in definitions:
|
||||
cls = cls.strip("\n")
|
||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||
print(f"Processing {name}")
|
||||
print(f" Processing {name}")
|
||||
visitor_methods.append(make_visitor_method(name, param))
|
||||
classes.append(make_class(name, cls, base))
|
||||
|
||||
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
|
||||
|
||||
def generate(definitions_path: Path, out_path: Path):
|
||||
print(f"Processing generating {out_path} from {definitions_path}")
|
||||
root_dir: Path = Path(__file__).parent.parent
|
||||
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||
src: str = definitions_path.read_text()
|
||||
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
|
||||
if m := IMPORTS_REGEX.search(src):
|
||||
imports = m.group("body").strip("\n")
|
||||
|
||||
preamble: str = ""
|
||||
if m := PREAMBLE_REGEX.search(src):
|
||||
preamble = m.group("body")
|
||||
|
||||
for section_m in SECTION_REGEX.finditer(src):
|
||||
full_name: str = section_m.group("name")
|
||||
base: str = section_m.group("base")
|
||||
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
|
||||
gen_path=Path(__file__).relative_to(root_dir),
|
||||
),
|
||||
imports=imports,
|
||||
preamble=preamble,
|
||||
sections="\n\n\n".join(sections),
|
||||
)
|
||||
out_path.write_text(result)
|
||||
|
||||
23
gen/midas.py
23
gen/midas.py
@@ -12,19 +12,24 @@ from midas.lexer.token import Token
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class TypeStmt:
|
||||
name: Token
|
||||
params: list[Param]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
###> Preamble
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeParam:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class TypeStmt:
|
||||
name: Token
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
|
||||
|
||||
class PropertyStmt:
|
||||
name: Token
|
||||
type: Type
|
||||
@@ -103,7 +108,7 @@ class NamedType:
|
||||
|
||||
class GenericType:
|
||||
type: Type
|
||||
params: list[Type]
|
||||
args: list[Type]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
|
||||
@@ -14,6 +14,13 @@ from midas.lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeParam:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
@@ -46,15 +53,9 @@ class Stmt(ABC):
|
||||
@dataclass(frozen=True)
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
params: list[Param]
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
@@ -243,7 +244,7 @@ class NamedType(Type):
|
||||
@dataclass(frozen=True)
|
||||
class GenericType(Type):
|
||||
type: Type
|
||||
params: list[Type]
|
||||
args: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_generic_type(self)
|
||||
|
||||
@@ -14,6 +14,7 @@ from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
####################
|
||||
|
||||
@@ -122,8 +122,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
params: list[Type] = [param.accept(self) for param in type.params]
|
||||
return self.types.apply_generic(type_, params)
|
||||
args: list[Type] = [arg.accept(self) for arg in type.args]
|
||||
return self.types.apply_generic(type_, args)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
|
||||
@@ -250,34 +250,34 @@ class TypesRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def apply_generic(self, type: Type, params: list[Type]) -> Type:
|
||||
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
||||
match type:
|
||||
case AliasType(name=name, type=base):
|
||||
return AliasType(name=name, type=self.apply_generic(base, params))
|
||||
return AliasType(name=name, type=self.apply_generic(base, args))
|
||||
|
||||
case GenericType(name=name, params=type_vars, body=body):
|
||||
n_params: int = len(params)
|
||||
case GenericType(name=name, args=type_vars, body=body):
|
||||
n_args: int = len(args)
|
||||
n_type_vars: int = len(type_vars)
|
||||
if n_params < n_type_vars:
|
||||
if n_args < n_type_vars:
|
||||
raise ValueError(
|
||||
f"Missing type parameters, expected {n_type_vars} but only {n_params} provided"
|
||||
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
|
||||
)
|
||||
if n_params > n_type_vars:
|
||||
if n_args > n_type_vars:
|
||||
raise ValueError(
|
||||
f"Too many type parameters, expected {n_type_vars} but {n_params} provided"
|
||||
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
|
||||
)
|
||||
substitutions: dict[str, Type] = {}
|
||||
for param, type_var in zip(params, type_vars):
|
||||
for arg, type_var in zip(args, type_vars):
|
||||
if type_var.bound is not None and not self.is_subtype(
|
||||
param, type_var.bound
|
||||
arg, type_var.bound
|
||||
):
|
||||
raise ValueError(
|
||||
f"Type parameter {param} is not a subtype of {type_var.bound}"
|
||||
f"Type argument {arg} is not a subtype of {type_var.bound}"
|
||||
)
|
||||
substitutions[type_var.name] = param
|
||||
substitutions[type_var.name] = arg
|
||||
return AppliedType(
|
||||
name=name,
|
||||
args=params,
|
||||
args=args,
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
|
||||
@@ -288,8 +288,8 @@ class MidasHighlighter(
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self.wrap(type, "generic-type")
|
||||
type.type.accept(self)
|
||||
for param in type.params:
|
||||
param.accept(self)
|
||||
for arg in type.args:
|
||||
arg.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self.wrap(type, "constraint-type")
|
||||
|
||||
@@ -18,6 +18,7 @@ from midas.ast.midas import (
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
TypeParam,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -108,9 +109,7 @@ class MidasParser(Parser):
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
params: list[TypeStmt.Param] = []
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params = self.type_stmt_params()
|
||||
params: list[TypeParam] = self.type_params()
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||
|
||||
@@ -123,16 +122,19 @@ class MidasParser(Parser):
|
||||
type=type,
|
||||
)
|
||||
|
||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
||||
"""Parse a generic template expression
|
||||
def type_params(self) -> list[TypeParam]:
|
||||
"""Parse a list of type parameters
|
||||
|
||||
A template is written `[TypeExpr]`
|
||||
Type parameters are a comma-separated list of type variables wrapped in brackets.
|
||||
Each type variable is either a simple variable, or a bounded variable written `S <: T`
|
||||
|
||||
Returns:
|
||||
TemplateExpr: the parsed template expression
|
||||
list[TypeParam]: the list of type parameters, if any, or an empty list
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
||||
params: list[TypeStmt.Param] = []
|
||||
if not self.match(TokenType.LEFT_BRACKET):
|
||||
return []
|
||||
|
||||
params: list[TypeParam] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
||||
bound: Optional[Type] = None
|
||||
@@ -140,7 +142,7 @@ class MidasParser(Parser):
|
||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||
bound = self.type_expr()
|
||||
params.append(
|
||||
TypeStmt.Param(
|
||||
TypeParam(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
bound=bound,
|
||||
@@ -148,7 +150,7 @@ class MidasParser(Parser):
|
||||
)
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
||||
return params
|
||||
|
||||
def type_expr(self) -> Type:
|
||||
@@ -187,23 +189,23 @@ class MidasParser(Parser):
|
||||
def generic_type(self) -> Type:
|
||||
type: Type = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params: list[Type] = self.type_params()
|
||||
args: list[Type] = self.type_args()
|
||||
return GenericType(
|
||||
location=Location.span(type.location, self.previous().get_location()),
|
||||
type=type,
|
||||
params=params,
|
||||
args=args,
|
||||
)
|
||||
return type
|
||||
|
||||
def type_params(self) -> list[Type]:
|
||||
params: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
||||
def type_args(self) -> list[Type]:
|
||||
args: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
params.append(self.type_expr())
|
||||
args.append(self.type_expr())
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
||||
return params
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||
return args
|
||||
|
||||
def named_type(self) -> Type:
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
|
||||
@@ -2385,7 +2385,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "GeoLocation"
|
||||
@@ -2416,7 +2416,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "Latitude"
|
||||
@@ -2433,7 +2433,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "Longitude"
|
||||
@@ -2464,7 +2464,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "Latitude"
|
||||
@@ -2494,7 +2494,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Difference"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "NamedType",
|
||||
"name": "Longitude"
|
||||
@@ -2638,7 +2638,7 @@
|
||||
"_type": "NamedType",
|
||||
"name": "Optional"
|
||||
},
|
||||
"params": [
|
||||
"args": [
|
||||
{
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
|
||||
@@ -17,6 +17,7 @@ from midas.ast.midas import (
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
TypeParam,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -46,13 +47,11 @@ class MidasAstJsonSerializer(
|
||||
return {
|
||||
"_type": "TypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"params": [
|
||||
self._serialize_type_stmt_template_param(param) for param in stmt.params
|
||||
],
|
||||
"params": [self._serialize_type_param(param) for param in stmt.params],
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
|
||||
def _serialize_type_param(self, param: TypeParam) -> dict:
|
||||
return {
|
||||
"name": param.name.lexeme,
|
||||
"bound": self._serialize_optional(param.bound),
|
||||
@@ -150,7 +149,7 @@ class MidasAstJsonSerializer(
|
||||
return {
|
||||
"_type": "GenericType",
|
||||
"type": type.type.accept(self),
|
||||
"params": self._serialize_list(type.params),
|
||||
"args": self._serialize_list(type.args),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user