From 3dacf1e28061e78dfc77213617f38f84e3352970 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 16 Mar 2022 10:19:59 +0100 Subject: [PATCH] sharedmethodcache --- pystencils/cache.py | 29 ++++++++ pystencils_tests/test_sharedmethodcache.py | 82 ++++++++++++++++++++++ 2 files changed, 111 insertions(+) create mode 100644 pystencils_tests/test_sharedmethodcache.py diff --git a/pystencils/cache.py b/pystencils/cache.py index f29678920..b8ac2b06e 100644 --- a/pystencils/cache.py +++ b/pystencils/cache.py @@ -32,6 +32,35 @@ def memorycache_if_hashable(maxsize=128, typed=False): return wrapper + +def sharedmethodcache(cache_id: str): + """Decorator for memoization of instance methods, allowing multiple methods to use the same cache. + + This decorator caches results of instance methods per instantiated object of the surrounding class. + It allows multiple methods to use the same cache, by passing them the same `cache_id` string. + Cached values are stored in a dictionary, which is added as a member `self.<cache_id>` to the + `self` object instance. Make sure that this doesn't cause any naming conflicts with other members! + Of course, for this to be useful, said methods must have the same signature (up to additional kwargs) + and must return the same result when called with the same arguments.""" + def _decorator(user_method): + def _decorated_func(self, *args, **kwargs): + objdict = self.__dict__ + cache = objdict.setdefault(cache_id, dict()) + + key = args + for item in kwargs.items(): + key += item + + if key not in cache: + result = user_method(self, *args, **kwargs) + cache[key] = result + return result + else: + return cache[key] + return _decorated_func + return _decorator + + # Disable memory cache: # disk_cache = lambda o: o # disk_cache_no_fallback = lambda o: o diff --git a/pystencils_tests/test_sharedmethodcache.py b/pystencils_tests/test_sharedmethodcache.py new file mode 100644 index 000000000..7489dd61b --- /dev/null +++ b/pystencils_tests/test_sharedmethodcache.py @@ -0,0 +1,82 @@ +from pystencils.cache import sharedmethodcache + + +class Fib: + + def __init__(self): + self.fib_rec_called = 0 + self.fib_iter_called = 0 + + @sharedmethodcache("fib_cache") + def fib_rec(self, n): + self.fib_rec_called += 1 + return 1 if n <= 1 else self.fib_rec(n-1) + self.fib_rec(n-2) + + @sharedmethodcache("fib_cache") + def fib_iter(self, n): + self.fib_iter_called += 1 + f1, f2 = 0, 1 + for i in range(n): + f2 = f1 + f2 + f1 = f2 - f1 + return f2 + + +def test_fib_memoization_1(): + fib = Fib() + + assert "fib_cache" not in fib.__dict__ + + f13 = fib.fib_rec(13) + assert fib.fib_rec_called == 14 + assert "fib_cache" in fib.__dict__ + assert fib.fib_cache[(13,)] == f13 + + for k in range(14): + # fib_iter should use cached results from fib_rec + fib.fib_iter(k) + + assert fib.fib_iter_called == 0 + + +def test_fib_memoization_2(): + fib = Fib() + + f11 = fib.fib_iter(11) + f12 = fib.fib_iter(12) + + assert fib.fib_iter_called == 2 + + f13 = fib.fib_rec(13) + + # recursive calls should be cached + assert fib.fib_rec_called == 1 + + +class Triad: + + def __init__(self): + self.triad_called = 0 + + @sharedmethodcache("triad_cache") + def triad(self, a, b, c=0): + self.triad_called += 1 + return a * b + c + + +def test_triab_memoization(): + triad = Triad() + + t = triad.triad(12, 4, 15) + assert triad.triad_called == 1 + assert triad.triad_cache[(12, 4, 15)] == t + + t = triad.triad(12, 4, c=15) + assert triad.triad_called == 2 + assert triad.triad_cache[(12, 4, 'c', 15)] == t + + t = triad.triad(12, 4, 15) + assert triad.triad_called == 2 + + t = triad.triad(12, 4, c=15) + assert triad.triad_called == 2 -- GitLab