diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index df194bde9b9c50cc9ee3f1064ff5b5361205c227..e0e914b4d423fb5b9e32950185c6aa3474976d39 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -9,7 +9,7 @@ who wish to customize or extend the behaviour of the code generator in their app .. toctree:: :maxdepth: 1 - symbols + objects ast iteration_space translation diff --git a/docs/source/backend/symbols.rst b/docs/source/backend/objects.rst similarity index 80% rename from docs/source/backend/symbols.rst rename to docs/source/backend/objects.rst index 66c8c43ba63c7740f033e7409cb5fc6f6be9bc07..b0c3af6db67ff3cfb1e6a3d3603e84e6c4abb6cb 100644 --- a/docs/source/backend/symbols.rst +++ b/docs/source/backend/objects.rst @@ -8,5 +8,8 @@ Symbols, Constants and Arrays .. autoclass:: pystencils.backend.constants.PsConstant :members: +.. autoclass:: pystencils.backend.literals.PsLiteral + :members: + .. automodule:: pystencils.backend.arrays :members: diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 0666d96873d4bdd3d722a7912b6e704b4aee1cf8..7bcf62b973d8ace8e9ad9847ae165c398f1cbb0e 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -5,6 +5,7 @@ import operator from ..symbols import PsSymbol from ..constants import PsConstant +from ..literals import PsLiteral from ..arrays import PsLinearizedArray, PsArrayBasePointer from ..functions import PsFunction from ...types import ( @@ -76,12 +77,19 @@ class PsExpression(PsAstNode, ABC): def make(obj: PsConstant) -> PsConstantExpr: pass + @overload + @staticmethod + def make(obj: PsLiteral) -> PsLiteralExpr: + pass + @staticmethod - def make(obj: PsSymbol | PsConstant) -> PsSymbolExpr | PsConstantExpr: + def make(obj: PsSymbol | PsConstant | PsLiteral) -> PsExpression: if isinstance(obj, PsSymbol): return PsSymbolExpr(obj) elif isinstance(obj, PsConstant): return PsConstantExpr(obj) + elif isinstance(obj, PsLiteral): + return PsLiteralExpr(obj) else: raise ValueError(f"Cannot make expression out of {obj}") @@ -150,6 +158,34 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def __repr__(self) -> str: return f"PsConstantExpr({repr(self._constant)})" + + +class PsLiteralExpr(PsLeafMixIn, PsExpression): + __match_args__ = ("literal",) + + def __init__(self, literal: PsLiteral): + super().__init__(literal.dtype) + self._literal = literal + + @property + def literal(self) -> PsLiteral: + return self._literal + + @literal.setter + def literal(self, lit: PsLiteral): + self._literal = lit + + def clone(self) -> PsLiteralExpr: + return PsLiteralExpr(self._literal) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, PsLiteralExpr): + return False + + return self._literal == other._literal + + def __repr__(self) -> str: + return f"PsLiteralExpr({repr(self._literal)})" class PsSubscript(PsLvalue, PsExpression): diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 588ac410a6118b668eacd08114fdea3c7853ba6f..f3d56c6c4c20e5969ee10d08ee42b6803a2e0b1c 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -34,6 +34,7 @@ from .ast.expressions import ( PsSub, PsSubscript, PsSymbolExpr, + PsLiteralExpr, PsVectorArrayAccess, PsAnd, PsOr, @@ -245,6 +246,9 @@ class CAstPrinter: ) return dtype.create_literal(constant.value) + + case PsLiteralExpr(lit): + return lit.text case PsVectorArrayAccess(): raise EmissionError("Cannot print vectorized array accesses") diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 313b622beaa151d294b8bcf6c66d730830ce2497..e420deaa657e0569996837c64a10187e323cf511 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -15,10 +15,13 @@ TODO: Figure out the best way to describe function signatures and overloads for """ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING from abc import ABC from enum import Enum +from ..types import PsType +from .exceptions import PsInternalCompilerError + if TYPE_CHECKING: from .ast.expressions import PsExpression @@ -69,19 +72,59 @@ class PsFunction(ABC): class CFunction(PsFunction): - """A concrete C function.""" + """A concrete C function. + + Instances of this class represent a C function by its name, parameter types, and return type. + + Args: + name: Function name + param_types: Types of the function parameters + return_type: The function's return type + """ + + __match_args__ = ("name", "parameter_types", "return_type") + + @staticmethod + def parse(obj) -> CFunction: + """Parse the signature of a Python callable object to obtain a CFunction object. + + The callable must be fully annotated with type-like objects convertible by `create_type`. + """ + import inspect + from pystencils.types import create_type - def __init__(self, qualified_name: str, arg_count: int): - self._qname = qualified_name - self._arg_count = arg_count + if not inspect.isfunction(obj): + raise PsInternalCompilerError(f"Cannot parse object {obj} as a function") + + func_sig = inspect.signature(obj) + func_name = obj.__name__ + arg_types = [ + create_type(param.annotation) for param in func_sig.parameters.values() + ] + ret_type = create_type(func_sig.return_annotation) + + return CFunction(func_name, arg_types, ret_type) + + def __init__(self, name: str, param_types: Sequence[PsType], return_type: PsType): + super().__init__(name, len(param_types)) + + self._param_types = tuple(param_types) + self._return_type = return_type @property - def qualified_name(self) -> str: - return self._qname + def parameter_types(self) -> tuple[PsType, ...]: + return self._param_types @property - def arg_count(self) -> int: - return self._arg_count + def return_type(self) -> PsType: + return self._return_type + + def __str__(self) -> str: + param_types = ", ".join(str(t) for t in self._param_types) + return f"{self._return_type} {self._name}({param_types})" + + def __repr__(self) -> str: + return f"CFunction({self._name}, {self._param_types}, {self._return_type})" class PsMathFunction(PsFunction): diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index 83c406b0a99d52cee9599f321d6c32477f6dbf8a..d5695be93a7a89134f3bc7a12f623006768579d1 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -37,6 +37,10 @@ class AstFactory: self._freeze = FreezeExpressions(ctx) self._typify = Typifier(ctx) + @overload + def parse_sympy(self, sp_obj: sp.Symbol) -> PsSymbolExpr: + pass + @overload def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression: pass diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 1bf3c49807ff52a34bc9ab319f5da67e4fa59ebc..dbec20235f0a37cfab763771e2e2fbed05a3c196 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -39,11 +39,12 @@ from ..ast.expressions import ( PsLookup, PsSubscript, PsSymbolExpr, + PsLiteralExpr, PsRel, PsNeg, PsNot, ) -from ..functions import PsMathFunction +from ..functions import PsMathFunction, CFunction __all__ = ["Typifier"] @@ -158,6 +159,14 @@ class TypeContext: f" Constant type: {c.dtype}\n" f" Target type: {self._target_type}" ) + + case PsLiteralExpr(lit): + if not self._compatible(lit.dtype): + raise TypificationError( + f"Type mismatch at literal {lit}: Literal type did not match the context's target type\n" + f" Literal type: {lit.dtype}\n" + f" Target type: {self._target_type}" + ) case PsSymbolExpr(symb): assert symb.dtype is not None @@ -356,6 +365,9 @@ class Typifier: else: tc.infer_dtype(expr) + case PsLiteralExpr(lit): + tc.apply_dtype(lit.dtype, expr) + case PsArrayAccess(bptr, idx): tc.apply_dtype(bptr.array.element_type, expr) @@ -467,6 +479,14 @@ class Typifier: for arg in args: self.visit_expr(arg, tc) tc.infer_dtype(expr) + + case CFunction(_, arg_types, ret_type): + tc.apply_dtype(ret_type, expr) + + for arg, arg_type in zip(args, arg_types, strict=True): + arg_tc = TypeContext(arg_type) + self.visit_expr(arg, arg_tc) + case _: raise TypificationError( f"Don't know how to typify calls to {function}" @@ -494,6 +514,7 @@ class Typifier: ) else: items_tc.apply_dtype(tc.target_type.base_type) + tc.infer_dtype(expr) else: arr_type = PsArrayType(items_tc.target_type, len(items)) tc.apply_dtype(arr_type, expr) diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7504f520f8950b46df76b0359aaad371244b19 --- /dev/null +++ b/src/pystencils/backend/literals.py @@ -0,0 +1,43 @@ +from __future__ import annotations +from ..types import PsType, constify + + +class PsLiteral: + """Representation of literal code. + + Instances of this class represent code literals inside the AST. + These literals are not to be confused with C literals; the name `Literal` refers to the fact that + the code generator takes them "literally", printing them as they are. + + Each literal has to be annotated with a type, and is considered constant within the scope of a kernel. + Instances of `PsLiteral` are immutable. + """ + + __match_args__ = ("text", "dtype") + + def __init__(self, text: str, dtype: PsType) -> None: + self._text = text + self._dtype = constify(dtype) + + @property + def text(self) -> str: + return self._text + + @property + def dtype(self) -> PsType: + return self._dtype + + def __str__(self) -> str: + return f"{self._text}: {self._dtype}" + + def __repr__(self) -> str: + return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsLiteral): + return False + + return self._text == other._text and self._dtype == other._dtype + + def __hash__(self) -> int: + return hash((PsLiteral, self._text, self._dtype)) diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 6899ac9474d303623f79e2bdc7c3765c64380a6c..17589bf27d3109a3ce891acfdc248e0120da2f77 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -55,6 +55,7 @@ class GenericCpu(Platform): self, math_function: PsMathFunction, dtype: PsType ) -> CFunction: func = math_function.func + arg_types = (dtype,) * func.num_args if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64): match func: case ( @@ -64,9 +65,9 @@ class GenericCpu(Platform): | MathFunctions.Tan | MathFunctions.Pow ): - return CFunction(func.function_name, func.num_args) + return CFunction(func.function_name, arg_types, dtype) case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: - return CFunction("f" + func.function_name, func.num_args) + return CFunction("f" + func.function_name, arg_types, dtype) raise MaterializationError( f"No implementation available for function {math_function} on data type {dtype}" diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 839cd34f4c9fada060a0ac6253b635f7c7812948..6e860d32b07de8d7993ee96b103ce2d0983766d8 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -11,26 +11,26 @@ from ..kernelcreation.iteration_space import ( from ..ast.structural import PsBlock, PsConditional from ..ast.expressions import ( PsExpression, - PsSymbolExpr, + PsLiteralExpr, PsAdd, ) from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType -from ..symbols import PsSymbol +from ..literals import PsLiteral int32 = PsSignedIntegerType(width=32, const=False) BLOCK_IDX = [ - PsSymbolExpr(PsSymbol(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") ] THREAD_IDX = [ - PsSymbolExpr(PsSymbol(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") ] BLOCK_DIM = [ - PsSymbolExpr(PsSymbol(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") ] GRID_DIM = [ - PsSymbolExpr(PsSymbol(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") ] diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index fa5af4655810943c47f503329a8c41ce3baa36c5..ccaf9fbe99f46ce4b0ecbb81c775c9f274678026 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -10,7 +10,7 @@ from ..ast.expressions import ( PsSubscript, ) from ..transformations.select_intrinsics import IntrinsicOps -from ...types import PsCustomType, PsVectorType +from ...types import PsCustomType, PsVectorType, PsPointerType from ..constants import PsConstant from ..exceptions import MaterializationError @@ -124,10 +124,13 @@ class X86VectorCpu(GenericVectorCpu): def constant_vector(self, c: PsConstant) -> PsExpression: vtype = c.dtype assert isinstance(vtype, PsVectorType) + stype = vtype.scalar_type prefix = self._vector_arch.intrin_prefix(vtype) suffix = self._vector_arch.intrin_suffix(vtype) - set_func = CFunction(f"{prefix}_set_{suffix}", vtype.vector_entries) + set_func = CFunction( + f"{prefix}_set_{suffix}", (stype,) * vtype.vector_entries, vtype + ) values = c.value return set_func(*values) @@ -164,7 +167,10 @@ def _x86_packed_load( ) -> CFunction: prefix = varch.intrin_prefix(vtype) suffix = varch.intrin_suffix(vtype) - return CFunction(f"{prefix}_load{'' if aligned else 'u'}_{suffix}", 1) + ptr_type = PsPointerType(vtype.scalar_type, const=True) + return CFunction( + f"{prefix}_load{'' if aligned else 'u'}_{suffix}", (ptr_type,), vtype + ) @cache @@ -173,7 +179,12 @@ def _x86_packed_store( ) -> CFunction: prefix = varch.intrin_prefix(vtype) suffix = varch.intrin_suffix(vtype) - return CFunction(f"{prefix}_store{'' if aligned else 'u'}_{suffix}", 2) + ptr_type = PsPointerType(vtype.scalar_type, const=True) + return CFunction( + f"{prefix}_store{'' if aligned else 'u'}_{suffix}", + (ptr_type, vtype), + PsCustomType("void"), + ) @cache @@ -197,4 +208,5 @@ def _x86_op_intrin( case _: assert False - return CFunction(f"{prefix}_{opstr}_{suffix}", 3 if op == IntrinsicOps.FMA else 2) + num_args = 3 if op == IntrinsicOps.FMA else 2 + return CFunction(f"{prefix}_{opstr}_{suffix}", (vtype,) * num_args, vtype) diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 7fa4766eb305954f56d10b8cf8052c2fb26cb8fe..ddfa33f08272d59d032e1e657e66baa96fb41d04 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -9,6 +9,7 @@ from ..ast.expressions import ( PsExpression, PsConstantExpr, PsSymbolExpr, + PsLiteralExpr, PsBinOp, PsAdd, PsSub, @@ -159,8 +160,8 @@ class EliminateConstants: Returns: (transformed_expr, is_const): The tranformed expression, and a flag indicating whether it is constant """ - # Return constants as they are - if isinstance(expr, PsConstantExpr): + # Return constants and literals as they are + if isinstance(expr, (PsConstantExpr, PsLiteralExpr)): return expr, True # Shortcut symbols @@ -251,7 +252,6 @@ class EliminateConstants: # Detect constant expressions if all(subtree_constness): dtype = expr.get_dtype() - assert isinstance(dtype, PsNumericType) is_int = isinstance(dtype, PsIntegerType) is_float = isinstance(dtype, PsIeeeFloatType) @@ -274,6 +274,7 @@ class EliminateConstants: py_operator = expr.python_operator if do_fold and py_operator is not None: + assert isinstance(dtype, PsNumericType) folded = PsConstant(py_operator(val), dtype) return self._typify(PsConstantExpr(folded)), True @@ -287,6 +288,7 @@ class EliminateConstants: v2 = op2.constant.value if do_fold: + assert isinstance(dtype, PsNumericType) py_operator = expr.python_operator folded = None @@ -316,7 +318,7 @@ class EliminateConstants: # If required, extract constant subexpressions if self._extract_constant_exprs: for i, (child, is_const) in enumerate(subtree_results): - if is_const and not isinstance(child, PsConstantExpr): + if is_const and not isinstance(child, (PsConstantExpr, PsLiteralExpr)): replacement = ecc.extract_expression(child) expr.set_child(i, replacement) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 5824239e40ff8365a20658defab344892973f58f..cb9c9e92064d2061198f66cf5dc893d6491e1b34 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -7,6 +7,7 @@ from ..ast.expressions import ( PsExpression, PsSymbolExpr, PsConstantExpr, + PsLiteralExpr, PsCall, PsDeref, PsSubscript, @@ -40,7 +41,7 @@ class HoistContext: symbol in self.invariant_symbols ) - case PsConstantExpr(): + case PsConstantExpr() | PsLiteralExpr(): return True case PsCall(func): diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py index 29744c384e03131784c08857c494bb7f83e7f0bd..9ce2f661d840641d28774134070fc7050e90e6d1 100644 --- a/tests/nbackend/kernelcreation/test_domain_kernels.py +++ b/tests/nbackend/kernelcreation/test_domain_kernels.py @@ -59,5 +59,3 @@ def test_filter_kernel_fixedsize(): expected[1:-1, 1:-1].fill(18.0) np.testing.assert_allclose(dst_arr, expected) - -test_filter_kernel() \ No newline at end of file diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 60d0d6e7424bdfea730cafe18995afdb7dc253df..01f68c0a3e637e3139990f9208710e9861243e9d 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -26,10 +26,12 @@ from pystencils.backend.ast.expressions import ( PsLe, PsGt, PsLt, + PsCall, ) from pystencils.backend.constants import PsConstant +from pystencils.backend.functions import CFunction from pystencils.types import constify, create_type, create_numeric_type -from pystencils.types.quick import Fp, Bool +from pystencils.types.quick import Fp, Int, Bool from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -354,7 +356,7 @@ def test_invalid_conditions(): x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] - + cond = PsConditional(x + y, PsBlock([])) with pytest.raises(TypificationError): typify(cond) @@ -362,3 +364,24 @@ def test_invalid_conditions(): cond = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([])) with pytest.raises(TypificationError): typify(cond) + + +def test_cfunction(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"] + + def _threeway(x: np.float32, y: np.float32) -> np.int32: + assert False + + threeway = CFunction.parse(_threeway) + + result = typify(PsCall(threeway, [x, y])) + + assert result.get_dtype() == Int(32, const=True) + assert result.args[0].get_dtype() == Fp(32, const=True) + assert result.args[1].get_dtype() == Fp(32, const=True) + + with pytest.raises(TypificationError): + _ = typify(PsCall(threeway, (x, p))) diff --git a/tests/nbackend/test_ast_nodes.py b/tests/nbackend/test_ast.py similarity index 100% rename from tests/nbackend/test_ast_nodes.py rename to tests/nbackend/test_ast.py diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index 1fc6821d7b530a8b8e10b0298b641219bd31a53a..4c83e6e995f0823f81a4627e93d38256f648d28c 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -1,49 +1,16 @@ from pystencils import Target -from pystencils.backend.ast.expressions import PsExpression, PsArrayAccess +from pystencils.backend.ast.expressions import PsExpression from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock from pystencils.backend.kernelfunction import KernelFunction from pystencils.backend.symbols import PsSymbol from pystencils.backend.constants import PsConstant +from pystencils.backend.literals import PsLiteral from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter -# def test_basic_kernel(): - -# u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, )) -# u_size = PsExpression.make(u_arr.shape[0]) -# u_base = PsArrayBasePointer("u_data", u_arr) - -# loop_ctr = PsExpression.make(PsSymbol("ctr", UInt(32))) -# one = PsExpression.make(PsConstant(1, SInt(32))) - -# update = PsAssignment( -# PsArrayAccess(u_base, loop_ctr), -# PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one), -# ) - -# loop = PsLoop( -# loop_ctr, -# one, -# u_size - one, -# one, -# PsBlock([update]) -# ) - -# func = KernelFunction(PsBlock([loop]), Target.CPU, "kernel", set()) - -# printer = CAstPrinter() -# code = printer(func) - -# paramlist = func.get_parameters().params -# params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist) - -# assert code.find("(" + params_str + ")") >= 0 -# assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1];") >= 0 - - def test_arithmetic_precedence(): (a, b, c, d, e, f) = [PsExpression.make(PsSymbol(x, Fp(64))) for x in "abcdef"] cprint = CAstPrinter() diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..75726a3512b1291531ecf73a61af22258d17003c --- /dev/null +++ b/tests/nbackend/test_extensions.py @@ -0,0 +1,59 @@ + +import sympy as sp + +from pystencils import make_slice, Field, Assignment +from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace +from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations +from pystencils.backend.literals import PsLiteral +from pystencils.backend.emission import CAstPrinter +from pystencils.backend.ast.expressions import PsExpression, PsSubscript +from pystencils.backend.ast.structural import PsBlock, PsDeclaration +from pystencils.types.quick import Arr, Int + + +def test_literals(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + f = Field.create_generic("f", 3) + x = sp.Symbol("x") + + cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64, const=True), 3))) + global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype)) + + loop_slice = make_slice[ + 0:PsSubscript(cells, factory.parse_index(0)), + 0:PsSubscript(cells, factory.parse_index(1)), + 0:PsSubscript(cells, factory.parse_index(2)), + ] + + ispace = FullIterationSpace.create_from_slice(ctx, loop_slice) + ctx.set_iteration_space(ispace) + + x_decl = PsDeclaration(factory.parse_sympy(x), global_constant) + + loop_body = PsBlock([ + x_decl, + factory.parse_sympy(Assignment(f.center(), x)) + ]) + + loops = factory.loops_from_ispace(ispace, loop_body) + ast = PsBlock([loops]) + + canon = CanonicalizeSymbols(ctx) + ast = canon(ast) + + hoist = HoistLoopInvariantDeclarations(ctx) + ast = hoist(ast) + + assert isinstance(ast, PsBlock) + assert len(ast.statements) == 2 + assert ast.statements[0] == x_decl + + code = CAstPrinter()(ast) + print(code) + + assert "const double x = C;" in code + assert "CELLS[0]" in code + assert "CELLS[1]" in code + assert "CELLS[2]" in code diff --git a/tests/nbackend/test_functions.py b/tests/nbackend/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..e88e51e48244f557d19e8c3ebbc5129d335c40a5 --- /dev/null +++ b/tests/nbackend/test_functions.py @@ -0,0 +1,77 @@ +import sympy as sp +import numpy as np +import pytest + +from pystencils import create_kernel, CreateKernelConfig, Target, Assignment, Field + +UNARY_FUNCTIONS = { + "exp": (sp.exp, np.exp), + "sin": (sp.sin, np.sin), + "cos": (sp.cos, np.cos), + "tan": (sp.tan, np.tan), + "abs": (sp.Abs, np.abs), +} + +BINARY_FUNCTIONS = { + "min": (sp.Min, np.fmin), + "max": (sp.Max, np.fmax), + "pow": (sp.Pow, np.power), +} + + +@pytest.mark.parametrize("target", (Target.GenericCPU,)) +@pytest.mark.parametrize("function_name", UNARY_FUNCTIONS.keys()) +@pytest.mark.parametrize("dtype", (np.float32, np.float64)) +def test_unary_functions(target, function_name, dtype): + sp_func, np_func = UNARY_FUNCTIONS[function_name] + resolution: dtype = np.finfo(dtype).resolution + + inp = np.array( + [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [np.pi, np.e, 0.0]], dtype=dtype + ) + outp = np.zeros_like(inp) + + reference = np_func(inp) + + inp_field = Field.create_from_numpy_array("inp", inp) + outp_field = inp_field.new_field_with_different_name("outp") + + asms = [Assignment(outp_field.center(), sp_func(inp_field.center()))] + gen_config = CreateKernelConfig(target=target, default_dtype=dtype) + + kernel = create_kernel(asms, gen_config) + kfunc = kernel.compile() + kfunc(inp=inp, outp=outp) + + np.testing.assert_allclose(outp, reference, rtol=resolution) + + +@pytest.mark.parametrize("target", (Target.GenericCPU,)) +@pytest.mark.parametrize("function_name", BINARY_FUNCTIONS.keys()) +@pytest.mark.parametrize("dtype", (np.float32, np.float64)) +def test_binary_functions(target, function_name, dtype): + sp_func, np_func = BINARY_FUNCTIONS[function_name] + resolution: dtype = np.finfo(dtype).resolution + + inp = np.array( + [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [np.pi, np.e, 0.0]], dtype=dtype + ) + inp2 = np.array( + [[3.1, -0.5, 21.409], [11.0, 1.0, -14e3], [2.0 * np.pi, - np.e, 0.0]], dtype=dtype + ) + outp = np.zeros_like(inp) + + reference = np_func(inp, inp2) + + inp_field = Field.create_from_numpy_array("inp", inp) + inp2_field = Field.create_from_numpy_array("inp2", inp) + outp_field = inp_field.new_field_with_different_name("outp") + + asms = [Assignment(outp_field.center(), sp_func(inp_field.center(), inp2_field.center()))] + gen_config = CreateKernelConfig(target=target, default_dtype=dtype) + + kernel = create_kernel(asms, gen_config) + kfunc = kernel.compile() + kfunc(inp=inp, inp2=inp2, outp=outp) + + np.testing.assert_allclose(outp, reference, rtol=resolution)