Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 2859 additions and 449 deletions
import numpy as np
import sympy as sp
import pystencils as ps
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
from pystencils.cpu.vectorization import vectorize
from pystencils.transformations import replace_inner_stride_with_one
def test_vector_type_propagation():
a, b, c, d, e = sp.symbols("a b c d e")
arr = np.ones((2 ** 2 + 2, 2 ** 3 + 2))
arr *= 10.0
f, g = ps.fields(f=arr, g=arr)
update_rule = [ps.Assignment(a, f[1, 0]),
ps.Assignment(b, a),
ps.Assignment(g[0, 0], b + 3 + f[0, 1])]
ast = ps.create_kernel(update_rule)
vectorize(ast)
func = ast.compile()
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
def test_inplace_update():
shape = (9, 9, 3)
arr = np.ones(shape, order='f')
@ps.kernel
def update_rule(s):
f = ps.fields("f(3) : [2D]", f=arr)
s.tmp0 @= f(0)
s.tmp1 @= f(1)
s.tmp2 @= f(2)
f0, f1, f2 = f(0), f(1), f(2)
f0 @= 2 * s.tmp0
f1 @= 2 * s.tmp0
f2 @= 2 * s.tmp0
ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': 'avx'})
kernel = ast.compile()
kernel(f=arr)
np.testing.assert_equal(arr, 2)
def test_vectorization_fixed_size():
configurations = []
# Fixed size - multiple of four
arr = np.ones((20 + 2, 24 + 2)) * 5.0
f, g = ps.fields(f=arr, g=arr)
configurations.append((arr, f, g))
# Fixed size - no multiple of four
arr = np.ones((21 + 2, 25 + 2)) * 5.0
f, g = ps.fields(f=arr, g=arr)
configurations.append((arr, f, g))
# Fixed size - different remainder
arr = np.ones((23 + 2, 17 + 2)) * 5.0
f, g = ps.fields(f=arr, g=arr)
configurations.append((arr, f, g))
for arr, f, g in configurations:
update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
ast = ps.create_kernel(update_rule)
vectorize(ast)
func = ast.compile()
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(dst[1:-1, 1:-1], 5 * 5.0 + 42.0)
def test_vectorization_variable_size():
f, g = ps.fields("f, g : double[2D]")
update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
ast = ps.create_kernel(update_rule)
replace_inner_stride_with_one(ast)
vectorize(ast)
func = ast.compile()
arr = np.ones((23 + 2, 17 + 2)) * 5.0
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(dst[1:-1, 1:-1], 5 * 5.0 + 42.0)
def test_piecewise1():
a, b, c, d, e = sp.symbols("a b c d e")
arr = np.ones((2 ** 3 + 2, 2 ** 4 + 2)) * 5.0
f, g = ps.fields(f=arr, g=arr)
update_rule = [ps.Assignment(a, f[1, 0]),
ps.Assignment(b, a),
ps.Assignment(c, f[0, 0] > 0.0),
ps.Assignment(g[0, 0], sp.Piecewise((b + 3 + f[0, 1], c), (0.0, True)))]
ast = ps.create_kernel(update_rule)
vectorize(ast)
func = ast.compile()
dst = np.zeros_like(arr)
func(g=dst, f=arr)
np.testing.assert_equal(dst[1:-1, 1:-1], 5 + 3 + 5.0)
def test_piecewise2():
arr = np.zeros((20, 20))
@ps.kernel
def test_kernel(s):
f, g = ps.fields(f=arr, g=arr)
s.condition @= f[0, 0] > 1
s.result @= 0.0 if s.condition else 1.0
g[0, 0] @= s.result
ast = ps.create_kernel(test_kernel)
vectorize(ast)
func = ast.compile()
func(f=arr, g=arr)
np.testing.assert_equal(arr, np.ones_like(arr))
def test_piecewise3():
arr = np.zeros((22, 22))
@ps.kernel
def test_kernel(s):
f, g = ps.fields(f=arr, g=arr)
s.b @= f[0, 1]
g[0, 0] @= 1.0 / (s.b + s.k) if f[0, 0] > 0.0 else 1.0
ast = ps.create_kernel(test_kernel)
vectorize(ast)
ast.compile()
def test_logical_operators():
arr = np.zeros((22, 22))
@ps.kernel
def test_kernel(s):
f, g = ps.fields(f=arr, g=arr)
s.c @= sp.And(f[0, 1] < 0.0, f[1, 0] < 0.0)
g[0, 0] @= sp.Piecewise([1.0 / f[1, 0], s.c], [1.0, True])
ast = ps.create_kernel(test_kernel)
vectorize(ast)
ast.compile()
def test_hardware_query():
instruction_sets = get_supported_instruction_sets()
assert 'sse' in instruction_sets
[pytest] [pytest]
testpaths = src tests doc/notebooks
pythonpath = src
python_files = test_*.py *_test.py scenario_*.py python_files = test_*.py *_test.py scenario_*.py
norecursedirs = *.egg-info .git .cache .ipynb_checkpoints htmlcov norecursedirs = *.egg-info .git .cache .ipynb_checkpoints htmlcov
addopts = --doctest-modules --durations=20 --cov-config pytest.ini addopts = --doctest-modules --durations=20 --cov-config pytest.ini
markers =
longrun: tests only run at night since they have large execution time
notebook: mark for notebooks
# these warnings all come from third party libraries.
filterwarnings =
ignore:an integer is required:DeprecationWarning
ignore:\s*load will be removed, use:PendingDeprecationWarning
ignore:the imp module is deprecated in favour of importlib:DeprecationWarning
ignore:.*is a deprecated alias for the builtin `bool`:DeprecationWarning
ignore:'contextfilter' is renamed to 'pass_context':DeprecationWarning
ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc':DeprecationWarning
ignore:Animation was deleted without rendering anything:UserWarning
[run] [run]
branch = True branch = True
source = pystencils source = src/pystencils
pystencils_tests tests
omit = doc/* omit = doc/*
pystencils_tests/* tests/*
setup.py setup.py
quicktest.py
conftest.py conftest.py
pystencils/jupytersetup.py versioneer.py
pystencils/cpu/msvc_detection.py src/pystencils/jupytersetup.py
pystencils/sympy_gmpy_bug_workaround.py src/pystencils/cpu/msvc_detection.py
pystencils/cache.py src/pystencils/sympy_gmpy_bug_workaround.py
pystencils/pacxx/benchmark.py src/pystencils/cache.py
src/pystencils/pacxx/benchmark.py
src/pystencils/_version.py
venv/
[report] [report]
exclude_lines = exclude_lines =
...@@ -24,10 +42,12 @@ exclude_lines = ...@@ -24,10 +42,12 @@ exclude_lines =
pragma: no cover pragma: no cover
def __repr__ def __repr__
def _repr_html_
# Don't complain if tests don't hit defensive assertion code: # Don't complain if tests don't hit defensive assertion code:
raise AssertionError raise AssertionError
raise NotImplementedError raise NotImplementedError
NotImplementedError()
#raise ValueError #raise ValueError
# Don't complain if non-runnable code isn't run: # Don't complain if non-runnable code isn't run:
...@@ -36,7 +56,7 @@ exclude_lines = ...@@ -36,7 +56,7 @@ exclude_lines =
if __name__ == .__main__.: if __name__ == .__main__.:
skip_covered = True skip_covered = True
fail_under = 74 fail_under = 85
[html] [html]
directory = coverage_report directory = coverage_report
#!/usr/bin/env python3
from contextlib import redirect_stdout
import io
from tests.test_quicktests import (
test_basic_kernel,
test_basic_blocking_staggered,
test_basic_vectorization,
)
quick_tests = [
test_basic_kernel,
test_basic_blocking_staggered,
test_basic_vectorization,
]
if __name__ == "__main__":
print("Running pystencils quicktests")
for qt in quick_tests:
print(f" -> {qt.__name__}")
with redirect_stdout(io.StringIO()):
qt()
...@@ -8,6 +8,6 @@ read new_version ...@@ -8,6 +8,6 @@ read new_version
git tag -s release/${new_version} git tag -s release/${new_version}
git push origin master release/${new_version} git push origin master release/${new_version}
python setup.py sdist bdist_wheel
rm -rf dist rm -rf dist
twine upload dist/* python setup.py sdist
\ No newline at end of file twine upload dist/*
import distutils from setuptools import setup, __version__ as setuptools_version
import io
import os
import sys
from contextlib import redirect_stdout
from distutils.extension import Extension
from importlib import import_module
from setuptools import find_packages, setup if int(setuptools_version.split('.')[0]) < 61:
raise Exception(
"[ERROR] pystencils requires at least setuptools version 61 to install.\n"
"If this error occurs during an installation via pip, it is likely that there is a conflict between "
"versions of setuptools installed by pip and the system package manager. "
"In this case, it is recommended to install pystencils into a virtual environment instead."
)
if '--use-cython' in sys.argv: import versioneer
USE_CYTHON = True
sys.argv.remove('--use-cython')
else:
USE_CYTHON = False
quick_tests = [
'test_datahandling.test_kernel',
'test_blocking_staggered.test_blocking_staggered',
'test_blocking_staggered.test_blocking_staggered',
'test_vectorization.test_vectorization_variable_size',
]
def get_cmdclass():
return versioneer.get_cmdclass()
class SimpleTestRunner(distutils.cmd.Command):
"""A custom command to run selected tests"""
description = 'run some quick tests' setup(
user_options = [] version=versioneer.get_version(),
cmdclass=get_cmdclass(),
@staticmethod )
def _run_tests_in_module(test):
"""Short test runner function - to work also if py.test is not installed."""
test = 'pystencils_tests.' + test
mod, function_name = test.rsplit('.', 1)
if isinstance(mod, str):
mod = import_module(mod)
func = getattr(mod, function_name)
print(" -> %s in %s" % (function_name, mod.__name__))
with redirect_stdout(io.StringIO()):
func()
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
"""Run command."""
for test in quick_tests:
self._run_tests_in_module(test)
def readme():
with open('README.md') as f:
return f.read()
def cython_extensions(*extensions):
ext = '.pyx' if USE_CYTHON else '.c'
result = [Extension(e, [e.replace('.', '/') + ext]) for e in extensions]
if USE_CYTHON:
from Cython.Build import cythonize
result = cythonize(result, language_level=3)
return result
try:
sys.path.insert(0, os.path.abspath('doc'))
from version_from_git import version_number_from_git
version = version_number_from_git()
with open("RELEASE-VERSION", "w") as f:
f.write(version)
except ImportError:
version = open('RELEASE-VERSION', 'r').read()
setup(name='pystencils',
description='Speeding up stencil computations on CPUs and GPUs',
version=version,
long_description=readme(),
long_description_content_type="text/markdown",
author='Martin Bauer',
license='AGPLv3',
author_email='martin.bauer@fau.de',
url='https://i10git.cs.fau.de/pycodegen/pystencils/',
packages=['pystencils'] + ['pystencils.' + s for s in find_packages('pystencils')],
install_requires=['sympy>=1.1', 'numpy', 'appdirs', 'joblib'],
package_data={'pystencils': ['include/*.h', 'backends/cuda_known_functions.txt']},
ext_modules=cython_extensions("pystencils.boundaries.createindexlistcython"),
classifiers=[
'Development Status :: 4 - Beta',
'Framework :: Jupyter',
'Topic :: Software Development :: Code Generators',
'Topic :: Scientific/Engineering :: Physics',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)',
],
project_urls={
"Bug Tracker": "https://i10git.cs.fau.de/pycodegen/pystencils/issues",
"Documentation": "http://pycodegen.pages.walberla.net/pystencils/",
"Source Code": "https://i10git.cs.fau.de/pycodegen/pystencils",
},
extras_require={
'gpu': ['pycuda'],
'alltrafos': ['islpy', 'py-cpuinfo'],
'bench_db': ['blitzdb', 'pymongo', 'pandas'],
'interactive': ['matplotlib', 'ipy_table', 'imageio', 'jupyter', 'pyevtk'],
'autodiff': ['pystencils-autodiff'],
'doc': ['sphinx', 'sphinx_rtd_theme', 'nbsphinx',
'sphinxcontrib-bibtex', 'sphinx_autodoc_typehints', 'pandoc'],
},
tests_require=['pytest', 'pytest-cov', 'pytest-xdist', 'flake8', 'nbformat', 'nbconvert', 'ipython'],
python_requires=">=3.6",
cmdclass={
'quicktest': SimpleTestRunner
},
)
"""Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions""" """Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions"""
from . import sympy_gmpy_bug_workaround # NOQA from .enums import Backend, Target
from . import fd from . import fd
from . import stencil as stencil from . import stencil as stencil
from .assignment import Assignment, assignment_from_stencil from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
from .data_types import TypedSymbol from .typing.typed_sympy import TypedSymbol
from .datahandling import create_data_handling from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .display_utils import show_code, to_dot
from .field import Field, FieldType, fields from .field import Field, FieldType, fields
from .kernel_decorator import kernel from .config import CreateKernelConfig
from .kernelcreation import create_indexed_kernel, create_kernel, create_staggered_kernel from .cache import clear_cache
from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel, create_staggered_kernel
from .simp import AssignmentCollection from .simp import AssignmentCollection
from .slicing import make_slice from .slicing import make_slice
from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered
from .sympyextensions import SymbolCreator from .sympyextensions import SymbolCreator
from .datahandling import create_data_handling
try:
import pystencils_autodiff
autodiff = pystencils_autodiff
except ImportError:
pass
__all__ = ['Field', 'FieldType', 'fields', __all__ = ['Field', 'FieldType', 'fields',
'TypedSymbol', 'TypedSymbol',
'make_slice', 'make_slice',
'create_kernel', 'create_indexed_kernel', 'create_staggered_kernel', 'CreateKernelConfig',
'show_code', 'to_dot', 'create_kernel', 'create_staggered_kernel',
'Target', 'Backend',
'show_code', 'to_dot', 'get_code_obj', 'get_code_str',
'AssignmentCollection', 'AssignmentCollection',
'Assignment', 'Assignment', 'AddAugmentedAssignment',
'assignment_from_stencil', 'assignment_from_stencil',
'SymbolCreator', 'SymbolCreator',
'create_data_handling', 'create_data_handling',
'kernel', 'clear_cache',
'kernel', 'kernel_config',
'x_', 'y_', 'z_',
'x_staggered', 'y_staggered', 'z_staggered',
'x_vector', 'x_staggered_vector',
'fd', 'fd',
'stencil'] 'stencil']
from . import _version
__version__ = _version.get_versions()['version']
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
# directories (produced by setup.py build) will contain a much shorter file
# that just contains the computed version number.
# This file is released into the public domain.
# Generated by versioneer-0.29
# https://github.com/python-versioneer/python-versioneer
"""Git implementation of _version.py."""
import errno
import os
import re
import subprocess
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple
import functools
def get_keywords() -> Dict[str, str]:
"""Get the keywords needed to look up the version information."""
# these strings will be replaced by git during git-archive.
# setup.py/versioneer.py will grep for the variable names, so they must
# each be defined on a line of their own. _version.py will just call
# get_keywords().
git_refnames = "$Format:%d$"
git_full = "$Format:%H$"
git_date = "$Format:%ci$"
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
return keywords
class VersioneerConfig:
"""Container for Versioneer configuration parameters."""
VCS: str
style: str
tag_prefix: str
parentdir_prefix: str
versionfile_source: str
verbose: bool
def get_config() -> VersioneerConfig:
"""Create, populate and return the VersioneerConfig() object."""
# these strings are filled in when 'setup.py versioneer' creates
# _version.py
cfg = VersioneerConfig()
cfg.VCS = "git"
cfg.style = "pep440"
cfg.tag_prefix = "release/"
cfg.parentdir_prefix = "pystencils-"
cfg.versionfile_source = "src/pystencils/_version.py"
cfg.verbose = False
return cfg
class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY: Dict[str, str] = {}
HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator
"""Create decorator to mark a method as the handler of a VCS."""
def decorate(f: Callable) -> Callable:
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
return decorate
def run_command(
commands: List[str],
args: List[str],
cwd: Optional[str] = None,
verbose: bool = False,
hide_stderr: bool = False,
env: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[str], Optional[int]]:
"""Call the given command(s)."""
assert isinstance(commands, list)
process = None
popen_kwargs: Dict[str, Any] = {}
if sys.platform == "win32":
# This hides the console window if pythonw.exe is used
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
popen_kwargs["startupinfo"] = startupinfo
for command in commands:
try:
dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
process = subprocess.Popen([command] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr
else None), **popen_kwargs)
break
except OSError as e:
if e.errno == errno.ENOENT:
continue
if verbose:
print("unable to run %s" % dispcmd)
print(e)
return None, None
else:
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
stdout = process.communicate()[0].strip().decode()
if process.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
print("stdout was %s" % stdout)
return None, process.returncode
return stdout, process.returncode
def versions_from_parentdir(
parentdir_prefix: str,
root: str,
verbose: bool,
) -> Dict[str, Any]:
"""Try to determine the version from the parent directory name.
Source tarballs conventionally unpack into a directory that includes both
the project name and a version string. We will also support searching up
two directory levels for an appropriately named parent directory
"""
rootdirs = []
for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %s but none started with prefix %s" %
(str(rootdirs), parentdir_prefix))
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs: str) -> Dict[str, str]:
"""Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these
# keywords. When used from setup.py, we don't want to import _version.py,
# so we do it with a regexp instead. This function is not used from
# _version.py.
keywords: Dict[str, str] = {}
try:
with open(versionfile_abs, "r") as fobj:
for line in fobj:
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["full"] = mo.group(1)
if line.strip().startswith("git_date ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["date"] = mo.group(1)
except OSError:
pass
return keywords
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(
keywords: Dict[str, str],
tag_prefix: str,
verbose: bool,
) -> Dict[str, Any]:
"""Get version information from git keywords."""
if "refnames" not in keywords:
raise NotThisMethod("Short version file found")
date = keywords.get("date")
if date is not None:
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
# git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
# datestamp. However we prefer "%ci" (which expands to an "ISO-8601
# -like" string, which we must then edit to make compliant), because
# it's been around since git-1.5.3, and it's too difficult to
# discover which version we're using, or to work around using an
# older one.
date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
refnames = keywords["refnames"].strip()
if refnames.startswith("$Format"):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
# expansion behaves like git log --decorate=short and strips out the
# refs/heads/ and refs/tags/ prefixes that would let us distinguish
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = {r for r in refs if re.search(r'\d', r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
print("likely tags: %s" % ",".join(sorted(tags)))
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
# Filter out refs that exactly match prefix or that don't start
# with a number once the prefix is stripped (mostly a concern
# when prefix is '')
if not re.match(r'\d', r):
continue
if verbose:
print("picking %s" % r)
return {"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": None,
"date": date}
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
return {"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": "no suitable tags", "date": None}
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(
tag_prefix: str,
root: str,
verbose: bool,
runner: Callable = run_command
) -> Dict[str, Any]:
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
expanded, and _version.py hasn't already been rewritten with a short
version string, meaning we're inside a checked out source tree.
"""
GITS = ["git"]
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
# GIT_DIR can interfere with correct operation of Versioneer.
# It may be intended to be passed to the Versioneer-versioned project,
# but that should not change where we get our version from.
env = os.environ.copy()
env.pop("GIT_DIR", None)
runner = functools.partial(runner, env=env)
_, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=not verbose)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
raise NotThisMethod("'git rev-parse --git-dir' returned error")
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = runner(GITS, [
"describe", "--tags", "--dirty", "--always", "--long",
"--match", f"{tag_prefix}[[:digit:]]*"
], cwd=root)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
pieces: Dict[str, Any] = {}
pieces["long"] = full_out
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
cwd=root)
# --abbrev-ref was added in git-1.6.3
if rc != 0 or branch_name is None:
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
branch_name = branch_name.strip()
if branch_name == "HEAD":
# If we aren't exactly on a branch, pick a branch which represents
# the current commit. If all else fails, we are on a branchless
# commit.
branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
# --contains was added in git-1.5.4
if rc != 0 or branches is None:
raise NotThisMethod("'git branch --contains' returned error")
branches = branches.split("\n")
# Remove the first line if we're running detached
if "(" in branches[0]:
branches.pop(0)
# Strip off the leading "* " from the list of branches.
branches = [branch[2:] for branch in branches]
if "master" in branches:
branch_name = "master"
elif not branches:
branch_name = None
else:
# Pick the first branch that is returned. Good or bad.
branch_name = branches[0]
pieces["branch"] = branch_name
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
# look for -dirty suffix
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
git_describe = git_describe[:git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
if not mo:
# unparsable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'"
% describe_out)
return pieces
# tag
full_tag = mo.group(1)
if not full_tag.startswith(tag_prefix):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
% (full_tag, tag_prefix))
return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix):]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
# commit: short hex revision ID
pieces["short"] = mo.group(3)
else:
# HEX: no tags
pieces["closest-tag"] = None
out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
pieces["distance"] = len(out.split()) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
def plus_or_dot(pieces: Dict[str, Any]) -> str:
"""Return a + if we don't already have one, else return a ."""
if "+" in pieces.get("closest-tag", ""):
return "."
return "+"
def render_pep440(pieces: Dict[str, Any]) -> str:
"""Build up version string, with post-release "local version identifier".
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
Exceptions:
1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_branch(pieces: Dict[str, Any]) -> str:
"""TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
The ".dev0" means not master branch. Note that .dev0 sorts backwards
(a feature branch will appear "older" than the master branch).
Exceptions:
1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0"
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:
"""Split pep440 version string at the post-release segment.
Returns the release segments before the post-release and the
post-release version number (or -1 if no post-release segment is present).
"""
vc = str.split(ver, ".post")
return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
def render_pep440_pre(pieces: Dict[str, Any]) -> str:
"""TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
if pieces["distance"]:
# update the post release segment
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
rendered = tag_version
if post_version is not None:
rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
else:
rendered += ".post0.dev%d" % (pieces["distance"])
else:
# no commits, use the tag as the version
rendered = pieces["closest-tag"]
else:
# exception #1
rendered = "0.post0.dev%d" % pieces["distance"]
return rendered
def render_pep440_post(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]+gHEX] .
The ".dev0" means dirty. Note that .dev0 sorts backwards
(a dirty tree will appear "older" than the corresponding clean one),
but you shouldn't be releasing software with -dirty anyways.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
return rendered
def render_pep440_post_branch(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
The ".dev0" means not master branch.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_old(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]] .
The ".dev0" means dirty.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
return rendered
def render_git_describe(pieces: Dict[str, Any]) -> str:
"""TAG[-DISTANCE-gHEX][-dirty].
Like 'git describe --tags --dirty --always'.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"]:
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render_git_describe_long(pieces: Dict[str, Any]) -> str:
"""TAG-DISTANCE-gHEX[-dirty].
Like 'git describe --tags --dirty --always -long'.
The distance/hash is unconditional.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None}
if not style or style == "default":
style = "pep440" # the default
if style == "pep440":
rendered = render_pep440(pieces)
elif style == "pep440-branch":
rendered = render_pep440_branch(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
elif style == "pep440-post-branch":
rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
rendered = render_git_describe(pieces)
elif style == "git-describe-long":
rendered = render_git_describe_long(pieces)
else:
raise ValueError("unknown style '%s'" % style)
return {"version": rendered, "full-revisionid": pieces["long"],
"dirty": pieces["dirty"], "error": None,
"date": pieces.get("date")}
def get_versions() -> Dict[str, Any]:
"""Get version information or return default if unable to do so."""
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
# __file__, we can work backwards from there to the root. Some
# py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
# case we can only use expanded keywords.
cfg = get_config()
verbose = cfg.verbose
try:
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
verbose)
except NotThisMethod:
pass
try:
root = os.path.realpath(__file__)
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
for _ in cfg.versionfile_source.split('/'):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to find root of source tree",
"date": None}
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
return render(pieces, cfg.style)
except NotThisMethod:
pass
try:
if cfg.parentdir_prefix:
return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
except NotThisMethod:
pass
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to compute version", "date": None}
import numpy as np import numpy as np
def aligned_empty(shape, byte_alignment=32, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True): def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True):
""" """
Creates an aligned empty numpy array Creates an aligned empty numpy array
Args: Args:
shape: size of the array shape: size of the array
byte_alignment: alignment in bytes, for the start address of the array holds (a % byte_alignment) == 0 byte_alignment: alignment in bytes, for the start address of the array holds (a % byte_alignment) == 0
By default, use the maximum required by the CPU (or 512 bits if this cannot be detected).
When 'cacheline' is specified, the size of a cache line is used.
dtype: numpy data type dtype: numpy data type
byte_offset: offset in bytes for position that should be aligned i.e. (a+byte_offset) % byte_alignment == 0 byte_offset: offset in bytes for position that should be aligned i.e. (a+byte_offset) % byte_alignment == 0
typically used to align first inner cell instead of ghost layer typically used to align first inner cell instead of ghost layer
order: storage linearization order order: storage linearization order
align_inner_coordinate: if True, the start of the innermost coordinate lines are aligned as well align_inner_coordinate: if True, the start of the innermost coordinate lines are aligned as well
""" """
if byte_alignment is True or byte_alignment == 'cacheline':
from pystencils.backends.simd_instruction_sets import (get_supported_instruction_sets, get_cacheline_size,
get_vector_instruction_set)
instruction_sets = get_supported_instruction_sets()
if instruction_sets is None:
byte_alignment = 64
elif byte_alignment == 'cacheline':
cacheline_sizes = [get_cacheline_size(is_name) for is_name in instruction_sets]
if all([s is None for s in cacheline_sizes]) or \
max([s for s in cacheline_sizes if s is not None]) > 0x100000:
widths = [get_vector_instruction_set(dtype, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets
if type(get_vector_instruction_set(dtype, is_name)['width']) is int]
byte_alignment = 64 if all([s is None for s in widths]) else max(widths)
else:
byte_alignment = max([s for s in cacheline_sizes if s is not None])
elif not any([type(get_vector_instruction_set(dtype, is_name)['width']) is int
for is_name in instruction_sets]):
byte_alignment = 64
else:
byte_alignment = max([get_vector_instruction_set(dtype, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets
if type(get_vector_instruction_set(dtype, is_name)['width']) is int])
if (not align_inner_coordinate) or (not hasattr(shape, '__len__')): if (not align_inner_coordinate) or (not hasattr(shape, '__len__')):
size = np.prod(shape) size = np.prod(shape)
d = np.dtype(dtype) d = np.dtype(dtype)
...@@ -51,7 +77,7 @@ def aligned_empty(shape, byte_alignment=32, dtype=np.float64, byte_offset=0, ord ...@@ -51,7 +77,7 @@ def aligned_empty(shape, byte_alignment=32, dtype=np.float64, byte_offset=0, ord
return tmp return tmp
def aligned_zeros(shape, byte_alignment=16, dtype=float, byte_offset=0, order='C', align_inner_coordinate=True): def aligned_zeros(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True):
arr = aligned_empty(shape, dtype=dtype, byte_offset=byte_offset, arr = aligned_empty(shape, dtype=dtype, byte_offset=byte_offset,
order=order, byte_alignment=byte_alignment, align_inner_coordinate=align_inner_coordinate) order=order, byte_alignment=byte_alignment, align_inner_coordinate=align_inner_coordinate)
x = np.zeros((), arr.dtype) x = np.zeros((), arr.dtype)
...@@ -59,7 +85,7 @@ def aligned_zeros(shape, byte_alignment=16, dtype=float, byte_offset=0, order='C ...@@ -59,7 +85,7 @@ def aligned_zeros(shape, byte_alignment=16, dtype=float, byte_offset=0, order='C
return arr return arr
def aligned_ones(shape, byte_alignment=16, dtype=float, byte_offset=0, order='C', align_inner_coordinate=True): def aligned_ones(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True):
arr = aligned_empty(shape, dtype=dtype, byte_offset=byte_offset, arr = aligned_empty(shape, dtype=dtype, byte_offset=byte_offset,
order=order, byte_alignment=byte_alignment, align_inner_coordinate=align_inner_coordinate) order=order, byte_alignment=byte_alignment, align_inner_coordinate=align_inner_coordinate)
x = np.ones((), arr.dtype) x = np.ones((), arr.dtype)
......
# -*- coding: utf-8 -*-
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
from sympy.printing.latex import LatexPrinter from sympy.printing.latex import LatexPrinter
try: __all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil']
from sympy.codegen.ast import Assignment
except ImportError:
Assignment = None
__all__ = ['Assignment', 'assignment_from_stencil']
def print_assignment_latex(printer, expr): def print_assignment_latex(printer, expr):
binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
"""sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
printed_lhs = printer.doprint(expr.lhs) printed_lhs = printer.doprint(expr.lhs)
printed_rhs = printer.doprint(expr.rhs) printed_rhs = printer.doprint(expr.rhs)
return r"{printed_lhs} \leftarrow {printed_rhs}".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs) return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"
def assignment_str(assignment): def assignment_str(assignment):
return r"{lhs} ← {rhs}".format(lhs=assignment.lhs, rhs=assignment.rhs) op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else ''
return fr"{assignment.lhs} {op} {assignment.rhs}"
if Assignment:
Assignment.__str__ = assignment_str
LatexPrinter._print_Assignment = print_assignment_latex
else: _old_new = sp.codegen.ast.Assignment.__new__
# back port for older sympy versions that don't have Assignment yet
class Assignment(sp.Rel): # pragma: no cover
rel_op = ':=' # TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults
__slots__ = [] def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
def __new__(cls, lhs, rhs=0, **assumptions):
from sympy.matrices.expressions.matexpr import (
MatrixElement, MatrixSymbol)
lhs = sp.sympify(lhs)
rhs = sp.sympify(rhs)
# Tuple of things that can be on the lhs of an assignment
assignable = (sp.Symbol, MatrixSymbol, MatrixElement, sp.Indexed)
if not isinstance(lhs, assignable):
raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
return sp.Rel.__new__(cls, lhs, rhs, **assumptions)
__str__ = assignment_str Assignment.__str__ = assignment_str
_print_Assignment = print_assignment_latex Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
# Apparently, in SymPy 1.4 Assignment.__hash__ is not implemented. This has been fixed in current master AugmentedAssignment.__str__ = assignment_str
try: LatexPrinter._print_AugmentedAssignment = print_assignment_latex
sympy_version = sp.__version__.split('.')
if int(sympy_version[0]) <= 1 and int(sympy_version[1]) <= 4: sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
def hash_fun(self):
return hash((self.lhs, self.rhs))
Assignment.__hash__ = hash_fun
except Exception:
pass
def assignment_from_stencil(stencil_array, input_field, output_field, def assignment_from_stencil(stencil_array, input_field, output_field,
...@@ -84,16 +63,19 @@ def assignment_from_stencil(stencil_array, input_field, output_field, ...@@ -84,16 +63,19 @@ def assignment_from_stencil(stencil_array, input_field, output_field,
... [0, 6, 0]] ... [0, 6, 0]]
By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down
>>> assignment_from_stencil(stencil, f, g, order='visual') >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0])
Assignment(g_C, 3*f_W + 6*f_S + 4*f_C + 2*f_N + 5*f_E) >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output
True
'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc. 'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
>>> assignment_from_stencil(stencil, f, g, order='numpy') >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0])
Assignment(g_C, 2*f_W + 3*f_S + 4*f_C + 5*f_N + 6*f_E) >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output
True
You can also pass field accesses to apply the stencil at an already shifted position: You can also pass field accesses to apply the stencil at an already shifted position:
>>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0])
Assignment(g_2E, 3*f_C + 6*f_SE + 4*f_E + 2*f_NE + 5*f_2E) >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output
True
""" """
from pystencils.field import Field from pystencils.field import Field
......
import collections.abc
import itertools
import uuid import uuid
from typing import Any, List, Optional, Sequence, Set, Union from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.assignment import Assignment
from pystencils.enums import Target, Backend
from pystencils.field import Field from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs from pystencils.sympyextensions import fast_subs
from pystencils.typing import (create_type, get_next_parent_of_type,
FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol, CFunction)
NodeOrExpr = Union['Node', sp.Expr] NodeOrExpr = Union['Node', sp.Expr]
...@@ -33,9 +37,13 @@ class Node: ...@@ -33,9 +37,13 @@ class Node:
raise NotImplementedError() raise NotImplementedError()
def subs(self, subs_dict) -> None: def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace.""" """Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args: for i, a in enumerate(self.args):
a.subs(subs_dict) result = a.subs(subs_dict)
if isinstance(a, sp.Expr): # sympy expressions' subs is out-of-place
self.args[i] = result
else: # all other should be in-place
assert result is None
@property @property
def func(self): def func(self):
...@@ -102,14 +110,22 @@ class Conditional(Node): ...@@ -102,14 +110,22 @@ class Conditional(Node):
result = self.true_block.undefined_symbols result = self.true_block.undefined_symbols
if self.false_block: if self.false_block:
result.update(self.false_block.undefined_symbols) result.update(self.false_block.undefined_symbols)
result.update(self.condition_expr.atoms(sp.Symbol)) if hasattr(self.condition_expr, 'atoms'):
result.update(self.condition_expr.atoms(sp.Symbol))
return result return result
def __str__(self): def __str__(self):
return 'if:({!s}) '.format(self.condition_expr) return self.__repr__()
def __repr__(self): def __repr__(self):
return 'if:({!r}) '.format(self.condition_expr) result = f'if:({self.condition_expr!r}) '
if self.true_block:
result += f'\n\t{self.true_block}) '
if self.false_block:
result = 'else: '
result += f'\n\t{self.false_block} '
return result
def replace_by_true_block(self): def replace_by_true_block(self):
"""Replaces the conditional by its True block""" """Replaces the conditional by its True block"""
...@@ -121,7 +137,6 @@ class Conditional(Node): ...@@ -121,7 +137,6 @@ class Conditional(Node):
class KernelFunction(Node): class KernelFunction(Node):
class Parameter: class Parameter:
"""Function parameter. """Function parameter.
...@@ -161,7 +176,9 @@ class KernelFunction(Node): ...@@ -161,7 +176,9 @@ class KernelFunction(Node):
def field_name(self): def field_name(self):
return self.fields[0].name return self.fields[0].name
def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel"): def __init__(self, body, target: Target, backend: Backend, compile_function, ghost_layers,
function_name: str = "kernel",
assignments=None):
super(KernelFunction, self).__init__() super(KernelFunction, self).__init__()
self._body = body self._body = body
body.parent = self body.parent = self
...@@ -175,15 +192,20 @@ class KernelFunction(Node): ...@@ -175,15 +192,20 @@ class KernelFunction(Node):
self.instruction_set = None # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use self.instruction_set = None # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
# function that compiles the node to a Python callable, is set by the backends # function that compiles the node to a Python callable, is set by the backends
self._compile_function = compile_function self._compile_function = compile_function
self.assignments = assignments
# If nontemporal stores are activated together with the Neon instruction set it results in cacheline zeroing
# For cacheline zeroing the information of the field size for each field is needed. Thus, in this case
# all field sizes are kernel parameters and not just the common field size used for the loops
self.use_all_written_field_sizes = False
@property @property
def target(self): def target(self):
"""Currently either 'cpu' or 'gpu' """ """See pystencils.Target"""
return self._target return self._target
@property @property
def backend(self): def backend(self):
"""Backend for generating the code e.g. 'llvm', 'c', 'cuda' """ """Backend for generating the code: `Backend`"""
return self._backend return self._backend
@property @property
...@@ -208,13 +230,21 @@ class KernelFunction(Node): ...@@ -208,13 +230,21 @@ class KernelFunction(Node):
return self._body, return self._body,
@property @property
def fields_accessed(self) -> Set['ResolvedFieldAccess']: def fields_accessed(self) -> Set[Field]:
"""Set of Field instances: fields which are accessed inside this kernel function""" """Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess)) return set(o.field for o in itertools.chain(self.atoms(ResolvedFieldAccess)))
def fields_written(self): @property
assigments = self.atoms(SympyAssignment) def fields_written(self) -> Set[Field]:
return {a.lhs.field for a in assigments if isinstance(a.lhs, ResolvedFieldAccess)} assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.lhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
@property
def fields_read(self) -> Set[Field]:
assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
def get_parameters(self) -> Sequence['KernelFunction.Parameter']: def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
"""Returns list of parameters for this function. """Returns list of parameters for this function.
...@@ -222,6 +252,11 @@ class KernelFunction(Node): ...@@ -222,6 +252,11 @@ class KernelFunction(Node):
This function is expensive, cache the result where possible! This function is expensive, cache the result where possible!
""" """
field_map = {f.name: f for f in self.fields_accessed} field_map = {f.name: f for f in self.fields_accessed}
sizes = set()
if self.use_all_written_field_sizes:
sizes = set().union(*(a.shape[:a.spatial_dimensions] for a in self.fields_written))
sizes = filter(lambda s: isinstance(s, FieldShapeSymbol), sizes)
def get_fields(symbol): def get_fields(symbol):
if hasattr(symbol, 'field_name'): if hasattr(symbol, 'field_name'):
...@@ -231,9 +266,13 @@ class KernelFunction(Node): ...@@ -231,9 +266,13 @@ class KernelFunction(Node):
return () return ()
argument_symbols = self._body.undefined_symbols - self.global_variables argument_symbols = self._body.undefined_symbols - self.global_variables
argument_symbols.update(sizes)
parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols] parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
if hasattr(self, 'indexing'): if hasattr(self, 'indexing'):
parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()] parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
# Exclude paramters of type CFunction. These parameters will result in a C function call that will be handled
# by including a respective header file in the compute kernel. Hence, it is not a free parameter.
parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)]
parameters.sort(key=lambda p: p.symbol.name) parameters.sort(key=lambda p: p.symbol.name)
return parameters return parameters
...@@ -244,7 +283,7 @@ class KernelFunction(Node): ...@@ -244,7 +283,7 @@ class KernelFunction(Node):
def __repr__(self): def __repr__(self):
params = [p.symbol for p in self.get_parameters()] params = [p.symbol for p in self.get_parameters()]
return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params) return f'{type(self).__name__} {self.function_name}({params})'
def compile(self, *args, **kwargs): def compile(self, *args, **kwargs):
if self._compile_function is None: if self._compile_function is None:
...@@ -267,12 +306,17 @@ class SkipIteration(Node): ...@@ -267,12 +306,17 @@ class SkipIteration(Node):
class Block(Node): class Block(Node):
def __init__(self, nodes: List[Node]): def __init__(self, nodes: Union[Node, List[Node]]):
super(Block, self).__init__() super(Block, self).__init__()
if not isinstance(nodes, list):
nodes = [nodes]
self._nodes = nodes self._nodes = nodes
self.parent = None self.parent = None
for n in self._nodes: for n in self._nodes:
n.parent = self try:
n.parent = self
except AttributeError:
pass
@property @property
def args(self): def args(self):
...@@ -282,23 +326,39 @@ class Block(Node): ...@@ -282,23 +326,39 @@ class Block(Node):
for a in self.args: for a in self.args:
a.subs(subs_dict) a.subs(subs_dict)
def insert_front(self, node): def fast_subs(self, subs_dict, skip=None):
node.parent = self self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
self._nodes.insert(0, node) return self
def insert_front(self, node, if_not_exists=False):
if if_not_exists and len(self._nodes) > 0 and self._nodes[0] == node:
return
if isinstance(node, collections.abc.Iterable):
node = list(node)
for n in node:
n.parent = self
self._nodes = node + self._nodes
else:
node.parent = self
self._nodes.insert(0, node)
def insert_before(self, new_node, insert_before): def insert_before(self, new_node, insert_before, if_not_exists=False):
new_node.parent = self new_node.parent = self
assert self._nodes.count(insert_before) == 1
idx = self._nodes.index(insert_before) idx = self._nodes.index(insert_before)
# move all assignment (definitions to the top) if not if_not_exists or self._nodes[idx] != new_node:
if isinstance(new_node, SympyAssignment) and new_node.is_declaration: self._nodes.insert(idx, new_node)
while idx > 0:
pn = self._nodes[idx - 1] def insert_after(self, new_node, insert_after, if_not_exists=False):
if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional): new_node.parent = self
idx -= 1 assert self._nodes.count(insert_after) == 1
else: idx = self._nodes.index(insert_after) + 1
break
self._nodes.insert(idx, new_node) if not if_not_exists or not (self._nodes[idx - 1] == new_node
or (idx < len(self._nodes) and self._nodes[idx] == new_node)):
self._nodes.insert(idx, new_node)
def append(self, node): def append(self, node):
if isinstance(node, list) or isinstance(node, tuple): if isinstance(node, list) or isinstance(node, tuple):
...@@ -315,6 +375,7 @@ class Block(Node): ...@@ -315,6 +375,7 @@ class Block(Node):
return tmp return tmp
def replace(self, child, replacements): def replace(self, child, replacements):
assert self._nodes.count(child) == 1
idx = self._nodes.index(child) idx = self._nodes.index(child)
del self._nodes[idx] del self._nodes[idx]
if type(replacements) is list: if type(replacements) is list:
...@@ -329,7 +390,10 @@ class Block(Node): ...@@ -329,7 +390,10 @@ class Block(Node):
def symbols_defined(self): def symbols_defined(self):
result = set() result = set()
for a in self.args: for a in self.args:
result.update(a.symbols_defined) if isinstance(a, Assignment):
result.update(a.free_symbols)
else:
result.update(a.symbols_defined)
return result return result
@property @property
...@@ -337,8 +401,12 @@ class Block(Node): ...@@ -337,8 +401,12 @@ class Block(Node):
result = set() result = set()
defined_symbols = set() defined_symbols = set()
for a in self.args: for a in self.args:
result.update(a.undefined_symbols) if isinstance(a, Assignment):
defined_symbols.update(a.symbols_defined) result.update(a.free_symbols)
defined_symbols.update({a.lhs})
else:
result.update(a.undefined_symbols)
defined_symbols.update(a.symbols_defined)
return result - defined_symbols return result - defined_symbols
def __str__(self): def __str__(self):
...@@ -361,9 +429,9 @@ class PragmaBlock(Block): ...@@ -361,9 +429,9 @@ class PragmaBlock(Block):
class LoopOverCoordinate(Node): class LoopOverCoordinate(Node):
LOOP_COUNTER_NAME_PREFIX = "ctr" LOOP_COUNTER_NAME_PREFIX = "ctr"
BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr" BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"
def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False): def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, custom_loop_ctr=None):
super(LoopOverCoordinate, self).__init__(parent=None) super(LoopOverCoordinate, self).__init__(parent=None)
self.body = body self.body = body
body.parent = self body.parent = self
...@@ -374,11 +442,12 @@ class LoopOverCoordinate(Node): ...@@ -374,11 +442,12 @@ class LoopOverCoordinate(Node):
self.body.parent = self self.body.parent = self
self.prefix_lines = [] self.prefix_lines = []
self.is_block_loop = is_block_loop self.is_block_loop = is_block_loop
self.custom_loop_ctr = custom_loop_ctr
def new_loop_with_different_body(self, new_body): def new_loop_with_different_body(self, new_body):
result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop, result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
self.step, self.is_block_loop) self.step, self.is_block_loop, self.custom_loop_ctr)
result.prefix_lines = [l for l in self.prefix_lines] result.prefix_lines = [prefix_line for prefix_line in self.prefix_lines]
return result return result
def subs(self, subs_dict): def subs(self, subs_dict):
...@@ -390,6 +459,16 @@ class LoopOverCoordinate(Node): ...@@ -390,6 +459,16 @@ class LoopOverCoordinate(Node):
if hasattr(self.step, "subs"): if hasattr(self.step, "subs"):
self.step = self.step.subs(subs_dict) self.step = self.step.subs(subs_dict)
def fast_subs(self, subs_dict, skip=None):
self.body = fast_subs(self.body, subs_dict, skip)
if isinstance(self.start, sp.Basic):
self.start = fast_subs(self.start, subs_dict, skip)
if isinstance(self.stop, sp.Basic):
self.stop = fast_subs(self.stop, subs_dict, skip)
if isinstance(self.step, sp.Basic):
self.step = fast_subs(self.step, subs_dict, skip)
return self
@property @property
def args(self): def args(self):
result = [self.body] result = [self.body]
...@@ -422,18 +501,21 @@ class LoopOverCoordinate(Node): ...@@ -422,18 +501,21 @@ class LoopOverCoordinate(Node):
@staticmethod @staticmethod
def get_loop_counter_name(coordinate_to_loop_over): def get_loop_counter_name(coordinate_to_loop_over):
return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over) return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
@staticmethod @staticmethod
def get_block_loop_counter_name(coordinate_to_loop_over): def get_block_loop_counter_name(coordinate_to_loop_over):
return "%s_%s" % (LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over) return f"{LoopOverCoordinate.BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
@property @property
def loop_counter_name(self): def loop_counter_name(self):
if self.is_block_loop: if self.custom_loop_ctr:
return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over) return self.custom_loop_ctr.name
else: else:
return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over) if self.is_block_loop:
return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
else:
return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
@staticmethod @staticmethod
def is_loop_counter_symbol(symbol): def is_loop_counter_symbol(symbol):
...@@ -447,22 +529,26 @@ class LoopOverCoordinate(Node): ...@@ -447,22 +529,26 @@ class LoopOverCoordinate(Node):
@staticmethod @staticmethod
def get_loop_counter_symbol(coordinate_to_loop_over): def get_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int') return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)
@staticmethod @staticmethod
def get_block_loop_counter_symbol(coordinate_to_loop_over): def get_block_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), 'int') return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
'int',
nonnegative=True)
@property @property
def loop_counter_symbol(self): def loop_counter_symbol(self):
if self.is_block_loop: if self.custom_loop_ctr:
return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over) return self.custom_loop_ctr
else: else:
return self.get_loop_counter_symbol(self.coordinate_to_loop_over) if self.is_block_loop:
return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
else:
return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
@property @property
def is_outermost_loop(self): def is_outermost_loop(self):
from pystencils.transformations import get_next_parent_of_type
return get_next_parent_of_type(self, LoopOverCoordinate) is None return get_next_parent_of_type(self, LoopOverCoordinate) is None
@property @property
...@@ -482,15 +568,17 @@ class LoopOverCoordinate(Node): ...@@ -482,15 +568,17 @@ class LoopOverCoordinate(Node):
class SympyAssignment(Node): class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True): def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
super(SympyAssignment, self).__init__(parent=None) super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol self._lhs_symbol = sp.sympify(lhs_symbol)
self.rhs = rhs_expr self._rhs = sp.sympify(rhs_expr)
self._is_const = is_const self._is_const = is_const
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
self._use_auto = use_auto
def __is_declaration(self): def __is_declaration(self):
if isinstance(self._lhs_symbol, cast_func): from pystencils.typing import CastFunc
if isinstance(self._lhs_symbol, CastFunc):
return False return False
if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)): if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
return False return False
...@@ -500,15 +588,35 @@ class SympyAssignment(Node): ...@@ -500,15 +588,35 @@ class SympyAssignment(Node):
def lhs(self): def lhs(self):
return self._lhs_symbol return self._lhs_symbol
@property
def rhs(self):
return self._rhs
@lhs.setter @lhs.setter
def lhs(self, new_value): def lhs(self, new_value):
self._lhs_symbol = new_value self._lhs_symbol = new_value
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
@rhs.setter
def rhs(self, new_rhs_expr):
self._rhs = new_rhs_expr
def subs(self, subs_dict): def subs(self, subs_dict):
self.lhs = fast_subs(self.lhs, subs_dict) self.lhs = fast_subs(self.lhs, subs_dict)
self.rhs = fast_subs(self.rhs, subs_dict) self.rhs = fast_subs(self.rhs, subs_dict)
def fast_subs(self, subs_dict, skip=None):
self.lhs = fast_subs(self.lhs, subs_dict, skip)
self.rhs = fast_subs(self.rhs, subs_dict, skip)
return self
def optimize(self, optimizations):
try:
from sympy.codegen.rewriting import optimize
self.rhs = optimize(self.rhs, optimizations)
except Exception:
pass
@property @property
def args(self): def args(self):
return [self._lhs_symbol, self.rhs] return [self._lhs_symbol, self.rhs]
...@@ -521,7 +629,7 @@ class SympyAssignment(Node): ...@@ -521,7 +629,7 @@ class SympyAssignment(Node):
@property @property
def undefined_symbols(self): def undefined_symbols(self):
result = self.rhs.atoms(sp.Symbol) result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
# Add loop counters if there a field accesses # Add loop counters if there a field accesses
loop_counters = set() loop_counters = set()
for symbol in result: for symbol in result:
...@@ -529,7 +637,9 @@ class SympyAssignment(Node): ...@@ -529,7 +637,9 @@ class SympyAssignment(Node):
for i in range(len(symbol.offsets)): for i in range(len(symbol.offsets)):
loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i)) loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
result.update(loop_counters) result.update(loop_counters)
result.update(self._lhs_symbol.atoms(sp.Symbol)) result.update(self._lhs_symbol.atoms(sp.Symbol))
return result return result
@property @property
...@@ -540,6 +650,10 @@ class SympyAssignment(Node): ...@@ -540,6 +650,10 @@ class SympyAssignment(Node):
def is_const(self): def is_const(self):
return self._is_const return self._is_const
@property
def use_auto(self):
return self._use_auto
def replace(self, child, replacement): def replace(self, child, replacement):
if child == self.lhs: if child == self.lhs:
replacement.parent = self replacement.parent = self
...@@ -548,7 +662,7 @@ class SympyAssignment(Node): ...@@ -548,7 +662,7 @@ class SympyAssignment(Node):
replacement.parent = self replacement.parent = self
self.rhs = replacement self.rhs = replacement
else: else:
raise ValueError('%s is not in args of %s' % (replacement, self.__class__)) raise ValueError(f'{replacement} is not in args of {self.__class__}')
def __repr__(self): def __repr__(self):
return repr(self.lhs) + "" + repr(self.rhs) return repr(self.lhs) + "" + repr(self.rhs)
...@@ -556,13 +670,21 @@ class SympyAssignment(Node): ...@@ -556,13 +670,21 @@ class SympyAssignment(Node):
def _repr_html_(self): def _repr_html_(self):
printed_lhs = sp.latex(self.lhs) printed_lhs = sp.latex(self.lhs)
printed_rhs = sp.latex(self.rhs) printed_rhs = sp.latex(self.rhs)
return "${printed_lhs} \\leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs) return f"${printed_lhs} \\leftarrow {printed_rhs}$"
def __hash__(self):
return hash((self.lhs, self.rhs))
def __eq__(self, other):
return type(self) is type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)
class ResolvedFieldAccess(sp.Indexed): class ResolvedFieldAccess(sp.Indexed):
def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values): def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
if not isinstance(base, sp.IndexedBase): if not isinstance(base, sp.IndexedBase):
assert isinstance(base, TypedSymbol)
base = sp.IndexedBase(base, shape=(1,)) base = sp.IndexedBase(base, shape=(1,))
assert isinstance(base.label, TypedSymbol)
obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index) obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
obj.field = field obj.field = field
obj.offsets = offsets obj.offsets = offsets
...@@ -574,7 +696,7 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -574,7 +696,7 @@ class ResolvedFieldAccess(sp.Indexed):
self.args[1].subs(old, new), self.args[1].subs(old, new),
self.field, self.offsets, self.idx_coordinate_values) self.field, self.offsets, self.idx_coordinate_values)
def fast_subs(self, substitutions): def fast_subs(self, substitutions, skip=None):
if self in substitutions: if self in substitutions:
return substitutions[self] return substitutions[self]
return ResolvedFieldAccess(self.args[0].subs(substitutions), return ResolvedFieldAccess(self.args[0].subs(substitutions),
...@@ -591,11 +713,14 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -591,11 +713,14 @@ class ResolvedFieldAccess(sp.Indexed):
def __str__(self): def __str__(self):
top = super(ResolvedFieldAccess, self).__str__() top = super(ResolvedFieldAccess, self).__str__()
return "%s (%s)" % (top, self.typed_symbol.dtype) return f"{top} ({self.typed_symbol.dtype})"
def __getnewargs__(self): def __getnewargs__(self):
return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
def __getnewargs_ex__(self):
return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values), {}
class TemporaryMemoryAllocation(Node): class TemporaryMemoryAllocation(Node):
"""Node for temporary memory buffer allocation. """Node for temporary memory buffer allocation.
...@@ -667,60 +792,79 @@ def early_out(condition): ...@@ -667,60 +792,79 @@ def early_out(condition):
return Conditional(vec_all(condition), Block([SkipIteration()])) return Conditional(vec_all(condition), Block([SkipIteration()]))
class DestructuringBindingsForFieldClass(Node): def get_dummy_symbol(dtype='bool'):
""" return TypedSymbol(f'dummy{uuid.uuid4().hex}', create_type(dtype))
Defines all variables needed for describing a field (shape, pointer, strides)
"""
CLASS_TO_MEMBER_DICT = { class SourceCodeComment(Node):
FieldPointerSymbol: "data", def __init__(self, text):
FieldShapeSymbol: "shape[%i]", self.text = text
FieldStrideSymbol: "stride[%i]"
}
CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>"
@property @property
def fields_accessed(self) -> Set['ResolvedFieldAccess']: def args(self):
"""Set of Field instances: fields which are accessed inside this kernel function""" return []
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def __init__(self, body): @property
super(DestructuringBindingsForFieldClass, self).__init__() def symbols_defined(self):
self.headers = ['<PyStencilsField.h>'] return set()
self.body = body
@property @property
def args(self) -> List[NodeOrExpr]: def undefined_symbols(self):
"""Returns all arguments/children of this node."""
return set() return set()
def __str__(self):
return "/* " + self.text + " */"
def __repr__(self):
return self.__str__()
class EmptyLine(Node):
def __init__(self):
pass
@property @property
def symbols_defined(self) -> Set[sp.Symbol]: def args(self):
"""Set of symbols which are defined by this node.""" return []
undefined_field_symbols = {s for s in self.body.undefined_symbols
if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))}
return undefined_field_symbols
@property @property
def undefined_symbols(self) -> Set[sp.Symbol]: def symbols_defined(self):
field_map = {f.name: f for f in self.fields_accessed} return set()
undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols)
def subs(self, subs_dict) -> None: @property
"""Inplace! substitute, similar to sympy's but modifies the AST inplace.""" def undefined_symbols(self):
self.body.subs(subs_dict) return set()
def __str__(self):
return ""
def __repr__(self):
return self.__str__()
class ConditionalFieldAccess(sp.Function):
"""
:class:`pystencils.Field.Access` that is only executed if a certain condition is met.
Can be used, for instance, for out-of-bound checks.
"""
def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value))
@property @property
def func(self): def access(self):
return self.__class__ return self.args[0]
def atoms(self, arg_type) -> Set[Any]: @property
return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)} def outofbounds_condition(self):
return self.args[1]
@property
def outofbounds_value(self):
return self.args[2]
def get_dummy_symbol(dtype='bool'): def __getnewargs__(self):
return TypedSymbol('dummy%s' % uuid.uuid4().hex, create_type(dtype)) return self.access, self.outofbounds_condition, self.outofbounds_value
def __getnewargs_ex__(self):
return (self.access, self.outofbounds_condition, self.outofbounds_value), {}
...@@ -6,9 +6,3 @@ try: ...@@ -6,9 +6,3 @@ try:
__all__.append('print_dot') __all__.append('print_dot')
except ImportError: except ImportError:
pass pass
try:
from .llvm import generate_llvm # NOQA
__all__.append('generate_llvm')
except ImportError:
pass
from pystencils.typing import CFunction
def get_argument_string(function_shortcut, first=''):
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
if first:
arg_string += first + ', '
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
return arg_string
def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
if instruction_set not in ['neon', 'sme'] and not instruction_set.startswith('sve'):
raise NotImplementedError(instruction_set)
if instruction_set in ['sve', 'sve2', 'sme']:
cmp = 'cmp'
elif instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
cmp = 'cmp'
bitwidth = int(instruction_set[4:])
elif instruction_set.startswith('sve'):
cmp = 'cmp'
bitwidth = int(instruction_set[3:])
elif instruction_set == 'neon':
cmp = 'c'
bitwidth = 128
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'sqrt': 'sqrt[0]',
'loadU': 'ld1[0]',
'storeU': 'st1[0, 1]',
'abs': 'abs[0]',
'==': f'{cmp}eq[0, 1]',
'!=': f'{cmp}eq[0, 1]',
'<=': f'{cmp}le[0, 1]',
'<': f'{cmp}lt[0, 1]',
'>=': f'{cmp}ge[0, 1]',
'>': f'{cmp}gt[0, 1]',
}
bits = {'double': 64,
'float': 32,
'int': 32}
result = dict()
if instruction_set in ['sve', 'sve2', 'sme']:
width = 'svcntd()' if data_type == 'double' else 'svcntw()'
intwidth = 'svcntw()'
result['bytes'] = 'svcntb()'
else:
width = bitwidth // bits[data_type]
intwidth = bitwidth // bits['int']
result['bytes'] = bitwidth // 8
if instruction_set.startswith('sve') or instruction_set == 'sme':
base_names['stream'] = 'stnt1[0, 1]'
prefix = 'sv'
suffix = f'_f{bits[data_type]}'
elif instruction_set == 'neon':
prefix = 'v'
suffix = f'q_f{bits[data_type]}'
if instruction_set in ['sve', 'sve2', 'sme']:
predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})'
else:
predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})'
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
arg_string = get_argument_string(function_shortcut, first=predicate if prefix == 'sv' else '')
if prefix == 'sv' and not name.startswith('ld') and not name.startswith('st') and not name.startswith(cmp):
undef = '_x'
else:
undef = ''
result[intrinsic_id] = prefix + name + suffix + undef + arg_string
if instruction_set in ['sve', 'sve2', 'sme']:
result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int")
else:
result['width'] = width
result['intwidth'] = intwidth
if instruction_set.startswith('sve') or instruction_set == 'sme':
result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})'
result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
if instruction_set != 'sme':
vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})'
result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
result['streamS'] = f'svstnt1_scatter_u{bits[data_type]}offset_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format(f"{{2}}*{bits[data_type]//8}") + ', {1})'
result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
result['int'] = f'svint{bits["int"]}_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
result['bool'] = f'svbool_{"s" if instruction_set not in ["sve", "sve2", "sme"] else ""}t'
result['headers'] = ['<arm_sve.h>', '<arm_acle.h>', '"arm_neon_helpers.h"']
result['&'] = f'svand_b_z({predicate},' + ' {0}, {1})'
result['|'] = f'svorr_b_z({predicate},' + ' {0}, {1})'
result['blendv'] = f'svsel_f{bits[data_type]}' + '({2}, {1}, {0})'
result['any'] = f'svptest_any({predicate}, {{0}})'
result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
result['maskStream'] = result['stream'].replace(predicate, '{2}')
if instruction_set != 'sme':
result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
result['maskStreamS'] = result['streamS'].replace(predicate, '{3}')
result['streamFence'] = '__dmb(15)'
if instruction_set == 'sme':
result['function_prefix'] = '__attribute__((arm_locally_streaming))'
elif instruction_set not in ['sve', 'sve2', 'sme']:
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
else:
result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in
range(width)]) + ')'
result['makeVecConstInt'] = f'vdupq_n_s{bits["int"]}' + '({0})'
result['makeVecInt'] = f'makeVec_s{bits["int"]}' + '({0}, {1}, {2}, {3})'
result['+int'] = f"vaddq_s{bits['int']}" + "({0}, {1})"
result[data_type] = f'float{bits[data_type]}x{width}_t'
result['int'] = f'int{bits["int"]}x{intwidth}_t'
result['bool'] = f'uint{bits[data_type]}x{width}_t'
result['headers'] = ['<arm_neon.h>', '"arm_neon_helpers.h"']
result['!='] = f'vmvnq_u{bits[data_type]}({result["=="]})'
result['&'] = f'vandq_u{bits[data_type]}' + '({0}, {1})'
result['|'] = f'vorrq_u{bits[data_type]}' + '({0}, {1})'
result['blendv'] = f'vbslq_f{bits[data_type]}' + '({2}, {1}, {0})'
result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0'
result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff'
# SVE has real nontemporal stores, so we only need to zero cachlines on Neon
result['cachelineZero'] = 'cachelineZero((void*) {0})'
result['cachelineSize'] = 'cachelineSize()'
return result
import re
from collections import namedtuple from collections import namedtuple
import hashlib
from typing import Set from typing import Set
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.printing.ccode import C89CodePrinter from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from pystencils.astnodes import KernelFunction, Node from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import ( from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
reinterpret_cast_func, vector_memory_access) from pystencils.typing import (
PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction)
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.functions import DivFunc, AddressOf
from pystencils.integer_functions import ( from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil) int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
try: try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter from sympy.printing.c import C99CodePrinter as CCodePrinter # for sympy versions > 1.6
except ImportError: except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 from sympy.printing.ccode import C99CodePrinter as CCodePrinter
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
KERNCRAFT_NO_TERNARY_MODE = False
HEADER_REGEX = re.compile(r'^[<"].*[">]$')
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str: def generate_c(ast_node: Node,
signature_only: bool = False,
dialect: Backend = Backend.C,
custom_backend=None,
with_globals=True) -> str:
"""Prints an abstract syntax tree node as C or CUDA code. """Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded This function does not need to distinguish for most AST nodes between C, C++ or CUDA code, it just prints 'C-like'
in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel code as encoded in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different
functions. create_kernel functions.
Args: Args:
ast_node: ast_node: ast representation of kernel
signature_only: signature_only: generate signature without function body
dialect: 'c' or 'cuda' dialect: `Backend`: 'C' or 'CUDA'
custom_backend: use own custom printer for code generation
with_globals: enable usage of global variables
Returns: Returns:
C-like code for the ast node and its descendants C-like code for the ast node and its descendants
""" """
...@@ -49,23 +61,25 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom ...@@ -49,23 +61,25 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom
ast_node.global_variables = d.symbols_defined ast_node.global_variables = d.symbols_defined
if custom_backend: if custom_backend:
printer = custom_backend printer = custom_backend
elif dialect == 'c': elif dialect == Backend.C:
try: try:
# TODO Vectorization Revamp: instruction_set should not be just slapped on ast
instruction_set = ast_node.instruction_set instruction_set = ast_node.instruction_set
except Exception: except Exception:
instruction_set = None instruction_set = None
printer = CBackend(signature_only=signature_only, printer = CBackend(signature_only=signature_only,
vector_instruction_set=instruction_set) vector_instruction_set=instruction_set)
elif dialect == 'cuda': elif dialect == Backend.CUDA:
from pystencils.backends.cuda_backend import CudaBackend from pystencils.backends.cuda_backend import CudaBackend
printer = CudaBackend(signature_only=signature_only) printer = CudaBackend(signature_only=signature_only)
else: else:
raise ValueError("Unknown dialect: " + str(dialect)) raise ValueError(f'Unknown {dialect=}')
code = printer(ast_node) code = printer(ast_node)
if not signature_only and isinstance(ast_node, KernelFunction): if not signature_only and isinstance(ast_node, KernelFunction):
code = "\n" + code if with_globals and global_declarations:
for declaration in global_declarations: code = "\n" + code
code = printer(declaration) + "\n" + code for declaration in global_declarations:
code = printer(declaration) + "\n" + code
return code return code
...@@ -74,8 +88,8 @@ def get_global_declarations(ast): ...@@ -74,8 +88,8 @@ def get_global_declarations(ast):
global_declarations = [] global_declarations = []
def visit_node(sub_ast): def visit_node(sub_ast):
nonlocal global_declarations
if hasattr(sub_ast, "required_global_declarations"): if hasattr(sub_ast, "required_global_declarations"):
nonlocal global_declarations
global_declarations += sub_ast.required_global_declarations global_declarations += sub_ast.required_global_declarations
if hasattr(sub_ast, "args"): if hasattr(sub_ast, "args"):
...@@ -84,7 +98,7 @@ def get_global_declarations(ast): ...@@ -84,7 +98,7 @@ def get_global_declarations(ast):
visit_node(ast) visit_node(ast)
return set(global_declarations) return sorted(set(global_declarations), key=str)
def get_headers(ast_node: Node) -> Set[str]: def get_headers(ast_node: Node) -> Set[str]:
...@@ -97,15 +111,22 @@ def get_headers(ast_node: Node) -> Set[str]: ...@@ -97,15 +111,22 @@ def get_headers(ast_node: Node) -> Set[str]:
if hasattr(ast_node, 'headers'): if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers) headers.update(ast_node.headers)
for a in ast_node.args: for a in ast_node.args:
if isinstance(a, Node): if isinstance(a, (sp.Expr, Node)):
headers.update(get_headers(a)) headers.update(get_headers(a))
for g in get_global_declarations(ast_node):
if isinstance(g, Node):
headers.update(get_headers(g))
for h in headers:
assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/'
return headers return headers
# --------------------------------------- Backend Specific Nodes ------------------------------------------------------- # --------------------------------------- Backend Specific Nodes -------------------------------------------------------
# TODO future CustomCodeNode should not be backend specific move it elsewhere
class CustomCodeNode(Node): class CustomCodeNode(Node):
def __init__(self, code, symbols_read, symbols_defined, parent=None): def __init__(self, code, symbols_read, symbols_defined, parent=None):
super(CustomCodeNode, self).__init__(parent=parent) super(CustomCodeNode, self).__init__(parent=parent)
...@@ -114,7 +135,7 @@ class CustomCodeNode(Node): ...@@ -114,7 +135,7 @@ class CustomCodeNode(Node):
self._symbols_defined = set(symbols_defined) self._symbols_defined = set(symbols_defined)
self.headers = [] self.headers = []
def get_code(self, dialect, vector_instruction_set): def get_code(self, dialect, vector_instruction_set, print_arg):
return self._code return self._code
@property @property
...@@ -129,11 +150,17 @@ class CustomCodeNode(Node): ...@@ -129,11 +150,17 @@ class CustomCodeNode(Node):
def undefined_symbols(self): def undefined_symbols(self):
return self._symbols_read - self._symbols_defined return self._symbols_read - self._symbols_defined
def __eq__(self, other):
return type(self) is type(other) and self._code == other._code
def __hash__(self):
return hash(self._code)
class PrintNode(CustomCodeNode): class PrintNode(CustomCodeNode):
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
def __init__(self, symbol_to_print): def __init__(self, symbol_to_print):
code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbol_to_print.name, symbol_to_print.name) code = f'\nstd::cout << "{symbol_to_print.name} = " << {symbol_to_print.name} << std::endl; \n'
super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set()) super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set())
self.headers.append("<iostream>") self.headers.append("<iostream>")
...@@ -144,7 +171,7 @@ class PrintNode(CustomCodeNode): ...@@ -144,7 +171,7 @@ class PrintNode(CustomCodeNode):
# noinspection PyPep8Naming # noinspection PyPep8Naming
class CBackend: class CBackend:
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect=Backend.C):
if sympy_printer is None: if sympy_printer is None:
if vector_instruction_set is not None: if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
...@@ -157,6 +184,8 @@ class CBackend: ...@@ -157,6 +184,8 @@ class CBackend:
self._indent = " " self._indent = " "
self._dialect = dialect self._dialect = dialect
self._signatureOnly = signature_only self._signatureOnly = signature_only
self._kwargs = {}
self.sympy_printer._kwargs = self._kwargs
def __call__(self, node): def __call__(self, node):
prev_is = VectorType.instruction_set prev_is = VectorType.instruction_set
...@@ -166,19 +195,25 @@ class CBackend: ...@@ -166,19 +195,25 @@ class CBackend:
return result return result
def _print(self, node): def _print(self, node):
if isinstance(node, str):
return node
for cls in type(node).__mro__: for cls in type(node).__mro__:
method_name = "_print_" + cls.__name__ method_name = f"_print_{cls.__name__}"
if hasattr(self, method_name): if hasattr(self, method_name):
return getattr(self, method_name)(node) return getattr(self, method_name)(node)
raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__) raise NotImplementedError(f"{self.__class__.__name__} does not support node of type {node.__class__.__name__}")
def _print_AbstractType(self, node):
return str(node)
def _print_KernelFunction(self, node): def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()] function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()
if not type(s.symbol) is CFunction]
launch_bounds = "" launch_bounds = ""
if self._dialect == 'cuda': if self._dialect == Backend.CUDA:
max_threads = node.indexing.max_threads_per_block() max_threads = node.indexing.max_threads_per_block()
if max_threads: if max_threads:
launch_bounds = "__launch_bounds__({}) ".format(max_threads) launch_bounds = f"__launch_bounds__({max_threads}) "
func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name, func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
", ".join(function_arguments)) ", ".join(function_arguments))
if self._signatureOnly: if self._signatureOnly:
...@@ -192,50 +227,166 @@ class CBackend: ...@@ -192,50 +227,166 @@ class CBackend:
return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True))) return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))
def _print_PragmaBlock(self, node): def _print_PragmaBlock(self, node):
return "%s\n%s" % (node.pragma_line, self._print_Block(node)) return f"{node.pragma_line}\n{self._print_Block(node)}"
def _print_LoopOverCoordinate(self, node): def _print_LoopOverCoordinate(self, node):
counter_symbol = node.loop_counter_name counter_name = node.loop_counter_name
start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start)) counter_dtype = node.loop_counter_symbol.dtype.c_name
condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop)) start = f"{counter_dtype} {counter_name} = {self.sympy_printer.doprint(node.start)}"
update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),) condition = f"{counter_name} < {self.sympy_printer.doprint(node.stop)}"
loop_str = "for (%s; %s; %s)" % (start, condition, update) update = f"{counter_name} += {self.sympy_printer.doprint(node.step)}"
loop_str = f"for ({start}; {condition}; {update})"
self._kwargs['loop_counter'] = counter_name
self._kwargs['loop_stop'] = node.stop
prefix = "\n".join(node.prefix_lines) prefix = "\n".join(node.prefix_lines)
if prefix: if prefix:
prefix += "\n" prefix += "\n"
return "%s%s\n%s" % (prefix, loop_str, self._print(node.body)) return f"{prefix}{loop_str}\n{self._print(node.body)}"
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
printed_lhs = self.sympy_printer.doprint(node.lhs)
printed_rhs = self.sympy_printer.doprint(node.rhs)
if node.is_declaration: if node.is_declaration:
data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " " if node.use_auto:
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), data_type = 'auto'
self.sympy_printer.doprint(node.rhs)) else:
data_type = self._print(node.lhs.dtype).replace(' const', '')
if node.is_const:
data_type = f'const {data_type}'
return f"{data_type} {printed_lhs} = {printed_rhs};"
else: else:
lhs_type = get_type_of_expression(node.lhs) lhs_type = get_type_of_expression(node.lhs) # TOOD: this should have been typed
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): printed_mask = ""
arg, data_type, aligned, nontemporal = node.lhs.args if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc):
arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
instr = 'storeU' instr = 'storeU'
if aligned: if nontemporal and 'storeA' not in self._vector_instruction_set and \
instr = 'stream' if nontemporal else 'storeA' 'stream' in self._vector_instruction_set:
instr = 'stream'
elif aligned:
instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
if mask != True: # NOQA
instr = 'maskStream' if nontemporal and 'maskStream' in self._vector_instruction_set else \
'maskStoreA' if aligned else 'maskStoreU'
if instr not in self._vector_instruction_set:
if instr == 'maskStream' and 'stream' in self._vector_instruction_set:
store, load = 'stream', 'loadA'
elif (instr in ('maskStream', 'maskStoreA')) and 'storeA' in self._vector_instruction_set:
store, load = 'storeA', 'loadA'
else:
store, load = 'storeU', 'loadU'
load = load if load in self._vector_instruction_set else 'loadU'
self._vector_instruction_set[instr] = self._vector_instruction_set[store].format(
'{0}', self._vector_instruction_set['blendv'].format(
self._vector_instruction_set[load].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs), **self._kwargs)
printed_mask = self.sympy_printer.doprint(mask)
if data_type.base_type.c_name == 'double':
if self._vector_instruction_set['double'] == '__m256d':
printed_mask = f"_mm256_castpd_si256({printed_mask})"
elif self._vector_instruction_set['double'] == '__m128d':
printed_mask = f"_mm_castpd_si128({printed_mask})"
elif data_type.base_type.c_name == 'float':
if self._vector_instruction_set['float'] == '__m256':
printed_mask = f"_mm256_castps_si256({printed_mask})"
elif self._vector_instruction_set['float'] == '__m128':
printed_mask = f"_mm_castps_si128({printed_mask})"
rhs_type = get_type_of_expression(node.rhs) rhs_type = get_type_of_expression(node.rhs)
if type(rhs_type) is not VectorType: if type(rhs_type) is not VectorType:
rhs = cast_func(node.rhs, VectorType(rhs_type)) raise ValueError(f'Cannot vectorize {node.rhs} of type {rhs_type} inside of the pretty printer! '
f'This should have happen earlier!')
# rhs = CastFunc(node.rhs, VectorType(rhs_type)) # Unknown width
else: else:
rhs = node.rhs rhs = node.rhs
return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]), ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
self.sympy_printer.doprint(rhs)) + ';'
if stride != 1:
instr = ('maskStreamS' if nontemporal and 'maskStreamS' in self._vector_instruction_set else
'maskStoreS') if mask != True else \
('streamS' if nontemporal and 'streamS' in self._vector_instruction_set else 'storeS') # NOQA
return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
stride, printed_mask, **self._kwargs) + ';'
pre_code = ''
if nontemporal and 'cachelineZero' in self._vector_instruction_set and mask == True: # NOQA
first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0"
offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i))
* node.lhs.args[0].field.spatial_strides[i] for i in
range(len(node.lhs.args[0].field.spatial_strides))])
if stride == 1:
offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
element_size = 8 if data_type.base_type.c_name == 'double' else 4
size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
printed_mask, **self._kwargs) + ';'
flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
if nontemporal and 'flushCacheline' in self._vector_instruction_set:
code2 = self._vector_instruction_set['flushCacheline'].format(
ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
elif aligned and nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
lhs_hash = hashlib.sha1(self.sympy_printer.doprint(node.lhs).encode('ascii')).hexdigest()[:8]
rhs_hash = hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
tmpvar = f'_tmp_{lhs_hash}_{rhs_hash}'
code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
+ self.sympy_printer.doprint(rhs) + ';'
code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
maskStore, store, load = 'maskStoreAAndFlushCacheline', 'storeAAndFlushCacheline', 'loadA'
instr2 = maskStore if mask != True else store # NOQA
if instr2 not in self._vector_instruction_set:
self._vector_instruction_set[maskStore] = self._vector_instruction_set[store].format(
'{0}', self._vector_instruction_set['blendv'].format(
self._vector_instruction_set[load].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs),
**self._kwargs)
code2 = self._vector_instruction_set[instr2].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
return pre_code + code
else: else:
return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) return f"{printed_lhs} = {printed_rhs};"
def _print_NontemporalFence(self, _):
if 'streamFence' in self._vector_instruction_set:
return self._vector_instruction_set['streamFence'] + ';'
else:
return ''
def _print_CachelineSize(self, node):
if 'cachelineSize' in self._vector_instruction_set:
code = f'const size_t {node.symbol} = {self._vector_instruction_set["cachelineSize"]};\n'
code += f'const size_t {node.mask_symbol} = {node.symbol} - 1;\n'
vectorsize = self._vector_instruction_set['bytes']
code += f'const size_t {node.last_symbol} = {node.symbol} - {vectorsize};\n'
return code
else:
return ''
def _print_TemporaryMemoryAllocation(self, node): def _print_TemporaryMemoryAllocation(self, node):
align = 64 if self._vector_instruction_set:
align = self._vector_instruction_set['bytes']
else:
align = node.symbol.dtype.base_type.numpy_dtype.itemsize
np_dtype = node.symbol.dtype.base_type.numpy_dtype np_dtype = node.symbol.dtype.base_type.numpy_dtype
required_size = np_dtype.itemsize * node.size + align required_size = np_dtype.itemsize * node.size + align
size = modulo_ceil(required_size, align) size = modulo_ceil(required_size, align)
code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};" code = "#if defined(_MSC_VER)\n"
code += "{dtype} {name}=({dtype})_aligned_malloc({size}, {align}) + {offset};\n"
code += "#elif __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n"
code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n"
code += "#else\n"
code += "{dtype} {name};\n"
code += "posix_memalign((void**) &{name}, {align}, {size});\n"
code += "{name} += {offset};\n"
code += "#endif"
return code.format(dtype=node.symbol.dtype, return code.format(dtype=node.symbol.dtype,
name=self.sympy_printer.doprint(node.symbol.name), name=self.sympy_printer.doprint(node.symbol.name),
size=self.sympy_printer.doprint(size), size=self.sympy_printer.doprint(size),
...@@ -243,45 +394,46 @@ class CBackend: ...@@ -243,45 +394,46 @@ class CBackend:
align=align) align=align)
def _print_TemporaryMemoryFree(self, node): def _print_TemporaryMemoryFree(self, node):
align = 64 if self._vector_instruction_set:
return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) align = self._vector_instruction_set['bytes']
else:
align = node.symbol.dtype.base_type.numpy_dtype.itemsize
code = "#if defined(_MSC_VER)\n"
code += "_aligned_free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
code += "#else\n"
code += "free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
code += "#endif"
return code
def _print_SkipIteration(self, _): def _print_SkipIteration(self, _):
return "continue;" return "continue;"
def _print_CustomCodeNode(self, node): def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set) return node.get_code(self._dialect, self._vector_instruction_set, print_arg=self.sympy_printer._print)
def _print_SourceCodeComment(self, node):
return f"/* {node.text } */"
def _print_EmptyLine(self, node):
return ""
def _print_Conditional(self, node): def _print_Conditional(self, node):
if type(node.condition_expr) is BooleanTrue:
return self._print_Block(node.true_block)
elif type(node.condition_expr) is BooleanFalse:
return self._print_Block(node.false_block)
cond_type = get_type_of_expression(node.condition_expr) cond_type = get_type_of_expression(node.condition_expr)
if isinstance(cond_type, VectorType): if isinstance(cond_type, VectorType):
raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all") raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
condition_expr = self.sympy_printer.doprint(node.condition_expr) condition_expr = self.sympy_printer.doprint(node.condition_expr)
true_block = self._print_Block(node.true_block) true_block = self._print_Block(node.true_block)
result = "if (%s)\n%s " % (condition_expr, true_block) result = f"if ({condition_expr})\n{true_block} "
if node.false_block: if node.false_block:
false_block = self._print_Block(node.false_block) false_block = self._print_Block(node.false_block)
result += "else " + false_block result += f"else {false_block}"
return result return result
def _print_DestructuringBindingsForFieldClass(self, node):
# Define all undefined symbols
undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s %s = %s.%s;" %
(u.dtype,
u.name,
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
node.CLASS_TO_MEMBER_DICT[u.__class__] %
(() if type(u) == FieldPointerSymbol else (u.coordinate,)))
for u in undefined_field_symbols
]
destructuring_bindings.sort() # only for code aesthetics
return "{\n" + self._indent + \
("\n" + self._indent).join(destructuring_bindings) + \
"\n" + self._indent + \
("\n" + self._indent).join(self._print(node.body).splitlines()) + \
"\n}"
# ------------------------------------------ Helper function & classes ------------------------------------------------- # ------------------------------------------ Helper function & classes -------------------------------------------------
...@@ -291,27 +443,31 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -291,27 +443,31 @@ class CustomSympyPrinter(CCodePrinter):
def __init__(self): def __init__(self):
super(CustomSympyPrinter, self).__init__() super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
if 'Min' in self.known_functions:
del self.known_functions['Min']
if 'Max' in self.known_functions:
del self.known_functions['Max']
def _print_Pow(self, expr): def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication""" """Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols: # Ideally the printer has as little logic as possible. Therefore,
return self._typed_number(expr.evalf(), get_type_of_expression(expr)) # powers should be rewritten as `DivFunc`s / unevaluated `Mul`s before
# printing. `NodeCollection` offers a convenience function to do just
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: # that. However, `cut_loops` rewrites unevaluated multiplications as
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" # `Pow`s again. Neither `deepcopy` nor `func(*args)` are suited to
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: # rebuild unevaluated expressions. Therefore, as long as we stick with
return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) # SymPy, this is the only way to avoid printing `pow`s.
exp = expr.exp.expr if isinstance(expr.exp, CastFunc) else expr.exp
one_type = expr.base.dtype if hasattr(expr.base, "dtype") else get_type_of_expression(expr.base)
if exp.is_integer and exp.is_number and (0 < exp <= 8):
return f"({self._print(sp.Mul(*[expr.base] * exp, evaluate=False))})"
elif exp.is_integer and exp.is_number and (-8 <= exp < 0):
return f"{self._typed_number(1, one_type)} / ({self._print(sp.Mul(*([expr.base] * -exp), evaluate=False))})"
else: else:
return super(CustomSympyPrinter, self)._print_Pow(expr) return super(CustomSympyPrinter, self)._print_Pow(expr)
# TODO don't print ones in sp.Mul
def _print_Rational(self, expr): def _print_Rational(self, expr):
"""Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0""" """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
res = str(expr.evalf().num) res = str(expr.evalf(17))
return res return res
def _print_Equality(self, expr): def _print_Equality(self, expr):
...@@ -323,6 +479,15 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -323,6 +479,15 @@ class CustomSympyPrinter(CCodePrinter):
result = super(CustomSympyPrinter, self)._print_Piecewise(expr) result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
return result.replace("\n", "") return result.replace("\n", "")
def _print_Abs(self, expr):
if expr.args[0].is_integer:
return f'abs({self._print(expr.args[0])})'
else:
return f'fabs({self._print(expr.args[0])})'
def _print_AbstractType(self, node):
return str(node)
def _print_Function(self, expr): def _print_Function(self, expr):
infix_functions = { infix_functions = {
bitwise_xor: '^', bitwise_xor: '^',
...@@ -333,34 +498,65 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -333,34 +498,65 @@ class CustomSympyPrinter(CCodePrinter):
} }
if hasattr(expr, 'to_c'): if hasattr(expr, 'to_c'):
return expr.to_c(self._print) return expr.to_c(self._print)
if isinstance(expr, reinterpret_cast_func): if isinstance(expr, ReinterpretCastFunc):
arg, data_type = expr.args arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) if isinstance(data_type, PointerType):
elif isinstance(expr, address_of): const_str = "const" if data_type.const else ""
return f"(({const_str} {self._print(data_type.base_type)} *)(& {self._print(arg)}))"
else:
return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
elif isinstance(expr, AddressOf):
assert len(expr.args) == 1, "address_of must only have one argument" assert len(expr.args) == 1, "address_of must only have one argument"
return "&(%s)" % self._print(expr.args[0]) return f"&({self._print(expr.args[0])})"
elif isinstance(expr, cast_func): elif isinstance(expr, CastFunc):
cast = "(({data_type})({code}))"
arg, data_type = expr.args arg, data_type = expr.args
if isinstance(arg, sp.Number): if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
return self._typed_number(arg, data_type) return self._typed_number(arg, data_type)
elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
and data_type == BasicType('float32'):
known = self.known_functions[arg.__class__.__name__.lower()]
code = self._print(arg)
return code.replace(known, f"{known}f")
elif isinstance(arg, (sp.Pow, sp.exp)) and data_type == BasicType('float32'):
known = ['sqrt', 'cbrt', 'pow', 'exp']
code = self._print(arg)
for k in known:
if k in code:
return code.replace(k, f'{k}f')
# Powers of small integers are printed as divisions/multiplications.
if '/' in code or '*' in code:
return cast.format(data_type=data_type, code=code)
raise ValueError(f"{code} doesn't give {known=} function back.")
else: else:
return "((%s)(%s))" % (data_type, self._print(arg)) return cast.format(data_type=data_type, code=self._print(arg))
elif isinstance(expr, fast_division): elif isinstance(expr, fast_division):
return "({})".format(self._print(expr.args[0] / expr.args[1])) raise ValueError("fast_division is only supported for Taget.GPU")
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return "({})".format(self._print(sp.sqrt(expr.args[0]))) raise ValueError("fast_sqrt is only supported for Taget.GPU")
elif isinstance(expr, fast_inv_sqrt):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, vec_any) or isinstance(expr, vec_all): elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0]) return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, sp.Abs):
return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return f"abs({self._print(expr.args[0])})"
elif isinstance(expr, sp.Mod):
if expr.args[0].is_integer and expr.args[1].is_integer:
return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})"
else:
return f"fmod({self._print(expr.args[0])}, {self._print(expr.args[1])})"
elif expr.func in infix_functions: elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1])) return f"({self._print(expr.args[0])} {infix_functions[expr.func]} {self._print(expr.args[1])})"
elif expr.func == int_power_of_2: elif expr.func == int_power_of_2:
return "(1 << (%s))" % (self._print(expr.args[0])) return f"(1 << ({self._print(expr.args[0])}))"
elif expr.func == int_div: elif expr.func == int_div:
return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1])) return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"
elif expr.func == DivFunc:
return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))'
else: else:
return super(CustomSympyPrinter, self)._print_Function(expr) name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
arg_str = ', '.join(self._print(a) for a in expr.args)
return f'{name}({arg_str})'
def _typed_number(self, number, dtype): def _typed_number(self, number, dtype):
res = self._print(number) res = self._print(number)
...@@ -368,11 +564,39 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -368,11 +564,39 @@ class CustomSympyPrinter(CCodePrinter):
return res + '.0f' if '.' not in res else res + 'f' return res + '.0f' if '.' not in res else res + 'f'
elif dtype.numpy_dtype == np.float64: elif dtype.numpy_dtype == np.float64:
return res + '.0' if '.' not in res else res return res + '.0' if '.' not in res else res
elif dtype.is_int():
tokens = res.split('.')
if len(tokens) == 1:
return res
elif int(tokens[1]) != 0:
raise ValueError(f"Cannot print non-integer number {res} as an integer.")
else:
return tokens[0]
else: else:
return res return res
_print_Max = C89CodePrinter._print_Max def _print_ConditionalFieldAccess(self, node):
_print_Min = C89CodePrinter._print_Min return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))
def _print_Max(self, expr):
def inner_print_max(args):
if len(args) == 1:
return self._print(args[0])
half = len(args) // 2
a = inner_print_max(args[:half])
b = inner_print_max(args[half:])
return f"(({a} > {b}) ? {a} : {b})"
return inner_print_max(expr.args)
def _print_Min(self, expr):
def inner_print_min(args):
if len(args) == 1:
return self._print(args[0])
half = len(args) // 2
a = inner_print_min(args[:half])
b = inner_print_min(args[half:])
return f"(({a} < {b}) ? {a} : {b})"
return inner_print_min(expr.args)
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -391,41 +615,113 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -391,41 +615,113 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert self.instruction_set['width'] == expr_type.width assert self.instruction_set['width'] == expr_type.width
return None return None
def _print_Abs(self, expr):
if isinstance(get_type_of_expression(expr), (VectorType, VectorMemoryAccess)):
return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
return super()._print_Abs(expr)
def _typed_vectorized_number(self, expr, data_type):
basic_data_type = data_type.base_type
number = self._typed_number(expr, basic_data_type)
instruction = 'makeVecConst'
if basic_data_type.is_bool():
instruction = 'makeVecConstBool'
# TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint)
elif basic_data_type.is_int():
instruction = 'makeVecConstInt'
return self.instruction_set[instruction].format(number, **self._kwargs)
def _typed_vectorized_symbol(self, expr, data_type):
if not isinstance(expr, TypedSymbol):
raise ValueError(f'{expr} is not a TypeSymbol. It is {expr.type=}')
basic_data_type = data_type.base_type
symbol = self._print(expr)
if basic_data_type != expr.dtype:
symbol = f'(({basic_data_type})({symbol}))'
instruction = 'makeVecConst'
if basic_data_type.is_bool():
instruction = 'makeVecConstBool'
# TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint)
elif basic_data_type.is_int():
instruction = 'makeVecConstInt'
return self.instruction_set[instruction].format(symbol, **self._kwargs)
def _print_CastFunc(self, expr):
arg, data_type = expr.args
if type(data_type) is VectorType:
base_type = data_type.base_type
# vector_memory_access is a cast_func itself so it should't be directly inside a cast_func
assert not isinstance(arg, VectorMemoryAccess)
if isinstance(arg, sp.Tuple):
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
is_integer = get_type_of_expression(arg[0]) == create_type("int")
printed_args = [self._print(a) for a in arg]
instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec'
if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
increments = np.array(arg)[1:] - np.array(arg)[:-1]
if len(set(increments)) == 1:
return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
**self._kwargs)
return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
else:
if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
return self._typed_vectorized_number(arg, data_type)
elif isinstance(arg, TypedSymbol):
return self._typed_vectorized_symbol(arg, data_type)
elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
and base_type == BasicType('float32'):
raise NotImplementedError('Vectorizer is not tested for trigonometric functions yet')
# known = self.known_functions[arg.__class__.__name__.lower()]
# code = self._print(arg)
# return code.replace(known, f"{known}f")
elif isinstance(arg, sp.Pow):
if base_type == BasicType('float32') or base_type == BasicType('float64'):
return self._print_Pow(arg)
else:
raise NotImplementedError('Integer Pow is not implemented')
elif isinstance(arg, sp.UnevaluatedExpr):
return self._print(arg.args[0])
else:
raise NotImplementedError('Vectorizer cannot cast between different datatypes')
# to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
# from_type = self.instruction_set['suffix'][get_type_of_expression(arg).base_type.c_name]
# return self.instruction_set['cast'].format(from_type, to_type, self._print(arg))
else:
return self._scalarFallback('_print_Function', expr)
# raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.')
def _print_Function(self, expr): def _print_Function(self, expr):
if isinstance(expr, vector_memory_access): if isinstance(expr, VectorMemoryAccess):
arg, data_type, aligned, _ = expr.args arg, data_type, aligned, _, mask, stride = expr.args
if stride != 1:
return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
return instruction.format("& " + self._print(arg)) return instruction.format(f"& {self._print(arg)}", **self._kwargs)
elif isinstance(expr, cast_func): elif expr.func == DivFunc:
arg, data_type = expr.args
if type(data_type) is VectorType:
return self.instruction_set['makeVec'].format(self._print(arg))
elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr) result = self._scalarFallback('_print_Function', expr)
if not result: if not result:
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend),
**self._kwargs)
return result return result
elif expr.func == fast_sqrt: elif isinstance(expr, fast_division):
return "({})".format(self._print(sp.sqrt(expr.args[0]))) raise ValueError("fast_division is only supported for Taget.GPU")
elif expr.func == fast_inv_sqrt: elif isinstance(expr, fast_sqrt):
result = self._scalarFallback('_print_Function', expr) raise ValueError("fast_sqrt is only supported for Taget.GPU")
if not result: elif isinstance(expr, fast_inv_sqrt):
if self.instruction_set['rsqrt']: raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
else: instr = 'any' if isinstance(expr, vec_any) else 'all'
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
elif isinstance(expr, vec_any):
expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType:
return self._print(expr.args[0])
else:
return self.instruction_set['any'].format(self._print(expr.args[0]))
elif isinstance(expr, vec_all):
expr_type = get_type_of_expression(expr.args[0]) expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType: if type(expr_type) is not VectorType:
return self._print(expr.args[0]) return self._print(expr.args[0])
else: else:
return self.instruction_set['all'].format(self._print(expr.args[0])) if isinstance(expr.args[0], sp.Rel):
op = expr.args[0].rel_op
if (instr, op) in self.instruction_set:
return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args],
**self._kwargs)
return self.instruction_set[instr].format(self._print(expr.args[0]), **self._kwargs)
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
...@@ -438,7 +734,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -438,7 +734,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(arg_strings) > 0 assert len(arg_strings) > 0
result = arg_strings[0] result = arg_strings[0]
for item in arg_strings[1:]: for item in arg_strings[1:]:
result = self.instruction_set['&'].format(result, item) result = self.instruction_set['&'].format(result, item, **self._kwargs)
return result return result
def _print_Or(self, expr): def _print_Or(self, expr):
...@@ -450,16 +746,31 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -450,16 +746,31 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(arg_strings) > 0 assert len(arg_strings) > 0
result = arg_strings[0] result = arg_strings[0]
for item in arg_strings[1:]: for item in arg_strings[1:]:
result = self.instruction_set['|'].format(result, item) result = self.instruction_set['|'].format(result, item, **self._kwargs)
return result return result
def _print_Add(self, expr, order=None): def _print_Add(self, expr, order=None):
result = self._scalarFallback('_print_Add', expr) try:
result = self._scalarFallback('_print_Add', expr)
except Exception:
result = None
if result: if result:
return result return result
args = expr.args
# special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
suffix = ""
if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
dtype = set([e.dtype for e in args if type(e) is CastFunc])
assert len(dtype) == 1
dtype = dtype.pop()
args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
for e in args]
suffix = "int"
summands = [] summands = []
for term in expr.args: for term in args:
if term.func == sp.Mul: if term.func == sp.Mul:
sign, t = self._print_Mul(term, inside_add=True) sign, t = self._print_Mul(term, inside_add=True)
else: else:
...@@ -475,30 +786,36 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -475,30 +786,36 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(summands) >= 2 assert len(summands) >= 2
processed = summands[0].term processed = summands[0].term
for summand in summands[1:]: for summand in summands[1:]:
func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+'] func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix]
processed = func.format(processed, summand.term) processed = func.format(processed, summand.term, **self._kwargs)
return processed return processed
def _print_Pow(self, expr): def _print_Pow(self, expr):
result = self._scalarFallback('_print_Pow', expr) # Due to loop cutting sp.Mul is evaluated again.
try:
result = self._scalarFallback('_print_Pow', expr)
except ValueError:
result = None
if result: if result:
return result return result
one = self.instruction_set['makeVec'].format(1.0) one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number:
elif expr.exp == -1: exp = expr.exp.args[0]
one = self.instruction_set['makeVec'].format(1.0) else:
return self.instruction_set['/'].format(one, self._print(expr.base)) exp = expr.exp
elif expr.exp == 0.5:
return self.instruction_set['sqrt'].format(self._print(expr.base)) # TODO the printer should not have any intelligence like this.
elif expr.exp == -0.5: # TODO To remove all of these cases the vectoriser needs to be reworked. See loop cutting
root = self.instruction_set['sqrt'].format(self._print(expr.base)) if exp.is_integer and exp.is_number and 0 < exp < 8:
return self.instruction_set['/'].format(one, root) return self._print(sp.Mul(*[expr.base] * exp, evaluate=False))
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: elif exp == 0.5:
return self.instruction_set['/'].format(one, return root
self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) elif exp == -0.5:
return self.instruction_set['/'].format(one, root, **self._kwargs)
else: else:
raise ValueError("Generic exponential not supported: " + str(expr)) raise ValueError("Generic exponential not supported: " + str(expr))
...@@ -506,7 +823,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -506,7 +823,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# noinspection PyProtectedMember # noinspection PyProtectedMember
from sympy.core.mul import _keep_coeff from sympy.core.mul import _keep_coeff
result = self._scalarFallback('_print_Mul', expr) if not inside_add:
result = self._scalarFallback('_print_Mul', expr)
else:
result = None
if result: if result:
return result return result
...@@ -537,19 +857,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -537,19 +857,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = a_str[0] result = a_str[0]
for item in a_str[1:]: for item in a_str[1:]:
result = self.instruction_set['*'].format(result, item) result = self.instruction_set['*'].format(result, item, **self._kwargs)
if len(b) > 0: if len(b) > 0:
denominator_str = b_str[0] denominator_str = b_str[0]
for item in b_str[1:]: for item in b_str[1:]:
denominator_str = self.instruction_set['*'].format(denominator_str, item) denominator_str = self.instruction_set['*'].format(denominator_str, item, **self._kwargs)
result = self.instruction_set['/'].format(result, denominator_str) result = self.instruction_set['/'].format(result, denominator_str, **self._kwargs)
if inside_add: if inside_add:
return sign, result return sign, result
else: else:
if sign < 0: if sign < 0:
return self.instruction_set['*'].format(self._print(S.NegativeOne), result) return self.instruction_set['*'].format(self._print(S.NegativeOne), result, **self._kwargs)
else: else:
return result return result
...@@ -557,13 +877,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -557,13 +877,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._scalarFallback('_print_Relational', expr) result = self._scalarFallback('_print_Relational', expr)
if result: if result:
return result return result
return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs)) return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs), **self._kwargs)
def _print_Equality(self, expr): def _print_Equality(self, expr):
result = self._scalarFallback('_print_Equality', expr) result = self._scalarFallback('_print_Equality', expr)
if result: if result:
return result return result
return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs)) return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs), **self._kwargs)
def _print_Piecewise(self, expr): def _print_Piecewise(self, expr):
result = self._scalarFallback('_print_Piecewise', expr) result = self._scalarFallback('_print_Piecewise', expr)
...@@ -581,13 +901,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -581,13 +901,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._print(expr.args[-1][0]) result = self._print(expr.args[-1][0])
for true_expr, condition in reversed(expr.args[:-1]): for true_expr, condition in reversed(expr.args[:-1]):
if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"): if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"):
if not KERNCRAFT_NO_TERNARY_MODE: result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), result, **self._kwargs)
result)
else:
print("Warning - skipping ternary op")
else: else:
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition)) result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition),
**self._kwargs)
return result return result
from os.path import dirname, join
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines()
CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l}
def generate_cuda(astnode: Node, signature_only: bool = False) -> str: def generate_cuda(ast_node: Node, signature_only: bool = False, custom_backend=None, with_globals=True) -> str:
"""Prints an abstract syntax tree node as CUDA code. """Prints an abstract syntax tree node as CUDA code.
Args: Args:
astnode: KernelFunction node to generate code for ast_node: ast representation of kernel
signature_only: if True only the signature is printed signature_only: generate signature without function body
custom_backend: use own custom printer for code generation
with_globals: enable usage of global variables
Returns: Returns:
C-like code for the ast node and its descendants CUDA code for the ast node and its descendants
""" """
return generate_c(astnode, signature_only, dialect='cuda') return generate_c(ast_node, signature_only, dialect=Backend.CUDA,
custom_backend=custom_backend, with_globals=with_globals)
class CudaBackend(CBackend): class CudaBackend(CBackend):
...@@ -29,26 +27,23 @@ class CudaBackend(CBackend): ...@@ -29,26 +27,23 @@ class CudaBackend(CBackend):
if not sympy_printer: if not sympy_printer:
sympy_printer = CudaSympyPrinter() sympy_printer = CudaSympyPrinter()
super().__init__(sympy_printer, signature_only, dialect='cuda') super().__init__(sympy_printer, signature_only, dialect=Backend.CUDA)
def _print_SharedMemoryAllocation(self, node): def _print_SharedMemoryAllocation(self, node):
code = "__shared__ {dtype} {name}[{num_elements}];" dtype = node.symbol.dtype
return code.format(dtype=node.symbol.dtype, name = self.sympy_printer.doprint(node.symbol.name)
name=self.sympy_printer.doprint(node.symbol.name), num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
num_elements='*'.join([str(s) for s in node.shared_mem.shape])) code = f"__shared__ {dtype} {name}[{num_elements}];"
return code
@staticmethod @staticmethod
def _print_ThreadBlockSynchronization(node): def _print_ThreadBlockSynchronization(_):
code = "__synchtreads();" return "__synchtreads();"
return code
def _print_TextureDeclaration(self, node): def _print_TextureDeclaration(self, node):
code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % ( cond = node.texture.field.dtype.numpy_dtype.itemsize > 4
str(node.texture.field.dtype), return f'texture<{"fp_tex_" if cond else ""}{str(node.texture.field.dtype)}, ' \
node.texture.field.spatial_dimensions, f'cudaTextureType{node.texture.field.spacial_dimensions}D, cudaReadModeElementType> {node.texture};'
node.texture
)
return code
def _print_SkipIteration(self, _): def _print_SkipIteration(self, _):
return "return;" return "return;"
...@@ -59,28 +54,15 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -59,28 +54,15 @@ class CudaSympyPrinter(CustomSympyPrinter):
def __init__(self): def __init__(self):
super(CudaSympyPrinter, self).__init__() super(CudaSympyPrinter, self).__init__()
self.known_functions.update(CUDA_KNOWN_FUNCTIONS)
def _print_TextureAccess(self, node):
if node.texture.cubic_bspline_interpolation:
template = "cubicTex%iDSimple<%s>(%s, %s)"
else:
template = "tex%iD<%s>(%s, %s)"
code = template % (
node.texture.field.spatial_dimensions,
str(node.texture.field.dtype),
str(node.texture),
', '.join(self._print(o) for o in node.offsets)
)
return code
def _print_Function(self, expr): def _print_Function(self, expr):
if isinstance(expr, fast_division): if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} where given"
return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__fsqrt_rn({self._print(expr.args[0])})"
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__frsqrt_rn({self._print(expr.args[0])})"
return super()._print_Function(expr) return super()._print_Function(expr)
import graphviz import graphviz
from graphviz import Digraph, lang try:
from graphviz import Digraph
import graphviz.quoting as quote
except ImportError:
from graphviz import Digraph
import graphviz.lang as quote
from sympy.printing.printer import Printer from sympy.printing.printer import Printer
...@@ -12,7 +17,7 @@ class DotPrinter(Printer): ...@@ -12,7 +17,7 @@ class DotPrinter(Printer):
super(DotPrinter, self).__init__() super(DotPrinter, self).__init__()
self._node_to_str_function = node_to_str_function self._node_to_str_function = node_to_str_function
self.dot = Digraph(**kwargs) self.dot = Digraph(**kwargs)
self.dot.quote_edge = lang.quote self.dot.quote_edge = quote.quote
def _print_KernelFunction(self, func): def _print_KernelFunction(self, func):
self.dot.node(str(id(func)), style='filled', fillcolor='#a056db', label=self._node_to_str_function(func)) self.dot.node(str(id(func)), style='filled', fillcolor='#a056db', label=self._node_to_str_function(func))
...@@ -50,22 +55,20 @@ class DotPrinter(Printer): ...@@ -50,22 +55,20 @@ class DotPrinter(Printer):
def __shortened(node): def __shortened(node):
from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment, Block, Conditional from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment, Conditional
if isinstance(node, LoopOverCoordinate): if isinstance(node, LoopOverCoordinate):
return "Loop over dim %d" % (node.coordinate_to_loop_over,) return "Loop over dim %d" % (node.coordinate_to_loop_over,)
elif isinstance(node, KernelFunction): elif isinstance(node, KernelFunction):
params = node.get_parameters() params = node.get_parameters()
param_names = [p.field_name for p in params if p.is_field_pointer] param_names = [p.field_name for p in params if p.is_field_pointer]
param_names += [p.symbol.name for p in params if not p.is_field_parameter] param_names += [p.symbol.name for p in params if not p.is_field_parameter]
return "Func: %s (%s)" % (node.function_name, ",".join(param_names)) return f"Func: {node.function_name} ({','.join(param_names)})"
elif isinstance(node, SympyAssignment): elif isinstance(node, SympyAssignment):
return repr(node.lhs) return repr(node.lhs)
elif isinstance(node, Block):
return "Block" + str(id(node))
elif isinstance(node, Conditional): elif isinstance(node, Conditional):
return repr(node) return repr(node)
else: else:
raise NotImplementedError("Cannot handle node type %s" % (type(node),)) raise NotImplementedError(f"Cannot handle node type {type(node)}")
def print_dot(node, view=False, short=False, **kwargs): def print_dot(node, view=False, short=False, **kwargs):
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import json
from pystencils.astnodes import NodeOrExpr
from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
try:
import yaml
except ImportError:
raise ImportError('yaml not installed')
def expr_to_dict(expr_or_node: NodeOrExpr, with_c_code=True, full_class_names=False):
"""Converts a SymPy expression to a serializable dict (mainly for debugging purposes)
The dict recursively contains all args of the expression as ``dict``s
See :func:`.write_json`
Args:
expr_or_node (NodeOrExpr): a SymPy expression or a :class:`pystencils.astnodes.Node`
with_c_code (bool, optional): include C representation of the nodes
full_class_names (bool, optional): use full class names (type(object) instead of ``type(object).__name__``
"""
self = {'str': str(expr_or_node)}
if with_c_code:
try:
self.update({'c': generate_c(expr_or_node)})
except Exception:
try:
self.update({'c': CustomSympyPrinter().doprint(expr_or_node)})
except Exception:
pass
for a in expr_or_node.args:
self.update({str(a.__class__ if full_class_names else a.__class__.__name__): expr_to_dict(a)})
return self
def print_json(expr_or_node: NodeOrExpr):
"""Print debug JSON of an AST to string
Args:
expr_or_node (NodeOrExpr): a SymPy expression or a :class:`pystencils.astnodes.Node`
Returns:
str: JSON representation of AST
"""
expr_or_node_dict = expr_to_dict(expr_or_node)
return json.dumps(expr_or_node_dict, indent=4)
def write_json(filename: str, expr_or_node: NodeOrExpr):
"""Writes debug JSON represenation of AST to file
Args:
filename (str): Path for the file to write
expr_or_node (NodeOrExpr): a SymPy expression or a :class:`pystencils.astnodes.Node`
"""
expr_or_node_dict = expr_to_dict(expr_or_node)
with open(filename, 'w') as f:
json.dump(expr_or_node_dict, f, indent=4)
def print_yaml(expr_or_node):
expr_or_node_dict = expr_to_dict(expr_or_node, full_class_names=False)
return yaml.dump(expr_or_node_dict)
def write_yaml(filename, expr_or_node):
expr_or_node_dict = expr_to_dict(expr_or_node)
with open(filename, 'w') as f:
yaml.dump(expr_or_node_dict, f)
def get_argument_string(function_shortcut):
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
return arg_string
def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'):
if instruction_set != 'vsx':
raise NotImplementedError(instruction_set)
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'sqrt': 'sqrt[0]',
'rsqrt': 'rsqrte[0]', # rsqrt is available too, but not on Clang
'loadU': 'xl[0x0, 0]',
'loadA': 'ld[0x0, 0]',
'storeU': 'xst[1, 0x0, 0]',
'storeA': 'st[1, 0x0, 0]',
'storeAAndFlushCacheline': 'stl[1, 0x0, 0]',
'abs': 'abs[0]',
'==': 'cmpeq[0, 1]',
'!=': 'cmpne[0, 1]',
'<=': 'cmple[0, 1]',
'<': 'cmplt[0, 1]',
'>=': 'cmpge[0, 1]',
'>': 'cmpgt[0, 1]',
'&': 'and[0, 1]',
'|': 'or[0, 1]',
'blendv': 'sel[0, 1, 2]',
('any', '=='): 'any_eq[0, 1]',
('any', '!='): 'any_ne[0, 1]',
('any', '<='): 'any_le[0, 1]',
('any', '<'): 'any_lt[0, 1]',
('any', '>='): 'any_ge[0, 1]',
('any', '>'): 'any_gt[0, 1]',
('all', '=='): 'all_eq[0, 1]',
('all', '!='): 'all_ne[0, 1]',
('all', '<='): 'all_le[0, 1]',
('all', '<'): 'all_lt[0, 1]',
('all', '>='): 'all_ge[0, 1]',
('all', '>'): 'all_gt[0, 1]',
}
bits = {'double': 64,
'float': 32,
'int': 32}
width = 128 // bits[data_type]
intwidth = 128 // bits['int']
result = dict()
result['bytes'] = 16
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
arg_string = get_argument_string(function_shortcut)
result[intrinsic_id] = 'vec_' + name + arg_string
if data_type == 'double':
# Clang and XL C++ are missing these for doubles
result['loadA'] = '(__vector double)' + result['loadA'].format('(float*) {0}')
result['storeA'] = result['storeA'].format('(float*) {0}', '(__vector float) {1}')
result['storeAAndFlushCacheline'] = result['storeAAndFlushCacheline'].format('(float*) {0}',
'(__vector float) {1}')
result['+int'] = "vec_add({0}, {1})"
result['width'] = width
result['intwidth'] = intwidth
result[data_type] = f'__vector {data_type}'
result['int'] = '__vector int'
result['bool'] = f'__vector __bool {"long long" if data_type == "double" else "int"}'
result['headers'] = ['<altivec.h>', '"ppc_altivec_helpers.h"']
result['makeVecConst'] = '((' + result[data_type] + '){{' + \
", ".join(['(' + data_type + ') {0}' for _ in range(width)]) + '}})'
result['makeVec'] = '((' + result[data_type] + '){{' + \
", ".join(['{' + data_type + '} {' + str(i) + '}' for i in range(width)]) + '}})'
result['makeVecConstInt'] = '((' + result['int'] + '){{' + ", ".join(['(int) {0}' for _ in range(intwidth)]) + '}})'
result['makeVecInt'] = '((' + result['int'] + '){{(int) {0}, (int) {1}, (int) {2}, (int) {3}}})'
result['any'] = 'vec_any_ne({0}, ((' + result['bool'] + ') {{' + ", ".join(['0'] * width) + '}}))'
result['all'] = 'vec_all_ne({0}, ((' + result['bool'] + ') {{' + ", ".join(['0'] * width) + '}}))'
result['cachelineSize'] = 'cachelineSize()'
result['cachelineZero'] = 'cachelineZero((void*) {0})'
return result
from pystencils.typing import CFunction
def get_argument_string(function_shortcut, last=''):
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
if last:
arg_string += last + ','
arg_string = arg_string[:-1] + ")"
return arg_string
def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
assert instruction_set == 'rvv'
bits = {'double': 64,
'float': 32,
'int': 32}
base_names = {
'+': 'fadd_vv[0, 1]',
'-': 'fsub_vv[0, 1]',
'*': 'fmul_vv[0, 1]',
'/': 'fdiv_vv[0, 1]',
'sqrt': 'fsqrt_v[0]',
'loadU': f'le{bits[data_type]}_v[0]',
'storeU': f'se{bits[data_type]}_v[0, 1]',
'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]',
'loadS': f'lse{bits[data_type]}_v[0, 1]',
'storeS': f'sse{bits[data_type]}_v[0, 2, 1]',
'maskStoreS': f'sse{bits[data_type]}_v[3, 0, 2, 1]',
'abs': 'fabs_v[0]',
'==': 'mfeq_vv[0, 1]',
'!=': 'mfne_vv[0, 1]',
'<=': 'mfle_vv[0, 1]',
'<': 'mflt_vv[0, 1]',
'>=': 'mfge_vv[0, 1]',
'>': 'mfgt_vv[0, 1]',
'&': 'mand_mm[0, 1]',
'|': 'mor_mm[0, 1]',
'blendv': 'merge_vvm[2, 0, 1]',
'any': 'cpop_m[0]',
'all': 'cpop_m[0]',
}
result = dict()
width = f'vsetvlmax_e{bits[data_type]}m1()'
intwidth = 'vsetvlmax_e{bits["int"]}m1()'
result['bytes'] = 'vsetvlmax_e8m1()'
prefix = 'v'
suffix = f'_f{bits[data_type]}m1'
vl = '{loop_stop} - {loop_counter}'
int_vl = f'({vl})*{bits[data_type]//bits["int"]}'
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
if name.startswith('mf'):
suffix2 = suffix + f'_b{bits[data_type]}'
elif name.endswith('_mm') or name.endswith('_m'):
suffix2 = f'_b{bits[data_type]}'
elif intrinsic_id.startswith('mask'):
suffix2 = suffix + '_m'
else:
suffix2 = suffix
arg_string = get_argument_string(function_shortcut, last=vl)
result[intrinsic_id] = prefix + name + suffix2 + arg_string
result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int")
result['makeVecConst'] = f'vfmv_v_f_f{bits[data_type]}m1({{0}}, {vl})'
result['makeVecConstInt'] = f'vmv_v_x_i{bits["int"]}m1({{0}}, {int_vl})'
result['makeVecIndex'] = f'vmacc_vx_i{bits["int"]}m1({result["makeVecConstInt"]}, {{1}}, ' + \
f'vid_v_i{bits["int"]}m1({int_vl}), {int_vl})'
result['storeS'] = result['storeS'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
result['loadS'] = result['loadS'].replace('{1}', f'{{1}}*{bits[data_type]//8}')
result['maskStoreS'] = result['maskStoreS'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})"
result['float'] = f'vfloat{bits["float"]}m1_t'
result['double'] = f'vfloat{bits["double"]}m1_t'
result['int'] = f'vint{bits["int"]}m1_t'
result['bool'] = f'vbool{bits[data_type]}_t'
result['headers'] = ['<riscv_vector.h>', '"riscv_v_helpers.h"']
result['any'] += ' > 0x0'
result['all'] += f' == vsetvl_e{bits[data_type]}m1({vl})'
result['cachelineSize'] = 'cachelineSize()'
result['cachelineZero'] = 'cachelineZero((void*) {0})'
return result
import os
import platform
from ctypes import CDLL, c_int, c_size_t, sizeof, byref
from warnings import warn
import numpy as np
from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm
from pystencils.backends.ppc_instruction_sets import get_vector_instruction_set_ppc
from pystencils.backends.riscv_instruction_sets import get_vector_instruction_set_riscv
from pystencils.cache import memorycache
from pystencils.typing import numpy_name_to_c
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
if data_type == 'float':
warn(f"Ambiguous input for data_type: {data_type}. For single precision please use float32. "
f"For more information please take numpy.dtype as a reference. This input will not be supported in future "
f"releases")
data_type = 'float64'
type_name = numpy_name_to_c(np.dtype(data_type).name)
if instruction_set in ['neon', 'sme'] or instruction_set.startswith('sve'):
return get_vector_instruction_set_arm(type_name, instruction_set)
elif instruction_set in ['vsx']:
return get_vector_instruction_set_ppc(type_name, instruction_set)
elif instruction_set in ['rvv']:
return get_vector_instruction_set_riscv(type_name, instruction_set)
else:
return get_vector_instruction_set_x86(type_name, instruction_set)
@memorycache
def get_supported_instruction_sets():
"""List of supported instruction sets on current hardware, or None if query failed."""
if 'PYSTENCILS_SIMD' in os.environ:
return os.environ['PYSTENCILS_SIMD'].split(',')
if platform.system() == 'Darwin' and platform.machine() == 'arm64':
result = ['neon']
libc = CDLL('/usr/lib/libc.dylib')
value = c_int(0)
size = c_size_t(sizeof(value))
status = libc.sysctlbyname(b"hw.optional.arm.FEAT_SME", byref(value), byref(size), None, 0)
if status == 0 and value.value == 1:
result.insert(0, "sme")
return result
elif platform.system() == 'Windows' and platform.machine() == 'ARM64':
return ['neon']
elif platform.system() == 'Linux' and platform.machine() == 'aarch64':
result = ['neon'] # Neon is mandatory on 64-bit ARM
libc = CDLL('libc.so.6')
hwcap = libc.getauxval(16) # AT_HWCAP
hwcap2 = libc.getauxval(26) # AT_HWCAP2
if hwcap & (1 << 22): # HWCAP_SVE
if hwcap2 & (1 << 1): # HWCAP2_SVE2
name = 'sve2'
else:
name = 'sve'
length = 8 * libc.prctl(51, 0, 0, 0, 0) # PR_SVE_GET_VL
if length < 0:
raise OSError("SVE length query failed")
while length >= 128:
result.append(f"{name}{length}")
length //= 2
result.append(name)
if hwcap2 & (1 << 23): # HWCAP2_SME
result.insert(0, "sme") # prepend to list so it is not automatically chosen as best instruction set
return result
elif platform.system() == 'Linux' and platform.machine().startswith('riscv'):
libc = CDLL('libc.so.6')
hwcap = libc.getauxval(16) # AT_HWCAP
hwcap_isa_v = 1 << (ord('V') - ord('A')) # COMPAT_HWCAP_ISA_V
return ['rvv'] if hwcap & hwcap_isa_v else []
elif platform.system() == 'Linux' and platform.machine().startswith('ppc64'):
libc = CDLL('libc.so.6')
hwcap = libc.getauxval(16) # AT_HWCAP
return ['vsx'] if hwcap & 0x00000080 else [] # PPC_FEATURE_HAS_VSX
elif platform.machine() in ['x86_64', 'x86', 'AMD64', 'i386']:
try:
from cpuinfo import get_cpu_info
except ImportError:
return None
result = []
required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'}
required_avx_flags = {'avx', 'avx2'}
required_avx512_flags = {'avx512f'}
possible_avx512vl_flags = {'avx512vl', 'avx10_1'}
flags = set(get_cpu_info()['flags'])
if flags.issuperset(required_sse_flags):
result.append("sse")
if flags.issuperset(required_avx_flags):
result.append("avx")
if flags.issuperset(required_avx512_flags):
result.append("avx512")
if not flags.isdisjoint(possible_avx512vl_flags):
result.append("avx512vl")
return result
else:
raise NotImplementedError('Instruction set detection for %s on %s is not implemented' %
(platform.system(), platform.machine()))
@memorycache
def get_cacheline_size(instruction_set):
"""Get the size (in bytes) of a cache block that can be zeroed without memory access.
Usually, this is identical to the cache line size."""
instruction_sets = get_vector_instruction_set('double', instruction_set)
if 'cachelineSize' not in instruction_sets:
return None
import pystencils as ps
from pystencils.astnodes import SympyAssignment
import numpy as np
from pystencils.cpu.vectorization import CachelineSize
arr = np.zeros((1, 1), dtype=np.float32)
f = ps.Field.create_from_numpy_array('f', arr, index_dimensions=0)
ass = [CachelineSize(), SympyAssignment(f.center, CachelineSize.symbol)]
ast = ps.create_kernel(ass, cpu_vectorize_info={'instruction_set': instruction_set})
kernel = ast.compile()
kernel(**{f.name: arr, CachelineSize.symbol.name: 0})
return int(arr[0, 0])
def get_argument_string(intrinsic_id, width, function_shortcut):
if intrinsic_id == 'makeVecConst' or intrinsic_id == 'makeVecConstInt':
arg_string = f"({','.join(['{0}'] * width)})"
elif intrinsic_id == 'makeVec' or intrinsic_id == 'makeVecInt':
params = ["{" + str(i) + "}" for i in reversed(range(width))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecBool':
params = [f"(({{{i}}} ? -1.0 : 0.0)" for i in reversed(range(width))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecConstBool':
params = ["(({0}) ? -1.0 : 0.0)" for _ in range(width)]
arg_string = f"({','.join(params)})"
else:
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
return arg_string
def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
comparisons = {
'==': '_CMP_EQ_UQ',
'!=': '_CMP_NEQ_UQ',
'>=': '_CMP_GE_OQ',
'<=': '_CMP_LE_OQ',
'<': '_CMP_NGE_UQ',
'>': '_CMP_NLE_UQ',
}
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'&': 'and[0, 1]',
'|': 'or[0, 1]',
'blendv': 'blendv[0, 1, 2]',
'sqrt': 'sqrt[0]',
'makeVecConst': 'set[]',
'makeVec': 'set[]',
'makeVecBool': 'set[]',
'makeVecConstBool': 'set[]',
'makeVecInt': 'set[]',
'makeVecConstInt': 'set[]',
'loadU': 'loadu[0]',
'loadA': 'load[0]',
'storeU': 'storeu[0,1]',
'storeA': 'store[0,1]',
'stream': 'stream[0,1]',
'maskStoreA': 'mask_store[0, 2, 1]' if instruction_set.startswith('avx512') else 'maskstore[0, 2, 1]',
'maskStoreU': 'mask_storeu[0, 2, 1]' if instruction_set.startswith('avx512') else 'maskstore[0, 2, 1]',
}
for comparison_op, constant in comparisons.items():
base_names[comparison_op] = f'cmp[0, 1, {constant}]'
headers = {
'avx512': ['<immintrin.h>'],
'avx512vl': ['<immintrin.h>'],
'avx': ['<immintrin.h>'],
'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>',
'<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
}
suffix = {
'double': 'pd',
'float': 'ps',
'int': 'epi32'
}
prefix = {
'sse': '_mm',
'avx': '_mm256',
'avx512vl': '_mm256',
'avx512': '_mm512',
}
width = {
("double", "sse"): 2,
("float", "sse"): 4,
("int", "sse"): 4,
("double", "avx"): 4,
("float", "avx"): 8,
("int", "avx"): 8,
("double", "avx512vl"): 4,
("float", "avx512vl"): 8,
("int", "avx512vl"): 8,
("double", "avx512"): 8,
("float", "avx512"): 16,
("int", "avx512"): 16,
}
result = {
'width': width[(data_type, instruction_set)],
'intwidth': width[('int', instruction_set)],
'bytes': 4 * width[("float", instruction_set)]
}
pre = prefix[instruction_set]
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
if 'Int' in intrinsic_id:
suf = suffix['int']
arg_string = get_argument_string(intrinsic_id, result['intwidth'], function_shortcut)
else:
suf = suffix[data_type]
arg_string = get_argument_string(intrinsic_id, result['width'], function_shortcut)
mask_suffix = '_mask' if instruction_set.startswith('avx512') and intrinsic_id in comparisons.keys() else ''
result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = f"__m{bit_width}d"
result['float'] = f"__m{bit_width}"
result['int'] = f"__m{bit_width}i"
result['bool'] = result[data_type]
result['headers'] = headers[instruction_set]
result['any'] = f"{pre}_movemask_{suf}({{0}}) > 0"
result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}"
setsuf = "x" if bit_width < 512 and bit_width // result['width'] == 64 else ""
if instruction_set.startswith('avx512'):
size = result['width']
masksize = max(size, 8)
result['&'] = f'_kand_mask{masksize}({{0}}, {{1}})'
result['|'] = f'_kor_mask{masksize}({{0}}, {{1}})'
result['any'] = f'!_ktestz_mask{masksize}_u8({{0}}, {{0}})'
result['all'] = f'_kortestc_mask{masksize}_u8({{0}}, {{0}})'
result['blendv'] = f'{pre}_mask_blend_{suf}({{2}}, {{0}}, {{1}})'
result['rsqrt'] = f"{pre}_rsqrt14_{suf}({{0}})"
result['bool'] = f"__mmask{masksize}"
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = f"__mmask8(({params}) )"
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = f"__mmask8(({params}) )"
vindex = f'{pre}_set_epi{bit_width//size}{setsuf}(' + \
', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}{setsuf}({{0}}))'
scale = bit_width // size // 8
result['storeS'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
f', {{1}}, {scale})'
result['maskStoreS'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
f', {{1}}, {scale})'
if bit_width == 512:
result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {scale})'
else:
result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}({{0}}, ' + vindex.format("{1}") + f', {scale})'
# abs intrinsic exists in 512 bits, but expands to a sequence. We generate that same sequence for 128 and 256 bits
if instruction_set == 'avx512':
result['abs'] = f"{pre}_abs_{suf}({{0}})"
else:
result['abs'] = f"{pre}_castsi{bit_width}_{suf}({pre}_and_si{bit_width}(" + \
f"{pre}_set1_epi{bit_width // result['width']}{setsuf}(0x7" + \
'f' * (bit_width // result['width'] // 4 - 1) + "), " + \
f"{pre}_cast{suf}_si{bit_width}({{0}})))"
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
result['+int'] = f"{pre}_add_{suffix['int']}({{0}}, {{1}})"
result['streamFence'] = '_mm_mfence()'
return result