feat(checker): handle comparisons

This commit is contained in:
2026-06-01 14:12:22 +02:00
parent ab0fa1de1a
commit 9d45163d9c
2 changed files with 64 additions and 16 deletions

View File

@@ -8,7 +8,7 @@ import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment
from midas.checker.operators import OPERATOR_METHODS
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
@@ -294,7 +294,23 @@ class Checker(
return UnknownType()
return result
def visit_compare_expr(self, expr: p.CompareExpr) -> Type: ...
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType()
left: Type = self.evaluate(expr.left)
right: Type = self.evaluate(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
self.error(
expr.location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return result
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...

View File

@@ -2,12 +2,20 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from midas.checker.types import BaseType, Type
from midas.checker.types import BaseType, Type, UnitType
if TYPE_CHECKING:
from midas.resolver.midas import MidasResolver
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
ctx.define_operation(
left=t1,
operator=operator,
right=t2,
result=t3,
)
def basic_op(ctx: MidasResolver, type: Type, op: str):
ctx.define_operation(
left=type,
@@ -19,21 +27,45 @@ def basic_op(ctx: MidasResolver, type: Type, op: str):
def define_builtins(ctx: MidasResolver):
"""Define builtin types and operations"""
unit = ctx.define_type("None", UnitType())
bool = ctx.define_type("bool", BaseType(name="bool"))
int = ctx.define_type("int", BaseType(name="int"))
float = ctx.define_type("float", BaseType(name="float"))
str = ctx.define_type("str", BaseType(name="str"))
basic_op(ctx, int, "__add__")
basic_op(ctx, int, "__sub__")
basic_op(ctx, int, "__mul__")
basic_op(ctx, int, "__pow__")
basic_op(ctx, int, "__mod__")
basic_op(ctx, int, "__and__")
basic_op(ctx, int, "__or__")
basic_op(ctx, int, "__xor__")
basic_op(ctx, float, "__add__")
basic_op(ctx, float, "__sub__")
basic_op(ctx, float, "__mul__")
basic_op(ctx, float, "__truediv__")
basic_op(ctx, str, "__add__")
basic_op(ctx, int, "__add__") # int + int = int
basic_op(ctx, int, "__sub__") # int - int = int
basic_op(ctx, int, "__mul__") # int * int = int
basic_op(ctx, int, "__pow__") # int ** int = int
basic_op(ctx, int, "__mod__") # int % int = int
basic_op(ctx, int, "__and__") # int & int = int
basic_op(ctx, int, "__or__") # int | int = int
basic_op(ctx, int, "__xor__") # int ^ int = int
op(ctx, int, "__lt__", int, bool) # int < int = bool
op(ctx, int, "__gt__", int, bool) # int > int = bool
op(ctx, int, "__le__", int, bool) # int <= int = bool
op(ctx, int, "__ge__", int, bool) # int >= int = bool
op(ctx, int, "__eq__", int, bool) # int == int = bool
basic_op(ctx, float, "__add__") # float + float = float
basic_op(ctx, float, "__sub__") # float - float = float
basic_op(ctx, float, "__mul__") # float * float = float
basic_op(ctx, float, "__truediv__") # float / float = float
op(ctx, float, "__lt__", float, bool) # float < float = bool
op(ctx, float, "__gt__", float, bool) # float > float = bool
op(ctx, float, "__le__", float, bool) # float <= float = bool
op(ctx, float, "__ge__", float, bool) # float >= float = bool
op(ctx, float, "__eq__", float, bool) # float == float = bool
basic_op(ctx, str, "__add__") # str + str = str
op(ctx, str, "__eq__", str, bool) # str == str = bool
op(ctx, int, "__lt__", float, bool) # int < float = bool
op(ctx, int, "__gt__", float, bool) # int > float = bool
op(ctx, int, "__le__", float, bool) # int <= float = bool
op(ctx, int, "__ge__", float, bool) # int >= float = bool
op(ctx, int, "__eq__", float, bool) # int == float = bool
op(ctx, float, "__lt__", int, bool) # float < int = bool
op(ctx, float, "__gt__", int, bool) # float > int = bool
op(ctx, float, "__le__", int, bool) # float <= int = bool
op(ctx, float, "__ge__", int, bool) # float >= int = bool
op(ctx, float, "__eq__", int, bool) # float == int = bool