diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 37bc43f979110c4393eabc636793215d713bc1f7..c49077bb27a2af158624e7222d999973487f0951 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -101,6 +101,41 @@ minimal-sympy-master:
   tags:
     - docker
 
+
+pycodegen-integration:
+  image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full
+  stage: test
+  when: manual
+  script:
+    - git clone https://gitlab-ci-token:${CI_JOB_TOKEN}@i10git.cs.fau.de/pycodegen/pycodegen.git
+    - cd pycodegen
+    - git submodule sync --recursive
+    - git submodule update --init --recursive
+    - git submodule foreach git fetch origin # compare the latest master version!
+    - git submodule foreach git reset --hard origin/master
+    - cd pystencils
+    - git remote add test $CI_REPOSITORY_URL
+    - git fetch test
+    - git reset --hard $CI_COMMIT_SHA
+    - cd ..
+    - export PYTHONPATH=`pwd`/pystencils:`pwd`/lbmpy:`pwd`/pygrandchem:`pwd`/pystencils_walberla:`pwd`/lbmpy_walberla
+    - ./install_walberla.sh
+    - export NUM_CORES=$(nproc --all)
+    - mkdir -p ~/.config/matplotlib
+    - echo "backend:template" > ~/.config/matplotlib/matplotlibrc
+    - cd pystencils
+    - py.test -v -n $NUM_CORES .
+    - cd ../lbmpy
+    - py.test -v -n $NUM_CORES .
+    - cd ../pygrandchem
+    - py.test -v -n $NUM_CORES .
+    - cd ../walberla/build/
+    - make CodegenJacobiCPU CodegenJacobiGPU MicroBenchmarkGpuLbm LbCodeGenerationExample
+  tags:
+    - docker
+    - cuda
+    - AVX
+
 
 # -------------------- Linter & Documentation --------------------------------------------------------------------------
 
diff --git a/pystencils/cache.py b/pystencils/cache.py
index ecf9727f0aba74b2b5cccae8d0ed0d781cac7ccf..5df15ae7c7498e6e849c93f2f071435560a2c415 100644
--- a/pystencils/cache.py
+++ b/pystencils/cache.py
@@ -1,10 +1,14 @@
 import os
+from collections.abc import Hashable
+from functools import partial
+from itertools import chain
 
 try:
     from functools import lru_cache as memorycache
 except ImportError:
     from backports.functools_lru_cache import lru_cache as memorycache
 
+
 try:
     from joblib import Memory
     from appdirs import user_cache_dir
@@ -22,6 +26,20 @@ except ImportError:
         return o
 
 
+def _wrapper(wrapped_func, cached_func, *args, **kwargs):
+    if all(isinstance(a, Hashable) for a in chain(args, kwargs.values())):
+        return cached_func(*args, **kwargs)
+    else:
+        return wrapped_func(*args, **kwargs)
+
+
+def memorycache_if_hashable(maxsize=128, typed=False):
+
+    def wrapper(func):
+        return partial(_wrapper, func, memorycache(maxsize, typed)(func))
+
+    return wrapper
+
 # Disable memory cache:
 # disk_cache = lambda o: o
 # disk_cache_no_fallback = lambda o: o
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index fc7f8a4f0e916922297a7314d432627503ba8464..45099302b9615164fe329d21f8f85c631d45b1e3 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -1,11 +1,13 @@
 import ctypes
+from collections import defaultdict
+from functools import partial
 
 import numpy as np
 import sympy as sp
 from sympy.core.cache import cacheit
 from sympy.logic.boolalg import Boolean
 
-from pystencils.cache import memorycache
+from pystencils.cache import memorycache, memorycache_if_hashable
 from pystencils.utils import all_equal
 
 try:
@@ -408,11 +410,22 @@ def collate_types(types):
     return result
 
 
-@memorycache(maxsize=2048)
-def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
+@memorycache_if_hashable(maxsize=2048)
+def get_type_of_expression(expr,
+                           default_float_type='double',
+                           default_int_type='int',
+                           symbol_type_dict=None):
     from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any + if not symbol_type_dict: + symbol_type_dict = defaultdict(lambda: create_type('double')) + + get_type = partial(get_type_of_expression, + default_float_type=default_float_type, + default_int_type=default_int_type, + symbol_type_dict=symbol_type_dict) + expr = sp.sympify(expr) if isinstance(expr, sp.Integer): return create_type(default_int_type) @@ -423,14 +436,17 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type=' elif isinstance(expr, TypedSymbol): return expr.dtype elif isinstance(expr, sp.Symbol): - raise ValueError("All symbols inside this expression have to be typed! ", str(expr)) + if symbol_type_dict: + return symbol_type_dict[expr.name] + else: + raise ValueError("All symbols inside this expression have to be typed! ", str(expr)) elif isinstance(expr, cast_func): return expr.args[1] - elif isinstance(expr, vec_any) or isinstance(expr, vec_all): + elif isinstance(expr, (vec_any, vec_all)): return create_type("bool") elif hasattr(expr, 'func') and expr.func == sp.Piecewise: - collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args)) - collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args)) + collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args)) + collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args)) if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType: collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width) return collated_result_type @@ -440,16 +456,16 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type=' elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): # if any arg is of vector type return a vector boolean, else return a normal scalar boolean result = create_type("bool") - vec_args = [get_type_of_expression(a) for a in expr.args if isinstance(get_type_of_expression(a), VectorType)] + vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] if vec_args: result = VectorType(result, width=vec_args[0].width) return result - elif isinstance(expr, sp.Pow): - return get_type_of_expression(expr.args[0]) + elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)): + return get_type(expr.args[0]) elif isinstance(expr, sp.Expr): expr: sp.Expr if expr.args: - types = tuple(get_type_of_expression(a) for a in expr.args) + types = tuple(get_type(a) for a in expr.args) return collate_types(types) else: if expr.is_integer: diff --git a/pystencils/test_type_interference.py b/pystencils/test_type_interference.py new file mode 100644 index 0000000000000000000000000000000000000000..0daa6f9d2a36948184d537809457b4aaf8001d29 --- /dev/null +++ b/pystencils/test_type_interference.py @@ -0,0 +1,26 @@ +from sympy.abc import a, b, c, d, e, f + +import pystencils +from pystencils.data_types import cast_func, create_type + + +def test_type_interference(): + x = pystencils.fields('x: float32[3d]') + assignments = pystencils.AssignmentCollection({ + a: cast_func(10, create_type('float64')), + b: cast_func(10, create_type('uint16')), + e: 11, + c: b, + f: c + b, + d: c + b + x.center + e, + x.center: c + b + x.center + }) + + ast = pystencils.create_kernel(assignments) + + code = str(pystencils.show_code(ast)) + print(code) + assert 'double a' in code + assert 'uint16_t b' in code + assert 'uint16_t f' in code + assert 'int64_t e' 
in code diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 39a9abddbad8dab1ae6c4e3361a43ff1885c4ee8..554b9cf9b2257e6b0ef97b24239ea210baa8b63b 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -147,7 +147,10 @@ def get_common_shape(field_set): return shape -def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None): +def make_loop_over_domain(body, + iteration_slice=None, + ghost_layers=None, + loop_order=None): """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST. Args: @@ -189,17 +192,21 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or if iteration_slice is None: begin = ghost_layers[loop_coordinate][0] end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1] - new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1) + new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, + begin, end, 1) current_body = ast.Block([new_loop]) else: slice_component = iteration_slice[loop_coordinate] if type(slice_component) is slice: sc = slice_component - new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step) + new_loop = ast.LoopOverCoordinate(current_body, + loop_coordinate, sc.start, + sc.stop, sc.step) current_body = ast.Block([new_loop]) else: - assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate), - sp.sympify(slice_component)) + assignment = ast.SympyAssignment( + ast.LoopOverCoordinate.get_loop_counter_symbol( + loop_coordinate), sp.sympify(slice_component)) current_body.insert_front(assignment) return current_body, ghost_layers @@ -238,9 +245,11 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): offset += field.strides[coordinate_id] * coordinate_value if coordinate_id < field.spatial_dimensions: - offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id] + offset += field.strides[coordinate_id] * field_access.offsets[ + coordinate_id] if type(field_access.offsets[coordinate_id]) is int: - name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id]) + name += "_%d%d" % (coordinate_id, + field_access.offsets[coordinate_id]) else: list_to_hash.append(field_access.offsets[coordinate_id]) else: @@ -257,7 +266,8 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): return new_ptr, offset -def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions): +def parse_base_pointer_info(base_pointer_specification, loop_order, + spatial_dimensions, index_dimensions): """ Creates base pointer specification for :func:`resolve_field_accesses` function. 
@@ -295,11 +305,13 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime def add_new_element(elem): if elem >= spatial_dimensions + index_dimensions: - raise ValueError("Coordinate %d does not exist" % (elem,)) + raise ValueError("Coordinate %d does not exist" % (elem, )) new_group.append(elem) if elem in specified_coordinates: - raise ValueError("Coordinate %d specified two times" % (elem,)) + raise ValueError("Coordinate %d specified two times" % + (elem, )) specified_coordinates.add(elem) + for element in spec_group: if type(element) is int: add_new_element(element) @@ -320,7 +332,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime index = int(element[len("index"):]) add_new_element(spatial_dimensions + index) else: - raise ValueError("Unknown specification %s" % (element,)) + raise ValueError("Unknown specification %s" % (element, )) result.append(new_group) @@ -345,30 +357,42 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): base buffer index - required by 'resolve_buffer_accesses' function """ if loop_counters is None or loop_iterations is None: - loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)] + loops = [ + l for l in filtered_tree_iteration( + ast_node, ast.LoopOverCoordinate, ast.SympyAssignment) + ] loops.reverse() - parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True)) + parents_of_innermost_loop = list( + parents_of_type(loops[0], + ast.LoopOverCoordinate, + include_current=True)) assert len(loops) == len(parents_of_innermost_loop) - assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop)) + assert all(l1 is l2 + for l1, l2 in zip(loops, parents_of_innermost_loop)) loop_iterations = [(l.stop - l.start) / l.step for l in loops] loop_counters = [l.loop_counter_symbol for l in loops] field_accesses = ast_node.atoms(AbstractField.AbstractAccess) - buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} + buffer_accesses = { + fa + for fa in field_accesses if FieldType.is_buffer(fa.field) + } loop_counters = [v * len(buffer_accesses) for v in loop_counters] base_buffer_index = loop_counters[0] stride = 1 for idx, var in enumerate(loop_counters[1:]): cur_stride = loop_iterations[idx] - stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride + stride *= int(cur_stride) if isinstance(cur_stride, + float) else cur_stride base_buffer_index += var * stride return base_buffer_index -def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()): - +def resolve_buffer_accesses(ast_node, + base_buffer_index, + read_only_field_names=set()): def visit_sympy_expr(expr, enclosing_block, sympy_assignment): if isinstance(expr, AbstractField.AbstractAccess): field_access = expr @@ -378,17 +402,24 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s return expr buffer = field_access.field - field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names) + field_ptr = FieldPointerSymbol( + buffer.name, + buffer.dtype, + const=buffer.name in read_only_field_names) buffer_index = base_buffer_index if len(field_access.index) > 1: - raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!') + raise RuntimeError( + 'Only indexing dimensions up to 1 are currently supported in buffers!' 
+ ) if len(field_access.index) > 0: cell_index = field_access.index[0] buffer_index += cell_index - result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets, + result = ast.ResolvedFieldAccess(field_ptr, buffer_index, + field_access.field, + field_access.offsets, field_access.index) return visit_sympy_expr(result, enclosing_block, sympy_assignment) @@ -396,16 +427,23 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s if isinstance(expr, ast.ResolvedFieldAccess): return expr - new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args] - kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} + new_args = [ + visit_sympy_expr(e, enclosing_block, sympy_assignment) + for e in expr.args + ] + kwargs = { + 'evaluate': False + } if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} return expr.func(*new_args, **kwargs) if new_args else expr def visit_node(sub_ast): if isinstance(sub_ast, ast.SympyAssignment): enclosing_block = sub_ast.parent assert type(enclosing_block) is ast.Block - sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast) - sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast) + sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, + sub_ast) + sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, + sub_ast) else: for i, a in enumerate(sub_ast.args): visit_node(a) @@ -413,7 +451,8 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s return visit_node(ast_node) -def resolve_field_accesses(ast_node, read_only_field_names=set(), +def resolve_field_accesses(ast_node, + read_only_field_names=set(), field_to_base_pointer_info=MappingProxyType({}), field_to_fixed_coordinates=MappingProxyType({})): """ @@ -430,8 +469,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), Returns transformed AST """ - field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0])) - field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) + field_to_base_pointer_info = OrderedDict( + sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0])) + field_to_fixed_coordinates = OrderedDict( + sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) def visit_sympy_expr(expr, enclosing_block, sympy_assignment): if isinstance(expr, AbstractField.AbstractAccess): @@ -439,20 +480,29 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), field = field_access.field if field_access.indirect_addressing_fields: - new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment) - for off in field_access.offsets) - new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment) - if isinstance(ind, sp.Basic) else ind - for ind in field_access.index) + new_offsets = tuple( + visit_sympy_expr(off, enclosing_block, sympy_assignment) + for off in field_access.offsets) + new_indices = tuple( + visit_sympy_expr(ind, enclosing_block, sympy_assignment + ) if isinstance(ind, sp.Basic) else ind + for ind in field_access.index) field_access = Field.Access(field_access.field, new_offsets, - new_indices, field_access.is_absolute_access) + new_indices, + field_access.is_absolute_access) if field.name in field_to_base_pointer_info: base_pointer_info = field_to_base_pointer_info[field.name] else: - base_pointer_info = [list(range(field.index_dimensions + 
field.spatial_dimensions))] + base_pointer_info = [ + list( + range(field.index_dimensions + field.spatial_dimensions)) + ] - field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names) + field_ptr = FieldPointerSymbol( + field.name, + field.dtype, + const=field.name in read_only_field_names) def create_coordinate_dict(group_param): coordinates = {} @@ -460,12 +510,15 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), if e < field.spatial_dimensions: if field.name in field_to_fixed_coordinates: if not field_access.is_absolute_access: - coordinates[e] = field_to_fixed_coordinates[field.name][e] + coordinates[e] = field_to_fixed_coordinates[ + field.name][e] else: coordinates[e] = 0 else: if not field_access.is_absolute_access: - coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e) + coordinates[ + e] = ast.LoopOverCoordinate.get_loop_counter_symbol( + e) else: coordinates[e] = 0 coordinates[e] *= field.dtype.item_size @@ -474,9 +527,11 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), assert field.index_dimensions == 1 accessed_field_name = field_access.index[0] assert isinstance(accessed_field_name, str) - coordinates[e] = field.dtype.get_element_offset(accessed_field_name) + coordinates[e] = field.dtype.get_element_offset( + accessed_field_name) else: - coordinates[e] = field_access.index[e - field.spatial_dimensions] + coordinates[e] = field_access.index[ + e - field.spatial_dimensions] return coordinates @@ -484,19 +539,27 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), for group in reversed(base_pointer_info[1:]): coord_dict = create_coordinate_dict(group) - new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer) + new_ptr, offset = create_intermediate_base_pointer( + field_access, coord_dict, last_pointer) if new_ptr not in enclosing_block.symbols_defined: - new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False) - enclosing_block.insert_before(new_assignment, sympy_assignment) + new_assignment = ast.SympyAssignment(new_ptr, + last_pointer + offset, + is_const=False) + enclosing_block.insert_before(new_assignment, + sympy_assignment) last_pointer = new_ptr coord_dict = create_coordinate_dict(base_pointer_info[0]) - _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer) - result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field, - field_access.offsets, field_access.index) + _, offset = create_intermediate_base_pointer( + field_access, coord_dict, last_pointer) + result = ast.ResolvedFieldAccess(last_pointer, offset, + field_access.field, + field_access.offsets, + field_access.index) if isinstance(get_base_type(field_access.field.dtype), StructType): - new_type = field_access.field.dtype.get_element_type(field_access.index[0]) + new_type = field_access.field.dtype.get_element_type( + field_access.index[0]) result = reinterpret_cast_func(result, new_type) return visit_sympy_expr(result, enclosing_block, sympy_assignment) @@ -504,20 +567,28 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), if isinstance(expr, ast.ResolvedFieldAccess): return expr - new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args] - kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} + new_args = [ + visit_sympy_expr(e, enclosing_block, sympy_assignment) + for e in expr.args + ] + kwargs = { + 'evaluate': False + } 
if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} return expr.func(*new_args, **kwargs) if new_args else expr def visit_node(sub_ast): if isinstance(sub_ast, ast.SympyAssignment): enclosing_block = sub_ast.parent assert type(enclosing_block) is ast.Block - sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast) - sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast) + sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, + sub_ast) + sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, + sub_ast) elif isinstance(sub_ast, ast.Conditional): enclosing_block = sub_ast.parent assert type(enclosing_block) is ast.Block - sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast) + sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, + enclosing_block, sub_ast) visit_node(sub_ast.true_block) if sub_ast.false_block: visit_node(sub_ast.false_block) @@ -561,11 +632,14 @@ def move_constants_before_loop(ast_node): element = element.parent return last_block, last_block_child - def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True): + def check_if_assignment_already_in_block(assignment, + target_block, + rhs_or_lhs=True): for arg in target_block.args: if type(arg) is not ast.SympyAssignment: continue - if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs): + if (rhs_or_lhs and arg.rhs == assignment.rhs) or ( + not rhs_or_lhs and arg.lhs == assignment.lhs): return arg return None @@ -588,21 +662,24 @@ def move_constants_before_loop(ast_node): # Before traversing the next child, all symbols are substituted first. child.subs(substitute_variables) - if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments + if not isinstance( + child, ast.SympyAssignment): # only move SympyAssignments block.append(child) continue target, child_to_insert_before = find_block_to_move_to(child) - if target == block: # movement not possible + if target == block: # movement not possible target.append(child) else: if isinstance(child, ast.SympyAssignment): - exists_already = check_if_assignment_already_in_block(child, target, False) + exists_already = check_if_assignment_already_in_block( + child, target, False) else: exists_already = False if not exists_already: - rhs_identical = check_if_assignment_already_in_block(child, target, True) + rhs_identical = check_if_assignment_already_in_block( + child, target, True) if rhs_identical: # there is already an assignment out there with the same rhs # -> replace all lhs symbols in this block with the lhs of the outer assignment @@ -617,7 +694,9 @@ def move_constants_before_loop(ast_node): # -> symbol has to be renamed assert isinstance(child.lhs, TypedSymbol) new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype) - target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before) + target.insert_before( + ast.SympyAssignment(new_symbol, child.rhs), + child_to_insert_before) substitute_variables[child.lhs] = new_symbol @@ -633,7 +712,9 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): """ all_loops = ast_node.atoms(ast.LoopOverCoordinate) inner_loop = [l for l in all_loops if l.is_innermost_loop] - assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?" + assert len( + inner_loop + ) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?" 
inner_loop = inner_loop[0] assert type(inner_loop.body) is ast.Block outer_loop = [l for l in all_loops if l.is_outermost_loop] @@ -664,28 +745,38 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): if not isinstance(symbol, AbstractField.AbstractAccess): assert type(symbol) is TypedSymbol new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype)) - symbols_with_temporary_array[symbol] = sp.IndexedBase(new_ts, - shape=(1,))[inner_loop.loop_counter_symbol] + symbols_with_temporary_array[symbol] = sp.IndexedBase( + new_ts, shape=(1, ))[inner_loop.loop_counter_symbol] assignment_group = [] for assignment in inner_loop.body.args: if assignment.lhs in symbols_resolved: - new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items()) - if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group: + new_rhs = assignment.rhs.subs( + symbols_with_temporary_array.items()) + if not isinstance(assignment.lhs, AbstractField.AbstractAccess + ) and assignment.lhs in symbol_group: assert type(assignment.lhs) is TypedSymbol - new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) - new_lhs = sp.IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] + new_ts = TypedSymbol(assignment.lhs.name, + PointerType(assignment.lhs.dtype)) + new_lhs = sp.IndexedBase( + new_ts, shape=(1, ))[inner_loop.loop_counter_symbol] else: new_lhs = assignment.lhs assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs)) assignment_groups.append(assignment_group) - new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups] + new_loops = [ + inner_loop.new_loop_with_different_body(ast.Block(group)) + for group in assignment_groups + ] inner_loop.parent.replace(inner_loop, ast.Block(new_loops)) for tmp_array in symbols_with_temporary_array: - tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype)) - alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start) + tmp_array_pointer = TypedSymbol(tmp_array.name, + PointerType(tmp_array.dtype)) + alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, + inner_loop.stop, + inner_loop.start) free_node = ast.TemporaryMemoryFree(alloc_node) outer_loop.parent.insert_front(alloc_node) outer_loop.parent.append(free_node) @@ -715,15 +806,17 @@ def cut_loop(loop_node, cutting_points): elif new_end - new_start == 0: pass else: - new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over, - new_start, new_end, loop_node.step) + new_loop = ast.LoopOverCoordinate( + deepcopy(loop_node.body), loop_node.coordinate_to_loop_over, + new_start, new_end, loop_node.step) new_loops.append(new_loop) new_start = new_end loop_node.parent.replace(loop_node, new_loops) return new_loops -def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None: +def simplify_conditionals(node: ast.Node, + loop_counter_simplification: bool = False) -> None: """Removes conditionals that are always true/false. 
Args: @@ -739,14 +832,18 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa if conditional.condition_expr == sp.true: conditional.parent.replace(conditional, [conditional.true_block]) elif conditional.condition_expr == sp.false: - conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else []) + conditional.parent.replace( + conditional, + [conditional.false_block] if conditional.false_block else []) elif loop_counter_simplification: try: # noinspection PyUnresolvedReferences from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional simplify_loop_counter_dependent_conditional(conditional) except ImportError: - warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed") + warnings.warn( + "Integer simplifications in conditionals skipped, because ISLpy package not installed" + ) def cleanup_blocks(node: ast.Node) -> None: @@ -808,18 +905,28 @@ class KernelConstraintsCheck: elif type_constants and isinstance(rhs, sp.Number): return cast_func(rhs, create_type(self._type_for_symbol['_constant'])) elif isinstance(rhs, sp.Mul): - new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args] + new_args = [ + self.process_expression(arg, type_constants) + if arg not in (-1, 1) else arg for arg in rhs.args + ] return rhs.func(*new_args) if new_args else rhs elif isinstance(rhs, sp.Indexed): return rhs elif isinstance(rhs, cast_func): - return cast_func(self.process_expression(rhs.args[0], type_constants=False), rhs.dtype) + return cast_func( + self.process_expression(rhs.args[0], type_constants=False), + rhs.dtype) else: if isinstance(rhs, sp.Pow): # don't process exponents -> they should remain integers - return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1]) + return sp.Pow( + self.process_expression(rhs.args[0], type_constants), + rhs.args[1]) else: - new_args = [self.process_expression(arg, type_constants) for arg in rhs.args] + new_args = [ + self.process_expression(arg, type_constants) + for arg in rhs.args + ] return rhs.func(*new_args) if new_args else rhs @property @@ -829,7 +936,7 @@ class KernelConstraintsCheck: def _process_lhs(self, lhs): assert isinstance(lhs, sp.Symbol) self._update_accesses_lhs(lhs) - if not isinstance(lhs, AbstractField.AbstractAccess) and not isinstance(lhs, TypedSymbol): + if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) else: return lhs @@ -839,22 +946,32 @@ class KernelConstraintsCheck: fai = self.FieldAndIndex(lhs.field, lhs.index) self._field_writes[fai].add(lhs.offsets) if len(self._field_writes[fai]) > 1: - raise ValueError("Field {} is written at two different locations".format(lhs.field.name)) + raise ValueError( + "Field {} is written at two different locations".format( + lhs.field.name)) elif isinstance(lhs, sp.Symbol): if self.scopes.is_defined_locally(lhs): - raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name)) + raise ValueError( + "Assignments not in SSA form, multiple assignments to {}". 
+ format(lhs.name)) if lhs in self.scopes.free_parameters: - raise ValueError("Symbol {} is written, after it has been read".format(lhs.name)) + raise ValueError( + "Symbol {} is written, after it has been read".format( + lhs.name)) self.scopes.define_symbol(lhs) def _update_accesses_rhs(self, rhs): - if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition: - writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)] + if isinstance(rhs, AbstractField.AbstractAccess + ) and self.check_independence_condition: + writes = self._field_writes[self.FieldAndIndex( + rhs.field, rhs.index)] for write_offset in writes: assert len(writes) == 1 if write_offset != rhs.offsets: - raise ValueError("Violation of loop independence condition. Field " - "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset)) + raise ValueError( + "Violation of loop independence condition. Field " + "{} is read at {} and written at {}".format( + rhs.field, rhs.offsets, write_offset)) self.fields_read.add(rhs.field) elif isinstance(rhs, sp.Symbol): self.scopes.access_symbol(rhs) @@ -875,21 +992,29 @@ def add_types(eqs, type_for_symbol, check_independence_condition): ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields, list of equations where symbols have been replaced by typed symbols """ - if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'): + if isinstance(type_for_symbol, + str) or not hasattr(type_for_symbol, '__getitem__'): type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol) - check = KernelConstraintsCheck(type_for_symbol, check_independence_condition) + # assignments = ast.Block(eqs).atoms(ast.Assignment) + # type_for_symbol.update( {a.lhs: get_type_of_expression(a.rhs) for a in assignments}) + # print(type_for_symbol) + check = KernelConstraintsCheck(type_for_symbol, + check_independence_condition) def visit(obj): - if isinstance(obj, list) or isinstance(obj, tuple): + if isinstance(obj, (list, tuple)): return [visit(e) for e in obj] - if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment): + if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): return check.process_assignment(obj) elif isinstance(obj, ast.Conditional): check.scopes.push() - false_block = None if obj.false_block is None else visit(obj.false_block) - result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False), - true_block=visit(obj.true_block), false_block=false_block) + false_block = None if obj.false_block is None else visit( + obj.false_block) + result = ast.Conditional(check.process_expression( + obj.condition_expr, type_constants=False), + true_block=visit(obj.true_block), + false_block=false_block) check.scopes.pop() return result elif isinstance(obj, ast.Block): @@ -897,7 +1022,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition): result = ast.Block([visit(e) for e in obj.args]) check.scopes.pop() return result - elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): + elif isinstance( + obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): return obj else: raise ValueError("Invalid object in kernel " + str(type(obj))) @@ -956,7 +1082,8 @@ def insert_casts(node): for arg in node.args: args.append(insert_casts(arg)) # TODO indexed, LoopOverCoordinate - if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge): + if 
node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, + sp.Lt, sp.Le, sp.Gt, sp.Ge): # TODO optimize pow, don't cast integer on double types = [get_type_of_expression(arg) for arg in args] assert len(types) > 0 @@ -974,7 +1101,8 @@ def insert_casts(node): if target.func is PointerType: return node.func(*args) # TODO fix, not complete else: - return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target)) + return node.func( + lhs, *cast([(rhs, get_type_of_expression(rhs))], target)) elif node.func is ast.ResolvedFieldAccess: return node elif node.func is ast.Block: @@ -991,19 +1119,30 @@ def insert_casts(node): target = collate_types(types) zipped = list(zip(expressions, types)) casted_expressions = cast(zipped, target) - args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)] + args = [ + arg.func(*[expr, arg.cond]) + for (arg, expr) in zip(args, casted_expressions) + ] return node.func(*args) -def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None: +def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction + ) -> None: """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element""" - all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop] - assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop" + all_inner_loops = [ + l for l in function_node.atoms(ast.LoopOverCoordinate) + if l.is_innermost_loop + ] + assert len( + all_inner_loops + ) == 1, "Transformation works only on kernels with exactly one inner loop" inner_loop = all_inner_loops.pop() - for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True): + for loop in parents_of_type(inner_loop, + ast.LoopOverCoordinate, + include_current=True): cut_loop(loop, [loop.stop - 1]) simplify_conditionals(function_node.body, loop_counter_simplification=True) @@ -1016,7 +1155,7 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) - # --------------------------------------- Helper Functions ------------------------------------------------------------- -def typing_from_sympy_inspection(eqs, default_type="double"): +def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'): """ Creates a default symbol name to type mapping. 
If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double') @@ -1032,17 +1171,25 @@ def typing_from_sympy_inspection(eqs, default_type="double"): if isinstance(eq, ast.Conditional): result.update(typing_from_sympy_inspection(eq.true_block.args)) if eq.false_block: - result.update(typing_from_sympy_inspection(eq.false_block.args)) + result.update(typing_from_sympy_inspection( + eq.false_block.args)) elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment): continue else: from pystencils.cpu.vectorization import vec_all, vec_any - if isinstance(eq.rhs, vec_all) or isinstance(eq.rhs, vec_any): + if isinstance(eq.rhs, (vec_all, vec_any)): result[eq.lhs.name] = "bool" # problematic case here is when rhs is a symbol: then it is impossible to decide here without # further information what type the left hand side is - default fallback is the dict value then if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol): result[eq.lhs.name] = "bool" + try: + result[eq.lhs.name] = get_type_of_expression(eq.rhs, + default_float_type=default_type, + default_int_type=default_int_type, + symbol_type_dict=result) + except Exception: + pass # gracefully fail in case get_type_of_expression cannot determine type return result @@ -1084,13 +1231,17 @@ def get_optimal_loop_ordering(fields): ref_field = next(iter(fields)) for field in fields: if field.spatial_dimensions != ref_field.spatial_dimensions: - raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: " - + str({f.name: f.spatial_shape for f in fields})) + raise ValueError( + "All fields have to have the same number of spatial dimensions. Spatial field dimensions: " + + str({f.name: f.spatial_shape + for f in fields})) layouts = set([field.layout for field in fields]) if len(layouts) > 1: - raise ValueError("Due to different layout of the fields no optimal loop ordering exists " - + str({f.name: f.layout for f in fields})) + raise ValueError( + "Due to different layout of the fields no optimal loop ordering exists " + + str({f.name: f.layout + for f in fields})) layout = list(layouts)[0] return list(layout) @@ -1135,7 +1286,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: """ inner_loops = [] inner_loop_counters = set() - for loop in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment): + for loop in filtered_tree_iteration(ast_node, + ast.LoopOverCoordinate, + stop_type=ast.SympyAssignment): if loop.is_innermost_loop: inner_loops.append(loop) inner_loop_counters.add(loop.coordinate_to_loop_over) @@ -1146,8 +1299,10 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: inner_loop_counter = inner_loop_counters.pop() parameters = ast_node.get_parameters() - stride_params = [p.symbol for p in parameters - if p.is_field_stride and p.symbol.coordinate == inner_loop_counter] + stride_params = [ + p.symbol for p in parameters + if p.is_field_stride and p.symbol.coordinate == inner_loop_counter + ] subs_dict = {stride_param: 1 for stride_param in stride_params} if subs_dict: ast_node.subs(subs_dict) @@ -1163,7 +1318,10 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: Returns: number of dimensions blocked """ - loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)] + loops = [ + l for l in filtered_tree_iteration( + ast_node, ast.LoopOverCoordinate, 
stop_type=ast.SympyAssignment) + ] body = ast_node.body coordinates = [] @@ -1183,8 +1341,12 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: outer_loop = None for coord in reversed(coordinates): body = ast.Block([outer_loop]) if outer_loop else body - outer_loop = ast.LoopOverCoordinate(body, coord, loop_starts[coord], loop_stops[coord], - step=block_size[coord], is_block_loop=True) + outer_loop = ast.LoopOverCoordinate(body, + coord, + loop_starts[coord], + loop_stops[coord], + step=block_size[coord], + is_block_loop=True) ast_node.body = ast.Block([outer_loop]) @@ -1193,7 +1355,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: coord = inner_loop.coordinate_to_loop_over block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord) loop_range = inner_loop.stop - inner_loop.start - if sp.sympify(loop_range).is_number and loop_range % block_size[coord] == 0: + if sp.sympify( + loop_range).is_number and loop_range % block_size[coord] == 0: stop = block_ctr + block_size[coord] else: stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
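
Usage sketch (illustrative only, not part of the patch): assuming the module layout shown in the diff above, the snippet below exercises both additions at once — the symbol_type_dict fallback of get_type_of_expression and the cache bypass that memorycache_if_hashable performs when an argument is unhashable.

from collections import defaultdict

import sympy as sp

from pystencils.data_types import create_type, get_type_of_expression

x, y = sp.symbols('x y')

# Without symbol_type_dict this call would raise ValueError, because x and y
# are untyped.  With the fallback dict, every untyped symbol resolves to
# float32, so the collated type of x + y is float32.  The defaultdict is not
# hashable, so memorycache_if_hashable skips the lru_cache for this call
# instead of raising TypeError, as the previous @memorycache decorator would.
fallback_types = defaultdict(lambda: create_type('float32'))
print(get_type_of_expression(x + y, symbol_type_dict=fallback_types))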