Commit eaeec78a authored by Martin Bauer's avatar Martin Bauer
Browse files

RNG: Possibility to pass seed and block offset parameters

parent d7ad30ff
import warnings
from collections import defaultdict from collections import defaultdict
import sympy as sp
import numpy as np import numpy as np
import sympy as sp
from pystencils.field import Field from pystencils.field import Field
from pystencils.sympyextensions import multidimensional_sum, prod from pystencils.sympyextensions import multidimensional_sum, prod
from pystencils.utils import LinearEquationSystem, fully_contains from pystencils.utils import LinearEquationSystem, fully_contains
import warnings
class FiniteDifferenceStencilDerivation: class FiniteDifferenceStencilDerivation:
"""Derives finite difference stencils. """Derives finite difference stencils.
......
...@@ -22,13 +22,31 @@ philox_float4({parameters}, ...@@ -22,13 +22,31 @@ philox_float4({parameters},
""" """
def _get_philox_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
for i in range(dim)] + list(keys)
while len(parameters) < 6:
parameters.append(0)
parameters = parameters[:6]
assert len(parameters) == 6
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return template.format(parameters=', '.join(str(p) for p in parameters),
result_symbols=result_symbols)
else:
raise NotImplementedError("Not yet implemented for this backend")
class PhiloxTwoDoubles(CustomCodeNode): class PhiloxTwoDoubles(CustomCodeNode):
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)): def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2)) self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2))
symbols_read = [s for s in keys if isinstance(s, sp.Symbol)] symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step self._time_step = time_step
self._offsets = offsets
self.headers = ['"philox_rand.h"'] self.headers = ['"philox_rand.h"']
self.keys = tuple(keys) self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys)) self._args = sp.sympify((dim, time_step, keys))
...@@ -40,7 +58,7 @@ class PhiloxTwoDoubles(CustomCodeNode): ...@@ -40,7 +58,7 @@ class PhiloxTwoDoubles(CustomCodeNode):
@property @property
def undefined_symbols(self): def undefined_symbols(self):
result = {a for a in self.args if isinstance(a, sp.Symbol)} result = {a for a in (self._time_step, *self._offsets, *self.keys) if isinstance(a, sp.Symbol)}
loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i) loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
for i in range(self._dim)] for i in range(self._dim)]
result.update(loop_counters) result.update(loop_counters)
...@@ -50,20 +68,8 @@ class PhiloxTwoDoubles(CustomCodeNode): ...@@ -50,20 +68,8 @@ class PhiloxTwoDoubles(CustomCodeNode):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set): def get_code(self, dialect, vector_instruction_set):
parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) return _get_philox_code(philox_two_doubles_call, dialect, vector_instruction_set,
for i in range(self._dim)] + list(self.keys) self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
while len(parameters) < 6:
parameters.append(0)
parameters = parameters[:6]
assert len(parameters) == 6
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return philox_two_doubles_call.format(parameters=', '.join(str(p) for p in parameters),
result_symbols=self.result_symbols)
else:
raise NotImplementedError("Not yet implemented for this backend")
def __repr__(self): def __repr__(self):
return "{}, {} <- PhiloxRNG".format(*self.result_symbols) return "{}, {} <- PhiloxRNG".format(*self.result_symbols)
...@@ -71,15 +77,15 @@ class PhiloxTwoDoubles(CustomCodeNode): ...@@ -71,15 +77,15 @@ class PhiloxTwoDoubles(CustomCodeNode):
class PhiloxFourFloats(CustomCodeNode): class PhiloxFourFloats(CustomCodeNode):
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)): def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4)) self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4))
symbols_read = [s for s in keys if isinstance(s, sp.Symbol)] symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step self._time_step = time_step
self._offsets = offsets
self.headers = ['"philox_rand.h"'] self.headers = ['"philox_rand.h"']
self.keys = tuple(keys) self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys)) self._args = sp.sympify((dim, time_step, offsets, keys))
self._dim = dim self._dim = dim
@property @property
...@@ -88,7 +94,7 @@ class PhiloxFourFloats(CustomCodeNode): ...@@ -88,7 +94,7 @@ class PhiloxFourFloats(CustomCodeNode):
@property @property
def undefined_symbols(self): def undefined_symbols(self):
result = {a for a in self.args if isinstance(a, sp.Symbol)} result = {a for a in (self._time_step, *self._offsets, *self.keys) if isinstance(a, sp.Symbol)}
loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i) loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
for i in range(self._dim)] for i in range(self._dim)]
result.update(loop_counters) result.update(loop_counters)
...@@ -98,28 +104,17 @@ class PhiloxFourFloats(CustomCodeNode): ...@@ -98,28 +104,17 @@ class PhiloxFourFloats(CustomCodeNode):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set): def get_code(self, dialect, vector_instruction_set):
parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) return _get_philox_code(philox_four_floats_call, dialect, vector_instruction_set,
for i in range(self._dim)] + list(self.keys) self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
while len(parameters) < 6:
parameters.append(0)
parameters = parameters[:6]
assert len(parameters) == 6
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return philox_four_floats_call.format(parameters=', '.join(str(p) for p in parameters),
result_symbols=self.result_symbols)
else:
raise NotImplementedError("Not yet implemented for this backend")
def __repr__(self): def __repr__(self):
return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols) return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols)
def random_symbol(assignment_list, rng_node=PhiloxTwoDoubles, *args, **kwargs): def random_symbol(assignment_list, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, *args, **kwargs):
counter = 0
while True: while True:
node = rng_node(*args, **kwargs) node = rng_node(*args, keys=(counter, seed), **kwargs)
inserted = False inserted = False
for symbol in node.result_symbols: for symbol in node.result_symbols:
if not inserted: if not inserted:
......
...@@ -4,25 +4,12 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, ...@@ -4,25 +4,12 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
import sympy as sp import sympy as sp
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Node from pystencils.simp.simplifications import (
sort_assignments_topologically, sympy_cse_on_assignment_list,
transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs from pystencils.sympyextensions import count_operations, fast_subs
def transform_rhs(assignment_list, transformation, *args, **kwargs):
"""Applies a transformation function on the rhs of each element of the passed assignment list
If the list also contains other object, like AST nodes, these are ignored.
Additional parameters are passed to the transformation function"""
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a
for a in assignment_list]
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
return [Assignment(transformation(a.lhs, *args, **kwargs),
transformation(a.rhs, *args, **kwargs))
if isinstance(a, Assignment) else a
for a in assignment_list]
class AssignmentCollection: class AssignmentCollection:
""" """
A collection of equations with subexpression definitions, also represented as assignments, A collection of equations with subexpression definitions, also represented as assignments,
...@@ -98,9 +85,9 @@ class AssignmentCollection: ...@@ -98,9 +85,9 @@ class AssignmentCollection:
def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
"""Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
if sort_subexpressions: if sort_subexpressions:
self.subexpressions = sort_assignments_topologically(self.subexpressions) self.subexpressions = sympy_cse_on_assignment_list(self.subexpressions)
if sort_main_assignments: if sort_main_assignments:
self.main_assignments = sort_assignments_topologically(self.main_assignments) self.main_assignments = sympy_cse_on_assignment_list(self.main_assignments)
# ---------------------------------------------- Properties ------------------------------------------------------- # ---------------------------------------------- Properties -------------------------------------------------------
...@@ -419,22 +406,3 @@ class SymbolGen: ...@@ -419,22 +406,3 @@ class SymbolGen:
name = "{}_{}".format(self._symbol, self._ctr) name = "{}_{}".format(self._symbol, self._ctr)
self._ctr += 1 self._ctr += 1
return sp.Symbol(name) return sp.Symbol(name)
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
for c1, e1 in enumerate(assignments):
if isinstance(e1, Assignment):
symbols = [e1.lhs]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
else:
symbols = []
for lhs in symbols:
for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
edges.append((c1, c2))
elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
edges.append((c1, c2))
return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
from typing import Callable, List from itertools import chain
from typing import Callable, List, Sequence, Union
import sympy as sp import sympy as sp
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.field import AbstractField, Field from pystencils.field import AbstractField, Field
from pystencils.simp.assignment_collection import AssignmentCollection, transform_rhs
from pystencils.sympyextensions import subs_additive from pystencils.sympyextensions import subs_additive
AC = AssignmentCollection
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
for c1, e1 in enumerate(assignments):
if isinstance(e1, Assignment):
symbols = [e1.lhs]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
else:
symbols = []
for lhs in symbols:
for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
edges.append((c1, c2))
elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
edges.append((c1, c2))
return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
def sympy_cse(ac: AC) -> AC: def sympy_cse(ac):
"""Searches for common subexpressions inside the equation collection. """Searches for common subexpressions inside the equation collection.
Searches is done in both the existing subexpressions as well as the assignments themselves. Searches is done in both the existing subexpressions as well as the assignments themselves.
...@@ -18,27 +36,28 @@ def sympy_cse(ac: AC) -> AC: ...@@ -18,27 +36,28 @@ def sympy_cse(ac: AC) -> AC:
with the additional subexpressions found with the additional subexpressions found
""" """
symbol_gen = ac.subexpression_symbol_generator symbol_gen = ac.subexpression_symbol_generator
replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments,
symbols=symbol_gen) all_assignments = [e for e in chain(ac.subexpressions, ac.main_assignments) if isinstance(e, Assignment)]
other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]
replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen)
replacement_eqs = [Assignment(*r) for r in replacements] replacement_eqs = [Assignment(*r) for r in replacements]
modified_subexpressions = new_eq[:len(ac.subexpressions)] modified_subexpressions = new_eq[:len(ac.subexpressions)]
modified_update_equations = new_eq[len(ac.subexpressions):] modified_update_equations = new_eq[len(ac.subexpressions):]
new_subexpressions = replacement_eqs + modified_subexpressions new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions])
new_subexpressions = [Assignment(a[0], a[1]) for a in topologically_sorted_pairs]
return ac.copy(modified_update_equations, new_subexpressions) return ac.copy(modified_update_equations, new_subexpressions)
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments.""" """Extracts common subexpressions from a list of assignments."""
ec = AC([], assignments) from pystencils.simp.assignment_collection import AssignmentCollection
ec = AssignmentCollection([], assignments)
return sympy_cse(ec).all_assignments return sympy_cse(ec).all_assignments
def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC: def subexpression_substitution_in_existing_subexpressions(ac):
"""Goes through the subexpressions list and replaces the term in the following subexpressions.""" """Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = [] result = []
for outer_ctr, s in enumerate(ac.subexpressions): for outer_ctr, s in enumerate(ac.subexpressions):
...@@ -52,7 +71,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC: ...@@ -52,7 +71,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC:
return ac.copy(ac.main_assignments, result) return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_main_assignments(ac: AC) -> AC: def subexpression_substitution_in_main_assignments(ac):
"""Replaces already existing subexpressions in the equations of the assignment_collection.""" """Replaces already existing subexpressions in the equations of the assignment_collection."""
result = [] result = []
for s in ac.main_assignments: for s in ac.main_assignments:
...@@ -63,7 +82,7 @@ def subexpression_substitution_in_main_assignments(ac: AC) -> AC: ...@@ -63,7 +82,7 @@ def subexpression_substitution_in_main_assignments(ac: AC) -> AC:
return ac.copy(result) return ac.copy(result)
def add_subexpressions_for_divisions(ac: AC) -> AC: def add_subexpressions_for_divisions(ac):
r"""Introduces subexpressions for all divisions which have no constant in the denominator. r"""Introduces subexpressions for all divisions which have no constant in the denominator.
For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced. For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.
...@@ -87,7 +106,7 @@ def add_subexpressions_for_divisions(ac: AC) -> AC: ...@@ -87,7 +106,7 @@ def add_subexpressions_for_divisions(ac: AC) -> AC:
return ac.new_with_substitutions(substitutions, True) return ac.new_with_substitutions(substitutions, True)
def add_subexpressions_for_sums(ac: AC) -> AC: def add_subexpressions_for_sums(ac):
r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.""" r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions."""
addends = [] addends = []
...@@ -114,7 +133,7 @@ def add_subexpressions_for_sums(ac: AC) -> AC: ...@@ -114,7 +133,7 @@ def add_subexpressions_for_sums(ac: AC) -> AC:
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC: def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
r"""Substitutes field accesses on rhs of assignments with subexpressions r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation) Can change semantics of the update rule (which is the goal of this transformation)
...@@ -132,17 +151,36 @@ def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignm ...@@ -132,17 +151,36 @@ def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignm
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False) return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: def transform_rhs(assignment_list, transformation, *args, **kwargs):
"""Applies a transformation function on the rhs of each element of the passed assignment list
If the list also contains other object, like AST nodes, these are ignored.
Additional parameters are passed to the transformation function"""
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a
for a in assignment_list]
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
return [Assignment(transformation(a.lhs, *args, **kwargs),
transformation(a.rhs, *args, **kwargs))
if isinstance(a, Assignment) else a
for a in assignment_list]
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies sympy expand operation to all equations in collection.""" """Applies sympy expand operation to all equations in collection."""
def f(ac: AC) -> AC:
def f(ac):
return ac.copy(transform_rhs(ac.main_assignments, operation)) return ac.copy(transform_rhs(ac.main_assignments, operation))
f.__name__ = operation.__name__ f.__name__ = operation.__name__
return f return f
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies the given operation on all subexpressions of the AC.""" """Applies the given operation on all subexpressions of the AC."""
def f(ac: AC) -> AC:
def f(ac):
return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation)) return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))
f.__name__ = operation.__name__ f.__name__ = operation.__name__
return f return f
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment