Commit c83faa47 authored by Markus Holzer's avatar Markus Holzer
Browse files

adapted to new flake8 version

parent 5eedf7cc
...@@ -117,7 +117,7 @@ class Conditional(Node): ...@@ -117,7 +117,7 @@ class Conditional(Node):
if self.true_block: if self.true_block:
repr += '\n\t{}) '.format(self.true_block) repr += '\n\t{}) '.format(self.true_block)
if self.false_block: if self.false_block:
repr = 'else: '.format(self.false_block) repr = 'else: '
repr += '\n\t{} '.format(self.false_block) repr += '\n\t{} '.format(self.false_block)
return repr return repr
...@@ -421,7 +421,7 @@ class LoopOverCoordinate(Node): ...@@ -421,7 +421,7 @@ class LoopOverCoordinate(Node):
def new_loop_with_different_body(self, new_body): def new_loop_with_different_body(self, new_body):
result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop, result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
self.step, self.is_block_loop) self.step, self.is_block_loop)
result.prefix_lines = [l for l in self.prefix_lines] result.prefix_lines = [lo for lo in self.prefix_lines]
return result return result
def subs(self, subs_dict): def subs(self, subs_dict):
......
...@@ -7,7 +7,7 @@ from pystencils.interpolation_astnodes import DiffInterpolatorAccess, Interpolat ...@@ -7,7 +7,7 @@ from pystencils.interpolation_astnodes import DiffInterpolatorAccess, Interpolat
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines() lines = f.readlines()
CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l} CUDA_KNOWN_FUNCTIONS = {lo.strip(): lo.strip() for lo in lines if lo}
def generate_cuda(astnode: Node, signature_only: bool = False) -> str: def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
......
...@@ -8,7 +8,7 @@ from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqr ...@@ -8,7 +8,7 @@ from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqr
with open(join(dirname(__file__), 'opencl1.1_known_functions.txt')) as f: with open(join(dirname(__file__), 'opencl1.1_known_functions.txt')) as f:
lines = f.readlines() lines = f.readlines()
OPENCL_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l} OPENCL_KNOWN_FUNCTIONS = {lo.strip(): lo.strip() for lo in lines if lo}
def generate_opencl(astnode: Node, signature_only: bool = False) -> str: def generate_opencl(astnode: Node, signature_only: bool = False) -> str:
......
...@@ -177,8 +177,8 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass ...@@ -177,8 +177,8 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass
wrapper_block = ast.PragmaBlock('#pragma omp parallel' + threads_clause, body.take_child_nodes()) wrapper_block = ast.PragmaBlock('#pragma omp parallel' + threads_clause, body.take_child_nodes())
body.append(wrapper_block) body.append(wrapper_block)
outer_loops = [l for l in filtered_tree_iteration(body, LoopOverCoordinate, stop_type=SympyAssignment) outer_loops = [lo for lo in filtered_tree_iteration(body, LoopOverCoordinate, stop_type=SympyAssignment)
if l.is_outermost_loop] if lo.is_outermost_loop]
assert outer_loops, "No outer loop found" assert outer_loops, "No outer loop found"
if assume_single_outer_loop and len(outer_loops) > 1: if assume_single_outer_loop and len(outer_loops) > 1:
raise ValueError("More than one outer loop found, only one outer loop expected") raise ValueError("More than one outer loop found, only one outer loop expected")
...@@ -194,7 +194,7 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass ...@@ -194,7 +194,7 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass
num_threads = multiprocessing.cpu_count() num_threads = multiprocessing.cpu_count()
if loop_range is not None and loop_range < num_threads and not collapse: if loop_range is not None and loop_range < num_threads and not collapse:
contained_loops = [l for l in loop_to_parallelize.body.args if isinstance(l, LoopOverCoordinate)] contained_loops = [lo for lo in loop_to_parallelize.body.args if isinstance(lo, LoopOverCoordinate)]
if len(contained_loops) == 1: if len(contained_loops) == 1:
contained_loop = contained_loops[0] contained_loop = contained_loops[0]
try: try:
......
...@@ -83,7 +83,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -83,7 +83,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
"""Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment) all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
inner_loops = [n for n in all_loops if n.is_innermost_loop] inner_loops = [n for n in all_loops if n.is_innermost_loop]
zero_loop_counters = {l.loop_counter_symbol: 0 for l in all_loops} zero_loop_counters = {lo.loop_counter_symbol: 0 for lo in all_loops}
for loop_node in inner_loops: for loop_node in inner_loops:
loop_range = loop_node.stop - loop_node.start loop_range = loop_node.stop - loop_node.start
...@@ -95,7 +95,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -95,7 +95,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
loop_node.stop = new_stop loop_node.stop = new_stop
else: else:
cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)] loop_nodes = [lo for lo in cut_loop(loop_node,
[cutting_point]).args if isinstance(lo, ast.LoopOverCoordinate)]
assert len(loop_nodes) in (0, 1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width assert len(loop_nodes) in (0, 1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width
if len(loop_nodes) == 0: if len(loop_nodes) == 0:
continue continue
......
...@@ -206,8 +206,8 @@ class BlockIndexing(AbstractIndexing): ...@@ -206,8 +206,8 @@ class BlockIndexing(AbstractIndexing):
sorted_block_size = sorted_block_size[:-1] sorted_block_size = sorted_block_size[:-1]
result = list(block_size) result = list(block_size)
for l, bs in zip(reversed(layout), sorted_block_size): for lo, bs in zip(reversed(layout), sorted_block_size):
result[l] = bs result[lo] = bs
return tuple(result[:len(layout)]) return tuple(result[:len(layout)])
def max_threads_per_block(self): def max_threads_per_block(self):
......
...@@ -46,8 +46,8 @@ class PyStencilsKerncraftKernel(KernelCode): ...@@ -46,8 +46,8 @@ class PyStencilsKerncraftKernel(KernelCode):
self._keep_intermediates = debug_print self._keep_intermediates = debug_print
# Loops # Loops
inner_loops = [l for l in filtered_tree_iteration(ast, LoopOverCoordinate, stop_type=SympyAssignment) inner_loops = [lo for lo in filtered_tree_iteration(ast, LoopOverCoordinate, stop_type=SympyAssignment)
if l.is_innermost_loop] if lo.is_innermost_loop]
if len(inner_loops) == 0: if len(inner_loops) == 0:
raise ValueError("No loop found in pystencils AST") raise ValueError("No loop found in pystencils AST")
else: else:
......
...@@ -179,7 +179,7 @@ def coefficient_list(expr, matrix_form=False): ...@@ -179,7 +179,7 @@ def coefficient_list(expr, matrix_form=False):
for i in range(-max_offsets[0], max_offsets[0] + 1)] for i in range(-max_offsets[0], max_offsets[0] + 1)]
for j in y_range] for j in y_range]
for k in range(-max_offsets[2], max_offsets[2] + 1)] for k in range(-max_offsets[2], max_offsets[2] + 1)]
return [sp.Matrix(l) for l in result] if matrix_form else result return [sp.Matrix(lo) for lo in result] if matrix_form else result
else: else:
raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions") raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions")
......
...@@ -351,14 +351,14 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -351,14 +351,14 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
base buffer index - required by 'resolve_buffer_accesses' function base buffer index - required by 'resolve_buffer_accesses' function
""" """
if loop_counters is None or loop_iterations is None: if loop_counters is None or loop_iterations is None:
loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)] loops = [lo for lo in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)]
loops.reverse() loops.reverse()
parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True)) parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
assert len(loops) == len(parents_of_innermost_loop) assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop)) assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
loop_iterations = [(l.stop - l.start) / l.step for l in loops] loop_iterations = [(lo.stop - lo.start) / lo.step for lo in loops]
loop_counters = [l.loop_counter_symbol for l in loops] loop_counters = [lo.loop_counter_symbol for lo in loops]
field_accesses = ast_node.atoms(AbstractField.AbstractAccess) field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
...@@ -659,11 +659,11 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -659,11 +659,11 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
and which no symbol in a symbol group depends on, are not updated! and which no symbol in a symbol group depends on, are not updated!
""" """
all_loops = ast_node.atoms(ast.LoopOverCoordinate) all_loops = ast_node.atoms(ast.LoopOverCoordinate)
inner_loop = [l for l in all_loops if l.is_innermost_loop] inner_loop = [lo for lo in all_loops if lo.is_innermost_loop]
assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?" assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
inner_loop = inner_loop[0] inner_loop = inner_loop[0]
assert type(inner_loop.body) is ast.Block assert type(inner_loop.body) is ast.Block
outer_loop = [l for l in all_loops if l.is_outermost_loop] outer_loop = [lo for lo in all_loops if lo.is_outermost_loop]
assert len(outer_loop) == 1, "Error in AST, multiple outermost loops." assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
outer_loop = outer_loop[0] outer_loop = outer_loop[0]
...@@ -1077,7 +1077,7 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i ...@@ -1077,7 +1077,7 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i
"""Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
first and last element""" first and last element"""
all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop] all_inner_loops = [lo for lo in function_node.atoms(ast.LoopOverCoordinate) if lo.is_innermost_loop]
assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop" assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
inner_loop = all_inner_loops.pop() inner_loop = all_inner_loops.pop()
...@@ -1265,7 +1265,7 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1265,7 +1265,7 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
number of dimensions blocked number of dimensions blocked
""" """
loops = [ loops = [
l for l in filtered_tree_iteration( lo for lo in filtered_tree_iteration(
ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment) ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
] ]
body = ast_node.body body = ast_node.body
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment