Commit a3f92cb4 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Implement interpolation as an optimization

parent 2e85d1d1
......@@ -1322,6 +1322,8 @@ def implement_interpolations(ast_node: ast.Node,
FLOAT32_T = create_type('float32')
interpolation_accesses = ast_node.atoms(InterpolatorAccess)
if not interpolation_accesses:
return ast_node
def can_use_hw_interpolation(i):
return (use_hardware_interpolation_for_f32
......@@ -1346,22 +1348,16 @@ def implement_interpolations(ast_node: ast.Node,
pass
ast_node.subs({old_i: i})
if vectorize:
# TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather
raise NotImplementedError()
else:
substitutions = {i: i.implementation_with_stencils()
for i in interpolation_accesses if not can_use_hw_interpolation(i)}
if isinstance(ast_node, AssignmentCollection):
ast_node = ast_node.subs(substitutions)
else:
ast_node.subs(substitutions)
from pystencils.math_optimizations import ReplaceOptim, optimize_ast
# from pystencils.math_optimizations import ReplaceOptim, optimize_ast
ImplementInterpolationByStencils = ReplaceOptim(lambda e: isinstance(e, InterpolatorAccess)
and not can_use_hw_interpolation(i),
lambda e: e.implementation_with_stencils()
)
# RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
# lambda e: e.args[0]
# )
# optimize_ast(ast_node, [RemoveConjugate])
RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
lambda e: e.args[0]
)
optimize_ast(ast_node, [RemoveConjugate, ImplementInterpolationByStencils])
return ast_node
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment