Supervised Contrastive Replay [Kor]

Mai, Zheda / Supervised Contrastive Replay: Revisiting the Nearest Class Mean Classifier in Online Class-Incremental Continual Learning / CVPR 2021

Supervised Contrastive Replay: Revisiting the Nearest Class Mean Classifier in Online Class-Incremental Continual Learning [Kor]

1. Introduction

Continual Learning (CL)

CL์ด๋ž€, ์—ฐ์†์ ์œผ๋กœ ์ฃผ์–ด์ง€๋Š” Data Stream์„ Input์œผ๋กœ ๋ฐ›์•„, ์—ฐ์†์ ์œผ๋กœ ํ•™์Šตํ•˜๋Š” ๋ชจ๋ธ์„ ๋งŒ๋“ค์–ด๋‚ด๋Š” ๊ฒƒ์„ ๋ชฉํ‘œ๋กœ ํ•˜๋Š” ๋ฌธ์ œ ์„ธํŒ…์ž…๋‹ˆ๋‹ค. ํ˜„์žฌ ๋”ฅ ๋Ÿฌ๋‹ ๊ธฐ๋ฐ˜์˜ ๋ชจ๋ธ๋“ค์€, ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ์…‹์„ ํ•™์Šตํ•  ๊ฒฝ์šฐ ์ด์ „ ๋ฐ์ดํ„ฐ์…‹์—์„œ์˜ ์„ฑ๋Šฅ์€ ๋งค์šฐ ๋–จ์–ด์ง‘๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ํ˜„์ƒ์„ Catastrophic Forgetting(CF)๋ผ๊ณ  ๋ถ€๋ฆ…๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด ์„ค๋ช…ํ•˜์ž๋ฉด, Cifar10์„ ํ•™์Šตํ•œ ๋ชจ๋ธ์ด MNIST๋ฅผ ํ•™์Šตํ•  ๊ฒฝ์šฐ, MNIST์—์„œ์˜ ์„ฑ๋Šฅ์€ ๋†’์ง€๋งŒ, Cifar10์˜ ์„ฑ๋Šฅ์€ ๋‚ฎ์•„์ง‘๋‹ˆ๋‹ค.(๋‹จ์ˆœํžˆ MNIST๋ฅผ ํŠธ๋ ˆ์ด๋‹ ํ•œ ๊ฒฝ์šฐ, ๊ฑฐ์˜ 0%์— ๊ฐ€๊นŒ์šด ์„ฑ๋Šฅ์„ ๋ณด์ž…๋‹ˆ๋‹ค.) ์ด์ €์— Cifar10์—์„œ์˜ ์„ฑ๋Šฅ์ด ์–ด๋•Ÿ๋˜ ๊ฐ„์—, ๊ทน์ ์ธ ์„ฑ๋Šฅ ํ•˜๋ฝ์ด ๋‚˜ํƒ€๋‚˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ์ด๋•Œ Cifar10๊ณผ MNIST ๊ฐ™์ด ์—ฐ์†์ ์œผ๋กœ ๋“ค์–ด์˜ค๋Š” Dataset๋“ค์„ Task๋ผ๊ณ  ๋ถ€๋ฆ…๋‹ˆ๋‹ค.

CF is a problem that must be solved as deep learning gets deployed everywhere. Once a model is trained and served in a real product, data keeps accumulating. But if the model is additionally trained on this new data, its performance can actually degrade. To avoid this, we have to retrain using all the data the model was originally trained on plus the new data, which is enormously wasteful computationally. This is why the movie-style AI that finds data on its own and keeps getting smarter has not appeared yet.

CL is the problem setting that tries to solve this CF. The author of this paper, Zheda Mai, has recently published many strong papers in the CL field, repeatedly proposing methods close to SOTA. Among Mai's papers, this one is particularly attractive because, although it relies on a trick, it shows performance that was previously unimaginable for CL.

Experience Replay (ER)

Experience Replay is arguably the dominant method in the CL setting. It is heavily studied because it performs well despite its simplicity and leaves plenty of modular room for improvement. The method itself is simple: sample some data from previous tasks and store it in an external memory; when a new task arrives, train on it together with the data in the external memory.

Naturally, the larger the external memory, the better the performance drop on previous tasks can be prevented. The ultimate goal of ER is to reduce CF as much as possible while using as little external memory as possible.

ER์˜ ํ˜„์žฌ ์ตœ์‹  ์„ธํŒ…์„ ๊ฐ„๋žตํ•˜๊ฒŒ ์ •๋ฆฌํ•˜์ž๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์ ์ด ์ค‘์š”ํ•˜๋‹ค๊ณ  ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • ํ˜„์žฌ ํƒœ์Šคํฌ์˜ batch 1๊ฐœ + External Memory์—์„œ์˜ batch 1๊ฐœ๋ฅผ ํ•จ๊ป˜ ํŠธ๋ ˆ์ด๋‹ ํ•œ๋‹ค.

  • External Memory์˜ ๊ฒฝ์šฐ ํฌ๊ธฐ๊ฐ€ ๋ณดํ†ต ์ž‘๊ธฐ ๋•Œ๋ฌธ์— ๋‘˜์„ ๊ทธ๋Œ€๋กœ ํ•จ๊ป˜ ํŠธ๋ ˆ์ด๋‹ ํ•ด๋ฒ„๋ฆฌ๋ฉด ๋‘˜์˜ Class Imbalance๊ฐ€ ์ผ์–ด๋‚˜์„œ ์„ฑ๋Šฅ์ด ๋–จ์–ด์ง€๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ๋‘˜์˜ ๋น„์œจ์„ ๋งž์ถฐ์„œ ํŠธ๋ ˆ์ด๋‹ ํ•ด ์ฃผ๋Š” ๊ฒƒ์ด ER์˜ ์„ฑ๋Šฅ์„ ๋†’์ด๋Š” ํŒ์ž…๋‹ˆ๋‹ค.

2. Method

Problems of the Softmax Classifier in CL

์ด ๋…ผ๋ฌธ์˜ ํ•ต์‹ฌ Contribution์ด์ž ์ €์ž๊ฐ€ ์ฃผ์žฅํ•˜๋Š” ๊ฒƒ์€ Softmax Classifier์˜ ๋ฌธ์ œ์ ์ž…๋‹ˆ๋‹ค. Softmax Classifier๋Š” ๋งŽ์€ ๋ถ€๋ถ„์—์„œ ์ตœ๊ณ ์˜ ์„ฑ๋Šฅ์„ ๋‚ด๊ณ  ์žˆ์ง€๋งŒ, CL์—์„œ ๋งŒํผ์€ ์ข‹์ง€ ์•Š๋‹ค๋Š” ๊ฒƒ์ด ์ €์ž์˜ ์ƒ๊ฐ์ž…๋‹ˆ๋‹ค. ๊ทธ ์ด์œ ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • ์ƒˆ๋กœ์šด ํด๋ž˜์Šค๊ฐ€ ๋“ค์–ด์˜ค๋Š” ๊ฒƒ์— ์œ ์—ฐํ•˜์ง€ ์•Š๋‹ค

    • Softmax์˜ ํŠน์„ฑ์ƒ ์ฒ˜์Œ๋ถ€ํ„ฐ ํด๋ž˜์Šค์˜ ๊ฐฏ์ˆ˜๋ฅผ ์ •ํ•ด์ค˜์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด ๋•Œ๋ฌธ์— ํƒœ์Šคํฌ๊ฐ€ ์–ผ๋งˆ๋‚˜ ๋“ค์–ด์˜ฌ์ง€ ๋ชจ๋ฅด๋Š” CL ์„ธํŒ…์˜ ํŠน์„ฑ์— ๋งž์ง€ ์•Š์Šต๋‹ˆ๋‹ค. (ํ•˜์ง€๋งŒ ํ˜„์žฌ CL ์—ฐ๊ตฌ๋Š” ๋Œ€๋ถ€๋ถ„ ํƒœ์Šคํฌ๊ฐ€ ์–ผ๋งˆ๋‚˜ ๋“ค์–ด์˜ฌ์ง€ ์•Œ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ ํ›„์˜ ์‹คํ—˜์„ ๋ณด์‹œ๋ฉด ๋” ์ž˜ ์ดํ•ด๋ฉ๋‹ˆ๋‹ค.)

  • Representation and classification are decoupled

    • Whenever the encoder changes, the softmax layer has to be retrained.

  • Task-recency bias

    • ์ด์ „์˜ ์—ฌ๋Ÿฌ ์—ฐ๊ตฌ์—์„œ, Softmax classifier๊ฐ€ ์ตœ๊ทผ ํƒœ์Šคํฌ์— ์น˜์ค‘๋˜๋Š” ๊ฒฝํ–ฅ์ด ์žˆ๋‹ค๋Š” ๊ฒƒ์ด ๊ด€์ฐฐ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ด๋Š” ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํฌ๊ฐ€ ํ˜„์žฌ ํƒœ์Šคํฌ์— ์น˜์ค‘๋˜์–ด์žˆ๋Š” CL์˜ ํŠน์„ฑ์ƒ ์„ฑ๋Šฅ์— ์น˜๋ช…์ ์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Nearest Class Mean(NCM) Classifier

์ €์ž๋Š” ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด์„œ, Few-shot learning์—์„œ ์ฃผ๋กœ ์‚ฌ์šฉ๋˜๋Š” NCM Classifier๋ฅผ ์‚ฌ์šฉํ•˜์ž๊ณ  ์ฃผ์žฅํ•ฉ๋‹ˆ๋‹ค. NCM Classifier์˜ ๊ฒฝ์šฐ Prototype Classifier๋ผ๊ณ ๋„ ๋ถˆ๋ฆฝ๋‹ˆ๋‹ค. ์ด Classifier๋Š” ํŠธ๋ ˆ์ด๋‹์ด ๋๋‚œ ํ›„, ํŠธ๋ ˆ์ด๋‹์— ์‚ฌ์šฉ๋˜์—ˆ๋˜ ๋ชจ๋“  ํด๋ž˜์Šค ๋ฐ์ดํ„ฐ์˜ ํ‰๊ท ์„ ๋‚ด์–ด ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ์ €์žฅ๋œ ํ‰๊ท ๊ฐ’์€ Prototype์ฒ˜๋Ÿผ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค. Test์‹œ, ๊ฐ€์žฅ ๊ฐ€๊นŒ์šด Prototype์„ ๊ฐ€์ง€๋Š” ํด๋ž˜์Šค๋กœ ํด๋ž˜์Šค๋ฅผ ์ถ”์ธกํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

The NCM classifier resolves the problems of softmax and meshes very well with CL, which, like few-shot learning, suffers from data scarcity. In fact, merely switching to an NCM classifier substantially improves the performance of most CL methods.

$$u_c = \frac{1}{n_c}\sum_i f(x_i) \cdot 1\{y_i = c\}$$

$$y^* = \arg\min_{c=1,...,t} ||f(x) - u_c||$$

The equations used for the NCM classifier are shown above. Here $c$ denotes a class, and $1\{y_i = c\}$ equals 1 when $y_i$ is $c$. We compute, per class, the mean of the data stored in memory, and run inference by predicting the class whose mean is closest.
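A minimal NumPy sketch of these two equations, written for this review rather than taken from the paper's code (`ncm_fit`, `ncm_predict`, and the toy data are invented for illustration):

```python
import numpy as np

def ncm_fit(features, labels):
    """u_c: one normalized mean feature (prototype) per class."""
    means = {}
    for c in np.unique(labels):
        mu = features[labels == c].mean(axis=0)
        means[c] = mu / np.linalg.norm(mu)
    return means

def ncm_predict(features, means):
    """y* = argmin_c ||f(x) - u_c||: pick the nearest prototype."""
    classes = sorted(means)
    M = np.stack([means[c] for c in classes])                # (C, d)
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    dists = ((f[:, None, :] - M[None, :, :]) ** 2).sum(-1)   # (n, C)
    return np.array(classes)[dists.argmin(axis=1)]
```

Note that adding a new class only requires storing one more mean, which is exactly the flexibility the softmax classifier lacks.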

Supervised Contrastive Replay

NCM Classifier์˜ ํฌํ…์…œ์„ ๋” ๋†’์ผ ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์ด SCR์ž…๋‹ˆ๋‹ค. NCM Classifier๋Š” Representation ๊ฐ„ ๊ฑฐ๋ฆฌ๋ฅผ ์ค‘์‹ฌ์œผ๋กœ inference๋ฅผ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฐ ์ƒํ™ฉ์—์„œ ๋‹ค๋ฅธ ํด๋ž˜์Šค๋Š” ๋” ๋ฉ€๋ฆฌ, ๊ฐ™์€ ํด๋ž˜์Šค๋Š” ๋” ๊ฐ€๊นŒ์ด ๋ถ™์—ฌ๋‘๋Š” Contrastive Learning์€ NCM์— ํฐ ๋„์›€์ด ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ €์ž๋Š” ํŠธ๋ ˆ์ด๋‹ ๋ฐ์ดํ„ฐ์— ๋‹จ์ˆœํ•œ Augmented View๋ฅผ ์ถ”๊ฐ€ํ•˜๊ณ , ์ด ๋ฐ์ดํ„ฐ๋“ค์„ ์ด์šฉํ•˜์—ฌ Contrastive Learning์„ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค. ๋ฉ”๋ชจ๋ฆฌ ๋ฐ์ดํ„ฐ์™€ ํ˜„์žฌ ๋ฐ์ดํ„ฐ๋ฅผ ํ•จ๊ป˜ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

$$L_{SCL}(Z_I) = \sum_{i\in I} \frac{-1}{|P(i)|} \sum_{p\in P(i)} \log \frac{\exp(z_i\cdot z_p / \tau)}{\sum_{j \in A(i)}\exp(z_i \cdot z_j / \tau)}$$

Loss ์‹์€ ์œ„ ์‹๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค. $B = {x_k,y_k}{k=1,...,b}$์˜ Mini Batch๋ผ๊ณ  ํ•  ๋•Œ, $\tilde{B}$ $= { \tilde{x_k} = Aug(x_k), y_k }{k=1,...,b}$ ์ž…๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  $B_I = B \cap \tilde{B}$ ์ž…๋‹ˆ๋‹ค. $I$๋Š” $B_I$์˜ ์ง€์ˆ˜๋“ค์˜ ์ง‘ํ•ฉ์ด๊ณ , $A(i)=I \setminus {i}$ ์ž…๋‹ˆ๋‹ค. $P(i) = {p \in A(i) : y_p = y_i}$ ์ž…๋‹ˆ๋‹ค. ๋ณต์žกํ•ด ๋ณด์ด์ง€๋งŒ ์ฐฌ์ฐฌํžˆ ๋œฏ์–ด๋ณด๋ฉด ์–ด๋ ต์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๊ฒฐ๊ตญ $P(i)$๋Š” ์ƒ˜ํ”Œ i๋ฅผ ์ œ์™ธํ•œ ๊ฒƒ ์ค‘์—์„œ label์ด ๊ฐ™์€ ๊ฒƒ, ๊ทธ๋Ÿฌ๋‹ˆ๊นŒ Positive sample์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค. $Z_I = {z_i}_{i \in I} = Model(x_i)$ ์ด๊ณ , $\tau$๋Š” ์กฐ์ •์„ ์œ„ํ•œ temperature parameter ์ž…๋‹ˆ๋‹ค.

Implementation์—์„œ๋Š” Continual Learning์˜ ๋ฒค์น˜๋งˆํฌ๋ผ๊ณ  ํ•  ์ˆ˜๋„ ์žˆ๋Š” Split Cifar-10์—์„œ ์‹คํ—˜์„ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค. ์ผ๋ฐ˜์ ์ธ BaseLine์œผ๋กœ ๋งŽ์ด ์‚ฌ์šฉ๋˜๋Š” Experience Replay์— ๋Œ€ํ•œ ๊ตฌํ˜„๊ณผ, ์ด ๋…ผ๋ฌธ์—์„œ ์ œ์•ˆํ•œ NCN Classifier๋ฅผ ์‚ฌ์šฉํ•œ Experience Replay์— ๋Œ€ํ•œ ๊ตฌํ˜„์„ ์ค€๋น„ํ–ˆ์Šต๋‹ˆ๋‹ค.

Environment

We recommend running the experiments in a Colab environment.

import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as D
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

Setting of Continual Learning

์ด ์ฑ•ํ„ฐ์—์„œ๋Š” Continual Learning evaluation์„ ์œ„ํ•œ ๊ธฐ๋ณธ์ ์ธ ์„ธํŒ…์„ ์ค€๋น„ํ•ฉ๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ์…‹์€ Cifar-10์„ 5๊ฐœ์˜ ํƒœ์Šคํฌ๋กœ ๋‚˜๋ˆˆ Split Cifar-10์„ ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค. ๋…ผ๋ฌธ์—์„œ๋Š” Reduced_ResNet18์„ ๋ฒ ์ด์Šค ๋ชจ๋ธ๋กœ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ด Implementation์—์„œ๋Š” ๊ตฌํ˜„์˜ ๊ฐ„๋‹จํ•จ์„ ์œ„ํ•ด ์ž‘์€ CNN๋ชจ๋ธ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ด ์ฝ”๋“œ์—์„œ๋Š” Split Cifar-10์„ ๋งŒ๋“ค๊ณ , Reduced_ResNet18์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.

# Build Split-CIFAR10

def setting_data():
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])  # training transform

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    # Download the original CIFAR-10 dataset
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=1,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=1,
                                              shuffle=False)

    # The code below splits CIFAR-10 into 5 tasks (by label modulo 5).
    set_x = [[] for i in range(5)]
    set_y = [[] for i in range(5)]
    set_x_ = [[] for i in range(5)]
    set_y_ = [[] for i in range(5)]
    for batch_images, batch_labels in train_loader:
        y = batch_labels - 5 if batch_labels >= 5 else batch_labels
        set_x_[y].append(batch_images)
        set_y_[y].append(batch_labels)
    for i in range(5):
        set_x[i] = torch.stack(set_x_[i])
        set_y[i] = torch.stack(set_y_[i])

    set_x_t = [[] for i in range(5)]
    set_y_t = [[] for i in range(5)]
    set_x_t_ = [[] for i in range(5)]
    set_y_t_ = [[] for i in range(5)]
    for batch_images, batch_labels in test_loader:
        y = batch_labels - 5 if batch_labels >= 5 else batch_labels
        set_x_t_[y].append(batch_images)
        set_y_t_[y].append(batch_labels)
    for i in range(5):
        set_x_t[i] = torch.stack(set_x_t_[i])
        set_y_t[i] = torch.stack(set_y_t_[i])
    return set_x, set_y, set_x_t, set_y_t

The code below defines Reduced ResNet-18, the base CNN model we will use. One notable point is the `features` method, which skips the final FC layer; it is used later when implementing the NCM classifier.

class PreActBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
    def all_forward(self, x):
        out1 = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out1))
        out += self.shortcut(x)
        out = F.relu(out)
        return out1,out
    

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class Reduced_ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10,dropout=0.3):
        super(Reduced_ResNet, self).__init__()
        self.in_planes = 20

        self.conv1 = nn.Conv2d(3, 20, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(20)
        self.layer1 = self._make_layer(block, 20, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 40, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 80, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 160, num_blocks[3], stride=2)
        self.d1 = nn.Dropout(p=dropout)
        self.linear = nn.Linear(160*block.expansion, num_classes)
        self.linear3 = nn.Linear(400, num_classes)
        self.linear2 = nn.Linear(640,400)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
    def features(self, x):
        '''Features before FC layers'''
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return out

def reduced_ResNet18(num_classes=10):
    return Reduced_ResNet(BasicBlock,[2,2,2,2],num_classes=num_classes)

Experience Replay

The code below implements Experience Replay, one of the most widely used baselines in continual learning. It is fun to vary options such as memory size, training epochs, and learning rate and see how the performance changes.

First, the code below implements the external memory. The memory could be implemented in any way, but we wrap it in a class so that randomly choosing which data enters and leaves the memory is easy to implement.

class Memory():
    def __init__(self,mem_size,size=32): # mem_size sets the capacity of the memory.
        self.mem_size = mem_size
        self.image = []
        self.label = []
        self.num_tasks = 0
        self.image_size=size

    def add(self,image,label): # Takes the images and labels to store. The per-task share is resized automatically to fit the capacity chosen at construction.
        self.num_tasks +=1
        image_new= []
        label_new = []
        task_size = int(self.mem_size/self.num_tasks)
        if self.num_tasks != 1 :
            for task_number in range(len(self.label)):
                numbers = np.array([i for i in range(len(self.label[task_number]))])
                choosed = np.random.choice(numbers,task_size)
                image_new.append(self.image[task_number][choosed])
                label_new.append(self.label[task_number][choosed])
        numbers = np.array([i for i in range(len(label))])
        choosed = np.random.choice(numbers,task_size)
        image_new.append(image[choosed])
        label_new.append(label[choosed])
        self.image = image_new
        self.label = label_new

    def pull(self,size):
        # Draws `size` image-label pairs from memory, again chosen at random.
        image = torch.stack(self.image).view(-1,3,self.image_size,self.image_size)
        label = torch.stack(self.label).view(-1)
        numbers = np.array([i for i in range(len(label))])
        choosed = np.random.choice(numbers,size)
        return image[choosed],label[choosed]

๋ฉ”๋ชจ๋ฆฌ์— ๋“ค์–ด๊ฐˆ ์ƒ˜ํ”Œ๊ณผ, ๊บผ๋‚ด์ง€๋Š” ์ƒ˜ํ”Œ์„ ์ •ํ•˜๋Š” ๊ฒƒ์€ ER method์—์„œ ์ค‘์š”ํ•œ ๋ถ€๋ถ„์ž…๋‹ˆ๋‹ค. ๊ธฐ๋ณธ์ ์ธ ER method๋Š” ๋ชจ๋“  ๊ฒƒ์„ ๋žœ๋ค์œผ๋กœ ์กฐ์ •ํ•˜์ง€๋งŒ, MIR, GSS, ASER ๋“ฑ์˜ ์ถ”๊ฐ€์ ์ธ ๋ฉ”์†Œ๋“œ๋Š” ์ด ๋ถ€๋ถ„์œผ๋กœ ์ฃผ์š”ํ•˜๊ฒŒ ์กฐ์ •ํ•ฉ๋‹ˆ๋‹ค.

With the memory in place, the next things to build are training, testing, and the continual learning setting. For convenience, training and testing are factored out into separate functions, and the continual learning process itself is defined in the ER function.

def training(model,training_data,memory,opt,epoch,mem=False,mem_iter=1,mem_batch=10):
    model.train()
    dl = D.DataLoader(training_data,batch_size=10,shuffle=True)
    criterion = nn.CrossEntropyLoss()
    for ep in range(epoch):
        for i, batch_data in enumerate(dl):
            batch_x,batch_y = batch_data
            batch_x = batch_x.view(-1,3,32,32)
            batch_y = batch_y.view(-1)
            if mem==True:
                for j in range(mem_iter):
                    logits = model.forward(batch_x)
                    loss = criterion(logits,batch_y)
                    opt.zero_grad()
                    loss.backward()
                    # Draw a batch from memory and accumulate its gradient as well
                    mem_x, mem_y = memory.pull(mem_batch)
                    mem_x = mem_x.view(-1,3,32,32)
                    mem_y = mem_y.view(-1)
                    mem_logits = model.forward(mem_x)
                    mem_loss = criterion(mem_logits,mem_y)
                    mem_loss.backward()
            else:
                logits = model.forward(batch_x)
                loss = criterion(logits,batch_y)
                opt.zero_grad()
                loss.backward()
            opt.step()

def test(model,tls):
    accs = []
    model.eval()
    with torch.no_grad():
        for tl in tls:
            correct = 0
            total = 0
            for x,y in tl:
                total += y.size(0)
                output = model(x)
                _,predicted = output.max(1)
                correct += predicted.eq(y).sum().item()
            accs.append(100*correct/total)
    return accs

def make_test_loaders(set_x_t,set_y_t):
  tls = []
  for i in range(len(set_x_t)):
    ds = D.TensorDataset(set_x_t[i].view(-1,3,32,32),set_y_t[i].view(-1))
    dl = D.DataLoader(ds,batch_size=100,shuffle=True)
    tls.append(dl)
  return tls

def ER(mem_size):
    set_x,set_y,set_x_t,set_y_t = setting_data()
    test_loaders = make_test_loaders(set_x_t,set_y_t)
    model = reduced_ResNet18()
    optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
    accs = []
    Mem = Memory(mem_size)
    for i in range(0,len(set_x)):
        training_data = D.TensorDataset(set_x[i].view(-1,3,32,32),set_y[i].view(-1))
        if i != 0:
            training(model,training_data,Mem,optimizer,1,mem=True)
        else:
            training(model,training_data,[],optimizer,1,mem=False)
        Mem.add(set_x[i].view(-1,3,32,32),set_y[i].view(-1))
        acc = test(model,test_loaders)
        accs.append(acc)
        print(acc)

    print('final accuracy : ', np.array(acc).mean())

On a Colab CPU this takes about 20 minutes. With memory size 1000 and 1 epoch, a final average accuracy around 34-36 is excellent; the average reported in the author's paper is roughly 37. Lowering the learning rate to around 0.05-0.08 should get you close to the author's number.

Use NCM Classifier

Implementing the full contrastive learning part here is difficult given that only a CPU is used, so this Implementation adds the NCM classifier and lets you confirm the performance gain it brings.

def ncm_test(model,mem_x,mem_y,tls):
    labels = np.unique(np.array(mem_y))
    classes = labels
    exemplar_means = {}
    cls_sample = {label : [] for label in labels}
    ds = D.TensorDataset(mem_x.view(-1,3,32,32),mem_y.view(-1))
    dl = D.DataLoader(ds,batch_size=1,shuffle=False)
    accs = []
    for image,label in dl:
        cls_sample[label.item()].append(image)
    # Build one normalized prototype (mean feature) per class
    for cl, exemplar in cls_sample.items():
        features = []
        for ex in exemplar:
            feature = model.features(ex.view(-1,3,32,32)).detach().clone()
            feature.data = feature.data/feature.data.norm()
            features.append(feature)
        if len(features)==0:
            # No exemplar stored for this class: fall back to a random prototype
            mu_y = torch.normal(0,1,size=(1,160))
        else:
            features = torch.stack(features)
            mu_y = features.mean(0)
        mu_y.data = mu_y.data/mu_y.data.norm()
        exemplar_means[cl] = mu_y
    with torch.no_grad():
        for task, test_loader in enumerate(tls):
            correct = 0
            size = 0
            for batch_x,batch_y in test_loader:
                feature = model.features(batch_x)
                for j in range(feature.size(0)):
                    feature.data[j] = feature.data[j] / feature.data[j].norm()
                feature = feature.view(-1,160,1)
                means = torch.stack([exemplar_means[cls] for cls in classes]).view(-1,160)
                means = torch.stack([means] * batch_x.size(0))
                means = means.transpose(1,2)
                feature = feature.expand_as(means)
                # Predict the class whose prototype is nearest in feature space
                dists = (feature-means).pow(2).sum(1).squeeze()
                _,pred_label = dists.min(1)
                correct_cnt = (np.array(classes)[pred_label.tolist()]==batch_y.cpu().numpy()).sum().item()/batch_y.size(0)
                correct += correct_cnt * batch_y.size(0)
                size += batch_y.size(0)
            accs.append(correct/size)
        return accs

def NCM_ER(mem_size):
    set_x,set_y,set_x_t,set_y_t = setting_data()
    test_loaders = make_test_loaders(set_x_t,set_y_t)
    model = reduced_ResNet18()
    optimizer = torch.optim.SGD(model.parameters(),lr=0.05)
    accs = []
    Mem = Memory(mem_size)
    for i in range(0,len(set_x)):
        training_data = D.TensorDataset(set_x[i].view(-1,3,32,32),set_y[i].view(-1))
        if i != 0:
            training(model,training_data,Mem,optimizer,1,mem=True)
        else:
            training(model,training_data,[],optimizer,1,mem=False)
        Mem.add(set_x[i].view(-1,3,32,32),set_y[i].view(-1))
        # ncm_test expects the raw memory tensors, so flatten the per-task lists
        mem_x = torch.cat(Mem.image).view(-1,3,32,32)
        mem_y = torch.cat(Mem.label).view(-1)
        acc = ncm_test(model,mem_x,mem_y,test_loaders)
        accs.append(acc)
        print(acc)

    print('final accuracy : ', np.array(acc).mean())

NCM_ER์„ ์ด์šฉํ•  ๊ฒฝ์šฐ, Colab CPU์—์„œ ์•ฝ 21๋ถ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค. ์„ฑ๋Šฅ์€ memory size 1000 ๊ธฐ์ค€์œผ๋กœ ์•ฝ 38-41 ์ •๋„๋กœ, ์ €์ž์˜ reference ๊ฐ’๋ณด๋‹ค ๋‚ฎ๊ฒŒ ๋‚˜์˜ค๋”๋ผ๋„ ๊ดœ์ฐฎ์Šต๋‹ˆ๋‹ค. hyperparemeter tuning์„ ์ž˜ ์ˆ˜ํ–‰ํ•œ๋‹ค๋ฉด ์ €์ž์˜ ์„ฑ๋Šฅ์— ๊ทผ์ ‘ํ•˜๊ฒŒ ์„ฑ๋Šฅ์„ ์˜ฌ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Take Home Message

continual learning์€ ์•„์ง ๊ฐˆ ๊ธธ์ด ๋จธ๋‚˜, contrastive learning์ด๋‚˜ transformer์ฒ˜๋Ÿผ main vision task์—์„œ๋Š” ์ด๋ฏธ ๊ทธ ๋Šฅ๋ ฅ์ด ๊ฒ€์ฆ๋˜์—ˆ์ง€๋งŒ continual leanring์—์„œ๋Š” ์•ˆ ์“ฐ์ธ ๊ฒƒ๋“ค์ด ๋งŽ์Šต๋‹ˆ๋‹ค. ์ž˜ ์‚ดํŽด๋ณธ๋‹ค๋ฉด ์•„์ง continual learning์€ ๋ฐœ์ „ ๊ฐ€๋Šฅ์„ฑ์ด ์ถฉ๋ถ„ํ•ฉ๋‹ˆ๋‹ค.

Author

๊ถŒ๋ฏผ์ฐฌ (MINCHAN KWON)

  • KAIST AI

  • https://kmc0207.github.io/CV/

  • kmc0207@kaist.ac.kr

Reviewer

  1. Korean name (English name): Affiliation / Contact information

  2. Korean name (English name): Affiliation / Contact information

  3. ...

Reference & Additional materials

  1. Citation of this paper

  2. Official (unofficial) GitHub repository

  3. Citation of related work

  4. Other useful materials

  5. ...
