refactor(types): extract TypeParams
also rename generic type params to type args (when calling a generic)
This commit is contained in:
13
gen/gen.py
13
gen/gen.py
@@ -30,6 +30,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
{preamble}
|
||||||
{sections}
|
{sections}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
|
|||||||
re.MULTILINE | re.DOTALL,
|
re.MULTILINE | re.DOTALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PREAMBLE_REGEX = re.compile(
|
||||||
|
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
|
||||||
|
re.MULTILINE | re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def snake_case(text: str) -> str:
|
def snake_case(text: str) -> str:
|
||||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||||
@@ -88,6 +94,7 @@ def make_banner(text: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||||
|
print(f" Generating {full_name}")
|
||||||
visitor_methods: list[str] = []
|
visitor_methods: list[str] = []
|
||||||
classes: list[str] = []
|
classes: list[str] = []
|
||||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||||
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def generate(definitions_path: Path, out_path: Path):
|
def generate(definitions_path: Path, out_path: Path):
|
||||||
|
print(f"Processing generating {out_path} from {definitions_path}")
|
||||||
root_dir: Path = Path(__file__).parent.parent
|
root_dir: Path = Path(__file__).parent.parent
|
||||||
rel_path: Path = definitions_path.relative_to(root_dir)
|
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||||
src: str = definitions_path.read_text()
|
src: str = definitions_path.read_text()
|
||||||
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
|
|||||||
if m := IMPORTS_REGEX.search(src):
|
if m := IMPORTS_REGEX.search(src):
|
||||||
imports = m.group("body").strip("\n")
|
imports = m.group("body").strip("\n")
|
||||||
|
|
||||||
|
preamble: str = ""
|
||||||
|
if m := PREAMBLE_REGEX.search(src):
|
||||||
|
preamble = m.group("body")
|
||||||
|
|
||||||
for section_m in SECTION_REGEX.finditer(src):
|
for section_m in SECTION_REGEX.finditer(src):
|
||||||
full_name: str = section_m.group("name")
|
full_name: str = section_m.group("name")
|
||||||
base: str = section_m.group("base")
|
base: str = section_m.group("base")
|
||||||
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
|
|||||||
gen_path=Path(__file__).relative_to(root_dir),
|
gen_path=Path(__file__).relative_to(root_dir),
|
||||||
),
|
),
|
||||||
imports=imports,
|
imports=imports,
|
||||||
|
preamble=preamble,
|
||||||
sections="\n\n\n".join(sections),
|
sections="\n\n\n".join(sections),
|
||||||
)
|
)
|
||||||
out_path.write_text(result)
|
out_path.write_text(result)
|
||||||
|
|||||||
21
gen/midas.py
21
gen/midas.py
@@ -12,19 +12,24 @@ from midas.lexer.token import Token
|
|||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
###> Stmt | Statements
|
###> Preamble
|
||||||
class TypeStmt:
|
|
||||||
name: Token
|
|
||||||
params: list[Param]
|
|
||||||
type: Type
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class Param:
|
class TypeParam:
|
||||||
location: Location
|
location: Location
|
||||||
name: Token
|
name: Token
|
||||||
bound: Optional[Type]
|
bound: Optional[Type]
|
||||||
|
|
||||||
|
|
||||||
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
###> Stmt | Statements
|
||||||
|
class TypeStmt:
|
||||||
|
name: Token
|
||||||
|
params: list[TypeParam]
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
class PropertyStmt:
|
class PropertyStmt:
|
||||||
name: Token
|
name: Token
|
||||||
type: Type
|
type: Type
|
||||||
@@ -103,7 +108,7 @@ class NamedType:
|
|||||||
|
|
||||||
class GenericType:
|
class GenericType:
|
||||||
type: Type
|
type: Type
|
||||||
params: list[Type]
|
args: list[Type]
|
||||||
|
|
||||||
|
|
||||||
class ConstraintType:
|
class ConstraintType:
|
||||||
|
|||||||
@@ -14,6 +14,13 @@ from midas.lexer.token import Token
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypeParam:
|
||||||
|
location: Location
|
||||||
|
name: Token
|
||||||
|
bound: Optional[Type]
|
||||||
|
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Statements #
|
# Statements #
|
||||||
##############
|
##############
|
||||||
@@ -46,15 +53,9 @@ class Stmt(ABC):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TypeStmt(Stmt):
|
class TypeStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
params: list[Param]
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
|
||||||
class Param:
|
|
||||||
location: Location
|
|
||||||
name: Token
|
|
||||||
bound: Optional[Type]
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_type_stmt(self)
|
return visitor.visit_type_stmt(self)
|
||||||
|
|
||||||
@@ -243,7 +244,7 @@ class NamedType(Type):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GenericType(Type):
|
class GenericType(Type):
|
||||||
type: Type
|
type: Type
|
||||||
params: list[Type]
|
args: list[Type]
|
||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
return visitor.visit_generic_type(self)
|
return visitor.visit_generic_type(self)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from midas.ast.location import Location
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Type annotations #
|
# Type annotations #
|
||||||
####################
|
####################
|
||||||
|
|||||||
@@ -122,8 +122,8 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
|
|
||||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||||
type_: Type = type.type.accept(self)
|
type_: Type = type.type.accept(self)
|
||||||
params: list[Type] = [param.accept(self) for param in type.params]
|
args: list[Type] = [arg.accept(self) for arg in type.args]
|
||||||
return self.types.apply_generic(type_, params)
|
return self.types.apply_generic(type_, args)
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||||
type_: Type = type.type.accept(self)
|
type_: Type = type.type.accept(self)
|
||||||
|
|||||||
@@ -250,34 +250,34 @@ class TypesRegistry:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def apply_generic(self, type: Type, params: list[Type]) -> Type:
|
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(name=name, type=base):
|
case AliasType(name=name, type=base):
|
||||||
return AliasType(name=name, type=self.apply_generic(base, params))
|
return AliasType(name=name, type=self.apply_generic(base, args))
|
||||||
|
|
||||||
case GenericType(name=name, params=type_vars, body=body):
|
case GenericType(name=name, args=type_vars, body=body):
|
||||||
n_params: int = len(params)
|
n_args: int = len(args)
|
||||||
n_type_vars: int = len(type_vars)
|
n_type_vars: int = len(type_vars)
|
||||||
if n_params < n_type_vars:
|
if n_args < n_type_vars:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Missing type parameters, expected {n_type_vars} but only {n_params} provided"
|
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
|
||||||
)
|
)
|
||||||
if n_params > n_type_vars:
|
if n_args > n_type_vars:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Too many type parameters, expected {n_type_vars} but {n_params} provided"
|
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
|
||||||
)
|
)
|
||||||
substitutions: dict[str, Type] = {}
|
substitutions: dict[str, Type] = {}
|
||||||
for param, type_var in zip(params, type_vars):
|
for arg, type_var in zip(args, type_vars):
|
||||||
if type_var.bound is not None and not self.is_subtype(
|
if type_var.bound is not None and not self.is_subtype(
|
||||||
param, type_var.bound
|
arg, type_var.bound
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Type parameter {param} is not a subtype of {type_var.bound}"
|
f"Type argument {arg} is not a subtype of {type_var.bound}"
|
||||||
)
|
)
|
||||||
substitutions[type_var.name] = param
|
substitutions[type_var.name] = arg
|
||||||
return AppliedType(
|
return AppliedType(
|
||||||
name=name,
|
name=name,
|
||||||
args=params,
|
args=args,
|
||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -288,8 +288,8 @@ class MidasHighlighter(
|
|||||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||||
self.wrap(type, "generic-type")
|
self.wrap(type, "generic-type")
|
||||||
type.type.accept(self)
|
type.type.accept(self)
|
||||||
for param in type.params:
|
for arg in type.args:
|
||||||
param.accept(self)
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||||
self.wrap(type, "constraint-type")
|
self.wrap(type, "constraint-type")
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from midas.ast.midas import (
|
|||||||
PropertyStmt,
|
PropertyStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
|
TypeParam,
|
||||||
TypeStmt,
|
TypeStmt,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
@@ -108,9 +109,7 @@ class MidasParser(Parser):
|
|||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||||
params: list[TypeStmt.Param] = []
|
params: list[TypeParam] = self.type_params()
|
||||||
if self.check(TokenType.LEFT_BRACKET):
|
|
||||||
params = self.type_stmt_params()
|
|
||||||
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||||
|
|
||||||
@@ -123,16 +122,19 @@ class MidasParser(Parser):
|
|||||||
type=type,
|
type=type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
def type_params(self) -> list[TypeParam]:
|
||||||
"""Parse a generic template expression
|
"""Parse a list of type parameters
|
||||||
|
|
||||||
A template is written `[TypeExpr]`
|
Type parameters are a comma-separated list of type variables wrapped in brackets.
|
||||||
|
Each type variable is either a simple variable, or a bounded variable written `S <: T`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TemplateExpr: the parsed template expression
|
list[TypeParam]: the list of type parameters, if any, or an empty list
|
||||||
"""
|
"""
|
||||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
if not self.match(TokenType.LEFT_BRACKET):
|
||||||
params: list[TypeStmt.Param] = []
|
return []
|
||||||
|
|
||||||
|
params: list[TypeParam] = []
|
||||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
||||||
bound: Optional[Type] = None
|
bound: Optional[Type] = None
|
||||||
@@ -140,7 +142,7 @@ class MidasParser(Parser):
|
|||||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||||
bound = self.type_expr()
|
bound = self.type_expr()
|
||||||
params.append(
|
params.append(
|
||||||
TypeStmt.Param(
|
TypeParam(
|
||||||
location=name.location_to(self.previous()),
|
location=name.location_to(self.previous()),
|
||||||
name=name,
|
name=name,
|
||||||
bound=bound,
|
bound=bound,
|
||||||
@@ -148,7 +150,7 @@ class MidasParser(Parser):
|
|||||||
)
|
)
|
||||||
if not self.match(TokenType.COMMA):
|
if not self.match(TokenType.COMMA):
|
||||||
break
|
break
|
||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def type_expr(self) -> Type:
|
def type_expr(self) -> Type:
|
||||||
@@ -187,23 +189,23 @@ class MidasParser(Parser):
|
|||||||
def generic_type(self) -> Type:
|
def generic_type(self) -> Type:
|
||||||
type: Type = self.named_type()
|
type: Type = self.named_type()
|
||||||
if self.check(TokenType.LEFT_BRACKET):
|
if self.check(TokenType.LEFT_BRACKET):
|
||||||
params: list[Type] = self.type_params()
|
args: list[Type] = self.type_args()
|
||||||
return GenericType(
|
return GenericType(
|
||||||
location=Location.span(type.location, self.previous().get_location()),
|
location=Location.span(type.location, self.previous().get_location()),
|
||||||
type=type,
|
type=type,
|
||||||
params=params,
|
args=args,
|
||||||
)
|
)
|
||||||
return type
|
return type
|
||||||
|
|
||||||
def type_params(self) -> list[Type]:
|
def type_args(self) -> list[Type]:
|
||||||
params: list[Type] = []
|
args: list[Type] = []
|
||||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
|
||||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||||
params.append(self.type_expr())
|
args.append(self.type_expr())
|
||||||
if not self.match(TokenType.COMMA):
|
if not self.match(TokenType.COMMA):
|
||||||
break
|
break
|
||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||||
return params
|
return args
|
||||||
|
|
||||||
def named_type(self) -> Type:
|
def named_type(self) -> Type:
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||||
|
|||||||
@@ -2385,7 +2385,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "GeoLocation"
|
"name": "GeoLocation"
|
||||||
@@ -2416,7 +2416,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Latitude"
|
"name": "Latitude"
|
||||||
@@ -2433,7 +2433,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Longitude"
|
"name": "Longitude"
|
||||||
@@ -2464,7 +2464,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Latitude"
|
"name": "Latitude"
|
||||||
@@ -2494,7 +2494,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Difference"
|
"name": "Difference"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Longitude"
|
"name": "Longitude"
|
||||||
@@ -2638,7 +2638,7 @@
|
|||||||
"_type": "NamedType",
|
"_type": "NamedType",
|
||||||
"name": "Optional"
|
"name": "Optional"
|
||||||
},
|
},
|
||||||
"params": [
|
"args": [
|
||||||
{
|
{
|
||||||
"_type": "ConstraintType",
|
"_type": "ConstraintType",
|
||||||
"type": {
|
"type": {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from midas.ast.midas import (
|
|||||||
PropertyStmt,
|
PropertyStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
|
TypeParam,
|
||||||
TypeStmt,
|
TypeStmt,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
@@ -46,13 +47,11 @@ class MidasAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "TypeStmt",
|
"_type": "TypeStmt",
|
||||||
"name": stmt.name.lexeme,
|
"name": stmt.name.lexeme,
|
||||||
"params": [
|
"params": [self._serialize_type_param(param) for param in stmt.params],
|
||||||
self._serialize_type_stmt_template_param(param) for param in stmt.params
|
|
||||||
],
|
|
||||||
"type": stmt.type.accept(self),
|
"type": stmt.type.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
|
def _serialize_type_param(self, param: TypeParam) -> dict:
|
||||||
return {
|
return {
|
||||||
"name": param.name.lexeme,
|
"name": param.name.lexeme,
|
||||||
"bound": self._serialize_optional(param.bound),
|
"bound": self._serialize_optional(param.bound),
|
||||||
@@ -150,7 +149,7 @@ class MidasAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "GenericType",
|
"_type": "GenericType",
|
||||||
"type": type.type.accept(self),
|
"type": type.type.accept(self),
|
||||||
"params": self._serialize_list(type.params),
|
"args": self._serialize_list(type.args),
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
||||||
|
|||||||
Reference in New Issue
Block a user