Commit faed3110 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

(Optionally) deactivate unify_shape_symbols

parent dbb1c77e
Pipeline #28020 waiting for manual action with stage
in 5 minutes and 21 seconds
......@@ -247,7 +247,7 @@ class KernelFunction(Node):
if hasattr(symbol, 'field_name'):
return field_map[symbol.field_name],
elif hasattr(symbol, 'field_names'):
return tuple(field_map[fn] for fn in symbol.field_names)
return tuple(field_map.get(fn, field_map.get('diff' + fn)) for fn in symbol.field_names)
return ()
argument_symbols = self._body.undefined_symbols - self.global_variables
......@@ -297,7 +297,7 @@ class Block(Node):
except AttributeError:
pass
@property
@ property
def args(self):
return self._nodes
......@@ -361,7 +361,7 @@ class Block(Node):
replacements.parent = self
self._nodes.insert(idx, replacements)
@property
@ property
def symbols_defined(self):
result = set()
for a in self.args:
......@@ -371,7 +371,7 @@ class Block(Node):
result.update(a.symbols_defined)
return result
@property
@ property
def undefined_symbols(self):
result = set()
defined_symbols = set()
......@@ -443,7 +443,7 @@ class LoopOverCoordinate(Node):
self.step = fast_subs(self.step, subs_dict, skip)
return self
@property
@ property
def args(self):
result = [self.body]
for e in [self.start, self.stop, self.step]:
......@@ -461,11 +461,11 @@ class LoopOverCoordinate(Node):
elif child == self.stop:
self.stop = replacement
@property
@ property
def symbols_defined(self):
return {self.loop_counter_symbol}
@property
@ property
def undefined_symbols(self):
result = self.body.undefined_symbols
for possible_symbol in [self.start, self.stop, self.step]:
......@@ -473,22 +473,22 @@ class LoopOverCoordinate(Node):
result.update(possible_symbol.atoms(sp.Symbol))
return result - {self.loop_counter_symbol}
@staticmethod
@ staticmethod
def get_loop_counter_name(coordinate_to_loop_over):
return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
@staticmethod
@ staticmethod
def get_block_loop_counter_name(coordinate_to_loop_over):
return f"{LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
@property
@ property
def loop_counter_name(self):
if self.is_block_loop:
return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
else:
return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
@staticmethod
@ staticmethod
def is_loop_counter_symbol(symbol):
prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
if not symbol.name.startswith(prefix):
......@@ -498,29 +498,29 @@ class LoopOverCoordinate(Node):
coordinate = int(symbol.name[len(prefix) + 1:])
return coordinate
@staticmethod
@ staticmethod
def get_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)
@staticmethod
@ staticmethod
def get_block_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
'int',
nonnegative=True)
@property
@ property
def loop_counter_symbol(self):
if self.is_block_loop:
return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
else:
return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
@property
@ property
def is_outermost_loop(self):
from pystencils.transformations import get_next_parent_of_type
return get_next_parent_of_type(self, LoopOverCoordinate) is None
@property
@ property
def is_innermost_loop(self):
return len(self.atoms(LoopOverCoordinate)) == 0
......@@ -552,11 +552,11 @@ class SympyAssignment(Node):
return False
return True
@property
@ property
def lhs(self):
return self._lhs_symbol
@lhs.setter
@ lhs.setter
def lhs(self, new_value):
self._lhs_symbol = new_value
self._is_declaration = self.__is_declaration()
......@@ -572,17 +572,17 @@ class SympyAssignment(Node):
except Exception:
pass
@property
@ property
def args(self):
return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
@property
@ property
def symbols_defined(self):
if not self._is_declaration:
return set()
return {self._lhs_symbol}
@property
@ property
def undefined_symbols(self):
result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
# Add loop counters if there a field accesses
......@@ -596,11 +596,11 @@ class SympyAssignment(Node):
result.update(self._lhs_symbol.atoms(sp.Symbol))
return result
@property
@ property
def is_declaration(self):
return self._is_declaration
@property
@ property
def is_const(self):
return self._is_const
......@@ -657,7 +657,7 @@ class ResolvedFieldAccess(sp.Indexed):
super_class_contents = super(ResolvedFieldAccess, self)._hashable_content()
return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
@property
@ property
def typed_symbol(self):
return self.base.label
......@@ -687,18 +687,18 @@ class TemporaryMemoryAllocation(Node):
self.headers = ['<stdlib.h>']
self._align_offset = align_offset
@property
@ property
def symbols_defined(self):
return {self.symbol}
@property
@ property
def undefined_symbols(self):
if isinstance(self.size, sp.Basic):
return self.size.atoms(sp.Symbol)
else:
return set()
@property
@ property
def args(self):
return [self.symbol]
......@@ -714,22 +714,22 @@ class TemporaryMemoryFree(Node):
super(TemporaryMemoryFree, self).__init__(parent=None)
self.alloc_node = alloc_node
@property
@ property
def symbol(self):
return self.alloc_node.symbol
def offset(self, byte_alignment):
return self.alloc_node.offset(byte_alignment)
@property
@ property
def symbols_defined(self):
return set()
@property
@ property
def undefined_symbols(self):
return set()
@property
@ property
def args(self):
return []
......@@ -747,15 +747,15 @@ class SourceCodeComment(Node):
def __init__(self, text):
self.text = text
@property
@ property
def args(self):
return []
@property
@ property
def symbols_defined(self):
return set()
@property
@ property
def undefined_symbols(self):
return set()
......@@ -770,15 +770,15 @@ class EmptyLine(Node):
def __init__(self):
pass
@property
@ property
def args(self):
return []
@property
@ property
def symbols_defined(self):
return set()
@property
@ property
def undefined_symbols(self):
return set()
......@@ -798,15 +798,15 @@ class ConditionalFieldAccess(sp.Function):
def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value))
@property
@ property
def access(self):
return self.args[0]
@property
@ property
def outofbounds_condition(self):
return self.args[1]
@property
@ property
def outofbounds_value(self):
return self.args[2]
......
......@@ -15,7 +15,8 @@ def create_cuda_kernel(assignments,
iteration_slice=None,
ghost_layers=None,
skip_independence_check=False,
use_textures_for_interpolation=True):
use_textures_for_interpolation=True,
do_unify_shape_symbols=True):
assert assignments, "Assignments must not be empty!"
fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
all_fields = fields_read.union(fields_written)
......@@ -60,6 +61,7 @@ def create_cuda_kernel(assignments,
block = Block(assignments)
block = indexing.guard(block, common_shape)
if do_unify_shape_symbols:
unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)
ast = KernelFunction(block,
......
......@@ -30,6 +30,7 @@ def create_kernel(assignments,
use_textures_for_interpolation=True,
cpu_prepend_optimizations=[],
use_auto_for_assignments=False,
do_unify_shape_symbols=True,
opencl_queue=None,
opencl_ctx=None):
"""
......@@ -119,7 +120,8 @@ def create_kernel(assignments,
indexing_creator=indexing_creator_from_params(gpu_indexing, gpu_indexing_params),
iteration_slice=iteration_slice, ghost_layers=ghost_layers,
skip_independence_check=skip_independence_check,
use_textures_for_interpolation=use_textures_for_interpolation)
use_textures_for_interpolation=use_textures_for_interpolation,
do_unify_shape_symbols=do_unify_shape_symbols)
if target == 'opencl':
from pystencils.opencl.opencljit import make_python_function
ast._backend = 'opencl'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment