feat(checker): handle ternary expression

This commit is contained in:
2026-06-01 15:02:12 +02:00
parent 55060bfecd
commit bea3f399ad
4 changed files with 27 additions and 2 deletions

View File

@@ -12,3 +12,5 @@ def factorial(n: int) -> int:
if n <= 1:
return 1
return n * factorial(n - 1)
category = "Category 1" if a < 10 else "Category 2"

View File

@@ -9,7 +9,7 @@ from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.checker.types import BaseType, Function, SimpleType, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@@ -405,6 +405,22 @@ class Checker(
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return expr.type.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = expr.test.accept(self)
# TODO Allow subtypes or any type
if test_type != self.ctx.get_type("bool"):
self.error(
expr.test.location, f"If test must be a boolean, got {test_type}"
)
true_type: Type = expr.if_true.accept(self)
false_type: Type = expr.if_false.accept(self)
if true_type != false_type:
self.error(expr.location, f"Type mismatch in ternary if branches: true={true_type} != false={false_type}")
return UnknownType()
return true_type
def visit_base_type(self, node: p.BaseType) -> Type:
return self.ctx.get_type(node.base)

View File

@@ -203,6 +203,8 @@ class PythonHighlighter(
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"

View File

@@ -180,3 +180,8 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self.resolve(expr.expr)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self.resolve(expr.test)
self.resolve(expr.if_true)
self.resolve(expr.if_false)