From 7750926f40a965ddcf51365d6392968ab85b2d78 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Tue, 13 Aug 2019 14:03:52 +0200
Subject: [PATCH] Remove code duplication from Philox generation

---
 pystencils/include/philox_rand.h |  3 +-
 pystencils/rng.py                | 72 ++++++++++----------------------
 2 files changed, 23 insertions(+), 52 deletions(-)

diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h
index bbf1c1a..2832049 100644
--- a/pystencils/include/philox_rand.h
+++ b/pystencils/include/philox_rand.h
@@ -52,8 +52,7 @@ QUALIFIERS void _philox4x32bumpkey(uint32* key)
 
 QUALIFIERS double _uniform_double_hq(uint32 x, uint32 y)
 {
-    unsigned long long z = (unsigned long long)x ^
-        ((unsigned long long)y << (53 - 32));
+    uint64 z = (uint64)x ^ ((uint64)y << (53 - 32));
     return z * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0);
 }
 
diff --git a/pystencils/rng.py b/pystencils/rng.py
index 1982516..48f3fb9 100644
--- a/pystencils/rng.py
+++ b/pystencils/rng.py
@@ -5,21 +5,18 @@ from pystencils import TypedSymbol
 from pystencils.astnodes import LoopOverCoordinate
 from pystencils.backends.cbackend import CustomCodeNode
 
-philox_two_doubles_call = """
-{result_symbols[0].dtype} {result_symbols[0].name};
-{result_symbols[1].dtype} {result_symbols[1].name};
-philox_double2({parameters}, {result_symbols[0].name}, {result_symbols[1].name});
-"""
 
-philox_four_floats_call = """
-{result_symbols[0].dtype} {result_symbols[0].name};
-{result_symbols[1].dtype} {result_symbols[1].name};
-{result_symbols[2].dtype} {result_symbols[2].name};
-{result_symbols[3].dtype} {result_symbols[3].name};
-philox_float4({parameters},
-              {result_symbols[0].name}, {result_symbols[1].name}, {result_symbols[2].name}, {result_symbols[3].name});
-
-"""
+def _get_philox_template(data_type, num_vars):
+    if data_type is np.float32:
+        c_type = "float"
+    elif data_type is np.float64:
+        c_type = "double"
+    template = "\n"
+    for i in range(num_vars):
+        template += "{{result_symbols[{}].dtype}} {{result_symbols[{}].name}};\n".format(i, i)
+    template += ("philox_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \
+        .format(c_type, num_vars, *tuple(range(num_vars)))
+    return template
 
 
 def _get_philox_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
@@ -39,10 +36,10 @@ def _get_philox_code(template, dialect, vector_instruction_set, time_step, offse
         raise NotImplementedError("Not yet implemented for this backend")
 
 
-class PhiloxTwoDoubles(CustomCodeNode):
+class PhiloxBase(CustomCodeNode):
 
     def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
-        self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2))
+        self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, self._data_type) for _ in range(self._num_vars))
         symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
         super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
         self._time_step = time_step
@@ -68,47 +65,22 @@ class PhiloxTwoDoubles(CustomCodeNode):
         return self  # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
 
     def get_code(self, dialect, vector_instruction_set):
-        return _get_philox_code(philox_two_doubles_call, dialect, vector_instruction_set,
+        template = _get_philox_template(self._data_type, self._num_vars)
+        return _get_philox_code(template, dialect, vector_instruction_set,
                                 self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
 
     def __repr__(self):
-        return "{}, {} <- PhiloxRNG".format(*self.result_symbols)
-
-
-class PhiloxFourFloats(CustomCodeNode):
-
-    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
-        self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4))
-        symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
-        super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
-        self._time_step = time_step
-        self._offsets = offsets
-        self.headers = ['"philox_rand.h"']
-        self.keys = tuple(keys)
-        self._args = sp.sympify((dim, time_step, offsets, keys))
-        self._dim = dim
-
-    @property
-    def args(self):
-        return self._args
+        return (", ".join(['{}'] * self._num_vars) + " <- PhiloxRNG").format(*self.result_symbols)
 
-    @property
-    def undefined_symbols(self):
-        result = {a for a in (self._time_step, *self._offsets, *self.keys) if isinstance(a, sp.Symbol)}
-        loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
-                         for i in range(self._dim)]
-        result.update(loop_counters)
-        return result
 
-    def fast_subs(self, _):
-        return self  # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
+class PhiloxTwoDoubles(PhiloxBase):
+    _data_type = np.float64
+    _num_vars = 2
 
-    def get_code(self, dialect, vector_instruction_set):
-        return _get_philox_code(philox_four_floats_call, dialect, vector_instruction_set,
-                                self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
 
-    def __repr__(self):
-        return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols)
+class PhiloxFourFloats(PhiloxBase):
+    _data_type = np.float32
+    _num_vars = 4
 
 
 def random_symbol(assignment_list, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, *args, **kwargs):
-- 
GitLab