From 19f54169af547d8d33b42b5948b174c750476059 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 15 Jul 2019 09:18:36 +0200 Subject: [PATCH] Allow custom backends for cpujit and gpujit --- pystencils/backends/cbackend.py | 7 ++++--- pystencils/cpu/cpujit.py | 30 ++++++++++++++++-------------- pystencils/gpucuda/cudajit.py | 4 ++-- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index cbd75e1..9cab5a2 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -37,7 +37,7 @@ class UnsupportedCDialect(Exception): super(UnsupportedCDialect, self).__init__() -def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str: +def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str: """Prints an abstract syntax tree node as C or CUDA code. This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded @@ -57,8 +57,9 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str ast_node.global_variables.update(d.symbols_defined) else: ast_node.global_variables = d.symbols_defined - - if dialect == 'c': + if custom_backend: + printer = custom_backend + elif dialect == 'c': printer = CBackend(signature_only=signature_only, vector_instruction_set=ast_node.instruction_set) elif dialect == 'cuda': diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 0912836..76255c1 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -43,28 +43,28 @@ Then 'cl.exe' is used to compile. For Windows compilers the qualifier should be ``__restrict`` """ -import os import hashlib import json +import os import platform import shutil +import subprocess import textwrap +from collections import OrderedDict +from sysconfig import get_paths from tempfile import TemporaryDirectory import numpy as np -import subprocess -from appdirs import user_config_dir, user_cache_dir -from collections import OrderedDict +from appdirs import user_cache_dir, user_config_dir -from pystencils.utils import recursive_dict_update -from sysconfig import get_paths from pystencils import FieldType from pystencils.backends.cbackend import generate_c, get_headers -from pystencils.utils import file_handle_for_atomic_write, atomic_file_write from pystencils.include import get_pystencils_include_path +from pystencils.utils import (atomic_file_write, file_handle_for_atomic_write, + recursive_dict_update) -def make_python_function(kernel_function_node): +def make_python_function(kernel_function_node, custom_backend=None): """ Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function @@ -75,7 +75,7 @@ def make_python_function(kernel_function_node): :param kernel_function_node: the abstract syntax tree :return: kernel functor """ - result = compile_and_load(kernel_function_node) + result = compile_and_load(kernel_function_node, custom_backend) return result @@ -424,11 +424,12 @@ def run_compile_step(command): class ExtensionModuleCode: - def __init__(self, module_name='generated'): + def __init__(self, module_name='generated', custom_backend=None): self.module_name = module_name self._ast_nodes = [] self._function_names = [] + self._custom_backend = custom_backend def add_function(self, ast, name=None): self._ast_nodes.append(ast) @@ -452,7 +453,7 @@ class ExtensionModuleCode: for ast, name in zip(self._ast_nodes, self._function_names): old_name = ast.function_name ast.function_name = "kernel_" + name - print(generate_c(ast), file=file) + print(generate_c(ast, custom_backend=self._custom_backend), file=file) print(create_function_boilerplate_code(ast.get_parameters(), name), file=file) ast.function_name = old_name print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) @@ -515,10 +516,11 @@ def compile_module(code, code_hash, base_dir): return lib_file -def compile_and_load(ast): +def compile_and_load(ast, custom_backend=None): cache_config = get_cache_config() - code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c').encode()).hexdigest() - code = ExtensionModuleCode(module_name=code_hash_str) + code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c', + custom_backend=custom_backend).encode()).hexdigest() + code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend) code.add_function(ast, ast.function_name) if cache_config['object_cache'] is False: diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py index 90c23b1..2146543 100644 --- a/pystencils/gpucuda/cudajit.py +++ b/pystencils/gpucuda/cudajit.py @@ -9,7 +9,7 @@ from pystencils.include import get_pystencils_include_path USE_FAST_MATH = True -def make_python_function(kernel_function_node, argument_dict=None): +def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None): """ Creates a kernel function from an abstract syntax tree which was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel` @@ -35,7 +35,7 @@ def make_python_function(kernel_function_node, argument_dict=None): code = includes + "\n" code += "#define FUNC_PREFIX __global__\n" code += "#define RESTRICT __restrict__\n\n" - code += str(generate_c(kernel_function_node, dialect='cuda')) + code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend)) options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"] if USE_FAST_MATH: options.append("-use_fast_math") -- GitLab