From eaeec78a46d684cb6fff07b48461489420ad6288 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Thu, 8 Aug 2019 16:51:10 +0200
Subject: [PATCH] RNG: Possibility to pass seed and block offset parameters

---
 pystencils/fd/derivation.py              |  5 +-
 pystencils/rng.py                        | 67 ++++++++++----------
 pystencils/simp/assignment_collection.py | 42 ++-----------
 pystencils/simp/simplifications.py       | 78 ++++++++++++++++++------
 4 files changed, 96 insertions(+), 96 deletions(-)

diff --git a/pystencils/fd/derivation.py b/pystencils/fd/derivation.py
index 6cb11a93c..bfb4f7393 100644
--- a/pystencils/fd/derivation.py
+++ b/pystencils/fd/derivation.py
@@ -1,14 +1,13 @@
+import warnings
 from collections import defaultdict
 
-import sympy as sp
 import numpy as np
+import sympy as sp
 
 from pystencils.field import Field
 from pystencils.sympyextensions import multidimensional_sum, prod
 from pystencils.utils import LinearEquationSystem, fully_contains
 
-import warnings
-
 
 class FiniteDifferenceStencilDerivation:
     """Derives finite difference stencils.
diff --git a/pystencils/rng.py b/pystencils/rng.py
index 26a92b313..198251624 100644
--- a/pystencils/rng.py
+++ b/pystencils/rng.py
@@ -22,13 +22,31 @@ philox_float4({parameters},
 """
 
 
+def _get_philox_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
+    parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
+                                for i in range(dim)] + list(keys)
+
+    while len(parameters) < 6:
+        parameters.append(0)
+    parameters = parameters[:6]
+
+    assert len(parameters) == 6
+
+    if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
+        return template.format(parameters=', '.join(str(p) for p in parameters),
+                               result_symbols=result_symbols)
+    else:
+        raise NotImplementedError("Not yet implemented for this backend")
+
+
 class PhiloxTwoDoubles(CustomCodeNode):
 
-    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
+    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
         self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2))
         symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
         super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
         self._time_step = time_step
+        self._offsets = offsets
         self.headers = ['"philox_rand.h"']
         self.keys = tuple(keys)
         self._args = sp.sympify((dim, time_step, keys))
@@ -40,7 +58,7 @@ class PhiloxTwoDoubles(CustomCodeNode):
 
     @property
     def undefined_symbols(self):
-        result = {a for a in self.args if isinstance(a, sp.Symbol)}
+        result = {a for a in (self._time_step, *self._offsets, *self.keys) if isinstance(a, sp.Symbol)}
         loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
                          for i in range(self._dim)]
         result.update(loop_counters)
@@ -50,20 +68,8 @@ class PhiloxTwoDoubles(CustomCodeNode):
         return self  # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
 
     def get_code(self, dialect, vector_instruction_set):
-        parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i)
-                                          for i in range(self._dim)] + list(self.keys)
-
-        while len(parameters) < 6:
-            parameters.append(0)
-        parameters = parameters[:6]
-
-        assert len(parameters) == 6
-
-        if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
-            return philox_two_doubles_call.format(parameters=', '.join(str(p) for p in parameters),
-                                                  result_symbols=self.result_symbols)
-        else:
-            raise NotImplementedError("Not yet implemented for this backend")
+        return _get_philox_code(philox_two_doubles_call, dialect, vector_instruction_set,
+                                self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
 
     def __repr__(self):
         return "{}, {} <- PhiloxRNG".format(*self.result_symbols)
@@ -71,15 +77,15 @@ class PhiloxTwoDoubles(CustomCodeNode):
 
 class PhiloxFourFloats(CustomCodeNode):
 
-    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
+    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
         self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4))
         symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
-
         super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
         self._time_step = time_step
+        self._offsets = offsets
         self.headers = ['"philox_rand.h"']
         self.keys = tuple(keys)
-        self._args = sp.sympify((dim, time_step, keys))
+        self._args = sp.sympify((dim, time_step, offsets, keys))
         self._dim = dim
 
     @property
@@ -88,7 +94,7 @@ class PhiloxFourFloats(CustomCodeNode):
 
     @property
     def undefined_symbols(self):
-        result = {a for a in self.args if isinstance(a, sp.Symbol)}
+        result = {a for a in (self._time_step, *self._offsets, *self.keys) if isinstance(a, sp.Symbol)}
         loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
                          for i in range(self._dim)]
         result.update(loop_counters)
@@ -98,28 +104,17 @@ class PhiloxFourFloats(CustomCodeNode):
         return self  # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
 
     def get_code(self, dialect, vector_instruction_set):
-        parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i)
-                                          for i in range(self._dim)] + list(self.keys)
-
-        while len(parameters) < 6:
-            parameters.append(0)
-        parameters = parameters[:6]
-
-        assert len(parameters) == 6
-
-        if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
-            return philox_four_floats_call.format(parameters=', '.join(str(p) for p in parameters),
-                                                  result_symbols=self.result_symbols)
-        else:
-            raise NotImplementedError("Not yet implemented for this backend")
+        return _get_philox_code(philox_four_floats_call, dialect, vector_instruction_set,
+                                self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
 
     def __repr__(self):
         return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols)
 
 
-def random_symbol(assignment_list, rng_node=PhiloxTwoDoubles, *args, **kwargs):
+def random_symbol(assignment_list, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, *args, **kwargs):
+    counter = 0
     while True:
-        node = rng_node(*args, **kwargs)
+        node = rng_node(*args, keys=(counter, seed), **kwargs)
         inserted = False
         for symbol in node.result_symbols:
             if not inserted:
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index fca052ef9..9d253ff7b 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -4,25 +4,12 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
 import sympy as sp
 
 from pystencils.assignment import Assignment
-from pystencils.astnodes import Node
+from pystencils.simp.simplifications import (
+    sort_assignments_topologically, sympy_cse_on_assignment_list,
+    transform_lhs_and_rhs, transform_rhs)
 from pystencils.sympyextensions import count_operations, fast_subs
 
 
-def transform_rhs(assignment_list, transformation, *args, **kwargs):
-    """Applies a transformation function on the rhs of each element of the passed assignment list
-    If the list also contains other object, like AST nodes, these are ignored.
-    Additional parameters are passed to the transformation function"""
-    return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a
-            for a in assignment_list]
-
-
-def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
-    return [Assignment(transformation(a.lhs, *args, **kwargs),
-                       transformation(a.rhs, *args, **kwargs))
-            if isinstance(a, Assignment) else a
-            for a in assignment_list]
-
-
 class AssignmentCollection:
     """
     A collection of equations with subexpression definitions, also represented as assignments,
@@ -98,9 +85,9 @@ class AssignmentCollection:
     def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
         """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
         if sort_subexpressions:
-            self.subexpressions = sort_assignments_topologically(self.subexpressions)
+            self.subexpressions = sympy_cse_on_assignment_list(self.subexpressions)
         if sort_main_assignments:
-            self.main_assignments = sort_assignments_topologically(self.main_assignments)
+            self.main_assignments = sympy_cse_on_assignment_list(self.main_assignments)
 
     # ---------------------------------------------- Properties  -------------------------------------------------------
 
@@ -419,22 +406,3 @@ class SymbolGen:
         name = "{}_{}".format(self._symbol, self._ctr)
         self._ctr += 1
         return sp.Symbol(name)
-
-
-def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
-    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
-    edges = []
-    for c1, e1 in enumerate(assignments):
-        if isinstance(e1, Assignment):
-            symbols = [e1.lhs]
-        elif isinstance(e1, Node):
-            symbols = e1.symbols_defined
-        else:
-            symbols = []
-        for lhs in symbols:
-            for c2, e2 in enumerate(assignments):
-                if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
-                    edges.append((c1, c2))
-                elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
-                    edges.append((c1, c2))
-    return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index 24726fd29..ab2b3d83d 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -1,16 +1,34 @@
-from typing import Callable, List
+from itertools import chain
+from typing import Callable, List, Sequence, Union
 
 import sympy as sp
 
 from pystencils.assignment import Assignment
+from pystencils.astnodes import Node
 from pystencils.field import AbstractField, Field
-from pystencils.simp.assignment_collection import AssignmentCollection, transform_rhs
 from pystencils.sympyextensions import subs_additive
 
-AC = AssignmentCollection
+
+def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
+    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
+    edges = []
+    for c1, e1 in enumerate(assignments):
+        if isinstance(e1, Assignment):
+            symbols = [e1.lhs]
+        elif isinstance(e1, Node):
+            symbols = e1.symbols_defined
+        else:
+            symbols = []
+        for lhs in symbols:
+            for c2, e2 in enumerate(assignments):
+                if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
+                    edges.append((c1, c2))
+                elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
+                    edges.append((c1, c2))
+    return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
 
 
-def sympy_cse(ac: AC) -> AC:
+def sympy_cse(ac):
     """Searches for common subexpressions inside the equation collection.
 
     Searches is done in both the existing subexpressions as well as the assignments themselves.
@@ -18,27 +36,28 @@ def sympy_cse(ac: AC) -> AC:
     with the additional subexpressions found
     """
     symbol_gen = ac.subexpression_symbol_generator
-    replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments,
-                                  symbols=symbol_gen)
+
+    all_assignments = [e for e in chain(ac.subexpressions, ac.main_assignments) if isinstance(e, Assignment)]
+    other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]
+    replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen)
+
     replacement_eqs = [Assignment(*r) for r in replacements]
 
     modified_subexpressions = new_eq[:len(ac.subexpressions)]
     modified_update_equations = new_eq[len(ac.subexpressions):]
 
-    new_subexpressions = replacement_eqs + modified_subexpressions
-    topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions])
-    new_subexpressions = [Assignment(a[0], a[1]) for a in topologically_sorted_pairs]
-
+    new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
     return ac.copy(modified_update_equations, new_subexpressions)
 
 
 def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
     """Extracts common subexpressions from a list of assignments."""
-    ec = AC([], assignments)
+    from pystencils.simp.assignment_collection import AssignmentCollection
+    ec = AssignmentCollection([], assignments)
     return sympy_cse(ec).all_assignments
 
 
-def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC:
+def subexpression_substitution_in_existing_subexpressions(ac):
     """Goes through the subexpressions list and replaces the term in the following subexpressions."""
     result = []
     for outer_ctr, s in enumerate(ac.subexpressions):
@@ -52,7 +71,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC:
     return ac.copy(ac.main_assignments, result)
 
 
-def subexpression_substitution_in_main_assignments(ac: AC) -> AC:
+def subexpression_substitution_in_main_assignments(ac):
     """Replaces already existing subexpressions in the equations of the assignment_collection."""
     result = []
     for s in ac.main_assignments:
@@ -63,7 +82,7 @@ def subexpression_substitution_in_main_assignments(ac: AC) -> AC:
     return ac.copy(result)
 
 
-def add_subexpressions_for_divisions(ac: AC) -> AC:
+def add_subexpressions_for_divisions(ac):
     r"""Introduces subexpressions for all divisions which have no constant in the denominator.
 
     For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.
@@ -87,7 +106,7 @@ def add_subexpressions_for_divisions(ac: AC) -> AC:
     return ac.new_with_substitutions(substitutions, True)
 
 
-def add_subexpressions_for_sums(ac: AC) -> AC:
+def add_subexpressions_for_sums(ac):
     r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions."""
     addends = []
 
@@ -114,7 +133,7 @@ def add_subexpressions_for_sums(ac: AC) -> AC:
     return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
 
 
-def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC:
+def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
     r"""Substitutes field accesses on rhs of assignments with subexpressions
 
     Can change semantics of the update rule (which is the goal of this transformation)
@@ -132,17 +151,36 @@ def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignm
     return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
 
 
-def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
+def transform_rhs(assignment_list, transformation, *args, **kwargs):
+    """Applies a transformation function on the rhs of each element of the passed assignment list
+    If the list also contains other object, like AST nodes, these are ignored.
+    Additional parameters are passed to the transformation function"""
+    return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a
+            for a in assignment_list]
+
+
+def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
+    return [Assignment(transformation(a.lhs, *args, **kwargs),
+                       transformation(a.rhs, *args, **kwargs))
+            if isinstance(a, Assignment) else a
+            for a in assignment_list]
+
+
+def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
     """Applies sympy expand operation to all equations in collection."""
-    def f(ac: AC) -> AC:
+
+    def f(ac):
         return ac.copy(transform_rhs(ac.main_assignments, operation))
+
     f.__name__ = operation.__name__
     return f
 
 
-def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
+def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
     """Applies the given operation on all subexpressions of the AC."""
-    def f(ac: AC) -> AC:
+
+    def f(ac):
         return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))
+
     f.__name__ = operation.__name__
     return f
-- 
GitLab