Skip to content
Snippets Groups Projects
Commit ed8b7430 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fixed print Function in CUDA Backend

parent 9ded23ce
No related merge requests found
...@@ -52,7 +52,7 @@ pip install pystencils[interactive] ...@@ -52,7 +52,7 @@ pip install pystencils[interactive]
Without `[interactive]` you get a minimal version with very little dependencies. Without `[interactive]` you get a minimal version with very little dependencies.
All options: All options:
- `gpu`: use this if an Nvidia GPU is available and CUDA is installed - `gpu`: use this if an NVIDIA GPU is available and CUDA is installed
- `opencl`: basic OpenCL support (experimental) - `opencl`: basic OpenCL support (experimental)
- `alltrafos`: pulls in additional dependencies for loop simplification e.g. libisl - `alltrafos`: pulls in additional dependencies for loop simplification e.g. libisl
- `bench_db`: functionality to store benchmark result in object databases - `bench_db`: functionality to store benchmark result in object databases
......
...@@ -33,10 +33,11 @@ class CudaBackend(CBackend): ...@@ -33,10 +33,11 @@ class CudaBackend(CBackend):
super().__init__(sympy_printer, signature_only, dialect='cuda') super().__init__(sympy_printer, signature_only, dialect='cuda')
def _print_SharedMemoryAllocation(self, node): def _print_SharedMemoryAllocation(self, node):
code = "__shared__ {dtype} {name}[{num_elements}];" dtype = node.symbol.dtype
return code.format(dtype=node.symbol.dtype, name = self.sympy_printer.doprint(node.symbol.name)
name=self.sympy_printer.doprint(node.symbol.name), num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
num_elements='*'.join([str(s) for s in node.shared_mem.shape])) code = f"__shared__ {dtype} {name}[{num_elements}];"
return code
@staticmethod @staticmethod
def _print_ThreadBlockSynchronization(node): def _print_ThreadBlockSynchronization(node):
...@@ -45,6 +46,7 @@ class CudaBackend(CBackend): ...@@ -45,6 +46,7 @@ class CudaBackend(CBackend):
def _print_TextureDeclaration(self, node): def _print_TextureDeclaration(self, node):
# TODO: use fStrings here
if node.texture.field.dtype.numpy_dtype.itemsize > 4: if node.texture.field.dtype.numpy_dtype.itemsize > 4:
code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % ( code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
str(node.texture.field.dtype), str(node.texture.field.dtype),
...@@ -96,9 +98,13 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -96,9 +98,13 @@ class CudaSympyPrinter(CustomSympyPrinter):
def _print_Function(self, expr): def _print_Function(self, expr):
if isinstance(expr, fast_division): if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} where given"
return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__fsqrt_rn({self._print(expr.args[0])})"
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) print(len(expr.args) == 1)
assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__frsqrt_rn({self._print(expr.args[0])})"
return super()._print_Function(expr) return super()._print_Function(expr)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment