From e6deceee80fb330850336b90b51753bceb476b73 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 15 Mar 2024 10:59:26 +0100
Subject: [PATCH] typification of PsConditional

---
 .../backend/kernelcreation/typification.py    | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index d3f0b0331..259821afa 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -10,9 +10,17 @@ from ...types import (
     PsIntegerType,
     PsArrayType,
     PsSubscriptableType,
+    PsBoolType,
     deconstify,
 )
-from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment
+from ..ast.structural import (
+    PsAstNode,
+    PsBlock,
+    PsLoop,
+    PsConditional,
+    PsExpression,
+    PsAssignment,
+)
 from ..ast.expressions import (
     PsSymbolExpr,
     PsConstantExpr,
@@ -162,6 +170,15 @@ class Typifier:
                 assert tc.target_type is not None
                 self.visit_expr(rhs, tc)
 
+            case PsConditional(cond, branch_true, branch_false):
+                cond_tc = TypeContext(PsBoolType(const=True))
+                self.visit_expr(cond, cond_tc)
+
+                self.visit(branch_true)
+
+                if branch_false is not None:
+                    self.visit(branch_false)
+
             case PsLoop(ctr, start, stop, step, body):
                 if ctr.symbol.dtype is None:
                     ctr.symbol.apply_dtype(self._ctx.index_dtype)
-- 
GitLab