Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 2586 additions and 219 deletions
......@@ -6,9 +6,3 @@ try:
__all__.append('print_dot')
except ImportError:
pass
try:
from .llvm import generate_llvm # NOQA
__all__.append('generate_llvm')
except ImportError:
pass
This diff is collapsed.
from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
def generate_cuda(ast_node: Node, signature_only: bool = False, custom_backend=None, with_globals=True) -> str:
"""Prints an abstract syntax tree node as CUDA code.
Args:
ast_node: ast representation of kernel
signature_only: generate signature without function body
custom_backend: use own custom printer for code generation
with_globals: enable usage of global variables
Returns:
CUDA code for the ast node and its descendants
"""
return generate_c(ast_node, signature_only, dialect=Backend.CUDA,
custom_backend=custom_backend, with_globals=with_globals)
class CudaBackend(CBackend):
def __init__(self, sympy_printer=None,
signature_only=False):
if not sympy_printer:
sympy_printer = CudaSympyPrinter()
super().__init__(sympy_printer, signature_only, dialect=Backend.CUDA)
def _print_SharedMemoryAllocation(self, node):
dtype = node.symbol.dtype
name = self.sympy_printer.doprint(node.symbol.name)
num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
code = f"__shared__ {dtype} {name}[{num_elements}];"
return code
@staticmethod
def _print_ThreadBlockSynchronization(_):
return "__synchtreads();"
def _print_TextureDeclaration(self, node):
cond = node.texture.field.dtype.numpy_dtype.itemsize > 4
return f'texture<{"fp_tex_" if cond else ""}{str(node.texture.field.dtype)}, ' \
f'cudaTextureType{node.texture.field.spacial_dimensions}D, cudaReadModeElementType> {node.texture};'
def _print_SkipIteration(self, _):
return "return;"
class CudaSympyPrinter(CustomSympyPrinter):
language = "CUDA"
def __init__(self):
super(CudaSympyPrinter, self).__init__()
def _print_Function(self, expr):
if isinstance(expr, fast_division):
assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} where given"
return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
elif isinstance(expr, fast_sqrt):
assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__fsqrt_rn({self._print(expr.args[0])})"
elif isinstance(expr, fast_inv_sqrt):
assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__frsqrt_rn({self._print(expr.args[0])})"
return super()._print_Function(expr)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import sympy as sp
from pystencils import Field, TypedSymbol
from pystencils.integer_functions import bitwise_and
from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE
from pystencils.data_types import create_type
from pystencils.typing import TypedSymbol, create_type
from pystencils.field import Field
from pystencils.integer_functions import bitwise_and
def add_neumann_boundary(eqs, fields, flag_field, boundary_flag="neumann_flag", inverse_flag=False):
"""
Replaces all neighbor accesses by flag field guarded accesses.
If flag in neighboring cell is set, the center value is used instead
:param eqs: list of equations containing field accesses to direct neighbors
:param fields: fields for which the Neumann boundary should be applied
:param flag_field: integer field marking boundary cells
:param boundary_flag: if flag field has value 'boundary_flag' (no bit operations yet)
the cell is assumed to be boundary
:param inverse_flag: if true, boundary cells are where flag field has not the value of boundary_flag
:return: list of equations with guarded field accesses
Args:
eqs: list of equations containing field accesses to direct neighbors
fields: fields for which the Neumann boundary should be applied
flag_field: integer field marking boundary cells
boundary_flag: if flag field has value 'boundary_flag' (no bit operations yet)
the cell is assumed to be boundary
inverse_flag: if true, boundary cells are where flag field has not the value of boundary_flag
Returns:
list of equations with guarded field accesses
"""
if not hasattr(fields, "__len__"):
fields = [fields]
......