Commit ed8b7430 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fixed print Function in CUDA Backend

parent 9ded23ce
......@@ -52,7 +52,7 @@ pip install pystencils[interactive]
Without `[interactive]` you get a minimal version with very little dependencies.
All options:
- `gpu`: use this if an Nvidia GPU is available and CUDA is installed
- `gpu`: use this if an NVIDIA GPU is available and CUDA is installed
- `opencl`: basic OpenCL support (experimental)
- `alltrafos`: pulls in additional dependencies for loop simplification e.g. libisl
- `bench_db`: functionality to store benchmark result in object databases
......
......@@ -33,10 +33,11 @@ class CudaBackend(CBackend):
super().__init__(sympy_printer, signature_only, dialect='cuda')
def _print_SharedMemoryAllocation(self, node):
code = "__shared__ {dtype} {name}[{num_elements}];"
return code.format(dtype=node.symbol.dtype,
name=self.sympy_printer.doprint(node.symbol.name),
num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
dtype = node.symbol.dtype
name = self.sympy_printer.doprint(node.symbol.name)
num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
code = f"__shared__ {dtype} {name}[{num_elements}];"
return code
@staticmethod
def _print_ThreadBlockSynchronization(node):
......@@ -45,6 +46,7 @@ class CudaBackend(CBackend):
def _print_TextureDeclaration(self, node):
# TODO: use fStrings here
if node.texture.field.dtype.numpy_dtype.itemsize > 4:
code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
str(node.texture.field.dtype),
......@@ -96,9 +98,13 @@ class CudaSympyPrinter(CustomSympyPrinter):
def _print_Function(self, expr):
if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} where given"
return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
elif isinstance(expr, fast_sqrt):
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__fsqrt_rn({self._print(expr.args[0])})"
elif isinstance(expr, fast_inv_sqrt):
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
print(len(expr.args) == 1)
assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__frsqrt_rn({self._print(expr.args[0])})"
return super()._print_Function(expr)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment