Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Showing with 1087 additions and 67 deletions
"""Special symbols representing kernel parameters related to fields/arrays.
A `KernelFunction` node determines parameters that have to be passed to the function by searching for all undefined
symbols. Some symbols are not directly defined by the user, but are related to the `Field`s used in the kernel:
For each field a `FieldPointerSymbol` needs to be passed in, which is the pointer to the memory region where
the field is stored. This pointer is represented by the `FieldPointerSymbol` class that additionally stores the
name of the corresponding field. For fields whose size is not known at compile time, shape and stride
information additionally has to be passed in at runtime. These values are represented by `FieldShapeSymbol`
and `FieldStrideSymbol`.
The special symbols in this module store only the field name instead of a field reference. Storing a field reference
directly leads to problems with copying and pickling behaviour due to the circular dependency of `Field` and
e.g. `FieldShapeSymbol`, since a Field contains `FieldShapeSymbol`s in its shape, and a `FieldShapeSymbol`
would reference back to the field.
"""
from typing import Union
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from pystencils.data_types import (
PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
from pystencils.typing.types import BasicType, create_type, PointerType
def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]):
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
Args:
dtype (BasicType, np.dtype): a pystencils BasicType or a numpy dtype
Returns:
A dict of SymPy assumptions
"""
if hasattr(dtype, 'numpy_dtype'):
dtype = dtype.numpy_dtype
assumptions = dict()
try:
if np.issubdtype(dtype, np.integer):
assumptions.update({'integer': True})
SHAPE_DTYPE = create_composite_type_from_string("const int64")
STRIDE_DTYPE = create_composite_type_from_string("const int64")
if np.issubdtype(dtype, np.unsignedinteger):
assumptions.update({'negative': False})
if np.issubdtype(dtype, np.integer) or \
np.issubdtype(dtype, np.floating):
assumptions.update({'real': True})
except Exception: # TODO this is dirty
pass
return assumptions
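# Minimal usage sketch, assuming numpy is imported as np (as above):
#     assumptions_from_dtype(np.dtype('uint32'))
#     # -> {'integer': True, 'negative': False, 'real': True}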
class TypedSymbol(sp.Symbol):
def __new__(cls, *args, **kwds):
obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, name, dtype, **kwargs): # TODO does not match signature of sp.Symbol???
# TODO: also Symbol should be allowed ---> see sympy Variable
assumptions = assumptions_from_dtype(dtype)
assumptions.update(kwargs)
obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
try:
obj.numpy_dtype = create_type(dtype)
except (TypeError, ValueError):
# on error keep the string
obj.numpy_dtype = dtype
return obj
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
@property
def dtype(self):
return self.numpy_dtype
def _hashable_content(self):
return super()._hashable_content(), hash(self.numpy_dtype)
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), self.assumptions0
@property
def canonical(self):
return self
@property
def reversed(self):
return self
@property
def headers(self):
headers = []
try:
if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
try:
if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
return headers
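# Usage sketch: dtype-derived assumptions become ordinary SymPy assumptions on the symbol, e.g.
#     ctr = TypedSymbol('ctr_0', 'int32')
#     ctr.dtype       # BasicType( int32_t )
#     ctr.is_integer  # True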
SHAPE_DTYPE = BasicType('int64', const=True)
STRIDE_DTYPE = BasicType('int64', const=True)
class FieldStrideSymbol(TypedSymbol):
......@@ -83,6 +159,8 @@ class FieldPointerSymbol(TypedSymbol):
return obj
def __new_stage2__(cls, field_name, field_dtype, const):
from pystencils.typing.utilities import get_base_type
name = f"_data_{field_name}"
dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
......@@ -100,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol):
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
from abc import abstractmethod
from typing import Union
import numpy as np
import sympy as sp
def is_supported_type(dtype: np.dtype):
scalar = dtype.type
c = np.issubdtype(dtype, np.generic)
subclass = issubclass(scalar, np.floating) or issubclass(scalar, np.integer) or issubclass(scalar, np.bool_)
additional_checks = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None
return c and subclass and additional_checks
def numpy_name_to_c(name: str) -> str:
"""
Converts a np.dtype.name into a C type
Args:
name: np.dtype.name string
Returns:
type as a C string
"""
if name == 'float64':
return 'double'
elif name == 'float32':
return 'float'
elif name == 'float16' or name == 'half':
return 'half'
elif name.startswith('int'):
width = int(name[len("int"):])
return f"int{width}_t"
elif name.startswith('uint'):
width = int(name[len("uint"):])
return f"uint{width}_t"
elif name == 'bool':
return 'bool'
else:
raise NotImplementedError(f"Can't map numpy to C name for {name}")
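# For example, numpy_name_to_c('float64') returns 'double' and numpy_name_to_c('uint16') returns
# 'uint16_t', while an unmapped name such as 'complex128' raises NotImplementedError.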
class AbstractType(sp.Atom):
# TODO: Is it necessary to inherit from sp.Atom?
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
def _sympystr(self, *args, **kwargs):
return str(self)
@property
@abstractmethod
def base_type(self) -> Union[None, 'BasicType']:
"""
Returns: The BasicType of a Vector or Pointer type, None otherwise
"""
pass
@property
@abstractmethod
def item_size(self) -> int:
"""
Returns: Number of items.
E.g. width * item_size(basic_type) in vector's case, or simple numpy itemsize in Struct's case.
"""
pass
class BasicType(AbstractType):
"""
BasicType is defined with a const qualifier and a np.dtype.
"""
def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False):
if isinstance(dtype, BasicType):
self.numpy_dtype = dtype.numpy_dtype
self.const = dtype.const
else:
self.numpy_dtype = np.dtype(dtype)
self.const = const
assert is_supported_type(self.numpy_dtype), f'Type {self.numpy_dtype} is currently not supported!'
def __getnewargs__(self):
return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
return None
@property
def item_size(self): # TODO: Do we want self.numpy_type.itemsize????
return 1
def is_float(self):
return issubclass(self.numpy_dtype.type, np.floating)
def is_half(self):
return issubclass(self.numpy_dtype.type, np.half)
def is_int(self):
return issubclass(self.numpy_dtype.type, np.integer)
def is_uint(self):
return issubclass(self.numpy_dtype.type, np.unsignedinteger)
def is_sint(self):
return issubclass(self.numpy_dtype.type, np.signedinteger)
def is_bool(self):
return issubclass(self.numpy_dtype.type, np.bool_)
def dtype_eq(self, other):
if not isinstance(other, BasicType):
return False
else:
return self.numpy_dtype == other.numpy_dtype
@property
def c_name(self) -> str:
return numpy_name_to_c(self.numpy_dtype.name)
def __str__(self):
return f'{self.c_name}{" const" if self.const else ""}'
def __repr__(self):
return f'BasicType( {str(self)} )'
def _repr_html_(self):
return f'BasicType( {str(self)} )'
def __eq__(self, other):
return self.dtype_eq(other) and self.const == other.const
def __hash__(self):
return hash(str(self))
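# Illustration of the mapping above:
#     BasicType('float64').c_name          # 'double'
#     str(BasicType('int32', const=True))  # 'int32_t const'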
class VectorType(AbstractType):
"""
VectorType consists of a BasicType and a width.
"""
instruction_set = None
def __init__(self, base_type: BasicType, width: int):
self._base_type = base_type
self.width = width
@property
def base_type(self):
return self._base_type
@property
def item_size(self):
return self.width * self.base_type.item_size
def __eq__(self, other):
if not isinstance(other, VectorType):
return False
else:
return (self.base_type, self.width) == (other.base_type, other.width)
def __str__(self):
if self.instruction_set is None:
return f"{self.base_type}[{self.width}]"
else:
# TODO VectorizationRevamp: this seems super weird. the instruction_set should know how to print a type out!
# TODO VectorizationRevamp: this is error prone. base_type could be const=True. Use dtype instead
if self.base_type == create_type("int64") or self.base_type == create_type("int32"):
return self.instruction_set['int']
elif self.base_type == create_type("float64"):
return self.instruction_set['double']
elif self.base_type == create_type("float32"):
return self.instruction_set['float']
elif self.base_type == create_type("bool"):
return self.instruction_set['bool']
else:
raise NotImplementedError()
def __hash__(self):
return hash((self.base_type, self.width))
def __getnewargs__(self):
return self._base_type, self.width
def __getnewargs_ex__(self):
return (self._base_type, self.width), {}
class PointerType(AbstractType):
def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False):
self._base_type = base_type
self.const = const
self.restrict = restrict
self.double_pointer = double_pointer
def __getnewargs__(self):
return self.base_type, self.const, self.restrict, self.double_pointer
def __getnewargs_ex__(self):
return (self.base_type, self.const, self.restrict, self.double_pointer), {}
@property
def alias(self):
return not self.restrict
@property
def base_type(self):
return self._base_type
@property
def item_size(self):
if self.double_pointer:
raise NotImplementedError("The item_size for double_pointer is not implemented")
else:
return self.base_type.item_size
def __eq__(self, other):
if not isinstance(other, PointerType):
return False
else:
own = (self.base_type, self.const, self.restrict, self.double_pointer)
return own == (other.base_type, other.const, other.restrict, other.double_pointer)
def __str__(self):
restrict_str = "RESTRICT" if self.restrict else ""
const_str = "const" if self.const else ""
if self.double_pointer:
return f'{str(self.base_type)} ** {restrict_str} {const_str}'
else:
return f'{str(self.base_type)} * {restrict_str} {const_str}'
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self._base_type, self.const, self.restrict, self.double_pointer))
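# e.g. str(PointerType(BasicType('float64'))).strip() == 'double * RESTRICT'
# (restrict defaults to True and const to False, matching the "double * RESTRICT" pointers in the tests below)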
class StructType(AbstractType):
"""
A list of types (with C offsets).
It is implemented with uint8_t and casts to the correct datatype.
"""
def __init__(self, numpy_type, const=False):
self.const = const
self._dtype = np.dtype(numpy_type)
def __getnewargs__(self):
return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
@property
def base_type(self):
return None
@property
def numpy_dtype(self):
return self._dtype
@property
def item_size(self):
return self.numpy_dtype.itemsize
def get_element_offset(self, element_name):
return self.numpy_dtype.fields[element_name][1]
def get_element_type(self, element_name):
np_element_type = self.numpy_dtype.fields[element_name][0]
return BasicType(np_element_type, self.const)
def has_element(self, element_name):
return element_name in self.numpy_dtype.fields
def __eq__(self, other):
if not isinstance(other, StructType):
return False
else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
def __str__(self):
# structs are handled byte-wise
result = "uint8_t"
if self.const:
result += " const"
return result
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self.numpy_dtype, self.const))
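# Sketch with a packed numpy record dtype:
#     t = StructType(np.dtype([('x', np.float64), ('flag', np.uint8)]))
#     t.get_element_offset('flag')  # 8
#     t.get_element_type('flag')    # BasicType( uint8_t )
#     str(t)                        # 'uint8_t' -- struct accesses are generated byte-wise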
def create_type(specification: Union[type, AbstractType, str]) -> AbstractType:
# TODO: Deprecated. Use the constructor of BasicType or StructType instead
"""Creates a subclass of Type according to a string or an object of subclass Type.
Args:
specification: Type object, or a string
Returns:
Type object, or a new Type object parsed from the string
"""
if isinstance(specification, AbstractType):
return specification
else:
numpy_dtype = np.dtype(specification)
if numpy_dtype.fields is None:
return BasicType(numpy_dtype, const=False)
else:
return StructType(numpy_dtype, const=False)
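# e.g. create_type('float32') returns BasicType( float ), a structured dtype such as
# np.dtype([('x', np.float64)]) yields a StructType, and AbstractType instances pass through unchanged.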
from collections import defaultdict
from functools import partial
from typing import Tuple, Union, Sequence
import numpy as np
import sympy as sp
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils
from pystencils.cache import memorycache_if_hashable
from pystencils.typing.types import BasicType, VectorType, PointerType, create_type
from pystencils.typing.cast_functions import CastFunc
from pystencils.typing.typed_sympy import TypedSymbol
from pystencils.utils import all_equal
def typed_symbols(names, dtype, **kwargs):
"""
Creates TypedSymbols with the same functionality as sympy.symbols
Args:
names: See sympy.symbols
dtype: The data type all symbols will have
**kwargs: Key value arguments passed to sympy.symbols
Returns:
TypedSymbols
"""
symbols = sp.symbols(names, **kwargs)
if isinstance(symbols, Tuple):
return tuple(TypedSymbol(str(s), dtype) for s in symbols)
else:
return TypedSymbol(str(symbols), dtype)
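# e.g. x, y = typed_symbols('x y', 'float64') behaves like sp.symbols('x y') but every
# returned symbol additionally carries dtype BasicType( double ).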
def get_base_type(data_type):
"""
Returns the BasicType of a Pointer or a Vector
"""
while data_type.base_type is not None:
data_type = data_type.base_type
return data_type
def result_type(*args: np.dtype):
"""Returns the type of the result if the np.dtype arguments would be collated.
We can't use numpy functionality, because numpy casts don't behave exactly like C casts"""
s = sorted(args, key=lambda x: x.itemsize)
def kind_to_value(kind: str) -> int:
if kind == 'f':
return 3
elif kind == 'i':
return 2
elif kind == 'u':
return 1
elif kind == 'b':
return 0
else:
raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options')
s = sorted(s, key=lambda x: kind_to_value(x.kind))
return s[-1]
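# Note the C-like promotion: result_type(np.dtype('int64'), np.dtype('float32')) yields float32,
# because the float kind wins regardless of item size, whereas np.result_type would give float64.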
def collate_types(types: Sequence[Union[BasicType, VectorType]]):
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
"""
# Pointer arithmetic case i.e. pointer + [int, uint] is allowed
if any(isinstance(t, PointerType) for t in types):
pointer_type = None
for t in types:
if isinstance(t, PointerType):
if pointer_type is not None:
raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"')
pointer_type = t
elif isinstance(t, BasicType):
if not (t.is_int() or t.is_uint()):
raise ValueError("Invalid pointer arithmetic")
else:
raise ValueError("Invalid pointer arithmetic")
return pointer_type
# peel off vector types; if at least one vector type occurred, the result will also be a vector type
vector_type = [t for t in types if isinstance(t, VectorType)]
if not all_equal(t.width for t in vector_type):
raise ValueError("Collation failed because of vector types with different width")
# TODO: check if this is needed
# def peel_off_type(dtype, type_to_peel_off):
# while type(dtype) is type_to_peel_off:
# dtype = dtype.base_type
# return dtype
# types = [peel_off_type(t, VectorType) for t in types]
types = [t.base_type if isinstance(t, VectorType) else t for t in types]
# now we should have a list of basic types - struct types are not yet supported
assert all(type(t) is BasicType for t in types)
result_numpy_type = result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type)
if vector_type:
result = VectorType(result, vector_type[0].width)
return result
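# For instance, collate_types([BasicType('float32'), BasicType('float64')]) gives BasicType( double ),
# and mixing in a vector keeps its width:
#     collate_types([VectorType(BasicType('float64'), 4), BasicType('float32')])  # double[4]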
# TODO get_type_of_expression should be used after leaf_typing. So no defaults should be necessary
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
symbol_type_dict=None):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
if default_float_type == 'float':
default_float_type = 'float32'
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
# TODO this line is quite hard to understand, if possible simplify it
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
symbol_type_dict=symbol_type_dict)
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess):
return expr.field.dtype
elif isinstance(expr, pystencils.field.Field.Access):
return expr.field.dtype
elif isinstance(expr, TypedSymbol):
return expr.dtype
elif isinstance(expr, sp.Symbol):
# TODO delete if case
if symbol_type_dict:
return symbol_type_dict[expr.name]
else:
raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
elif isinstance(expr, CastFunc):
return expr.args[1]
elif isinstance(expr, (vec_any, vec_all)):
return create_type("bool")
elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
return collated_result_type
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return typed_symbol.dtype.base_type
elif isinstance(expr, (Boolean, BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
if vec_args:
result = VectorType(result, width=vec_args[0].width)
return result
elif isinstance(expr, sp.Pow):
base_type = get_type(expr.args[0])
if expr.exp.is_integer:
return base_type
else:
return collate_types([create_type(default_float_type), base_type])
elif isinstance(expr, (sp.Sum, sp.Product)):
return get_type(expr.args[0])
elif isinstance(expr, sp.Expr):
expr: sp.Expr
if expr.args:
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
else:
if expr.is_integer:
return create_type(default_int_type)
else:
return create_type(default_float_type)
raise NotImplementedError("Could not determine type for", expr, type(expr))
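# Sketch: for fully typed expressions the default types are irrelevant, e.g.
#     a, b = TypedSymbol('a', 'float32'), TypedSymbol('b', 'float64')
#     get_type_of_expression(a + b)  # BasicType( double ), determined via collate_types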
# Fix for sympy versions from 1.9
sympy_version = sp.__version__.split('.')
sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1])
if sympy_version_int >= 109:
# __setstate__ would bypass the constructor, so we remove it
if sympy_version_int >= 111:
del sp.Basic.__setstate__
del sp.Symbol.__setstate__
else:
sp.Number.__getstate__ = sp.Basic.__getstate__
del sp.Basic.__getstate__
# __reduce_ex__ would strip kwargs, so we override it
def basic_reduce_ex(self, protocol):
if hasattr(self, '__getnewargs_ex__'):
args, kwargs = self.__getnewargs_ex__()
else:
args, kwargs = self.__getnewargs__(), {}
if hasattr(self, '__getstate__'):
state = self.__getstate__()
else:
state = None
return partial(type(self), **kwargs), args, state
sp.Basic.__reduce_ex__ = basic_reduce_ex
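# Intended effect (sketch, assuming TypedSymbol from pystencils.typing.typed_sympy is imported):
#     import pickle
#     s = TypedSymbol('alpha', 'float64')
#     pickle.loads(pickle.dumps(s)).dtype == s.dtype  # expected to hold, the dtype survives the round trip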
def get_next_parent_of_type(node, parent_type):
"""Returns the next parent node of given type or None, if root is reached.
Traverses the AST nodes parents until a parent of given type was found.
If no such parent is found, None is returned
"""
parent = node.parent
while parent is not None:
if isinstance(parent, parent_type):
return parent
parent = parent.parent
return None
def parents_of_type(node, parent_type, include_current=False):
"""Generator for all parent nodes of given type"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
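# Minimal sketch with stand-in nodes (hypothetical class; only a .parent attribute is required):
#     class Node:
#         def __init__(self, parent=None): self.parent = parent
#     root = Node(); mid = Node(root); leaf = Node(mid)
#     get_next_parent_of_type(leaf, Node) is mid        # nearest matching parent
#     list(parents_of_type(leaf, Node)) == [mid, root]  # all matching parents, bottom-up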
import os
import itertools
from itertools import groupby
from collections import Counter
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
......@@ -23,13 +24,13 @@ class DotDict(dict):
self[key] = value
def all_equal(iterator):
iterator = iter(iterator)
try:
first = next(iterator)
except StopIteration:
return True
return all(first == rest for rest in iterator)
def all_equal(iterable):
"""
Returns ``True`` if all the elements are equal to each other.
Copied from: more-itertools 8.12.0
"""
g = groupby(iterable)
return next(g, True) and not next(g, False)
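# e.g. all_equal('aaaa') is True, all_equal([1, 1, 2]) is False, and all_equal([]) is True
# (an empty iterable counts as all-equal, matching more-itertools).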
def recursive_dict_update(d, u):
......@@ -51,33 +52,13 @@ def recursive_dict_update(d, u):
return d
@contextmanager
def file_handle_for_atomic_write(file_path):
"""Open temporary file object that atomically moves to destination upon exiting.
Allows reading and writing to and from the same filename.
The file will not be moved to destination in case of an exception.
Args:
file_path: path to file to be opened
"""
target_folder = os.path.dirname(os.path.abspath(file_path))
with NamedTemporaryFile(delete=False, dir=target_folder, mode='w') as f:
try:
yield f
finally:
f.flush()
os.fsync(f.fileno())
os.rename(f.name, file_path)
@contextmanager
def atomic_file_write(file_path):
target_folder = os.path.dirname(os.path.abspath(file_path))
with NamedTemporaryFile(delete=False, dir=target_folder) as f:
f.file.close()
yield f.name
os.rename(f.name, file_path)
os.replace(f.name, file_path)
def fully_contains(l1, l2):
......@@ -101,8 +82,8 @@ def boolean_array_bounding_box(boolean_array):
>>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a)
[(1, 3), (1, 3)]
>>> boolean_array_bounding_box(a) == [(1, 3), (1, 3)]
True
"""
dim = boolean_array.ndim
shape = boolean_array.shape
......@@ -115,6 +96,21 @@ def boolean_array_bounding_box(boolean_array):
return bounds
def binary_numbers(n):
"""Returns all binary numbers up to 2^n - 1
Example:
>>> binary_numbers(2)
[[0, 0], [0, 1], [1, 0], [1, 1]]
"""
result = list()
for i in range(1 << n):
binary_number = bin(i)[2:]
binary_number = '0' * (n - len(binary_number)) + binary_number
result.append((list(map(int, binary_number))))
return result
class LinearEquationSystem:
"""Symbolic linear system of equations - consisting of matrix and right hand side.
......@@ -240,3 +236,17 @@ class LinearEquationSystem:
break
result -= 1
self.next_zero_row = result
class ContextVar:
def __init__(self, value):
self.stack = [value]
@contextmanager
def __call__(self, new_value):
self.stack.append(new_value)
yield self
self.stack.pop()
def get(self):
return self.stack[-1]
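# Usage sketch:
#     flag = ContextVar(False)
#     with flag(True):
#         flag.get()  # True inside the context
#     flag.get()      # False again afterwards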
File moved
import pytest
import sympy as sp
import numpy
import pystencils
from pystencils.datahandling import create_data_handling
@pytest.mark.parametrize('dtype', ["float64", "float32"])
@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
def test_max(dtype, sympy_function):
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1, dtype=dtype)
dh.fill("x", 0.0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1, dtype=dtype)
dh.fill("y", 1.0, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1, dtype=dtype)
dh.fill("z", 2.0, ghost_layers=True)
config = pystencils.CreateKernelConfig(default_number_float=dtype)
# test sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3.3))
ast_1 = pystencils.create_kernel(assignment_1, config=config)
kernel_1 = ast_1.compile()
# pystencils.show_code(ast_1)
# test sp.Max with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy_function(0.5, y.center - 1.5))
ast_2 = pystencils.create_kernel(assignment_2, config=config)
kernel_2 = ast_2.compile()
# pystencils.show_code(ast_2)
# test sp.Max with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4.5, y.center - 1.5, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3, config=config)
kernel_3 = ast_3.compile()
# pystencils.show_code(ast_3)
if sympy_function is sp.Max:
results = [4.3, 0.5, 4.5]
else:
results = [4.3, -0.5, -0.5]
dh.run_kernel(kernel_1)
assert numpy.all(dh.gather_array('x') == results[0])
dh.run_kernel(kernel_2)
assert numpy.all(dh.gather_array('x') == results[1])
dh.run_kernel(kernel_3)
assert numpy.all(dh.gather_array('x') == results[2])
@pytest.mark.parametrize('dtype', ["int64", 'int32'])
@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
def test_max_integer(dtype, sympy_function):
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1, dtype=dtype)
dh.fill("x", 0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1, dtype=dtype)
dh.fill("y", 1, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1, dtype=dtype)
dh.fill("z", 2, ghost_layers=True)
config = pystencils.CreateKernelConfig(default_number_int=dtype)
# test sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3))
ast_1 = pystencils.create_kernel(assignment_1, config=config)
kernel_1 = ast_1.compile()
# pystencils.show_code(ast_1)
# test sp.Max with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy_function(1, y.center - 1))
ast_2 = pystencils.create_kernel(assignment_2, config=config)
kernel_2 = ast_2.compile()
# pystencils.show_code(ast_2)
# test sp.Max with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4, y.center - 1, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3, config=config)
kernel_3 = ast_3.compile()
# pystencils.show_code(ast_3)
if sympy_function is sp.Max:
results = [4, 1, 4]
else:
results = [4, 0, 0]
dh.run_kernel(kernel_1)
assert numpy.all(dh.gather_array('x') == results[0])
dh.run_kernel(kernel_2)
assert numpy.all(dh.gather_array('x') == results[1])
dh.run_kernel(kernel_3)
assert numpy.all(dh.gather_array('x') == results[2])
import pytest
import pystencils.config
import sympy
import pystencils as ps
from pystencils.data_types import cast_func, create_type
from pystencils.typing import CastFunc, create_type
def test_abs():
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
def test_abs(target):
x, y, z = ps.fields('x, y, z: float64[2d]')
default_int_type = create_type('int64')
assignments = ps.AssignmentCollection({
x[0, 0]: sympy.Abs(cast_func(y[0, 0], default_int_type))
})
assignments = ps.AssignmentCollection({x[0, 0]: sympy.Abs(CastFunc(y[0, 0], default_int_type))})
config = ps.CreateKernelConfig(target=ps.Target.GPU)
config = pystencils.config.CreateKernelConfig(target=target)
ast = ps.create_kernel(assignments, config=config)
code = ps.get_code_str(ast)
print(code)
......
"""
Test of pystencils.functions.AddressOf
"""
import pytest
import pystencils
from pystencils.typing import PointerType, CastFunc, BasicType
from pystencils.functions import AddressOf
from pystencils.simp.simplifications import sympy_cse
import sympy as sp
def test_address_of():
x, y = pystencils.fields('x, y: int64[2d]')
s = pystencils.TypedSymbol('s', PointerType(BasicType('int64')))
assert AddressOf(x[0, 0]).canonical() == x[0, 0]
assert AddressOf(x[0, 0]).dtype == PointerType(x[0, 0].dtype, restrict=True)
with pytest.raises(ValueError):
assert AddressOf(sp.Symbol("a")).dtype
assignments = pystencils.AssignmentCollection({
s: AddressOf(x[0, 0]),
y[0, 0]: CastFunc(s, BasicType('int64'))
})
kernel = pystencils.create_kernel(assignments).compile()
# pystencils.show_code(kernel.ast)
assignments = pystencils.AssignmentCollection({
y[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64'))
})
kernel = pystencils.create_kernel(assignments).compile()
# pystencils.show_code(kernel.ast)
def test_address_of_with_cse():
x, y = pystencils.fields('x, y: int64[2d]')
assignments = pystencils.AssignmentCollection({
x[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64')) + 1
})
kernel = pystencils.create_kernel(assignments).compile()
# pystencils.show_code(kernel.ast)
assignments_cse = sympy_cse(assignments)
kernel = pystencils.create_kernel(assignments_cse).compile()
# pystencils.show_code(kernel.ast)
......@@ -59,13 +59,13 @@ def test_alignment_of_different_layouts():
byte_offset = 8
for tries in range(16): # try a few times, since we might get lucky and get randomly a correct alignment
arr = create_numpy_array_with_layout((3, 4, 5), layout=(0, 1, 2),
alignment=True, byte_offset=byte_offset)
alignment=8*4, byte_offset=byte_offset)
assert is_aligned(arr[offset, ...], 8*4, byte_offset)
arr = create_numpy_array_with_layout((3, 4, 5), layout=(2, 1, 0),
alignment=True, byte_offset=byte_offset)
alignment=8*4, byte_offset=byte_offset)
assert is_aligned(arr[..., offset], 8*4, byte_offset)
arr = create_numpy_array_with_layout((3, 4, 5), layout=(2, 0, 1),
alignment=True, byte_offset=byte_offset)
alignment=8*4, byte_offset=byte_offset)
assert is_aligned(arr[:, 0, :], 8*4, byte_offset)
......@@ -170,3 +170,19 @@ def test_new_merged():
assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments
assert a1 in merged_ac.subexpressions
assert a3 in merged_ac.subexpressions
a1 = ps.Assignment(a, 20)
a2 = ps.Assignment(a, 10)
acommon = ps.Assignment(b, a)
# main assignments
a3 = ps.Assignment(f[0, 0](0), b)
a4 = ps.Assignment(d[0, 0](0), b)
ac = ps.AssignmentCollection([a3], subexpressions=[a1, acommon])
ac2 = ps.AssignmentCollection([a4], subexpressions=[a2, acommon])
merged_ac = ac.new_merged(ac2).new_without_subexpressions()
assert ps.Assignment(f[0, 0](0), 20) in merged_ac.main_assignments
assert ps.Assignment(d[0, 0](0), 10) in merged_ac.main_assignments
import pytest
import sys
import pystencils.config
import sympy as sp
import pystencils as ps
from pystencils import Assignment
from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.') if x.isdigit()]
if len(sympy_numeric_version) < 3:
sympy_numeric_version.append(0)
sympy_numeric_version.reverse()
sympy_version = sum(x * (100 ** i) for i, x in enumerate(sympy_numeric_version))
dst = ps.fields('dst(8): double[2D]')
s = sp.symbols('s_:8')
x = sp.symbols('x')
y = sp.symbols('y')
python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_kernel_function():
assignments = [
Assignment(dst[0, 0](0), s[0]),
......@@ -44,8 +41,6 @@ def test_skip_iteration():
assert skipped.undefined_symbols == set()
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_block():
assignments = [
Assignment(dst[0, 0](0), s[0]),
......@@ -91,21 +86,3 @@ def test_loop_over_coordinate():
assert loop.stop == 20
assert loop.step == 2
def test_sympy_assignment():
pytest.importorskip('sympy.codegen.rewriting')
from sympy.codegen.rewriting import optims_c99
assignment = SympyAssignment(dst[0, 0](0), sp.log(x + 3) / sp.log(2) + sp.log(x ** 2 + 1))
assignment.optimize(optims_c99)
ast = ps.create_kernel([assignment])
code = ps.get_code_str(ast)
assert 'log1p' in code
assert 'log2' in code
assignment.replace(assignment.lhs, dst[0, 0](1))
assignment.replace(assignment.rhs, sp.log(2))
assert assignment.lhs == dst[0, 0](1)
assert assignment.rhs == sp.log(2)
import pytest
import pystencils as ps
@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
def test_add_augmented_assignment(target):
if target == ps.Target.GPU:
pytest.importorskip("cupy")
domain_size = (5, 5)
dh = ps.create_data_handling(domain_size=domain_size, periodicity=True, default_target=target)
f = dh.add_array("f", values_per_cell=1)
dh.fill(f.name, 0.0)
g = dh.add_array("g", values_per_cell=1)
dh.fill(g.name, 1.0)
up = ps.AddAugmentedAssignment(f.center, g.center)
config = ps.CreateKernelConfig(target=dh.default_target)
ast = ps.create_kernel(up, config=config)
kernel = ast.compile()
for i in range(10):
dh.run_kernel(kernel)
if target == ps.Target.GPU:
dh.all_to_cpu()
result = dh.gather_array(f.name)
for x in range(domain_size[0]):
for y in range(domain_size[1]):
assert result[x, y] == 10
import pytest
from pystencils import Assignment, CreateKernelConfig, Target, fields, create_kernel, get_code_str
@pytest.mark.parametrize('target', (Target.CPU, Target.GPU))
def test_intermediate_base_pointer(target):
x = fields(f'x: double[3d]')
y = fields(f'y: double[3d]')
update = Assignment(x.center, y.center)
config = CreateKernelConfig(base_pointer_specification=[], target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# no intermediate base pointers are created
assert "_data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2] = " \
"_data_y[_stride_y_0*ctr_0 + _stride_y_1*ctr_1 + _stride_y_2*ctr_2];" in code
config = CreateKernelConfig(base_pointer_specification=[[0]], target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# intermediate base pointers for y and z
assert "double * RESTRICT _data_x_10_20 = _data_x + _stride_x_1*ctr_1 + _stride_x_2*ctr_2;" in code
assert " double * RESTRICT _data_y_10_20 = _data_y + _stride_y_1*ctr_1 + _stride_y_2*ctr_2;" in code
assert "_data_x_10_20[_stride_x_0*ctr_0] = _data_y_10_20[_stride_y_0*ctr_0];" in code
config = CreateKernelConfig(base_pointer_specification=[[1]], target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# intermediate base pointers for x and z
assert "double * RESTRICT _data_x_00_20 = _data_x + _stride_x_0*ctr_0 + _stride_x_2*ctr_2;" in code
assert "double * RESTRICT _data_y_00_20 = _data_y + _stride_y_0*ctr_0 + _stride_y_2*ctr_2;" in code
assert "_data_x_00_20[_stride_x_1*ctr_1] = _data_y_00_20[_stride_y_1*ctr_1];" in code
config = CreateKernelConfig(base_pointer_specification=[[2]], target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# intermediate base pointers for x and y
assert "double * RESTRICT _data_x_00_10 = _data_x + _stride_x_0*ctr_0 + _stride_x_1*ctr_1;" in code
assert "double * RESTRICT _data_y_00_10 = _data_y + _stride_y_0*ctr_0 + _stride_y_1*ctr_1;" in code
assert "_data_x_00_10[_stride_x_2*ctr_2] = _data_y_00_10[_stride_y_2*ctr_2];" in code
config = CreateKernelConfig(target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# by default no intermediate base pointers are created
assert "_data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2] = " \
"_data_y[_stride_y_0*ctr_0 + _stride_y_1*ctr_1 + _stride_y_2*ctr_2];" in code
import pytest
import numpy as np
import pystencils as ps
from pystencils import Field, Assignment, create_kernel
from pystencils.bit_masks import flag_cond
def test_flag_condition():
@pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64])
def test_flag_condition(mask_type):
f_arr = np.zeros((2, 2, 2), dtype=np.float64)
mask_arr = np.zeros((2, 2), dtype=np.uint64)
mask_arr = np.zeros((2, 2), dtype=mask_type)
mask_arr[0, 1] = (1 << 3)
mask_arr[1, 0] = (1 << 5)
......@@ -16,7 +20,7 @@ def test_flag_condition():
v1 = 42.3
v2 = 39.7
v3 = 119.87
v3 = 119
assignments = [
Assignment(f(0), flag_cond(3, mask(0), v1)),
......@@ -25,6 +29,8 @@ def test_flag_condition():
kernel = create_kernel(assignments).compile()
kernel(f=f_arr, mask=mask_arr)
code = ps.get_code_str(kernel)
assert '119.0' in code
reference = np.zeros((2, 2, 2), dtype=np.float64)
reference[0, 1, 0] = v1
......
......@@ -77,4 +77,4 @@ def test_jacobi3d_fixed_field_size():
print("Fixed Field Size: Smaller than block sizes")
arr = np.empty([3, 5, 6])
check_equivalence(jacobi(dst, src), arr)
\ No newline at end of file
check_equivalence(jacobi(dst, src), arr)
......@@ -12,8 +12,10 @@ def test_blocking_staggered():
f[0, 0, 0] - f[0, 0, -1],
]
assignments = [ps.Assignment(stag.staggered_access(d), terms[i]) for i, d in enumerate(stag.staggered_stencil)]
reference_kernel = ps.create_staggered_kernel(assignments)
print(ps.show_code(reference_kernel))
reference_kernel = reference_kernel.compile()
kernel = ps.create_staggered_kernel(assignments, cpu_blocking=(3, 16, 8)).compile()
reference_kernel = ps.create_staggered_kernel(assignments).compile()
print(ps.show_code(kernel.ast))
f_arr = np.random.rand(80, 33, 19)
......
......@@ -97,7 +97,7 @@ def test_kernel_vs_copy_boundary():
def test_boundary_gpu():
pytest.importorskip('pycuda')
pytest.importorskip('cupy')
dh = SerialDataHandling(domain_size=(7, 7), default_target=Target.GPU)
src = dh.add_array('src')
dh.fill("src", 0.0, ghost_layers=True)
......