Optimizing Python

March 9, 2023

Python is an easy language to write, but it’s also very slow. Since it’s a dynamically typed and interpreted language, every Python operation is much slower than the corresponding operation would be in C or FORTRAN—every line of Python must be interpreted, type checked, and so forth (see this little overview of what the Python interpreter does).

Fortunately for those of us who like programming in Python, there are a number of different ways to make Python code faster. The simplest way is just to use NumPy, the de facto standard for any sort of array-based computation in Python; NumPy functions are written in C/C++, and so are much faster than the corresponding native Python functions.

Another strategy is to use a just-in-time (JIT) compiler, like Jax or Numba, to accelerate Python code. This approach incurs a substantial one-time cost (compiling the function on its first call) but can make all subsequent calls orders of magnitude faster. Unfortunately, these libraries don’t support all possible Python functions or external libraries, so it’s sometimes difficult to write JIT-compilable code.

How do these strategies fare on a real-world problem? I selected pairwise distance calculations for a list of points as a test case; this problem is pretty common in a lot of scientific contexts, including calculating electrostatic interactions in molecular dynamics or quantum mechanics.
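Concretely, for N points with coordinates \(\mathbf{x}_1, \dots, \mathbf{x}_N \in \mathbb{R}^3\), every function below computes the N × N matrix

$$
D_{ij} = \lVert \mathbf{x}_i - \mathbf{x}_j \rVert_2 = \sqrt{\sum_{k=1}^{3} \left(x_{ik} - x_{jk}\right)^2}, \qquad i, j = 1, \dots, N.
$$

Since \(D_{ij} = D_{ji}\) and \(D_{ii} = 0\), only \(N(N-1)/2\) entries are distinct, and the total work grows as \(O(N^2)\) in the number of points.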

We can start by importing the necessary libraries and writing two functions. The first function is the “naïve” Python approach, and the second uses scipy.spatial.distance.cdist, one of the most overpowered functions I’ve encountered in any Python library.

import numpy as np
import numba
import cctk
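import scipy.spatial  # a bare `import scipy` may not load the spatial submodule

mol = cctk.XYZFile.read_file("30_dcm.xyz").get_molecule()
points = mol.geometry.view(np.ndarray)  # (N, 3) array of atomic coordinates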

def naive_get_distance(points):
    N = points.shape[0]
    distances = np.zeros(shape=(N,N))
    for i, A in enumerate(points):
        for j, B in enumerate(points):
            distances[i,j] = np.linalg.norm(A-B)
    return distances

def scipy_get_distance(points):
    return scipy.spatial.distance.cdist(points,points)

If we time these functions in Jupyter with %%timeit, we can see that cdist is almost 2000 times faster than the naïve loop-based function (55.2 µs vs. 103 ms)!

%%timeit
naive_get_distance(points)

103 ms ± 981 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%%timeit
scipy_get_distance(points)

55.2 µs ± 2.57 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
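%%timeit only works inside IPython/Jupyter. In a plain script, the standard-library timeit module gives a comparable measurement. The sketch below assumes 50 random points as a stand-in for 30_dcm.xyz (so cctk isn't needed), uses a fixed number of calls per run instead of letting %%timeit pick one, and first checks that both functions return the same matrix.

```python
import functools
import timeit

import numpy as np
from scipy.spatial.distance import cdist


def naive_get_distance(points):
    # Same double loop as above: one np.linalg.norm call per (i, j) pair.
    N = points.shape[0]
    distances = np.zeros(shape=(N, N))
    for i, A in enumerate(points):
        for j, B in enumerate(points):
            distances[i, j] = np.linalg.norm(A - B)
    return distances


def scipy_get_distance(points):
    return cdist(points, points)


# Hypothetical stand-in for the cctk geometry: 50 random points in 3-D.
points = np.random.default_rng(0).random((50, 3))

# Check that the two implementations agree before comparing their speed.
assert np.allclose(naive_get_distance(points), scipy_get_distance(points))

NUMBER = 5  # calls per run, like the "loops" in the %%timeit output
for func in (naive_get_distance, scipy_get_distance):
    # Seven runs of NUMBER calls each, converted to seconds per call.
    runs = timeit.repeat(functools.partial(func, points), number=NUMBER, repeat=7)
    per_call = np.array(runs) / NUMBER
    print(f"{func.__name__}: {per_call.mean() * 1e6:.1f} µs ± "
          f"{per_call.std() * 1e6:.1f} µs per call")
```

Absolute numbers will differ from the Jupyter results above, since the point set and machine are different.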

In this case, it’s pretty obvious that we should just use cdist. But what if there wasn’t a magic built-in function for this task—how close can we get to the performance of cdist with other performance optimizations?

The first and most obvious optimization is simply to take advantage of the symmetry of the distance matrix (the distance from i to j equals the distance from j to i) and not compute entries below the diagonal. (Note that this is sort of cheating, since cdist doesn’t know that both arguments are the same.)

def triangle_get_distance(points):
    N = points.shape[0]
    distances = np.zeros(shape=(N,N))
    for i in range(N):
        for j in range(i,N):
            distances[i,j] = np.linalg.norm(points[i]-points[j])
            distances[j,i] = distances[i,j]
    return distances

As expected, this roughly halves our time:

%%timeit
triangle_get_distance(points)

57.6 ms ± 409 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
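SciPy does have a routine that knows both arguments are the same point set: scipy.spatial.distance.pdist, which computes only the distances above the diagonal, paired with squareform to rebuild the full matrix. A minimal sketch, assuming random points in place of the molecule; it hasn't been timed against the numbers above, and the extra squareform copy may eat into the savings.

```python
import numpy as np
from scipy.spatial.distance import cdist, pdist, squareform


def pdist_get_distance(points):
    # pdist treats its input as a single point set, so it computes only the
    # N * (N - 1) / 2 distances above the diagonal, returned as a flat
    # "condensed" vector.
    condensed = pdist(points)
    # squareform expands the condensed vector into the full symmetric
    # N x N matrix, with zeros on the diagonal.
    return squareform(condensed)


points = np.random.default_rng(0).random((100, 3))  # hypothetical test points
assert np.allclose(pdist_get_distance(points), cdist(points, points))
```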

Next, we can use Numba to compile this function. This yields roughly a 10-fold speedup, bringing us to about two orders of magnitude slower than cdist.

numba_triangle_get_distance = numba.njit(triangle_get_distance)

%%timeit
numba_triangle_get_distance(points)

5.74 ms ± 36.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
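The one-time compilation cost mentioned earlier shows up on the first call: Numba compiles a function the first time it sees a given argument type, and later calls reuse the compiled code. A rough sketch of measuring this, assuming random points as input; the try/except fallback to a do-nothing njit is only there so the sketch still runs (uncompiled) without Numba installed, and is not part of Numba's API.

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (uncompiled) without Numba installed.
    def njit(func):
        return func


@njit
def triangle_get_distance(points):
    N = points.shape[0]
    distances = np.zeros((N, N))
    for i in range(N):
        for j in range(i, N):
            distances[i, j] = np.linalg.norm(points[i] - points[j])
            distances[j, i] = distances[i, j]
    return distances


def time_one_call(func, *args):
    start = time.perf_counter()
    func(*args)
    return time.perf_counter() - start


points = np.random.default_rng(0).random((100, 3))  # hypothetical test points
first = time_one_call(triangle_get_distance, points)   # includes JIT compilation
second = time_one_call(triangle_get_distance, points)  # reuses the compiled code
print(f"first call: {first * 1e3:.1f} ms, second call: {second * 1e3:.1f} ms")
```

Calling the compiled function once before running %%timeit keeps the compilation time out of the benchmark.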

Defining our own norm with Numba, instead of using np.linalg.norm, gives us another roughly 4x speedup:

def custom_norm(AB):
    return np.sqrt(AB[0]*AB[0] + AB[1]*AB[1] + AB[2]*AB[2])

numba_custom_norm = numba.njit(custom_norm)

def cn_triangle_get_distance(points):
    N = points.shape[0]
    distances = np.zeros(shape=(N,N))
    for i in range(N):
        for j in range(i,N):
            distances[i,j] = numba_custom_norm(points[i] - points[j])
            distances[j,i] = distances[i,j]
    return distances

numba_cn_triangle_get_distance = numba.njit(cn_triangle_get_distance)

%%timeit
numba_cn_triangle_get_distance(points)

1.35 ms ± 21.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
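numba_custom_norm still receives points[i] - points[j], a new three-element array built on every iteration. One possible next step, not benchmarked against the timings above, is to inline the arithmetic on scalars and skip the diagonal entirely. The sketch assumes (N, 3) input, like the geometry above; the njit fallback only lets it run without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (slowly) without Numba installed.
    def njit(func):
        return func


@njit
def scalar_triangle_get_distance(points):
    # Assumes points has shape (N, 3).
    N = points.shape[0]
    distances = np.zeros((N, N))  # the diagonal stays 0
    for i in range(N):
        for j in range(i + 1, N):
            # Work on scalars directly instead of building the temporary
            # array points[i] - points[j] on every iteration.
            dx = points[i, 0] - points[j, 0]
            dy = points[i, 1] - points[j, 1]
            dz = points[i, 2] - points[j, 2]
            d = np.sqrt(dx * dx + dy * dy + dz * dz)
            distances[i, j] = d
            distances[j, i] = d
    return distances


points = np.random.default_rng(0).random((100, 3))  # hypothetical test points
distances = scalar_triangle_get_distance(points)
```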

What about writing this program using only vectorized NumPy functions? This takes a bit more creativity; I came up with the following function, which is memory-inefficient (it builds several N × N × 3 intermediate arrays) but still runs quite quickly:

def numpy_get_distance(points):
    N = points.shape[0]

    # points_row[i, j] == points[i] and points_col[i, j] == points[j]
    points_row = np.repeat(np.expand_dims(points,1), N, axis=1)
    points_col = np.repeat(np.expand_dims(points,0), N, axis=0)

    # square the coordinate differences, then sum over x, y, z
    sq_diff = np.square(np.subtract(points_row, points_col))
    return np.sqrt(np.sum(sq_diff, axis=2))

%%timeit
numpy_get_distance(points)

426 µs ± 6.34 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
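NumPy broadcasting can stand in for the np.repeat calls: indexing with None adds a length-1 axis, and subtraction stretches it to length N without building the repeated copies first. A minimal sketch on random points; the (N, N, 3) difference array is still allocated, and it hasn't been benchmarked against the 426 µs above.

```python
import numpy as np


def broadcast_get_distance(points):
    # points[:, None, :] has shape (N, 1, 3) and points[None, :, :] has
    # shape (1, N, 3); subtracting them broadcasts to (N, N, 3), so
    # diff[i, j] == points[i] - points[j].
    diff = points[:, None, :] - points[None, :, :]
    return np.sqrt(np.sum(diff * diff, axis=2))


points = np.random.default_rng(0).random((100, 3))  # hypothetical test points
distances = broadcast_get_distance(points)
```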

Unfortunately, Numba doesn’t support the axis argument of np.repeat (it only accepts the array and the number of repeats), so I had to get a bit more creative to write a Numba-compilable version of the previous program. The best solution I found involves a few array reshaping operations, which are (presumably) pretty inefficient, and the final code runs only modestly faster than the NumPy-only version (338 µs vs. 426 µs).

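def numpy_get_distance2(points):
    N = points.shape[0]

    # repeat() without an axis flattens the array and repeats each coordinate
    # N times; after the reshape and swaps, points_row[i, j] == points[i]
    # and points_col[i, j] == points[j], as in numpy_get_distance
    points_row = np.swapaxes(points.repeat(N).reshape((N,3,N)),1,2)
    points_col = np.swapaxes(points_row,0,1)

    sq_diff = np.square(np.subtract(points_row, points_col))
    return np.sqrt(np.sum(sq_diff, axis=2))

numba_np_get_distance2 = numba.njit(numpy_get_distance2)
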
%%timeit
numba_np_get_distance2(points)

338 µs ± 4.11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
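To see why numpy_get_distance2 builds the same two arrays as the np.repeat(..., axis=...) version, it helps to trace the reshaping on a tiny input. This sketch uses a hypothetical two-point array and plain NumPy (no Numba).

```python
import numpy as np

points = np.arange(6, dtype=float).reshape(2, 3)  # two 3-D points
N = points.shape[0]

# Without an axis, repeat() flattens the array and repeats each element N
# times: 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5
flat = points.repeat(N)

# Reshaped to (N, 3, N), entry [i, k, j] is points[i, k] for every j.
stacked = flat.reshape((N, 3, N))

# Swapping the last two axes gives entry [i, j, k] == points[i, k].
points_row = np.swapaxes(stacked, 1, 2)

# Swapping the first two axes gives entry [i, j, k] == points[j, k].
points_col = np.swapaxes(points_row, 0, 1)

for i in range(N):
    for j in range(N):
        assert np.array_equal(points_row[i, j], points[i])
        assert np.array_equal(points_col[i, j], points[j])
```

In NumPy, np.swapaxes returns a view rather than a copy, so the swaps themselves are cheap; the copying happens in repeat() and in the later arithmetic on the (N, N, 3) arrays.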

I tried a few other approaches, but ultimately wasn’t able to find anything better; in theory, splitting the loops into chunks could improve cache utilization, but in practice anything clever I tried just made things slower.
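Splitting the loops into chunks is usually called loop tiling (or blocking). The sketch below shows one way it could look, so the idea is concrete; it is an illustration, not a faster version. Tiling typically only pays off once the arrays no longer fit in cache, which may be why it didn't help at this problem size. The tile size `block` is a hypothetical parameter, and the njit fallback only lets the sketch run without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (slowly) without Numba installed.
    def njit(func):
        return func


@njit
def blocked_get_distance(points, block=64):
    # Assumes points has shape (N, 3); `block` is a hypothetical tile size.
    N = points.shape[0]
    distances = np.zeros((N, N))
    for ii in range(0, N, block):
        for jj in range(ii, N, block):
            # Finish the tile rows [ii, ii + block) x columns [jj, jj + block)
            # before moving on, so the points it touches stay in cache.
            for i in range(ii, min(ii + block, N)):
                # On diagonal tiles (jj == ii), only visit j > i.
                for j in range(max(jj, i + 1), min(jj + block, N)):
                    dx = points[i, 0] - points[j, 0]
                    dy = points[i, 1] - points[j, 1]
                    dz = points[i, 2] - points[j, 2]
                    d = np.sqrt(dx * dx + dy * dy + dz * dz)
                    distances[i, j] = d
                    distances[j, i] = d
    return distances


points = np.random.default_rng(0).random((150, 3))  # hypothetical test points
distances = blocked_get_distance(points, 32)
```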

In the end, we were able to accelerate our code about 300x (from 103 ms to 338 µs) by using a combination of NumPy and Numba, but were unable to match the speed of an optimized low-level implementation: our best version is still about 6x slower than cdist. Maybe in a future post I’ll drop into C or C++ and see how close I can get to the reference. Until then, I hope you found this useful.

(I’m sure that there are ways that even this Python version could be improved; I did not even look at other tools, like Jax, Cython, or PyPy. Let me know if you think of anything clever!)
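One commonly used trick, not tried in this post, rewrites the squared distance as \(\lVert a - b \rVert^2 = \lVert a \rVert^2 + \lVert b \rVert^2 - 2\,a \cdot b\), so the pairwise part becomes a single matrix product that NumPy hands to its BLAS library. A sketch on random points, with no claim that it beats cdist here: for N × 3 inputs the matrix product is small, and the expansion loses some precision for points that are close together, which centering the points first (distances don't change under translation) partly mitigates.

```python
import numpy as np


def gram_get_distance(points):
    # Distances are translation-invariant; centering reduces cancellation error.
    centered = points - points.mean(axis=0)
    sq_norms = np.einsum("ij,ij->i", centered, centered)  # ||p_i||^2 per row
    # ||p_i - p_j||^2 = ||p_i||^2 + ||p_j||^2 - 2 p_i . p_j, all pairs at once.
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (centered @ centered.T)
    # Round-off can leave tiny negative values and a nonzero diagonal.
    np.maximum(sq_dist, 0.0, out=sq_dist)
    np.fill_diagonal(sq_dist, 0.0)
    return np.sqrt(sq_dist)


points = np.random.default_rng(0).random((100, 3))  # hypothetical test points
distances = gram_get_distance(points)
```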
