Memoization in One Line of Python
by Steve Marx
In my last post, I laid out four steps to dynamic programming:
- Formulate a recursive solution.
- Use memoization to avoid recalculating subproblems.
- Use a bottom-up approach to avoid recursion.
- Optimize space efficiency by forgetting results you no longer need.
To optimize time complexity, you can actually stop after step 2. A memoized solution is already optimal in that respect because each required subproblem is calculated only once.
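For comparison, here is a rough sketch of what steps 3 and 4 might look like for the Fibonacci function used later in this post. The name fib_bottom_up is made up for illustration, and the indexing (fib(1) and fib(2) are both 1) is assumed to match the recursive version below:

```python
def fib_bottom_up(n):
    """Return the nth Fibonacci number, with fib(1) == fib(2) == 1."""
    # Step 3: build up from the smallest subproblems instead of recursing.
    # Step 4: keep only the two most recent values, not the whole table.
    prev, curr = 1, 1
    for _ in range(n - 2):
        prev, curr = curr, prev + curr
    return curr


print(fib_bottom_up(10))  # 55
```

This does the same amount of work as the memoized version; what it saves is memory and recursion depth.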
Memoization is an easy way to speed up a slow recursive program. Python 3.9 added a decorator called cache (in the functools module) that adds memoization to a function with just one line of code! (In older versions of Python, you can use the equivalent lru_cache(maxsize=None).)
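If the same code has to run on both old and new Python versions, one possible approach is a fallback import. This is only a sketch; slow_square and the calls list are hypothetical names used to show that the function body runs once per distinct argument:

```python
try:
    from functools import cache  # available in Python 3.9+
except ImportError:
    from functools import lru_cache
    # An unbounded LRU cache behaves the same way as cache.
    cache = lru_cache(maxsize=None)

calls = []  # records every time the function body actually runs


@cache
def slow_square(n):
    calls.append(n)
    return n * n


slow_square(4)
slow_square(4)  # answered from the cache; the body does not run again
print(calls)    # [4]
```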
Memoized Fibonacci
Revisiting everybody’s favorite recursive example, here’s a Python function to compute the nth Fibonacci number:
def fib(n):
    return 1 if n <= 2 else fib(n-1) + fib(n-2)
As explained in my last post, this solution is quite inefficient because of how many times it has to recalculate the same numbers. Memoization can fix this by remembering previously calculated values rather than recalculating them each time.
By adding @cache above our function, we get an optimal memoized version:
from functools import cache

@cache
def fib(n):
    return 1 if n <= 2 else fib(n-1) + fib(n-2)
That’s really it!
Using @cache to solve Advent of Code Day 10
Advent of Code is an annual advent calendar of programming puzzles. This is my first year participating, and I’ve been having a blast!
Day 10 part 2 can be solved using memoization.
In brief, you’re given a bunch of adapters with different “joltage” ratings. You need to form a chain of these adapters from a wall socket, which counts as 0 jolts, to the highest rated adapter you have. Each adapter can only connect to an adapter that’s 1-3 jolts lower than its rating.
The challenge is to find how many different ways you can connect the adapters to reach that final adapter.
Here’s a solution using recursion and @cache. Note that data is a variable containing the raw data file from Advent of Code, where each line represents an adapter’s joltage rating:
from functools import cache

# socket is implicitly 0 "jolts"
# (split() with no arguments also tolerates a trailing newline in the file)
joltages = set(map(int, data.split())).union([0])
# we always have to reach the highest adapter
goal = max(joltages)

@cache
def ways(n):
    # base case of the highest adapter
    if n == goal:
        return 1
    # base case of non-existent adapter
    if n not in joltages:
        return 0
    return ways(n+1) + ways(n+2) + ways(n+3)

print(ways(0))
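To try this without a puzzle input, here is the same idea wrapped in a function and run on a tiny made-up input. The function name count_arrangements and the sample data are mine, not from Advent of Code:

```python
from functools import cache


def count_arrangements(data):
    """Count the adapter chains from the socket (0) to the highest adapter."""
    joltages = {int(line) for line in data.split()} | {0}
    goal = max(joltages)

    @cache
    def ways(n):
        if n == goal:
            return 1
        if n not in joltages:
            return 0
        return ways(n + 1) + ways(n + 2) + ways(n + 3)

    return ways(0)


sample = "1\n4\n5\n6\n7"  # hypothetical input, one rating per line
print(count_arrangements(sample))  # 4
```

The four chains are 0-1-4-5-6-7, 0-1-4-5-7, 0-1-4-6-7, and 0-1-4-7. Defining ways inside the function gives each input its own cache, which is one way to avoid stale results if you run it on more than one data set.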
Bonus: dynamic programming solution
This is the actual code I wrote when I originally solved Day 10 (with added comments):
from collections import defaultdict

joltages = list(map(int, data.split())) + [0]
# same goal as before: the highest adapter
goal = max(joltages)
# missing adapters default to 0 ways
ways = defaultdict(lambda: 0)
# base case of the highest adapter
ways[goal] = 1
# from second largest to smallest
for n in sorted(joltages, reverse=True)[1:]:
    ways[n] = ways[n+1] + ways[n+2] + ways[n+3]
print(ways[0])
To me, the dynamic programming solution is natural enough that I skipped right past the recursive approach.
Limitation
cache and lru_cache require that all the arguments passed to your function are hashable. This is because a dictionary is used to remember past results, keyed by the function arguments.
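Here is a quick illustration of that limitation, assuming Python 3.9 or later. The function total is a made-up example. A list argument fails, while the equivalent tuple works:

```python
from functools import cache


@cache
def total(values):
    return sum(values)


print(total((1, 2, 3)))  # 6, because tuples are hashable

try:
    total([1, 2, 3])  # lists are not hashable
except TypeError as error:
    print(error)  # unhashable type: 'list'
```

Converting a list to a tuple (or a set to a frozenset) before calling the function is a common workaround.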