From 65c7e576af6cf01dd533900a10f97559032009fd Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Wed, 11 May 2022 14:21:54 +0200
Subject: [PATCH] Implemented Logarithm

---
 pystencils/kernel_contrains_check.py |  3 ++-
 pystencils/typing/leaf_typing.py     |  2 +-
 pystencils_tests/test_logarithm.py   | 26 ++++++++++++++++++++++++++
 3 files changed, 29 insertions(+), 2 deletions(-)
 create mode 100644 pystencils_tests/test_logarithm.py

diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py
index b4b681e1c..f1fa4b8a1 100644
--- a/pystencils/kernel_contrains_check.py
+++ b/pystencils/kernel_contrains_check.py
@@ -10,10 +10,11 @@ from pystencils.field import Field
 from pystencils.node_collection import NodeCollection
 from pystencils.transformations import NestedScopes
 
-
+# TODO use this in Constraint Checker
 accepted_functions = [
     sp.Pow,
     sp.sqrt,
+    sp.log,
     # TODO trigonometric functions (and whatever tests will fail)
 ]
 
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index b5af7e4b2..ddffd61ce 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -212,7 +212,7 @@ class TypeAdder:
                     new_args.append(a)
             return expr.func(*new_args) if new_args else expr, collated_type
         elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
-                               HyperbolicFunction)):
+                               HyperbolicFunction, sp.log)):
             args_types = [self.figure_out_type(arg) for arg in expr.args]
             collated_type = collate_types([t for _, t in args_types])
             new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
diff --git a/pystencils_tests/test_logarithm.py b/pystencils_tests/test_logarithm.py
new file mode 100644
index 000000000..85d7814a3
--- /dev/null
+++ b/pystencils_tests/test_logarithm.py
@@ -0,0 +1,26 @@
+import pytest
+import numpy as np
+import sympy as sp
+
+import pystencils as ps
+
+
+@pytest.mark.parametrize('dtype', ["float64", "float32"])
+def test_log(dtype):
+    a = sp.Symbol("a")
+    x = ps.fields(f'x: {dtype}[1d]')
+
+    assignments = ps.AssignmentCollection({x.center(): sp.log(a)})
+
+    ast = ps.create_kernel(assignments)
+    code = ps.get_code_str(ast)
+    kernel = ast.compile()
+
+    # ps.show_code(ast)
+
+    if dtype == "float64":
+        assert "float" not in code
+
+    array = np.zeros((10,), dtype=dtype)
+    kernel(x=array, a=100)
+    assert np.allclose(array, 4.60517019)
-- 
GitLab