Skip to content
Snippets Groups Projects
Commit 19f54169 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow custom backends for cpujit and gpujit

parent 8a651aa4
No related merge requests found
...@@ -37,7 +37,7 @@ class UnsupportedCDialect(Exception): ...@@ -37,7 +37,7 @@ class UnsupportedCDialect(Exception):
super(UnsupportedCDialect, self).__init__() super(UnsupportedCDialect, self).__init__()
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str: def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
"""Prints an abstract syntax tree node as C or CUDA code. """Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
...@@ -57,8 +57,9 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str ...@@ -57,8 +57,9 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
ast_node.global_variables.update(d.symbols_defined) ast_node.global_variables.update(d.symbols_defined)
else: else:
ast_node.global_variables = d.symbols_defined ast_node.global_variables = d.symbols_defined
if custom_backend:
if dialect == 'c': printer = custom_backend
elif dialect == 'c':
printer = CBackend(signature_only=signature_only, printer = CBackend(signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set) vector_instruction_set=ast_node.instruction_set)
elif dialect == 'cuda': elif dialect == 'cuda':
......
...@@ -43,28 +43,28 @@ Then 'cl.exe' is used to compile. ...@@ -43,28 +43,28 @@ Then 'cl.exe' is used to compile.
For Windows compilers the qualifier should be ``__restrict`` For Windows compilers the qualifier should be ``__restrict``
""" """
import os
import hashlib import hashlib
import json import json
import os
import platform import platform
import shutil import shutil
import subprocess
import textwrap import textwrap
from collections import OrderedDict
from sysconfig import get_paths
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import numpy as np import numpy as np
import subprocess from appdirs import user_cache_dir, user_config_dir
from appdirs import user_config_dir, user_cache_dir
from collections import OrderedDict
from pystencils.utils import recursive_dict_update
from sysconfig import get_paths
from pystencils import FieldType from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.utils import file_handle_for_atomic_write, atomic_file_write
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils.utils import (atomic_file_write, file_handle_for_atomic_write,
recursive_dict_update)
def make_python_function(kernel_function_node): def make_python_function(kernel_function_node, custom_backend=None):
""" """
Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function
...@@ -75,7 +75,7 @@ def make_python_function(kernel_function_node): ...@@ -75,7 +75,7 @@ def make_python_function(kernel_function_node):
:param kernel_function_node: the abstract syntax tree :param kernel_function_node: the abstract syntax tree
:return: kernel functor :return: kernel functor
""" """
result = compile_and_load(kernel_function_node) result = compile_and_load(kernel_function_node, custom_backend)
return result return result
...@@ -424,11 +424,12 @@ def run_compile_step(command): ...@@ -424,11 +424,12 @@ def run_compile_step(command):
class ExtensionModuleCode: class ExtensionModuleCode:
def __init__(self, module_name='generated'): def __init__(self, module_name='generated', custom_backend=None):
self.module_name = module_name self.module_name = module_name
self._ast_nodes = [] self._ast_nodes = []
self._function_names = [] self._function_names = []
self._custom_backend = custom_backend
def add_function(self, ast, name=None): def add_function(self, ast, name=None):
self._ast_nodes.append(ast) self._ast_nodes.append(ast)
...@@ -452,7 +453,7 @@ class ExtensionModuleCode: ...@@ -452,7 +453,7 @@ class ExtensionModuleCode:
for ast, name in zip(self._ast_nodes, self._function_names): for ast, name in zip(self._ast_nodes, self._function_names):
old_name = ast.function_name old_name = ast.function_name
ast.function_name = "kernel_" + name ast.function_name = "kernel_" + name
print(generate_c(ast), file=file) print(generate_c(ast, custom_backend=self._custom_backend), file=file)
print(create_function_boilerplate_code(ast.get_parameters(), name), file=file) print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
ast.function_name = old_name ast.function_name = old_name
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
...@@ -515,10 +516,11 @@ def compile_module(code, code_hash, base_dir): ...@@ -515,10 +516,11 @@ def compile_module(code, code_hash, base_dir):
return lib_file return lib_file
def compile_and_load(ast): def compile_and_load(ast, custom_backend=None):
cache_config = get_cache_config() cache_config = get_cache_config()
code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c').encode()).hexdigest() code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c',
code = ExtensionModuleCode(module_name=code_hash_str) custom_backend=custom_backend).encode()).hexdigest()
code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend)
code.add_function(ast, ast.function_name) code.add_function(ast, ast.function_name)
if cache_config['object_cache'] is False: if cache_config['object_cache'] is False:
......
...@@ -9,7 +9,7 @@ from pystencils.include import get_pystencils_include_path ...@@ -9,7 +9,7 @@ from pystencils.include import get_pystencils_include_path
USE_FAST_MATH = True USE_FAST_MATH = True
def make_python_function(kernel_function_node, argument_dict=None): def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None):
""" """
Creates a kernel function from an abstract syntax tree which Creates a kernel function from an abstract syntax tree which
was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel` was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel`
...@@ -35,7 +35,7 @@ def make_python_function(kernel_function_node, argument_dict=None): ...@@ -35,7 +35,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
code = includes + "\n" code = includes + "\n"
code += "#define FUNC_PREFIX __global__\n" code += "#define FUNC_PREFIX __global__\n"
code += "#define RESTRICT __restrict__\n\n" code += "#define RESTRICT __restrict__\n\n"
code += str(generate_c(kernel_function_node, dialect='cuda')) code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend))
options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"] options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
if USE_FAST_MATH: if USE_FAST_MATH:
options.append("-use_fast_math") options.append("-use_fast_math")
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment