Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 317 additions and 71 deletions
File moved
...@@ -4,9 +4,11 @@ import warnings ...@@ -4,9 +4,11 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
from typing import Set
import sympy as sp import sympy as sp
import pystencils as ps
import pystencils.astnodes as ast import pystencils.astnodes as ast
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type, from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
...@@ -98,6 +100,45 @@ def generic_visit(term, visitor): ...@@ -98,6 +100,45 @@ def generic_visit(term, visitor):
return visitor(term) return visitor(term)
def iterate_loops_by_depth(node, nesting_depth):
"""Iterate all LoopOverCoordinate nodes in the given AST of the specified nesting depth.
Args:
node: Root node of the abstract syntax tree
nesting_depth: Nesting depth of the loops the pragmas should be applied to.
Outermost loop has depth 0.
A depth of -1 indicates the innermost loops.
Returns: Iterable listing all loop nodes of given nesting depth.
"""
from pystencils.astnodes import LoopOverCoordinate
def _internal_default(node, nesting_depth):
isloop = isinstance(node, LoopOverCoordinate)
if nesting_depth < 0: # here, a negative value indicates end of descent
return
elif nesting_depth == 0 and isloop:
yield node
else:
next_depth = nesting_depth - 1 if isloop else nesting_depth
for arg in node.args:
yield from _internal_default(arg, next_depth)
def _internal_innermost(node):
if isinstance(node, LoopOverCoordinate) and node.is_innermost_loop:
yield node
else:
for arg in node.args:
yield from _internal_innermost(arg)
if nesting_depth >= 0:
yield from _internal_default(node, nesting_depth)
elif nesting_depth == -1:
yield from _internal_innermost(node)
else:
raise ValueError(f"Invalid nesting depth: {nesting_depth}. Choose a nonnegative number, or -1.")
def unify_shape_symbols(body, common_shape, fields): def unify_shape_symbols(body, common_shape, fields):
"""Replaces symbols for array sizes to ensure they are represented by the same unique symbol. """Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
...@@ -122,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields): ...@@ -122,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields):
body.subs(substitutions) body.subs(substitutions)
def get_common_shape(field_set): def get_common_field(field_set):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise """Takes a set of pystencils Fields, checks if a common spatial shape exists and returns one
ValueError is raised""" representative field, that can be used for shape information etc. in the kernel creation.
If the fields have different shapes ValueError is raised"""
nr_of_fixed_shaped_fields = 0 nr_of_fixed_shaped_fields = 0
for f in field_set: for f in field_set:
if f.has_fixed_shape: if f.has_fixed_shape:
...@@ -142,8 +184,9 @@ def get_common_shape(field_set): ...@@ -142,8 +184,9 @@ def get_common_shape(field_set):
if len(shape_set) != 1: if len(shape_set) != 1:
raise ValueError("Differently sized field accesses in loop body: " + str(shape_set)) raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))
shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0] # Sort the fields by their name to ensure that always the same field is returned
return shape reference_field = sorted(field_set, key=lambda e: str(e))[0]
return reference_field
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None): def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
...@@ -161,9 +204,11 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -161,9 +204,11 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
tuple of loop-node, ghost_layer_info tuple of loop-node, ghost_layer_info
""" """
# find correct ordering by inspecting participating FieldAccesses # find correct ordering by inspecting participating FieldAccesses
absolut_accesses_only = False
field_accesses = body.atoms(Field.Access) field_accesses = body.atoms(Field.Access)
field_accesses = {e for e in field_accesses if not e.is_absolute_access} field_accesses = {e for e in field_accesses if not e.is_absolute_access}
if len(field_accesses) == 0: # when kernel contains only absolute accesses
absolut_accesses_only = True
# exclude accesses to buffers from field_list, because buffers are treated separately # exclude accesses to buffers from field_list, because buffers are treated separately
field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))] field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))]
if len(field_list) == 0: # when kernel contains only custom fields if len(field_list) == 0: # when kernel contains only custom fields
...@@ -174,14 +219,23 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -174,14 +219,23 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
if loop_order is None: if loop_order is None:
loop_order = get_optimal_loop_ordering(fields) loop_order = get_optimal_loop_ordering(fields)
shape = get_common_shape(fields) if absolut_accesses_only:
unify_shape_symbols(body, common_shape=shape, fields=fields) absolut_access_fields = {e.field for e in body.atoms(Field.Access)}
common_field = get_common_field(absolut_access_fields)
common_shape = common_field.spatial_shape
else:
common_field = get_common_field(fields)
common_shape = common_field.spatial_shape
unify_shape_symbols(body, common_shape=common_shape, fields=fields)
if iteration_slice is not None: if iteration_slice is not None:
iteration_slice = normalize_slice(iteration_slice, shape) iteration_slice = normalize_slice(iteration_slice, common_shape)
if ghost_layers is None: if ghost_layers is None:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses]) if absolut_accesses_only:
required_ghost_layers = 0
else:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order) ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
if isinstance(ghost_layers, int): if isinstance(ghost_layers, int):
ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order) ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)
...@@ -190,7 +244,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -190,7 +244,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
for i, loop_coordinate in enumerate(reversed(loop_order)): for i, loop_coordinate in enumerate(reversed(loop_order)):
if iteration_slice is None: if iteration_slice is None:
begin = ghost_layers[loop_coordinate][0] begin = ghost_layers[loop_coordinate][0]
end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1] end = common_shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1) new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
current_body = ast.Block([new_loop]) current_body = ast.Block([new_loop])
else: else:
...@@ -207,6 +261,28 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -207,6 +261,28 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
return current_body, ghost_layers return current_body, ghost_layers
def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
assert len(indexed_elements) > 0, "indexed_elements can not be empty"
shape_set = {s.shape for s in indexed_elements}
if len(shape_set) != 1:
for shape in shape_set:
assert not isinstance(shape, int), "If indexed elements are used, they must all have the same shape"
return sorted(indexed_elements, key=lambda e: str(e))[0]
def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
indexed_elements = loop_node.atoms(sp.Indexed)
if len(indexed_elements) == 0:
return loop_node
reference_element = get_common_indexed_element(indexed_elements)
index = reference_element.indices[0].atoms(TypedSymbol)
assert len(index) == 1, "index expressions must only contain one symbol representing the index"
new_loop = ast.LoopOverCoordinate(loop_node, 0, 0,
reference_element.shape[0], 1, custom_loop_ctr=index.pop())
return ast.Block([new_loop])
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
r""" r"""
Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]` Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
...@@ -342,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -342,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
ast_node: ast before any field accesses are resolved ast_node: ast before any field accesses are resolved
loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes) loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
for GPU kernels: list of 'loop counters' from inner to outer loop for GPU kernels: list of 'loop counters' from inner to outer loop
loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default
Returns: Returns:
base buffer index - required by 'resolve_buffer_accesses' function base buffer index - required by 'resolve_buffer_accesses' function
...@@ -354,15 +430,25 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -354,15 +430,25 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
assert len(loops) == len(parents_of_innermost_loop) assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop)) assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
actual_sizes = [int_div((loop.stop - loop.start), loop.step) loop_counters = [loop.loop_counter_symbol for loop in loops]
if loop.step != 1 else loop.stop - loop.start for loop in loops] loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
actual_steps = [int_div((loop.loop_counter_symbol - loop.start), loop.step) actual_sizes = list()
if loop.step != 1 else loop.loop_counter_symbol - loop.start for loop in loops] actual_steps = list()
for ctr, s in zip(loop_counters, loop_iterations):
if s.step != 1:
if (s.stop - s.start) % s.step == 0:
actual_sizes.append((s.stop - s.start) // s.step)
else:
actual_sizes.append(int_div((s.stop - s.start), s.step))
else: if (ctr - s.start) % s.step == 0:
actual_sizes = loop_iterations actual_steps.append((ctr - s.start) // s.step)
actual_steps = loop_counters else:
actual_steps.append(int_div((ctr - s.start), s.step))
else:
actual_sizes.append(s.stop - s.start)
actual_steps.append(ctr - s.start)
field_accesses = ast_node.atoms(Field.Access) field_accesses = ast_node.atoms(Field.Access)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
...@@ -506,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None, ...@@ -506,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None,
coord_dict = create_coordinate_dict(group) coord_dict = create_coordinate_dict(group)
new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer) new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
if new_ptr not in enclosing_block.symbols_defined: if new_ptr not in enclosing_block.symbols_defined:
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False) new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
enclosing_block.insert_before(new_assignment, sympy_assignment) enclosing_block.insert_before(new_assignment, sympy_assignment)
last_pointer = new_ptr last_pointer = new_ptr
...@@ -570,21 +656,65 @@ def move_constants_before_loop(ast_node): ...@@ -570,21 +656,65 @@ def move_constants_before_loop(ast_node):
""" """
assert isinstance(node.parent, ast.Block) assert isinstance(node.parent, ast.Block)
def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
if isinstance(node.lhs, ast.ResolvedFieldAccess):
return node.lhs.typed_symbol.name in symbol_names
else:
return node.lhs.name in symbol_names
elif isinstance(node, ast.Block):
for arg in node.args:
if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
continue
if modifies_or_declares(arg, symbol_names):
return True
return False
elif isinstance(node, ast.LoopOverCoordinate):
return modifies_or_declares(node.body, symbol_names)
elif isinstance(node, ast.Conditional):
return (
modifies_or_declares(node.true_block, symbol_names)
or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
)
elif isinstance(node, ast.KernelFunction):
return False
else:
defs = {s.name for s in node.symbols_defined}
return bool(symbol_names.intersection(defs))
dependencies = {s.name for s in node.undefined_symbols}
last_block = node.parent last_block = node.parent
last_block_child = node last_block_child = node
element = node.parent element = node.parent
prev_element = node prev_element = node
while element: while element:
if isinstance(element, ast.Block): if isinstance(element, (ast.Conditional, ast.KernelFunction)):
# Never move out of Conditionals or KernelFunctions.
break
elif isinstance(element, ast.Block):
last_block = element last_block = element
last_block_child = prev_element last_block_child = prev_element
if isinstance(element, ast.Conditional): if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
break # The node depends on one of the statements in this block.
# Do not move further out.
break
elif isinstance(element, ast.LoopOverCoordinate):
if element.loop_counter_symbol.name in dependencies:
# The node depends on the loop counter.
# Do not move out of this loop.
break
else: else:
critical_symbols = set([s.name for s in element.symbols_defined]) raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols): f'The expression {element} of type {type(element)} is not known yet.')
break
# No dependencies to symbols defined/modified within the current element.
# We can move the node up one level and in front of the current element.
prev_element = element prev_element = element
element = element.parent element = element.parent
return last_block, last_block_child return last_block, last_block_child
...@@ -721,7 +851,8 @@ def cut_loop(loop_node, cutting_points): ...@@ -721,7 +851,8 @@ def cut_loop(loop_node, cutting_points):
One loop is transformed into len(cuttingPoints)+1 new loops that range from One loop is transformed into len(cuttingPoints)+1 new loops that range from
old_begin to cutting_points[1], ..., cutting_points[-1] to old_end old_begin to cutting_points[1], ..., cutting_points[-1] to old_end
Modifies the ast in place Modifies the ast in place. Note Issue #5783 of SymPy. Deepcopy will evaluate mul
https://github.com/sympy/sympy/issues/5783
Returns: Returns:
list of new loop nodes list of new loop nodes
......
...@@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM ...@@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType, from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type) PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol, from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol) FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types, from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type) get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc', __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType', 'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type', 'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types', 'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type'] 'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
...@@ -6,11 +6,12 @@ import numpy as np ...@@ -6,11 +6,12 @@ import numpy as np
import sympy as sp import sympy as sp
from sympy import Piecewise from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.codegen import Assignment from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom from sympy.logic.boolalg import BooleanAtom
...@@ -51,7 +52,7 @@ class TypeAdder: ...@@ -51,7 +52,7 @@ class TypeAdder:
def visit(self, obj): def visit(self, obj):
if isinstance(obj, (list, tuple)): if isinstance(obj, (list, tuple)):
return [self.visit(e) for e in obj] return [self.visit(e) for e in obj]
if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): if isinstance(obj, ast.SympyAssignment):
return self.process_assignment(obj) return self.process_assignment(obj)
elif isinstance(obj, ast.Conditional): elif isinstance(obj, ast.Conditional):
condition, condition_type = self.figure_out_type(obj.condition_expr) condition, condition_type = self.figure_out_type(obj.condition_expr)
...@@ -67,7 +68,7 @@ class TypeAdder: ...@@ -67,7 +68,7 @@ class TypeAdder:
else: else:
raise ValueError("Invalid object in kernel " + str(type(obj))) raise ValueError("Invalid object in kernel " + str(type(obj)))
def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment: def process_assignment(self, assignment: ast.SympyAssignment) -> ast.SympyAssignment:
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs, rhs_type = self.figure_out_type(assignment.rhs) new_rhs, rhs_type = self.figure_out_type(assignment.rhs)
...@@ -81,11 +82,11 @@ class TypeAdder: ...@@ -81,11 +82,11 @@ class TypeAdder:
assert isinstance(new_lhs, (Field.Access, TypedSymbol)) assert isinstance(new_lhs, (Field.Access, TypedSymbol))
if lhs_type != rhs_type: if lhs_type != rhs_type:
logging.warning(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype ' logging.debug(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
f'rhs: "{new_rhs}" of type "{rhs_type}".') f'rhs: "{new_rhs}" of type "{rhs_type}".')
return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type)) return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type), assignment.is_const, assignment.use_auto)
else: else:
return ast.SympyAssignment(new_lhs, new_rhs) return ast.SympyAssignment(new_lhs, new_rhs, assignment.is_const, assignment.use_auto)
# Type System Specification # Type System Specification
# - Defined Types: TypedSymbol, Field, Field.Access, ...? # - Defined Types: TypedSymbol, Field, Field.Access, ...?
...@@ -171,12 +172,13 @@ class TypeAdder: ...@@ -171,12 +172,13 @@ class TypeAdder:
args_types = [self.figure_out_type(a) for a in expr.args] args_types = [self.figure_out_type(a) for a in expr.args]
new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types] new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types]
return expr.func(*new_args), bool_type return expr.func(*new_args), bool_type
elif type(expr, ) in pystencils.integer_functions.__dict__.values(): elif type(expr, ) in pystencils.integer_functions.__dict__.values() or isinstance(expr, sp.Mod):
args_types = [self.figure_out_type(a) for a in expr.args] args_types = [self.figure_out_type(a) for a in expr.args]
collated_type = collate_types([t for _, t in args_types]) collated_type = collate_types([t for _, t in args_types])
# TODO: should we downcast to integer? If yes then which integer type? # TODO: should we downcast to integer? If yes then which integer type?
if not collated_type.is_int(): if not collated_type.is_int():
raise ValueError(f"Integer functions need to be used with integer types but {collated_type} was given") raise ValueError(f"Integer functions or Modulo need to be used with integer types "
f"but {collated_type} was given")
return expr, collated_type return expr, collated_type
elif isinstance(expr, flag_cond): elif isinstance(expr, flag_cond):
...@@ -212,7 +214,7 @@ class TypeAdder: ...@@ -212,7 +214,7 @@ class TypeAdder:
new_args.append(a) new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction, sp.log)): HyperbolicFunction, sp.log, RoundFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args] args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types]) collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
...@@ -228,6 +230,15 @@ class TypeAdder: ...@@ -228,6 +230,15 @@ class TypeAdder:
new_func = expr.func(*new_args) if new_args else expr new_func = expr.func(*new_args) if new_args else expr
return CastFunc(new_func, collated_type), collated_type return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)): elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
# Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
# and resolve the typing entirely with the expression itself
if isinstance(expr, sp.Mul):
c, e = expr.as_coeff_Mul()
if c == NegativeOne():
args_types = self.figure_out_type(e)
new_args = [NegativeOne(), args_types[0]]
return expr.func(*new_args, evaluate=False), args_types[1]
args_types = [self.figure_out_type(arg) for arg in expr.args] args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types]) collated_type = collate_types([t for _, t in args_types])
if isinstance(collated_type, PointerType): if isinstance(collated_type, PointerType):
......
from typing import List from typing import List
from pystencils.astnodes import Node
from pystencils.config import CreateKernelConfig from pystencils.config import CreateKernelConfig
from pystencils.typing.leaf_typing import TypeAdder from pystencils.typing.leaf_typing import TypeAdder
from sympy.codegen import Assignment
def add_types(eqs: List[Assignment], config: CreateKernelConfig): def add_types(node_list: List[Node], config: CreateKernelConfig):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`. """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
The AST needs to be a pystencils AST. Thus, in the list of nodes every entry must be inherited from
`pystencils.astnodes.Node`
Additionally returns sets of all fields which are read/written Additionally returns sets of all fields which are read/written
Args: Args:
eqs: list of equations node_list: List of pystencils Nodes.
config: CreateKernelConfig config: CreateKernelConfig
Returns: Returns:
...@@ -22,4 +24,4 @@ def add_types(eqs: List[Assignment], config: CreateKernelConfig): ...@@ -22,4 +24,4 @@ def add_types(eqs: List[Assignment], config: CreateKernelConfig):
default_number_float=config.default_number_float, default_number_float=config.default_number_float,
default_number_int=config.default_number_int) default_number_int=config.default_number_int)
return check.visit(eqs) return check.visit(node_list)
...@@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol): ...@@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol):
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__)) __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
...@@ -7,7 +7,7 @@ import sympy as sp ...@@ -7,7 +7,7 @@ import sympy as sp
def is_supported_type(dtype: np.dtype): def is_supported_type(dtype: np.dtype):
scalar = dtype.type scalar = dtype.type
c = np.issctype(dtype) c = np.issubdtype(dtype, np.generic)
subclass = issubclass(scalar, np.floating) or issubclass(scalar, np.integer) or issubclass(scalar, np.bool_) subclass = issubclass(scalar, np.floating) or issubclass(scalar, np.integer) or issubclass(scalar, np.bool_)
additional_checks = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None additional_checks = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None
return c and subclass and additional_checks return c and subclass and additional_checks
...@@ -25,6 +25,8 @@ def numpy_name_to_c(name: str) -> str: ...@@ -25,6 +25,8 @@ def numpy_name_to_c(name: str) -> str:
return 'double' return 'double'
elif name == 'float32': elif name == 'float32':
return 'float' return 'float'
elif name == 'float16' or name == 'half':
return 'half'
elif name.startswith('int'): elif name.startswith('int'):
width = int(name[len("int"):]) width = int(name[len("int"):])
return f"int{width}_t" return f"int{width}_t"
...@@ -68,7 +70,7 @@ class BasicType(AbstractType): ...@@ -68,7 +70,7 @@ class BasicType(AbstractType):
BasicType is defined with a const qualifier and a np.dtype. BasicType is defined with a const qualifier and a np.dtype.
""" """
def __init__(self, dtype: Union[np.dtype, 'BasicType', str], const: bool = False): def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False):
if isinstance(dtype, BasicType): if isinstance(dtype, BasicType):
self.numpy_dtype = dtype.numpy_dtype self.numpy_dtype = dtype.numpy_dtype
self.const = dtype.const self.const = dtype.const
...@@ -94,6 +96,9 @@ class BasicType(AbstractType): ...@@ -94,6 +96,9 @@ class BasicType(AbstractType):
def is_float(self): def is_float(self):
return issubclass(self.numpy_dtype.type, np.floating) return issubclass(self.numpy_dtype.type, np.floating)
def is_half(self):
return issubclass(self.numpy_dtype.type, np.half)
def is_int(self): def is_int(self):
return issubclass(self.numpy_dtype.type, np.integer) return issubclass(self.numpy_dtype.type, np.integer)
...@@ -120,7 +125,10 @@ class BasicType(AbstractType): ...@@ -120,7 +125,10 @@ class BasicType(AbstractType):
return f'{self.c_name}{" const" if self.const else ""}' return f'{self.c_name}{" const" if self.const else ""}'
def __repr__(self): def __repr__(self):
return str(self) return f'BasicType( {str(self)} )'
def _repr_html_(self):
return f'BasicType( {str(self)} )'
def __eq__(self, other): def __eq__(self, other):
return self.dtype_eq(other) and self.const == other.const return self.dtype_eq(other) and self.const == other.const
...@@ -181,16 +189,17 @@ class VectorType(AbstractType): ...@@ -181,16 +189,17 @@ class VectorType(AbstractType):
class PointerType(AbstractType): class PointerType(AbstractType):
def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True): def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False):
self._base_type = base_type self._base_type = base_type
self.const = const self.const = const
self.restrict = restrict self.restrict = restrict
self.double_pointer = double_pointer
def __getnewargs__(self): def __getnewargs__(self):
return self.base_type, self.const, self.restrict return self.base_type, self.const, self.restrict, self.double_pointer
def __getnewargs_ex__(self): def __getnewargs_ex__(self):
return (self.base_type, self.const, self.restrict), {} return (self.base_type, self.const, self.restrict, self.double_pointer), {}
@property @property
def alias(self): def alias(self):
...@@ -202,22 +211,34 @@ class PointerType(AbstractType): ...@@ -202,22 +211,34 @@ class PointerType(AbstractType):
@property @property
def item_size(self): def item_size(self):
return self.base_type.item_size if self.double_pointer:
raise NotImplementedError("The item_size for double_pointer is not implemented")
else:
return self.base_type.item_size
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, PointerType): if not isinstance(other, PointerType):
return False return False
else: else:
return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) own = (self.base_type, self.const, self.restrict, self.double_pointer)
return own == (other.base_type, other.const, other.restrict, other.double_pointer)
def __str__(self): def __str__(self):
return f'{str(self.base_type)} * {"RESTRICT " if self.restrict else "" }{"const" if self.const else ""}' restrict_str = "RESTRICT" if self.restrict else ""
const_str = "const" if self.const else ""
if self.double_pointer:
return f'{str(self.base_type)} ** {restrict_str} {const_str}'
else:
return f'{str(self.base_type)} * {restrict_str} {const_str}'
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self): def __hash__(self):
return hash((self._base_type, self.const, self.restrict)) return hash((self._base_type, self.const, self.restrict, self.double_pointer))
class StructType(AbstractType): class StructType(AbstractType):
...@@ -273,11 +294,14 @@ class StructType(AbstractType): ...@@ -273,11 +294,14 @@ class StructType(AbstractType):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self): def __hash__(self):
return hash((self.numpy_dtype, self.const)) return hash((self.numpy_dtype, self.const))
def create_type(specification: Union[np.dtype, AbstractType, str]) -> AbstractType: def create_type(specification: Union[type, AbstractType, str]) -> AbstractType:
# TODO: Deprecated Use the constructor of BasicType or StructType instead # TODO: Deprecated Use the constructor of BasicType or StructType instead
"""Creates a subclass of Type according to a string or an object of subclass Type. """Creates a subclass of Type according to a string or an object of subclass Type.
......
...@@ -187,18 +187,15 @@ def get_type_of_expression(expr, ...@@ -187,18 +187,15 @@ def get_type_of_expression(expr,
# Fix for sympy versions from 1.9 # Fix for sympy versions from 1.9
sympy_version = sp.__version__.split('.') sympy_version = sp.__version__.split('.')
if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1])
if sympy_version_int >= 109:
# __setstate__ would bypass the contructor, so we remove it # __setstate__ would bypass the contructor, so we remove it
sp.Number.__getstate__ = sp.Basic.__getstate__ if sympy_version_int >= 111:
del sp.Basic.__getstate__ del sp.Basic.__setstate__
del sp.Symbol.__setstate__
class FunctorWithStoredKwargs: else:
def __init__(self, func, **kwargs): sp.Number.__getstate__ = sp.Basic.__getstate__
self.func = func del sp.Basic.__getstate__
self.kwargs = kwargs
def __call__(self, *args):
return self.func(*args, **self.kwargs)
# __reduce_ex__ would strip kwargs, so we override it # __reduce_ex__ would strip kwargs, so we override it
def basic_reduce_ex(self, protocol): def basic_reduce_ex(self, protocol):
...@@ -210,9 +207,7 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: ...@@ -210,9 +207,7 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
state = self.__getstate__() state = self.__getstate__()
else: else:
state = None state = None
return FunctorWithStoredKwargs(type(self), **kwargs), args, state return partial(type(self), **kwargs), args, state
sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__
sp.Basic.__reduce_ex__ = basic_reduce_ex sp.Basic.__reduce_ex__ = basic_reduce_ex
......
...@@ -82,8 +82,8 @@ def boolean_array_bounding_box(boolean_array): ...@@ -82,8 +82,8 @@ def boolean_array_bounding_box(boolean_array):
>>> a = np.zeros((4, 4), dtype=bool) >>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True >>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a) >>> boolean_array_bounding_box(a) == [(1, 3), (1, 3)]
[(1, 3), (1, 3)] True
""" """
dim = boolean_array.ndim dim = boolean_array.ndim
shape = boolean_array.shape shape = boolean_array.shape
...@@ -96,6 +96,21 @@ def boolean_array_bounding_box(boolean_array): ...@@ -96,6 +96,21 @@ def boolean_array_bounding_box(boolean_array):
return bounds return bounds
def binary_numbers(n):
"""Returns all binary numbers up to 2^n - 1
Example:
>>> binary_numbers(2)
[[0, 0], [0, 1], [1, 0], [1, 1]]
"""
result = list()
for i in range(1 << n):
binary_number = bin(i)[2:]
binary_number = '0' * (n - len(binary_number)) + binary_number
result.append((list(map(int, binary_number))))
return result
class LinearEquationSystem: class LinearEquationSystem:
"""Symbolic linear system of equations - consisting of matrix and right hand side. """Symbolic linear system of equations - consisting of matrix and right hand side.
......
File moved
File moved
...@@ -170,3 +170,19 @@ def test_new_merged(): ...@@ -170,3 +170,19 @@ def test_new_merged():
assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments
assert a1 in merged_ac.subexpressions assert a1 in merged_ac.subexpressions
assert a3 in merged_ac.subexpressions assert a3 in merged_ac.subexpressions
a1 = ps.Assignment(a, 20)
a2 = ps.Assignment(a, 10)
acommon = ps.Assignment(b, a)
# main assignments
a3 = ps.Assignment(f[0, 0](0), b)
a4 = ps.Assignment(d[0, 0](0), b)
ac = ps.AssignmentCollection([a3], subexpressions=[a1, acommon])
ac2 = ps.AssignmentCollection([a4], subexpressions=[a2, acommon])
merged_ac = ac.new_merged(ac2).new_without_subexpressions()
assert ps.Assignment(f[0, 0](0), 20) in merged_ac.main_assignments
assert ps.Assignment(d[0, 0](0), 10) in merged_ac.main_assignments
import pytest
import pystencils as ps
@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
def test_add_augmented_assignment(target):
if target == ps.Target.GPU:
pytest.importorskip("cupy")
domain_size = (5, 5)
dh = ps.create_data_handling(domain_size=domain_size, periodicity=True, default_target=target)
f = dh.add_array("f", values_per_cell=1)
dh.fill(f.name, 0.0)
g = dh.add_array("g", values_per_cell=1)
dh.fill(g.name, 1.0)
up = ps.AddAugmentedAssignment(f.center, g.center)
config = ps.CreateKernelConfig(target=dh.default_target)
ast = ps.create_kernel(up, config=config)
kernel = ast.compile()
for i in range(10):
dh.run_kernel(kernel)
if target == ps.Target.GPU:
dh.all_to_cpu()
result = dh.gather_array(f.name)
for x in range(domain_size[0]):
for y in range(domain_size[1]):
assert result[x, y] == 10