From 0738e9e0c3971eb5c8023ad6f2274a2dc0d42df5 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Wed, 28 Apr 2021 13:26:26 +0000
Subject: [PATCH] Incorporate header files and compiler flags into object cache
 hash

---
 pystencils/cpu/cpujit.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index fc9b810c3..4da43010a 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -205,10 +205,23 @@ def read_config():
     if config['cache']['object_cache'] is not False:
         config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())
 
-        if config['cache']['clear_cache_on_start']:
+        clear_cache = False
+        cache_status_file = os.path.join(config['cache']['object_cache'], 'last_config.json')
+        if os.path.exists(cache_status_file):
+            # check if compiler config has changed
+            last_config = json.load(open(cache_status_file, 'r'))
+            if set(last_config.items()) != set(config['compiler'].items()):
+                clear_cache = True
+            else:
+                for key in last_config.keys():
+                    if last_config[key] != config['compiler'][key]:
+                        clear_cache = True
+
+        if config['cache']['clear_cache_on_start'] or clear_cache:
             shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)
 
         create_folder(config['cache']['object_cache'], False)
+        json.dump(config['compiler'], open(cache_status_file, 'w'), indent=4)
 
     if config['compiler']['os'] == 'windows':
         from pystencils.cpu.msvc_detection import get_environment
@@ -531,6 +544,9 @@ class ExtensionModuleCode:
         header_list = list(headers)
         header_list.sort()
         header_list.insert(0, '"Python.h"')
+        ps_headers = [os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]) for h in header_list
+                      if os.path.exists(os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]))]
+        header_hash = b''.join([hashlib.sha256(open(h, 'rb').read()).digest() for h in ps_headers])
 
         includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
         self._code_string += includes
@@ -546,7 +562,7 @@ class ExtensionModuleCode:
             self._code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast)
             ast.function_name = old_name
 
-        self._code_hash = "mod_" + hashlib.sha256(self._code_string.encode()).hexdigest()
+        self._code_hash = "mod_" + hashlib.sha256(self._code_string.encode() + header_hash).hexdigest()
         self._code_string += create_module_boilerplate_code(self._code_hash, self._function_names)
 
     def get_hash_of_code(self):
-- 
GitLab