From 19f54169af547d8d33b42b5948b174c750476059 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 15 Jul 2019 09:18:36 +0200
Subject: [PATCH] Allow custom backends for cpujit and gpujit

---
 pystencils/backends/cbackend.py |  7 ++++---
 pystencils/cpu/cpujit.py        | 30 ++++++++++++++++--------------
 pystencils/gpucuda/cudajit.py   |  4 ++--
 3 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index cbd75e1..9cab5a2 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -37,7 +37,7 @@ class UnsupportedCDialect(Exception):
         super(UnsupportedCDialect, self).__init__()
 
 
-def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
+def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
     """Prints an abstract syntax tree node as C or CUDA code.
 
     This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
@@ -57,8 +57,9 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
             ast_node.global_variables.update(d.symbols_defined)
         else:
             ast_node.global_variables = d.symbols_defined
-
-    if dialect == 'c':
+    if custom_backend:
+        printer = custom_backend
+    elif dialect == 'c':
         printer = CBackend(signature_only=signature_only,
                            vector_instruction_set=ast_node.instruction_set)
     elif dialect == 'cuda':
diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index 0912836..76255c1 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -43,28 +43,28 @@ Then 'cl.exe' is used to compile.
   For Windows compilers the qualifier should be ``__restrict``
 
 """
-import os
 import hashlib
 import json
+import os
 import platform
 import shutil
+import subprocess
 import textwrap
+from collections import OrderedDict
+from sysconfig import get_paths
 from tempfile import TemporaryDirectory
 
 import numpy as np
-import subprocess
-from appdirs import user_config_dir, user_cache_dir
-from collections import OrderedDict
+from appdirs import user_cache_dir, user_config_dir
 
-from pystencils.utils import recursive_dict_update
-from sysconfig import get_paths
 from pystencils import FieldType
 from pystencils.backends.cbackend import generate_c, get_headers
-from pystencils.utils import file_handle_for_atomic_write, atomic_file_write
 from pystencils.include import get_pystencils_include_path
+from pystencils.utils import (atomic_file_write, file_handle_for_atomic_write,
+                              recursive_dict_update)
 
 
-def make_python_function(kernel_function_node):
+def make_python_function(kernel_function_node, custom_backend=None):
     """
     Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function
 
@@ -75,7 +75,7 @@ def make_python_function(kernel_function_node):
     :param kernel_function_node: the abstract syntax tree
     :return: kernel functor
     """
-    result = compile_and_load(kernel_function_node)
+    result = compile_and_load(kernel_function_node, custom_backend)
     return result
 
 
@@ -424,11 +424,12 @@ def run_compile_step(command):
 
 
 class ExtensionModuleCode:
-    def __init__(self, module_name='generated'):
+    def __init__(self, module_name='generated', custom_backend=None):
         self.module_name = module_name
 
         self._ast_nodes = []
         self._function_names = []
+        self._custom_backend = custom_backend
 
     def add_function(self, ast, name=None):
         self._ast_nodes.append(ast)
@@ -452,7 +453,7 @@ class ExtensionModuleCode:
         for ast, name in zip(self._ast_nodes, self._function_names):
             old_name = ast.function_name
             ast.function_name = "kernel_" + name
-            print(generate_c(ast), file=file)
+            print(generate_c(ast, custom_backend=self._custom_backend), file=file)
             print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
             ast.function_name = old_name
         print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
@@ -515,10 +516,11 @@ def compile_module(code, code_hash, base_dir):
     return lib_file
 
 
-def compile_and_load(ast):
+def compile_and_load(ast, custom_backend=None):
     cache_config = get_cache_config()
-    code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c').encode()).hexdigest()
-    code = ExtensionModuleCode(module_name=code_hash_str)
+    code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c',
+                                                       custom_backend=custom_backend).encode()).hexdigest()
+    code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend)
     code.add_function(ast, ast.function_name)
 
     if cache_config['object_cache'] is False:
diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py
index 90c23b1..2146543 100644
--- a/pystencils/gpucuda/cudajit.py
+++ b/pystencils/gpucuda/cudajit.py
@@ -9,7 +9,7 @@ from pystencils.include import get_pystencils_include_path
 USE_FAST_MATH = True
 
 
-def make_python_function(kernel_function_node, argument_dict=None):
+def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None):
     """
     Creates a kernel function from an abstract syntax tree which
     was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel`
@@ -35,7 +35,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
     code = includes + "\n"
     code += "#define FUNC_PREFIX __global__\n"
     code += "#define RESTRICT __restrict__\n\n"
-    code += str(generate_c(kernel_function_node, dialect='cuda'))
+    code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend))
     options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
     if USE_FAST_MATH:
         options.append("-use_fast_math")
-- 
GitLab