From 5022bccad306e73567e4610bb731a9f31ef77d82 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 18 Apr 2018 15:34:53 +0200
Subject: [PATCH] Generalized flag handling & support for flag dependent
 quantities

- flags can be force when setting boundaries, this is helpful for:
- setting up quantities (e.g. relaxation rate..) dependent on current
  flag configuration
- bit operation fix when pickling: function has to have the same name
  as the python object
---
 .../simplificationstrategy.py                 |  2 +-
 astnodes.py                                   |  2 +-
 bitoperations.py                              | 10 +--
 boundaries/boundaryconditions.py              |  2 +-
 boundaries/boundaryhandling.py                | 69 +++++++++++++++----
 data_types.py                                 |  2 +-
 field.py                                      |  2 +-
 7 files changed, 64 insertions(+), 25 deletions(-)

diff --git a/assignment_collection/simplificationstrategy.py b/assignment_collection/simplificationstrategy.py
index bfd88f2ef..d87b1ee1b 100644
--- a/assignment_collection/simplificationstrategy.py
+++ b/assignment_collection/simplificationstrategy.py
@@ -4,7 +4,7 @@ from typing import Callable, Any, Optional, Sequence
 from pystencils.assignment_collection.assignment_collection import AssignmentCollection
 
 
-class SimplificationStrategy(object):
+class SimplificationStrategy:
     """A simplification strategy is an ordered collection of simplification rules.
 
     Each simplification is a function taking an equation collection, and returning a new simplified
diff --git a/astnodes.py b/astnodes.py
index 96d935391..7206624ef 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -8,7 +8,7 @@ from typing import List, Set, Optional, Union, Any
 NodeOrExpr = Union['Node', sp.Expr]
 
 
-class Node(object):
+class Node:
     """Base class for all AST nodes."""
 
     def __init__(self, parent: Optional['Node'] = None):
diff --git a/bitoperations.py b/bitoperations.py
index 841505e08..865772efb 100644
--- a/bitoperations.py
+++ b/bitoperations.py
@@ -1,6 +1,6 @@
 import sympy as sp
-bitwise_xor = sp.Function("⊻")
-bit_shift_right = sp.Function("rshift")
-bit_shift_left = sp.Function("lshift")
-bitwise_and = sp.Function("Bit&")
-bitwise_or = sp.Function("Bit|")
+bitwise_xor = sp.Function("bitwise_xor")
+bit_shift_right = sp.Function("bit_shift_right")
+bit_shift_left = sp.Function("bit_shift_left")
+bitwise_and = sp.Function("bitwise_and")
+bitwise_or = sp.Function("bitwise_or")
diff --git a/boundaries/boundaryconditions.py b/boundaries/boundaryconditions.py
index 5032faff9..db898d982 100644
--- a/boundaries/boundaryconditions.py
+++ b/boundaries/boundaryconditions.py
@@ -3,7 +3,7 @@ from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
 from typing import List, Tuple, Any
 
 
-class Boundary(object):
+class Boundary:
     """Base class all boundaries should derive from"""
 
     def __init__(self, name=None):
diff --git a/boundaries/boundaryhandling.py b/boundaries/boundaryhandling.py
index 95dd77d82..314490e7a 100644
--- a/boundaries/boundaryhandling.py
+++ b/boundaries/boundaryhandling.py
@@ -7,30 +7,63 @@ from pystencils.boundaries.createindexlist import numpy_data_type_for_boundary_o
 from pystencils.cache import memorycache
 from pystencils.data_types import create_type
 
+DEFAULT_FLAG_TYPE = np.uint32
 
-class FlagInterface:
-    FLAG_DTYPE = np.uint32
 
-    def __init__(self, data_handling, flag_field_name):
+class FlagInterface:
+    """Manages the reservation of bits (i.e. flags) in an array of unsigned integers.
+
+    Examples:
+        >>> from pystencils.datahandling import SerialDataHandling
+        >>> dh = SerialDataHandling((4, 5))
+        >>> fi = FlagInterface(dh, 'flag_field', np.uint8)
+        >>> assert dh.has_data('flag_field')
+        >>> fi.reserve_next_flag()
+        2
+        >>> fi.reserve_flag(4)
+        4
+        >>> fi.reserve_next_flag()
+        8
+    """
+
+    def __init__(self, data_handling, flag_field_name, dtype=DEFAULT_FLAG_TYPE):
         self.flag_field_name = flag_field_name
-        self.domain_flag = self.FLAG_DTYPE(1 << 0)
-        self._nextFreeFlag = 1
+        self.domain_flag = dtype(1 << 0)
+        self._used_flags = {self.domain_flag}
         self.data_handling = data_handling
+        self.dtype = dtype
+        self.max_bits = self.dtype().itemsize * 8
 
         # Add flag field to data handling if it does not yet exist
         if data_handling.has_data(self.flag_field_name):
             raise ValueError("There is already a boundary handling registered at the data handling."
                              "If you want to add multiple handling objects, choose a different name.")
 
-        data_handling.add_array(self.flag_field_name, dtype=self.FLAG_DTYPE, cpu=True, gpu=False)
+        self.flag_field = data_handling.add_array(self.flag_field_name, dtype=self.dtype, cpu=True, gpu=False)
         ff_ghost_layers = data_handling.ghost_layers_of_field(self.flag_field_name)
         for b in data_handling.iterate(ghost_layers=ff_ghost_layers):
             b[self.flag_field_name].fill(self.domain_flag)
 
-    def allocate_next_flag(self):
-        result = self.FLAG_DTYPE(1 << self._nextFreeFlag)
-        self._nextFreeFlag += 1
-        return result
+    def reserve_next_flag(self):
+        for i in range(1, self.max_bits):
+            flag = self.dtype(1 << i)
+            if flag not in self._used_flags:
+                self._used_flags.add(flag)
+                assert self._is_power_of_2(flag)
+                return flag
+        raise ValueError(f"All available {self.max_bits} flags are reserved")
+
+    def reserve_flag(self, flag):
+        assert self._is_power_of_2(flag)
+        flag = self.dtype(flag)
+        if flag in self._used_flags:
+            raise ValueError(f"The flag {flag} is already reserved")
+        self._used_flags.add(flag)
+        return flag
+
+    @staticmethod
+    def _is_power_of_2(num):
+        return num != 0 and ((num & (num - 1)) == 0)
 
 
 class BoundaryHandling:
@@ -92,7 +125,7 @@ class BoundaryHandling:
             return result
 
     def set_boundary(self, boundary_obj, slice_obj=None, mask_callback=None,
-                     ghost_layers=True, inner_ghost_layers=True, replace=True):
+                     ghost_layers=True, inner_ghost_layers=True, replace=True, force_flag_value=None):
         """Sets boundary using either a rectangular slice, a boolean mask or a combination of both.
 
         Args:
@@ -111,11 +144,15 @@ class BoundaryHandling:
             inner_ghost_layers: see DataHandling.iterate()
             replace: by default all other flags are erased in the cells where the boundary is set. To add a
                      boundary condition, set this replace flag to False
+            force_flag_value: flag that should be reserved for this boundary. Has to be an integer that is a power of 2
+                              and was not reserved before for another boundary.
         """
         if isinstance(boundary_obj, str) and boundary_obj.lower() == 'domain':
             flag = self.flag_interface.domain_flag
         else:
-            flag = self._add_boundary(boundary_obj)
+            if force_flag_value:
+                self.flag_interface.reserve_flag(force_flag_value)
+            flag = self._add_boundary(boundary_obj, force_flag_value)
 
         for b in self._data_handling.iterate(slice_obj, ghost_layers=ghost_layers,
                                              inner_ghost_layers=inner_ghost_layers):
@@ -139,6 +176,7 @@ class BoundaryHandling:
         return flag
 
     def set_boundary_where_flag_is_set(self, boundary_obj, flag):
+        """Adds an (additional) boundary to all cells that have been previously marked with the passed flag."""
         self._add_boundary(boundary_obj, flag)
         self._dirty = True
         return flag
@@ -193,7 +231,8 @@ class BoundaryHandling:
             if b == 'domain':
                 masks_to_name[self.flag_interface.domain_flag] = 'domain'
             else:
-                masks_to_name[self._boundary_object_to_boundary_info[b].flag] = b.name
+                flag = self._boundary_object_to_boundary_info[b].flag
+                masks_to_name[flag] = b.name
 
         writer = self.data_handling.create_vtk_writer_for_flag_array(file_name, self.flag_interface.flag_field_name,
                                                                      masks_to_name, ghost_layers=ghost_layers)
@@ -208,7 +247,7 @@ class BoundaryHandling:
             ast = self._create_boundary_kernel(self._data_handling.fields[self._field_name],
                                                sym_index_field, boundary_obj)
             if flag is None:
-                flag = self.flag_interface.allocate_next_flag()
+                flag = self.flag_interface.reserve_next_flag()
             boundary_info = self.BoundaryInfo(boundary_obj, flag=flag, kernel=ast.compile())
             self._boundary_object_to_boundary_info[boundary_obj] = boundary_info
         return self._boundary_object_to_boundary_info[boundary_obj].flag
@@ -289,7 +328,7 @@ class BoundaryDataSetter:
         arr_field_names = index_array.dtype.names
         self.dim = 3 if 'z' in arr_field_names else 2
         assert 'x' in arr_field_names and 'y' in arr_field_names and 'dir' in arr_field_names, str(arr_field_names)
-        self.boundary_data_names = set(self.index_array.dtype.names) - set(['x', 'y', 'z', 'dir'])
+        self.boundary_data_names = set(self.index_array.dtype.names) - {'x', 'y', 'z', 'dir'}
         self.coord_map = {0: 'x', 1: 'y', 2: 'z'}
         self.ghost_layers = ghost_layers
 
diff --git a/data_types.py b/data_types.py
index d12065f2f..87c5b9678 100644
--- a/data_types.py
+++ b/data_types.py
@@ -470,7 +470,7 @@ class PointerType(Type):
         return hash((self._base_type, self.const, self.restrict))
 
 
-class StructType(object):
+class StructType:
     def __init__(self, numpy_type, const=False):
         self.const = const
         self._dtype = np.dtype(numpy_type)
diff --git a/field.py b/field.py
index 085f82c63..6fec4010d 100644
--- a/field.py
+++ b/field.py
@@ -35,7 +35,7 @@ class FieldType(Enum):
         return field.field_type == FieldType.BUFFER
 
 
-class Field(object):
+class Field:
     """
     With fields one can formulate stencil-like update rules on structured grids.
     This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.
-- 
GitLab