728x90
๋ฐ์ํ
- seq2seq ⇒ 2๊ฐ์ RNN์ ์ฐ๊ฒฐํด ํ๋์ ์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฅธ ์๊ณ์ด ๋ฐ์ดํฐ๋ก ๋ณํ.
- ์ดํ ์ ์ ์ญํ ⇒ seq2seq๋ฅผ ๋ ๊ฐ๋ ฅํ๊ฒ ํด์ค.
8.1 ์ดํ ์ ์ ๊ตฌ์กฐ
8.1.1 seq2seq์ ๋ฌธ์ ์
8.1.2 Encoder ๊ฐ์
8.1.3 Decoder ๊ฐ์ 1(์ ํ ์์ ๊ณ์ธต)
- ์ธ๊ฐ์ด ๋ฒ์ญํ ๋๋ ‘๋=I’, ๊ณ ์์ด=cat์ด๋ผ๋ ๋์ ๊ด๊ณ๋ก ์ง์์ ํ์ฉ⇒ ์ผ๋ผ์ด๋จผํธ๋ผ๊ณ ์นญํจ
- seq2seq2์๊ฒ๋ ์ ์ฉ์ํด

- ๊ธฐ์กด decoder์ ๋ง์ง๋ง์ ์๋, ์ต์ข ์๋ ์ํ๋ง์ ๊ฐ์ ธ๊ฐ๋ ํ์.

- ์ด๋ค ๊ณ์ฐ ์ธต์ด ์๋ก ์๊น

- ์ ํ์์
๊ณ์ธต
- 2๊ฐ์ง๋ฅผ ์ ๋ ฅ๋ฐ์ 1) (encoder์์ ๋์จ hs), 2) ๊ฐ์ค์น ๊ณ์ฐ์์ ๋์จ a
- ์๋ค๋ค์ ๊ณ์ฐํด์ c๋ก ์ ๋ฌ
- ๊ฐ์คํฉ์ ๊ตฌํด์ ๋งฅ๋ฝ๋ฒกํฐ๋ก ๋ง๋ ๋ค ex) ‘๋’ ๋ผ๋ ๋จ์ด์ ๊ฐ์ค์น๊ฐ 0.8๋ก ๊ฐ์ฅ ๋์ผ๋ฏ๋ก c๋ผ๋ ๋งฅ๋ฝ ๋ฒกํฐ์๋ ‘๋’ ๋ผ๋ ๋จ์ด์ ๋ํ ์ ๋ณด๊ฐ ๋ง์ด ํฌํจ ๋์ด ์์ ๊ฒ์.
8.1.4 Decoder ๊ฐ์ 2 (๊ฐ์ค์น(a) ๊ณ์ฐ ๊ณ์ธต)

- ๊ฐ์ค์น(a) ๊ณ์ฐ ๊ณ์ธต
- encoder์ ์๋ ์ํ ๋ชจ์์ธ hs + LSTM ์ถ๋ ฅ๊ฐ h 2๊ฐ ๊ฐ์ ์ ๋ ฅ์ผ๋ก ๋ฐ์.
- ๋๊ฐ์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ก a๋ผ๋ ๊ฐ์ค์น๋ฅผ ๋ง๋ค์ด์ผ ํจ.
- ์ด ๋ ์ฐ๋ฆฌ๊ฐ ํด์ผํ ์ผ์ hlstm์ด hs์์์ ๊ฐ ์๋ ์ํ์ ์ผ๋ง๋ ๊ด๋ จ์ด ์๋์ง๋ฅผ ํ๋์ ์์น๋ก ํํํ๋ ๊ฐ์ค์น๋ก ๋ง๋ค์ด์ผ ํจ.⇒ ๋ฒกํฐ์ ๋ด์ ์ ์ด์ฉํด ๊ตฌํจ.
- ⇒ ์ฝ์ฌ์ธ ์ ์ฌ๋: ๋ ๋ฒกํฐ์ ํฌ๊ธฐ ์ ๊ฒฝ x, ๋ ๋ฒกํฐ๊ฐ ๊ฐ๋ฆฌํค๋ ๋ฐฉํฅ ๊ฐ์ ๊ฐ๋๋ฅผ ๊ตฌํจ ⇒ ์ ์ฌ๋ ๊ณ์ฐ.
- ๊ฐ์ค์น ๊ณ์ฐ ๊ณ์ธต์ ๊ณ์ฐ ๊ทธ๋ํ
8.2 ์ดํ ์ ์ ๊ฐ์ถ seq2seq ๊ตฌํ
8.3 ์ดํ ์ ํ๊ฐ
8.4 ์ดํ ์ ์ ๊ดํ ๋จ์ ์ด์ผ๊ธฐ
8.4.1 ์๋ฐฉํฅ RNN


class TimeBiLSTM: def __init__(self, Wx1, Wh1, b1, Wx2, Wh2, b2, stateful=False): self.forward_lstm = TimeLSTM(Wx1, Wh1, b1, stateful) self.backward_lstm = TimeLSTM(Wx2, Wh2, b2, stateful) self.params = self.forward_lstm.params + self.backward_lstm.params self.grads = self.forward_lstm.grads + self.backward_lstm.grads def forward(self, xs): o1 = self.forward_lstm.forward(xs) o2 = self.backward_lstm.forward(xs[:, ::-1]) o2 = o2[:, ::-1] ## ์ ๋ค์ ๋ค์ง์? out = np.concatenate((o1, o2), axis=2) return out def backward(self, dhs): H = dhs.shape[2] // 2 do1 = dhs[:, :, :H] do2 = dhs[:, :, H:] dxs1 = self.forward_lstm.backward(do1) do2 = do2[:, ::-1] dxs2 = self.backward_lstm.backward(do2) dxs2 = dxs2[:, ::-1] dxs = dxs1 + dxs2 return dxs
8.4.2 Attention ๊ณ์ธต ์ฌ์ฉ ๋ฐฉ๋ฒ
8.4.3 seq2seq ์ฌ์ธตํ & skip connection

- ๊ณ์ธต์ ๊น๊ฒ ํ ๊ฒฝ์ฐ, ๋ชจ๋ธ์ ์ผ๋ฐํ ์ฑ๋ฅ( ๊ณผ์ ํฉ์ด ๋ฐ์ํ์ง ์๋๋ก ) ๋จ์ด๋จ๋ฆฌ์ง ์๋๋ก ํ๋ ๊ฒ์ด ์ค์ํจ.⇒ ๋ฐฉ๋ฒ์ ์๋ก skip connection ๊ธฐ๋ฒ์ด ์กด์ฌ.
- skip connection์ ์ถ๋ ฅ๊ฐ์ ๊น์ด ๋ฐฉํฅ์ ๋ค์ LSTM ๊ณ์ธต ์ถ๋ ฅ๊ฐ์ ๋ํด์ฃผ๋ ๋ฐฉ์์ผ๋ก ์ํ
- ๋ง์ ์ฐ์ฐ์ ํ๊ฒ ๋๋ฉด ์ญ์ ํ๊ฐ ์งํ๋์ด๋ ๊ธฐ์ธ๊ธฐ ์์ค์ด ์ผ์ด๋์ง ์์.
- ๊น์ด ๋ฐฉํฅ์์์ ๊ธฐ์ธ๊ธฐ ์์ค๊ณผ ํญ๋ฐ ⇒ skip connection
- ์๊ฐ ๋ฐฉํฅ์์์ ๊ธฐ์ธ๊ธฐ ์์ค ⇒ ๊ฒ์ดํธ ์ถ๊ฐํ LSTM, GRU
- ์๊ฐ ๋ฐฉํฅ์์์ ๊ธฐ์ธ๊ธฐ ํญ๋ฐ ⇒ Gradient Clipping( L2 ๊ท์ )
8.5 ์ดํ ์ ์์ฉ
8.5.1 ๊ตฌ๊ธ ์ ๊ฒฝ๋ง ๊ธฐ๊ณ ๋ฒ์ญ(GNMT)
- ๊ท์น ๊ธฐ๋ฐ ๋ฒ์ญ ⇒ ์ฉ๋ก ๊ธฐ๋ฐ ๋ฒ์ญ⇒ ํต๊ณ ๊ธฐ๋ฐ ๋ฒ์ญ

• ์ฐ๋ฆฌ๊ฐ ์์ ๋ฐฐ์ ๋ ์ดํ ์ ์ ๊ฐ์ถ seq2seq์ ๋ง์ฐฌ๊ฐ์ง๋ก Encoder, Decoder, Attention์ผ๋ก ๊ตฌ์ฑ๋์ด์๋ค. ๋ค๋ง, ์ฌ๊ธฐ์ ๋ฒ์ญ ์ ํ๋๋ฅผ ๋์ด๊ธฐ ์ํด LSTM ๊ณ์ธต์ ๋ค์ธตํ, ์๋ฐฉํฅ LSTM, skip ์ฐ๊ฒฐ ๋ฑ์ ์ถ๊ฐํ๋ค. ๊ทธ๋ฆฌ๊ณ ํ์ต ์๊ฐ์ ๋จ์ถํ๊ธฐ ์ํด GPU๋ก ๋ถ์ฐํ์ต์ ์ํํ๊ณ ์๋ค. ์ด์ธ์๋ ๋ฎ์ ๋น๋์ ๋จ์ด์ฒ๋ฆฌ๋ ์ถ๋ก ๊ณ ์ํ๋ฅผ ์ํ ์์ํ ๋ฑ์ ์ฐ๊ตฌ๋ ์ด๋ฃจ์ด์ง๊ณ ์๋ค. ์ด๋ก์จ ์ ์ ์ฌ๋์ ์ ํ๋์ ๊ฐ๊น์์ง๊ณ ์๋ค.
8.5.2 ํธ๋์คํฌ๋จธ
- RNN ๋์ ํฉ์ฑ๊ณฑ ๊ณ์ธต์ ํ์ฉํด seq2seq ๊ตฌํ1. RNN: ๋ณ๋ ฌ์ฒ๋ฆฌ ๋ถ๊ฐ.
- RNN: ๋ฌผ๋ฆฌ์ ๊ฑฐ๋ฆฌ๊ฐ ๋ฉ์ด์ง๋ฉด ๋์๊ด๊ณ ์ ํ์ต ๋ชปํจ.
- ⇒ ์ด์ ์๊ฐ์ ๋ํ ๊ฒฐ๊ณผ๋ฅผ ์ด์ฉํด ์์๋๋ก ๊ณ์ฐํ๊ธฐ ๋๋ฌธ์ ๋ณ๋ ฌ์ฒ๋ฆฌ๊ฐ ๋ถ๊ฐ๋ฅ.
- ์ ํ ์ดํ ์


- ์ ํ ์ดํ ์ ์ ๊ฒฝ์ฐ, ์ ๋ ฅ ์ํ์ค ๋ด์์์ ๋์๊ด๊ณ๋ฅผ ํ์ตํ๊ณ , ๋์์ ์ถ๋ ฅ ์ํ์ค ๋ด์์์ ๋์๊ด๊ณ๋ ํ์ตํจ.

- ๊ณ์ฐ๋ ์ค๊ณ , GPU ํ์ฉํ ๋ณ๋ ฌ ๊ณ์ฐ์ ํํ๋ ์ป์.
728x90
๋ฐ์ํ