Skip to content
Snippets Groups Projects
Commit b76caf99 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

introduce freezing of sp.Pow; additions freeze now introduces subtactions

parent 672d668c
No related merge requests found
Pipeline #63756 failed with stages
in 2 minutes and 54 seconds
...@@ -39,6 +39,8 @@ class MathFunctions(Enum): ...@@ -39,6 +39,8 @@ class MathFunctions(Enum):
Min = ("min", 2) Min = ("min", 2)
Max = ("max", 2) Max = ("max", 2)
Pow = ("pow", 2)
def __init__(self, func_name, arg_count): def __init__(self, func_name, arg_count):
self.function_name = func_name self.function_name = func_name
self.arg_count = arg_count self.arg_count = arg_count
......
from typing import overload, cast, Any from typing import overload, cast, Any
from functools import reduce from functools import reduce
from operator import add, mul from operator import add, mul, sub
import sympy as sp import sympy as sp
...@@ -93,7 +93,7 @@ class FreezeExpressions: ...@@ -93,7 +93,7 @@ class FreezeExpressions:
def freeze_expression(self, expr: sp.Expr) -> PsExpression: def freeze_expression(self, expr: sp.Expr) -> PsExpression:
return cast(PsExpression, self.visit(expr)) return cast(PsExpression, self.visit(expr))
def map_Assignment(self, expr: Assignment): # noqa def map_Assignment(self, expr: Assignment):
lhs = self.visit(expr.lhs) lhs = self.visit(expr.lhs)
rhs = self.visit(expr.rhs) rhs = self.visit(expr.rhs)
...@@ -112,10 +112,59 @@ class FreezeExpressions: ...@@ -112,10 +112,59 @@ class FreezeExpressions:
return PsSymbolExpr(symb) return PsSymbolExpr(symb)
def map_Add(self, expr: sp.Add) -> PsExpression: def map_Add(self, expr: sp.Add) -> PsExpression:
return reduce(add, (self.visit_expr(arg) for arg in expr.args)) # TODO: think about numerically sensible ways of freezing sums and products
signs: list[int] = []
for summand in expr.args:
if summand.is_negative:
signs.append(-1)
elif isinstance(summand, sp.Mul) and any(factor.is_negative for factor in summand.args):
signs.append(-1)
else:
signs.append(1)
frozen_expr = self.visit_expr(expr.args[0])
for sign, arg in zip(signs[1:], expr.args[1:]):
if sign == -1:
arg = - arg
op = sub
else:
op = add
frozen_expr = op(frozen_expr, self.visit_expr(arg))
return frozen_expr
def map_Mul(self, expr: sp.Mul) -> PsExpression: def map_Mul(self, expr: sp.Mul) -> PsExpression:
return reduce(mul, (self.visit_expr(arg) for arg in expr.args)) return reduce(mul, (self.visit_expr(arg) for arg in expr.args))
def map_Pow(self, expr: sp.Pow) -> PsExpression:
base = expr.args[0]
exponent = expr.args[1]
base_frozen = self.visit_expr(base)
reciprocal = False
expand_product = False
if exponent.is_Integer:
if exponent.is_negative:
reciprocal = True
exponent = - exponent
if exponent <= sp.Integer(5):
expand_product = True
if expand_product:
frozen_expr = reduce(mul, [base_frozen] * int(exponent))
else:
exponent_frozen = self.visit_expr(exponent)
frozen_expr = PsMathFunction(MathFunctions.Pow)(base_frozen, exponent_frozen)
if reciprocal:
one = PsExpression.make(PsConstant(1))
frozen_expr = one / frozen_expr
return frozen_expr
def map_Integer(self, expr: sp.Integer) -> PsConstantExpr: def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
value = int(expr) value = int(expr)
......
...@@ -28,6 +28,27 @@ class KernelParameter: ...@@ -28,6 +28,27 @@ class KernelParameter:
def dtype(self): def dtype(self):
return self._dtype return self._dtype
def _hashable_contents(self):
return (self._name, self._dtype)
def __hash__(self) -> int:
return hash(self._hashable_contents())
def __eq__(self, other: object) -> bool:
if not isinstance(other, KernelParameter):
return False
return (
type(self) is type(other)
and self._hashable_contents() == other._hashable_contents()
)
def __str__(self) -> str:
return self._name
def __repr__(self) -> str:
return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})"
class FieldParameter(KernelParameter, ABC): class FieldParameter(KernelParameter, ABC):
__match_args__ = KernelParameter.__match_args__ + ("field",) __match_args__ = KernelParameter.__match_args__ + ("field",)
...@@ -39,6 +60,9 @@ class FieldParameter(KernelParameter, ABC): ...@@ -39,6 +60,9 @@ class FieldParameter(KernelParameter, ABC):
@property @property
def field(self): def field(self):
return self._field return self._field
def _hashable_contents(self):
return super()._hashable_contents() + (self._field,)
class FieldShapeParam(FieldParameter): class FieldShapeParam(FieldParameter):
...@@ -51,6 +75,9 @@ class FieldShapeParam(FieldParameter): ...@@ -51,6 +75,9 @@ class FieldShapeParam(FieldParameter):
@property @property
def coordinate(self): def coordinate(self):
return self._coordinate return self._coordinate
def _hashable_contents(self):
return super()._hashable_contents() + (self._coordinate,)
class FieldStrideParam(FieldParameter): class FieldStrideParam(FieldParameter):
...@@ -63,6 +90,9 @@ class FieldStrideParam(FieldParameter): ...@@ -63,6 +90,9 @@ class FieldStrideParam(FieldParameter):
@property @property
def coordinate(self): def coordinate(self):
return self._coordinate return self._coordinate
def _hashable_contents(self):
return super()._hashable_contents() + (self._coordinate,)
class FieldPointerParam(FieldParameter): class FieldPointerParam(FieldParameter):
......
...@@ -123,8 +123,8 @@ class Field: ...@@ -123,8 +123,8 @@ class Field:
""" """
@staticmethod @staticmethod
def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0, layout='numpy', def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0,
index_shape=None, field_type=FieldType.GENERIC) -> 'Field': layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
""" """
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
......
...@@ -4,7 +4,7 @@ import uuid ...@@ -4,7 +4,7 @@ import uuid
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment from sympy.codegen.ast import Assignment, AugmentedAssignment
from sympy.printing.latex import LatexPrinter from sympy.printing.latex import LatexPrinter
import numpy as np import numpy as np
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment