Commit 9e649232 authored by Jan Hönig's avatar Jan Hönig
Browse files

Fixed simplification bug

parent 3b30bace
......@@ -377,6 +377,8 @@ class Block(Node):
return tmp
def replace(self, child, replacements):
if self._nodes.count(child) != 1:
print('here')
assert self._nodes.count(child) == 1
idx = self._nodes.index(child)
del self._nodes[idx]
......
......@@ -219,7 +219,7 @@ class CBackend:
method_name = f"_print_{cls.__name__}"
if hasattr(self, method_name):
return getattr(self, method_name)(node)
raise NotImplementedError(f"{self.__class__.__name__} does not support node of type {node.__class__.__name__}")
return self.sympy_printer.doprint(node)
def _print_AbstractType(self, node):
return str(node)
......
......@@ -5,6 +5,7 @@ import sympy as sp
import pystencils.astnodes as ast
from pystencils.typing import parents_of_type
from pystencils.backends import generate_c
def remove_brackets(s):
......@@ -51,7 +52,7 @@ def simplify_loop_counter_dependent_conditional(conditional):
dofs_in_loops, iteration_set = isl_iteration_set(conditional)
if dofs_in_condition.issubset(dofs_in_loops):
symbol_names = ','.join(dofs_in_loops)
condition_str = remove_brackets(str(conditional.condition_expr))
condition_str = remove_brackets(generate_c(conditional.condition_expr))
condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")
if condition_set.is_empty():
......
......@@ -772,7 +772,8 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
default.
"""
for conditional in node.atoms(ast.Conditional):
conditional.condition_expr = sp.simplify(conditional.condition_expr)
# TODO simplify conditional before the type system!
# conditional.condition_expr = sp.simplify(conditional.condition_expr)
if conditional.condition_expr == sp.true:
conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false:
......
......@@ -102,7 +102,7 @@ class TypeAdder:
# Possible Problems - Do we need to support this?
# - Mixture in expression with int and float
# - Mixture in expression with uint64 and sint64
# TODO: Lowest log level should log all casts ----> cast factory, make cast should contain logging
# TODO Logging: Lowest log level should log all casts ----> cast factory, make cast should contain logging
def figure_out_type(self, expr) -> Tuple[Any, Union[BasicType, PointerType]]:
# Trivial cases
from pystencils.field import Field
......@@ -112,7 +112,6 @@ class TypeAdder:
# TOOO: check the access
if isinstance(expr, Field.Access):
# TODO if Struct, look at the reinterpreted dtype
return expr, expr.dtype
elif isinstance(expr, TypedSymbol):
return expr, expr.dtype
......@@ -139,7 +138,7 @@ class TypeAdder:
elif isinstance(expr, BooleanAtom):
return expr, bool_type
elif isinstance(expr, Relational):
# TODO JAN: Code duplication with general case
# TODO Jan: Code duplication with general case
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
if isinstance(expr, sp.Equality) and collated_type.is_float():
......@@ -188,7 +187,7 @@ class TypeAdder:
return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
# elif isinstance(expr, sp.Mul):
# raise NotImplementedError('sp.Mul')
# # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
# # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? (See todo in backend)
# # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
elif isinstance(expr, sp.Indexed):
raise NotImplementedError('sp.Indexed')
......
......@@ -12,8 +12,10 @@ def test_blocking_staggered():
f[0, 0, 0] - f[0, 0, -1],
]
assignments = [ps.Assignment(stag.staggered_access(d), terms[i]) for i, d in enumerate(stag.staggered_stencil)]
reference_kernel = ps.create_staggered_kernel(assignments)
print(ps.show_code(reference_kernel))
reference_kernel = reference_kernel.compile()
kernel = ps.create_staggered_kernel(assignments, cpu_blocking=(3, 16, 8)).compile()
reference_kernel = ps.create_staggered_kernel(assignments).compile()
print(ps.show_code(kernel.ast))
f_arr = np.random.rand(80, 33, 19)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment