Commit 99872129 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add possibility to override nvcc arch for tensorflow_jit and compile...

Add possibility to override nvcc arch for tensorflow_jit and compile tensorflow module without loading
parent 97d525cd
Pipeline #17945 failed with stage
in 2 minutes and 24 seconds
......@@ -8,6 +8,7 @@
"""
import hashlib
import os
import subprocess
import sysconfig
from os.path import exists, join
......@@ -45,7 +46,7 @@ except ImportError:
pass
def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
def link(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
"""Compiles given :param:`source_file` to a Tensorflow shared Library.
.. warning::
......@@ -75,11 +76,18 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil
subprocess.check_call(command)
return destination_file
def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
destination_file = link(object_files, destination_file, overwrite_destination_file, additional_link_flags)
lib = tf.load_op_library(destination_file)
return lib
def try_get_cuda_arch_flag():
if 'PYSTENCILS_TENSORFLOW_NVCC_ARCH' in os.environ:
return "-arch=sm_" + os.environ['PYSTENCILS_TENSORFLOW_NVCC_ARCH']
try:
from pycuda.driver import Context
arch = "sm_%d%d" % Context.get_device().compute_capability()
......@@ -143,7 +151,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
return destination_file
def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]):
def compile_sources(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]):
object_files = []
......@@ -172,3 +180,42 @@ def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_f
if module:
print('Loaded Tensorflow module.')
return module
def compile_sources_and_load(host_sources,
cuda_sources=[],
additional_compile_flags=[],
additional_link_flags=[],
compile_only=False):
object_files = []
for source in tqdm(host_sources + cuda_sources, desc='Compiling Tensorflow module'):
is_cuda = source in cuda_sources
if exists(source):
source_code = read_file(source)
else:
source_code = source
file_extension = '.cu' if is_cuda else '.cpp'
file_name = join(pystencils.cache.cache_dir, f'{_hash(source_code.encode()).hexdigest()}{file_extension}')
write_file(file_name, source_code)
compile_file(file_name,
use_nvcc=is_cuda,
overwrite_destination_file=False,
additional_compile_flags=additional_compile_flags)
object_files.append(file_name + '.o')
print('Linking Tensorflow module...')
module_file = link(object_files,
overwrite_destination_file=False,
additional_link_flags=additional_link_flags)
if not compile_only:
module = tf.load_op_library(module_file)
if module:
print('Loaded Tensorflow module.')
return module
else:
return module_file
......@@ -8,6 +8,8 @@
"""
from os.path import exists
import pytest
import sympy
......@@ -44,6 +46,10 @@ def test_tensorflow_jit_gpu():
assert 'call_forward_jit_gpu' in dir(lib)
assert 'call_backward_jit_gpu' in dir(lib)
file_name = pystencils_autodiff.tensorflow_jit.compile_sources_and_load([], [str(module)], compile_only=True)
print(file_name)
assert exists(file_name)
def test_tensorflow_jit_cpu():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment