지난 시간에 이어, 오늘은 나머지 train.py, config.py, dataset.py 파일을 구현한다.
https://www.youtube.com/watch?v=ISNdQcPhsts
이 분의 코드를 바탕으로 구현하였습니다.
1. Dataset.py 구현
1-1. Bilingual Dataset
์ฌ์ฉํ ๋ฐ์ดํฐ์ ์ Hugging Face์์ ์ ๊ณตํ๋ opus_books Dataset์ ํ์ฉํ์๋ค.
https://huggingface.co/datasets/opus_books/viewer/en-it
opus_books · Datasets at Hugging Face
{ "en": "Nor could I pass unnoticed the suggestion of the bleak shores of Lapland, Siberia, Spitzbergen, Nova Zembla, Iceland, Greenland, with \"the vast sweep of the Arctic Zone, and those forlorn regions of dreary space,--that reservoir of frost and snow
huggingface.co
import torch
import torch.nn as nn
from torch.utils.data import Dataset
class BilingualDataset(Dataset):
    """Translation pairs turned into fixed-length encoder/decoder tensors.

    Every item is tokenized and padded to exactly ``seq_len`` tokens:

    * ``encoder_input``: ``[SOS] src_tokens [EOS] [PAD] ...``
    * ``decoder_input``: ``[SOS] tgt_tokens [PAD] ...``
    * ``label``:         ``tgt_tokens [EOS] [PAD] ...`` (decoder_input shifted left by one)

    Args:
        ds: Indexable dataset whose items look like
            ``{'translation': {src_lang: str, tgt_lang: str}}``.
        tokenizer_src: Source-language tokenizer; must provide
            ``encode(text).ids`` and ``token_to_id(token)``.
        tokenizer_tgt: Target-language tokenizer with the same interface.
        src_lang: Key of the source sentence inside ``item['translation']``.
        tgt_lang: Key of the target sentence inside ``item['translation']``.
        seq_len: Length every returned sequence is padded to.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        # Needed by __getitem__ to compute the padding amounts.
        self.seq_len = seq_len

        # Special-token ids as 1-element int64 tensors so they can be torch.cat-ed.
        # Lower-case torch.tensor is required: the legacy torch.Tensor constructor
        # does not accept a ``dtype`` keyword.
        # Encoder-side ids come from the source tokenizer ...
        self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)
        # ... and decoder-side ids from the target tokenizer. Its vocabulary is
        # separate, so its ids only match the source ones if both were built alike.
        self.sos_token_tgt = torch.tensor([tokenizer_tgt.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token_tgt = torch.tensor([tokenizer_tgt.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token_tgt = torch.tensor([tokenizer_tgt.token_to_id('[PAD]')], dtype=torch.int64)

    def __len__(self):
        """Return the number of translation pairs."""
        return len(self.ds)

    def __getitem__(self, index: int) -> dict:
        """Build padded encoder/decoder inputs, label and masks for one pair.

        Returns:
            dict with ``encoder_input``, ``decoder_input`` and ``label``, each of
            shape ``(seq_len,)``. It also has ``encoder_mask`` of shape
            ``(1, 1, seq_len)``, ``decoder_mask`` of shape
            ``(1, seq_len, seq_len)``, and the raw ``src_text`` / ``tgt_text``.

        Raises:
            ValueError: if either sentence does not fit in ``seq_len`` tokens
                together with its special tokens.
        """
        src_target_pair = self.ds[index]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Tokenize the text and keep only the token ids.
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Number of [PAD] tokens needed to reach seq_len.
        # Encoder: room is reserved for [SOS] and [EOS] (-2).
        # Decoder input / label: room for a single special token, [SOS] or [EOS] (-1).
        # Example, seq_len=10, 3 source tokens, 2 target tokens:
        #   encoder_input = [SOS, s1, s2, s3, EOS, PAD x5]   (10 - 3 - 2 = 5)
        #   decoder_input = [SOS, t1, t2, PAD x7]            (10 - 2 - 1 = 7)
        #   label         = [t1, t2, EOS, PAD x7]
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # A negative count means the sentence is longer than seq_len allows.
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # 1-2. Token concatenation
        # Encoder input: [SOS] + tokens + [EOS] + padding.
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                self.pad_token.repeat(enc_num_padding_tokens),
            ],
            dim=0,
        )

        # Decoder input: [SOS] + tokens + padding (no [EOS]).
        decoder_input = torch.cat(
            [
                self.sos_token_tgt,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.pad_token_tgt.repeat(dec_num_padding_tokens),
            ],
            dim=0,
        )

        # Label (expected decoder output): tokens + [EOS] + padding (no [SOS]).
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token_tgt,
                self.pad_token_tgt.repeat(dec_num_padding_tokens),
            ],
            dim=0,
        )

        # Internal sanity check: every sequence must be exactly seq_len long.
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            # 1 where the token is not [PAD]. The two leading singleton dims are
            # presumably there to broadcast against attention scores in the
            # model (not visible here). Shape (1, 1, seq_len).
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            # Padding mask (1, seq_len) AND causal mask (1, seq_len, seq_len):
            # position i may attend only to non-pad positions j <= i.
            "decoder_mask": (decoder_input != self.pad_token_tgt).unsqueeze(0).int()
            & causal_mask(decoder_input.size(0)),  # (1, seq_len, seq_len)
            "label": label,  # (seq_len) expected decoder output
            "src_text": src_text,  # raw source sentence
            "tgt_text": tgt_text,  # raw target sentence
        }
1-3. Mask 구현
def causal_mask(size):
    """Return a ``(1, size, size)`` boolean look-ahead mask.

    Entry ``[0, i, j]`` is True exactly when ``j <= i``. Query position ``i``
    may see itself and earlier positions, never later ones.
    See https://incredible.ai/nlp/2020/02/29/Transformer/#241-padding-mask-%ED%95%B5%EC%8B%AC-%EB%82%B4%EC%9A%A9
    """
    # tril keeps the main diagonal and everything below it; entries above become 0.
    lower_triangle = torch.tril(torch.ones(1, size, size, dtype=torch.int))
    return lower_triangle == 1
2. Config.py 구현
from pathlib import Path
def get_config():
    """Return a fresh dict of default hyper-parameters and file locations.

    ``tokenizer_file`` is a ``str.format`` template whose ``{0}`` placeholder is
    filled with a language code, e.g. ``"en"`` -> ``tokenizer_en.json``.
    """
    settings = dict(
        batch_size=8,
        num_epochs=20,
        lr=10**-4,
        seq_len=350,
        d_model=512,
        datasource='opus_books',
        lang_src="en",
        lang_tgt="it",
        model_folder="weights",
        model_basename="tmodel_",
        preload="latest",
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
    )
    return settings
# Path of the checkpoint file for a given epoch tag.
def get_weights_file_path(config, epoch: str):
    """Return ``<datasource>_<model_folder>/<model_basename><epoch>.pt`` as a string."""
    weights_dir = Path('.').joinpath(f"{config['datasource']}_{config['model_folder']}")
    checkpoint = weights_dir / f"{config['model_basename']}{epoch}.pt"
    return str(checkpoint)
# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    """Return the path of the highest-epoch checkpoint, or None if there is none.

    Searches ``<datasource>_<model_folder>`` for files whose names start with
    ``model_basename``. The part of the file stem after the basename is treated
    as the epoch tag (``tmodel_07.pt`` -> ``07``).

    Files with a numeric tag are ranked by its integer value, so epoch 100
    beats epoch 99; a plain string sort would not. Files with a non-numeric tag
    rank below every numeric one and are ordered by name among themselves.
    """
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    basename = config['model_basename']
    weights_files = list(Path(model_folder).glob(f"{basename}*"))
    if not weights_files:
        return None

    def _rank(path):
        # Rank by (has numeric tag, tag value, name).
        tag = path.stem[len(basename):]
        if tag.isdecimal():
            return (1, int(tag), path.name)
        return (0, 0, path.name)

    return str(max(weights_files, key=_rank))
3. Train.py (+Validation)
import os
import warnings
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from tqdm import tqdm

from dataset import BilingualDataset, causal_mask
from model import build_transformer
from config import get_config, get_weights_file_path, latest_weights_file_path
# Greedy decoding: at every step pick the single most likely next token and
# feed it back in as input for the following step.
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Greedily generate a target sequence for one source sequence.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``project``.
        source: Encoder input token ids. Presumably shape (1, seq_len), since
            only one sequence is generated; TODO confirm.
        source_mask: Encoder padding mask matching ``source``.
        tokenizer_src: Unused; kept for call compatibility.
        tokenizer_tgt: Supplies the ``[SOS]`` / ``[EOS]`` ids.
        max_len: Generation stops once the sequence holds this many tokens
            (the leading ``[SOS]`` included).
        device: Device on which new tensors are placed.

    Returns:
        1-D tensor of generated token ids. It starts with ``[SOS]`` and ends
        with ``[EOS]`` if one was produced before ``max_len``.
    """
    start_id = tokenizer_tgt.token_to_id('[SOS]')
    stop_id = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder output does not change between steps, so compute it once.
    memory = model.encode(source, source_mask)

    # Sequence generated so far: a (1, 1) tensor holding [SOS], same dtype as source.
    generated = torch.empty(1, 1).fill_(start_id).type_as(source).to(device)

    while generated.size(1) != max_len:
        # Causal mask so each position only attends to itself and earlier tokens.
        step_mask = causal_mask(generated.size(1)).type_as(source_mask).to(device)
        hidden = model.decode(memory, source_mask, generated, step_mask)
        # Scores for the token following the last generated position.
        scores = model.project(hidden[:, -1])
        _, next_token = torch.max(scores, dim=1)
        # Append the chosen token as a new (1, 1) column.
        next_column = torch.empty(1, 1).type_as(source).fill_(next_token.item()).to(device)
        generated = torch.cat([generated, next_column], dim=1)
        if next_token == stop_id:
            break

    return generated.squeeze(0)
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode a few validation examples and print source/target/prediction.

    Args:
        model: Model passed to ``greedy_decode``; switched to eval mode here.
        validation_ds: Iterable of batches with ``encoder_input``,
            ``encoder_mask``, ``src_text`` and ``tgt_text``. The batch size
            must be 1.
        tokenizer_src: Forwarded to ``greedy_decode``.
        tokenizer_tgt: Used to decode predicted ids back into text.
        max_len: Maximum generated length passed to ``greedy_decode``.
        device: Device the batch tensors are moved to.
        print_msg: Callable used to emit each output line. The training loop
            passes a tqdm ``write`` so the progress bar is not broken.
        global_step: Currently unused; reserved for metric logging.
        writer: Currently unused; reserved for metric logging.
        num_examples: Stop after printing this many examples.
    """
    import shutil

    model.eval()
    count = 0

    # Width of the separator line. get_terminal_size is portable and falls back
    # to 80 columns when there is no terminal, so no blanket except is needed.
    console_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            # greedy_decode generates one sentence at a time.
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            # Convert the predicted token ids back into text.
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            print_msg('-' * console_width)
            # Labels right-aligned in a 12-character column.
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

    # TODO: collect source/target/predicted texts and log character error rate,
    # word error rate and BLEU (e.g. via torchmetrics) to `writer` at `global_step`.
# Stream every sentence of one language from the raw dataset.
def get_all_sentences(ds, lang):
    """Lazily yield ``item['translation'][lang]`` for each item in ``ds``.

    Being a generator, it lets a tokenizer trainer consume sentences one at a
    time instead of needing them all in a list.
    """
    yield from (pair['translation'][lang] for pair in ds)
# Load the cached tokenizer for a language, or train and cache a new one.
def get_or_build_tokenizer(config, ds, lang):
    """Return a word-level tokenizer for ``lang``, training it on ``ds`` if needed.

    The tokenizer is cached at ``config['tokenizer_file'].format(lang)``, e.g.
    ``tokenizer_en.json`` or ``../tokenizers/tokenizer_en.json``. If that file
    exists it is loaded. Otherwise a new tokenizer is trained on
    ``ds[i]['translation'][lang]`` and saved there.

    Args:
        config: Dict providing ``tokenizer_file``.
        ds: Iterable of items shaped ``{'translation': {lang: str, ...}}``.
        lang: Language key, e.g. ``"en"``.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Word-level vocabulary; words not seen in training map to [UNK].
    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    # Pre-split text before building the vocabulary. The tokenizers docs say
    # Whitespace splits on whitespace and punctuation.
    tokenizer.pre_tokenizer = Whitespace()
    # Listing the special tokens guarantees they are in the vocabulary.
    trainer = WordLevelTrainer(special_tokens=['[UNK]', '[PAD]', '[SOS]', '[EOS]'])
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    # Create the destination directory first (e.g. for a '../tokenizers/...'
    # template); saving into a missing directory fails.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer
# Build tokenizers and the train/validation DataLoaders.
def get_ds(config):
    """Load the parallel corpus, build tokenizers and split it 90/10.

    Args:
        config: Dict providing ``datasource``, ``lang_src``, ``lang_tgt``,
            ``seq_len``, ``batch_size`` and ``tokenizer_file``.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
        The validation loader uses batch size 1.
    """
    lang_src = config['lang_src']
    lang_tgt = config['lang_tgt']

    # The dataset config name is "<src>-<tgt>", e.g. "en-it". Double quotes
    # are used outside: reusing the same quote inside an f-string replacement
    # field is a SyntaxError before Python 3.12.
    ds_raw = load_dataset(config['datasource'], f"{lang_src}-{lang_tgt}", split='train')

    # One tokenizer per language.
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, lang_src)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, lang_tgt)

    # Keep 90% for training and 10% for validation.
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])

    # Report the longest tokenized sentence on each side, which helps choose
    # config['seq_len']. Special tokens are not counted here.
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        pair = item['translation']
        max_len_src = max(max_len_src, len(tokenizer_src.encode(pair[lang_src]).ids))
        max_len_tgt = max(max_len_tgt, len(tokenizer_tgt.encode(pair[lang_tgt]).ids))
    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    # Validation decodes one sentence at a time, hence batch_size=1.
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
# Construct the Transformer from the config.
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Return ``build_transformer(...)`` sized by the vocabularies and config.

    ``config['seq_len']`` is used for both the source and target sequence
    lengths, and ``config['d_model']`` sets the model dimension.
    """
    seq_len = config["seq_len"]
    transformer = build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        seq_len,
        seq_len,
        d_model=config['d_model'],
    )
    return transformer
# Training entry point.
def train_model(config):
    """Train the Transformer described by ``config``.

    Steps:
      * Pick a device and build the data, tokenizers and model.
      * Optionally resume from a checkpoint. ``config['preload']`` may be
        ``'latest'`` (newest file found), an epoch tag, or falsy.
      * Each epoch, train over all batches, then print a few greedy-decoded
        validation examples and save a checkpoint with model and optimizer
        state.

    Args:
        config: Dict as returned by ``get_config()``.
    """
    # Pick the best available backend. torch.backends.mps.is_available() is
    # the supported check; torch.has_mps is deprecated.
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    print("Using device:", device_type)
    device = torch.device(device_type)
    if device.type == 'cuda':
        # These accept a torch.device; with no explicit index they describe the
        # current CUDA device.
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
    elif device.type == 'mps':
        print(f"Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")

    # Make sure the weights folder exists before any checkpoint is saved.
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

    # TensorBoard logging.
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # Optionally resume from a checkpoint.
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None
    if model_filename:
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be
        # resumed on another (e.g. CPU/MPS).
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # Labels are target-side ids, so [PAD] positions are ignored using the
    # *target* tokenizer's pad id.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
    vocab_tgt_size = tokenizer_tgt.get_vocab_size()

    for epoch in range(initial_epoch, config['num_epochs']):
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)  # (B, seq_len)
            decoder_input = batch['decoder_input'].to(device)  # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device)  # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and projection layer.
            encoder_output = model.encode(encoder_input, encoder_mask)  # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (B, seq_len, d_model)
            proj_output = model.project(decoder_output)  # (B, seq_len, vocab_size)

            # Compare the output with the label.
            label = batch['label'].to(device)  # (B, seq_len)
            loss = loss_fn(proj_output.view(-1, vocab_tgt_size), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss to TensorBoard.
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate and update the weights.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

        # Print a few validation translations at the end of every epoch; route
        # output through tqdm so the progress bar is not broken.
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Save a checkpoint at the end of every epoch.
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)

    writer.close()
if __name__ == '__main__':
    # Suppress all Python warnings for the whole training run.
    warnings.filterwarnings("ignore")
    # Default hyper-parameters; edit get_config() to change them.
    config = get_config()
    train_model(config)
실행 결과는 추후 올릴 예정.

'Deep Learning > [코드 구현] DL Architecture 구현' 카테고리의 다른 글
[Transformer] 아키텍처 구현하기 - 1 (Pytorch) (1) | 2024.02.17 |
---|---|
[UNet] copy and crop 코드 구현 및 아키텍처 구현하기 (Pytorch) (1) | 2024.02.08 |