From fd7730d0758a480fb70ab619e23a76679c9592fe Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Wed, 9 Dec 2020 16:50:27 +0100
Subject: [PATCH] Make the RNG node behave more like a regular node

---
 .gitlab-ci.yml                  |  2 +-
 pystencils/astnodes.py          |  8 +++-
 pystencils/rng.py               | 53 +++++++++-------------
 pystencils_tests/test_random.py | 78 ++++++++++++++++-----------------
 4 files changed, 67 insertions(+), 74 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 727db59f6..781368840 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -142,7 +142,7 @@ pycodegen-integration:
     - cd ../pygrandchem
     - py.test -v -n $NUM_CORES .
     - cd ../walberla/build/
-    - make CodegenJacobiCPU CodegenJacobiGPU CodegenPoissonCPU CodegenPoissonGPU MicroBenchmarkGpuLbm LbCodeGenerationExample UniformGridBenchmarkGPU_trt UniformGridBenchmarkGPU_entropic_kbc_n4
+    - make CodegenJacobiCPU CodegenJacobiGPU CodegenPoissonCPU CodegenPoissonGPU MicroBenchmarkGpuLbm LbCodeGenerationExample UniformGridBenchmarkGPU_trt UniformGridBenchmarkGPU_entropic_kbc_n4 FluctuatingMRT
     - cd apps/benchmarks/UniformGridGPU
     - make -j $NUM_CORES
     - cd ../UniformGridGenerated
diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 58a438717..b874db9b0 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -37,8 +37,12 @@ class Node:
 
     def subs(self, subs_dict) -> None:
         """Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
-        for a in self.args:
-            a.subs(subs_dict)
+        for i, a in enumerate(self.args):
+            result = a.subs(subs_dict)
+            if isinstance(a, sp.Expr):  # sympy expressions' subs is out-of-place
+                self.args[i] = result
+            else:  # all other should be in-place
+                assert result is None
 
     @property
     def func(self):
diff --git a/pystencils/rng.py b/pystencils/rng.py
index 27606324a..f567e0c1b 100644
--- a/pystencils/rng.py
+++ b/pystencils/rng.py
@@ -19,11 +19,9 @@ def _get_rng_template(name, data_type, num_vars):
     return template
 
 
-def _get_rng_code(template, dialect, vector_instruction_set, time_step, coordinates, keys, dim, result_symbols):
-    parameters = [time_step] + coordinates + [0] * (3 - dim) + list(keys)
-
+def _get_rng_code(template, dialect, vector_instruction_set, args, result_symbols):
     if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
-        return template.format(parameters=', '.join(str(p) for p in parameters),
+        return template.format(parameters=', '.join(str(a) for a in args),
                                result_symbols=result_symbols)
     else:
         raise NotImplementedError("Not yet implemented for this backend")
@@ -31,47 +29,38 @@ def _get_rng_code(template, dialect, vector_instruction_set, time_step, coordina
 
 class RNGBase(CustomCodeNode):
 
-    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=None):
+    id = 0
+
+    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=None, keys=None):
         if keys is None:
             keys = (0,) * self._num_keys
+        if offsets is None:
+            offsets = (0,) * dim
         if len(keys) != self._num_keys:
             raise ValueError(f"Provided {len(keys)} keys but need {self._num_keys}")
-        if len(offsets) != 3:
-            raise ValueError(f"Provided {len(offsets)} offsets but need {3}")
-        self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, self._data_type) for _ in range(self._num_vars))
-        symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
+        if len(offsets) != dim:
+            raise ValueError(f"Provided {len(offsets)} offsets but need {dim}")
+        coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
+        if dim < 3:
+            coordinates.append(0)
+
+        self._args = sp.sympify([time_step, *coordinates, *keys])
+        self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type)
+                                    for i in range(self._num_vars))
+        symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args])
         super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
-        self._time_step = time_step
-        self._offsets = offsets
-        self._coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
+
         self.headers = [f'"{self._name}_rand.h"']
-        self.keys = tuple(keys)
-        self._args = sp.sympify((dim, time_step, keys))
-        self._dim = dim
+
+        RNGBase.id += 1
 
     @property
     def args(self):
         return self._args
 
-    @property
-    def undefined_symbols(self):
-        result = {a for a in (self._time_step, *self._offsets, *self.keys) if isinstance(a, sp.Symbol)}
-        loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
-                         for i in range(self._dim)]
-        result.update(loop_counters)
-        return result
-
-    def subs(self, subs_dict) -> None:
-        for i in range(len(self._coordinates)):
-            self._coordinates[i] = self._coordinates[i].subs(subs_dict)
-
-    def fast_subs(self, *_):
-        return self  # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
-
     def get_code(self, dialect, vector_instruction_set):
         template = _get_rng_template(self._name, self._data_type, self._num_vars)
-        return _get_rng_code(template, dialect, vector_instruction_set,
-                             self._time_step, self._coordinates, self.keys, self._dim, self.result_symbols)
+        return _get_rng_code(template, dialect, vector_instruction_set, self.args, self.result_symbols)
 
     def __repr__(self):
         return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py
index 9c7ee4dd0..322718d1b 100644
--- a/pystencils_tests/test_random.py
+++ b/pystencils_tests/test_random.py
@@ -12,58 +12,58 @@ philox_reference = np.array([[[3576608082, 1252663339, 1987745383,  348040302],
                              [[2958765206, 3725192638, 2623672781, 1373196132],
                               [ 850605163, 1694561295, 3285694973, 2799652583]]])
 
-def test_philox_double():
-    for target in ('cpu', 'gpu'):
-        if target == 'gpu':
-            pytest.importorskip('pycuda')
+@pytest.mark.parametrize('target', ('cpu', 'gpu'))
+def test_philox_double(target):
+    if target == 'gpu':
+        pytest.importorskip('pycuda')
 
-        dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target=target)
-        f = dh.add_array("f", values_per_cell=2)
+    dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target=target)
+    f = dh.add_array("f", values_per_cell=2)
 
-        dh.fill('f', 42.0)
+    dh.fill('f', 42.0)
 
-        philox_node = PhiloxTwoDoubles(dh.dim)
-        assignments = [philox_node,
-                       ps.Assignment(f(0), philox_node.result_symbols[0]),
-                       ps.Assignment(f(1), philox_node.result_symbols[1])]
-        kernel = ps.create_kernel(assignments, target=dh.default_target).compile()
+    philox_node = PhiloxTwoDoubles(dh.dim)
+    assignments = [philox_node,
+                   ps.Assignment(f(0), philox_node.result_symbols[0]),
+                   ps.Assignment(f(1), philox_node.result_symbols[1])]
+    kernel = ps.create_kernel(assignments, target=dh.default_target).compile()
 
-        dh.all_to_gpu()
-        dh.run_kernel(kernel, time_step=124)
-        dh.all_to_cpu()
+    dh.all_to_gpu()
+    dh.run_kernel(kernel, time_step=124)
+    dh.all_to_cpu()
 
-        arr = dh.gather_array('f')
-        assert np.logical_and(arr <= 1.0, arr >= 0).all()
+    arr = dh.gather_array('f')
+    assert np.logical_and(arr <= 1.0, arr >= 0).all()
 
-        x = philox_reference[:,:,0::2]
-        y = philox_reference[:,:,1::2]
-        z = x ^ y << (53 - 32)
-        double_reference = z * 2.**-53 + 2.**-54
-        assert(np.allclose(arr, double_reference, rtol=0, atol=np.finfo(np.float64).eps))
+    x = philox_reference[:,:,0::2]
+    y = philox_reference[:,:,1::2]
+    z = x ^ y << (53 - 32)
+    double_reference = z * 2.**-53 + 2.**-54
+    assert(np.allclose(arr, double_reference, rtol=0, atol=np.finfo(np.float64).eps))
 
 
-def test_philox_float():
-    for target in ('cpu', 'gpu'):
-        if target == 'gpu':
-            pytest.importorskip('pycuda')
+@pytest.mark.parametrize('target', ('cpu', 'gpu'))
+def test_philox_float(target):
+    if target == 'gpu':
+        pytest.importorskip('pycuda')
 
-        dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target=target)
-        f = dh.add_array("f", values_per_cell=4)
+    dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target=target)
+    f = dh.add_array("f", values_per_cell=4)
 
-        dh.fill('f', 42.0)
+    dh.fill('f', 42.0)
 
-        philox_node = PhiloxFourFloats(dh.dim)
-        assignments = [philox_node] + [ps.Assignment(f(i), philox_node.result_symbols[i]) for i in range(4)]
-        kernel = ps.create_kernel(assignments, target=dh.default_target).compile()
+    philox_node = PhiloxFourFloats(dh.dim)
+    assignments = [philox_node] + [ps.Assignment(f(i), philox_node.result_symbols[i]) for i in range(4)]
+    kernel = ps.create_kernel(assignments, target=dh.default_target).compile()
 
-        dh.all_to_gpu()
-        dh.run_kernel(kernel, time_step=124)
-        dh.all_to_cpu()
-        arr = dh.gather_array('f')
-        assert np.logical_and(arr <= 1.0, arr >= 0).all()
+    dh.all_to_gpu()
+    dh.run_kernel(kernel, time_step=124)
+    dh.all_to_cpu()
+    arr = dh.gather_array('f')
+    assert np.logical_and(arr <= 1.0, arr >= 0).all()
 
-        float_reference = philox_reference * 2.**-32 + 2.**-33
-        assert(np.allclose(arr, float_reference, rtol=0, atol=np.finfo(np.float32).eps))
+    float_reference = philox_reference * 2.**-32 + 2.**-33
+    assert(np.allclose(arr, float_reference, rtol=0, atol=np.finfo(np.float32).eps))
 
 def test_aesni_double():
     dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target="cpu")
-- 
GitLab