
The Square Root Cancellation Heuristic

In the first equation of the popular Attention Is All You Need paper (see also this blog post), the authors write

{\rm Attention}(Q,K,V) = {\rm softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V.
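For concreteness, here is a minimal NumPy sketch of scaled dot-product attention, assuming a single head, no masking, and random Q, K, V. The softmax is applied to QK^T / \sqrt{d_k} alone, and the resulting weights then multiply V. The helpers `softmax` and `attention` are our own illustrative names, not an API from the paper.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)  # subtract the max to avoid overflow
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """softmax(Q K^T / sqrt(d_k)) V for a single head without masking."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # shape (n_queries, n_keys)
    weights = softmax(scores, axis=-1)  # each row is a probability vector
    return weights @ V  # shape (n_queries, d_v)


rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 64))  # 4 queries with d_k = 64
K = rng.standard_normal((6, 64))  # 6 keys
V = rng.standard_normal((6, 8))   # 6 values with d_v = 8
out = attention(Q, K, V)
print(out.shape)  # (4, 8)
```

Each output row is a weighted average of the rows of V, so every output entry lies between the smallest and largest entry of the corresponding column of V.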

In this post we are going to discuss where the \sqrt{d_k} comes from, leading us to some classical Probability Theory. We will first talk about the math with some examples and then quickly make the connection.

The principle of square root cancellation also appears in the Batch Norm paper, neural network weight initialization (see also Xavier Uniform), and elsewhere.

The Math

Let v be a d-dimensional vector with entries that are either +1 or -1. Then we may write

v= (v_1 , \ldots , v_d).

We will be interested in the simple question: what is the sum of the coordinates? For a fixed vector, we can easily compute this via

v_1 + \cdots + v_d.

The largest this sum can be is d and the smallest it can be is -d. However, what we would like to know is the average behavior of this sum, in particular how large it typically is.

Since the vector has both +1’s and -1’s, there will likely be cancellation in the sum. That is, we expect the sum to be far from both d and -d. It turns out that as d gets large, there will typically be a lot of cancellation.

Let’s first run an experiment in Python. We let d = 500 and compute the sums of 100,000 random vectors, each with \pm 1 entries.

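import numpy as np
import seaborn as sns

d = 500
sums = []

# Sum the coordinates of 100,000 random vectors with independent ±1 entries
for _ in range(100_000):
    v = np.random.choice([-1, 1], size=d)
    sums.append(v.sum())  # record v_1 + ... + v_d

_ = sns.histplot(data=sums)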

We see that the sums are typically far smaller in absolute value than 500, and barely any are smaller than -75 or larger than 75.

This is the so-called square root cancellation. If we take a random vector of dimension d with randomly generated \pm 1 entries, we expect the sum to be of order \sqrt{d} in absolute value. Here \sqrt{500} \approx 22.36, and it is a general phenomenon that most of the sums will lie between, say, -4\sqrt{d} and 4\sqrt{d}. This general principle goes much deeper than \pm 1 vectors and turns out to be a very well-studied concept in several areas of mathematics (see this post for Number Theory or this post for Probability). In fact, the famous Riemann Hypothesis can be reformulated as the statement that a certain sum exhibits square root cancellation.
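We can check these numbers with a quick sketch. It assumes, as in the experiment above, that the entries are independent fair \pm 1 signs. Since the sum equals 2K - d, where K \sim {\rm Binomial}(d, 1/2) counts the +1’s, the sketch samples the sums directly; this is our own shortcut and has the same distribution as building the vectors.

```python
import numpy as np

rng = np.random.default_rng(0)
d, trials = 500, 100_000

# The sum of d independent ±1 entries equals 2K - d, where K ~ Binomial(d, 1/2)
# counts the +1 entries, so we sample the sums without building the vectors.
sums = 2 * rng.binomial(d, 0.5, size=trials) - d

root_d = np.sqrt(d)  # about 22.36
within = np.mean(np.abs(sums) <= 4 * root_d)  # fraction inside [-4*sqrt(d), 4*sqrt(d)]
typical = np.mean(np.abs(sums))  # average size of |v_1 + ... + v_d|
print(f"sqrt(d) = {root_d:.2f}, mean |sum| = {typical:.2f}, within 4*sqrt(d): {within:.5f}")
```

The mean of |sum| should come out close to \sqrt{2d/\pi} \approx 17.8, a constant multiple of \sqrt{d}.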

The square root cancellation heuristic asserts that “most” vectors exhibit square root cancellation. Put another way, if we come across a vector in practice, we’d expect there to be square root cancellation as in the above example.

It turns out that the \pm 1 example above generalizes considerably. Besides independence of the coordinates, the only thing that really matters is that each coordinate has mean zero (which is easily arranged by subtracting a fixed constant). There are also some technical assumptions that the coordinates of v are not too large, but this is never an issue for vectors that only take on finitely many values. The motivated reader can consult the Lindeberg-Lévy Central Limit Theorem for more along this direction.

Before going any deeper, let’s explain what’s going on in the aforementioned Attention Is All You Need paper.

Attention is All You Need Example

In the above equation, we have

\frac{QK^T}{\sqrt{d_k}} .

Each entry of QK^T is the dot product of a row of Q and a row of K:

q \cdot k = q_1 k_1 + \cdots + q_{d_k} k_{d_k}.

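This is a sum of d_k numbers. If, as the paper assumes in a footnote, the components of q and k are independent random variables with mean 0 and variance 1, then each product q_i k_i has mean 0 and variance 1, and by the square root cancellation principle we expect the sum to be of order \sqrt{d_k}. Hence dividing through by this number makes the sum, on average, of size comparable to 1 (rather than the much larger \sqrt{d_k}).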
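A small simulation of this claim, assuming (as in the paper’s footnote) independent unit-variance components, here drawn from a standard normal; the dimensions and sample size are arbitrary choices of ours. The standard deviation of q \cdot k should grow like \sqrt{d_k}, while the scaled version stays near 1.

```python
import numpy as np

rng = np.random.default_rng(1)
n_pairs = 5_000  # number of random (q, k) pairs per dimension

stds = {}
for d_k in (16, 64, 256):
    # Assumption mirroring the paper's footnote: q and k have independent
    # entries with mean 0 and variance 1; here they are drawn from N(0, 1).
    q = rng.standard_normal((n_pairs, d_k))
    k = rng.standard_normal((n_pairs, d_k))
    dots = np.einsum("ij,ij->i", q, k)  # one dot product q . k per row
    stds[d_k] = (dots.std(), (dots / np.sqrt(d_k)).std())
    print(f"d_k={d_k:>3}: std of q.k = {stds[d_k][0]:.2f}, "
          f"std after dividing by sqrt(d_k) = {stds[d_k][1]:.2f}")
```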

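Keeping the sums from being too large keeps the softmax away from regions where it is nearly one-hot. As the paper notes, in those regions the softmax has extremely small gradients, which hampers training.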
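To see the saturation effect, the sketch below compares the average largest softmax weight per row with and without the 1/\sqrt{d_k} scaling. The Gaussian inputs, d_k = 512 and 16 keys are assumptions for illustration, and the largest weight is only a rough proxy: a value near 1 means the softmax is almost one-hot.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Row-wise numerically stable softmax."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(2)
d_k, n_queries, n_keys = 512, 1_000, 16
Q = rng.standard_normal((n_queries, d_k))  # assumed unit-variance entries
K = rng.standard_normal((n_keys, d_k))

raw = Q @ K.T                # scores with standard deviation about sqrt(d_k)
scaled = raw / np.sqrt(d_k)  # scores with standard deviation about 1

# Average of the largest attention weight in each row; values near 1 mean
# the softmax is essentially one-hot (saturated).
max_raw = softmax(raw).max(axis=-1).mean()
max_scaled = softmax(scaled).max(axis=-1).mean()
print(f"unscaled: {max_raw:.3f}, scaled: {max_scaled:.3f}")
```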

Back to the Math

Recall we had a vector v of d dimensions,

v = (v_1 , \ldots , v_d),

and we are interested in what the sum is, on average. To make the question more precise, we assume the coordinates are generated independently at random, each with mean zero. The key assumption is independence, that is, the generation of one coordinate does not impact the others. The mean zero assumption can be obtained by shifting the values (for instance, instead of recording a six-sided die roll, record the die roll minus 3.5).

Here are some examples in Python that you are welcome to check:

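import numpy as np

v1 = np.random.choice([1, 2, 3, 4, 5, 6], size=500) - 3.5  # die roll minus its mean 3.5
v2 = np.random.exponential(1, size=500) - 1  # Exponential(1) minus its mean 1
v3 = np.random.normal(0, 1, size=500)  # standard normal, already mean zero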

For such a vector, with d reasonably large (say over 30), we expect the sum to be comparable in size to the Euclidean length of v:

|v_1 + \cdots + v_d| \approx \sqrt{v_1^2 + \cdots + v_d^2}.
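Here is one way to check this heuristic for the three example distributions. It is a sketch with arbitrary sample sizes; the \approx should be read as "of the same order", and the ratio |v_1 + \cdots + v_d| / \sqrt{v_1^2 + \cdots + v_d^2} behaves roughly like the absolute value of a standard normal, so its root-mean-square should be close to 1 and its median close to 0.67.

```python
import numpy as np

rng = np.random.default_rng(3)
d, trials = 500, 5_000

# The three mean-zero distributions from the examples above.
samplers = {
    "die - 3.5": lambda size: rng.choice([1, 2, 3, 4, 5, 6], size=size) - 3.5,
    "exponential - 1": lambda size: rng.exponential(1.0, size=size) - 1.0,
    "normal": lambda size: rng.normal(0.0, 1.0, size=size),
}

ratios = {}
for name, sample in samplers.items():
    v = sample((trials, d))  # each row is one random vector of dimension d
    # |v_1 + ... + v_d| divided by sqrt(v_1^2 + ... + v_d^2), one value per row
    r = np.abs(v.sum(axis=1)) / np.sqrt((v ** 2).sum(axis=1))
    ratios[name] = r
    print(f"{name:>16}: median ratio {np.median(r):.2f}, "
          f"RMS ratio {np.sqrt(np.mean(r ** 2)):.2f}")
```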

Specializing to the case where the coordinates are \pm 1, we recover the \sqrt{d} from earlier. Now why should we expect that the sum is of size \sqrt{d}?

One way to see this is to explicitly compute the expected value of

\left( v_1 + \cdots + v_d \right)^2,

which is precisely d when the coordinates are \pm 1 (and \sum_j E[v_j^2] in general). Since the expected square of the sum is d, the sum is typically of size about \sqrt{d}. This computation can be done by expanding the square and using that the covariance of two different coordinates is 0. We will put the classical computation at the end of the post for those interested.
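For \pm 1 entries this expectation can be checked exactly for small d by averaging over all 2^d equally likely sign vectors; d = 10 below is an arbitrary choice.

```python
from itertools import product

d = 10
# Enumerate all 2**d equally likely sign vectors and average the squared sum.
squares = [sum(signs) ** 2 for signs in product((-1, 1), repeat=d)]
expected_square = sum(squares) / len(squares)
print(expected_square)  # 10.0
```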

It is a common theme in this area that working with the square of the sum is theoretically much easier than working with the sum itself.

The Central Limit Theorem

It is worth reiterating that the square root cancellation heuristic generalizes much further than the \pm 1 example, but let’s continue our focus there. The astute reader will notice that the histogram above looks like the normal distribution. In fact, it is approximately normal. If we consider the normalized sum

\frac{v_1 + \cdots + v_d}{\sqrt{d}},

the histogram will be approximately normal with mean zero and variance 1. Put another way,

|v_1 + \cdots + v_d| \approx \sqrt{d},

and the Central Limit Theorem quantifies the \approx precisely. Thus if we have a sum of d numbers (where we expect each number to be 0 on average and of size about 1), we can divide the sum by \sqrt{d} in order to make the result of size about 1.
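For \pm 1 entries the normal approximation can be compared with exact probabilities. The sketch below (with our own helper `prob_normalized_sum_within`) uses that the sum equals 2k - d with probability \binom{d}{k}/2^d, and compares P(|v_1 + \cdots + v_d|/\sqrt{d} \le t) with P(|Z| \le t) for a standard normal Z. For d = 500 the two should agree to within a few hundredths; the remaining gap at t = 1 mostly reflects that the sum only takes even values.

```python
import math

d = 500
root_d = math.sqrt(d)


def prob_normalized_sum_within(t: float) -> float:
    """Exact P(|v_1 + ... + v_d| / sqrt(d) <= t) for independent fair ±1 entries.

    The sum equals 2k - d when k of the d entries are +1, which happens with
    probability comb(d, k) / 2**d.
    """
    total = sum(math.comb(d, k) for k in range(d + 1) if abs(2 * k - d) <= t * root_d)
    return total / 2 ** d


results = {}
for t in (1, 2, 3):
    exact = prob_normalized_sum_within(t)
    normal = math.erf(t / math.sqrt(2))  # P(|Z| <= t) for a standard normal Z
    results[t] = (exact, normal)
    print(f"t={t}: exact {exact:.4f}, normal approximation {normal:.4f}")
```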

The Variance Computation

We will show

E\left[ \left( v_1 + \cdots + v_d \right)^2 \right] = \sum_{j} E[v_j^2],

which equals d when every coordinate is \pm 1. Indeed,

E\left[ \left( v_1 + \cdots + v_d \right)^2 \right] = E\left[ \sum_{i,j} v_i v_j \right] = \sum_{i,j} E[v_i v_j],

using FOIL from algebra and linearity of expectation. The terms with i = j contribute \sum_j E[v_j^2], so it is enough to show

E[v_i v_j] = 0

for every i \neq j. But this follows from independence and mean zero: E[v_i v_j] = E[v_i] E[v_j] = 0.
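The identity also holds beyond \pm 1 entries. As an exact check (using Python’s Fraction type to avoid rounding), take the shifted die rolls from earlier with d = 3, where E[v_j^2] = 35/12, and enumerate all 6^3 equally likely outcomes.

```python
from fractions import Fraction
from itertools import product

faces = [Fraction(f) - Fraction(7, 2) for f in range(1, 7)]  # die roll minus 3.5
d = 3

# E[v_j^2] for a single shifted die roll
second_moment = sum(x * x for x in faces) / len(faces)

# E[(v_1 + ... + v_d)^2] by enumerating all 6**d equally likely outcomes
outcomes = list(product(faces, repeat=d))
expected_square = sum(sum(o) ** 2 for o in outcomes) / len(outcomes)

print(second_moment, expected_square)  # 35/12 35/4
```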

Entropy and Sumsets: An example

The following post is the result of a discussion with Imre Ruzsa. Motivated by the following easy inequality in additive combinatorics

\displaystyle A+2 \cdot A \subset A+A+A , \ \ q \cdot A := \{qa : a \in A\},

I asked if the following was true for a finitely valued random variable {X}:

\displaystyle H(X+2 \cdot X) \leq H(X+X+X), \qquad H(X) := -\sum_{x : \mathbb{P}(X = x) > 0} \mathbb{P}(X = x) \log_2 \mathbb{P}(X = x). \qquad (1)

Here all sums are of independent copies of the random variables. The idea is that one might expect {X+X} to be a bit more uniform than {2 \cdot X}.
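The sumset inclusion that motivated (1) holds because a + 2b = a + b + b. A brute-force check on a small set, with A = \{0, 1, 3\} chosen arbitrarily and helper names of our own, also shows that the inclusion can be strict (4 lies in A+A+A but not in A + 2 \cdot A).

```python
def sumset(*sets):
    """All sums a_1 + ... + a_m with a_i drawn from the i-th set."""
    result = {0}
    for s in sets:
        result = {r + a for r in result for a in s}
    return result


def dilate(q, A):
    """The dilate q . A = {q a : a in A}."""
    return {q * a for a in A}


A = {0, 1, 3}
lhs = sumset(A, dilate(2, A))  # A + 2.A
rhs = sumset(A, A, A)          # A + A + A
print(sorted(lhs), sorted(rhs), lhs <= rhs)
```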

First, Imre provided a counterexample to the related question

\displaystyle H(X+ 2\cdot Y) \leq H(X+Y+Y).

I find this example particularly elegant. Let {X} be uniform on {\{0,1\}} and {Y} be uniform on {\{0 , \ldots , n\}}. Then {X+2 \cdot Y} is uniform on {\{0 , \ldots , 2n+1\}}, while {X + Y + Y} is also supported on {\{0 , \ldots , 2n+1\}} but is not uniform (there is concentration in the middle thanks to the distribution of {Y+Y}). Since the uniform distribution has the largest entropy among distributions on a given finite set, H(X+2 \cdot Y) > H(X+Y+Y).
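A sketch of Imre’s example for n = 3, computing entropies exactly from the probability mass functions; the helpers `entropy_bits`, `sum_dist` and `dilate_dist` are illustrative names of ours. H(X + 2 \cdot Y) should equal \log_2(2n+2) = 3 bits, and H(X+Y+Y) should be strictly smaller.

```python
import math
from collections import Counter


def entropy_bits(dist):
    """Shannon entropy in bits of a distribution given as {value: probability}."""
    return -sum(p * math.log2(p) for p in dist.values() if p > 0)


def sum_dist(dx, dy):
    """Distribution of X + Y for independent X and Y."""
    out = Counter()
    for x, px in dx.items():
        for y, py in dy.items():
            out[x + y] += px * py
    return dict(out)


def dilate_dist(q, dx):
    """Distribution of q . X."""
    return {q * x: p for x, p in dx.items()}


n = 3
X = {0: 0.5, 1: 0.5}                        # uniform on {0, 1}
Y = {y: 1 / (n + 1) for y in range(n + 1)}  # uniform on {0, ..., n}

h_lhs = entropy_bits(sum_dist(X, dilate_dist(2, Y)))  # H(X + 2.Y)
h_rhs = entropy_bits(sum_dist(X, sum_dist(Y, Y)))     # H(X + Y + Y)
print(h_lhs, h_rhs)
```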

We then seriously questioned the validity of (1). After some discussion, Imre eventually said something about higher dimensional concentration that made me think one should check (1) for the “Gaussian.” The reason Gaussian is in quotes is that it is not finitely valued as assumed in (1), so strictly speaking we cannot check it for the Gaussian. To see if there was hope, I looked at the differential entropy of a real-valued random variable {G} with density {p}, defined via

\displaystyle H(G) := -\int_{-\infty}^{\infty} p(x) \log p(x) dx.

Let us take {G} to be the Gaussian with mean zero (this is irrelevant for entropy) and variance 1. Recall some basic properties of variance:

\displaystyle {\rm Var}(aG) = a^2 {\rm Var}(G) , \ \ {\rm Var}(G+G) = 2 {\rm Var}(G),

where {a \in \mathbb{R}} and {G+G} is understood to be the sum of two independent copies of {G}. Thus

\displaystyle {\rm Var}(G + 2 \cdot G) = 5 , \ \ {\rm Var}(G + G +G ) = 3.
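Numerically, one can evaluate the differential entropy integral for Gaussians with these two variances and compare it with the closed form \frac{1}{2}\log(2\pi e \sigma^2) (natural log, matching the integral above). The sketch uses a plain Riemann sum over \pm 12 standard deviations, an assumption that is more than accurate enough for a Gaussian density; the gap between the two entropies is \frac{1}{2}\log(5/3) \approx 0.26 nats.

```python
import numpy as np


def gaussian_entropy_numeric(var, half_width=12.0, n=200_001):
    """Approximate -integral of p(x) log p(x) dx for N(0, var) by a Riemann sum."""
    sigma = np.sqrt(var)
    x = np.linspace(-half_width * sigma, half_width * sigma, n)
    p = np.exp(-x ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)  # Gaussian density
    dx = x[1] - x[0]
    return float(-np.sum(p * np.log(p)) * dx)


def gaussian_entropy_closed_form(var):
    """Differential entropy of N(0, var) in nats: (1/2) log(2 pi e var)."""
    return float(0.5 * np.log(2 * np.pi * np.e * var))


for label, var in (("G + 2.G", 5.0), ("G + G + G", 3.0)):
    print(label, gaussian_entropy_numeric(var), gaussian_entropy_closed_form(var))
```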

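Since a sum of independent Gaussians is again Gaussian, {G + 2 \cdot G} and {G+G+G} are Gaussians with variances 5 and 3, and the differential entropy of a Gaussian with variance {\sigma^2} is {\frac{1}{2} \log (2 \pi e \sigma^2)}, which increases with {\sigma^2}. So we indeed see that H(G + 2 \cdot G) > H(G+G+G), and (1) is not true for the Gaussian.

To construct a finitely valued random variable that does not satisfy (1), we can convolve a Bernoulli random variable with itself until (1) fails. This assumes that going from discrete to continuous does not destroy (1), which is not obvious without checking, since {2 \cdot X} has a strange support condition. For instance, the same argument would prove H(2 \cdot G) \geq H(G+G), which is clearly not true for discrete random variables (for discrete {X}, H(2 \cdot X) = H(X) \leq H(X+X)). Anyway, I wrote some quick Python code to check this and found that for {X = B + B + B}, where {B} is the random variable of a fair coin flip, we have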

\displaystyle H(X+2 \cdot X) \approx 2.984 , \ \ H(X+X+X) \approx 2.630.

Here {X+ 2 \cdot X} and {X+X+X} are supported on {\{0 , \ldots , 9\}}, so their entropies are bounded by the entropy of the uniform distribution on 10 elements, which is

\displaystyle \log_2 10 \approx 3.322.
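The post does not include the original code, so here is a sketch of our own that computes these entropies exactly. It convolves probability mass functions using Python’s Fraction type and converts to floating point only for the logarithm; the helper names are hypothetical.

```python
import math
from collections import Counter
from fractions import Fraction


def sum_dist(dx, dy):
    """Distribution of X + Y for independent X and Y (exact fractions)."""
    out = Counter()
    for x, px in dx.items():
        for y, py in dy.items():
            out[x + y] += px * py
    return dict(out)


def entropy_bits(dist):
    """Shannon entropy in bits; probabilities may be Fractions."""
    return -sum(float(p) * math.log2(float(p)) for p in dist.values() if p > 0)


B = {0: Fraction(1, 2), 1: Fraction(1, 2)}  # fair coin flip
X = sum_dist(sum_dist(B, B), B)             # X = B + B + B
two_X = {2 * x: p for x, p in X.items()}    # 2.X

h_x_2x = entropy_bits(sum_dist(X, two_X))          # H(X + 2.X)
h_xxx = entropy_bits(sum_dist(sum_dist(X, X), X))  # H(X + X + X)
print(round(h_x_2x, 3), round(h_xxx, 3), round(math.log2(10), 3))
```

Note that X + X + X is a sum of nine independent fair coin flips, so its distribution is Binomial(9, 1/2).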

Sometimes entropy analogs of sumset inequalities hold and sometimes they do not (see this paper of Ruzsa or this paper of Tao, or a host of work by Madiman and coauthors).