Commit 9ed3556c authored by Stephan Seitz
Browse files

Make openmp for tf cpu optional

parent 8f91bec3
......@@ -137,7 +137,12 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH:
_nvcc_flags.append('-use_fast_math')
def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True, additional_compile_flags=[]):
def compile_file(file,
use_nvcc=False,
nvcc='nvcc',
overwrite_destination_file=True,
additional_compile_flags=[],
openmp=True):
if 'tensorflow_host_compiler' not in get_compiler_config():
get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
import tensorflow as tf
......@@ -155,7 +160,6 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
'cu',
'-Xcompiler',
_position_independent_flag,
_openmp_flag,
_do_not_link_flag,
*tf.sysconfig.get_compile_flags(),
*_include_flags,
......@@ -166,13 +170,14 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
*(get_compiler_config()['flags']).split(' '),
file,
_do_not_link_flag,
_openmp_flag,
*tf.sysconfig.get_compile_flags(),
*_include_flags,
*additional_compile_flags,
_output_flag]
destination_file = f'{file}_{_hash(".".join(command_prefix).encode()).hexdigest()}{_object_file_extension}'
if openmp:
command_prefix.append(_output_flag)
if not exists(destination_file) or overwrite_destination_file:
command = command_prefix + [destination_file]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment