Commit 800eaf6c authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'master' into 'ConditionalFieldAccess'

# Conflicts:
#   pystencils/
parents 8dfe844c ae69d3d3
Pipeline #18637 passed with stage
in 2 minutes and 59 seconds
......@@ -465,11 +465,13 @@ class LoopOverCoordinate(Node):
def get_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int')
return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)
def get_block_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), 'int')
return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
def loop_counter_symbol(self):
......@@ -503,7 +505,7 @@ class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol
self.rhs = sp.simplify(rhs_expr)
self.rhs = sp.sympify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration()
......@@ -42,12 +42,11 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
AST node representing a function, that can be printed as C or CUDA code
def type_symbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
if not hasattr(type_info, '__getitem__'):
if isinstance(type_info, str) or not hasattr(type_info, '__getitem__'):
return TypedSymbol(, create_type(type_info))
return TypedSymbol(, type_info[])
......@@ -240,6 +240,10 @@ class TypedSymbol(sp.Symbol):
def canonical(self):
return self
def reversed(self):
return self
def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type.
......@@ -111,6 +111,7 @@ class AssignmentCollection:
"Not in SSA form - same symbol assigned multiple times"
return bound_symbols_set
def free_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.free_symbols if hasattr(s, 'field')}
......@@ -46,6 +46,7 @@ def test_inplace_update():
np.testing.assert_equal(arr, 2)
def test_vectorization_fixed_size():
configurations = []
# Fixed size - multiple of four
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment