From 65c7e576af6cf01dd533900a10f97559032009fd Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Wed, 11 May 2022 14:21:54 +0200 Subject: [PATCH] Implemented Logarithm --- pystencils/kernel_contrains_check.py | 3 ++- pystencils/typing/leaf_typing.py | 2 +- pystencils_tests/test_logarithm.py | 26 ++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 pystencils_tests/test_logarithm.py diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py index b4b681e1c..f1fa4b8a1 100644 --- a/pystencils/kernel_contrains_check.py +++ b/pystencils/kernel_contrains_check.py @@ -10,10 +10,11 @@ from pystencils.field import Field from pystencils.node_collection import NodeCollection from pystencils.transformations import NestedScopes - +# TODO use this in Constraint Checker accepted_functions = [ sp.Pow, sp.sqrt, + sp.log, # TODO trigonometric functions (and whatever tests will fail) ] diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index b5af7e4b2..ddffd61ce 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -212,7 +212,7 @@ class TypeAdder: new_args.append(a) return expr.func(*new_args) if new_args else expr, collated_type elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, - HyperbolicFunction)): + HyperbolicFunction, sp.log)): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] diff --git a/pystencils_tests/test_logarithm.py b/pystencils_tests/test_logarithm.py new file mode 100644 index 000000000..85d7814a3 --- /dev/null +++ b/pystencils_tests/test_logarithm.py @@ -0,0 +1,26 @@ +import pytest +import numpy as np +import sympy as sp + +import pystencils as ps + + +@pytest.mark.parametrize('dtype', ["float64", "float32"]) +def test_log(dtype): + a = sp.Symbol("a") + x = ps.fields(f'x: {dtype}[1d]') + + assignments = ps.AssignmentCollection({x.center(): sp.log(a)}) + + ast = ps.create_kernel(assignments) + code = ps.get_code_str(ast) + kernel = ast.compile() + + # ps.show_code(ast) + + if dtype == "float64": + assert "float" not in code + + array = np.zeros((10,), dtype=dtype) + kernel(x=array, a=100) + assert np.allclose(array, 4.60517019) -- GitLab