Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 3224 additions and 580 deletions
import time
from pystencils.integer_functions import modulo_ceil
class TimeLoop:
    """Drives repeated execution of registered calls and counts the executed time steps.

    One pass over all calls registered with :meth:`add_call` is counted as ``steps``
    time steps (see :attr:`fixed_steps`). In :meth:`run`, time steps that do not fill a
    complete pass are executed through the functions registered with
    :meth:`add_single_step_function`, each invocation round counting as one time step.

    Args:
        steps: number of time steps that one pass over the registered calls represents
    """

    def __init__(self, steps=2):
        self._call_data = []  # list of (callable, kwargs-dict) pairs, executed in order
        self._fixed_steps = steps
        self._pre_run_functions = []
        self._post_run_functions = []
        self._single_step_functions = []
        self.time_steps_run = 0  # running total over all run/benchmark methods

    @property
    def fixed_steps(self):
        """Number of time steps counted for one pass over the registered calls."""
        return self._fixed_steps

    def add_pre_run_function(self, f):
        """Register a zero-argument callable that is invoked by :meth:`pre_run`."""
        self._pre_run_functions.append(f)

    def add_post_run_function(self, f):
        """Register a zero-argument callable that is invoked by :meth:`post_run`."""
        self._post_run_functions.append(f)

    def add_single_step_function(self, f):
        """Register a zero-argument callable used by :meth:`run` for remainder time steps."""
        self._single_step_functions.append(f)

    def add_call(self, functor, argument_list):
        """Register a call that is executed once per pass for every argument dict.

        Args:
            functor: callable invoked as ``functor(**kwargs)``; if it has a ``kernel``
                     attribute, that attribute is called instead
            argument_list: a single kwargs dict or a list of kwargs dicts; one call is
                           registered per dict
        """
        if hasattr(functor, 'kernel'):
            functor = functor.kernel
        if not isinstance(argument_list, list):
            argument_list = [argument_list]

        for argument_dict in argument_list:
            self._call_data.append((functor, argument_dict))

    def pre_run(self):
        """Invoke all registered pre-run functions in registration order."""
        for f in self._pre_run_functions:
            f()

    def post_run(self):
        """Invoke all registered post-run functions in registration order."""
        for f in self._post_run_functions:
            f()

    def run(self, time_steps=1):
        """Execute ``time_steps`` time steps, framed by :meth:`pre_run` and :meth:`post_run`.

        ``time_steps`` is split into complete passes over the registered calls plus a
        remainder that is executed with the single-step functions. A KeyboardInterrupt
        stops the stepping early; :meth:`post_run` is still called afterwards.
        """
        self.pre_run()
        fixed_steps = self._fixed_steps
        call_data = self._call_data
        main_iterations, rest_iterations = divmod(time_steps, fixed_steps)
        try:
            for _ in range(main_iterations):
                for func, kwargs in call_data:
                    func(**kwargs)
                self.time_steps_run += fixed_steps
            for _ in range(rest_iterations):
                for func in self._single_step_functions:
                    func()
                self.time_steps_run += 1
        except KeyboardInterrupt:
            # deliberate: allow the user to abort a long run and still get post_run()
            pass
        self.post_run()

    def benchmark_run(self, time_steps=0, init_time_steps=0):
        """Time the registered calls and return the seconds needed for one time step.

        Both step counts are rounded with ``modulo_ceil(..., fixed_steps)`` so that only
        complete passes are executed. The warm-up passes are not timed.

        Args:
            time_steps: number of time steps to measure
            init_time_steps: number of warm-up time steps executed before the measurement

        Returns:
            measured wall-clock time divided by the number of time steps that were
            actually executed during the measurement

        Raises:
            ZeroDivisionError: if the rounded number of measured time steps is zero
        """
        init_time_steps_rounded = modulo_ceil(init_time_steps, self._fixed_steps)
        time_steps_rounded = modulo_ceil(time_steps, self._fixed_steps)
        call_data = self._call_data

        self.pre_run()
        for _ in range(init_time_steps_rounded // self._fixed_steps):
            for func, kwargs in call_data:
                func(**kwargs)
        self.time_steps_run += init_time_steps_rounded

        start = time.perf_counter()
        for _ in range(time_steps_rounded // self._fixed_steps):
            for func, kwargs in call_data:
                func(**kwargs)
        end = time.perf_counter()
        self.time_steps_run += time_steps_rounded
        self.post_run()

        # Divide by the number of steps that were really executed (the rounded count);
        # dividing by the requested count would overstate the time per step.
        time_for_one_iteration = (end - start) / time_steps_rounded
        return time_for_one_iteration

    def run_time_span(self, seconds):
        """Execute complete passes until at least ``seconds`` of wall-clock time have elapsed.

        Returns:
            tuple ``(time steps executed, elapsed seconds)``
        """
        iterations = 0

        self.pre_run()
        start = time.perf_counter()
        while time.perf_counter() < start + seconds:
            for func, kwargs in self._call_data:
                func(**kwargs)
            iterations += self._fixed_steps
        end = time.perf_counter()
        self.post_run()
        self.time_steps_run += iterations
        return iterations, end - start

    def benchmark(self, time_for_benchmark=5, init_time_steps=2, number_of_time_steps_for_estimation='auto'):
        """Returns the time in seconds for one time step.

        Args:
            time_for_benchmark: number of seconds benchmark should take
            init_time_steps: number of time steps run initially for warm up, to get arrays into cache etc
            number_of_time_steps_for_estimation: time steps run before real benchmarks, to determine number of time
                                                 steps that approximately take 'time_for_benchmark' or 'auto'
        """
        # Run a few time steps to get a first estimate of the duration of one time step
        if number_of_time_steps_for_estimation == 'auto':
            self.run(1)
            iterations, total_time = self.run_time_span(0.5)
            duration_of_time_step = total_time / iterations
        else:
            duration_of_time_step = self.benchmark_run(number_of_time_steps_for_estimation, init_time_steps)

        # Run for approximately 'time_for_benchmark' seconds, but never fewer than 4 time steps
        time_steps = int(time_for_benchmark / duration_of_time_step)
        time_steps = max(time_steps, 4)
        return self.benchmark_run(time_steps, init_time_steps)
import hashlib
import pickle
import warnings
from collections import OrderedDict
from copy import deepcopy
from types import MappingProxyType
from typing import Set
import sympy as sp
import pystencils as ps
import pystencils.astnodes as ast
from pystencils.assignment import Assignment
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
from pystencils.field import Field, FieldType
from pystencils.typing import FieldPointerSymbol
from pystencils.sympyextensions import fast_subs
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
class NestedScopes:
    """Stack-of-scopes model for symbol visibility.

    - accessing a symbol that is not defined in any open scope records it as a "free parameter"
    - free parameters are kept globally, independent of the scope stack
    - ``push`` opens a new innermost scope, ``pop`` discards the innermost scope

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        # the scope stack always starts with one (outermost) scope
        self._defined = [set()]
        self.free_parameters = set()

    def access_symbol(self, symbol):
        """Record ``symbol`` as a free parameter unless some open scope defines it."""
        if self.is_defined(symbol):
            return
        self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Define ``symbol`` in the innermost scope."""
        innermost = self._defined[-1]
        innermost.add(symbol)

    def is_defined(self, symbol):
        """Return True if any open scope defines ``symbol``."""
        for scope in self._defined:
            if symbol in scope:
                return True
        return False

    def is_defined_locally(self, symbol):
        """Return True if the innermost scope defines ``symbol``."""
        innermost = self._defined[-1]
        return symbol in innermost

    def push(self):
        """Open a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Discard the innermost scope; at least one scope has to remain afterwards."""
        self._defined.pop()
        assert self.depth >= 1

    @property
    def depth(self):
        """Number of currently open scopes."""
        return len(self._defined)
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yield all descendants of ``node`` that are instances of ``node_type``, in depth-first order.

    Args:
        node: root of the tree; every visited object has to provide an ``args`` sequence
        node_type: type (or tuple of types) of the nodes to yield
        stop_type: optional type (or tuple of types); children of this type are neither
                   yielded nor descended into, unless they also match ``node_type``

    Nodes matching ``node_type`` are yielded and searched further, so nested matches are
    found as well. ``node`` itself is never yielded.
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type is not None and isinstance(arg, stop_type):
            # prune the subtree below a stop node
            continue

        # hand stop_type down so that pruning applies on every level, not only the first
        yield from filtered_tree_iteration(arg, node_type, stop_type)
def generic_visit(term, visitor):
    """Apply ``visitor`` to the expressions contained in ``term`` and rebuild the container.

    Containers are handled recursively:

    - AssignmentCollection: main assignments and subexpressions are visited, a copy is returned
    - list: every entry is visited, a new list is returned
    - Assignment: only the right-hand side is visited, the left-hand side is kept
    - sympy Matrix: every entry is visited via ``applyfunc``

    Anything else is passed directly to ``visitor`` and its result is returned.
    """
    if isinstance(term, AssignmentCollection):
        visited_main = generic_visit(term.main_assignments, visitor)
        visited_sub = generic_visit(term.subexpressions, visitor)
        return term.copy(visited_main, visited_sub)

    if isinstance(term, list):
        return [generic_visit(entry, visitor) for entry in term]

    if isinstance(term, Assignment):
        lhs = term.lhs
        visited_rhs = generic_visit(term.rhs, visitor)
        return Assignment(lhs, visited_rhs)

    if isinstance(term, sp.Matrix):
        return term.applyfunc(lambda entry: generic_visit(entry, visitor))

    return visitor(term)
def iterate_loops_by_depth(node, nesting_depth):
    """Yield the LoopOverCoordinate nodes of an AST that sit at the requested nesting depth.

    Args:
        node: Root node of the abstract syntax tree
        nesting_depth: Nesting depth of the loops to yield.
                       The outermost loops have depth 0.
                       The value -1 selects the innermost loops.

    Returns: Iterable over all loop nodes of the given nesting depth.

    Raises:
        ValueError: if ``nesting_depth`` is neither nonnegative nor -1
                    (raised when the returned generator is first advanced)
    """
    from pystencils.astnodes import LoopOverCoordinate

    def _loops_at_depth(current, remaining_depth):
        current_is_loop = isinstance(current, LoopOverCoordinate)

        if remaining_depth < 0:  # descended below the requested depth: nothing to find here
            return
        if remaining_depth == 0 and current_is_loop:
            yield current
            return

        # only passing through a loop consumes one level of depth
        child_depth = remaining_depth - 1 if current_is_loop else remaining_depth
        for child in current.args:
            yield from _loops_at_depth(child, child_depth)

    def _innermost_loops(current):
        if isinstance(current, LoopOverCoordinate) and current.is_innermost_loop:
            yield current
            return
        for child in current.args:
            yield from _innermost_loops(child)

    if nesting_depth >= 0:
        yield from _loops_at_depth(node, nesting_depth)
    elif nesting_depth == -1:
        yield from _innermost_loops(node)
    else:
        raise ValueError(f"Invalid nesting depth: {nesting_depth}. Choose a nonnegative number, or -1.")
def unify_shape_symbols(body, common_shape, fields):
    """Substitutes per-field shape symbols by the symbols of one common shape.

    For kernels with variable array sizes all passed arrays must have the same size, which
    is checked when the kernel is called. Inside the kernel it is therefore sufficient to use
    a single symbol per dimension, e.g. shape_arr1[0] and shape_arr2[0] are represented by
    the same symbol.

    Args:
        body: ast node of the kernel part in which the substitution is made; ``body.subs``
              is called on it with the collected substitutions
        common_shape: shape whose components replace the components of the field shapes
        fields: all fields whose shapes should be replaced by common_shape; fields with
                fixed shape are left untouched
    """
    replacements = {}
    for field in fields:
        field_shape = field.spatial_shape
        assert len(field_shape) == len(common_shape)
        if field.has_fixed_shape:
            continue
        for common_component, own_component in zip(common_shape, field_shape):
            if own_component != common_component:
                replacements[own_component] = common_component

    # only touch the body when there is something to replace
    if replacements:
        body.subs(replacements)
def get_common_field(field_set):
    """Picks one representative field out of a collection of fields with a common spatial shape.

    The returned field can be used for shape information etc. in the kernel creation.
    The choice is made deterministic by taking the first field in the order of the
    fields' string representation.

    Raises:
        ValueError: if fixed-shape and variable-shape fields are mixed, or if all fields
                    have a fixed shape but the shapes differ
    """
    total = len(field_set)
    fixed_shape_fields = [f for f in field_set if f.has_fixed_shape]
    nr_of_fixed_shaped_fields = len(fixed_shape_fields)

    mixes_fixed_and_variable = 0 < nr_of_fixed_shaped_fields and nr_of_fixed_shaped_fields != total
    if mixes_fixed_and_variable:
        fixed_field_names = ",".join(f.name for f in fixed_shape_fields)
        var_field_names = ",".join(f.name for f in field_set if not f.has_fixed_shape)
        raise ValueError(
            "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
            f"Variable shaped: {var_field_names} \nFixed shaped: {fixed_field_names}"
        )

    shape_set = {f.spatial_shape for f in field_set}
    all_fixed = nr_of_fixed_shaped_fields == total
    if all_fixed and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # sorting by the string representation guarantees that always the same field is returned
    ordered = sorted(field_set, key=str)
    return ordered[0]
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Wraps the given AST block in (multiple) loops, based on its :class:`pystencils.field.Field.Access` atoms.

    Args:
        body: Block object with inner loop contents
        iteration_slice: if not None, iteration is done only over this slice of the field;
                         components that are not slices fix the loop counter of that coordinate
                         by an assignment instead of creating a loop
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
                      or a single int that is used for the lower and upper side of all coordinates.
                      If None, the number of ghost layers is determined automatically from the field
                      accesses and assumed to be equal for all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple of loop-node, ghost_layer_info
    """
    # the loop structure is derived from the relative (non-absolute) field accesses
    relative_accesses = {acc for acc in body.atoms(Field.Access) if not acc.is_absolute_access}
    # True when the kernel contains only absolute accesses
    only_absolute_accesses = len(relative_accesses) == 0

    # accesses to buffers are excluded here, because buffers are treated separately
    field_list = [acc.field for acc in relative_accesses
                  if not (FieldType.is_buffer(acc.field) or FieldType.is_custom(acc.field))]
    if not field_list:  # when kernel contains only custom fields
        field_list = [acc.field for acc in relative_accesses if not (FieldType.is_buffer(acc.field))]
    fields = set(field_list)

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    if only_absolute_accesses:
        shape_candidates = {acc.field for acc in body.atoms(Field.Access)}
    else:
        shape_candidates = fields
    common_shape = get_common_field(shape_candidates).spatial_shape
    # no-op in the absolute-only case, since 'fields' is empty then
    unify_shape_symbols(body, common_shape=common_shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, common_shape)

    if ghost_layers is None:
        if only_absolute_accesses:
            required_ghost_layers = 0
        else:
            required_ghost_layers = max(acc.required_ghost_layers for acc in relative_accesses)
        ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    # build the loop nest from the innermost to the outermost loop
    current_body = body
    for loop_coordinate in reversed(loop_order):
        if iteration_slice is None:
            layers = ghost_layers[loop_coordinate]
            begin = layers[0]
            end = common_shape[loop_coordinate] - layers[1]
            loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
            current_body = ast.Block([loop])
            continue

        slice_component = iteration_slice[loop_coordinate]
        if type(slice_component) is slice:
            loop = ast.LoopOverCoordinate(current_body, loop_coordinate,
                                          slice_component.start, slice_component.stop, slice_component.step)
            current_body = ast.Block([loop])
        else:
            # fixed coordinate: assign the loop counter instead of looping over it
            counter = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate)
            current_body.insert_front(ast.SympyAssignment(counter, sp.sympify(slice_component)))

    return current_body, ghost_layers
def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
    """Return a deterministic representative of the given indexed elements.

    The representative is the first element in the order of the elements' string
    representation. If the elements do not all share one shape, every distinct shape is
    asserted not to be an ``int``.
    """
    assert len(indexed_elements) > 0, "indexed_elements can not be empty"

    distinct_shapes = {element.shape for element in indexed_elements}
    if len(distinct_shapes) != 1:
        assert not any(isinstance(shape, int) for shape in distinct_shapes), \
            "If indexed elements are used, they must all have the same shape"

    ordered = sorted(indexed_elements, key=str)
    return ordered[0]
def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
    """Wrap ``loop_node`` in a loop over the first dimension of its ``sp.Indexed`` atoms.

    If ``loop_node`` contains no indexed elements it is returned unchanged. Otherwise a
    loop from 0 to the first shape entry of a reference element is created, using the
    single TypedSymbol of the reference element's first index as loop counter.
    """
    indexed_elements = loop_node.atoms(sp.Indexed)
    if not indexed_elements:
        return loop_node

    reference_element = get_common_indexed_element(indexed_elements)
    index_symbols = reference_element.indices[0].atoms(TypedSymbol)
    assert len(index_symbols) == 1, "index expressions must only contain one symbol representing the index"

    upper_bound = reference_element.shape[0]
    outer_loop = ast.LoopOverCoordinate(loop_node, 0, 0, upper_bound, 1,
                                        custom_loop_ctr=index_symbols.pop())
    return ast.Block([outer_loop])
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Elements of structured arrays are addressed with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    This sum can be split into several parts, so that parts of it can be pulled before loops.
    This function computes the partial sum for the coordinates :math:`i \in \mbox{coordinates}`
    and derives the name of a new pointer symbol that encodes which coordinates have been resolved.

    Integer offsets / index values are encoded literally in the name (a minus sign is written as 'm');
    all non-integer contributions are collected and appended as a 16 character md5 hex digest.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    name_suffix = ""
    parts_to_hash = []

    for coordinate_id, coordinate_value in coordinates.items():
        stride = field.strides[coordinate_id]
        offset += stride * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the relative offset of the access contributes as well
            access_offset = field_access.offsets[coordinate_id]
            offset += stride * access_offset
            if access_offset.is_Integer:
                name_suffix += "_%d%d" % (coordinate_id, access_offset)
            else:
                parts_to_hash.append(access_offset)
        elif type(coordinate_value) is int:
            name_suffix += "_%d%d" % (coordinate_id, coordinate_value)
        else:
            parts_to_hash.append(coordinate_value)

    if parts_to_hash:
        digest = hashlib.md5(pickle.dumps(parts_to_hash)).hexdigest()
        name_suffix += digest[:16]

    name_suffix = name_suffix.replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + name_suffix, previous_ptr.dtype)
    return new_ptr, offset
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)] creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate
    zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of lists that can be passed to :func:`resolve_field_accesses`; coordinates that
        are not mentioned in the specification are appended as a final group in ascending order

    Raises:
        ValueError: if a specification cannot be parsed, or a coordinate does not exist
                    (negative or too large) or is specified more than once

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[0], [3], [1, 2]]
    """
    nr_of_coordinates = spatial_dimensions + index_dimensions
    result = []
    specified_coordinates = set()
    # from here on the loop order is stored from inner to outer loop
    loop_order = list(reversed(loop_order))

    def add_new_element(group, elem):
        # negative ids would otherwise pass silently and end up as a bogus coordinate
        if not 0 <= elem < nr_of_coordinates:
            raise ValueError("Coordinate %d does not exist" % (elem,))
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem,))
        specified_coordinates.add(elem)
        group.append(elem)

    for spec_group in base_pointer_specification:
        new_group = []

        for element in spec_group:
            if isinstance(element, int):
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(new_group, loop_order[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # NOTE(review): on the inner-to-outer list, -0 selects element 0, i.e.
                    # "spatialOuter0" resolves to the innermost loop and "spatialOuter1" to the
                    # outermost one. This looks off by one with respect to the description above,
                    # but the documented example relies on it - confirm before changing.
                    add_new_element(new_group, loop_order[-index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, spatial_dimensions + index)
            else:
                raise ValueError(f"Unknown specification {element}")

        result.append(new_group)

    rest = set(range(nr_of_coordinates)) - specified_coordinates
    if rest:
        # sorted, so that the result does not depend on set iteration order
        result.append(sorted(rest))

    return result
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Determines the linearized buffer index for buffer fields, dependent on the loop counter symbols.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        # collect the loops of the AST and bring them into inner-to-outer order
        loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))
        loops.reverse()
        # the loops found have to form exactly the chain of loops enclosing the innermost one
        enclosing_loops = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        assert len(loops) == len(enclosing_loops)
        assert all(found is enclosing for found, enclosing in zip(loops, enclosing_loops))

        loop_counters = [loop.loop_counter_symbol for loop in loops]
        loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]

    def _divide_by_step(numerator, step):
        # plain floor division when the remainder is known to be zero, int_div otherwise
        if numerator % step == 0:
            return numerator // step
        return int_div(numerator, step)

    actual_sizes = []
    actual_steps = []
    for counter, iteration in zip(loop_counters, loop_iterations):
        extent = iteration.stop - iteration.start
        progress = counter - iteration.start
        if iteration.step != 1:
            actual_sizes.append(_divide_by_step(extent, iteration.step))
            actual_steps.append(_divide_by_step(progress, iteration.step))
        else:
            actual_sizes.append(extent)
            actual_steps.append(progress)

    field_accesses = ast_node.atoms(Field.Access)
    buffer_index_size = len({access for access in field_accesses if FieldType.is_buffer(access.field)})

    # linearize: each loop's position is weighted by the product of the sizes of all loops before it
    base_buffer_index = actual_steps[0]
    actual_stride = 1
    for actual_step, cur_stride in zip(actual_steps[1:], actual_sizes):
        actual_stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
        base_buffer_index += actual_stride * actual_step
    return base_buffer_index * buffer_index_size
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=None):
    """Replaces accesses to buffer fields by :class:`ResolvedFieldAccess` nodes, in-place.

    Args:
        ast_node: the AST root; the lhs and rhs of all contained SympyAssignments are rewritten
        base_buffer_index: index into the buffer, to which the (single) index coordinate of
                           an access is added
        read_only_field_names: names of fields whose pointer symbols are created as const

    Raises:
        RuntimeError: if a buffer access has more than one index coordinate
    """
    if read_only_field_names is None:
        read_only_field_names = set()

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if not isinstance(expr, Field.Access):
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            rewritten_args = [visit_sympy_expr(a, enclosing_block, sympy_assignment) for a in expr.args]
            if not rewritten_args:
                return expr
            if type(expr) in (sp.Add, sp.Mul, sp.Piecewise):
                return expr.func(*rewritten_args, evaluate=False)
            return expr.func(*rewritten_args)

        field_access = expr

        # Do not apply transformation if field is not a buffer
        if not FieldType.is_buffer(field_access.field):
            return expr

        buffer = field_access.field
        field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

        nr_of_index_coordinates = len(field_access.index)
        if nr_of_index_coordinates > 1:
            raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

        buffer_index = base_buffer_index
        if nr_of_index_coordinates > 0:
            buffer_index += field_access.index[0]

        result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                         field_access.index)
        return visit_sympy_expr(result, enclosing_block, sympy_assignment)

    def visit_node(sub_ast):
        if not isinstance(sub_ast, ast.SympyAssignment):
            for child in sub_ast.args:
                visit_node(child)
            return

        enclosing_block = sub_ast.parent
        assert type(enclosing_block) is ast.Block
        sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
        sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)

    return visit_node(ast_node)
def resolve_field_accesses(ast_node, read_only_field_names=None,
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The AST is modified in-place: lhs and rhs of SympyAssignments and the condition of
    Conditionals are rewritten, and assignments for intermediate base pointers are inserted
    into the enclosing block in front of the statement that needs them.

    Args:
        ast_node: the AST root
        read_only_field_names: set of field names which are considered read-only
        field_to_base_pointer_info: mapping of field name to a base pointer specification, indicating which
                                    intermediate base pointers should be created
                                    for details see :func:`parse_base_pointer_info`.
                                    Fields without entry get all coordinates resolved in a single group.
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns
        the return value of the internal node visitor, which has no explicit return value (i.e. None)
    """
    if read_only_field_names is None:
        read_only_field_names = set()
    # both mappings are copied into dicts that are ordered by field name
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        # Rewrites one expression; 'sympy_assignment' is the statement in 'enclosing_block'
        # in front of which new base pointer assignments are inserted.
        if isinstance(expr, Field.Access):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets/indices contain field accesses themselves: resolve those first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                # default: a single group containing all coordinates
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            field_ptr = FieldPointerSymbol(
                field.name,
                field.dtype,
                const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                # Maps every coordinate id of the group to the value it is indexed with.
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        # spatial coordinate: fixed coordinate symbol or loop counter,
                        # absolute accesses use 0 as base coordinate
                        if field.name in field_to_fixed_coordinates:
                            if not field_access.is_absolute_access:
                                coordinates[e] = field_to_fixed_coordinates[field.name][e]
                            else:
                                coordinates[e] = 0
                        else:
                            if not field_access.is_absolute_access:
                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                            else:
                                coordinates[e] = 0
                        # spatial coordinates are scaled by the item_size of the field's dtype
                        coordinates[e] *= field.dtype.item_size
                    else:
                        # index coordinate
                        if isinstance(field.dtype, StructType):
                            # struct fields are indexed by element name, which is translated
                            # to the offset of that element
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            if isinstance(accessed_field_name, sp.Symbol):
                                accessed_field_name = accessed_field_name.name
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # all groups except the first one become intermediate base pointers,
            # chained starting from the last group
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                # define each intermediate pointer only once per block
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group is resolved directly in the access, relative to the last pointer
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                # reinterpret the access as the type of the accessed struct element
                accessed_field_name = field_access.index[0]
                if isinstance(accessed_field_name, sp.Symbol):
                    accessed_field_name = accessed_field_name.name
                new_type = field_access.field.dtype.get_element_type(accessed_field_name)
                result = ReinterpretCastFunc(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            if hasattr(expr, 'args'):
                new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            else:
                new_args = []
            # rebuild Add, Mul and Piecewise without evaluation
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        # Walks the AST and rewrites the expressions of assignments and conditionals in-place.
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            # plain Python numbers can appear as node arguments and have no 'args'
            if isinstance(sub_ast, (bool, int, float)):
                return
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    The AST below ``ast_node`` is modified in-place; nothing is returned.
    """
    def find_block_to_move_to(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to

        Args:
            node: SympyAssignment inside a Block

        Returns:
            tuple (block to insert into, child of that block to insert before)
        """
        assert isinstance(node.parent, ast.Block)

        def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
            # True if 'node' assigns to / defines one of the given symbol names.
            if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
                if isinstance(node.lhs, ast.ResolvedFieldAccess):
                    return node.lhs.typed_symbol.name in symbol_names
                else:
                    return node.lhs.name in symbol_names
            elif isinstance(node, ast.Block):
                for arg in node.args:
                    # declarations directly inside a nested block are skipped
                    if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
                        continue
                    if modifies_or_declares(arg, symbol_names):
                        return True
                return False
            elif isinstance(node, ast.LoopOverCoordinate):
                return modifies_or_declares(node.body, symbol_names)
            elif isinstance(node, ast.Conditional):
                return (
                    modifies_or_declares(node.true_block, symbol_names)
                    or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
                )
            elif isinstance(node, ast.KernelFunction):
                return False
            else:
                # any other node type: fall back to the symbols it reports as defined
                defs = {s.name for s in node.symbols_defined}
                return bool(symbol_names.intersection(defs))

        # names of all symbols the assignment reads but does not define itself
        dependencies = {s.name for s in node.undefined_symbols}

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node

        while element:
            if isinstance(element, (ast.Conditional, ast.KernelFunction)):
                # Never move out of Conditionals or KernelFunctions.
                break

            elif isinstance(element, ast.Block):
                # remember the outermost block reached so far and the child we came from
                last_block = element
                last_block_child = prev_element

                if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
                    # The node depends on one of the statements in this block.
                    # Do not move further out.
                    break

            elif isinstance(element, ast.LoopOverCoordinate):
                if element.loop_counter_symbol.name in dependencies:
                    # The node depends on the loop counter.
                    # Do not move out of this loop.
                    break

            else:
                raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
                                          f'The expression {element} of type {type(element)} is not known yet.')

            # No dependencies to symbols defined/modified within the current element.
            # We can move the node up one level and in front of the current element.
            prev_element = element
            element = element.parent

        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
        # Returns the first SympyAssignment of target_block with the same rhs (rhs_or_lhs=True)
        # or the same lhs (rhs_or_lhs=False) as 'assignment', or None if there is none.
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
                return arg
        return None

    def get_blocks(node, result_list):
        # Appends all Block nodes below (and including) 'node' to result_list, parents first.
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        # take the children out of the block and re-add each one either to this block
        # or to the block it can be moved to
        children = block.take_child_nodes()
        for child in children:
            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:  # movement not possible
                target.append(child)
            else:
                # look for an assignment to the same lhs in the target block
                if isinstance(child, ast.SympyAssignment):
                    exists_already = check_if_assignment_already_in_block(child, target, False)
                else:
                    exists_already = False

                if not exists_already:
                    target.insert_before(child, child_to_insert_before)
                elif exists_already and exists_already.rhs == child.rhs:
                    # identical assignment is already there: the child is dropped, and the existing
                    # assignment is moved in front of the insertion point if it comes after it
                    if target.args.index(exists_already) > target.args.index(child_to_insert_before):
                        assert target.args.count(exists_already) == 1
                        assert target.args.count(child_to_insert_before) == 1
                        target.args.remove(exists_already)
                        target.insert_before(exists_already, child_to_insert_before)
                else:
                    # this variable already exists in outer block, but with different rhs
                    # -> symbol has to be renamed
                    assert isinstance(child.lhs, TypedSymbol)
                    new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
                                         child_to_insert_before)
                    # the original symbol is kept in the inner block as a copy of the renamed one
                    block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    The AST is modified in place: the single innermost loop is replaced by a block containing one loop per
    symbol group. Non-field symbols listed in a symbol group are rewritten to indexed accesses on a pointer-typed
    symbol of the same name, and a temporary allocation/free pair for each of them is added to the parent of the
    outermost loop.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups
                       and which no symbol in a symbol group depends on, are not updated!
    """
    all_loops = ast_node.atoms(ast.LoopOverCoordinate)
    # the transformation requires exactly one innermost and exactly one outermost loop
    inner_loop = [loop for loop in all_loops if loop.is_innermost_loop]
    assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = inner_loop[0]
    assert type(inner_loop.body) is ast.Block
    outer_loop = [loop for loop in all_loops if loop.is_outermost_loop]
    assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outer_loop[0]

    # maps original symbol -> indexed access into its temporary array; filled group by group,
    # so later groups see the temporaries introduced by earlier ones
    symbols_with_temporary_array = OrderedDict()
    # lhs -> node, for every child of the inner loop body that has an ``lhs`` attribute
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args if hasattr(a, 'lhs'))

    assignment_groups = []
    for symbol_group in symbol_groups:
        # get all dependent symbols
        symbols_to_process = list(symbol_group)
        symbols_resolved = set()
        while symbols_to_process:
            s = symbols_to_process.pop()
            if s in symbols_resolved:
                continue

            if s in assignment_map:  # if there is no assignment inside the loop body it is independent already
                for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
                    # field accesses and symbols that already live in a temporary array are not followed further
                    if not isinstance(new_symbol, Field.Access) and \
                            new_symbol not in symbols_with_temporary_array:
                        symbols_to_process.append(new_symbol)
            symbols_resolved.add(s)

        # register a temporary array access for every non-field symbol of this group
        for symbol in symbol_group:
            if not isinstance(symbol, Field.Access):
                assert type(symbol) is TypedSymbol
                new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
                symbols_with_temporary_array[symbol] = sp.IndexedBase(
                    new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]

        # collect (in original order) the assignments this group needs, rewritten to use the temporaries
        assignment_group = []
        for assignment in inner_loop.body.args:
            # NOTE(review): unlike assignment_map above, this accesses ``lhs`` on every child of the body
            if assignment.lhs in symbols_resolved:
                # use fast_subs here because it checks if multiplications should be evaluated or not
                new_rhs = fast_subs(assignment.rhs, symbols_with_temporary_array)
                if not isinstance(assignment.lhs, Field.Access) and assignment.lhs in symbol_group:
                    assert type(assignment.lhs) is TypedSymbol
                    new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
                    new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
                else:
                    new_lhs = assignment.lhs
                assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    # one new loop per group, replacing the original inner loop
    new_loops = [
        inner_loop.new_loop_with_different_body(ast.Block(group))
        for group in assignment_groups
    ]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    # allocation goes to the front, the matching free to the end of the outermost loop's parent
    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        outer_loop.parent.insert_front(alloc_node)
        outer_loop.parent.append(free_node)
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    The loop range is divided into the consecutive pieces
    old_begin to cutting_points[0], ..., cutting_points[-1] to old_end.
    A piece of length one becomes a copy of the loop body with the loop counter substituted by the
    piece's start, an empty piece is dropped, any other piece becomes a new loop.

    Modifies the ast in place. Note Issue #5783 of SymPy. Deepcopy will evaluate mul
    https://github.com/sympy/sympy/issues/5783

    Returns:
        the block holding the new nodes, which has replaced ``loop_node`` in its parent

    Raises:
        NotImplementedError: if the loop step is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")

    pieces = ast.Block([])
    piece_begin = loop_node.start
    for piece_end in [*cutting_points, loop_node.stop]:
        extent = piece_end - piece_begin
        if extent == 1:
            # single iteration: emit the body itself with the counter fixed to the start value
            unrolled_body = deepcopy(loop_node.body)
            unrolled_body.subs({loop_node.loop_counter_symbol: piece_begin})
            pieces.append(unrolled_body)
        elif extent == 0:
            # empty piece, nothing to emit
            pass
        else:
            pieces.append(ast.LoopOverCoordinate(deepcopy(loop_node.body),
                                                 loop_node.coordinate_to_loop_over,
                                                 piece_begin, piece_end, loop_node.step))
        piece_begin = piece_end

    loop_node.parent.replace(loop_node, pieces)
    return pieces
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default.
    """
    from sympy.codegen.rewriting import ReplaceOptim, optimize

    def _is_cast(candidate):
        return isinstance(candidate, CastFunc)

    def _unwrap_cast(cast):
        return cast.expr

    strip_casts = ReplaceOptim(_is_cast, _unwrap_cast)

    for cond_node in node.atoms(ast.Conditional):
        # TODO simplify conditional before the type system! Casts make it very hard here
        simplified = sp.simplify(optimize(cond_node.condition_expr, [strip_casts]))

        if simplified == sp.true:
            # always taken: keep only the true branch
            cond_node.parent.replace(cond_node, [cond_node.true_block])
            continue

        if simplified == sp.false:
            # never taken: keep the false branch if there is one, otherwise drop the node
            replacement = [cond_node.false_block] if cond_node.false_block else []
            cond_node.parent.replace(cond_node, replacement)
            continue

        if not loop_counter_simplification:
            continue

        try:
            # noinspection PyUnresolvedReferences
            from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            simplify_loop_counter_dependent_conditional(cond_node)
        except ImportError:
            warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child

    A block is only dissolved when its parent is a block as well. Assignments are not descended into.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    if not isinstance(node, ast.Block):
        for child in node.args:
            cleanup_blocks(child)
        return

    # iterate over a snapshot, the recursive calls may replace children of this block
    for child in tuple(node.args):
        cleanup_blocks(child)

    can_dissolve = len(node.args) <= 1 and isinstance(node.parent, ast.Block)
    if can_dissolve:
        node.parent.replace(node, node.args)
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
    first and last element

    Args:
        function_node: kernel to transform in place, has to contain exactly one innermost loop
        include_first: if true the loops are additionally cut after their first element
    """
    innermost = [loop for loop in function_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop]
    assert len(innermost) == 1, "Transformation works only on kernels with exactly one inner loop"
    start_loop = innermost[0]

    # cut the innermost loop and every loop surrounding it
    for enclosing in parents_of_type(start_loop, ast.LoopOverCoordinate, include_current=True):
        cut_points = [enclosing.stop - 1]
        if include_first:
            cut_points.insert(0, enclosing.start + 1)
        cut_loop(enclosing, cut_points)

    kernel_body = function_node.body
    simplify_conditionals(kernel_body, loop_counter_simplification=True)
    cleanup_blocks(kernel_body)
    move_constants_before_loop(function_node.body)
    cleanup_blocks(function_node.body)
# --------------------------------------- Helper Functions -------------------------------------------------------------
def get_optimal_loop_ordering(fields):
    """
    Determines the optimal loop order for a given set of fields.
    If the fields differ in their number of spatial dimensions or in their memory layout an exception is thrown.

    Args:
        fields: sequence of fields

    Returns:
        list of coordinate ids, where the first list entry should be the outermost loop

    Raises:
        ValueError: if the fields disagree in spatial dimensionality or layout
    """
    assert len(fields) > 0
    reference_dims = next(iter(fields)).spatial_dimensions

    if any(f.spatial_dimensions != reference_dims for f in fields):
        shapes_by_name = {f.name: f.spatial_shape for f in fields}
        raise ValueError(
            "All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
            + str(shapes_by_name))

    distinct_layouts = {f.layout for f in fields}
    if len(distinct_layouts) > 1:
        layouts_by_name = {f.name: f.layout for f in fields}
        raise ValueError(
            "Due to different layout of the fields no optimal loop ordering exists "
            + str(layouts_by_name))

    # exactly one layout is left, its entries are the loop order
    return list(next(iter(distinct_layouts)))
def get_loop_hierarchy(ast_node):
    """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.

    Walks upwards through the enclosing ``LoopOverCoordinate`` nodes and records the coordinate each of
    them loops over.

    Returns:
        reverse iterator over the collected ``coordinate_to_loop_over`` values, i.e. the coordinate of the
        loop found last on the way up comes first
    """
    coordinates = []
    current = ast_node
    while current is not None:
        current = get_next_parent_of_type(current, ast.LoopOverCoordinate)
        if not current:
            continue
        coordinates.append(current.coordinate_to_loop_over)
    return reversed(coordinates)
def get_loop_counter_symbol_hierarchy(ast_node):
    """Determines the loop counter symbols around a given AST node.

    Args:
        ast_node: the AST node

    Returns:
        list of loop counter symbols, where the first list entry is the symbol of the innermost loop
    """
    counter_symbols = []
    current = ast_node
    while current is not None:
        current = get_next_parent_of_type(current, ast.LoopOverCoordinate)
        if not current:
            continue
        counter_symbols.append(current.loop_counter_symbol)
    return counter_symbols
def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
    """Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering).

    Variable sized kernels can handle arbitrary field sizes and field shapes. However, the kernel is most efficient
    if the innermost loop accesses the fields with stride 1. The inner loop can also only be vectorized if the inner
    stride is 1. This transformation hard codes this inner stride to one to enable e.g. vectorization.

    Warning: the assumption is not checked at runtime!

    Raises:
        ValueError: if the innermost loops do not all iterate over one and the same coordinate
    """
    innermost_coordinates = {
        loop.coordinate_to_loop_over
        for loop in filtered_tree_iteration(ast_node,
                                            ast.LoopOverCoordinate,
                                            stop_type=ast.SympyAssignment)
        if loop.is_innermost_loop
    }
    if len(innermost_coordinates) != 1:
        raise ValueError("Inner loops iterate over different coordinates")
    (inner_coordinate,) = innermost_coordinates

    # every field stride parameter belonging to the inner coordinate is substituted by 1
    unit_strides = {}
    for parameter in ast_node.get_parameters():
        if parameter.is_field_stride and parameter.symbol.coordinate == inner_coordinate:
            unit_strides[parameter.symbol] = 1

    if unit_strides:
        ast_node.subs(unit_strides)
def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
    """Blocking of loops to enhance cache locality. Modifies the ast node in-place.

    New outer "block loops" that advance in steps of the block size are wrapped around the existing kernel
    body, and the existing loops are restricted to the range of the current block.

    Args:
        ast_node: kernel function node before vectorization transformation has been applied
        block_size: sequence defining block size in x, y, (z) direction.
                    If chosen as zero the direction will not be used for blocking.

    Returns:
        number of dimensions blocked
    """
    loops = [
        l for l in filtered_tree_iteration(
            ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
    ]
    body = ast_node.body

    # coordinates in order of first appearance, together with the bounds of the first loop seen for each
    coordinates = []
    coordinates_taken_into_account = 0
    loop_starts = {}
    loop_stops = {}

    for loop in loops:
        coord = loop.coordinate_to_loop_over
        if coord not in coordinates:
            coordinates.append(coord)
            loop_starts[coord] = loop.start
            loop_stops[coord] = loop.stop
        else:
            # further loops over an already known coordinate have to share its bounds
            assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \
                f"Multiple loops over coordinate {coord} with different loop bounds"

    # Create the outer loops that iterate over the blocks
    outer_loop = None
    for coord in reversed(coordinates):
        if block_size[coord] == 0:
            continue
        coordinates_taken_into_account += 1
        # the first block loop wraps the original body, every further one wraps the previous block loop
        body = ast.Block([outer_loop]) if outer_loop else body
        outer_loop = ast.LoopOverCoordinate(body,
                                            coord,
                                            loop_starts[coord],
                                            loop_stops[coord],
                                            step=block_size[coord],
                                            is_block_loop=True)

    # NOTE(review): if no coordinate was blocked, outer_loop is still None at this point
    ast_node.body = ast.Block([outer_loop])

    # modify the existing loops to only iterate within one block
    for inner_loop in loops:
        coord = inner_loop.coordinate_to_loop_over
        if block_size[coord] == 0:
            continue
        block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
        loop_range = inner_loop.stop - inner_loop.start
        # the Min guard is only omitted when the range is a number and a multiple of the block size
        if sp.sympify(
                loop_range).is_number and loop_range % block_size[coord] == 0:
            stop = block_ctr + block_size[coord]
        else:
            stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
        inner_loop.start = block_ctr
        inner_loop.stop = stop
    return coordinates_taken_into_account
from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorMemoryAccess, ReinterpretCastFunc,
PointerArithmeticFunc)
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean
from pystencils.typing.types import AbstractType, BasicType
from pystencils.typing.typed_sympy import TypedSymbol
class CastFunc(sp.Function):
    """
    CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type
    a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number.

    The first argument is the expression that is cast, the second one the target data type.
    """
    # NOTE(review): presumably set so that sympy treats a cast as a leaf of the expression tree - confirm
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        """Creates the cast node from ``(expr, dtype, *other_args)``.

        A ``dtype`` that is not an ``AbstractType`` is wrapped in ``BasicType``. If ``expr`` is itself a plain
        ``CastFunc`` only its inner expression is kept.
        """
        # NOTE(review): this check has no effect, the branch only contains ``pass``
        if len(args) != 2:
            pass
        expr, dtype, *other_args = args

        # If we have two consecutive casts, throw the inner one away.
        # This optimisation is only available for simple casts. Thus the == is intended here!
        if expr.__class__ == CastFunc:
            expr = expr.args[0]
        if not isinstance(dtype, AbstractType):
            dtype = BasicType(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType('bool')):
            cls = BooleanCastFunc

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        """``canonical`` of the cast expression; raises NotImplementedError if the expression has none."""
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        """Commutativity is taken over from the cast expression."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Target data type of the cast (second argument)."""
        return self.args[1]

    @property
    def expr(self):
        """Expression that is cast (first argument)."""
        return self.args[0]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        False if the target data type is an unsigned Numpy integer type, otherwise the predicate of the parent
        class is used.
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        True if :attr:`is_negative` is known to be False, otherwise the predicate of the parent class is used.
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        True if the target data type is a Numpy integer or floating point type, otherwise the predicate of the
        parent class is used.
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or np.issubdtype(self.dtype.numpy_dtype,
                                                                                      np.floating) or super().is_real
        else:
            return super().is_real
class BooleanCastFunc(CastFunc, Boolean):
    # TODO: documentation
    # Cast that is itself a sympy ``Boolean``. It adds no members of its own; ``CastFunc.__new__`` switches to
    # this class when the cast expression is a ``Boolean`` that is either not a ``TypedSymbol`` or is a
    # ``TypedSymbol`` of dtype 'bool'.
    pass
class VectorMemoryAccess(CastFunc):
    """
    Special memory access for vectorized kernel.
    Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride
    """
    # number of arguments accepted by this function: the six listed in the docstring
    nargs = (6,)
class ReinterpretCastFunc(CastFunc):
    """
    Reinterpret cast is necessary for the StructType

    This subclass adds no members of its own; it only gives such casts a distinct class.
    """
    pass
class PointerArithmeticFunc(sp.Function, Boolean):
    # TODO: documentation, or deprecate!
    # sympy function that additionally derives from ``Boolean``; the only member defined here is ``canonical``.
    @property
    def canonical(self):
        """``canonical`` of the first argument; raises NotImplementedError if that argument has none."""
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()
from collections import namedtuple
from typing import Union, Tuple, Any, DefaultDict
import logging
import numpy as np
import sympy as sp
from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom
from pystencils import astnodes as ast
from pystencils.functions import DivFunc, AddressOf
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.field import Field
from pystencils.typing.types import BasicType, PointerType
from pystencils.typing.utilities import collate_types
from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.fast_approximation import fast_sqrt, fast_division, fast_inv_sqrt
from pystencils.utils import ContextVar
class TypeAdder:
    # TODO: specification -> jupyter notebook
    """Adds type information to assignments and the expressions they contain.

    Untyped :class:`sympy.Symbol` instances are replaced by :class:`TypedSymbol`, numbers are wrapped in a
    :class:`CastFunc` with the configured default integer / float type, and casts are inserted where the operands
    of an expression or the two sides of an assignment have different types.

    Args:
        type_for_symbol: mapping from symbol name to the data type used for untyped symbols. It is updated
                         with the rhs type whenever an untyped symbol is assigned to.
        default_number_float: data type given to non-integer numbers
        default_number_int: data type given to integer numbers
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol: DefaultDict[str, BasicType], default_number_float: BasicType,
                 default_number_int: BasicType):
        self.type_for_symbol = type_for_symbol
        self.default_number_float = ContextVar(default_number_float)
        self.default_number_int = ContextVar(default_number_int)

    def visit(self, obj):
        """Returns a typed version of ``obj``.

        Lists and tuples are processed element-wise (a list is returned), assignments, conditionals and blocks
        are rebuilt from their typed parts. Any other AST node except ``LoopOverCoordinate`` is returned
        unchanged.

        Raises:
            ValueError: for a ``LoopOverCoordinate`` or an object that is not an AST node
        """
        if isinstance(obj, (list, tuple)):
            return [self.visit(e) for e in obj]
        if isinstance(obj, ast.SympyAssignment):
            return self.process_assignment(obj)
        elif isinstance(obj, ast.Conditional):
            condition, condition_type = self.figure_out_type(obj.condition_expr)
            assert condition_type == BasicType('bool')
            true_block = self.visit(obj.true_block)
            false_block = None if obj.false_block is None else self.visit(
                obj.false_block)
            return ast.Conditional(condition, true_block=true_block, false_block=false_block)
        elif isinstance(obj, ast.Block):
            return ast.Block([self.visit(e) for e in obj.args])
        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
            return obj
        else:
            raise ValueError("Invalid object in kernel " + str(type(obj)))

    def process_assignment(self, assignment: ast.SympyAssignment) -> ast.SympyAssignment:
        """Returns a new assignment with typed lhs and rhs.

        An untyped lhs symbol takes over the type of the rhs. If lhs and rhs types differ, the rhs is cast
        to the lhs type.

        Raises:
            ValueError: if the lhs is neither a field access nor a symbol
        """
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        new_rhs, rhs_type = self.figure_out_type(assignment.rhs)

        lhs = assignment.lhs
        if not isinstance(lhs, (Field.Access, TypedSymbol)):
            if isinstance(lhs, sp.Symbol):
                # remember the type so that later reads of this symbol are typed consistently
                self.type_for_symbol[lhs.name] = rhs_type
            else:
                raise ValueError(f'Lhs: `{lhs}` is not a subtype of sp.Symbol')
        new_lhs, lhs_type = self.figure_out_type(lhs)
        assert isinstance(new_lhs, (Field.Access, TypedSymbol))

        if lhs_type != rhs_type:
            logging.debug(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
                          f'rhs: "{new_rhs}" of type "{rhs_type}".')
            return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type), assignment.is_const, assignment.use_auto)
        else:
            return ast.SympyAssignment(new_lhs, new_rhs, assignment.is_const, assignment.use_auto)

    # Type System Specification
    # - Defined Types: TypedSymbol, Field, Field.Access, ...?
    # - Indexed: always unsigned_integer64
    # - Undefined Types: Symbol
    #       - Is specified in Config in the dict or as 'default_type' or behaves like `auto` in the case of lhs.
    # - Constants/Numbers: Are either integer or floating. The precision and sign is specified via config
    #       - Example: 1.4 config:float32 -> float32
    # - Expressions deduce types from arguments
    # - Functions deduce types from arguments
    # - default_type and default_float and default_int can be given for a list of assignment, or
    #   individually as a list for assignment

    # Possible Problems - Do we need to support this?
    # - Mixture in expression with int and float
    # - Mixture in expression with uint64 and sint64
    # TODO Logging: Lowest log level should log all casts ----> cast factory, make cast should contain logging
    def figure_out_type(self, expr) -> Tuple[Any, Union[BasicType, PointerType]]:
        """Recursively types ``expr``.

        Returns:
            tuple of the (possibly rewritten) expression and its data type

        Raises:
            AssertionError: for Numpy scalars and for numbers that are neither integer, float nor rational
            ValueError: if an integer function or modulo is used with non-integer arguments
            NotImplementedError: for expression classes unknown to the typing and for pointer arithmetic
                                 other than addition
        """
        # Trivial cases
        from pystencils.field import Field
        import pystencils.integer_functions
        from pystencils.bit_masks import flag_cond
        bool_type = BasicType('bool')

        # TODO: check the access
        if isinstance(expr, Field.Access):
            return expr, expr.dtype
        elif isinstance(expr, TypedSymbol):
            return expr, expr.dtype
        elif isinstance(expr, sp.Symbol):
            t = TypedSymbol(expr.name, self.type_for_symbol[expr.name])
            return t, t.dtype
        elif isinstance(expr, np.generic):
            # raised explicitly (instead of ``assert False``) so the check also holds when asserts are disabled
            raise AssertionError(f'Why do we have a np.generic in rhs???? {expr}')
        elif isinstance(expr, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
            return expr, BasicType('float32')  # see https://en.cppreference.com/w/cpp/numeric/math/INFINITY
        elif isinstance(expr, sp.Number):
            if expr.is_Integer:
                data_type = self.default_number_int.get()
            elif expr.is_Float or expr.is_Rational:
                data_type = self.default_number_float.get()
            else:
                # report the offending number, not the sp.Number class
                raise AssertionError(f'{expr} is neither Float nor Integer')
            return CastFunc(expr, data_type), data_type
        elif isinstance(expr, AddressOf):
            of = expr.args[0]
            # TODO Basically this should do address_of already
            assert isinstance(of, (Field.Access, TypedSymbol, Field))
            return expr, PointerType(of.dtype)
        elif isinstance(expr, BooleanAtom):
            return expr, bool_type
        elif isinstance(expr, Relational):
            # TODO Jan: Code duplication with general case
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            if isinstance(expr, sp.Equality) and collated_type.is_float():
                logging.warning(f"Using floating point numbers in equality comparison: {expr}")
            new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
            new_eq = expr.func(*new_args)
            return new_eq, bool_type
        elif isinstance(expr, CastFunc):
            # the inner expression is typed, the cast itself keeps its target type
            new_expr, _ = self.figure_out_type(expr.expr)
            return expr.func(*[new_expr, expr.dtype]), expr.dtype
        elif isinstance(expr, ast.ConditionalFieldAccess):
            access, access_type = self.figure_out_type(expr.access)
            value, value_type = self.figure_out_type(expr.outofbounds_value)
            condition, condition_type = self.figure_out_type(expr.outofbounds_condition)
            assert condition_type == bool_type
            collated_type = collate_types([access_type, value_type])
            if collated_type == access_type:
                new_access = access
            else:
                logging.warning(f"In {expr} the Field Access had to be casted to {collated_type}. This is "
                                f"probably due to a type missmatch of the Field and the value of "
                                f"ConditionalFieldAccess")
                new_access = CastFunc(access, collated_type)

            new_value = value if value_type == collated_type else CastFunc(value, collated_type)
            return expr.func(new_access, condition, new_value), collated_type
        elif isinstance(expr, (vec_any, vec_all)):
            return expr, bool_type
        elif isinstance(expr, BooleanFunction):
            args_types = [self.figure_out_type(a) for a in expr.args]
            new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types]
            return expr.func(*new_args), bool_type
        elif type(expr) in pystencils.integer_functions.__dict__.values() or isinstance(expr, sp.Mod):
            # the arguments are only typed for the check, the expression itself is returned unchanged
            args_types = [self.figure_out_type(a) for a in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            # TODO: should we downcast to integer? If yes then which integer type?
            if not collated_type.is_int():
                raise ValueError(f"Integer functions or Modulo need to be used with integer types "
                                 f"but {collated_type} was given")

            return expr, collated_type
        elif isinstance(expr, flag_cond):
            #   do not process the arguments to the bit shift - they must remain integers
            args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))]
            collated_type = collate_types([t for _, t in args_types])
            new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
            return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
        # elif isinstance(expr, sp.Mul):
        #    raise NotImplementedError('sp.Mul')
        #    # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? (See todo in backend)
        #    # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
        elif isinstance(expr, sp.Indexed):
            typed_symbol = expr.base.label
            return expr, typed_symbol.dtype
        elif isinstance(expr, ExprCondPair):
            expr_expr, expr_type = self.figure_out_type(expr.expr)
            condition, condition_type = self.figure_out_type(expr.cond)
            if condition_type != bool_type:
                logging.warning(f'Condition "{condition}" is of type "{condition_type}" and not "bool"')
            return expr.func(expr_expr, condition), expr_type
        elif isinstance(expr, Piecewise):
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            new_args = []
            for a, t in args_types:
                if t != collated_type:
                    if isinstance(a, ExprCondPair):
                        # only the value of the pair is cast, the condition is kept
                        new_args.append(a.func(CastFunc(a.expr, collated_type), a.cond))
                    else:
                        new_args.append(CastFunc(a, collated_type))
                else:
                    new_args.append(a)
            return expr.func(*new_args) if new_args else expr, collated_type
        elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
                               HyperbolicFunction, sp.log, RoundFunction)):
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
            new_func = expr.func(*new_args) if new_args else expr
            # the result is cast explicitly unless the collated type is float64
            if collated_type == BasicType('float64'):
                return new_func, collated_type
            else:
                return CastFunc(new_func, collated_type), collated_type
        elif isinstance(expr, (fast_sqrt, fast_division, fast_inv_sqrt)):
            # the fast approximations are always typed as float32
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = BasicType('float32')
            new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
            new_func = expr.func(*new_args) if new_args else expr
            return CastFunc(new_func, collated_type), collated_type
        elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
            # Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
            # and resolve the typing entirely with the expression itself
            if isinstance(expr, sp.Mul):
                c, e = expr.as_coeff_Mul()
                if c == NegativeOne():
                    args_types = self.figure_out_type(e)
                    new_args = [NegativeOne(), args_types[0]]
                    return expr.func(*new_args, evaluate=False), args_types[1]
            args_types = [self.figure_out_type(arg) for arg in expr.args]
            collated_type = collate_types([t for _, t in args_types])
            if isinstance(collated_type, PointerType):
                if isinstance(expr, sp.Add):
                    return expr.func(*[a for a, _ in args_types]), collated_type
                else:
                    raise NotImplementedError(f'Pointer Arithmetic is implemented only for Add, not {expr}')
            new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]

            if isinstance(expr, (sp.Add, sp.Mul)):
                return expr.func(*new_args, evaluate=False) if new_args else expr, collated_type
            else:
                return expr.func(*new_args) if new_args else expr, collated_type
        else:
            raise NotImplementedError(f'expr {type(expr)}: {expr} unknown to typing')
from typing import List
from pystencils.astnodes import Node
from pystencils.config import CreateKernelConfig
from pystencils.typing.leaf_typing import TypeAdder
def add_types(node_list: List[Node], config: CreateKernelConfig):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

    The AST needs to be a pystencils AST. Thus, in the list of nodes every entry must be inherited from
    `pystencils.astnodes.Node`

    Args:
        node_list: List of pystencils Nodes.
        config: CreateKernelConfig, its ``data_type``, ``default_number_float`` and ``default_number_int``
                entries determine the types that are assigned

    Returns:
        ``typed_equations`` list of equations where symbols have been replaced by typed symbols
    """
    type_adder_arguments = {
        'type_for_symbol': config.data_type,
        'default_number_float': config.default_number_float,
        'default_number_int': config.default_number_int,
    }
    return TypeAdder(**type_adder_arguments).visit(node_list)
from typing import Union
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from pystencils.typing.types import BasicType, create_type, PointerType
def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]):
    """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype

    Objects that have a ``numpy_dtype`` attribute are represented by that attribute. Integer types yield
    ``integer``, unsigned integer types additionally ``negative=False``, integer and floating point types
    ``real``. Errors raised while inspecting the type are swallowed, the assumptions collected up to that
    point are returned.

    Args:
        dtype (BasicType, np.dtype): a Numpy data type

    Returns:
        A dict of SymPy assumptions
    """
    numpy_type = getattr(dtype, 'numpy_dtype', dtype)

    derived = {}
    try:
        is_integral = np.issubdtype(numpy_type, np.integer)
        if is_integral:
            derived['integer'] = True
        if np.issubdtype(numpy_type, np.unsignedinteger):
            derived['negative'] = False
        if is_integral or np.issubdtype(numpy_type, np.floating):
            derived['real'] = True
    except Exception:  # TODO this is dirty
        pass

    return derived
