Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 1614 additions and 548 deletions
from typing import List
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.sympyextensions import is_constant
from pystencils.transformations import generic_visit
class PlaceholderFunction:
pass
def to_placeholder_function(expr, name):
"""Replaces an expression by a sympy function.
- replacing an expression with just a symbol would lead to problem when calculating derivatives
- placeholder functions get rid of this problem
Examples:
>>> x, t = sp.symbols("x, t")
>>> temperature = x**2 + t**4 # some 'complicated' dependency
>>> temperature_placeholder = to_placeholder_function(temperature, 'T')
>>> diffusivity = temperature_placeholder + 42 * t
>>> sp.diff(diffusivity, t) # returns a symbol instead of the computed derivative
_dT_dt + 42
>>> result, subexpr = remove_placeholder_functions(diffusivity)
>>> result
T + 42*t
>>> subexpr
[Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]
"""
symbols = list(expr.atoms(sp.Symbol))
symbols.sort(key=lambda e: e.name)
derivative_symbols = [sp.Symbol(f"_d{name}_d{s.name}") for s in symbols]
derivatives = [sp.diff(expr, s) for s in symbols]
assignments = [Assignment(sp.Symbol(name), expr)]
assignments += [Assignment(symbol, derivative)
for symbol, derivative in zip(derivative_symbols, derivatives)
if not is_constant(derivative)]
def fdiff(_, index):
result = derivatives[index - 1]
return result if is_constant(result) else derivative_symbols[index - 1]
func = type(name, (sp.Function, PlaceholderFunction),
{'fdiff': fdiff,
'value': sp.Symbol(name),
'subexpressions': assignments,
'nargs': len(symbols)})
return func(*symbols)
def remove_placeholder_functions(expr):
subexpressions = []
def visit(e):
if isinstance(e, Node):
return e
elif isinstance(e, PlaceholderFunction):
for se in e.subexpressions:
if se.lhs not in {a.lhs for a in subexpressions}:
subexpressions.append(se)
return e.value
else:
new_args = [visit(a) for a in e.args]
return e.func(*new_args) if new_args else e
return generic_visit(expr, visit), subexpressions
def prepend_placeholder_functions(assignments: List[Assignment]):
result, subexpressions = remove_placeholder_functions(assignments)
return subexpressions + result
...@@ -3,6 +3,9 @@ This module extends the pyplot module with functions to show scalar and vector f ...@@ -3,6 +3,9 @@ This module extends the pyplot module with functions to show scalar and vector f
simulation coordinate system (y-axis goes up), instead of the "image coordinate system" (y axis goes down) that simulation coordinate system (y-axis goes up), instead of the "image coordinate system" (y axis goes down) that
matplotlib normally uses. matplotlib normally uses.
""" """
import warnings
from itertools import cycle
from matplotlib.pyplot import * from matplotlib.pyplot import *
...@@ -65,6 +68,26 @@ def scalar_field(array, **kwargs): ...@@ -65,6 +68,26 @@ def scalar_field(array, **kwargs):
return res return res
def scalar_field_surface(array, **kwargs):
"""Plots scalar field as 3D surface
Args:
array: the two dimensional numpy array to plot
kwargs: keyword arguments passed to :func:`mpl_toolkits.mplot3d.Axes3D.plot_surface`
"""
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
fig = gcf()
ax = fig.add_subplot(111, projection='3d')
x, y = np.meshgrid(np.arange(array.shape[0]), np.arange(array.shape[1]), indexing='ij')
kwargs.setdefault('rstride', 2)
kwargs.setdefault('cstride', 2)
kwargs.setdefault('color', 'b')
kwargs.setdefault('cmap', cm.coolwarm)
return ax.plot_surface(x, y, array, **kwargs)
def scalar_field_alpha_value(array, color, clip=False, **kwargs): def scalar_field_alpha_value(array, color, clip=False, **kwargs):
"""Plots an image with same color everywhere, using the array values as transparency. """Plots an image with same color everywhere, using the array values as transparency.
...@@ -135,6 +158,28 @@ def multiple_scalar_fields(array, **kwargs): ...@@ -135,6 +158,28 @@ def multiple_scalar_fields(array, **kwargs):
colorbar() colorbar()
def phase_plot(phase_field: np.ndarray, linewidth=1.0, clip=True) -> None:
"""Plots a phase field array using the phase variables as alpha channel.
Args:
phase_field: array with len(shape) == 3, first two dimensions are spatial, the last one indexes the phase
components.
linewidth: line width of the 0.5 contour lines that are drawn over the alpha blended phase images
clip: see scalar_field_alpha_value function
"""
color_cycle = cycle(['#fe0002', '#00fe00', '#0000ff', '#ffa800', '#f600ff'])
assert len(phase_field.shape) == 3
with warnings.catch_warnings():
warnings.simplefilter("ignore")
for i in range(phase_field.shape[-1]):
scalar_field_alpha_value(phase_field[..., i], next(color_cycle), clip=clip, interpolation='bilinear')
if linewidth:
for i in range(phase_field.shape[-1]):
scalar_field_contour(phase_field[..., i], levels=[0.5], colors='k', linewidths=[linewidth])
def sympy_function(expr, x_values=None, **kwargs): def sympy_function(expr, x_values=None, **kwargs):
"""Plots the graph of a sympy term that depends on one symbol only. """Plots the graph of a sympy term that depends on one symbol only.
...@@ -284,14 +329,12 @@ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, re ...@@ -284,14 +329,12 @@ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, re
return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames) return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)
def surface_plot_animation(run_function, frames=90, interval=30, **kwargs): def surface_plot_animation(run_function, frames=90, interval=30, zlim=None, **kwargs):
"""Animation of scalar field as 3D plot.""" """Animation of scalar field as 3D plot."""
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation import matplotlib.animation as animation
import matplotlib.pyplot as plt
from matplotlib import cm from matplotlib import cm
fig = gcf()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d') ax = fig.add_subplot(111, projection='3d')
data = run_function() data = run_function()
x, y = np.meshgrid(np.arange(data.shape[0]), np.arange(data.shape[1]), indexing='ij') x, y = np.meshgrid(np.arange(data.shape[0]), np.arange(data.shape[1]), indexing='ij')
...@@ -300,13 +343,15 @@ def surface_plot_animation(run_function, frames=90, interval=30, **kwargs): ...@@ -300,13 +343,15 @@ def surface_plot_animation(run_function, frames=90, interval=30, **kwargs):
kwargs.setdefault('color', 'b') kwargs.setdefault('color', 'b')
kwargs.setdefault('cmap', cm.coolwarm) kwargs.setdefault('cmap', cm.coolwarm)
ax.plot_surface(x, y, data, **kwargs) ax.plot_surface(x, y, data, **kwargs)
ax.set_zlim(-1.0, 1.0) if zlim is not None:
ax.set_zlim(*zlim)
def update_figure(*_): def update_figure(*_):
d = run_function() d = run_function()
ax.clear() ax.clear()
plot = ax.plot_surface(x, y, d, **kwargs) plot = ax.plot_surface(x, y, d, **kwargs)
ax.set_zlim(-1.0, 1.0) if zlim is not None:
ax.set_zlim(*zlim)
return plot, return plot,
return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames, blit=False) return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames, blit=False)
import copy
import numpy as np
import sympy as sp
from pystencils.typing import TypedSymbol, CastFunc
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode
from pystencils.sympyextensions import fast_subs
class RNGBase(CustomCodeNode):
id = 0
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=None, keys=None):
if keys is None:
keys = (0,) * self._num_keys
if offsets is None:
offsets = (0,) * dim
if len(keys) != self._num_keys:
raise ValueError(f"Provided {len(keys)} keys but need {self._num_keys}")
if len(offsets) != dim:
raise ValueError(f"Provided {len(offsets)} offsets but need {dim}")
coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
if dim < 3:
coordinates.append(0)
self._args = sp.sympify([time_step, *coordinates, *keys])
self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type)
for i in range(self._num_vars))
symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args])
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self.headers = [f'"{self._name.split("_")[0]}_rand.h"']
RNGBase.id += 1
@property
def args(self):
return self._args
def fast_subs(self, subs_dict, skip):
rng = copy.deepcopy(self)
rng._args = [fast_subs(a, subs_dict, skip) for a in rng._args]
return rng
def get_code(self, dialect, vector_instruction_set, print_arg):
code = "\n"
for r in self.result_symbols:
if vector_instruction_set and not self.args[1].atoms(CastFunc):
# this vector RNG has become scalar through substitution
code += f"{r.dtype} {r.name};\n"
else:
code += f"{vector_instruction_set[r.dtype.c_name] if vector_instruction_set else r.dtype} " + \
f"{r.name};\n"
args = [print_arg(a) for a in self.args] + ['' + r.name for r in self.result_symbols]
code += (self._name + "(" + ", ".join(args) + ");\n")
return code
def __repr__(self):
return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \
self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")"
def _hashable_content(self):
return (self._name, *self.result_symbols, *self.args)
def __eq__(self, other):
return type(self) is type(other) and self._hashable_content() == other._hashable_content()
def __hash__(self):
return hash(self._hashable_content())
class PhiloxTwoDoubles(RNGBase):
_name = "philox_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 2
class PhiloxFourFloats(RNGBase):
_name = "philox_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 2
class AESNITwoDoubles(RNGBase):
_name = "aesni_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 4
class AESNIFourFloats(RNGBase):
_name = "aesni_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 4
def random_symbol(assignment_list, dim, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles,
time_step=TypedSymbol("time_step", np.uint32), offsets=None):
"""Return a symbol generator for random numbers
Args:
assignment_list: the subexpressions member of an AssignmentCollection, into which helper variables assignments
will be inserted
dim: 2 or 3 for two or three spatial dimensions
seed: an integer or TypedSymbol(..., np.uint32) to seed the random number generator. If you create multiple
symbol generators, please pass them different seeds so you don't get the same stream of random numbers!
rng_node: which random number generator to use (PhiloxTwoDoubles, PhiloxFourFloats, AESNITwoDoubles,
AESNIFourFloats).
time_step: TypedSymbol(..., np.uint32) that indicates the number of the current time step
offsets: tuple of offsets (constant integers or TypedSymbol(..., np.uint32)) that give the global coordinates
of the local origin
"""
counter = 0
while True:
keys = (counter, seed) + (0,) * (rng_node._num_keys - 2)
node = rng_node(dim, keys=keys, time_step=time_step, offsets=offsets)
inserted = False
for symbol in node.result_symbols:
if not inserted:
assignment_list.insert(0, node)
inserted = True
yield symbol
counter += 1
import time
import socket import socket
from typing import Dict, Sequence, Iterator import time
from types import MappingProxyType
from typing import Dict, Iterator, Sequence
import blitzdb import blitzdb
import six
from blitzdb.backends.file.backend import serializer_classes
from blitzdb.backends.file.utils import JsonEncoder
from pystencils.cpu.cpujit import get_compiler_config from pystencils.cpu.cpujit import get_compiler_config
from pystencils import CreateKernelConfig, Target, Backend, Field
import json
import sympy as sp
from pystencils.typing import BasicType
class PystencilsJsonEncoder(JsonEncoder):
def default(self, obj):
if isinstance(obj, CreateKernelConfig):
return obj.__dict__
if isinstance(obj, (sp.Float, sp.Rational)):
return float(obj)
if isinstance(obj, sp.Integer):
return int(obj)
if isinstance(obj, (BasicType, MappingProxyType)):
return str(obj)
if isinstance(obj, (Target, Backend, sp.Symbol)):
return obj.name
if isinstance(obj, Field):
return f"pystencils.Field(name = {obj.name}, field_type = {obj.field_type.name}, " \
f"dtype = {str(obj.dtype)}, layout = {obj.layout}, shape = {obj.shape}, " \
f"strides = {obj.strides})"
return JsonEncoder.default(self, obj)
class PystencilsJsonSerializer(object):
@classmethod
def serialize(cls, data):
if six.PY3:
if isinstance(data, bytes):
return json.dumps(data.decode('utf-8'), cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
@classmethod
def deserialize(cls, data):
if six.PY3:
return json.loads(data.decode('utf-8'))
else:
return json.loads(data.decode('utf-8'))
class Database: class Database:
...@@ -33,16 +85,18 @@ class Database: ...@@ -33,16 +85,18 @@ class Database:
... assert next(db.filter_params(params))['params'] == params # get data set, keys are 'params', 'results' ... assert next(db.filter_params(params))['params'] == params # get data set, keys are 'params', 'results'
... # and 'env' ... # and 'env'
... # get a pandas object with all results matching a query ... # get a pandas object with all results matching a query
... db.to_pandas({'dx': 1.5}, remove_prefix=True) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE ... df = db.to_pandas({'dx': 1.5}, remove_prefix=True)
dx method error ... # order columns alphabetically (just for doctest output)
pk ... df.reindex(sorted(df.columns), axis=1) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
... 1.5 finite_diff 0.000001 dx error method
pk
... 1.5 0.000001 finite_diff
""" """
class SimulationResult(blitzdb.Document): class SimulationResult(blitzdb.Document):
pass pass
def __init__(self, file: str) -> None: def __init__(self, file: str, serializer_info: tuple = None) -> None:
if file.startswith("mongo://"): if file.startswith("mongo://"):
from pymongo import MongoClient from pymongo import MongoClient
db_name = file[len("mongo://"):] db_name = file[len("mongo://"):]
...@@ -53,6 +107,10 @@ class Database: ...@@ -53,6 +107,10 @@ class Database:
self.backend.autocommit = True self.backend.autocommit = True
if serializer_info:
serializer_classes.update({serializer_info[0]: serializer_info[1]})
self.backend.load_config({'serializer_class': serializer_info[0]}, True)
def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None: def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None:
"""Stores a simulation result in the database. """Stores a simulation result in the database.
...@@ -116,7 +174,7 @@ class Database: ...@@ -116,7 +174,7 @@ class Database:
Returns: Returns:
pandas data frame pandas data frame
""" """
from pandas.io.json import json_normalize from pandas import json_normalize
query_result = self.filter_params(parameter_query) query_result = self.filter_params(parameter_query)
attributes = [e.attributes for e in query_result] attributes = [e.attributes for e in query_result]
...@@ -142,10 +200,15 @@ class Database: ...@@ -142,10 +200,15 @@ class Database:
'cpuCompilerConfig': get_compiler_config(), 'cpuCompilerConfig': get_compiler_config(),
} }
try: try:
from git import Repo, InvalidGitRepositoryError from git import Repo
except ImportError:
return result
try:
from git import InvalidGitRepositoryError
repo = Repo(search_parent_directories=True) repo = Repo(search_parent_directories=True)
result['git_hash'] = str(repo.head.commit) result['git_hash'] = str(repo.head.commit)
except (ImportError, InvalidGitRepositoryError): except InvalidGitRepositoryError:
pass pass
return result return result
......
import json
import datetime import datetime
import itertools
import json
import os import os
import socket import socket
import itertools
from copy import deepcopy
from collections import namedtuple from collections import namedtuple
from copy import deepcopy
from time import sleep from time import sleep
from typing import Dict, Callable, Sequence, Any, Tuple, Optional from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from pystencils.runhelper import Database from pystencils.runhelper import Database
from pystencils.runhelper.db import PystencilsJsonSerializer
from pystencils.utils import DotDict from pystencils.utils import DotDict
ParameterDict = Dict[str, Any] ParameterDict = Dict[str, Any]
WeightFunction = Callable[[Dict], int] WeightFunction = Callable[[Dict], int]
FilterFunction = Callable[[ParameterDict], Optional[ParameterDict]] FilterFunction = Callable[[ParameterDict], Optional[ParameterDict]]
...@@ -54,10 +55,11 @@ class ParameterStudy: ...@@ -54,10 +55,11 @@ class ParameterStudy:
Run = namedtuple("Run", ['parameter_dict', 'weight']) Run = namedtuple("Run", ['parameter_dict', 'weight'])
def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (), def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (),
database_connector: str = './db') -> None: database_connector: str = './db',
serializer_info: tuple = ('pystencils_serializer', PystencilsJsonSerializer)) -> None:
self.runs = list(runs) self.runs = list(runs)
self.run_function = run_function self.run_function = run_function
self.db = Database(database_connector) self.db = Database(database_connector, serializer_info)
def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None: def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None:
"""Schedule a dictionary of parameters to run in this parameter study. """Schedule a dictionary of parameters to run in this parameter study.
...@@ -215,7 +217,7 @@ class ParameterStudy: ...@@ -215,7 +217,7 @@ class ParameterStudy:
def log_message(self, fmt, *args): def log_message(self, fmt, *args):
return return
print("Listening to connections on {}:{}. Scenarios to simulate: {}".format(ip, port, len(filtered_runs))) print(f"Listening to connections on {ip}:{port}. Scenarios to simulate: {len(filtered_runs)}")
server = HTTPServer((ip, port), ParameterStudyServer) server = HTTPServer((ip, port), ParameterStudyServer)
while len(ParameterStudyServer.currently_running) > 0 or len(ParameterStudyServer.runs) > 0: while len(ParameterStudyServer.currently_running) > 0 or len(ParameterStudyServer.runs) > 0:
server.handle_request() server.handle_request()
...@@ -241,7 +243,7 @@ class ParameterStudy: ...@@ -241,7 +243,7 @@ class ParameterStudy:
from urllib.error import URLError from urllib.error import URLError
import time import time
parameter_update = {} if parameter_update is None else parameter_update parameter_update = {} if parameter_update is None else parameter_update
url = "http://{}:{}".format(server, port) url = f"http://{server}:{port}"
client_name = client_name.format(hostname=socket.gethostname(), pid=os.getpid()) client_name = client_name.format(hostname=socket.gethostname(), pid=os.getpid())
start_time = time.time() start_time = time.time()
while True: while True:
...@@ -265,7 +267,7 @@ class ParameterStudy: ...@@ -265,7 +267,7 @@ class ParameterStudy:
'client_name': client_name} 'client_name': client_name}
urlopen(url + '/result', data=json.dumps(answer).encode()) urlopen(url + '/result', data=json.dumps(answer).encode())
except URLError: except URLError:
print("Cannot connect to server {} retrying in 5 seconds...".format(url)) print(f"Cannot connect to server {url} retrying in 5 seconds...")
sleep(5) sleep(5)
def run_from_command_line(self, argv: Optional[Sequence[str]] = None) -> None: def run_from_command_line(self, argv: Optional[Sequence[str]] = None) -> None:
......
import numpy as np
import sympy as sp
import pystencils as ps
from pystencils.jupyter import make_imshow_animation, display_animation, set_display_mode
import pystencils.plot as plt
__all__ = ['sp', 'np', 'ps', 'plt', 'make_imshow_animation', 'display_animation', 'set_display_mode']
from .assignment_collection import AssignmentCollection from .assignment_collection import AssignmentCollection
from .simplifications import (
add_subexpressions_for_constants,
add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
from .subexpression_insertion import (
insert_aliases, insert_zeros, insert_constants,
insert_constant_additions, insert_constant_multiples,
insert_squares, insert_symbol_times_minus_one)
from .simplificationstrategy import SimplificationStrategy from .simplificationstrategy import SimplificationStrategy
from .simplifications import sympy_cse, sympy_cse_on_assignment_list, \
apply_to_all_assignments, apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions, \
subexpression_substitution_in_main_assignments, add_subexpressions_for_divisions, add_subexpressions_for_field_reads
__all__ = ['AssignmentCollection', 'SimplificationStrategy', __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions', 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
'add_subexpressions_for_field_reads'] 'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads',
'insert_aliases', 'insert_zeros', 'insert_constants',
'insert_constant_additions', 'insert_constant_multiples',
'insert_squares', 'insert_symbol_times_minus_one']
import sympy as sp import itertools
from copy import copy from copy import copy
from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable, Union from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp
import pystencils
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.sympyextensions import fast_subs, count_operations, sort_assignments_topologically from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs
class AssignmentCollection: class AssignmentCollection:
...@@ -12,15 +17,16 @@ class AssignmentCollection: ...@@ -12,15 +17,16 @@ class AssignmentCollection:
These simplification methods can change the subexpressions, but the number and These simplification methods can change the subexpressions, but the number and
left hand side of the main equations themselves is not altered. left hand side of the main equations themselves is not altered.
Additionally a dictionary of simplification hints is stored, which are set by the functions that create Additionally a dictionary of simplification hints is stored, which are set by the functions that create
equation collections to transport information to the simplification system. assignment collections to transport information to the simplification system.
Attributes: Args:
main_assignments: list of assignments main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
subexpressions: list of assignments defining subexpressions used in main equations assignment is a field access. Thus the generated equations write on arrays.
simplification_hints: dict that is used to annotate the equation collection with hints that are subexpressions: List of assignments defining subexpressions used in main equations
simplification_hints: Dict that is used to annotate the assignment collection with hints that are
used by the simplification system. See documentation of the simplification rules for used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning. potentially required hints and their meaning.
subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
used to get new symbols that are unique for this AssignmentCollection used to get new symbols that are unique for this AssignmentCollection
""" """
...@@ -28,9 +34,13 @@ class AssignmentCollection: ...@@ -28,9 +34,13 @@ class AssignmentCollection:
# ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
simplification_hints: Optional[Dict[str, Any]] = None, simplification_hints: Optional[Dict[str, Any]] = None,
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
if subexpressions is None:
subexpressions = {}
if isinstance(main_assignments, Dict): if isinstance(main_assignments, Dict):
main_assignments = [Assignment(k, v) main_assignments = [Assignment(k, v)
for k, v in main_assignments.items()] for k, v in main_assignments.items()]
...@@ -38,6 +48,11 @@ class AssignmentCollection: ...@@ -38,6 +48,11 @@ class AssignmentCollection:
subexpressions = [Assignment(k, v) subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()] for k, v in subexpressions.items()]
main_assignments = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
self.main_assignments = main_assignments self.main_assignments = main_assignments
self.subexpressions = subexpressions self.subexpressions = subexpressions
...@@ -46,8 +61,11 @@ class AssignmentCollection: ...@@ -46,8 +61,11 @@ class AssignmentCollection:
self.simplification_hints = simplification_hints self.simplification_hints = simplification_hints
ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
if subexpression_symbol_generator is None: if subexpression_symbol_generator is None:
self.subexpression_symbol_generator = SymbolGen() self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
else: else:
self.subexpression_symbol_generator = subexpression_symbol_generator self.subexpression_symbol_generator = subexpression_symbol_generator
...@@ -91,32 +109,70 @@ class AssignmentCollection: ...@@ -91,32 +109,70 @@ class AssignmentCollection:
"""Subexpression and main equations as a single list.""" """Subexpression and main equations as a single list."""
return self.subexpressions + self.main_assignments return self.subexpressions + self.main_assignments
@property
def rhs_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which occur on the rhs of any assignment."""
rhs_symbols = set()
for eq in self.all_assignments:
if isinstance(eq, Assignment):
rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
elif isinstance(eq, pystencils.astnodes.Node):
rhs_symbols.update(eq.undefined_symbols)
return rhs_symbols
@property @property
def free_symbols(self) -> Set[sp.Symbol]: def free_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
free_symbols = set() return self.rhs_symbols - self.bound_symbols
for eq in self.all_assignments:
free_symbols.update(eq.rhs.atoms(sp.Symbol))
return free_symbols - self.bound_symbols
@property @property
def bound_symbols(self) -> Set[sp.Symbol]: def bound_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression.""" """All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set = set([eq.lhs for eq in self.all_assignments]) bound_symbols_set = set(
assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \ [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
)
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times" "Not in SSA form - same symbol assigned multiple times"
bound_symbols_set = bound_symbols_set.union(*[
assignment.symbols_defined for assignment in self.all_assignments
if isinstance(assignment, pystencils.astnodes.Node)
])
return bound_symbols_set return bound_symbols_set
@property
def rhs_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
@property
def free_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.free_symbols if hasattr(s, 'field')}
@property
def bound_fields(self):
"""All field accessed on the left hand side of a main assignment or a subexpression."""
return {s.field for s in self.bound_symbols if hasattr(s, 'field')}
@property @property
def defined_symbols(self) -> Set[sp.Symbol]: def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations""" """All symbols which occur as left-hand-sides of one of the main equations"""
return set([assignment.lhs for assignment in self.main_assignments]) lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
if isinstance(assignment, pystencils.astnodes.Node)]))
@property @property
def operation_count(self): def operation_count(self):
"""See :func:`count_operations` """ """See :func:`count_operations` """
return count_operations(self.all_assignments, only_type=None) return count_operations(self.all_assignments, only_type=None)
def atoms(self, *args):
return set().union(*[a.atoms(*args) for a in self.all_assignments])
def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]: def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
"""Returns all symbols that depend on one of the passed symbols. """Returns all symbols that depend on one of the passed symbols.
...@@ -168,6 +224,7 @@ class AssignmentCollection: ...@@ -168,6 +224,7 @@ class AssignmentCollection:
return {s: func(*args, **kwargs) for s, func in lambdas.items()} return {s: func(*args, **kwargs) for s, func in lambdas.items()}
return f return f
# ---------------------------- Creating new modified collections --------------------------------------------------- # ---------------------------- Creating new modified collections ---------------------------------------------------
def copy(self, def copy(self,
...@@ -192,35 +249,36 @@ class AssignmentCollection: ...@@ -192,35 +249,36 @@ class AssignmentCollection:
return res return res
def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False, def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
substitute_on_lhs: bool = True) -> 'AssignmentCollection': substitute_on_lhs: bool = True,
sort_topologically: bool = True) -> 'AssignmentCollection':
"""Returns new object, where terms are substituted according to the passed substitution dict. """Returns new object, where terms are substituted according to the passed substitution dict.
Args: Args:
substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
sort_topologically: if subexpressions are added as substitutions and this parameters is true,
the subexpressions are sorted topologically after insertion
Returns: Returns:
New AssignmentCollection where substitutions have been applied, self is not altered. New AssignmentCollection where substitutions have been applied, self is not altered.
""" """
if substitute_on_lhs: transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
new_subexpressions = [fast_subs(eq, substitutions) for eq in self.subexpressions] transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
new_equations = [fast_subs(eq, substitutions) for eq in self.main_assignments] transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
else:
new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.subexpressions]
new_equations = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.main_assignments]
if add_substitutions_as_subexpressions: if add_substitutions_as_subexpressions:
new_subexpressions = [Assignment(b, a) for a, b in substitutions.items()] + new_subexpressions transformed_subexpressions = [Assignment(b, a) for a, b in
new_subexpressions = sort_assignments_topologically(new_subexpressions) substitutions.items()] + transformed_subexpressions
return self.copy(new_equations, new_subexpressions) if sort_topologically:
transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
return self.copy(transformed_assignments, transformed_subexpressions)
def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
"""Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
own_definitions = set([e.lhs for e in self.main_assignments]) own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments]) other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \ assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot new_merged, since both collection define the same symbols" "Cannot merge collections, since both define the same symbols"
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {} substitution_dict = {}
...@@ -228,12 +286,13 @@ class AssignmentCollection: ...@@ -228,12 +286,13 @@ class AssignmentCollection:
processed_other_subexpression_equations = [] processed_other_subexpression_equations = []
for other_subexpression_eq in other.subexpressions: for other_subexpression_eq in other.subexpressions:
if other_subexpression_eq.lhs in own_subexpression_symbols: if other_subexpression_eq.lhs in own_subexpression_symbols:
if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
continue # exact the same subexpression equation exists already continue # exact the same subexpression equation exists already
else: else:
# different definition - a new name has to be introduced # different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator) new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict)) new_eq = Assignment(new_lhs, new_rhs)
processed_other_subexpression_equations.append(new_eq) processed_other_subexpression_equations.append(new_eq)
substitution_dict[other_subexpression_eq.lhs] = new_lhs substitution_dict[other_subexpression_eq.lhs] = new_lhs
else: else:
...@@ -256,9 +315,9 @@ class AssignmentCollection: ...@@ -256,9 +315,9 @@ class AssignmentCollection:
if eq.lhs in symbols_to_extract: if eq.lhs in symbols_to_extract:
new_assignments.append(eq) new_assignments.append(eq)
new_sub_expr = [eq for eq in self.subexpressions new_sub_expr = [eq for eq in self.all_assignments
if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
return AssignmentCollection(new_assignments, new_sub_expr) return self.copy(new_assignments, new_sub_expr)
def new_without_unused_subexpressions(self) -> 'AssignmentCollection': def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
"""Returns new collection that only contains subexpressions required to compute the main assignments.""" """Returns new collection that only contains subexpressions required to compute the main assignments."""
...@@ -281,8 +340,10 @@ class AssignmentCollection: ...@@ -281,8 +340,10 @@ class AssignmentCollection:
new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
return self.copy(new_eqs, new_subexpressions) return self.copy(new_eqs, new_subexpressions)
def new_without_subexpressions(self, subexpressions_to_keep: Set[sp.Symbol] = set()) -> 'AssignmentCollection': def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
"""Returns a new collection where all subexpressions have been inserted.""" """Returns a new collection where all subexpressions have been inserted."""
if subexpressions_to_keep is None:
subexpressions_to_keep = set()
if len(self.subexpressions) == 0: if len(self.subexpressions) == 0:
return self.copy() return self.copy()
...@@ -291,7 +352,7 @@ class AssignmentCollection: ...@@ -291,7 +352,7 @@ class AssignmentCollection:
kept_subexpressions = [] kept_subexpressions = []
if self.subexpressions[0].lhs in subexpressions_to_keep: if self.subexpressions[0].lhs in subexpressions_to_keep:
substitution_dict = {} substitution_dict = {}
kept_subexpressions = self.subexpressions[0] kept_subexpressions.append(self.subexpressions[0])
else: else:
substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
...@@ -310,6 +371,7 @@ class AssignmentCollection: ...@@ -310,6 +371,7 @@ class AssignmentCollection:
def _repr_html_(self): def _repr_html_(self):
"""Interface to Jupyter notebook, to display as a nicely formatted HTML table""" """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
def make_html_equation_table(equations): def make_html_equation_table(equations):
no_border = 'style="border:none"' no_border = 'style="border:none"'
html_table = '<table style="border:none; width: 100%; ">' html_table = '<table style="border:none; width: 100%; ">'
...@@ -330,19 +392,19 @@ class AssignmentCollection: ...@@ -330,19 +392,19 @@ class AssignmentCollection:
return result return result
def __repr__(self): def __repr__(self):
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments]) return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
def __str__(self): def __str__(self):
result = "Subexpressions:\n" result = "Subexpressions:\n"
for eq in self.subexpressions: for eq in self.subexpressions:
result += "\t{eq}\n".format(eq=eq) result += f"\t{eq}\n"
result += "Main Assignments:\n" result += "Main Assignments:\n"
for eq in self.main_assignments: for eq in self.main_assignments:
result += "\t{eq}\n".format(eq=eq) result += f"\t{eq}\n"
return result return result
def __iter__(self): def __iter__(self):
return self.main_assignments.__iter__() return self.all_assignments.__iter__()
@property @property
def main_assignments_dict(self): def main_assignments_dict(self):
...@@ -357,21 +419,55 @@ class AssignmentCollection: ...@@ -357,21 +419,55 @@ class AssignmentCollection:
for k, v in main_assignments_dict.items()] for k, v in main_assignments_dict.items()]
def set_sub_expressions_from_dict(self, sub_expressions_dict): def set_sub_expressions_from_dict(self, sub_expressions_dict):
self.sub_expressions = [Assignment(k, v) self.subexpressions = [Assignment(k, v)
for k, v in sub_expressions_dict.items()] for k, v in sub_expressions_dict.items()]
def find(self, *args, **kwargs):
return set.union(
*[a.find(*args, **kwargs) for a in self.all_assignments]
)
def match(self, *args, **kwargs):
rtn = {}
for a in self.all_assignments:
partial_result = a.match(*args, **kwargs)
if partial_result:
rtn.update(partial_result)
return rtn
def subs(self, *args, **kwargs):
return AssignmentCollection(
main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
)
def replace(self, *args, **kwargs):
return AssignmentCollection(
main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
)
def __eq__(self, other):
return set(self.all_assignments) == set(other.all_assignments)
def __bool__(self):
return bool(self.all_assignments)
class SymbolGen: class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ...""" """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi"): def __init__(self, symbol="xi", dtype=None, ctr=0):
self._ctr = 0 self._ctr = ctr
self._symbol = symbol self._symbol = symbol
self._dtype = dtype
def __iter__(self): def __iter__(self):
return self return self
def __next__(self): def __next__(self):
name = "{}_{}".format(self._symbol, self._ctr) name = f"{self._symbol}_{self._ctr}"
self._ctr += 1 self._ctr += 1
if self._dtype is not None:
return pystencils.TypedSymbol(name, self._dtype)
return sp.Symbol(name) return sp.Symbol(name)
from itertools import chain
from typing import Callable, List, Sequence, Union
from collections import defaultdict
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.field import Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.typing import TypedSymbol
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
for c1, e1 in enumerate(assignments):
if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'):
symbols = [e1.lhs]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
else:
raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
for lhs in symbols:
for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
edges.append((c1, c2))
elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
edges.append((c1, c2))
return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
def sympy_cse(ac, **kwargs):
"""Searches for common subexpressions inside the assignment collection.
Searches is done in both the existing subexpressions as well as the assignments themselves.
It uses the sympy subexpression detection to do this. Return a new assignment collection
with the additional subexpressions found
"""
symbol_gen = ac.subexpression_symbol_generator
all_assignments = [e for e in chain(ac.subexpressions, ac.main_assignments) if isinstance(e, Assignment)]
other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]
replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen, **kwargs)
replacement_eqs = [Assignment(*r) for r in replacements]
modified_subexpressions = new_eq[:len(ac.subexpressions)]
modified_update_equations = new_eq[len(ac.subexpressions):]
new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
return ac.copy(modified_update_equations, new_subexpressions)
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
from pystencils.simp.assignment_collection import AssignmentCollection
ec = AssignmentCollection([], assignments)
return sympy_cse(ec).all_assignments
def subexpression_substitution_in_existing_subexpressions(ac):
"""Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = []
for outer_ctr, s in enumerate(ac.subexpressions):
new_rhs = s.rhs
for inner_ctr in range(outer_ctr):
sub_expr = ac.subexpressions[inner_ctr]
new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs)
result.append(Assignment(s.lhs, new_rhs))
return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_main_assignments(ac):
"""Replaces already existing subexpressions in the equations of the assignment_collection."""
result = []
for s in ac.main_assignments:
new_rhs = s.rhs
for sub_expr in ac.subexpressions:
new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
result.append(Assignment(s.lhs, new_rhs))
return ac.copy(result)
def add_subexpressions_for_constants(ac):
"""Extracts constant factors to subexpressions in the given assignment collection.
SymPy will exclude common factors from a sum only if they are symbols. This simplification
can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,
the number of multiplications is reduced and in some cases, more common subexpressions can be found.
"""
constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))
def visit(expr):
args = list(expr.args)
if len(args) == 0:
return expr
if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul):
for i, arg in enumerate(args):
if is_constant(arg) and abs(arg) != 1:
if arg < 0:
args[i] = - constants_to_subexp_dict[- arg]
else:
args[i] = constants_to_subexp_dict[arg]
return expr.func(*(visit(a) for a in args))
main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]
subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]
symbols_to_collect = set(constants_to_subexp_dict.values())
main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments]
subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions]
subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions
return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions)
def add_subexpressions_for_divisions(ac):
r"""Introduces subexpressions for all divisions which have no constant in the denominator.
For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.
"""
divisors = set()
def search_divisors(term):
if term.func == sp.Pow:
if term.exp.is_integer and term.exp.is_number and term.exp < 0:
divisors.add(term)
else:
for a in term.args:
search_divisors(a)
for eq in ac.all_assignments:
search_divisors(eq.rhs)
divisors = sorted(list(divisors), key=lambda x: str(x))
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
def add_subexpressions_for_sums(ac):
r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions."""
addends = []
def contains_sum(term):
if term.func == sp.Add:
return True
if term.is_Atom:
return False
return any([contains_sum(a) for a in term.args])
def search_addends(term):
if term.func == sp.Add:
if all([not contains_sum(a) for a in term.args]):
addends.extend(term.args)
for a in term.args:
search_addends(a)
for eq in ac.all_assignments:
search_addends(eq.rhs)
addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, Field.Access)]
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels
"""
field_reads = set()
to_iterate = []
if subexpressions:
to_iterate = chain(to_iterate, ac.subexpressions)
if main_assignments:
to_iterate = chain(to_iterate, ac.main_assignments)
for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access))
if not field_reads:
return ac
substitutions = dict()
for fa in field_reads:
lhs = next(ac.subexpression_symbol_generator)
if data_type is not None:
substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
else:
substitutions.update({fa: lhs})
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False)
def transform_rhs(assignment_list, transformation, *args, **kwargs):
"""Applies a transformation function on the rhs of each element of the passed assignment list
If the list also contains other object, like AST nodes, these are ignored.
Additional parameters are passed to the transformation function"""
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list]
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
return [Assignment(transformation(a.lhs, *args, **kwargs),
transformation(a.rhs, *args, **kwargs))
if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list]
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies a given operation to all equations in collection."""
def f(ac):
return ac.copy(transform_rhs(ac.main_assignments, operation))
f.__name__ = operation.__name__
return f
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies the given operation on all subexpressions of the AC."""
def f(ac):
return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))
f.__name__ = operation.__name__
return f
# TODO Markus
# make this really work for Assignmentcollections
# this function should ONLY evaluate
# do the optims_c99 elsewhere optionally
# def apply_sympy_optimisations(ac: AssignmentCollection):
# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
# and applies the default sympy optimisations. See sympy.codegen.rewriting
# """
#
# # Evaluates all constant terms
#
# assignments = ac.all_assignments
#
# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
# lambda p: p.evalf())
#
# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
#
# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
# if hasattr(a, 'lhs')
# else a for a in assignments]
# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
# for a in chain.from_iterable(assignments_nodes):
# a.optimize(sympy_optimisations)
#
# return AssignmentCollection(assignments)
import sympy as sp
from collections import namedtuple from collections import namedtuple
from typing import Callable, Any, Optional, Sequence from typing import Any, Callable, Optional, Sequence
import sympy as sp
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
class SimplificationStrategy: class SimplificationStrategy:
"""A simplification strategy is an ordered collection of simplification rules. """A simplification strategy is an ordered collection of simplification rules.
Each simplification is a function taking an equation collection, and returning a new simplified Each simplification is a function taking an assignment collection, and returning a new simplified
equation collection. The strategy can nicely print intermediate simplification stages and results assignment collection. The strategy can nicely print intermediate simplification stages and results
to Jupyter notebooks. to Jupyter notebooks.
""" """
...@@ -90,7 +92,7 @@ class SimplificationStrategy: ...@@ -90,7 +92,7 @@ class SimplificationStrategy:
assignment_collection = t(assignment_collection) assignment_collection = t(assignment_collection)
end_time = timeit.default_timer() end_time = timeit.default_timer()
op = assignment_collection.operation_count op = assignment_collection.operation_count
time_str = "%.2f ms" % ((end_time - start_time) * 1000,) time_str = f"{(end_time - start_time) * 1000:.2f} ms"
total = op['adds'] + op['muls'] + op['divs'] total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total)) report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total))
return report return report
...@@ -127,7 +129,7 @@ class SimplificationStrategy: ...@@ -127,7 +129,7 @@ class SimplificationStrategy:
def _repr_html_(self): def _repr_html_(self):
def print_assignment_collection(title, c): def print_assignment_collection(title, c):
text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, ) text = f'<h5 style="padding-bottom:10px">{title}</h5> <div style="padding-left:20px;">'
if self.restrict_symbols: if self.restrict_symbols:
text += "\n".join(["$$" + sp.latex(e) + '$$' text += "\n".join(["$$" + sp.latex(e) + '$$'
for e in c.new_filtered(self.restrict_symbols).main_assignments]) for e in c.new_filtered(self.restrict_symbols).main_assignments])
...@@ -149,5 +151,5 @@ class SimplificationStrategy: ...@@ -149,5 +151,5 @@ class SimplificationStrategy:
def __repr__(self): def __repr__(self):
result = "Simplification Strategy:\n" result = "Simplification Strategy:\n"
for t in self._rules: for t in self._rules:
result += " - %s\n" % (t.__name__,) result += f" - {t.__name__}\n"
return result return result
import sympy as sp
from pystencils.sympyextensions import is_constant
# Subexpression Insertion
def insert_subexpressions(ac, selection_callback, skip=None):
"""
Removes a number of subexpressions from an assignment collection by
inserting their right-hand side wherever they occur.
Args:
- selection_callback: Function that is called to qualify subexpressions
for insertion. Should return `True` for any subexpression that is to be
inserted, and `False` otherwise.
- skip: Set of symbols (left-hand sides of subexpressions) that should be
ignored even if qualified by the callback.
"""
if skip is None:
skip = set()
i = 0
while i < len(ac.subexpressions):
exp = ac.subexpressions[i]
if exp.lhs not in skip and selection_callback(exp):
ac = ac.new_with_inserted_subexpression(exp.lhs)
else:
i += 1
return ac
def insert_aliases(ac, **kwargs):
"""Inserts subexpressions that are aliases of other symbols,
i.e. their right-hand side is only another symbol."""
return insert_subexpressions(ac, lambda x: isinstance(x.rhs, sp.Symbol), **kwargs)
def insert_zeros(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is zero."""
zero = sp.Integer(0)
return insert_subexpressions(ac, lambda x: x.rhs == zero, **kwargs)
def insert_constants(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is constant,
i.e. contains no symbols."""
return insert_subexpressions(ac, lambda x: is_constant(x.rhs), **kwargs)
def insert_symbol_times_minus_one(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is just a
negation of another symbol."""
def callback(exp):
rhs = exp.rhs
minus_one = sp.Integer(-1)
atoms = rhs.atoms(sp.Symbol)
return len(atoms) == 1 and rhs == minus_one * atoms.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_multiples(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a constant
multiplied with another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() * symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_additions(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a sum of a
constant and another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() + symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_squares(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is another symbol squared."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
return len(symbols) == 1 and rhs == symbols.pop() ** 2
return insert_subexpressions(ac, callback, **kwargs)
def bind_symbols_to_skip(insertion_function, skip):
return lambda ac: insertion_function(ac, skip=skip)
from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
def create_simplification_strategy():
"""
Creates a default simplification `ps.simp.SimplificationStrategy`. The idea behind the default simplification
strategy is to reduce the number of subexpressions by inserting single constants and to evaluate constant
terms beforehand.
"""
s = SimplificationStrategy()
s.add(insert_symbol_times_minus_one)
s.add(insert_constant_multiples)
s.add(insert_constant_additions)
s.add(insert_squares)
s.add(insert_zeros)
s.add(insert_constants)
s.add(lambda ac: ac.new_without_unused_subexpressions())
import sympy as sp import sympy as sp
from pystencils.field import create_numpy_array_with_layout, get_layout_of_array from pystencils.field import create_numpy_array_with_layout, get_layout_of_array
...@@ -88,9 +89,12 @@ def shift_slice(slices, offset): ...@@ -88,9 +89,12 @@ def shift_slice(slices, offset):
raise ValueError() raise ValueError()
if hasattr(offset, '__len__'): if hasattr(offset, '__len__'):
return [shift_slice_component(k, off) for k, off in zip(slices, offset)] return tuple(shift_slice_component(k, off) for k, off in zip(slices, offset))
else: else:
return [shift_slice_component(k, offset) for k in slices] if isinstance(slices, slice) or isinstance(slices, int) or isinstance(slices, float):
return shift_slice_component(slices, offset)
else:
return tuple(shift_slice_component(k, offset) for k in slices)
def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0): def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0):
......
import sympy
import pystencils
import pystencils.astnodes
x_, y_, z_ = tuple(pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(3))
x_staggered, y_staggered, z_staggered = x_ + 0.5, y_ + 0.5, z_ + 0.5
def x_vector(ndim):
return sympy.Matrix(tuple(pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(ndim)))
def x_staggered_vector(ndim):
return sympy.Matrix(tuple(
pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) + 0.5 for i in range(ndim)
))
"""This submodule offers functions to work with stencils in expression an offset-list form."""
from collections import defaultdict
from typing import Sequence from typing import Sequence
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from collections import defaultdict
from pystencils.utils import binary_numbers
def inverse_direction(direction): def inverse_direction(direction):
"""Returns inverse i.e. negative of given direction tuple""" """Returns inverse i.e. negative of given direction tuple
Example:
>>> inverse_direction((1, -1, 0))
(-1, 1, 0)
"""
return tuple([-i for i in direction]) return tuple([-i for i in direction])
def is_valid_stencil(stencil, max_neighborhood=None): def inverse_direction_string(direction):
"""Returns inverse of given direction string"""
return offset_to_direction_string(inverse_direction(direction_string_to_offset(direction)))
def is_valid(stencil, max_neighborhood=None):
""" """
Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length. Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length.
If max_neighborhood is specified, it is also verified that the stencil does not contain any direction components If max_neighborhood is specified, it is also verified that the stencil does not contain any direction components
with absolute value greater than the maximal neighborhood. with absolute value greater than the maximal neighborhood.
Examples:
>>> is_valid([(1, 0), (1, 0, 0)]) # stencil entries have different length
False
>>> is_valid([(2, 0), (1, 0)])
True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
""" """
expected_dim = len(stencil[0]) expected_dim = len(stencil[0])
for d in stencil: for d in stencil:
...@@ -26,15 +50,33 @@ def is_valid_stencil(stencil, max_neighborhood=None): ...@@ -26,15 +50,33 @@ def is_valid_stencil(stencil, max_neighborhood=None):
return True return True
def is_symmetric_stencil(stencil): def is_symmetric(stencil):
"""Tests for every direction d, that -d is also in the stencil""" """Tests for every direction d, that -d is also in the stencil
Examples:
>>> is_symmetric([(1, 0), (0, 1)])
False
>>> is_symmetric([(1, 0), (-1, 0)])
True
"""
for d in stencil: for d in stencil:
if inverse_direction(d) not in stencil: if inverse_direction(d) not in stencil:
return False return False
return True return True
def stencils_have_same_entries(s1, s2): def have_same_entries(s1, s2):
"""Checks if two stencils are the same
Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2)
True
>>> have_same_entries(stencil1, stencil3)
False
"""
if len(s1) != len(s2): if len(s1) != len(s2):
return False return False
return len(set(s1) - set(s2)) == 0 return len(set(s1) - set(s2)) == 0
...@@ -43,7 +85,7 @@ def stencils_have_same_entries(s1, s2): ...@@ -43,7 +85,7 @@ def stencils_have_same_entries(s1, s2):
# -------------------------------------Expression - Coefficient Form Conversion ---------------------------------------- # -------------------------------------Expression - Coefficient Form Conversion ----------------------------------------
def stencil_coefficient_dict(expr): def coefficient_dict(expr):
"""Extracts coefficients in front of field accesses in a expression. """Extracts coefficients in front of field accesses in a expression.
Expression may only access a single field at a single index. Expression may only access a single field at a single index.
...@@ -57,12 +99,12 @@ def stencil_coefficient_dict(expr): ...@@ -57,12 +99,12 @@ def stencil_coefficient_dict(expr):
Examples: Examples:
>>> import pystencils as ps >>> import pystencils as ps
>>> f = ps.fields("f(3) : double[2D]") >>> f = ps.fields("f(3) : double[2D]")
>>> field, coeffs, nonlinear_part = stencil_coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123) >>> field, coeffs, nonlinear_part = coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123)
>>> assert nonlinear_part == 123 and field == f(1) >>> assert nonlinear_part == 123 and field == f(1)
>>> sorted(coeffs.items()) >>> sorted(coeffs.items())
[((-1, 0), 3), ((0, 1), 2)] [((-1, 0), 3), ((0, 1), 2)]
""" """
from .field import Field from pystencils.field import Field
expr = expr.expand() expr = expr.expand()
field_accesses = expr.atoms(Field.Access) field_accesses = expr.atoms(Field.Access)
fields = set(fa.field for fa in field_accesses) fields = set(fa.field for fa in field_accesses)
...@@ -77,70 +119,70 @@ def stencil_coefficient_dict(expr): ...@@ -77,70 +119,70 @@ def stencil_coefficient_dict(expr):
field = fields.pop() field = fields.pop()
idx = accessed_indices.pop() idx = accessed_indices.pop()
coefficients = defaultdict(lambda: 0) coeffs = defaultdict(lambda: 0)
coefficients.update({fa.offsets: expr.coeff(fa) for fa in field_accesses}) coeffs.update({fa.offsets: expr.coeff(fa) for fa in field_accesses})
linear_part = sum(c * field[off](*idx) for off, c in coefficients.items()) linear_part = sum(c * field[off](*idx) for off, c in coeffs.items())
nonlinear_part = expr - linear_part nonlinear_part = expr - linear_part
return field(*idx), coefficients, nonlinear_part return field(*idx), coeffs, nonlinear_part
def stencil_coefficients(expr): def coefficients(expr):
"""Returns two lists - one with accessed offsets and one with their coefficients. """Returns two lists - one with accessed offsets and one with their coefficients.
Same restrictions as `stencil_coefficient_dict` apply. Expression must not have any nonlinear part Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part
>>> import pystencils as ps >>> import pystencils as ps
>>> f = ps.fields("f(3) : double[2D]") >>> f = ps.fields("f(3) : double[2D]")
>>> coff = stencil_coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1)) >>> coff = coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1))
""" """
field_center, coefficients, nonlinear_part = stencil_coefficient_dict(expr) field_center, coeffs, nonlinear_part = coefficient_dict(expr)
assert nonlinear_part == 0 assert nonlinear_part == 0
stencil = list(coefficients.keys()) stencil = list(coeffs.keys())
entries = [coefficients[c] for c in stencil] entries = [coeffs[c] for c in stencil]
return stencil, entries return stencil, entries
def stencil_coefficient_list(expr, matrix_form=False): def coefficient_list(expr, matrix_form=False):
"""Returns stencil coefficients in the form of nested lists """Returns stencil coefficients in the form of nested lists
Same restrictions as `stencil_coefficient_dict` apply. Expression must not have any nonlinear part Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part
Examples: Examples:
>>> import pystencils as ps >>> import pystencils as ps
>>> f = ps.fields("f: double[2D]") >>> f = ps.fields("f: double[2D]")
>>> stencil_coefficient_list(2 * f[0, 1] + 3 * f[-1, 0]) >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0])
[[0, 0, 0], [3, 0, 0], [0, 2, 0]] [[0, 0, 0], [3, 0, 0], [0, 2, 0]]
>>> stencil_coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True) >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True)
Matrix([ Matrix([
[0, 2, 0], [0, 2, 0],
[3, 0, 0], [3, 0, 0],
[0, 0, 0]]) [0, 0, 0]])
""" """
field_center, coefficients, nonlinear_part = stencil_coefficient_dict(expr) field_center, coeffs, nonlinear_part = coefficient_dict(expr)
assert nonlinear_part == 0 assert nonlinear_part == 0
field = field_center.field field = field_center.field
dim = field.spatial_dimensions dim = field.spatial_dimensions
max_offsets = defaultdict(lambda: 0) max_offsets = defaultdict(lambda: 0)
for offset in coefficients.keys(): for offset in coeffs.keys():
for d, off in enumerate(offset): for d, off in enumerate(offset):
max_offsets[d] = max(max_offsets[d], abs(off)) max_offsets[d] = max(max_offsets[d], abs(off))
if dim == 1: if dim == 1:
result = [coefficients[(i,)] for i in range(-max_offsets[0], max_offsets[0] + 1)] result = [coeffs[(i,)] for i in range(-max_offsets[0], max_offsets[0] + 1)]
return sp.Matrix(result) if matrix_form else result return sp.Matrix(result) if matrix_form else result
else: else:
y_range = list(range(-max_offsets[1], max_offsets[1] + 1)) y_range = list(range(-max_offsets[1], max_offsets[1] + 1))
if matrix_form: if matrix_form:
y_range.reverse() y_range.reverse()
if dim == 2: if dim == 2:
result = [[coefficients[(i, j)] result = [[coeffs[(i, j)]
for i in range(-max_offsets[0], max_offsets[0] + 1)] for i in range(-max_offsets[0], max_offsets[0] + 1)]
for j in y_range] for j in y_range]
return sp.Matrix(result) if matrix_form else result return sp.Matrix(result) if matrix_form else result
elif dim == 3: elif dim == 3:
result = [[[coefficients[(i, j, k)] result = [[[coeffs[(i, j, k)]
for i in range(-max_offsets[0], max_offsets[0] + 1)] for i in range(-max_offsets[0], max_offsets[0] + 1)]
for j in y_range] for j in y_range]
for k in range(-max_offsets[2], max_offsets[2] + 1)] for k in range(-max_offsets[2], max_offsets[2] + 1)]
...@@ -253,13 +295,45 @@ def direction_string_to_offset(direction: str, dim: int = 3): ...@@ -253,13 +295,45 @@ def direction_string_to_offset(direction: str, dim: int = 3):
return offset[:dim] return offset[:dim]
def adjacent_directions(direction):
"""
Returns all adjacent directions for a direction as tuple of tuples. This is useful for exmple to find all directions
relevant for neighbour communication.
Args:
direction: tuple representing a direction. For example (0, 1, 0) for the northern side
Examples:
>>> adjacent_directions((0, 0, 0))
((0, 0, 0),)
>>> adjacent_directions((0, 1, 0))
((0, 1, 0),)
>>> adjacent_directions((0, 1, 1))
((0, 0, 1), (0, 1, 0), (0, 1, 1))
>>> adjacent_directions((-1, -1))
((-1, -1), (-1, 0), (0, -1))
"""
result = set()
if all(e == 0 for e in direction):
result.add(direction)
return tuple(result)
binary_numbers_list = binary_numbers(len(direction))
for adjacent_direction in binary_numbers_list:
for i, entry in enumerate(direction):
if entry == 0:
adjacent_direction[i] = 0
if entry == -1 and adjacent_direction[i] == 1:
adjacent_direction[i] = -1
if not all(e == 0 for e in adjacent_direction):
result.add(tuple(adjacent_direction))
return tuple(sorted(result))
# -------------------------------------- Visualization ----------------------------------------------------------------- # -------------------------------------- Visualization -----------------------------------------------------------------
def visualize_stencil(stencil, **kwargs): def plot(stencil, **kwargs):
dim = len(stencil[0]) dim = len(stencil[0])
if dim == 2: if dim == 2:
visualize_stencil_2d(stencil, **kwargs) plot_2d(stencil, **kwargs)
else: else:
slicing = False slicing = False
if 'slice' in kwargs: if 'slice' in kwargs:
...@@ -267,18 +341,19 @@ def visualize_stencil(stencil, **kwargs): ...@@ -267,18 +341,19 @@ def visualize_stencil(stencil, **kwargs):
del kwargs['slice'] del kwargs['slice']
if slicing: if slicing:
visualize_stencil_3d_by_slicing(stencil, **kwargs) plot_3d_slicing(stencil, **kwargs)
else: else:
visualize_stencil_3d(stencil, **kwargs) plot_3d(stencil, **kwargs)
def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs): def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs):
""" """
Creates a matplotlib 2D plot of the stencil Creates a matplotlib 2D plot of the stencil
Args: Args:
stencil: sequence of directions stencil: sequence of directions
axes: optional matplotlib axes axes: optional matplotlib axes
figure: optional matplotlib figure
data: data to annotate the directions with, if none given, the indices are used data: data to annotate the directions with, if none given, the indices are used
textsize: size of annotation text textsize: size of annotation text
""" """
...@@ -292,15 +367,15 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1 ...@@ -292,15 +367,15 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1
text_box_style = BoxStyle("Round", pad=0.3) text_box_style = BoxStyle("Round", pad=0.3)
head_length = 0.1 head_length = 0.1
max_offsets = [max(abs(d[c]) for d in stencil) for c in (0, 1)] max_offsets = [max(abs(int(d[c])) for d in stencil) for c in (0, 1)]
if data is None: if data is None:
data = list(range(len(stencil))) data = list(range(len(stencil)))
for direction, annotation in zip(stencil, data): for direction, annotation in zip(stencil, data):
assert len(direction) == 2, "Works only for 2D stencils" assert len(direction) == 2, "Works only for 2D stencils"
direction = tuple(int(i) for i in direction)
if not(direction[0] == 0 and direction[1] == 0): if not (direction[0] == 0 and direction[1] == 0):
axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k') axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')
if isinstance(annotation, sp.Basic): if isinstance(annotation, sp.Basic):
...@@ -316,7 +391,7 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1 ...@@ -316,7 +391,7 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1
else: else:
return 0 return 0
text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)] text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)]
axes.text(*text_position, annotation, verticalalignment='center', axes.text(x=text_position[0], y=text_position[1], s=annotation, verticalalignment='center',
zorder=30, horizontalalignment='center', size=textsize, zorder=30, horizontalalignment='center', size=textsize,
bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0)) bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))
...@@ -328,12 +403,13 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1 ...@@ -328,12 +403,13 @@ def visualize_stencil_2d(stencil, axes=None, figure=None, data=None, textsize='1
axes.set_ylim([-border - max_offsets[1], border + max_offsets[1]]) axes.set_ylim([-border - max_offsets[1], border + max_offsets[1]])
def visualize_stencil_3d_by_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs): def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
"""Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis. """Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis.
Args: Args:
stencil: stencil as sequence of directions stencil: stencil as sequence of directions
slice_axis: 0, 1, or 2 indicating the axis to slice through slice_axis: 0, 1, or 2 indicating the axis to slice through
figure: optional matplotlib figure
data: optional data to print as text besides the arrows data: optional data to print as text besides the arrows
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -357,12 +433,12 @@ def visualize_stencil_3d_by_slicing(stencil, slice_axis=2, figure=None, data=Non ...@@ -357,12 +433,12 @@ def visualize_stencil_3d_by_slicing(stencil, slice_axis=2, figure=None, data=Non
splitted_data[split_idx].append(i if data is None else data[i]) splitted_data[split_idx].append(i if data is None else data[i])
for i in range(3): for i in range(3):
visualize_stencil_2d(splitted_directions[i], axes=axes[i], data=splitted_data[i], **kwargs) plot_2d(splitted_directions[i], axes=axes[i], data=splitted_data[i], **kwargs)
for i in [-1, 0, 1]: for i in [-1, 0, 1]:
axes[i + 1].set_title("Cut at %s=%d" % (axes_names[slice_axis], i)) axes[i + 1].set_title("Cut at %s=%d" % (axes_names[slice_axis], i), y=1.08)
def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8'): def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
""" """
Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`visualize_stencil_2d` Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`visualize_stencil_2d`
If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))`` If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))``
...@@ -379,17 +455,21 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8 ...@@ -379,17 +455,21 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8
FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
self._verts3d = xs, ys, zs self._verts3d = xs, ys, zs
def draw(self, renderer): def do_3d_projection(self, *_):
xs3d, ys3d, zs3d = self._verts3d xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
FancyArrowPatch.draw(self, renderer)
return np.min(zs)
if axes is None: if axes is None:
if figure is None: if figure is None:
figure = plt.figure() figure = plt.figure()
axes = figure.gca(projection='3d') axes = figure.add_subplot(projection='3d')
axes.set_aspect("equal") try:
axes.set_aspect("equal")
except NotImplementedError:
pass
if data is None: if data is None:
data = [None] * len(stencil) data = [None] * len(stencil)
...@@ -401,10 +481,11 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8 ...@@ -401,10 +481,11 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8
r = [-1, 1] r = [-1, 1]
for s, e in combinations(np.array(list(product(r, r, r))), 2): for s, e in combinations(np.array(list(product(r, r, r))), 2):
if np.sum(np.abs(s - e)) == r[1] - r[0]: if np.sum(np.abs(s - e)) == r[1] - r[0]:
axes.plot3D(*zip(s, e), color="k", alpha=0.5) axes.plot(*zip(s, e), color="k", alpha=0.5)
for d, annotation in zip(stencil, data): for d, annotation in zip(stencil, data):
assert len(d) == 3, "Works only for 3D stencils" assert len(d) == 3, "Works only for 3D stencils"
d = tuple(int(i) for i in d)
if not (d[0] == 0 and d[1] == 0 and d[2] == 0): if not (d[0] == 0 and d[1] == 0 and d[2] == 0):
if d[0] == 0: if d[0] == 0:
color = '#348abd' color = '#348abd'
...@@ -424,8 +505,8 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8 ...@@ -424,8 +505,8 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8
else: else:
annotation = str(annotation) annotation = str(annotation)
axes.text(d[0] * text_offset, d[1] * text_offset, d[2] * text_offset, axes.text(x=d[0] * text_offset, y=d[1] * text_offset, z=d[2] * text_offset,
annotation, verticalalignment='center', zorder=30, s=annotation, verticalalignment='center', zorder=30,
size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0)) size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))
axes.set_xlim([-text_offset * 1.1, text_offset * 1.1]) axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
...@@ -434,14 +515,14 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8 ...@@ -434,14 +515,14 @@ def visualize_stencil_3d(stencil, figure=None, axes=None, data=None, textsize='8
axes.set_axis_off() axes.set_axis_off()
def visualize_stencil_expression(expr, **kwargs): def plot_expression(expr, **kwargs):
"""Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing.""" """Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing."""
stencil, coefficients = stencil_coefficients(expr) stencil, coeffs = coefficients(expr)
dim = len(stencil[0]) dim = len(stencil[0])
assert 0 < dim <= 3 assert 0 < dim <= 3
if dim == 1: if dim == 1:
return stencil_coefficient_list(expr, matrix_form=True) return coefficient_list(expr, matrix_form=True)
elif dim == 2: elif dim == 2:
return visualize_stencil_2d(stencil, data=coefficients, **kwargs) return plot_2d(stencil, data=coeffs, **kwargs)
elif dim == 3: elif dim == 3:
return visualize_stencil_3d_by_slicing(stencil, data=coefficients, **kwargs) return plot_3d_slicing(stencil, data=coeffs, **kwargs)
import itertools import itertools
import warnings
import operator import operator
from functools import reduce, partial import warnings
from collections import defaultdict, Counter from collections import Counter, defaultdict
from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp import sympy as sp
from sympy import PolynomialError
from sympy.functions import Abs from sympy.functions import Abs
from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple from sympy.core.numbers import Zero
from pystencils.data_types import get_type_of_expression, get_base_type, cast_func
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.functions import DivFunc
from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
from pystencils.typing.typed_sympy import FieldPointerSymbol
T = TypeVar('T') T = TypeVar('T')
...@@ -155,17 +160,23 @@ def fast_subs(expression: T, substitutions: Dict, ...@@ -155,17 +160,23 @@ def fast_subs(expression: T, substitutions: Dict,
if type(expression) is sp.Matrix: if type(expression) is sp.Matrix:
return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions)) return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions))
def visit(expr): def visit(expr, evaluate=True):
if skip and skip(expr): if skip and skip(expr):
return expr return expr
if hasattr(expr, "fast_subs"): elif hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions) return expr.fast_subs(substitutions, skip)
if expr in substitutions: elif expr in substitutions:
return substitutions[expr] return substitutions[expr]
if not hasattr(expr, 'args'): elif not hasattr(expr, 'args'):
return expr return expr
param_list = [visit(a) for a in expr.args] elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)):
return expr if not param_list else expr.func(*param_list) args = [visit(a, False) for a in expr.args]
return expr.func(*args)
else:
param_list = [visit(a, evaluate) for a in expr.args]
if isinstance(expr, (sp.Mul, sp.Add)):
return expr if not param_list else expr.func(*param_list, evaluate=evaluate)
return expr if not param_list else expr.func(*param_list)
if len(substitutions) == 0: if len(substitutions) == 0:
return expression return expression
...@@ -173,6 +184,14 @@ def fast_subs(expression: T, substitutions: Dict, ...@@ -173,6 +184,14 @@ def fast_subs(expression: T, substitutions: Dict,
return visit(expression) return visit(expression)
def is_constant(expr):
"""Simple version of checking if a sympy expression is constant.
Works also for piecewise defined functions - sympy's is_constant() has a problem there, see:
https://github.com/sympy/sympy/issues/16662
"""
return len(expr.free_symbols) == 0
def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
required_match_replacement: Optional[Union[int, float]] = 0.5, required_match_replacement: Optional[Union[int, float]] = 0.5,
required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
...@@ -224,6 +243,9 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -224,6 +243,9 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args)) normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))
if isinstance(subexpression, sp.Number):
return expr.subs({replacement: subexpression})
def visit(current_expr): def visit(current_expr):
if current_expr.is_Add: if current_expr.is_Add:
expr_max_length = max(len(current_expr.args), len(subexpression.args)) expr_max_length = max(len(current_expr.args), len(subexpression.args))
...@@ -233,7 +255,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -233,7 +255,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients)) intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match): if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match):
# find common factor # find common factor
factors = defaultdict(lambda: 0) factors = defaultdict(int)
skips = 0 skips = 0
for common_symbol in subexpression_coefficient_dict.keys(): for common_symbol in subexpression_coefficient_dict.keys():
if common_symbol not in expr_coefficients: if common_symbol not in expr_coefficients:
...@@ -251,7 +273,10 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -251,7 +273,10 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
if not param_list: if not param_list:
return current_expr return current_expr
else: else:
return current_expr.func(*param_list, evaluate=False) if current_expr.func == sp.Mul and Zero() in param_list:
return sp.simplify(current_expr)
else:
return current_expr.func(*param_list, evaluate=False)
return visit(expr) return visit(expr)
...@@ -259,7 +284,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -259,7 +284,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol], def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
positive: Optional[bool] = None, positive: Optional[bool] = None,
replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr: replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ). """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions transformation can be done to find more common sub-expressions
...@@ -280,7 +305,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym ...@@ -280,7 +305,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if expr.is_Mul: if expr.is_Mul:
distinct_search_symbols = set() distinct_search_symbols = set()
nr_of_search_terms = 0 nr_of_search_terms = 0
other_factors = 1 other_factors = sp.Integer(1)
for t in expr.args: for t in expr.args:
if t in search_symbols: if t in search_symbols:
nr_of_search_terms += 1 nr_of_search_terms += 1
...@@ -331,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order ...@@ -331,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count = 0 factor_count = 0
if type(product) is Mul: if type(product) is Mul:
for factor in product.args: for factor in product.args:
if type(factor) == Pow: if type(factor) is Pow:
if factor.args[0] in symbols: if factor.args[0] in symbols:
factor_count += factor.args[1] factor_count += factor.args[1]
if factor in symbols: if factor in symbols:
...@@ -341,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order ...@@ -341,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count += product.args[1] factor_count += product.args[1]
return factor_count return factor_count
if type(expr) == Mul or type(expr) == Pow: if type(expr) is Mul or type(expr) is Pow:
if velocity_factors_in_product(expr) <= order: if velocity_factors_in_product(expr) <= order:
return expr return expr
else: else:
return sp.Rational(0, 1) return Zero()
if type(expr) != Add: if type(expr) is not Add:
return expr return expr
for sum_term in expr.args: for sum_term in expr.args:
...@@ -417,7 +442,104 @@ def extract_most_common_factor(term): ...@@ -417,7 +442,104 @@ def extract_most_common_factor(term):
return common_factor, term / common_factor return common_factor, term / common_factor
def count_operations(term: Union[sp.Expr, List[sp.Expr]], def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
``expr`` must be rewritable as a polynomial in the given ``symbols``.
It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
Args:
expr: A sympy expression.
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occuring
most often in the expression.
"""
if order_by_occurences:
symbols = list(expr.atoms(sp.Symbol) & set(symbols))
symbols = sorted(symbols, key=expr.count, reverse=True)
if len(symbols) == 0:
return expr
symbol = symbols[0]
collected = expr.collect(symbol)
try:
collected_poly = sp.Poly(collected, symbol)
except PolynomialError:
return expr
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
def summands(expr):
return set(expr.args) if isinstance(expr, sp.Add) else {expr}
def simplify_by_equality(expr, a, b, c):
"""
Uses the equality a = b + c, where a and b must be symbols, to simplify expr
by attempting to express additive combinations of two quantities by the third.
This works on expressions that are reducible to the form
:math:`a * (...) + b * (...) + c * (...)`,
without any mixed terms of a, b and c.
"""
if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol):
raise ValueError("a and b must be symbols.")
c = sp.sympify(c)
if not (isinstance(c, sp.Symbol) or is_constant(c)):
raise ValueError("c must be either a symbol or a constant!")
expr = sp.sympify(expr)
expr_expanded = sp.expand(expr)
a_coeff = expr_expanded.coeff(a, 1)
expr_expanded -= (a * a_coeff).expand()
b_coeff = expr_expanded.coeff(b, 1)
expr_expanded -= (b * b_coeff).expand()
if isinstance(c, sp.Symbol):
c_coeff = expr_expanded.coeff(c, 1)
rest = expr_expanded - (c * c_coeff).expand()
else:
c_coeff = expr_expanded / c
rest = 0
a_summands = summands(a_coeff)
b_summands = summands(b_coeff)
c_summands = summands(c_coeff)
# replace b + c by a
b_plus_c_coeffs = b_summands & c_summands
for coeff in b_plus_c_coeffs:
rest += a * coeff
b_summands -= b_plus_c_coeffs
c_summands -= b_plus_c_coeffs
# replace a - b by c
neg_b_summands = {-x for x in b_summands}
a_minus_b_coeffs = a_summands & neg_b_summands
for coeff in a_minus_b_coeffs:
rest += c * coeff
a_summands -= a_minus_b_coeffs
b_summands -= {-x for x in a_minus_b_coeffs}
# replace a - c by b
neg_c_summands = {-x for x in c_summands}
a_minus_c_coeffs = a_summands & neg_c_summands
for coeff in a_minus_c_coeffs:
rest += b * coeff
a_summands -= a_minus_c_coeffs
c_summands -= {-x for x in a_minus_c_coeffs}
# put it back together
return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand()
def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
only_type: Optional[str] = 'real') -> Dict[str, int]: only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division. """Counts the number of additions, multiplications and division.
...@@ -428,8 +550,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -428,8 +550,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
Returns: Returns:
dict with 'adds', 'muls' and 'divs' keys dict with 'adds', 'muls' and 'divs' keys
""" """
result = {'adds': 0, 'muls': 0, 'divs': 0} from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence): if isinstance(term, Sequence):
for element in term: for element in term:
r = count_operations(element, only_type) r = count_operations(element, only_type)
...@@ -439,16 +563,20 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -439,16 +563,20 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(term, Assignment): elif isinstance(term, Assignment):
term = term.rhs term = term.rhs
if hasattr(term, 'evalf'):
term = term.evalf()
def check_type(e): def check_type(e):
if only_type is None: if only_type is None:
return True return True
if isinstance(e, FieldPointerSymbol) and only_type == "real":
return only_type == "int"
try: try:
base_type = get_base_type(get_type_of_expression(e)) base_type = get_type_of_expression(e)
except ValueError: except ValueError:
return False return False
if isinstance(base_type, VectorType):
return False
if isinstance(base_type, PointerType):
return only_type == 'int'
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()): if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
return True return True
if only_type == 'real' and (base_type.is_float()): if only_type == 'real' and (base_type.is_float()):
...@@ -473,23 +601,37 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -473,23 +601,37 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
pass pass
elif isinstance(t, sp.Symbol): elif isinstance(t, sp.Symbol):
visit_children = False visit_children = False
elif isinstance(t, sp.tensor.Indexed): elif isinstance(t, sp.Indexed):
visit_children = False visit_children = False
elif t.is_integer: elif t.is_integer:
pass pass
elif isinstance(t, cast_func): elif isinstance(t, CastFunc):
visit_children = False visit_children = False
visit(t.args[0]) visit(t.args[0])
elif t.func is fast_sqrt:
result['fast_sqrts'] += 1
elif t.func is fast_inv_sqrt:
result['fast_inv_sqrts'] += 1
elif t.func is fast_division:
result['fast_div'] += 1
elif t.func is sp.Pow: elif t.func is sp.Pow:
if check_type(t.args[0]): if check_type(t.args[0]):
visit_children = False visit_children = True
if t.exp.is_integer and t.exp.is_number: if t.exp.is_integer and t.exp.is_number:
if t.exp >= 0: if t.exp >= 0:
result['muls'] += int(t.exp) - 1 result['muls'] += int(t.exp) - 1
else: else:
result['muls'] -= 1 if result['muls'] > 0:
result['muls'] -= 1
result['divs'] += 1 result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1 result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1
elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
result["sqrts"] += 1
result["divs"] += 1
else:
warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
else: else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, " warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate") "counting will be inaccurate")
...@@ -497,8 +639,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -497,8 +639,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
for child_term, condition in t.args: for child_term, condition in t.args:
visit(child_term) visit(child_term)
visit_children = False visit_children = False
elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
pass
elif isinstance(t, DivFunc):
result["divs"] += 1
else: else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
if visit_children: if visit_children:
for a in t.args: for a in t.args:
...@@ -511,14 +657,13 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -511,14 +657,13 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
def count_operations_in_ast(ast) -> Dict[str, int]: def count_operations_in_ast(ast) -> Dict[str, int]:
"""Counts number of operations in an abstract syntax tree, see also :func:`count_operations`""" """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`"""
from pystencils.astnodes import SympyAssignment from pystencils.astnodes import SympyAssignment
result = {'adds': 0, 'muls': 0, 'divs': 0} result = defaultdict(int)
def visit(node): def visit(node):
if isinstance(node, SympyAssignment): if isinstance(node, SympyAssignment):
r = count_operations(node.rhs) r = count_operations(node.rhs)
result['adds'] += r['adds'] for k, v in r.items():
result['muls'] += r['muls'] result[k] += v
result['divs'] += r['divs']
else: else:
for arg in node.args: for arg in node.args:
visit(arg) visit(arg)
...@@ -547,12 +692,6 @@ def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr: ...@@ -547,12 +692,6 @@ def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict)) return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))
def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[Assignment]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in assignments])
return [Assignment(a, b) for a, b in res]
class SymbolCreator: class SymbolCreator:
def __getattribute__(self, name): def __getattribute__(self, name):
return sp.Symbol(name) return sp.Symbol(name)
File moved
import hashlib
import pickle
import warnings import warnings
from collections import defaultdict, OrderedDict, namedtuple from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
import pickle from typing import Set
import hashlib
import sympy as sp import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase import pystencils as ps
import pystencils.astnodes as ast
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
from pystencils.field import Field, FieldType from pystencils.field import Field, FieldType
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \ from pystencils.typing import FieldPointerSymbol
cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type from pystencils.sympyextensions import fast_subs
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice from pystencils.slicing import normalize_slice
import pystencils.astnodes as ast from pystencils.integer_functions import int_div
class NestedScopes: class NestedScopes:
...@@ -80,6 +85,60 @@ def filtered_tree_iteration(node, node_type, stop_type=None): ...@@ -80,6 +85,60 @@ def filtered_tree_iteration(node, node_type, stop_type=None):
yield from filtered_tree_iteration(arg, node_type) yield from filtered_tree_iteration(arg, node_type)
def generic_visit(term, visitor):
if isinstance(term, AssignmentCollection):
new_main_assignments = generic_visit(term.main_assignments, visitor)
new_subexpressions = generic_visit(term.subexpressions, visitor)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [generic_visit(e, visitor) for e in term]
elif isinstance(term, Assignment):
return Assignment(term.lhs, generic_visit(term.rhs, visitor))
elif isinstance(term, sp.Matrix):
return term.applyfunc(lambda e: generic_visit(e, visitor))
else:
return visitor(term)
def iterate_loops_by_depth(node, nesting_depth):
"""Iterate all LoopOverCoordinate nodes in the given AST of the specified nesting depth.
Args:
node: Root node of the abstract syntax tree
nesting_depth: Nesting depth of the loops the pragmas should be applied to.
Outermost loop has depth 0.
A depth of -1 indicates the innermost loops.
Returns: Iterable listing all loop nodes of given nesting depth.
"""
from pystencils.astnodes import LoopOverCoordinate
def _internal_default(node, nesting_depth):
isloop = isinstance(node, LoopOverCoordinate)
if nesting_depth < 0: # here, a negative value indicates end of descent
return
elif nesting_depth == 0 and isloop:
yield node
else:
next_depth = nesting_depth - 1 if isloop else nesting_depth
for arg in node.args:
yield from _internal_default(arg, next_depth)
def _internal_innermost(node):
if isinstance(node, LoopOverCoordinate) and node.is_innermost_loop:
yield node
else:
for arg in node.args:
yield from _internal_innermost(arg)
if nesting_depth >= 0:
yield from _internal_default(node, nesting_depth)
elif nesting_depth == -1:
yield from _internal_innermost(node)
else:
raise ValueError(f"Invalid nesting depth: {nesting_depth}. Choose a nonnegative number, or -1.")
def unify_shape_symbols(body, common_shape, fields): def unify_shape_symbols(body, common_shape, fields):
"""Replaces symbols for array sizes to ensure they are represented by the same unique symbol. """Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
...@@ -104,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields): ...@@ -104,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields):
body.subs(substitutions) body.subs(substitutions)
def get_common_shape(field_set): def get_common_field(field_set):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise """Takes a set of pystencils Fields, checks if a common spatial shape exists and returns one
ValueError is raised""" representative field, that can be used for shape information etc. in the kernel creation.
If the fields have different shapes ValueError is raised"""
nr_of_fixed_shaped_fields = 0 nr_of_fixed_shaped_fields = 0
for f in field_set: for f in field_set:
if f.has_fixed_shape: if f.has_fixed_shape:
...@@ -116,7 +176,7 @@ def get_common_shape(field_set): ...@@ -116,7 +176,7 @@ def get_common_shape(field_set):
fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape]) fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape])
var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape]) var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape])
msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n" msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
msg += "Variable shaped: %s \nFixed shaped: %s" % (var_field_names, fixed_field_names) msg += f"Variable shaped: {var_field_names} \nFixed shaped: {fixed_field_names}"
raise ValueError(msg) raise ValueError(msg)
shape_set = set([f.spatial_shape for f in field_set]) shape_set = set([f.spatial_shape for f in field_set])
...@@ -124,16 +184,16 @@ def get_common_shape(field_set): ...@@ -124,16 +184,16 @@ def get_common_shape(field_set):
if len(shape_set) != 1: if len(shape_set) != 1:
raise ValueError("Differently sized field accesses in loop body: " + str(shape_set)) raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))
shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0] # Sort the fields by their name to ensure that always the same field is returned
return shape reference_field = sorted(field_set, key=lambda e: str(e))[0]
return reference_field
def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None): def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
"""Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST. """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.
Args: Args:
body: Block object with inner loop contents body: Block object with inner loop contents
function_name: name of generated C function
iteration_slice: if not None, iteration is done only over this slice of the field iteration_slice: if not None, iteration is done only over this slice of the field
ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
if None, the number of ghost layers is determined automatically and assumed to be equal for a if None, the number of ghost layers is determined automatically and assumed to be equal for a
...@@ -141,27 +201,41 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer ...@@ -141,27 +201,41 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout) loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)
Returns: Returns:
:class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts tuple of loop-node, ghost_layer_info
""" """
# find correct ordering by inspecting participating FieldAccesses # find correct ordering by inspecting participating FieldAccesses
absolut_accesses_only = False
field_accesses = body.atoms(Field.Access) field_accesses = body.atoms(Field.Access)
field_accesses = {e for e in field_accesses if not e.is_absolute_access} field_accesses = {e for e in field_accesses if not e.is_absolute_access}
if len(field_accesses) == 0: # when kernel contains only absolute accesses
absolut_accesses_only = True
# exclude accesses to buffers from field_list, because buffers are treated separately # exclude accesses to buffers from field_list, because buffers are treated separately
field_list = [e.field for e in field_accesses if not FieldType.is_buffer(e.field)] field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))]
if len(field_list) == 0: # when kernel contains only custom fields
field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field))]
fields = set(field_list) fields = set(field_list)
if loop_order is None: if loop_order is None:
loop_order = get_optimal_loop_ordering(fields) loop_order = get_optimal_loop_ordering(fields)
shape = get_common_shape(fields) if absolut_accesses_only:
unify_shape_symbols(body, common_shape=shape, fields=fields) absolut_access_fields = {e.field for e in body.atoms(Field.Access)}
common_field = get_common_field(absolut_access_fields)
common_shape = common_field.spatial_shape
else:
common_field = get_common_field(fields)
common_shape = common_field.spatial_shape
unify_shape_symbols(body, common_shape=common_shape, fields=fields)
if iteration_slice is not None: if iteration_slice is not None:
iteration_slice = normalize_slice(iteration_slice, shape) iteration_slice = normalize_slice(iteration_slice, common_shape)
if ghost_layers is None: if ghost_layers is None:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses]) if absolut_accesses_only:
required_ghost_layers = 0
else:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order) ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
if isinstance(ghost_layers, int): if isinstance(ghost_layers, int):
ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order) ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)
...@@ -170,7 +244,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer ...@@ -170,7 +244,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
for i, loop_coordinate in enumerate(reversed(loop_order)): for i, loop_coordinate in enumerate(reversed(loop_order)):
if iteration_slice is None: if iteration_slice is None:
begin = ghost_layers[loop_coordinate][0] begin = ghost_layers[loop_coordinate][0]
end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1] end = common_shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1) new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
current_body = ast.Block([new_loop]) current_body = ast.Block([new_loop])
else: else:
...@@ -184,8 +258,29 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer ...@@ -184,8 +258,29 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
sp.sympify(slice_component)) sp.sympify(slice_component))
current_body.insert_front(assignment) current_body.insert_front(assignment)
ast_node = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu') return current_body, ghost_layers
return ast_node
def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
assert len(indexed_elements) > 0, "indexed_elements can not be empty"
shape_set = {s.shape for s in indexed_elements}
if len(shape_set) != 1:
for shape in shape_set:
assert not isinstance(shape, int), "If indexed elements are used, they must all have the same shape"
return sorted(indexed_elements, key=lambda e: str(e))[0]
def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
indexed_elements = loop_node.atoms(sp.Indexed)
if len(indexed_elements) == 0:
return loop_node
reference_element = get_common_indexed_element(indexed_elements)
index = reference_element.indices[0].atoms(TypedSymbol)
assert len(index) == 1, "index expressions must only contain one symbol representing the index"
new_loop = ast.LoopOverCoordinate(loop_node, 0, 0,
reference_element.shape[0], 1, custom_loop_ctr=index.pop())
return ast.Block([new_loop])
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
...@@ -222,7 +317,7 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): ...@@ -222,7 +317,7 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
if coordinate_id < field.spatial_dimensions: if coordinate_id < field.spatial_dimensions:
offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id] offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
if type(field_access.offsets[coordinate_id]) is int: if field_access.offsets[coordinate_id].is_Integer:
name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id]) name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id])
else: else:
list_to_hash.append(field_access.offsets[coordinate_id]) list_to_hash.append(field_access.offsets[coordinate_id])
...@@ -283,6 +378,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime ...@@ -283,6 +378,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
if elem in specified_coordinates: if elem in specified_coordinates:
raise ValueError("Coordinate %d specified two times" % (elem,)) raise ValueError("Coordinate %d specified two times" % (elem,))
specified_coordinates.add(elem) specified_coordinates.add(elem)
for element in spec_group: for element in spec_group:
if type(element) is int: if type(element) is int:
add_new_element(element) add_new_element(element)
...@@ -303,7 +399,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime ...@@ -303,7 +399,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
index = int(element[len("index"):]) index = int(element[len("index"):])
add_new_element(spatial_dimensions + index) add_new_element(spatial_dimensions + index)
else: else:
raise ValueError("Unknown specification %s" % (element,)) raise ValueError(f"Unknown specification {element}")
result.append(new_group) result.append(new_group)
...@@ -322,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -322,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
ast_node: ast before any field accesses are resolved ast_node: ast before any field accesses are resolved
loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes) loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
for GPU kernels: list of 'loop counters' from inner to outer loop for GPU kernels: list of 'loop counters' from inner to outer loop
loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default
Returns: Returns:
base buffer index - required by 'resolve_buffer_accesses' function base buffer index - required by 'resolve_buffer_accesses' function
...@@ -334,23 +430,43 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -334,23 +430,43 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
assert len(loops) == len(parents_of_innermost_loop) assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop)) assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
loop_iterations = [(l.stop - l.start) / l.step for l in loops] loop_counters = [loop.loop_counter_symbol for loop in loops]
loop_counters = [l.loop_counter_symbol for l in loops] loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
actual_sizes = list()
actual_steps = list()
for ctr, s in zip(loop_counters, loop_iterations):
if s.step != 1:
if (s.stop - s.start) % s.step == 0:
actual_sizes.append((s.stop - s.start) // s.step)
else:
actual_sizes.append(int_div((s.stop - s.start), s.step))
if (ctr - s.start) % s.step == 0:
actual_steps.append((ctr - s.start) // s.step)
else:
actual_steps.append(int_div((ctr - s.start), s.step))
else:
actual_sizes.append(s.stop - s.start)
actual_steps.append(ctr - s.start)
field_accesses = ast_node.atoms(Field.Access) field_accesses = ast_node.atoms(Field.Access)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
loop_counters = [v * len(buffer_accesses) for v in loop_counters] buffer_index_size = len(buffer_accesses)
base_buffer_index = actual_steps[0]
actual_stride = 1
for idx, actual_step in enumerate(actual_steps[1:]):
cur_stride = actual_sizes[idx]
actual_stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += actual_stride * actual_step
return base_buffer_index * buffer_index_size
base_buffer_index = loop_counters[0]
stride = 1
for idx, var in enumerate(loop_counters[1:]):
cur_stride = loop_iterations[idx]
stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += var * stride
return base_buffer_index
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=None):
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()): if read_only_field_names is None:
read_only_field_names = set()
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access): if isinstance(expr, Field.Access):
...@@ -396,7 +512,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s ...@@ -396,7 +512,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
return visit_node(ast_node) return visit_node(ast_node)
def resolve_field_accesses(ast_node, read_only_field_names=set(), def resolve_field_accesses(ast_node, read_only_field_names=None,
field_to_base_pointer_info=MappingProxyType({}), field_to_base_pointer_info=MappingProxyType({}),
field_to_fixed_coordinates=MappingProxyType({})): field_to_fixed_coordinates=MappingProxyType({})):
""" """
...@@ -413,6 +529,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -413,6 +529,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
Returns Returns
transformed AST transformed AST
""" """
if read_only_field_names is None:
read_only_field_names = set()
field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0])) field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
...@@ -435,7 +553,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -435,7 +553,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
else: else:
base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))] base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]
field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names) field_ptr = FieldPointerSymbol(
field.name,
field.dtype,
const=field.name in read_only_field_names)
def create_coordinate_dict(group_param): def create_coordinate_dict(group_param):
coordinates = {} coordinates = {}
...@@ -456,6 +577,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -456,6 +577,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(field.dtype, StructType): if isinstance(field.dtype, StructType):
assert field.index_dimensions == 1 assert field.index_dimensions == 1
accessed_field_name = field_access.index[0] accessed_field_name = field_access.index[0]
if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
assert isinstance(accessed_field_name, str) assert isinstance(accessed_field_name, str)
coordinates[e] = field.dtype.get_element_offset(accessed_field_name) coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
else: else:
...@@ -469,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -469,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
coord_dict = create_coordinate_dict(group) coord_dict = create_coordinate_dict(group)
new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer) new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
if new_ptr not in enclosing_block.symbols_defined: if new_ptr not in enclosing_block.symbols_defined:
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False) new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
enclosing_block.insert_before(new_assignment, sympy_assignment) enclosing_block.insert_before(new_assignment, sympy_assignment)
last_pointer = new_ptr last_pointer = new_ptr
...@@ -479,15 +602,21 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -479,15 +602,21 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
field_access.offsets, field_access.index) field_access.offsets, field_access.index)
if isinstance(get_base_type(field_access.field.dtype), StructType): if isinstance(get_base_type(field_access.field.dtype), StructType):
new_type = field_access.field.dtype.get_element_type(field_access.index[0]) accessed_field_name = field_access.index[0]
result = reinterpret_cast_func(result, new_type) if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
new_type = field_access.field.dtype.get_element_type(accessed_field_name)
result = ReinterpretCastFunc(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment) return visit_sympy_expr(result, enclosing_block, sympy_assignment)
else: else:
if isinstance(expr, ast.ResolvedFieldAccess): if isinstance(expr, ast.ResolvedFieldAccess):
return expr return expr
new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args] if hasattr(expr, 'args'):
new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
else:
new_args = []
kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
return expr.func(*new_args, **kwargs) if new_args else expr return expr.func(*new_args, **kwargs) if new_args else expr
...@@ -497,8 +626,17 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -497,8 +626,17 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
assert type(enclosing_block) is ast.Block assert type(enclosing_block) is ast.Block
sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast) sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast) sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
elif isinstance(sub_ast, ast.Conditional):
enclosing_block = sub_ast.parent
assert type(enclosing_block) is ast.Block
sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
visit_node(sub_ast.true_block)
if sub_ast.false_block:
visit_node(sub_ast.false_block)
else: else:
for i, a in enumerate(sub_ast.args): if isinstance(sub_ast, (bool, int, float)):
return
for a in sub_ast.args:
visit_node(a) visit_node(a)
return visit_node(ast_node) return visit_node(ast_node)
...@@ -518,30 +656,74 @@ def move_constants_before_loop(ast_node): ...@@ -518,30 +656,74 @@ def move_constants_before_loop(ast_node):
""" """
assert isinstance(node.parent, ast.Block) assert isinstance(node.parent, ast.Block)
def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
if isinstance(node.lhs, ast.ResolvedFieldAccess):
return node.lhs.typed_symbol.name in symbol_names
else:
return node.lhs.name in symbol_names
elif isinstance(node, ast.Block):
for arg in node.args:
if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
continue
if modifies_or_declares(arg, symbol_names):
return True
return False
elif isinstance(node, ast.LoopOverCoordinate):
return modifies_or_declares(node.body, symbol_names)
elif isinstance(node, ast.Conditional):
return (
modifies_or_declares(node.true_block, symbol_names)
or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
)
elif isinstance(node, ast.KernelFunction):
return False
else:
defs = {s.name for s in node.symbols_defined}
return bool(symbol_names.intersection(defs))
dependencies = {s.name for s in node.undefined_symbols}
last_block = node.parent last_block = node.parent
last_block_child = node last_block_child = node
element = node.parent element = node.parent
prev_element = node prev_element = node
while element: while element:
if isinstance(element, ast.Block): if isinstance(element, (ast.Conditional, ast.KernelFunction)):
# Never move out of Conditionals or KernelFunctions.
break
elif isinstance(element, ast.Block):
last_block = element last_block = element
last_block_child = prev_element last_block_child = prev_element
if isinstance(element, ast.Conditional): if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
critical_symbols = element.condition_expr.atoms(sp.Symbol) # The node depends on one of the statements in this block.
# Do not move further out.
break
elif isinstance(element, ast.LoopOverCoordinate):
if element.loop_counter_symbol.name in dependencies:
# The node depends on the loop counter.
# Do not move out of this loop.
break
else: else:
critical_symbols = element.symbols_defined raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
if node.undefined_symbols.intersection(critical_symbols): f'The expression {element} of type {type(element)} is not known yet.')
break
# No dependencies to symbols defined/modified within the current element.
# We can move the node up one level and in front of the current element.
prev_element = element prev_element = element
element = element.parent element = element.parent
return last_block, last_block_child return last_block, last_block_child
def check_if_assignment_already_in_block(assignment, target_block): def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
for arg in target_block.args: for arg in target_block.args:
if type(arg) is not ast.SympyAssignment: if type(arg) is not ast.SympyAssignment:
continue continue
if arg.lhs == assignment.lhs: if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
return arg return arg
return None return None
...@@ -557,21 +739,36 @@ def move_constants_before_loop(ast_node): ...@@ -557,21 +739,36 @@ def move_constants_before_loop(ast_node):
for block in all_blocks: for block in all_blocks:
children = block.take_child_nodes() children = block.take_child_nodes()
for child in children: for child in children:
if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments
block.append(child)
continue
target, child_to_insert_before = find_block_to_move_to(child) target, child_to_insert_before = find_block_to_move_to(child)
if target == block: # movement not possible if target == block: # movement not possible
target.append(child) target.append(child)
else: else:
if isinstance(child, ast.SympyAssignment): if isinstance(child, ast.SympyAssignment):
exists_already = check_if_assignment_already_in_block(child, target) exists_already = check_if_assignment_already_in_block(child, target, False)
else: else:
exists_already = False exists_already = False
if not exists_already: if not exists_already:
target.insert_before(child, child_to_insert_before) target.insert_before(child, child_to_insert_before)
elif exists_already and exists_already.rhs == child.rhs: elif exists_already and exists_already.rhs == child.rhs:
pass if target.args.index(exists_already) > target.args.index(child_to_insert_before):
assert target.args.count(exists_already) == 1
assert target.args.count(child_to_insert_before) == 1
target.args.remove(exists_already)
target.insert_before(exists_already, child_to_insert_before)
else: else:
block.append(child) # don't move in this case - better would be to rename symbol # this variable already exists in outer block, but with different rhs
# -> symbol has to be renamed
assert isinstance(child.lhs, TypedSymbol)
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
child_to_insert_before)
block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
def split_inner_loop(ast_node: ast.Node, symbol_groups): def split_inner_loop(ast_node: ast.Node, symbol_groups):
...@@ -585,16 +782,16 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -585,16 +782,16 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
and which no symbol in a symbol group depends on, are not updated! and which no symbol in a symbol group depends on, are not updated!
""" """
all_loops = ast_node.atoms(ast.LoopOverCoordinate) all_loops = ast_node.atoms(ast.LoopOverCoordinate)
inner_loop = [l for l in all_loops if l.is_innermost_loop] inner_loop = [loop for loop in all_loops if loop.is_innermost_loop]
assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?" assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
inner_loop = inner_loop[0] inner_loop = inner_loop[0]
assert type(inner_loop.body) is ast.Block assert type(inner_loop.body) is ast.Block
outer_loop = [l for l in all_loops if l.is_outermost_loop] outer_loop = [loop for loop in all_loops if loop.is_outermost_loop]
assert len(outer_loop) == 1, "Error in AST, multiple outermost loops." assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
outer_loop = outer_loop[0] outer_loop = outer_loop[0]
symbols_with_temporary_array = OrderedDict() symbols_with_temporary_array = OrderedDict()
assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args) assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args if hasattr(a, 'lhs'))
assignment_groups = [] assignment_groups = []
for symbol_group in symbol_groups: for symbol_group in symbol_groups:
...@@ -608,30 +805,36 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -608,30 +805,36 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if s in assignment_map: # if there is no assignment inside the loop body it is independent already if s in assignment_map: # if there is no assignment inside the loop body it is independent already
for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol): for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
if type(new_symbol) is not Field.Access and new_symbol not in symbols_with_temporary_array: if not isinstance(new_symbol, Field.Access) and \
new_symbol not in symbols_with_temporary_array:
symbols_to_process.append(new_symbol) symbols_to_process.append(new_symbol)
symbols_resolved.add(s) symbols_resolved.add(s)
for symbol in symbol_group: for symbol in symbol_group:
if type(symbol) is not Field.Access: if not isinstance(symbol, Field.Access):
assert type(symbol) is TypedSymbol assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype)) new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] symbols_with_temporary_array[symbol] = sp.IndexedBase(
new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
assignment_group = [] assignment_group = []
for assignment in inner_loop.body.args: for assignment in inner_loop.body.args:
if assignment.lhs in symbols_resolved: if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items()) # use fast_subs here because it checks if multiplications should be evaluated or not
if type(assignment.lhs) is not Field.Access and assignment.lhs in symbol_group: new_rhs = fast_subs(assignment.rhs, symbols_with_temporary_array)
if not isinstance(assignment.lhs, Field.Access) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
else: else:
new_lhs = assignment.lhs new_lhs = assignment.lhs
assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs)) assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
assignment_groups.append(assignment_group) assignment_groups.append(assignment_group)
new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups] new_loops = [
inner_loop.new_loop_with_different_body(ast.Block(group))
for group in assignment_groups
]
inner_loop.parent.replace(inner_loop, ast.Block(new_loops)) inner_loop.parent.replace(inner_loop, ast.Block(new_loops))
for tmp_array in symbols_with_temporary_array: for tmp_array in symbols_with_temporary_array:
...@@ -648,14 +851,15 @@ def cut_loop(loop_node, cutting_points): ...@@ -648,14 +851,15 @@ def cut_loop(loop_node, cutting_points):
One loop is transformed into len(cuttingPoints)+1 new loops that range from One loop is transformed into len(cuttingPoints)+1 new loops that range from
old_begin to cutting_points[1], ..., cutting_points[-1] to old_end old_begin to cutting_points[1], ..., cutting_points[-1] to old_end
Modifies the ast in place Modifies the ast in place. Note Issue #5783 of SymPy. Deepcopy will evaluate mul
https://github.com/sympy/sympy/issues/5783
Returns: Returns:
list of new loop nodes list of new loop nodes
""" """
if loop_node.step != 1: if loop_node.step != 1:
raise NotImplementedError("Can only split loops that have a step of 1") raise NotImplementedError("Can only split loops that have a step of 1")
new_loops = [] new_loops = ast.Block([])
new_start = loop_node.start new_start = loop_node.start
cutting_points = list(cutting_points) + [loop_node.stop] cutting_points = list(cutting_points) + [loop_node.stop]
for new_end in cutting_points: for new_end in cutting_points:
...@@ -666,8 +870,9 @@ def cut_loop(loop_node, cutting_points): ...@@ -666,8 +870,9 @@ def cut_loop(loop_node, cutting_points):
elif new_end - new_start == 0: elif new_end - new_start == 0:
pass pass
else: else:
new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over, new_loop = ast.LoopOverCoordinate(
new_start, new_end, loop_node.step) deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
new_start, new_end, loop_node.step)
new_loops.append(new_loop) new_loops.append(new_loop)
new_start = new_end new_start = new_end
loop_node.parent.replace(loop_node, new_loops) loop_node.parent.replace(loop_node, new_loops)
...@@ -685,11 +890,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa ...@@ -685,11 +890,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
This analysis needs the integer set library (ISL) islpy, so it is not done by This analysis needs the integer set library (ISL) islpy, so it is not done by
default. default.
""" """
from sympy.codegen.rewriting import ReplaceOptim, optimize
remove_casts = ReplaceOptim(lambda e: isinstance(e, CastFunc), lambda p: p.expr)
for conditional in node.atoms(ast.Conditional): for conditional in node.atoms(ast.Conditional):
conditional.condition_expr = sp.simplify(conditional.condition_expr) # TODO simplify conditional before the type system! Casts make it very hard here
if conditional.condition_expr == sp.true: condition_expression = optimize(conditional.condition_expr, [remove_casts])
condition_expression = sp.simplify(condition_expression)
if condition_expression == sp.true:
conditional.parent.replace(conditional, [conditional.true_block]) conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false: elif condition_expression == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else []) conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
elif loop_counter_simplification: elif loop_counter_simplification:
try: try:
...@@ -715,243 +925,19 @@ def cleanup_blocks(node: ast.Node) -> None: ...@@ -715,243 +925,19 @@ def cleanup_blocks(node: ast.Node) -> None:
cleanup_blocks(a) cleanup_blocks(a)
class KernelConstraintsCheck: def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
"""Checks if the input to create_kernel is valid. """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
first and last element"""
Test the following conditions:
- SSA Form for pure symbols:
- Every pure symbol may occur only once as left-hand-side of an assignment
- Every pure symbol that is read, may not be written to later
- Independence / Parallelization condition:
- a field that is written may only be read at exact the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
"""
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, type_for_symbol, check_independence_condition):
self._type_for_symbol = type_for_symbol
self.scopes = NestedScopes()
self._field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
def process_assignment(self, assignment):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs = self.process_expression(assignment.rhs)
new_lhs = self._process_lhs(assignment.lhs)
return ast.SympyAssignment(new_lhs, new_rhs)
def process_expression(self, rhs, type_constants=True):
self._update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access):
self.fields_read.add(rhs.field)
self.fields_read.update(rhs.indirect_addressing_fields)
return rhs
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul):
new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed):
return rhs
else:
if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers
return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
else:
new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
@property
def fields_written(self):
return set(k.field for k, v in self._field_writes.items() if len(v))
def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, Field.Access) and not isinstance(lhs, TypedSymbol):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
else:
return lhs
def _update_accesses_lhs(self, lhs):
if isinstance(lhs, Field.Access):
fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets)
if len(self._field_writes[fai]) > 1:
raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
if lhs in self.scopes.free_parameters:
raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
self.scopes.define_symbol(lhs)
def _update_accesses_rhs(self, rhs):
if isinstance(rhs, Field.Access) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
for write_offset in writes:
assert len(writes) == 1
if write_offset != rhs.offsets:
raise ValueError("Violation of loop independence condition. Field "
"{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self.scopes.access_symbol(rhs)
def add_types(eqs, type_for_symbol, check_independence_condition):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
Additionally returns sets of all fields which are read/written
Args:
eqs: list of equations
type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
kernels
Returns:
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols
"""
if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
def visit(obj):
if isinstance(obj, list) or isinstance(obj, tuple):
return [visit(e) for e in obj]
if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment):
return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
check.scopes.push()
false_block = None if obj.false_block is None else visit(obj.false_block)
result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block), false_block=false_block)
check.scopes.pop()
return result
elif isinstance(obj, ast.Block):
check.scopes.push()
result = ast.Block([visit(e) for e in obj.args])
check.scopes.pop()
return result
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
typed_equations = visit(eqs)
return check.fields_read, check.fields_written, typed_equations
def insert_casts(node):
"""Checks the types and inserts casts and pointer arithmetic where necessary.
Args:
node: the head node of the ast
Returns:
modified AST
"""
def cast(zipped_args_types, target_dtype):
"""
Adds casts to the arguments if their type differs from the target type
:param zipped_args_types: a zipped list of args and types
:param target_dtype: The target data type
:return: args with possible casts
"""
casted_args = []
for argument, data_type in zipped_args_types:
if data_type.numpy_dtype != target_dtype.numpy_dtype: # ignoring const
casted_args.append(cast_func(argument, target_dtype))
else:
casted_args.append(argument)
return casted_args
def pointer_arithmetic(expr_args):
"""
Creates a valid pointer arithmetic function
:param expr_args: Arguments of the add expression
:return: pointer_arithmetic_func
"""
pointer = None
new_args = []
for arg, data_type in expr_args:
if data_type.func is PointerType:
assert pointer is None
pointer = arg
for arg, data_type in expr_args:
if arg != pointer:
assert data_type.is_int() or data_type.is_uint()
new_args.append(arg)
new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
return pointer_arithmetic_func(pointer, new_args)
if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
return node
args = []
for arg in node.args:
args.append(insert_casts(arg))
# TODO indexed, LoopOverCoordinate
if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
# TODO optimize pow, don't cast integer on double
types = [get_type_of_expression(arg) for arg in args]
assert len(types) > 0
target = collate_types(types)
zipped = list(zip(args, types))
if target.func is PointerType:
assert node.func is sp.Add
return pointer_arithmetic(zipped)
else:
return node.func(*cast(zipped, target))
elif node.func is ast.SympyAssignment:
lhs = args[0]
rhs = args[1]
target = get_type_of_expression(lhs)
if target.func is PointerType:
return node.func(*args) # TODO fix, not complete
else:
return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
elif node.func is ast.ResolvedFieldAccess:
return node
elif node.func is ast.Block:
for old_arg, new_arg in zip(node.args, args):
node.replace(old_arg, new_arg)
return node
elif node.func is ast.LoopOverCoordinate:
for old_arg, new_arg in zip(node.args, args):
node.replace(old_arg, new_arg)
return node
elif node.func is sp.Piecewise:
expressions = [expr for (expr, _) in args]
types = [get_type_of_expression(expr) for expr in expressions]
target = collate_types(types)
zipped = list(zip(expressions, types))
casted_expressions = cast(zipped, target)
args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]
return node.func(*args)
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
"""Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop] all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop" assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
inner_loop = all_inner_loops.pop() inner_loop = all_inner_loops.pop()
for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True): for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
cut_loop(loop, [loop.stop - 1]) if include_first:
cut_loop(loop, [loop.start + 1, loop.stop - 1])
else:
cut_loop(loop, [loop.stop - 1])
simplify_conditionals(function_node.body, loop_counter_simplification=True) simplify_conditionals(function_node.body, loop_counter_simplification=True)
cleanup_blocks(function_node.body) cleanup_blocks(function_node.body)
...@@ -961,58 +947,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) - ...@@ -961,58 +947,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -
# --------------------------------------- Helper Functions ------------------------------------------------------------- # --------------------------------------- Helper Functions -------------------------------------------------------------
def typing_from_sympy_inspection(eqs, default_type="double"):
"""
Creates a default symbol name to type mapping.
If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
Args:
eqs: list of equations
default_type: the type for non-boolean symbols
Returns:
dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: default_type)
for eq in eqs:
if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args))
if eq.false_block:
result.update(typing_from_sympy_inspection(eq.false_block.args))
elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
continue
else:
# problematic case here is when rhs is a symbol: then it is impossible to decide here without
# further information what type the left hand side is - default fallback is the dict value then
if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
result[eq.lhs.name] = "bool"
return result
def get_next_parent_of_type(node, parent_type):
"""Returns the next parent node of given type or None, if root is reached.
Traverses the AST nodes parents until a parent of given type was found.
If no such parent is found, None is returned
"""
parent = node.parent
while parent is not None:
if isinstance(parent, parent_type):
return parent
parent = parent.parent
return None
def parents_of_type(node, parent_type, include_current=False):
"""Generator for all parent nodes of given type"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
def get_optimal_loop_ordering(fields): def get_optimal_loop_ordering(fields):
""" """
Determines the optimal loop order for a given set of fields. Determines the optimal loop order for a given set of fields.
...@@ -1028,17 +962,50 @@ def get_optimal_loop_ordering(fields): ...@@ -1028,17 +962,50 @@ def get_optimal_loop_ordering(fields):
ref_field = next(iter(fields)) ref_field = next(iter(fields))
for field in fields: for field in fields:
if field.spatial_dimensions != ref_field.spatial_dimensions: if field.spatial_dimensions != ref_field.spatial_dimensions:
raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: " raise ValueError(
+ str({f.name: f.spatial_shape for f in fields})) "All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
+ str({f.name: f.spatial_shape
for f in fields}))
layouts = set([field.layout for field in fields]) layouts = set([field.layout for field in fields])
if len(layouts) > 1: if len(layouts) > 1:
raise ValueError("Due to different layout of the fields no optimal loop ordering exists " raise ValueError(
+ str({f.name: f.layout for f in fields})) "Due to different layout of the fields no optimal loop ordering exists "
+ str({f.name: f.layout
for f in fields}))
layout = list(layouts)[0] layout = list(layouts)[0]
return list(layout) return list(layout)
def get_loop_hierarchy(ast_node):
"""Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.
Returns:
sequence of LoopOverCoordinate nodes, starting from outer loop to innermost loop
"""
result = []
node = ast_node
while node is not None:
node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
if node:
result.append(node.coordinate_to_loop_over)
return reversed(result)
def get_loop_counter_symbol_hierarchy(ast_node):
"""Determines the loop counter symbols around a given AST node.
:param ast_node: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
"""
result = []
node = ast_node
while node is not None:
node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
if node:
result.append(node.loop_counter_symbol)
return result
def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
"""Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering). """Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering).
...@@ -1050,7 +1017,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: ...@@ -1050,7 +1017,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
""" """
inner_loops = [] inner_loops = []
inner_loop_counters = set() inner_loop_counters = set()
for loop in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment): for loop in filtered_tree_iteration(ast_node,
ast.LoopOverCoordinate,
stop_type=ast.SympyAssignment):
if loop.is_innermost_loop: if loop.is_innermost_loop:
inner_loops.append(loop) inner_loops.append(loop)
inner_loop_counters.add(loop.coordinate_to_loop_over) inner_loop_counters.add(loop.coordinate_to_loop_over)
...@@ -1061,8 +1030,10 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: ...@@ -1061,8 +1030,10 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
inner_loop_counter = inner_loop_counters.pop() inner_loop_counter = inner_loop_counters.pop()
parameters = ast_node.get_parameters() parameters = ast_node.get_parameters()
stride_params = [p.symbol for p in parameters stride_params = [
if p.is_field_stride and p.symbol.coordinate == inner_loop_counter] p.symbol for p in parameters
if p.is_field_stride and p.symbol.coordinate == inner_loop_counter
]
subs_dict = {stride_param: 1 for stride_param in stride_params} subs_dict = {stride_param: 1 for stride_param in stride_params}
if subs_dict: if subs_dict:
ast_node.subs(subs_dict) ast_node.subs(subs_dict)
...@@ -1073,17 +1044,23 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1073,17 +1044,23 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
Args: Args:
ast_node: kernel function node before vectorization transformation has been applied ast_node: kernel function node before vectorization transformation has been applied
block_size: sequence defining block size in x, y, (z) direction block_size: sequence defining block size in x, y, (z) direction.
If chosen as zero the direction will not be used for blocking.
Returns: Returns:
number of dimensions blocked number of dimensions blocked
""" """
loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)] loops = [
l for l in filtered_tree_iteration(
ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
]
body = ast_node.body body = ast_node.body
coordinates = [] coordinates = []
coordinates_taken_into_account = 0
loop_starts = {} loop_starts = {}
loop_stops = {} loop_stops = {}
for loop in loops: for loop in loops:
coord = loop.coordinate_to_loop_over coord = loop.coordinate_to_loop_over
if coord not in coordinates: if coord not in coordinates:
...@@ -1092,26 +1069,36 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1092,26 +1069,36 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
loop_stops[coord] = loop.stop loop_stops[coord] = loop.stop
else: else:
assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \ assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \
"Multiple loops over coordinate {} with different loop bounds".format(coord) f"Multiple loops over coordinate {coord} with different loop bounds"
# Create the outer loops that iterate over the blocks # Create the outer loops that iterate over the blocks
outer_loop = None outer_loop = None
for coord in reversed(coordinates): for coord in reversed(coordinates):
if block_size[coord] == 0:
continue
coordinates_taken_into_account += 1
body = ast.Block([outer_loop]) if outer_loop else body body = ast.Block([outer_loop]) if outer_loop else body
outer_loop = ast.LoopOverCoordinate(body, coord, loop_starts[coord], loop_stops[coord], outer_loop = ast.LoopOverCoordinate(body,
step=block_size[coord], is_block_loop=True) coord,
loop_starts[coord],
loop_stops[coord],
step=block_size[coord],
is_block_loop=True)
ast_node.body = ast.Block([outer_loop]) ast_node.body = ast.Block([outer_loop])
# modify the existing loops to only iterate within one block # modify the existing loops to only iterate within one block
for inner_loop in loops: for inner_loop in loops:
coord = inner_loop.coordinate_to_loop_over coord = inner_loop.coordinate_to_loop_over
if block_size[coord] == 0:
continue
block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord) block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
loop_range = inner_loop.stop - inner_loop.start loop_range = inner_loop.stop - inner_loop.start
if sp.sympify(loop_range).is_number and loop_range % block_size[coord] == 0: if sp.sympify(
loop_range).is_number and loop_range % block_size[coord] == 0:
stop = block_ctr + block_size[coord] stop = block_ctr + block_size[coord]
else: else:
stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord]) stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
inner_loop.start = block_ctr inner_loop.start = block_ctr
inner_loop.stop = stop inner_loop.stop = stop
return len(coordinates) return coordinates_taken_into_account
from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorMemoryAccess, ReinterpretCastFunc,
PointerArithmeticFunc)
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']