From eb8adcdf7cc76e638e096ad34db5a60b5a1ee447 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 9 Jun 2021 06:00:22 +0000
Subject: [PATCH] Stream-Only Kernel Fixes

---
 lbmpy/updatekernels.py            | 44 +++++++++++++++++--------------
 lbmpy_tests/test_update_kernel.py | 39 +++++++++++++++++----------
 2 files changed, 49 insertions(+), 34 deletions(-)

diff --git a/lbmpy/updatekernels.py b/lbmpy/updatekernels.py
index 8edb11b7..f979880b 100644
--- a/lbmpy/updatekernels.py
+++ b/lbmpy/updatekernels.py
@@ -55,11 +55,28 @@ def create_lbm_kernel(collision_rule, input_field, output_field, accessor):
     return result
 
 
-def create_stream_only_kernel(stencil, numpy_arr=None, src_field_name="src", dst_field_name="dst",
-                              generic_layout='numpy', generic_field_type=np.float64,
-                              accessor=StreamPullTwoFieldsAccessor()):
+def create_stream_only_kernel(stencil, src_field, dst_field, accessor=StreamPullTwoFieldsAccessor()):
     """Creates a stream kernel, without collision.
 
+    Args:
+        stencil: lattice Boltzmann stencil; one temporary symbol ``tmp_i`` is created per stencil entry
+        src_field: Field the pre-streaming values are read from (through ``accessor.read``)
+        dst_field: Field the post-streaming values are written to (through ``accessor.write``)
+        accessor: Field accessor which is used to create the update rule. See 'fieldaccess.PdfFieldAccessor'
+
+    Returns:
+        AssignmentCollection: subexpressions load every read access into a temporary, main assignments store them
+    """
+    temporary_symbols = sp.symbols(f'tmp_:{len(stencil)}')
+    subexpressions = [Assignment(tmp, acc) for tmp, acc in zip(temporary_symbols, accessor.read(src_field, stencil))]
+    main_assignments = [Assignment(acc, tmp) for acc, tmp in zip(accessor.write(dst_field, stencil), temporary_symbols)]
+    return AssignmentCollection(main_assignments, subexpressions=subexpressions)
+
+
+def create_stream_pull_only_kernel(stencil, numpy_arr=None, src_field_name="src", dst_field_name="dst",
+                                   generic_layout='numpy', generic_field_type=np.float64):
+    """Creates a stream kernel with the pull scheme, without collision.
+
     Args:
         stencil: lattice Boltzmann stencil which is used
         numpy_arr: numpy array which containes the pdf field data. If no numpy array is provided the symbolic field
@@ -68,11 +85,12 @@ def create_stream_only_kernel(stencil, numpy_arr=None, src_field_name="src", dst
         dst_field_name: name of the destination field.
         generic_layout: data layout. for example 'fzyx' of 'zyxf'.
         generic_field_type: field data type.
-        accessor: Field accessor which is used to create the update rule. See 'fieldaccess.PdfFieldAccessor'
 
     Returns:
         AssignmentCollection of the stream only update rule
     """