class TypedSymbol(sp.Symbol):
    """Sympy symbol that carries a data type in addition to its name.

    The data type is stored in the attribute ``numpy_dtype`` and exposed through :attr:`dtype`; it takes part
    in the hashable content of the symbol.
    """

    def __new__(cls, *args, **kwds):
        # construction goes through the cached variant of __new_stage2__
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, **kwargs):  # TODO does not match signature of sp.Symbol???
        # TODO: also Symbol should be allowed ---> see sympy Variable
        # assumptions derived from the dtype can be overridden / extended by explicitly passed ones
        assumptions = assumptions_from_dtype(dtype)
        assumptions.update(kwargs)
        obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
        try:
            obj.numpy_dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj.numpy_dtype = dtype
        return obj

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type of the symbol as stored at construction."""
        return self.numpy_dtype

    def _hashable_content(self):
        # the hash of the data type is added to the content of the parent class
        return super()._hashable_content(), hash(self.numpy_dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        return (self.name, self.dtype), self.assumptions0

    @property
    def canonical(self):
        return self

    @property
    def reversed(self):
        return self

    @property
    def headers(self):
        """List of header names; contains ``"cuda_complex.hpp"`` for complex floating point types.

        Both the data type itself and its ``base_type`` are inspected, any error during inspection is ignored.
        """
        headers = []
        try:
            if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
                headers.append('"cuda_complex.hpp"')
        except Exception:
            pass
        try:
            if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
                headers.append('"cuda_complex.hpp"')
        except Exception:
            pass

        return headers
# data type given to FieldShapeSymbol instances
SHAPE_DTYPE = BasicType('int64', const=True)
# data type given to FieldStrideSymbol instances
STRIDE_DTYPE = BasicType('int64', const=True)
class FieldStrideSymbol(TypedSymbol):
    """Sympy symbol representing the stride value of a field in a specific coordinate.

    The symbol is named ``_stride_<field_name>_<coordinate>``, has the type ``STRIDE_DTYPE`` and is assumed
    to be positive.
    """
    def __new__(cls, *args, **kwds):
        # construction goes through the cached variant of __new_stage2__
        obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, field_name, coordinate):
        name = f"_stride_{field_name}_{coordinate}"
        obj = super(FieldStrideSymbol, cls).__xnew__(cls, name, STRIDE_DTYPE, positive=True)
        # kept on the symbol so that it can be rebuilt from (field_name, coordinate)
        obj.field_name = field_name
        obj.coordinate = coordinate
        return obj

    def __getnewargs__(self):
        return self.field_name, self.coordinate

    def __getnewargs_ex__(self):
        return (self.field_name, self.coordinate), {}

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def _hashable_content(self):
        # coordinate and field name are added to the content of the parent class
        return super()._hashable_content(), self.coordinate, self.field_name
class FieldShapeSymbol(TypedSymbol):
    """Sympy symbol representing the shape value of a sequence of fields. In a kernel iterating over multiple fields
    there is only one set of `FieldShapeSymbol`s since all the fields have to be of equal size."""

    def __new__(cls, *args, **kwds):
        # construction always goes through the `cacheit`-wrapped second stage
        return FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, field_names, coordinate):
        """Create the positive, SHAPE_DTYPE-typed symbol named ``_size_<name1>_<name2>..._<coordinate>``."""
        joined_names = "_".join(field_names)
        symbol_name = f"_size_{joined_names}_{coordinate}"
        symbol = super(FieldShapeSymbol, cls).__xnew__(cls, symbol_name, SHAPE_DTYPE, positive=True)
        symbol.field_names = tuple(field_names)
        symbol.coordinate = coordinate
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        return (self.field_names, self.coordinate)

    def __getnewargs_ex__(self):
        positional = (self.field_names, self.coordinate)
        return positional, {}

    def _hashable_content(self):
        parent_content = super()._hashable_content()
        return parent_content, self.coordinate, self.field_names
class FieldPointerSymbol(TypedSymbol):
    """Sympy symbol representing the pointer to the beginning of the field data."""

    def __new__(cls, *args, **kwds):
        # construction always goes through the `cacheit`-wrapped second stage
        return FieldPointerSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, field_name, field_dtype, const):
        """Create the symbol ``_data_<field_name>`` typed as a restrict pointer to the field's base type."""
        from pystencils.typing.utilities import get_base_type

        pointer_dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
        symbol = super(FieldPointerSymbol, cls).__xnew__(cls, f"_data_{field_name}", pointer_dtype)
        symbol.field_name = field_name
        return symbol

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        return (self.field_name, self.dtype, self.dtype.const)

    def __getnewargs_ex__(self):
        positional = (self.field_name, self.dtype, self.dtype.const)
        return positional, {}

    def _hashable_content(self):
        parent_content = super()._hashable_content()
        return parent_content, self.field_name
