Skip to content
Snippets Groups Projects
Commit f1c556e6 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Integrate reduction pointers to parameters.py

parent 777ab888
1 merge request!438Reduction Support
from __future__ import annotations
from warnings import warn
from typing import Sequence, Iterable
from typing import Sequence, Iterable, Optional
from .properties import (
PsSymbolProperty,
_FieldProperty,
FieldShape,
FieldStride,
FieldBasePtr,
FieldBasePtr, ReductionPointerVariable,
)
from ..types import PsType
from ..field import Field
......@@ -39,6 +39,9 @@ class Parameter:
key=lambda f: f.name,
)
)
self._reduction_ptr: Optional[ReductionPointerVariable] = next(
(e for e in self._properties if isinstance(e, ReductionPointerVariable)), None
)
@property
def name(self):
......@@ -79,6 +82,11 @@ class Parameter:
"""Set of fields associated with this parameter."""
return self._fields
@property
def reduction_pointer(self) -> Optional[ReductionPointerVariable]:
"""Reduction pointer associated with this parameter."""
return self._reduction_ptr
def get_properties(
self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
) -> set[PsSymbolProperty]:
......@@ -105,6 +113,10 @@ class Parameter:
)
return bool(self.get_properties(FieldBasePtr))
@property
def is_reduction_pointer(self) -> bool:
return bool(self._reduction_ptr)
@property
def is_field_stride(self) -> bool: # pragma: no cover
warn(
......
......@@ -206,6 +206,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
self._array_assoc_var_extractions: dict[Parameter, str] = dict()
self._scalar_extractions: dict[Parameter, str] = dict()
self._reduction_ptrs: dict[Parameter, str] = dict()
self._constraint_checks: list[str] = []
self._call: str | None = None
......@@ -265,10 +267,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return self._array_buffers[field]
def extract_scalar(self, param: Parameter) -> str:
if any(isinstance(e, ReductionPointerVariable) for e in param.properties):
# TODO: implement
pass
elif param not in self._scalar_extractions:
if param not in self._scalar_extractions:
extract_func = self._scalar_extractor(param.dtype)
code = self.TMPL_EXTRACT_SCALAR.format(
name=param.name,
......@@ -279,6 +278,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return param.name
def extract_reduction_ptr(self, param: Parameter) -> str:
if param not in self._reduction_ptrs:
# TODO: implement
pass
return param.name
def extract_array_assoc_var(self, param: Parameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields[0]
......@@ -306,7 +311,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return param.name
def extract_parameter(self, param: Parameter):
if param.is_field_parameter:
if param.is_reduction_pointer:
self.extract_reduction_ptr(param)
elif param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment