BEIT [Kor]
Bao et al. / BEIT - BERT Pre-Training of Image Transformers / ICLR 2022 Oral
1. Problem definition
์ด ๋ ผ๋ฌธ์ self-supervised pre-training์ ํตํด ์ด๋ฏธ์ง์ representation learning์ ์ํํ๋ ์ฐ๊ตฌ๋ฅผ ์งํํ์ต๋๋ค. ๊ธฐ์กด์ ๋น์ ์์ญ์์ ํ ์ด๋ฏธ์ง์ ์๋ก ๋ค๋ฅธ perturbation์ ์ ์ฉํ ๋ค ์ด๋ฅผ ๋ฐํ์ผ๋ก representation learning์ ์งํํ๋ SimCLR๋ BYOL ๋ฑ๊ณผ๋ ๋ค๋ฅด๊ฒ, NLP ์์ญ์์ ํฐ ์ฑ๊ณผ๋ฅผ ๊ฑฐ๋ BERT์ Masked Language Modeling(MLM)์ ์ด๋ฏธ์ง์ ์ ์ฉ์ํจ ๊ฒ์ด ์ด ๋ ผ๋ฌธ์ ์ฃผ๋ contribution์ด๋ผ ํ ์ ์๊ฒ ์ต๋๋ค.
pre-training์ผ๋ก ํ์ต๋ representation์ ์ฑ๋ฅ์ ๊ฒ์ฆํ๊ธฐ ์ํ fine-tuning task (ํน์ downstream task)๋ก๋ ์ด๋ฏธ์ง ๋ถ๋ฅ(image classification)์ semantic segmentation์ ์งํํ์ต๋๋ค.
2. Motivation
Related Works
Self-supervised Representation Learning
๋น์ ์์ญ์์ ์ฃผ๋ก ์ด๋ฃจ์ด์ง representation learning ์ค ๋ํ์ ์ธ ์ฐ๊ตฌ๋ฅผ ๊ผฝ์ผ๋ผ๋ฉด SimCLR (Chen et al.)๋ฅผ ๋นผ๋์ ์ ์์ต๋๋ค. ์ด ์ฐ๊ตฌ๋ contrastive learning์ ํตํด ์ด๋ฏธ์ง์ representation learning์ ์งํํ ์ฐ๊ตฌ์ธ๋ฐ์, contrastive learning์ ๊ธฐ๋ณธ ๊ฐ๋ ๊ณผ ํจ๊ป ๊ฐ๋จํ ์ค๋ช ๋๋ฆฌ๊ฒ ์ต๋๋ค.
Contrastive learning์ ๊ฐ ์ด๋ฏธ์ง๊ฐ ๋ชจ๋ธ์ ํต๊ณผํด์ ๋์จ latent vector (ํน์ representation vector)๊ฐ ์กด์ฌํ๋ latent space ์์์, positive pair๋ค์ latent vector๋ค๋ผ๋ฆฌ๋ ๊ฐ๊น๊ฒ negative pair๋ค์ latent vector๋ค๋ผ๋ฆฌ๋ ๋ฉ๊ฒ ํ์ต์ํค๋ ๋ฐฉ์์ ๋งํฉ๋๋ค. ์ด๋ฅผ ์ํํ๊ธฐ ์ํด ์ฃผ๋ก InfoNCE๋ผ๊ณ ๋ถ๋ฆฌ๋, ์๋์ loss ํจ์๋ฅผ ํตํด ๋ชจ๋ธ์ ์ต์ ํ์ํค๋๋ฐ์, ์ง๊ด์ ์ผ๋ก ์ค๋ช ๋๋ฆฌ์๋ฉด positive pair ํน์ negative pair์์ ๋์จ latent vector pair๋ค๋ผ๋ฆฌ์ similarity๋ฅผ ๊ณ์ฐํ ๋ค ์ด๋ฅผ ๋ถ๋ฅ ๋ฌธ์ ์์์ logit ๊ฐ์ผ๋ก ์ทจ๊ธํ์ฌ cross entropy๋ก ํ์ต์ํจ๋ค๊ณ ๋ณด์๋ฉด ๋ฉ๋๋ค. ์ด ๋, cross entropy ํ ์ Ground Truth label์ positive pair๊ฐ ๋๊ธฐ ๋๋ฌธ์ ํ์ต ๊ณผ์ ์์ positive pair์ similarity๋ ๋์ด๊ณ negative pair์ similarity๋ ๋ฎ์์ง๊ฒ ๋ฉ๋๋ค.
์ฌ๊ธฐ์ (i, j) pair๋ positive pair์ ์ด๋ฏธ์ง ์ธ๋ฑ์ค๋ฅผ ์๋ฏธํ๊ณ , ๋ ๋ฒ์งธ ์ด๋ฏธ์ง๊ฐ ๋ชจ๋ธ์ ํต๊ณผํ์ฌ ๋์จ latent vector์ ๋๋ค. ๋ํ ๋ ๋ฒกํฐ ์ฌ์ด์ similarity๋ฅผ ๋ํ๋ด๋ ํจ์๋ ๋ด์ ํน์ cosine similarity๋ฅผ ์ฌ์ฉํ๊ณค ํฉ๋๋ค. ์ด๋ฌํ contrastive learning์ ๊ณ ์ง์ ์ธ ๋ฌธ์ ์ค ํ๋๋ ๋ชจ๋ธ์ด ํ์ตํ๋ ๊ณผ์ ์์ collapse๊ฐ ์ผ์ด๋๋ค๋ ์ ์ ๋๋ค. ๋ฌด์จ ๋ง์ด๋๋ฉด, ์ฐ๋ฆฌ๊ฐ ๊ธฐ๋ํ๊ธฐ๋ก๋ ๋ชจ๋ธ์ด positive pair์ธ ์ด๋ฏธ์ง๋ค๋ผ๋ฆฌ๋ ๋ฉ๊ฒ ํ๊ณ negative pair์ธ ์ด๋ฏธ์ง๋ค๋ผ๋ฆฌ๋ ๋ฉ๊ฒ ํด์ latent space ์์ ๋ค์ํ ์ด๋ฏธ์ง๋ค์ representatation vector๋ฅผ ์ฌ๊ธฐ์ ๊ธฐ ํฉ๋ฟ๋ ค์ค ์ค ์์๋๋ฐ, ์ค์ ๋ก ํด๋ณด๋๊น ๊ทธ๋ ๊ฒ ๋์ง ์๊ณ ์ด๋ฏธ์ง๋ค์ representation vector๋ค์ด latent space์ ์์ฃผ ์์ ๋ถ๋ถ ์์์๋ง ๋๊ณ ์๋๋ผ๋ ๊ฒ์ ๋๋ค. ๋ชจ๋ธ์ด ์ด๋ ๊ฒ ํ์ต๋๋ ํต์ฌ์ ์ธ ์์ธ์ ์์์ ์ ์๋ loss ํจ์์ ๊ด๋ จ์ด ์๋๋ฐ์, ์์ธํ ๋ณด์๋ฉด positive pair ๊ฐ์ similarity๊ฐ ๋งค์ฐ ๋๊ธฐ๋ง ํ๋ฉด loss ๊ฐ์ด ๋จ์ด์ง ๊ฒ์ด๋ผ๋ ์๊ฐ์ ํ์ค ์ ์์ ๊ฒ๋๋ค. ๋ฌผ๋ก negative pair ๊ฐ์ similarity์ ๋นํด์ positive pair ๊ฐ์ similarity๊ฐ ๋์์ผ ํ๊ฒ ์ง๋ง (์ด๋ ๊ฒ ํ์ต๋๊ธธ ๊ธฐ๋ํ ๊ฒ์ด๊ธฐ๋ ํ๊ณ ์), ๋ชจ๋ธ์ด ํด๋น loss ํจ์๋ฅผ ์ต์ ํํ๋ ๊ณผ์ ์์ ๊ทธ๋ฅ ๋ชจ๋ ์ด๋ฏธ์ง๋ค์ ๋น์ทํ representation vector๋ก ๋ง๋ค์ด๋ฒ๋ฆฌ๋ ๊ฒ์ด negative pair๋ค ๊ฐ์ similarity๋ ํจ๊ป ๊ณ ๋ คํ๋ ๊ฒ๋ณด๋ค loss๋ฅผ ๋จ์ด๋จ๋ฆฌ๊ธฐ ์์ํ๊ธฐ ๋๋ฌธ์ ์ด๋ฐ ์ผ์ด ๋ฐ์ํ์ ๊ฐ๋ฅ์ฑ์ด ์์ต๋๋ค. ๊ฒฐ๊ตญ ๋ชจ๋ธ์ด pre-training ๋จ๊ณ์์ ์ด๋ ๊ฒ ๋ชจ๋ ์ด๋ฏธ์ง๋ค์ ๋น์ทํ representation vector๋ก ๋ง๋ค์ด๋ฒ๋ฆฌ๋ฉด ๋น์ฐํ fine-tuning ๋จ๊ณ์์ ์ด๋ฏธ์ง ๋ถ๋ฅ๋ semantic segmentation ๋ฑ์ downstream task๋ฅผ ์ํํ๋ ๊ฒ์ ์คํ๋ ค ๋ ์ด๋ ต๊ฒ ๋ง๋ญ๋๋ค. ์ด๋ฌํ ๊ณ ์ง์ ์ธ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํจ๊ณผ ๋์์ contrastive learning์ ํตํด ์๋ฏธ์๋ ์ฑ๋ฅ ํฅ์์ ๋ณด์ฌ์ค ๋ํ์ ์ธ ๋น์ ํ์ดํผ๊ฐ ๋ฐ๋ก SimCLR (Chen et al.)์ ๋๋ค. ์ด ๋ ผ๋ฌธ์ ์ฃผ์ contribution๊ณผ ๊ทธ ๋ฐฉ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ด ์ ๋ฆฌํด๋ณผ ์ ์๊ฒ ์ต๋๋ค.
๋ชจ๋ธ์ ํ์ต์ํฌ ๋ ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ํฌ๊ฒ ๋๋ฆฌ๋ ๊ฒ๊ณผ ๋์์ ๋ฐฐ์น ๋ด์์ ์์ ์ positive sample์ ์ ์ธํ ๋ชจ๋ ๋ค๋ฅธ ์ด๋ฏธ์ง๋ฅผ negative sample๋ก ์ทจ๊ธํจ์ผ๋ก์จ ์์ฒญ๋๊ฒ ๋ง์ negative pair๋ฅผ ํตํด ์์์ ๋ง์๋๋ฆฐ model collapse๋ฅผ ๋ฐฉ์งํ์ต๋๋ค.
Contrastive learning์ ํต์ฌ์ด๋ผ๊ณ ํ ์ ์๋ positive pair๋ฅผ ํ ์ด๋ฏธ์ง์ ์๋ก ๋ค๋ฅธ transformation (ํน์ perturbation)์ ์ ์ฉํ์ฌ ๋์จ ๋ ์ด๋ฏธ์ง๋ก ์ ์ํ๋ ๊ฒ์ ์ ์ํ์ผ๋ฉฐ, ์ด๋ฅผ ํตํด downstream task์ ๋ํด ์๋ฏธ ์๋ ์ฑ๋ฅ ํฅ์์ ๋ณด์ฌ์ฃผ์์ต๋๋ค.
์์ ๊ฐ์ด positive pair๋ฅผ ๊ตฌ์ฑํ ๋ ํจ๊ณผ์ ์ธ transformation์ ์กฐํฉ์ ์ ์ํ์ต๋๋ค.
ํํธ, representation learning์ ์์ฒ๋ผ contrastive learning์ ํตํด์๋ง ํ ์ ์๋ ๊ฒ์ ์๋๋๋ค. ์คํ๋ ค contrastive task๋ pre-training ๋จ๊ณ์์ ํ ์ ์๋ pretext task์ ์ผ๋ถ์ผ ๋ฟ์ธ๋ฐ์, ์ด ๋ ผ๋ฌธ์์ ์ด๋ฏธ์ง์ representation learning์ ์ํด ์ฌ์ฉํ๋ ๋ฐฉ์ ๋ํ contrastive learning์ด ์๋ Masked Image Modeling(MIM)์ผ๋ก, BERT (Devlin et al.)์์ NLP๋ฅผ ์ํด ์ฌ์ฉ๋ MLM์ ์ด๋ฏธ์ง ๋๋ฉ์ธ์ ๋ง๊ฒ ๋ณํ์ํจ ํํ์ ๋๋ค. ์ด ๋ ผ๋ฌธ์ ํต์ฌ์ด๋ผ๊ณ ํด๋นํ ์ ์๋ ์์ด๋์ด๊ฐ BERT์์ motivated๋ ๋งํผ, BERT์ ๋ํด์๋ ๊ฐ๋จํ ์ค๋ช ๋๋ฆฌ๊ฒ ์ต๋๋ค. BERT๋ ์์ฐ์ด ๋๋ฉ์ธ์์ pre-training ๋ฐฉ๋ฒ๋ก ์ ์ ์ํ๊ณ , ์ด๋ฅผ Transformer Encoder์ ์ ์ฉํ์ฌ NLP ์์ญ์์ ๋๋ถ๋ถ์ downstream task์ ๋ํด ์ผ๊ด์ ์ผ๋ก ์ฑ๋ฅ ํฅ์์ ๋ณด์ฌ์ค ๋ ผ๋ฌธ์ ๋๋ค. BERT์์ ์ ์ํ pre-training ๋ฐฉ๋ฒ์๋ ์ด ๋ ๊ฐ์ง, MLM๊ณผ Next Sentence Prediction(NSP)์ด ์์ต๋๋ค.
Masked Language Modeling MLM์ input token๋ค๋ก ์ด๋ฃจ์ด์ง sequence์์ ์ผ๋ถ randomํ token์ ์ ํํ์ฌ [MASK]๋ผ๋ token์ผ๋ก ๋์ฒดํ ์ํ๋ก ๋ชจ๋ธ์ ํต๊ณผ์ํจ ๋ค, ํด๋น token์ output์ linear layer๋ฅผ ๋ฌ์ mask๋๊ธฐ ์ ์ ์๋ token์ด ๋ฌด์์ด์๋์ง๋ฅผ ๋ง์ถ๋ task์ ๋๋ค.
Next Sentence Prediction BERT๋ ์์์ MLM์ ํ๋ ๊ฒ๊ณผ ๋์์ NSP๋ฅผ ํตํด์๋ ๋ชจ๋ธ์ ํ์ต์ํค๋๋ฐ์, NSP๋ input sequence๊ฐ ๋ ๊ฐ์ sentence๋ค๋ก ์ด๋ฃจ์ด์ก์ ๋ ๊ทธ ๋ sentence๊ฐ ์ฐ๊ฒฐ๋์ด ์๋ ๋ฌธ์ฅ์ธ์ง, ์๋์ง๋ฅผ ๋ง์ถ๋ task์ ๋๋ค.
Transformers for Visual Tasks
NLP ์์ญ์์ Transformer ์ํคํ ์ณ๋ ๊ธฐ์กด์ RNN์ ๋นํด ๋๋ผ์ธ๋งํผ์ ์ฑ๋ฅ ํฅ์์ ๋ณด์ฌ์ฃผ์์ต๋๋ค. ๊ทธ ํ๋ฆ์ ๋ฐ๋ผ ๋น์ ์์ญ์์๋ Transformer๋ฅผ ํ์ฉํ๋ ์ฐ๊ตฌ๋ค์ด ์ด๋ฃจ์ด์ก๋๋ฐ, ๋ํ์ ์ผ๋ก fully-Transformer-based ์ํคํ ์ณ๋ฅผ ํตํด ์ด๋ฏธ์ง ๋ถ๋ฅ ํ์คํฌ๋ฅผ ์งํํ vision transformer (Dosovitskiy et al.) (ViT)์, semantic segmentation๊ณผ ๊ฐ์ scene understanding ํ์คํฌ๋ฅผ ์งํํ swin transformer (Liu et al.)๋ฅผ ๋ค ์ ์์ต๋๋ค. ์ด ๋ชจ๋ธ๋ค์ ํ์ฌ ๋ค์ํ SOTA ๋ชจ๋ธ๋ค์ backbone ์ํคํ ์ณ๋ก ์ฐ์ด๊ณ ์๋๋ฐ์, Transformer์ ํต์ฌ ์์ด๋์ด์ธ self-attention mechanism ํ์ ์ถฉ๋ถํ computational resource์ training time์ด ํ์ํ๋ค๋ ๋จ์ ์ด ์์ต๋๋ค. ๋ํ Transformer ์ํคํ ์ณ๋ ์ผ๋ฐ์ ์ผ๋ก CNN-based ๋ชจ๋ธ๋ณด๋ค ๋ ๋ง์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ์๊ตฌํ๊ธฐ ๋๋ฌธ์ ์ฑ๋ฅ์ ๋ณด์ฅ๋์ง๋ง ํ์ต์ํค๊ธฐ ์ํ ์กฐ๊ฑด์ด ๊น๋ค๋กญ๋ค๋ ์ ์ด ๋ํ์ ์ธ ๋ฌธ์ ๋ผ๊ณ ํ ์ ์์ต๋๋ค.
Idea
์ด ๋ ผ๋ฌธ์ ํต์ฌ ์์ด๋์ด๋ ๋ค์๊ณผ ๊ฐ์ด ์์ฝํ ์ ์์ต๋๋ค.
Self-supervised Learning์ ํตํด Vision Transformer๊ฐ CNN-based model์ ๋นํด ๋ ๋ง์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ์๊ตฌํ๋ค๋ ๋ฌธ์ ๋ฅผ ํด๊ฒฐ
์ด ๋, self-supervised learning์ ๋ฐฉ๋ฒ๋ก ์ผ๋ก์ ๊ธฐ์กด์ NLP ์์ญ์์ ์์ฒญ๋ ์ฑ๊ณผ๋ฅผ ๋ณด์ฌ์ค BERT์ MLM์ ์ด๋ฏธ์ง ๋๋ฉ์ธ์ ์ ์ฉํ ์ ์๋๋ก ๋ณํํ์ฌ ์ ์
3. Method
Image Patch
๋จผ์ , input์ผ๋ก ์ฌ์ฉ๋๋ 224 x 224 ์ด๋ฏธ์ง๋ฅผ 16 x 16์ ์์ patch๋ค๋ก ์ชผ๊ฐญ๋๋ค. ๋ฐ๋ผ์ ์ด (224 / 16) x (224 / 16) = 14 x 14๊ฐ์ patch๋ก ์ชผ๊ฐ์ง๋ฉฐ, ์ข์๋ถํฐ ์ฐํ๊น์ง ์์๋๋ก Vision Transformer์ input sequence๋ฅผ ๊ตฌ์ฑํฉ๋๋ค.
Visual Token
BERT์ฒ๋ผ input sequence์ ์ผ๋ถ๋ฅผ maskingํ ๋ค mask๋๊ธฐ ์ ์ token์ ์์ธกํ๋ MLM์ ํ๊ธฐ ์ํด์๋, input patch์ ๋ํ discretization์ด ์ด๋ฃจ์ด์ ธ์ผ ํฉ๋๋ค. ๋ฌผ๋ก , mask๋ token์ด ํต๊ณผํด์ ๋์จ hidden token์ mask๋๊ธฐ ์ ์ ์๋ณธ ์ด๋ฏธ์ง๋ก ๋ณต์์ํค๋ ์ผ์ข ์ regression task๋ฅผ ์งํํ ์๋ ์๊ฒ ์ง๋ง, ์ ์์ ๋ง์ ๋ฐ๋ฅด๋ฉด ์ด๋ ์ ์ ํ ๋ฐฉ๋ฒ์ด ์๋๋ผ๊ณ ํฉ๋๋ค.
However, such pixel-level recovery task tends to waste modeling capability on pre-training short-range dependencies and high-frequency details.
๋ฐ๋ผ์, ์ฐ๋ฆฌ๋ MLM์ฒ๋ผ discrimination (classification) task, ์ฆ hidden token์ ํตํด mask๋๊ธฐ ์ ์ ์๋ณธ์ ์์ธกํ๊ฒ๋ ํ์ฌ ๋ชจ๋ธ์ ํ์ต์ํฌ ๊ฒ๋๋ค. ์ด๋ฅผ ์ํด์๋ ์์ ์ธ๊ธํ๋๋ก input patch์ ๋ํ discretization์ด ํ์ํ๋ค๊ณ ํ๋๋ฐ์, ์ด๊ฒ ๋ฌด์์ ์๋ฏธํ๋ ๊ฑธ๊น์?
์ ๊ทธ๋ฆผ์ ๋ณด์์ฃ . ๋ง์ฝ ์ฐ๋ฆฌ๊ฐ MLM์ ํ๋ค๋ฉด ์์ ๊ฐ์ด mask๋ token์ด ๋ชจ๋ธ์ ํต๊ณผํ๊ณ ๋์จ hidden token์์ ๋ฏธ๋ฆฌ ์ ์๋ vocabulary์ ์๋ ๋จ์ด๋ค ์ค ํ๋๋ก classification์ ์งํํ ๊ฒ๋๋ค. Classification ๊ฒฐ๊ณผ ๊ฐ์ฅ ๋์ ํ๋ฅ ๊ฐ์ ๊ฐ์ง ๋จ์ด๊ฐ ๋ชจ๋ธ์ด ์์ธกํ๋ ๊ฐ์ฅ ๊ทธ๋ด๋ฏํ ๋จ์ด๊ฒ ์ฃ . ๊ทธ๋ฐ๋ฐ ์ฐ๋ฆฌ๋ ์์ฐ์ด๋ฅผ input์ผ๋ก ํต๊ณผ์ํฌ ๊ฒ์ด ์๋๋ผ image patch๋ค์ ํต๊ณผ์ํฌ ๊ฑด๋ฐ, ์ฌ๊ธฐ์ ๋ฌธ์ ๊ฐ ์๊น๋๋ค. ์ฐ๋ฆฌ๊ฐ ์ด๋ฏธ์ง์ ๋ํ vocabulary๊ฐ ๋ฐ๋ก ์์ง๊ฐ ์์๋ฐ, ๋์ฒด ์ด๋ป๊ฒ masked hidden token์ ๋ํ classification์ ์ํํ ๊ฒ์ด๋๋ ๊ฒ๋๋ค. ๋ฐ๋ก ์ด ๋ฌธ์ ๋ก ์ธํด์ ์ฐ๋ฆฌ๋ image patch๋ค์ ๋ง์น ์์ฐ์ด๊ฐ ๊ทธ๋ ๊ฒ ํ๋ฏ์ด tokenize (discretize)ํ๊ณ , token๋ค์ set์ธ vocabulary๋ฅผ ์ ์ํด์ผ ํฉ๋๋ค. ์ฆ discretize๋ผ๋ ๊ฒ์, continuousํ RGB๊ฐ๋ค๋ก ์ด๋ฃจ์ด์ง image patch๋ฅผ discreteํ ๋จ์๋ก ์ชผ๊ฐฌ์ผ๋ก์จ ์๋ณธ ์ด๋ฏธ์ง ์์ธก์ classification task๋ก ์ํํ ์ ์๊ฒ๋ ํ๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
์ด๋ฅผ ์ํด์ ์ฐ๋ฆฌ๋ ๋ฏธ๋ฆฌ ์ ์๋ vocabulary์ tokenizer๊ฐ ํ์ํฉ๋๋ค. ์ด ๋ ผ๋ฌธ์์๋ DALL-E (Ramesh et al.)์์ ๊ณต๊ฐํ tokenizer๋ฅผ ์ฌ์ฉํ๋๋ฐ, ์ด๋ discrete variational auto-encoder์ธ VQ-VAE (Vector quantized-Variational AutoEncoder) (Oord et al.)๋ก ์ด๋ฃจ์ด์ก์ต๋๋ค. ์ฆ, VQ-VAE์ codebook์ด vocabulary๊ฐ ๋๋ฉฐ vector quantization์ ํตํด ๊ฐ image patch๊ฐ codebook ์์ ํน์ vector๋ก tokenize ๋ฉ๋๋ค. quantization ๊ณผ์ ์ ๋ํ ์ค๋ช ์ ๋ณธ ๋ฆฌ๋ทฐ์ scope๋ฅผ ๋ฒ์ด๋๊ธฐ ๋๋ฌธ์ ์์ธํ ๋ด์ฉ์ ์์ ํ์ง ์๊ฒ ์ง๋ง, ์ค์ํ ๊ฒ์ ์ ๋ชจ๋์ ํตํด์ ๊ฐ image patch๊ฐ discreteํ visual token์ด ๋๊ณ , ์ด๋ฅผ ํตํด MLM๊ณผ ๊ฐ์ task๋ฅผ ์ํํ ์ ์๋ค๋ ์ ์ ํ์ ํ์ จ๊ธฐ๋ฅผ ๋ฐ๋๋๋ค.
Masked Image Modeling
์ฌ๊ธฐ๊น์ง ์์ผ๋ฉด ์ด์ ๋จ์๊ฑด (1) input์ผ๋ก ๋ค์ด๊ฐ image patch์ ์ผ๋ถ๋ฅผ maskingํ ๋ค (2) Transformer Encoder์ ํต๊ณผ์ํค๊ณ , (3) mask๋ ์์น์์ ๋์จ hidden output token์ ๊ฐ์ง๊ณ mask๋๊ธฐ ์ ์ ์๋ณธ image patch์ code๋ก classificationํ๋, ์ด๋ฅธ๋ฐ Masked Image Modeling (MIM)์ ์งํํ๋ฉด ๋ฉ๋๋ค. ๊ฐ ๋จ๊ณ๋ฅผ ์ฐจ๊ทผ์ฐจ๊ทผ ๋ฐ๋ผ๊ฐ๋ด ์๋ค.
(1) Blockwise Masking
์ด ๋ ผ๋ฌธ์์๋ (14 x 14)๊ฐ์ image patch๋ค ์ค ์ฝ 40% ์ ๋๋ฅผ maskingํ๊ธฐ๋ก ํ์ต๋๋ค. ์ฐ๋ฆฌ๊ฐ ์๊ฐํ๊ธฐ์ ๊ฐ์ฅ naiveํ ๋ฐฉ์์ (14 x 14)๊ฐ์ patch๋ค์ ๊ฐ๊ฐ 40% ํ๋ฅ ๋ก masking๋ ์ง ์๋ ์ง ์ ํํ๊ฒ๋ ํ๋ฉด ๋ ํ ์ง๋ง, ์ด ๋ ผ๋ฌธ์์๋ ๋ค๋ฅธ ๋ฐฉ์์ masking ๋ฐฉ๋ฒ์ ์ ์ํ์ต๋๋ค. ๋ฐ๋ก blockwise masking์ธ๋ฐ์, ์ฝ๊ฒ ๋ง์๋๋ฆฌ์๋ฉด ๊ฐ image patch๋ฅผ ๋ ๋ฆฝ์ ์ผ๋ก maskingํ๋ ๊ฒ ์๋๋ผ ์ฐ์๋ image patch๋ฅผ ๊ณจ๋ผ์ block ๋จ์๋ก maskingํ์๋ ๊ฒ๋๋ค. ์ด๋ฅผ ์ํด์ span masking์ฒ๋ผ transformer์ input์ผ๋ก ๋ฃ๊ธฐ ์ํด ํ ์ค๋ก ์ธ์ด sequence์์ ์ฐ์๋ token์ maskingํ ์๋ ์๊ฒ ์ง๋ง, ๊ทธ๊ฒ๋ณด๋ค๋ image์ ํน์ฑ์ ์ด๋ ค์ ํ ์ค๋ก ์ธ์ฐ๊ธฐ ์ ์ image patch๋ค์ block ๋จ์๋ก ๋ฌถ๋ ๊ฒ์ ์ ์ํ์ต๋๋ค. ์ Figure 3์ ์ผ์ชฝ์ blockwise masking์ด ๋๋ ๋ถ๋ถ์ ๋ณด์๋ฉด 2 x 2์ patch๊ฐ masking๋๋ ๊ฒ์ ํ์ธํ์ค ์ ์์ต๋๋ค. ๊ตฌ์ฒด์ ์ผ๋ก๋ ๋ค์๊ณผ ๊ฐ์ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก masking๋ image patch๋ค์ด ์ ํ๋ฉ๋๋ค. ์ด๋ ๊ฒ ์ ํ๋ image patch๋ค์ learnableํ special token์ผ๋ก ๋์ฒด๋ฉ๋๋ค.
(2) Forwarding to Transformer Encoder
์ด๋ ๊ฒ ํด์ masking์ด ์๋ฃ๋๊ณ ํ ์ค๋ก ์ธ์์ง input sequence์์ ๋งจ ์์ "start of sequence"๋ฅผ ๋ปํ๋ ๋ ๋ค๋ฅธ special token์ ์ถ๊ฐํด์ฃผ๋ฉด, input์ผ๋ก ๋ค์ด๊ฐ ์ ๋ค์ ๋ค ์ ํด์ก์ต๋๋ค. ๋ค๋ง, mask token๊ณผ sos token์ด ์๋ image patch๋ค์ ์์ง rawํ RGB ํฝ์ ๊ฐ๋ค๋ก ์ด๋ฃจ์ด์ก๊ธฐ ๋๋ฌธ์ transformer์ ํ์ฐ๊ธฐ ์ํด์ image patch๋ค์ ๊ฐ๊ฐ vector๋ก ํํํด์ค์ผ ํฉ๋๋ค. ์ด๋ ๊ฐ๋จํ๊ฒ image patch๋ฅผ linear layer๋ฅผ ํตํด ์ ํด์ง ์ฐจ์ (์๋ง 768์ด๊ฒ ์ฃ ?)์ผ๋ก projectionํจ์ผ๋ก์จ ์ด๋ฃจ์ด์ง๋๋ค. ์, ๊ทธ๋ผ ์ ๋ง ์ต์ข ์ ์ผ๋ก transformer encoder์ ๋ค์ด๊ฐ input์ด ๊ฒฐ์ ๋์ต๋๋ค. ์ฌ๊ธฐ์ patch๋ค ๊ฐ์ ์์ ์ ๋ณด ์ญํ ์ ํด์ค learnableํ position embedding์ด ๋ํด์ง๋ฉด, ๊ทธ๋๋ก transformer encoder๋ฅผ ํต๊ณผํด์ ๊ฐ input token๋ง๋ค hidden output token์ด ๋์ค๊ฒ ๋ฉ๋๋ค.
(3) Masked Image Modeling
์ด์ ์ฐ๋ฆฌ๋ Masked Image Modeling์ ํ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ hidden output token๋ค ์ค์์ mask๊ฐ ๋์๋ ์์น์ output token๋ง์ ์ฌ์ฉํ ๊ฒ๋๋ค. ํด๋น hidden vector๋ค์, VQ-VAE๋ฅผ ํ์ต์ํด์ผ๋ก์จ ๋ฏธ๋ฆฌ ๊ตฌํด๋จ๋ codebook์ size๋ก projection (์ ์๋ค์ ์ด Linear layer๋ฅผ Masked Image Modeling Head๋ผ๊ณ ๋ถ๋ฆ ๋๋ค.) ํด์ ์๋ณธ image patch์ code๋ก classification์ ์งํํ๋ฉด ๋์ ๋๋ค.
4. Experiment & Result
Experimental Setup
Pre-training Setup
Dataset ์ด ๋ ผ๋ฌธ์์ pre-training์ ์ํด ์ฌ์ฉํ ๋ฐ์ดํฐ์ ์ ImageNet-1K์ training set์ ๋๋ค. ์ด 1.2๋ฐฑ๋ง๊ฐ์ 224 x 224์ resolution ์ด๋ฏธ์ง๋ค์ ์ฌ์ฉํ๋ค๊ณ ํฉ๋๋ค. ๋ํ data augmentation์ผ๋ก random resized cropping, horizontal flipping, color jittering์ ์ ์ฉํ์ผ๋ฉฐ mask ratio๋ ์์์ ์ธ๊ธํ๋๋ก ์ด 40%, ๊ฐ์๋ก ์น๋ฉด 14 x 14 = 196๊ฐ์ image patch ์ค ์ต๋ 75๊ฐ์ image patch๋ฅผ maskingํ๋ค๊ณ ํฉ๋๋ค.
Model Architecture ๋ชจ๋ธ ์ํคํ ์ณ๋ BERT-Base (ํน์ ViT-Base) ๋ชจ๋ธ๊ณผ ์์ ํ ๋์ผํฉ๋๋ค. ๊ตฌ์ฒด์ ์ผ๋ก๋, Transformer Encoder layer์ ๊ฐ์๋ 12๊ฐ, Encoder์ hidden dimension size๋ 768, feed-forward dimension size๋ 3072, multi-head attention heads์ ๊ฐ์๋ 12๊ฐ์ ๋๋ค. ๋ํ, ๊ฐ image patch์ size๋ 16 x 16์ผ๋ก ์ฌ์ฉํ์ต๋๋ค.
Training Hyperparameters 2K batch size๋ก ์ด 500K steps (=800 epochs)๋งํผ ํ์ต์์ผฐ์ต๋๋ค. 16๊ฐ์ Nvidia Tesla V100 32GB GPU๋ก ์ด 5์ผ๋์ ํ์ตํ๋ค๊ณ ํฉ๋๋ค.
Baseline ์ฑ๋ฅ ๋น๊ต๋ฅผ ์ํด contrastive learning ๋ฐฉ์์ SSL ๋ชจ๋ธ์ธ MoCo v3์ self-distillation ๋ฐฉ์์ DINO ๋ชจ๋ธ์ baseline์ผ๋ก ์ผ์์ต๋๋ค.
Fine-tuning Setup
Fine-tuning Task Pre-training์ ์ฑ๋ฅ์ ๊ฒ์ฆํ๊ธฐ ์ํ downstream task๋ก๋ image classification๊ณผ semantic segmentation์ ์งํํ์ต๋๋ค. ๋, intermediate fine-tuning์ด๋ผ๋ ๊ฒ์ ์งํํ๋ค๊ณ ํ๋๋ฐ, ์ด๊ฑด self-supervised pre-training์ด ๋๋ ๋ค ํด๋น pre-training ๋ฐ์ดํฐ์ ์ downstream task๋ก ๋ค์ ํ ๋ฒ fine-tuning์ ์งํํ ๋ค์, ์ต์ข ์ ์ผ๋ก target dataset์ fine-tuning์ ์ํค๋ ์์ ์ธ ๊ฒ ๊ฐ์ต๋๋ค.
Dataset image classificaton์ ์งํํ๊ธฐ ์ํ dataset์ผ๋ก๋ CIFAR-100๊ณผ pre-trainingํ ๋ ์ฌ์ฉํ๋ ImageNet-1K๋ฅผ ์ฌ์ฉํ๋ค๊ณ ํฉ๋๋ค. ๋ํ, Semantic segmentation์ ์งํํ๊ธฐ ์ํ dataset์ผ๋ก๋ ADE20K์ ImageNet-1K๋ฅผ ์ฌ์ฉํ์ต๋๋ค.
Evaluation Metric Image classification์ ์ฑ๋ฅ ๊ฒ์ฆ์ ์ํด์๋ Top-1 accuracy๋ฅผ ์ฌ์ฉํ๊ณ , semantic segmentation์ ์ฑ๋ฅ ๊ฒ์ฆ์ ์ํด์๋ mIoU metric์ ์ฌ์ฉํ์ต๋๋ค.
Result
Image Classification
Image classification์ ๊ฒฐ๊ณผ๋ ์์ ๊ฐ์ต๋๋ค. Baseline์ด ์ข ๋ง์์ ๋ณด๊ธฐ์ ์ข ๋์กํ ์ ์๋๋ฐ์, ๋ณด์ค๋งํ ๋ถ๋ถ๋ง ์ง์ด์ ๋ง์๋๋ฆฌ๊ฒ ์ต๋๋ค. ์ฐ์ MoCo v3์ ๋น๊ตํ์ ๋ BEIT์ ์ฑ๋ฅ์ด ๋ ๋ฐ์ด๋ ๊ฒ์ ํ์ธํ์ค ์ ์์ต๋๋ค. Intermediate fine-tuning์ ๊ฑฐ์ณค์ ๋๋ DINO๋ณด๋ค BEIT์ ์ฑ๋ฅ์ด ์์ฃผ ์ฝ๊ฐ ํฅ์๋๋ ๊ฒ๋ ํ์ธํ์ค ์ ์์ผ์ค ๊ฒ๋๋ค. ํํธ, ์ฑ๋ฅ๊ณผ๋ ๋ณ๊ฐ๋ก BEIT์ ํ์ต ์๋ ด ์๋๊ฐ random initialization๋ DeiT์ ํ์ต ์๋ ด ์๋๋ณด๋ค ๋น ๋ฅด๋ค๊ณ ํฉ๋๋ค. ํด๋น ์๋ฃ๋ ์๋์์ ํ์ธํ์ค ์ ์์ต๋๋ค.
Semantic Segmentation
Semantic Segmentation์ ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
Semantic segmentation task์์๋ ๋ง์ฐฌ๊ฐ์ง๋ก BEIT๊ฐ DINO๋ณด๋ค ์ฑ๋ฅ์ด ์ข์์ผ๋ฉฐ, intermediate fine-tuning์ ํ ๊ฒฝ์ฐ ์ฑ๋ฅ์ด ๋ ์ข์์ง๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ๋ํ, self-attention map์ ๋ณด์๋ฉด pre-training ๋จ๊ณ์์ semantic ๊ฒฝ๊ณ์ ๊ด๋ จ๋ ์๋ฌด๋ฐ annotation ์์ด ํ์ต๋์์๋ ๋ถ๊ตฌํ๊ณ pre-training๋ self-attention map์ด ์ด๋ฏธ์ง ๋ด์ object์ semanticํ ๊ฒฝ๊ณ๋ฅผ ์ ๊ฒ์ถํ๋ ๊ฒ์ ํ์ธํ์ค ์ ์์ต๋๋ค.
5. Conclusion
๊ฒฐ๋ก ์ ์ผ๋ก ์ด ๋ ผ๋ฌธ์ ๋ค์๊ณผ ๊ฐ์ด ์ ๋ฆฌํ ์ ์์ ๊ฒ ๊ฐ์ต๋๋ค.
Vision Transformer๋ฅผ self-supervision์ผ๋ก pre-trainingํจ์ผ๋ก์จ image classification, semantic segmentation ๋ฑ์ downstream task์ ๋ํด ์ฑ๋ฅ์ ํฅ์์ํค๋ ๋ฐฉ๋ฒ๋ก ์ ์ ์.
๊ธฐ์กด์ BERT์ฒ๋ผ MLM ๋ฐฉ์๋๋ก pre-trainingํ๋ ๋ฐฉ๋ฒ์ ์ด๋ฏธ์ง ๋๋ฉ์ธ์ ๋ง๊ฒ ๋ณํ์์ผ ์ ์ฉํจ.
Take home message (์ค๋์ ๊ตํ)
์ฌ์ค ์ด ๋ ผ๋ฌธ์ ์์ ํ ์๋ก์ด ๋ฐฉ๋ฒ๋ก ์ ์ ์ํ๋ค๊ธฐ ๋ณด๋ค๋ ๊ธฐ์กด์ ๋ฐฉ๋ฒ๋ก ๋ค์ ViT์ ์ ์ฉํด๋ณธ ๊ฒ์ ๋ถ๊ณผํฉ๋๋ค.
BERT์ MLM ํ์ต๋ฐฉ์์ ๊ฑฐ์ ๊ทธ๋๋ก ๋ฐ๋ผ๊ฐ๊ณ , ์ด๋ฅผ ์ํ image tokenizer๋ DALL-E์์ ๊ณต๊ฐํ tokenizer๋ฅผ ๊ทธ๋๋ก ์ฌ์ฉํ์ ๋ฟ๋ง ์๋๋ผ Backbone ์ํคํ ์ณ๋ ViT๋ฅผ ๊ทธ๋๋ก ์ฌ์ฉํ์ฃ .
๊ทธ๋ผ์๋ ๋ถ๊ตฌํ๊ณ ICLR 2022, ๊ทธ๊ฒ๋ Oral๋ก ๋ถ์ ๊ฒ์ ๋ณด๋ฉด, ๊ธฐ์กด์ ๋ฐฉ๋ฒ๋ก ๋ค์ ์ ์ฃผ๋ฌผ๋ฌ์ ์๋ก์ด ๋๋ฉ์ธ์ด๋ ์ปจ์ ์ ์ ์ฉํ๋ ๊ฒ๋ ๊ด์ฐฎ์ ์ฐ๊ตฌ์ฃผ์ ๊ฐ ๋ ๊ฒ์ด๋ผ๋ ์๊ฐ์ด ๋ญ๋๋ค.
Author / Reviewer information
Author
์ค์ ์ฐ (Jungwoo Oh)
KAIST AI
Reviewer
Korean name (English name): Affiliation / Contact information
Korean name (English name): Affiliation / Contact information
...
Reference & Additional materials
Last updated