diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index dbe3ae0eb892449efee86123de62ca94b6c28c7e..abcde4fdee7f85af337bf55fdfd0d120aa3ce767 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -326,9 +326,10 @@ nbackend-unit-tests: before_script: - pip install -e .[tests] script: - - pytest tests/nbackend + - pytest tests/nbackend tests/symbolics tags: - docker + - cuda11 doctest: stage: "Unit Tests" diff --git a/conftest.py b/conftest.py index ef075c534f6b32beafc693bb589db476083f4312..742ff7caa8a1825d1073d8342ce5cb1a70b5de46 100644 --- a/conftest.py +++ b/conftest.py @@ -7,7 +7,6 @@ import pathlib import nbformat import pytest -from nbconvert import PythonExporter # Trigger config file reading / creation once - to avoid race conditions when multiple instances are creating it # at the same time @@ -134,6 +133,7 @@ class IPyNbTest(pytest.Item): class IPyNbFile(pytest.File): def collect(self): + from nbconvert import PythonExporter exporter = PythonExporter() exporter.exclude_markdown = True exporter.exclude_input_prompt = True diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 9af45d147dd6b0c6f9b24cb7794bb8e87b67fed1..4763300a13d7ae5b363425977650f435f04fc0ec 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -5,6 +5,6 @@ API Reference .. toctree:: :maxdepth: 2 - sympyextensions/index + symbolic_language/index kernelcreation/index types diff --git a/docs/source/api/sympyextensions/astnodes.rst b/docs/source/api/symbolic_language/astnodes.rst similarity index 100% rename from docs/source/api/sympyextensions/astnodes.rst rename to docs/source/api/symbolic_language/astnodes.rst diff --git a/docs/source/api/sympyextensions/field.rst b/docs/source/api/symbolic_language/field.rst similarity index 58% rename from docs/source/api/sympyextensions/field.rst rename to docs/source/api/symbolic_language/field.rst index a435c716c7fe044d8f56cef599ea247297fb4a75..a76d39a3a10df67e853558b52752a1cdf107562e 100644 --- a/docs/source/api/sympyextensions/field.rst +++ b/docs/source/api/symbolic_language/field.rst @@ -1,6 +1,6 @@ -------------------------- -Fields (pystencils.field) -------------------------- +----------------------------- +Fields API (pystencils.field) +----------------------------- .. automodule:: pystencils.field :members: diff --git a/docs/source/api/symbolic_language/index.rst b/docs/source/api/symbolic_language/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..bf478414ba09068e3bb3d0ca6a08a95f1a113c6c --- /dev/null +++ b/docs/source/api/symbolic_language/index.rst @@ -0,0 +1,70 @@ +***************** +Symbolic Language +***************** + +.. toctree:: + :maxdepth: 1 + + field + astnodes + sympyextensions + +Pystencils allows you to define near-arbitrarily complex numerical kernels in its symbolic +language, which is based on the computer algebra system `SymPy <https://www.sympy.org>`_. +The pystencils code generator is able to parse and translate a large portion of SymPy's +symbolic expression toolkit, and furthermore extends it with its own features. +Among the supported SymPy features are: symbols, constants, arithmetic and logical expressions, +trigonometric and most transcendental functions, as well as piecewise definitions. + +Fields +====== + +The most important extension to SymPy brought by pystencils are *fields*. +Fields are a symbolic representation of multidimensional cartesian numerical arrays, +as used in many stencil algorithms. +They are represented by the `Field` class. + +Piecewise Definitions +===================== + +Pystencils can parse and translate piecewise function definitions using `sympy.Piecewise` +*only if* they have a default case. +So, for instance, + +.. code-block:: Python + + sp.Piecewise((0, x < 0), (1, x >= 0)) + +will result in an error from pystencils, while the equivalent + +.. code-block:: Python + + sp.Piecewise((0, x < 0), (1, True)) + +will be accepted. This is because pystencils cannot reason about whether or not +the given cases completely cover the entire possible input range. + +Integer Operations +================== + +Division and Remainder +---------------------- + +Care has to be taken when working with integer division operations in pystencils. +The python operators ``//`` and ``%`` work differently from their counterparts in the C family of languages. +Where in C, integer division always rounds toward zero, ``//`` performs a floor-divide (or euclidean division) +which rounds toward negative infinity. +These two operations differ whenever one of the operands is negative. +Accordingly, in Python ``a % b`` returns the *euclidean modulus*, +while C ``a % b`` computes the *remainder* of division. +The euclidean modulus is always nonnegative, while the remainder, if nonzero, always has the same sign as ``a``. + +When ``//`` and ``%`` occur in symbolic expressions given to pystencils, they are interpreted the Python-way. +This can lead to inefficient generated code, since Pythonic integer division does not map to the corresponding C +operators. +To achieve C behaviour (and efficient code), you can use `pystencils.symb.int_div` and `pystencils.symb.int_rem` +which translate to C ``/`` and ``%``, respectively. + +When expressions are translated in an integer type context, the Python ``/`` operator (or `sympy.Div`) +will also be converted to C-style ``/`` integer division. +Still, use of ``/`` for integers is discouraged, as it is designed to return a floating-point value in Python. diff --git a/docs/source/api/symbolic_language/sympyextensions.rst b/docs/source/api/symbolic_language/sympyextensions.rst new file mode 100644 index 0000000000000000000000000000000000000000..413746cd44a8d6da7645f8203e7a8938741d27b0 --- /dev/null +++ b/docs/source/api/symbolic_language/sympyextensions.rst @@ -0,0 +1,6 @@ +------------------- +Extensions to SymPy +------------------- + +.. automodule:: pystencils.symb + :members: diff --git a/docs/source/api/sympyextensions/index.rst b/docs/source/api/sympyextensions/index.rst deleted file mode 100644 index 606a36771e583f244d008483768ceae237ec8177..0000000000000000000000000000000000000000 --- a/docs/source/api/sympyextensions/index.rst +++ /dev/null @@ -1,10 +0,0 @@ -***************** -Symbolic Language -***************** - -.. toctree:: - :maxdepth: 1 - - field - astnodes - diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 5e6adfa4ff73d0a813f8668d634a27a698180aea..908f31052ed4b8d49a66cd1ce801d9841ef4fb7d 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -419,6 +419,10 @@ class PsCall(PsExpression): if not isinstance(other, PsCall): return False return super().structurally_equal(other) and self._function == other._function + + def __str__(self): + args = ", ".join(str(arg) for arg in self._args) + return f"PsCall({self._function}, ({args}))" class PsTernary(PsExpression): @@ -632,7 +636,7 @@ class PsIntDiv(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any]: - from .util import c_intdiv + from ...utils import c_intdiv return c_intdiv @@ -642,7 +646,7 @@ class PsRem(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any]: - from .util import c_rem + from ...utils import c_rem return c_rem diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index 2fdf6078d0cf285062656a739d2bdcea7736e1c2..0d3b78629fa9ee41d753893b1b6b4198cc75ae51 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -36,14 +36,3 @@ class AstEqWrapper: # TODO: consider replacing this with smth. more performant # TODO: Check that repr is implemented by all AST nodes return hash(repr(self._node)) - - -def c_intdiv(num, denom): - """C-style integer division""" - return int(num / denom) - - -def c_rem(num, denom): - """C-style integer remainder""" - div = c_intdiv(num, denom) - return num - div * denom diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index e420deaa657e0569996837c64a10187e323cf511..30b243d9cd614d9f843021dff52167297fccdbba 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -33,16 +33,25 @@ class MathFunctions(Enum): """ Exp = ("exp", 1) + Log = ("log", 1) Sin = ("sin", 1) Cos = ("cos", 1) Tan = ("tan", 1) + Sinh = ("sinh", 1) + Cosh = ("cosh", 1) + ASin = ("asin", 1) + ACos = ("acos", 1) + ATan = ("atan", 1) Abs = ("abs", 1) + Floor = ("floor", 1) + Ceil = ("ceil", 1) Min = ("min", 2) Max = ("max", 2) Pow = ("pow", 2) + ATan2 = ("atan2", 2) def __init__(self, func_name, num_args): self.function_name = func_name @@ -137,3 +146,15 @@ class PsMathFunction(PsFunction): @property def func(self) -> MathFunctions: return self._func + + def __str__(self) -> str: + return f"{self._func.function_name}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsMathFunction): + return False + + return self._func == other._func + + def __hash__(self) -> int: + return hash(self._func) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ab6261b95361c3832afce8c9203f6c523b12621a..3865db38fe603a6cf5fe4d31deef1743d4276bd6 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -7,8 +7,14 @@ import sympy.core.relational import sympy.logic.boolalg from sympy.codegen.ast import AssignmentBase, AugmentedAssignment -from ...sympyextensions import Assignment, AssignmentCollection, integer_functions +from ...sympyextensions import ( + Assignment, + AssignmentCollection, + integer_functions, + ConditionalFieldAccess, +) from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc +from ...sympyextensions.pointers import AddressOf from ...field import Field, FieldType from .context import KernelCreationContext @@ -27,10 +33,12 @@ from ..ast.expressions import ( PsBitwiseAnd, PsBitwiseOr, PsBitwiseXor, + PsAddressOf, PsCall, PsCast, PsConstantExpr, PsIntDiv, + PsRem, PsLeftShift, PsLookup, PsRightShift, @@ -350,6 +358,12 @@ class FreezeExpressions: else: return PsArrayAccess(ptr, index) + def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess): + facc = self.visit_expr(acc.access) + condition = self.visit_expr(acc.outofbounds_condition) + fallback = self.visit_expr(acc.outofbounds_value) + return PsTernary(condition, fallback, facc) + def map_Function(self, func: sp.Function) -> PsExpression: """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols. @@ -361,16 +375,36 @@ class FreezeExpressions: match func: case sp.Abs(): return PsCall(PsMathFunction(MathFunctions.Abs), args) + case sp.floor(): + return PsCall(PsMathFunction(MathFunctions.Floor), args) + case sp.ceiling(): + return PsCall(PsMathFunction(MathFunctions.Ceil), args) case sp.exp(): return PsCall(PsMathFunction(MathFunctions.Exp), args) + case sp.log(): + return PsCall(PsMathFunction(MathFunctions.Log), args) case sp.sin(): return PsCall(PsMathFunction(MathFunctions.Sin), args) case sp.cos(): return PsCall(PsMathFunction(MathFunctions.Cos), args) case sp.tan(): return PsCall(PsMathFunction(MathFunctions.Tan), args) + case sp.sinh(): + return PsCall(PsMathFunction(MathFunctions.Sinh), args) + case sp.cosh(): + return PsCall(PsMathFunction(MathFunctions.Cosh), args) + case sp.asin(): + return PsCall(PsMathFunction(MathFunctions.ASin), args) + case sp.acos(): + return PsCall(PsMathFunction(MathFunctions.ACos), args) + case sp.atan(): + return PsCall(PsMathFunction(MathFunctions.ATan), args) + case sp.atan2(): + return PsCall(PsMathFunction(MathFunctions.ATan2), args) case integer_functions.int_div(): return PsIntDiv(*args) + case integer_functions.int_rem(): + return PsRem(*args) case integer_functions.bit_shift_left(): return PsLeftShift(*args) case integer_functions.bit_shift_right(): @@ -389,6 +423,8 @@ class FreezeExpressions: # TODO: requires if *expression* # case integer_functions.modulo_ceil(): # case integer_functions.div_ceil(): + case AddressOf(): + return PsAddressOf(*args) case _: raise FreezeError(f"Unsupported function: {func}") @@ -414,12 +450,19 @@ class FreezeExpressions: return ternary def map_Min(self, expr: sp.Min) -> PsCall: - args = tuple(self.visit_expr(arg) for arg in expr.args) - return PsCall(PsMathFunction(MathFunctions.Min), args) + return self._minmax(expr, PsMathFunction(MathFunctions.Min)) def map_Max(self, expr: sp.Max) -> PsCall: - args = tuple(self.visit_expr(arg) for arg in expr.args) - return PsCall(PsMathFunction(MathFunctions.Max), args) + return self._minmax(expr, PsMathFunction(MathFunctions.Max)) + + def _minmax(self, expr: sp.Min | sp.Max, func: PsMathFunction) -> PsCall: + args = [self.visit_expr(arg) for arg in expr.args] + while len(args) > 1: + args = [ + (PsCall(func, (args[i], args[i + 1])) if i + 1 < len(args) else args[i]) + for i in range(0, len(args), 2) + ] + return cast(PsCall, args[0]) def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr)) @@ -443,12 +486,12 @@ class FreezeExpressions: raise FreezeError(f"Unsupported relation: {other}") def map_And(self, conj: sympy.logic.And) -> PsAnd: - arg1, arg2 = [self.visit_expr(arg) for arg in conj.args] - return PsAnd(arg1, arg2) + args = [self.visit_expr(arg) for arg in conj.args] + return reduce(PsAnd, args) # type: ignore def map_Or(self, disj: sympy.logic.Or) -> PsOr: - arg1, arg2 = [self.visit_expr(arg) for arg in disj.args] - return PsOr(arg1, arg2) + args = [self.visit_expr(arg) for arg in disj.args] + return reduce(PsOr, args) # type: ignore def map_Not(self, neg: sympy.logic.Not) -> PsNot: arg = self.visit_expr(neg.args[0]) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 42885bedc9272469b8bfac151358cdaedf4e880f..8ef6edd24d54aa9a9992c04adf51abe620ab4813 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -13,6 +13,7 @@ from ...types import ( PsPointerType, PsBoolType, constify, + deconstify, ) from ..ast.structural import ( PsAstNode, @@ -413,6 +414,11 @@ class Typifier: tc.apply_dtype(ptr_tc.target_type.base_type, expr) case PsAddressOf(arg): + if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsDeref, PsLookup)): + raise TypificationError( + f"Illegal expression below AddressOf operator: {arg}" + ) + arg_tc = TypeContext() self.visit_expr(arg, arg_tc) @@ -421,7 +427,29 @@ class Typifier: f"Unable to determine type of argument to AddressOf: {arg}" ) - ptr_type = PsPointerType(arg_tc.target_type, const=True) + # Inherit pointed-to type from referenced object, not from the subexpression + match arg: + case PsSymbolExpr(s): + pointed_to_type = s.get_dtype() + case PsSubscript(arr, _) | PsDeref(arr): + arr_type = arr.get_dtype() + assert isinstance(arr_type, PsDereferencableType) + pointed_to_type = arr_type.base_type + case PsLookup(aggr, member_name): + struct_type = aggr.get_dtype() + assert isinstance(struct_type, PsStructType) + if struct_type.const: + pointed_to_type = constify( + struct_type.get_member(member_name).dtype + ) + else: + pointed_to_type = deconstify( + struct_type.get_member(member_name).dtype + ) + case _: + assert False, "unreachable code" + + ptr_type = PsPointerType(pointed_to_type, const=True) tc.apply_dtype(ptr_type, expr) case PsLookup(aggr, member_name): @@ -438,7 +466,7 @@ class Typifier: member = aggr_type.find_member(member_name) if member is None: raise TypificationError( - f"Aggregate of type {aggr_type} does not have a member {member}." + f"Aggregate of type {aggr_type} does not have a member {member_name}." ) member_type = member.dtype diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 17589bf27d3109a3ce891acfdc248e0120da2f77..25228ba8fba81f84d45844ea23bd8813fd15ce73 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -1,8 +1,10 @@ from typing import Sequence from abc import ABC, abstractmethod +from pystencils.backend.ast.expressions import PsCall + from ..functions import CFunction, PsMathFunction, MathFunctions -from ...types import PsType, PsIeeeFloatType +from ...types import PsIntegerType, PsIeeeFloatType from .platform import Platform from ..exceptions import MaterializationError @@ -22,6 +24,9 @@ from ..ast.expressions import ( PsArrayAccess, PsVectorArrayAccess, PsLookup, + PsGe, + PsLe, + PsTernary ) from ...types import PsVectorType, PsCustomType from ..transformations.select_intrinsics import IntrinsicOps @@ -51,26 +56,54 @@ class GenericCpu(Platform): else: assert False, "unreachable code" - def select_function( - self, math_function: PsMathFunction, dtype: PsType - ) -> CFunction: - func = math_function.func + def select_function(self, call: PsCall) -> PsExpression: + assert isinstance(call.function, PsMathFunction) + + func = call.function.func + dtype = call.get_dtype() arg_types = (dtype,) * func.num_args + if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64): + cfunc: CFunction match func: case ( MathFunctions.Exp + | MathFunctions.Log | MathFunctions.Sin | MathFunctions.Cos | MathFunctions.Tan + | MathFunctions.Sinh + | MathFunctions.Cosh + | MathFunctions.ASin + | MathFunctions.ACos + | MathFunctions.ATan + | MathFunctions.ATan2 | MathFunctions.Pow + | MathFunctions.Floor + | MathFunctions.Ceil ): - return CFunction(func.function_name, arg_types, dtype) + cfunc = CFunction(func.function_name, arg_types, dtype) case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: - return CFunction("f" + func.function_name, arg_types, dtype) + cfunc = CFunction("f" + func.function_name, arg_types, dtype) + + call.function = cfunc + return call + + if isinstance(dtype, PsIntegerType): + match func: + case MathFunctions.Abs: + zero = PsExpression.make(PsConstant(0, dtype)) + arg = call.args[0] + return PsTernary(PsGe(arg, zero), arg, - arg) + case MathFunctions.Min: + arg1, arg2 = call.args + return PsTernary(PsLe(arg1, arg2), arg1, arg2) + case MathFunctions.Max: + arg1, arg2 = call.args + return PsTernary(PsGe(arg1, arg2), arg1, arg2) raise MaterializationError( - f"No implementation available for function {math_function} on data type {dtype}" + f"No implementation available for function {func} on data type {dtype}" ) # Internals diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 6e860d32b07de8d7993ee96b103ce2d0983766d8..1403b8f5ca1329812749bcc9d9bd50d4fcf4ac98 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -1,5 +1,3 @@ -from pystencils.backend.functions import CFunction, PsMathFunction -from pystencils.types.types import PsType from .platform import Platform from ..kernelcreation.iteration_space import ( @@ -13,6 +11,7 @@ from ..ast.expressions import ( PsExpression, PsLiteralExpr, PsAdd, + PsCall ) from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType @@ -57,9 +56,7 @@ class GenericGpu(Platform): return indices[:dim] - def select_function( - self, math_function: PsMathFunction, dtype: PsType - ) -> CFunction: + def select_function(self, call: PsCall) -> PsExpression: raise NotImplementedError() # Internals diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 2c718ae5fd329197c3a67a26d51c9a737f63271f..3f8912e81c6f42ba776dfd5e9cd7d8895f93ae4b 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from ..ast.structural import PsBlock -from ..functions import PsMathFunction, CFunction -from ...types import PsType +from ..ast.expressions import PsCall, PsExpression from ..kernelcreation.context import KernelCreationContext from ..kernelcreation.iteration_space import IterationSpace @@ -33,8 +32,8 @@ class Platform(ABC): @abstractmethod def select_function( - self, math_function: PsMathFunction, dtype: PsType - ) -> CFunction: + self, call: PsCall + ) -> PsExpression: """Select an implementation for the given function on the given data type. If no viable implementation exists, raise a `MaterializationError`. diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 15a6d5c5c48f5e27f7bc65880f0ad1eb851b4825..bd3b2bb5871b59dd1e0431ed3d5ee7d26c4ff7a6 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -316,7 +316,7 @@ class EliminateConstants: ) elif isinstance(expr, PsDiv): if is_int: - from ..ast.util import c_intdiv + from ...utils import c_intdiv folded = PsConstant(c_intdiv(v1, v2), dtype) elif isinstance(dtype, PsIeeeFloatType): folded = PsConstant(v1 / v2, dtype) diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py index c4085b2bc913a885cdd9f6a42de19a8f2c2ab404..e41c345ae4ed71101d07fcaa5b9df88b1e0f54e2 100644 --- a/src/pystencils/backend/transformations/select_functions.py +++ b/src/pystencils/backend/transformations/select_functions.py @@ -12,13 +12,12 @@ class SelectFunctions: self._platform = platform def __call__(self, node: PsAstNode) -> PsAstNode: - self.visit(node) - return node + return self.visit(node) - def visit(self, node: PsAstNode): - for c in node.children: - self.visit(c) + def visit(self, node: PsAstNode) -> PsAstNode: + node.children = [self.visit(c) for c in node.children] if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction): - impl = self._platform.select_function(node.function, node.get_dtype()) - node.function = impl + return self._platform.select_function(node) + else: + return node diff --git a/src/pystencils/config.py b/src/pystencils/config.py index 7c49c4c37c0257adde151c3c32680faa50e9a36a..fe0e87900800c019d042b0994211bfb9cdd99b15 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -10,7 +10,7 @@ from .field import Field, FieldType from .backend.jit import JitBase from .backend.exceptions import PsOptionsError -from .types import PsIntegerType, PsNumericType, PsIeeeFloatType +from .types import PsIntegerType, UserTypeSpec, PsIeeeFloatType from .defaults import DEFAULTS @@ -169,7 +169,7 @@ class CreateKernelConfig: index_dtype: PsIntegerType = DEFAULTS.index_dtype """Data type used for all index calculations.""" - default_dtype: PsNumericType = PsIeeeFloatType(64) + default_dtype: UserTypeSpec = PsIeeeFloatType(64) """Default numeric data type. This data type will be applied to all untyped symbols. diff --git a/src/pystencils/kernel_wrapper.py b/src/pystencils/kernel_wrapper.py index d5dfbecca5ae25fe926c0cf5e6753e7bd0cca5cb..fba94abd5b67ddfa5d196b24cabd1c6633c9253a 100644 --- a/src/pystencils/kernel_wrapper.py +++ b/src/pystencils/kernel_wrapper.py @@ -16,6 +16,10 @@ class KernelWrapper: def __call__(self, **kwargs): return self.kernel(**kwargs) + + @property + def target(self) -> pystencils.Target: + return self.ast.target @property def code(self): diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 66c2a0d6c16e291ba5f6315478406668e7e91069..3cda5aa46313d46251ef9c73c6348e2f65c1af54 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -2,6 +2,7 @@ from typing import cast from .enums import Target from .config import CreateKernelConfig +from .types import create_numeric_type from .backend import ( KernelFunction, KernelParameter, @@ -53,7 +54,7 @@ def create_kernel( """ ctx = KernelCreationContext( - default_dtype=config.default_dtype, index_dtype=config.index_dtype + default_dtype=create_numeric_type(config.default_dtype), index_dtype=config.index_dtype ) if isinstance(assignments, Assignment): diff --git a/src/pystencils/symb.py b/src/pystencils/symb.py new file mode 100644 index 0000000000000000000000000000000000000000..0c682b26113c70ca2304bc63a15a6aa7e8d8ad9f --- /dev/null +++ b/src/pystencils/symb.py @@ -0,0 +1,23 @@ +"""pystencils extensions to the SymPy symbolic language.""" + +from .sympyextensions.integer_functions import ( + bitwise_and, + bitwise_or, + bitwise_xor, + bit_shift_left, + bit_shift_right, + int_div, + int_rem, + int_power_of_2, +) + +__all__ = [ + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "bit_shift_left", + "bit_shift_right", + "int_div", + "int_rem", + "int_power_of_2", +] diff --git a/src/pystencils/sympyextensions/integer_functions.py b/src/pystencils/sympyextensions/integer_functions.py index c3dd181083ecbfecf5dc6d445f5fa1d0ac45f601..3b215266eac147ea3082a0536e180728201d6b3e 100644 --- a/src/pystencils/sympyextensions/integer_functions.py +++ b/src/pystencils/sympyextensions/integer_functions.py @@ -45,9 +45,19 @@ class bitwise_or(IntegerFunctionTwoArgsMixIn): # noinspection PyPep8Naming class int_div(IntegerFunctionTwoArgsMixIn): + """C-style round-to-zero integer division""" def _eval_op(self, arg1, arg2): - return int(arg1 // arg2) + from ..utils import c_intdiv + return c_intdiv(arg1, arg2) + + +class int_rem(IntegerFunctionTwoArgsMixIn): + """C-style round-to-zero integer remainder""" + + def _eval_op(self, arg1, arg2): + from ..utils import c_rem + return c_rem(arg1, arg2) # noinspection PyPep8Naming diff --git a/src/pystencils/sympyextensions/pointers.py b/src/pystencils/sympyextensions/pointers.py index a814f941e0a2968be7fedfbb82bff612ae8f1d1a..c69f9376dd31c8e0f7976c750e9af17308b7991e 100644 --- a/src/pystencils/sympyextensions/pointers.py +++ b/src/pystencils/sympyextensions/pointers.py @@ -1,5 +1,5 @@ import sympy as sp -from ..types import PsPointerType +from ..types import PsPointerType, PsType class AddressOf(sp.Function): @@ -25,7 +25,9 @@ class AddressOf(sp.Function): @property def dtype(self): - if hasattr(self.args[0], 'dtype'): - return PsPointerType(self.args[0].dtype, restrict=True, const=True) + arg_type = getattr(self.args[0], 'dtype', None) + if arg_type is not None: + assert isinstance(arg_type, PsType) + return PsPointerType(arg_type, restrict=True, const=True) else: raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}') diff --git a/src/pystencils/utils.py b/src/pystencils/utils.py index f872ae48a54a2ae8c9437ec826be9b51c061f52e..98331e7e5f561405cee99c3422b9f232747ecfcb 100644 --- a/src/pystencils/utils.py +++ b/src/pystencils/utils.py @@ -250,3 +250,14 @@ class ContextVar: def get(self): return self.stack[-1] + + +def c_intdiv(num, denom): + """C-style integer division""" + return int(num / denom) + + +def c_rem(num, denom): + """C-style integer remainder""" + div = c_intdiv(num, denom) + return num - div * denom diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 1593be684ca48ec24b730168feaadb3331d4d7dd..b22df7d0bd132cc530e289b630f9c48851e4996b 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -25,9 +25,11 @@ from pystencils.backend.ast.expressions import ( PsLt, PsLe, PsGt, - PsGe + PsGe, + PsCall ) from pystencils.backend.constants import PsConstant +from pystencils.backend.functions import PsMathFunction, MathFunctions from pystencils.backend.kernelcreation import ( KernelCreationContext, FreezeExpressions, @@ -163,14 +165,21 @@ def test_freeze_booleans(): x2 = PsExpression.make(ctx.get_symbol("x")) y2 = PsExpression.make(ctx.get_symbol("y")) z2 = PsExpression.make(ctx.get_symbol("z")) + w2 = PsExpression.make(ctx.get_symbol("w")) - x, y, z = sp.symbols("x, y, z") + x, y, z, w = sp.symbols("x, y, z, w") + + expr = freeze(sp.Not(sp.And(x, y))) + assert expr.structurally_equal(PsNot(PsAnd(x2, y2))) + + expr = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x)))) + assert expr.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2)))) - expr1 = freeze(sp.Not(sp.And(x, y))) - assert expr1.structurally_equal(PsNot(PsAnd(x2, y2))) + expr = freeze(sp.And(w, x, y, z)) + assert expr.structurally_equal(PsAnd(PsAnd(PsAnd(w2, x2), y2), z2)) - expr2 = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x)))) - assert expr2.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2)))) + expr = freeze(sp.Or(w, x, y, z)) + assert expr.structurally_equal(PsOr(PsOr(PsOr(w2, x2), y2), z2)) @pytest.mark.parametrize("rel_pair", [ @@ -220,3 +229,33 @@ def test_freeze_piecewise(): piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q))) with pytest.raises(FreezeError): freeze(piecewise) + + +def test_multiarg_min_max(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + w, x, y, z = sp.symbols("w, x, y, z") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + w2 = PsExpression.make(ctx.get_symbol("w")) + + def op(a, b): + return PsCall(PsMathFunction(MathFunctions.Min), (a, b)) + + expr = freeze(sp.Min(w, x, y)) + assert expr.structurally_equal(op(op(w2, x2), y2)) + + expr = freeze(sp.Min(w, x, y, z)) + assert expr.structurally_equal(op(op(w2, x2), op(y2, z2))) + + def op(a, b): + return PsCall(PsMathFunction(MathFunctions.Max), (a, b)) + + expr = freeze(sp.Max(w, x, y)) + assert expr.structurally_equal(op(op(w2, x2), y2)) + + expr = freeze(sp.Max(w, x, y, z)) + assert expr.structurally_equal(op(op(w2, x2), op(y2, z2))) diff --git a/tests/nbackend/test_functions.py b/tests/nbackend/test_functions.py index e88e51e48244f557d19e8c3ebbc5129d335c40a5..c14e118a088a8ffb18b9444f8481f266b85f03e3 100644 --- a/tests/nbackend/test_functions.py +++ b/tests/nbackend/test_functions.py @@ -6,16 +6,25 @@ from pystencils import create_kernel, CreateKernelConfig, Target, Assignment, Fi UNARY_FUNCTIONS = { "exp": (sp.exp, np.exp), + "log": (sp.log, np.log), "sin": (sp.sin, np.sin), "cos": (sp.cos, np.cos), "tan": (sp.tan, np.tan), + "sinh": (sp.sinh, np.sinh), + "cosh": (sp.cosh, np.cosh), + "asin": (sp.asin, np.arcsin), + "acos": (sp.acos, np.arccos), + "atan": (sp.atan, np.arctan), "abs": (sp.Abs, np.abs), + "floor": (sp.floor, np.floor), + "ceil": (sp.ceiling, np.ceil), } BINARY_FUNCTIONS = { "min": (sp.Min, np.fmin), "max": (sp.Max, np.fmax), "pow": (sp.Pow, np.power), + "atan2": (sp.atan2, np.arctan2), } diff --git a/tests/test_Min_Max.py b/tests/symbolics/test_Min_Max.py similarity index 100% rename from tests/test_Min_Max.py rename to tests/symbolics/test_Min_Max.py diff --git a/tests/symbolics/test_abs.py b/tests/symbolics/test_abs.py new file mode 100644 index 0000000000000000000000000000000000000000..daa4b17c17c51909c3124079677dcfa1eb87cb33 --- /dev/null +++ b/tests/symbolics/test_abs.py @@ -0,0 +1,21 @@ +import pytest + +import pystencils as ps +import sympy + + +@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) +def test_abs(target): + if target == ps.Target.GPU: + # FIXME + pytest.skip("GPU target not ready yet") + + x, y, z = ps.fields('x, y, z: int64[2d]') + + assignments = ps.AssignmentCollection({x[0, 0]: sympy.Abs(y[0, 0])}) + + config = ps.CreateKernelConfig(target=target) + ast = ps.create_kernel(assignments, config=config) + code = ps.get_code_str(ast) + print(code) + assert 'fabs(' not in code diff --git a/tests/test_address_of.py b/tests/symbolics/test_address_of.py similarity index 53% rename from tests/test_address_of.py rename to tests/symbolics/test_address_of.py index 0d25be62c145032ee8e3c298da45543ff835d63d..da11ecbe5374b95801f2de027b4db4df9e2fa04d 100644 --- a/tests/test_address_of.py +++ b/tests/symbolics/test_address_of.py @@ -3,35 +3,37 @@ Test of pystencils.data_types.address_of """ import pytest import pystencils -from pystencils.typing import PointerType, CastFunc, BasicType +from pystencils.types import PsPointerType, create_type from pystencils.sympyextensions.pointers import AddressOf -from pystencils.simp.simplifications import sympy_cse +from pystencils.sympyextensions.typed_sympy import CastFunc +from pystencils.sympyextensions import sympy_cse import sympy as sp def test_address_of(): x, y = pystencils.fields('x, y: int64[2d]') - s = pystencils.TypedSymbol('s', PointerType(BasicType('int64'))) + s = pystencils.TypedSymbol('s', PsPointerType(create_type('int64'))) assert AddressOf(x[0, 0]).canonical() == x[0, 0] - assert AddressOf(x[0, 0]).dtype == PointerType(x[0, 0].dtype, restrict=True) + assert AddressOf(x[0, 0]).dtype == PsPointerType(x[0, 0].dtype, restrict=True, const=True) + with pytest.raises(ValueError): assert AddressOf(sp.Symbol("a")).dtype assignments = pystencils.AssignmentCollection({ s: AddressOf(x[0, 0]), - y[0, 0]: CastFunc(s, BasicType('int64')) + y[0, 0]: CastFunc(s, create_type('int64')) }) - kernel = pystencils.create_kernel(assignments).compile() + _ = pystencils.create_kernel(assignments).compile() # pystencils.show_code(kernel.ast) assignments = pystencils.AssignmentCollection({ - y[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64')) + y[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) }) - kernel = pystencils.create_kernel(assignments).compile() + _ = pystencils.create_kernel(assignments).compile() # pystencils.show_code(kernel.ast) @@ -39,12 +41,12 @@ def test_address_of_with_cse(): x, y = pystencils.fields('x, y: int64[2d]') assignments = pystencils.AssignmentCollection({ - x[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64')) + 1 + x[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) + 1 }) - kernel = pystencils.create_kernel(assignments).compile() + _ = pystencils.create_kernel(assignments).compile() # pystencils.show_code(kernel.ast) assignments_cse = sympy_cse(assignments) - kernel = pystencils.create_kernel(assignments_cse).compile() + _ = pystencils.create_kernel(assignments_cse).compile() # pystencils.show_code(kernel.ast) diff --git a/tests/test_conditional_field_access.py b/tests/symbolics/test_conditional_field_access.py similarity index 93% rename from tests/test_conditional_field_access.py rename to tests/symbolics/test_conditional_field_access.py index 1e120304bdb8bf812dbd4719f11327ef69a60e79..bd384a95948511ede2d65222b69a81479c717a30 100644 --- a/tests/test_conditional_field_access.py +++ b/tests/symbolics/test_conditional_field_access.py @@ -51,6 +51,9 @@ def add_fixed_constant_boundary_handling(assignments, with_cse): @pytest.mark.parametrize('dtype', ('float64', 'float32')) @pytest.mark.parametrize('with_cse', (False, 'with_cse')) def test_boundary_check(dtype, with_cse): + if with_cse: + pytest.xfail("Doesn't typify correctly yet.") + f, g = ps.fields(f"f, g : {dtype}[2D]") stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4) @@ -59,7 +62,7 @@ def test_boundary_check(dtype, with_cse): assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse) - config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, ghost_layers=0) + config = ps.CreateKernelConfig(default_dtype=ps.create_type(dtype), ghost_layers=0) kernel_checked = ps.create_kernel(assignments, config=config).compile() # ps.show_code(kernel_checked) diff --git a/tests/test_abs.py b/tests/test_abs.py deleted file mode 100644 index 277cf4f5c4a39598aafbded82a267e6619c15bee..0000000000000000000000000000000000000000 --- a/tests/test_abs.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -import pystencils.config -import sympy - -import pystencils as ps -from pystencils.typing import CastFunc, create_type - - -@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) -def test_abs(target): - x, y, z = ps.fields('x, y, z: float64[2d]') - - default_int_type = create_type('int64') - - assignments = ps.AssignmentCollection({x[0, 0]: sympy.Abs(CastFunc(y[0, 0], default_int_type))}) - - config = pystencils.config.CreateKernelConfig(target=target) - ast = ps.create_kernel(assignments, config=config) - code = ps.get_code_str(ast) - print(code) - assert 'fabs(' not in code