Commit 5022bcca authored by Martin Bauer's avatar Martin Bauer
Browse files

Generalized flag handling & support for flag dependent quantities

- flags can be force when setting boundaries, this is helpful for:
- setting up quantities (e.g. relaxation rate..) dependent on current
  flag configuration
- bit operation fix when pickling: function has to have the same name
  as the python object
parent ecb9390a
......@@ -4,7 +4,7 @@ from typing import Callable, Any, Optional, Sequence
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
class SimplificationStrategy(object):
class SimplificationStrategy:
"""A simplification strategy is an ordered collection of simplification rules.
Each simplification is a function taking an equation collection, and returning a new simplified
......@@ -8,7 +8,7 @@ from typing import List, Set, Optional, Union, Any
NodeOrExpr = Union['Node', sp.Expr]
class Node(object):
class Node:
"""Base class for all AST nodes."""
def __init__(self, parent: Optional['Node'] = None):
import sympy as sp
bitwise_xor = sp.Function("")
bit_shift_right = sp.Function("rshift")
bit_shift_left = sp.Function("lshift")
bitwise_and = sp.Function("Bit&")
bitwise_or = sp.Function("Bit|")
bitwise_xor = sp.Function("bitwise_xor")
bit_shift_right = sp.Function("bit_shift_right")
bit_shift_left = sp.Function("bit_shift_left")
bitwise_and = sp.Function("bitwise_and")
bitwise_or = sp.Function("bitwise_or")
......@@ -3,7 +3,7 @@ from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
from typing import List, Tuple, Any
class Boundary(object):
class Boundary:
"""Base class all boundaries should derive from"""
def __init__(self, name=None):
......@@ -7,30 +7,63 @@ from pystencils.boundaries.createindexlist import numpy_data_type_for_boundary_o
from pystencils.cache import memorycache
from pystencils.data_types import create_type
class FlagInterface:
FLAG_DTYPE = np.uint32
def __init__(self, data_handling, flag_field_name):
class FlagInterface:
"""Manages the reservation of bits (i.e. flags) in an array of unsigned integers.
>>> from pystencils.datahandling import SerialDataHandling
>>> dh = SerialDataHandling((4, 5))
>>> fi = FlagInterface(dh, 'flag_field', np.uint8)
>>> assert dh.has_data('flag_field')
>>> fi.reserve_next_flag()
>>> fi.reserve_flag(4)
>>> fi.reserve_next_flag()
def __init__(self, data_handling, flag_field_name, dtype=DEFAULT_FLAG_TYPE):
self.flag_field_name = flag_field_name
self.domain_flag = self.FLAG_DTYPE(1 << 0)
self._nextFreeFlag = 1
self.domain_flag = dtype(1 << 0)
self._used_flags = {self.domain_flag}
self.data_handling = data_handling
self.dtype = dtype
self.max_bits = self.dtype().itemsize * 8
# Add flag field to data handling if it does not yet exist
if data_handling.has_data(self.flag_field_name):
raise ValueError("There is already a boundary handling registered at the data handling."
"If you want to add multiple handling objects, choose a different name.")
data_handling.add_array(self.flag_field_name, dtype=self.FLAG_DTYPE, cpu=True, gpu=False)
self.flag_field = data_handling.add_array(self.flag_field_name, dtype=self.dtype, cpu=True, gpu=False)
ff_ghost_layers = data_handling.ghost_layers_of_field(self.flag_field_name)
for b in data_handling.iterate(ghost_layers=ff_ghost_layers):
def allocate_next_flag(self):
result = self.FLAG_DTYPE(1 << self._nextFreeFlag)
self._nextFreeFlag += 1
return result
def reserve_next_flag(self):
for i in range(1, self.max_bits):
flag = self.dtype(1 << i)
if flag not in self._used_flags:
assert self._is_power_of_2(flag)
return flag
raise ValueError(f"All available {self.max_bits} flags are reserved")
def reserve_flag(self, flag):
assert self._is_power_of_2(flag)
flag = self.dtype(flag)
if flag in self._used_flags:
raise ValueError(f"The flag {flag} is already reserved")
return flag
def _is_power_of_2(num):
return num != 0 and ((num & (num - 1)) == 0)
class BoundaryHandling:
......@@ -92,7 +125,7 @@ class BoundaryHandling:
return result
def set_boundary(self, boundary_obj, slice_obj=None, mask_callback=None,
ghost_layers=True, inner_ghost_layers=True, replace=True):
ghost_layers=True, inner_ghost_layers=True, replace=True, force_flag_value=None):
"""Sets boundary using either a rectangular slice, a boolean mask or a combination of both.
......@@ -111,11 +144,15 @@ class BoundaryHandling:
inner_ghost_layers: see DataHandling.iterate()
replace: by default all other flags are erased in the cells where the boundary is set. To add a
boundary condition, set this replace flag to False
force_flag_value: flag that should be reserved for this boundary. Has to be an integer that is a power of 2
and was not reserved before for another boundary.
if isinstance(boundary_obj, str) and boundary_obj.lower() == 'domain':
flag = self.flag_interface.domain_flag
flag = self._add_boundary(boundary_obj)
if force_flag_value:
flag = self._add_boundary(boundary_obj, force_flag_value)
for b in self._data_handling.iterate(slice_obj, ghost_layers=ghost_layers,
......@@ -139,6 +176,7 @@ class BoundaryHandling:
return flag
def set_boundary_where_flag_is_set(self, boundary_obj, flag):
"""Adds an (additional) boundary to all cells that have been previously marked with the passed flag."""
self._add_boundary(boundary_obj, flag)
self._dirty = True
return flag
......@@ -193,7 +231,8 @@ class BoundaryHandling:
if b == 'domain':
masks_to_name[self.flag_interface.domain_flag] = 'domain'
masks_to_name[self._boundary_object_to_boundary_info[b].flag] =
flag = self._boundary_object_to_boundary_info[b].flag
masks_to_name[flag] =
writer = self.data_handling.create_vtk_writer_for_flag_array(file_name, self.flag_interface.flag_field_name,
masks_to_name, ghost_layers=ghost_layers)
......@@ -208,7 +247,7 @@ class BoundaryHandling:
ast = self._create_boundary_kernel(self._data_handling.fields[self._field_name],
sym_index_field, boundary_obj)
if flag is None:
flag = self.flag_interface.allocate_next_flag()
flag = self.flag_interface.reserve_next_flag()
boundary_info = self.BoundaryInfo(boundary_obj, flag=flag, kernel=ast.compile())
self._boundary_object_to_boundary_info[boundary_obj] = boundary_info
return self._boundary_object_to_boundary_info[boundary_obj].flag
......@@ -289,7 +328,7 @@ class BoundaryDataSetter:
arr_field_names = index_array.dtype.names
self.dim = 3 if 'z' in arr_field_names else 2
assert 'x' in arr_field_names and 'y' in arr_field_names and 'dir' in arr_field_names, str(arr_field_names)
self.boundary_data_names = set(self.index_array.dtype.names) - set(['x', 'y', 'z', 'dir'])
self.boundary_data_names = set(self.index_array.dtype.names) - {'x', 'y', 'z', 'dir'}
self.coord_map = {0: 'x', 1: 'y', 2: 'z'}
self.ghost_layers = ghost_layers
......@@ -470,7 +470,7 @@ class PointerType(Type):
return hash((self._base_type, self.const, self.restrict))
class StructType(object):
class StructType:
def __init__(self, numpy_type, const=False):
self.const = const
self._dtype = np.dtype(numpy_type)
......@@ -35,7 +35,7 @@ class FieldType(Enum):
return field.field_type == FieldType.BUFFER
class Field(object):
class Field:
With fields one can formulate stencil-like update rules on structured grids.
This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment