diff --git a/.gitignore b/.gitignore index f63541d..b540ed8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ __pycache__ .env venv .venv -*.pyc \ No newline at end of file +*.pyc +uv.lock +.python-version \ No newline at end of file diff --git a/examples/00_syntax_prototype/02_custom_types.py b/examples/00_syntax_prototype/02_custom_types.py index 0297058..16bf442 100644 --- a/examples/00_syntax_prototype/02_custom_types.py +++ b/examples/00_syntax_prototype/02_custom_types.py @@ -21,7 +21,7 @@ lat + lon # Invalid operation # Registered operations are permitted lat1: Latitude = lat[0] lat2: Latitude = lat[1] -lat_diff: LatitudeDiff = lat2 - lat1 # Valid operation +lat_diff: Difference[Latitude] = lat2 - lat1 # Valid operation # In addition to the type, a column can have one or more constraints, either defined inline or in a separate file df2: Frame[ diff --git a/examples/00_syntax_prototype/04_functions.py b/examples/00_syntax_prototype/04_functions.py new file mode 100644 index 0000000..3b07899 --- /dev/null +++ b/examples/00_syntax_prototype/04_functions.py @@ -0,0 +1,15 @@ +# type: ignore +# ruff: disable[F821] +from __future__ import annotations + + +def func( + col1: Column[float + (0 <= _ <= 1)], + col2: Column[float + (0 <= _ <= 1)], +) -> Column[float + (0 <= _ <= 2)]: + result: Column[float + (0 <= _ <= 2)] = col1 + col2 + return result + + +def func2(a: int, /, b: float, *, c: str): + pass diff --git a/gen/gen.py b/gen/gen.py index 47cb827..75e6100 100644 --- a/gen/gen.py +++ b/gen/gen.py @@ -3,53 +3,34 @@ import re HEADER = '''""" This file was generated by a script. Any manual changes might be overwritten. -Please modify gen/ast.py instead and run gen/gen.py +Please modify {defs_path} instead and run {gen_path} """''' +SECTION_TEMPLATE = """{banner} + + +@dataclass(frozen=True, kw_only=True) +class {base}(ABC): + location: Optional[Location] = None + + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... 
+ + class Visitor(ABC, Generic[T]): +{visitor_methods} + + +{classes}""" + TEMPLATE = """{header} from __future__ import annotations -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar - -from lexer.token import Token +{imports} T = TypeVar("T") -############## -# Statements # -############## - - -@dataclass(frozen=True) -class Stmt(ABC): - @abstractmethod - def accept(self, visitor: Visitor[T]) -> T: ... - - class Visitor(ABC, Generic[T]): -{stmt_visitor_methods} - - -{statements} - - -############### -# Expressions # -############### - - -@dataclass(frozen=True) -class Expr(ABC): - @abstractmethod - def accept(self, visitor: Visitor[T]) -> T: ... - - class Visitor(ABC, Generic[T]): -{expr_visitor_methods} - - -{expressions} +{sections} """ VISITOR_METHOD_TEMPLATE = """ @@ -66,17 +47,28 @@ class {cls}({base}): return visitor.visit_{func_name}(self) """ +SECTION_REGEX = re.compile( + r"^###>\s*(?P[^\n]*?)\s*\|\s*(?P[^\n]*?)(\s*\|\s*(?P[^\n]*?))?\s*?\n(?P.*?)\n###<$", + re.MULTILINE | re.DOTALL, +) + +IMPORTS_REGEX = re.compile( + r"^###>\s*Imports\s*?\n(?P.*?)\n###<$", + re.MULTILINE | re.DOTALL, +) + + def snake_case(text: str) -> str: return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_") + def make_visitor_method(cls: str, param: str): method: str = VISITOR_METHOD_TEMPLATE.format( - func_name=snake_case(cls), - param=param, - cls=cls + func_name=snake_case(cls), param=param, cls=cls ) return method.strip("\n") + def make_class(name: str, cls: str, base: str): body: str = cls.split("\n", 1)[1] func_name: str = snake_case(name) @@ -88,40 +80,66 @@ def make_class(name: str, cls: str, base: str): ) return cls_def.strip("\n") -def generate(src: str): - classes: list[str] = src.split("\n\n") - stmt_visitor_methods: list[str] = [] - expr_visitor_methods: list[str] = [] - statements: list[str] = [] - expressions: list[str] = [] - for cls in classes: +def 
make_banner(text: str) -> str: + middle: str = f"# {text} #" + rule: str = "#" * len(middle) + return "\n".join((rule, middle, rule)) + + +def make_section(full_name: str, base: str, param: str, body: str) -> str: + visitor_methods: list[str] = [] + classes: list[str] = [] + definitions: list[str] = body.strip("\n").split("\n\n\n") + for cls in definitions: cls = cls.strip("\n") name: str = re.match("class (.*?):", cls).group(1) # type: ignore print(f"Processing {name}") - if name.endswith("Stmt"): - stmt_visitor_methods.append(make_visitor_method(name, "stmt")) - statements.append(make_class(name, cls, "Stmt")) - elif name.endswith("Expr"): - expr_visitor_methods.append(make_visitor_method(name, "expr")) - expressions.append(make_class(name, cls, "Expr")) + visitor_methods.append(make_visitor_method(name, param)) + classes.append(make_class(name, cls, base)) - return TEMPLATE.format( - header=HEADER, - stmt_visitor_methods="\n\n".join(stmt_visitor_methods), - expr_visitor_methods="\n\n".join(expr_visitor_methods), - statements="\n\n\n".join(statements), - expressions="\n\n\n".join(expressions), + return SECTION_TEMPLATE.format( + banner=make_banner(full_name), + base=base, + visitor_methods="\n\n".join(visitor_methods), + classes="\n\n\n".join(classes), ) + +def generate(definitions_path: Path, out_path: Path): + root_dir: Path = Path(__file__).parent.parent + rel_path: Path = definitions_path.relative_to(root_dir) + src: str = definitions_path.read_text() + sections: list[str] = [] + + imports: str = "" + if m := IMPORTS_REGEX.search(src): + imports = m.group("body").strip("\n") + + for section_m in SECTION_REGEX.finditer(src): + full_name: str = section_m.group("name") + base: str = section_m.group("base") + param: str = section_m.group("param") or base.lower() + body: str = section_m.group("body") + sections.append(make_section(full_name, base, param, body)) + + result: str = TEMPLATE.format( + header=HEADER.format( + defs_path=rel_path, + 
gen_path=Path(__file__).relative_to(root_dir), + ), + imports=imports, + sections="\n\n\n".join(sections), + ) + out_path.write_text(result) + + def main(): root: Path = Path(__file__).parent.parent - in_path: Path = root / "gen" / "ast.py" - out_path: Path = root / "core" / "ast" / "midas.py" - - src: str = in_path.read_text() - generated: str = generate(src) - out_path.write_text(generated) + defs_dir: Path = root / "gen" + ast_dir: Path = root / "midas" / "ast" + generate(defs_dir / "midas.py", ast_dir / "midas.py") + generate(defs_dir / "python.py", ast_dir / "python.py") if __name__ == "__main__": diff --git a/gen/ast.py b/gen/midas.py similarity index 76% rename from gen/ast.py rename to gen/midas.py index 6fca631..7187554 100644 --- a/gen/ast.py +++ b/gen/midas.py @@ -1,72 +1,110 @@ +# type: ignore +# ruff: disable[F821, F401] + +###> Imports +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Generic, Optional, TypeVar + +from midas.ast.location import Location +from midas.lexer.token import Token + +###< + + +###> Stmt | Statements class SimpleTypeStmt: name: Token template: Optional[TemplateExpr] base: TypeExpr constraint: Optional[Expr] -class SimpleTypeExpr: - name: Token - optional: bool - -class LogicalExpr: - left: Expr - operator: Token - right: Expr - -class BinaryExpr: - left: Expr - operator: Token - right: Expr - -class UnaryExpr: - operator: Token - right: Expr - -class GetExpr: - expr: Expr - name: Token - -class VariableExpr: - name: Token - -class GroupingExpr: - expr: Expr - -class LiteralExpr: - value: Any - -class WildcardExpr: - token: Token - -class TemplateExpr: - type: TypeExpr - -class TypeExpr: - name: Token - template: Optional[TemplateExpr] - optional: bool class ComplexTypeStmt: name: Token template: Optional[TemplateExpr] properties: list[PropertyStmt] + class PropertyStmt: name: Token type: TypeExpr constraint: Optional[Expr] + class ExtendStmt: type: TypeExpr operations: 
list[OpStmt] + class OpStmt: name: Token operand: TypeExpr result: TypeExpr + class PredicateStmt: name: Token subject: Token type: TypeExpr condition: Expr + + +###< + + +###> Expr | Expressions +class SimpleTypeExpr: + name: Token + optional: bool + + +class LogicalExpr: + left: Expr + operator: Token + right: Expr + + +class BinaryExpr: + left: Expr + operator: Token + right: Expr + + +class UnaryExpr: + operator: Token + right: Expr + + +class GetExpr: + expr: Expr + name: Token + + +class VariableExpr: + name: Token + + +class GroupingExpr: + expr: Expr + + +class LiteralExpr: + value: Any + + +class WildcardExpr: + token: Token + + +class TemplateExpr: + type: TypeExpr + + +class TypeExpr: + name: Token + template: Optional[TemplateExpr] + optional: bool + + +###< diff --git a/gen/python.py b/gen/python.py new file mode 100644 index 0000000..0aadd57 --- /dev/null +++ b/gen/python.py @@ -0,0 +1,119 @@ +# type: ignore +# ruff: disable[F821, F401] + +###> Imports +import ast +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Generic, Optional, TypeVar + +from midas.ast.location import Location + +###< + + +###> MidasType | Type annotations | node +class BaseType: + base: str + param: Optional[MidasType] + + +class ConstraintType: + type: MidasType + constraint: ast.expr + + +class FrameColumn: + name: Optional[str] + type: Optional[MidasType] + + +class FrameType: + columns: list[FrameColumn] + + +###< + + +###> Stmt | Statements +class ExpressionStmt: + expr: Expr + + +class Function: + name: str + posonlyargs: list[Argument] + args: list[Argument] + kwonlyargs: list[Argument] + returns: Optional[MidasType] + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[str] + type: Optional[MidasType] + + +class TypeAssign: + name: str + type: MidasType + + +class AssignStmt: + targets: list[Expr] + value: Expr + + +###< + + +###> Expr | Expressions +class 
BinaryExpr: + left: Expr + operator: ast.operator + right: Expr + + +class CompareExpr: + left: Expr + operator: ast.cmpop + right: Expr + + +class UnaryExpr: + operator: ast.unaryop + right: Expr + + +class CallExpr: + callee: Expr + arguments: list[Expr] + keywords: dict[str, Expr] + + +class GetExpr: + object: Expr + name: str + + +class LiteralExpr: + value: Any + + +class VariableExpr: + name: str + + +class LogicalExpr: + left: Expr + operator: ast.boolop + right: Expr + + +class SetExpr: + object: Expr + name: str + value: Expr + + +###< diff --git a/lexer/keyword.py b/lexer/keyword.py deleted file mode 100644 index e5c4b64..0000000 --- a/lexer/keyword.py +++ /dev/null @@ -1,12 +0,0 @@ -from lexer.token import TokenType - -KEYWORDS: dict[str, TokenType] = { - "type": TokenType.TYPE, - "op": TokenType.OP, - "predicate": TokenType.PREDICATE, - "extend": TokenType.EXTEND, - "where": TokenType.WHERE, - "true": TokenType.TRUE, - "false": TokenType.FALSE, - "none": TokenType.NONE, -} diff --git a/lexer/__init__.py b/midas/__init__.py similarity index 100% rename from lexer/__init__.py rename to midas/__init__.py diff --git a/core/ast/json_serializer.py b/midas/ast/json_serializer.py similarity index 99% rename from core/ast/json_serializer.py rename to midas/ast/json_serializer.py index 0064726..d602117 100644 --- a/core/ast/json_serializer.py +++ b/midas/ast/json_serializer.py @@ -1,6 +1,6 @@ from typing import Optional, Sequence -from core.ast.midas import ( +from midas.ast.midas import ( BinaryExpr, ComplexTypeStmt, Expr, diff --git a/midas/ast/location.py b/midas/ast/location.py new file mode 100644 index 0000000..47fe360 --- /dev/null +++ b/midas/ast/location.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Protocol + + +class HasLocation(Protocol): + lineno: int + col_offset: int + end_lineno: Optional[int] + end_col_offset: Optional[int] + + +@dataclass(frozen=True, kw_only=True) 
+class Location: + lineno: int + col_offset: int + end_lineno: Optional[int] + end_col_offset: Optional[int] + + @staticmethod + def from_ast(obj: HasLocation) -> Location: + return Location( + lineno=obj.lineno, + col_offset=obj.col_offset, + end_lineno=obj.end_lineno, + end_col_offset=obj.end_col_offset, + ) + + @staticmethod + def span(start: Location, end: Location) -> Location: + return Location( + lineno=start.lineno, + col_offset=start.col_offset, + end_lineno=end.lineno, + end_col_offset=end.end_col_offset, + ) diff --git a/core/ast/midas.py b/midas/ast/midas.py similarity index 95% rename from core/ast/midas.py rename to midas/ast/midas.py index f4280fb..9cea8c2 100644 --- a/core/ast/midas.py +++ b/midas/ast/midas.py @@ -1,6 +1,6 @@ """ This file was generated by a script. Any manual changes might be overwritten. -Please modify gen/ast.py instead and run gen/gen.py +Please modify gen/midas.py instead and run gen/gen.py """ from __future__ import annotations @@ -9,7 +9,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Generic, Optional, TypeVar -from lexer.token import Token +from midas.ast.location import Location +from midas.lexer.token import Token T = TypeVar("T") @@ -18,8 +19,10 @@ T = TypeVar("T") ############## -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Stmt(ABC): + location: Optional[Location] = None + @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... @@ -109,8 +112,10 @@ class PredicateStmt(Stmt): ############### -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Expr(ABC): + location: Optional[Location] = None + @abstractmethod def accept(self, visitor: Visitor[T]) -> T: ... 
diff --git a/core/ast/printer.py b/midas/ast/printer.py similarity index 61% rename from core/ast/printer.py rename to midas/ast/printer.py index 61fede8..e3ecde9 100644 --- a/core/ast/printer.py +++ b/midas/ast/printer.py @@ -1,11 +1,13 @@ from __future__ import annotations +import ast import io from contextlib import contextmanager from enum import Enum, auto from typing import Generator, Generic, Optional, Protocol, TypeVar -import core.ast.midas as m +import midas.ast.midas as m +import midas.ast.python as p class _Level(Enum): @@ -84,7 +86,7 @@ class AstPrinter(Generic[T]): class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): - #Statements + # Statements def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): self._write_line("SimpleTypeStmt") @@ -346,3 +348,205 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): def visit_type_expr(self, expr: m.TypeExpr): template: str = expr.template.accept(self) if expr.template is not None else "" return f"{expr.name.lexeme}{template}{'?' 
if expr.optional else ''}" + + +class PythonAstPrinter( + AstPrinter, + p.MidasType.Visitor[None], + p.Stmt.Visitor[None], + p.Expr.Visitor[None], +): + def visit_base_type(self, node: p.BaseType) -> None: + self._write_line("BaseType") + with self._child_level(): + self._write_line(f"base: {node.base}") + self._write_optional_child("param", node.param, last=True) + + def visit_constraint_type(self, node: p.ConstraintType) -> None: + self._write_line("ConstraintType") + with self._child_level(): + self._write_line("type") + with self._child_level(single=True): + node.type.accept(self) + self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True) + + def visit_frame_column(self, node: p.FrameColumn) -> None: + self._write_line("FrameColumn") + with self._child_level(): + self._write_line(f"name: {node.name}") + self._write_optional_child("type", node.type, last=True) + + def visit_frame_type(self, node: p.FrameType) -> None: + self._write_line("FrameType") + with self._child_level(): + self._write_line("columns", last=True) + with self._child_level(): + for i, col in enumerate(node.columns): + self._idx = i + if i == len(node.columns) - 1: + self._mark_last() + col.accept(self) + + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: + stmt.expr.accept(self) + + def visit_function(self, stmt: p.Function) -> None: + self._write_line("Function") + with self._child_level(): + self._write_line(f"name: {stmt.name}") + + self._write_line("posonlyargs") + with self._child_level(): + for i, arg in enumerate(stmt.posonlyargs): + self._idx = i + if i == len(stmt.posonlyargs) - 1: + self._mark_last() + self._print_argument(arg) + + self._write_line("args") + with self._child_level(): + for i, arg in enumerate(stmt.args): + self._idx = i + if i == len(stmt.args) - 1: + self._mark_last() + self._print_argument(arg) + + self._write_line("kwonlyargs") + with self._child_level(): + for i, arg in enumerate(stmt.kwonlyargs): + self._idx = i + if i == 
len(stmt.kwonlyargs) - 1: + self._mark_last() + self._print_argument(arg) + + self._write_optional_child("returns", stmt.returns, last=True) + + def _print_argument(self, arg: p.Function.Argument) -> None: + self._write_line("FunctionArgument") + with self._child_level(): + self._write_line(f"name: {arg.name}") + self._write_optional_child("type", arg.type, last=True) + + def visit_type_assign(self, stmt: p.TypeAssign) -> None: + self._write_line("TypeAssign") + with self._child_level(): + self._write_line(f"name: {stmt.name}") + self._write_line("type", last=True) + with self._child_level(single=True): + stmt.type.accept(self) + + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: + self._write_line("AssignStmt") + with self._child_level(): + self._write_line("targets") + with self._child_level(): + for i, target in enumerate(stmt.targets): + self._idx = i + if i == len(stmt.targets) - 1: + self._mark_last() + target.accept(self) + self._write_line("value", last=True) + with self._child_level(single=True): + stmt.value.accept(self) + + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: + self._write_line("BinaryExpr") + with self._child_level(): + self._write_line("left") + with self._child_level(single=True): + expr.left.accept(self) + + self._write_line(f"operator: {expr.operator.__class__.__name__}") + + self._write_line("right", last=True) + with self._child_level(single=True): + expr.right.accept(self) + + def visit_compare_expr(self, expr: p.CompareExpr) -> None: + self._write_line("CompareExpr") + with self._child_level(): + self._write_line("left") + with self._child_level(single=True): + expr.left.accept(self) + + self._write_line(f"operator: {expr.operator.__class__.__name__}") + + self._write_line("right", last=True) + with self._child_level(single=True): + expr.right.accept(self) + + def visit_unary_expr(self, expr: p.UnaryExpr) -> None: + self._write_line("UnaryExpr") + with self._child_level(): + self._write_line(f"operator: 
{expr.operator.__class__.__name__}") + + self._write_line("right", last=True) + with self._child_level(single=True): + expr.right.accept(self) + + def visit_call_expr(self, expr: p.CallExpr) -> None: + self._write_line("CallExpr") + with self._child_level(): + self._write_line("callee") + with self._child_level(single=True): + expr.callee.accept(self) + + self._write_line("arguments") + with self._child_level(): + for i, arg in enumerate(expr.arguments): + self._idx = i + if i == len(expr.arguments) - 1: + self._mark_last() + arg.accept(self) + + self._write_line("keywords", last=True) + with self._child_level(): + for i, (name, arg) in enumerate(expr.keywords.items()): + self._idx = i + if i == len(expr.keywords) - 1: + self._mark_last() + self._write_line(name) + with self._child_level(single=True): + arg.accept(self) + + def visit_get_expr(self, expr: p.GetExpr) -> None: + self._write_line("GetExpr") + with self._child_level(): + self._write_line("object") + with self._child_level(single=True): + expr.object.accept(self) + self._write_line(f"name: {expr.name}", last=True) + + def visit_literal_expr(self, expr: p.LiteralExpr) -> None: + self._write_line("LiteralExpr") + with self._child_level(single=True): + self._write_line(f"value: {expr.value}") + + def visit_variable_expr(self, expr: p.VariableExpr) -> None: + self._write_line("VariableExpr") + with self._child_level(single=True): + self._write_line(f"name: {expr.name}") + + def visit_logical_expr(self, expr: p.LogicalExpr) -> None: + self._write_line("LogicalExpr") + with self._child_level(): + self._write_line("left") + with self._child_level(single=True): + expr.left.accept(self) + + self._write_line(f"operator: {expr.operator.__class__.__name__}") + + self._write_line("right", last=True) + with self._child_level(single=True): + expr.right.accept(self) + + def visit_set_expr(self, expr: p.SetExpr) -> None: + self._write_line("SetExpr") + with self._child_level(): + self._write_line("object") + with 
self._child_level(single=True): + expr.object.accept(self) + self._write_line(f"name: {expr.name}") + self._write_line("value", last=True) + with self._child_level(single=True): + expr.value.accept(self) diff --git a/midas/ast/python.py b/midas/ast/python.py new file mode 100644 index 0000000..d4fc032 --- /dev/null +++ b/midas/ast/python.py @@ -0,0 +1,270 @@ +""" +This file was generated by a script. Any manual changes might be overwritten. +Please modify gen/python.py instead and run gen/gen.py +""" + +from __future__ import annotations + +import ast +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Generic, Optional, TypeVar + +from midas.ast.location import Location + +T = TypeVar("T") + +#################### +# Type annotations # +#################### + + +@dataclass(frozen=True, kw_only=True) +class MidasType(ABC): + location: Optional[Location] = None + + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): + @abstractmethod + def visit_base_type(self, node: BaseType) -> T: ... + + @abstractmethod + def visit_constraint_type(self, node: ConstraintType) -> T: ... + + @abstractmethod + def visit_frame_column(self, node: FrameColumn) -> T: ... + + @abstractmethod + def visit_frame_type(self, node: FrameType) -> T: ... 
+ + +@dataclass(frozen=True) +class BaseType(MidasType): + base: str + param: Optional[MidasType] + + def accept(self, visitor: MidasType.Visitor[T]) -> T: + return visitor.visit_base_type(self) + + +@dataclass(frozen=True) +class ConstraintType(MidasType): + type: MidasType + constraint: ast.expr + + def accept(self, visitor: MidasType.Visitor[T]) -> T: + return visitor.visit_constraint_type(self) + + +@dataclass(frozen=True) +class FrameColumn(MidasType): + name: Optional[str] + type: Optional[MidasType] + + def accept(self, visitor: MidasType.Visitor[T]) -> T: + return visitor.visit_frame_column(self) + + +@dataclass(frozen=True) +class FrameType(MidasType): + columns: list[FrameColumn] + + def accept(self, visitor: MidasType.Visitor[T]) -> T: + return visitor.visit_frame_type(self) + + +############## +# Statements # +############## + + +@dataclass(frozen=True, kw_only=True) +class Stmt(ABC): + location: Optional[Location] = None + + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): + @abstractmethod + def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: ... + + @abstractmethod + def visit_function(self, stmt: Function) -> T: ... + + @abstractmethod + def visit_type_assign(self, stmt: TypeAssign) -> T: ... + + @abstractmethod + def visit_assign_stmt(self, stmt: AssignStmt) -> T: ... 
+ + +@dataclass(frozen=True) +class ExpressionStmt(Stmt): + expr: Expr + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_expression_stmt(self) + + +@dataclass(frozen=True) +class Function(Stmt): + name: str + posonlyargs: list[Argument] + args: list[Argument] + kwonlyargs: list[Argument] + returns: Optional[MidasType] + + @dataclass(frozen=True, kw_only=True) + class Argument: + location: Optional[Location] = None + name: Optional[str] + type: Optional[MidasType] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_function(self) + + +@dataclass(frozen=True) +class TypeAssign(Stmt): + name: str + type: MidasType + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_type_assign(self) + + +@dataclass(frozen=True) +class AssignStmt(Stmt): + targets: list[Expr] + value: Expr + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_assign_stmt(self) + + +############### +# Expressions # +############### + + +@dataclass(frozen=True, kw_only=True) +class Expr(ABC): + location: Optional[Location] = None + + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: ... + + class Visitor(ABC, Generic[T]): + @abstractmethod + def visit_binary_expr(self, expr: BinaryExpr) -> T: ... + + @abstractmethod + def visit_compare_expr(self, expr: CompareExpr) -> T: ... + + @abstractmethod + def visit_unary_expr(self, expr: UnaryExpr) -> T: ... + + @abstractmethod + def visit_call_expr(self, expr: CallExpr) -> T: ... + + @abstractmethod + def visit_get_expr(self, expr: GetExpr) -> T: ... + + @abstractmethod + def visit_literal_expr(self, expr: LiteralExpr) -> T: ... + + @abstractmethod + def visit_variable_expr(self, expr: VariableExpr) -> T: ... + + @abstractmethod + def visit_logical_expr(self, expr: LogicalExpr) -> T: ... + + @abstractmethod + def visit_set_expr(self, expr: SetExpr) -> T: ... 
+ + +@dataclass(frozen=True) +class BinaryExpr(Expr): + left: Expr + operator: ast.operator + right: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_binary_expr(self) + + +@dataclass(frozen=True) +class CompareExpr(Expr): + left: Expr + operator: ast.cmpop + right: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_compare_expr(self) + + +@dataclass(frozen=True) +class UnaryExpr(Expr): + operator: ast.unaryop + right: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_unary_expr(self) + + +@dataclass(frozen=True) +class CallExpr(Expr): + callee: Expr + arguments: list[Expr] + keywords: dict[str, Expr] + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_call_expr(self) + + +@dataclass(frozen=True) +class GetExpr(Expr): + object: Expr + name: str + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_get_expr(self) + + +@dataclass(frozen=True) +class LiteralExpr(Expr): + value: Any + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_literal_expr(self) + + +@dataclass(frozen=True) +class VariableExpr(Expr): + name: str + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_variable_expr(self) + + +@dataclass(frozen=True) +class LogicalExpr(Expr): + left: Expr + operator: ast.boolop + right: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_logical_expr(self) + + +@dataclass(frozen=True) +class SetExpr(Expr): + object: Expr + name: str + value: Expr + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_set_expr(self) diff --git a/midas/cli/__init__.py b/midas/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/midas/cli/highlight.css b/midas/cli/highlight.css new file mode 100644 index 0000000..31f005d --- /dev/null +++ b/midas/cli/highlight.css @@ -0,0 +1,57 @@ +html, +body { + margin: 0; + font-size: 14pt; +} + +* 
{ + box-sizing: border-box; +} + +#code { + display: flex; + flex-direction: column; + font-family: monospace; + white-space: pre-wrap; +} + +.line { + display: flex; + + &:nth-child(odd) { + background-color: rgb(247, 247, 247); + } + + .no { + width: 4em; + text-align: right; + padding: 0.2em 0.4em; + border-right: solid black 1px; + flex-shrink: 0; + } + + .txt { + flex-grow: 1; + padding: 0.2em 0.8em; + } +} + +span { + --col: transparent; + --opacity: 0.1; + --border: 0px; + background-color: rgba(var(--col), var(--opacity)); + outline: solid rgb(var(--col)) var(--border); + outline-offset: 2px; + border-radius: 2px; + + &:hover:not(:has(*:hover)) { + --opacity: 0.8; + --border: 2px; + z-index: 10; + } + + &.keyword { + color: rgb(211, 72, 9); + } +} \ No newline at end of file diff --git a/midas/cli/highlighter.py b/midas/cli/highlighter.py new file mode 100644 index 0000000..f4801bb --- /dev/null +++ b/midas/cli/highlighter.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Generic, Optional, Protocol, TextIO, TypeVar + +from midas.ast.location import Location +import midas.ast.midas as m +import midas.ast.python as p + +H = TypeVar("H", bound="Highlighter", contravariant=True) + + +class Highlightable(Protocol, Generic[H]): + def accept(self, visitor: H): ... + + +class Locatable(Protocol): + @property + @abstractmethod + def location(self) -> Optional[Location]: ... 
+ + +class Highlighter(ABC): + BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css" + EXTRA_CSS_PATH: Optional[Path] = None + + def __init__(self, source: str) -> None: + self.source: str = source + self.lines: list[str] = self.source.splitlines() + self.openings: dict[tuple[int, int], list[str]] = {} + self.closings: dict[tuple[int, int], list[str]] = {} + + def format_css(self, path: Path) -> list[str]: + css: str = path.read_text() + css = "\n".join((" " + line).rstrip() for line in css.splitlines()) + return [ + " ", + ] + + def dump(self, buf: TextIO): + base_css: list[str] = self.format_css(self.BASE_CSS_PATH) + extra_css: list[str] = ( + self.format_css(self.EXTRA_CSS_PATH) + if self.EXTRA_CSS_PATH is not None + else [] + ) + lines: list[str] = [ + "", + '', + "", + ' ', + ' ', + " Highlighted file", + *base_css, + *extra_css, + "", + "", + '
', + ] + for l, line in enumerate(self.lines): + lineno: int = l + 1 + line_buf: str = ( + f'
{lineno}
' + ) + for c, char in enumerate(line): + pos: tuple[int, int] = (lineno, c) + closings: list[str] = self.closings.get(pos, []) + openings: list[str] = self.openings.get(pos, []) + line_buf += "".join(closings + openings) + line_buf += char + line_buf += "
" + lines.append(" " + line_buf) + lines.extend( + [ + "
", + "", + "", + ] + ) + + buf.write("\n".join(lines)) + + def wrap(self, node: Locatable, cls: str): + if node.location is None: + return + if node.location.end_lineno is None or node.location.end_col_offset is None: + return + start_pos: tuple[int, int] = (node.location.lineno, node.location.col_offset) + end_pos: tuple[int, int] = ( + node.location.end_lineno, + node.location.end_col_offset, + ) + opening: str = f'' + closing: str = "" + self.openings.setdefault(start_pos, []).append(opening) + self.closings.setdefault(end_pos, []).insert(0, closing) + if start_pos[0] != end_pos[0]: + for l in range(start_pos[0], end_pos[0]): + c: int = len(self.lines[l - 1]) + self.closings.setdefault((l, c), []).insert(0, closing) + self.openings.setdefault((l + 1, 0), []).append(opening) + + +class PythonHighlighter( + Highlighter, + p.MidasType.Visitor[None], + p.Stmt.Visitor[None], + p.Expr.Visitor[None], +): + EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_python.css" + + def highlight(self, node: Highlightable[PythonHighlighter]): + node.accept(self) + + def visit_base_type(self, node: p.BaseType) -> None: + self.wrap(node, "base-type") + if node.param is not None: + self.wrap(node.param, "param") + node.param.accept(self) + + def visit_constraint_type(self, node: p.ConstraintType) -> None: + self.wrap(node, "constraint-type") + node.type.accept(self) + + def visit_frame_column(self, node: p.FrameColumn) -> None: + self.wrap(node, "frame-column") + if node.type is not None: + node.type.accept(self) + + def visit_frame_type(self, node: p.FrameType) -> None: + self.wrap(node, "frame-type") + for column in node.columns: + column.accept(self) + + def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: + stmt.expr.accept(self) + + def visit_function(self, stmt: p.Function) -> None: + self.wrap(stmt, "function") + for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs: + self._highlight_function_argument(arg) + + def _highlight_function_argument(self, 
arg: p.Function.Argument) -> None: + self.wrap(arg, "argument") + if arg.type is not None: + arg.type.accept(self) + + def visit_type_assign(self, stmt: p.TypeAssign) -> None: + stmt.type.accept(self) + + def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ... + + def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... + + def visit_compare_expr(self, expr: p.CompareExpr) -> None: ... + + def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ... + + def visit_call_expr(self, expr: p.CallExpr) -> None: ... + + def visit_get_expr(self, expr: p.GetExpr) -> None: ... + + def visit_literal_expr(self, expr: p.LiteralExpr) -> None: ... + + def visit_variable_expr(self, expr: p.VariableExpr) -> None: ... + + def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ... + + def visit_set_expr(self, expr: p.SetExpr) -> None: ... + + +class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]): + EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css" + + def highlight(self, node: Highlightable[MidasHighlighter]): + node.accept(self) + + def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None: + self.wrap(stmt, "simple-type") + if stmt.template is not None: + stmt.template.accept(self) + stmt.base.accept(self) + if stmt.constraint is not None: + self.wrap(stmt.constraint, "constraint") + stmt.constraint.accept(self) + + def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: + self.wrap(stmt, "complex-type") + if stmt.template is not None: + stmt.template.accept(self) + for prop in stmt.properties: + prop.accept(self) + + def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: + self.wrap(stmt, "property") + stmt.type.accept(self) + if stmt.constraint is not None: + self.wrap(stmt.constraint, "constraint") + stmt.constraint.accept(self) + + def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: + self.wrap(stmt, "extend") + stmt.type.accept(self) + for op in stmt.operations: + op.accept(self) + + def 
visit_op_stmt(self, stmt: m.OpStmt) -> None: + self.wrap(stmt, "op") + stmt.operand.accept(self) + stmt.result.accept(self) + + def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: + self.wrap(stmt, "predicate") + stmt.type.accept(self) + stmt.condition.accept(self) + + def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None: + self.wrap(expr, "simple-type-expr") + + def visit_logical_expr(self, expr: m.LogicalExpr) -> None: + self.wrap(expr, "logical-expr") + expr.left.accept(self) + expr.right.accept(self) + + def visit_binary_expr(self, expr: m.BinaryExpr) -> None: + self.wrap(expr, "binary-expr") + expr.left.accept(self) + expr.right.accept(self) + + def visit_unary_expr(self, expr: m.UnaryExpr) -> None: + self.wrap(expr, "unary-expr") + expr.right.accept(self) + + def visit_get_expr(self, expr: m.GetExpr) -> None: + self.wrap(expr, "get-expr") + expr.expr.accept(self) + + def visit_variable_expr(self, expr: m.VariableExpr) -> None: + self.wrap(expr, "variable") + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> None: + expr.expr.accept(self) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ... + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... 
+ + def visit_template_expr(self, expr: m.TemplateExpr) -> None: + self.wrap(expr, "template") + expr.type.accept(self) + + def visit_type_expr(self, expr: m.TypeExpr) -> None: + self.wrap(expr, "type") + if expr.template is not None: + expr.template.accept(self) diff --git a/midas/cli/hl_midas.css b/midas/cli/hl_midas.css new file mode 100644 index 0000000..e8adef6 --- /dev/null +++ b/midas/cli/hl_midas.css @@ -0,0 +1,55 @@ +span { + &.comment { + --col: 200, 200, 200; + color: rgb(110, 110, 110); + font-style: italic; + } + + &.simple-type { + --col: 108, 233, 108; + } + + &.complex-type { + --col: 233, 206, 108; + } + + &.constraint { + --col: 233, 108, 108; + } + + &.property { + --col: 233, 108, 176; + } + + &.extend { + --col: 108, 197, 233; + } + + &.op { + --col: 108, 148, 233; + } + + &.predicate { + --col: 193, 108, 233; + } + + &.simple-type-expr { + --col: 150, 150, 150; + } + + &.logical-expr, + &.binary-expr, + &.unary-expr, + &.get-expr { + --col: 123, 215, 193; + } + + &.template { + --col: 163, 117, 71; + } + + &.type { + --col: 200, 200, 200; + font-weight: bold; + } +} \ No newline at end of file diff --git a/midas/cli/hl_python.css b/midas/cli/hl_python.css new file mode 100644 index 0000000..e6dc43b --- /dev/null +++ b/midas/cli/hl_python.css @@ -0,0 +1,29 @@ +span { + &.base-type { + --col: 108, 233, 108; + } + + &.param { + --col: 103, 192, 224; + } + + &.constraint-type { + --col: 174, 200, 195; + } + + &.frame-column { + --col: 216, 231, 81; + } + + &.frame-type { + --col: 231, 46, 40; + } + + &.function { + --col: 215, 103, 224; + } + + &.argument { + --col: 103, 192, 224; + } +} \ No newline at end of file diff --git a/midas/cli/main.py b/midas/cli/main.py new file mode 100644 index 0000000..11d69e0 --- /dev/null +++ b/midas/cli/main.py @@ -0,0 +1,111 @@ +import ast +from dataclasses import dataclass +from typing import Optional, TextIO + +import click + +import midas.ast.midas as m +import midas.ast.python as p +from midas.ast.location 
import Location +from midas.ast.printer import PythonAstPrinter +from midas.cli.highlighter import Highlighter, MidasHighlighter, PythonHighlighter +from midas.lexer.midas import MidasLexer +from midas.lexer.token import Token, TokenType +from midas.parser.midas import MidasParser +from midas.parser.python import PythonParser + + +@click.group() +def midas(): + click.echo("Welcome to Midas!") + + +@midas.command() +@click.argument("file", type=click.File("r")) +def compile(file: TextIO): + raise NotImplementedError + + +@midas.group() +def utils(): + pass + + +@utils.command() +@click.option("-o", "--output", type=click.File("w")) +@click.option("-p", "--parse", is_flag=True) +@click.argument("file", type=click.File("r")) +def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO): + source: str = file.read() + tree: ast.Module = ast.parse(source, filename=file.name) + dump: str + + if parse: + parser = PythonParser() + stmts: list[p.Stmt] = parser.parse_module(tree) + printer = PythonAstPrinter() + dump = "" + for stmt in stmts: + dump += printer.print(stmt) + dump += "\n" + + else: + dump = ast.dump(tree, indent=4) + + if output is None: + click.echo(dump) + else: + output.write(dump) + + +def highlight_python(source: str, path: str) -> Highlighter: + tree: ast.Module = ast.parse(source, filename=path) + parser = PythonParser() + stmts: list[p.Stmt] = parser.parse_module(tree) + highlighter = PythonHighlighter(source) + for stmt in stmts: + highlighter.highlight(stmt) + return highlighter + + +def highlight_midas(source: str, path: str) -> Highlighter: + lexer = MidasLexer(source, file=path) + tokens: list[Token] = lexer.process() + parser = MidasParser(tokens) + stmts: list[m.Stmt] = parser.parse() + highlighter = MidasHighlighter(source) + for err in parser.errors: + print(err.get_report()) + + @dataclass(frozen=True) + class LocatableToken: + token: Token + + @property + def location(self) -> Location: + return self.token.get_location() + + for stmt in 
stmts: + highlighter.highlight(stmt) + for token in tokens: + if token.type == TokenType.COMMENT: + highlighter.wrap(LocatableToken(token), "comment") + elif token.is_keyword: + highlighter.wrap(LocatableToken(token), "keyword") + return highlighter + + +@utils.command() +@click.option("-o", "--output", type=click.File("w"), default="-") +@click.argument("file", type=click.File("r")) +def highlight(output: TextIO, file: TextIO): + source: str = file.read() + highlighter: Highlighter + + if file.name.endswith(".py"): + highlighter = highlight_python(source, file.name) + elif file.name.endswith(".midas"): + highlighter = highlight_midas(source, file.name) + else: + raise ValueError("Unsupported file type") + highlighter.dump(output) diff --git a/midas/lexer/__init__.py b/midas/lexer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lexer/base.py b/midas/lexer/base.py similarity index 98% rename from lexer/base.py rename to midas/lexer/base.py index f6f357d..c4f4d82 100644 --- a/lexer/base.py +++ b/midas/lexer/base.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Optional -from lexer.position import Position -from lexer.token import Token, TokenType +from midas.lexer.position import Position +from midas.lexer.token import Token, TokenType class MidasSyntaxError(Exception): diff --git a/lexer/midas.py b/midas/lexer/midas.py similarity index 98% rename from lexer/midas.py rename to midas/lexer/midas.py index 054f91d..acc97d6 100644 --- a/lexer/midas.py +++ b/midas/lexer/midas.py @@ -1,6 +1,5 @@ -from lexer.base import Lexer -from lexer.keyword import KEYWORDS -from lexer.token import TokenType +from midas.lexer.base import Lexer +from midas.lexer.token import KEYWORDS, TokenType class MidasLexer(Lexer): diff --git a/lexer/position.py b/midas/lexer/position.py similarity index 99% rename from lexer/position.py rename to midas/lexer/position.py index 306e24d..8ff0972 100644 --- a/lexer/position.py +++ 
b/midas/lexer/position.py @@ -5,6 +5,7 @@ from typing import Optional @dataclass(frozen=True) class Position: """A simple structure to store the position of a token""" + file: Optional[str] line: int column: int diff --git a/lexer/token.py b/midas/lexer/token.py similarity index 50% rename from lexer/token.py rename to midas/lexer/token.py index 1097493..a518a8b 100644 --- a/lexer/token.py +++ b/midas/lexer/token.py @@ -1,8 +1,11 @@ +from __future__ import annotations + from dataclasses import dataclass from enum import Enum, auto from typing import Any -from lexer.position import Position +from midas.ast.location import Location +from midas.lexer.position import Position class TokenType(Enum): @@ -55,6 +58,18 @@ class TokenType(Enum): NEWLINE = auto() +KEYWORDS: dict[str, TokenType] = { + "type": TokenType.TYPE, + "op": TokenType.OP, + "predicate": TokenType.PREDICATE, + "extend": TokenType.EXTEND, + "where": TokenType.WHERE, + "true": TokenType.TRUE, + "false": TokenType.FALSE, + "none": TokenType.NONE, +} + + @dataclass(frozen=True) class Token: """A scanned token""" @@ -63,3 +78,27 @@ class Token: lexeme: str value: Any position: Position + + def get_location(self) -> Location: + lineno: int = self.position.line + col_offset: int = self.position.column - 1 + end_lineno = lineno + end_col_offset = col_offset + for c in self.lexeme: + end_col_offset += 1 + if c == "\n": + end_lineno += 1 + end_col_offset = 0 + return Location( + lineno=lineno, + col_offset=col_offset, + end_lineno=end_lineno, + end_col_offset=end_col_offset, + ) + + def location_to(self, to: Token) -> Location: + return Location.span(self.get_location(), to.get_location()) + + @property + def is_keyword(self) -> bool: + return self.lexeme in KEYWORDS diff --git a/parser/base.py b/midas/parser/base.py similarity index 98% rename from parser/base.py rename to midas/parser/base.py index 74962db..255cd26 100644 --- a/parser/base.py +++ b/midas/parser/base.py @@ -2,8 +2,8 @@ from abc import ABC, 
abstractmethod from dataclasses import dataclass from typing import Generic, TypeVar -from lexer.token import Token, TokenType -from parser.errors import ParsingError +from midas.lexer.token import Token, TokenType +from midas.parser.errors import ParsingError @dataclass(frozen=True) diff --git a/parser/errors.py b/midas/parser/errors.py similarity index 100% rename from parser/errors.py rename to midas/parser/errors.py diff --git a/parser/midas.py b/midas/parser/midas.py similarity index 72% rename from parser/midas.py rename to midas/parser/midas.py index 65e2786..4998c51 100644 --- a/parser/midas.py +++ b/midas/parser/midas.py @@ -1,6 +1,7 @@ from typing import Optional -from core.ast.midas import ( +from midas.ast.location import Location +from midas.ast.midas import ( BinaryExpr, ComplexTypeStmt, Expr, @@ -21,9 +22,9 @@ from core.ast.midas import ( VariableExpr, WildcardExpr, ) -from lexer.token import Token, TokenType -from parser.base import Parser -from parser.errors import ParsingError +from midas.lexer.token import Token, TokenType +from midas.parser.base import Parser +from midas.parser.errors import ParsingError class MidasParser(Parser): @@ -104,6 +105,7 @@ class MidasParser(Parser): Returns: TypeStmt: the parsed type declaration statement """ + keyword: Token = self.previous() name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") template: Optional[TemplateExpr] = None if self.check(TokenType.LEFT_BRACKET): @@ -116,11 +118,20 @@ class MidasParser(Parser): if self.match(TokenType.WHERE): constraint = self.constraint() return SimpleTypeStmt( - name=name, template=template, base=base, constraint=constraint + location=keyword.location_to(self.previous()), + name=name, + template=template, + base=base, + constraint=constraint, ) else: properties: list[PropertyStmt] = self.type_properties() - return ComplexTypeStmt(name=name, template=template, properties=properties) + return ComplexTypeStmt( + location=keyword.location_to(self.previous()), 
+ name=name, + template=template, + properties=properties, + ) def template_expr(self) -> TemplateExpr: """Parse a generic template expression @@ -130,10 +141,14 @@ class MidasParser(Parser): Returns: TemplateExpr: the parsed template expression """ - self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression") + left: Token = self.consume( + TokenType.LEFT_BRACKET, "Missing '[' before template expression" + ) type: TypeExpr = self.type_expr() - self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression") - return TemplateExpr(type=type) + right: Token = self.consume( + TokenType.RIGHT_BRACKET, "Missing ']' after template expression" + ) + return TemplateExpr(location=left.location_to(right), type=type) def type_expr(self) -> TypeExpr: """Parse a type expression @@ -149,7 +164,12 @@ class MidasParser(Parser): if self.check(TokenType.LEFT_BRACKET): template = self.template_expr() optional: bool = self.match(TokenType.QMARK) - return TypeExpr(name=name, template=template, optional=optional) + return TypeExpr( + location=name.location_to(self.previous()), + name=name, + template=template, + optional=optional, + ) def simple_type_expr(self) -> SimpleTypeExpr: """Parse a simple type expression @@ -161,7 +181,9 @@ class MidasParser(Parser): """ name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") optional: bool = self.match(TokenType.QMARK) - return SimpleTypeExpr(name=name, optional=optional) + return SimpleTypeExpr( + location=name.location_to(self.previous()), name=name, optional=optional + ) def constraint(self) -> Expr: """Parse a constraint @@ -183,7 +205,12 @@ class MidasParser(Parser): while self.match(TokenType.AND): operator: Token = self.previous() right: Expr = self.equality() - expr = LogicalExpr(left=expr, operator=operator, right=right) + location: Optional[Location] = None + if expr.location and right.location: + location = Location.span(expr.location, right.location) + expr = LogicalExpr( + 
location=location, left=expr, operator=operator, right=right + ) return expr def equality(self) -> Expr: @@ -196,7 +223,12 @@ class MidasParser(Parser): while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL): operator: Token = self.previous() right: Expr = self.comparison() - expr = BinaryExpr(left=expr, operator=operator, right=right) + location: Optional[Location] = None + if expr.location and right.location: + location = Location.span(expr.location, right.location) + expr = BinaryExpr( + location=location, left=expr, operator=operator, right=right + ) return expr def comparison(self) -> Expr: @@ -214,7 +246,12 @@ class MidasParser(Parser): ): operator: Token = self.previous() right: Expr = self.unary() - expr = BinaryExpr(left=expr, operator=operator, right=right) + location: Optional[Location] = None + if expr.location and right.location: + location = Location.span(expr.location, right.location) + expr = BinaryExpr( + location=location, left=expr, operator=operator, right=right + ) return expr def unary(self) -> Expr: @@ -226,7 +263,10 @@ class MidasParser(Parser): if self.match(TokenType.MINUS): operator: Token = self.previous() right: Expr = self.unary() - return UnaryExpr(operator=operator, right=right) + location: Optional[Location] = None + if right.location: + location = Location.span(operator.get_location(), right.location) + return UnaryExpr(location=location, operator=operator, right=right) return self.reference() def reference(self) -> Expr: @@ -240,7 +280,10 @@ class MidasParser(Parser): name: Token = self.consume( TokenType.IDENTIFIER, "Expected property name after '.'" ) - expr = GetExpr(expr=expr, name=name) + location: Optional[Location] = None + if expr.location: + location = Location.span(expr.location, name.get_location()) + expr = GetExpr(location=location, expr=expr, name=name) return expr def primary(self) -> Expr: @@ -251,26 +294,27 @@ class MidasParser(Parser): Returns: Expr: the parsed expression """ + token: Token = self.peek() 
if self.match(TokenType.FALSE): - return LiteralExpr(False) + return LiteralExpr(location=token.get_location(), value=False) if self.match(TokenType.TRUE): - return LiteralExpr(True) + return LiteralExpr(location=token.get_location(), value=True) if self.match(TokenType.NONE): - return LiteralExpr(None) + return LiteralExpr(location=token.get_location(), value=None) if self.match(TokenType.NUMBER): - return LiteralExpr(self.previous().value) + return LiteralExpr(location=token.get_location(), value=token.value) if self.match(TokenType.IDENTIFIER): - return VariableExpr(self.previous()) + return VariableExpr(location=token.get_location(), name=token) if self.match(TokenType.UNDERSCORE): - return WildcardExpr(self.previous()) + return WildcardExpr(location=token.get_location(), token=token) if self.match(TokenType.LEFT_PAREN): expr: Expr = self.constraint() - self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") - return GroupingExpr(expr) + right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") + return GroupingExpr(location=token.location_to(right), expr=expr) raise self.error(self.peek(), "Expected expression") @@ -304,7 +348,12 @@ class MidasParser(Parser): constraint: Optional[Expr] = None if self.match(TokenType.WHERE): constraint = self.constraint() - return PropertyStmt(name=name, type=type, constraint=constraint) + return PropertyStmt( + location=name.location_to(self.previous()), + name=name, + type=type, + constraint=constraint, + ) def extend_declaration(self) -> ExtendStmt: """Parse an extension definition @@ -314,13 +363,17 @@ class MidasParser(Parser): Returns: ExtendStmt: the parsed extension statement """ + keyword: Token = self.previous() type: TypeExpr = self.type_expr() self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") operations: list[OpStmt] = [] while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): operations.append(self.op_declaration()) self.consume(TokenType.RIGHT_BRACE, 
"Unclosed extend body") - return ExtendStmt(type=type, operations=operations) + location: Optional[Location] = None + if type.location: + location = keyword.location_to(self.previous()) + return ExtendStmt(location=location, type=type, operations=operations) def op_declaration(self) -> OpStmt: """Parse an operation definition @@ -330,7 +383,7 @@ class MidasParser(Parser): Returns: OpStmt: the parsed operation statement """ - self.consume(TokenType.OP, "Expected 'op' keyword") + keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword") name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") @@ -340,7 +393,12 @@ class MidasParser(Parser): self.consume(TokenType.ARROW, "Expected '->' before result type") result: TypeExpr = self.type_expr() - return OpStmt(name=name, operand=operand, result=result) + return OpStmt( + location=keyword.location_to(self.previous()), + name=name, + operand=operand, + result=result, + ) def predicate_declaration(self) -> PredicateStmt: """Parse a predicate declaration @@ -350,6 +408,7 @@ class MidasParser(Parser): Returns: PredicateStmt: the parsed predicate declaration statement """ + keyword: Token = self.previous() name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") @@ -358,4 +417,10 @@ class MidasParser(Parser): self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") condition: Expr = self.constraint() - return PredicateStmt(name=name, subject=subject, type=type, condition=condition) + return PredicateStmt( + location=keyword.location_to(self.previous()), + name=name, + subject=subject, + type=type, + condition=condition, + ) diff --git a/midas/parser/python.py 
b/midas/parser/python.py new file mode 100644 index 0000000..4b6a3f1 --- /dev/null +++ b/midas/parser/python.py @@ -0,0 +1,343 @@ +import ast +from typing import Optional + +from midas.ast.location import Location + +from midas.ast.python import ( + AssignStmt, + BaseType, + BinaryExpr, + CallExpr, + CompareExpr, + ConstraintType, + Expr, + ExpressionStmt, + FrameColumn, + FrameType, + Function, + GetExpr, + LiteralExpr, + LogicalExpr, + MidasType, + Stmt, + TypeAssign, + UnaryExpr, + VariableExpr, +) + + +class InvalidSyntaxError(Exception): + pass + + +class UnsupportedSyntaxError(Exception): + def __init__(self, expr: ast.expr) -> None: + super().__init__( + f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}" + ) + + +class PythonParser: + def parse_module(self, node: ast.Module) -> list[Stmt]: + statements: list[Stmt] = [] + for stmt in node.body: + try: + parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt) + if isinstance(parsed, Stmt): + statements.append(parsed) + elif parsed is not None: + statements.extend(parsed) + except UnsupportedSyntaxError as e: + print(f"{e}, skipping") + continue + return statements + + def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]: + match node: + case ast.AnnAssign(): + return self.parse_annotation_assign(node) + + case ast.Assign(): + return self.parse_assign(node) + + case ast.FunctionDef(): + return self.parse_function(node) + + case ast.Expr(value=expr): + return ExpressionStmt(expr=self.parse_expr(expr)) + + case _: + print(f"Unsupported statement: {ast.unparse(node)}") + return None + + def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]: + statements: list[Stmt] = [] + loc: Location = Location.from_ast(node) + match node: + case ast.AnnAssign( + target=ast.Name(id=target), + annotation=annotation, + value=value, + simple=1, + ): + type = self._parse_type(annotation, root=True) + if type is not None: + statements.append( + TypeAssign( + location=loc, + 
name=target, + type=type, + ) + ) + + if value is not None: + statements.append( + AssignStmt( + location=loc, + targets=[ + VariableExpr( + location=Location.from_ast(node.target), name=target + ), + ], + value=self.parse_expr(value), + ), + ) + case _: + print(f"Unsupported annotation: {ast.unparse(node)}") + return statements + + def parse_assign(self, node: ast.Assign) -> AssignStmt: + targets: list[Expr] = [] + for target in node.targets: + targets.append(self.parse_expr(target)) + value: Expr = self.parse_expr(node.value) + return AssignStmt( + location=Location.from_ast(node), + targets=targets, + value=value, + ) + + def parse_function(self, node: ast.FunctionDef) -> Function: + loc: Location = Location.from_ast(node) + match node: + case ast.FunctionDef( + name=name, + args=ast.arguments( + posonlyargs=posonlyargs, + args=args, + kwonlyargs=kwonlyargs, + ), + returns=returns, + ): + + def parse_args(args_list: list[ast.arg]) -> list[Function.Argument]: + return [self._parse_function_argument(arg) for arg in args_list] + + return Function( + location=loc, + name=name, + posonlyargs=parse_args(posonlyargs), + args=parse_args(args), + kwonlyargs=parse_args(kwonlyargs), + returns=self._parse_type(returns) if returns is not None else None, + ) + case _: + print(f"Unsupported function definition: {ast.unparse(node)}") + + def _parse_function_argument(self, arg: ast.arg) -> Function.Argument: + loc: Location = Location.from_ast(arg) + name: str = arg.arg + type: Optional[MidasType] = None + if arg.annotation is not None: + type = self._parse_type(arg.annotation) + return Function.Argument( + location=loc, + name=name, + type=type, + ) + + def _parse_type( + self, type_expr: ast.expr, root: bool = False + ) -> Optional[MidasType]: + loc: Location = Location.from_ast(type_expr) + match type_expr: + case ast.Subscript(value=ast.Name(id="Frame"), slice=schema): + return self._parse_frame_type(schema) + + case ast.Subscript(value=ast.Name(id=name), slice=param): + 
return BaseType( + location=loc, + base=name, + param=self._parse_type(param), + ) + + case ast.Name(id=name): + return BaseType( + location=loc, + base=name, + param=None, + ) + + case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr): + left = self._parse_type(left_expr) + match left: + case None: + raise InvalidSyntaxError() + + # If chained constraints, separate base type and rebuild constraint + case ConstraintType(type=left_type, constraint=left_constraint): + constraint = ast.BinOp( + left=left_constraint, + op=ast.Add(), + right=right_expr, + ) + ast.copy_location(constraint, type_expr) + return ConstraintType( + location=loc, + type=left_type, + constraint=constraint, + ) + + case _: + return ConstraintType( + location=loc, + type=left, + constraint=right_expr, + ) + + case _: + if root: + return None + raise UnsupportedSyntaxError(type_expr) + + def _parse_frame_type(self, schema: ast.expr) -> FrameType: + loc: Location = Location.from_ast(schema) + columns: list[FrameColumn] = [] + + match schema: + case ast.Tuple(elts=cols): + for col in cols: + columns.append(self._parse_frame_column(col)) + + case ast.Slice() | ast.Name(): + columns.append(self._parse_frame_column(schema)) + + case _: + raise UnsupportedSyntaxError(schema) + + return FrameType(location=loc, columns=columns) + + def _parse_frame_column(self, column: ast.expr) -> FrameColumn: + loc: Location = Location.from_ast(column) + match column: + case ast.Name(): + return FrameColumn( + location=loc, + name=None, + type=self._parse_type(column), + ) + + case ast.Slice(lower=ast.Name(id=name), upper=type_expr): + if name == "_": + name = None + + type: Optional[MidasType] = None + match type_expr: + case None: + raise InvalidSyntaxError("Missing column type") + case ast.Name(id="_"): + type = None + case ast.expr(): + type = self._parse_type(type_expr) + case _: + raise UnsupportedSyntaxError(type_expr) + return FrameColumn(location=loc, name=name, type=type) + + case _: + raise 
UnsupportedSyntaxError(column) + + def parse_expr(self, node: ast.expr) -> Expr: + match node: + case ast.BoolOp(): + return self.parse_bool_op(node) + + case ast.BinOp(left=left, op=op, right=right): + return BinaryExpr( + left=self.parse_expr(left), + operator=op, + right=self.parse_expr(right), + ) + + case ast.UnaryOp(op=op, operand=right): + return UnaryExpr( + operator=op, + right=self.parse_expr(right), + ) + + case ast.Compare(): + return self.parse_compare(node) + + case ast.Call(): + return self.parse_call(node) + + case ast.Constant(value=value): + return LiteralExpr(value=value) + + case ast.Attribute(value=object, attr=name): + return GetExpr( + object=self.parse_expr(object), + name=name, + ) + + case ast.Name(id=name): + return VariableExpr(name=name) + + case _: + raise UnsupportedSyntaxError(node) + + def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr: + op: ast.boolop = node.op + values: list[ast.expr] = node.values + expr: LogicalExpr = LogicalExpr( + left=self.parse_expr(values[0]), + operator=op, + right=self.parse_expr(values[1]), + ) + for value in values[2:]: + expr = LogicalExpr( + left=expr, + operator=op, + right=self.parse_expr(value), + ) + return expr + + def parse_compare(self, node: ast.Compare) -> Expr: + ops: list[ast.cmpop] = node.ops # one operator per comparator, in order + rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators] + expr: Expr = CompareExpr( # first link: node.left, ops[0], rights[0] + left=self.parse_expr(node.left), + operator=ops[0], + right=rights[0], + ) + for left, op, right in zip(rights, ops[1:], rights[1:]): + expr = LogicalExpr( + left=expr, + operator=ast.And(), + right=CompareExpr( # later links take ops[1:], never ops[0] again + left=left, + operator=op, + right=right, + ), + ) + return expr + + def parse_call(self, node: ast.Call) -> CallExpr: + return CallExpr( + callee=self.parse_expr(node.func), + arguments=[self.parse_expr(arg) for arg in node.args], + keywords={ + arg.arg: self.parse_expr(arg.value) + for arg in node.keywords + if arg.arg is not None # Should always be True, type checker happy + }, + ) diff 
--git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..69a9f7e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[project] +name = "midas" +version = "0.1.0" +description = "A static-first type checking framework for Python data-frames" +readme = "README.md" +requires-python = ">=3.11" +authors = [ + { name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" }, +] +classifiers = ["Programming Language :: Python :: 3"] +dependencies = ["click>=8.4.1"] + +[project.urls] +Homepage = "https://git.kbk28.ch/HEL/midas" +Repository = "https://git.kbk28.ch/HEL/midas" + +[project.scripts] +midas = "midas.cli.main:midas" + +[build-system] +requires = ['hatchling'] +build-backend = 'hatchling.build' diff --git a/test.py b/test.py index 048329a..522bbac 100644 --- a/test.py +++ b/test.py @@ -1,10 +1,10 @@ import json from pathlib import Path -from core.ast.printer import MidasAstPrinter -from lexer.midas import MidasLexer -from lexer.token import Token -from parser.midas import MidasParser +from midas.ast.printer import MidasAstPrinter +from midas.lexer.midas import MidasLexer +from midas.lexer.token import Token +from midas.parser.midas import MidasParser def test_midas(): diff --git a/tester.py b/tester.py index 597ddee..3238a67 100644 --- a/tester.py +++ b/tester.py @@ -8,12 +8,12 @@ from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Iterator, Optional -from core.ast.json_serializer import AstJsonSerializer -from core.ast.midas import Stmt -from lexer.base import MidasSyntaxError -from lexer.midas import MidasLexer -from lexer.token import Token -from parser.midas import MidasParser +from midas.ast.json_serializer import AstJsonSerializer +from midas.ast.midas import Stmt +from midas.lexer.base import MidasSyntaxError +from midas.lexer.midas import MidasLexer +from midas.lexer.token import Token +from midas.parser.midas import MidasParser DEFAULT_BASE_DIR: Path = Path() / "tests"