17 Commits

Author SHA1 Message Date
d1c217a335 refactor: use metaclass to collect frame methods 2026-06-25 22:31:59 +02:00
5b3e87afcb refactor: add MethodResolver class 2026-06-25 22:14:25 +02:00
894d5a7196 feat: add dummy classes for typing frames and columns 2026-06-25 21:35:47 +02:00
eb809c6341 fix(checker): improve heterogeneous error message 2026-06-25 21:35:19 +02:00
bd68d1003f feat(checker): lookup dataframe methods 2026-06-25 21:34:59 +02:00
72c9236650 feat(checker): defined add method of dataframes 2026-06-25 21:34:00 +02:00
90051c7981 feat(checker): add structural subtyping rule for dataframes 2026-06-25 21:09:14 +02:00
dd1e2e693c feat(cli): print context for multiline diagnostics 2026-06-25 16:32:15 +02:00
78e10e0895 feat(checker): process frame type definitions 2026-06-24 14:36:53 +02:00
c81e4a9560 feat(cli): add frame type to highlighter 2026-06-24 14:36:53 +02:00
6d0cf1a055 feat(parser): add frame type to midas syntax 2026-06-24 14:36:52 +02:00
cc5e7af143 feat(gen): add support for tuples and dataframes 2026-06-24 14:36:51 +02:00
3bdbc80079 feat(checker): handle setting dataframe column 2026-06-24 14:36:51 +02:00
c1b5284f72 feat(checker): type check subscript on dataframes 2026-06-24 14:36:28 +02:00
5e9ccd4e13 feat(types): add TupleType 2026-06-24 14:36:04 +02:00
cf083fc0c3 fix(types): add str methods to dataframe types 2026-06-24 14:35:31 +02:00
a80da5db2c feat(types): add DataFrameType and ColumnType 2026-06-24 14:35:30 +02:00
16 changed files with 718 additions and 15 deletions

View File

@@ -152,4 +152,14 @@ class FunctionType:
required: bool
class FrameType:
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
###<

View File

@@ -253,6 +253,9 @@ class Type(ABC):
@abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ...
@abstractmethod
def visit_frame_type(self, type: FrameType) -> T: ...
@dataclass(frozen=True)
class NamedType(Type):
@@ -311,3 +314,17 @@ class FunctionType(Type):
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_function_type(self)
@dataclass(frozen=True)
class FrameType(Type):
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_frame_type(self)

View File

@@ -350,6 +350,25 @@ class MidasAstPrinter(
arg.type.accept(self)
self._write_line(f"required: {arg.required}", last=True)
def visit_frame_type(self, type: m.FrameType) -> None:
self._write_line("FrameType")
with self._child_level(single=True):
self._write_line("columns")
with self._child_level():
for i, column in enumerate(type.columns):
self._idx = i
if i == len(type.columns) - 1:
self._mark_last()
self._print_frame_column(column)
def _print_frame_column(self, column: m.FrameType.Column) -> None:
self._write_line("Column")
with self._child_level():
self._write_line(f'name: "{column.name.lexeme}"')
self._write_line("type")
with self._child_level(single=True):
column.type.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
@@ -502,6 +521,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res += "?"
return res
def visit_frame_type(self, type: m.FrameType) -> str:
res: str = self.indented("Frame[")
if len(type.columns) != 0:
res += "\n"
self.level += 1
columns: list[str] = []
for column in type.columns:
columns.append(self.indented(self._print_frame_column(column)))
res += ",\n".join(columns)
self.level -= 1
res += "\n"
res += "]"
return res
def _print_frame_column(self, column: m.FrameType.Column) -> str:
return f"{column.name.lexeme}: {column.type.accept(self)}"
class PythonAstPrinter(
AstPrinter,

View File

@@ -0,0 +1,153 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnType,
DataFrameType,
Function,
Type,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
frame: DataFrameType
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
class _MethodRegistryMeta(type):
_methods: dict[str, Callable] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr
return new_class
class MethodRegistry(metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
def call(
self,
method: str,
call: Call,
) -> Type:
func: Optional[Callable] = self._methods.get(method)
if func is None:
self.reporter.error(call.location, f"Unknown method {method}")
return UnknownType()
return func(self, call)
@frame_method("add", "__add__")
def add(
self,
call: Call,
) -> Type:
new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
in_frame1.add(column.name)
col_type1: Type = column.type
col_type: Type = ColumnType(type=UnknownType())
if column.name in by_name:
column2 = by_name[column.name]
col_type2: Type = column2.type
if self.types.are_equivalent(col_type2, col_type1):
col_type = col_type1
new_column = DataFrameType.Column(
index=column.index,
name=column.name,
type=col_type,
)
new_columns.append(new_column)
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
continue
new_columns.append(
DataFrameType.Column(
index=len(new_columns),
name=column.name,
type=ColumnType(type=UnknownType()),
)
)
signature = Function(
args=[
Function.Argument(
pos=0,
name="other",
type=DataFrameType(columns=[]),
required=True,
),
],
returns=DataFrameType(columns=new_columns),
)
return (
self.typer._get_call_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
or UnknownType()
)

154
midas/checker/frames.py Normal file
View File

@@ -0,0 +1,154 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frame_methods import Call, MethodRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
def assign(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
value_type: Type,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
return self.assign_column(reporter, location, frame, name, value_type)
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(idx, str) for idx in indices
):
raise NotImplementedError
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
def assign_column(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
name: str,
type: Type,
) -> Type:
if not isinstance(type, ColumnType):
reporter.error(
location,
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
)
return frame
return self._set_column(frame, name, type)
def get(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
return column
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
@classmethod
def _set_column(
cls, frame: DataFrameType, name: str, column: ColumnType
) -> DataFrameType:
new_columns: list[DataFrameType.Column] = []
index: int = len(frame.columns)
replace: bool = False
for i, col in enumerate(frame.columns):
if col.name == name:
index = i
replace = True
# TODO: check column type here to prevent changing it
new_columns.append(col)
new_col: DataFrameType.Column = DataFrameType.Column(
index=index,
name=name,
type=column,
)
if replace:
new_columns[index] = new_col
else:
new_columns.append(new_col)
return DataFrameType(columns=new_columns)
@classmethod
def _set_columns(
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType:
for name, col in zip(names, columns):
frame = cls._set_column(frame, name, col)
return frame
@classmethod
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
for col in frame.columns:
if col.name == name:
return col.type
return None
@classmethod
def _get_columns(
cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names]
def call(
self,
method: str,
location: Location,
frame: DataFrameType,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
frame=frame,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)

View File

@@ -14,8 +14,10 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
AppliedType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
@@ -401,6 +403,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
)
def visit_frame_type(self, type: m.FrameType) -> Type:
def process_column(i: int, col: m.FrameType.Column) -> DataFrameType.Column:
return DataFrameType.Column(
index=i,
name=col.name.lexeme,
type=ColumnType(type=col.type.accept(self)),
)
return DataFrameType(
columns=[process_column(i, col) for i, col in enumerate(type.columns)]
)
def _resolve_type_params(self, params: list[m.TypeParam]):
vars: list[TypeVar] = []
for param in params:

View File

@@ -8,6 +8,7 @@ from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator
from midas.checker.frames import FrameManager
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,
@@ -21,10 +22,13 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ConstraintType,
DataFrameType,
Function,
GenericType,
OverloadedFunction,
TupleType,
Type,
TypeVar,
UnitType,
@@ -71,6 +75,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self)
self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {}
@@ -319,9 +324,15 @@ class PythonTyper(
case p.VariableExpr():
self._assign_var(location, target, value_type)
# Allow any kind of object because we disallow creating new attributes
case p.GetExpr(object=object, name=name):
self._assign_attr(location, object, name, value_type)
# Only support variable expressions because modifying
# the underlying value would require reference types
case p.SubscriptExpr(object=p.VariableExpr() as var, index=index):
self._assign_sub(location, var, index, value_type)
case _:
if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}")
@@ -360,6 +371,27 @@ class PythonTyper(
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
)
def _assign_sub(
self,
location: Location,
var: p.VariableExpr,
index: p.Expr,
value_type: Type,
):
var_type: Type = self.type_of(var)
# TODO: what happens if type is an alias of a dataframe type
match var_type:
case DataFrameType() as frame:
new_type: Type = self.frame_mgr.assign(
self.reporter, location, frame, index, value_type
)
self.env.assign(var.name, new_type)
case _:
self.reporter.error(
location,
f"Cannot assign {value_type} to index {index} of {var_type}",
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
self.env.return_types.append(type)
@@ -483,13 +515,27 @@ class PythonTyper(
case p.VariableExpr(name="TypeVar"):
return self.define_typevar(expr) or UnknownType()
callee: Type = self.type_of(expr.callee)
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
match expr.callee:
case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj)
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method,
expr.location,
unfolded,
positional,
keywords,
)
callee: Type = self.type_of(expr.callee)
return (
self._get_call_result(
location=expr.location,
@@ -594,7 +640,7 @@ class PythonTyper(
return self.types.apply_generic(list_type, [item_type])
self.reporter.error(
expr.location,
f"Heterogeneous list items: {item_types}",
f"Heterogeneous list items: [{', '.join(map(str, item_types))}]",
)
return self.types.apply_generic(list_type, [UnknownType()])
@@ -626,7 +672,7 @@ class PythonTyper(
else:
self.reporter.error(
expr.location,
f"Heterogeneous dict keys: {key_types}",
f"Heterogeneous dict keys: [{', '.join(map(str, key_types))}]",
)
if len(value_types) == 1:
@@ -634,12 +680,19 @@ class PythonTyper(
else:
self.reporter.error(
expr.location,
f"Heterogeneous dict values: {value_types}",
f"Heterogeneous dict values: [{', '.join(map(str, value_types))}]",
)
return self.types.apply_generic(dict_type, [key_type, value_type])
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
object: Type = self.type_of(expr.object)
unfolded: Type = unfold_type(object)
match unfolded:
case TupleType():
return self._visit_tuple_subscript(unfolded, expr)
case DataFrameType():
return self._visit_frame_subscript(unfolded, expr)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
if operation is None:
self.reporter.error(
@@ -677,13 +730,26 @@ class PythonTyper(
self.reporter.warning(node.location, "ConstraintType not yet supported")
return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> Type:
self.reporter.warning(node.location, "FrameColumn not yet supported")
return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
return ColumnType(
type=(
self.resolve_type_expr(node.type)
if node.type is not None
else UnknownType()
)
)
def visit_frame_type(self, node: p.FrameType) -> Type:
self.reporter.warning(node.location, "FrameType not yet supported")
return UnknownType()
return DataFrameType(
columns=[
DataFrameType.Column(
index=i,
name=column.name,
type=self.visit_frame_column(column),
)
for i, column in enumerate(node.columns)
]
)
def _get_call_result(
self,
@@ -1216,3 +1282,23 @@ class PythonTyper(
expr.location, f"Cannot evaluate cast to {target_type} statically"
)
return False
def _visit_tuple_subscript(self, tup: TupleType, expr: p.SubscriptExpr) -> Type:
match expr.index:
case p.LiteralExpr(value=int() as index):
if index < 0 or index >= len(tup.items):
self.reporter.error(
expr.location, f"Index {index} out of range for tuple {tup}"
)
return UnknownType()
return tup.items[index]
case _:
self.reporter.error(
expr.location, f"Invalid index type {expr.index} on {tup}"
)
return UnknownType()
def _visit_frame_subscript(
self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type:
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)

View File

@@ -8,8 +8,10 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
@@ -157,6 +159,24 @@ class TypesRegistry:
return False
return True
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
# TODO: check order?
by_name1: dict[str, DataFrameType.Column] = {
col.name: col for col in columns1 if col.name is not None
}
for col2 in columns2:
if col2.name not in by_name1:
return False
if not self.is_subtype(by_name1[col2.name].type, col2.type):
return False
return True
case (ColumnType(type=inner1), ColumnType(type=inner2)):
# TODO: invariant, replace ColumnType with simple GenericType
if not self.are_equivalent(inner1, inner2):
return False
return True
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
@@ -187,6 +207,9 @@ class TypesRegistry:
return False
def are_equivalent(self, type1: Type, type2: Type) -> bool:
return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
# TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
"""Check whether a function is a subtype of another

View File

@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
case p.GetExpr():
target.accept(self)
case p.SubscriptExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Optional, assert_never
from typing import Optional, assert_never, cast
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
@@ -156,6 +156,37 @@ class ConstraintType:
return f"{self.type} where {printer.print(self.constraint)}"
@dataclass(frozen=True, kw_only=True)
class TupleType:
items: tuple[Type, ...]
def __str__(self) -> str:
return f"({', '.join(map(str, self.items))})"
@dataclass(frozen=True, kw_only=True)
class ColumnType:
type: Type
def __str__(self) -> str:
return f"Column[{self.type}]"
@dataclass(frozen=True, kw_only=True)
class DataFrameType:
columns: list[Column]
def __str__(self) -> str:
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
return f"Frame[{', '.join(schema)}]"
@dataclass(frozen=True, kw_only=True)
class Column:
index: int
name: Optional[str]
type: ColumnType
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -165,6 +196,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
required=arg.required,
)
def sub_column(col: DataFrameType.Column):
return DataFrameType.Column(
index=col.index,
name=col.name,
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
)
match type:
case TopType():
return type
@@ -250,10 +288,26 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions),
)
case TupleType(items=items):
return TupleType(
items=tuple(substitute_typevars(item, substitutions) for item in items),
)
case ColumnType(type=items_type):
return ColumnType(
type=substitute_typevars(items_type, substitutions),
)
case DataFrameType(columns=columns):
return DataFrameType(
columns=list(map(sub_column, columns)),
)
case UnknownType() | UnitType():
return type
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
@@ -317,6 +371,15 @@ def to_annotation(type: Type) -> str:
case ConstraintType():
return str(type)
case TupleType(items=items):
return f"Tuple[{', '.join(map(to_annotation, items))}]"
case ColumnType():
return "pd.Series"
case DataFrameType():
return "pd.DataFrame"
case _:
assert_never(type)
@@ -342,4 +405,7 @@ Type = (
| GenericType
| AppliedType
| ConstraintType
| TupleType
| ColumnType
| DataFrameType
)

View File

@@ -350,6 +350,14 @@ class MidasHighlighter(
for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self)
def visit_frame_type(self, type: m.FrameType) -> None:
self.wrap(type, "frame")
for column in type.columns:
self._visit_frame_column(column)
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
self.wrap(column, "column")
class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -68,7 +68,7 @@ class DiagnosticPrinter:
loc: Location = diagnostic.location
if loc.lineno != loc.end_lineno:
print(diagnostic)
self.print_multiline(lines, diagnostic, indent)
return
start_offset: int = loc.col_offset
@@ -95,3 +95,27 @@ class DiagnosticPrinter:
print(indent_str + before + subject + after)
print(indent_str + cursor)
print()
def print_multiline(
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
):
loc: Location = diagnostic.location
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
start_offset: int = loc.col_offset
end_offset: int = loc.end_col_offset or (start_offset + 1)
indent_str: str = " " * indent
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
res: str = indent_str + lines[0][:start_offset]
res += Ansi.FG(color) + lines[0][start_offset:]
for line in lines[1:-1]:
res += "\n" + indent_str + line
res += "\n" + indent_str + lines[-1][:end_offset]
res += Ansi.RESET + lines[-1][end_offset:]
print(diagnostic.location_str + ":")
print(res)
print()
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
print()

View File

@@ -1,4 +1,5 @@
import ast
import logging
import shutil
from dataclasses import dataclass, field
from pathlib import Path
@@ -13,13 +14,16 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
@@ -40,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
self.rel_src_path: Path = Path()
self.logger: logging.Logger = logging.getLogger("Generator")
self._typed_ast: TypedAST = TypedAST(
stmts=[],
@@ -332,6 +337,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
self._make_cast_asserts(src_location, item, item_type)
case (
TopType()
| Function()
@@ -339,8 +357,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
| ComplexType()
| ExtensionType()
| GenericType()
| ColumnType()
| DataFrameType()
):
raise NotImplementedError(f"Can't make assertion for type {type}")
self.logger.warning(f"Can't make assertion for type {type}")
# Ensure exhaustiveness
case _:

View File

@@ -7,13 +7,16 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
@@ -30,6 +33,7 @@ class StubsGenerator:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.import_pandas: bool = False
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
@@ -38,6 +42,7 @@ class StubsGenerator:
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
self.import_pandas = False
for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type
@@ -53,7 +58,7 @@ class StubsGenerator:
continue
self.generate_stub(name, type)
imports = [
imports: list[ast.stmt] = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
@@ -70,6 +75,17 @@ class StubsGenerator:
level=0,
)
)
if self.import_pandas:
imports.append(
ast.Import(
names=[
ast.alias(
name="pandas",
asname="pd",
)
],
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
@@ -231,6 +247,31 @@ class StubsGenerator:
case ConstraintType():
return self.dump_type(type.type)
case TupleType(items=items):
return ast.Subscript(
value=ast.Name(id="tuple"),
slice=ast.Tuple(
elts=[self.dump_type(item) for item in items],
),
)
case ColumnType(type=inner):
self.import_pandas = True
return ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
slice=self.dump_type(inner),
)
case DataFrameType():
self.import_pandas = True
return ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
)
case _:
assert_never(type)

View File

@@ -9,6 +9,7 @@ from midas.ast.midas import (
Expr,
ExtendStmt,
ExtensionType,
FrameType,
FunctionType,
GenericType,
GetExpr,
@@ -204,8 +205,10 @@ class MidasParser(Parser):
return self.generic_type()
def generic_type(self) -> Type:
type: Type = self.named_type()
type: NamedType = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
if type.name.lexeme == "Frame":
return self.frame_type()
args: list[Type] = self.type_args()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
@@ -224,7 +227,7 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
return args
def named_type(self) -> Type:
def named_type(self) -> NamedType:
name: Token = self.consume_identifier("Expected type name")
return NamedType(
location=name.get_location(),
@@ -259,6 +262,32 @@ class MidasParser(Parser):
members=members,
)
def frame_type(self) -> FrameType:
keyword: Token = self.previous()
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
columns: list[FrameType.Column] = []
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
name: Token = self.advance()
self.consume(TokenType.COLON, "Expected ':' between column name and type")
type: Type = self.type_expr()
columns.append(
FrameType.Column(
location=name.location_to(self.previous()),
name=name,
type=type,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
return FrameType(
location=keyword.location_to(self.previous()),
columns=columns,
)
def constraint(self) -> Expr:
"""Parse a constraint

View File

@@ -1,3 +1,4 @@
from typing import Generic, TypeVar
from typing import cast as typing_cast
cast = typing_cast
@@ -32,3 +33,20 @@ This operation is unsound, use at your own risk!
_**Internal Python documentation**_
"""
T = TypeVar("T")
class Frame(Generic[T]):
"""A `Frame` is the abstract type implemented by `DataFrame`
A frame contains any number of named columns (see :class:`Column`)
"""
class Column(Generic[T]):
"""A `Column` is the abstract type implemented by `Series`
A column contains a any number of values of the same type
"""