728x90
๋ฐ์ํ
์ง๋ ์๊ฐ์ ์ด์ด, ์ค๋์ ๋๋จธ์ง train.py, config.py, dataset.py ํ์ผ์ ๊ตฌํํ๋ค.
https://www.youtube.com/watch?v=ISNdQcPhsts
์ด ๋ถ ์ฝ๋๋ฅผ ๋ฐํ์ผ๋ก ๊ตฌํํ์์ต๋๋ค.
1. Dataset.py ๊ตฌํ
1-1. Bilingual Dataset
์ฌ์ฉํ ๋ฐ์ดํฐ์ ์ Hugging Face์์ ์ ๊ณตํ๋ opus_books Dataset์ ํ์ฉํ์๋ค.
https://huggingface.co/datasets/opus_books/viewer/en-it
import torch
import torch.nn as nn
from torch.utils.data import Dataset
class BilingualDataset(Dataset):
#ํด๋์ค๋ฅผ ์์ฑํ ๋ ์คํ๋๋ ์์ฑ์
def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
super().__init__()
self.ds = ds
self.tokenizer_src = tokenizer_src
self.tokenizer_tgt = tokenizer_tgt
self.src_lang = src_lang
self.tgt_lang = tgt_lang
self.sos_token = torch.Tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
self.eos_token = torch.Tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)
self.pad_token = torch.Tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)
#์์์ ๊ฐ์๋ฅผ ์
๋ ์ ๊ทผ๋๋ ๋ฉ์๋
def __len__(self):
return len(self.ds)
#์ธ๋ฑ์ค์ ์ ๊ทผํ ๋ ํธ์ถ๋๋ ๋ฉ์๋
def __getitem__(self, index : Any) -> Any:
src_target_pair = self.ds[index]
src_text = src_target_pair['translation'][self.src_lang]
tgt_text = src_target_pair['translation'][self.tgt_lang]
# ํ ํฌ๋์ด์ ๋ฅผ ์ฌ์ฉํ์ฌ ํ
์คํธ๋ฅผ ํ ํฐํํ๊ณ , ๊ทธ ๊ฒฐ๊ณผ์์ ํ ํฐ IDs๋ฅผ ์ถ์ถ
enc_input_tokens = self.tokenizer_src.encode(src_text).ids
dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
#seq_len(์ํ์ค ๊ธธ์ด)์ ๋ง์ถ๊ธฐ ์ํด ํจ๋ฉ์ ์ถ๊ฐํด์ค
#์๋ณธ ์ธ์ด์ ์
๋ ฅ ์ํ์ค์ ์ถ๊ฐํด์ผ ํ๋ ํจ๋ฉ ํ ํฐ์ ์
# SOS, EOS token (-2)
enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
#๋์ ์ธ์ด์ ์ถ๋ ฅ ์ํ์ค์ ์ถ๊ฐํด์ผ ํ๋ ํจ๋ฉ ํ ํฐ์ ์
# EOS token (-1)
dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
#self.seq_len : 10
#enc_input_tokens : ["์๋
", "ํ์ธ์", "๋ฐ๊ฐ์ต๋๋ค"]
#dec_input_tokens : ["Hello", "World"]
#enc_num_padding_tokens: 10 - len(["์๋
", "ํ์ธ์", "๋ฐ๊ฐ์ต๋๋ค"]) - 2 = 10 - 3 - 2 = 5
#dec_num_padding_tokens: 10 - len(["Hello", "World"]) - 1 = 10 - 2 - 1 = 7
#['sos',"์๋
", "ํ์ธ์", "๋ฐ๊ฐ์ต๋๋ค", 'eos',ํจ๋ฉ, ํจ๋ฉ, ํจ๋ฉ, ํจ๋ฉ, ํจ๋ฉ]
#['Hello', 'World', 'EOS', ํจ๋ฉ, ํจ๋ฉ, ํจ๋ฉ, ํจ๋ฉ, ํจ๋ฉ, ํจ๋ฉ, ํจ๋ฉ]
#0๋ณด๋ค ์์ ๊ฒฝ์ฐ, ์๋ฌ ๋ฐ์ ex. 10 - 9 - 2 = -1
#์ ์ฒด ๋ฌธ์ฅ ๊ธธ์ด๋ณด๋ค input ๋ฌธ์ฅ ๊ธธ์ด๊ฐ ๋ ๊ธด ๊ฒฝ์ฐ,
if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
raise ValueError("Sentence is too long")
1-2. ํ ํฐ Concat
# ํ ํฐ Concat
# Add <s> and </s> token
encoder_input = torch.cat(
[
self.sos_token,
torch.tensor(enc_input_tokens, dtype=torch.int64),
self.eos_token,
torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
],
dim=0,
)
# Add only <sos> token
decoder_input = torch.cat(
[
self.sos_token,
torch.tensor(dec_input_tokens, dtype=torch.int64),
torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
],
dim=0,
)
# Add only <eos> token
label = torch.cat(
[
torch.tensor(dec_input_tokens, dtype=torch.int64),
self.eos_token,
torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
],
dim=0,
)
# Double check the size of the tensors to make sure they are all seq_len long
# Asser : if ๊ฐ๋
# ์ต๋ ์ํ์ค์ input+ํจ๋ฉ ํ ๊ฒฐ๊ณผ๊ฐ ๋์ผํ ๊ธธ์ด์ธ์ง๋ฅผ ๋ฌผ์ด๋ด
assert encoder_input.size(0) == self.seq_len
assert decoder_input.size(0) == self.seq_len
assert label.size(0) == self.seq_len
return {
"encoder_input": encoder_input, # (seq_len)
"decoder_input": decoder_input, # (seq_len)
#์ฐจ์์ ๋ง์ถฐ์ฃผ๊ธฐ ์ํด ์คํ
"encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
# ํ์ฌ ์์น ์ดํ์ ํ ํฐ ๋ง์คํฌ ์ฒ๋ฆฌ (ํจ๋ฉ ํ ํฐ์ด ์๋ ์์น์ ๋ํด True, ํจ๋ฉ ํ ํฐ์ธ ์์น์ ๋ํด False ๋ฐํ)
# -> ํน์ ์์น์ ํ ํฐ์ ๋ํ attention์ ์ํํ ์ง ๋ง์ง๋ฅผ ๊ฒฐ์
"decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
"label": label, # (seq_len) ๋์ฝ๋์ ์ค์ ์ถ๋ ฅ์ ํด๋นํ๋ ๊ฐ
"src_text": src_text, #input text
"tgt_text": tgt_text, #target text
1-3. Mask ๊ตฌํ
def causal_mask(size):
#https://incredible.ai/nlp/2020/02/29/Transformer/#241-padding-mask-%ED%95%B5%EC%8B%AC-%EB%82%B4%EC%9A%A9
#triu : ์ ์ฌ๊ฐํ์ n x n ์ด ์์ ๋, ์๋์ชฝ ์ผ๊ฐ ๋ถ๋ถ์ 0์ผ๋ก, ์์ชฝ ์ผ๊ฐ ๋ถ๋ถ๋ง 1๋ก ๋ฆฌํด
mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
return mask == 0
2. Config.py ๊ตฌํ
from pathlib import Path
def get_config():
return {
"batch_size": 8,
"num_epochs": 20,
"lr": 10**-4,
"seq_len": 350,
"d_model": 512,
"datasource": 'opus_books',
"lang_src": "en",
"lang_tgt": "it",
"model_folder": "weights",
"model_basename": "tmodel_",
"preload": "latest",
"tokenizer_file": "tokenizer_{0}.json",
"experiment_name": "runs/tmodel"
}
#weight ํด๋ ๋ฐ ์ด๋ฆ ๋ถ๋ฌ์ค๊ธฐ
def get_weights_file_path(config, epoch: str):
model_folder = f"{config['datasource']}_{config['model_folder']}"
model_filename = f"{config['model_basename']}{epoch}.pt"
return str(Path('.') / model_folder / model_filename)
# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
model_folder = f"{config['datasource']}_{config['model_folder']}"
model_filename = f"{config['model_basename']}*"
weights_files = list(Path(model_folder).glob(model_filename))
if len(weights_files) == 0:
return None
#๊ฐ์ ธ์จ ํ์ผ๋ค์ ํ์ผ์ ์์ ์๊ฐ์ ๊ธฐ์ค์ผ๋ก ์ ๋ ฌ
#๋งจ ๋ง์ง๋ง ํ์ผ์ด ๊ฐ์ฅ ์ต๊ทผ์ ์์ฑ๋ ๋ชจ๋ธ ๊ฐ์ค์น ํ์ผ์ด ๋จ!
weights_files.sort()
return str(weights_files[-1])
3. Train.py (+Validation)
import torch
import torch.nn. as nn
from torch.utils.data import Dataset, DataLoader, random_split
from dataset import BilingualDataset, causal_mask
from model import build_transformer
from config import get_config, get_weights_file_path, latest_weights_file_path
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers_pre_tokenizers import Whitespace
from torch.utils.tensorboard import SummaryWriter
import warnings
from tqdm import tqdm
from pathlib import Path
#๋์ฝ๋์์ ๊ฐ ์์ ๋ง๋ค ๊ฐ์ฅ ํ๋ฅ ์ด ๋์ ๋จ์ด๋ฅผ ์ ํํด
#๋ค์ ์์ ์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ ๊ฐ๋จํ ๋์ฝ๋ฉ ์ ๋ต ์ค ํ๋
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
sos_idx = tokenizer_tgt.token_to_id('[SOS]')
eos_idx = tokenizer_tgt.token_to_id('[EOS]')
# Precompute the encoder output and reuse it for every step
encoder_output = model.encode(source, source_mask)
# Initialize the decoder input with the sos token
decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
while True:
#๋์ฝ๋ ์
๋ ฅ์ ๊ธธ์ด๊ฐ max_len์ ๋๋ฌํ๋ฉด ๋ฃจํ๋ฅผ ์ข
๋ฃ
if decoder_input.size(1) == max_len:
break
# build mask for target
#ํ์ฌ๊น์ง ์์ฑ๋ ๋์ฝ๋ ์
๋ ฅ์ ๋ํ ์ดํ
์
๋ง์คํฌ๋ฅผ ์์ฑ
#ํ์ฌ ์์น ์ดํ์ ํ ํฐ์ ๋ํ ์ดํ
์
์ ๋ฐฉ์ง
decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
# calculate output
# ๋์ฝ๋๋ฅผ ์ฌ์ฉํ์ฌ ํ์ฌ๊น์ง์ ๋์ฝ๋ ์
๋ ฅ์ ๋ํ ์ถ๋ ฅ์ ๊ณ์ฐ
out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
# get next token
# ๋์ฝ๋ ์ถ๋ ฅ์ ๋ง์ง๋ง ํ ํฐ์ ๋ํ ํ๋ฅ ๋ถํฌ๋ฅผ ๊ณ์ฐ
prob = model.project(out[:, -1])
#๊ฐ์ฅ ๋์ ํ๋ฅ ์ ๋ค์ ๋จ์ด๋ก ์ ์
#_, : ๋ณ์ ๊ฐ์ ๊ตณ์ด ์ฌ์ฉํ ํ์๊ฐ ์์ ๋ ์ฌ์ฉ
_, next_word = torch.max(prob, dim=1)
#์ ํ๋ ๋ค์ ๋จ์ด๋ฅผ ํ์ฌ๊น์ง์ ๋์ฝ๋ ์
๋ ฅ
decoder_input = torch.cat(
[decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
)
#๋ง์ฝ ๋ค์ ๋จ์ด๊ฐ eos ํ ํฐ์ด๋ฉด ์ค์ง
if next_word == eos_idx:
break
return decoder_input.squeeze(0)
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
model.eval()
count = 0
# source_texts = []
# expected = []
# predicted = []
#TMI
try:
# get the console window width
with os.popen('stty size', 'r') as console:
_, console_width = console.read().split()
console_width = int(console_width)
except:
# If we can't get the console width, use 80 as default
console_width = 80
#Val ์งํ
with torch.no_grad():
for batch in validation_ds:
count += 1
encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)
# check that the batch size is 1
assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"
#Greedy Decode ์ ์ฉ
model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
#ํ์ฌ ๋ฐฐ์น์์ ์์ค ๋ฌธ์ฅ๊ณผ ํ๊น ๋ฌธ์ฅ์ ๊ฐ์ ธ์ด
source_text = batch["src_text"][0]
target_text = batch["tgt_text"][0]
#๋ชจ๋ธ์ ์ถ๋ ฅ์ ๋์ฝ๋ฉํ์ฌ ํ
์คํธ๋ก ๋ณํํจ
model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
# #๊ฐ๊ฐ ์์ค ๋ฌธ์ฅ, ํ๊น ๋ฌธ์ฅ, ๋ชจ๋ธ์ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฆฌ์คํธ์ ์ถ๊ฐ
# source_texts.append(source_text)
# expected.append(target_text)
# predicted.append(model_out_text)
# Print the source, target and model output
print_msg('-'*console_width)
#์ฐ์ธก ์ ๋ ฌ
print_msg(f"{f'SOURCE: ':>12}{source_text}")
print_msg(f"{f'TARGET: ':>12}{target_text}")
print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")
#์ง์ ๋ ์์ ์๋งํผ ์ถ๋ ฅํ๋ฉด ์ข
๋ฃ
if count == num_examples:
print_msg('-'*console_width)
break
# if writer:
# # Evaluate the character error rate
# # Compute the char error rate
# metric = torchmetrics.CharErrorRate()
# cer = metric(predicted, expected)
# writer.add_scalar('validation cer', cer, global_step)
# writer.flush()
# # Compute the word error rate
# metric = torchmetrics.WordErrorRate()
# wer = metric(predicted, expected)
# writer.add_scalar('validation wer', wer, global_step)
# writer.flush()
# # Compute the BLEU metric
# metric = torchmetrics.BLEUScore()
# bleu = metric(predicted, expected)
# writer.add_scalar('validation BLEU', bleu, global_step)
# writer.flush()
#๋ชจ๋ ๋ฌธ์ฅ ๊ฐ์ ธ์ค๊ธฐ
def get_all_sentences(ds,lang):
for item in ds:
#yield : ๋ฐ๋ณต๋ ๋๋ง๋ค ํ๋์ ๊ฐ์ ์์ฑ <-> ๊ธฐ์กด def๋ ๋จ์ผ ๊ฒฐ๊ณผ๋ฌผ๋ง ๋ฐํ
yield item['translation'][lang]
#ํ ํฌ๋์ด์ ๊ฐ์ ธ์ค๊ฑฐ๋ ์๋ก ์์ฑ
def get_or_build_tokenizer(config, ds, lang): #datasets, langauage
# config['tokenizer_file] = '../tokenizers/tokenizer_{0}.json
tokenizer_path = Path(config['tokenizer_file'].format(lang))
#HuggingFace ๋ผ์ด๋ธ๋ฌ๋ฆฌ
if not Path.exists(tokenizer_path):
tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
tokenizer.pre_tokenizer = Whitespace() #๊ณต๋ฐฑ์ ๊ธฐ์ค์ผ๋ก ๋ถ๋ฆฌ
#WordLevelTrainer : ํ์ต ๋ฐ ์ถ๋ก (๋จ์ด ํ ํฐ์ ํ์ตํ๋๋ฐ ์ฌ์ฉ๋๋ ๋ฉ์๋)
trainer = WordLevelTrainer(special_tokens=['[UNK]', '[PAD]','[SOS]','[EOS]'])
tokenizer.train_from_iterator(get_all_sentences(ds,lang),trainer=trainer)
tokenizer.save(str(tokenizer_path))
else:
tokenizer = Tokenizer.from_file(str(tokenizer_path))
return tokenizer
#train, val data๋ก ๋๋๋ ํจ์
def get_ds(config):
ds_raw = load_dataset('opus_books',f'{config['lang_src']}-{config['lang_tgt']}',split='train')
#Build Tokenizers
tokenizer_src = get_or_build_tokenizer(config, ds_raw, config=['lang_src'])
tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config=['lang_src'])
#Keep 90% for training and 10% for validation
train_ds_size = int(0.9+len(ds_raw))
val_ds_size = len(ds_raw) - train_ds_size
train_ds_raw, val_ds_raw = random_split(ds_raw,[train_ds_size, val_ds_size])
#Dataset class ์ ์ฉ
train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
# Find the maximum length of each sentence in the source and target sentence
# max seq length ์ฐพ๊ธฐ
max_len_src = 0
max_len_tgt = 0
for item in ds_raw:
src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
max_len_src = max(max_len_src, len(src_ids))
max_len_tgt = max(max_len_tgt, len(tgt_ids))
print(f'Max length of source sentence: {max_len_src}')
print(f'Max length of target sentence: {max_len_tgt}')
#Data ๋ถ๋ฌ์ค๊ธฐ
train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
#๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
def get_model(config, vocab_src_len, vocab_tgt_len):
model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
return model
#Train ํจ์
def train_model(config):
# Define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
if (device == 'cuda'):
print(f"Device name: {torch.cuda.get_device_name(device.index)}")
print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")
elif (device == 'mps'):
print(f"Device name: <mps>")
else:
print("NOTE: If you have a GPU, consider using it for training.")
print(" On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
print(" On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
device = torch.device(device)
# Make sure the weights folder exists
Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
#๋ชจ๋ธ ์ ์ฉ
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
# Tensorboard
writer = SummaryWriter(config['experiment_name'])
#Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
# If the user specified a model to preload before training, load it
initial_epoch = 0
global_step = 0
preload = config['preload']
model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
if model_filename:
print(f'Preloading model {model_filename}')
state = torch.load(model_filename)
model.load_state_dict(state['model_state_dict'])
initial_epoch = state['epoch'] + 1
optimizer.load_state_dict(state['optimizer_state_dict'])
global_step = state['global_step']
else:
print('No model to preload, starting from scratch')
#cross Entropy ์ฌ์ฉ
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
for epoch in range(initial_epoch, config['num_epochs']):
torch.cuda.empty_cache()
model.train()
batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
for batch in batch_iterator:
#model.train()
encoder_input = batch['encoder_input'].to(device) # (B, seq_len)
decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)
# Run the tensors through the encoder, decoder and the projection layer
encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)
# Compare the output with the label
label = batch['label'].to(device) # (B, seq_len)
# Compute the loss using a simple cross entropy
loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
# Log the loss (Tensorboard ์ฌ์ฉํ ๋)
writer.add_scalar('train loss', loss.item(), global_step)
writer.flush()
# Backpropagate the loss
loss.backward()
# Update the weights
optimizer.step()
optimizer.zero_grad(set_to_none=True)
#run_validation
global_step += 1
# Run validation at the end of every epoch
run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)
# Save the model at the end of every epoch
# ์ํญ๋ง๋ค ์ ์ฅ
model_filename = get_weights_file_path(config, f"{epoch:02d}")
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'global_step': global_step
}, model_filename)
if __name__ == '__main__':
warnings.filterwarnings("ignore")
config = get_config() #์ง์ ํ ํ์ดํผํ๋ผ๋ฏธํฐ๊ฐ ์ ์ฉ
train_model(config)
์คํ ๊ฒฐ๊ณผ๋ ์ถํ ์ฌ๋ฆด ์์ .
728x90
๋ฐ์ํ
'Deep Learning > [์ฝ๋ ๊ตฌํ] DL Architecture ๊ตฌํ' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[Transformer] ์ํคํ ์ฒ ๊ตฌํํ๊ธฐ - 1 (Pytorch) (1) | 2024.02.17 |
---|---|
[UNet] copy and crop ์ฝ๋ ๊ตฌํ ๋ฐ ์ํคํ ์ฒ ๊ตฌํํ๊ธฐ (Pytorch) (1) | 2024.02.08 |