diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 916a61392bded7660c95962c64cfe55f0c743b88..3e8e8d8e4ea6652b4f331f173266e5e02d23ee49 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -45,6 +45,7 @@ from .sympyextensions.reduction import ( MinReducedssignment, MaxReducedssignment ) +from .binop_mapping import binop_str_to_expr __all__ = [ "Field", @@ -75,6 +76,7 @@ __all__ = [ "inspect", "AssignmentCollection", "Assignment", + "binop_str_to_expr", "AddAugmentedAssignment", "AddReducedAssignment", "SubReducedAssignment", diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 1238f16af8d7ae9619e664526c86df897e81b2a9..68868e1438c54463f65fa1f278f059bc398c2ec4 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -13,7 +13,7 @@ from ...sympyextensions import ( integer_functions, ConditionalFieldAccess, ) -from ...sympyextensions.binop_mapping import binop_str_to_expr +from ...binop_mapping import binop_str_to_expr from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType from ...sympyextensions.pointers import AddressOf, mem_acc from ...sympyextensions.reduction import ReducedAssignment @@ -185,6 +185,8 @@ class FreezeExpressions: orig_lhs_symb = lhs.symbol dtype = rhs.dtype # TODO: kernel with (implicit) up/downcasts? + assert isinstance(dtype, PsNumericType) + # replace original symbol with pointer-based type used for export orig_lhs_symb_as_ptr = PsSymbol(orig_lhs_symb.name, PsPointerType(dtype)) @@ -196,7 +198,6 @@ class FreezeExpressions: new_rhs: PsExpression = binop_str_to_expr(expr.op, new_lhs.clone(), rhs) # match for reduction operation and set neutral init_val - new_rhs: PsExpression init_val: PsExpression match expr.op: case "+": diff --git a/src/pystencils/sympyextensions/binop_mapping.py b/src/pystencils/binop_mapping.py similarity index 85% rename from src/pystencils/sympyextensions/binop_mapping.py rename to src/pystencils/binop_mapping.py index 04cfb6107d3fd8f449b9a300a7442ca3baaea5ab..060fa40aad732922e6f14d5f62ab04e71a2ff487 100644 --- a/src/pystencils/sympyextensions/binop_mapping.py +++ b/src/pystencils/binop_mapping.py @@ -1,8 +1,8 @@ from operator import truediv, mul, sub, add -from ..backend.ast.expressions import PsCall, PsExpression -from ..backend.exceptions import FreezeError -from ..backend.functions import MathFunctions, PsMathFunction +from .backend.ast.expressions import PsExpression, PsCall +from .backend.exceptions import FreezeError +from .backend.functions import PsMathFunction, MathFunctions _available_operator_interface: set[str] = {'+', '-', '*', '/'} diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index b47ad8a9e4dd7c3b624ec7fd003bafce8193752c..d68bfbcacb7933ff32da32015673a0fb75ac1a92 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -7,8 +7,8 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO from .kernel import Kernel, GpuKernel, GpuThreadsRange from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable from .parameters import Parameter +from ..binop_mapping import binop_str_to_expr from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr -from ..sympyextensions.binop_mapping import binop_str_to_expr from ..types import create_numeric_type, PsIntegerType, PsScalarType @@ -155,14 +155,14 @@ class DefaultKernelCreationDriver: self._intermediates.constants_eliminated = kernel_ast.clone() # Init local reduction variable copy - for local_red, prop in self._ctx.local_reduction_symbols.items(): - kernel_ast.statements = [PsDeclaration(PsSymbolExpr(local_red), prop.init_val)] + kernel_ast.statements + for local_red, local_prop in self._ctx.local_reduction_symbols.items(): + kernel_ast.statements = [PsDeclaration(PsSymbolExpr(local_red), local_prop.init_val)] + kernel_ast.statements # Write back result to reduction target variable - for red_ptr, prop in self._ctx.reduction_pointer_symbols.items(): + for red_ptr, ptr_prop in self._ctx.reduction_pointer_symbols.items(): ptr_access = PsMemAcc(PsSymbolExpr(red_ptr), PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) kernel_ast.statements += [PsAssignment( - ptr_access, binop_str_to_expr(prop.op, ptr_access, PsSymbolExpr(prop.local_symbol)))] + ptr_access, binop_str_to_expr(ptr_prop.op, ptr_access, PsSymbolExpr(ptr_prop.local_symbol)))] # Target-Specific optimizations if self._cfg.target.is_cpu(): diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index 8d832ba2a9400d1ec0c1aef28c417066fdb78541..6ab24e936a2355782755badbf835d7e5c3bee73e 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -2,7 +2,6 @@ from .astnodes import ConditionalFieldAccess from .typed_sympy import TypedSymbol, CastFunc from .pointers import mem_acc from .reduction import reduced_assign -from .binop_mapping import binop_str_to_expr from .math import ( prod, @@ -36,7 +35,6 @@ from .math import ( __all__ = [ "ConditionalFieldAccess", "reduced_assign", - "binop_str_to_expr", "TypedSymbol", "CastFunc", "mem_acc",