feat(checker): handle comparisons

This commit is contained in:
2026-06-01 14:12:22 +02:00
parent ab0fa1de1a
commit 9d45163d9c
2 changed files with 64 additions and 16 deletions

View File

@@ -8,7 +8,7 @@ import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import OPERATOR_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token from midas.lexer.token import Token
@@ -294,7 +294,23 @@ class Checker(
return UnknownType() return UnknownType()
return result return result
def visit_compare_expr(self, expr: p.CompareExpr) -> Type: ... def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType()
left: Type = self.evaluate(expr.left)
right: Type = self.evaluate(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
self.error(
expr.location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return result
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...

View File

@@ -2,12 +2,20 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from midas.checker.types import BaseType, Type from midas.checker.types import BaseType, Type, UnitType
if TYPE_CHECKING: if TYPE_CHECKING:
from midas.resolver.midas import MidasResolver from midas.resolver.midas import MidasResolver
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
ctx.define_operation(
left=t1,
operator=operator,
right=t2,
result=t3,
)
def basic_op(ctx: MidasResolver, type: Type, op: str): def basic_op(ctx: MidasResolver, type: Type, op: str):
ctx.define_operation( ctx.define_operation(
left=type, left=type,
@@ -19,21 +27,45 @@ def basic_op(ctx: MidasResolver, type: Type, op: str):
def define_builtins(ctx: MidasResolver): def define_builtins(ctx: MidasResolver):
"""Define builtin types and operations""" """Define builtin types and operations"""
unit = ctx.define_type("None", UnitType())
bool = ctx.define_type("bool", BaseType(name="bool")) bool = ctx.define_type("bool", BaseType(name="bool"))
int = ctx.define_type("int", BaseType(name="int")) int = ctx.define_type("int", BaseType(name="int"))
float = ctx.define_type("float", BaseType(name="float")) float = ctx.define_type("float", BaseType(name="float"))
str = ctx.define_type("str", BaseType(name="str")) str = ctx.define_type("str", BaseType(name="str"))
basic_op(ctx, int, "__add__") basic_op(ctx, int, "__add__") # int + int = int
basic_op(ctx, int, "__sub__") basic_op(ctx, int, "__sub__") # int - int = int
basic_op(ctx, int, "__mul__") basic_op(ctx, int, "__mul__") # int * int = int
basic_op(ctx, int, "__pow__") basic_op(ctx, int, "__pow__") # int ** int = int
basic_op(ctx, int, "__mod__") basic_op(ctx, int, "__mod__") # int % int = int
basic_op(ctx, int, "__and__") basic_op(ctx, int, "__and__") # int & int = int
basic_op(ctx, int, "__or__") basic_op(ctx, int, "__or__") # int | int = int
basic_op(ctx, int, "__xor__") basic_op(ctx, int, "__xor__") # int ^ int = int
basic_op(ctx, float, "__add__") op(ctx, int, "__lt__", int, bool) # int < int = bool
basic_op(ctx, float, "__sub__") op(ctx, int, "__gt__", int, bool) # int > int = bool
basic_op(ctx, float, "__mul__") op(ctx, int, "__le__", int, bool) # int <= int = bool
basic_op(ctx, float, "__truediv__") op(ctx, int, "__ge__", int, bool) # int >= int = bool
basic_op(ctx, str, "__add__") op(ctx, int, "__eq__", int, bool) # int == int = bool
basic_op(ctx, float, "__add__") # float + float = float
basic_op(ctx, float, "__sub__") # float - float = float
basic_op(ctx, float, "__mul__") # float * float = float
basic_op(ctx, float, "__truediv__") # float / float = float
op(ctx, float, "__lt__", float, bool) # float < float = bool
op(ctx, float, "__gt__", float, bool) # float > float = bool
op(ctx, float, "__le__", float, bool) # float <= float = bool
op(ctx, float, "__ge__", float, bool) # float >= float = bool
op(ctx, float, "__eq__", float, bool) # float == float = bool
basic_op(ctx, str, "__add__") # str + str = str
op(ctx, str, "__eq__", str, bool) # str == str = bool
op(ctx, int, "__lt__", float, bool) # int < float = bool
op(ctx, int, "__gt__", float, bool) # int > float = bool
op(ctx, int, "__le__", float, bool) # int <= float = bool
op(ctx, int, "__ge__", float, bool) # int >= float = bool
op(ctx, int, "__eq__", float, bool) # int == float = bool
op(ctx, float, "__lt__", int, bool) # float < int = bool
op(ctx, float, "__gt__", int, bool) # float > int = bool
op(ctx, float, "__le__", int, bool) # float <= int = bool
op(ctx, float, "__ge__", int, bool) # float >= int = bool
op(ctx, float, "__eq__", int, bool) # float == int = bool