Supervised Contrastive Replay: Revisiting the Nearest Class Mean Classifier in Online Class-Incremental Continual Learning[Kor]
1. Introduction
Continaul Learning (CL)
CL์ด๋, ์ฐ์์ ์ผ๋ก ์ฃผ์ด์ง๋ Data Stream์ Input์ผ๋ก ๋ฐ์, ์ฐ์์ ์ผ๋ก ํ์ตํ๋ ๋ชจ๋ธ์ ๋ง๋ค์ด๋ด๋ ๊ฒ์ ๋ชฉํ๋ก ํ๋ ๋ฌธ์ ์ธํ
์
๋๋ค. ํ์ฌ ๋ฅ ๋ฌ๋ ๊ธฐ๋ฐ์ ๋ชจ๋ธ๋ค์, ์๋ก์ด ๋ฐ์ดํฐ์
์ ํ์ตํ ๊ฒฝ์ฐ ์ด์ ๋ฐ์ดํฐ์
์์์ ์ฑ๋ฅ์ ๋งค์ฐ ๋จ์ด์ง๋๋ค. ์ด๋ฌํ ํ์์ Catastrophic Forgetting(CF)๋ผ๊ณ ๋ถ๋ฆ
๋๋ค. ์๋ฅผ ๋ค์ด ์ค๋ช
ํ์๋ฉด, Cifar10์ ํ์ตํ ๋ชจ๋ธ์ด MNIST๋ฅผ ํ์ตํ ๊ฒฝ์ฐ, MNIST์์์ ์ฑ๋ฅ์ ๋์ง๋ง, Cifar10์ ์ฑ๋ฅ์ ๋ฎ์์ง๋๋ค.(๋จ์ํ MNIST๋ฅผ ํธ๋ ์ด๋ ํ ๊ฒฝ์ฐ, ๊ฑฐ์ 0%์ ๊ฐ๊น์ด ์ฑ๋ฅ์ ๋ณด์
๋๋ค.) ์ด์ ์ Cifar10์์์ ์ฑ๋ฅ์ด ์ด๋๋ ๊ฐ์, ๊ทน์ ์ธ ์ฑ๋ฅ ํ๋ฝ์ด ๋ํ๋๊ฒ ๋ฉ๋๋ค. ์ด๋ Cifar10๊ณผ MNIST ๊ฐ์ด ์ฐ์์ ์ผ๋ก ๋ค์ด์ค๋ Dataset๋ค์ Task๋ผ๊ณ ๋ถ๋ฆ
๋๋ค.
CF๋ ๋ฅ ๋ฌ๋์ด ์ฌ๊ธฐ์ ๊ธฐ์ ์ฐ์ด๊ณ ์๋ ๊ณผ์ ์์ ๊ผญ ํด๊ฒฐํด์ผ ํ ๋ฌธ์ ์
๋๋ค. ํ๋ฒ ๋ชจ๋ธ์ ํ๋ จ์ํค๊ณ ๋ ํ, ๊ทธ ๋ชจ๋ธ์ ์ค์ ์๋น์ค์ ์๋นํ ๊ฒฝ์ฐ ๋ฐ์ดํฐ๋ ๋ ์์ด๊ฒ ๋ฉ๋๋ค. ํ์ง๋ง ์ด ๋ฐ์ดํฐ๋ฅผ ์ถ๊ฐ๋ก ํ์ต์ํค๊ฒ ๋๋ฉด, ๋ชจ๋ธ์ ์คํ๋ ค ์ฑ๋ฅ์ด ๋จ์ด์ง ์ ์์ต๋๋ค. ์ด์ ์ ๋ชจ๋ธ์ ํธ๋ ์ด๋ ํ ๋ ์ฌ์ฉํ๋ ๋ฐ์ดํฐ๋ฅผ ์ ๋ถ ๋ค ๋ค์ ์ฌ์ฉํ๊ณ , ์ถ๊ฐ๋ก ์ถ๊ฐ ๋ฐ์ดํฐ๋ฅผ ๋ฃ์ด์ฃผ์ด์ ํธ๋ ์ด๋์ ์์ผ์ผ ํ๋ ๊ฒ์
๋๋ค. ์ด๋ ๊ทน์ ์ธ ๊ณ์ฐ ๋นํจ์จ์ฑ์ ๋ถ๋ฆ
๋๋ค. ์๋์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฐพ์์ ์ ์ ๋๋ํด์ง๋, ์ํ์ ๊ฐ์ AI๋ ์ง๊ธ ๋ํ๋์ง ์๋ ์ด์ ์
๋๋ค.
์ด๋ฌํ CF๋ฅผ ํด๊ฒฐํ๊ณ ์ ํ๋ ๋ฌธ์ ์ธํ
์ด CL์
๋๋ค. ์ด ๋
ผ๋ฌธ์ ์ ์ Zheda Mai๋ CL ๋ถ์ผ์์ ์ต๊ทผ ์ข์ ๋
ผ๋ฌธ์ ๋ง์ด ๋ด๋ฉฐ SOTA์ ๊ฐ๊น์ด ๋ฐฉ๋ฒ๋ก ๋ค์ ๋งค๋ฒ ์ ์ํ๊ณ ์์ต๋๋ค. Mai์ ๋
ผ๋ฌธ ์ค์์๋ ์ด ๋
ผ๋ฌธ์, ๋น๋ก ํธ๋ฆญ์ ์ฌ์ฉํ๊ธฐ๋ ํ์ง๋ง CL๋ก์๋ ์์๋ ํ์ง ๋ชปํ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ ๋
ผ๋ฌธ์ด๊ธฐ ๋๋ฌธ์ ์๋นํ ๋งค๋ ฅ์ ์
๋๋ค.
Experience Replay(ER)
CL ๋ฌธ์ ์ธํ
์์ ํ์ฌ ์ง๋ฐฐ์ ์ด๋ผ๊ณ ํ ์ ์๋ ๋ฐฉ๋ฒ๋ก ์ Experience Replay์
๋๋ค. ๋จ์ํ ๋ฐฉ๋ฒ์๋ ๋ถ๊ตฌํ๊ณ ์ข์ ์ฑ๋ฅ์ ๋ณด์ด๊ณ , ๊ฐ์ ํ ์ฌ์ง๊ฐ ๋ชจ๋์ ์ผ๋ก ๋ง์ด ๋จ์์๊ธฐ ๋๋ฌธ์ ๋ง์ด ์ฐ๊ตฌ๋๊ณ ์์ต๋๋ค. ER์ ๋ฐฉ๋ฒ๋ก ์ ๊ฐ๋จํฉ๋๋ค. ์ด์ ํ์คํฌ์์ ๋ช๊ฐ์ง ๋ฐ์ดํฐ๋ฅผ ๋ฝ์ External Memory์ ์ ์ฅํด๋ก๋๋ค. ์๋ก์ด ํ์คํฌ๊ฐ ๋ค์ด์ค๋ฉด External Memory์ ์๋ ๋ฐ์ดํฐ์ ํจ๊ป ํ๋ จ์ํต๋๋ค.
๋น์ฐํ External Memory๊ฐ ๋ง์ผ๋ฉด ๋ง์ ์๋ก ์ด์ ํ์คํฌ์ ์ฑ๋ฅ ์ ํ๋ฅผ ์ ๋ง์ ์ ์์ต๋๋ค. ER์ ์ต์ข
๋ชฉํ๋ ์ต์ํ์ External Memory๋ฅผ ์ด์ฉํด์ ์ต๋ํ CF๋ฅผ ์ค์ด๋ ๊ฒ ์
๋๋ค.
ER์ ํ์ฌ ์ต์ ์ธํ
์ ๊ฐ๋ตํ๊ฒ ์ ๋ฆฌํ์๋ฉด ๋ค์๊ณผ ๊ฐ์ ์ ์ด ์ค์ํ๋ค๊ณ ํ ์ ์์ต๋๋ค.
ํ์ฌ ํ์คํฌ์ batch 1๊ฐ + External Memory์์์ batch 1๊ฐ๋ฅผ ํจ๊ป ํธ๋ ์ด๋ ํ๋ค.
External Memory์ ๊ฒฝ์ฐ ํฌ๊ธฐ๊ฐ ๋ณดํต ์๊ธฐ ๋๋ฌธ์ ๋์ ๊ทธ๋๋ก ํจ๊ป ํธ๋ ์ด๋ ํด๋ฒ๋ฆฌ๋ฉด ๋์ Class Imbalance๊ฐ ์ผ์ด๋์ ์ฑ๋ฅ์ด ๋จ์ด์ง๊ฒ ๋ฉ๋๋ค. ๋์ ๋น์จ์ ๋ง์ถฐ์ ํธ๋ ์ด๋ ํด ์ฃผ๋ ๊ฒ์ด ER์ ์ฑ๋ฅ์ ๋์ด๋ ํ์
๋๋ค.
2. Method
SoftMax Classifier์ CL์์์ ๋ฌธ์ ์
์ด ๋
ผ๋ฌธ์ ํต์ฌ Contribution์ด์ ์ ์๊ฐ ์ฃผ์ฅํ๋ ๊ฒ์ Softmax Classifier์ ๋ฌธ์ ์ ์
๋๋ค. Softmax Classifier๋ ๋ง์ ๋ถ๋ถ์์ ์ต๊ณ ์ ์ฑ๋ฅ์ ๋ด๊ณ ์์ง๋ง, CL์์ ๋งํผ์ ์ข์ง ์๋ค๋ ๊ฒ์ด ์ ์์ ์๊ฐ์
๋๋ค. ๊ทธ ์ด์ ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
์๋ก์ด ํด๋์ค๊ฐ ๋ค์ด์ค๋ ๊ฒ์ ์ ์ฐํ์ง ์๋ค
Softmax์ ํน์ฑ์ ์ฒ์๋ถํฐ ํด๋์ค์ ๊ฐฏ์๋ฅผ ์ ํด์ค์ผ ํฉ๋๋ค. ์ด ๋๋ฌธ์ ํ์คํฌ๊ฐ ์ผ๋ง๋ ๋ค์ด์ฌ์ง ๋ชจ๋ฅด๋ CL ์ธํ
์ ํน์ฑ์ ๋ง์ง ์์ต๋๋ค. (ํ์ง๋ง ํ์ฌ CL ์ฐ๊ตฌ๋ ๋๋ถ๋ถ ํ์คํฌ๊ฐ ์ผ๋ง๋ ๋ค์ด์ฌ์ง ์๊ณ ์์ต๋๋ค. ์ด๊ฒ์ ํ์ ์คํ์ ๋ณด์๋ฉด ๋ ์ ์ดํด๋ฉ๋๋ค.)
representation๊ณผ classification์ด ์ฐ๊ฒฐ๋์ด ์์ง ์๋ค
Encoder๊ฐ ๋ฐ๋ ๊ฒฝ์ฐ Softmax layer๋ ์๋ก ํ๋ จ๋์ด์ผ ํฉ๋๋ค.
Task-recency bias
์ด์ ์ ์ฌ๋ฌ ์ฐ๊ตฌ์์, Softmax classifier๊ฐ ์ต๊ทผ ํ์คํฌ์ ์น์ค๋๋ ๊ฒฝํฅ์ด ์๋ค๋ ๊ฒ์ด ๊ด์ฐฐ๋์์ต๋๋ค. ์ด๋ ๋ฐ์ดํฐ์ ๋ถํฌ๊ฐ ํ์ฌ ํ์คํฌ์ ์น์ค๋์ด์๋ CL์ ํน์ฑ์ ์ฑ๋ฅ์ ์น๋ช
์ ์ผ ์ ์์ต๋๋ค.
Nearest Class Mean(NCM) Classifier
์ ์๋ ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์, Few-shot learning์์ ์ฃผ๋ก ์ฌ์ฉ๋๋ NCM Classifier๋ฅผ ์ฌ์ฉํ์๊ณ ์ฃผ์ฅํฉ๋๋ค. NCM Classifier์ ๊ฒฝ์ฐ Prototype Classifier๋ผ๊ณ ๋ ๋ถ๋ฆฝ๋๋ค. ์ด Classifier๋ ํธ๋ ์ด๋์ด ๋๋ ํ, ํธ๋ ์ด๋์ ์ฌ์ฉ๋์๋ ๋ชจ๋ ํด๋์ค ๋ฐ์ดํฐ์ ํ๊ท ์ ๋ด์ด ์ ์ฅํฉ๋๋ค. ์ด๋ ๊ฒ ์ ์ฅ๋ ํ๊ท ๊ฐ์ Prototype์ฒ๋ผ ์๋ํฉ๋๋ค. Test์, ๊ฐ์ฅ ๊ฐ๊น์ด Prototype์ ๊ฐ์ง๋ ํด๋์ค๋ก ํด๋์ค๋ฅผ ์ถ์ธกํ๊ฒ ๋ฉ๋๋ค.
NCM Classifier๋ SoftMax์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ฉด์, few-shot learning์ฒ๋ผ data ๋ถ์กฑ ํ์์ ์๋ฌ๋ฆฌ๋ CL๊ณผ ๊ต์ฅํ ๊ถํฉ์ด ์ ๋ง์ต๋๋ค. ์ค์ ๋ก NCM Classfier๋ฅผ ์ ์ฉํ๋ ๊ฒ๋ง์ผ๋ก๋ ๋๋ถ๋ถ์ CL ๋ฐฉ๋ฒ๋ก ์ ์ฑ๋ฅ์ด ํฌ๊ฒ ์์นํฉ๋๋ค.
u c = 1 n c โ i f ( x i ) โ
1 { y i = c } u_c = \frac{1}{n_c}\sum_i f(x_i) \cdot 1\{y_i = c \} u c โ = n c โ 1 โ โ i โ f ( x i โ ) โ
1 { y i โ = c }
y โ = a r g m i n c = 1 , . . . , t โฃ โฃ f ( x ) โ u c โฃ โฃ y^* = argmin_{c=1,...,t} ||f(x) - u_c || y โ = a r g mi n c = 1 , ... , t โ โฃโฃ f ( x ) โ u c โ โฃโฃ
NCM classifier๋ฅผ ์ํด ์ฌ์ฉ๋๋ ์์์ ์์ ๊ฐ๋ค. ์ฌ๊ธฐ์ c๋ ํด๋์ค๋ฅผ ๋ปํ๊ณ , 1{y=c} ๋ y๊ฐ c์ผ ๋๋ฌธ 1์ด๋ผ๋ ๊ฒ์ ์๋ฏธํ๋ค. ํด๋์ค ๋ณ ๋ฉ๋ชจ๋ฆฌ์ ๋ค์ด์๋ ๋ฐ์ดํฐ์ ํ๊ท ์ ๊ตฌํ๊ณ , ๊ทธ ํ๊ท ์ ๊ฐ์ฅ ๊ฐ๊น์ด ํด๋์ค๋ก Inference๋ฅผ ์งํํ๋ค.
Supervisied Contrastive Replay
NCM Classifier์ ํฌํ
์
์ ๋ ๋์ผ ์ ์๋ ๋ฐฉ๋ฒ์ด SCR์
๋๋ค. NCM Classifier๋ Representation ๊ฐ ๊ฑฐ๋ฆฌ๋ฅผ ์ค์ฌ์ผ๋ก inference๋ฅผ ์งํํฉ๋๋ค. ์ด๋ฐ ์ํฉ์์ ๋ค๋ฅธ ํด๋์ค๋ ๋ ๋ฉ๋ฆฌ, ๊ฐ์ ํด๋์ค๋ ๋ ๊ฐ๊น์ด ๋ถ์ฌ๋๋ Contrastive Learning์ NCM์ ํฐ ๋์์ด ๋ ์ ์์ต๋๋ค. ์ ์๋ ํธ๋ ์ด๋ ๋ฐ์ดํฐ์ ๋จ์ํ Augmented View๋ฅผ ์ถ๊ฐํ๊ณ , ์ด ๋ฐ์ดํฐ๋ค์ ์ด์ฉํ์ฌ Contrastive Learning์ ์งํํฉ๋๋ค. ๋ฉ๋ชจ๋ฆฌ ๋ฐ์ดํฐ์ ํ์ฌ ๋ฐ์ดํฐ๋ฅผ ํจ๊ป ์ฌ์ฉํฉ๋๋ค.
L S C L ( Z I ) = โ i โ I 1 โฃ P ( i ) โฃ โ p โ P ( i ) l o g e x p ( z i โ
z p / ฯ ) โ j โ A ( i ) e x p ( z i โ
z j / ฯ ) L_{SCL}(Z_I) = \sum_{i\in I} \frac{1}{|P(i)|} \sum{p\in P(i)} log \frac{exp(z_i\cdot z_p / \tau)}{\sum{j \in A(i)}exp(z_i \cdot z_j / \tau) } L SC L โ ( Z I โ ) = โ i โ I โ โฃ P ( i ) โฃ 1 โ โ p โ P ( i ) l o g โ j โ A ( i ) e x p ( z i โ โ
z j โ / ฯ ) e x p ( z i โ โ
z p โ / ฯ ) โ
Loss ์์ ์ ์๊ณผ ๊ฐ์ต๋๋ค. $B = {x_k,y_k}{k=1,...,b}$์ Mini Batch๋ผ๊ณ ํ ๋, $\tilde{B}$ $= { \tilde{x_k} = Aug(x_k), y_k } {k=1,...,b}$ ์
๋๋ค. ๊ทธ๋ฆฌ๊ณ $B_I = B \cap \tilde{B}$ ์
๋๋ค. $I$๋ $B_I$์ ์ง์๋ค์ ์งํฉ์ด๊ณ , $A(i)=I \setminus {i}$ ์
๋๋ค. $P(i) = {p \in A(i) : y_p = y_i}$ ์
๋๋ค. ๋ณต์กํด ๋ณด์ด์ง๋ง ์ฐฌ์ฐฌํ ๋ฏ์ด๋ณด๋ฉด ์ด๋ ต์ง ์์ต๋๋ค. ๊ฒฐ๊ตญ $P(i)$๋ ์ํ i๋ฅผ ์ ์ธํ ๊ฒ ์ค์์ label์ด ๊ฐ์ ๊ฒ, ๊ทธ๋ฌ๋๊น Positive sample์ ์๋ฏธํฉ๋๋ค. $Z_I = {z_i}_{i \in I} = Model(x_i)$ ์ด๊ณ , $\tau$๋ ์กฐ์ ์ ์ํ temperature parameter ์
๋๋ค.
Implementation์์๋ Continual Learning์ ๋ฒค์น๋งํฌ๋ผ๊ณ ํ ์๋ ์๋ Split Cifar-10์์ ์คํ์ ์งํํฉ๋๋ค. ์ผ๋ฐ์ ์ธ BaseLine์ผ๋ก ๋ง์ด ์ฌ์ฉ๋๋ Experience Replay์ ๋ํ ๊ตฌํ๊ณผ, ์ด ๋
ผ๋ฌธ์์ ์ ์ํ NCN Classifier๋ฅผ ์ฌ์ฉํ Experience Replay์ ๋ํ ๊ตฌํ์ ์ค๋นํ์ต๋๋ค.
Environment
Colab ํ๊ฒฝ์์ ์คํํ๊ธฐ๋ฅผ ์ถ์ฒ๋๋ฆฝ๋๋ค.
Copy import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as D
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
Setting of Continual Learning
์ด ์ฑํฐ์์๋ Continual Learning evaluation์ ์ํ ๊ธฐ๋ณธ์ ์ธ ์ธํ
์ ์ค๋นํฉ๋๋ค. ๋ฐ์ดํฐ์
์ Cifar-10์ 5๊ฐ์ ํ์คํฌ๋ก ๋๋ Split Cifar-10์ ์ฌ์ฉํ์ต๋๋ค. ๋
ผ๋ฌธ์์๋ Reduced_ResNet18์ ๋ฒ ์ด์ค ๋ชจ๋ธ๋ก ์ฌ์ฉํฉ๋๋ค. ํ์ง๋ง ์ด Implementation์์๋ ๊ตฌํ์ ๊ฐ๋จํจ์ ์ํด ์์ CNN๋ชจ๋ธ์ ์ฌ์ฉํฉ๋๋ค. ์ด ์ฝ๋์์๋ Split Cifar-10์ ๋ง๋ค๊ณ , Reduced_ResNet18์ ์ ์ํฉ๋๋ค.
Copy # Made Split-Cifar10
def setting_data():
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]) #settign transform
transform_test = transforms.Compose([
transforms.ToTensor(),
])
#์๋ณธ Cifar-10 dataset์ ๋ค์ด๋ก๋ ๋ฐ์ ์ค๋๋ค
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=1,
shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=1,
shuffle=False)
#์๋์ ์ฝ๋๋ Cifar10์ ์์์ ์์(5๋ก ๋๋ ๋๋จธ์ง)๋ก 5๊ฐ์ ํ์คํฌ๋ก ๋ถ๋ฆฌํฉ๋๋ค.
set_x = [[] for i in range(5)]
set_y = [[] for i in range(5)]
set_x_ = [[] for i in range(5)]
set_y_ = [[] for i in range(5)]
if shuffle==False:
for batch_images, batch_labels in train_loader:
if batch_labels >= 5:
y = batch_labels-5
else :
y = batch_labels
set_x_[y].append(batch_images)
set_y_[y].append(batch_labels)
for i in range(5):
set_x[i] = torch.stack(set_x_[i])
set_y[i] = torch.stack(set_y_[i])
set_x_t = [[] for i in range(5)]
set_y_t = [[] for i in range(5)]
set_x_t_ = [[] for i in range(5)]
set_y_t_ = [[] for i in range(5)]
for batch_images, batch_labels in test_loader:
if batch_labels >= 5:
y = batch_labels-5
else :
y = batch_labels
set_x_t_[y].append(batch_images)
set_y_t_[y].append(batch_labels)
for i in range(5):
set_x_t[i] = torch.stack(set_x_t_[i])
set_y_t[i] = torch.stack(set_y_t_[i])
return set_x,set_y,set_x_t,set_y_t
์๋ ์ฝ๋๋ ์ฌ์ฉ๋ Base CNN ๋ชจ๋ธ์ธ Reduced ResNet18์ ์ ์ํฉ๋๋ค. ๋ง์ง๋ง FC ๋ ์ด์ด๋ฅผ ์ฌ์ฉํ์ง ์๋ features๋ผ๋ ํจ์๊ฐ ์กด์ฌํ๋ ์ ์ด ํน์ดํ ๋งํ ์ ์
๋๋ค. ์ด features๋ ํ์ NCM classifier๋ฅผ ๊ตฌํํ ๋ ์ฌ์ฉ๋ฉ๋๋ค,
Copy class PreActBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(PreActBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
)
def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out += shortcut
return out
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
def all_forward(self, x):
out1 = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out1))
out += self.shortcut(x)
out = F.relu(out)
return out1,out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Reduced_ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10,dropout=0.3):
super(Reduced_ResNet, self).__init__()
self.in_planes = 20
self.conv1 = nn.Conv2d(3, 20, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(20)
self.layer1 = self._make_layer(block, 20, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 40, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 80, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 160, num_blocks[3], stride=2)
self.d1 = nn.Dropout(p=dropout)
self.linear = nn.Linear(160*block.expansion, num_classes)
self.linear3 = nn.Linear(400, num_classes)
self.linear2 = nn.Linear(640,400)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def features(self, x):
'''Features before FC layers'''
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
return out
def reduced_ResNet18(num_classes=10):
return Reduced_ResNet(BasicBlock,[2,2,2,2],num_classes=num_classes)
Experience Replay
์๋ ์ฝ๋๋ Continual Learning์์ ๊ฐ์ฅ ๋ง์ด ์ฐ์ด๋ ๋ฒ ์ด์ค๋ผ์ธ ์ค ํ๋์ธ Experience Replay๋ฅผ ๊ตฌํํฉ๋๋ค. Memory size, training epoch, learning rate ๋ฑ ๋ค์ํ ์ต์
๋ค์ ๋ฐ๋๋ฉฐ ์ฑ๋ฅ์ด ์ด๋ป๊ฒ ๋ณํ๋์ง ์์๋ณด๋ฉด ์ฌ๋ฏธ์์ ๊ฒ์
๋๋ค.
๋จผ์ ์๋ ์ฝ๋์์๋ External Memory๋ฅผ ๊ตฌํํฉ๋๋ค. ๋ฉ๋ชจ๋ฆฌ๋ ์ด๋ค ์์ผ๋ก ๊ตฌํํด๋ ์๊ด์ ์์ง๋ง, ๋๋ค์ผ๋ก ๋ฉ๋ชจ๋ฆฌ์ ๋ค์ด๊ฐ/๋ฉ๋ชจ๋ฆฌ์์ ๋ฝํ ๋ฐ์ดํฐ๋ฅผ ์ฝ๊ฒ ๊ตฌํํ๊ธฐ ์ํด ํด๋์ค๋ฅผ ํ๋ ๋ง๋ค์์ต๋๋ค.
Copy class Memory():
def __init__(self,mem_size,size=32): #mem_size๋ ๋ฉ๋ชจ๋ฆฌ์ ํฌ๊ธฐ๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
self.mem_size = mem_size
self.image = []
self.label = []
self.num_tasks = 0
self.image_size=size
def add(self,image,label): #๋ฉ๋ชจ๋ฆฌ์ ๋ค์ด๊ฐ image์ label์ input์ผ๋ก ๋ฐ์ต๋๋ค. ์ ์ธํ ๋ ์ ํด์ค ๋ฉ๋ชจ๋ฆฌ ์ฌ์ด์ฆ์ ๋ง์ถ์ด ์๋์ผ๋ก ์ฌ์ด์ฆ๊ฐ ์กฐ์ ๋ฉ๋๋ค.
self.num_tasks +=1
image_new= []
label_new = []
task_size = int(self.mem_size/self.num_tasks)
if self.num_tasks != 1 :
for task_number in range(len(self.label)):
numbers = np.array([i for i in range(len(self.label[task_number]))])
choosed = np.random.choice(numbers,task_size)
image_new.append(self.image[task_number][choosed])
label_new.append(self.label[task_number][choosed])
numbers = np.array([i for i in range(len(label))])
choosed = np.random.choice(numbers,task_size)
image_new.append(image[choosed])
label_new.append(label[choosed])
self.image = image_new
self.label = label_new
def pull(self,size):
#๋ฉ๋ชจ๋ฆฌ์์ size๋งํผ์ image-label ์์ ๊บผ๋
๋๋ค. ์ญ์ ๋๋ค์ผ๋ก ์กฐ์ ํด์ค๋๋ค.
image = torch.stack(self.image).view(-1,3,self.image_size,self.image_size)
label = torch.stack(self.label).view(-1)
numbers = np.array([i for i in range(len(label))])
choosed = np.random.choice(numbers,size)
return image[choosed],label[choosed]
๋ฉ๋ชจ๋ฆฌ์ ๋ค์ด๊ฐ ์ํ๊ณผ, ๊บผ๋ด์ง๋ ์ํ์ ์ ํ๋ ๊ฒ์ ER method์์ ์ค์ํ ๋ถ๋ถ์
๋๋ค. ๊ธฐ๋ณธ์ ์ธ ER method๋ ๋ชจ๋ ๊ฒ์ ๋๋ค์ผ๋ก ์กฐ์ ํ์ง๋ง, MIR, GSS, ASER ๋ฑ์ ์ถ๊ฐ์ ์ธ ๋ฉ์๋๋ ์ด ๋ถ๋ถ์ผ๋ก ์ฃผ์ํ๊ฒ ์กฐ์ ํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ง๋ค์์ผ๋ ๋ค์์ผ๋ก ์งํํ ๊ฒ์ ํธ๋ ์ด๋, ํ
์คํธ, ๊ทธ๋ฆฌ๊ณ Continual Leaerning setting์
๋๋ค. ์งํํ๊ธฐ ํธํ๊ฒ ํธ๋ ์ด๋๊ณผ ํ
์คํธ๋ฅผ ๋ฐ๋ก ํจ์ํ ํ๊ณ , Continual Learning process๋ ER ํจ์์์ ๋ฐ๋ก ์ ์ํด์ค๋๋ค.
Copy from typing_extensions import TypeAlias
def training(model,training_data,memory,opt,epoch,mem=False,mem_iter=1,mem_batch=10):
model.train()
dl = D.DataLoader(training_data,batch_size=10,shuffle=True)
criterion = nn.CrossEntropyLoss()
for ep in range(epoch):
for i, batch_data in enumerate(dl):
batch_x,batch_y = batch_data
batch_x = batch_x.view(-1,3,32,32)
batch_y = batch_y.view(-1)
if mem==True:
for j in range(mem_iter) :
logits = model.forward(batch_x)
loss = criterion(logits,batch_y)
opt.zero_grad()
loss.backward()
mem_x, mem_y = memory.pull(mem_batch)
mem_x = mem_x.view(-1,3,32,32)
mem_y = mem_y.view(-1)
mem_logits = model.forward(mem_x)
mem_loss = criterion(mem_logits,mem_y)
mem_loss.backward()
else :
logits = model.forward(batch_x)
loss = criterion(logits,batch_y)
opt.zero_grad()
loss.backward()
opt.step()
def test(model,tls):
accs = []
model.eval()
for tl in tls:
correct = 0
total = 0
for x,y in tl:
x = x
y = y
total += y.size(0)
output = model(x)
_,predicted = output.max(1)
correct += predicted.eq(y).sum().item()
accs.append(100*correct/total)
return accs
def make_test_loaders(set_x_t,set_y_t):
tls = []
for i in range(len(set_x_t)):
ds = D.TensorDataset(set_x_t[i].view(-1,3,32,32),set_y_t[i].view(-1))
dl = D.DataLoader(ds,batch_size=100,shuffle=True)
tls.append(dl)
return tls
def ER(mem_size):
set_x,set_y,set_x_t,set_y_t = setting_data()
test_loaders = make_test_loaders(set_x_t,set_y_t)
model = reduced_ResNet18()
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
mem_x = []
mem_y = []
accs = []
Mem = Memory(mem_size)
for i in range(0,len(set_x)):
training_data = D.TensorDataset(set_x[i].view(-1,3,32,32),set_y[i].view(-1))
if i !=0:
training(model,training_data,Mem,optimizer,1,mem=True)
else:
training(model,training_data,[],optimizer,1,mem=False)
Mem.add(set_x[i].view(-1,3,32,32),set_y[i].view(-1))
acc = test(model,test_loaders)
accs.append(acc)
print(acc)
print('final accracy : ', np.array(acc).mean())
colab cpu๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ ์ฝ 20๋ถ ์ ๋๊ฐ ์์๋ฉ๋๋ค. Memory size 1000, epoch 1์ ์ํฉ์์ ์ต์ข
์ฑ๋ฅ์ ํ๊ท ์ ์ฝ 34-36์ ๋๋ก ๋์จ๋ค๋ฉด ํ๋ฅญํฉ๋๋ค. ์ ์์ ๋
ผ๋ฌธ์ ๋์จ ํ๊ท ๊ฐ์ ๋๋ต 37 ์ ๋์
๋๋ค. learning rate์ 0.05-0.08 ์ ๋๋ก ๋ฎ์ถ๋ค๋ฉด ์ ์์ ์ฑ๋ฅ์ ๊ทผ์ ํ ๊ฐ์ ์ป์ ์ ์์ต๋๋ค.
Use NCM Classifier
์ฌ๊ธฐ์ Contrastive Learning๊น์ง ๊ตฌํํ๋ ๊ฒ์ CPU๋ง ์ฌ์ฉํ๋ ํน์ฑ์ ์ด๋ ต๊ธฐ ๋๋ฌธ์, NCM Classifier๋ฅผ ๊ตฌํํ๊ณ , ์ฑ๋ฅ ์์น์ ํ์ธํ ์ ์๋๋ก Implementation ํ ๊ฒ์
๋๋ค.
Copy def ncm_test(model,mem_x,mem_y,tls):
labels = np.unique(np.array(mem_y))
classes= labels
exemplar_means = {}
cls_sample = {label : [] for label in labels}
ds = D.TensorDataset(mem_x.view(-1,3,32,32),mem_y.view(-1))
dl = D.DataLoader(ds,batch_size=1,shuffle=False)
accs = []
for image,label in dl:
cls_sample[label.item()].append(image)
for cl, exemplar in cls_sample.items():
features = []
for ex in exemplar:
feature = model.features(ex.view(-1,3,32,32)).detach().clone()
feature.data= feature.data/feature.data.norm()
features.append(feature)
if len(features)==0:
mu_y = torch.normal(0,1,size=tuple(model.features(x.view(-1,3,32,32)).detach().size()))
else :
features = torch.stack(features)
mu_y = features.mean(0)
mu_y.data = mu_y.data/mu_y.data.norm()
exemplar_means[cl] = mu_y
with torch.no_grad():
model = model
for task, test_loader in enumerate(tls):
acc = []
correct = 0
size =0
for batch_x,batch_y in test_loader:
batch_x = batch_x
batch_y = batch_y
feature = model.features(batch_x)
for j in range(feature.size(0)):
feature.data[j] = feature.data[j] / feature.data[j].norm()
feature = feature.view(-1,160,1)
means = torch.stack([exemplar_means[cls] for cls in classes]).view(-1,160)
means = torch.stack([means] * batch_x.size(0))
means = means.transpose(1,2)
feature = feature.expand_as(means)
dists = (feature-means).pow(2).sum(1).squeeze()
_,pred_label = dists.min(1)
correct_cnt = (np.array(classes)[pred_label.tolist()]==batch_y.cpu().numpy()).sum().item()/batch_y.size(0)
correct += correct_cnt * batch_y.size(0)
size += batch_y.size(0)
accs.append(correct/size)
return accs
def NCM_ER(mem_size):
set_x,set_y,set_x_t,set_y_t = setting_data()
test_loaders = make_test_loaders(set_x_t,set_y_t)
model = reduced_ResNet18()
optimizer = torch.optim.SGD(model.parameters(),lr=0.05)
mem_x = []
mem_y = []
accs = []
Mem = Memory(mem_size)
for i in range(0,len(set_x)):
training_data = D.TensorDataset(set_x[i].view(-1,3,32,32),set_y[i].view(-1))
if i !=0:
training(model,training_data,Mem,optimizer,1,mem=True)
else:
training(model,training_data,[],optimizer,1,mem=False)
Mem.add(set_x[i].view(-1,3,32,32),set_y[i].view(-1))
acc = ncm_test(model,Mem,test_loaders)
print(acc)
print('final accracy : ', np.array(acc).mean())
NCM_ER์ ์ด์ฉํ ๊ฒฝ์ฐ, Colab CPU์์ ์ฝ 21๋ถ์ด ์์๋ฉ๋๋ค. ์ฑ๋ฅ์ memory size 1000 ๊ธฐ์ค์ผ๋ก ์ฝ 38-41 ์ ๋๋ก, ์ ์์ reference ๊ฐ๋ณด๋ค ๋ฎ๊ฒ ๋์ค๋๋ผ๋ ๊ด์ฐฎ์ต๋๋ค. hyperparemeter tuning์ ์ ์ํํ๋ค๋ฉด ์ ์์ ์ฑ๋ฅ์ ๊ทผ์ ํ๊ฒ ์ฑ๋ฅ์ ์ฌ๋ฆด ์ ์์ต๋๋ค.
Take Home Message
continual learning์ ์์ง ๊ฐ ๊ธธ์ด ๋จธ๋, contrastive learning์ด๋ transformer์ฒ๋ผ main vision task์์๋ ์ด๋ฏธ ๊ทธ ๋ฅ๋ ฅ์ด ๊ฒ์ฆ๋์์ง๋ง continual leanring์์๋ ์ ์ฐ์ธ ๊ฒ๋ค์ด ๋ง์ต๋๋ค. ์ ์ดํด๋ณธ๋ค๋ฉด ์์ง continual learning์ ๋ฐ์ ๊ฐ๋ฅ์ฑ์ด ์ถฉ๋ถํฉ๋๋ค.
Author
๊ถ๋ฏผ์ฐฌ (MINCHAN KWON)
https://kmc0207.github.io/CV/
Reviewer
Korean name (English name): Affiliation / Contact information
Korean name (English name): Affiliation / Contact information
Reference & Additional materials
Official (unofficial) GitHub repository