Merge pull request 'Python parser' (#4) from feat/python-parser into main

Reviewed-on: #4
This commit was merged in pull request #4.
This commit is contained in:
2026-05-26 08:28:42 +00:00
31 changed files with 1850 additions and 175 deletions

4
.gitignore vendored
View File

@@ -3,4 +3,6 @@ __pycache__
.env .env
venv venv
.venv .venv
*.pyc *.pyc
uv.lock
.python-version

View File

@@ -21,7 +21,7 @@ lat + lon # Invalid operation
# Registered operations are permitted # Registered operations are permitted
lat1: Latitude = lat[0] lat1: Latitude = lat[0]
lat2: Latitude = lat[1] lat2: Latitude = lat[1]
lat_diff: LatitudeDiff = lat2 - lat1 # Valid operation lat_diff: Difference[Latitude] = lat2 - lat1 # Valid operation
# In addition to the type, a column can have one or more constraints, either defined inline or in a separate file # In addition to the type, a column can have one or more constraints, either defined inline or in a separate file
df2: Frame[ df2: Frame[

View File

@@ -0,0 +1,15 @@
# type: ignore
# ruff: disable[F821]
from __future__ import annotations
# Example with constrained column annotations (presumably parser input — confirm):
# both inputs are annotated as floats within [0, 1] and the result as a float
# within [0, 2]. `Column` and `_` are not defined in this file, hence the
# `type: ignore` / F821 suppressions at the top.
def func(
    col1: Column[float + (0 <= _ <= 1)],
    col2: Column[float + (0 <= _ <= 1)],
) -> Column[float + (0 <= _ <= 2)]:
    # Annotated assignment: the declared constraint is the sum of the input ranges.
    result: Column[float + (0 <= _ <= 2)] = col1 + col2
    return result
# Example covering all three parameter kinds: `a` is positional-only (before `/`),
# `b` is a regular parameter, and `c` is keyword-only (after `*`).
def func2(a: int, /, b: float, *, c: str):
    pass

View File

@@ -3,53 +3,34 @@ import re
HEADER = '''""" HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten. This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/ast.py instead and run gen/gen.py Please modify {defs_path} instead and run {gen_path}
"""''' """'''
SECTION_TEMPLATE = """{banner}
@dataclass(frozen=True, kw_only=True)
class {base}(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{visitor_methods}
{classes}"""
TEMPLATE = """{header} TEMPLATE = """{header}
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod {imports}
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from lexer.token import Token
T = TypeVar("T") T = TypeVar("T")
############## {sections}
# Statements #
##############
@dataclass(frozen=True)
class Stmt(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{stmt_visitor_methods}
{statements}
###############
# Expressions #
###############
@dataclass(frozen=True)
class Expr(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{expr_visitor_methods}
{expressions}
""" """
VISITOR_METHOD_TEMPLATE = """ VISITOR_METHOD_TEMPLATE = """
@@ -66,17 +47,28 @@ class {cls}({base}):
return visitor.visit_{func_name}(self) return visitor.visit_{func_name}(self)
""" """
# Matches one definitions section of the form
#
#     ###> <base> | <name> [| <param>]
#     ...body...
#     ###<
#
# `base` and `name` are required and separated by `|`; the third field `param`
# is optional. DOTALL lets `body` span several lines, and MULTILINE anchors the
# opening and closing markers to line boundaries.
SECTION_REGEX = re.compile(
    r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
    re.MULTILINE | re.DOTALL,
)
# Matches the `###> Imports` ... `###<` section; `body` is everything between
# the two marker lines.
IMPORTS_REGEX = re.compile(
    r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
    re.MULTILINE | re.DOTALL,
)
def snake_case(text: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    Every ASCII uppercase letter is replaced by an underscore followed by its
    lowercase form; the result is lowercased and stripped of leading and
    trailing underscores.
    """
    pieces: list[str] = []
    for ch in text:
        if "A" <= ch <= "Z":
            pieces.append("_" + ch.lower())
        else:
            pieces.append(ch)
    return "".join(pieces).lower().strip("_")
def make_visitor_method(cls: str, param: str):
    """Render the abstract ``visit_*`` method for node class ``cls``.

    ``param`` is the name given to the node parameter of the generated method.
    Surrounding newlines of the rendered template are removed.
    """
    fields: dict[str, str] = {
        "func_name": snake_case(cls),
        "param": param,
        "cls": cls,
    }
    return VISITOR_METHOD_TEMPLATE.format(**fields).strip("\n")
def make_class(name: str, cls: str, base: str): def make_class(name: str, cls: str, base: str):
body: str = cls.split("\n", 1)[1] body: str = cls.split("\n", 1)[1]
func_name: str = snake_case(name) func_name: str = snake_case(name)
@@ -88,40 +80,66 @@ def make_class(name: str, cls: str, base: str):
) )
return cls_def.strip("\n") return cls_def.strip("\n")
def generate(src: str):
classes: list[str] = src.split("\n\n")
stmt_visitor_methods: list[str] = []
expr_visitor_methods: list[str] = []
statements: list[str] = []
expressions: list[str] = []
def make_banner(text: str) -> str:
    """Return a three-line comment banner with ``text`` framed by ``#`` characters."""
    title_line: str = "# " + text + " #"
    border: str = "#" * len(title_line)
    return f"{border}\n{title_line}\n{border}"
def make_section(full_name: str, base: str, param: str, body: str) -> str:
    """Render one section: banner, abstract base with its visitor, and node classes.

    Args:
        full_name: Human-readable section title used for the banner.
        base: Name of the abstract base class every node in the section extends.
        param: Parameter name used in the generated ``visit_*`` methods.
        body: Raw class definitions, separated from each other by two blank lines.
    """
    rendered: list[tuple[str, str]] = []
    for chunk in body.strip("\n").split("\n\n\n"):
        definition: str = chunk.strip("\n")
        # The class name is whatever sits between "class " and the first colon.
        name: str = re.match("class (.*?):", definition).group(1)  # type: ignore
        print(f"Processing {name}")
        rendered.append(
            (
                make_visitor_method(name, param),
                make_class(name, definition, base),
            )
        )

    methods: str = "\n\n".join(method for method, _ in rendered)
    node_classes: str = "\n\n\n".join(node for _, node in rendered)
    return SECTION_TEMPLATE.format(
        banner=make_banner(full_name),
        base=base,
        visitor_methods=methods,
        classes=node_classes,
    )
def generate(definitions_path: Path, out_path: Path):
    """Generate an AST module from a definitions file and write it to ``out_path``.

    The definitions file is scanned for an optional ``Imports`` section and any
    number of node sections; each node section is rendered with
    ``make_section`` and the pieces are assembled with ``TEMPLATE``.

    Raises:
        ValueError: if ``definitions_path`` is not located under the project
            root (the parent of this script's directory).
    """
    project_root: Path = Path(__file__).parent.parent
    # Computed before reading, so a misplaced definitions file fails first.
    defs_rel: Path = definitions_path.relative_to(project_root)
    source: str = definitions_path.read_text()

    imports_match = IMPORTS_REGEX.search(source)
    imports_block: str = (
        "" if imports_match is None else imports_match.group("body").strip("\n")
    )

    rendered_sections: list[str] = [
        make_section(
            found.group("name"),
            found.group("base"),
            # The visitor parameter defaults to the lowercased base class name.
            found.group("param") or found.group("base").lower(),
            found.group("body"),
        )
        for found in SECTION_REGEX.finditer(source)
    ]

    header: str = HEADER.format(
        defs_path=defs_rel,
        gen_path=Path(__file__).relative_to(project_root),
    )
    out_path.write_text(
        TEMPLATE.format(
            header=header,
            imports=imports_block,
            sections="\n\n\n".join(rendered_sections),
        )
    )
def main():
    """Regenerate both AST modules from their definitions under ``gen/``."""
    project_root: Path = Path(__file__).parent.parent
    definitions: Path = project_root / "gen"
    targets: Path = project_root / "midas" / "ast"
    # Same file name on both sides: gen/<name> is rendered to midas/ast/<name>.
    for filename in ("midas.py", "python.py"):
        generate(definitions / filename, targets / filename)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,72 +1,110 @@
# type: ignore
# ruff: disable[F821, F401]
###> Imports
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
from midas.lexer.token import Token
###<
###> Stmt | Statements
class SimpleTypeStmt: class SimpleTypeStmt:
name: Token name: Token
template: Optional[TemplateExpr] template: Optional[TemplateExpr]
base: TypeExpr base: TypeExpr
constraint: Optional[Expr] constraint: Optional[Expr]
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr:
left: Expr
operator: Token
right: Expr
class BinaryExpr:
left: Expr
operator: Token
right: Expr
class UnaryExpr:
operator: Token
right: Expr
class GetExpr:
expr: Expr
name: Token
class VariableExpr:
name: Token
class GroupingExpr:
expr: Expr
class LiteralExpr:
value: Any
class WildcardExpr:
token: Token
class TemplateExpr:
type: TypeExpr
class TypeExpr:
name: Token
template: Optional[TemplateExpr]
optional: bool
class ComplexTypeStmt: class ComplexTypeStmt:
name: Token name: Token
template: Optional[TemplateExpr] template: Optional[TemplateExpr]
properties: list[PropertyStmt] properties: list[PropertyStmt]
class PropertyStmt: class PropertyStmt:
name: Token name: Token
type: TypeExpr type: TypeExpr
constraint: Optional[Expr] constraint: Optional[Expr]
class ExtendStmt: class ExtendStmt:
type: TypeExpr type: TypeExpr
operations: list[OpStmt] operations: list[OpStmt]
class OpStmt: class OpStmt:
name: Token name: Token
operand: TypeExpr operand: TypeExpr
result: TypeExpr result: TypeExpr
class PredicateStmt: class PredicateStmt:
name: Token name: Token
subject: Token subject: Token
type: TypeExpr type: TypeExpr
condition: Expr condition: Expr
###<
###> Expr | Expressions
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr:
left: Expr
operator: Token
right: Expr
class BinaryExpr:
left: Expr
operator: Token
right: Expr
class UnaryExpr:
operator: Token
right: Expr
class GetExpr:
expr: Expr
name: Token
class VariableExpr:
name: Token
class GroupingExpr:
expr: Expr
class LiteralExpr:
value: Any
class WildcardExpr:
token: Token
class TemplateExpr:
type: TypeExpr
class TypeExpr:
name: Token
template: Optional[TemplateExpr]
optional: bool
###<

119
gen/python.py Normal file
View File

@@ -0,0 +1,119 @@
# type: ignore
# ruff: disable[F821, F401]
###> Imports
import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
###<
###> MidasType | Type annotations | node
class BaseType:
base: str
param: Optional[MidasType]
class ConstraintType:
type: MidasType
constraint: ast.expr
class FrameColumn:
name: Optional[str]
type: Optional[MidasType]
class FrameType:
columns: list[FrameColumn]
###<
###> Stmt | Statements
class ExpressionStmt:
expr: Expr
class Function:
name: str
posonlyargs: list[Argument]
args: list[Argument]
kwonlyargs: list[Argument]
returns: Optional[MidasType]
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[str]
type: Optional[MidasType]
class TypeAssign:
name: str
type: MidasType
class AssignStmt:
targets: list[Expr]
value: Expr
###<
###> Expr | Expressions
class BinaryExpr:
left: Expr
operator: ast.operator
right: Expr
class CompareExpr:
left: Expr
operator: ast.cmpop
right: Expr
class UnaryExpr:
operator: ast.unaryop
right: Expr
class CallExpr:
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
class GetExpr:
object: Expr
name: str
class LiteralExpr:
value: Any
class VariableExpr:
name: str
class LogicalExpr:
left: Expr
operator: ast.boolop
right: Expr
class SetExpr:
object: Expr
name: str
value: Expr
###<

View File

@@ -1,12 +0,0 @@
from lexer.token import TokenType
KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE,
"op": TokenType.OP,
"predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND,
"where": TokenType.WHERE,
"true": TokenType.TRUE,
"false": TokenType.FALSE,
"none": TokenType.NONE,
}

View File

@@ -1,6 +1,6 @@
from typing import Optional, Sequence from typing import Optional, Sequence
from core.ast.midas import ( from midas.ast.midas import (
BinaryExpr, BinaryExpr,
ComplexTypeStmt, ComplexTypeStmt,
Expr, Expr,

37
midas/ast/location.py Normal file
View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Protocol
class HasLocation(Protocol):
    """Structural type for objects exposing the four position attributes.

    Anything with these attributes can be passed to ``Location.from_ast``.
    """

    lineno: int
    col_offset: int
    # The end position may be unknown, hence optional.
    end_lineno: Optional[int]
    end_col_offset: Optional[int]
@dataclass(frozen=True, kw_only=True)
class Location:
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@staticmethod
def from_ast(obj: HasLocation) -> Location:
return Location(
lineno=obj.lineno,
col_offset=obj.col_offset,
end_lineno=obj.end_lineno,
end_col_offset=obj.end_col_offset,
)
@staticmethod
def span(start: Location, end: Location) -> Location:
return Location(
lineno=start.lineno,
col_offset=start.col_offset,
end_lineno=end.lineno,
end_col_offset=end.end_col_offset,
)

View File

@@ -1,6 +1,6 @@
""" """
This file was generated by a script. Any manual changes might be overwritten. This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/ast.py instead and run gen/gen.py Please modify gen/midas.py instead and run gen/gen.py
""" """
from __future__ import annotations from __future__ import annotations
@@ -9,7 +9,8 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar from typing import Any, Generic, Optional, TypeVar
from lexer.token import Token from midas.ast.location import Location
from midas.lexer.token import Token
T = TypeVar("T") T = TypeVar("T")
@@ -18,8 +19,10 @@ T = TypeVar("T")
############## ##############
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class Stmt(ABC): class Stmt(ABC):
location: Optional[Location] = None
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...
@@ -109,8 +112,10 @@ class PredicateStmt(Stmt):
############### ###############
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class Expr(ABC): class Expr(ABC):
location: Optional[Location] = None
@abstractmethod @abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...

View File

@@ -1,11 +1,13 @@
from __future__ import annotations from __future__ import annotations
import ast
import io import io
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum, auto from enum import Enum, auto
from typing import Generator, Generic, Optional, Protocol, TypeVar from typing import Generator, Generic, Optional, Protocol, TypeVar
import core.ast.midas as m import midas.ast.midas as m
import midas.ast.python as p
class _Level(Enum): class _Level(Enum):
@@ -84,7 +86,7 @@ class AstPrinter(Generic[T]):
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
#Statements # Statements
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
self._write_line("SimpleTypeStmt") self._write_line("SimpleTypeStmt")
@@ -346,3 +348,205 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def visit_type_expr(self, expr: m.TypeExpr): def visit_type_expr(self, expr: m.TypeExpr):
template: str = expr.template.accept(self) if expr.template is not None else "" template: str = expr.template.accept(self) if expr.template is not None else ""
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
class PythonAstPrinter(
    AstPrinter,
    p.MidasType.Visitor[None],
    p.Stmt.Visitor[None],
    p.Expr.Visitor[None],
):
    """Tree printer for the nodes of ``midas.ast.python``.

    Implements the type-annotation, statement and expression visitors on top
    of the line and nesting helpers inherited from ``AstPrinter``.
    """

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------

    def _accept_node(self, node) -> None:
        """Dispatch ``node`` back into this printer."""
        node.accept(self)

    def _labelled_child(self, label: str, node, *, last: bool = False) -> None:
        """Write ``label`` and print ``node`` as its single child."""
        if last:
            self._write_line(label, last=True)
        else:
            self._write_line(label)
        with self._child_level(single=True):
            node.accept(self)

    def _labelled_list(self, label: str, items, emit, *, last: bool = False) -> None:
        """Write ``label`` and print every element of ``items`` beneath it.

        ``emit`` is called once per element. Before each call the running
        index is stored in ``_idx`` and the final element is flagged through
        ``_mark_last``.
        """
        if last:
            self._write_line(label, last=True)
        else:
            self._write_line(label)
        final: int = len(items) - 1
        with self._child_level():
            for position, item in enumerate(items):
                self._idx = position
                if position == final:
                    self._mark_last()
                emit(item)

    def _infix(self, title: str, expr) -> None:
        """Print a node made of ``left``, ``operator`` and ``right``."""
        self._write_line(title)
        with self._child_level():
            self._labelled_child("left", expr.left)
            self._write_line(f"operator: {expr.operator.__class__.__name__}")
            self._labelled_child("right", expr.right, last=True)

    # ------------------------------------------------------------------
    # Type annotations
    # ------------------------------------------------------------------

    def visit_base_type(self, node: p.BaseType) -> None:
        self._write_line("BaseType")
        with self._child_level():
            self._write_line(f"base: {node.base}")
            self._write_optional_child("param", node.param, last=True)

    def visit_constraint_type(self, node: p.ConstraintType) -> None:
        self._write_line("ConstraintType")
        with self._child_level():
            self._labelled_child("type", node.type)
            # The constraint is a raw `ast` expression, so it is shown as source.
            self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)

    def visit_frame_column(self, node: p.FrameColumn) -> None:
        self._write_line("FrameColumn")
        with self._child_level():
            self._write_line(f"name: {node.name}")
            self._write_optional_child("type", node.type, last=True)

    def visit_frame_type(self, node: p.FrameType) -> None:
        self._write_line("FrameType")
        with self._child_level():
            self._labelled_list("columns", node.columns, self._accept_node, last=True)

    # ------------------------------------------------------------------
    # Statements
    # ------------------------------------------------------------------

    def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
        stmt.expr.accept(self)

    def visit_function(self, stmt: p.Function) -> None:
        self._write_line("Function")
        with self._child_level():
            self._write_line(f"name: {stmt.name}")
            # The three argument groups are printed identically, in this order.
            for group in ("posonlyargs", "args", "kwonlyargs"):
                self._labelled_list(group, getattr(stmt, group), self._print_argument)
            self._write_optional_child("returns", stmt.returns, last=True)

    def _print_argument(self, arg: p.Function.Argument) -> None:
        self._write_line("FunctionArgument")
        with self._child_level():
            self._write_line(f"name: {arg.name}")
            self._write_optional_child("type", arg.type, last=True)

    def visit_type_assign(self, stmt: p.TypeAssign) -> None:
        self._write_line("TypeAssign")
        with self._child_level():
            self._write_line(f"name: {stmt.name}")
            self._labelled_child("type", stmt.type, last=True)

    def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
        self._write_line("AssignStmt")
        with self._child_level():
            self._labelled_list("targets", stmt.targets, self._accept_node)
            self._labelled_child("value", stmt.value, last=True)

    # ------------------------------------------------------------------
    # Expressions
    # ------------------------------------------------------------------

    def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
        self._infix("BinaryExpr", expr)

    def visit_compare_expr(self, expr: p.CompareExpr) -> None:
        self._infix("CompareExpr", expr)

    def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
        self._write_line("UnaryExpr")
        with self._child_level():
            self._write_line(f"operator: {expr.operator.__class__.__name__}")
            self._labelled_child("right", expr.right, last=True)

    def _print_keyword(self, pair) -> None:
        """Print one ``(name, value)`` keyword argument of a call."""
        keyword, value = pair
        self._labelled_child(keyword, value)

    def visit_call_expr(self, expr: p.CallExpr) -> None:
        self._write_line("CallExpr")
        with self._child_level():
            self._labelled_child("callee", expr.callee)
            self._labelled_list("arguments", expr.arguments, self._accept_node)
            self._labelled_list(
                "keywords", expr.keywords.items(), self._print_keyword, last=True
            )

    def visit_get_expr(self, expr: p.GetExpr) -> None:
        self._write_line("GetExpr")
        with self._child_level():
            self._labelled_child("object", expr.object)
            self._write_line(f"name: {expr.name}", last=True)

    def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
        self._write_line("LiteralExpr")
        with self._child_level(single=True):
            self._write_line(f"value: {expr.value}")

    def visit_variable_expr(self, expr: p.VariableExpr) -> None:
        self._write_line("VariableExpr")
        with self._child_level(single=True):
            self._write_line(f"name: {expr.name}")

    def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
        self._infix("LogicalExpr", expr)

    def visit_set_expr(self, expr: p.SetExpr) -> None:
        self._write_line("SetExpr")
        with self._child_level():
            self._labelled_child("object", expr.object)
            self._write_line(f"name: {expr.name}")
            self._labelled_child("value", expr.value, last=True)

270
midas/ast/python.py Normal file
View File

@@ -0,0 +1,270 @@
"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/python.py instead and run gen/gen.py
"""
from __future__ import annotations
import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
T = TypeVar("T")
####################
# Type annotations #
####################
@dataclass(frozen=True, kw_only=True)
class MidasType(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_base_type(self, node: BaseType) -> T: ...
@abstractmethod
def visit_constraint_type(self, node: ConstraintType) -> T: ...
@abstractmethod
def visit_frame_column(self, node: FrameColumn) -> T: ...
@abstractmethod
def visit_frame_type(self, node: FrameType) -> T: ...
@dataclass(frozen=True)
class BaseType(MidasType):
base: str
param: Optional[MidasType]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self)
@dataclass(frozen=True)
class ConstraintType(MidasType):
type: MidasType
constraint: ast.expr
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class FrameColumn(MidasType):
name: Optional[str]
type: Optional[MidasType]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_frame_column(self)
@dataclass(frozen=True)
class FrameType(MidasType):
columns: list[FrameColumn]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_frame_type(self)
##############
# Statements #
##############
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: ...
@abstractmethod
def visit_function(self, stmt: Function) -> T: ...
@abstractmethod
def visit_type_assign(self, stmt: TypeAssign) -> T: ...
@abstractmethod
def visit_assign_stmt(self, stmt: AssignStmt) -> T: ...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
expr: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_expression_stmt(self)
@dataclass(frozen=True)
class Function(Stmt):
name: str
posonlyargs: list[Argument]
args: list[Argument]
kwonlyargs: list[Argument]
returns: Optional[MidasType]
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[str]
type: Optional[MidasType]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_function(self)
@dataclass(frozen=True)
class TypeAssign(Stmt):
name: str
type: MidasType
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_type_assign(self)
@dataclass(frozen=True)
class AssignStmt(Stmt):
targets: list[Expr]
value: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_assign_stmt(self)
###############
# Expressions #
###############
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
@abstractmethod
def visit_compare_expr(self, expr: CompareExpr) -> T: ...
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T: ...
@abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ...
@abstractmethod
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
@abstractmethod
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
@abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@abstractmethod
def visit_set_expr(self, expr: SetExpr) -> T: ...
@dataclass(frozen=True)
class BinaryExpr(Expr):
left: Expr
operator: ast.operator
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_binary_expr(self)
@dataclass(frozen=True)
class CompareExpr(Expr):
left: Expr
operator: ast.cmpop
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_compare_expr(self)
@dataclass(frozen=True)
class UnaryExpr(Expr):
operator: ast.unaryop
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True)
class GetExpr(Expr):
object: Expr
name: str
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_get_expr(self)
@dataclass(frozen=True)
class LiteralExpr(Expr):
value: Any
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_literal_expr(self)
@dataclass(frozen=True)
class VariableExpr(Expr):
name: str
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_variable_expr(self)
@dataclass(frozen=True)
class LogicalExpr(Expr):
left: Expr
operator: ast.boolop
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_logical_expr(self)
@dataclass(frozen=True)
class SetExpr(Expr):
object: Expr
name: str
value: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_set_expr(self)

0
midas/cli/__init__.py Normal file
View File

57
midas/cli/highlight.css Normal file
View File

@@ -0,0 +1,57 @@
/* Base stylesheet for the highlighted-source page. */

html,
body {
    margin: 0;
    font-size: 14pt;
}

/* Padding and borders count towards an element's declared size. */
* {
    box-sizing: border-box;
}

/* Container stacking one .line row per source line. */
#code {
    display: flex;
    flex-direction: column;
    font-family: monospace;
    /* Keep the source's own whitespace while still wrapping long lines. */
    white-space: pre-wrap;
}

/* One row: line number (.no) followed by the line's text (.txt). */
.line {
    display: flex;

    /* Zebra striping on every other row. */
    &:nth-child(odd) {
        background-color: rgb(247, 247, 247);
    }

    /* Fixed-width, right-aligned line number column. */
    .no {
        width: 4em;
        text-align: right;
        padding: 0.2em 0.4em;
        border-right: solid black 1px;
        flex-shrink: 0;
    }

    /* The text takes all remaining width. */
    .txt {
        flex-grow: 1;
        padding: 0.2em 0.8em;
    }
}

/*
 * Highlight spans. --col is substituted inside rgb()/rgba(), so it has to be
 * a comma-separated "r, g, b" triple to produce a colour; presumably the
 * language-specific stylesheets set it per class (confirm there).
 */
span {
    --col: transparent;
    --opacity: 0.1;
    --border: 0px;
    background-color: rgba(var(--col), var(--opacity));
    outline: solid rgb(var(--col)) var(--border);
    outline-offset: 2px;
    border-radius: 2px;

    /* Emphasise only the innermost hovered span: one with no hovered descendant. */
    &:hover:not(:has(*:hover)) {
        --opacity: 0.8;
        --border: 2px;
        z-index: 10;
    }

    &.keyword {
        color: rgb(211, 72, 9);
    }
}

258
midas/cli/highlighter.py Normal file
View File

@@ -0,0 +1,258 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Generic, Optional, Protocol, TextIO, TypeVar
from midas.ast.location import Location
import midas.ast.midas as m
import midas.ast.python as p
H = TypeVar("H", bound="Highlighter", contravariant=True)
class Highlightable(Protocol, Generic[H]):
    """Structural type for nodes that accept a highlighter of type ``H`` as visitor."""

    def accept(self, visitor: H): ...
class Locatable(Protocol):
    """Structural type for anything exposing an optional ``location``."""

    @property
    @abstractmethod
    def location(self) -> Optional[Location]: ...
class Highlighter(ABC):
    """Render source text as an HTML page with selected ranges wrapped in ``<span>``.

    Ranges are registered with :meth:`wrap`; :meth:`dump` then writes the page,
    one row per source line, inserting the registered tags at their positions.

    NOTE(review): positions are matched against character indices of each
    line. If locations come from Python's ``ast`` module, whose column offsets
    are UTF-8 byte offsets, lines containing non-ASCII text may be highlighted
    at shifted columns — confirm against the producers of ``Location``.
    """

    BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
    EXTRA_CSS_PATH: Optional[Path] = None

    def __init__(self, source: str) -> None:
        self.source: str = source
        self.lines: list[str] = self.source.splitlines()
        # Tags to insert, keyed by (1-based line number, 0-based column).
        self.openings: dict[tuple[int, int], list[str]] = {}
        self.closings: dict[tuple[int, int], list[str]] = {}

    def format_css(self, path: Path) -> list[str]:
        """Return the stylesheet at ``path`` as lines of an inline ``<style>`` element."""
        css: str = path.read_text()
        css = "\n".join((" " + line).rstrip() for line in css.splitlines())
        return [
            " <style>",
            css,
            " </style>",
        ]

    def dump(self, buf: TextIO):
        """Write the complete HTML document to ``buf``.

        Source characters are HTML-escaped, and tags registered at the column
        just past the last character of a line (including column 0 of an empty
        line) are emitted as well, so every opened span is closed on its line.
        """
        # Local import keeps this class self-contained.
        import html

        base_css: list[str] = self.format_css(self.BASE_CSS_PATH)
        extra_css: list[str] = (
            self.format_css(self.EXTRA_CSS_PATH)
            if self.EXTRA_CSS_PATH is not None
            else []
        )
        lines: list[str] = [
            "<!DOCTYPE html>",
            '<html lang="en">',
            "<head>",
            ' <meta charset="UTF-8">',
            ' <meta name="viewport" content="width=device-width, initial-scale=1.0">',
            " <title>Highlighted file</title>",
            *base_css,
            *extra_css,
            "</head>",
            "<body>",
            ' <div id="code">',
        ]
        for lineno, line in enumerate(self.lines, start=1):
            parts: list[str] = [
                f'<div class="line" id="l{lineno}"><div class="no">{lineno}</div><div class="txt">'
            ]
            for col, char in enumerate(line):
                pos: tuple[int, int] = (lineno, col)
                # Spans ending here are closed before new ones are opened.
                parts.extend(self.closings.get(pos, ()))
                parts.extend(self.openings.get(pos, ()))
                # Without escaping, `<`, `>` and `&` in the source would be
                # parsed as markup.
                parts.append(html.escape(char, quote=False))
            # End-of-line position. A span opened here has no content on this
            # line, so it is opened first and then closed together with the
            # spans that end here; that keeps the row's tags balanced.
            eol: tuple[int, int] = (lineno, len(line))
            parts.extend(self.openings.get(eol, ()))
            parts.extend(self.closings.get(eol, ()))
            parts.append("</div></div>")
            lines.append(" " + "".join(parts))
        lines.extend(
            [
                " </div>",
                "</body>",
                "</html>",
            ]
        )
        buf.write("\n".join(lines))

    def wrap(self, node: Locatable, cls: str):
        """Register a ``<span class=cls>`` around the source range of ``node``.

        Nodes without a location, or without a known end position, are
        ignored. A range covering several lines is split into one span per
        line, because every line is rendered inside its own ``<div>``.
        """
        if node.location is None:
            return
        if node.location.end_lineno is None or node.location.end_col_offset is None:
            return

        start_pos: tuple[int, int] = (node.location.lineno, node.location.col_offset)
        end_pos: tuple[int, int] = (
            node.location.end_lineno,
            node.location.end_col_offset,
        )
        opening: str = f'<span class="{cls}" title="{cls}">'
        closing: str = "</span>"

        # Openings are appended and closings prepended, so ranges registered
        # later at the same position nest inside those registered earlier.
        self.openings.setdefault(start_pos, []).append(opening)
        self.closings.setdefault(end_pos, []).insert(0, closing)

        # For every line but the last: close at its end, reopen on the next.
        for lineno in range(start_pos[0], end_pos[0]):
            line_end: int = len(self.lines[lineno - 1])
            self.closings.setdefault((lineno, line_end), []).insert(0, closing)
            self.openings.setdefault((lineno + 1, 0), []).append(opening)
class PythonHighlighter(
    Highlighter,
    p.MidasType.Visitor[None],
    p.Stmt.Visitor[None],
    p.Expr.Visitor[None],
):
    """Highlighter for the nodes of ``midas.ast.python``.

    Type annotations, functions and function arguments are wrapped in spans;
    the assignment and expression visitors are intentionally empty for now.
    """

    EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_python.css"

    def highlight(self, node: Highlightable[PythonHighlighter]):
        """Register the spans for ``node`` and everything this class visits below it."""
        node.accept(self)

    # -- type annotations ----------------------------------------------

    def visit_base_type(self, node: p.BaseType) -> None:
        self.wrap(node, "base-type")
        if node.param is not None:
            # The parameter gets its own "param" span in addition to the span
            # its own visitor registers.
            self.wrap(node.param, "param")
            node.param.accept(self)

    def visit_constraint_type(self, node: p.ConstraintType) -> None:
        self.wrap(node, "constraint-type")
        node.type.accept(self)

    def visit_frame_column(self, node: p.FrameColumn) -> None:
        self.wrap(node, "frame-column")
        if node.type is not None:
            node.type.accept(self)

    def visit_frame_type(self, node: p.FrameType) -> None:
        self.wrap(node, "frame-type")
        for column in node.columns:
            column.accept(self)

    # -- statements ----------------------------------------------------

    def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
        stmt.expr.accept(self)

    def visit_function(self, stmt: p.Function) -> None:
        self.wrap(stmt, "function")
        for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
            self._highlight_function_argument(arg)
        # The return annotation is a type like the argument annotations and
        # is highlighted the same way.
        if stmt.returns is not None:
            stmt.returns.accept(self)

    def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
        """Wrap one argument and highlight its annotation, if it has one."""
        self.wrap(arg, "argument")
        if arg.type is not None:
            arg.type.accept(self)

    def visit_type_assign(self, stmt: p.TypeAssign) -> None:
        stmt.type.accept(self)

    def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ...

    # -- expressions (nothing is highlighted for these) ------------------

    def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...

    def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...

    def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...

    def visit_call_expr(self, expr: p.CallExpr) -> None: ...

    def visit_get_expr(self, expr: p.GetExpr) -> None: ...

    def visit_literal_expr(self, expr: p.LiteralExpr) -> None: ...

    def visit_variable_expr(self, expr: p.VariableExpr) -> None: ...

    def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...

    def visit_set_expr(self, expr: p.SetExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
    """Visitor that tags Midas statements and expressions for highlighting.

    Each ``visit_*`` method passes the visited node and a class name to
    ``self.wrap`` and then descends into the node's children.
    """

    # Stylesheet with the Midas highlight classes, located next to this module.
    EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"

    def highlight(self, node: Highlightable[MidasHighlighter]):
        """Highlight ``node`` by dispatching it to this visitor."""
        node.accept(self)

    # --- statements ---------------------------------------------------

    def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
        """Tag a simple type declaration, its template, base and constraint."""
        self.wrap(stmt, "simple-type")
        template = stmt.template
        if template is not None:
            template.accept(self)
        stmt.base.accept(self)
        constraint = stmt.constraint
        if constraint is None:
            return
        self.wrap(constraint, "constraint")
        constraint.accept(self)

    def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
        """Tag a complex type declaration, its template and every property."""
        self.wrap(stmt, "complex-type")
        template = stmt.template
        if template is not None:
            template.accept(self)
        for member in stmt.properties:
            member.accept(self)

    def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
        """Tag a property, its type and its optional constraint."""
        self.wrap(stmt, "property")
        stmt.type.accept(self)
        constraint = stmt.constraint
        if constraint is None:
            return
        self.wrap(constraint, "constraint")
        constraint.accept(self)

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
        """Tag an ``extend`` block, the extended type and each operation."""
        self.wrap(stmt, "extend")
        stmt.type.accept(self)
        for operation in stmt.operations:
            operation.accept(self)

    def visit_op_stmt(self, stmt: m.OpStmt) -> None:
        """Tag an operation and visit its operand and result types."""
        self.wrap(stmt, "op")
        for type_expr in (stmt.operand, stmt.result):
            type_expr.accept(self)

    def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
        """Tag a predicate and visit its subject type and condition."""
        self.wrap(stmt, "predicate")
        for child in (stmt.type, stmt.condition):
            child.accept(self)

    # --- expressions --------------------------------------------------

    def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
        """Tag a simple type expression; it has no children to visit."""
        self.wrap(expr, "simple-type-expr")

    def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
        """Tag a logical expression and visit left then right operand."""
        self.wrap(expr, "logical-expr")
        for operand in (expr.left, expr.right):
            operand.accept(self)

    def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
        """Tag a binary expression and visit left then right operand."""
        self.wrap(expr, "binary-expr")
        for operand in (expr.left, expr.right):
            operand.accept(self)

    def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
        """Tag a unary expression and visit its operand."""
        self.wrap(expr, "unary-expr")
        operand = expr.right
        operand.accept(self)

    def visit_get_expr(self, expr: m.GetExpr) -> None:
        """Tag a property access and visit the expression being accessed."""
        self.wrap(expr, "get-expr")
        target = expr.expr
        target.accept(self)

    def visit_variable_expr(self, expr: m.VariableExpr) -> None:
        """Tag a variable reference."""
        self.wrap(expr, "variable")

    def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
        """Visit the parenthesised expression; the grouping itself is not tagged."""
        inner = expr.expr
        inner.accept(self)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
        """No-op: literals are not highlighted."""

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
        """No-op: wildcards are not highlighted."""

    def visit_template_expr(self, expr: m.TemplateExpr) -> None:
        """Tag a template expression and visit the type it carries."""
        self.wrap(expr, "template")
        expr.type.accept(self)

    def visit_type_expr(self, expr: m.TypeExpr) -> None:
        """Tag a type expression and visit its template, if any."""
        self.wrap(expr, "type")
        template = expr.template
        if template is None:
            return
        template.accept(self)

55
midas/cli/hl_midas.css Normal file
View File

@@ -0,0 +1,55 @@
/*
 * Colours for the classes MidasHighlighter attaches to highlighted spans.
 * Most rules only set the `--col` custom property to an "R, G, B" triple.
 * NOTE(review): `--col` is presumably consumed by a shared base stylesheet
 * that is not part of this file - confirm.
 * NOTE(review): the highlighter also emits the `keyword` and `variable`
 * classes, which have no rule here - confirm they are styled elsewhere or
 * are intentionally left unstyled.
 */
span {
/* Comment tokens: also given an explicit grey, italic text style. */
&.comment {
--col: 200, 200, 200;
color: rgb(110, 110, 110);
font-style: italic;
}
/* Simple type declarations. */
&.simple-type {
--col: 108, 233, 108;
}
/* Complex type declarations. */
&.complex-type {
--col: 233, 206, 108;
}
/* Constraints attached to a type or property. */
&.constraint {
--col: 233, 108, 108;
}
/* Properties of a complex type. */
&.property {
--col: 233, 108, 176;
}
/* `extend` blocks. */
&.extend {
--col: 108, 197, 233;
}
/* Operation declarations. */
&.op {
--col: 108, 148, 233;
}
/* Predicate declarations. */
&.predicate {
--col: 193, 108, 233;
}
/* Simple type expressions. */
&.simple-type-expr {
--col: 150, 150, 150;
}
/* All composite expression kinds share one colour. */
&.logical-expr,
&.binary-expr,
&.unary-expr,
&.get-expr {
--col: 123, 215, 193;
}
/* Template expressions. */
&.template {
--col: 163, 117, 71;
}
/* Type expressions: rendered bold in addition to the colour. */
&.type {
--col: 200, 200, 200;
font-weight: bold;
}
}

29
midas/cli/hl_python.css Normal file
View File

@@ -0,0 +1,29 @@
/*
 * Colours for the classes PythonHighlighter attaches to highlighted spans.
 * Every rule only sets the `--col` custom property to an "R, G, B" triple.
 * NOTE(review): `--col` is presumably consumed by a shared base stylesheet
 * that is not part of this file - confirm.
 */
span {
/* Base types. */
&.base-type {
--col: 108, 233, 108;
}
/* Parameter of a base type (same colour as `.argument`). */
&.param {
--col: 103, 192, 224;
}
/* Types carrying a constraint. */
&.constraint-type {
--col: 174, 200, 195;
}
/* A single column inside a frame type. */
&.frame-column {
--col: 216, 231, 81;
}
/* Frame types. */
&.frame-type {
--col: 231, 46, 40;
}
/* Function definitions. */
&.function {
--col: 215, 103, 224;
}
/* Function arguments (same colour as `.param`). */
&.argument {
--col: 103, 192, 224;
}
}

111
midas/cli/main.py Normal file
View File

@@ -0,0 +1,111 @@
import ast
from dataclasses import dataclass
from typing import Optional, TextIO
import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import PythonAstPrinter
from midas.cli.highlighter import Highlighter, MidasHighlighter, PythonHighlighter
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
@click.group()
def midas():
    # Root command group of the CLI; its body runs before every sub-command.
    # The banner is written to stderr so it never mixes with payload that a
    # sub-command writes to stdout (for example `utils highlight`, whose
    # --output defaults to "-").
    click.echo("Welcome to Midas!", err=True)
@midas.command()
@click.argument("file", type=click.File("r"))
def compile(file: TextIO):
    # Placeholder command: compilation is not implemented yet, so invoking
    # it always fails with NotImplementedError.
    raise NotImplementedError()
@midas.group()
def utils():
    # Container group for the developer utilities registered on it
    # (`dump-ast`, `highlight`); it has no behaviour of its own.
    return None
@utils.command()
@click.option("-o", "--output", type=click.File("w"))
@click.option("-p", "--parse", is_flag=True)
@click.argument("file", type=click.File("r"))
def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO):
    # Dump the AST of a Python file.
    #   output: destination stream, or None to echo to the terminal.
    #   parse:  when set, run the Midas Python parser and print its
    #           statements (one per line) instead of the raw `ast` dump.
    #   file:   the Python source to read.
    module: ast.Module = ast.parse(file.read(), filename=file.name)

    text: str
    if not parse:
        text = ast.dump(module, indent=4)
    else:
        parser = PythonParser()
        parsed: list[p.Stmt] = parser.parse_module(module)
        printer = PythonAstPrinter()
        # Every printed statement is terminated by a newline.
        text = "".join(printer.print(stmt) + "\n" for stmt in parsed)

    if output is not None:
        output.write(text)
    else:
        click.echo(text)
def highlight_python(source: str, path: str) -> Highlighter:
    """Build a highlighter for a Python source file.

    Args:
        source: the text of the file.
        path: the file name, passed to ``ast.parse`` for error reporting.

    Returns:
        A ``PythonHighlighter`` over ``source`` on which every parsed
        statement has been highlighted.
    """
    module: ast.Module = ast.parse(source, filename=path)
    statements: list[p.Stmt] = PythonParser().parse_module(module)
    result = PythonHighlighter(source)
    for statement in statements:
        result.highlight(statement)
    return result
def highlight_midas(source: str, path: str) -> Highlighter:
    """Build a highlighter for a Midas source file.

    The source is lexed and parsed, every statement is highlighted, and
    comment and keyword tokens are then tagged directly from the token
    stream. Parser errors are reported on stderr so they cannot end up
    inside highlighted output that is written to stdout.

    Args:
        source: the text of the file.
        path: the file name handed to the lexer.

    Returns:
        A ``MidasHighlighter`` over ``source`` with all tags applied.
    """
    lexer = MidasLexer(source, file=path)
    tokens: list[Token] = lexer.process()
    parser = MidasParser(tokens)
    stmts: list[m.Stmt] = parser.parse()
    highlighter = MidasHighlighter(source)

    for error in parser.errors:
        click.echo(error.get_report(), err=True)

    @dataclass(frozen=True)
    class LocatableToken:
        # Adapter giving a token the ``location`` attribute that the
        # highlighted nodes expose.
        token: Token

        @property
        def location(self) -> Location:
            return self.token.get_location()

    for stmt in stmts:
        highlighter.highlight(stmt)

    # Tokens are tagged after the statements, in source order.
    for token in tokens:
        if token.type == TokenType.COMMENT:
            highlighter.wrap(LocatableToken(token), "comment")
        elif token.is_keyword:
            highlighter.wrap(LocatableToken(token), "keyword")
    return highlighter
@utils.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def highlight(output: TextIO, file: TextIO):
    # Highlight a source file and dump the result to `output` (stdout by
    # default). The language is chosen from the file extension: ".py" or
    # ".midas"; anything else raises ValueError.
    source: str = file.read()
    name = file.name
    if name.endswith(".py"):
        build = highlight_python
    elif name.endswith(".midas"):
        build = highlight_midas
    else:
        raise ValueError("Unsupported file type")
    highlighter: Highlighter = build(source, name)
    highlighter.dump(output)

0
midas/lexer/__init__.py Normal file
View File

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from lexer.position import Position from midas.lexer.position import Position
from lexer.token import Token, TokenType from midas.lexer.token import Token, TokenType
class MidasSyntaxError(Exception): class MidasSyntaxError(Exception):

View File

@@ -1,6 +1,5 @@
from lexer.base import Lexer from midas.lexer.base import Lexer
from lexer.keyword import KEYWORDS from midas.lexer.token import KEYWORDS, TokenType
from lexer.token import TokenType
class MidasLexer(Lexer): class MidasLexer(Lexer):

View File

@@ -5,6 +5,7 @@ from typing import Optional
@dataclass(frozen=True) @dataclass(frozen=True)
class Position: class Position:
"""A simple structure to store the position of a token""" """A simple structure to store the position of a token"""
file: Optional[str] file: Optional[str]
line: int line: int
column: int column: int

View File

@@ -1,8 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import Any from typing import Any
from lexer.position import Position from midas.ast.location import Location
from midas.lexer.position import Position
class TokenType(Enum): class TokenType(Enum):
@@ -55,6 +58,18 @@ class TokenType(Enum):
NEWLINE = auto() NEWLINE = auto()
# Reserved words of the Midas language, mapped from their lexeme to the token
# type they produce. Token.is_keyword tests membership in this table.
KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE,
"op": TokenType.OP,
"predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND,
"where": TokenType.WHERE,
"true": TokenType.TRUE,
"false": TokenType.FALSE,
"none": TokenType.NONE,
}
@dataclass(frozen=True) @dataclass(frozen=True)
class Token: class Token:
"""A scanned token""" """A scanned token"""
@@ -63,3 +78,27 @@ class Token:
lexeme: str lexeme: str
value: Any value: Any
position: Position position: Position
def get_location(self) -> Location:
    """Compute the source span covered by this token.

    The start comes from the token position, with the column shifted down
    by one. The end is found from the lexeme: each newline advances the
    end line, and the end column counts the characters after the last
    newline (or after the start column when there is none).

    Returns:
        Location: the span from the first to just past the last character.
    """
    text = self.lexeme
    start_line: int = self.position.line
    start_col: int = self.position.column - 1

    line_breaks = text.count("\n")
    if line_breaks:
        # Characters after the last newline.
        end_col = len(text) - text.rfind("\n") - 1
    else:
        end_col = start_col + len(text)

    return Location(
        lineno=start_line,
        col_offset=start_col,
        end_lineno=start_line + line_breaks,
        end_col_offset=end_col,
    )
def location_to(self, to: Token) -> Location:
    """Return the location spanning from this token to ``to``.

    Args:
        to: the token closing the span.
    """
    start = self.get_location()
    end = to.get_location()
    return Location.span(start, end)
@property
def is_keyword(self) -> bool:
    """Whether this token's lexeme is one of the entries of ``KEYWORDS``."""
    reserved = KEYWORDS
    return self.lexeme in reserved

View File

@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, TypeVar from typing import Generic, TypeVar
from lexer.token import Token, TokenType from midas.lexer.token import Token, TokenType
from parser.errors import ParsingError from midas.parser.errors import ParsingError
@dataclass(frozen=True) @dataclass(frozen=True)

View File

@@ -1,6 +1,7 @@
from typing import Optional from typing import Optional
from core.ast.midas import ( from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr, BinaryExpr,
ComplexTypeStmt, ComplexTypeStmt,
Expr, Expr,
@@ -21,9 +22,9 @@ from core.ast.midas import (
VariableExpr, VariableExpr,
WildcardExpr, WildcardExpr,
) )
from lexer.token import Token, TokenType from midas.lexer.token import Token, TokenType
from parser.base import Parser from midas.parser.base import Parser
from parser.errors import ParsingError from midas.parser.errors import ParsingError
class MidasParser(Parser): class MidasParser(Parser):
@@ -104,6 +105,7 @@ class MidasParser(Parser):
Returns: Returns:
TypeStmt: the parsed type declaration statement TypeStmt: the parsed type declaration statement
""" """
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None template: Optional[TemplateExpr] = None
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
@@ -116,11 +118,20 @@ class MidasParser(Parser):
if self.match(TokenType.WHERE): if self.match(TokenType.WHERE):
constraint = self.constraint() constraint = self.constraint()
return SimpleTypeStmt( return SimpleTypeStmt(
name=name, template=template, base=base, constraint=constraint location=keyword.location_to(self.previous()),
name=name,
template=template,
base=base,
constraint=constraint,
) )
else: else:
properties: list[PropertyStmt] = self.type_properties() properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(name=name, template=template, properties=properties) return ComplexTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
properties=properties,
)
def template_expr(self) -> TemplateExpr: def template_expr(self) -> TemplateExpr:
"""Parse a generic template expression """Parse a generic template expression
@@ -130,10 +141,14 @@ class MidasParser(Parser):
Returns: Returns:
TemplateExpr: the parsed template expression TemplateExpr: the parsed template expression
""" """
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") left: Token = self.consume(
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
)
type: TypeExpr = self.type_expr() type: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") right: Token = self.consume(
return TemplateExpr(type=type) TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
)
return TemplateExpr(location=left.location_to(right), type=type)
def type_expr(self) -> TypeExpr: def type_expr(self) -> TypeExpr:
"""Parse a type expression """Parse a type expression
@@ -149,7 +164,12 @@ class MidasParser(Parser):
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr() template = self.template_expr()
optional: bool = self.match(TokenType.QMARK) optional: bool = self.match(TokenType.QMARK)
return TypeExpr(name=name, template=template, optional=optional) return TypeExpr(
location=name.location_to(self.previous()),
name=name,
template=template,
optional=optional,
)
def simple_type_expr(self) -> SimpleTypeExpr: def simple_type_expr(self) -> SimpleTypeExpr:
"""Parse a simple type expression """Parse a simple type expression
@@ -161,7 +181,9 @@ class MidasParser(Parser):
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
optional: bool = self.match(TokenType.QMARK) optional: bool = self.match(TokenType.QMARK)
return SimpleTypeExpr(name=name, optional=optional) return SimpleTypeExpr(
location=name.location_to(self.previous()), name=name, optional=optional
)
def constraint(self) -> Expr: def constraint(self) -> Expr:
"""Parse a constraint """Parse a constraint
@@ -183,7 +205,12 @@ class MidasParser(Parser):
while self.match(TokenType.AND): while self.match(TokenType.AND):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.equality() right: Expr = self.equality()
expr = LogicalExpr(left=expr, operator=operator, right=right) location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = LogicalExpr(
location=location, left=expr, operator=operator, right=right
)
return expr return expr
def equality(self) -> Expr: def equality(self) -> Expr:
@@ -196,7 +223,12 @@ class MidasParser(Parser):
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL): while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.comparison() right: Expr = self.comparison()
expr = BinaryExpr(left=expr, operator=operator, right=right) location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr return expr
def comparison(self) -> Expr: def comparison(self) -> Expr:
@@ -214,7 +246,12 @@ class MidasParser(Parser):
): ):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.unary() right: Expr = self.unary()
expr = BinaryExpr(left=expr, operator=operator, right=right) location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr return expr
def unary(self) -> Expr: def unary(self) -> Expr:
@@ -226,7 +263,10 @@ class MidasParser(Parser):
if self.match(TokenType.MINUS): if self.match(TokenType.MINUS):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.unary() right: Expr = self.unary()
return UnaryExpr(operator=operator, right=right) location: Optional[Location] = None
if right.location:
location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.reference() return self.reference()
def reference(self) -> Expr: def reference(self) -> Expr:
@@ -240,7 +280,10 @@ class MidasParser(Parser):
name: Token = self.consume( name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'" TokenType.IDENTIFIER, "Expected property name after '.'"
) )
expr = GetExpr(expr=expr, name=name) location: Optional[Location] = None
if expr.location:
location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr return expr
def primary(self) -> Expr: def primary(self) -> Expr:
@@ -251,26 +294,27 @@ class MidasParser(Parser):
Returns: Returns:
Expr: the parsed expression Expr: the parsed expression
""" """
token: Token = self.peek()
if self.match(TokenType.FALSE): if self.match(TokenType.FALSE):
return LiteralExpr(False) return LiteralExpr(location=token.get_location(), value=False)
if self.match(TokenType.TRUE): if self.match(TokenType.TRUE):
return LiteralExpr(True) return LiteralExpr(location=token.get_location(), value=True)
if self.match(TokenType.NONE): if self.match(TokenType.NONE):
return LiteralExpr(None) return LiteralExpr(location=token.get_location(), value=None)
if self.match(TokenType.NUMBER): if self.match(TokenType.NUMBER):
return LiteralExpr(self.previous().value) return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.IDENTIFIER): if self.match(TokenType.IDENTIFIER):
return VariableExpr(self.previous()) return VariableExpr(location=token.get_location(), name=token)
if self.match(TokenType.UNDERSCORE): if self.match(TokenType.UNDERSCORE):
return WildcardExpr(self.previous()) return WildcardExpr(location=token.get_location(), token=token)
if self.match(TokenType.LEFT_PAREN): if self.match(TokenType.LEFT_PAREN):
expr: Expr = self.constraint() expr: Expr = self.constraint()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return GroupingExpr(expr) return GroupingExpr(location=token.location_to(right), expr=expr)
raise self.error(self.peek(), "Expected expression") raise self.error(self.peek(), "Expected expression")
@@ -304,7 +348,12 @@ class MidasParser(Parser):
constraint: Optional[Expr] = None constraint: Optional[Expr] = None
if self.match(TokenType.WHERE): if self.match(TokenType.WHERE):
constraint = self.constraint() constraint = self.constraint()
return PropertyStmt(name=name, type=type, constraint=constraint) return PropertyStmt(
location=name.location_to(self.previous()),
name=name,
type=type,
constraint=constraint,
)
def extend_declaration(self) -> ExtendStmt: def extend_declaration(self) -> ExtendStmt:
"""Parse an extension definition """Parse an extension definition
@@ -314,13 +363,17 @@ class MidasParser(Parser):
Returns: Returns:
ExtendStmt: the parsed extension statement ExtendStmt: the parsed extension statement
""" """
keyword: Token = self.previous()
type: TypeExpr = self.type_expr() type: TypeExpr = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = [] operations: list[OpStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
operations.append(self.op_declaration()) operations.append(self.op_declaration())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body") self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
return ExtendStmt(type=type, operations=operations) location: Optional[Location] = None
if type.location:
location = keyword.location_to(self.previous())
return ExtendStmt(location=location, type=type, operations=operations)
def op_declaration(self) -> OpStmt: def op_declaration(self) -> OpStmt:
"""Parse an operation definition """Parse an operation definition
@@ -330,7 +383,7 @@ class MidasParser(Parser):
Returns: Returns:
OpStmt: the parsed operation statement OpStmt: the parsed operation statement
""" """
self.consume(TokenType.OP, "Expected 'op' keyword") keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
@@ -340,7 +393,12 @@ class MidasParser(Parser):
self.consume(TokenType.ARROW, "Expected '->' before result type") self.consume(TokenType.ARROW, "Expected '->' before result type")
result: TypeExpr = self.type_expr() result: TypeExpr = self.type_expr()
return OpStmt(name=name, operand=operand, result=result) return OpStmt(
location=keyword.location_to(self.previous()),
name=name,
operand=operand,
result=result,
)
def predicate_declaration(self) -> PredicateStmt: def predicate_declaration(self) -> PredicateStmt:
"""Parse a predicate declaration """Parse a predicate declaration
@@ -350,6 +408,7 @@ class MidasParser(Parser):
Returns: Returns:
PredicateStmt: the parsed predicate declaration statement PredicateStmt: the parsed predicate declaration statement
""" """
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
@@ -358,4 +417,10 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint() condition: Expr = self.constraint()
return PredicateStmt(name=name, subject=subject, type=type, condition=condition) return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
subject=subject,
type=type,
condition=condition,
)

