Skip to content
Snippets Groups Projects
Commit ed8b7430 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fixed print Function in CUDA Backend

parent 9ded23ce
Branches
Tags
No related merge requests found
......@@ -52,7 +52,7 @@ pip install pystencils[interactive]
Without `[interactive]` you get a minimal version with very little dependencies.
All options:
- `gpu`: use this if an Nvidia GPU is available and CUDA is installed
- `gpu`: use this if an NVIDIA GPU is available and CUDA is installed
- `opencl`: basic OpenCL support (experimental)
- `alltrafos`: pulls in additional dependencies for loop simplification e.g. libisl
- `bench_db`: functionality to store benchmark result in object databases
......
......@@ -33,10 +33,11 @@ class CudaBackend(CBackend):
super().__init__(sympy_printer, signature_only, dialect='cuda')
def _print_SharedMemoryAllocation(self, node):
code = "__shared__ {dtype} {name}[{num_elements}];"
return code.format(dtype=node.symbol.dtype,
name=self.sympy_printer.doprint(node.symbol.name),
num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
dtype = node.symbol.dtype
name = self.sympy_printer.doprint(node.symbol.name)
num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
code = f"__shared__ {dtype} {name}[{num_elements}];"
return code
@staticmethod
def _print_ThreadBlockSynchronization(node):
......@@ -45,6 +46,7 @@ class CudaBackend(CBackend):
def _print_TextureDeclaration(self, node):
# TODO: use fStrings here
if node.texture.field.dtype.numpy_dtype.itemsize > 4:
code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
str(node.texture.field.dtype),
......@@ -96,9 +98,13 @@ class CudaSympyPrinter(CustomSympyPrinter):
def _print_Function(self, expr):
    """Print *expr*, mapping the fast-math functions to CUDA intrinsics.

    ``fast_division`` is emitted as ``__fdividef(a, b)``, ``fast_sqrt`` as
    ``__fsqrt_rn(a)`` and ``fast_inv_sqrt`` as ``__frsqrt_rn(a)``. Any other
    function is delegated to the parent printer.

    Raises:
        AssertionError: if one of the fast-math functions carries an
            unexpected number of arguments.
    """
    if isinstance(expr, fast_division):
        assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} were given"
        return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
    elif isinstance(expr, fast_sqrt):
        assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} were given"
        return f"__fsqrt_rn({self._print(expr.args[0])})"
    elif isinstance(expr, fast_inv_sqrt):
        # No debug output here: the printer must only return the generated code.
        assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} were given"
        return f"__frsqrt_rn({self._print(expr.args[0])})"
    return super()._print_Function(expr)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment