Memoization in One Line of Python
by Steve Marx on
In my last post, I laid out four steps to dynamic programming:
1. Formulate a recursive solution.
2. Use memoization to avoid recalculating subproblems.
3. Use a bottom-up approach to avoid recursion.
4. Optimize space efficiency by forgetting results you no longer need.
To optimize time complexity, you can actually stop after step 2. A memoized solution is already optimal in that respect, because each required subproblem is calculated only once.
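For contrast, here is a rough sketch of what steps 3 and 4 might look like for the Fibonacci numbers, using the convention fib(1) = fib(2) = 1. The function names fib_bottom_up and fib_constant_space are made up for this illustration, and both assume n is at least 1.

```python
def fib_bottom_up(n):
    """Step 3: fill a table from the smallest subproblem upward (no recursion)."""
    table = [0, 1, 1]  # table[i] holds the i-th Fibonacci number; index 0 is padding
    for i in range(3, n + 1):
        table.append(table[i - 1] + table[i - 2])
    return table[n]


def fib_constant_space(n):
    """Step 4: keep only the two most recent values instead of the whole table."""
    previous, current = 1, 1  # fib(1) and fib(2)
    for _ in range(3, n + 1):
        previous, current = current, previous + current
    return current


print(fib_bottom_up(10), fib_constant_space(10))  # 55 55
```

Both do the same amount of arithmetic as a memoized version. What they change is the recursion and the memory use, not the time complexity.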
Memoization is an easy way to speed up a slow recursive program. Python 3.9 added a decorator called cache (in the functools module) that adds memoization to a function with just one line of code! (In older versions of Python, you can use the equivalent lru_cache(maxsize=None).)
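If the same code has to run on both newer and older Python versions, one possible approach is an import fallback like the sketch below. The square function is a hypothetical placeholder, only there to show the decorator being applied.

```python
try:
    from functools import cache  # available in Python 3.9 and later
except ImportError:
    from functools import lru_cache

    # lru_cache(maxsize=None) is an unbounded cache, which is what cache provides
    cache = lru_cache(maxsize=None)


@cache
def square(n):
    # hypothetical example function, only here to show the decorator in use
    return n * n


print(square(4))  # computed on the first call
print(square(4))  # served from the cache on the second call
```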
Revisiting everybody’s favorite recursive example, here’s a Python function to compute the nth Fibonacci number:
    def fib(n):
        return 1 if n <= 2 else fib(n-1) + fib(n-2)
As explained in my last post, this solution is quite inefficient because of how many times it has to recalculate the same numbers. Memoization can fix this by remembering previously calculated values rather than recalculating them each time.
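To make the difference concrete, here is a small sketch that counts function calls with and without a hand-rolled cache. The names fib_naive, fib_memo, memo, and call_count are invented for this illustration. The dictionary plays the role that a memoization decorator would otherwise handle for you.

```python
call_count = 0


def fib_naive(n):
    """The plain recursive version, instrumented to count calls."""
    global call_count
    call_count += 1
    return 1 if n <= 2 else fib_naive(n - 1) + fib_naive(n - 2)


memo = {}  # maps n to the already-computed nth Fibonacci number


def fib_memo(n):
    """Same recursion, but each result is stored and reused."""
    global call_count
    call_count += 1
    if n not in memo:
        memo[n] = 1 if n <= 2 else fib_memo(n - 1) + fib_memo(n - 2)
    return memo[n]


call_count = 0
fib_naive(20)
naive_calls = call_count

call_count = 0
fib_memo(20)
memo_calls = call_count

print(naive_calls, memo_calls)  # prints: 13529 37
```

The naive version's call count grows exponentially with n, while the memoized version's grows linearly.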
By adding @cache above our function, we get an optimal memoized version:
    from functools import cache

    @cache
    def fib(n):
        return 1 if n <= 2 else fib(n-1) + fib(n-2)
That’s really it!
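If you want to check that the cache is doing its job, the decorated function exposes cache_info() and cache_clear(). A minimal sketch, assuming Python 3.9 or later:

```python
from functools import cache


@cache
def fib(n):
    return 1 if n <= 2 else fib(n - 1) + fib(n - 2)


print(fib(50))           # 12586269025
print(fib.cache_info())  # hit/miss counters and current cache size
fib.cache_clear()        # empties the cache if you ever need a fresh start
```

Each miss corresponds to a subproblem that was actually computed, so the miss count should equal the number of distinct values of n that were needed.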
Using @cache to solve Advent of Code Day 10
Day 10 part 2 can be solved using memoization.
In brief, you’re given a bunch of adapters with different “joltage” ratings. You need to form a chain of these adapters from a wall socket, which counts as 0 jolts, to the highest rated adapter you have. Each adapter can only connect to an adapter that’s 1-3 jolts lower than its rating.
The challenge is to find how many different ways you can connect the adapters to reach that final adapter.
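To get a feel for the problem, here is a brute-force sketch that lists every valid chain for a tiny, made-up set of adapters. The chains function and the ratings 1, 2, 3, 6 are hypothetical and not taken from the puzzle. Enumerating chains like this is only practical for very small inputs.

```python
def chains(adapters):
    """Enumerate every valid chain from the socket (0) to the highest adapter."""
    available = set(adapters)
    goal = max(available)

    def extend(chain):
        last = chain[-1]
        if last == goal:
            yield chain
            return
        # the next adapter must be rated 1-3 jolts higher than the current one
        for step in (1, 2, 3):
            if last + step in available:
                yield from extend(chain + [last + step])

    return list(extend([0]))


for chain in chains([1, 2, 3, 6]):  # made-up ratings
    print(chain)
```

For these ratings there are four chains: [0, 1, 2, 3, 6], [0, 1, 3, 6], [0, 2, 3, 6], and [0, 3, 6].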
Here’s a solution using recursion and @cache. Note that data is a variable containing the raw data file from Advent of Code, where each line represents an adapter’s joltage rating:
    from functools import cache

    # socket is implicitly 0 "jolts"
    joltages = set(map(int, data.split('\n'))).union([0])

    # we always have to reach the highest adapter
    goal = max(joltages)

    @cache
    def ways(n):
        # base case of the highest adapter
        if n == goal:
            return 1

        # base case of non-existent adapter
        if n not in joltages:
            return 0

        return ways(n+1) + ways(n+2) + ways(n+3)

    print(ways(0))
Bonus: dynamic programming solution
This is the actual code I wrote when I originally solved Day 10 (with added comments):
    from collections import defaultdict

    # socket is implicitly 0 "jolts"
    joltages = list(map(int, data.split('\n'))) + [0]

    # we always have to reach the highest adapter
    goal = max(joltages)

    ways = defaultdict(lambda: 0)

    # base case of the highest adapter
    ways[goal] = 1

    # from second largest to smallest
    for n in sorted(joltages, reverse=True)[1:]:
        ways[n] = ways[n+1] + ways[n+2] + ways[n+3]

    print(ways[0])
To me, the dynamic programming solution is natural enough that I skipped right past the recursive approach.
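As a sanity check, both approaches can be wrapped in functions and compared on a small input. This is only a sketch: the names count_memoized and count_bottom_up, and the sample ratings, are made up for the comparison and are not real puzzle data.

```python
from collections import defaultdict
from functools import cache


def count_memoized(adapters):
    """Top-down: recursion plus @cache."""
    joltages = set(adapters) | {0}  # the socket counts as 0 jolts
    goal = max(joltages)

    @cache
    def ways(n):
        if n == goal:
            return 1
        if n not in joltages:
            return 0
        return ways(n + 1) + ways(n + 2) + ways(n + 3)

    return ways(0)


def count_bottom_up(adapters):
    """Bottom-up: fill the table from the highest adapter down to the socket."""
    joltages = list(adapters) + [0]
    goal = max(joltages)
    ways = defaultdict(int)
    ways[goal] = 1
    for n in sorted(joltages, reverse=True)[1:]:
        ways[n] = ways[n + 1] + ways[n + 2] + ways[n + 3]
    return ways[0]


sample = [1, 2, 3, 4]  # made-up ratings, not real puzzle input
print(count_memoized(sample), count_bottom_up(sample))  # 7 7
```

Defining ways inside count_memoized gives each call its own cache, so results from one set of adapters cannot leak into another.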
Note that both cache and lru_cache require that all the parameters to your function are hashable. This is because a dictionary is used to remember past results, keyed by the function parameters.
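Here is a short sketch of what that restriction looks like in practice. The total function is a hypothetical helper. Passing a tuple works, while passing a list raises a TypeError before the function body ever runs.

```python
from functools import cache


@cache
def total(values):
    # hypothetical helper: sums a sequence, caching by the argument itself
    return sum(values)


print(total((1, 2, 3)))  # a tuple is hashable, so this works

try:
    total([1, 2, 3])  # a list is not hashable
except TypeError as error:
    print("TypeError:", error)
```

A common workaround is to convert mutable arguments to an immutable equivalent, such as a tuple or frozenset, before calling the cached function.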