Skip to content
Snippets Groups Projects
Commit 3dacf1e2 authored by Frederik Hennig's avatar Frederik Hennig Committed by Markus Holzer
Browse files

sharedmethodcache

parent 825be1df
1 merge request!285sharedmethodcache
......@@ -32,6 +32,35 @@ def memorycache_if_hashable(maxsize=128, typed=False):
return wrapper
def sharedmethodcache(cache_id: str):
"""Decorator for memoization of instance methods, allowing multiple methods to use the same cache.
This decorator caches results of instance methods per instantiated object of the surrounding class.
It allows multiple methods to use the same cache, by passing them the same `cache_id` string.
Cached values are stored in a dictionary, which is added as a member `self.<cache_id>` to the
`self` object instance. Make sure that this doesn't cause any naming conflicts with other members!
Of course, for this to be useful, said methods must have the same signature (up to additional kwargs)
and must return the same result when called with the same arguments."""
def _decorator(user_method):
def _decorated_func(self, *args, **kwargs):
objdict = self.__dict__
cache = objdict.setdefault(cache_id, dict())
key = args
for item in kwargs.items():
key += item
if key not in cache:
result = user_method(self, *args, **kwargs)
cache[key] = result
return result
else:
return cache[key]
return _decorated_func
return _decorator
# Disable memory cache:
# disk_cache = lambda o: o
# disk_cache_no_fallback = lambda o: o
from pystencils.cache import sharedmethodcache
class Fib:
def __init__(self):
self.fib_rec_called = 0
self.fib_iter_called = 0
@sharedmethodcache("fib_cache")
def fib_rec(self, n):
self.fib_rec_called += 1
return 1 if n <= 1 else self.fib_rec(n-1) + self.fib_rec(n-2)
@sharedmethodcache("fib_cache")
def fib_iter(self, n):
self.fib_iter_called += 1
f1, f2 = 0, 1
for i in range(n):
f2 = f1 + f2
f1 = f2 - f1
return f2
def test_fib_memoization_1():
fib = Fib()
assert "fib_cache" not in fib.__dict__
f13 = fib.fib_rec(13)
assert fib.fib_rec_called == 14
assert "fib_cache" in fib.__dict__
assert fib.fib_cache[(13,)] == f13
for k in range(14):
# fib_iter should use cached results from fib_rec
fib.fib_iter(k)
assert fib.fib_iter_called == 0
def test_fib_memoization_2():
fib = Fib()
f11 = fib.fib_iter(11)
f12 = fib.fib_iter(12)
assert fib.fib_iter_called == 2
f13 = fib.fib_rec(13)
# recursive calls should be cached
assert fib.fib_rec_called == 1
class Triad:
def __init__(self):
self.triad_called = 0
@sharedmethodcache("triad_cache")
def triad(self, a, b, c=0):
self.triad_called += 1
return a * b + c
def test_triab_memoization():
triad = Triad()
t = triad.triad(12, 4, 15)
assert triad.triad_called == 1
assert triad.triad_cache[(12, 4, 15)] == t
t = triad.triad(12, 4, c=15)
assert triad.triad_called == 2
assert triad.triad_cache[(12, 4, 'c', 15)] == t
t = triad.triad(12, 4, 15)
assert triad.triad_called == 2
t = triad.triad(12, 4, c=15)
assert triad.triad_called == 2
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment