Commit faed3110 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

(Optionally) deactivate unify_shape_symbols

parent dbb1c77e
Pipeline #28020 waiting for manual action with stage
in 5 minutes and 21 seconds
...@@ -247,7 +247,7 @@ class KernelFunction(Node): ...@@ -247,7 +247,7 @@ class KernelFunction(Node):
if hasattr(symbol, 'field_name'): if hasattr(symbol, 'field_name'):
return field_map[symbol.field_name], return field_map[symbol.field_name],
elif hasattr(symbol, 'field_names'): elif hasattr(symbol, 'field_names'):
return tuple(field_map[fn] for fn in symbol.field_names) return tuple(field_map.get(fn, field_map.get('diff' + fn)) for fn in symbol.field_names)
return () return ()
argument_symbols = self._body.undefined_symbols - self.global_variables argument_symbols = self._body.undefined_symbols - self.global_variables
...@@ -297,7 +297,7 @@ class Block(Node): ...@@ -297,7 +297,7 @@ class Block(Node):
except AttributeError: except AttributeError:
pass pass
@property @ property
def args(self): def args(self):
return self._nodes return self._nodes
...@@ -361,7 +361,7 @@ class Block(Node): ...@@ -361,7 +361,7 @@ class Block(Node):
replacements.parent = self replacements.parent = self
self._nodes.insert(idx, replacements) self._nodes.insert(idx, replacements)
@property @ property
def symbols_defined(self): def symbols_defined(self):
result = set() result = set()
for a in self.args: for a in self.args:
...@@ -371,7 +371,7 @@ class Block(Node): ...@@ -371,7 +371,7 @@ class Block(Node):
result.update(a.symbols_defined) result.update(a.symbols_defined)
return result return result
@property @ property
def undefined_symbols(self): def undefined_symbols(self):
result = set() result = set()
defined_symbols = set() defined_symbols = set()
...@@ -443,7 +443,7 @@ class LoopOverCoordinate(Node): ...@@ -443,7 +443,7 @@ class LoopOverCoordinate(Node):
self.step = fast_subs(self.step, subs_dict, skip) self.step = fast_subs(self.step, subs_dict, skip)
return self return self
@property @ property
def args(self): def args(self):
result = [self.body] result = [self.body]
for e in [self.start, self.stop, self.step]: for e in [self.start, self.stop, self.step]:
...@@ -461,11 +461,11 @@ class LoopOverCoordinate(Node): ...@@ -461,11 +461,11 @@ class LoopOverCoordinate(Node):
elif child == self.stop: elif child == self.stop:
self.stop = replacement self.stop = replacement
@property @ property
def symbols_defined(self): def symbols_defined(self):
return {self.loop_counter_symbol} return {self.loop_counter_symbol}
@property @ property
def undefined_symbols(self): def undefined_symbols(self):
result = self.body.undefined_symbols result = self.body.undefined_symbols
for possible_symbol in [self.start, self.stop, self.step]: for possible_symbol in [self.start, self.stop, self.step]:
...@@ -473,22 +473,22 @@ class LoopOverCoordinate(Node): ...@@ -473,22 +473,22 @@ class LoopOverCoordinate(Node):
result.update(possible_symbol.atoms(sp.Symbol)) result.update(possible_symbol.atoms(sp.Symbol))
return result - {self.loop_counter_symbol} return result - {self.loop_counter_symbol}
@staticmethod @ staticmethod
def get_loop_counter_name(coordinate_to_loop_over): def get_loop_counter_name(coordinate_to_loop_over):
return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}" return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
@staticmethod @ staticmethod
def get_block_loop_counter_name(coordinate_to_loop_over): def get_block_loop_counter_name(coordinate_to_loop_over):
return f"{LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}" return f"{LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
@property @ property
def loop_counter_name(self): def loop_counter_name(self):
if self.is_block_loop: if self.is_block_loop:
return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over) return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
else: else:
return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over) return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
@staticmethod @ staticmethod
def is_loop_counter_symbol(symbol): def is_loop_counter_symbol(symbol):
prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
if not symbol.name.startswith(prefix): if not symbol.name.startswith(prefix):
...@@ -498,29 +498,29 @@ class LoopOverCoordinate(Node): ...@@ -498,29 +498,29 @@ class LoopOverCoordinate(Node):
coordinate = int(symbol.name[len(prefix) + 1:]) coordinate = int(symbol.name[len(prefix) + 1:])
return coordinate return coordinate
@staticmethod @ staticmethod
def get_loop_counter_symbol(coordinate_to_loop_over): def get_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True) return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)
@staticmethod @ staticmethod
def get_block_loop_counter_symbol(coordinate_to_loop_over): def get_block_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
'int', 'int',
nonnegative=True) nonnegative=True)
@property @ property
def loop_counter_symbol(self): def loop_counter_symbol(self):
if self.is_block_loop: if self.is_block_loop:
return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over) return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
else: else:
return self.get_loop_counter_symbol(self.coordinate_to_loop_over) return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
@property @ property
def is_outermost_loop(self): def is_outermost_loop(self):
from pystencils.transformations import get_next_parent_of_type from pystencils.transformations import get_next_parent_of_type
return get_next_parent_of_type(self, LoopOverCoordinate) is None return get_next_parent_of_type(self, LoopOverCoordinate) is None
@property @ property
def is_innermost_loop(self): def is_innermost_loop(self):
return len(self.atoms(LoopOverCoordinate)) == 0 return len(self.atoms(LoopOverCoordinate)) == 0
...@@ -552,11 +552,11 @@ class SympyAssignment(Node): ...@@ -552,11 +552,11 @@ class SympyAssignment(Node):
return False return False
return True return True
@property @ property
def lhs(self): def lhs(self):
return self._lhs_symbol return self._lhs_symbol
@lhs.setter @ lhs.setter
def lhs(self, new_value): def lhs(self, new_value):
self._lhs_symbol = new_value self._lhs_symbol = new_value
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
...@@ -572,17 +572,17 @@ class SympyAssignment(Node): ...@@ -572,17 +572,17 @@ class SympyAssignment(Node):
except Exception: except Exception:
pass pass
@property @ property
def args(self): def args(self):
return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)] return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
@property @ property
def symbols_defined(self): def symbols_defined(self):
if not self._is_declaration: if not self._is_declaration:
return set() return set()
return {self._lhs_symbol} return {self._lhs_symbol}
@property @ property
def undefined_symbols(self): def undefined_symbols(self):
result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)} result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
# Add loop counters if there a field accesses # Add loop counters if there a field accesses
...@@ -596,11 +596,11 @@ class SympyAssignment(Node): ...@@ -596,11 +596,11 @@ class SympyAssignment(Node):
result.update(self._lhs_symbol.atoms(sp.Symbol)) result.update(self._lhs_symbol.atoms(sp.Symbol))
return result return result
@property @ property
def is_declaration(self): def is_declaration(self):
return self._is_declaration return self._is_declaration
@property @ property
def is_const(self): def is_const(self):
return self._is_const return self._is_const
...@@ -657,7 +657,7 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -657,7 +657,7 @@ class ResolvedFieldAccess(sp.Indexed):
super_class_contents = super(ResolvedFieldAccess, self)._hashable_content() super_class_contents = super(ResolvedFieldAccess, self)._hashable_content()
return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field)) return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
@property @ property
def typed_symbol(self): def typed_symbol(self):
return self.base.label return self.base.label
...@@ -687,18 +687,18 @@ class TemporaryMemoryAllocation(Node): ...@@ -687,18 +687,18 @@ class TemporaryMemoryAllocation(Node):
self.headers = ['<stdlib.h>'] self.headers = ['<stdlib.h>']
self._align_offset = align_offset self._align_offset = align_offset
@property @ property
def symbols_defined(self): def symbols_defined(self):
return {self.symbol} return {self.symbol}
@property @ property
def undefined_symbols(self): def undefined_symbols(self):
if isinstance(self.size, sp.Basic): if isinstance(self.size, sp.Basic):
return self.size.atoms(sp.Symbol) return self.size.atoms(sp.Symbol)
else: else:
return set() return set()
@property @ property
def args(self): def args(self):
return [self.symbol] return [self.symbol]
...@@ -714,22 +714,22 @@ class TemporaryMemoryFree(Node): ...@@ -714,22 +714,22 @@ class TemporaryMemoryFree(Node):
super(TemporaryMemoryFree, self).__init__(parent=None) super(TemporaryMemoryFree, self).__init__(parent=None)
self.alloc_node = alloc_node self.alloc_node = alloc_node
@property @ property
def symbol(self): def symbol(self):
return self.alloc_node.symbol return self.alloc_node.symbol
def offset(self, byte_alignment): def offset(self, byte_alignment):
return self.alloc_node.offset(byte_alignment) return self.alloc_node.offset(byte_alignment)
@property @ property
def symbols_defined(self): def symbols_defined(self):
return set() return set()
@property @ property
def undefined_symbols(self): def undefined_symbols(self):
return set() return set()
@property @ property
def args(self): def args(self):
return [] return []
...@@ -747,15 +747,15 @@ class SourceCodeComment(Node): ...@@ -747,15 +747,15 @@ class SourceCodeComment(Node):
def __init__(self, text): def __init__(self, text):
self.text = text self.text = text
@property @ property
def args(self): def args(self):
return [] return []
@property @ property
def symbols_defined(self): def symbols_defined(self):
return set() return set()
@property @ property
def undefined_symbols(self): def undefined_symbols(self):
return set() return set()
...@@ -770,15 +770,15 @@ class EmptyLine(Node): ...@@ -770,15 +770,15 @@ class EmptyLine(Node):
def __init__(self): def __init__(self):
pass pass
@property @ property
def args(self): def args(self):
return [] return []
@property @ property
def symbols_defined(self): def symbols_defined(self):
return set() return set()
@property @ property
def undefined_symbols(self): def undefined_symbols(self):
return set() return set()
...@@ -798,15 +798,15 @@ class ConditionalFieldAccess(sp.Function): ...@@ -798,15 +798,15 @@ class ConditionalFieldAccess(sp.Function):
def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0): def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value)) return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value))
@property @ property
def access(self): def access(self):
return self.args[0] return self.args[0]
@property @ property
def outofbounds_condition(self): def outofbounds_condition(self):
return self.args[1] return self.args[1]
@property @ property
def outofbounds_value(self): def outofbounds_value(self):
return self.args[2] return self.args[2]
......
...@@ -15,7 +15,8 @@ def create_cuda_kernel(assignments, ...@@ -15,7 +15,8 @@ def create_cuda_kernel(assignments,
iteration_slice=None, iteration_slice=None,
ghost_layers=None, ghost_layers=None,
skip_independence_check=False, skip_independence_check=False,
use_textures_for_interpolation=True): use_textures_for_interpolation=True,
do_unify_shape_symbols=True):
assert assignments, "Assignments must not be empty!" assert assignments, "Assignments must not be empty!"
fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check) fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
all_fields = fields_read.union(fields_written) all_fields = fields_read.union(fields_written)
...@@ -60,7 +61,8 @@ def create_cuda_kernel(assignments, ...@@ -60,7 +61,8 @@ def create_cuda_kernel(assignments,
block = Block(assignments) block = Block(assignments)
block = indexing.guard(block, common_shape) block = indexing.guard(block, common_shape)
unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers) if do_unify_shape_symbols:
unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)
ast = KernelFunction(block, ast = KernelFunction(block,
'gpu', 'gpu',
......
...@@ -30,6 +30,7 @@ def create_kernel(assignments, ...@@ -30,6 +30,7 @@ def create_kernel(assignments,
use_textures_for_interpolation=True, use_textures_for_interpolation=True,
cpu_prepend_optimizations=[], cpu_prepend_optimizations=[],
use_auto_for_assignments=False, use_auto_for_assignments=False,
do_unify_shape_symbols=True,
opencl_queue=None, opencl_queue=None,
opencl_ctx=None): opencl_ctx=None):
""" """
...@@ -119,7 +120,8 @@ def create_kernel(assignments, ...@@ -119,7 +120,8 @@ def create_kernel(assignments,
indexing_creator=indexing_creator_from_params(gpu_indexing, gpu_indexing_params), indexing_creator=indexing_creator_from_params(gpu_indexing, gpu_indexing_params),
iteration_slice=iteration_slice, ghost_layers=ghost_layers, iteration_slice=iteration_slice, ghost_layers=ghost_layers,
skip_independence_check=skip_independence_check, skip_independence_check=skip_independence_check,
use_textures_for_interpolation=use_textures_for_interpolation) use_textures_for_interpolation=use_textures_for_interpolation,
do_unify_shape_symbols=do_unify_shape_symbols)
if target == 'opencl': if target == 'opencl':
from pystencils.opencl.opencljit import make_python_function from pystencils.opencl.opencljit import make_python_function
ast._backend = 'opencl' ast._backend = 'opencl'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment