diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index fc9b810c30977825c503b560ba17244777bd7930..4da43010a44945920687a84db2f5fa3bb848b01c 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -205,10 +205,23 @@ def read_config(): if config['cache']['object_cache'] is not False: config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid()) - if config['cache']['clear_cache_on_start']: + clear_cache = False + cache_status_file = os.path.join(config['cache']['object_cache'], 'last_config.json') + if os.path.exists(cache_status_file): + # check if compiler config has changed + last_config = json.load(open(cache_status_file, 'r')) + if set(last_config.items()) != set(config['compiler'].items()): + clear_cache = True + else: + for key in last_config.keys(): + if last_config[key] != config['compiler'][key]: + clear_cache = True + + if config['cache']['clear_cache_on_start'] or clear_cache: shutil.rmtree(config['cache']['object_cache'], ignore_errors=True) create_folder(config['cache']['object_cache'], False) + json.dump(config['compiler'], open(cache_status_file, 'w'), indent=4) if config['compiler']['os'] == 'windows': from pystencils.cpu.msvc_detection import get_environment @@ -531,6 +544,9 @@ class ExtensionModuleCode: header_list = list(headers) header_list.sort() header_list.insert(0, '"Python.h"') + ps_headers = [os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]) for h in header_list + if os.path.exists(os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]))] + header_hash = b''.join([hashlib.sha256(open(h, 'rb').read()).digest() for h in ps_headers]) includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list]) self._code_string += includes @@ -546,7 +562,7 @@ class ExtensionModuleCode: self._code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast) ast.function_name = old_name - self._code_hash = "mod_" + hashlib.sha256(self._code_string.encode()).hexdigest() + self._code_hash = "mod_" + hashlib.sha256(self._code_string.encode() + header_hash).hexdigest() self._code_string += create_module_boilerplate_code(self._code_hash, self._function_names) def get_hash_of_code(self):