343
midas/parser/python.py Normal file
View File

@@ -0,0 +1,343 @@
import ast
from typing import Optional
from midas.ast.location import Location
from midas.ast.python import (
AssignStmt,
BaseType,
BinaryExpr,
CallExpr,
CompareExpr,
ConstraintType,
Expr,
ExpressionStmt,
FrameColumn,
FrameType,
Function,
GetExpr,
LiteralExpr,
LogicalExpr,
MidasType,
Stmt,
TypeAssign,
UnaryExpr,
VariableExpr,
)
class InvalidSyntaxError(Exception):
    """Raised when a type annotation is malformed.

    The parser raises it for a constraint whose left-hand side yields no
    type and for a frame column that is missing its type.
    """
class UnsupportedSyntaxError(Exception):
    """Raised when the parser meets a node it cannot translate.

    The message carries the node's line, column offset and unparsed source.
    """

    def __init__(self, expr: ast.expr) -> None:
        line, column = expr.lineno, expr.col_offset
        snippet = ast.unparse(expr)
        super().__init__(f"Unsupported syntax at L{line}:{column}: {snippet}")
class PythonParser:
def parse_module(self, node: ast.Module) -> list[Stmt]:
    """Parse every top-level statement of a module.

    Statements that raise ``UnsupportedSyntaxError`` are reported on
    stdout and skipped; statements that yield nothing are dropped.

    Args:
        node: the module produced by ``ast.parse``.

    Returns:
        list[Stmt]: the parsed statements, flattened, in source order.
    """
    collected: list[Stmt] = []
    for child in node.body:
        try:
            outcome: None | Stmt | list[Stmt] = self.parse_stmt(child)
        except UnsupportedSyntaxError as error:
            print(f"{error}, skipping")
            continue
        if outcome is None:
            continue
        if isinstance(outcome, Stmt):
            collected.append(outcome)
        else:
            # A single source statement may expand to several statements.
            collected.extend(outcome)
    return collected
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
    """Dispatch a statement to the parser matching its kind.

    Args:
        node: the statement to parse.

    Returns:
        The parsed statement(s), or ``None`` (after printing a notice)
        when the statement kind is not supported.
    """
    if isinstance(node, ast.AnnAssign):
        return self.parse_annotation_assign(node)
    if isinstance(node, ast.Assign):
        return self.parse_assign(node)
    if isinstance(node, ast.FunctionDef):
        return self.parse_function(node)
    if isinstance(node, ast.Expr):
        return ExpressionStmt(expr=self.parse_expr(node.value))
    print(f"Unsupported statement: {ast.unparse(node)}")
    return None
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
statements: list[Stmt] = []
loc: Location = Location.from_ast(node)
match node:
case ast.AnnAssign(
target=ast.Name(id=target),
annotation=annotation,
value=value,
simple=1,
):
type = self._parse_type(annotation, root=True)
if type is not None:
statements.append(
TypeAssign(
location=loc,
name=target,
type=type,
)
)
if value is not None:
statements.append(
AssignStmt(
location=loc,
targets=[
VariableExpr(
location=Location.from_ast(node.target), name=target
),
],
value=self.parse_expr(value),
),
)
case _:
print(f"Unsupported annotation: {ast.unparse(node)}")
return statements
def parse_assign(self, node: ast.Assign) -> AssignStmt:
    """Parse a plain assignment, including chained targets (``a = b = v``).

    Args:
        node: the assignment to parse.

    Returns:
        AssignStmt: the parsed targets and value, located at ``node``.
    """
    parsed_targets: list[Expr] = [self.parse_expr(t) for t in node.targets]
    parsed_value: Expr = self.parse_expr(node.value)
    return AssignStmt(
        location=Location.from_ast(node),
        targets=parsed_targets,
        value=parsed_value,
    )
def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node)
match node:
case ast.FunctionDef(
name=name,
args=ast.arguments(
posonlyargs=posonlyargs,
args=args,
kwonlyargs=kwonlyargs,
),
returns=returns,
):
def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]:
return [self._parse_function_argument(arg) for arg in args_list]
return Function(
location=loc,
name=name,
posonlyargs=parse_args(posonlyargs),
args=parse_args(args),
kwonlyargs=parse_args(kwonlyargs),
returns=self._parse_type(returns) if returns is not None else None,
)
case _:
print(f"Unsupported function definition: {ast.unparse(node)}")
def _parse_function_argument(self, arg: ast.arg) -> Function.Argument:
    """Parse one function argument.

    Args:
        arg: the argument node.

    Returns:
        Function.Argument: the argument's name and its parsed annotation
        (``None`` when the argument is not annotated).
    """
    loc: Location = Location.from_ast(arg)
    annotation = arg.annotation
    parsed_type: Optional[MidasType] = (
        None if annotation is None else self._parse_type(annotation)
    )
    return Function.Argument(location=loc, name=arg.arg, type=parsed_type)
def _parse_type(
self, type_expr: ast.expr, root: bool = False
) -> Optional[MidasType]:
loc: Location = Location.from_ast(type_expr)
match type_expr:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=param):
return BaseType(
location=loc,
base=name,
param=self._parse_type(param),
)
case ast.Name(id=name):
return BaseType(
location=loc,
base=name,
param=None,
)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
left = self._parse_type(left_expr)
match left:
case None:
raise InvalidSyntaxError()
# If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp(
left=left_constraint,
op=ast.Add(),
right=right_expr,
)
ast.copy_location(constraint, type_expr)
return ConstraintType(
location=loc,
type=left_type,
constraint=constraint,
)
case _:
return ConstraintType(
location=loc,
type=left,
constraint=right_expr,
)
case _:
if root:
return None
raise UnsupportedSyntaxError(type_expr)
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
    """Parse the schema found inside ``Frame[...]``.

    Args:
        schema: the subscript content; a tuple of columns, or a single
            column given as a slice or a name.

    Returns:
        FrameType: the frame type with its columns in source order.

    Raises:
        UnsupportedSyntaxError: when the schema has any other shape.
    """
    loc: Location = Location.from_ast(schema)
    columns: list[FrameColumn]
    if isinstance(schema, ast.Tuple):
        columns = [self._parse_frame_column(item) for item in schema.elts]
    elif isinstance(schema, (ast.Slice, ast.Name)):
        columns = [self._parse_frame_column(schema)]
    else:
        raise UnsupportedSyntaxError(schema)
    return FrameType(location=loc, columns=columns)
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
    """Parse one column of a frame schema.

    A bare name is an unnamed column of that type. A slice ``name: type``
    is a named column, where ``_`` in either position means "none".

    Args:
        column: the column expression.

    Returns:
        FrameColumn: the column's optional name and optional type.

    Raises:
        InvalidSyntaxError: when a slice has no type part.
        UnsupportedSyntaxError: for any other column shape.
    """
    loc: Location = Location.from_ast(column)
    if isinstance(column, ast.Name):
        return FrameColumn(
            location=loc, name=None, type=self._parse_type(column)
        )
    if not (isinstance(column, ast.Slice) and isinstance(column.lower, ast.Name)):
        raise UnsupportedSyntaxError(column)

    label: Optional[str] = column.lower.id
    if label == "_":
        label = None

    annotation = column.upper
    column_type: Optional[MidasType]
    if annotation is None:
        raise InvalidSyntaxError("Missing column type")
    if isinstance(annotation, ast.Name) and annotation.id == "_":
        column_type = None
    elif isinstance(annotation, ast.expr):
        column_type = self._parse_type(annotation)
    else:
        raise UnsupportedSyntaxError(annotation)
    return FrameColumn(location=loc, name=label, type=column_type)
def parse_expr(self, node: ast.expr) -> Expr:
    """Translate a Python expression into the matching Midas expression.

    Args:
        node: the expression to translate.

    Returns:
        Expr: the translated expression.

    Raises:
        UnsupportedSyntaxError: when the expression kind is not handled.
    """
    if isinstance(node, ast.BoolOp):
        return self.parse_bool_op(node)
    if isinstance(node, ast.BinOp):
        return BinaryExpr(
            left=self.parse_expr(node.left),
            operator=node.op,
            right=self.parse_expr(node.right),
        )
    if isinstance(node, ast.UnaryOp):
        return UnaryExpr(
            operator=node.op,
            right=self.parse_expr(node.operand),
        )
    if isinstance(node, ast.Compare):
        return self.parse_compare(node)
    if isinstance(node, ast.Call):
        return self.parse_call(node)
    if isinstance(node, ast.Constant):
        return LiteralExpr(value=node.value)
    if isinstance(node, ast.Attribute):
        return GetExpr(
            object=self.parse_expr(node.value),
            name=node.attr,
        )
    if isinstance(node, ast.Name):
        return VariableExpr(name=node.id)
    raise UnsupportedSyntaxError(node)
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
    """Translate ``a op b op c ...`` into nested binary logical expressions.

    The operands are combined left to right, so three operands produce
    ``(a op b) op c``.

    Args:
        node: the boolean operation.

    Returns:
        LogicalExpr: the outermost logical expression.
    """
    operator: ast.boolop = node.op
    operands: list[ast.expr] = node.values
    combined: LogicalExpr = LogicalExpr(
        left=self.parse_expr(operands[0]),
        operator=operator,
        right=self.parse_expr(operands[1]),
    )
    for operand in operands[2:]:
        combined = LogicalExpr(
            left=combined,
            operator=operator,
            right=self.parse_expr(operand),
        )
    return combined
def parse_compare(self, node: ast.Compare) -> Expr:
    """Translate a (possibly chained) comparison.

    ``ast`` stores ``a < b > c`` as ``left=a``, ``ops=[<, >]`` and
    ``comparators=[b, c]``. It is lowered to ``(a < b) and (b > c)``: a
    single ``CompareExpr`` for a plain comparison, otherwise a left-nested
    chain of ``and`` expressions.

    Args:
        node: the comparison to translate.

    Returns:
        Expr: the translated comparison.
    """
    ops: list[ast.cmpop] = node.ops
    rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
    expr: Expr = CompareExpr(
        left=self.parse_expr(node.left),
        operator=ops[0],
        right=rights[0],
    )
    # Each further link compares one comparator with the next, using the
    # operator written between them: rights[i] <ops[i + 1]> rights[i + 1].
    for left, operator, right in zip(rights, ops[1:], rights[1:]):
        expr = LogicalExpr(
            left=expr,
            operator=ast.And(),
            right=CompareExpr(
                left=left,
                operator=operator,
                right=right,
            ),
        )
    return expr
def parse_call(self, node: ast.Call) -> CallExpr:
    """Translate a call expression.

    Args:
        node: the call to translate.

    Returns:
        CallExpr: the callee, positional arguments and named keyword
        arguments.
    """
    callee = self.parse_expr(node.func)
    positional = [self.parse_expr(argument) for argument in node.args]
    # ``keyword.arg`` is None for ``**mapping`` unpacking; such entries have
    # no name to key on and are left out.
    named = {
        keyword.arg: self.parse_expr(keyword.value)
        for keyword in node.keywords
        if keyword.arg is not None
    }
    return CallExpr(callee=callee, arguments=positional, keywords=named)