class CFunction(TypedSymbol):
    """Typed symbol whose name is the given ``function`` and whose type is ``dtype``."""

    def __new__(cls, function, dtype):
        # construction always goes through the `cacheit`-wrapped second stage
        instance = CFunction.__xnew_cached_(cls, function, dtype)
        return instance

    def __new_stage2__(cls, function, dtype):
        instance = super(CFunction, cls).__xnew__(cls, function, dtype)
        return instance

    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        return (self.name, self.dtype)

    def __getnewargs_ex__(self):
        positional = (self.name, self.dtype)
        return positional, {}
from abc import abstractmethod
from typing import Union
import numpy as np
import sympy as sp
def is_supported_type(dtype: np.dtype):
    """Check whether a numpy dtype is a plain floating point, integer or boolean scalar type.

    Structured dtypes (``fields``), dtypes containing Python objects (``hasobject``) and
    sub-array dtypes (``subdtype``) are rejected.
    """
    scalar_class = dtype.type
    if not np.issubdtype(dtype, np.generic):
        return False
    if not issubclass(scalar_class, (np.floating, np.integer, np.bool_)):
        return False
    return dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None
def numpy_name_to_c(name: str) -> str:
    """
    Translate the name of a numpy dtype into the corresponding C type name.

    Args:
        name: np.dtype.name string

    Returns:
        type as a C string

    Raises:
        NotImplementedError: if no C name is known for ``name``
    """
    floating_point_names = {
        'float64': 'double',
        'float32': 'float',
        'float16': 'half',
        'half': 'half',
    }
    if name in floating_point_names:
        return floating_point_names[name]

    # fixed width integers: "int32" -> "int32_t", "uint8" -> "uint8_t"
    for prefix in ('int', 'uint'):
        if name.startswith(prefix):
            width = int(name[len(prefix):])
            return f"{prefix}{width}_t"

    if name == 'bool':
        return 'bool'
    raise NotImplementedError(f"Can't map numpy to C name for {name}")
class AbstractType(sp.Atom):
    """Common base of all type classes; instances are created as argument-less sympy objects."""
    # TODO: Is it necessary to inherit from sp.Atom?

    def __new__(cls, *args, **kwargs):
        # constructor arguments are not forwarded to sympy
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        # printing delegates to the subclass' __str__
        return str(self)

    @property
    @abstractmethod
    def base_type(self) -> Union[None, 'BasicType']:
        """
        Returns: Returns BasicType of a Vector or Pointer type, None otherwise
        """

    @property
    @abstractmethod
    def item_size(self) -> int:
        """
        Returns: Number of items.
            E.g. width * item_size(basic_type) in vector's case, or simple numpy itemsize in Struct's case.
        """
class BasicType(AbstractType):
    """
    Scalar type: a np.dtype combined with a const qualifier.
    """

    def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False):
        if isinstance(dtype, BasicType):
            # copy construction: dtype and constness are taken from the given BasicType,
            # the `const` argument is not used in this case
            self.numpy_dtype, self.const = dtype.numpy_dtype, dtype.const
        else:
            self.numpy_dtype, self.const = np.dtype(dtype), const
        assert is_supported_type(self.numpy_dtype), f'Type {self.numpy_dtype} is currently not supported!'

    def __getnewargs__(self):
        return (self.numpy_dtype, self.const)

    def __getnewargs_ex__(self):
        positional = (self.numpy_dtype, self.const)
        return positional, {}

    @property
    def base_type(self):
        """A BasicType has no base type."""
        return None

    @property
    def item_size(self):  # TODO: Do we want self.numpy_type.itemsize????
        return 1

    # --- classification helpers, all based on the numpy scalar class of the dtype ---

    def is_float(self):
        scalar_class = self.numpy_dtype.type
        return issubclass(scalar_class, np.floating)

    def is_half(self):
        scalar_class = self.numpy_dtype.type
        return issubclass(scalar_class, np.half)

    def is_int(self):
        scalar_class = self.numpy_dtype.type
        return issubclass(scalar_class, np.integer)

    def is_uint(self):
        scalar_class = self.numpy_dtype.type
        return issubclass(scalar_class, np.unsignedinteger)

    def is_sint(self):
        scalar_class = self.numpy_dtype.type
        return issubclass(scalar_class, np.signedinteger)

    def is_bool(self):
        scalar_class = self.numpy_dtype.type
        return issubclass(scalar_class, np.bool_)

    def dtype_eq(self, other):
        """Compare only the numpy dtypes, ignoring the const qualifier."""
        if isinstance(other, BasicType):
            return self.numpy_dtype == other.numpy_dtype
        return False

    @property
    def c_name(self) -> str:
        return numpy_name_to_c(self.numpy_dtype.name)

    def __str__(self):
        const_suffix = " const" if self.const else ""
        return f'{self.c_name}{const_suffix}'

    def __repr__(self):
        return f'BasicType( {str(self)} )'

    def _repr_html_(self):
        return f'BasicType( {str(self)} )'

    def __eq__(self, other):
        # `dtype_eq` is False for non-BasicType objects, so `other.const` is only read for BasicTypes
        return self.dtype_eq(other) and self.const == other.const

    def __hash__(self):
        return hash(str(self))
class VectorType(AbstractType):
    """
    Vector type: a BasicType together with a width.
    """
    # when set, `__str__` looks up the printed type name in this mapping
    instruction_set = None

    def __init__(self, base_type: BasicType, width: int):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return f"{self.base_type}[{self.width}]"

        # TODO VectorizationRevamp: this seems super weird. the instruction_set should know how to print a type out!
        # TODO VectorizationRevamp: this is error prone. base_type could be cons=True. Use dtype instead
        instruction_set_keys = (
            ('int', ("int64", "int32")),
            ('double', ("float64",)),
            ('float', ("float32",)),
            ('bool', ("bool",)),
        )
        for key, type_names in instruction_set_keys:
            if any(self.base_type == create_type(type_name) for type_name in type_names):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return (self._base_type, self.width)

    def __getnewargs_ex__(self):
        positional = (self._base_type, self.width)
        return positional, {}
class PointerType(AbstractType):
    """Pointer (or double pointer) to a base type, with const and restrict qualifiers."""

    def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict
        self.double_pointer = double_pointer

    def _state(self):
        # everything that defines the pointer type; shared by pickling support and comparison
        return self.base_type, self.const, self.restrict, self.double_pointer

    def __getnewargs__(self):
        return self._state()

    def __getnewargs_ex__(self):
        return self._state(), {}

    @property
    def alias(self):
        """True if the pointer is not marked as restrict."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        if not self.double_pointer:
            return self.base_type.item_size
        raise NotImplementedError("The item_size for double_pointer is not implemented")

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return self._state() == (other.base_type, other.const, other.restrict, other.double_pointer)
        return False

    def __str__(self):
        stars = "**" if self.double_pointer else "*"
        restrict_str = "RESTRICT" if self.restrict else ""
        const_str = "const" if self.const else ""
        return f'{str(self.base_type)} {stars} {restrict_str} {const_str}'

    def __repr__(self):
        return str(self)

    def _repr_html_(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict, self.double_pointer))
class StructType(AbstractType):
    """
    A list of types (with C offsets).
    It is implemented with uint8_t and casts to the correct datatype.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return (self.numpy_dtype, self.const)

    def __getnewargs_ex__(self):
        positional = (self.numpy_dtype, self.const)
        return positional, {}

    @property
    def base_type(self):
        """A StructType has no base type."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Offset of the named element, read from the numpy dtype's ``fields`` entry."""
        _, offset = self.numpy_dtype.fields[element_name][:2]
        return offset

    def get_element_type(self, element_name):
        """BasicType of the named element; it inherits the const qualifier of the struct."""
        element_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(element_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def _repr_html_(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))
def create_type(specification: Union[type, AbstractType, str]) -> AbstractType:
    # TODO: Deprecated Use the constructor of BasicType or StructType instead
    """Create a type object from a specification.

    Args:
        specification: Type object, or anything ``np.dtype`` accepts (e.g. a string)

    Returns:
        The specification itself if it already is a type object. Otherwise a non-const
        StructType if the numpy dtype has fields, else a non-const BasicType.
    """
    if isinstance(specification, AbstractType):
        return specification

    numpy_dtype = np.dtype(specification)
    type_class = BasicType if numpy_dtype.fields is None else StructType
    return type_class(numpy_dtype, const=False)
from collections import defaultdict
from functools import partial
from typing import Tuple, Union, Sequence
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils
from pystencils.cache import memorycache_if_hashable
from pystencils.typing.types import BasicType, VectorType, PointerType, create_type
from pystencils.typing.cast_functions import CastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.utils import all_equal
def typed_symbols(names, dtype, **kwargs):
    """
    Creates TypedSymbols with the same functionality as sympy.symbols

    Args:
        names: See sympy.symbols
        dtype: The data type all symbols will have
        **kwargs: Key value arguments passed to sympy.symbols

    Returns:
        A tuple of TypedSymbols if sympy.symbols returned a tuple, otherwise a single TypedSymbol
    """
    symbols = sp.symbols(names, **kwargs)
    # check against the builtin `tuple`; `typing.Tuple` is a deprecated alias meant for annotations only
    if isinstance(symbols, tuple):
        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
    return TypedSymbol(str(symbols), dtype)
def get_base_type(data_type):
    """
    Follows the ``base_type`` chain of a type (e.g. a Pointer or a Vector) until a type
    without base type is reached, and returns that type.
    """
    current = data_type
    while True:
        inner = current.base_type
        if inner is None:
            return current
        current = inner
def result_type(*args: np.dtype):
    """Returns the type of the result if the np.dtype arguments would be collated.
    We can't use numpy functionality, because numpy casts don't behave exactly like C casts

    The winning dtype is the one with the highest kind (float > signed int > unsigned int > bool);
    among dtypes of the same kind the one with the largest itemsize wins.
    """
    kind_ranks = {'f': 3, 'i': 2, 'u': 1, 'b': 0}

    def ordering_key(dtype):
        kind = dtype.kind
        if kind not in kind_ranks:
            raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options')
        return kind_ranks[kind], dtype.itemsize

    # a single stable sort on (kind, itemsize) orders exactly like sorting by itemsize
    # first and by kind afterwards
    ordered = sorted(args, key=ordering_key)
    return ordered[-1]
