
Memoization in One Line of Python

by Steve Marx on

In my last post, I laid out four steps to dynamic programming:

  1. Formulate a recursive solution.
  2. Use memoization to avoid recalculating subproblems.
  3. Use a bottom-up approach to avoid recursion.
  4. Optimize space efficiency by forgetting results you no longer need.

To optimize time complexity, you can actually stop after step 2. A memoized solution already has optimal time complexity because each required subproblem is computed only once.
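For comparison, here's a sketch of where steps 3 and 4 lead for the Fibonacci example below. fib_bottom_up is a hypothetical name, not code from the previous post. It does the same O(n) work as a memoized version, but it builds up from the base cases and keeps only the two most recent values, so it needs O(1) extra space and no recursion.

```python
def fib_bottom_up(n):
    # Step 3: start from the base cases and build upward instead of recursing.
    # Step 4: keep only the last two values instead of a table of all results.
    a, b = 1, 1  # fib(1), fib(2)
    for _ in range(n - 2):
        a, b = b, a + b
    return b

print(fib_bottom_up(10))  # 55
```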

Memoization is an easy way to speed up a slow recursive program. Python 3.9 added a decorator, functools.cache, that adds memoization to a function with just one line of code!

(In older versions of Python, back to 3.2, you can use the equivalent functools.lru_cache(maxsize=None).)
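On a pre-3.9 interpreter, a sketch of the same one-line memoization with lru_cache might look like this. Passing maxsize=None means the cache never evicts entries, which is the behavior cache gives you by default.

```python
from functools import lru_cache

# maxsize=None lets the cache grow without bound,
# matching what functools.cache does on Python 3.9+.
@lru_cache(maxsize=None)
def fib(n):
    return 1 if n <= 2 else fib(n-1) + fib(n-2)

print(fib(30))  # 832040
```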

Memoized Fibonacci

Revisiting everybody’s favorite recursive example, here’s a Python function to compute the nth Fibonacci number:

def fib(n):
    return 1 if n <= 2 else fib(n-1) + fib(n-2)

As explained in my last post, this solution is quite inefficient because of how many times it has to recalculate the same numbers. Memoization can fix this by remembering previously calculated values rather than recalculating them each time.
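To see how much repeated work that is, here's a sketch that instruments fib with a hypothetical global calls counter. For this definition the total number of calls works out to 2·fib(n) - 1, so it grows exponentially with n, even though there are only n distinct subproblems.

```python
calls = 0

def fib(n):
    global calls
    calls += 1  # count every invocation, including repeated subproblems
    return 1 if n <= 2 else fib(n-1) + fib(n-2)

fib(20)
print(calls)  # 13529 calls for only 20 distinct inputs
```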

By adding @cache above our function, we get an optimal memoized version:

from functools import cache

@cache
def fib(n):
    return 1 if n <= 2 else fib(n-1) + fib(n-2)

That’s really it!
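If you're curious what @cache does for you, here's a rough hand-written equivalent. memoize is a hypothetical name, and this sketch only handles positional arguments. The real functools version also handles keyword arguments and adds helpers such as cache_info(), shown at the end.

```python
from functools import cache, wraps

def memoize(func):
    """Rough hand-written equivalent of @cache (positional arguments only)."""
    results = {}  # maps each tuple of arguments to its computed result

    @wraps(func)
    def wrapper(*args):
        if args not in results:
            results[args] = func(*args)  # compute each subproblem only once
        return results[args]

    return wrapper

@memoize
def fib(n):
    return 1 if n <= 2 else fib(n-1) + fib(n-2)

print(fib(100))

# The real @cache also keeps hit/miss statistics you can inspect.
@cache
def fib_cached(n):
    return 1 if n <= 2 else fib_cached(n-1) + fib_cached(n-2)

fib_cached(100)
print(fib_cached.cache_info())
```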

Using @cache to solve Advent of Code Day 10


Advent of Code is an annual advent calendar of programming puzzles. This is my first year participating, and I’ve been having a blast!

Day 10 part 2 can be solved using memoization.

In brief, you’re given a bunch of adapters with different “joltage” ratings. You need to form a chain of these adapters from a wall socket, which counts as 0 jolts, to the highest rated adapter you have. Each adapter can only connect to an adapter that’s 1-3 jolts lower than its rating.

The challenge is to find how many different ways you can connect the adapters to reach that final adapter.

Here’s a solution using recursion and @cache. Note that data is a variable containing the raw contents of the Advent of Code input file, where each line is one adapter’s joltage rating:

from functools import cache

# socket is implicitly 0 "jolts"
# (split() with no argument also ignores a trailing newline)
joltages = set(map(int, data.split())).union([0])
# we always have to reach the highest adapter
goal = max(joltages)

@cache
def ways(n):
    # base case of the highest adapter
    if n == goal:
        return 1

    # base case of non-existent adapter
    if n not in joltages:
        return 0

    return ways(n+1) + ways(n+2) + ways(n+3)

print(ways(0))
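To try the approach without a puzzle input, here's the same logic wrapped in a hypothetical count_arrangements function and run on a small sample list of adapters. Counting by hand, that sample has 8 valid chains from the socket to the 19-jolt adapter. Defining ways inside the function gives each call its own fresh cache.

```python
from functools import cache

def count_arrangements(data):
    """Count adapter chains from the socket (0 jolts) to the highest adapter."""
    joltages = set(map(int, data.split())) | {0}
    goal = max(joltages)

    @cache
    def ways(n):
        if n == goal:
            return 1  # reached the highest adapter
        if n not in joltages:
            return 0  # no adapter with this rating
        return ways(n+1) + ways(n+2) + ways(n+3)

    return ways(0)

# A small sample input, one joltage per line.
sample = "16\n10\n15\n5\n1\n11\n7\n19\n6\n12\n4\n"
print(count_arrangements(sample))  # 8
```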

Bonus: dynamic programming solution

This is the actual code I wrote when I originally solved Day 10 (with added comments):

from collections import defaultdict

joltages = list(map(int, data.split())) + [0]
# we always have to reach the highest adapter
goal = max(joltages)
# joltages with no adapter default to 0 ways
ways = defaultdict(int)

# base case of the highest adapter
ways[goal] = 1

# from second largest to smallest
for n in sorted(joltages, reverse=True)[1:]:
    ways[n] = ways[n+1] + ways[n+2] + ways[n+3]

print(ways[0])

To me, the dynamic programming solution is natural enough that I skipped right past the recursive approach.
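Step 4 from the list at the top (forgetting results you no longer need) applies here too. Because ways(n) depends only on the next three values, a sketch can walk every integer from the goal down to 0 and keep just a three-value window. count_arrangements_constant_space is a hypothetical name; it trades the dictionary for O(1) extra space at the cost of also visiting joltages that have no adapter.

```python
def count_arrangements_constant_space(data):
    """Bottom-up count keeping only ways(n+1), ways(n+2), ways(n+3)."""
    joltages = set(map(int, data.split())) | {0}
    goal = max(joltages)

    # Window for n = goal - 1: ways(goal) = 1, and everything above goal is 0.
    w1, w2, w3 = 1, 0, 0
    for n in range(goal - 1, -1, -1):
        current = w1 + w2 + w3 if n in joltages else 0
        w1, w2, w3 = current, w1, w2  # slide the window down by one
    return w1  # ways(0)

sample = "16\n10\n15\n5\n1\n11\n7\n19\n6\n12\n4\n"
print(count_arrangements_constant_space(sample))  # 8
```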

Limitation

cache and lru_cache require that all the arguments passed to your function be hashable. This is because a dictionary is used to remember past results, keyed by those arguments.
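Here's a small illustration with a hypothetical total function: a tuple argument works, but a list raises TypeError because lists are unhashable. Converting the argument to a tuple (or a set to a frozenset) before the call is the usual workaround. The same limitation would apply to the Day 10 solution if the joltages set were passed into ways as an argument instead of being read from the enclosing scope.

```python
from functools import cache

@cache
def total(values):
    # values becomes part of the cache key, so it must be hashable
    return sum(values)

print(total((1, 2, 3)))  # tuples are hashable: prints 6

try:
    total([1, 2, 3])  # lists are mutable, so they can't be hashed
except TypeError as err:
    print("TypeError:", err)

# Workaround: convert to a hashable type before calling
print(total(tuple([1, 2, 3])))  # 6
```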