22
pyproject.toml Normal file
View File

@@ -0,0 +1,22 @@
[project]
name = "midas"
version = "0.1.0"
description = "A static-first type checking framework for Python data-frames"
readme = "README.md"
requires-python = ">=3.11"
authors = [
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
]
classifiers = ["Programming Language :: Python :: 3"]
dependencies = ["click>=8.4.1"]
[project.urls]
Homepage = "https://git.kbk28.ch/HEL/midas"
Repository = "https://git.kbk28.ch/HEL/midas"
[project.scripts]
midas = "midas.cli.main:midas"
[build-system]
requires = ['hatchling']
build-backend = 'hatchling.build'

View File

@@ -1,10 +1,10 @@
import json import json
from pathlib import Path from pathlib import Path
from core.ast.printer import MidasAstPrinter from midas.ast.printer import MidasAstPrinter
from lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from lexer.token import Token from midas.lexer.token import Token
from parser.midas import MidasParser from midas.parser.midas import MidasParser
def test_midas(): def test_midas():

View File

@@ -8,12 +8,12 @@ from dataclasses import asdict, dataclass, field
from pathlib import Path from pathlib import Path
from typing import Iterator, Optional from typing import Iterator, Optional
from core.ast.json_serializer import AstJsonSerializer from midas.ast.json_serializer import AstJsonSerializer
from core.ast.midas import Stmt from midas.ast.midas import Stmt
from lexer.base import MidasSyntaxError from midas.lexer.base import MidasSyntaxError
from lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from lexer.token import Token from midas.lexer.token import Token
from parser.midas import MidasParser from midas.parser.midas import MidasParser
DEFAULT_BASE_DIR: Path = Path() / "tests" DEFAULT_BASE_DIR: Path = Path() / "tests"