feat(types): add ConstraintType
This commit is contained in:
@@ -10,6 +10,7 @@ from midas.checker.reporter import FileReporter, Reporter
|
|||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
@@ -184,10 +185,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||||
type_: Type = type.type.accept(self)
|
return ConstraintType(
|
||||||
type.constraint.accept(self)
|
type=type.type.accept(self),
|
||||||
# TODO
|
constraint=type.constraint,
|
||||||
return UnknownType()
|
)
|
||||||
|
|
||||||
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
||||||
return ComplexType(
|
return ComplexType(
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, assert_never
|
from typing import Optional, assert_never
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class TopType:
|
class TopType:
|
||||||
@@ -130,6 +133,16 @@ class AppliedType:
|
|||||||
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ConstraintType:
|
||||||
|
type: Type
|
||||||
|
constraint: m.Expr
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
printer = MidasPrinter()
|
||||||
|
return f"{self.type} where {printer.print(self.constraint)}"
|
||||||
|
|
||||||
|
|
||||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||||
def sub_argument(arg: Function.Argument):
|
def sub_argument(arg: Function.Argument):
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -195,6 +208,12 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ConstraintType():
|
||||||
|
return ConstraintType(
|
||||||
|
type=substitute_typevars(type.type, substitutions),
|
||||||
|
constraint=type.constraint,
|
||||||
|
)
|
||||||
|
|
||||||
case TypeVar(name=name):
|
case TypeVar(name=name):
|
||||||
if name in substitutions:
|
if name in substitutions:
|
||||||
return substitutions[name]
|
return substitutions[name]
|
||||||
@@ -238,4 +257,5 @@ Type = (
|
|||||||
| TypeVar
|
| TypeVar
|
||||||
| GenericType
|
| GenericType
|
||||||
| AppliedType
|
| AppliedType
|
||||||
|
| ConstraintType
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, assert_never
|
from typing import Optional, assert_never
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
@@ -11,6 +12,7 @@ from midas.checker.types import (
|
|||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
@@ -308,6 +310,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
case AppliedType(body=body):
|
case AppliedType(body=body):
|
||||||
self._make_cast_asserts(src_location, expr, body)
|
self._make_cast_asserts(src_location, expr, body)
|
||||||
|
|
||||||
|
case ConstraintType(type=base, constraint=constraint):
|
||||||
|
self._make_cast_asserts(src_location, expr, base)
|
||||||
|
self._make_constraint_assert(src_location, expr, constraint)
|
||||||
|
|
||||||
case TypeVar():
|
case TypeVar():
|
||||||
raise RuntimeError("Unexpected TypeVar")
|
raise RuntimeError("Unexpected TypeVar")
|
||||||
|
|
||||||
@@ -347,3 +353,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
ast.Constant(f" to {type}"),
|
ast.Constant(f" to {type}"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _make_constraint_assert(
|
||||||
|
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||||
|
):
|
||||||
|
# TODO
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user