Compare commits

...

40 Commits

Author SHA1 Message Date
c81287df7f Merge pull request 'Initial dataframe implementation' (#25) from feat/dataframes into main
Reviewed-on: #25
2026-07-01 08:24:36 +00:00
ffccc1bedd feat(cli): generate stubs in build dir when compiling 2026-07-01 10:16:13 +02:00
d14f208897 feat(gen): add tuple expr to generator 2026-07-01 10:16:13 +02:00
293953a078 tests: update with multi-parameter generics 2026-07-01 10:16:12 +02:00
bccc96e4d0 fix: minor fixes 2026-07-01 10:16:11 +02:00
9db56adf56 feat: add Python tuple expression 2026-07-01 10:16:10 +02:00
3f99563ac8 feat: handle multi-parameter generic in Python 2026-07-01 10:16:10 +02:00
b36896cc7b feat(checker): add len() 2026-07-01 10:16:09 +02:00
cb75878ae9 fix(checker): allow some assignments to unknown 2026-07-01 10:16:08 +02:00
a5fe985eb2 feat(checker): add methods on str 2026-07-01 10:16:08 +02:00
e324f414e6 feat(checker): type check tuple instantiation in Midas 2026-07-01 10:16:07 +02:00
256536562f fix(parser): parse empty calls 2026-07-01 10:16:06 +02:00
64f4314f0d fix(gen): prevent empty loop for column asserts 2026-07-01 10:16:06 +02:00
6f6245d283 fix(checker): allow iterating on unknown 2026-07-01 10:16:05 +02:00
3392bc347d fix(checker): allow subtypes and unknown as if test 2026-07-01 10:16:04 +02:00
7e0319906a feat(gen): assertions for column values 2026-07-01 10:16:03 +02:00
75bd203d4a fix(checker): allow calling unknown method on dataframes 2026-07-01 10:15:16 +02:00
db40198357 feat(gen): generate asserts for dataframes and columns 2026-07-01 10:15:16 +02:00
d79e1dee18 fix(checker): change heterogeneous errors to warnings 2026-07-01 10:15:15 +02:00
4ea400265c feat(checker): add mean method on frames 2026-07-01 10:15:14 +02:00
24bffdabd4 fix(checker): type check None literal 2026-07-01 10:15:13 +02:00
d7bb6326de feat(checker): lookup dunders on dataframes 2026-07-01 10:15:12 +02:00
dbf6f9e2db tests: update with reordered argument typing 2026-07-01 10:15:12 +02:00
3cdc9031d3 refactor: use metaclass to collect frame methods 2026-07-01 10:15:11 +02:00
00e2ca8fe3 refactor: add MethodResolver class 2026-07-01 10:15:10 +02:00
4efb01285c feat: add dummy classes for typing frames and columns 2026-07-01 10:15:10 +02:00
f84a19159f fix(checker): improve heterogeneous error message 2026-07-01 10:15:09 +02:00
946b2e0d2e feat(checker): lookup dataframe methods 2026-07-01 10:15:08 +02:00
08dd7408ec feat(checker): defined add method of dataframes 2026-07-01 10:15:07 +02:00
b33fadf768 feat(checker): add structural subtyping rule for dataframes 2026-07-01 10:15:06 +02:00
7219109e5d feat(cli): print context for multiline diagnostics 2026-07-01 10:14:48 +02:00
cdf1725c26 feat(checker): process frame type definitions 2026-07-01 10:14:48 +02:00
7074b074bc feat(cli): add frame type to highlighter 2026-07-01 10:14:17 +02:00
ede7272c09 feat(parser): add frame type to midas syntax 2026-07-01 10:14:16 +02:00
87d5e286d2 feat(gen): add support for tuples and dataframes 2026-07-01 10:14:16 +02:00
c91b206791 feat(checker): handle setting dataframe column 2026-07-01 10:13:30 +02:00
a31d295eb1 feat(checker): type check subscript on dataframes 2026-07-01 10:13:30 +02:00
0d20993f02 feat(types): add TupleType 2026-07-01 10:13:28 +02:00
5357ca8e58 fix(types): add str methods to dataframe types 2026-07-01 10:13:28 +02:00
556765fd35 feat(types): add DataFrameType and ColumnType 2026-07-01 10:13:27 +02:00
37 changed files with 1990 additions and 744 deletions

View File

@@ -157,4 +157,14 @@ class FunctionType:
required: bool required: bool
class FrameType:
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
###< ###<

View File

@@ -15,7 +15,7 @@ from midas.ast.location import Location
###> MidasType | Type annotations | node ###> MidasType | Type annotations | node
class BaseType: class BaseType:
base: str base: str
param: Optional[MidasType] args: tuple[MidasType, ...]
class ConstraintType: class ConstraintType:
@@ -174,6 +174,10 @@ class SliceExpr:
step: Optional[Expr] step: Optional[Expr]
class TupleExpr:
items: tuple[Expr, ...]
class RawExpr: class RawExpr:
expr: ast.expr expr: ast.expr

View File

@@ -265,6 +265,9 @@ class Type(ABC):
@abstractmethod @abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ... def visit_function_type(self, type: FunctionType) -> T: ...
@abstractmethod
def visit_frame_type(self, type: FrameType) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class NamedType(Type): class NamedType(Type):
@@ -323,3 +326,17 @@ class FunctionType(Type):
def accept(self, visitor: Type.Visitor[T]) -> T: def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_function_type(self) return visitor.visit_function_type(self)
@dataclass(frozen=True)
class FrameType(Type):
    """Midas type node describing a frame as an ordered list of columns."""

    # Columns in declaration order.
    columns: list[Column]

    @dataclass(kw_only=True, frozen=True)
    class Column:
        """A single ``name: type`` entry of a frame type."""

        # Source position of the column; may be absent.
        location: Optional[Location] = None
        # Token holding the column name (its text is read via ``lexeme``).
        name: Token
        # Declared type of the column's values.
        type: Type

    def accept(self, visitor: Type.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_frame_type`` and return its result."""
        return visitor.visit_frame_type(self)

View File

@@ -358,6 +358,25 @@ class MidasAstPrinter(
arg.type.accept(self) arg.type.accept(self)
self._write_line(f"required: {arg.required}", last=True) self._write_line(f"required: {arg.required}", last=True)
def visit_frame_type(self, type: m.FrameType) -> None:
    """Print a ``FrameType`` node and, nested below it, each of its columns."""
    self._write_line("FrameType")
    with self._child_level(single=True):
        self._write_line("columns")
        with self._child_level():
            columns = type.columns
            final_position = len(columns) - 1
            for position, column in enumerate(columns):
                # Record the child index before printing so the tree
                # drawing helpers know which entry is being emitted.
                self._idx = position
                if position == final_position:
                    self._mark_last()
                self._print_frame_column(column)
def _print_frame_column(self, column: m.FrameType.Column) -> None:
    """Print one frame column: its quoted name, then its type as a subtree."""
    self._write_line("Column")
    with self._child_level():
        column_name = column.name.lexeme
        self._write_line(f'name: "{column_name}"')
        # NOTE(review): other printers pass ``last=True`` for the final
        # child line; this one does not — confirm that is intended.
        self._write_line("type")
        with self._child_level(single=True):
            column.type.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]): class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4): def __init__(self, indent: int = 4):
@@ -513,6 +532,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res += "?" res += "?"
return res return res
def visit_frame_type(self, type: m.FrameType) -> str:
    """Render a frame type as ``Frame[...]`` source text.

    An empty frame is rendered as ``Frame[]``. Otherwise every column is
    written on its own line, one indentation level deeper, separated by
    commas, and the closing bracket goes on a line of its own.
    """
    opening: str = self.indented("Frame[")
    if not type.columns:
        return opening + "]"

    # Columns are indented one level deeper than the opening bracket.
    self.level += 1
    rendered: list[str] = [
        self.indented(self._print_frame_column(column)) for column in type.columns
    ]
    self.level -= 1
    return opening + "\n" + ",\n".join(rendered) + "\n" + "]"
def _print_frame_column(self, column: m.FrameType.Column) -> str:
    """Render a single frame column as ``<name>: <type>``."""
    column_name = column.name.lexeme
    rendered_type = column.type.accept(self)
    return f"{column_name}: {rendered_type}"
class PythonAstPrinter( class PythonAstPrinter(
AstPrinter, AstPrinter,
@@ -524,7 +560,13 @@ class PythonAstPrinter(
self._write_line("BaseType") self._write_line("BaseType")
with self._child_level(): with self._child_level():
self._write_line(f"base: {node.base}") self._write_line(f"base: {node.base}")
self._write_optional_child("param", node.param, last=True) self._write_line("args:", last=True)
with self._child_level():
for i, arg in enumerate(node.args):
self._idx = i
if i == len(node.args) - 1:
self._mark_last()
arg.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None: def visit_constraint_type(self, node: p.ConstraintType) -> None:
self._write_line("ConstraintType") self._write_line("ConstraintType")
@@ -837,6 +879,17 @@ class PythonAstPrinter(
self._write_optional_child("upper", expr.upper) self._write_optional_child("upper", expr.upper)
self._write_optional_child("step", expr.step, last=True) self._write_optional_child("step", expr.step, last=True)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
    """Print a ``TupleExpr`` node with its items as nested children."""
    self._write_line("TupleExpr")
    with self._child_level():
        self._write_line("items", last=True)
        with self._child_level():
            items = expr.items
            final_position = len(items) - 1
            for position, item in enumerate(items):
                # Record the child index before visiting the item.
                self._idx = position
                if position == final_position:
                    self._mark_last()
                item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: def visit_raw_expr(self, expr: p.RawExpr) -> None:
self._write_line("RawExpr") self._write_line("RawExpr")
with self._child_level(single=True): with self._child_level(single=True):

View File

@@ -44,7 +44,7 @@ class MidasType(ABC):
@dataclass(frozen=True) @dataclass(frozen=True)
class BaseType(MidasType): class BaseType(MidasType):
base: str base: str
param: Optional[MidasType] args: tuple[MidasType, ...]
def accept(self, visitor: MidasType.Visitor[T]) -> T: def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self) return visitor.visit_base_type(self)
@@ -268,6 +268,9 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_slice_expr(self, expr: SliceExpr) -> T: ... def visit_slice_expr(self, expr: SliceExpr) -> T: ...
@abstractmethod
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_raw_expr(self, expr: RawExpr) -> T: ... def visit_raw_expr(self, expr: RawExpr) -> T: ...
@@ -402,6 +405,14 @@ class SliceExpr(Expr):
return visitor.visit_slice_expr(self) return visitor.visit_slice_expr(self)
@dataclass(frozen=True)
class TupleExpr(Expr):
    """Python expression node holding the item expressions of a tuple."""

    # Item expressions, in source order.
    items: tuple[Expr, ...]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_tuple_expr`` and return its result."""
        return visitor.visit_tuple_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class RawExpr(Expr): class RawExpr(Expr):
expr: ast.expr expr: ast.expr

View File

@@ -178,4 +178,100 @@ extend dict[K, V] {
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V] // def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V] // def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
} }
// Method signatures declared on the built-in `str` type.
// A method name that appears several times lists alternative signatures
// (presumably resolved as overloads — confirm against the checker).
// NOTE(review): the `?` suffix looks like it marks an optional parameter and
// a trailing `/` positional-only parameters, as in Python — confirm.
// NOTE(review): the commented-out entries appear to rely on features or
// types that are not available here (varargs, Iterator, helper types) — confirm.
extend str {
def capitalize: fn() -> str
def casefold: fn() -> str
def center: fn(width: int, fillchar: str?, /) -> str
// start/end accept every combination of int and None.
def count: fn(sub: str, start: None?, end: None?, /) -> int
def count: fn(sub: str, start: int, end: None?, /) -> int
def count: fn(sub: str, start: None, end: int, /) -> int
def count: fn(sub: str, start: int, end: int, /) -> int
def encode: fn(encoding: str?, errors: str?) -> bytes
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
def expandtabs: fn(tabsize: int?) -> str
def find: fn(sub: str, start: None?, end: None?, /) -> int
def find: fn(sub: str, start: int, end: None?, /) -> int
def find: fn(sub: str, start: None, end: int, /) -> int
def find: fn(sub: str, start: int, end: int, /) -> int
// def format: fn(*args: object, **kwargs: object) -> str
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
def index: fn(sub: str, start: None?, end: None?, /) -> int
def index: fn(sub: str, start: int, end: None?, /) -> int
def index: fn(sub: str, start: None, end: int, /) -> int
def index: fn(sub: str, start: int, end: int, /) -> int
def isalnum: fn() -> bool
def isalpha: fn() -> bool
def isascii: fn() -> bool
def isdecimal: fn() -> bool
def isdigit: fn() -> bool
def isidentifier: fn() -> bool
def islower: fn() -> bool
def isnumeric: fn() -> bool
def isprintable: fn() -> bool
def isspace: fn() -> bool
def istitle: fn() -> bool
def isupper: fn() -> bool
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
def ljust: fn(width: int, fillchar: str?, /) -> str
def lower: fn() -> str
def lstrip: fn(chars: None?, /) -> str
def lstrip: fn(chars: str, /) -> str
def partition: fn(sep: str, /) -> tuple[str, str, str]
def replace: fn(old: str, new: str, count: int?, /) -> str
def removeprefix: fn(prefix: str, /) -> str
def removesuffix: fn(suffix: str, /) -> str
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
def rfind: fn(sub: str, start: int, end: None?, /) -> int
def rfind: fn(sub: str, start: None, end: int, /) -> int
def rfind: fn(sub: str, start: int, end: int, /) -> int
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
def rindex: fn(sub: str, start: int, end: None?, /) -> int
def rindex: fn(sub: str, start: None, end: int, /) -> int
def rindex: fn(sub: str, start: int, end: int, /) -> int
def rjust: fn(width: int, fillchar: str?, /) -> str
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
def rstrip: fn(chars: None?, /) -> str
def rstrip: fn(chars: str, /) -> str
def split: fn(sep: None?, maxsplit: int?) -> list[str]
def split: fn(sep: str, maxsplit: int?) -> list[str]
def splitlines: fn(keepends: bool?) -> list[str]
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
def strip: fn(chars: None?, /) -> str
def strip: fn(chars: str, /) -> str
def swapcase: fn() -> str
def title: fn() -> str
// def translate: fn(table: _TranslateTable, /) -> str
def upper: fn() -> str
def zfill: fn(width: int, /) -> str
// Operator and protocol (dunder) methods.
def __add__: fn(value: str, /) -> str
// Incompatible with Sequence.__contains__
def __contains__: fn(key: str, /) -> bool
def __eq__: fn(value: object, /) -> bool
def __ge__: fn(value: str, /) -> bool
def __getitem__: fn(key: slice, /) -> str
def __getitem__: fn(key: int, /) -> str
def __gt__: fn(value: str, /) -> bool
def __hash__: fn() -> int
// def __iter__: fn() -> Iterator[str]
def __le__: fn(value: str, /) -> bool
def __len__: fn() -> int
def __lt__: fn(value: str, /) -> bool
def __mod__: fn(value: Any, /) -> str
def __mul__: fn(value: int, /) -> str
def __ne__: fn(value: object, /) -> bool
def __rmul__: fn(value: int, /) -> str
def __getnewargs__: fn() -> tuple[str]
def __format__: fn(format_spec: str, /) -> str
}

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = { BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str"}, "object": {"float", "list", "dict", "str", "bytes", "tuple"},
"float": {"int"}, "float": {"int"},
"int": {"bool"}, "int": {"bool"},
} }
@@ -26,12 +26,15 @@ def define_builtins(reg: TypesRegistry):
any = reg.define_type("Any", TopType()) any = reg.define_type("Any", TopType())
unit = reg.define_type("None", UnitType()) unit = reg.define_type("None", UnitType())
object = reg.define_type("object", BaseType(name="object")) object = reg.define_type("object", BaseType(name="object"))
bytes = reg.define_type("bytes", BaseType(name="bytes"))
bool = reg.define_type("bool", BaseType(name="bool")) bool = reg.define_type("bool", BaseType(name="bool"))
int = reg.define_type("int", BaseType(name="int")) int = reg.define_type("int", BaseType(name="int"))
float = reg.define_type("float", BaseType(name="float")) float = reg.define_type("float", BaseType(name="float"))
str = reg.define_type("str", BaseType(name="str")) str = reg.define_type("str", BaseType(name="str"))
slice = reg.define_type("slice", BaseType(name="slice")) slice = reg.define_type("slice", BaseType(name="slice"))
tuple = reg.define_type("tuple", BaseType(name="tuple"))
list = reg.define_type( list = reg.define_type(
"list", "list",
GenericType( GenericType(

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnType,
DataFrameType,
Function,
OverloadedFunction,
TopType,
Type,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
def frame_method(
    *names: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Mark a function as the implementation of one or more frame methods.

    This is a plain module-level decorator factory. It used to be wrapped in
    ``@staticmethod``, which bound the name to a ``staticmethod`` object
    (only callable on Python >= 3.10) rather than to a function.

    Args:
        *names: Method names the decorated function answers to. When none
            are given, the function's own ``__name__`` is used.

    Returns:
        A decorator that stores the names on the function as
        ``__method_names__`` and returns the function unchanged.
    """

    def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        # Fall back to the function name when no explicit alias is given.
        method_names: tuple[str, ...] = names if names else (func.__name__,)
        setattr(func, "__method_names__", method_names)
        return func

    return wrapper
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
frame: DataFrameType
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
class _MethodRegistryMeta(type):
    """Metaclass that collects tagged methods into a per-class lookup table.

    Every callable in the class body that carries a ``__method_names__``
    attribute is registered under each of those names in ``cls._methods``.
    Only the class's own namespace is scanned, so each class starts from an
    empty table.
    """

    _methods: dict[str, Callable[..., Type]] = {}

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
    ):
        new_class = super().__new__(cls, name, bases, namespace)
        # Later definitions win when two callables claim the same name.
        new_class._methods = {
            alias: member
            for member in namespace.values()
            if callable(member) and hasattr(member, "__method_names__")
            for alias in member.__method_names__
        }
        return new_class
class MethodRegistry(metaclass=_MethodRegistryMeta):
    """Computes the result type of method calls made on dataframes.

    Methods decorated with ``frame_method`` are collected by the metaclass
    into ``_methods`` and dispatched by name through :meth:`call`.
    """

    def __init__(self, typer: PythonTyper) -> None:
        self.typer: PythonTyper = typer

    @property
    def reporter(self) -> FileReporter:
        """Reporter of the owning typer."""
        return self.typer.reporter

    @property
    def types(self) -> TypesRegistry:
        """Types registry of the owning typer."""
        return self.typer.types

    def call(
        self,
        method: str,
        call: Call,
    ) -> Type:
        """Dispatch ``method`` to its registered handler.

        An unregistered method produces a warning and ``UnknownType``.
        """
        handler: Optional[Callable[..., Type]] = self._methods.get(method)
        if handler is not None:
            return handler(self, call)
        self.reporter.warning(call.location, f"Unknown method {method}")
        return UnknownType()

    @frame_method("add", "__add__")
    def add(
        self,
        call: Call,
    ) -> Type:
        """Type ``frame.add(other)`` / ``frame + other``.

        Columns of the receiver keep their type only when the other operand
        is a dataframe holding a column of the same name with an equivalent
        type. Every other column, including those that exist only in the
        other frame, becomes a column of unknown type.
        """
        # TODO: support add with scalar, sequence, Series, dict
        # TODO: check operation exists on inner column types
        other_frame: Optional[DataFrameType] = None
        other_by_name: dict[str, DataFrameType.Column] = {}
        if call.positional:
            candidate: Type = unfold_type(call.positional[0][1])
            if isinstance(candidate, DataFrameType):
                other_frame = candidate
                other_by_name = {
                    col.name: col
                    for col in candidate.columns
                    if col.name is not None
                }

        own_names: set[str] = set()
        merged: list[DataFrameType.Column] = []
        for own in call.frame.columns:
            if own.name is not None:
                own_names.add(own.name)

            merged_type: Type = ColumnType(type=UnknownType())
            match_in_other = other_by_name.get(own.name)
            if match_in_other is not None and self.types.are_equivalent(
                match_in_other.type, own.type
            ):
                merged_type = own.type

            merged.append(
                DataFrameType.Column(
                    index=own.index,
                    name=own.name,
                    type=merged_type,
                )
            )

        if other_frame is not None:
            # Columns present only in the other frame are appended at the end.
            for extra in other_frame.columns:
                if extra.name not in own_names:
                    merged.append(
                        DataFrameType.Column(
                            index=len(merged),
                            name=extra.name,
                            type=ColumnType(type=UnknownType()),
                        )
                    )

        other_arg = Function.Argument(
            pos=0,
            name="other",
            type=DataFrameType(columns=[]),
            required=True,
        )
        signature = Function(
            args=[other_arg],
            returns=DataFrameType(columns=merged),
        )
        result = self.typer._get_call_result(
            location=call.location,
            callee=signature,
            positional=call.positional,
            keywords=call.keywords,
        )
        return result or UnknownType()

    @frame_method()
    def mean(self, call: Call) -> Type:
        """Type ``frame.mean(...)`` using two overloads on ``axis``.

        An optional ``int`` axis yields a column; a required ``None`` axis
        yields the top type.
        """
        int_axis = Function.Argument(
            pos=0,
            name="axis",
            type=self.types.get_type("int"),
            required=False,
        )
        none_axis = Function.Argument(
            pos=0,
            name="axis",
            type=self.types.get_type("None"),
            required=True,
        )
        overload = OverloadedFunction(
            overloads=[
                Function(kw_args=[int_axis], returns=ColumnType(type=TopType())),
                Function(kw_args=[none_axis], returns=TopType()),
            ]
        )
        result = self.typer._get_call_result(
            location=call.location,
            callee=overload,
            positional=call.positional,
            keywords=call.keywords,
        )
        return result or UnknownType()

154
midas/checker/frames.py Normal file
View File

@@ -0,0 +1,154 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frame_methods import Call, MethodRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
    """Return True when every expression in ``exprs`` is a ``LiteralExpr``.

    An empty list is accepted.
    """
    for expr in exprs:
        if not isinstance(expr, p.LiteralExpr):
            return False
    return True
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
def assign(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
value_type: Type,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
return self.assign_column(reporter, location, frame, name, value_type)
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(idx, str) for idx in indices
):
raise NotImplementedError
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
def assign_column(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
name: str,
type: Type,
) -> Type:
if not isinstance(type, ColumnType):
reporter.error(
location,
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
)
return frame
return self._set_column(frame, name, type)
def get(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
return column
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
@classmethod
def _set_column(
cls, frame: DataFrameType, name: str, column: ColumnType
) -> DataFrameType:
new_columns: list[DataFrameType.Column] = []
index: int = len(frame.columns)
replace: bool = False
for i, col in enumerate(frame.columns):
if col.name == name:
index = i
replace = True
# TODO: check column type here to prevent changing it
new_columns.append(col)
new_col: DataFrameType.Column = DataFrameType.Column(
index=index,
name=name,
type=column,
)
if replace:
new_columns[index] = new_col
else:
new_columns.append(new_col)
return DataFrameType(columns=new_columns)
@classmethod
def _set_columns(
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType:
for name, col in zip(names, columns):
frame = cls._set_column(frame, name, col)
return frame
@classmethod
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
for col in frame.columns:
if col.name == name:
return col.type
return None
@classmethod
def _get_columns(
cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names]
def call(
self,
method: str,
location: Location,
frame: DataFrameType,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
frame=frame,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)

View File

@@ -13,8 +13,10 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import ( from midas.checker.types import (
AppliedType, AppliedType,
ColumnType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DataFrameType,
DerivedType, DerivedType,
ExtensionType, ExtensionType,
Function, Function,
@@ -408,6 +410,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)], kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
) )
def visit_frame_type(self, type: m.FrameType) -> Type:
    """Convert a frame type annotation into a ``DataFrameType``.

    Each declared column keeps its position and name; its declared type is
    resolved through this visitor and wrapped in a ``ColumnType``.
    """
    columns: list[DataFrameType.Column] = []
    for position, declared in enumerate(type.columns):
        columns.append(
            DataFrameType.Column(
                index=position,
                name=declared.name.lexeme,
                type=ColumnType(type=declared.type.accept(self)),
            )
        )
    return DataFrameType(columns=columns)
def _resolve_type_params(self, params: list[m.TypeParam]): def _resolve_type_params(self, params: list[m.TypeParam]):
vars: list[TypeVar] = [] vars: list[TypeVar] = []
for param in params: for param in params:

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional from typing import Any, Callable, Optional
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
@@ -17,7 +17,7 @@ class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None: def __init__(self, types: TypesRegistry) -> None:
super().__init__() super().__init__()
self._types: TypesRegistry = types self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable] = {} self._python_funcs: dict[str, Callable[..., Any]] = {}
self._def_type_constructor("object", object) self._def_type_constructor("object", object)
self._def_type_constructor("float", float) self._def_type_constructor("float", float)
@@ -34,7 +34,7 @@ class Preamble(Environment):
# TODO: use sink # TODO: use sink
self._def_function( self._def_function(
name="print", name="print",
pos=[Param("object", TopType())], pos=[Param("object", TopType(), required=False)],
returns=UnitType(), returns=UnitType(),
py_function=print, py_function=print,
) )
@@ -64,11 +64,18 @@ class Preamble(Environment):
pos=[Param("prompt", TopType(), required=False)], pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"), returns=self._types.get_type("str"),
) )
self._def_function(
name="len",
pos=[Param("object", TopType())],
returns=self._types.get_type("int"),
)
def _list_of(self, item_type: Type) -> Type: def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type]) return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None): def _def_type_constructor(
self, name: str, py_function: Optional[Callable[..., Any]] = None
):
# TODO: more specific arg types # TODO: more specific arg types
self._def_function( self._def_function(
name=name, name=name,
@@ -121,7 +128,7 @@ class Preamble(Environment):
kw: list[Param] = [], kw: list[Param] = [],
returns: Type = UnitType(), returns: Type = UnitType(),
type_vars: list[TypeVar] = [], type_vars: list[TypeVar] = [],
py_function: Optional[Callable] = None, py_function: Optional[Callable[..., Any]] = None,
): ):
function: Type = self._make_function( function: Type = self._make_function(
name=name, name=name,
@@ -135,5 +142,5 @@ class Preamble(Environment):
if py_function is not None: if py_function is not None:
self._python_funcs[name] = py_function self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable]: def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
return self._python_funcs.get(name) return self._python_funcs.get(name)

View File

@@ -8,6 +8,7 @@ from midas.ast.location import Location
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator from midas.checker.evaluator import Evaluator
from midas.checker.frames import FrameManager
from midas.checker.operators import ( from midas.checker.operators import (
PY_COMPARATOR_METHODS, PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS, PY_OPERATOR_METHODS,
@@ -20,11 +21,14 @@ from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AppliedType, AppliedType,
BaseType, BaseType,
ColumnType,
ConstraintType, ConstraintType,
DataFrameType,
DerivedType, DerivedType,
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
TupleType,
Type, Type,
TypeVar, TypeVar,
UnitType, UnitType,
@@ -43,6 +47,10 @@ class ReturnException(Exception):
pass pass
class UndefinedMethodException(Exception):
    """Raised when a method lookup on a type finds no member of that name."""
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class MappedArgument: class MappedArgument:
expr: p.Expr expr: p.Expr
@@ -71,6 +79,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper") self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self)
self.global_env: Environment = Preamble(self.types) self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
@@ -190,6 +199,36 @@ class PythonTyper(
return self.env.get_at(distance, name) return self.env.get_at(distance, name)
return self.global_env.get(name) return self.global_env.get(name)
def call_method(
    self,
    location: Location,
    obj: Type,
    method_name: str,
    positional: list[TypedExpr],
    keywords: dict[str, TypedExpr],
) -> Optional[Type]:
    """Type a call to ``method_name`` on a value of type ``obj``.

    Dataframes are handled by the frame manager. Any other type has the
    method looked up as a member and then called.

    Raises:
        UndefinedMethodException: If ``obj`` has no member ``method_name``.
    """
    receiver: Type = unfold_type(obj)
    if isinstance(receiver, DataFrameType):
        return self.frame_mgr.call(
            method=method_name,
            location=location,
            frame=receiver,
            positional=positional,
            keywords=keywords,
        )

    # The member lookup uses the original type, not the unfolded one.
    member: Optional[Type] = self.types.lookup_member(obj, method_name)
    if member is None:
        raise UndefinedMethodException

    return self._get_call_result(location, member, positional, keywords)
def is_subtype(self, type1: Type, type2: Type) -> bool: def is_subtype(self, type1: Type, type2: Type) -> bool:
return self.types.is_subtype(type1, type2) return self.types.is_subtype(type1, type2)
@@ -319,9 +358,15 @@ class PythonTyper(
case p.VariableExpr(): case p.VariableExpr():
self._assign_var(location, target, value_type) self._assign_var(location, target, value_type)
# Allow any kind of object because we disallow creating new attributes
case p.GetExpr(object=object, name=name): case p.GetExpr(object=object, name=name):
self._assign_attr(location, object, name, value_type) self._assign_attr(location, object, name, value_type)
# Only support variable expressions because modifying
# the underlying value would require reference types
case p.SubscriptExpr(object=p.VariableExpr() as var, index=index):
self._assign_sub(location, var, index, value_type)
case _: case _:
if not isinstance(target, p.VariableExpr): if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}") self.logger.warning(f"Unsupported assignment to {target}")
@@ -360,6 +405,30 @@ class PythonTyper(
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}", f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
) )
def _assign_sub(
    self,
    location: Location,
    var: p.VariableExpr,
    index: p.Expr,
    value_type: Type,
):
    """Type the subscript assignment ``var[index] = value``.

    For a dataframe variable, the variable is rebound to the frame type
    computed by the frame manager. Variables of unknown type are ignored.
    Every other type reports an error.
    """
    var_type: Type = self.type_of(var)
    target: Type = unfold_type(var_type)

    # TODO: what happens if type is an alias of a dataframe type
    if isinstance(target, DataFrameType):
        updated: Type = self.frame_mgr.assign(
            self.reporter, location, target, index, value_type
        )
        self.env.assign(var.name, updated)
        return

    if isinstance(target, UnknownType):
        return

    self.reporter.error(
        location,
        f"Cannot assign {value_type} to index {index} of {var_type}",
    )
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType() type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
self.env.return_types.append(type) self.env.return_types.append(type)
@@ -373,8 +442,10 @@ class PythonTyper(
# print(m) # <- m is still defined # print(m) # <- m is still defined
test_type: Type = self.type_of(stmt.test) test_type: Type = self.type_of(stmt.test)
# TODO Allow subtypes or any type if (
if test_type != self.types.get_type("bool"): not self.types.is_subtype(test_type, self.types.get_type("bool"))
and test_type != UnknownType()
):
self.reporter.error( self.reporter.error(
stmt.test.location, f"If test must be a boolean, got {test_type}" stmt.test.location, f"If test must be a boolean, got {test_type}"
) )
@@ -390,13 +461,16 @@ class PythonTyper(
pass pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None: def visit_for_stmt(self, stmt: p.ForStmt) -> None:
item_type: Optional[Type] = self._get_iterator_type(stmt.iterator) item_type: Type = UnknownType()
if item_type is None: iterator_type: Type = self.type_of(stmt.iterator)
iterator_type: Type = self.compute_type(stmt.iterator) if iterator_type != UnknownType():
self.reporter.error( maybe_item_type = self._get_iterator_type(stmt.iterator, iterator_type)
stmt.iterator.location, f"{iterator_type} is not iterable" if maybe_item_type is None:
) self.reporter.error(
item_type = UnknownType() stmt.iterator.location, f"{iterator_type} is not iterable"
)
else:
item_type = maybe_item_type
self._assign(stmt.location, stmt.target, item_type) self._assign(stmt.location, stmt.target, item_type)
self.judge(stmt.target, item_type) self.judge(stmt.target, item_type)
@@ -436,20 +510,16 @@ class PythonTyper(
left: Type = self.type_of(left_expr) left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr) right: Type = self.type_of(right_expr)
operation: Optional[Type] = self.types.lookup_member(left, method) result: Optional[Type]
if operation is None: try:
result = self.call_method(location, left, method, [(right_expr, right)], {})
except UndefinedMethodException:
self.reporter.error( self.reporter.error(
location, location,
f"Undefined operation {method} between {left} and {right}", f"Undefined operation {method} between {left} and {right}",
) )
return UnknownType() return UnknownType()
result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result or UnknownType() return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
@@ -462,20 +532,17 @@ class PythonTyper(
return UnknownType() return UnknownType()
operand: Type = self.type_of(expr.right) operand: Type = self.type_of(expr.right)
operation: Optional[Type] = self.types.lookup_member(operand, method)
if operation is None: result: Optional[Type]
try:
result = self.call_method(expr.location, operand, method, [], {})
except UndefinedMethodException:
self.reporter.error( self.reporter.error(
expr.location, expr.location,
f"Undefined operation {method} for {operand}", f"Undefined operation {method} for {operand}",
) )
return UnknownType() return UnknownType()
result: Optional[Type] = self._get_call_result(
expr.location,
operation,
[],
{},
)
return result or UnknownType() return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type: def visit_call_expr(self, expr: p.CallExpr) -> Type:
@@ -483,13 +550,27 @@ class PythonTyper(
case p.VariableExpr(name="TypeVar"): case p.VariableExpr(name="TypeVar"):
return self.define_typevar(expr) or UnknownType() return self.define_typevar(expr) or UnknownType()
callee: Type = self.type_of(expr.callee)
positional: list[TypedExpr] = [ positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments (arg, self.type_of(arg)) for arg in expr.arguments
] ]
keywords: dict[str, TypedExpr] = { keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
} }
match expr.callee:
case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj)
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method,
expr.location,
unfolded,
positional,
keywords,
)
callee: Type = self.type_of(expr.callee)
return ( return (
self._get_call_result( self._get_call_result(
location=expr.location, location=expr.location,
@@ -504,7 +585,7 @@ class PythonTyper(
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
member: Optional[Type] = self.types.lookup_member(object, expr.name) member: Optional[Type] = self.types.lookup_member(object, expr.name)
if member is None: if member is None:
self.reporter.error( self.reporter.warning(
expr.location, f"Unknown member '{expr.name}' of {object}" expr.location, f"Unknown member '{expr.name}' of {object}"
) )
return UnknownType() return UnknownType()
@@ -521,6 +602,8 @@ class PythonTyper(
return self.types.get_type("float") return self.types.get_type("float")
case str(): case str():
return self.types.get_type("str") return self.types.get_type("str")
case None:
return self.types.get_type("None")
case _: case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}") self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType() return UnknownType()
@@ -563,7 +646,10 @@ class PythonTyper(
test_type: Type = self.type_of(expr.test) test_type: Type = self.type_of(expr.test)
# TODO Allow subtypes or any type # TODO Allow subtypes or any type
if test_type != self.types.get_type("bool"): if (
not self.is_subtype(test_type, self.types.get_type("bool"))
and test_type != UnknownType()
):
self.reporter.error( self.reporter.error(
expr.test.location, f"If test must be a boolean, got {test_type}" expr.test.location, f"If test must be a boolean, got {test_type}"
) )
@@ -592,9 +678,9 @@ class PythonTyper(
if len(item_types) == 1: if len(item_types) == 1:
item_type: Type = item_types[0] item_type: Type = item_types[0]
return self.types.apply_generic(list_type, [item_type]) return self.types.apply_generic(list_type, [item_type])
self.reporter.error( self.reporter.warning(
expr.location, expr.location,
f"Heterogeneous list items: {item_types}", f"Heterogeneous list items: [{', '.join(map(str, item_types))}]",
) )
return self.types.apply_generic(list_type, [UnknownType()]) return self.types.apply_generic(list_type, [UnknownType()])
@@ -624,22 +710,29 @@ class PythonTyper(
if len(key_types) == 1: if len(key_types) == 1:
key_type = key_types[0] key_type = key_types[0]
else: else:
self.reporter.error( self.reporter.warning(
expr.location, expr.location,
f"Heterogeneous dict keys: {key_types}", f"Heterogeneous dict keys: [{', '.join(map(str, key_types))}]",
) )
if len(value_types) == 1: if len(value_types) == 1:
value_type = value_types[0] value_type = value_types[0]
else: else:
self.reporter.error( self.reporter.warning(
expr.location, expr.location,
f"Heterogeneous dict values: {value_types}", f"Heterogeneous dict values: [{', '.join(map(str, value_types))}]",
) )
return self.types.apply_generic(dict_type, [key_type, value_type]) return self.types.apply_generic(dict_type, [key_type, value_type])
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
unfolded: Type = unfold_type(object)
match unfolded:
case TupleType():
return self._visit_tuple_subscript(unfolded, expr)
case DataFrameType():
return self._visit_frame_subscript(unfolded, expr)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
if operation is None: if operation is None:
self.reporter.error( self.reporter.error(
@@ -657,6 +750,11 @@ class PythonTyper(
def visit_slice_expr(self, expr: p.SliceExpr) -> Type: def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
return self.types.get_type("slice") return self.types.get_type("slice")
def visit_tuple_expr(self, expr: p.TupleExpr) -> Type:
return TupleType(
items=tuple(self.type_of(item) for item in expr.items),
)
def visit_raw_expr(self, expr: p.RawExpr) -> Type: def visit_raw_expr(self, expr: p.RawExpr) -> Type:
return UnknownType() return UnknownType()
@@ -668,22 +766,35 @@ class PythonTyper(
self.reporter.warning(node.location, f"Unknown type '{node.base}'") self.reporter.warning(node.location, f"Unknown type '{node.base}'")
return UnknownType() return UnknownType()
if node.param is not None: if len(node.args) != 0:
param: Type = self.resolve_type_expr(node.param) args: list[Type] = [self.resolve_type_expr(arg) for arg in node.args]
return self.types.apply_generic(base, [param]) return self.types.apply_generic(base, args)
return base return base
def visit_constraint_type(self, node: p.ConstraintType) -> Type: def visit_constraint_type(self, node: p.ConstraintType) -> Type:
self.reporter.warning(node.location, "ConstraintType not yet supported") self.reporter.warning(node.location, "ConstraintType not yet supported")
return UnknownType() return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> Type: def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
self.reporter.warning(node.location, "FrameColumn not yet supported") return ColumnType(
return UnknownType() type=(
self.resolve_type_expr(node.type)
if node.type is not None
else UnknownType()
)
)
def visit_frame_type(self, node: p.FrameType) -> Type: def visit_frame_type(self, node: p.FrameType) -> Type:
self.reporter.warning(node.location, "FrameType not yet supported") return DataFrameType(
return UnknownType() columns=[
DataFrameType.Column(
index=i,
name=column.name,
type=self.visit_frame_column(column),
)
for i, column in enumerate(node.columns)
]
)
def _get_call_result( def _get_call_result(
self, self,
@@ -1055,9 +1166,8 @@ class PythonTyper(
return False return False
return True return True
def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]: def _get_iterator_type(self, expr: p.Expr, type: Type) -> Optional[Type]:
# TODO: lookup __iter__ # TODO: lookup __iter__
type: Type = self.type_of(expr)
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__") getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
if getitem is None: if getitem is None:
return None return None
@@ -1123,7 +1233,7 @@ class PythonTyper(
node: ast.Expression = ast.parse(value, mode="eval") node: ast.Expression = ast.parse(value, mode="eval")
return parser._parse_type(node.body) return parser._parse_type(node.body)
case p.VariableExpr(name=name): case p.VariableExpr(name=name):
return p.BaseType(location=location, base=name, param=None) return p.BaseType(location=location, base=name, args=())
case _: case _:
raise NotImplementedError raise NotImplementedError
@@ -1211,8 +1321,34 @@ class PythonTyper(
return False return False
return True return True
case DataFrameType() | ColumnType():
self.reporter.error(
expr.location, f"Cannot cast {lit_value!r} to {target_type}"
)
return False
case _: case _:
self.reporter.info( self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically" expr.location, f"Cannot evaluate cast to {target_type} statically"
) )
return False return False
def _visit_tuple_subscript(self, tup: TupleType, expr: p.SubscriptExpr) -> Type:
match expr.index:
case p.LiteralExpr(value=int() as index):
if index < 0 or index >= len(tup.items):
self.reporter.error(
expr.location, f"Index {index} out of range for tuple {tup}"
)
return UnknownType()
return tup.items[index]
case _:
self.reporter.error(
expr.location, f"Invalid index type {expr.index} on {tup}"
)
return UnknownType()
def _visit_frame_subscript(
self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type:
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)

View File

@@ -7,8 +7,10 @@ from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import ( from midas.checker.types import (
AppliedType, AppliedType,
BaseType, BaseType,
ColumnType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DataFrameType,
DerivedType, DerivedType,
ExtensionType, ExtensionType,
Function, Function,
@@ -16,6 +18,7 @@ from midas.checker.types import (
OverloadedFunction, OverloadedFunction,
Predicate, Predicate,
TopType, TopType,
TupleType,
Type, Type,
TypeVar, TypeVar,
UnknownType, UnknownType,
@@ -157,6 +160,24 @@ class TypesRegistry:
return False return False
return True return True
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
# TODO: check order?
by_name1: dict[str, DataFrameType.Column] = {
col.name: col for col in columns1 if col.name is not None
}
for col2 in columns2:
if col2.name not in by_name1:
return False
if not self.is_subtype(by_name1[col2.name].type, col2.type):
return False
return True
case (ColumnType(type=inner1), ColumnType(type=inner2)):
# TODO: invariant, replace ColumnType with simple GenericType
if not self.are_equivalent(inner1, inner2):
return False
return True
case (Function(), Function()): case (Function(), Function()):
return self.is_func_subtype(type1, type2) return self.is_func_subtype(type1, type2)
@@ -187,6 +208,9 @@ class TypesRegistry:
return False return False
def are_equivalent(self, type1: Type, type2: Type) -> bool:
return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
# TODO: verify the logic in here # TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool: def is_func_subtype(self, func1: Function, func2: Function) -> bool:
"""Check whether a function is a subtype of another """Check whether a function is a subtype of another
@@ -323,6 +347,9 @@ class TypesRegistry:
body=substitute_typevars(body, substitutions), body=substitute_typevars(body, substitutions),
) )
case BaseType(name="tuple"):
return TupleType(items=tuple(args))
case _: case _:
raise ValueError(f"{type} is not a generic type") raise ValueError(f"{type} is not a generic type")

View File

@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
case p.GetExpr(): case p.GetExpr():
target.accept(self) target.accept(self)
case p.SubscriptExpr():
target.accept(self)
case _: case _:
raise Exception(f"Unsupported assignment to {target}") raise Exception(f"Unsupported assignment to {target}")
@@ -232,5 +236,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
if expr.step is not None: if expr.step is not None:
self.resolve(expr.step) self.resolve(expr.step)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
self.resolve(item)
def visit_raw_expr(self, expr: p.RawExpr) -> None: def visit_raw_expr(self, expr: p.RawExpr) -> None:
pass pass

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import StrEnum from enum import StrEnum
from typing import Optional, assert_never from typing import Optional, assert_never, cast
import midas.ast.midas as m import midas.ast.midas as m
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
@@ -156,6 +156,37 @@ class ConstraintType:
return f"{self.type} where {printer.print(self.constraint)}" return f"{self.type} where {printer.print(self.constraint)}"
@dataclass(frozen=True, kw_only=True)
class TupleType:
items: tuple[Type, ...]
def __str__(self) -> str:
return f"({', '.join(map(str, self.items))})"
@dataclass(frozen=True, kw_only=True)
class ColumnType:
type: Type
def __str__(self) -> str:
return f"Column[{self.type}]"
@dataclass(frozen=True, kw_only=True)
class DataFrameType:
columns: list[Column]
def __str__(self) -> str:
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
return f"Frame[{', '.join(schema)}]"
@dataclass(frozen=True, kw_only=True)
class Column:
index: int
name: Optional[str]
type: ColumnType
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument): def sub_argument(arg: Function.Argument):
return Function.Argument( return Function.Argument(
@@ -165,6 +196,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
required=arg.required, required=arg.required,
) )
def sub_column(col: DataFrameType.Column):
return DataFrameType.Column(
index=col.index,
name=col.name,
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
)
match type: match type:
case TopType(): case TopType():
return type return type
@@ -252,10 +290,26 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions), body=substitute_typevars(body, substitutions),
) )
case TupleType(items=items):
return TupleType(
items=tuple(substitute_typevars(item, substitutions) for item in items),
)
case ColumnType(type=items_type):
return ColumnType(
type=substitute_typevars(items_type, substitutions),
)
case DataFrameType(columns=columns):
return DataFrameType(
columns=list(map(sub_column, columns)),
)
case UnknownType() | UnitType(): case UnknownType() | UnitType():
return type return type
case TopType() | GenericType(): case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}") raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness # Ensure exhaustiveness
@@ -319,6 +373,15 @@ def to_annotation(type: Type) -> str:
case ConstraintType(): case ConstraintType():
return str(type) return str(type)
case TupleType(items=items):
return f"Tuple[{', '.join(map(to_annotation, items))}]"
case ColumnType():
return "pd.Series"
case DataFrameType():
return "pd.DataFrame"
case _: case _:
assert_never(type) assert_never(type)
@@ -344,4 +407,7 @@ Type = (
| GenericType | GenericType
| AppliedType | AppliedType
| ConstraintType | ConstraintType
| TupleType
| ColumnType
| DataFrameType
) )

View File

@@ -5,7 +5,7 @@
import sys import sys
from pathlib import Path from pathlib import Path
from typing import TextIO from typing import Optional, TextIO
import click import click
@@ -19,18 +19,23 @@ from midas.utils import TypedAST
@click.command(help="Compile source") @click.command(help="Compile source")
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True) @click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-s", "--stubs", type=str, multiple=True)
@click.option("--ignore-errors", is_flag=True) @click.option("--ignore-errors", is_flag=True)
def compile( def compile(
file: TextIO, file: TextIO,
types: tuple[TextIO], types: tuple[TextIO],
stubs: tuple[str],
ignore_errors: bool, ignore_errors: bool,
): ):
source: str = file.read() source: str = file.read()
source_path: Path = Path(file.name).resolve() source_path: Path = Path(file.name).resolve()
checker = TypeChecker() checker = TypeChecker()
for types_file in types: type_files: list[tuple[Path, Optional[str]]] = []
checker.import_midas(Path(types_file.name).resolve()) for i, types_file in enumerate(types):
in_path: Path = Path(types_file.name).resolve()
checker.import_midas(in_path)
type_files.append((in_path, stubs[i] if i < len(stubs) else None))
typed_ast: TypedAST = checker.type_check_source(source, str(source_path)) typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
diagnostics: list[Diagnostic] = checker.diagnostics.copy() diagnostics: list[Diagnostic] = checker.diagnostics.copy()
@@ -43,4 +48,4 @@ def compile(
sys.exit(1) sys.exit(1)
generator = Generator(workdir=source_path.parent, types=checker.types) generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path) generator.generate(typed_ast, source_path, type_files=type_files)

View File

@@ -1,7 +1,7 @@
import ast import ast
import time import time
from pathlib import Path from pathlib import Path
from typing import TextIO from typing import Optional, TextIO
import black import black
import click import click
@@ -38,15 +38,17 @@ class Handler(FileSystemEventHandler):
@click.command(help="Generate stubs from Midas definitions") @click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-") @click.option("-o", "--output", type=click.File("w"))
@click.option("-w", "--watch", is_flag=True) @click.option("-w", "--watch", is_flag=True)
def stubs( def stubs(
file: TextIO, file: TextIO,
output: TextIO, output: Optional[TextIO],
watch: bool, watch: bool,
): ):
source_path: Path = Path(file.name).resolve() source_path: Path = Path(file.name).resolve()
out_path: Path = Path(output.name).resolve() out_path: Path = source_path.with_suffix(".pyi")
if output is not None:
out_path = Path(output.name).resolve()
generate_stubs(source_path, out_path) generate_stubs(source_path, out_path)
if watch: if watch:

View File

@@ -134,9 +134,9 @@ class PythonHighlighter(
def visit_base_type(self, node: p.BaseType) -> None: def visit_base_type(self, node: p.BaseType) -> None:
self.wrap(node, "base-type") self.wrap(node, "base-type")
if node.param is not None: for arg in node.args:
self.wrap(node.param, "param") self.wrap(arg, "arg")
node.param.accept(self) arg.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None: def visit_constraint_type(self, node: p.ConstraintType) -> None:
self.wrap(node, "constraint-type") self.wrap(node, "constraint-type")
@@ -247,6 +247,10 @@ class PythonHighlighter(
if expr.step is not None: if expr.step is not None:
expr.step.accept(self) expr.step.accept(self)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ... def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ... def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
@@ -350,6 +354,14 @@ class MidasHighlighter(
for param in spec.pos + spec.mixed + spec.kw: for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self) param.type.accept(self)
def visit_frame_type(self, type: m.FrameType) -> None:
self.wrap(type, "frame")
for column in type.columns:
self._visit_frame_column(column)
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
self.wrap(column, "column")
class DiagnosticsHighlighter(Highlighter): class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css" EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -3,7 +3,7 @@ span {
--col: 108, 233, 108; --col: 108, 233, 108;
} }
&.param { &.arg {
--col: 103, 192, 224; --col: 103, 192, 224;
} }

View File

@@ -68,7 +68,7 @@ class DiagnosticPrinter:
loc: Location = diagnostic.location loc: Location = diagnostic.location
if loc.lineno != loc.end_lineno: if loc.lineno != loc.end_lineno:
print(diagnostic) self.print_multiline(lines, diagnostic, indent)
return return
start_offset: int = loc.col_offset start_offset: int = loc.col_offset
@@ -95,3 +95,27 @@ class DiagnosticPrinter:
print(indent_str + before + subject + after) print(indent_str + before + subject + after)
print(indent_str + cursor) print(indent_str + cursor)
print() print()
def print_multiline(
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
):
loc: Location = diagnostic.location
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
start_offset: int = loc.col_offset
end_offset: int = loc.end_col_offset or (start_offset + 1)
indent_str: str = " " * indent
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
res: str = indent_str + lines[0][:start_offset]
res += Ansi.FG(color) + lines[0][start_offset:]
for line in lines[1:-1]:
res += "\n" + indent_str + line
res += "\n" + indent_str + lines[-1][:end_offset]
res += Ansi.RESET + lines[-1][end_offset:]
print(diagnostic.location_str + ":")
print(res)
print()
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
print()

View File

@@ -1,4 +1,5 @@
import ast import ast
import logging
import shutil import shutil
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@@ -8,38 +9,47 @@ import midas.ast.midas as m
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
from midas.checker.types import ( from midas.checker.types import (
AppliedType, AppliedType,
BaseType, BaseType,
ColumnType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DataFrameType,
DerivedType, DerivedType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
TopType, TopType,
TupleType,
Type, Type,
TypeVar, TypeVar,
UnitType, UnitType,
UnknownType, UnknownType,
) )
from midas.generator.constraints import ConstraintGenerator from midas.generator.constraints import ConstraintGenerator
from midas.generator.stubs import StubsGenerator
from midas.utils import TypedAST from midas.utils import TypedAST
@dataclass @dataclass
class Scope: class Scope:
pre_assertions: list[ast.stmt] = field(default_factory=list) pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
aliases: list[str] = field(default_factory=list) aliases: list[str] = field(default_factory=list[str])
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
IS_COLUMN_FUNC = "__midas_is_column__"
def __init__(self, workdir: Path, types: TypesRegistry) -> None: def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve() self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas" self.build_dir: Path = self.workdir / "build" / "midas"
self.rel_src_path: Path = Path() self.rel_src_path: Path = Path()
self.logger: logging.Logger = logging.getLogger("Generator")
self._typed_ast: TypedAST = TypedAST( self._typed_ast: TypedAST = TypedAST(
stmts=[], stmts=[],
@@ -53,20 +63,37 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types) self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = [] self._constraints: list[tuple[m.Expr, ast.expr]] = []
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST: self.define_is_dataframe: bool = False
self.rel_src_path = src_path.resolve().relative_to(self.workdir) self.define_is_column: bool = False
def set_src_path(self, path: Path):
self.rel_src_path = path.resolve().relative_to(self.workdir)
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
self._typed_ast = typed_ast self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts) body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
predicates: list[ast.stmt] = self._constraint_generator.get_definitions() predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
body = predicates + body
if self.define_is_dataframe:
body = [self._is_dataframe_definition()] + body
if self.define_is_column:
body = [self._is_column_definition()] + body
module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module) module = ast.fix_missing_locations(module)
return module return module
def generate( def generate(
self, typed_ast: TypedAST, src_path: Path, out_path: Optional[Path] = None self,
typed_ast: TypedAST,
src_path: Path,
out_path: Optional[Path] = None,
type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
) -> Path: ) -> Path:
module: ast.AST = self.generate_ast(typed_ast, src_path) self.set_src_path(src_path)
compiled: str = ast.unparse(module)
if out_path is None: if out_path is None:
if self.build_dir.exists(): if self.build_dir.exists():
shutil.rmtree(self.build_dir) shutil.rmtree(self.build_dir)
@@ -78,10 +105,30 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
raise ValueError( raise ValueError(
f"Directory traversal, {self.rel_src_path} points outside of parent directory" f"Directory traversal, {self.rel_src_path} points outside of parent directory"
) )
out_path.parent.mkdir(parents=True, exist_ok=True) out_dir: Path = out_path.parent
out_dir.parent.mkdir(parents=True, exist_ok=True)
if type_files is not None:
for in_path, out_name in type_files:
if out_name is None:
out_name = in_path.stem
self.generate_stubs(in_path, out_dir / f"{out_name}.py")
module: ast.AST = self.generate_ast(typed_ast)
compiled: str = ast.unparse(module)
out_path.write_text(compiled) out_path.write_text(compiled)
return out_path return out_path
def generate_stubs(self, in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output: str = ast.unparse(module)
out_path.write_text(output)
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr: def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
return ast.BinOp( return ast.BinOp(
left=expr.left.accept(self), left=expr.left.accept(self),
@@ -139,7 +186,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
alias: ast.expr = self._make_alias(expr2) alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr) type: Type = self._get_expr_type(expr)
self._make_cast_asserts(expr.location, alias, type) asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
for assert_ in asserts:
self._add_assert(assert_)
return alias return alias
@@ -174,6 +223,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
step=expr.step.accept(self) if expr.step is not None else None, step=expr.step.accept(self) if expr.step is not None else None,
) )
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
return ast.Tuple(
elts=[item.accept(self) for item in expr.items],
)
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr: def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
return expr.expr return expr.expr
@@ -274,63 +328,156 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
) )
return alias return alias
def _add_assert(self, expr: ast.expr, message: str | ast.expr): def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
if isinstance(message, str): if isinstance(message, str):
message = ast.Constant(value=message) message = ast.Constant(value=message)
self._scopes[-1].pre_assertions.append( return ast.Assert(
ast.Assert( test=expr,
test=expr, msg=message,
msg=message,
)
) )
def _add_assert(self, assertion: ast.stmt):
self._scopes[-1].pre_assertions.append(assertion)
def _get_expr_type(self, query: p.Expr) -> Type: def _get_expr_type(self, query: p.Expr) -> Type:
for expr, type in self._typed_ast.judgements: for expr, type in self._typed_ast.judgements:
if expr == query: if expr == query:
return type return type
raise RuntimeError(f"Cannot get type judgement for {query}") raise RuntimeError(f"Cannot get type judgement for {query}")
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type): def _make_cast_asserts(
self, src_location: Location, expr: ast.expr, type: Type
) -> list[ast.stmt]:
match type: match type:
case UnknownType(): case UnknownType():
pass return []
case BaseType(name=name): case BaseType(name=name):
self._add_assert( return [
ast.Call( self._build_assert(
func=ast.Name(id="isinstance"), ast.Call(
args=[expr, ast.Name(id=name)], func=ast.Name(id="isinstance"),
keywords=[], args=[expr, ast.Name(id=name)],
), keywords=[],
self._make_cast_assert_message(src_location, expr, type), ),
) self._make_cast_assert_message(src_location, expr, type),
)
]
case DerivedType(type=base): case DerivedType(type=base):
self._make_cast_asserts(src_location, expr, base) return self._make_cast_asserts(src_location, expr, base)
case UnitType(): case UnitType():
self._add_assert( return [
ast.Compare( self._build_assert(
left=expr, ast.Compare(
ops=[ast.Is()], left=expr,
comparators=[ ops=[ast.Is()],
ast.Constant(value=None), comparators=[
], ast.Constant(value=None),
],
),
self._make_cast_assert_message(src_location, expr, type),
), ),
self._make_cast_assert_message(src_location, expr, type), ]
)
case AppliedType(body=body): case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body) return self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint): case ConstraintType(type=base, constraint=constraint):
self._make_cast_asserts(src_location, expr, base) asserts: list[ast.stmt] = self._make_cast_asserts(
self._make_constraint_assert(src_location, expr, constraint) src_location, expr, base
)
asserts.append(
self._make_constraint_assert(src_location, expr, constraint)
)
return asserts
case TypeVar(bound=bound): case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context # TODO: check with type from arguments / use call-site context
if bound is not None: if bound is None:
self._make_cast_asserts(src_location, expr, bound) return []
return self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
),
]
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
asserts.extend(
self._make_cast_asserts(src_location, item, item_type)
)
return asserts
case DataFrameType(columns=columns):
self.define_is_dataframe = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a dataframe"
),
),
]
for column in columns:
asserts.append(
self._build_assert(
ast.Compare(
left=ast.Constant(value=column.name),
ops=[ast.In()],
comparators=[expr],
),
self._make_cast_assert_message(
src_location,
expr,
type,
f": Missing column {column.name}",
),
)
)
asserts.extend(
self._make_cast_asserts(
src_location,
ast.Subscript(
value=expr, slice=ast.Constant(value=column.name)
),
column.type,
)
)
return asserts
case ColumnType():
self.define_is_column = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_COLUMN_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a column"
),
),
]
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
src_location, expr, type
)
if inner_assert is not None:
asserts.append(inner_assert)
return asserts
case ( case (
TopType() TopType()
@@ -340,14 +487,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
| ExtensionType() | ExtensionType()
| GenericType() | GenericType()
): ):
raise NotImplementedError(f"Can't make assertion for type {type}") self.logger.warning(f"Can't make assertion for type {type}")
return []
# Ensure exhaustiveness # Ensure exhaustiveness
case _: case _:
assert_never(type) assert_never(type)
def _make_cast_assert_message( def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type self,
location: Location,
expr: ast.expr,
type: Type,
extra: Optional[str] = None,
) -> ast.expr: ) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}" loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type" # f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
@@ -365,15 +517,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
), ),
conversion=-1, conversion=-1,
), ),
ast.Constant(f" to {type}"), ast.Constant(f" to {type}{extra or ''}"),
] ]
) )
def _make_constraint_assert( def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr self, src_location: Location, expr: ast.expr, constraint: m.Expr
): ) -> ast.stmt:
test_func: ast.expr = self._get_constraint(constraint) test_func: ast.expr = self._get_constraint(constraint)
self._add_assert( return self._build_assert(
ast.Call( ast.Call(
func=test_func, func=test_func,
args=[expr], args=[expr],
@@ -401,3 +553,90 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
constraint: ast.expr = self._constraint_generator.generate(expr) constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint)) self._constraints.append((expr, constraint))
return constraint return constraint
def _is_dataframe_definition(self) -> ast.stmt:
"""
def IS_DATAFRAME_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.DataFrame)
"""
return ast.FunctionDef(
name=self.IS_DATAFRAME_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _is_column_definition(self) -> ast.stmt:
"""
def IS_COLUMN_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.Series)
"""
return ast.FunctionDef(
name=self.IS_COLUMN_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _make_column_inner_assert(
self, src_location: Location, column: ast.expr, type: ColumnType
) -> Optional[ast.stmt]:
# TODO: improve message, maybe chain contexts
col: ast.expr = ast.Name(id="col")
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
if len(body) == 0:
return None
return ast.For(
target=col,
iter=column,
body=body,
orelse=[],
)

View File

@@ -6,14 +6,17 @@ from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import ( from midas.checker.types import (
AppliedType, AppliedType,
BaseType, BaseType,
ColumnType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DataFrameType,
DerivedType, DerivedType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
TopType, TopType,
TupleType,
Type, Type,
TypeVar, TypeVar,
UnitType, UnitType,
@@ -30,6 +33,7 @@ class StubsGenerator:
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = [] self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set() self.typing_imports: set[str] = set()
self.import_pandas: bool = False
self.protocol_idx: int = 0 self.protocol_idx: int = 0
self.stub_idx: int = 0 self.stub_idx: int = 0
self.type_var_idx: int = 0 self.type_var_idx: int = 0
@@ -38,6 +42,7 @@ class StubsGenerator:
def generate_stubs(self) -> ast.Module: def generate_stubs(self) -> ast.Module:
self.stubs = [] self.stubs = []
self.typing_imports = set() self.typing_imports = set()
self.import_pandas = False
for name, type in self.types._types.items(): for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override # Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type # TODO: check if added members on builtin type
@@ -53,7 +58,7 @@ class StubsGenerator:
continue continue
self.generate_stub(name, type) self.generate_stub(name, type)
imports = [ imports: list[ast.stmt] = [
ast.ImportFrom( ast.ImportFrom(
module="__future__", module="__future__",
names=[ast.alias(name="annotations")], names=[ast.alias(name="annotations")],
@@ -70,6 +75,17 @@ class StubsGenerator:
level=0, level=0,
) )
) )
if self.import_pandas:
imports.append(
ast.Import(
names=[
ast.alias(
name="pandas",
asname="pd",
)
],
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[]) return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type): def generate_stub(self, name: str, type: Type):
@@ -231,6 +247,31 @@ class StubsGenerator:
case ConstraintType(): case ConstraintType():
return self.dump_type(type.type) return self.dump_type(type.type)
case TupleType(items=items):
return ast.Subscript(
value=ast.Name(id="tuple"),
slice=ast.Tuple(
elts=[self.dump_type(item) for item in items],
),
)
case ColumnType(type=inner):
self.import_pandas = True
return ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
slice=self.dump_type(inner),
)
case DataFrameType():
self.import_pandas = True
return ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
)
case _: case _:
assert_never(type) assert_never(type)

View File

@@ -10,6 +10,7 @@ from midas.ast.midas import (
Expr, Expr,
ExtendStmt, ExtendStmt,
ExtensionType, ExtensionType,
FrameType,
FunctionType, FunctionType,
GenericType, GenericType,
GetExpr, GetExpr,
@@ -226,8 +227,10 @@ class MidasParser(Parser):
return self.generic_type() return self.generic_type()
def generic_type(self) -> Type: def generic_type(self) -> Type:
type: Type = self.named_type() type: NamedType = self.named_type()
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
if type.name.lexeme == "Frame":
return self.frame_type()
args: list[Type] = self.type_args() args: list[Type] = self.type_args()
return GenericType( return GenericType(
location=Location.span(type.location, self.previous().get_location()), location=Location.span(type.location, self.previous().get_location()),
@@ -246,7 +249,7 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments") self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
return args return args
def named_type(self) -> Type: def named_type(self) -> NamedType:
name: Token = self.consume_identifier("Expected type name") name: Token = self.consume_identifier("Expected type name")
return NamedType( return NamedType(
location=name.get_location(), location=name.get_location(),
@@ -281,6 +284,32 @@ class MidasParser(Parser):
members=members, members=members,
) )
def frame_type(self) -> FrameType:
keyword: Token = self.previous()
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
columns: list[FrameType.Column] = []
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
name: Token = self.advance()
self.consume(TokenType.COLON, "Expected ':' between column name and type")
type: Type = self.type_expr()
columns.append(
FrameType.Column(
location=name.location_to(self.previous()),
name=name,
type=type,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
return FrameType(
location=keyword.location_to(self.previous()),
columns=columns,
)
def constraint(self) -> Expr: def constraint(self) -> Expr:
"""Parse a constraint """Parse a constraint
@@ -370,7 +399,7 @@ class MidasParser(Parser):
pos_args: list[Expr] = [] pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {} kw_args: dict[str, Expr] = {}
keywords: bool = False keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN): while not self.check(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL): if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True keywords = True
keyword: Token = self.advance() keyword: Token = self.advance()

View File

@@ -30,6 +30,7 @@ from midas.ast.python import (
Stmt, Stmt,
SubscriptExpr, SubscriptExpr,
TernaryExpr, TernaryExpr,
TupleExpr,
TypeAssign, TypeAssign,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
@@ -300,26 +301,28 @@ class PythonParser:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema): case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
return self._parse_frame_type(schema) return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=param): case ast.Subscript(value=ast.Name(id=name), slice=arg):
args: tuple[MidasType, ...] = (
tuple(self._parse_type(a) for a in arg.elts)
if isinstance(arg, ast.Tuple)
else (self._parse_type(arg),)
)
return BaseType( return BaseType(
location=loc, location=loc,
base=name, base=name,
param=self._parse_type(param), args=args,
) )
case ast.Name(id=name): case ast.Name(id=name):
return BaseType( return BaseType(
location=loc, location=loc,
base=name, base=name,
param=None, args=(),
) )
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr): case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
left = self._parse_type(left_expr) left = self._parse_type(left_expr)
match left: match left:
case None:
raise InvalidSyntaxError()
# If chained constraints, separate base type and rebuild constraint # If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint): case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp( constraint = ast.BinOp(
@@ -345,7 +348,7 @@ class PythonParser:
return BaseType( return BaseType(
location=loc, location=loc,
base="None", base="None",
param=None, args=(),
) )
case _: case _:
@@ -477,6 +480,12 @@ class PythonParser:
step=self.parse_expr(step) if step is not None else None, step=self.parse_expr(step) if step is not None else None,
) )
case ast.Tuple(elts=items):
return TupleExpr(
location=location,
items=tuple(self.parse_expr(item) for item in items),
)
case _: case _:
print(f"Unsupported expression: {ast.unparse(node)}") print(f"Unsupported expression: {ast.unparse(node)}")
return RawExpr(location=location, expr=node) return RawExpr(location=location, expr=node)

View File

@@ -1,3 +1,4 @@
from typing import Generic, TypeVar
from typing import cast as typing_cast from typing import cast as typing_cast
cast = typing_cast cast = typing_cast
@@ -32,3 +33,20 @@ This operation is unsound, use at your own risk!
_**Internal Python documentation**_ _**Internal Python documentation**_
""" """
T = TypeVar("T")
class Frame(Generic[T]):
"""A `Frame` is the abstract type implemented by `DataFrame`
A frame contains any number of named columns (see :class:`Column`)
"""
class Column(Generic[T]):
"""A `Column` is the abstract type implemented by `Series`
A column contains a any number of values of the same type
"""

View File

@@ -4,7 +4,35 @@
"type": "Warning", "type": "Warning",
"location": { "location": {
"start": [ "start": [
6, 8,
12
],
"end": [
8,
43
]
},
"message": "ConstraintType not yet supported"
},
{
"type": "Warning",
"location": {
"start": [
10,
10
],
"end": [
10,
18
]
},
"message": "Unknown type 'datetime'"
},
{
"type": "Warning",
"location": {
"start": [
13,
4 4
], ],
"end": [ "end": [
@@ -12,7 +40,7 @@
5 5
] ]
}, },
"message": "FrameType not yet supported" "message": "Unknown type '_'"
} }
], ],
"judgments": [] "judgments": []

View File

@@ -328,6 +328,19 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L6:9",
"to": "L6:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{ {
"location": { "location": {
"from": "L6:5", "from": "L6:5",
@@ -373,19 +386,6 @@
} }
} }
}, },
{
"location": {
"from": "L6:9",
"to": "L6:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{ {
"location": { "location": {
"from": "L6:5", "from": "L6:5",
@@ -407,6 +407,32 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L7:9",
"to": "L7:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L7:12",
"to": "L7:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L7:5", "from": "L7:5",
@@ -452,32 +478,6 @@
} }
} }
}, },
{
"location": {
"from": "L7:9",
"to": "L7:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L7:12",
"to": "L7:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L7:5", "from": "L7:5",
@@ -503,6 +503,32 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L8:9",
"to": "L8:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:14",
"to": "L8:17"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L8:5", "from": "L8:5",
@@ -548,32 +574,6 @@
} }
} }
}, },
{
"location": {
"from": "L8:9",
"to": "L8:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:14",
"to": "L8:17"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L8:5", "from": "L8:5",
@@ -600,6 +600,45 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L9:9",
"to": "L9:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L9:12",
"to": "L9:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:17",
"to": "L9:23"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{ {
"location": { "location": {
"from": "L9:5", "from": "L9:5",
@@ -645,45 +684,6 @@
} }
} }
}, },
{
"location": {
"from": "L9:9",
"to": "L9:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L9:12",
"to": "L9:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:17",
"to": "L9:23"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{ {
"location": { "location": {
"from": "L9:5", "from": "L9:5",
@@ -713,6 +713,45 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L10:9",
"to": "L10:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:12",
"to": "L10:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:19",
"to": "L10:22"
},
"expr": {
"_type": "LiteralExpr",
"value": 3.0
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L10:5", "from": "L10:5",
@@ -758,45 +797,6 @@
} }
} }
}, },
{
"location": {
"from": "L10:9",
"to": "L10:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:12",
"to": "L10:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:19",
"to": "L10:22"
},
"expr": {
"_type": "LiteralExpr",
"value": 3.0
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L10:5", "from": "L10:5",
@@ -827,6 +827,19 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{ {
"location": { "location": {
"from": "L11:5", "from": "L11:5",
@@ -872,19 +885,6 @@
} }
} }
}, },
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{ {
"location": { "location": {
"from": "L11:5", "from": "L11:5",
@@ -906,6 +906,19 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L12:11",
"to": "L12:17"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{ {
"location": { "location": {
"from": "L12:5", "from": "L12:5",
@@ -951,19 +964,6 @@
} }
} }
}, },
{
"location": {
"from": "L12:11",
"to": "L12:17"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{ {
"location": { "location": {
"from": "L12:5", "from": "L12:5",
@@ -985,6 +985,45 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L14:10",
"to": "L14:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L14:13",
"to": "L14:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L14:20",
"to": "L14:26"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{ {
"location": { "location": {
"from": "L14:6", "from": "L14:6",
@@ -1030,45 +1069,6 @@
} }
} }
}, },
{
"location": {
"from": "L14:10",
"to": "L14:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L14:13",
"to": "L14:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L14:20",
"to": "L14:26"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{ {
"location": { "location": {
"from": "L14:6", "from": "L14:6",
@@ -1101,6 +1101,45 @@
"name": "bool" "name": "bool"
} }
}, },
{
"location": {
"from": "L15:10",
"to": "L15:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:15",
"to": "L15:18"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L15:22",
"to": "L15:28"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{ {
"location": { "location": {
"from": "L15:6", "from": "L15:6",
@@ -1146,45 +1185,6 @@
} }
} }
}, },
{
"location": {
"from": "L15:10",
"to": "L15:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:15",
"to": "L15:18"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L15:22",
"to": "L15:28"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{ {
"location": { "location": {
"from": "L15:6", "from": "L15:6",
@@ -1217,6 +1217,45 @@
"name": "bool" "name": "bool"
} }
}, },
{
"location": {
"from": "L16:10",
"to": "L16:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L16:25",
"to": "L16:28"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L16:6", "from": "L16:6",
@@ -1262,45 +1301,6 @@
} }
} }
}, },
{
"location": {
"from": "L16:10",
"to": "L16:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L16:25",
"to": "L16:28"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L16:6", "from": "L16:6",
@@ -1333,6 +1333,45 @@
"name": "bool" "name": "bool"
} }
}, },
{
"location": {
"from": "L18:10",
"to": "L18:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "a"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L18:15",
"to": "L18:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L18:20",
"to": "L18:25"
},
"expr": {
"_type": "LiteralExpr",
"value": false
},
"type": {
"name": "bool"
}
},
{ {
"location": { "location": {
"from": "L18:6", "from": "L18:6",
@@ -1378,45 +1417,6 @@
} }
} }
}, },
{
"location": {
"from": "L18:10",
"to": "L18:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "a"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L18:15",
"to": "L18:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L18:20",
"to": "L18:25"
},
"expr": {
"_type": "LiteralExpr",
"value": false
},
"type": {
"name": "bool"
}
},
{ {
"location": { "location": {
"from": "L18:6", "from": "L18:6",

View File

@@ -24,7 +24,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Meter", "base": "Meter",
"param": null "args": []
}, },
"expr": { "expr": {
"_type": "LiteralExpr", "_type": "LiteralExpr",
@@ -62,7 +62,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Second", "base": "Second",
"param": null "args": []
}, },
"expr": { "expr": {
"_type": "LiteralExpr", "_type": "LiteralExpr",

View File

@@ -100,6 +100,32 @@
"name": "float" "name": "float"
} }
}, },
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L11:5", "from": "L11:5",
@@ -135,32 +161,6 @@
} }
} }
}, },
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{ {
"location": { "location": {
"from": "L11:5", "from": "L11:5",

View File

@@ -72,29 +72,6 @@
} }
], ],
"judgments": [ "judgments": [
{
"location": {
"from": "L26:0",
"to": "L26:5"
},
"expr": {
"_type": "VariableExpr",
"name": "print"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {}
}
},
{ {
"location": { "location": {
"from": "L27:4", "from": "L27:4",
@@ -325,6 +302,29 @@
} }
} }
}, },
{
"location": {
"from": "L26:0",
"to": "L26:5"
},
"expr": {
"_type": "VariableExpr",
"name": "print"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {}
}
},
{ {
"location": { "location": {
"from": "L26:0", "from": "L26:0",

View File

@@ -63,31 +63,6 @@
"name": "float" "name": "float"
} }
}, },
{
"location": {
"from": "L6:11",
"to": "L6:15"
},
"expr": {
"_type": "VariableExpr",
"name": "bool"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{ {
"location": { "location": {
"from": "L6:16", "from": "L6:16",
@@ -135,6 +110,31 @@
"name": "int" "name": "int"
} }
}, },
{
"location": {
"from": "L6:11",
"to": "L6:15"
},
"expr": {
"_type": "VariableExpr",
"name": "bool"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{ {
"location": { "location": {
"from": "L6:11", "from": "L6:11",
@@ -367,6 +367,54 @@
} }
} }
}, },
{
"location": {
"from": "L12:21",
"to": "L12:27"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L12:29",
"to": "L12:35"
},
"expr": {
"_type": "VariableExpr",
"name": "floats"
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{ {
"location": { "location": {
"from": "L12:17", "from": "L12:17",
@@ -455,54 +503,6 @@
} }
} }
}, },
{
"location": {
"from": "L12:21",
"to": "L12:27"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L12:29",
"to": "L12:35"
},
"expr": {
"_type": "VariableExpr",
"name": "floats"
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{ {
"location": { "location": {
"from": "L12:17", "from": "L12:17",
@@ -538,6 +538,54 @@
} }
} }
}, },
{
"location": {
"from": "L13:19",
"to": "L13:25"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L13:27",
"to": "L13:31"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{ {
"location": { "location": {
"from": "L13:15", "from": "L13:15",
@@ -626,54 +674,6 @@
} }
} }
}, },
{
"location": {
"from": "L13:19",
"to": "L13:25"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L13:27",
"to": "L13:31"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{ {
"location": { "location": {
"from": "L13:15", "from": "L13:15",
@@ -699,6 +699,54 @@
}, },
"type": {} "type": {}
}, },
{
"location": {
"from": "L14:15",
"to": "L14:21"
},
"expr": {
"_type": "VariableExpr",
"name": "is_odd"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "int"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L14:23",
"to": "L14:27"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{ {
"location": { "location": {
"from": "L14:11", "from": "L14:11",
@@ -787,54 +835,6 @@
} }
} }
}, },
{
"location": {
"from": "L14:15",
"to": "L14:21"
},
"expr": {
"_type": "VariableExpr",
"name": "is_odd"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "int"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L14:23",
"to": "L14:27"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{ {
"location": { "location": {
"from": "L14:11", "from": "L14:11",

View File

@@ -16,7 +16,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "bool", "base": "bool",
"param": null "args": []
} }
}, },
{ {
@@ -25,7 +25,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "int", "base": "int",
"param": null "args": []
} }
}, },
{ {
@@ -36,7 +36,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "float", "base": "float",
"param": null "args": []
}, },
"constraint": "(_ > 0) + (_ < 250)" "constraint": "(_ > 0) + (_ < 250)"
} }
@@ -47,7 +47,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "str", "base": "str",
"param": null "args": []
} }
}, },
{ {
@@ -56,7 +56,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "datetime", "base": "datetime",
"param": null "args": []
} }
}, },
{ {
@@ -65,7 +65,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "float", "base": "float",
"param": null "args": []
} }
}, },
{ {
@@ -79,7 +79,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "_", "base": "_",
"param": null "args": []
} }
} }
] ]

View File

@@ -16,7 +16,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "GeoLocation", "base": "GeoLocation",
"param": null "args": []
} }
} }
] ]
@@ -28,11 +28,13 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Column", "base": "Column",
"param": { "args": [
"_type": "BaseType", {
"base": "GeoLocation", "_type": "BaseType",
"param": null "base": "GeoLocation",
} "args": []
}
]
} }
}, },
{ {
@@ -65,11 +67,13 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Column", "base": "Column",
"param": { "args": [
"_type": "BaseType", {
"base": "GeoLocation", "_type": "BaseType",
"param": null "base": "GeoLocation",
} "args": []
}
]
} }
}, },
{ {
@@ -117,7 +121,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Latitude", "base": "Latitude",
"param": null "args": []
} }
}, },
{ {
@@ -146,7 +150,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Latitude", "base": "Latitude",
"param": null "args": []
} }
}, },
{ {
@@ -175,11 +179,13 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Difference", "base": "Difference",
"param": { "args": [
"_type": "BaseType", {
"base": "Latitude", "_type": "BaseType",
"param": null "base": "Latitude",
} "args": []
}
]
} }
}, },
{ {
@@ -217,7 +223,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "int", "base": "int",
"param": null "args": []
}, },
"constraint": "_ >= 0" "constraint": "_ >= 0"
} }
@@ -230,7 +236,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "float", "base": "float",
"param": null "args": []
}, },
"constraint": "_ >= 0" "constraint": "_ >= 0"
} }
@@ -252,7 +258,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "int", "base": "int",
"param": null "args": []
}, },
"constraint": "Positive" "constraint": "Positive"
} }
@@ -265,7 +271,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "float", "base": "float",
"param": null "args": []
}, },
"constraint": "Positive" "constraint": "Positive"
} }

View File

@@ -14,15 +14,17 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Column", "base": "Column",
"param": { "args": [
"_type": "ConstraintType", {
"type": { "_type": "ConstraintType",
"_type": "BaseType", "type": {
"base": "float", "_type": "BaseType",
"param": null "base": "float",
}, "args": []
"constraint": "0 <= _ <= 1" },
} "constraint": "0 <= _ <= 1"
}
]
}, },
"default": null "default": null
}, },
@@ -31,15 +33,17 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Column", "base": "Column",
"param": { "args": [
"_type": "ConstraintType", {
"type": { "_type": "ConstraintType",
"_type": "BaseType", "type": {
"base": "float", "_type": "BaseType",
"param": null "base": "float",
}, "args": []
"constraint": "0 <= _ <= 1" },
} "constraint": "0 <= _ <= 1"
}
]
}, },
"default": null "default": null
} }
@@ -50,15 +54,17 @@
"returns": { "returns": {
"_type": "BaseType", "_type": "BaseType",
"base": "Column", "base": "Column",
"param": { "args": [
"_type": "ConstraintType", {
"type": { "_type": "ConstraintType",
"_type": "BaseType", "type": {
"base": "float", "_type": "BaseType",
"param": null "base": "float",
}, "args": []
"constraint": "0 <= _ <= 2" },
} "constraint": "0 <= _ <= 2"
}
]
}, },
"body": [ "body": [
{ {
@@ -67,15 +73,17 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "Column", "base": "Column",
"param": { "args": [
"_type": "ConstraintType", {
"type": { "_type": "ConstraintType",
"_type": "BaseType", "type": {
"base": "float", "_type": "BaseType",
"param": null "base": "float",
}, "args": []
"constraint": "0 <= _ <= 2" },
} "constraint": "0 <= _ <= 2"
}
]
} }
}, },
{ {
@@ -117,7 +125,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "int", "base": "int",
"param": null "args": []
}, },
"default": null "default": null
} }
@@ -128,7 +136,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "float", "base": "float",
"param": null "args": []
}, },
"default": null "default": null
} }
@@ -140,7 +148,7 @@
"type": { "type": {
"_type": "BaseType", "_type": "BaseType",
"base": "str", "base": "str",
"param": null "args": []
}, },
"default": null "default": null
} }

View File

@@ -46,7 +46,8 @@ class GeneratorTester(Tester):
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics): if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent, types=checker.types) generator = Generator(workdir=path.parent, types=checker.types)
result.compiled_ast = generator.generate_ast(typed_ast, path) generator.set_src_path(path)
result.compiled_ast = generator.generate_ast(typed_ast)
return result return result

View File

@@ -8,6 +8,7 @@ from midas.ast.midas import (
Expr, Expr,
ExtendStmt, ExtendStmt,
ExtensionType, ExtensionType,
FrameType,
FunctionType, FunctionType,
GenericType, GenericType,
GetExpr, GetExpr,
@@ -197,3 +198,15 @@ class MidasAstJsonSerializer(
"base": type.base.accept(self), "base": type.base.accept(self),
"extension": type.extension.accept(self), "extension": type.extension.accept(self),
} }
def visit_frame_type(self, type: FrameType) -> dict:
return {
"_type": "FrameType",
"columns": [self._serialize_column(col) for col in type.columns],
}
def _serialize_column(self, column: FrameType.Column):
return {
"name": column.name.lexeme,
"type": column.type.accept(self),
}

View File

@@ -30,6 +30,7 @@ from midas.ast.python import (
Stmt, Stmt,
SubscriptExpr, SubscriptExpr,
TernaryExpr, TernaryExpr,
TupleExpr,
TypeAssign, TypeAssign,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
@@ -98,7 +99,7 @@ class PythonAstJsonSerializer(
return { return {
"_type": "BaseType", "_type": "BaseType",
"base": node.base, "base": node.base,
"param": self._serialize_optional(node.param), "args": self._serialize_list(node.args),
} }
def visit_constraint_type(self, node: ConstraintType) -> dict: def visit_constraint_type(self, node: ConstraintType) -> dict:
@@ -302,6 +303,12 @@ class PythonAstJsonSerializer(
"step": self._serialize_optional(expr.step), "step": self._serialize_optional(expr.step),
} }
def visit_tuple_expr(self, expr: TupleExpr) -> dict:
return {
"_type": "TupleExpr",
"items": [item.accept(self) for item in expr.items],
}
def visit_raw_expr(self, expr: RawExpr) -> dict: def visit_raw_expr(self, expr: RawExpr) -> dict:
return { return {
"_type": "RawExpr", "_type": "RawExpr",