Commit 19f54169 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow custom backends for cpujit and gpujit

parent 8a651aa4
......@@ -37,7 +37,7 @@ class UnsupportedCDialect(Exception):
super(UnsupportedCDialect, self).__init__()
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
......@@ -57,8 +57,9 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
ast_node.global_variables = d.symbols_defined
if dialect == 'c':
if custom_backend:
printer = custom_backend
elif dialect == 'c':
printer = CBackend(signature_only=signature_only,
elif dialect == 'cuda':
......@@ -43,28 +43,28 @@ Then 'cl.exe' is used to compile.
For Windows compilers the qualifier should be ``__restrict``
import os
import hashlib
import json
import os
import platform
import shutil
import subprocess
import textwrap
from collections import OrderedDict
from sysconfig import get_paths
from tempfile import TemporaryDirectory
import numpy as np
import subprocess
from appdirs import user_config_dir, user_cache_dir
from collections import OrderedDict
from appdirs import user_cache_dir, user_config_dir
from pystencils.utils import recursive_dict_update
from sysconfig import get_paths
from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.utils import file_handle_for_atomic_write, atomic_file_write
from pystencils.include import get_pystencils_include_path
from pystencils.utils import (atomic_file_write, file_handle_for_atomic_write,
def make_python_function(kernel_function_node):
def make_python_function(kernel_function_node, custom_backend=None):
Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function
......@@ -75,7 +75,7 @@ def make_python_function(kernel_function_node):
:param kernel_function_node: the abstract syntax tree
:return: kernel functor
result = compile_and_load(kernel_function_node)
result = compile_and_load(kernel_function_node, custom_backend)
return result
......@@ -424,11 +424,12 @@ def run_compile_step(command):
class ExtensionModuleCode:
def __init__(self, module_name='generated'):
def __init__(self, module_name='generated', custom_backend=None):
self.module_name = module_name
self._ast_nodes = []
self._function_names = []
self._custom_backend = custom_backend
def add_function(self, ast, name=None):
......@@ -452,7 +453,7 @@ class ExtensionModuleCode:
for ast, name in zip(self._ast_nodes, self._function_names):
old_name = ast.function_name
ast.function_name = "kernel_" + name
print(generate_c(ast), file=file)
print(generate_c(ast, custom_backend=self._custom_backend), file=file)
print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
ast.function_name = old_name
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
......@@ -515,10 +516,11 @@ def compile_module(code, code_hash, base_dir):
return lib_file
def compile_and_load(ast):
def compile_and_load(ast, custom_backend=None):
cache_config = get_cache_config()
code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c').encode()).hexdigest()
code = ExtensionModuleCode(module_name=code_hash_str)
code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c',
code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend)
code.add_function(ast, ast.function_name)
if cache_config['object_cache'] is False:
......@@ -9,7 +9,7 @@ from pystencils.include import get_pystencils_include_path
def make_python_function(kernel_function_node, argument_dict=None):
def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None):
Creates a kernel function from an abstract syntax tree which
was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel`
......@@ -35,7 +35,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
code = includes + "\n"
code += "#define FUNC_PREFIX __global__\n"
code += "#define RESTRICT __restrict__\n\n"
code += str(generate_c(kernel_function_node, dialect='cuda'))
code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend))
options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment