The Intro
Attention Is All You Need | arxiv
Pytorch Transformers from Scratch (Attention is all you need) | Aladdin Persson | YouTube
TRANSFORMERS FROM SCRATCH | blog
transformer_from_scratch.py | GitHub
Attention and Q,K,V
Queries, keys, and values are terms borrowed from the field of information retrieval and recommendation systems.
There is a collection of key-value pairs $D = \{(k_1,v_1),(k_2,v_2),\cdots,(k_m,v_m)\}$ and a query $q$. You need to find the key-value pair $(k_?, v_?)$ that is best for your query.
$$
\begin{split}
\text{Attention}(q,D) &\overset{def}{=} \sum_{i=1}^m \alpha(q,k_i)v_i\\
&=
\begin{bmatrix}\alpha(q,k_1) & \cdots & \alpha(q,k_m)\end{bmatrix}
\begin{bmatrix}v_1 \\ \vdots \\ v_m\end{bmatrix}
\end{split}
$$
where $\alpha(q,k_i)\in\mathbb{R}$ are scalar attention weights. The operation itself is typically referred to as attention pooling. The name attention derives from the fact that the operation pays particular attention to the terms for which the weight $\alpha$ is significant (i.e., large).
As such, the attention over $D$ generates a linear combination of the values contained in the database.
We can apply a softmax operation to $[\alpha(q,k_1)\;\cdots\;\alpha(q,k_m)]$ to make the weights nonnegative and sum to 1:
$$
\text{Attention}(q,D) =
\text{softmax}\left(\begin{bmatrix}\alpha(q,k_1) & \cdots & \alpha(q,k_m)\end{bmatrix}\right)
\begin{bmatrix}v_1 \\ \vdots \\ v_m\end{bmatrix}
$$
In particular, when $q$, $k_i$, $v_i$ are all row vectors and the function $\alpha(\cdot,\cdot)$ is the vector dot product, we get dot-product attention, that is
$$
\begin{split}
\text{Attention}(q,D) &=
\text{softmax}\left(q\cdot [k^T_1,\cdots, k^T_m]\right)
\begin{bmatrix}v_1 \\ \vdots \\ v_m\end{bmatrix}\\
&= \text{softmax}(qK^T)V
\end{split}
$$
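To make the formula concrete, here is a minimal PyTorch sketch of this single-query dot-product attention (the toy shapes and random tensors are just for illustration):

```python
import torch

m, d = 4, 8                      # 4 key-value pairs, feature dimension 8
q = torch.randn(1, d)            # query, a row vector
K = torch.randn(m, d)            # keys stacked as rows
V = torch.randn(m, d)            # values stacked as rows

weights = torch.softmax(q @ K.T, dim=-1)   # [1, m] attention weights, sum to 1
out = weights @ V                          # [1, d] linear combination of values
print(weights.sum().item(), out.shape)     # 1.0, torch.Size([1, 8])
```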
Self-Attention
The “self” in “self-attention” means that there is no external collection of key-value pairs and no external query; instead, the queries, keys, and values all come from the input itself.
Input Embedding
$$
\{x_1, x_2, \cdots, x_n\} \xrightarrow[\text{Input Embedding}]{f(\cdot)} \{a_1, a_2, \cdots, a_n\}
$$
$x_1,\cdots,x_n$ is the original input sequence, and $a_1,\cdots,a_n$ is the sequence obtained by linearly embedding $x_1,\cdots,x_n$ into a higher-dimensional space. All elements are row vectors.
Q,K,V
$$
Q =
\underset{n\times d_q}{
\begin{bmatrix}
q_1 \\ q_2 \\ \vdots \\ q_n
\end{bmatrix}
} =
\underset{n\times d_a}{
\begin{bmatrix}
a_1 \\ a_2 \\ \vdots \\ a_n
\end{bmatrix}
} \cdot
\underset{d_a\times d_q}{W^q},
\quad
K =
\underset{n\times d_k}{
\begin{bmatrix}
k_1 \\ k_2 \\ \vdots \\ k_n
\end{bmatrix}
} =
\underset{n\times d_a}{
\begin{bmatrix}
a_1 \\ a_2 \\ \vdots \\ a_n
\end{bmatrix}
} \cdot
\underset{d_a\times d_k}{W^k},
\quad
V =
\underset{n\times d_v}{
\begin{bmatrix}
v_1 \\ v_2 \\ \vdots \\ v_n
\end{bmatrix}
} =
\underset{n\times d_a}{
\begin{bmatrix}
a_1 \\ a_2 \\ \vdots \\ a_n
\end{bmatrix}
} \cdot
\underset{d_a\times d_v}{W^v}
$$
$$
d_q = d_k \overset{?}{=} d_v
$$
Note that $d_q = d_k$ is required so that the dot products $q_i k_j^T$ are defined, while $d_v$ may be chosen independently.
Attention
$$
\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
Note: $d_k$ is the dimension of each key vector $k_1,\cdots,k_n$, not the shape of $K$, which is $(n\times d_k)$.
$$
QK^T =
\begin{bmatrix}
q_1 \\ q_2 \\ \vdots \\ q_n
\end{bmatrix}
\cdot
\begin{bmatrix}
k_1^T & k_2^T & \cdots & k_n^T
\end{bmatrix} =
\begin{bmatrix}
q_1k_1^T & q_1k_2^T & \cdots & q_1k_n^T\\
q_2k_1^T & q_2k_2^T & \cdots & q_2k_n^T\\
\vdots & \vdots & \ddots & \vdots\\
q_nk_1^T & q_nk_2^T & \cdots & q_nk_n^T\\
\end{bmatrix}
$$
$$
\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) =
\begin{bmatrix}
\text{softmax}(\frac{q_1k_1^T}{\sqrt{d_k}} & \frac{q_1k_2^T}{\sqrt{d_k}} & \cdots & \frac{q_1k_n^T}{\sqrt{d_k}})\\
\text{softmax}(\frac{q_2k_1^T}{\sqrt{d_k}} & \frac{q_2k_2^T}{\sqrt{d_k}} & \cdots & \frac{q_2k_n^T}{\sqrt{d_k}})\\
\vdots & \vdots & & \vdots\\
\text{softmax}(\frac{q_nk_1^T}{\sqrt{d_k}} & \frac{q_nk_2^T}{\sqrt{d_k}} & \cdots & \frac{q_nk_n^T}{\sqrt{d_k}})\\
\end{bmatrix}
$$
i.e., the softmax is applied to each row independently.
$$
\text{Attention}(Q,K,V) =
\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$
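This matrix form maps directly to a few lines of PyTorch; a minimal sketch (the function name and shapes below are illustrative, the full module implementations come later):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Q: [n, d_k], K: [n, d_k], V: [n, d_v] -> [n, d_v]"""
    d_k = K.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # [n, n]
    weights = torch.softmax(scores, dim=-1)            # row-wise softmax
    return weights @ V

n, d_k, d_v = 5, 8, 16
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_v)
print(scaled_dot_product_attention(Q, K, V).shape)     # torch.Size([5, 16])
```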
Multi-head Self Attention
$$
\underset{n\times d_{model}}{Q} =
\underset{n\times d_{model}}{
\begin{bmatrix}
q_1 \\ q_2 \\ \vdots \\ q_n
\end{bmatrix}
} =
\underset{n\times d_a}{
\begin{bmatrix}
a_1 \\ a_2 \\ \vdots \\ a_n
\end{bmatrix}
} \cdot
\underset{d_a\times d_{model}}{W^q}\\
\underset{n\times d_{model}}{K} =
\underset{n\times d_{model}}{
\begin{bmatrix}
k_1 \\ k_2 \\ \vdots \\ k_n
\end{bmatrix}
} =
\underset{n\times d_a}{
\begin{bmatrix}
a_1 \\ a_2 \\ \vdots \\ a_n
\end{bmatrix}
} \cdot
\underset{d_a\times d_{model}}{W^k}\\
\underset{n\times d_{model}}{V} =
\underset{n\times d_{model}}{
\begin{bmatrix}
v_1 \\ v_2 \\ \vdots \\ v_n
\end{bmatrix}
} =
\underset{n\times d_a}{
\begin{bmatrix}
a_1 \\ a_2 \\ \vdots \\ a_n
\end{bmatrix}
} \cdot
\underset{d_a\times d_{model}}{W^v}
$$
Let $d_q=d_k=d_v=d_{model}/h$.
$$
\text{MultiHead}(Q,K,V) = \text{Concat}(head_1,\cdots,head_h)W^O\\
head_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)
$$
In the paper Attention Is All You Need, the authors employed $d_{model} = 512$, $h = 8$, and $d_q = d_k = d_v = d_{model}/h = 64$.
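A minimal sketch of this formula with the paper's sizes (random tensors stand in for trained weights; only the shapes matter here):

```python
import torch

n, d_model, h = 10, 512, 8
d_k = d_v = d_model // h                                # 64

Q = K = V = torch.randn(n, d_model)                     # self-attention: same source
W_Q = [torch.randn(d_model, d_k) for _ in range(h)]     # per-head projection matrices
W_K = [torch.randn(d_model, d_k) for _ in range(h)]
W_V = [torch.randn(d_model, d_v) for _ in range(h)]
W_O = torch.randn(h * d_v, d_model)

def attention(q, k, v):                                 # scaled dot-product attention
    return torch.softmax(q @ k.T / d_k ** 0.5, dim=-1) @ v

heads = [attention(Q @ W_Q[i], K @ W_K[i], V @ W_V[i]) for i in range(h)]
out = torch.cat(heads, dim=1) @ W_O                     # Concat(head_1..head_h) W^O
print(out.shape)                                        # torch.Size([10, 512])
```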
$$
\begin{split}
Q &\rightarrow QW^Q_1,QW^Q_2,\cdots,QW^Q_h\\
K &\rightarrow KW^K_1,KW^K_2,\cdots,KW^K_h\\
V &\rightarrow VW^V_1,VW^V_2,\cdots,VW^V_h
\end{split}
$$
$W^Q_1 \cdots W^Q_8$, $W^K_1 \cdots W^K_8$ and $W^V_1 \cdots W^V_8$ are matrices of the following form:
$$
\begin{array}{c}
\begin{matrix}
1 \\ 2 \\ \vdots \\ 64 \\ \\ \vdots \\ 512
\end{matrix}
\begin{bmatrix}
\textcolor{red}{1} & 0 & \cdots & 0\\
0 & \textcolor{red}{1} & \cdots & 0\\
\vdots & \vdots & \ddots & \vdots\\
0 & 0 & \cdots & \textcolor{red}{1}\\
0 & 0 & \cdots & 0\\
\vdots & \vdots && \vdots\\
0 & 0 & \cdots & 0
\end{bmatrix}
\end{array},
\quad
\begin{array}{c}
\begin{matrix}
1 \\ \vdots \\ \\ 65 \\ 66 \\ \vdots \\ 128 \\ \\ \vdots \\ 512
\end{matrix}
\begin{bmatrix}
0 & 0 & \cdots & 0\\
\vdots & \vdots && \vdots\\
0 & 0 & \cdots & 0\\
\textcolor{red}{1} & 0 & \cdots & 0\\
0 & \textcolor{red}{1} & \cdots & 0\\
\vdots & \vdots & \ddots & \vdots\\
0 & 0 & \cdots & \textcolor{red}{1}\\
0 & 0 & \cdots & 0\\
\vdots & \vdots && \vdots\\
0 & 0 & \cdots & 0
\end{bmatrix}
\end{array},
\quad \cdots, \quad
\begin{array}{c}
\begin{matrix}
1 \\ \vdots \\ \\ 449 \\ 450 \\ \vdots \\ 512
\end{matrix}
\begin{bmatrix}
0 & 0 & \cdots & 0\\
\vdots & \vdots && \vdots\\
0 & 0 & \cdots & 0\\
\textcolor{red}{1} & 0 & \cdots & 0\\
0 & \textcolor{red}{1} & \cdots & 0\\
\vdots & \vdots & \ddots & \vdots\\
0 & 0 & \cdots & \textcolor{red}{1}
\end{bmatrix}
\end{array}
$$
In this way, $\underset{n\times 512}{Q}$, $\underset{n\times 512}{K}$, $\underset{n\times 512}{V}$ are split uniformly along their columns into 8 parts.
$$
\begin{split}
\underset{n\times 512}{Q} \xrightarrow{W_1\cdots W_8} \underset{n\times 64}{Q_1},\cdots, \underset{n\times 64}{Q_8}\\
\underset{n\times 512}{K} \xrightarrow{W_1\cdots W_8} \underset{n\times 64}{K_1},\cdots, \underset{n\times 64}{K_8}\\
\underset{n\times 512}{V} \xrightarrow{W_1\cdots W_8} \underset{n\times 64}{V_1},\cdots, \underset{n\times 64}{V_8}
\end{split}
$$
In particular, we have
$$
\begin{split}
\underset{n\times 512}{Q} = \text{concat}[\underset{n\times 64}{Q_1}\cdots \underset{n\times 64}{Q_8}]\\
\underset{n\times 512}{K} = \text{concat}[\underset{n\times 64}{K_1}\cdots \underset{n\times 64}{K_8}]\\
\underset{n\times 512}{V} = \text{concat}[\underset{n\times 64}{V_1}\cdots \underset{n\times 64}{V_8}]
\end{split}
$$
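With these slicing matrices, multiplying by $W_1$ is exactly the same as taking the first 64 columns; a quick sanity check (a sketch, not code from the paper):

```python
import torch

n, d_model, h = 10, 512, 8
d_head = d_model // h                       # 64
Q = torch.randn(n, d_model)

# W_1: identity block in rows 1..64, zeros elsewhere (the first slicing matrix)
W_1 = torch.zeros(d_model, d_head)
W_1[:d_head, :] = torch.eye(d_head)

Q_1 = Q @ W_1                               # projection with the slicing matrix
assert torch.allclose(Q_1, Q[:, :d_head])   # identical to column slicing
print(Q_1.shape)                            # torch.Size([10, 64])
```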
What’s the difference between Multi-head and normal Attention
In both cases we obtain $Q,K,V$ from the input sequence $x_1,\cdots,x_n$ by the same linear embedding. For multi-head attention we then split $Q,K,V$ into several parts to get $Q_1\cdots Q_h$, $K_1\cdots K_h$, $V_1\cdots V_h$.
In particular, when we use special parameter matrices (such as the slicing matrices above), we get the special case:
$$
\begin{split}
Q &= \text{concat}[Q_1\cdots Q_h]\\
K &= \text{concat}[K_1\cdots K_h]\\
V &= \text{concat}[V_1\cdots V_h]
\end{split}
$$
How to calculate normal Attention
$$
\text{Attention}(Q,K,V) =
\text{Attention}(\text{concat}[Q_1\cdots Q_h],\text{concat}[K_1\cdots K_h],\text{concat}[V_1\cdots V_h])
$$
How to calculate multi-head Attention
$$
\text{MultiHead}(Q,K,V) = \text{concat}[\text{Attention}(Q_1,K_1,V_1),\cdots, \text{Attention}(Q_h,K_h,V_h)]
$$
Comparison of the computation
$$
\begin{split}
Q &= \text{concat}[Q_1\cdots Q_h]\\
K &= \text{concat}[K_1\cdots K_h]\\
V &= \text{concat}[V_1\cdots V_h]
\end{split}
$$
$$
\begin{split}
QK^TV &=
\begin{bmatrix}Q_1 & Q_2 & \cdots & Q_h\end{bmatrix}
\begin{bmatrix}K^T_1 \\ K^T_2 \\ \vdots \\ K^T_h\end{bmatrix}V\\
&= (Q_1K^T_1 + Q_2K^T_2 + \cdots + Q_hK^T_h)
\begin{bmatrix}V_1 & V_2 & \cdots & V_h\end{bmatrix}\\
&=
\begin{bmatrix}
\begin{matrix}
\textcolor{red}{Q_1 K^T_1 V_1}\\ + \\ Q_2 K^T_2 V_1 \\ + \\ \vdots \\ + \\ Q_h K^T_h V_1
\end{matrix},&
\begin{matrix}
Q_1 K^T_1 V_2 \\ + \\ \textcolor{red}{Q_2 K^T_2 V_2} \\ + \\ \vdots \\ + \\ Q_h K^T_h V_2
\end{matrix},
& \cdots, &
\begin{matrix}
Q_1 K^T_1 V_h \\ + \\ Q_2 K^T_2 V_h \\ + \\ \vdots \\ + \\ \textcolor{red}{Q_h K^T_h V_h}
\end{matrix}
\end{bmatrix}
\end{split}
$$
(The softmax is omitted here so that the block structure stays visible.)
$$
\text{Attention} \uparrow\downarrow \text{Multi-Head}
$$
$$
\begin{bmatrix}
\textcolor{red}{Q_1 K^T_1 V_1}&
\textcolor{red}{Q_2 K^T_2 V_2}&
\cdots&
\textcolor{red}{Q_h K^T_h V_h}
\end{bmatrix}
$$
Compared with normal attention, multi-head attention keeps only the "diagonal" terms $Q_i K^T_i V_i$ (in red) and drops the cross-head terms $Q_i K^T_i V_j$ ($i\neq j$).
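A quick numerical check of this block structure (a sketch with the softmax omitted, as in the derivation above; the sizes are placeholders):

```python
import torch

n, d_model, h = 6, 512, 8
d_head = d_model // h
Q, K, V = (torch.randn(n, d_model) for _ in range(3))

Qs = Q.split(d_head, dim=1)              # Q_1, ..., Q_h, each [n, 64]
Ks = K.split(d_head, dim=1)
Vs = V.split(d_head, dim=1)

# "normal" attention logits: QK^T equals the sum of the per-head logits
logits_full = Q @ K.T
logits_sum = sum(q_i @ k_i.T for q_i, k_i in zip(Qs, Ks))
print(torch.allclose(logits_full, logits_sum, atol=1e-3))   # True

# multi-head (without softmax): only the diagonal terms Q_i K_i^T V_i survive
multihead = torch.cat([(q_i @ k_i.T) @ v_i for q_i, k_i, v_i in zip(Qs, Ks, Vs)], dim=1)
print(multihead.shape)                   # torch.Size([6, 512])
```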
Something about $\frac{1}{\sqrt{d_k}}$
Expectation and variance of the product of random variables
Assume that $q^1,\cdots,q^{d_k} \sim N(\mu_q, \sigma_q)$ and $k^1,\cdots,k^{d_k} \sim N(\mu_k, \sigma_k)$, where $q^i$ and $k^i$ denote the components of $q$ and $k$, and $\sigma$ denotes the variance.
In practice, $\mu_q = \mu_k = 0$ and $\sigma_q = \sigma_k = 1$.
$$
\begin{split}
E(q\cdot k^T) &= E\left(\sum_{i=1}^{d_k}q^ik^i\right) \\
&= \sum_{i=1}^{d_k}E(q^ik^i)\\
&= \sum_{i=1}^{d_k} \left(Eq^i\cdot Ek^i + cov(q^i, k^i) \right)\\
\xrightarrow[cov(q^i,k^i)=0]{independence} &= \sum_{i=1}^{d_k} \left(Eq^i\cdot Ek^i + 0 \right) \\
&= d_k\cdot\mu_q\cdot\mu_k
\xrightarrow{\mu=0} 0
\end{split}
$$
$$
\begin{split}
Var(q\cdot k^T) &= Var\left(\sum_{i=1}^{d_k}q^ik^i\right)\\
\xrightarrow{independence} &= \sum_{i=1}^{d_k} Var(q^ik^i)\\
&= \sum_{i=1}^{d_k} \left(Var(q^i)\cdot Var(k^i) + Var(q^i)\cdot E^2k^i + E^2q^i\cdot Var(k^i)\right)\\
&= \sum_{i=1}^{d_k} \left(\sigma_q\cdot \sigma_k + \sigma_q\cdot \mu_k^2 + \mu_q^2\cdot \sigma_k\right)\\
&= d_k \cdot (\sigma_q\cdot \sigma_k + \sigma_q\cdot \mu_k^2 + \mu_q^2\cdot \sigma_k)
\xrightarrow[\sigma=1]{\mu=0} d_k
\end{split}
$$
You can infer from the equations above
$$
E\left(\frac{q\cdot k^T}{\sqrt{d_k}}\right) = \frac{1}{\sqrt{d_k}}\cdot E(q\cdot k^T) = \sqrt{d_k} \cdot \mu_q \cdot \mu_k \xrightarrow{\mu=0} 0
\\
Var\left(\frac{q\cdot k^T}{\sqrt{d_k}}\right) = \frac{1}{d_k}\cdot Var(q\cdot k^T) = \sigma_q\cdot \sigma_k + \sigma_q\cdot \mu_k^2 + \mu_q^2\cdot \sigma_k \xrightarrow[\sigma=1]{\mu=0} 1
$$
Scaling by $\frac{1}{\sqrt{d_k}}$ therefore keeps the variance of the attention logits at 1 regardless of $d_k$, which prevents the softmax from saturating when $d_k$ is large.
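An empirical check of these two statements (a sketch; the sample size is arbitrary):

```python
import torch

d_k, num_samples = 64, 100_000
q = torch.randn(num_samples, d_k)          # components ~ N(0, 1)
k = torch.randn(num_samples, d_k)

scores = (q * k).sum(dim=-1)               # q · k^T for each sample
print(scores.var().item())                 # ≈ d_k = 64
print((scores / d_k ** 0.5).var().item())  # ≈ 1 after scaling by 1/sqrt(d_k)
```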
Holistic Perspective
B: Batch_size, T: Block_size (Time), C: Embedding_size (Channel)
```mermaid
graph LR
    Input["Input: [B,T,C]"]
    Q["Q: [B,T,dim_q]"] --> QK
    K["K: [B,T,dim_k]"] --> QK
    V["V: [B,T,dim_v]"]
    QK["Q·K^T: [B,T,T] (dim_q=dim_k)"]
    Input --"Wq: [C, dim_q]"--> Q
    Input --"Wk: [C, dim_k]"--> K
    Input --"Wv: [C, dim_v]"--> V
    Out["(Q·K^T)·V: [B,T,dim_v]"]
    QK --> Out
    V ----> Out
```
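A shape-only sketch of this batched flow (the names B, T, C, Wq, Wk, Wv mirror the diagram; the sizes are placeholders, and the full module implementations follow in the next section):

```python
import torch
from torch import nn

B, T, C = 2, 10, 32                       # batch, time (block size), channels
dim_q = dim_k = 16
dim_v = 24

x = torch.rand(B, T, C)                   # Input: [B, T, C]
Wq = nn.Linear(C, dim_q, bias=False)
Wk = nn.Linear(C, dim_k, bias=False)
Wv = nn.Linear(C, dim_v, bias=False)

Q, K, V = Wq(x), Wk(x), Wv(x)             # [B, T, dim_q], [B, T, dim_k], [B, T, dim_v]
scores = Q @ K.transpose(-2, -1) / dim_k ** 0.5   # [B, T, T]
out = torch.softmax(scores, dim=-1) @ V   # [B, T, dim_v]
print(out.shape)                          # torch.Size([2, 10, 24])
```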
Implementation in Python
Self-Attention
```python
import torch
from torch import nn

'''
x.shape: [n, dim_in]
Q = x @ Wq, K = x @ Wk, V = x @ Wv
attention = softmax((Q @ K^T)/sqrt(dim_k)) @ V
'''

class SelfAttention(nn.Module):
    def __init__(self, dim_in, dim_q, dim_k, dim_v):
        super(SelfAttention, self).__init__()
        assert dim_k == dim_q
        self.dim_in = dim_in
        self.dim_q = dim_q
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.linear_q = nn.Linear(dim_in, dim_q, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self.norm = dim_k ** (1 / 2)

    def forward(self, x):
        '''x: [n, dim_in]'''
        assert x.shape[-1] == self.dim_in
        q = self.linear_q(x)  # [n, dim_q]
        k = self.linear_k(x)  # [n, dim_k]
        v = self.linear_v(x)  # [n, dim_v]
        attention = torch.mm(q, k.transpose(0, 1)) / self.norm  # [n, n]
        attention = nn.Softmax(dim=-1)(attention)               # row-wise softmax
        attention = torch.mm(attention, v)                      # [n, dim_v]
        return attention


if __name__ == "__main__":
    input = torch.rand(3, 16)
    attention = SelfAttention(dim_in=16, dim_q=8, dim_k=8, dim_v=16)
    output = attention.forward(input)
    print(output.shape)  # torch.Size([3, 16])
```
MultiHead-Attention
```python
import torch
from torch import nn

'''
x.shape: [n, dim_in]
Q = x @ Wq, K = x @ Wk, V = x @ Wv
Q = [Q1,..,Qh], K = [K1,...,Kh], V = [V1,...,Vh]
concat[attention(Q1,K1,V1), ..., attention(Qh,Kh,Vh)]
'''

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim_in, dim_q, dim_k, dim_v, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        assert dim_q == dim_k
        self.dim_in = dim_in
        self.dim_q = dim_q
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.linear_q = nn.Linear(dim_in, dim_q, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self.norm_fact = (dim_k // num_heads) ** (1 / 2)

    def forward(self, x):
        '''x.shape: [n, dim_in]'''
        n, dim_in = x.shape
        assert dim_in == self.dim_in
        dim_q, dim_k, dim_v = self.dim_q, self.dim_k, self.dim_v
        heads = self.num_heads

        # split Q, K, V column-wise into `heads` parts: [heads, n, dim_*//heads]
        q = self.linear_q(x).reshape(n, heads, dim_q // heads).transpose(0, 1)
        k = self.linear_k(x).reshape(n, heads, dim_k // heads).transpose(0, 1)
        v = self.linear_v(x).reshape(n, heads, dim_v // heads).transpose(0, 1)

        attention = torch.matmul(q, k.transpose(1, 2)) / self.norm_fact  # [heads, n, n]
        attention = nn.Softmax(dim=-1)(attention)
        attention = torch.matmul(attention, v)                   # [heads, n, dim_v//heads]
        attention = attention.transpose(0, 1).reshape(n, dim_v)  # concat heads -> [n, dim_v]
        return attention


if __name__ == "__main__":
    input = torch.rand(3, 16)
    multihead = MultiHeadSelfAttention(dim_in=16, dim_q=8, dim_k=8, dim_v=16, num_heads=8)
    output = multihead.forward(input)
    print(output.shape)  # torch.Size([3, 16])
```
MSA of Images
How do we calculate the multi-head self-attention of an image, for example one with a shape of 3×224×224 pixels?
Patches
We first split the RGB image into non-overlapping patches. Each patch is treated as a "token" (a term from NLP meaning roughly a basic unit), and its feature is the concatenation of the raw pixel RGB values.
For example, in the implementation of Swin-Transformer, the authors use a patch size of 4×4, so the feature dimension of each patch is 4 × 4 × 3 = 48.
The patch-splitting operation transforms an RGB image from shape 3×224×224 to 48×56×56: there are $\frac{224}{4}\times\frac{224}{4} = 56\times56$ patches, and each patch carries $3\times4\times4 = 48$ pixel values. Finally, we transpose or reshape the 48×56×56 tensor to get a 56×56×48 tensor, i.e., 56×56 tokens with 48-dimensional features (see the code sketch after the diagram below).
$$
\fbox{$\begin{array}{cccccccc}
·&·&·&·&·&·&·&·\\
·&·&·&·&·&·&·&·\\
·&·&·&·&·&·&·&·\\
·&·&·&·&·&·&·&·\\
·&·&·&·&·&·&·&·\\
·&·&·&·&·&·&·&·\\
·&·&·&·&·&·&·&·\\
·&·&·&·&·&·&·&·
\end{array}$}
\xrightarrow{\text{patch splitting}}
\begin{array}{cccc}
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$}\\
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$}\\
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$}\\
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$} &
\fbox{$\begin{array}{cc}·&·\\·&·\end{array}$}
\end{array}
$$
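A minimal patch-splitting sketch with reshape/permute, assuming the 4×4 patch size used in Swin-Transformer (the variable names are illustrative, and the ordering of the 48 values within each token may differ from the actual Swin implementation):

```python
import torch

C, H, W, P = 3, 224, 224, 4               # channels, height, width, patch size
img = torch.rand(C, H, W)                 # a 3×224×224 RGB image

# [C, H, W] -> [C, H/P, P, W/P, P] -> [H/P, W/P, C, P, P] -> [H/P, W/P, C*P*P]
patches = (img.reshape(C, H // P, P, W // P, P)
              .permute(1, 3, 0, 2, 4)
              .reshape(H // P, W // P, C * P * P))
print(patches.shape)                      # torch.Size([56, 56, 48])

tokens = patches.reshape(-1, C * P * P)   # [3136, 48]: 56*56 tokens of dimension 48
print(tokens.shape)
```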