From 0cec1fa81c131b4fbc43d582ac552c8b70d2bc2f Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 8 Mar 2024 18:19:24 +0100
Subject: [PATCH] various fixes to bugs encountered during waLBerla integration

---
 .../backend/kernelcreation/context.py         |  2 ++
 .../backend/kernelcreation/freeze.py          |  4 +++
 .../backend/kernelcreation/iteration_space.py | 25 +++++++++++--------
 .../backend/platforms/generic_cpu.py          | 10 +++++++-
 src/pystencils/types/__init__.py              |  3 ++-
 src/pystencils/types/basic_types.py           | 12 +++++++++
 6 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 8c2b34fc6..22286156e 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -25,6 +25,8 @@ class FieldsInKernel:
         self.custom_fields: set[Field] = set()
         self.buffer_fields: set[Field] = set()
 
+        self.archetype_field: Field | None = None
+
     def __iter__(self) -> Iterator:
         return chain(
             self.domain_fields,
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 63dc9170a..a13f21ae2 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -182,6 +182,10 @@ class FreezeExpressions:
     def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
         value = int(expr)
         return PsConstantExpr(PsConstant(value))
+    
+    def map_Float(self, expr: sp.Float) -> PsConstantExpr:
+        value = float(expr)  # TODO: check accuracy of evaluation
+        return PsConstantExpr(PsConstant(value))
 
     def map_Rational(self, expr: sp.Rational) -> PsExpression:
         num = PsConstantExpr(PsConstant(expr.numerator))
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index e5b586688..3a7143bc2 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -114,11 +114,7 @@ class FullIterationSpace(IterationSpace):
             )
         ]
 
-        #   Determine loop order by permuting dimensions
-        loop_order = archetype_field.layout
-        dimensions = [dimensions[coordinate] for coordinate in loop_order]
-
-        return FullIterationSpace(ctx, dimensions)
+        return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field)
 
     @staticmethod
     def create_from_slice(
@@ -176,18 +172,21 @@ class FullIterationSpace(IterationSpace):
             )
         ]
 
-        #   Determine loop order by permuting dimensions
-        loop_order = archetype_field.layout
-        dimensions = [dimensions[coordinate] for coordinate in loop_order]
-
-        return FullIterationSpace(ctx, dimensions)
+        return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field)
 
-    def __init__(self, ctx: KernelCreationContext, dimensions: Sequence[Dimension]):
+    def __init__(
+        self,
+        ctx: KernelCreationContext,
+        dimensions: Sequence[Dimension],
+        archetype_field: Field | None = None,
+    ):
         super().__init__(tuple(dim.counter for dim in dimensions))
 
         self._ctx = ctx
         self._dimensions = dimensions
 
+        self._archetype_field = archetype_field
+
     @property
     def dimensions(self):
         return self._dimensions
@@ -204,6 +203,10 @@ class FullIterationSpace(IterationSpace):
     def steps(self):
         return (dim.step for dim in self._dimensions)
 
+    @property
+    def archetype_field(self) -> Field | None:
+        return self._archetype_field
+
     def actual_iterations(self, dimension: int | None = None) -> PsExpression:
         if dimension is None:
             return reduce(
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index d512153a6..6b19a88ff 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -6,7 +6,7 @@ from .platform import Platform
 from ..kernelcreation.iteration_space import (
     IterationSpace,
     FullIterationSpace,
-    SparseIterationSpace,
+    SparseIterationSpace
 )
 
 from ..constants import PsConstant
@@ -43,7 +43,15 @@ class GenericCpu(Platform):
     def _create_domain_loops(
         self, body: PsBlock, ispace: FullIterationSpace
     ) -> PsBlock:
+        
         dimensions = ispace.dimensions
+
+        #   Determine loop order by permuting dimensions
+        archetype_field = ispace.archetype_field
+        if archetype_field is not None:
+            loop_order = archetype_field.layout
+            dimensions = [dimensions[coordinate] for coordinate in loop_order]
+
         outer_block = body
 
         for dimension in dimensions[::-1]:
diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py
index 7096bf6c2..0c475c8ac 100644
--- a/src/pystencils/types/__init__.py
+++ b/src/pystencils/types/__init__.py
@@ -15,7 +15,7 @@ from .basic_types import (
     deconstify,
 )
 
-from .quick import create_type, create_numeric_type
+from .quick import UserTypeSpec, create_type, create_numeric_type
 
 from .exception import PsTypeError
 
@@ -34,6 +34,7 @@ __all__ = [
     "PsIeeeFloatType",
     "constify",
     "deconstify",
+    "UserTypeSpec",
     "create_type",
     "create_numeric_type",
     "PsTypeError",
diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py
index eb373bf3d..55055205f 100644
--- a/src/pystencils/types/basic_types.py
+++ b/src/pystencils/types/basic_types.py
@@ -463,6 +463,18 @@ class PsBoolType(PsScalarType):
             return np.False_
         else:
             raise PsTypeError(f"Cannot create boolean constant from value {value}")
+        
+    def c_string(self) -> str:
+        return "bool"
+    
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsBoolType):
+            return False
+        
+        return self._base_equal(other)
+    
+    def __hash__(self) -> int:
+        return hash(("PsBoolType", self._const))
 
 
 class PsIntegerType(PsScalarType, ABC):
-- 
GitLab