def collate_types(types: Sequence[Union[BasicType, VectorType]]):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double

    The scalar result is determined by `result_type`. If a pointer type is involved, the pointer
    type is returned (pointer arithmetic). If vector types are involved, the result is a vector
    of the collated scalar type.

    Raises:
        ValueError: for two pointer types, for a pointer combined with a non-integer type,
                    or for vector types of different width
    """
    # Pointer arithmetic case i.e. pointer + [int, uint] is allowed
    if any(isinstance(t, PointerType) for t in types):
        pointer_type = None
        for t in types:
            if isinstance(t, PointerType):
                if pointer_type is not None:
                    raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"')
                pointer_type = t
                continue
            is_integer = isinstance(t, BasicType) and (t.is_int() or t.is_uint())
            if not is_integer:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel off vector types, if at least one vector type occurred the result will also be the vector type
    vector_types = [t for t in types if isinstance(t, VectorType)]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")

    # TODO: check if repeatedly peeling off nested VectorTypes (via a helper like
    #       `peel_off_type(dtype, VectorType)`) is needed; currently one level is removed
    scalar_types = [t.base_type if isinstance(t, VectorType) else t for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    collated = BasicType(result_type(*(t.numpy_dtype for t in scalar_types)))
    if vector_types:
        return VectorType(collated, vector_types[0].width)
    return collated
# TODO get_type_of_expression should be used after leaf_typing. So no defaults should be necessary
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determine the data type of a (sympy) expression.

    Args:
        expr: expression to inspect; it is passed through ``sp.sympify`` first
        default_float_type: type specification used for rational / float numbers and other
                            non-integer leaves; the string 'float' is mapped to 'float32'
        default_int_type: type specification used for integer numbers and integer leaves
        symbol_type_dict: mapping from symbol name to type, used for untyped ``sp.Symbol``

    Returns:
        The type of the expression. Compound expressions are typed by collating the types
        of their arguments with `collate_types`.

    Raises:
        ValueError: for an untyped symbol when no non-empty ``symbol_type_dict`` is available
        NotImplementedError: if the type of the expression can not be determined
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if default_float_type == 'float':
        default_float_type = 'float32'

    if not symbol_type_dict:
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    # recursive calls keep the defaults and the symbol dictionary of this call
    # TODO this line is quite hard to understand, if possible simpl
    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.Access):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        # TODO delete if case
        # NOTE(review): an empty mapping is falsy - this includes the defaultdict created above -
        # so untyped symbols end up in the ValueError unless a non-empty dict was passed. Confirm intended.
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, CastFunc):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        # a vectorized condition makes the result a vector as well
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, (Boolean, BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        # determine the type of every argument only once
        arg_types = (get_type(a) for a in expr.args)
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        base_type = get_type(expr.args[0])
        if expr.exp.is_integer:
            return base_type
        else:
            return collate_types([create_type(default_float_type), base_type])
    elif isinstance(expr, (sp.Sum, sp.Product)):
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        expr: sp.Expr
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(types)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
# Fix for sympy versions from 1.9
# The version is encoded as major * 100 + minor, e.g. 1.9 -> 109 and 1.11 -> 111.
# NOTE: this block monkey-patches sympy classes at import time of this module.
sympy_version = sp.__version__.split('.')
sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1])
if sympy_version_int >= 109:
    # __setstate__ would bypass the constructor, so we remove it
    if sympy_version_int >= 111:
        del sp.Basic.__setstate__
        del sp.Symbol.__setstate__
    else:
        # keep __getstate__ available on sp.Number before removing it from sp.Basic
        sp.Number.__getstate__ = sp.Basic.__getstate__
        del sp.Basic.__getstate__

    # __reduce_ex__ would strip kwargs, so we override it
    def basic_reduce_ex(self, protocol):
        """Replacement for ``sp.Basic.__reduce_ex__``.

        Returns the triple ``(callable, args, state)``: the callable is the object's class with
        the keyword arguments from ``__getnewargs_ex__`` bound via ``partial``, so they are
        passed to the class again when the callable is invoked with ``args``.
        """
        if hasattr(self, '__getnewargs_ex__'):
            args, kwargs = self.__getnewargs_ex__()
        else:
            args, kwargs = self.__getnewargs__(), {}
        if hasattr(self, '__getstate__'):
            state = self.__getstate__()
        else:
            state = None
        return partial(type(self), **kwargs), args, state
    sp.Basic.__reduce_ex__ = basic_reduce_ex
def get_next_parent_of_type(node, parent_type):
    """Returns the closest ancestor of ``node`` that is an instance of ``parent_type``.

    The ``parent`` links are followed upwards, starting at ``node.parent``.
    If the root is reached without finding such an ancestor, None is returned.
    """
    ancestor = node.parent
    while ancestor is not None and not isinstance(ancestor, parent_type):
        ancestor = ancestor.parent
    return ancestor
def parents_of_type(node, parent_type, include_current=False):
    """Generator for all parent nodes of given type, from the closest one up to the root.

    With ``include_current=True`` the node itself is considered as well.
    """
    if include_current:
        candidate = node
    else:
        candidate = node.parent

    while candidate is not None:
        if isinstance(candidate, parent_type):
            yield candidate
        candidate = candidate.parent
import os
import itertools
from itertools import groupby
from collections import Counter
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Mapping
import numpy as np
import sympy as sp
class DotDict(dict):
    """Normal dict with additional dot access for all keys.

    Attribute access is mapped onto item access: reading a missing attribute returns None
    (``dict.get``), assigning / deleting an attribute sets / deletes the key.
    Nested dicts passed to the constructor are converted to DotDicts recursively.
    """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    # Recursively make DotDict: https://stackoverflow.com/questions/13520421/recursive-dotdict
    def __init__(self, dct=None):
        """
        Args:
            dct: optional dictionary with the initial content; None creates an empty DotDict
        """
        super().__init__()
        # `None` instead of a mutable `{}` default, which would be shared between all calls
        if dct is None:
            return
        for key, value in dct.items():
            if isinstance(value, dict):
                value = DotDict(value)
            self[key] = value
def all_equal(iterable):
    """
    Returns ``True`` if all the elements are equal to each other.
    An empty iterable counts as all-equal.

    Copied from: more-itertools 8.12.0
    """
    runs = groupby(iterable)
    # skip the first run of equal elements (if any) ...
    next(runs, None)
    # ... the elements are all equal exactly if no second run follows
    return next(runs, None) is None
def recursive_dict_update(d, u):
    """Updates the first dict argument, using second dictionary recursively.

    The arguments are not modified, an updated copy of ``d`` is returned.
    If a mapping in ``u`` replaces a value of ``d`` that is not a mapping itself,
    the old value is discarded.

    Examples:
        >>> d = {'sub_dict': {'a': 1, 'b': 2}, 'outer': 42}
        >>> u = {'sub_dict': {'a': 5, 'c': 10}, 'outer': 41, 'outer2': 43}
        >>> recursive_dict_update(d, u)
        {'sub_dict': {'a': 5, 'b': 2, 'c': 10}, 'outer': 41, 'outer2': 43}
        >>> recursive_dict_update({'a': 1}, {'a': {'b': 2}})
        {'a': {'b': 2}}
    """
    d = d.copy()
    for k, v in u.items():
        if isinstance(v, Mapping):
            existing = d.get(k)
            # only merge into an existing mapping; a scalar can not be updated recursively
            base = existing if isinstance(existing, Mapping) else {}
            d[k] = recursive_dict_update(base, v)
        else:
            d[k] = v
    return d
@contextmanager
def atomic_file_write(file_path):
    """Context manager to write a file atomically.

    Yields the name of a temporary file located in the target folder. When the with-block
    finishes without error, the temporary file replaces ``file_path`` via ``os.replace``.
    If the with-block raises, the temporary file is removed again and ``file_path`` is left
    untouched.

    Args:
        file_path: path of the file to be written
    """
    target_folder = os.path.dirname(os.path.abspath(file_path))
    with NamedTemporaryFile(delete=False, dir=target_folder) as f:
        # only the name is needed; the caller opens the file itself
        f.file.close()
        try:
            yield f.name
        except BaseException:
            # do not leave the temporary file behind when writing failed
            try:
                os.remove(f.name)
            except OSError:
                pass
            raise
    os.replace(f.name, file_path)
def fully_contains(l1, l2):
    """Tests if elements of sequence 1 are in sequence 2 in same or higher number.

    >>> fully_contains([1, 1, 2], [1, 2])  # 1 is only present once in second list
    False
    >>> fully_contains([1, 1, 2], [1, 1, 4, 2])
    True
    """
    required = Counter(l1)
    available = Counter(l2)
    # a Counter reports 0 for elements it does not contain
    return all(available[element] >= count for element, count in required.items())
def boolean_array_bounding_box(boolean_array):
    """Returns bounding box around "true" area of boolean array

    The result is a list with one ``(begin, end)`` pair per array axis, ``end`` being exclusive.

    >>> a = np.zeros((4, 4), dtype=bool)
    >>> a[1:-1, 1:-1] = True
    >>> boolean_array_bounding_box(a) == [(1, 3), (1, 3)]
    True
    """
    ndim = boolean_array.ndim
    assert 0 not in boolean_array.shape, "Shape must not contain zero"

    bounding_box = []
    # each combination contains all axes but one; reducing over them leaves the remaining axis
    for reduced_axes in itertools.combinations(reversed(range(ndim)), ndim - 1):
        occupied = np.any(boolean_array, axis=reduced_axes)
        first, last = np.where(occupied)[0][[0, -1]]
        bounding_box.append((first, last + 1))
    return bounding_box
def binary_numbers(n):
    """Returns all binary numbers up to 2^n - 1, each as a list of n digits

    Example:
        >>> binary_numbers(2)
        [[0, 0], [0, 1], [1, 0], [1, 1]]
    """
    # zero-padded binary representation of every number below 2^n
    digit_strings = (format(number, 'b').zfill(n) for number in range(1 << n))
    return [[int(digit) for digit in digits] for digits in digit_strings]
class LinearEquationSystem:
    """Symbolic linear system of equations - consisting of matrix and right hand side.

    Equations can be added incrementally. System is held in reduced row echelon form to quickly determine if
    system has a single, multiple, or no solution.

    Example:
        >>> x, y= sp.symbols("x, y")
        >>> les = LinearEquationSystem([x, y])
        >>> les.add_equation(x - y - 3)
        >>> les.solution_structure()
        'multiple'
        >>> les.add_equation(x + y - 4)
        >>> les.solution_structure()
        'single'
        >>> les.solution()
        {x: 7/2, y: 1/2}

    """

    def __init__(self, unknowns):
        size = len(unknowns)
        # one column per unknown plus one column for the right hand side
        self._matrix = sp.zeros(size, size + 1)
        self.unknowns = unknowns
        # index of the first row that does not hold an equation yet
        self.next_zero_row = 0
        # True while `_matrix` is known to be in reduced row echelon form
        self._reduced = True

    def copy(self):
        """Returns a copy of the equation system."""
        new = LinearEquationSystem(self.unknowns)
        new._matrix = self._matrix.copy()
        new.next_zero_row = self.next_zero_row
        # without this, the copy of an unreduced system would be treated as already reduced
        new._reduced = self._reduced
        return new

    def add_equation(self, linear_equation):
        """Add a linear equation as sympy expression. Implicit "-0" is assumed. Equation has to be linear and contain
        only unknowns passed to the constructor otherwise a ValueError is raised. """
        self._resize_if_necessary()
        linear_equation = linear_equation.expand()

        # validate first, so that a rejected equation leaves the system untouched
        coefficients = [linear_equation.coeff(unknown) for unknown in self.unknowns]
        control = 0
        for unknown, coefficient in zip(self.unknowns, coefficients):
            control += unknown * coefficient
        rest = linear_equation - control
        if rest.atoms(sp.Symbol):
            raise ValueError("Not a linear equation in the unknowns")

        zero_row_idx = self.next_zero_row
        for i, coefficient in enumerate(coefficients):
            self._matrix[zero_row_idx, i] = coefficient
        self._matrix[zero_row_idx, -1] = -rest
        self.next_zero_row += 1
        self._reduced = False

    def add_equations(self, linear_equations):
        """Add a sequence of equations. For details see `add_equation`. """
        self._resize_if_necessary(len(linear_equations))
        for eq in linear_equations:
            self.add_equation(eq)

    def set_unknown_zero(self, unknown_idx):
        """Sets an unknown to zero - pass the index not the variable itself!"""
        assert unknown_idx < len(self.unknowns)
        self._resize_if_necessary()
        self._matrix[self.next_zero_row, unknown_idx] = 1
        self.next_zero_row += 1
        self._reduced = False

    def reduce(self):
        """Brings the system in reduced row echelon form."""
        if self._reduced:
            return
        self._matrix = self._matrix.rref()[0]
        self._update_next_zero_row()
        self._reduced = True

    @property
    def matrix(self):
        """Return a matrix that represents the equation system.
        Has one column more than unknowns for the affine part."""
        self.reduce()
        return self._matrix

    @property
    def rank(self):
        """Number of non-zero rows of the reduced system."""
        self.reduce()
        return self.next_zero_row

    def solution_structure(self):
        """Returns either 'multiple', 'none' or 'single' to indicate how many solutions the system has."""
        self.reduce()
        non_zero_rows = self.next_zero_row
        num_unknowns = len(self.unknowns)
        if non_zero_rows == 0:
            return 'multiple'

        # last non-zero row: `left` is the coefficient of the last unknown, `right` the right hand side
        *row_begin, left, right = self._matrix.row(non_zero_rows - 1)
        if non_zero_rows > num_unknowns:
            return 'none'
        elif non_zero_rows == num_unknowns:
            if left == 0 and right != 0:
                return 'none'
            else:
                return 'single'
        elif non_zero_rows < num_unknowns:
            if right != 0 and left == 0 and all(e == 0 for e in row_begin):
                return 'none'
            else:
                return 'multiple'

    def solution(self):
        """Solves the system. Under- and overdetermined systems are supported.
        Returns a dictionary mapping symbol to solution value."""
        return sp.solve_linear_system(self._matrix, *self.unknowns)

    def _resize_if_necessary(self, new_rows=1):
        """Appends ``new_rows`` zero rows if the matrix has no room for that many additional equations."""
        if self.next_zero_row + new_rows > self._matrix.shape[0]:
            self._matrix = self._matrix.row_insert(self._matrix.shape[0] + 1,
                                                   sp.zeros(new_rows, self._matrix.shape[1]))

    def _update_next_zero_row(self):
        """Sets `next_zero_row` to the index after the last row that contains a non-zero entry."""
        result = self._matrix.shape[0]
        # stop at 0: for an all-zero matrix there is no row left to check
        while result > 0:
            row_to_check = result - 1
            if any(e != 0 for e in self._matrix.row(row_to_check)):
                break
            result -= 1
        self.next_zero_row = result
class ContextVar:
    """Variable whose value can be overridden temporarily inside a ``with`` block.

    Example:
        >>> var = ContextVar(1)
        >>> with var(2):
        ...     var.get()
        2
        >>> var.get()
        1
    """

    def __init__(self, value):
        # the last element is the currently active value
        self.stack = [value]

    @contextmanager
    def __call__(self, new_value):
        """Makes ``new_value`` the active value for the duration of the with-block."""
        self.stack.append(new_value)
        try:
            yield self
        finally:
            # restore the previous value even if the with-block raised
            self.stack.pop()

    def get(self):
        """Returns the currently active value."""
        return self.stack[-1]
import itertools
import operator
import warnings
from collections import defaultdict
from collections.abc import Sequence
from functools import reduce

import sympy as sp
def prod(seq):
    """Takes a sequence and returns the product of all elements (1 for an empty sequence)"""
    result = 1
    for factor in seq:
        result = result * factor
    return result
def allIn(a, b):
    """Tests if all elements of a container 'a' are contained in 'b'"""
    for element in a:
        if element not in b:
            return False
    return True
def isIntegerSequence(sequence):
    """Tests if the argument is iterable and all of its elements can be converted with ``int()``.

    Returns False instead of raising if the argument is not iterable or an element can not
    be converted (``TypeError`` as well as ``ValueError``, e.g. for non-numeric strings).
    """
    try:
        for element in sequence:
            int(element)
    except (TypeError, ValueError):
        return False
    return True
def scalarProduct(a, b):
    """Sum of the element-wise products of 'a' and 'b'; surplus elements of the longer one are ignored"""
    total = 0
    for a_i, b_i in zip(a, b):
        total = total + a_i * b_i
    return total
def equationsToMatrix(equations, degreesOfFreedom):
    """Builds the matrix whose entry (row, col) is the coefficient of degreesOfFreedom[col] in equations[row]"""
    def coefficient(row, col):
        return equations[row].coeff(degreesOfFreedom[col])

    numRows, numCols = len(equations), len(degreesOfFreedom)
    return sp.Matrix(numRows, numCols, coefficient)
def kroneckerDelta(*args):
    """Kronecker delta for variable number of arguments,
    1 if all args are equal, otherwise 0"""
    # every argument is compared against the first one
    differs = any(a != args[0] for a in args)
    return 0 if differs else 1
def multidimensionalSummation(i, dim):
    """Multidimensional summation: iterator over all index tuples of length 'i' with entries in range(dim)"""
    indexRanges = [range(dim) for _ in range(i)]
    return itertools.product(*indexRanges)
def normalizeProduct(product):
    """
    Expects a sympy expression that can be interpreted as a product and
        - for a Mul node returns its factors ('args')
        - for a Pow node with positive integer exponent returns a list of factors
        - for other node types [product] is returned
    Pow nodes with positive integer exponent that occur as factors of a Mul node are expanded as well.
    """
    def powerFactors(power):
        # base**n with a positive integer number n becomes [base] * n, anything else stays one factor
        exponent = power.exp
        if exponent.is_integer and exponent.is_number and exponent > 0:
            return [power.base] * exponent
        return [power]

    if product.func == sp.Pow:
        return powerFactors(product)
    if product.func != sp.Mul:
        return [product]

    factors = []
    for arg in product.args:
        factors.extend(powerFactors(arg) if arg.func == sp.Pow else [arg])
    return factors
def productSymmetric(*args, withDiagonal=True):
    """Similar to itertools.product but returns only values where the index is ascending i.e. values below diagonal

    With ``withDiagonal=True`` consecutive indices may be equal, otherwise they have to be strictly ascending.
    """
    def isAscending(idx):
        neighbours = zip(idx, idx[1:])
        if withDiagonal:
            return all(previous <= following for previous, following in neighbours)
        return all(previous < following for previous, following in neighbours)

    indexRanges = [range(len(a)) for a in args]
    for idx in itertools.product(*indexRanges):
        if isAscending(idx):
            yield tuple(a[i] for a, i in zip(args, idx))
def fastSubs(term, subsDict, skip=None):
    """Similar to sympy subs function.
    This version is much faster for big substitution dictionaries than sympy version

    :param term: expression in which the substitutions are done
    :param subsDict: dictionary mapping expressions to their replacements
    :param skip: optional predicate; expressions for which it returns True are left untouched
    """
    if len(subsDict) == 0:
        return term

    def visit(expr):
        if skip and skip(expr):
            return expr
        # expressions can provide their own substitution routine
        if hasattr(expr, "fastSubs"):
            return expr.fastSubs(subsDict)
        if expr in subsDict:
            return subsDict[expr]
        if not hasattr(expr, 'args'):
            return expr
        newArgs = [visit(arg) for arg in expr.args]
        if newArgs:
            return expr.func(*newArgs)
        return expr

    return visit(term)
def fastSubsWithNormalize(term, subsDict, normalizeFunc):
    """Substitutes like `fastSubs`, but every sub-expression in which something was replaced is
    passed through 'normalizeFunc' and the normalized result is visited again."""
    if len(subsDict) == 0:
        return term

    def visit(expr):
        # returns the new expression and a flag telling if anything was substituted
        if expr in subsDict:
            return subsDict[expr], True
        if not hasattr(expr, 'args'):
            return expr, False

        visitedArgs = [visit(arg) for arg in expr.args]
        if not visitedArgs:
            return expr, False

        rebuilt = expr.func(*[newArg for newArg, _ in visitedArgs])
        if any(changed for _, changed in visitedArgs):
            normalized, _ = visit(normalizeFunc(rebuilt))
            return normalized, True
        return rebuilt, False

    newTerm, _ = visit(term)
    return newTerm
def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None):
    """
    Transformation for replacing a given subexpression inside a sum

    Example 1:
        expr = 3*x + 3 * y
        replacement = k
        subExpression = x+y
        return = 3*k

    Example 2:
        expr = 3*x + 3 * y + z
        replacement = k
        subExpression = x+y+z
        return:
            if minimalMatchingTerms >=3 the expression would not be altered
            if smaller than 3 the result is 3*k - 2*z

    :param expr: input expression
    :param replacement: expression that is inserted for subExpression (if found)
    :param subExpression: expression to replace
    :param requiredMatchReplacement:
         - if float: the percentage of terms of the subExpression that has to be matched in order to replace
         - if integer: the total number of terms that has to be matched in order to replace
         - None: is equal to integer 1
         - if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND)
    :param requiredMatchOriginal:
         - if float: the percentage of terms of the original addition expression that has to be matched
         - if integer: the total number of terms that has to be matched in order to replace
         - None: is equal to integer 1
    :return: new expression with replacement
    """
    def normalizeMatchParameter(matchParameter, expressingLength):
        # converts a match parameter (None, fraction or absolute count) into an absolute number of terms >= 1
        if matchParameter is None:
            return 1
        elif isinstance(matchParameter, float):
            assert 0 <= matchParameter <= 1
            res = int(matchParameter * expressingLength)
            return max(res, 1)
        elif isinstance(matchParameter, int):
            assert matchParameter > 0
            return matchParameter
        raise ValueError("Invalid parameter")

    normalizedReplacementMatch = normalizeMatchParameter(requiredMatchReplacement, len(subExpression.args))

    def visit(currentExpr):
        if currentExpr.is_Add:
            exprMaxLength = max(len(currentExpr.args), len(subExpression.args))
            normalizedCurrentExprMatch = normalizeMatchParameter(requiredMatchOriginal, exprMaxLength)
            exprCoeffs = currentExpr.as_coefficients_dict()
            subexprCoeffDict = subExpression.as_coefficients_dict()
            # terms that occur in the sum as well as in the subexpression
            intersection = set(subexprCoeffDict.keys()).intersection(set(exprCoeffs))
            if len(intersection) >= max(normalizedReplacementMatch, normalizedCurrentExprMatch):
                # find common factor
                # maps a coefficient ratio (sum / subexpression) to the number of terms having this ratio
                factors = defaultdict(lambda: 0)
                skips = 0
                for commonSymbol in subexprCoeffDict.keys():
                    if commonSymbol not in exprCoeffs:
                        skips += 1
                        continue
                    factor = exprCoeffs[commonSymbol] / subexprCoeffDict[commonSymbol]
                    factors[sp.simplify(factor)] += 1

                # the ratio shared by most terms is used as factor of the replacement
                commonFactor = max(factors.items(), key=operator.itemgetter(1))[0]
                if factors[commonFactor] >= max(normalizedCurrentExprMatch, normalizedReplacementMatch):
                    return currentExpr - commonFactor * subExpression + commonFactor * replacement

        # if no subexpression was found
        # recurse into the arguments and rebuild the node without evaluation
        paramList = [visit(a) for a in currentExpr.args]
        if not paramList:
            return currentExpr
        else:
            return currentExpr.func(*paramList, evaluate=False)

    return visit(expr)
def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed=None):
    """
    Replaces second order mixed terms like x*y by 1/2 * ( (x+y)**2 - x**2 - y**2 )
    (or by -1/2 * ( (x-y)**2 - x**2 - y**2 ) for the negative variant).
    This makes the term longer - simplify usually is undoing these - however this
    transformation can be done to find more common sub-expressions

    :param expr: input expression
    :param searchSymbols: list of symbols that are searched for
                          Example: given [ x,y,z] terms like x*y, x*z, z*y are replaced
    :param positive: there are two ways to do this substitution, either with term
                    (x+y)**2 or (x-y)**2 . if positive=True the first version is done,
                    if positive=False the second version is done, if positive=None the
                    sign is determined by the sign of the mixed term that is replaced
    :param replaceMixed: if a list is passed here the expr x+y or x-y is replaced by a special new symbol
                         the replacement equation is added to the list
    :return: expression with the mixed products replaced; nodes are rebuilt with evaluate=False
    """
    if replaceMixed is not None:
        # symbols that already have a replacement equation in the list
        mixedSymbolsReplaced = set([e.lhs for e in replaceMixed])

    if expr.is_Mul:
        distinctVelTerms = set()
        nrOfVelTerms = 0
        otherFactors = 1
        # split the product into factors from searchSymbols and all remaining factors
        for t in expr.args:
            if t in searchSymbols:
                nrOfVelTerms += 1
                distinctVelTerms.add(t)
            else:
                otherFactors *= t
        # only products with exactly two different search symbols are replaced
        if len(distinctVelTerms) == 2 and nrOfVelTerms == 2:
            u, v = sorted(list(distinctVelTerms), key=lambda symbol: symbol.name)
            if positive is None:
                # determine the sign of the prefactor by setting all of its symbols to one
                otherFactorsWithoutSymbols = otherFactors
                for s in otherFactors.atoms(sp.Symbol):
                    otherFactorsWithoutSymbols = otherFactorsWithoutSymbols.subs(s, 1)
                positive = otherFactorsWithoutSymbols.is_positive
                assert positive is not None

            sign = 1 if positive else -1
            if replaceMixed is not None:
                # the new symbol is named <u>P<v> or <u>M<v> with underscores removed
                newSymbolStr = 'P' if positive else 'M'
                mixedSymbolName = u.name + newSymbolStr + v.name
                mixedSymbol = sp.Symbol(mixedSymbolName.replace("_", ""))
                if mixedSymbol not in mixedSymbolsReplaced:
                    mixedSymbolsReplaced.add(mixedSymbol)
                    replaceMixed.append(sp.Eq(mixedSymbol, u + sign * v))
            else:
                mixedSymbol = u + sign * v
            return sp.Rational(1, 2) * sign * otherFactors * (mixedSymbol ** 2 - u ** 2 - v ** 2)

    # no replacement at this node: recurse into the arguments
    paramList = [replaceSecondOrderProducts(a, searchSymbols, positive, replaceMixed) for a in expr.args]
    result = expr.func(*paramList, evaluate=False) if paramList else expr
    return result
def removeHigherOrderTerms(term, order=3, symbols=None):
    """
    Removes all terms that contain more than 'order' factors of the given 'symbols'.

    The term is expanded first. For every summand of the expanded term the exponents of all factors
    that are one of the 'symbols' are summed up; the summand is dropped if this sum exceeds 'order'.

    :param term: sympy expression
    :param order: highest number of symbol factors a summand may contain without being removed
    :param symbols: sequence of symbols that are counted. If nothing (or an empty sequence) is passed,
                    the symbols u_0, u_1, u_2 are used, both without assumptions and with real=True
    :return: the expanded term without the summands of higher order

    Example:
        >>> x, y = sp.symbols("x y")
        >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
        >>> removeHigherOrderTerms(term, order=2, symbols=[x, y])
        x + y**2
    """
    from sympy.core.power import Pow
    from sympy.core.add import Add, Mul

    term = term.expand()
    if not symbols:
        defaultNames = " ".join(["u_%d" % (i,) for i in range(3)])
        symbols = sp.symbols(defaultNames) + sp.symbols(defaultNames, real=True)

    def velocityFactorsInProduct(product):
        """Sums up the exponents of all factors of 'product' that are one of the counted symbols"""
        # a summand that is not a product (single symbol, single power, number, ...) is its own only factor;
        # this way a bare symbol counts as one factor instead of being ignored
        factors = product.args if type(product) is Mul else (product,)
        uFactorCount = 0
        for factor in factors:
            if type(factor) is Pow:
                if factor.args[0] in symbols:
                    uFactorCount += factor.args[1]
            if factor in symbols:
                uFactorCount += 1
        return uFactorCount

    summands = term.args if type(term) is Add else (term,)
    # sp.Add of nothing is a sympy zero, of a single summand it is the summand itself
    return sp.Add(*[summand for summand in summands if velocityFactorsInProduct(summand) <= order])
def completeTheSquare(expr, symbolToComplete, newVariable):
    """
    Transforms a second order polynomial such that its linear part vanishes, i.e.
        a*symbolToComplete**2 + b*symbolToComplete + c
    is transformed into
        a*newVariable**2 + d
    by substituting symbolToComplete with newVariable - b/(2*a)

    :param expr: sympy expression
    :param symbolToComplete: symbol in which expr is a second order polynomial
    :param newVariable: symbol that stands for the shifted variable in the result
    :return: tuple (replacedExpr, substitution) where substitution is a tuple
             (newVariable, symbolToComplete + b/(2*a)), i.e. replacing newVariable by the second entry
             in replacedExpr yields the old expr again.
             If the given expr is not a second order polynomial in symbolToComplete, (expr, None) is returned
    """
    try:
        coeffs = sp.Poly(expr, symbolToComplete).all_coeffs()
    except sp.PolynomialError:
        # expr is no polynomial in symbolToComplete at all, e.g. the symbol occurs inside a function
        return expr, None

    if len(coeffs) != 3:
        return expr, None

    a, b, _ = coeffs
    shift = b / (2 * a)
    shiftedExpr = expr.subs(symbolToComplete, newVariable - shift)
    return sp.simplify(shiftedExpr), (newVariable, symbolToComplete + shift)
def makeExponentialFuncArgumentSquares(expr, variablesToCompleteSquares):
    """
    Completes the squares in the arguments of all exponential functions of an expression.

    The expression is simplified first. Then, in the argument of every exp, the square is completed
    one after the other for each of the given variables and the resulting argument is expanded.
    The shifted variables are named like the original variables in the result.

    Very useful for integrating Maxwell-Boltzmann and its moment generating function

    :param expr: sympy expression
    :param variablesToCompleteSquares: sequence of symbols for which the squares are completed
    :return: transformed expression
    """
    simplified = sp.simplify(expr)
    # temporary stand-ins for the shifted variables, replaced by the original symbols at the end
    placeholders = [sp.Dummy() for _ in range(len(variablesToCompleteSquares))]

    def visit(node):
        if node.func != sp.exp:
            children = [visit(child) for child in node.args]
            return node.func(*children) if children else node

        exponent = node.args[0]
        for variable, placeholder in zip(variablesToCompleteSquares, placeholders):
            exponent, _ = completeTheSquare(exponent, variable, placeholder)
        return sp.exp(sp.expand(exponent))

    transformed = visit(simplified)
    for variable, placeholder in zip(variablesToCompleteSquares, placeholders):
        transformed = transformed.subs(placeholder, variable)
    return transformed
def pow2mul(expr):
    """
    Convert integer powers in an expression to Muls, like a**2 => a*a.

    Powers with negative integer exponent are converted to the reciprocal of such a product,
    like a**(-2) => 1/(a*a), instead of being replaced by 1.

    :param expr: sympy expression
    :return: expression where the powers are substituted by unevaluated products
    :raises ValueError: if expr contains a power with non-integer exponent
    """
    powers = [(power,) + tuple(power.as_base_exp()) for power in expr.atoms(sp.Pow)]
    if any(not exponent.is_Integer for _, _, exponent in powers):
        raise ValueError("A power contains a non-integer exponent")

    def unrolled(base, exponent):
        product = sp.Mul(*[base] * abs(int(exponent)), evaluate=False)
        if exponent < 0:
            # a list repeated a negative number of times is empty, so the sign has to be handled separately
            return sp.Pow(product, -1, evaluate=False)
        return product

    replacements = [(power, unrolled(base, exponent)) for power, base, exponent in powers]
    return expr.subs(replacements)
def extractMostCommonFactor(term):
    """
    Processes a sum of fractions: determines the most common factor and splits term in common factor and rest

    The candidates for the common factor are the absolute values of the coefficients of the term.

    :param term: sympy expression
    :return: tuple (commonFactor, term / commonFactor)
    """
    from collections import Counter
    from sympy.functions import Abs

    absCoefficientCounts = Counter(Abs(coefficient) for coefficient in term.as_coefficients_dict().values())
    commonFactor, occurrences = max(absCoefficientCounts.items(), key=lambda entry: entry[1])
    # if no coefficient occurs more than once, prefer the trivial factor 1 whenever it is available
    if occurrences == 1 and 1 in absCoefficientCounts:
        commonFactor = 1
    return commonFactor, term / commonFactor
def mostCommonTermFactorization(term):
    """
    Pulls the most common coefficient out of a term and factorizes the rest.

    :param term: sympy expression
    :return: product of the most common factor and the factorized rest, if the factorization of the rest
             contains at least two factors with symbols; otherwise the unevaluated product of the most
             common factor and the unfactorized rest
    """
    commonFactor, rest = extractMostCommonFactor(term)
    factorization = sp.factor(rest)

    if not factorization.is_Mul:
        return sp.Mul(commonFactor, rest, evaluate=False)

    # split the factors into the ones containing symbols and a purely constant prefactor
    symbolicFactors = []
    constantFactor = 1
    for factor in factorization.args:
        if factor.atoms(sp.Symbol):
            symbolicFactors.append(factor)
        else:
            constantFactor *= factor

    if len(symbolicFactors) <= 1:
        return sp.Mul(commonFactor, rest, evaluate=False)

    # the constant prefactor is merged into the last symbolic factor
    *leadingFactors, lastFactor = symbolicFactors
    return sp.Mul(commonFactor, *(leadingFactors + [constantFactor * lastFactor]))
def countNumberOfOperations(term):
    """
    Counts the number of additions, multiplications and division

    Sequences are processed element-wise and the counts are summed up, of equations only the right hand
    side is considered. A warning is issued for nodes that can not be counted.

    :param term: a sympy term, equation or sequence of terms/equations
    :return: a dictionary with 'adds', 'muls' and 'divs' keys
    """
    result = {'adds': 0, 'muls': 0, 'divs': 0}

    if isinstance(term, Sequence):
        for element in term:
            for operationName, count in countNumberOfOperations(element).items():
                result[operationName] += count
        return result
    elif isinstance(term, sp.Eq):
        term = term.rhs

    term = term.evalf()

    def visit(node):
        descend = True
        nodeType = node.func

        if nodeType is sp.Add:
            result['adds'] += len(node.args) - 1
        elif nodeType is sp.Mul:
            # factors of 1 or -1 are not counted as multiplications
            trivialFactors = sum(1 for factor in node.args if factor == 1 or factor == -1)
            result['muls'] += len(node.args) - 1 - trivialFactors
        elif nodeType is sp.Float:
            pass
        elif isinstance(node, sp.Symbol):
            descend = False
        elif isinstance(node, sp.tensor.Indexed):
            descend = False
        elif node.is_integer:
            pass
        elif nodeType is sp.Pow:
            descend = False
            exponent = node.exp
            if exponent.is_integer and exponent.is_number:
                if exponent >= 0:
                    result['muls'] += int(exponent) - 1
                else:
                    # one division, and one multiplication less than for the positive exponent
                    result['divs'] += 1
                    result['muls'] += (-int(exponent)) - 2
            else:
                warnings.warn("Counting operations: only integer exponents are supported in Pow, "
                              "counting will be inaccurate")
        else:
            warnings.warn("Unknown sympy node of type " + str(node.func) + " counting will be inaccurate")

        if descend:
            for child in node.args:
                visit(child)

    visit(term)
    return result
def countNumberOfOperationsInAst(ast):
    """
    Counts number of operations in an abstract syntax tree, see also :func:`countNumberOfOperations`

    :param ast: root node of the tree; all nodes reachable over 'args' are visited
    :return: a dictionary with 'adds', 'muls' and 'divs' keys, summed over the right hand sides of
             all SympyAssignment nodes
    """
    from pystencils.astnodes import SympyAssignment

    totals = {'adds': 0, 'muls': 0, 'divs': 0}

    def visit(node):
        if not isinstance(node, SympyAssignment):
            for child in node.args:
                visit(child)
            return
        counts = countNumberOfOperations(node.rhs)
        for operationName in ('adds', 'muls', 'divs'):
            totals[operationName] += counts[operationName]

    visit(ast)
    return totals
def matrixFromColumnVectors(columnVectors):
    """Creates a sympy matrix from column vectors.

    :param columnVectors: nested sequence - i.e. a sequence of column vectors
    :return: sympy matrix whose i-th column holds the entries of the i-th vector
    """
    # the vectors are first stored as rows, the transposition turns them into columns
    vectorsAsRows = [list(columnVector) for columnVector in columnVectors]
    return sp.Matrix(vectorsAsRows).transpose()
def commonDenominator(expr):
    """Returns the least common multiple of the denominators of all rational numbers occurring in expr"""
    rationals = expr.atoms(sp.Rational)
    return sp.lcm([rational.q for rational in rationals])
def getSymmetricPart(term, vars):
    r"""
    Returns the symmetric part of a sympy expression.

    :param term: sympy expression, labeled here as :math:`f`
    :param vars: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...`
    :returns: :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1) ]`
    """
    mirroredTerm = term.subs({symbol: -symbol for symbol in vars})
    return sp.Rational(1, 2) * (term + mirroredTerm)
def sortEquationsTopologically(equationSequence):
    """
    Orders equations with sympy's reps_toposort, which is applied to the (lhs, rhs) pairs.

    :param equationSequence: sequence of sympy equations
    :return: list of new sympy equations in the order determined by reps_toposort
    """
    lhsRhsPairs = [[equation.lhs, equation.rhs] for equation in equationSequence]
    sortedPairs = sp.cse_main.reps_toposort(lhsRhsPairs)
    return [sp.Eq(lhs, rhs) for lhs, rhs in sortedPairs]
def getEquationsFromFunction(func, **kwargs):
    """
    Mechanism to simplify the generation of a list of sympy equations.

    Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an
    equation in the result list. Note that executing this function normally yields an error.
    Additionally the shortcut object 'S' is available to quickly create new sympy symbols.

    The source code of the function body is rewritten line by line and executed with the globals and
    locals of the caller; the passed keyword arguments are available as local names.

    :param func: function whose body contains the "@=" lines, it may not contain 'return'
    :param kwargs: additional names that are visible in the executed function body
    :return: list of sympy equations, one for each "@=" line
    :raises ValueError: if a line of the function body contains 'return'

    Example:
        >>> def myKernel():
        ...     from pystencils import Field
        ...     f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=0)
        ...     g = f.newFieldWithDifferentName('g')
        ...
        ...     S.neighbors @= f[0,1] + f[1,0]
        ...     g[0,0]      @= S.neighbors + f[0,0]
        >>> getEquationsFromFunction(myKernel)
        [Eq(neighbors, f_E + f_N), Eq(g_C, f_C + neighbors)]
    """
    import inspect
    import re

    class SymbolCreator:
        def __getattribute__(self, name):
            return sp.Symbol(name)

    assignmentRegexp = re.compile(r'(\s*)(.+?)@=(.*)')
    whitespaceRegexp = re.compile(r'(\s*)(.*)')

    sourceLines = inspect.getsourcelines(func)[0]

    # the leading whitespace of the line after the 'def' line determines the indentation of the body
    indentationMatch = whitespaceRegexp.match(sourceLines[1])
    assert indentationMatch, "First line is not indented"
    numWhitespaces = len(indentationMatch.group(1))

    # drop the 'def' line, remove the indentation and rewrite every "@=" line into an append to the result
    bodyLines = []
    for rawLine in sourceLines[1:]:
        line = rawLine[numWhitespaces:]
        if 'return' in line:
            raise ValueError("Function may not have a return statement!")
        assignmentMatch = assignmentRegexp.match(line)
        if assignmentMatch:
            line = "%s_result.append(Eq(%s, %s))\n" % assignmentMatch.groups()
        bodyLines.append(line)
    code = "".join(bodyLines)

    result = []
    localsDict = {'_result': result,
                  'Eq': sp.Eq,
                  'S': SymbolCreator()}
    localsDict.update(kwargs)

    callerFrame = inspect.stack()[1][0]
    globalsDict = callerFrame.f_globals.copy()
    globalsDict.update(callerFrame.f_locals)

    exec(code, globalsDict, localsDict)
    return result
import pytest
import sympy as sp
import numpy
import pystencils
from pystencils.datahandling import create_data_handling
@pytest.mark.parametrize('dtype', ["float64", "float32"])
@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
def test_max(dtype, sympy_function):
    """Compiles and runs kernels that apply sp.Min / sp.Max to one, two and many floating point arguments."""
    dh = create_data_handling(domain_size=(10, 10), periodicity=True)

    arrays = {}
    for name, fill_value in (('x', 0.0), ('y', 1.0), ('z', 2.0)):
        arrays[name] = dh.add_array(name, values_per_cell=1, dtype=dtype)
        dh.fill(name, fill_value, ghost_layers=True)
    x, y, z = arrays['x'], arrays['y'], arrays['z']

    config = pystencils.CreateKernelConfig(default_number_float=dtype)

    right_hand_sides = [
        # one argument
        sympy_function(y.center + 3.3),
        # two arguments
        sympy_function(0.5, y.center - 1.5),
        # many arguments
        sympy_function(z.center, 4.5, y.center - 1.5, y.center + z.center),
    ]

    kernels = []
    for rhs in right_hand_sides:
        ast = pystencils.create_kernel(pystencils.Assignment(x.center, rhs), config=config)
        kernels.append(ast.compile())

    expected_values = [4.3, 0.5, 4.5] if sympy_function is sp.Max else [4.3, -0.5, -0.5]

    for kernel, expected in zip(kernels, expected_values):
        dh.run_kernel(kernel)
        assert numpy.all(dh.gather_array('x') == expected)
@pytest.mark.parametrize('dtype', ["int64", 'int32'])
@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
def test_max_integer(dtype, sympy_function):
    """Compiles and runs kernels that apply sp.Min / sp.Max to one, two and many integer arguments."""
    dh = create_data_handling(domain_size=(10, 10), periodicity=True)

    arrays = {}
    for name, fill_value in (('x', 0), ('y', 1), ('z', 2)):
        arrays[name] = dh.add_array(name, values_per_cell=1, dtype=dtype)
        dh.fill(name, fill_value, ghost_layers=True)
    x, y, z = arrays['x'], arrays['y'], arrays['z']

    config = pystencils.CreateKernelConfig(default_number_int=dtype)

    right_hand_sides = [
        # one argument
        sympy_function(y.center + 3),
        # two arguments
        sympy_function(1, y.center - 1),
        # many arguments
        sympy_function(z.center, 4, y.center - 1, y.center + z.center),
    ]

    kernels = []
    for rhs in right_hand_sides:
        ast = pystencils.create_kernel(pystencils.Assignment(x.center, rhs), config=config)
        kernels.append(ast.compile())

    expected_values = [4, 1, 4] if sympy_function is sp.Max else [4, 0, 0]

    for kernel, expected in zip(kernels, expected_values):
        dh.run_kernel(kernel)
        assert numpy.all(dh.gather_array('x') == expected)
import pytest
import pystencils.config
import sympy
import pystencils as ps
from pystencils.typing import CastFunc, create_type
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
def test_abs(target):
    """The code generated for Abs of a value cast to int64 must not contain a call to fabs."""
    x, y, z = ps.fields('x, y, z: float64[2d]')

    integer_value = CastFunc(y[0, 0], create_type('int64'))
    assignments = ps.AssignmentCollection({x[0, 0]: sympy.Abs(integer_value)})

    kernel_config = pystencils.config.CreateKernelConfig(target=target)
    generated_code = ps.get_code_str(ps.create_kernel(assignments, config=kernel_config))
    print(generated_code)

    assert 'fabs(' not in generated_code
"""
Tests for pystencils.functions.AddressOf
"""
import pytest
import pystencils
from pystencils.typing import PointerType, CastFunc, BasicType
from pystencils.functions import AddressOf
from pystencils.simp.simplifications import sympy_cse
import sympy as sp
def test_address_of():
    """Checks canonical form and dtype of AddressOf and compiles kernels that use it."""
    x, y = pystencils.fields('x, y: int64[2d]')
    pointer_symbol = pystencils.TypedSymbol('s', PointerType(BasicType('int64')))

    address = AddressOf(x[0, 0])
    assert address.canonical() == x[0, 0]
    assert AddressOf(x[0, 0]).dtype == PointerType(x[0, 0].dtype, restrict=True)

    # the address of a plain sympy symbol has no type
    with pytest.raises(ValueError):
        assert AddressOf(sp.Symbol("a")).dtype

    # address stored in a pointer symbol first, then cast to an integer
    via_pointer_symbol = pystencils.AssignmentCollection({
        pointer_symbol: AddressOf(x[0, 0]),
        y[0, 0]: CastFunc(pointer_symbol, BasicType('int64'))
    })
    pystencils.create_kernel(via_pointer_symbol).compile()

    # address cast to an integer directly
    direct_cast = pystencils.AssignmentCollection({
        y[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64'))
    })
    pystencils.create_kernel(direct_cast).compile()
def test_address_of_with_cse():
    """Compiles a kernel using AddressOf once as written and once after sympy_cse was applied."""
    x, y = pystencils.fields('x, y: int64[2d]')

    address_as_integer = CastFunc(AddressOf(x[0, 0]), BasicType('int64'))
    assignments = pystencils.AssignmentCollection({x[0, 0]: address_as_integer + 1})
    pystencils.create_kernel(assignments).compile()

    pystencils.create_kernel(sympy_cse(assignments)).compile()
import pytest
from pystencils import create_data_handling
from pystencils.alignedarray import *
from pystencils.field import create_numpy_array_with_layout
def is_aligned(arr, alignment, byte_offset=0):
    """
    Checks if the data address of an array, shifted by byte_offset, is a multiple of alignment.

    The address is read from the first entry of the 'data' item of the array interface of arr.
    If the check fails, the remainder of the division is printed.
    """
    base_address = arr.__array_interface__['data'][0]
    remainder = (base_address + byte_offset) % alignment
    if remainder != 0:
        print("Alignment rest", remainder)
        return False
    return True
@pytest.mark.parametrize("alignment", [8, 8*4, True])
@pytest.mark.parametrize("shape", [17, 16, (16, 16), (17, 17), (18, 18), (19, 19)])
def test_1d_arrays(alignment, shape):
    """Arrays created by aligned_zeros, aligned_ones and aligned_empty start at an aligned address."""
    allocators = (aligned_zeros, aligned_ones, aligned_empty)
    arrays = [allocate(shape, alignment) for allocate in allocators]
    for array in arrays:
        assert is_aligned(array, alignment)
@pytest.mark.parametrize("order", ['C', 'F'])
@pytest.mark.parametrize("alignment", [8, 8*4, True])
@pytest.mark.parametrize("shape", [(16, 16), (17, 17), (18, 18), (19, 19)])
def test_3d_arrays(order, alignment, shape):
    """For both memory orders the array itself and the slices with index 1 and 5 are aligned."""
    allocators = (aligned_zeros, aligned_ones, aligned_empty)
    arrays = [allocate(shape, alignment, order=order) for allocate in allocators]

    def line(array, index):
        # for 'C' order the first coordinate is fixed, otherwise the last one
        return array[index] if order == 'C' else array[..., index]

    for array in arrays:
        assert is_aligned(array, alignment)
        for index in (1, 5):
            assert is_aligned(line(array, index), alignment)
@pytest.mark.parametrize("parallel", [False, True])
def test_data_handling(parallel):
    """Arrays added to a data handling with an alignment are aligned at the tested inner slice."""
    alignment = 8 * 4
    # try a few times, since we might get lucky and get randomly a correct alignment
    for _ in range(16):
        dh = create_data_handling((6, 7), default_ghost_layers=1, parallel=parallel)
        dh.add_array('test', alignment=alignment, values_per_cell=1)
        for block in dh.iterate(ghost_layers=True, inner_ghost_layers=True):
            assert is_aligned(block['test'][1:, 3:], alignment)
def test_alignment_of_different_layouts():
    """Arrays created with different layouts and a byte offset are aligned at the tested slices."""
    offset = 1
    byte_offset = 8
    alignment = 8 * 4

    # layout and the index of the slice whose alignment is checked
    cases = [
        ((0, 1, 2), (offset, Ellipsis)),
        ((2, 1, 0), (Ellipsis, offset)),
        ((2, 0, 1), (slice(None), 0, slice(None))),
    ]

    # try a few times, since we might get lucky and get randomly a correct alignment
    for _ in range(16):
        for layout, index in cases:
            arr = create_numpy_array_with_layout((3, 4, 5), layout=layout,
                                                 alignment=alignment, byte_offset=byte_offset)
            assert is_aligned(arr[index], alignment, byte_offset)
import pytest
import sympy as sp
import pystencils as ps
from pystencils import Assignment, AssignmentCollection
from pystencils.astnodes import Conditional
from pystencils.simp.assignment_collection import SymbolGen
# Plain sympy symbols shared by the tests of this module
a, b, c = sp.symbols("a b c")
x, y, z, t = sp.symbols("x y z t")
# Generator passed as subexpression_symbol_generator; test_assignment_collection expects a_0 as
# the first symbol created with it
symbol_gen = SymbolGen("a")
# Fields used as left hand sides and substitution targets in the tests
# (presumably 2D fields with two values per cell - confirm against ps.fields)
f = ps.fields("f(2) : [2D]")
d = ps.fields("d(2) : [2D]")
def test_assignment_collection():
    """Adding, sorting and inserting subexpressions of an AssignmentCollection and its string output."""
    collection = AssignmentCollection([Assignment(z, x + y)],
                                      [], subexpression_symbol_generator=symbol_gen)

    generated_symbol = collection.add_subexpression(t)
    assert generated_symbol == sp.Symbol("a_0")

    collection.subexpressions.append(Assignment(t, 3))
    collection.topological_sort(sort_main_assignments=False, sort_subexpressions=True)
    assert collection.subexpressions[0].lhs == t

    # inserting a symbol that is not defined by a subexpression changes nothing
    assert collection.new_with_inserted_subexpression(sp.Symbol("not_defined")) == collection

    inserted = collection.new_with_inserted_subexpression(t)
    only_generated_kept = collection.new_without_subexpressions({generated_symbol})
    assert all(first == second
               for first, second in zip(inserted.all_assignments, only_generated_kept.all_assignments))

    print(inserted)
    assert inserted.subexpressions[0] == Assignment(generated_symbol, 3)

    assert 'a_0' in str(inserted)
    assert '<table' in inserted._repr_html_()
def test_free_and_defined_symbols():
    """An AssignmentCollection that contains a Conditional can be created and printed."""
    main_assignments = [
        Assignment(z, x + y),
        Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2)),
    ]
    collection = AssignmentCollection(main_assignments, [], subexpression_symbol_generator=symbol_gen)

    print(collection)
    print(collection.__repr__)
def test_vector_assignments():
    """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)

    An Assignment can be created from two matrices of the same length.
    """
    targets = sp.Matrix([a, b, c])
    values = sp.Matrix([1, 2, 3])
    vector_assignment = ps.Assignment(targets, values)
    print(vector_assignment)
def test_wrong_vector_assignments():
    """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)

    An Assignment of matrices with different lengths fails with an AssertionError.
    """
    expected_message = r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'
    with pytest.raises(AssertionError, match=expected_message):
        ps.Assignment(sp.Matrix([a, b]), sp.Matrix([1, 2, 3]))
def test_vector_assignment_collection():
    """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)

    An AssignmentCollection can be created from matrix assignments, given as dict or as list.
    """
    y_m, x_m = sp.Matrix([a, b, c]), sp.Matrix([1, 2, 3])

    from_dict = ps.AssignmentCollection({y_m: x_m})
    print(from_dict)

    from_list = ps.AssignmentCollection([ps.Assignment(y_m, x_m)])
    print(from_list)
def test_new_with_substitutions():
    """Substitutions on left and right hand sides, optionally added as subexpressions."""
    ac = ps.AssignmentCollection([ps.Assignment(f[0, 0](0), a * b),
                                  ps.Assignment(f[0, 0](1), b * c)], subexpressions=[])

    def substituted(substitutions, as_subexpressions, on_lhs):
        return ac.new_with_substitutions(substitutions,
                                         add_substitutions_as_subexpressions=as_subexpressions,
                                         substitute_on_lhs=on_lhs,
                                         sort_topologically=True)

    # replace the written field accesses
    field_substitutions = {f[0, 0](0): d[0, 0](0), f[0, 0](1): d[0, 0](1)}

    subs_ac = substituted(field_substitutions, as_subexpressions=False, on_lhs=True)
    assert subs_ac.main_assignments[0].lhs == d[0, 0](0)
    assert subs_ac.main_assignments[1].lhs == d[0, 0](1)

    subs_ac = substituted(field_substitutions, as_subexpressions=False, on_lhs=False)
    assert subs_ac.main_assignments[0].lhs == f[0, 0](0)
    assert subs_ac.main_assignments[1].lhs == f[0, 0](1)

    # replace an expression of a right hand side
    rhs_substitutions = {a * b: sp.symbols('xi')}

    subs_ac = substituted(rhs_substitutions, as_subexpressions=False, on_lhs=False)
    assert subs_ac.main_assignments[0].rhs == sp.symbols('xi')
    assert len(subs_ac.subexpressions) == 0

    subs_ac = substituted(rhs_substitutions, as_subexpressions=True, on_lhs=False)
    assert subs_ac.main_assignments[0].rhs == sp.symbols('xi')
    assert len(subs_ac.subexpressions) == 1
    assert subs_ac.subexpressions[0].lhs == sp.symbols('xi')
def test_copy():
    """The copy of an AssignmentCollection is equal to the original."""
    main_assignments = [ps.Assignment(f[0, 0](0), a * b),
                        ps.Assignment(f[0, 0](1), b * c)]
    original = ps.AssignmentCollection(main_assignments, subexpressions=[])
    duplicate = original.copy()
    assert duplicate == original
def test_set_expressions():
    """Setting main assignments and subexpressions from dicts and removing subexpressions again."""
    ac = ps.AssignmentCollection([ps.Assignment(f[0, 0](0), a * b),
                                  ps.Assignment(f[0, 0](1), b * c)], subexpressions=[])

    ac.set_main_assignments_from_dict({d[0, 0](0): b * c})
    assert len(ac.main_assignments) == 1
    assert ac.main_assignments[0] == ps.Assignment(d[0, 0](0), b * c)

    ac.set_sub_expressions_from_dict({sp.symbols('xi'): a * b})
    assert len(ac.subexpressions) == 1
    assert ac.subexpressions[0] == ps.Assignment(sp.symbols('xi'), a * b)

    # a subexpression that is explicitly kept survives ...
    ac = ac.new_without_subexpressions(subexpressions_to_keep={sp.symbols('xi')})
    assert ac.subexpressions[0] == ps.Assignment(sp.symbols('xi'), a * b)

    # ... but it is dropped as unused subexpression
    ac = ac.new_without_unused_subexpressions()
    assert len(ac.subexpressions) == 0

    without_subexpressions = ac.new_without_subexpressions()
    assert ac == without_subexpressions
def test_free_and_bound_symbols():
    """The written field access is a bound symbol, the field access that is only read a free symbol."""
    subexpression = ps.Assignment(a, d[0, 0](0))
    main_assignment = ps.Assignment(f[0, 0](1), b * c)
    collection = ps.AssignmentCollection([main_assignment], subexpressions=[subexpression])

    assert f[0, 0](1) in collection.bound_symbols
    assert d[0, 0](0) in collection.free_symbols
def test_new_merged():
    """Merging two collections whose subexpressions define the same symbol differently."""
    # both collections define 'a', but with different right hand sides
    a_from_bc = ps.Assignment(a, b * c)
    a_from_xy = ps.Assignment(a, x * y)
    t_squared = ps.Assignment(t, x ** 2)

    # main assignments
    write_f = ps.Assignment(f[0, 0](0), a)
    write_d = ps.Assignment(d[0, 0](0), a)

    first = ps.AssignmentCollection([write_f], subexpressions=[a_from_bc])
    second = ps.AssignmentCollection([write_d], subexpressions=[a_from_xy, t_squared])

    merged = first.new_merged(second)

    assert len(merged.subexpressions) == 3
    assert len(merged.main_assignments) == 2
    assert ps.Assignment(sp.symbols('xi_0'), x * y) in merged.subexpressions
    assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged.main_assignments
    assert a_from_bc in merged.subexpressions
    assert t_squared in merged.subexpressions

    # same situation, but the main assignments depend on 'a' only over a common subexpression
    a_is_20 = ps.Assignment(a, 20)
    a_is_10 = ps.Assignment(a, 10)
    common = ps.Assignment(b, a)

    # main assignments
    write_f = ps.Assignment(f[0, 0](0), b)
    write_d = ps.Assignment(d[0, 0](0), b)

    first = ps.AssignmentCollection([write_f], subexpressions=[a_is_20, common])
    second = ps.AssignmentCollection([write_d], subexpressions=[a_is_10, common])

    merged = first.new_merged(second).new_without_subexpressions()

    assert ps.Assignment(f[0, 0](0), 20) in merged.main_assignments
    assert ps.Assignment(d[0, 0](0), 10) in merged.main_assignments
import pystencils
def test_assignment_collection_dict_conversion():
    """Collections created from dicts are printed like the ones created from lists of assignments."""
    x, y = pystencils.fields('x,y: [2D]')

    def check(lhs_rhs_pairs):
        as_dict = dict(lhs_rhs_pairs)
        collection_normal = pystencils.AssignmentCollection(
            [pystencils.Assignment(lhs, rhs) for lhs, rhs in lhs_rhs_pairs],
            []
        )
        collection_dict = pystencils.AssignmentCollection(as_dict, {})

        assert str(collection_normal) == str(collection_dict)
        assert collection_dict.main_assignments_dict == as_dict
        assert collection_dict.subexpressions_dict == {}

    check([(x.center(), y[1, 0] + y[0, 0])])
    check([(y[1, 0], x.center()),
           (y[0, 0], x.center())])
import numpy as np
import pystencils
def test_assignment_from_stencil():
    """assignment_from_stencil creates an Assignment, with and without normalization factor."""
    stencil = [
        [0, 0, 4, 1, 0, 0, 0],
        [0, 0, 0, 2, 0, 0, 0],
        [0, 0, 0, 3, 0, 0, 0]
    ]
    x, y = pystencils.fields('x, y: [2D]')
    from_stencil = pystencils.assignment.assignment_from_stencil

    plain = from_stencil(stencil, x, y)
    assert isinstance(plain, pystencils.Assignment)
    assert plain.rhs == x[0, 1] + 4 * x[-1, 1] + 2 * x[0, 0] + 3 * x[0, -1]

    normalized = from_stencil(stencil, x, y, normalization_factor=1 / np.sum(stencil))
    assert isinstance(normalized, pystencils.Assignment)
import pytest
import sys
import pystencils.config
import sympy as sp
import pystencils as ps
from pystencils import Assignment
from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
# Field that is written and read by the assignments of the tests below
dst = ps.fields('dst(8): double[2D]')
# Sequence of sympy symbols with the prefix 's_'; the tests access it by index
s = sp.symbols('s_:8')
x = sp.symbols('x')
y = sp.symbols('y')
# Version of the running interpreter as "major.minor.micro"
# NOTE(review): not referenced in the visible part of this file
python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
def test_kernel_function():
    """Target, backend, symbols and fields of a kernel created with default settings."""
    kernel_ast = ps.create_kernel([
        Assignment(dst[0, 0](0), s[0]),
        Assignment(x, dst[0, 0](2)),
    ])

    assert kernel_ast.target == ps.Target.CPU
    assert kernel_ast.backend == ps.Backend.C

    # symbols_defined and undefined_symbols will always return an empty set
    assert kernel_ast.symbols_defined == set()
    assert kernel_ast.undefined_symbols == set()

    assert kernel_ast.fields_written == {dst}
    assert kernel_ast.fields_read == {dst}
def test_skip_iteration():
    """SkipIteration is an object which should give back empty data structures."""
    node = SkipIteration()
    assert node.args == []
    assert node.symbols_defined == set()
    assert node.undefined_symbols == set()
def test_block():
    """Defined symbols of a Block and adding nodes at its end and its front."""
    block = Block([
        Assignment(dst[0, 0](0), s[0]),
        Assignment(x, dst[0, 0](2)),
    ])
    symbols_of_initial_block = {dst[0, 0](0), dst[0, 0](2), s[0], x}
    assert block.symbols_defined == symbols_of_initial_block

    block.append([Assignment(y, 10)])
    assert block.symbols_defined == symbols_of_initial_block | {y}
    assert len(block.args) == 3

    # insert_front is called with an iterator here, not with a list
    front_nodes = iter([Assignment(s[1], 11)])
    block.insert_front(front_nodes)
    assert block.args[0] == Assignment(s[1], 11)
def test_loop_over_coordinate():
    """Body exchange and replacement of start, stop and step of a LoopOverCoordinate."""
    assignments = [
        Assignment(dst[0, 0](0), s[0]),
        Assignment(x, dst[0, 0](2))
    ]
    original_body = Block(assignments)
    loop = LoopOverCoordinate(original_body, coordinate_to_loop_over=0, start=0, stop=10, step=1)
    assert loop.body == original_body

    shortened_body = Block([assignments[0]])
    loop = loop.new_loop_with_different_body(shortened_body)
    assert loop.body == shortened_body

    # the loop bounds are taken over by the new loop
    assert (loop.start, loop.stop, loop.step) == (0, 10, 1)

    new_bounds = (('start', 2), ('stop', 20), ('step', 2))
    for attribute, new_value in new_bounds:
        loop.replace(getattr(loop, attribute), new_value)

    for attribute, new_value in new_bounds:
        assert getattr(loop, attribute) == new_value