From 0cec1fa81c131b4fbc43d582ac552c8b70d2bc2f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 8 Mar 2024 18:19:24 +0100 Subject: [PATCH] various fixes to bugs encountered during waLBerla integration --- .../backend/kernelcreation/context.py | 2 ++ .../backend/kernelcreation/freeze.py | 4 +++ .../backend/kernelcreation/iteration_space.py | 25 +++++++++++-------- .../backend/platforms/generic_cpu.py | 10 +++++++- src/pystencils/types/__init__.py | 3 ++- src/pystencils/types/basic_types.py | 12 +++++++++ 6 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 8c2b34fc6..22286156e 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -25,6 +25,8 @@ class FieldsInKernel: self.custom_fields: set[Field] = set() self.buffer_fields: set[Field] = set() + self.archetype_field: Field | None = None + def __iter__(self) -> Iterator: return chain( self.domain_fields, diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 63dc9170a..a13f21ae2 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -182,6 +182,10 @@ class FreezeExpressions: def map_Integer(self, expr: sp.Integer) -> PsConstantExpr: value = int(expr) return PsConstantExpr(PsConstant(value)) + + def map_Float(self, expr: sp.Float) -> PsConstantExpr: + value = float(expr) # TODO: check accuracy of evaluation + return PsConstantExpr(PsConstant(value)) def map_Rational(self, expr: sp.Rational) -> PsExpression: num = PsConstantExpr(PsConstant(expr.numerator)) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index e5b586688..3a7143bc2 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -114,11 +114,7 @@ class FullIterationSpace(IterationSpace): ) ] - # Determine loop order by permuting dimensions - loop_order = archetype_field.layout - dimensions = [dimensions[coordinate] for coordinate in loop_order] - - return FullIterationSpace(ctx, dimensions) + return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field) @staticmethod def create_from_slice( @@ -176,18 +172,21 @@ class FullIterationSpace(IterationSpace): ) ] - # Determine loop order by permuting dimensions - loop_order = archetype_field.layout - dimensions = [dimensions[coordinate] for coordinate in loop_order] - - return FullIterationSpace(ctx, dimensions) + return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field) - def __init__(self, ctx: KernelCreationContext, dimensions: Sequence[Dimension]): + def __init__( + self, + ctx: KernelCreationContext, + dimensions: Sequence[Dimension], + archetype_field: Field | None = None, + ): super().__init__(tuple(dim.counter for dim in dimensions)) self._ctx = ctx self._dimensions = dimensions + self._archetype_field = archetype_field + @property def dimensions(self): return self._dimensions @@ -204,6 +203,10 @@ class FullIterationSpace(IterationSpace): def steps(self): return (dim.step for dim in self._dimensions) + @property + def archetype_field(self) -> Field | None: + return self._archetype_field + def actual_iterations(self, dimension: int | None = None) -> PsExpression: if dimension is None: return reduce( diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index d512153a6..6b19a88ff 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -6,7 +6,7 @@ from .platform import Platform from ..kernelcreation.iteration_space import ( IterationSpace, FullIterationSpace, - SparseIterationSpace, + SparseIterationSpace ) from ..constants import PsConstant @@ -43,7 +43,15 @@ class GenericCpu(Platform): def _create_domain_loops( self, body: PsBlock, ispace: FullIterationSpace ) -> PsBlock: + dimensions = ispace.dimensions + + # Determine loop order by permuting dimensions + archetype_field = ispace.archetype_field + if archetype_field is not None: + loop_order = archetype_field.layout + dimensions = [dimensions[coordinate] for coordinate in loop_order] + outer_block = body for dimension in dimensions[::-1]: diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py index 7096bf6c2..0c475c8ac 100644 --- a/src/pystencils/types/__init__.py +++ b/src/pystencils/types/__init__.py @@ -15,7 +15,7 @@ from .basic_types import ( deconstify, ) -from .quick import create_type, create_numeric_type +from .quick import UserTypeSpec, create_type, create_numeric_type from .exception import PsTypeError @@ -34,6 +34,7 @@ __all__ = [ "PsIeeeFloatType", "constify", "deconstify", + "UserTypeSpec", "create_type", "create_numeric_type", "PsTypeError", diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py index eb373bf3d..55055205f 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/basic_types.py @@ -463,6 +463,18 @@ class PsBoolType(PsScalarType): return np.False_ else: raise PsTypeError(f"Cannot create boolean constant from value {value}") + + def c_string(self) -> str: + return "bool" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsBoolType): + return False + + return self._base_equal(other) + + def __hash__(self) -> int: + return hash(("PsBoolType", self._const)) class PsIntegerType(PsScalarType, ABC): -- GitLab