diff --git a/midas/checker/midas.py b/midas/checker/midas.py index fabdd59..5e117c2 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -10,6 +10,7 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( AliasType, ComplexType, + ConstraintType, ExtensionType, Function, GenericType, @@ -184,10 +185,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type return UnknownType() def visit_constraint_type(self, type: m.ConstraintType) -> Type: - type_: Type = type.type.accept(self) - type.constraint.accept(self) - # TODO - return UnknownType() + return ConstraintType( + type=type.type.accept(self), + constraint=type.constraint, + ) def visit_complex_type(self, type: m.ComplexType) -> ComplexType: return ComplexType( diff --git a/midas/checker/types.py b/midas/checker/types.py index c16db79..f66468b 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -3,6 +3,9 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Optional, assert_never +import midas.ast.midas as m +from midas.ast.printer import MidasPrinter + @dataclass(frozen=True, kw_only=True) class TopType: @@ -130,6 +133,16 @@ class AppliedType: return f"{self.name}[{', '.join(map(str, self.args))}]" +@dataclass(frozen=True, kw_only=True) +class ConstraintType: + type: Type + constraint: m.Expr + + def __str__(self) -> str: + printer = MidasPrinter() + return f"{self.type} where {printer.print(self.constraint)}" + + def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: def sub_argument(arg: Function.Argument): return Function.Argument( @@ -195,6 +208,12 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: body=substitute_typevars(body, substitutions), ) + case ConstraintType(): + return ConstraintType( + type=substitute_typevars(type.type, substitutions), + constraint=type.constraint, + ) + case TypeVar(name=name): if name in substitutions: return substitutions[name] @@ -238,4 +257,5 @@ Type = ( | TypeVar | GenericType | AppliedType + | ConstraintType ) diff --git a/midas/generator/generator.py b/midas/generator/generator.py index c7b1323..67e11e9 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Optional, assert_never +import midas.ast.midas as m import midas.ast.python as p from midas.ast.location import Location from midas.checker.types import ( @@ -11,6 +12,7 @@ from midas.checker.types import ( AppliedType, BaseType, ComplexType, + ConstraintType, ExtensionType, Function, GenericType, @@ -308,6 +310,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): case AppliedType(body=body): self._make_cast_asserts(src_location, expr, body) + case ConstraintType(type=base, constraint=constraint): + self._make_cast_asserts(src_location, expr, base) + self._make_constraint_assert(src_location, expr, constraint) + case TypeVar(): raise RuntimeError("Unexpected TypeVar") @@ -347,3 +353,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): ast.Constant(f" to {type}"), ] ) + + def _make_constraint_assert( + self, src_location: Location, expr: ast.expr, constraint: m.Expr + ): + # TODO + pass