diff --git a/midas/checker/checker.py b/midas/checker/checker.py index 4ebedd5..cec5c80 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -8,7 +8,7 @@ import midas.ast.python as p from midas.ast.location import Location from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.environment import Environment -from midas.checker.operators import OPERATOR_METHODS +from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS from midas.checker.types import Function, Type, UnitType, UnknownType from midas.lexer.midas import MidasLexer from midas.lexer.token import Token @@ -294,7 +294,23 @@ class Checker( return UnknownType() return result - def visit_compare_expr(self, expr: p.CompareExpr) -> Type: ... + def visit_compare_expr(self, expr: p.CompareExpr) -> Type: + method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.warning(expr.location, f"Unsupported operator {expr.operator}") + return UnknownType() + left: Type = self.evaluate(expr.left) + right: Type = self.evaluate(expr.right) + + result: Optional[Type] = self.ctx.get_operation_result(left, method, right) + if result is None: + self.error( + expr.location, + f"Undefined operation {method} between {left} and {right}", + ) + return UnknownType() + return result def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... diff --git a/midas/resolver/builtin.py b/midas/resolver/builtin.py index 0ceae90..6a00b14 100644 --- a/midas/resolver/builtin.py +++ b/midas/resolver/builtin.py @@ -2,12 +2,20 @@ from __future__ import annotations from typing import TYPE_CHECKING -from midas.checker.types import BaseType, Type +from midas.checker.types import BaseType, Type, UnitType if TYPE_CHECKING: from midas.resolver.midas import MidasResolver +def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type): + ctx.define_operation( + left=t1, + operator=operator, + right=t2, + result=t3, + ) + def basic_op(ctx: MidasResolver, type: Type, op: str): ctx.define_operation( left=type, @@ -19,21 +27,45 @@ def basic_op(ctx: MidasResolver, type: Type, op: str): def define_builtins(ctx: MidasResolver): """Define builtin types and operations""" + unit = ctx.define_type("None", UnitType()) bool = ctx.define_type("bool", BaseType(name="bool")) int = ctx.define_type("int", BaseType(name="int")) float = ctx.define_type("float", BaseType(name="float")) str = ctx.define_type("str", BaseType(name="str")) - basic_op(ctx, int, "__add__") - basic_op(ctx, int, "__sub__") - basic_op(ctx, int, "__mul__") - basic_op(ctx, int, "__pow__") - basic_op(ctx, int, "__mod__") - basic_op(ctx, int, "__and__") - basic_op(ctx, int, "__or__") - basic_op(ctx, int, "__xor__") - basic_op(ctx, float, "__add__") - basic_op(ctx, float, "__sub__") - basic_op(ctx, float, "__mul__") - basic_op(ctx, float, "__truediv__") - basic_op(ctx, str, "__add__") + basic_op(ctx, int, "__add__") # int + int = int + basic_op(ctx, int, "__sub__") # int - int = int + basic_op(ctx, int, "__mul__") # int * int = int + basic_op(ctx, int, "__pow__") # int ** int = int + basic_op(ctx, int, "__mod__") # int % int = int + basic_op(ctx, int, "__and__") # int & int = int + basic_op(ctx, int, "__or__") # int | int = int + basic_op(ctx, int, "__xor__") # int ^ int = int + op(ctx, int, "__lt__", int, bool) # int < int = bool + op(ctx, int, "__gt__", int, bool) # int > int = bool + op(ctx, int, "__le__", int, bool) # int <= int = bool + op(ctx, int, "__ge__", int, bool) # int >= int = bool + op(ctx, int, "__eq__", int, bool) # int == int = bool + basic_op(ctx, float, "__add__") # float + float = float + basic_op(ctx, float, "__sub__") # float - float = float + basic_op(ctx, float, "__mul__") # float * float = float + basic_op(ctx, float, "__truediv__") # float / float = float + op(ctx, float, "__lt__", float, bool) # float < float = bool + op(ctx, float, "__gt__", float, bool) # float > float = bool + op(ctx, float, "__le__", float, bool) # float <= float = bool + op(ctx, float, "__ge__", float, bool) # float >= float = bool + op(ctx, float, "__eq__", float, bool) # float == float = bool + basic_op(ctx, str, "__add__") # str + str = str + op(ctx, str, "__eq__", str, bool) # str == str = bool + + op(ctx, int, "__lt__", float, bool) # int < float = bool + op(ctx, int, "__gt__", float, bool) # int > float = bool + op(ctx, int, "__le__", float, bool) # int <= float = bool + op(ctx, int, "__ge__", float, bool) # int >= float = bool + op(ctx, int, "__eq__", float, bool) # int == float = bool + + op(ctx, float, "__lt__", int, bool) # float < int = bool + op(ctx, float, "__gt__", int, bool) # float > int = bool + op(ctx, float, "__le__", int, bool) # float <= int = bool + op(ctx, float, "__ge__", int, bool) # float >= int = bool + op(ctx, float, "__eq__", int, bool) # float == int = bool \ No newline at end of file