Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Showing 672 additions and 1064 deletions
%% Cell type:code id: tags:
``` python
import pytest
pytest.importorskip('waLBerla')
```
%% Cell type:code id: tags:
``` python
from pystencils.session import *
from time import perf_counter
from statistics import median
from functools import partial
```
%% Cell type:markdown id: tags:
## Benchmark for Python call overhead
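%% Cell type:markdown id: tags:
The cells below compare the cost of calling a compiled pystencils kernel directly with the cost of calling it through a data handling object. All variants follow the same measurement pattern, distilled here as an illustrative sketch only (this helper is not used by the benchmark functions below):
%% Cell type:code id: tags:
``` python
def time_per_call(fn, repeats=100):
    # time `repeats` calls of fn and report the mean wall-clock time per call
    start = perf_counter()
    for _ in range(repeats):
        fn()
    return (perf_counter() - start) / repeats
```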
%% Cell type:code id: tags:
``` python
inner_repeats = 100
outer_repeats = 5
sizes = [2**i for i in range(1, 8)]
sizes
```
%% Output
$\displaystyle \left[ 2, \ 4, \ 8, \ 16, \ 32, \ 64, \ 128\right]$
[2, 4, 8, 16, 32, 64, 128]
%% Cell type:code id: tags:
``` python
def benchmark_pure(domain_size, extract_first=False):
    # Time direct calls to a compiled kernel; with extract_first=True the raw
    # compiled function (kernel.kernel) is called, bypassing the Python wrapper.
    src = np.zeros(domain_size)
    dst = np.zeros_like(src)
    f_src, f_dst = ps.fields("src, dst", src=src, dst=dst)
    kernel = ps.create_kernel(ps.Assignment(f_dst.center, f_src.center)).compile()
    if extract_first:
        kernel = kernel.kernel
        start = perf_counter()
        for i in range(inner_repeats):
            kernel(src=src, dst=dst)
            src, dst = dst, src
        end = perf_counter()
    else:
        start = perf_counter()
        for i in range(inner_repeats):
            kernel(src=src, dst=dst)
            src, dst = dst, src
        end = perf_counter()
    return (end - start) / inner_repeats


def benchmark_datahandling(domain_size, parallel=False):
    # Same kernel, but allocation, kernel launch and field swapping go through
    # a (serial or waLBerla-parallel) data handling object.
    dh = ps.create_data_handling(domain_size, parallel=parallel)
    f_src = dh.add_array('src')
    f_dst = dh.add_array('dst')
    kernel = ps.create_kernel(ps.Assignment(f_dst.center, f_src.center)).compile()
    start = perf_counter()
    for i in range(inner_repeats):
        dh.run_kernel(kernel)
        dh.swap('src', 'dst')
    end = perf_counter()
    return (end - start) / inner_repeats


name_to_func = {
    'pure_extract': partial(benchmark_pure, extract_first=True),
    'pure_no_extract': partial(benchmark_pure, extract_first=False),
    'dh_serial': partial(benchmark_datahandling, parallel=False),
    'dh_parallel': partial(benchmark_datahandling, parallel=True),
}
```
%% Cell type:code id: tags:
``` python
result = {'block_size': [],
          'name': [],
          'time': []}

for bs in sizes:
    print("Computing size ", bs)
    for name, func in name_to_func.items():
        for i in range(outer_repeats):
            time = func((bs, bs))
            result['block_size'].append(bs)
            result['name'].append(name)
            result['time'].append(time)
%% Output
Computing size 2
Computing size 4
Computing size 8
Computing size 16
Computing size 32
Computing size 64
Computing size 128
%% Cell type:code id: tags:
``` python
if 'is_test_run' not in globals():
    import pandas as pd
    import seaborn as sns

    data = pd.DataFrame.from_dict(result)

    # left: logarithmic time axis, right: linear time axis
    plt.subplot(1, 2, 1)
    sns.barplot(x='block_size', y='time', hue='name', data=data, alpha=0.6)
    plt.yscale('log')

    plt.subplot(1, 2, 2)
    sns.barplot(x='block_size', y='time', hue='name', data=data, alpha=0.6)
%% Output
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import numpy as np
import sympy
from sympy.abc import k
import pystencils
from pystencils.data_types import create_type
def test_sum():
    sum = sympy.Sum(k, (k, 1, 100))
    expanded_sum = sum.doit()
    print(sum)
    print(expanded_sum)

    x = pystencils.fields('x: float32[1d]')
    assignments = pystencils.AssignmentCollection({
        x.center(): sum
    })

    ast = pystencils.create_kernel(assignments)
    code = str(pystencils.get_code_obj(ast))
    kernel = ast.compile()
    print(code)
    assert 'double sum' in code

    array = np.zeros((10,), np.float32)
    kernel(x=array)
    assert np.allclose(array, int(expanded_sum) * np.ones_like(array))


def test_sum_use_float():
    sum = sympy.Sum(k, (k, 1, 100))
    expanded_sum = sum.doit()
    print(sum)
    print(expanded_sum)

    x = pystencils.fields('x: float32[1d]')
    assignments = pystencils.AssignmentCollection({
        x.center(): sum
    })

    ast = pystencils.create_kernel(assignments, data_type=create_type('float32'))
    code = str(pystencils.get_code_obj(ast))
    kernel = ast.compile()
    print(code)
    print(pystencils.get_code_obj(ast))
    assert 'float sum' in code

    array = np.zeros((10,), np.float32)
    kernel(x=array)
    assert np.allclose(array, int(expanded_sum) * np.ones_like(array))


def test_product():
    k = pystencils.TypedSymbol('k', create_type('int64'))

    sum = sympy.Product(k, (k, 1, 10))
    expanded_sum = sum.doit()
    print(sum)
    print(expanded_sum)

    x = pystencils.fields('x: int64[1d]')
    assignments = pystencils.AssignmentCollection({
        x.center(): sum
    })

    ast = pystencils.create_kernel(assignments)
    code = pystencils.get_code_str(ast)
    kernel = ast.compile()
    print(code)
    assert 'int64_t product' in code

    array = np.zeros((10,), np.int64)
    kernel(x=array)
    assert np.allclose(array, int(expanded_sum) * np.ones_like(array))


def test_prod_var_limit():
    k = pystencils.TypedSymbol('k', create_type('int64'))
    limit = pystencils.TypedSymbol('limit', create_type('int64'))

    sum = sympy.Sum(k, (k, 1, limit))
    expanded_sum = sum.replace(limit, 100).doit()
    print(sum)
    print(expanded_sum)

    x = pystencils.fields('x: int64[1d]')
    assignments = pystencils.AssignmentCollection({
        x.center(): sum
    })

    ast = pystencils.create_kernel(assignments)
    pystencils.show_code(ast)
    kernel = ast.compile()

    array = np.zeros((10,), np.int64)
    kernel(x=array, limit=100)
    assert np.allclose(array, int(expanded_sum) * np.ones_like(array))
import pystencils as ps
from pystencils import TypedSymbol
from pystencils.astnodes import LoopOverCoordinate, SympyAssignment
from pystencils.data_types import create_type
from pystencils.transformations import filtered_tree_iteration, get_loop_hierarchy, get_loop_counter_symbol_hierarchy
def test_loop_information():
    f, g = ps.fields("f, g: double[2D]")
    update_rule = ps.Assignment(g[0, 0], f[0, 0])

    ast = ps.create_kernel(update_rule)
    inner_loops = [l for l in filtered_tree_iteration(ast, LoopOverCoordinate, stop_type=SympyAssignment)
                   if l.is_innermost_loop]

    loop_order = []
    for i in get_loop_hierarchy(inner_loops[0].args[0]):
        loop_order.append(i)
    assert loop_order == [0, 1]

    loop_symbols = get_loop_counter_symbol_hierarchy(inner_loops[0].args[0])
    assert loop_symbols == [TypedSymbol("ctr_1", create_type("int"), nonnegative=True),
                            TypedSymbol("ctr_0", create_type("int"), nonnegative=True)]
from sympy.abc import a, b, c, d, e, f
import pystencils
from pystencils.data_types import cast_func, create_type
def test_type_interference():
    x = pystencils.fields('x: float32[3d]')
    assignments = pystencils.AssignmentCollection({
        a: cast_func(10, create_type('float64')),
        b: cast_func(10, create_type('uint16')),
        e: 11,
        c: b,
        f: c + b,
        d: c + b + x.center + e,
        x.center: c + b + x.center
    })

    ast = pystencils.create_kernel(assignments)
    code = str(pystencils.get_code_str(ast))
    assert 'double a' in code
    assert 'uint16_t b' in code
    assert 'uint16_t f' in code
    assert 'int64_t e' in code
import pytest
import numpy as np
import sympy as sp
import pystencils as ps
from pystencils.backends.simd_instruction_sets import (get_cacheline_size, get_supported_instruction_sets,
get_vector_instruction_set)
from pystencils.data_types import cast_func, VectorType
from pystencils.enums import Target
supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_vectorisation_varying_arch(instruction_set):
    shape = (9, 9, 3)
    arr = np.ones(shape, order='f')

    @ps.kernel
    def update_rule(s):
        f = ps.fields("f(3) : [2D]", f=arr)
        s.tmp0 @= f(0)
        s.tmp1 @= f(1)
        s.tmp2 @= f(2)
        f0, f1, f2 = f(0), f(1), f(2)
        f0 @= 2 * s.tmp0
        f1 @= 2 * s.tmp0
        f2 @= 2 * s.tmp0

    config = ps.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
    ast = ps.create_kernel(update_rule, config=config)
    kernel = ast.compile()
    kernel(f=arr)
    np.testing.assert_equal(arr, 2)
@pytest.mark.parametrize('dtype', ('float', 'double'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_vectorized_abs(instruction_set, dtype):
    """Some instruction sets have abs, some don't.
       Furthermore, the special treatment of unary minus makes this data type-sensitive too.
    """
    arr = np.ones((2 ** 2 + 2, 2 ** 3 + 2), dtype=np.float64 if dtype == 'double' else np.float32)
    arr[-3:, :] = -1

    f, g = ps.fields(f=arr, g=arr)
    update_rule = [ps.Assignment(g.center(), sp.Abs(f.center()))]

    config = ps.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
    ast = ps.create_kernel(update_rule, config=config)
    func = ast.compile()

    dst = np.zeros_like(arr)
    func(g=dst, f=arr)
    np.testing.assert_equal(np.sum(dst[1:-1, 1:-1]), 2 ** 2 * 2 ** 3)
@pytest.mark.parametrize('dtype', ('float', 'double'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_strided(instruction_set, dtype):
    f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]")
    update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
    if 'storeS' not in get_vector_instruction_set(dtype, instruction_set) \
            and not instruction_set in ['avx512', 'rvv'] and not instruction_set.startswith('sve'):
        with pytest.warns(UserWarning) as warn:
            config = ps.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
            ast = ps.create_kernel(update_rule, config=config)
            assert 'Could not vectorize loop' in warn[0].message.args[0]
    else:
        with pytest.warns(None) as warn:
            config = ps.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
            ast = ps.create_kernel(update_rule, config=config)
            assert len(warn) == 0

    func = ast.compile()
    ref_func = ps.create_kernel(update_rule).compile()

    arr = np.random.random((23 + 2, 17 + 2)).astype(np.float64 if dtype == 'double' else np.float32)
    dst = np.zeros_like(arr, dtype=np.float64 if dtype == 'double' else np.float32)
    ref = np.zeros_like(arr, dtype=np.float64 if dtype == 'double' else np.float32)

    func(g=dst, f=arr)
    ref_func(g=ref, f=arr)
    np.testing.assert_almost_equal(dst, ref, 13 if dtype == 'double' else 5)
@pytest.mark.parametrize('dtype', ('float', 'double'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('gl_field, gl_kernel', [(1, 0), (0, 1), (1, 1)])
def test_alignment_and_correct_ghost_layers(gl_field, gl_kernel, instruction_set, dtype):
    dtype = np.float64 if dtype == 'double' else np.float32

    domain_size = (128, 128)
    dh = ps.create_data_handling(domain_size, periodicity=(True, True), default_target=Target.CPU)
    src = dh.add_array("src", values_per_cell=1, dtype=dtype, ghost_layers=gl_field, alignment=True)
    dh.fill(src.name, 1.0, ghost_layers=True)
    dst = dh.add_array("dst", values_per_cell=1, dtype=dtype, ghost_layers=gl_field, alignment=True)
    dh.fill(dst.name, 1.0, ghost_layers=True)

    update_rule = ps.Assignment(dst[0, 0], src[0, 0])
    opt = {'instruction_set': instruction_set, 'assume_aligned': True,
           'nontemporal': True, 'assume_inner_stride_one': True}
    config = ps.CreateKernelConfig(target=dh.default_target, cpu_vectorize_info=opt, ghost_layers=gl_kernel)
    ast = ps.create_kernel(update_rule, config=config)
    kernel = ast.compile()

    if gl_kernel != gl_field:
        with pytest.raises(ValueError):
            dh.run_kernel(kernel)
    else:
        dh.run_kernel(kernel)
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_cacheline_size(instruction_set):
    cacheline_size = get_cacheline_size(instruction_set)
    if cacheline_size is None and instruction_set in ['sse', 'avx', 'avx512', 'rvv']:
        pytest.skip()

    instruction_set = get_vector_instruction_set('double', instruction_set)
    vector_size = instruction_set['bytes']
    assert cacheline_size > 8 and cacheline_size < 0x100000, "Cache line size is implausible"
    if type(vector_size) is int:
        assert cacheline_size % vector_size == 0, "Cache line size should be multiple of vector size"
    assert cacheline_size & (cacheline_size - 1) == 0, "Cache line size is not a power of 2"
# test_vectorization is not parametrized because it is supposed to run without pytest, so we parametrize it here
from pystencils_tests import test_vectorization
@pytest.mark.parametrize('instruction_set', sorted(set(supported_instruction_sets) - set([test_vectorization.instruction_set])))
@pytest.mark.parametrize('function', [f for f in test_vectorization.__dict__ if f.startswith('test_') and f != 'test_hardware_query'])
def test_vectorization_other(instruction_set, function):
    test_vectorization.__dict__[function](instruction_set)
 [pytest]
+testpaths = src tests doc/notebooks
+pythonpath = src
 python_files = test_*.py *_test.py scenario_*.py
 norecursedirs = *.egg-info .git .cache .ipynb_checkpoints htmlcov
 addopts = --doctest-modules --durations=20 --cov-config pytest.ini
 markers =
-    kerncraft: tests depending on kerncraft
+    longrun: tests only run at night since they have large execution time
     notebook: mark for notebooks
 # these warnings all come from third party libraries.
 filterwarnings =
@@ -13,23 +15,25 @@ filterwarnings =
     ignore:.*is a deprecated alias for the builtin `bool`:DeprecationWarning
     ignore:'contextfilter' is renamed to 'pass_context':DeprecationWarning
     ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc':DeprecationWarning
+    ignore:Animation was deleted without rendering anything:UserWarning

 [run]
 branch = True
-source = pystencils
-         pystencils_tests
+source = src/pystencils
+         tests

 omit = doc/*
-       pystencils_tests/*
+       tests/*
        setup.py
+       quicktest.py
        conftest.py
        versioneer.py
-       pystencils/jupytersetup.py
-       pystencils/cpu/msvc_detection.py
-       pystencils/sympy_gmpy_bug_workaround.py
-       pystencils/cache.py
-       pystencils/pacxx/benchmark.py
-       pystencils/_version.py
+       src/pystencils/jupytersetup.py
+       src/pystencils/cpu/msvc_detection.py
+       src/pystencils/sympy_gmpy_bug_workaround.py
+       src/pystencils/cache.py
+       src/pystencils/pacxx/benchmark.py
+       src/pystencils/_version.py
        venv/

 [report]
@@ -52,7 +56,7 @@ exclude_lines =
     if __name__ == .__main__.:

 skip_covered = True
-fail_under = 87
+fail_under = 85

 [html]
 directory = coverage_report
#!/usr/bin/env python3

from contextlib import redirect_stdout
import io

from tests.test_quicktests import (
    test_basic_kernel,
    test_basic_blocking_staggered,
    test_basic_vectorization,
)

quick_tests = [
    test_basic_kernel,
    test_basic_blocking_staggered,
    test_basic_vectorization,
]

if __name__ == "__main__":
    print("Running pystencils quicktests")
    for qt in quick_tests:
        print(f" -> {qt.__name__}")
        with redirect_stdout(io.StringIO()):
            qt()
# See the docstring in versioneer.py for instructions. Note that you must
# re-run 'versioneer.py setup' after changing this section, and commit the
# resulting files.
[versioneer]
VCS = git
style = pep440
versionfile_source = pystencils/_version.py
versionfile_build = pystencils/_version.py
tag_prefix = release/
parentdir_prefix = pystencils-
-import distutils
-import io
-import os
-from contextlib import redirect_stdout
-from importlib import import_module
-
-import setuptools
+from setuptools import setup, __version__ as setuptools_version
+
+if int(setuptools_version.split('.')[0]) < 61:
+    raise Exception(
+        "[ERROR] pystencils requires at least setuptools version 61 to install.\n"
+        "If this error occurs during an installation via pip, it is likely that there is a conflict between "
+        "versions of setuptools installed by pip and the system package manager. "
+        "In this case, it is recommended to install pystencils into a virtual environment instead."
+    )

 import versioneer

-try:
-    import cython  # noqa
-    USE_CYTHON = True
-except ImportError:
-    USE_CYTHON = False
-
-quick_tests = [
-    'test_datahandling.test_kernel',
-    'test_blocking_staggered.test_blocking_staggered',
-    'test_blocking_staggered.test_blocking_staggered',
-    'test_vectorization.test_vectorization_variable_size',
-]
-
-
-class SimpleTestRunner(distutils.cmd.Command):
-    """A custom command to run selected tests"""
-    description = 'run some quick tests'
-    user_options = []
-
-    @staticmethod
-    def _run_tests_in_module(test):
-        """Short test runner function - to work also if py.test is not installed."""
-        test = f'pystencils_tests.{test}'
-        mod, function_name = test.rsplit('.', 1)
-        if isinstance(mod, str):
-            mod = import_module(mod)
-
-        func = getattr(mod, function_name)
-        print(f" -> {function_name} in {mod.__name__}")
-        with redirect_stdout(io.StringIO()):
-            func()
-
-    def initialize_options(self):
-        pass
-
-    def finalize_options(self):
-        pass
-
-    def run(self):
-        """Run command."""
-        for test in quick_tests:
-            self._run_tests_in_module(test)
-
-
-def readme():
-    with open('README.md') as f:
-        return f.read()
-
-
-def cython_extensions(*extensions):
-    from distutils.extension import Extension
-    if USE_CYTHON:
-        ext = '.pyx'
-        result = [Extension(e, [os.path.join(*e.split(".")) + ext]) for e in extensions]
-        from Cython.Build import cythonize
-        result = cythonize(result, language_level=3)
-        return result
-    elif all([os.path.exists(os.path.join(*e.split(".")) + '.c') for e in extensions]):
-        ext = '.c'
-        result = [Extension(e, [os.path.join(*e.split(".")) + ext]) for e in extensions]
-        return result
-    else:
-        return None
-
-
 def get_cmdclass():
-    cmdclass = {"quicktest": SimpleTestRunner}
-    cmdclass.update(versioneer.get_cmdclass())
-    return cmdclass
-
-
-setuptools.setup(name='pystencils',
-                 description='Speeding up stencil computations on CPUs and GPUs',
-                 version=versioneer.get_version(),
-                 long_description=readme(),
-                 long_description_content_type="text/markdown",
-                 author='Martin Bauer, Jan Hönig, Markus Holzer',
-                 license='AGPLv3',
-                 author_email='cs10-codegen@fau.de',
-                 url='https://i10git.cs.fau.de/pycodegen/pystencils/',
-                 packages=['pystencils'] + ['pystencils.' + s for s in setuptools.find_packages('pystencils')],
-                 install_requires=['sympy>=1.5.1,<=1.9', 'numpy>=1.8.0', 'appdirs', 'joblib'],
-                 package_data={'pystencils': ['include/*.h',
-                                              'kerncraft_coupling/templates/*',
-                                              'backends/cuda_known_functions.txt',
-                                              'backends/opencl1.1_known_functions.txt',
-                                              'boundaries/createindexlistcython.c',
-                                              'boundaries/createindexlistcython.pyx']},
-                 ext_modules=cython_extensions("pystencils.boundaries.createindexlistcython"),
-                 classifiers=[
-                     'Development Status :: 4 - Beta',
-                     'Framework :: Jupyter',
-                     'Topic :: Software Development :: Code Generators',
-                     'Topic :: Scientific/Engineering :: Physics',
-                     'Intended Audience :: Developers',
-                     'Intended Audience :: Science/Research',
-                     'License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)',
-                 ],
-                 project_urls={
-                     "Bug Tracker": "https://i10git.cs.fau.de/pycodegen/pystencils/issues",
-                     "Documentation": "http://pycodegen.pages.walberla.net/pystencils/",
-                     "Source Code": "https://i10git.cs.fau.de/pycodegen/pystencils",
-                 },
-                 extras_require={
-                     'gpu': ['pycuda'],
-                     'opencl': ['pyopencl'],
-                     'alltrafos': ['islpy', 'py-cpuinfo'],
-                     'bench_db': ['blitzdb', 'pymongo', 'pandas'],
-                     'interactive': ['matplotlib', 'ipy_table', 'imageio', 'jupyter', 'pyevtk', 'rich', 'graphviz'],
-                     'autodiff': ['pystencils-autodiff'],
-                     'doc': ['sphinx', 'sphinx_rtd_theme', 'nbsphinx',
-                             'sphinxcontrib-bibtex', 'sphinx_autodoc_typehints', 'pandoc'],
-                     'use_cython': ['Cython'],
-                     'kerncraft': ['osaca', 'kerncraft'],
-                     'llvm_jit': ['llvmlite']
-                 },
-                 tests_require=['pytest',
-                                'pytest-cov',
-                                'pytest-html',
-                                'ansi2html',
-                                'pytest-xdist',
-                                'flake8',
-                                'nbformat',
-                                'nbconvert',
-                                'ipython',
-                                'randomgen>=1.18'],
-                 python_requires=">=3.8",
-                 cmdclass=get_cmdclass()
-                 )
+    return versioneer.get_cmdclass()
+
+
+setup(
+    version=versioneer.get_version(),
+    cmdclass=get_cmdclass(),
+)
@@ -2,46 +2,39 @@
 from .enums import Backend, Target
 from . import fd
 from . import stencil as stencil
-from .assignment import Assignment, assignment_from_stencil
-from .data_types import TypedSymbol
-from .datahandling import create_data_handling
+from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
+from .typing.typed_sympy import TypedSymbol
 from .display_utils import get_code_obj, get_code_str, show_code, to_dot
 from .field import Field, FieldType, fields
-from .kernel_decorator import kernel
-from .kernelcreation import (
-    CreateKernelConfig, create_domain_kernel, create_indexed_kernel, create_kernel, create_staggered_kernel)
+from .config import CreateKernelConfig
+from .cache import clear_cache
+from .kernel_decorator import kernel, kernel_config
+from .kernelcreation import create_kernel, create_staggered_kernel
 from .simp import AssignmentCollection
 from .slicing import make_slice
 from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered
 from .sympyextensions import SymbolCreator
+from .datahandling import create_data_handling

-try:
-    import pystencils_autodiff
-    autodiff = pystencils_autodiff
-except ImportError:
-    pass

 __all__ = ['Field', 'FieldType', 'fields',
            'TypedSymbol',
            'make_slice',
-           'create_kernel', 'create_domain_kernel', 'create_indexed_kernel', 'create_staggered_kernel',
            'CreateKernelConfig',
+           'create_kernel', 'create_staggered_kernel',
            'Target', 'Backend',
            'show_code', 'to_dot', 'get_code_obj', 'get_code_str',
            'AssignmentCollection',
-           'Assignment',
+           'Assignment', 'AddAugmentedAssignment',
            'assignment_from_stencil',
            'SymbolCreator',
            'create_data_handling',
-           'kernel',
+           'clear_cache',
+           'kernel', 'kernel_config',
            'x_', 'y_', 'z_',
            'x_staggered', 'y_staggered', 'z_staggered',
            'x_vector', 'x_staggered_vector',
            'fd',
            'stencil']

-from ._version import get_versions
-
-__version__ = get_versions()['version']
-del get_versions
+from . import _version
+__version__ = _version.get_versions()['version']
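The reorganized `__init__` above re-exports the new kernel-creation entry points (`CreateKernelConfig` from `pystencils.config`, `create_kernel` and `create_staggered_kernel` from `pystencils.kernelcreation`). A minimal usage sketch of that public API, assuming only the exports listed in the diff (field names and the domain size are illustrative):
``` python
import numpy as np
import pystencils as ps

# two 2D double-precision fields and a trivial copy kernel
src, dst = ps.fields("src, dst: double[2D]")
update = ps.Assignment(dst.center, src.center)

# configuration object re-exported from pystencils.config
config = ps.CreateKernelConfig(target=ps.Target.CPU)
kernel = ps.create_kernel(update, config=config).compile()

a = np.zeros((8, 8))
b = np.zeros_like(a)
kernel(src=a, dst=b)
```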
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
# directories (produced by setup.py build) will contain a much shorter file # directories (produced by setup.py build) will contain a much shorter file
# that just contains the computed version number. # that just contains the computed version number.
# This file is released into the public domain. Generated by # This file is released into the public domain.
# versioneer-0.19 (https://github.com/python-versioneer/python-versioneer) # Generated by versioneer-0.29
# https://github.com/python-versioneer/python-versioneer
"""Git implementation of _version.py.""" """Git implementation of _version.py."""
...@@ -15,9 +16,11 @@ import os ...@@ -15,9 +16,11 @@ import os
import re import re
import subprocess import subprocess
import sys import sys
from typing import Any, Callable, Dict, List, Optional, Tuple
import functools
def get_keywords(): def get_keywords() -> Dict[str, str]:
"""Get the keywords needed to look up the version information.""" """Get the keywords needed to look up the version information."""
# these strings will be replaced by git during git-archive. # these strings will be replaced by git during git-archive.
# setup.py/versioneer.py will grep for the variable names, so they must # setup.py/versioneer.py will grep for the variable names, so they must
...@@ -33,8 +36,15 @@ def get_keywords(): ...@@ -33,8 +36,15 @@ def get_keywords():
class VersioneerConfig: class VersioneerConfig:
"""Container for Versioneer configuration parameters.""" """Container for Versioneer configuration parameters."""
VCS: str
style: str
tag_prefix: str
parentdir_prefix: str
versionfile_source: str
verbose: bool
def get_config():
def get_config() -> VersioneerConfig:
"""Create, populate and return the VersioneerConfig() object.""" """Create, populate and return the VersioneerConfig() object."""
# these strings are filled in when 'setup.py versioneer' creates # these strings are filled in when 'setup.py versioneer' creates
# _version.py # _version.py
...@@ -43,7 +53,7 @@ def get_config(): ...@@ -43,7 +53,7 @@ def get_config():
cfg.style = "pep440" cfg.style = "pep440"
cfg.tag_prefix = "release/" cfg.tag_prefix = "release/"
cfg.parentdir_prefix = "pystencils-" cfg.parentdir_prefix = "pystencils-"
cfg.versionfile_source = "pystencils/_version.py" cfg.versionfile_source = "src/pystencils/_version.py"
cfg.verbose = False cfg.verbose = False
return cfg return cfg
...@@ -52,13 +62,13 @@ class NotThisMethod(Exception): ...@@ -52,13 +62,13 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario.""" """Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY = {} LONG_VERSION_PY: Dict[str, str] = {}
HANDLERS = {} HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs, method): # decorator def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator
"""Create decorator to mark a method as the handler of a VCS.""" """Create decorator to mark a method as the handler of a VCS."""
def decorate(f): def decorate(f: Callable) -> Callable:
"""Store f in HANDLERS[vcs][method].""" """Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS: if vcs not in HANDLERS:
HANDLERS[vcs] = {} HANDLERS[vcs] = {}
...@@ -67,22 +77,35 @@ def register_vcs_handler(vcs, method): # decorator ...@@ -67,22 +77,35 @@ def register_vcs_handler(vcs, method): # decorator
return decorate return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, def run_command(
env=None): commands: List[str],
args: List[str],
cwd: Optional[str] = None,
verbose: bool = False,
hide_stderr: bool = False,
env: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[str], Optional[int]]:
"""Call the given command(s).""" """Call the given command(s)."""
assert isinstance(commands, list) assert isinstance(commands, list)
p = None process = None
for c in commands:
popen_kwargs: Dict[str, Any] = {}
if sys.platform == "win32":
# This hides the console window if pythonw.exe is used
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
popen_kwargs["startupinfo"] = startupinfo
for command in commands:
try: try:
dispcmd = str([c] + args) dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git # remember shell=False, so use git.cmd on windows, not just git
p = subprocess.Popen([c] + args, cwd=cwd, env=env, process = subprocess.Popen([command] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr stderr=(subprocess.PIPE if hide_stderr
else None)) else None), **popen_kwargs)
break break
except EnvironmentError: except OSError as e:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT: if e.errno == errno.ENOENT:
continue continue
if verbose: if verbose:
...@@ -93,16 +116,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, ...@@ -93,16 +116,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
if verbose: if verbose:
print("unable to find command, tried %s" % (commands,)) print("unable to find command, tried %s" % (commands,))
return None, None return None, None
stdout = p.communicate()[0].strip().decode() stdout = process.communicate()[0].strip().decode()
if p.returncode != 0: if process.returncode != 0:
if verbose: if verbose:
print("unable to run %s (error)" % dispcmd) print("unable to run %s (error)" % dispcmd)
print("stdout was %s" % stdout) print("stdout was %s" % stdout)
return None, p.returncode return None, process.returncode
return stdout, p.returncode return stdout, process.returncode
def versions_from_parentdir(parentdir_prefix, root, verbose): def versions_from_parentdir(
parentdir_prefix: str,
root: str,
verbose: bool,
) -> Dict[str, Any]:
"""Try to determine the version from the parent directory name. """Try to determine the version from the parent directory name.
Source tarballs conventionally unpack into a directory that includes both Source tarballs conventionally unpack into a directory that includes both
...@@ -111,15 +138,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): ...@@ -111,15 +138,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
""" """
rootdirs = [] rootdirs = []
for i in range(3): for _ in range(3):
dirname = os.path.basename(root) dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix): if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):], return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None, "full-revisionid": None,
"dirty": False, "error": None, "date": None} "dirty": False, "error": None, "date": None}
else: rootdirs.append(root)
rootdirs.append(root) root = os.path.dirname(root) # up a level
root = os.path.dirname(root) # up a level
if verbose: if verbose:
print("Tried directories %s but none started with prefix %s" % print("Tried directories %s but none started with prefix %s" %
...@@ -128,39 +154,42 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): ...@@ -128,39 +154,42 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
@register_vcs_handler("git", "get_keywords") @register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs): def git_get_keywords(versionfile_abs: str) -> Dict[str, str]:
"""Extract version information from the given file.""" """Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these # the code embedded in _version.py can just fetch the value of these
# keywords. When used from setup.py, we don't want to import _version.py, # keywords. When used from setup.py, we don't want to import _version.py,
# so we do it with a regexp instead. This function is not used from # so we do it with a regexp instead. This function is not used from
# _version.py. # _version.py.
keywords = {} keywords: Dict[str, str] = {}
try: try:
f = open(versionfile_abs, "r") with open(versionfile_abs, "r") as fobj:
for line in f.readlines(): for line in fobj:
if line.strip().startswith("git_refnames ="): if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line) mo = re.search(r'=\s*"(.*)"', line)
if mo: if mo:
keywords["refnames"] = mo.group(1) keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="): if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line) mo = re.search(r'=\s*"(.*)"', line)
if mo: if mo:
keywords["full"] = mo.group(1) keywords["full"] = mo.group(1)
if line.strip().startswith("git_date ="): if line.strip().startswith("git_date ="):
mo = re.search(r'=\s*"(.*)"', line) mo = re.search(r'=\s*"(.*)"', line)
if mo: if mo:
keywords["date"] = mo.group(1) keywords["date"] = mo.group(1)
f.close() except OSError:
except EnvironmentError:
pass pass
return keywords return keywords
@register_vcs_handler("git", "keywords") @register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose): def git_versions_from_keywords(
keywords: Dict[str, str],
tag_prefix: str,
verbose: bool,
) -> Dict[str, Any]:
"""Get version information from git keywords.""" """Get version information from git keywords."""
if not keywords: if "refnames" not in keywords:
raise NotThisMethod("no keywords at all, weird") raise NotThisMethod("Short version file found")
date = keywords.get("date") date = keywords.get("date")
if date is not None: if date is not None:
# Use only the last line. Previous lines may contain GPG signature # Use only the last line. Previous lines may contain GPG signature
...@@ -179,11 +208,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): ...@@ -179,11 +208,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose: if verbose:
print("keywords are unexpanded, not using") print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball") raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")]) refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those. # just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: " TAG = "tag: "
tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags: if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use # Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d # a heuristic: assume all version tags have a digit. The old git %d
...@@ -192,7 +221,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): ...@@ -192,7 +221,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we # between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and # filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master". # "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r'\d', r)]) tags = {r for r in refs if re.search(r'\d', r)}
if verbose: if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags)) print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose: if verbose:
...@@ -201,6 +230,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): ...@@ -201,6 +230,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# sorting will prefer e.g. "2.0" over "2.0rc1" # sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix): if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):] r = ref[len(tag_prefix):]
# Filter out refs that exactly match prefix or that don't start
# with a number once the prefix is stripped (mostly a concern
# when prefix is '')
if not re.match(r'\d', r):
continue
if verbose: if verbose:
print("picking %s" % r) print("picking %s" % r)
return {"version": r, return {"version": r,
...@@ -216,7 +250,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): ...@@ -216,7 +250,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
@register_vcs_handler("git", "pieces_from_vcs") @register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): def git_pieces_from_vcs(
tag_prefix: str,
root: str,
verbose: bool,
runner: Callable = run_command
) -> Dict[str, Any]:
"""Get version from 'git describe' in the root of the source tree. """Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not* This only gets called if the git-archive 'subst' keywords were *not*
...@@ -227,8 +266,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): ...@@ -227,8 +266,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32": if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"] GITS = ["git.cmd", "git.exe"]
out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, # GIT_DIR can interfere with correct operation of Versioneer.
hide_stderr=True) # It may be intended to be passed to the Versioneer-versioned project,
# but that should not change where we get our version from.
env = os.environ.copy()
env.pop("GIT_DIR", None)
runner = functools.partial(runner, env=env)
_, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=not verbose)
if rc != 0: if rc != 0:
if verbose: if verbose:
print("Directory %s not under git control" % root) print("Directory %s not under git control" % root)
...@@ -236,24 +282,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): ...@@ -236,24 +282,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM) # if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", describe_out, rc = runner(GITS, [
"--always", "--long", "describe", "--tags", "--dirty", "--always", "--long",
"--match", "%s*" % tag_prefix], "--match", f"{tag_prefix}[[:digit:]]*"
cwd=root) ], cwd=root)
# --long was added in git-1.5.5 # --long was added in git-1.5.5
if describe_out is None: if describe_out is None:
raise NotThisMethod("'git describe' failed") raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip() describe_out = describe_out.strip()
full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None: if full_out is None:
raise NotThisMethod("'git rev-parse' failed") raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip() full_out = full_out.strip()
pieces = {} pieces: Dict[str, Any] = {}
pieces["long"] = full_out pieces["long"] = full_out
pieces["short"] = full_out[:7] # maybe improved later pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None pieces["error"] = None
branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
cwd=root)
# --abbrev-ref was added in git-1.6.3
if rc != 0 or branch_name is None:
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
branch_name = branch_name.strip()
if branch_name == "HEAD":
# If we aren't exactly on a branch, pick a branch which represents
# the current commit. If all else fails, we are on a branchless
# commit.
branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
# --contains was added in git-1.5.4
if rc != 0 or branches is None:
raise NotThisMethod("'git branch --contains' returned error")
branches = branches.split("\n")
# Remove the first line if we're running detached
if "(" in branches[0]:
branches.pop(0)
# Strip off the leading "* " from the list of branches.
branches = [branch[2:] for branch in branches]
if "master" in branches:
branch_name = "master"
elif not branches:
branch_name = None
else:
# Pick the first branch that is returned. Good or bad.
branch_name = branches[0]
pieces["branch"] = branch_name
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens. # TAG might have hyphens.
git_describe = describe_out git_describe = describe_out
...@@ -270,7 +349,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): ...@@ -270,7 +349,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# TAG-NUM-gHEX # TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
if not mo: if not mo:
# unparseable. Maybe git-describe is misbehaving? # unparsable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'" pieces["error"] = ("unable to parse git-describe output: '%s'"
% describe_out) % describe_out)
return pieces return pieces
...@@ -295,13 +374,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): ...@@ -295,13 +374,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else: else:
# HEX: no tags # HEX: no tags
pieces["closest-tag"] = None pieces["closest-tag"] = None
count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
cwd=root) pieces["distance"] = len(out.split()) # total number of commits
pieces["distance"] = int(count_out) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords() # commit date: see ISO-8601 comment in git_versions_from_keywords()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature # Use only the last line. Previous lines may contain GPG signature
# information. # information.
date = date.splitlines()[-1] date = date.splitlines()[-1]
...@@ -310,14 +387,14 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): ...@@ -310,14 +387,14 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
return pieces return pieces
def plus_or_dot(pieces): def plus_or_dot(pieces: Dict[str, Any]) -> str:
"""Return a + if we don't already have one, else return a .""" """Return a + if we don't already have one, else return a ."""
if "+" in pieces.get("closest-tag", ""): if "+" in pieces.get("closest-tag", ""):
return "." return "."
return "+" return "+"
def render_pep440(pieces): def render_pep440(pieces: Dict[str, Any]) -> str:
"""Build up version string, with post-release "local version identifier". """Build up version string, with post-release "local version identifier".
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
...@@ -342,23 +419,71 @@ def render_pep440(pieces): ...@@ -342,23 +419,71 @@ def render_pep440(pieces):
return rendered return rendered
def render_pep440_pre(pieces): def render_pep440_branch(pieces: Dict[str, Any]) -> str:
"""TAG[.post0.devDISTANCE] -- No -dirty. """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
The ".dev0" means not master branch. Note that .dev0 sorts backwards
(a feature branch will appear "older" than the master branch).
Exceptions: Exceptions:
1: no tags. 0.post0.devDISTANCE 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
""" """
if pieces["closest-tag"]: if pieces["closest-tag"]:
rendered = pieces["closest-tag"] rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0"
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:
"""Split pep440 version string at the post-release segment.
Returns the release segments before the post-release and the
post-release version number (or -1 if no post-release segment is present).
"""
vc = str.split(ver, ".post")
return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
def render_pep440_pre(pieces: Dict[str, Any]) -> str:
"""TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
if pieces["distance"]: if pieces["distance"]:
rendered += ".post0.dev%d" % pieces["distance"] # update the post release segment
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
rendered = tag_version
if post_version is not None:
rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
else:
rendered += ".post0.dev%d" % (pieces["distance"])
else:
# no commits, use the tag as the version
rendered = pieces["closest-tag"]
else: else:
# exception #1 # exception #1
rendered = "0.post0.dev%d" % pieces["distance"] rendered = "0.post0.dev%d" % pieces["distance"]
return rendered return rendered
def render_pep440_post(pieces): def render_pep440_post(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]+gHEX] . """TAG[.postDISTANCE[.dev0]+gHEX] .
The ".dev0" means dirty. Note that .dev0 sorts backwards The ".dev0" means dirty. Note that .dev0 sorts backwards
...@@ -385,7 +510,36 @@ def render_pep440_post(pieces): ...@@ -385,7 +510,36 @@ def render_pep440_post(pieces):
return rendered return rendered
def render_pep440_old(pieces): def render_pep440_post_branch(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
The ".dev0" means not master branch.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_old(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]] . """TAG[.postDISTANCE[.dev0]] .
The ".dev0" means dirty. The ".dev0" means dirty.
...@@ -407,7 +561,7 @@ def render_pep440_old(pieces): ...@@ -407,7 +561,7 @@ def render_pep440_old(pieces):
return rendered return rendered
def render_git_describe(pieces): def render_git_describe(pieces: Dict[str, Any]) -> str:
"""TAG[-DISTANCE-gHEX][-dirty]. """TAG[-DISTANCE-gHEX][-dirty].
Like 'git describe --tags --dirty --always'. Like 'git describe --tags --dirty --always'.
...@@ -427,7 +581,7 @@ def render_git_describe(pieces): ...@@ -427,7 +581,7 @@ def render_git_describe(pieces):
return rendered return rendered
def render_git_describe_long(pieces): def render_git_describe_long(pieces: Dict[str, Any]) -> str:
"""TAG-DISTANCE-gHEX[-dirty]. """TAG-DISTANCE-gHEX[-dirty].
Like 'git describe --tags --dirty --always -long'. Like 'git describe --tags --dirty --always -long'.
...@@ -447,7 +601,7 @@ def render_git_describe_long(pieces): ...@@ -447,7 +601,7 @@ def render_git_describe_long(pieces):
return rendered return rendered
def render(pieces, style): def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:
"""Render the given version pieces into the requested style.""" """Render the given version pieces into the requested style."""
if pieces["error"]: if pieces["error"]:
return {"version": "unknown", return {"version": "unknown",
...@@ -461,10 +615,14 @@ def render(pieces, style): ...@@ -461,10 +615,14 @@ def render(pieces, style):
if style == "pep440": if style == "pep440":
rendered = render_pep440(pieces) rendered = render_pep440(pieces)
elif style == "pep440-branch":
rendered = render_pep440_branch(pieces)
elif style == "pep440-pre": elif style == "pep440-pre":
rendered = render_pep440_pre(pieces) rendered = render_pep440_pre(pieces)
elif style == "pep440-post": elif style == "pep440-post":
rendered = render_pep440_post(pieces) rendered = render_pep440_post(pieces)
elif style == "pep440-post-branch":
rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old": elif style == "pep440-old":
rendered = render_pep440_old(pieces) rendered = render_pep440_old(pieces)
elif style == "git-describe": elif style == "git-describe":
...@@ -479,7 +637,7 @@ def render(pieces, style): ...@@ -479,7 +637,7 @@ def render(pieces, style):
"date": pieces.get("date")} "date": pieces.get("date")}
def get_versions(): def get_versions() -> Dict[str, Any]:
"""Get version information or return default if unable to do so.""" """Get version information or return default if unable to do so."""
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
# __file__, we can work backwards from there to the root. Some # __file__, we can work backwards from there to the root. Some
...@@ -500,7 +658,7 @@ def get_versions(): ...@@ -500,7 +658,7 @@ def get_versions():
# versionfile_source is the relative path from the top of the source # versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert # tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__. # this to find the root from __file__.
for i in cfg.versionfile_source.split('/'): for _ in cfg.versionfile_source.split('/'):
root = os.path.dirname(root) root = os.path.dirname(root)
except NameError: except NameError:
return {"version": "0+unknown", "full-revisionid": None, return {"version": "0+unknown", "full-revisionid": None,
......
import numpy as np import numpy as np
from pystencils.data_types import BasicType
def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True): def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True):
...@@ -21,26 +20,26 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o ...@@ -21,26 +20,26 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o
from pystencils.backends.simd_instruction_sets import (get_supported_instruction_sets, get_cacheline_size, from pystencils.backends.simd_instruction_sets import (get_supported_instruction_sets, get_cacheline_size,
get_vector_instruction_set) get_vector_instruction_set)
type_name = BasicType.numpy_name_to_c(np.dtype(dtype).name)
instruction_sets = get_supported_instruction_sets() instruction_sets = get_supported_instruction_sets()
if instruction_sets is None: if instruction_sets is None:
byte_alignment = 64 byte_alignment = 64
elif byte_alignment == 'cacheline': elif byte_alignment == 'cacheline':
cacheline_sizes = [get_cacheline_size(is_name) for is_name in instruction_sets] cacheline_sizes = [get_cacheline_size(is_name) for is_name in instruction_sets]
if all([s is None for s in cacheline_sizes]): if all([s is None for s in cacheline_sizes]) or \
widths = [get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize max([s for s in cacheline_sizes if s is not None]) > 0x100000:
widths = [get_vector_instruction_set(dtype, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets for is_name in instruction_sets
if type(get_vector_instruction_set(type_name, is_name)['width']) is int] if type(get_vector_instruction_set(dtype, is_name)['width']) is int]
byte_alignment = 64 if all([s is None for s in widths]) else max(widths) byte_alignment = 64 if all([s is None for s in widths]) else max(widths)
else: else:
byte_alignment = max([s for s in cacheline_sizes if s is not None]) byte_alignment = max([s for s in cacheline_sizes if s is not None])
elif not any([type(get_vector_instruction_set(type_name, is_name)['width']) is int elif not any([type(get_vector_instruction_set(dtype, is_name)['width']) is int
for is_name in instruction_sets]): for is_name in instruction_sets]):
byte_alignment = 64 byte_alignment = 64
else: else:
byte_alignment = max([get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize byte_alignment = max([get_vector_instruction_set(dtype, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets for is_name in instruction_sets
if type(get_vector_instruction_set(type_name, is_name)['width']) is int]) if type(get_vector_instruction_set(dtype, is_name)['width']) is int])
if (not align_inner_coordinate) or (not hasattr(shape, '__len__')): if (not align_inner_coordinate) or (not hasattr(shape, '__len__')):
size = np.prod(shape) size = np.prod(shape)
d = np.dtype(dtype) d = np.dtype(dtype)
...@@ -78,7 +77,7 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o ...@@ -78,7 +77,7 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o
return tmp return tmp
def aligned_zeros(shape, byte_alignment=True, dtype=float, byte_offset=0, order='C', align_inner_coordinate=True): def aligned_zeros(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True):
arr = aligned_empty(shape, dtype=dtype, byte_offset=byte_offset, arr = aligned_empty(shape, dtype=dtype, byte_offset=byte_offset,
order=order, byte_alignment=byte_alignment, align_inner_coordinate=align_inner_coordinate) order=order, byte_alignment=byte_alignment, align_inner_coordinate=align_inner_coordinate)
x = np.zeros((), arr.dtype) x = np.zeros((), arr.dtype)
...@@ -86,7 +85,7 @@ def aligned_zeros(shape, byte_alignment=True, dtype=float, byte_offset=0, order= ...@@ -86,7 +85,7 @@ def aligned_zeros(shape, byte_alignment=True, dtype=float, byte_offset=0, order=
return arr return arr
def aligned_ones(shape, byte_alignment=True, dtype=float, byte_offset=0, order='C', align_inner_coordinate=True): def aligned_ones(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True):
arr = aligned_empty(shape, dtype=dtype, byte_offset=byte_offset, arr = aligned_empty(shape, dtype=dtype, byte_offset=byte_offset,
order=order, byte_alignment=byte_alignment, align_inner_coordinate=align_inner_coordinate) order=order, byte_alignment=byte_alignment, align_inner_coordinate=align_inner_coordinate)
x = np.ones((), arr.dtype) x = np.ones((), arr.dtype)
......
import numpy as np
import sympy as sp
-from sympy.codegen.ast import Assignment
+from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
from sympy.printing.latex import LatexPrinter

-__all__ = ['Assignment', 'assignment_from_stencil']
+__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil']

def print_assignment_latex(printer, expr):
+    binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
    """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
    printed_lhs = printer.doprint(expr.lhs)
    printed_rhs = printer.doprint(expr.rhs)
-    return r"{printed_lhs} \leftarrow {printed_rhs}".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
+    return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"

def assignment_str(assignment):
-    return r"{lhs} ← {rhs}".format(lhs=assignment.lhs, rhs=assignment.rhs)
+    op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else ''
+    return fr"{assignment.lhs} {op} {assignment.rhs}"

_old_new = sp.codegen.ast.Assignment.__new__

+# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
    if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
        assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
@@ -31,20 +34,10 @@ Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex

-sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
-
-# Apparently, in SymPy 1.4 Assignment.__hash__ is not implemented. This has been fixed in current master
-try:
-    sympy_version = sp.__version__.split('.')
-    if int(sympy_version[0]) <= 1 and int(sympy_version[1]) <= 4:
-        def hash_fun(self):
-            return hash((self.lhs, self.rhs))
-        Assignment.__hash__ = hash_fun
-except Exception:
-    pass
+AugmentedAssignment.__str__ = assignment_str
+LatexPrinter._print_AugmentedAssignment = print_assignment_latex
+
+sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))

def assignment_from_stencil(stencil_array, input_field, output_field,
...
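For orientation, a minimal sketch of what the new augmented-assignment printing produces (assuming pystencils is imported so the monkey-patched `__str__`/LaTeX printers are active; this example is illustrative and not part of the diff):

```python
import sympy as sp
from sympy.codegen.ast import AddAugmentedAssignment

import pystencils  # noqa: F401  # importing pystencils installs the patched printers

x, y = sp.symbols("x y")

# With the patched assignment_str, the binary operator shows up in the output,
# e.g. something like "x += y" instead of a plain arrow.
print(AddAugmentedAssignment(x, y))
```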
@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp

-import pystencils
-from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
+from pystencils.assignment import Assignment
from pystencils.enums import Target, Backend
from pystencils.field import Field
-from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs
+from pystencils.typing import (create_type, get_next_parent_of_type,
+                               FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol, CFunction)

NodeOrExpr = Union['Node', sp.Expr]
@@ -193,6 +193,10 @@ class KernelFunction(Node):
# function that compiles the node to a Python callable, is set by the backends # function that compiles the node to a Python callable, is set by the backends
self._compile_function = compile_function self._compile_function = compile_function
self.assignments = assignments self.assignments = assignments
# If nontemporal stores are activated together with the Neon instruction set it results in cacheline zeroing
# For cacheline zeroing the information of the field size for each field is needed. Thus, in this case
# all field sizes are kernel parameters and not just the common field size used for the loops
self.use_all_written_field_sizes = False
@property @property
def target(self): def target(self):
@@ -228,13 +232,13 @@ class KernelFunction(Node):
    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
-        from pystencils.interpolation_astnodes import InterpolatorAccess
-        return set(o.field for o in itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess)))
+        return set(o.field for o in itertools.chain(self.atoms(ResolvedFieldAccess)))

    @property
    def fields_written(self) -> Set[Field]:
        assignments = self.atoms(SympyAssignment)
-        return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
+        return set().union(itertools.chain.from_iterable([f.field for f in a.lhs.free_symbols if hasattr(f, 'field')]
+                                                          for a in assignments))

    @property
    def fields_read(self) -> Set[Field]:
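A quick illustration of these properties (a hedged sketch, not part of the diff; it assumes the usual top-level pystencils API):

```python
import pystencils as ps

src, dst = ps.fields("src, dst: double[2D]")
ast = ps.create_kernel(ps.Assignment(dst.center, 2 * src.center))

# fields_written collects the fields behind all assignment left-hand sides,
# fields_read the remaining accessed fields.
print(ast.fields_written)  # expected: {dst}
print(ast.fields_read)     # expected: {src}
```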
@@ -248,6 +252,11 @@ class KernelFunction(Node):
This function is expensive, cache the result where possible! This function is expensive, cache the result where possible!
""" """
field_map = {f.name: f for f in self.fields_accessed} field_map = {f.name: f for f in self.fields_accessed}
sizes = set()
if self.use_all_written_field_sizes:
sizes = set().union(*(a.shape[:a.spatial_dimensions] for a in self.fields_written))
sizes = filter(lambda s: isinstance(s, FieldShapeSymbol), sizes)
def get_fields(symbol): def get_fields(symbol):
if hasattr(symbol, 'field_name'): if hasattr(symbol, 'field_name'):
@@ -257,9 +266,13 @@ class KernelFunction(Node):
return () return ()
argument_symbols = self._body.undefined_symbols - self.global_variables argument_symbols = self._body.undefined_symbols - self.global_variables
argument_symbols.update(sizes)
parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols] parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
if hasattr(self, 'indexing'): if hasattr(self, 'indexing'):
parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()] parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
# Exclude paramters of type CFunction. These parameters will result in a C function call that will be handled
# by including a respective header file in the compute kernel. Hence, it is not a free parameter.
parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)]
parameters.sort(key=lambda p: p.symbol.name) parameters.sort(key=lambda p: p.symbol.name)
return parameters return parameters
@@ -293,8 +306,10 @@ class SkipIteration(Node):
class Block(Node):
-    def __init__(self, nodes: List[Node]):
+    def __init__(self, nodes: Union[Node, List[Node]]):
        super(Block, self).__init__()
+        if not isinstance(nodes, list):
+            nodes = [nodes]
        self._nodes = nodes
        self.parent = None
        for n in self._nodes:
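A small usage sketch of the relaxed constructor (assumes the post-change `pystencils.astnodes` API; illustrative only):

```python
import sympy as sp
from pystencils.astnodes import Block, SympyAssignment

x, y = sp.symbols("x y")

# Both spellings are now accepted: a list of nodes as before, or a single node
# that is wrapped into a one-element list automatically.
b1 = Block([SympyAssignment(x, 2 * y)])
b2 = Block(SympyAssignment(x, 2 * y))
```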
@@ -333,14 +348,6 @@ class Block(Node):
assert self._nodes.count(insert_before) == 1 assert self._nodes.count(insert_before) == 1
idx = self._nodes.index(insert_before) idx = self._nodes.index(insert_before)
# move all assignment (definitions to the top)
if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
while idx > 0:
pn = self._nodes[idx - 1]
if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
idx -= 1
else:
break
if not if_not_exists or self._nodes[idx] != new_node: if not if_not_exists or self._nodes[idx] != new_node:
self._nodes.insert(idx, new_node) self._nodes.insert(idx, new_node)
@@ -349,14 +356,6 @@ class Block(Node):
assert self._nodes.count(insert_after) == 1 assert self._nodes.count(insert_after) == 1
idx = self._nodes.index(insert_after) + 1 idx = self._nodes.index(insert_after) + 1
# move all assignment (definitions to the top)
if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
while idx > 0:
pn = self._nodes[idx - 1]
if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
idx -= 1
else:
break
if not if_not_exists or not (self._nodes[idx - 1] == new_node if not if_not_exists or not (self._nodes[idx - 1] == new_node
or (idx < len(self._nodes) and self._nodes[idx] == new_node)): or (idx < len(self._nodes) and self._nodes[idx] == new_node)):
self._nodes.insert(idx, new_node) self._nodes.insert(idx, new_node)
@@ -391,7 +390,7 @@ class Block(Node):
    def symbols_defined(self):
        result = set()
        for a in self.args:
-            if isinstance(a, pystencils.Assignment):
+            if isinstance(a, Assignment):
                result.update(a.free_symbols)
            else:
                result.update(a.symbols_defined)
@@ -402,7 +401,7 @@ class Block(Node):
        result = set()
        defined_symbols = set()
        for a in self.args:
-            if isinstance(a, pystencils.Assignment):
+            if isinstance(a, Assignment):
                result.update(a.free_symbols)
                defined_symbols.update({a.lhs})
            else:
@@ -432,7 +431,7 @@ class LoopOverCoordinate(Node):
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

-    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
+    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, custom_loop_ctr=None):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
@@ -443,11 +442,12 @@ class LoopOverCoordinate(Node):
        self.body.parent = self
        self.prefix_lines = []
        self.is_block_loop = is_block_loop
+        self.custom_loop_ctr = custom_loop_ctr

    def new_loop_with_different_body(self, new_body):
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
-                                    self.step, self.is_block_loop)
-        result.prefix_lines = [l for l in self.prefix_lines]
+                                    self.step, self.is_block_loop, self.custom_loop_ctr)
+        result.prefix_lines = [prefix_line for prefix_line in self.prefix_lines]
        return result

    def subs(self, subs_dict):
@@ -509,10 +509,13 @@ class LoopOverCoordinate(Node):
    @property
    def loop_counter_name(self):
-        if self.is_block_loop:
-            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
+        if self.custom_loop_ctr:
+            return self.custom_loop_ctr.name
        else:
-            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
+            if self.is_block_loop:
+                return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
+            else:
+                return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
@@ -536,14 +539,16 @@ class LoopOverCoordinate(Node):
    @property
    def loop_counter_symbol(self):
-        if self.is_block_loop:
-            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
+        if self.custom_loop_ctr:
+            return self.custom_loop_ctr
        else:
-            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
+            if self.is_block_loop:
+                return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
+            else:
+                return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
-        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
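A hedged sketch of what the new `custom_loop_ctr` argument enables (illustrative; assumes the post-change `pystencils.astnodes` and `pystencils.typing` modules):

```python
from pystencils.astnodes import Block, LoopOverCoordinate
from pystencils.typing import TypedSymbol, create_type

# Supply a caller-defined counter symbol instead of the generated "ctr_0".
ctr = TypedSymbol("my_ctr", create_type("int64"))
loop = LoopOverCoordinate(Block([]), coordinate_to_loop_over=0,
                          start=0, stop=64, step=1, custom_loop_ctr=ctr)

print(loop.loop_counter_name)    # expected: "my_ctr"
print(loop.loop_counter_symbol)  # expected: the TypedSymbol passed in
```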
@@ -566,13 +571,14 @@ class SympyAssignment(Node):
    def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = sp.sympify(lhs_symbol)
-        self.rhs = sp.sympify(rhs_expr)
+        self._rhs = sp.sympify(rhs_expr)
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()
-        self.use_auto = use_auto
+        self._use_auto = use_auto

    def __is_declaration(self):
-        if isinstance(self._lhs_symbol, cast_func):
+        from pystencils.typing import CastFunc
+        if isinstance(self._lhs_symbol, CastFunc):
            return False
        if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
            return False
@@ -582,15 +588,28 @@ class SympyAssignment(Node):
def lhs(self): def lhs(self):
return self._lhs_symbol return self._lhs_symbol
@property
def rhs(self):
return self._rhs
@lhs.setter @lhs.setter
def lhs(self, new_value): def lhs(self, new_value):
self._lhs_symbol = new_value self._lhs_symbol = new_value
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
@rhs.setter
def rhs(self, new_rhs_expr):
self._rhs = new_rhs_expr
def subs(self, subs_dict): def subs(self, subs_dict):
self.lhs = fast_subs(self.lhs, subs_dict) self.lhs = fast_subs(self.lhs, subs_dict)
self.rhs = fast_subs(self.rhs, subs_dict) self.rhs = fast_subs(self.rhs, subs_dict)
def fast_subs(self, subs_dict, skip=None):
self.lhs = fast_subs(self.lhs, subs_dict, skip)
self.rhs = fast_subs(self.rhs, subs_dict, skip)
return self
def optimize(self, optimizations): def optimize(self, optimizations):
try: try:
from sympy.codegen.rewriting import optimize from sympy.codegen.rewriting import optimize
@@ -600,7 +619,7 @@ class SympyAssignment(Node):
    @property
    def args(self):
-        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
+        return [self._lhs_symbol, self.rhs]

    @property
    def symbols_defined(self):
@@ -617,9 +636,10 @@ class SympyAssignment(Node):
        if isinstance(symbol, Field.Access):
            for i in range(len(symbol.offsets)):
                loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
-        result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result
@property @property
@@ -630,6 +650,10 @@ class SympyAssignment(Node):
def is_const(self): def is_const(self):
return self._is_const return self._is_const
@property
def use_auto(self):
return self._use_auto
def replace(self, child, replacement): def replace(self, child, replacement):
if child == self.lhs: if child == self.lhs:
replacement.parent = self replacement.parent = self
@@ -652,7 +676,7 @@ class SympyAssignment(Node):
        return hash((self.lhs, self.rhs))

    def __eq__(self, other):
-        return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)
+        return type(self) is type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)

class ResolvedFieldAccess(sp.Indexed):
...
@@ -6,9 +6,3 @@ try:
    __all__.append('print_dot')
except ImportError:
    pass
-
-try:
-    from .llvm import generate_llvm # NOQA
-    __all__.append('generate_llvm')
-except ImportError:
-    pass
+from pystencils.typing import CFunction

def get_argument_string(function_shortcut, first=''):
    args = function_shortcut[function_shortcut.index('[') + 1: -1]
    arg_string = "("
@@ -16,10 +19,13 @@ def get_argument_string(function_shortcut, first=''):
def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
-    if instruction_set != 'neon' and not instruction_set.startswith('sve'):
+    if instruction_set not in ['neon', 'sme'] and not instruction_set.startswith('sve'):
        raise NotImplementedError(instruction_set)
-    if instruction_set == 'sve':
+    if instruction_set in ['sve', 'sve2', 'sme']:
        cmp = 'cmp'
+    elif instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
+        cmp = 'cmp'
+        bitwidth = int(instruction_set[4:])
    elif instruction_set.startswith('sve'):
        cmp = 'cmp'
        bitwidth = int(instruction_set[3:])
@@ -35,9 +41,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
        'sqrt': 'sqrt[0]',

        'loadU': 'ld1[0]',
-        'loadA': 'ld1[0]',
        'storeU': 'st1[0, 1]',
-        'storeA': 'st1[0, 1]',
        'abs': 'abs[0]',
        '==': f'{cmp}eq[0, 1]',
@@ -54,7 +58,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
    result = dict()

-    if instruction_set == 'sve':
+    if instruction_set in ['sve', 'sve2', 'sme']:
        width = 'svcntd()' if data_type == 'double' else 'svcntw()'
        intwidth = 'svcntw()'
        result['bytes'] = 'svcntb()'
@@ -62,14 +66,15 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
        width = bitwidth // bits[data_type]
        intwidth = bitwidth // bits['int']
        result['bytes'] = bitwidth // 8
-    if instruction_set.startswith('sve'):
+    if instruction_set.startswith('sve') or instruction_set == 'sme':
+        base_names['stream'] = 'stnt1[0, 1]'
        prefix = 'sv'
        suffix = f'_f{bits[data_type]}'
    elif instruction_set == 'neon':
        prefix = 'v'
        suffix = f'q_f{bits[data_type]}'

-    if instruction_set == 'sve':
+    if instruction_set in ['sve', 'sve2', 'sme']:
        predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
        int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})'
    else:
@@ -88,33 +93,36 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
        result[intrinsic_id] = prefix + name + suffix + undef + arg_string

-    if instruction_set == 'sve':
-        from pystencils.backends.cbackend import CFunction
+    if instruction_set in ['sve', 'sve2', 'sme']:
        result['width'] = CFunction(width, "int")
        result['intwidth'] = CFunction(intwidth, "int")
    else:
        result['width'] = width
        result['intwidth'] = intwidth

-    if instruction_set.startswith('sve'):
+    if instruction_set.startswith('sve') or instruction_set == 'sme':
        result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})'
        result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
        result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'

-        vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
-        result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
-                           vindex.format("{2}") + ', {1})'
-        result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
-                          vindex.format("{1}") + ')'
+        if instruction_set != 'sme':
+            vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
+            result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                               vindex.format("{2}") + ', {1})'
+            result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                              vindex.format("{1}") + ')'
+            if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
+                result['streamS'] = f'svstnt1_scatter_u{bits[data_type]}offset_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                                    vindex.format(f"{{2}}*{bits[data_type]//8}") + ', {1})'

        result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"

-        result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set != "sve" else ""}t'
-        result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set != "sve" else ""}t'
-        result['int'] = f'svint{bits["int"]}_{"s" if instruction_set != "sve" else ""}t'
-        result['bool'] = f'svbool_{"s" if instruction_set != "sve" else ""}t'
+        result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
+        result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
+        result['int'] = f'svint{bits["int"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
+        result['bool'] = f'svbool_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'

-        result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"']
+        result['headers'] = ['<arm_sve.h>', '<arm_acle.h>', '"arm_neon_helpers.h"']

        result['&'] = f'svand_b_z({predicate},' + ' {0}, {1})'
        result['|'] = f'svorr_b_z({predicate},' + ' {0}, {1})'
@@ -123,10 +131,17 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
        result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'

        result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
-        result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
-        result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
+        result['maskStream'] = result['stream'].replace(predicate, '{2}')
+        if instruction_set != 'sme':
+            result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
+            if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
+                result['maskStreamS'] = result['streamS'].replace(predicate, '{3}')
+
+        result['streamFence'] = '__dmb(15)'

-        if instruction_set != 'sve':
+        if instruction_set == 'sme':
+            result['function_prefix'] = '__attribute__((arm_locally_streaming))'
+        elif instruction_set not in ['sve', 'sve2', 'sme']:
            result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
    else:
        result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
@@ -151,9 +166,9 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
        result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0'
        result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff'

-    if instruction_set == 'sve' or bitwidth & (bitwidth - 1) == 0:
-        # only power-of-2 vector sizes will evenly divide a cacheline
-        result['cachelineSize'] = 'cachelineSize()'
-        result['cachelineZero'] = 'cachelineZero((void*) {0})'
+        # SVE has real nontemporal stores, so we only need to zero cachlines on Neon
+        result['cachelineZero'] = 'cachelineZero((void*) {0})'
+
+    result['cachelineSize'] = 'cachelineSize()'

    return result
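For orientation, a hedged sketch of how the resulting dictionary can be inspected (the function only assembles intrinsic name strings, so this runs on any host; the values in the comments are examples of the expected shape, not guaranteed output):

```python
from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm

neon = get_vector_instruction_set_arm('double', 'neon')
print(neon['width'])            # 2 double lanes in a 128-bit NEON register
print(neon['storeU'])           # e.g. 'vst1q_f64({0}, {1})'

sve512 = get_vector_instruction_set_arm('double', 'sve512')
print(sve512['compile_flags'])  # e.g. ['-msve-vector-bits=512'] for fixed-width SVE
```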
@@ -6,16 +6,18 @@ from typing import Set
import numpy as np
import sympy as sp
from sympy.core import S
-from sympy.core.cache import cacheit
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
+from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
+from sympy.functions.elementary.hyperbolic import HyperbolicFunction

from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
-from pystencils.data_types import (
-    PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
-    reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol)
+from pystencils.typing import (
+    PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
+    ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction)
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
+from pystencils.functions import DivFunc, AddressOf
from pystencils.integer_functions import (
    bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
    int_div, int_power_of_2, modulo_ceil)
@@ -30,8 +32,6 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy
HEADER_REGEX = re.compile(r'^[<"].*[">]$')

-KERNCRAFT_NO_TERNARY_MODE = False

def generate_c(ast_node: Node,
               signature_only: bool = False,
@@ -47,7 +47,7 @@ def generate_c(ast_node: Node,
    Args:
        ast_node: ast representation of kernel
        signature_only: generate signature without function body
-        dialect: `Backend`: 'C', 'CUDA' or 'OPENCL'
+        dialect: `Backend`: 'C' or 'CUDA'
        custom_backend: use own custom printer for code generation
        with_globals: enable usage of global variables
    Returns:
@@ -63,6 +63,7 @@ def generate_c(ast_node: Node,
printer = custom_backend printer = custom_backend
elif dialect == Backend.C: elif dialect == Backend.C:
try: try:
# TODO Vectorization Revamp: instruction_set should not be just slapped on ast
instruction_set = ast_node.instruction_set instruction_set = ast_node.instruction_set
except Exception: except Exception:
instruction_set = None instruction_set = None
@@ -71,9 +72,6 @@ def generate_c(ast_node: Node,
elif dialect == Backend.CUDA: elif dialect == Backend.CUDA:
from pystencils.backends.cuda_backend import CudaBackend from pystencils.backends.cuda_backend import CudaBackend
printer = CudaBackend(signature_only=signature_only) printer = CudaBackend(signature_only=signature_only)
elif dialect == Backend.OPENCL:
from pystencils.backends.opencl_backend import OpenClBackend
printer = OpenClBackend(signature_only=signature_only)
else: else:
raise ValueError(f'Unknown {dialect=}') raise ValueError(f'Unknown {dialect=}')
code = printer(ast_node) code = printer(ast_node)
@@ -123,12 +121,12 @@ def get_headers(ast_node: Node) -> Set[str]:
    for h in headers:
        assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/'

-    return sorted(headers)
+    return headers

# --------------------------------------- Backend Specific Nodes -------------------------------------------------------

+# TODO future CustomCodeNode should not be backend specific move it elsewhere
class CustomCodeNode(Node):
    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super(CustomCodeNode, self).__init__(parent=parent)
@@ -152,8 +150,8 @@ class CustomCodeNode(Node):
    def undefined_symbols(self):
        return self._symbols_read - self._symbols_defined

-    def __eq___(self, other):
-        return self._code == other._code
+    def __eq__(self, other):
+        return type(self) is type(other) and self._code == other._code

    def __hash__(self):
        return hash(self._code)
@@ -167,23 +165,6 @@ class PrintNode(CustomCodeNode):
        self.headers.append("<iostream>")

-class CFunction(TypedSymbol):
-    def __new__(cls, function, dtype):
-        return CFunction.__xnew_cached_(cls, function, dtype)
-
-    def __new_stage2__(cls, function, dtype):
-        return super(CFunction, cls).__xnew__(cls, function, dtype)
-
-    __xnew__ = staticmethod(__new_stage2__)
-    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
-
-    def __getnewargs__(self):
-        return self.name, self.dtype
-
-    def __getnewargs_ex__(self):
-        return (self.name, self.dtype), {}

# ------------------------------------------- Printer ------------------------------------------------------------------
@@ -217,12 +198,12 @@ class CBackend:
        if isinstance(node, str):
            return node
        for cls in type(node).__mro__:
-            method_name = "_print_" + cls.__name__
+            method_name = f"_print_{cls.__name__}"
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
-        raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__)
+        raise NotImplementedError(f"{self.__class__.__name__} does not support node of type {node.__class__.__name__}")

-    def _print_Type(self, node):
+    def _print_AbstractType(self, node):
        return str(node)

    def _print_KernelFunction(self, node):
@@ -249,12 +230,13 @@ class CBackend:
        return f"{node.pragma_line}\n{self._print_Block(node)}"

    def _print_LoopOverCoordinate(self, node):
-        counter_symbol = node.loop_counter_name
-        start = f"int64_t {counter_symbol} = {self.sympy_printer.doprint(node.start)}"
-        condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}"
-        update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}"
+        counter_name = node.loop_counter_name
+        counter_dtype = node.loop_counter_symbol.dtype.c_name
+        start = f"{counter_dtype} {counter_name} = {self.sympy_printer.doprint(node.start)}"
+        condition = f"{counter_name} < {self.sympy_printer.doprint(node.stop)}"
+        update = f"{counter_name} += {self.sympy_printer.doprint(node.step)}"
        loop_str = f"for ({start}; {condition}; {update})"
-        self._kwargs['loop_counter'] = counter_symbol
+        self._kwargs['loop_counter'] = counter_name
        self._kwargs['loop_stop'] = node.stop

        prefix = "\n".join(node.prefix_lines)
@@ -263,41 +245,50 @@ class CBackend:
        return f"{prefix}{loop_str}\n{self._print(node.body)}"

    def _print_SympyAssignment(self, node):
+        printed_lhs = self.sympy_printer.doprint(node.lhs)
+        printed_rhs = self.sympy_printer.doprint(node.rhs)
        if node.is_declaration:
            if node.use_auto:
-                data_type = 'auto '
+                data_type = 'auto'
            else:
+                data_type = self._print(node.lhs.dtype).replace(' const', '')
                if node.is_const:
-                    prefix = 'const '
-                else:
-                    prefix = ''
-                data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
-            return "%s%s = %s;" % (data_type,
-                                   self.sympy_printer.doprint(node.lhs),
-                                   self.sympy_printer.doprint(node.rhs))
+                    data_type = f'const {data_type}'
+            return f"{data_type} {printed_lhs} = {printed_rhs};"
        else:
-            lhs_type = get_type_of_expression(node.lhs)
+            lhs_type = get_type_of_expression(node.lhs)  # TOOD: this should have been typed
            printed_mask = ""
-            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
+            if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc):
                arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
                instr = 'storeU'
if aligned: if nontemporal and 'storeA' not in self._vector_instruction_set and \
'stream' in self._vector_instruction_set:
instr = 'stream'
elif aligned:
instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA' instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
if mask != True: # NOQA if mask != True: # NOQA
instr = 'maskStoreA' if aligned else 'maskStoreU' instr = 'maskStream' if nontemporal and 'maskStream' in self._vector_instruction_set else \
'maskStoreA' if aligned else 'maskStoreU'
if instr not in self._vector_instruction_set: if instr not in self._vector_instruction_set:
self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format( if instr == 'maskStream' and 'stream' in self._vector_instruction_set:
store, load = 'stream', 'loadA'
elif (instr in ('maskStream', 'maskStoreA')) and 'storeA' in self._vector_instruction_set:
store, load = 'storeA', 'loadA'
else:
store, load = 'storeU', 'loadU'
load = load if load in self._vector_instruction_set else 'loadU'
self._vector_instruction_set[instr] = self._vector_instruction_set[store].format(
'{0}', self._vector_instruction_set['blendv'].format( '{0}', self._vector_instruction_set['blendv'].format(
self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs), self._vector_instruction_set[load].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs), **self._kwargs) '{1}', '{2}', **self._kwargs), **self._kwargs)
printed_mask = self.sympy_printer.doprint(mask) printed_mask = self.sympy_printer.doprint(mask)
if data_type.base_type.base_name == 'double': if data_type.base_type.c_name == 'double':
if self._vector_instruction_set['double'] == '__m256d': if self._vector_instruction_set['double'] == '__m256d':
printed_mask = f"_mm256_castpd_si256({printed_mask})" printed_mask = f"_mm256_castpd_si256({printed_mask})"
elif self._vector_instruction_set['double'] == '__m128d': elif self._vector_instruction_set['double'] == '__m128d':
printed_mask = f"_mm_castpd_si128({printed_mask})" printed_mask = f"_mm_castpd_si128({printed_mask})"
elif data_type.base_type.base_name == 'float': elif data_type.base_type.c_name == 'float':
if self._vector_instruction_set['float'] == '__m256': if self._vector_instruction_set['float'] == '__m256':
printed_mask = f"_mm256_castps_si256({printed_mask})" printed_mask = f"_mm256_castps_si256({printed_mask})"
elif self._vector_instruction_set['float'] == '__m128': elif self._vector_instruction_set['float'] == '__m128':
@@ -305,19 +296,23 @@ class CBackend:
rhs_type = get_type_of_expression(node.rhs) rhs_type = get_type_of_expression(node.rhs)
if type(rhs_type) is not VectorType: if type(rhs_type) is not VectorType:
rhs = cast_func(node.rhs, VectorType(rhs_type)) raise ValueError(f'Cannot vectorize {node.rhs} of type {rhs_type} inside of the pretty printer! '
f'This should have happen earlier!')
# rhs = CastFunc(node.rhs, VectorType(rhs_type)) # Unknown width
else: else:
rhs = node.rhs rhs = node.rhs
ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0]) ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
if stride != 1: if stride != 1:
instr = 'maskStoreS' if mask != True else 'storeS' # NOQA instr = ('maskStreamS' if nontemporal and 'maskStreamS' in self._vector_instruction_set else
'maskStoreS') if mask != True else \
('streamS' if nontemporal and 'streamS' in self._vector_instruction_set else 'storeS') # NOQA
return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs), return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
stride, printed_mask, **self._kwargs) + ';' stride, printed_mask, **self._kwargs) + ';'
pre_code = '' pre_code = ''
if nontemporal and 'cachelineZero' in self._vector_instruction_set: if nontemporal and 'cachelineZero' in self._vector_instruction_set and mask == True: # NOQA
first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0" first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0"
offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i)) offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i))
* node.lhs.args[0].field.spatial_strides[i] for i in * node.lhs.args[0].field.spatial_strides[i] for i in
@@ -325,7 +320,7 @@ class CBackend:
if stride == 1: if stride == 1:
offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1}) offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
size = sp.Mul(*node.lhs.args[0].field.spatial_shape) size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
element_size = 8 if data_type.base_type.base_name == 'double' else 4 element_size = 8 if data_type.base_type.c_name == 'double' else 4
size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}" size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \ pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n' self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
@@ -337,17 +332,26 @@ class CBackend:
code2 = self._vector_instruction_set['flushCacheline'].format( code2 = self._vector_instruction_set['flushCacheline'].format(
ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';' ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}" code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set: elif aligned and nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8] lhs_hash = hashlib.sha1(self.sympy_printer.doprint(node.lhs).encode('ascii')).hexdigest()[:8]
rhs_hash = hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
tmpvar = f'_tmp_{lhs_hash}_{rhs_hash}'
code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \ code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
+ self.sympy_printer.doprint(rhs) + ';' + self.sympy_printer.doprint(rhs) + ';'
code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';' code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask, maskStore, store, load = 'maskStoreAAndFlushCacheline', 'storeAAndFlushCacheline', 'loadA'
**self._kwargs) + ';' instr2 = maskStore if mask != True else store # NOQA
if instr2 not in self._vector_instruction_set:
self._vector_instruction_set[maskStore] = self._vector_instruction_set[store].format(
'{0}', self._vector_instruction_set['blendv'].format(
self._vector_instruction_set[load].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs),
**self._kwargs)
code2 = self._vector_instruction_set[instr2].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}" code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
return pre_code + code return pre_code + code
else: else:
return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};" return f"{printed_lhs} = {printed_rhs};"
def _print_NontemporalFence(self, _): def _print_NontemporalFence(self, _):
if 'streamFence' in self._vector_instruction_set: if 'streamFence' in self._vector_instruction_set:
@@ -439,23 +443,31 @@ class CustomSympyPrinter(CCodePrinter):
    def __init__(self):
        super(CustomSympyPrinter, self).__init__()
-        self._float_type = create_type("float32")

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
-        if not expr.free_symbols:
-            return self._typed_number(expr.evalf(), get_type_of_expression(expr))
-
-        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
-            return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
-        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
-            return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})"
+        # Ideally the printer has as little logic as possible. Therefore,
+        # powers should be rewritten as `DivFunc`s / unevaluated `Mul`s before
+        # printing. `NodeCollection` offers a convenience function to do just
+        # that. However, `cut_loops` rewrites unevaluated multiplications as
+        # `Pow`s again. Neither `deepcopy` nor `func(*args)` are suited to
+        # rebuild unevaluated expressions. Therefore, as long as we stick with
+        # SymPy, this is the only way to avoid printing `pow`s.
+        exp = expr.exp.expr if isinstance(expr.exp, CastFunc) else expr.exp
+        one_type = expr.base.dtype if hasattr(expr.base, "dtype") else get_type_of_expression(expr.base)
+
+        if exp.is_integer and exp.is_number and (0 < exp <= 8):
+            return f"({self._print(sp.Mul(*[expr.base] * exp, evaluate=False))})"
+        elif exp.is_integer and exp.is_number and (-8 <= exp < 0):
+            return f"{self._typed_number(1, one_type)} / ({self._print(sp.Mul(*([expr.base] * -exp), evaluate=False))})"
        else:
            return super(CustomSympyPrinter, self)._print_Pow(expr)

+    # TODO don't print ones in sp.Mul
    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
-        res = str(expr.evalf().num)
+        res = str(expr.evalf(17))
        return res

    def _print_Equality(self, expr):
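A hedged end-to-end illustration of the power handling (assumes the standard pystencils kernel API; the generated code in the comment is paraphrased, not verbatim):

```python
import pystencils as ps

x, y = ps.fields("x, y: double[2D]")
ast = ps.create_kernel(ps.Assignment(y.center, x.center ** 3))

# The emitted kernel body is expected to contain repeated multiplication such
# as (x*x*x) rather than a call to pow(x, 3).
print(ps.get_code_str(ast))
```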
@@ -473,7 +485,7 @@ class CustomSympyPrinter(CCodePrinter):
        else:
            return f'fabs({self._print(expr.args[0])})'

-    def _print_Type(self, node):
+    def _print_AbstractType(self, node):
        return str(node)

    def _print_Function(self, expr):
@@ -486,30 +498,48 @@ class CustomSympyPrinter(CCodePrinter):
} }
if hasattr(expr, 'to_c'): if hasattr(expr, 'to_c'):
return expr.to_c(self._print) return expr.to_c(self._print)
if isinstance(expr, reinterpret_cast_func): if isinstance(expr, ReinterpretCastFunc):
arg, data_type = expr.args arg, data_type = expr.args
return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))" if isinstance(data_type, PointerType):
elif isinstance(expr, address_of): const_str = "const" if data_type.const else ""
return f"(({const_str} {self._print(data_type.base_type)} *)(& {self._print(arg)}))"
else:
return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
elif isinstance(expr, AddressOf):
assert len(expr.args) == 1, "address_of must only have one argument" assert len(expr.args) == 1, "address_of must only have one argument"
return f"&({self._print(expr.args[0])})" return f"&({self._print(expr.args[0])})"
elif isinstance(expr, cast_func): elif isinstance(expr, CastFunc):
cast = "(({data_type})({code}))"
arg, data_type = expr.args arg, data_type = expr.args
if isinstance(arg, sp.Number) and arg.is_finite: if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
return self._typed_number(arg, data_type) return self._typed_number(arg, data_type)
elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
and data_type == BasicType('float32'):
known = self.known_functions[arg.__class__.__name__.lower()]
code = self._print(arg)
return code.replace(known, f"{known}f")
elif isinstance(arg, (sp.Pow, sp.exp)) and data_type == BasicType('float32'):
known = ['sqrt', 'cbrt', 'pow', 'exp']
code = self._print(arg)
for k in known:
if k in code:
return code.replace(k, f'{k}f')
# Powers of small integers are printed as divisions/multiplications.
if '/' in code or '*' in code:
return cast.format(data_type=data_type, code=code)
raise ValueError(f"{code} doesn't give {known=} function back.")
else: else:
return f"(({data_type})({self._print(arg)}))" return cast.format(data_type=data_type, code=self._print(arg))
elif isinstance(expr, fast_division): elif isinstance(expr, fast_division):
return f"({self._print(expr.args[0] / expr.args[1])})" raise ValueError("fast_division is only supported for Taget.GPU")
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return f"({self._print(sp.sqrt(expr.args[0]))})" raise ValueError("fast_sqrt is only supported for Taget.GPU")
elif isinstance(expr, fast_inv_sqrt):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, vec_any) or isinstance(expr, vec_all): elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0]) return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt):
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, sp.Abs): elif isinstance(expr, sp.Abs):
return f"abs({self._print(expr.args[0])})" return f"abs({self._print(expr.args[0])})"
elif isinstance(expr, sp.Max):
return self._print(expr)
elif isinstance(expr, sp.Mod): elif isinstance(expr, sp.Mod):
if expr.args[0].is_integer and expr.args[1].is_integer: if expr.args[0].is_integer and expr.args[1].is_integer:
return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})" return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})"
@@ -521,6 +551,8 @@ class CustomSympyPrinter(CCodePrinter):
return f"(1 << ({self._print(expr.args[0])}))" return f"(1 << ({self._print(expr.args[0])}))"
elif expr.func == int_div: elif expr.func == int_div:
return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))" return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"
elif expr.func == DivFunc:
return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))'
else: else:
name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__ name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
arg_str = ', '.join(self._print(a) for a in expr.args) arg_str = ', '.join(self._print(a) for a in expr.args)
@@ -543,52 +575,6 @@ class CustomSympyPrinter(CCodePrinter):
else: else:
return res return res
def _print_Sum(self, expr):
template = """[&]() {{
{dtype} sum = ({dtype}) 0;
for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
sum += {expr};
}}
return sum;
}}()"""
var = expr.limits[0][0]
start = expr.limits[0][1]
end = expr.limits[0][2]
code = template.format(
dtype=get_type_of_expression(expr.args[0]),
iterator_dtype='int',
var=self._print(var),
start=self._print(start),
end=self._print(end),
expr=self._print(expr.function),
increment=str(1),
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
)
return code
def _print_Product(self, expr):
template = """[&]() {{
{dtype} product = ({dtype}) 1;
for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
product *= {expr};
}}
return product;
}}()"""
var = expr.limits[0][0]
start = expr.limits[0][1]
end = expr.limits[0][2]
code = template.format(
dtype=get_type_of_expression(expr.args[0]),
iterator_dtype='int',
var=self._print(var),
start=self._print(start),
end=self._print(end),
expr=self._print(expr.function),
increment=str(1),
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
)
return code
def _print_ConditionalFieldAccess(self, node): def _print_ConditionalFieldAccess(self, node):
return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True))) return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))
@@ -612,27 +598,6 @@ class CustomSympyPrinter(CCodePrinter):
return f"(({a} < {b}) ? {a} : {b})" return f"(({a} < {b}) ? {a} : {b})"
return inner_print_min(expr.args) return inner_print_min(expr.args)
def _print_re(self, expr):
return f"real({self._print(expr.args[0])})"
def _print_im(self, expr):
return f"imag({self._print(expr.args[0])})"
def _print_ImaginaryUnit(self, expr):
return "complex<double>{0,1}"
def _print_TypedImaginaryUnit(self, expr):
if expr.dtype.numpy_dtype == np.complex64:
return "complex<float>{0,1}"
elif expr.dtype.numpy_dtype == np.complex128:
return "complex<double>{0,1}"
else:
raise NotImplementedError(
"only complex64 and complex128 supported")
def _print_Complex(self, expr):
return self._typed_number(expr, np.complex64)
# noinspection PyPep8Naming # noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -651,55 +616,100 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
        return None

    def _print_Abs(self, expr):
-        if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
+        if isinstance(get_type_of_expression(expr), (VectorType, VectorMemoryAccess)):
            return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
        return super()._print_Abs(expr)
def _typed_vectorized_number(self, expr, data_type):
basic_data_type = data_type.base_type
number = self._typed_number(expr, basic_data_type)
instruction = 'makeVecConst'
if basic_data_type.is_bool():
instruction = 'makeVecConstBool'
# TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint)
elif basic_data_type.is_int():
instruction = 'makeVecConstInt'
return self.instruction_set[instruction].format(number, **self._kwargs)
def _typed_vectorized_symbol(self, expr, data_type):
if not isinstance(expr, TypedSymbol):
raise ValueError(f'{expr} is not a TypeSymbol. It is {expr.type=}')
basic_data_type = data_type.base_type
symbol = self._print(expr)
if basic_data_type != expr.dtype:
symbol = f'(({basic_data_type})({symbol}))'
instruction = 'makeVecConst'
if basic_data_type.is_bool():
instruction = 'makeVecConstBool'
# TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint)
elif basic_data_type.is_int():
instruction = 'makeVecConstInt'
return self.instruction_set[instruction].format(symbol, **self._kwargs)
def _print_CastFunc(self, expr):
arg, data_type = expr.args
if type(data_type) is VectorType:
base_type = data_type.base_type
# vector_memory_access is a cast_func itself so it should't be directly inside a cast_func
assert not isinstance(arg, VectorMemoryAccess)
if isinstance(arg, sp.Tuple):
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
is_integer = get_type_of_expression(arg[0]) == create_type("int")
printed_args = [self._print(a) for a in arg]
instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec'
if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
increments = np.array(arg)[1:] - np.array(arg)[:-1]
if len(set(increments)) == 1:
return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
**self._kwargs)
return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
else:
if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
return self._typed_vectorized_number(arg, data_type)
elif isinstance(arg, TypedSymbol):
return self._typed_vectorized_symbol(arg, data_type)
elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
and base_type == BasicType('float32'):
raise NotImplementedError('Vectorizer is not tested for trigonometric functions yet')
# known = self.known_functions[arg.__class__.__name__.lower()]
# code = self._print(arg)
# return code.replace(known, f"{known}f")
elif isinstance(arg, sp.Pow):
if base_type == BasicType('float32') or base_type == BasicType('float64'):
return self._print_Pow(arg)
else:
raise NotImplementedError('Integer Pow is not implemented')
elif isinstance(arg, sp.UnevaluatedExpr):
return self._print(arg.args[0])
else:
raise NotImplementedError('Vectorizer cannot cast between different datatypes')
# to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
# from_type = self.instruction_set['suffix'][get_type_of_expression(arg).base_type.c_name]
# return self.instruction_set['cast'].format(from_type, to_type, self._print(arg))
else:
return self._scalarFallback('_print_Function', expr)
# raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.')
    def _print_Function(self, expr):
-       if isinstance(expr, vector_memory_access):
+       if isinstance(expr, VectorMemoryAccess):
            arg, data_type, aligned, _, mask, stride = expr.args
            if stride != 1:
-               return self.instruction_set['loadS'].format("& " + self._print(arg), stride, **self._kwargs)
+               return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
-           return instruction.format("& " + self._print(arg), **self._kwargs)
+           return instruction.format(f"& {self._print(arg)}", **self._kwargs)
-       elif isinstance(expr, cast_func):
-           arg, data_type = expr.args
-           if type(data_type) is VectorType:
-               # vector_memory_access is a cast_func itself so it shouldn't be directly inside a cast_func
-               assert not isinstance(arg, vector_memory_access)
-               if isinstance(arg, sp.Tuple):
-                   is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
-                   is_integer = get_type_of_expression(arg[0]) == create_type("int")
-                   printed_args = [self._print(a) for a in arg]
-                   instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec'
-                   if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
-                       increments = np.array(arg)[1:] - np.array(arg)[:-1]
-                       if len(set(increments)) == 1:
-                           return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
-                                                                              **self._kwargs)
-                   return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
-               else:
-                   is_boolean = get_type_of_expression(arg) == create_type("bool")
-                   is_integer = get_type_of_expression(arg) == create_type("int") or \
-                       (isinstance(arg, TypedSymbol) and not isinstance(arg.dtype, VectorType) and arg.dtype.is_int())
-                   instruction = 'makeVecConstBool' if is_boolean else \
-                       'makeVecConstInt' if is_integer else 'makeVecConst'
-                   return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)
-       elif expr.func == fast_division:
+       elif expr.func == DivFunc:
            result = self._scalarFallback('_print_Function', expr)
            if not result:
-               result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]),
+               result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend),
                                                          **self._kwargs)
            return result
-       elif expr.func == fast_sqrt:
-           return f"({self._print(sp.sqrt(expr.args[0]))})"
-       elif expr.func == fast_inv_sqrt:
-           result = self._scalarFallback('_print_Function', expr)
-           if not result:
-               if 'rsqrt' in self.instruction_set:
-                   return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
-               else:
-                   return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
+       elif isinstance(expr, fast_division):
+           raise ValueError("fast_division is only supported for Target.GPU")
+       elif isinstance(expr, fast_sqrt):
+           raise ValueError("fast_sqrt is only supported for Target.GPU")
+       elif isinstance(expr, fast_inv_sqrt):
+           raise ValueError("fast_inv_sqrt is only supported for Target.GPU")
        elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
            instr = 'any' if isinstance(expr, vec_any) else 'all'
            expr_type = get_type_of_expression(expr.args[0])
@@ -750,12 +760,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
            # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
            suffix = ""
-           if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
+           if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
                    or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
-               dtype = set([e.dtype for e in args if type(e) is cast_func])
+               dtype = set([e.dtype for e in args if type(e) is CastFunc])
                assert len(dtype) == 1
                dtype = dtype.pop()
-               args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
+               args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
                        for e in args]
                suffix = "int"
@@ -781,26 +791,31 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
        return processed

    def _print_Pow(self, expr):
-       result = self._scalarFallback('_print_Pow', expr)
+       # Due to loop cutting sp.Mul is evaluated again.
+       try:
+           result = self._scalarFallback('_print_Pow', expr)
+       except ValueError:
+           result = None
        if result:
            return result
        one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
-       if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
-           return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
-       elif expr.exp == -1:
-           one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
-           return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs)
-       elif expr.exp == 0.5:
-           return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
-       elif expr.exp == -0.5:
-           root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
+       root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
+       if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number:
+           exp = expr.exp.args[0]
+       else:
+           exp = expr.exp
+       # TODO the printer should not have any intelligence like this.
+       # TODO To remove all of these cases the vectoriser needs to be reworked. See loop cutting
+       if exp.is_integer and exp.is_number and 0 < exp < 8:
+           return self._print(sp.Mul(*[expr.base] * exp, evaluate=False))
+       elif exp == 0.5:
+           return root
+       elif exp == -0.5:
            return self.instruction_set['/'].format(one, root, **self._kwargs)
-       elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
-           return self.instruction_set['/'].format(one,
-                                                   self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)),
-                                                   **self._kwargs)
        else:
            raise ValueError("Generic exponential not supported: " + str(expr))
@@ -808,7 +823,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

-       result = self._scalarFallback('_print_Mul', expr)
+       if not inside_add:
+           result = self._scalarFallback('_print_Mul', expr)
+       else:
+           result = None
        if result:
            return result
@@ -883,12 +901,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
-           if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
-               if not KERNCRAFT_NO_TERNARY_MODE:
-                   result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
-                                                          result, **self._kwargs)
-               else:
-                   print("Warning - skipping ternary op")
+           if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"):
+               result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
+                                                      result, **self._kwargs)
            else:
                # noinspection SpellCheckingInspection
                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition),
...
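As a side note on the `_print_Pow` change above: for small integer exponents the vectorised printer avoids a `pow()` call by handing SymPy an unevaluated product. A minimal sketch with plain SymPy (no pystencils involved) shows the idea:

``` python
import sympy as sp

x = sp.Symbol('x')
exp = 3

# x**3 is rewritten as the unevaluated product x*x*x, so the backend only
# needs repeated vector multiplications instead of a pow() intrinsic.
expanded = sp.Mul(*[x] * exp, evaluate=False)
print(expanded)            # x*x*x
print(sp.ccode(expanded))  # x*x*x
```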
-from os.path import dirname, join

from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
-from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode

-with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
-    lines = f.readlines()
-    CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l}


def generate_cuda(ast_node: Node, signature_only: bool = False, custom_backend=None, with_globals=True) -> str:
@@ -44,26 +37,13 @@ class CudaBackend(CBackend):
        return code

    @staticmethod
-   def _print_ThreadBlockSynchronization(node):
-       code = "__synchtreads();"
-       return code
+   def _print_ThreadBlockSynchronization(_):
+       return "__syncthreads();"

    def _print_TextureDeclaration(self, node):
-       # TODO: use fStrings here
-       if node.texture.field.dtype.numpy_dtype.itemsize > 4:
-           code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
-               str(node.texture.field.dtype),
-               node.texture.field.spatial_dimensions,
-               node.texture
-           )
-       else:
-           code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
-               str(node.texture.field.dtype),
-               node.texture.field.spatial_dimensions,
-               node.texture
-           )
-       return code
+       cond = node.texture.field.dtype.numpy_dtype.itemsize > 4
+       return f'texture<{"fp_tex_" if cond else ""}{str(node.texture.field.dtype)}, ' \
+              f'cudaTextureType{node.texture.field.spatial_dimensions}D, cudaReadModeElementType> {node.texture};'

    def _print_SkipIteration(self, _):
        return "return;"
@@ -74,31 +54,6 @@ class CudaSympyPrinter(CustomSympyPrinter):
    def __init__(self):
        super(CudaSympyPrinter, self).__init__()
-       self.known_functions.update(CUDA_KNOWN_FUNCTIONS)

-   def _print_InterpolatorAccess(self, node):
-       dtype = node.interpolator.field.dtype.numpy_dtype
-       if type(node) == DiffInterpolatorAccess:
-           # cubicTex3D_1st_derivative_x(texture tex, float3 coord)
-           template = f"cubicTex%iD_1st_derivative_{list(reversed('xyz'[:node.ndim]))[node.diff_coordinate_idx]}(%s, %s)"  # noqa
-       elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
-           template = "cubicTex%iDSimple(%s, %s)"
-       else:
-           if dtype.itemsize > 4:
-               # Use PyCuda hack!
-               # https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp
-               template = "fp_tex%iD(%s, %s)"
-           else:
-               template = "tex%iD(%s, %s)"
-       code = template % (
-           node.interpolator.field.spatial_dimensions,
-           str(node.interpolator),
-           # + 0.5 comes from Nvidia's staggered indexing
-           ', '.join(self._print(o + 0.5) for o in reversed(node.offsets))
-       )
-       return code

    def _print_Function(self, expr):
        if isinstance(expr, fast_division):
...
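To make the rewritten `_print_TextureDeclaration` above easier to follow, here is a standalone sketch of the same f-string logic with plain values; the helper name and its inputs are stand-ins, not pystencils objects. Element types wider than 4 bytes get PyCUDA's `fp_tex_` texture types.

``` python
def texture_declaration(dtype_name: str, itemsize: int, dims: int, name: str) -> str:
    # Mirrors the condition in _print_TextureDeclaration: wide element types
    # (e.g. double) need PyCUDA's fp_tex_* helper types.
    cond = itemsize > 4
    return f'texture<{"fp_tex_" if cond else ""}{dtype_name}, ' \
           f'cudaTextureType{dims}D, cudaReadModeElementType> {name};'

print(texture_declaration("float", 4, 3, "tex_src"))
# texture<float, cudaTextureType3D, cudaReadModeElementType> tex_src;
print(texture_declaration("double", 8, 3, "tex_src"))
# texture<fp_tex_double, cudaTextureType3D, cudaReadModeElementType> tex_src;
```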
import graphviz
-from graphviz import Digraph, lang
+try:
+    from graphviz import Digraph
+    import graphviz.quoting as quote
+except ImportError:
+    from graphviz import Digraph
+    import graphviz.lang as quote

from sympy.printing.printer import Printer
@@ -12,7 +17,7 @@ class DotPrinter(Printer):
        super(DotPrinter, self).__init__()
        self._node_to_str_function = node_to_str_function
        self.dot = Digraph(**kwargs)
-       self.dot.quote_edge = lang.quote
+       self.dot.quote_edge = quote.quote

    def _print_KernelFunction(self, func):
        self.dot.node(str(id(func)), style='filled', fillcolor='#a056db', label=self._node_to_str_function(func))
...