Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
import pytest
import sympy as sp
from pystencils.utils import LinearEquationSystem
from pystencils.utils import DotDict
def test_linear_equation_system():
unknowns = sp.symbols("x_:3")
x, y, z = unknowns
m = LinearEquationSystem(unknowns)
m.add_equation(x + y - 2)
m.add_equation(x - y - 1)
assert m.solution_structure() == 'multiple'
m.set_unknown_zero(2)
assert m.solution_structure() == 'single'
solution = m.solution()
assert solution[unknowns[2]] == 0
assert solution[unknowns[1]] == sp.Rational(1, 2)
assert solution[unknowns[0]] == sp.Rational(3, 2)
m.set_unknown_zero(0)
assert m.solution_structure() == 'none'
# special case where less rows than unknowns, but no solution
m = LinearEquationSystem(unknowns)
m.add_equation(x - 3)
m.add_equation(x - 4)
assert m.solution_structure() == 'none'
m.add_equation(y - 4)
assert m.solution_structure() == 'none'
with pytest.raises(ValueError) as e:
m.add_equation(x**2 - 1)
assert 'Not a linear equation' in str(e.value)
x, y, z = sp.symbols("x, y, z")
les = LinearEquationSystem([x, y, z])
les.add_equation(1 * x + 2 * y - 1 * z + 4)
les.add_equation(2 * x + 1 * y + 1 * z - 2)
les.add_equation(1 * x + 2 * y + 1 * z + 2)
# usually reduce is not necessary since almost every function of LinearEquationSystem calls reduce beforehand
les.reduce()
expected_matrix = sp.Matrix([[1, 0, 0, sp.Rational(5, 3)],
[0, 1, 0, sp.Rational(-7, 3)],
[0, 0, 1, sp.Integer(1)]])
assert les.matrix == expected_matrix
assert les.rank == 3
sol = les.solution()
assert sol[x] == sp.Rational(5, 3)
assert sol[y] == sp.Rational(-7, 3)
assert sol[z] == sp.Integer(1)
les = LinearEquationSystem([x, y])
assert les.solution_structure() == 'multiple'
les.add_equation(x + 1)
assert les.solution_structure() == 'multiple'
les.add_equation(y + 2)
assert les.solution_structure() == 'single'
les.add_equation(x + y + 5)
assert les.solution_structure() == 'none'
def test_dot_dict():
d = {'a': {'c': 7}, 'b': 6}
t = DotDict(d)
assert t.a.c == 7
assert t.b == 6
assert len(t) == 2
delattr(t, 'b')
assert len(t) == 1
t.b = 6
assert len(t) == 2
assert t.b == 6
import numpy as np
import pytest
import pystencils.config
import sympy as sp
import pystencils as ps
import pystencils.astnodes as ast
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
from pystencils.cpu.vectorization import vectorize
from pystencils.enums import Target
from pystencils.transformations import replace_inner_stride_with_one
supported_instruction_sets = get_supported_instruction_sets()
if supported_instruction_sets:
instruction_set = supported_instruction_sets[-1]
else:
instruction_set = None
# TODO: Skip tests if no instruction set is available and check all codes if they are really vectorised !
def test_vector_type_propagation1(instruction_set=instruction_set):
a, b, c, d, e = sp.symbols("a b c d e")
arr = np.ones((2 ** 2 + 2, 2 ** 3 + 2))
arr *= 10.0
f, g = ps.fields(f=arr, g=arr)
update_rule = [ps.Assignment(a, f[1, 0]),
ps.Assignment(b, a),
ps.Assignment(g[0, 0], b + 3 + f[0, 1])]
ast = ps.create_kernel(update_rule)
vectorize(ast, instruction_set=instruction_set)
# ps.show_code(ast)
func = ast.compile()
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
def test_vector_type_propagation2(instruction_set=instruction_set):
data_type = 'float32'
assume_aligned = True
assume_inner_stride_one = True
assume_sufficient_line_padding = True
domain_size = (24, 24)
dh = ps.create_data_handling(domain_size, default_target=Target.CPU)
# fields
g = dh.add_array("g", values_per_cell=1, dtype=data_type, alignment=True, ghost_layers=0)
f = dh.add_array("f", values_per_cell=1, dtype=data_type, alignment=True, ghost_layers=0)
dh.fill("g", 1.0, ghost_layers=True)
dh.fill("f", 0.5, ghost_layers=True)
collision = [
ps.Assignment(sp.Symbol("b"), sp.sqrt(sp.Symbol("a") * (1 - (1 - f.center)**2))),
ps.Assignment(g.center, sp.Symbol("b"))
]
config = ps.CreateKernelConfig(
target=ps.Target.CPU,
data_type=data_type,
default_number_float=data_type,
cpu_vectorize_info={
'assume_aligned': assume_aligned,
'assume_inner_stride_one': assume_inner_stride_one,
'assume_sufficient_line_padding': assume_sufficient_line_padding,
},
ghost_layers=0
)
ast = ps.create_kernel(collision, config=config)
print(ast)
code = ps.get_code_str(ast)
print(code)
kernel = ast.compile()
dh.run_kernel(kernel, a=4)
g_arr = dh.cpu_arrays['g']
np.testing.assert_allclose(g_arr, np.full_like(g_arr, np.sqrt(3)))
def test_vectorize_moved_constants1(instruction_set=instruction_set):
opt = {'instruction_set': instruction_set, 'assume_inner_stride_one': True}
f = ps.fields("f: [1D]")
x = ast.TypedSymbol("x", np.float64)
kernel_func = ps.create_kernel(
[ast.SympyAssignment(x, 2.0), ast.SympyAssignment(f[0], x)],
cpu_prepend_optimizations=[ps.transformations.move_constants_before_loop], # explicitly move constants
cpu_vectorize_info=opt,
)
ps.show_code(kernel_func) # fails if `x` on rhs was not correctly vectorized
kernel = kernel_func.compile()
f_arr = np.zeros(9)
kernel(f=f_arr)
assert(np.all(f_arr == 2))
def test_vectorize_moved_constants2(instruction_set=instruction_set):
opt = {'instruction_set': instruction_set, 'assume_inner_stride_one': True}
f = ps.fields("f: [1D]")
x = ast.TypedSymbol("x", np.float64)
y = ast.TypedSymbol("y", np.float64)
kernel_func = ps.create_kernel(
[ast.SympyAssignment(x, 2.0), ast.SympyAssignment(y, 3.0), ast.SympyAssignment(f[0], x + y)],
cpu_prepend_optimizations=[ps.transformations.move_constants_before_loop], # explicitly move constants
cpu_vectorize_info=opt,
)
ps.show_code(kernel_func) # fails if `x` on rhs was not correctly vectorized
kernel = kernel_func.compile()
f_arr = np.zeros(9)
kernel(f=f_arr)
assert(np.all(f_arr == 5))
@pytest.mark.parametrize('openmp', [True, False])
def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
domain_size = (24, 24)
# create a datahandling object
dh = ps.create_data_handling(domain_size, periodicity=(True, True), parallel=False, default_target=Target.CPU)
# fields
alignment = 'cacheline' if openmp else True
g = dh.add_array("g", values_per_cell=1, alignment=alignment)
dh.fill("g", 1.0, ghost_layers=True)
f = dh.add_array("f", values_per_cell=1, alignment=alignment)
dh.fill("f", 0.0, ghost_layers=True)
opt = {'instruction_set': instruction_set, 'assume_aligned': True, 'nontemporal': True,
'assume_inner_stride_one': True}
update_rule = [ps.Assignment(f.center(), 0.25 * (g[-1, 0] + g[1, 0] + g[0, -1] + g[0, 1]))]
# Without the base pointer spec, the inner store is not aligned
config = pystencils.config.CreateKernelConfig(target=dh.default_target, cpu_vectorize_info=opt, cpu_openmp=openmp)
ast = ps.create_kernel(update_rule, config=config)
if instruction_set in ['sse'] or instruction_set.startswith('avx') or instruction_set.startswith('sve'):
assert 'stream' in ast.instruction_set
assert 'streamFence' in ast.instruction_set
if instruction_set in ['neon', 'vsx', 'rvv']:
assert 'cachelineZero' in ast.instruction_set
if instruction_set in ['vsx']:
assert 'storeAAndFlushCacheline' in ast.instruction_set
for instruction in ['stream', 'streamFence', 'cachelineZero', 'storeAAndFlushCacheline', 'flushCacheline']:
if instruction in ast.instruction_set:
assert ast.instruction_set[instruction].split('{')[0] in ps.get_code_str(ast)
kernel = ast.compile()
# ps.show_code(ast)
dh.run_kernel(kernel)
np.testing.assert_equal(np.sum(dh.cpu_arrays['f']), np.prod(domain_size))
def test_nt_stores_symbolic_size(instruction_set=instruction_set):
f, g = ps.fields('f, g: [2D]', layout='fzyx')
update_rule = [ps.Assignment(f.center(), 0.0), ps.Assignment(g.center(), 0.0)]
opt = {'instruction_set': instruction_set, 'assume_aligned': True, 'nontemporal': True,
'assume_inner_stride_one': True}
config = pystencils.config.CreateKernelConfig(target=Target.CPU, cpu_vectorize_info=opt)
ast = ps.create_kernel(update_rule, config=config)
# ps.show_code(ast)
ast.compile()
def test_inplace_update(instruction_set=instruction_set):
shape = (9, 9, 3)
arr = np.ones(shape, order='f')
@ps.kernel
def update_rule(s):
f = ps.fields("f(3) : [2D]", f=arr)
s.tmp0 @= f(0)
s.tmp1 @= f(1)
s.tmp2 @= f(2)
f0, f1, f2 = f(0), f(1), f(2)
f0 @= 2 * s.tmp0
f1 @= 2 * s.tmp0
f2 @= 2 * s.tmp0
config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
ast = ps.create_kernel(update_rule, config=config)
kernel = ast.compile()
kernel(f=arr)
np.testing.assert_equal(arr, 2)
def test_vectorization_fixed_size(instruction_set=instruction_set):
instructions = get_vector_instruction_set(instruction_set=instruction_set)
configurations = []
# Fixed size - multiple of four
arr = np.ones((20 + 2, 24 + 2)) * 5.0
f, g = ps.fields(f=arr, g=arr)
configurations.append((arr, f, g))
# Fixed size - no multiple of four
arr = np.ones((21 + 2, 25 + 2)) * 5.0
f, g = ps.fields(f=arr, g=arr)
configurations.append((arr, f, g))
# Fixed size - different remainder
arr = np.ones((23 + 2, 17 + 2)) * 5.0
f, g = ps.fields(f=arr, g=arr)
configurations.append((arr, f, g))
for arr, f, g in configurations:
update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
ast = ps.create_kernel(update_rule)
vectorize(ast, instruction_set=instruction_set)
code = ps.get_code_str(ast)
add_instruction = instructions["+"][:instructions["+"].find("(")]
assert add_instruction in code
# print(code)
func = ast.compile()
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(dst[1:-1, 1:-1], 5 * 5.0 + 42.0)
def test_vectorization_variable_size(instruction_set=instruction_set):
f, g = ps.fields("f, g : double[2D]")
update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
ast = ps.create_kernel(update_rule)
replace_inner_stride_with_one(ast)
vectorize(ast, instruction_set=instruction_set)
func = ast.compile()
arr = np.ones((23 + 2, 17 + 2)) * 5.0
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(dst[1:-1, 1:-1], 5 * 5.0 + 42.0)
def test_piecewise1(instruction_set=instruction_set):
a, b, c, d, e = sp.symbols("a b c d e")
arr = np.ones((2 ** 3 + 2, 2 ** 4 + 2)) * 5.0
f, g = ps.fields(f=arr, g=arr)
update_rule = [ps.Assignment(a, f[1, 0]),
ps.Assignment(b, a),
ps.Assignment(c, f[0, 0] > 0.0),
ps.Assignment(g[0, 0], sp.Piecewise((b + 3 + f[0, 1], c), (0.0, True)))]
ast = ps.create_kernel(update_rule)
vectorize(ast, instruction_set=instruction_set)
func = ast.compile()
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(dst[1:-1, 1:-1], 5 + 3 + 5.0)
def test_piecewise2(instruction_set=instruction_set):
arr = np.zeros((20, 20))
@ps.kernel
def test_kernel(s):
f, g = ps.fields(f=arr, g=arr)
s.condition @= f[0, 0] > 1
s.result @= 0.0 if s.condition else 1.0
g[0, 0] @= s.result
ast = ps.create_kernel(test_kernel)
# ps.show_code(ast)
vectorize(ast, instruction_set=instruction_set)
# ps.show_code(ast)
func = ast.compile()
func(f=arr, g=arr)
np.testing.assert_equal(arr, np.ones_like(arr))
def test_piecewise3(instruction_set=instruction_set):
arr = np.zeros((22, 22))
@ps.kernel
def test_kernel(s):
f, g = ps.fields(f=arr, g=arr)
s.b @= f[0, 1]
g[0, 0] @= 1.0 / (s.b + s.k) if f[0, 0] > 0.0 else 1.0
ast = ps.create_kernel(test_kernel)
# ps.show_code(ast)
vectorize(ast, instruction_set=instruction_set)
# ps.show_code(ast)
ast.compile()
def test_logical_operators(instruction_set=instruction_set):
arr = np.zeros((22, 22))
@ps.kernel
def kernel_and(s):
f, g = ps.fields(f=arr, g=arr)
s.c @= sp.And(f[0, 1] < 0.0, f[1, 0] < 0.0)
g[0, 0] @= sp.Piecewise([1.0 / f[1, 0], s.c], [1.0, True])
ast = ps.create_kernel(kernel_and)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
@ps.kernel
def kernel_or(s):
f, g = ps.fields(f=arr, g=arr)
s.c @= sp.Or(f[0, 1] < 0.0, f[1, 0] < 0.0)
g[0, 0] @= sp.Piecewise([1.0 / f[1, 0], s.c], [1.0, True])
ast = ps.create_kernel(kernel_or)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
@ps.kernel
def kernel_equal(s):
f, g = ps.fields(f=arr, g=arr)
s.c @= sp.Eq(f[0, 1], 2.0)
g[0, 0] @= sp.Piecewise([1.0 / f[1, 0], s.c], [1.0, True])
ast = ps.create_kernel(kernel_equal)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
def test_hardware_query():
assert {'sse', 'neon', 'sve', 'sve2', 'sme', 'vsx', 'rvv'}.intersection(supported_instruction_sets)
def test_vectorised_pow(instruction_set=instruction_set):
arr = np.zeros((24, 24))
f, g = ps.fields(f=arr, g=arr)
as1 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], 2))
as2 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], 0.5))
as3 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -0.5))
as4 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], 4))
as5 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -4))
as6 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -1))
ast = ps.create_kernel(as1)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
ast = ps.create_kernel(as2)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
ast = ps.create_kernel(as3)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
ast = ps.create_kernel(as4)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
ast = ps.create_kernel(as5)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
ast = ps.create_kernel(as6)
vectorize(ast, instruction_set=instruction_set)
ast.compile()
def test_issue40(*_):
"""https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/40"""
opt = {'instruction_set': "avx512", 'assume_aligned': False,
'nontemporal': False, 'assume_inner_stride_one': True}
src = ps.fields("src(1): double[2D]", layout='fzyx')
eq = [ps.Assignment(sp.Symbol('rho'), 1.0),
ps.Assignment(src[0, 0](0), sp.Rational(4, 9) * sp.Symbol('rho'))]
config = pystencils.config.CreateKernelConfig(target=Target.CPU, cpu_vectorize_info=opt, data_type='float64')
ast = ps.create_kernel(eq, config=config)
code = ps.get_code_str(ast)
assert 'epi32' not in code
import pytest
import numpy as np
import pystencils.config
import sympy as sp
import pystencils as ps
from pystencils.backends.simd_instruction_sets import (get_cacheline_size, get_supported_instruction_sets,
get_vector_instruction_set)
from pystencils.enums import Target
from pystencils.typing import CFunction
from . import test_vectorization
supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_vectorisation_varying_arch(instruction_set):
shape = (9, 9, 3)
arr = np.ones(shape, order='f')
@ps.kernel
def update_rule(s):
f = ps.fields("f(3) : [2D]", f=arr)
s.tmp0 @= f(0)
s.tmp1 @= f(1)
s.tmp2 @= f(2)
f0, f1, f2 = f(0), f(1), f(2)
f0 @= 2 * s.tmp0
f1 @= 2 * s.tmp0
f2 @= 2 * s.tmp0
config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
ast = ps.create_kernel(update_rule, config=config)
kernel = ast.compile()
kernel(f=arr)
np.testing.assert_equal(arr, 2)
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_vectorized_abs_field(instruction_set, dtype):
"""Some instructions sets have abs, some don't.
Furthermore, the special treatment of unary minus makes this data type-sensitive too.
"""
arr = np.ones((2 ** 2 + 2, 2 ** 3 + 2), dtype=dtype)
arr[-3:, :] = -1
f, g = ps.fields(f=arr, g=arr)
update_rule = [ps.Assignment(g.center(), sp.Abs(f.center()))]
config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
ast = ps.create_kernel(update_rule, config=config)
func = ast.compile()
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(np.sum(dst[1:-1, 1:-1]), 2 ** 2 * 2 ** 3)
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_vectorized_abs_scalar(instruction_set):
"""Some instructions sets have abs, some don't.
Furthermore, the special treatment of unary minus makes this data type-sensitive too.
"""
arr = np.zeros((2 ** 2 + 2, 2 ** 3 + 2), dtype="float64")
f = ps.fields(f=arr)
update_rule = [ps.Assignment(f.center(), sp.Abs(sp.Symbol("a")))]
config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
ast = ps.create_kernel(update_rule, config=config)
func = ast.compile()
func(f=arr, a=-1)
np.testing.assert_equal(np.sum(arr[1:-1, 1:-1]), 2 ** 2 * 2 ** 3)
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('nontemporal', [False, True])
def test_strided(instruction_set, dtype, nontemporal):
f, g = ps.fields(f"f, g : {dtype}[2D]")
update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set,
'nontemporal': nontemporal},
default_number_float=dtype)
if 'storeS' not in get_vector_instruction_set(dtype, instruction_set) \
and instruction_set not in ['avx512', 'avx512vl', 'rvv'] and not instruction_set.startswith('sve'):
with pytest.warns(UserWarning) as warn:
ast = ps.create_kernel(update_rule, config=config)
assert 'Could not vectorize loop' in warn[0].message.args[0]
else:
with pytest.warns(None) as warn:
ast = ps.create_kernel(update_rule, config=config)
assert len(warn) == 0
instruction = 'streamS' if nontemporal and 'streamS' in ast.instruction_set else 'storeS'
assert ast.instruction_set[instruction].split('{')[0] in ps.get_code_str(ast)
instruction = 'cachelineZero'
if instruction in ast.instruction_set:
assert ast.instruction_set[instruction] not in ps.get_code_str(ast)
# ps.show_code(ast)
func = ast.compile()
ref_config = pystencils.config.CreateKernelConfig(default_number_float=dtype)
ref_func = ps.create_kernel(update_rule, config=ref_config).compile()
# For some reason other array creations fail on the emulated ppc pipeline
size = (25, 19)
arr = np.zeros(size).astype(dtype)
for i in range(size[0]):
for j in range(size[1]):
arr[i, j] = i * j
dst = np.zeros_like(arr, dtype=dtype)
ref = np.zeros_like(arr, dtype=dtype)
func(g=dst, f=arr)
ref_func(g=ref, f=arr)
# print("dst: ", dst)
# print("np array: ", arr)
np.testing.assert_almost_equal(dst[1:-1, 1:-1], ref[1:-1, 1:-1], 13 if dtype == 'float64' else 5)
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('gl_field, gl_kernel', [(1, 0), (0, 1), (1, 1)])
def test_alignment_and_correct_ghost_layers(gl_field, gl_kernel, instruction_set, dtype):
domain_size = (128, 128)
dh = ps.create_data_handling(domain_size, periodicity=(True, True), default_target=Target.CPU)
src = dh.add_array("src", values_per_cell=1, dtype=dtype, ghost_layers=gl_field, alignment=True)
dh.fill(src.name, 1.0, ghost_layers=True)
dst = dh.add_array("dst", values_per_cell=1, dtype=dtype, ghost_layers=gl_field, alignment=True)
dh.fill(dst.name, 1.0, ghost_layers=True)
update_rule = ps.Assignment(dst[0, 0], src[0, 0])
opt = {'instruction_set': instruction_set, 'assume_aligned': True,
'nontemporal': True, 'assume_inner_stride_one': True}
config = pystencils.config.CreateKernelConfig(target=dh.default_target,
cpu_vectorize_info=opt, ghost_layers=gl_kernel)
ast = ps.create_kernel(update_rule, config=config)
kernel = ast.compile()
if ('loadA' in ast.instruction_set or 'storeA' in ast.instruction_set) and gl_kernel != gl_field:
with pytest.raises(ValueError):
dh.run_kernel(kernel)
else:
dh.run_kernel(kernel)
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_cacheline_size(instruction_set):
cacheline_size = get_cacheline_size(instruction_set)
if cacheline_size is None and instruction_set in ['sse', 'avx', 'avx512', 'avx512vl', 'rvv']:
pytest.skip()
instruction_set = get_vector_instruction_set('double', instruction_set)
vector_size = instruction_set['bytes']
assert 8 < cacheline_size < 0x100000, "Cache line size is implausible"
if type(vector_size) is int:
assert cacheline_size % vector_size == 0, "Cache line size should be multiple of vector size"
assert cacheline_size & (cacheline_size - 1) == 0, "Cache line size is not a power of 2"
# TODO move to vectorise
@pytest.mark.parametrize('instruction_set',
sorted(set(supported_instruction_sets) - {test_vectorization.instruction_set}))
@pytest.mark.parametrize('function',
[f for f in test_vectorization.__dict__ if f.startswith('test_') and f not in ['test_hardware_query', 'test_aligned_and_nt_stores']])
def test_vectorization_other(instruction_set, function):
test_vectorization.__dict__[function](instruction_set)
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('field_layout', ('fzyx', 'zyxf'))
def test_square_root(dtype, instruction_set, field_layout):
config = pystencils.config.CreateKernelConfig(data_type=dtype,
default_number_float=dtype,
cpu_vectorize_info={'instruction_set': instruction_set,
'assume_inner_stride_one': True,
'assume_aligned': False,
'nontemporal': False})
src_field = ps.Field.create_generic('pdfs', 2, dtype, index_dimensions=1, layout=field_layout, index_shape=(9,))
eq = [ps.Assignment(sp.Symbol("xi"), sum(src_field.center_vector)),
ps.Assignment(sp.Symbol("xi_2"), sp.Symbol("xi") * sp.sqrt(src_field.center))]
ast = ps.create_kernel(eq, config=config)
ast.compile()
code = ps.get_code_str(ast)
print(code)
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('padding', (True, False))
def test_square_root_2(dtype, instruction_set, padding):
x, y = sp.symbols("x y")
src = ps.fields(f"src: {dtype}[2D]", layout='fzyx')
up = ps.Assignment(src[0, 0], 1 / x + sp.sqrt(y * 0.52 + x ** 2))
cpu_vec = {'instruction_set': instruction_set, 'assume_inner_stride_one': True,
'assume_sufficient_line_padding': padding,
'assume_aligned': True}
config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, cpu_vectorize_info=cpu_vec)
ast = ps.create_kernel(up, config=config)
ast.compile()
code = ps.get_code_str(ast)
print(code)
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('padding', (True, False))
def test_pow(dtype, instruction_set, padding):
config = pystencils.config.CreateKernelConfig(data_type=dtype,
default_number_float=dtype,
cpu_vectorize_info={'instruction_set': instruction_set,
'assume_inner_stride_one': True,
'assume_sufficient_line_padding': padding,
'assume_aligned': False, 'nontemporal': False})
src_field = ps.Field.create_generic('pdfs', 2, dtype, index_dimensions=1, layout='fzyx', index_shape=(9,))
eq = [ps.Assignment(sp.Symbol("xi"), sum(src_field.center_vector)),
ps.Assignment(sp.Symbol("xi_2"), sp.Symbol("xi") * sp.Pow(src_field.center, 0.5))]
ast = ps.create_kernel(eq, config=config)
ast.compile()
code = ps.get_code_str(ast)
print(code)
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('padding', (True, False))
def test_issue62(dtype, instruction_set, padding):
opt = {'instruction_set': instruction_set, 'assume_aligned': True,
'assume_inner_stride_one': True,
'assume_sufficient_line_padding': padding}
dx = sp.Symbol("dx")
dy = sp.Symbol("dy")
src, dst, rhs = ps.fields(f"src, src_tmp, rhs: {dtype}[2D]", layout='fzyx')
up = ps.Assignment(dst[0, 0], ((dy ** 2 * (src[1, 0] + src[-1, 0]))
+ (dx ** 2 * (src[0, 1] + src[0, -1]))
- (rhs[0, 0] * dx ** 2 * dy ** 2)) / (2 * (dx ** 2 + dy ** 2)))
config = ps.CreateKernelConfig(data_type=dtype,
default_number_float=dtype,
cpu_vectorize_info=opt)
ast = ps.create_kernel(up, config=config)
ast.compile()
code = ps.get_code_str(ast)
print(code)
assert 'pow' not in code
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_div_and_unevaluated_expr(dtype, instruction_set):
opt = {'instruction_set': instruction_set, 'assume_aligned': True, 'assume_inner_stride_one': True,
'assume_sufficient_line_padding': False}
x, y, z = sp.symbols("x y z")
rhs = (-4 * x ** 4 * y ** 2 * z ** 2 + (x ** 4 * y ** 2 / 3) + (x ** 4 * z ** 2 / 3)) / x ** 3
src = ps.fields(f"src: {dtype}[2D]", layout='fzyx')
up = ps.Assignment(src[0, 0], rhs)
config = ps.CreateKernelConfig(data_type=dtype,
default_number_float=dtype,
cpu_vectorize_info=opt)
ast = ps.create_kernel(up, config=config)
code = ps.get_code_str(ast)
# print(code)
ast.compile()
assert 'pow' not in code
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', ('sve', 'sve2', 'sme', 'rvv'))
def test_check_ast_parameters_sizeless(dtype, instruction_set):
f, g = ps.fields(f"f, g: {dtype}[3D]", layout='fzyx')
update_rule = [ps.Assignment(g.center(), 2 * f.center())]
config = pystencils.config.CreateKernelConfig(data_type=dtype,
cpu_vectorize_info={'instruction_set': instruction_set})
ast = ps.create_kernel(update_rule, config=config)
ast_symbols = [p.symbol for p in ast.get_parameters()]
assert ast.instruction_set['width'] not in ast_symbols
assert ast.instruction_set['intwidth'] not in ast_symbols
# TODO this test case needs a complete rework of the vectoriser. The reason is that the vectoriser does not
# TODO vectorise symbols at the moment because they could be strides or field sizes, thus involved in pointer arithmetic
# TODO This means that the vectoriser only works if fields are involved on the rhs.
# @pytest.mark.parametrize('dtype', ('float32', 'float64'))
# @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
# def test_vectorised_symbols(dtype, instruction_set):
# opt = {'instruction_set': instruction_set, 'assume_aligned': True, 'assume_inner_stride_one': True,
# 'assume_sufficient_line_padding': False}
#
# x, y, z = sp.symbols("x y z")
# rhs = 1 / x ** 2 * (x + y)
#
# src = ps.fields(f"src: {dtype}[2D]", layout='fzyx')
#
# up = ps.Assignment(src[0, 0], rhs)
#
# config = ps.CreateKernelConfig(data_type=dtype,
# default_number_float=dtype,
# cpu_vectorize_info=opt)
#
# ast = ps.create_kernel(up, config=config)
# code = ps.get_code_str(ast)
# print(code)
#
# ast.compile()
#
# assert 'pow' not in code
import pystencils as ps
def test_version_string():
version = ps.__version__
print(version)
numeric_version = [int(x, 10) for x in version.split('.')[0:1]]
test_version = sum(x * (100 ** i) for i, x in enumerate(numeric_version))
assert test_version >= 1
from collections import defaultdict
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.field import Field, offsetComponentToDirectionString
from pystencils.types import TypedSymbol, DataType
from pystencils.slicing import normalizeSlice
import pystencils.ast as ast
def fastSubs(term, subsDict):
"""Similar to sympy subs function.
This version is much faster for big substitution dictionaries than sympy version"""
def visit(expr):
if expr in subsDict:
return subsDict[expr]
if not hasattr(expr, 'args'):
return expr
paramList = [visit(a) for a in expr.args]
return expr if not paramList else expr.func(*paramList)
return visit(term)
def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None, loopOrder=None):
"""
Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.
:param body: list of nodes
:param functionName: name of generated C function
:param iterationSlice: if not None, iteration is done only over this slice of the field
:param ghostLayers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
if None, the number of ghost layers is determined automatically and assumed to be equal for a
all dimensions
:param loopOrder: loop ordering from outer to inner loop (optimal ordering is same as layout)
:return: :class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts
"""
# find correct ordering by inspecting participating FieldAccesses
fieldAccesses = body.atoms(Field.Access)
fieldList = [e.field for e in fieldAccesses]
fields = set(fieldList)
if loopOrder is None:
loopOrder = getOptimalLoopOrdering(fields)
shapes = set([f.spatialShape for f in fields])
if len(shapes) > 1:
nrOfFixedSizedFields = 0
for shape in shapes:
if not isinstance(shape[0], sp.Basic):
nrOfFixedSizedFields += 1
assert nrOfFixedSizedFields <= 1, "Differently sized field accesses in loop body: " + str(shapes)
shape = list(shapes)[0]
if iterationSlice is not None:
iterationSlice = normalizeSlice(iterationSlice, shape)
if ghostLayers is None:
requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
ghostLayers = [(requiredGhostLayers, requiredGhostLayers)] * len(loopOrder)
currentBody = body
lastLoop = None
for i, loopCoordinate in enumerate(reversed(loopOrder)):
if iterationSlice is None:
begin = ghostLayers[loopCoordinate][0]
end = shape[loopCoordinate] - ghostLayers[loopCoordinate][1]
newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, begin, end, 1)
lastLoop = newLoop
currentBody = ast.Block([lastLoop])
else:
sliceComponent = iterationSlice[loopCoordinate]
if type(sliceComponent) is slice:
sc = sliceComponent
newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, sc.start, sc.stop, sc.step)
lastLoop = newLoop
currentBody = ast.Block([lastLoop])
else:
assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate),
sp.sympify(sliceComponent))
currentBody.insertFront(assignment)
return ast.KernelFunction(currentBody, fields, functionName)
def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
r"""
Addressing elements in structured arrays are done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
Returns a new typed symbol, where the name encodes which coordinates have been resolved.
:param fieldAccess: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
:param coordinates: mapping of coordinate ids to its value, where stride*value is calculated
:param previousPtr: the pointer which is dereferenced
:return: tuple with the new pointer symbol and the calculated offset
Example:
>>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
>>> x, y = sp.symbols("x y")
>>> prevPointer = TypedSymbol("ptr", "double")
>>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
(ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
>>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
(ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
"""
field = fieldAccess.field
offset = 0
name = ""
listToHash = []
for coordinateId, coordinateValue in coordinates.items():
offset += field.strides[coordinateId] * coordinateValue
if coordinateId < field.spatialDimensions:
offset += field.strides[coordinateId] * fieldAccess.offsets[coordinateId]
if type(fieldAccess.offsets[coordinateId]) is int:
offsetComp = offsetComponentToDirectionString(coordinateId, fieldAccess.offsets[coordinateId])
name += "_"
name += offsetComp if offsetComp else "C"
else:
listToHash.append(fieldAccess.offsets[coordinateId])
else:
if type(coordinateValue) is int:
name += "_%d" % (coordinateValue,)
else:
listToHash.append(coordinateValue)
if len(listToHash) > 0:
name += "%0.6X" % (abs(hash(tuple(listToHash))))
newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
return newPtr, offset
def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
"""
Creates base pointer specification for :func:`resolveFieldAccesses` function.
Specification of how many and which intermediate pointers are created for a field access.
For example [ (0), (2,3,)] creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate
zero directly in the field access. These specifications are more sensible defined dependent on the loop ordering.
This function translates more readable version into the specification above.
Allowed specifications:
- "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
spatialInner1 the loop enclosing the innermost
- "spatialOuter<int>" spatialOuter0 is the outermost loop
- "index<int>": index coordinate
- "<int>": specifying directly the coordinate
:param basePointerSpecification: nested list with above specifications
:param loopOrder: list with ordering of loops from outer to inner
:param field:
:return: list of tuples that can be passed to :func:`resolveFieldAccesses`
"""
result = []
specifiedCoordinates = set()
loopOrder = list(reversed(loopOrder))
for specGroup in basePointerSpecification:
newGroup = []
def addNewElement(i):
if i >= field.spatialDimensions + field.indexDimensions:
raise ValueError("Coordinate %d does not exist" % (i,))
newGroup.append(i)
if i in specifiedCoordinates:
raise ValueError("Coordinate %d specified two times" % (i,))
specifiedCoordinates.add(i)
for element in specGroup:
if type(element) is int:
addNewElement(element)
elif element.startswith("spatial"):
element = element[len("spatial"):]
if element.startswith("Inner"):
index = int(element[len("Inner"):])
addNewElement(loopOrder[index])
elif element.startswith("Outer"):
index = int(element[len("Outer"):])
addNewElement(loopOrder[-index])
elif element == "all":
for i in range(field.spatialDimensions):
addNewElement(i)
else:
raise ValueError("Could not parse " + element)
elif element.startswith("index"):
index = int(element[len("index"):])
addNewElement(field.spatialDimensions + index)
else:
raise ValueError("Unknown specification %s" % (element,))
result.append(newGroup)
allCoordinates = set(range(field.spatialDimensions + field.indexDimensions))
rest = allCoordinates - specifiedCoordinates
if rest:
result.append(list(rest))
return result
def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
"""
Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing
:param astNode: the AST root
:param readOnlyFieldNames: set of field names which are considered read-only
:param fieldToBasePointerInfo: a list of tuples indicating which intermediate base pointers should be created
for details see :func:`parseBasePointerInfo`
:param fieldToFixedCoordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
counters to index the field these symbols are used as coordinates
:return: transformed AST
"""
def visitSympyExpr(expr, enclosingBlock, sympyAssignment):
if isinstance(expr, Field.Access):
fieldAccess = expr
field = fieldAccess.field
if field.name in fieldToBasePointerInfo:
basePointerInfo = fieldToBasePointerInfo[field.name]
else:
basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
dtype = DataType(field.dtype)
dtype.alias = False
dtype.ptr = True
if field.name in readOnlyFieldNames:
dtype.const = True
fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype)
lastPointer = fieldPtr
def createCoordinateDict(group):
coordDict = {}
for e in group:
if e < field.spatialDimensions:
if field.name in fieldToFixedCoordinates:
coordDict[e] = fieldToFixedCoordinates[field.name][e]
else:
ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), "int")
else:
coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
return coordDict
for group in reversed(basePointerInfo[1:]):
coordDict = createCoordinateDict(group)
newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
if newPtr not in enclosingBlock.symbolsDefined:
newAssignment = ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False)
enclosingBlock.insertBefore(newAssignment, sympyAssignment)
lastPointer = newPtr
_, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]),
lastPointer)
baseArr = IndexedBase(lastPointer, shape=(1,))
return baseArr[offset]
else:
newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
return expr.func(*newArgs, **kwargs) if newArgs else expr
def visitNode(subAst):
if isinstance(subAst, ast.SympyAssignment):
enclosingBlock = subAst.parent
assert type(enclosingBlock) is ast.Block
subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst)
subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst)
else:
for i, a in enumerate(subAst.args):
visitNode(a)
return visitNode(astNode)
def moveConstantsBeforeLoop(astNode):
"""
Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.
Call this after creating the loop structure with :func:`makeLoopOverDomain`
:param astNode:
:return:
"""
def findBlockToMoveTo(node):
"""
Traverses parents of node as long as the symbols are independent and returns a (parent) block
the assignment can be safely moved to
:param node: SympyAssignment inside a Block
:return blockToInsertTo, childOfBlockToInsertBefore
"""
assert isinstance(node, ast.SympyAssignment)
assert isinstance(node.parent, ast.Block)
lastBlock = node.parent
lastBlockChild = node
element = node.parent
prevElement = node
while element:
if isinstance(element, ast.Block):
lastBlock = element
lastBlockChild = prevElement
if node.undefinedSymbols.intersection(element.symbolsDefined):
break
prevElement = element
element = element.parent
return lastBlock, lastBlockChild
def checkIfAssignmentAlreadyInBlock(assignment, targetBlock):
for arg in targetBlock.args:
if type(arg) is not ast.SympyAssignment:
continue
if arg.lhs == assignment.lhs:
return arg
return None
for block in astNode.atoms(ast.Block):
children = block.takeChildNodes()
for child in children:
if not isinstance(child, ast.SympyAssignment):
block.append(child)
else:
target, childToInsertBefore = findBlockToMoveTo(child)
if target == block: # movement not possible
target.append(child)
else:
existingAssignment = checkIfAssignmentAlreadyInBlock(child, target)
if not existingAssignment:
target.insertBefore(child, childToInsertBefore)
else:
assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already"
def splitInnerLoop(astNode, symbolGroups):
"""
Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams
:param astNode: AST root
:param symbolGroups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups and which
no symbol in a symbol group depends on, are not updated!
:return: transformed AST
"""
allLoops = astNode.atoms(ast.LoopOverCoordinate)
innerLoop = [l for l in allLoops if l.isInnermostLoop]
assert len(innerLoop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
innerLoop = innerLoop[0]
assert type(innerLoop.body) is ast.Block
outerLoop = [l for l in allLoops if l.isOutermostLoop]
assert len(outerLoop) == 1, "Error in AST, multiple outermost loops."
outerLoop = outerLoop[0]
symbolsWithTemporaryArray = dict()
assignmentMap = {a.lhs: a for a in innerLoop.body.args}
assignmentGroups = []
for symbolGroup in symbolGroups:
# get all dependent symbols
symbolsToProcess = list(symbolGroup)
symbolsResolved = set()
while symbolsToProcess:
s = symbolsToProcess.pop()
if s in symbolsResolved:
continue
if s in assignmentMap: # if there is no assignment inside the loop body it is independent already
for newSymbol in assignmentMap[s].rhs.atoms(sp.Symbol):
if type(newSymbol) is not Field.Access and newSymbol not in symbolsWithTemporaryArray:
symbolsToProcess.append(newSymbol)
symbolsResolved.add(s)
for symbol in symbolGroup:
if type(symbol) is not Field.Access:
assert type(symbol) is TypedSymbol
symbolsWithTemporaryArray[symbol] = IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol]
assignmentGroup = []
for assignment in innerLoop.body.args:
if assignment.lhs in symbolsResolved:
newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup:
newLhs = IndexedBase(assignment.lhs, shape=(1,))[innerLoop.loopCounterSymbol]
else:
newLhs = assignment.lhs
assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs))
assignmentGroups.append(assignmentGroup)
newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups]
innerLoop.parent.replace(innerLoop, ast.Block(newLoops))
for tmpArray in symbolsWithTemporaryArray:
outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop))
outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArray))
def symbolNameToVariableName(symbolName):
"""Replaces characters which are allowed in sympy symbol names but not in C/C++ variable names"""
return symbolName.replace("^", "_")
def typeAllEquations(eqs, typeForSymbol):
"""
Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
Additionally returns sets of all fields which are read/written
:param eqs: list of equations
:param typeForSymbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
:return: ``fieldsRead, fieldsWritten, typedEquations`` set of read fields, set of written fields, list of equations
where symbols have been replaced by typed symbols
"""
fieldsWritten = set()
fieldsRead = set()
def processRhs(term):
"""Replaces Symbols by:
- TypedSymbol if symbol is not a field access
"""
if isinstance(term, Field.Access):
fieldsRead.add(term.field)
return term
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(symbolNameToVariableName(term.name), typeForSymbol[term.name])
else:
newArgs = [processRhs(arg) for arg in term.args]
return term.func(*newArgs) if newArgs else term
def processLhs(term):
"""Replaces symbol by TypedSymbol and adds field to fieldsWriten"""
if isinstance(term, Field.Access):
fieldsWritten.add(term.field)
return term
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, typeForSymbol[term.name])
else:
assert False, "Expected a symbol as left-hand-side"
typedEquations = []
for eq in eqs:
if isinstance(eq, sp.Eq):
newLhs = processLhs(eq.lhs)
newRhs = processRhs(eq.rhs)
typedEquations.append(ast.SympyAssignment(newLhs, newRhs))
else:
assert isinstance(eq, ast.Node), "Only equations and ast nodes are allowed in input"
typedEquations.append(eq)
typedEquations = typedEquations
return fieldsRead, fieldsWritten, typedEquations
# --------------------------------------- Helper Functions -------------------------------------------------------------
def typingFromSympyInspection(eqs, defaultType="double"):
"""
Creates a default symbol name to type mapping.
If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
:param eqs: list of equations
:param defaultType: the type for non-boolean symbols
:return: dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: defaultType)
for eq in eqs:
# problematic case here is when rhs is a symbol: then it is impossible to decide here without
# further information what type the left hand side is - default fallback is the dict value then
if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
result[eq.lhs.name] = "bool"
return result
def getNextParentOfType(node, parentType):
"""
Traverses the AST nodes parents until a parent of given type was found. If no such parent is found, None is returned
"""
parent = node.parent
while parent is not None:
if isinstance(parent, parentType):
return parent
parent = parent.parent
return None
def getOptimalLoopOrdering(fields):
"""
Determines the optimal loop order for a given set of fields.
If the fields have different memory layout or different sizes an exception is thrown.
:param fields: sequence of fields
:return: list of coordinate ids, where the first list entry should be the outermost loop
"""
assert len(fields) > 0
refField = next(iter(fields))
for field in fields:
if field.spatialDimensions != refField.spatialDimensions:
raise ValueError("All fields have to have the same number of spatial dimensions")
layouts = set([field.layout for field in fields])
if len(layouts) > 1:
raise ValueError("Due to different layout of the fields no optimal loop ordering exists " + str(layouts))
layout = list(layouts)[0]
return list(layout)
def getLoopHierarchy(astNode):
"""Determines the loop structure around a given AST node.
:param astNode: the AST node
:return: list of coordinate ids, where the first list entry is the innermost loop
"""
result = []
node = astNode
while node is not None:
node = getNextParentOfType(node, ast.LoopOverCoordinate)
if node:
result.append(node.coordinateToLoopOver)
return reversed(result)
\ No newline at end of file
import sympy as sp
from sympy.core.cache import cacheit
class TypedSymbol(sp.Symbol):
def __new__(cls, name, *args, **kwds):
obj = TypedSymbol.__xnew_cached_(cls, name, *args, **kwds)
return obj
def __new_stage2__(cls, name, dtype):
obj = super(TypedSymbol, cls).__xnew__(cls, name)
obj._dtype = dtype
return obj
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
@property
def dtype(self):
return self._dtype
def _hashable_content(self):
superClassContents = list(super(TypedSymbol, self)._hashable_content())
t = tuple(superClassContents + [hash(repr(self._dtype))])
return t
def __getnewargs__(self):
return self.name, self.dtype
_c_dtype_dict = {0: 'int', 1: 'double', 2: 'float'}
_dtype_dict = {'int': 0, 'double': 1, 'float': 2}
class DataType(object):
def __init__(self, dtype):
self.alias = True
self.const = False
self.ptr = False
if isinstance(dtype, str):
self.dtype = _dtype_dict[dtype]
else:
self.dtype = dtype
def __repr__(self):
return "{!s} {!s}{!s} {!s}".format("const" if self.const else "", _c_dtype_dict[self.dtype],
"*" if self.ptr else "", "__restrict__" if not self.alias else "")