Commit b44b0045 authored by Markus Holzer

Merge branch 'Fix_Wrong_fString' into 'master'

Fix: Wrong fString in Cuda Backend

See merge request pycodegen/pystencils!159
parents 82af488a ed8b7430
Pipeline #24585 passed with stages in 43 minutes and 24 seconds
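The underlying bug: in the `fast_sqrt` and `fast_inv_sqrt` branches, the printed arguments were wrapped in `tuple(...)` inside an f-string, so the tuple's Python repr ended up in the generated CUDA source (the `%`-formatted `__fdividef` line was still valid and is merely modernized). A minimal standalone Python sketch of the broken vs. fixed formatting, not pystencils code:

    # A tuple interpolated into an f-string is rendered via its repr(),
    # quotes and parentheses included.
    args = ["x"]
    broken = f"__fsqrt_rn({tuple(a for a in args)})"
    print(broken)  # __fsqrt_rn(('x',))  <- invalid CUDA
    fixed = f"__fsqrt_rn({args[0]})"
    print(fixed)   # __fsqrt_rn(x)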
@@ -52,7 +52,7 @@ pip install pystencils[interactive]
Without `[interactive]` you get a minimal version with very few dependencies.

All options:
-- `gpu`: use this if an Nvidia GPU is available and CUDA is installed
+- `gpu`: use this if an NVIDIA GPU is available and CUDA is installed
- `opencl`: basic OpenCL support (experimental)
- `alltrafos`: pulls in additional dependencies for loop simplification, e.g. libisl
- `bench_db`: functionality to store benchmark results in object databases
@@ -33,10 +33,11 @@ class CudaBackend(CBackend):
        super().__init__(sympy_printer, signature_only, dialect='cuda')

    def _print_SharedMemoryAllocation(self, node):
-        code = "__shared__ {dtype} {name}[{num_elements}];"
-        return code.format(dtype=node.symbol.dtype,
-                           name=self.sympy_printer.doprint(node.symbol.name),
-                           num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
+        dtype = node.symbol.dtype
+        name = self.sympy_printer.doprint(node.symbol.name)
+        num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
+        code = f"__shared__ {dtype} {name}[{num_elements}];"
+        return code

    @staticmethod
    def _print_ThreadBlockSynchronization(node):
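For illustration, the refactored method assembles the allocation line step by step; a standalone sketch with made-up dtype, name and shape values:

    dtype, name, shape = "double", "sharedMem", (16, 16)
    num_elements = '*'.join(str(s) for s in shape)
    print(f"__shared__ {dtype} {name}[{num_elements}];")
    # -> __shared__ double sharedMem[16*16];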
@@ -45,6 +46,7 @@ class CudaBackend(CBackend):
    def _print_TextureDeclaration(self, node):
+        # TODO: use fStrings here
        if node.texture.field.dtype.numpy_dtype.itemsize > 4:
            code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
                str(node.texture.field.dtype),
@@ -96,9 +98,13 @@ class CudaSympyPrinter(CustomSympyPrinter):
    def _print_Function(self, expr):
        if isinstance(expr, fast_division):
-            return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
+            assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} were given"
+            return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        elif isinstance(expr, fast_sqrt):
-            return f"__fsqrt_rn({tuple(self._print(a) for a in expr.args)})"
+            assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} were given"
+            return f"__fsqrt_rn({self._print(expr.args[0])})"
        elif isinstance(expr, fast_inv_sqrt):
-            return f"__frsqrt_rn({tuple(self._print(a) for a in expr.args)})"
+            assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} were given"
+            return f"__frsqrt_rn({self._print(expr.args[0])})"
        return super()._print_Function(expr)
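With the tuple wrapping gone, each branch emits a plain argument list. A quick sketch of the resulting strings, with `self._print` stubbed out as `str` and illustrative argument names:

    _print = str  # stand-in for self._print
    print(f"__fdividef({_print('a')}, {_print('b')})")  # __fdividef(a, b)
    print(f"__fsqrt_rn({_print('x')})")                 # __fsqrt_rn(x)
    print(f"__frsqrt_rn({_print('x')})")                # __frsqrt_rn(x)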
@@ -12,6 +12,7 @@ def test_fast_sqrt():
    assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1
    assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1

    ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu')
+    ast.compile()
    code_str = ps.get_code_str(ast)
    assert '__fsqrt_rn' in code_str
@@ -21,6 +22,7 @@ def test_fast_sqrt():
    ac = ps.AssignmentCollection([expr], [])
    assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1

    ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu')
+    ast.compile()
    code_str = ps.get_code_str(ast)
    assert '__frsqrt_rn' in code_str
@@ -34,5 +36,6 @@ def test_fast_divisions():
    assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1

    ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_divisions(expr)), target='gpu')
+    ast.compile()
    code_str = ps.get_code_str(ast)
    assert '__fdividef' in code_str
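The added `ast.compile()` calls are what actually guard against this class of bug: the substring asserts pass even on the broken output, because the intrinsic name is still present in it. A small sketch, assuming the printer returns plain strings for arguments:

    broken_code = "__fsqrt_rn(('x',))"  # what the old f-string produced
    assert '__fsqrt_rn' in broken_code  # substring check passes anyway
    # only compiling the generated kernel rejects the malformed call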