From b76caf99cd66d51519ee9c5212ac1adf1bf89074 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 6 Mar 2024 15:24:32 +0100 Subject: [PATCH] introduce freezing of sp.Pow; additions freeze now introduces subtactions --- src/pystencils/backend/functions.py | 2 + .../backend/kernelcreation/freeze.py | 55 ++++++++++++++++++- src/pystencils/backend/kernelfunction.py | 30 ++++++++++ src/pystencils/field.py | 4 +- src/pystencils/sympyextensions/astnodes.py | 2 +- 5 files changed, 87 insertions(+), 6 deletions(-) diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 81c48d108..4380d8c56 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -39,6 +39,8 @@ class MathFunctions(Enum): Min = ("min", 2) Max = ("max", 2) + Pow = ("pow", 2) + def __init__(self, func_name, arg_count): self.function_name = func_name self.arg_count = arg_count diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 31d322aa1..8b54da651 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -1,6 +1,6 @@ from typing import overload, cast, Any from functools import reduce -from operator import add, mul +from operator import add, mul, sub import sympy as sp @@ -93,7 +93,7 @@ class FreezeExpressions: def freeze_expression(self, expr: sp.Expr) -> PsExpression: return cast(PsExpression, self.visit(expr)) - def map_Assignment(self, expr: Assignment): # noqa + def map_Assignment(self, expr: Assignment): lhs = self.visit(expr.lhs) rhs = self.visit(expr.rhs) @@ -112,10 +112,59 @@ class FreezeExpressions: return PsSymbolExpr(symb) def map_Add(self, expr: sp.Add) -> PsExpression: - return reduce(add, (self.visit_expr(arg) for arg in expr.args)) + # TODO: think about numerically sensible ways of freezing sums and products + signs: list[int] = [] + for summand in expr.args: + if summand.is_negative: + signs.append(-1) + elif isinstance(summand, sp.Mul) and any(factor.is_negative for factor in summand.args): + signs.append(-1) + else: + signs.append(1) + + frozen_expr = self.visit_expr(expr.args[0]) + + for sign, arg in zip(signs[1:], expr.args[1:]): + if sign == -1: + arg = - arg + op = sub + else: + op = add + + frozen_expr = op(frozen_expr, self.visit_expr(arg)) + + return frozen_expr def map_Mul(self, expr: sp.Mul) -> PsExpression: return reduce(mul, (self.visit_expr(arg) for arg in expr.args)) + + def map_Pow(self, expr: sp.Pow) -> PsExpression: + base = expr.args[0] + exponent = expr.args[1] + + base_frozen = self.visit_expr(base) + reciprocal = False + expand_product = False + + if exponent.is_Integer: + if exponent.is_negative: + reciprocal = True + exponent = - exponent + + if exponent <= sp.Integer(5): + expand_product = True + + if expand_product: + frozen_expr = reduce(mul, [base_frozen] * int(exponent)) + else: + exponent_frozen = self.visit_expr(exponent) + frozen_expr = PsMathFunction(MathFunctions.Pow)(base_frozen, exponent_frozen) + + if reciprocal: + one = PsExpression.make(PsConstant(1)) + frozen_expr = one / frozen_expr + + return frozen_expr def map_Integer(self, expr: sp.Integer) -> PsConstantExpr: value = int(expr) diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index 33a9288a3..20d63e85e 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -28,6 +28,27 @@ class KernelParameter: def dtype(self): return self._dtype + def _hashable_contents(self): + return (self._name, self._dtype) + + def __hash__(self) -> int: + return hash(self._hashable_contents()) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, KernelParameter): + return False + + return ( + type(self) is type(other) + and self._hashable_contents() == other._hashable_contents() + ) + + def __str__(self) -> str: + return self._name + + def __repr__(self) -> str: + return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})" + class FieldParameter(KernelParameter, ABC): __match_args__ = KernelParameter.__match_args__ + ("field",) @@ -39,6 +60,9 @@ class FieldParameter(KernelParameter, ABC): @property def field(self): return self._field + + def _hashable_contents(self): + return super()._hashable_contents() + (self._field,) class FieldShapeParam(FieldParameter): @@ -51,6 +75,9 @@ class FieldShapeParam(FieldParameter): @property def coordinate(self): return self._coordinate + + def _hashable_contents(self): + return super()._hashable_contents() + (self._coordinate,) class FieldStrideParam(FieldParameter): @@ -63,6 +90,9 @@ class FieldStrideParam(FieldParameter): @property def coordinate(self): return self._coordinate + + def _hashable_contents(self): + return super()._hashable_contents() + (self._coordinate,) class FieldPointerParam(FieldParameter): diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 646cecb5d..b055ccb6b 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -123,8 +123,8 @@ class Field: """ @staticmethod - def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0, layout='numpy', - index_shape=None, field_type=FieldType.GENERIC) -> 'Field': + def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0, + layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field': """ Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py index 8483977d8..4fdc0f612 100644 --- a/src/pystencils/sympyextensions/astnodes.py +++ b/src/pystencils/sympyextensions/astnodes.py @@ -4,7 +4,7 @@ import uuid from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union import sympy as sp -from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment +from sympy.codegen.ast import Assignment, AugmentedAssignment from sympy.printing.latex import LatexPrinter import numpy as np -- GitLab