diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 8c2b34fc65716c298d6dd15a65d0d660b305b8c4..22286156e8e4bd0da55c6ca8fc74e8ea990bab33 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -25,6 +25,8 @@ class FieldsInKernel: self.custom_fields: set[Field] = set() self.buffer_fields: set[Field] = set() + self.archetype_field: Field | None = None + def __iter__(self) -> Iterator: return chain( self.domain_fields, diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 63dc9170a0efc1e5bb0b626353d290774c584ac8..a13f21ae29715f491cefb1f08245855e499fc4d5 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -182,6 +182,10 @@ class FreezeExpressions: def map_Integer(self, expr: sp.Integer) -> PsConstantExpr: value = int(expr) return PsConstantExpr(PsConstant(value)) + + def map_Float(self, expr: sp.Float) -> PsConstantExpr: + value = float(expr) # TODO: check accuracy of evaluation + return PsConstantExpr(PsConstant(value)) def map_Rational(self, expr: sp.Rational) -> PsExpression: num = PsConstantExpr(PsConstant(expr.numerator)) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index e5b586688297bf7433dafc20e7801efda739dbc2..3a7143bc2cb3dfa25500b50a12f87257cae825ca 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -114,11 +114,7 @@ class FullIterationSpace(IterationSpace): ) ] - # Determine loop order by permuting dimensions - loop_order = archetype_field.layout - dimensions = [dimensions[coordinate] for coordinate in loop_order] - - return FullIterationSpace(ctx, dimensions) + return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field) @staticmethod def create_from_slice( @@ -176,18 +172,21 @@ class FullIterationSpace(IterationSpace): ) ] - # Determine loop order by permuting dimensions - loop_order = archetype_field.layout - dimensions = [dimensions[coordinate] for coordinate in loop_order] - - return FullIterationSpace(ctx, dimensions) + return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field) - def __init__(self, ctx: KernelCreationContext, dimensions: Sequence[Dimension]): + def __init__( + self, + ctx: KernelCreationContext, + dimensions: Sequence[Dimension], + archetype_field: Field | None = None, + ): super().__init__(tuple(dim.counter for dim in dimensions)) self._ctx = ctx self._dimensions = dimensions + self._archetype_field = archetype_field + @property def dimensions(self): return self._dimensions @@ -204,6 +203,10 @@ class FullIterationSpace(IterationSpace): def steps(self): return (dim.step for dim in self._dimensions) + @property + def archetype_field(self) -> Field | None: + return self._archetype_field + def actual_iterations(self, dimension: int | None = None) -> PsExpression: if dimension is None: return reduce( diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index d512153a6e7c2c8a22821f5a1f6b596af1addc6c..6b19a88ff975ab1a63e519913a526c6be80b6784 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -6,7 +6,7 @@ from .platform import Platform from ..kernelcreation.iteration_space import ( IterationSpace, FullIterationSpace, - SparseIterationSpace, + SparseIterationSpace ) from ..constants import PsConstant @@ -43,7 +43,15 @@ class GenericCpu(Platform): def _create_domain_loops( self, body: PsBlock, ispace: FullIterationSpace ) -> PsBlock: + dimensions = ispace.dimensions + + # Determine loop order by permuting dimensions + archetype_field = ispace.archetype_field + if archetype_field is not None: + loop_order = archetype_field.layout + dimensions = [dimensions[coordinate] for coordinate in loop_order] + outer_block = body for dimension in dimensions[::-1]: diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py index 7096bf6c24a9989e70fa12c5770a4168ba6cd2c4..0c475c8ac9178c1ed89a17dcfd9d5ceb2ebf112a 100644 --- a/src/pystencils/types/__init__.py +++ b/src/pystencils/types/__init__.py @@ -15,7 +15,7 @@ from .basic_types import ( deconstify, ) -from .quick import create_type, create_numeric_type +from .quick import UserTypeSpec, create_type, create_numeric_type from .exception import PsTypeError @@ -34,6 +34,7 @@ __all__ = [ "PsIeeeFloatType", "constify", "deconstify", + "UserTypeSpec", "create_type", "create_numeric_type", "PsTypeError", diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py index eb373bf3de878379054f4dd9cdfba4f5630516a4..55055205fe1c8725ac567d57b1d225cc6e0e8eb9 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/basic_types.py @@ -463,6 +463,18 @@ class PsBoolType(PsScalarType): return np.False_ else: raise PsTypeError(f"Cannot create boolean constant from value {value}") + + def c_string(self) -> str: + return "bool" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsBoolType): + return False + + return self._base_equal(other) + + def __hash__(self) -> int: + return hash(("PsBoolType", self._const)) class PsIntegerType(PsScalarType, ABC):