Commit e8db1ac3 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add test for MatrixSymbols

parent a8997a2f
Pipeline #21828 failed with stage
in 33 minutes and 48 seconds
import sympy as sp
import pystencils
from pystencils.data_types import TypedMatrixSymbol, TypedSymbol, create_type
class DynamicFunction(sp.Function):
"""
Function that is passed as an argument to a kernel.
Can be printed for example as `std::function` or as a functor template.
"""
def __new__(cls, typed_function_symbol, return_dtype, *args):
obj = sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
if hasattr(return_dtype, 'shape'):
obj.shape = return_dtype.shape
return obj
@property
def function_dtype(self):
return self.args[0].dtype
@property
def dtype(self):
return self.args[1].dtype
@property
def name(self):
return self.args[0].name
def __str__(self):
return f'{self.name}({", ".join(str(a) for a in self.args[2:])})'
def __repr__(self):
return self.__str__()
def test_dynamic_matrix_location_dependent():
x, y = pystencils.fields('x, y: float32[3d]')
A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
'std::function<Vector3<double>(int, int, int)>'),
A.dtype,
*pystencils.x_vector(3))
assignments = pystencils.AssignmentCollection({
A: my_fun_call,
y.center: A[0] + A[1] + A[2]
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment