๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
Deep Learning/[์ฝ”๋“œ ๊ตฌํ˜„] DL Architecture ๊ตฌํ˜„

[Transformer] Implementing train.py, dataset.py, config.py, and Masks - 2 (PyTorch)

by ์ œ๋ฃฝ 2024. 2. 21.

์ง€๋‚œ ์‹œ๊ฐ„์— ์ด์–ด, ์˜ค๋Š˜์€ ๋‚˜๋จธ์ง€ train.py, config.py, dataset.py ํŒŒ์ผ์„ ๊ตฌํ˜„ํ–ˆ๋‹ค.

 

https://www.youtube.com/watch?v=ISNdQcPhsts

 

์ด ๋ถ„ ์ฝ”๋“œ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๊ตฌํ˜„ํ•˜์˜€์Šต๋‹ˆ๋‹ค.


Transformer Overall Architecture

 

1. Dataset.py ๊ตฌํ˜„

1-1. Bilingual Dataset

ํŒจ๋”ฉ ์ถ”๊ฐ€ ๋ถ€๋ถ„ ์ฐธ๊ณ  (input, output)

 

์‚ฌ์šฉํ•œ ๋ฐ์ดํ„ฐ์…‹์€ Hugging Face์—์„œ ์ œ๊ณตํ•˜๋Š” opus_books Dataset์„ ํ™œ์šฉํ•˜์˜€๋‹ค.

https://huggingface.co/datasets/opus_books/viewer/en-it

 

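Each record in opus_books is a dictionary whose translation field is keyed by language code, roughly like this (a sketch of one en-it record; the text values are made up):

item = {
    "id": "0",
    "translation": {
        "en": "Some English source sentence ...",
        "it": "La frase italiana corrispondente ...",
    },
}
# BilingualDataset below reads item['translation'][src_lang] and item['translation'][tgt_lang]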

import torch
import torch.nn as nn
from torch.utils.data import Dataset

class BilingualDataset(Dataset):

    # Constructor, executed when the class is instantiated
    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # torch.tensor (lowercase) is the factory that accepts a dtype argument
        self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)

    # Called when the number of elements is queried (len())
    def __len__(self):
        return len(self.ds)

    # Called when an element is accessed by index
    def __getitem__(self, index):
        src_target_pair = self.ds[index]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Tokenize the text and extract the token ids from the result
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Add padding to fill each sequence up to seq_len
        # Number of padding tokens for the source (encoder) input:
        # SOS and EOS are both added, hence -2
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        # Number of padding tokens for the target (decoder) sequence:
        # only SOS (input) or EOS (label) is added, hence -1
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Example with self.seq_len = 10:
        # enc_input_tokens : ["안녕", "하세요", "반갑습니다"]
        # dec_input_tokens : ["Hello", "World"]
        # enc_num_padding_tokens: 10 - 3 - 2 = 5
        # dec_num_padding_tokens: 10 - 2 - 1 = 7
        # ['SOS', "안녕", "하세요", "반갑습니다", 'EOS', PAD, PAD, PAD, PAD, PAD]
        # ['Hello', 'World', 'EOS', PAD, PAD, PAD, PAD, PAD, PAD, PAD]

        # If the count is negative (e.g. 10 - 9 - 2 = -1), the sentence is
        # longer than seq_len, so raise an error
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

 

 

1-2. ํ† ํฐ Concat

        # ํ† ํฐ Concat

        # Add <s> and </s> token
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )        
        

        # Add only <sos> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only <eos> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        # Asser  : if ๊ฐœ๋…
        # ์ตœ๋Œ€ ์‹œํ€€์Šค์™€ input+ํŒจ๋”ฉ ํ•œ ๊ฒฐ๊ณผ๊ฐ€ ๋™์ผํ•œ ๊ธธ์ด์ธ์ง€๋ฅผ ๋ฌผ์–ด๋ด„
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)

            #์ฐจ์›์„ ๋งž์ถฐ์ฃผ๊ธฐ ์œ„ํ•ด ์‹คํ–‰
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
            # ํ˜„์žฌ ์œ„์น˜ ์ดํ›„์˜ ํ† ํฐ ๋งˆ์Šคํฌ ์ฒ˜๋ฆฌ (ํŒจ๋”ฉ ํ† ํฐ์ด ์•„๋‹Œ ์œ„์น˜์— ๋Œ€ํ•ด True, ํŒจ๋”ฉ ํ† ํฐ์ธ ์œ„์น˜์— ๋Œ€ํ•ด False ๋ฐ˜ํ™˜) 
            # -> ํŠน์ • ์œ„์น˜์— ํ† ํฐ์— ๋Œ€ํ•œ attention์„ ์ˆ˜ํ–‰ํ• ์ง€ ๋ง์ง€๋ฅผ ๊ฒฐ์ •
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
            "label": label,  # (seq_len) ๋””์ฝ”๋”์˜ ์‹ค์ œ ์ถœ๋ ฅ์— ํ•ด๋‹นํ•˜๋Š” ๊ฐ’
            "src_text": src_text, #input text
            "tgt_text": tgt_text, #target text

 

 

1-3. Mask ๊ตฌํ˜„

 

def causal_mask(size):
    # https://incredible.ai/nlp/2020/02/29/Transformer/#241-padding-mask-%ED%95%B5%EC%8B%AC-%EB%82%B4%EC%9A%A9
    # triu: for an n x n square matrix, returns 1 in the upper-triangular part
    # (here strictly above the diagonal, because diagonal=1) and 0 elsewhere
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    # mask == 0 flips it: True on and below the diagonal, False above it
    return mask == 0
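As a quick check, calling the function just defined with size 4 shows that each position can attend only to itself and to earlier positions:

print(causal_mask(4).int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]], dtype=torch.int32)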

 

 

2. Config.py ๊ตฌํ˜„

from pathlib import Path

def get_config():
    return {
        "batch_size": 8,
        "num_epochs": 20,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "datasource": 'opus_books',
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }

#weight ํด๋” ๋ฐ ์ด๋ฆ„ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
def get_weights_file_path(config, epoch: str):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if len(weights_files) == 0:
        return None
    #๊ฐ€์ ธ์˜จ ํŒŒ์ผ๋“ค์„ ํŒŒ์ผ์˜ ์ˆ˜์ • ์‹œ๊ฐ„์„ ๊ธฐ์ค€์œผ๋กœ ์ •๋ ฌ
    #๋งจ ๋งˆ์ง€๋ง‰ ํŒŒ์ผ์ด ๊ฐ€์žฅ ์ตœ๊ทผ์— ์ƒ์„ฑ๋œ ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ํŒŒ์ผ์ด ๋จ!
    weights_files.sort()
    return str(weights_files[-1])
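With the config above, the helpers resolve to paths like the following (a quick sketch; the epoch string is whatever the caller passes in):

config = get_config()
print(get_weights_file_path(config, "05"))
# opus_books_weights/tmodel_05.pt

# latest_weights_file_path(config) returns None until at least one
# checkpoint has been saved into opus_books_weights/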

 

 

3. Implementing train.py (+ Validation)

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from dataset import BilingualDataset, causal_mask
from model import build_transformer
from config import get_config, get_weights_file_path, latest_weights_file_path

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from torch.utils.tensorboard import SummaryWriter

import warnings
from tqdm import tqdm
from pathlib import Path

#๋””์ฝ”๋”์—์„œ ๊ฐ ์‹œ์ ๋งˆ๋‹ค ๊ฐ€์žฅ ํ™•๋ฅ ์ด ๋†’์€ ๋‹จ์–ด๋ฅผ ์„ ํƒํ•ด 
#๋‹ค์Œ ์‹œ์ ์˜ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•˜๋Š” ๊ฐ„๋‹จํ•œ ๋””์ฝ”๋”ฉ ์ „๋žต ์ค‘ ํ•˜๋‚˜
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)
    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        # Stop once the decoder input reaches max_len
        if decoder_input.size(1) == max_len:
            break

        # Build the mask for the target:
        # an attention mask over the decoder input generated so far,
        # preventing attention to tokens after the current position
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # Calculate the decoder output for the input generated so far
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Get the next token: project the last position of the decoder
        # output into a probability distribution over the vocabulary
        prob = model.project(out[:, -1])
        # Pick the word with the highest probability
        # (the underscore discards the max value; only the index is needed)
        _, next_word = torch.max(prob, dim=1)
        # Append the chosen word to the decoder input generated so far
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        # Stop if the next word is the eos token
        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)
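A small illustration of the torch.max call used above (a standalone sketch with made-up logits); with a dim argument it returns both the maximum values and their indices:

import torch

prob = torch.tensor([[0.1, 0.7, 0.2]])      # (batch=1, vocab_size=3)
values, next_word = torch.max(prob, dim=1)
print(values, next_word)                    # tensor([0.7000]) tensor([1])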


def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    model.eval()
    count = 0

    # source_texts = []
    # expected = []
    # predicted = []

    # Side note: measure the terminal width for prettier logging
    try:
        # get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except:
        # If we can't get the console width, use 80 as default
        console_width = 80

    #Val ์ง„ํ–‰
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

            # check that the batch size is 1
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            #Greedy Decode ์ ์šฉ
            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
            
            #ํ˜„์žฌ ๋ฐฐ์น˜์—์„œ ์†Œ์Šค ๋ฌธ์žฅ๊ณผ ํƒ€๊นƒ ๋ฌธ์žฅ์„ ๊ฐ€์ ธ์˜ด
            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            #๋ชจ๋ธ์˜ ์ถœ๋ ฅ์„ ๋””์ฝ”๋”ฉํ•˜์—ฌ ํ…์ŠคํŠธ๋กœ ๋ณ€ํ™˜ํ•จ
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            # #๊ฐ๊ฐ ์†Œ์Šค ๋ฌธ์žฅ, ํƒ€๊นƒ ๋ฌธ์žฅ, ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฆฌ์ŠคํŠธ์— ์ถ”๊ฐ€
            # source_texts.append(source_text)
            # expected.append(target_text)
            # predicted.append(model_out_text)
            
            # Print the source, target and model output
            print_msg('-'*console_width)
            #์šฐ์ธก ์ •๋ ฌ
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            #์ง€์ •๋œ ์˜ˆ์ œ ์ˆ˜๋งŒํผ ์ถœ๋ ฅํ•˜๋ฉด ์ข…๋ฃŒ 
            if count == num_examples:
                print_msg('-'*console_width)
                break
    
    # if writer:
    #     # Evaluate the character error rate
    #     # Compute the char error rate 
    #     metric = torchmetrics.CharErrorRate()
    #     cer = metric(predicted, expected)
    #     writer.add_scalar('validation cer', cer, global_step)
    #     writer.flush()

    #     # Compute the word error rate
    #     metric = torchmetrics.WordErrorRate()
    #     wer = metric(predicted, expected)
    #     writer.add_scalar('validation wer', wer, global_step)
    #     writer.flush()

    #     # Compute the BLEU metric
    #     metric = torchmetrics.BLEUScore()
    #     bleu = metric(predicted, expected)
    #     writer.add_scalar('validation BLEU', bleu, global_step)
    #     writer.flush()


#๋ชจ๋“  ๋ฌธ์žฅ ๊ฐ€์ ธ์˜ค๊ธฐ
def get_all_sentences(ds,lang):
    for item in ds:
        #yield : ๋ฐ˜๋ณต๋  ๋•Œ๋งˆ๋‹ค ํ•˜๋‚˜์˜ ๊ฐ’์„ ์ƒ์„ฑ <-> ๊ธฐ์กด def๋Š” ๋‹จ์ผ ๊ฒฐ๊ณผ๋ฌผ๋งŒ ๋ฐ˜ํ™˜
        yield item['translation'][lang]

#ํ† ํฌ๋‚˜์ด์ € ๊ฐ€์ ธ์˜ค๊ฑฐ๋‚˜ ์ƒˆ๋กœ ์ƒ์„ฑ 
def get_or_build_tokenizer(config, ds, lang): #datasets, langauage
    # config['tokenizer_file] = '../tokenizers/tokenizer_{0}.json
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    #HuggingFace ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
        tokenizer.pre_tokenizer = Whitespace() #๊ณต๋ฐฑ์„ ๊ธฐ์ค€์œผ๋กœ ๋ถ„๋ฆฌ 
        #WordLevelTrainer : ํ•™์Šต ๋ฐ ์ถ”๋ก (๋‹จ์–ด ํ† ํฐ์„ ํ•™์Šตํ•˜๋Š”๋ฐ ์‚ฌ์šฉ๋˜๋Š” ๋ฉ”์„œ๋“œ)
        trainer = WordLevelTrainer(special_tokens=['[UNK]', '[PAD]','[SOS]','[EOS]']) 
        tokenizer.train_from_iterator(get_all_sentences(ds,lang),trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
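Once built, the tokenizer maps text to ids and back. Roughly (a sketch; the actual ids depend on the trained vocabulary):

tokenizer = get_or_build_tokenizer(config, ds_raw, 'en')
ids = tokenizer.encode("Hello world").ids   # e.g. [512, 1847] (corpus-dependent)
text = tokenizer.decode(ids)                # "Hello world"
pad_id = tokenizer.token_to_id('[PAD]')     # typically 1, following the special_tokens order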

#train, val data๋กœ ๋‚˜๋ˆ„๋Š” ํ•จ์ˆ˜
def get_ds(config):
    ds_raw = load_dataset('opus_books',f'{config['lang_src']}-{config['lang_tgt']}',split='train')

    #Build Tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config=['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config=['lang_src'])

    #Keep 90% for training and 10% for validation
    train_ds_size = int(0.9+len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw,[train_ds_size, val_ds_size])

    #Dataset class ์ ์šฉ
    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Find the maximum token length of the source and target sentences
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')
    
    #Data ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

#๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
    return model

#Train ํ•จ์ˆ˜
def train_model(config):
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print("Using device:", device)
    if (device == 'cuda'):
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
    elif (device == 'mps'):
        print(f"Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
    device = torch.device(device)

    # Make sure the weights folder exists
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    
    #๋ชจ๋ธ ์ ์šฉ
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])
    #Adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    #cross Entropy ์‚ฌ์šฉ
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:
            #model.train()
            encoder_input = batch['encoder_input'].to(device) # (B, seq_len)
            decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
            proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device) # (B, seq_len)

            # Compute the loss using a simple cross entropy
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss (when using TensorBoard)
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            #run_validation

            global_step += 1

        # Run validation at the end of every epoch
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    config = get_config()  # apply the hyperparameter values defined in config.py
    train_model(config)

 

 

์‹คํ—˜ ๊ฒฐ๊ณผ๋Š” ์ถ”ํ›„ ์˜ฌ๋ฆด ์˜ˆ์ •.