+    warnings.warn("This function is deprecated. Please use create_stream_only_kernel. If no PdfFieldAccessor is "
+                  "provided to this function a standard StreamPullTwoFieldsAccessor is used ", DeprecationWarning)
     dim = len(stencil[0])
     if numpy_arr is None:
         src = Field.create_generic(src_field_name, dim, index_shape=(len(stencil),),
@@ -82,22 +100,7 @@ def create_stream_only_kernel(stencil, numpy_arr=None, src_field_name="src", dst
     else:
         src = Field.create_from_numpy_array(src_field_name, numpy_arr, index_dimensions=1)
         dst = Field.create_from_numpy_array(dst_field_name, numpy_arr, index_dimensions=1)
-
-    eqs = [Assignment(a, b) for a, b in zip(accessor.write(dst, stencil), accessor.read(src, stencil))]
-    return AssignmentCollection(eqs, [])
-
-
-def create_stream_pull_only_kernel(stencil, numpy_arr=None, src_field_name="src", dst_field_name="dst",
-                                   generic_layout='numpy', generic_field_type=np.float64):
-    """Creates a stream kernel with the pull scheme, without collision.
-
-    For parameters see function ``create_stream_pull_collide_kernel``
-    """
-    warnings.warn("This function is depricated. Please use create_stream_only_kernel. If no PdfFieldAccessor is "
-                  "provided to this function a standard StreamPullTwoFieldsAccessor is used ", DeprecationWarning)
-    return create_stream_only_kernel(stencil, numpy_arr=numpy_arr, src_field_name=src_field_name,
-                                     dst_field_name=dst_field_name, generic_layout=generic_layout,
-                                     generic_field_type=generic_field_type, accessor=StreamPullTwoFieldsAccessor())
+    return create_stream_only_kernel(stencil, src, dst, accessor=StreamPullTwoFieldsAccessor())
 
 
 def create_stream_pull_with_output_kernel(lb_method, src_field, dst_field, output):
@@ -114,6 +117,7 @@ def create_stream_pull_with_output_kernel(lb_method, src_field, dst_field, outpu
     return LbmCollisionRule(lb_method, main_eqs, subexpressions,
                             simplification_hints=output_eq_collection.simplification_hints)
 
+
 # ---------------------------------- Pdf array creation for various layouts --------------------------------------------
 
 
diff --git a/lbmpy_tests/test_update_kernel.py b/lbmpy_tests/test_update_kernel.py
index 43b780ad..993d06a2 100644
--- a/lbmpy_tests/test_update_kernel.py
+++ b/lbmpy_tests/test_update_kernel.py
@@ -3,27 +3,38 @@ import pytest
 import pystencils as ps
 
 from lbmpy.stencils import get_stencil
-from lbmpy.fieldaccess import StreamPullTwoFieldsAccessor, StreamPushTwoFieldsAccessor,\
-    AAOddTimeStepAccessor, AAEvenTimeStepAccessor, EsoTwistOddTimeStepAccessor, EsoTwistEvenTimeStepAccessor
+from lbmpy.advanced_streaming.utility import get_timesteps, streaming_patterns, get_accessor, is_inplace, AccessPdfValues
 from lbmpy.updatekernels import create_stream_only_kernel
+from pystencils import create_kernel
 
 
-@pytest.mark.parametrize('accessor', [StreamPullTwoFieldsAccessor(), StreamPushTwoFieldsAccessor(),
-                                      AAOddTimeStepAccessor(), AAEvenTimeStepAccessor(),
-                                      EsoTwistOddTimeStepAccessor(), EsoTwistEvenTimeStepAccessor()])
-def test_stream_only_kernel(accessor):
+@pytest.mark.parametrize('streaming_pattern', streaming_patterns)
+def test_stream_only_kernel(streaming_pattern):
     domain_size = (4, 4)
     stencil = get_stencil("D2Q9")
     dh = ps.create_data_handling(domain_size, default_target='cpu')
+    pdfs = dh.add_array('pdfs', values_per_cell=len(stencil))
+    pdfs_tmp = dh.add_array_like('pdfs_tmp', 'pdfs')
 
-    src = dh.add_array('src', values_per_cell=len(stencil))
-    dh.fill('src', 0.0, ghost_layers=True)
+    for t in get_timesteps(streaming_pattern):
+        accessor = get_accessor(streaming_pattern, t)
+        src = pdfs
+        dst = pdfs if is_inplace(streaming_pattern) else pdfs_tmp
 
-    dst = dh.add_array_like('dst', 'src')
-    dh.fill('dst', 0.0, ghost_layers=True)
+        dh.fill(src.name, 0.0)
+        dh.fill(dst.name, 0.0)
 
-    pull = create_stream_only_kernel(stencil, None, src.name, dst.name, accessor=accessor)
+        stream_kernel = create_stream_only_kernel(stencil, src, dst, accessor=accessor)
+        stream_func = create_kernel(stream_kernel).compile()
 
-    for i, eq in enumerate(pull.main_assignments):
-        assert eq.rhs.offsets == accessor.read(src, stencil)[i].offsets
-        assert eq.lhs.offsets == accessor.write(dst, stencil)[i].offsets
\ No newline at end of file
+        #   Check functionality
+        acc_in = AccessPdfValues(stencil, streaming_dir='in', accessor=accessor)
+        for i in range(len(stencil)):
+            acc_in.write_pdf(dh.cpu_arrays[src.name], (1, 1), i, i)
+
+        dh.run_kernel(stream_func)
+
+        #   After streaming, each value must be readable via the 'out' accessor at the same cell
+        acc_out = AccessPdfValues(stencil, streaming_dir='out', accessor=accessor)
+        for i in range(len(stencil)):
+            assert acc_out.read_pdf(dh.cpu_arrays[dst.name], (1, 1), i) == i
-- 
GitLab