feat(checker): add environment preamble

this adds some builtin functions such as the builtin type constructors
This commit is contained in:
2026-06-16 14:02:45 +02:00
parent c4062c9595
commit 732f7b0796
3 changed files with 127 additions and 5 deletions

121
midas/checker/preamble.py Normal file
View File

@@ -0,0 +1,121 @@
from dataclasses import dataclass
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
from midas.checker.types import Function, GenericType, TopType, Type, TypeVar, UnitType
@dataclass(frozen=True)
class Param:
name: str
type: Type
required: bool = True
class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._def_type_constructor("object")
self._def_type_constructor("float")
self._def_type_constructor("int")
self._def_type_constructor("bool")
self._def_type_constructor("str")
self._def_function(
name="list",
pos=[Param("object", TopType())],
returns=self._list_of(TopType()),
)
# TODO: use sink
self._def_function(
name="print",
pos=[Param("object", TopType())],
returns=UnitType(),
)
map_in = TypeVar(name="T", bound=None)
map_out = TypeVar(name="U", bound=None)
mapper = self._make_function(
name="MapTransform",
pos=[Param("v", map_in)],
returns=map_out,
)
self._def_function(
name="map",
pos=[
Param("transform", mapper),
Param(
"iterable",
self._list_of(map_in), # TODO: replace with Iterable[T]
),
],
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str):
# TODO: more specific arg types
self._def_function(
name=name,
pos=[Param("object", TopType())],
returns=self._types.get_type(name),
)
def _make_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
) -> Type:
def map_args(params: list[Param], offset: int) -> list[Function.Argument]:
return [
Function.Argument(
pos=i + offset,
name=param.name,
type=param.type,
required=param.required,
)
for i, param in enumerate(params)
]
function = Function(
pos_args=map_args(pos, 0),
args=map_args(mixed, len(pos)),
kw_args=map_args(kw, len(pos) + len(mixed)),
returns=returns,
)
if len(type_vars) != 0:
function = GenericType(
name=name,
params=type_vars,
body=function,
)
return function
def _def_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
):
function: Type = self._make_function(
name=name,
pos=pos,
mixed=mixed,
kw=kw,
returns=returns,
type_vars=type_vars,
)
self.define(name, function)

View File

@@ -7,6 +7,7 @@ import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
@@ -56,7 +57,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper") self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.global_env: Environment = Environment() self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = [] self.judgements: list[tuple[p.Expr, Type]] = []

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@@ -41,9 +41,9 @@ class UnitType:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Function: class Function:
pos_args: list[Argument] pos_args: list[Argument] = field(default_factory=list)
args: list[Argument] args: list[Argument] = field(default_factory=list)
kw_args: list[Argument] kw_args: list[Argument] = field(default_factory=list)
returns: Type returns: Type
def __str__(self) -> str: def __str__(self) -> str: