Skip to content
Snippets Groups Projects
Commit 65c7e576 authored by Markus Holzer's avatar Markus Holzer
Browse files

Implemented Logarithm

parent ea943334
Branches
Tags
No related merge requests found
...@@ -10,10 +10,11 @@ from pystencils.field import Field ...@@ -10,10 +10,11 @@ from pystencils.field import Field
from pystencils.node_collection import NodeCollection from pystencils.node_collection import NodeCollection
from pystencils.transformations import NestedScopes from pystencils.transformations import NestedScopes
# TODO use this in Constraint Checker
accepted_functions = [ accepted_functions = [
sp.Pow, sp.Pow,
sp.sqrt, sp.sqrt,
sp.log,
# TODO trigonometric functions (and whatever tests will fail) # TODO trigonometric functions (and whatever tests will fail)
] ]
......
...@@ -212,7 +212,7 @@ class TypeAdder: ...@@ -212,7 +212,7 @@ class TypeAdder:
new_args.append(a) new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction)): HyperbolicFunction, sp.log)):
args_types = [self.figure_out_type(arg) for arg in expr.args] args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types]) collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
......
import pytest
import numpy as np
import sympy as sp
import pystencils as ps
@pytest.mark.parametrize('dtype', ["float64", "float32"])
def test_log(dtype):
a = sp.Symbol("a")
x = ps.fields(f'x: {dtype}[1d]')
assignments = ps.AssignmentCollection({x.center(): sp.log(a)})
ast = ps.create_kernel(assignments)
code = ps.get_code_str(ast)
kernel = ast.compile()
# ps.show_code(ast)
if dtype == "float64":
assert "float" not in code
array = np.zeros((10,), dtype=dtype)
kernel(x=array, a=100)
assert np.allclose(array, 4.60517019)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment