diff --git a/src/pystencils/codegen/parameters.py b/src/pystencils/codegen/parameters.py index d8411266ee514d4270a7a9d1c2fb24383f005329..094553517a56a189f4fa714749b5f7b5761f8e33 100644 --- a/src/pystencils/codegen/parameters.py +++ b/src/pystencils/codegen/parameters.py @@ -1,14 +1,14 @@ from __future__ import annotations from warnings import warn -from typing import Sequence, Iterable +from typing import Sequence, Iterable, Optional from .properties import ( PsSymbolProperty, _FieldProperty, FieldShape, FieldStride, - FieldBasePtr, + FieldBasePtr, ReductionPointerVariable, ) from ..types import PsType from ..field import Field @@ -39,6 +39,9 @@ class Parameter: key=lambda f: f.name, ) ) + self._reduction_ptr: Optional[ReductionPointerVariable] = next( + (e for e in self._properties if isinstance(e, ReductionPointerVariable)), None + ) @property def name(self): @@ -79,6 +82,11 @@ class Parameter: """Set of fields associated with this parameter.""" return self._fields + @property + def reduction_pointer(self) -> Optional[ReductionPointerVariable]: + """Reduction pointer associated with this parameter.""" + return self._reduction_ptr + def get_properties( self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...] ) -> set[PsSymbolProperty]: @@ -105,6 +113,10 @@ class Parameter: ) return bool(self.get_properties(FieldBasePtr)) + @property + def is_reduction_pointer(self) -> bool: + return bool(self._reduction_ptr) + @property def is_field_stride(self) -> bool: # pragma: no cover warn( diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py index c2c969eaae1057f7892150d10124f11d49e6d060..f9c04200c249594babede2bab3ff79d59e909045 100644 --- a/src/pystencils/jit/cpu_extension_module.py +++ b/src/pystencils/jit/cpu_extension_module.py @@ -206,6 +206,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ self._array_assoc_var_extractions: dict[Parameter, str] = dict() self._scalar_extractions: dict[Parameter, str] = dict() + self._reduction_ptrs: dict[Parameter, str] = dict() + self._constraint_checks: list[str] = [] self._call: str | None = None @@ -265,10 +267,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return self._array_buffers[field] def extract_scalar(self, param: Parameter) -> str: - if any(isinstance(e, ReductionPointerVariable) for e in param.properties): - # TODO: implement - pass - elif param not in self._scalar_extractions: + if param not in self._scalar_extractions: extract_func = self._scalar_extractor(param.dtype) code = self.TMPL_EXTRACT_SCALAR.format( name=param.name, @@ -279,6 +278,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name + def extract_reduction_ptr(self, param: Parameter) -> str: + if param not in self._reduction_ptrs: + # TODO: implement + pass + return param.name + def extract_array_assoc_var(self, param: Parameter) -> str: if param not in self._array_assoc_var_extractions: field = param.fields[0] @@ -306,7 +311,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name def extract_parameter(self, param: Parameter): - if param.is_field_parameter: + if param.is_reduction_pointer: + self.extract_reduction_ptr(param) + elif param.is_field_parameter: self.extract_array_assoc_var(param) else: self.extract_scalar(param)