I had only ever read the Transformer paper; this is my first time taking it apart in code.
The authors of the paper really do seem like geniuses.
I implemented the code while following a YouTube tutorial, and this post focuses solely on the architecture.
The data and training part is planned for next week's post.
1. Implementing Input Embedding
import torch
import torch.nn as nn
import math

# Input embedding
class InputEmbeddings(nn.Module):
    # d_model: embedding dimension, vocab_size: how many words the vocabulary holds
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Embedding table of shape (vocab_size, d_model)
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scale the embeddings by sqrt(d_model), as a kind of weighting (as in the paper)
        return self.embedding(x) * math.sqrt(self.d_model)
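As a quick sanity check, the embedding layer can be exercised on its own. The sizes below (d_model=512, vocab_size=1000, a batch of 2 sequences of 10 tokens) are made up purely for illustration:

# Shape check with made-up sizes (not from the original post)
emb = InputEmbeddings(d_model=512, vocab_size=1000)
tokens = torch.randint(0, 1000, (2, 10))  # (batch, seq_len) of token ids
print(emb(tokens).shape)                  # torch.Size([2, 10, 512])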
2. Implementing Positional Encoding
# Positional Encoding
class PositionalEncoding(nn.Module):
    # "-> None" annotates the return type: the function returns nothing (purely for code readability)
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # 1. Create an empty tensor of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # 2. Positions 0..seq_len-1 along the row direction (unsqueeze(dim=1))
        #    Each value is the position of a word in the sequence
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # 3. Even indices 0, 2, 4, ... along the column direction (the 2i in the paper)
        _2i = torch.arange(0, d_model, 2, dtype=torch.float)
        # 4. Apply sin to the even columns and cos to the odd columns
        #    (0::2 / 1::2 mean "every second column" starting at index 0 / 1)
        pe[:, 0::2] = torch.sin(position / 10000 ** (_2i / d_model))
        pe[:, 1::2] = torch.cos(position / 10000 ** (_2i / d_model))
        # Add a batch dimension at position 0:
        # (seq_len, d_model) -> (1, seq_len, d_model)
        pe = pe.unsqueeze(0)
        # https://velog.io/@nawnoes/pytorch-%EB%AA%A8%EB%8D%B8%EC%9D%98-%ED%8C%8C%EB%9D%BC%EB%AF%B8%ED%84%B0%EB%A1%9C-%EB%93%B1%EB%A1%9D%ED%95%98%EC%A7%80-%EC%95%8A%EA%B8%B0-%EC%9C%84%ED%95%9C-registerbuffer
        # A buffer is a region of memory that temporarily holds data while it is moved from one place to another.
        # Here it means state that is not a learnable parameter but must still be saved and loaded with the model.
        # register_buffer makes sure 'pe' is stored and restored together with the model's state_dict.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # input x + positional encoding
        # Take the positional embeddings corresponding to each position of the input;
        # they are fixed, so no gradient is needed for backpropagation.
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)
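For reference, the buffer built above holds the sinusoidal encodings from the paper:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

Each pair of dimensions is a sinusoid with a different wavelength, so every position gets a unique, deterministic pattern; the values are fixed rather than learned, which is why they live in a register_buffer and are excluded from gradients.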
3. Implementing Layer Normalization
# Layer normalization
class LayerNormalization(nn.Module):
    def __init__(self, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))  # multiplicative (scale)
        self.bias = nn.Parameter(torch.zeros(1))  # additive (shift)

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias
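Written out, the forward pass normalizes each token's feature vector over the last dimension:

y = alpha * (x - mean(x)) / (std(x) + eps) + bias

where alpha and bias are learned and eps keeps the division stable when the standard deviation is near zero. This roughly mirrors what PyTorch's built-in nn.LayerNorm does, except that here a single scalar alpha and bias are shared across all features rather than one per feature.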
4. Implementing Feed Forward
# Feed-forward block
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        # Inner (expansion) dimension of the feed-forward network (d_ff = 2048 in the paper)
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # project back down to d_model (512)

    def forward(self, x):
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
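In equation form, this is the position-wise feed-forward network from the paper, with dropout inserted between the two layers in this version:

FFN(x) = max(0, x·W1 + b1)·W2 + b2

where W1 projects d_model -> d_ff (512 -> 2048) and W2 projects back d_ff -> d_model (2048 -> 512), applied independently at every position.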
5. Implementing Multi-Head Attention
# Multi-head attention
class MultiHeadAttentionBlock(nn.Module):
    # h: number of attention heads
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        # Stop immediately with an AssertionError if d_model cannot be split evenly across the heads
        assert d_model % h == 0, 'd_model is not divisible by h'
        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model)  # Wq
        self.w_k = nn.Linear(d_model, d_model)  # Wk
        self.w_v = nn.Linear(d_model, d_model)  # Wv
        # Wo, applied after the heads are concatenated
        self.w_o = nn.Linear(d_model, d_model)  # Wo
        self.dropout = nn.Dropout(dropout)

    # A static method behaves like a function declared outside the class:
    # it can be called without creating an instance, e.g. MultiHeadAttentionBlock.attention(...),
    # and is useful when the computation does not depend on any instance state.
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]  # query: (Batch, h, seq_len, d_k)
        # (Batch, h, seq_len, d_k) -> (Batch, h, seq_len, seq_len)
        # Dot product between queries and keys, then scale by sqrt(d_k)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        # If a mask is given, fill the positions where the mask is 0
        # with a very large negative value (-1e9) so they vanish after the softmax
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)
        # Softmax turns the attention scores into a probability distribution:
        # the attention weights for each position
        attention_scores = attention_scores.softmax(dim=-1)  # (Batch, h, seq_len, seq_len)
        # Apply dropout to the attention weights if a dropout module was provided
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # Return the attention-weighted values along with the attention weights
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        # 1. Project Q, K, V with Wq, Wk, Wv
        query = self.w_q(q)  # (Batch, seq_len, d_model) -> (Batch, seq_len, d_model)
        key = self.w_k(k)
        value = self.w_v(v)
        # 2. Split Q, K, V into h heads
        # (Batch, seq_len, d_model) -> (Batch, seq_len, h, d_k) -> (Batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        # 3. Concatenate the heads back together
        # (Batch, h, seq_len, d_k) -> (Batch, seq_len, h, d_k) -> (Batch, seq_len, d_model)
        # https://ebbnflow.tistory.com/351
        # contiguous: whether the tensor's values are laid out sequentially in memory
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)  # -1 lets PyTorch infer that dimension
        # (Batch, seq_len, d_model) -> (Batch, seq_len, d_model)
        return self.w_o(x)
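The two steps above correspond to the usual formulas:

Attention(Q, K, V) = softmax(Q·K^T / sqrt(d_k)) · V
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O

Instead of keeping h separate projection matrices per head, this implementation uses single d_model x d_model matrices Wq, Wk, Wv and then reshapes the result into h heads of size d_k = d_model / h, which is equivalent.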
6. Implementing ResidualConnection
# Residual (skip) connection
class ResidualConnection(nn.Module):
    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    # sublayer: the layer wrapped by the skip connection (self-attention or feed-forward)
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
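So each sublayer is applied as

output = x + Dropout(Sublayer(LayerNorm(x)))

which is the pre-norm variant: normalization happens before the sublayer, whereas the original paper normalizes after the residual addition, LayerNorm(x + Sublayer(x)). Both are widely used (pre-norm is generally considered easier to train), but it is a small deviation from the paper worth noting.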
7. Implementing the EncoderBlock
class EncoderBlock(nn.Module):
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Two ResidualConnection(dropout) modules: one skip connection around self-attention,
        # one around the feed-forward block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        # Skip connection around the self-attention sublayer
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        # Skip connection around the feed-forward sublayer
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
8. Implementing the Encoder
class Encoder(nn.Module):
    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
9. Implementing the DecoderBlock
# Decoder block
class DecoderBlock(nn.Module):
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    # tgt_mask: hides the words after the current position in the decoder (causal mask)
    # src_mask: binary mask over the encoder output with 0 at padding-token positions and 1 at real words,
    #           so the decoder does not attend to padding
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        # Cross-attention: keys and values come from the encoder output
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x
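The masks themselves belong to the data/training code (planned for the next post), but for context, here is a minimal sketch of how they might be built. A padding token id of 0 is assumed purely for illustration:

# Hypothetical mask construction (pad id = 0 is an assumption, not from this post)
def make_src_mask(src):
    # (Batch, seq_len) -> (Batch, 1, 1, seq_len): True at real tokens, False at padding
    return (src != 0).unsqueeze(1).unsqueeze(2)

def make_tgt_mask(tgt):
    # Padding mask combined with a lower-triangular (causal) mask -> (Batch, 1, seq_len, seq_len)
    pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
    causal = torch.tril(torch.ones(1, tgt.shape[1], tgt.shape[1])).bool()
    return pad_mask & causal

Both masks broadcast against the (Batch, h, seq_len, seq_len) attention scores inside MultiHeadAttentionBlock.attention, where positions that are 0/False get filled with -1e9.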
10. Implementing the Decoder
# Decoder
class Decoder(nn.Module):
    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)
11. Implementing the ProjectionLayer
# Projection layer
class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (Batch, seq_len, d_model) -> (Batch, seq_len, vocab_size)
        return torch.log_softmax(self.proj(x), dim=-1)
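Because the projection already returns log-probabilities (log_softmax), it pairs naturally with nn.NLLLoss during training; nn.CrossEntropyLoss would instead expect raw logits. A rough sketch, where proj, decoder_output, and target are placeholder names and ignore_index=0 assumes a padding id of 0:

loss_fn = nn.NLLLoss(ignore_index=0)      # ignoring an assumed pad id of 0
log_probs = proj(decoder_output)          # (Batch, seq_len, vocab_size)
loss = loss_fn(log_probs.view(-1, log_probs.size(-1)), target.view(-1))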
12. Implementing the Transformer
# Transformer
class Transformer(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)


def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)
    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)
    # Create the encoder and decoder
    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))
    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
    # Initialize parameters (Xavier uniform for weight matrices)
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer
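A minimal smoke test of the assembled model, using made-up vocabulary sizes and sequence lengths just to confirm the shapes line up:

# Hypothetical sizes, chosen only to exercise the model
model = build_transformer(src_vocab_size=100, tgt_vocab_size=100, src_seq_len=20, tgt_seq_len=20)
src = torch.randint(0, 100, (2, 20))                    # (Batch, src_seq_len) of token ids
tgt = torch.randint(0, 100, (2, 20))                    # (Batch, tgt_seq_len) of token ids
src_mask = torch.ones(2, 1, 1, 20, dtype=torch.bool)    # no padding in this toy example
tgt_mask = torch.tril(torch.ones(1, 1, 20, 20)).bool()  # causal mask
encoder_output = model.encode(src, src_mask)                            # (2, 20, 512)
decoder_output = model.decode(encoder_output, src_mask, tgt, tgt_mask)  # (2, 20, 512)
print(model.project(decoder_output).shape)                              # torch.Size([2, 20, 100])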
References
https://www.youtube.com/watch?v=ISNdQcPhsts
https://code-angie.tistory.com/9
https://code-angie.tistory.com/7#3-position-wise-fully-connected-feed-forward-network
Thank you for reading.