From 956c89a049255cf0a0ac52e813c45171329275d7 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 20 Apr 2018 10:09:07 +0200
Subject: [PATCH] Bug fix for shared library cache -> switched to atomic
 filesystem write

- when running multiple pystencils instances, sometimes errors happened
  because one process might have partially written a cached file, which
  is already read before writing was finished
-> switched to "atomic write" (only on linux yet) that uses os.rename
   which is guaranteed to be atomic
---
 cpu/cpujit.py | 41 +++++++++++++++++++++--------------------
 utils.py      | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index 7f32f2b5e..b793302e2 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -76,7 +76,7 @@ from collections import OrderedDict
 from pystencils.transformations import symbol_name_to_variable_name
 from pystencils.data_types import to_ctypes, get_base_type, StructType
 from pystencils.field import FieldType
-from pystencils.utils import recursive_dict_update
+from pystencils.utils import recursive_dict_update, file_handle_for_atomic_write, atomic_file_write
 
 
 def make_python_function(kernel_function_node, argument_dict={}):
@@ -263,19 +263,18 @@ def compile_object_cache_to_shared_library():
 atexit.register(compile_object_cache_to_shared_library)
 
 
-def generate_code(ast, restrict_qualifier, function_prefix, target_file):
+def generate_code(ast, restrict_qualifier, function_prefix, source_file):
     headers = get_headers(ast)
     headers.update(['<cmath>', '<cstdint>'])
 
-    with open(target_file, 'w') as source_file:
-        code = generate_c(ast)
-        includes = "\n".join(["#include %s" % (include_file,) for include_file in headers])
-        print(includes, file=source_file)
-        print("#define RESTRICT %s" % (restrict_qualifier,), file=source_file)
-        print("#define FUNC_PREFIX %s" % (function_prefix,), file=source_file)
-        print('extern "C" { ', file=source_file)
-        print(code, file=source_file)
-        print('}', file=source_file)
+    code = generate_c(ast)
+    includes = "\n".join(["#include %s" % (include_file,) for include_file in headers])
+    print(includes, file=source_file)
+    print("#define RESTRICT %s" % (restrict_qualifier,), file=source_file)
+    print("#define FUNC_PREFIX %s" % (function_prefix,), file=source_file)
+    print('extern "C" { ', file=source_file)
+    print(code, file=source_file)
+    print('}', file=source_file)
 
 
 def run_compile_step(command):
@@ -298,16 +297,18 @@ def compile_linux(ast, code_hash_str, src_file, lib_file):
     compiler_config = get_compiler_config()
 
     object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.o')
-    # Compilation
     if not os.path.exists(object_file):
-        generate_code(ast, compiler_config['restrict_qualifier'], '', src_file)
-        compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
-        compile_cmd += ['-o', object_file, src_file]
-        run_compile_step(compile_cmd)
+        with file_handle_for_atomic_write(src_file) as f:
+            generate_code(ast, compiler_config['restrict_qualifier'], '', f)
+        with atomic_file_write(object_file) as file_name:
+            compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
+            compile_cmd += ['-o', file_name, src_file]
+            run_compile_step(compile_cmd)
 
     # Linking
-    run_compile_step([compiler_config['command'], '-shared', object_file, '-o', lib_file] +
-                     compiler_config['flags'].split())
+    with atomic_file_write(lib_file) as file_name:
+        run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name] +
+                         compiler_config['flags'].split())
 
 
 def compile_windows(ast, code_hash_str, src_file, lib_file):
@@ -317,8 +318,8 @@ def compile_windows(ast, code_hash_str, src_file, lib_file):
     object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.obj')
     # Compilation
     if not os.path.exists(object_file):
-        generate_code(ast, compiler_config['restrict_qualifier'],
-                      '__declspec(dllexport)', src_file)
+        with file_handle_for_atomic_write(src_file) as f:
+            generate_code(ast, compiler_config['restrict_qualifier'], '__declspec(dllexport)', f)
 
         # /c compiles only, /EHsc turns of exception handling in c code
         compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
diff --git a/utils.py b/utils.py
index 57abda263..bc7833629 100644
--- a/utils.py
+++ b/utils.py
@@ -1,3 +1,6 @@
+import os
+from tempfile import NamedTemporaryFile
+from contextlib import contextmanager
 from typing import Mapping
 
 
@@ -34,3 +37,32 @@ def recursive_dict_update(d, u):
         else:
             d[k] = u[k]
     return d
+
+
+@contextmanager
+def file_handle_for_atomic_write(file_path):
+    """Open temporary file object that atomically moves to destination upon exiting.
+
+    Allows reading and writing to and from the same filename.
+    The file will not be moved to destination in case of an exception.
+
+    Args:
+        file_path: path to file to be opened
+    """
+    target_folder = os.path.dirname(os.path.abspath(file_path))
+    with NamedTemporaryFile(delete=False, dir=target_folder, mode='w') as f:
+        try:
+            yield f
+        finally:
+            f.flush()
+            os.fsync(f.fileno())
+    os.rename(f.name, file_path)
+
+
+@contextmanager
+def atomic_file_write(file_path):
+    target_folder = os.path.dirname(os.path.abspath(file_path))
+    with NamedTemporaryFile(delete=False, dir=target_folder) as f:
+        f.file.close()
+        yield f.name
+    os.rename(f.name, file_path)
-- 
GitLab