자기-어텐션으로 시퀀스를 인코딩하는 방법과, 토큰 순서 정보를 주입하기 위한 고정(사인/코사인) 위치 인코딩을 설명한다.
Title: 11.6. Self-Attention and Positional Encoding — Dive into Deep Learning 1.0.3 documentation
Colab에서 노트북 열기
딥러닝에서는 시퀀스를 인코딩하기 위해 CNN이나 RNN을 자주 사용합니다. 이제 어텐션 메커니즘을 염두에 두고, 토큰 시퀀스를 어텐션 메커니즘에 입력하되 매 단계마다 각 토큰이 자체적인 쿼리(query), 키(key), 값(value)을 가진다고 상상해 봅시다. 여기서 다음 층에서 토큰 표현의 값을 계산할 때, 그 토큰은 (자신의 쿼리 벡터를 통해) 다른 어떤 토큰이든 (키 벡터에 기반해 매칭하여) 주의를 기울일 수 있습니다. 쿼리-키의 호환성 점수 전체를 사용하면, 각 토큰에 대해 다른 토큰들에 대한 적절한 가중합을 구성하여 표현을 계산할 수 있습니다. 모든 토큰이 서로에게 주의를 기울이기 때문에(디코더 단계가 인코더 단계에 주의를 기울이는 경우와 달리), 이러한 아키텍처는 보통 자기-어텐션(self-attention) 모델(Lin et al., 2017, Vaswani et al., 2017)로 묘사되며, 다른 곳에서는 인트라-어텐션(intra-attention) 모델(Cheng et al., 2016, Parikh et al., 2016, Paulus et al., 2017)로도 불립니다. 이 절에서는 시퀀스 순서를 위한 추가 정보를 사용하는 것을 포함하여, 자기-어텐션을 사용한 시퀀스 인코딩을 다룹니다.
import math
import torch
from torch import nn
from d2l import torch as d2l
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import numpy as np
import tensorflow as tf
from d2l import tensorflow as d2l
입력 토큰 시퀀스 이 주어졌고 각 ()라고 하자. 자기-어텐션의 출력은 같은 길이의 시퀀스 이며,
(11.6.1)¶
이는 (11.1.1)에서 정의한 어텐션 풀링에 따른 것입니다. 멀티-헤드 어텐션을 사용하면, 아래 코드 조각은 형태가 (배치 크기, 시간 단계 수 또는 토큰 기준 시퀀스 길이, )인 텐서의 자기-어텐션을 계산합니다. 출력 텐서의 형태도 동일합니다.
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
training=False)[0][0],
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
(batch_size, num_queries, num_hiddens))
각 입력 또는 출력 토큰이 차원 벡터로 표현되는, 길이 인 토큰 시퀀스를 같은 길이의 다른 시퀀스로 매핑하는 아키텍처를 비교해 봅시다. 구체적으로 CNN, RNN, 자기-어텐션을 살펴봅니다. 계산 복잡도, 순차 연산 수, 최대 경로 길이를 비교할 것입니다. 순차 연산은 병렬 계산을 막는 반면, 시퀀스 위치 조합 간 경로가 짧을수록 시퀀스 내 장거리 의존성을 학습하기가 더 쉽습니다(Hochreiter et al., 2001).
그림 11.6.1 CNN(패딩 토큰은 생략), RNN, 자기-어텐션 아키텍처 비교.¶
어떤 텍스트 시퀀스든 “1차원 이미지”로 볼 수 있습니다. 마찬가지로 1차원 CNN은 텍스트에서 n-그램 같은 국소 특징을 처리할 수 있습니다. 길이 인 시퀀스가 주어졌을 때, 커널 크기가 이고 입력/출력 채널 수가 모두 인 합성곱 층을 생각해 봅시다. 이 합성곱 층의 계산 복잡도는 입니다. 그림 11.6.1에서 보듯 CNN은 계층적이므로, 순차 연산은 이고 최대 경로 길이는 입니다. 예를 들어 그림 11.6.1에서 과 는 커널 크기 3인 2-층 CNN의 수용영역(receptive field) 안에 있습니다.
RNN의 은닉 상태를 갱신할 때, 가중치 행렬과 차원 은닉 상태의 곱은 의 계산 복잡도를 가집니다. 시퀀스 길이가 이므로 순환 층의 계산 복잡도는 입니다. 그림 11.6.1에 따르면 병렬화할 수 없는 순차 연산이 개 있으며 최대 경로 길이도 입니다.
자기-어텐션에서는 쿼리, 키, 값이 모두 행렬입니다. (11.3.6)의 스케일드 닷-프로덕트 어텐션을 고려하면, 행렬에 행렬을 곱한 다음, 그 출력 행렬에 행렬을 곱합니다. 결과적으로 자기-어텐션의 계산 복잡도는 입니다. 그림 11.6.1에서 볼 수 있듯, 각 토큰은 자기-어텐션을 통해 어떤 다른 토큰과도 직접 연결됩니다. 따라서 계산은 개의 순차 연산만으로 병렬화할 수 있고, 최대 경로 길이도 입니다.
정리하면, CNN과 자기-어텐션은 모두 병렬 계산의 이점을 누리며 자기-어텐션이 최대 경로 길이가 가장 짧습니다. 하지만 시퀀스 길이에 대한 이차(제곱) 계산 복잡도 때문에, 매우 긴 시퀀스에서는 자기-어텐션이 지나치게 느려질 수 있습니다.
토큰을 하나씩 순환적으로 처리하는 RNN과 달리, 자기-어텐션은 순차 연산을 버리고 병렬 계산을 택합니다. 하지만 자기-어텐션만으로는 시퀀스의 순서를 보존하지 못한다는 점에 유의해야 합니다. 입력 시퀀스가 어떤 순서로 들어왔는지 모델이 아는 것이 정말 중요하다면 어떻게 해야 할까요?
토큰 순서에 대한 정보를 보존하는 지배적인 접근법은, 각 토큰에 연관된 추가 입력으로 이 정보를 모델에 제공하는 것입니다. 이러한 입력을 _위치 인코딩(positional encodings)_이라고 하며, 학습될 수도 있고 사전에 고정될 수도 있습니다. 이제 사인과 코사인 함수에 기반한, 고정 위치 인코딩의 간단한 설계를 설명합니다(Vaswani et al., 2017).
입력 표현 가 시퀀스의 개 토큰에 대한 차원 임베딩을 포함한다고 하자. 위치 인코딩은 같은 형태의 위치 임베딩 행렬 를 사용하여 를 출력하며, 그 번째 행의 번째 또는 번째 열 원소는 다음과 같습니다.
(11.6.2)¶
처음 보면 이 삼각함수 기반 설계는 이상해 보일 수 있습니다. 이 설계를 설명하기 전에, 먼저 아래 PositionalEncoding 클래스에서 이를 구현해 봅시다.
class PositionalEncoding(nn.Module): #@save
"""Positional encoding."""
def __init__ (self, num_hiddens, dropout, max_len=1000):
super(). __init__ ()
self.dropout = nn.Dropout(dropout)
# Create a long enough P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
class PositionalEncoding(nn.Block): #@save
"""Positional encoding."""
def __init__ (self, num_hiddens, dropout, max_len=1000):
super(). __init__ ()
self.dropout = nn.Dropout(dropout)
# Create a long enough P
self.P = np.zeros((1, max_len, num_hiddens))
X = np.arange(max_len).reshape(-1, 1) / np.power(
10000, np.arange(0, num_hiddens, 2) / num_hiddens)
self.P[:, :, 0::2] = np.sin(X)
self.P[:, :, 1::2] = np.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].as_in_ctx(X.ctx)
return self.dropout(X)
class PositionalEncoding(nn.Module): #@save
"""Positional encoding."""
num_hiddens: int
dropout: float
max_len: int = 1000
def setup(self):
# Create a long enough P
self.P = jnp.zeros((1, self.max_len, self.num_hiddens))
X = jnp.arange(self.max_len, dtype=jnp.float32).reshape(
-1, 1) / jnp.power(10000, jnp.arange(
0, self.num_hiddens, 2, dtype=jnp.float32) / self.num_hiddens)
self.P = self.P.at[:, :, 0::2].set(jnp.sin(X))
self.P = self.P.at[:, :, 1::2].set(jnp.cos(X))
@nn.compact
def __call__ (self, X, training=False):
# Flax sow API is used to capture intermediate variables
self.sow('intermediates', 'P', self.P)
X = X + self.P[:, :X.shape[1], :]
return nn.Dropout(self.dropout)(X, deterministic=not training)
class PositionalEncoding(tf.keras.layers.Layer): #@save
"""Positional encoding."""
def __init__ (self, num_hiddens, dropout, max_len=1000):
super(). __init__ ()
self.dropout = tf.keras.layers.Dropout(dropout)
# Create a long enough P
self.P = np.zeros((1, max_len, num_hiddens))
X = np.arange(max_len, dtype=np.float32).reshape(
-1,1)/np.power(10000, np.arange(
0, num_hiddens, 2, dtype=np.float32) / num_hiddens)
self.P[:, :, 0::2] = np.sin(X)
self.P[:, :, 1::2] = np.cos(X)
def call(self, X, **kwargs):
X = X + self.P[:, :X.shape[1], :]
return self.dropout(X, **kwargs)
위치 임베딩 행렬 에서, 행은 시퀀스 내 위치에 대응하고 열은 서로 다른 위치 인코딩 차원을 나타냅니다. 아래 예제에서 위치 임베딩 행렬의 6번째와 7번째 열이 8번째와 9번째 열보다 더 높은 주파수를 갖는 것을 볼 수 있습니다. 6번째와 7번째(8번째와 9번째도 마찬가지) 열 사이의 위상 차이는 사인과 코사인 함수를 번갈아 쓰기 때문에 나타납니다.
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.initialize()
X = pos_encoding(np.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
params = pos_encoding.init(d2l.get_key(), jnp.zeros((1, num_steps, encoding_dim)))
X, inter_vars = pos_encoding.apply(params, jnp.zeros((1, num_steps, encoding_dim)),
mutable='intermediates')
P = inter_vars['intermediates']['P'][0] # retrieve intermediate value P
P = P[:, :X.shape[1], :]
d2l.plot(jnp.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in jnp.arange(6, 10)])
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(tf.zeros((1, num_steps, encoding_dim)), training=False)
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
인코딩 차원에 따라 단조롭게 감소하는 주파수가 절대적 위치 정보와 어떻게 관련되는지 보기 위해, 0,1,…,7의 이진 표현을 출력해 봅시다. 보이는 것처럼 최하위 비트, 그다음 비트, 그다음 비트는 각각 매 숫자마다, 두 숫자마다, 네 숫자마다 번갈아 바뀝니다.
for i in range(8):
print(f'{i} in binary is {i:>03b}')
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
for i in range(8):
print(f'{i} in binary is {i:>03b}')
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
for i in range(8):
print(f'{i} in binary is {i:>03b}')
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
for i in range(8):
print(f'{i} in binary is {i:>03b}')
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
이진 표현에서는 상위 비트가 하위 비트보다 더 낮은 주파수를 갖습니다. 마찬가지로 아래 히트맵에서 보이듯, 위치 인코딩은 삼각함수를 사용해 인코딩 차원을 따라 주파수를 감소시킵니다. 출력은 부동소수점 수이므로, 이러한 연속 표현은 이진 표현보다 공간 효율적입니다.
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = np.expand_dims(np.expand_dims(P[0, :, :], 0), 0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = jnp.expand_dims(jnp.expand_dims(P[0, :, :], axis=0), axis=0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = tf.expand_dims(tf.expand_dims(P[0, :, :], axis=0), axis=0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
위의 위치 인코딩은 절대적 위치 정보를 포착하는 것 외에도, 모델이 상대적 위치에 따라 주의를 기울이는 것을 쉽게 학습할 수 있게 합니다. 이는 어떤 고정된 위치 오프셋 에 대해서도, 위치 에서의 위치 인코딩이 위치 에서의 위치 인코딩을 선형 투영(linear projection)한 것으로 표현될 수 있기 때문입니다.
이 투영은 수학적으로 설명할 수 있습니다. 라고 두면, (11.6.2)의 쌍은 어떤 고정 오프셋 에 대해서도 선형적으로 로 투영될 수 있습니다:
(11.6.3)¶
여기서 투영 행렬은 어떤 위치 인덱스 에도 의존하지 않습니다.
자기-어텐션에서는 쿼리, 키, 값이 모두 같은 곳에서 옵니다. CNN과 자기-어텐션은 모두 병렬 계산의 이점을 가지며, 자기-어텐션의 최대 경로 길이가 가장 짧습니다. 그러나 시퀀스 길이에 대한 이차 계산 복잡도 때문에, 매우 긴 시퀀스에서는 자기-어텐션이 지나치게 느려질 수 있습니다. 시퀀스 순서 정보를 사용하기 위해, 입력 표현에 위치 인코딩을 더해 절대적 또는 상대적 위치 정보를 주입할 수 있습니다.
위치 인코딩을 포함한 자기-어텐션 층을 여러 층 쌓아 시퀀스를 표현하는 딥 아키텍처를 설계한다고 가정하자. 가능한 문제점은 무엇일까?
학습 가능한 위치 인코딩 방법을 설계할 수 있는가?
자기-어텐션에서 비교되는 쿼리와 키 사이의 오프셋이 다를 때, 오프셋에 따라 서로 다른 학습 임베딩을 부여할 수 있을까? 힌트: 상대 위치 임베딩(relative position embeddings)(Huang et al., 2018, Shaw et al., 2018)을 참고하라.