[๋ ผ๋ฌธ๋ฆฌ๋ทฐ]
ICLR 2023 notable top 5%
ABSTRACT
- ํ๋ จํ ํ์ ์์ด ๊ธฐ์กด ViT ๋ชจ๋ธ์ ์ฒ๋ฆฌ๋์ ๋๋ฆด์ ์๋ ๊ฐ๋จํ ๋ฐฉ๋ฒ
- ์ผ๋ฐ์ ์ด๊ณ ๊ฐ๋ฒผ์ด ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ์ฌ ์ ์ฌํ ํ ํฐ๋ค์ ์ ์ง์ ์ผ๋ก ํฉ์น๋ค.
- ToMe(Token Merging)์ training๋์ ์ฝ๊ฒ ์ ์ฉ์ด ๊ฐ๋ฅํ๋ค.
Introduction
- Transformer์ ํ ํฐ์ Run-time์ Pruningํ์ฌ ๋ ๋น ๋ฅธ ๋ชจ๋ธ์ ๊ฐ๋ฅํ๊ฒ ํ๋ ๋ถ์ผ๊ฐ ๋ฑ์ฅ
- Token Pruning์ ๋๋ถ๋ถ์ Training ์๋๋ฅผ ๋์ด๊ธฐ ์ํด ์ ์ฉํ ์ ์๋ค.
- ๋ฐ๋ผ์ Token์ Pruningํ๋ ๊ฒ ๋ณด๋ค ๋ ๋์ ๋ฐฉ์์ธ combine์ ์ ์ฉํ์ฌ Token Merging ๋ฐฉ๋ฒ์ ์ ์ํ๋ค.
- ๋ณธ ๋ ผ๋ฌธ์์์ custome matching algorithm์ pruning๋งํผ ๋น ๋ฅด๋ฉด์ ๋ ์ ํํ๋ค. ๊ฒ๋ค๊ฐ ๋ณธ ๋ ผ๋ฌธ์ ๋ฐฉ๋ฒ์ training ์ค์ด๊ฑฐ๋ training ์ค์ด ์๋๋๋ผ๋ ๊ฐ๋ฅํ๋ค. (ํ๋ จ ์ฌ๋ถ์ ์๊ด์์)
RELATED WORK
Efficient Transformers
- ๋ณธ ๋ ผ๋ฌธ์์๋ ๊ธฐ์กด์ ViT model์ token๋ค์ mergingํจ์ผ๋ก์จ ์๋๋ฅผ ๋์ด๋ ๊ฒ์ ์ด์ ์ ๋์๋ค.
Token Reduction
- Transformer๋ token์ ์์ ์๊ด์์ด ์๋ํ ์ ์๊ธฐ ๋๋ฌธ์, ์ต๊ทผ ๋ง์ ์ฐ๊ตฌ์์ ์ด๋ฌํ token๋ค์ pruningํ๋ ์๋๊ฐ ๋ง์ด ์์๋ค.
- ํ์ง๋ง, ์ด๋ฌํ ๋ฐฉ๋ฒ๋ค์ training์ ํ์๋ก ํ์ง๋ง, ์ฐ๋ฆฌ์ ๋ฐฉ๋ฒ์ training์ด ์์ด ์ฌ์ฉ์ด ๊ฐ๋ฅํ๋ค.
- ๋ํ, ๋๋ถ๋ถ์ pruning work๋ค์ ๋์ (dynamic)ํ๋ค. ์ฆ, ํ ํฐ์ ์๋ image ๋๋ ๋ฌธ์ฅ๋ง๋ค ๋ค๋ฅด๋ค.
- ์ด๋ฌํ ๋ฐฉ๋ฒ์ accuracy์๋ ๋์์ด ๋์ง๋ง, ํ ํฐ ์๊ฐ ๋ค๋ฅธ ์ํ์ ๋์ด์ ๋ฐฐ์นํ ์ ์๊ธฐ ๋๋ฌธ์ ์ค์ฉ์ฑ์ด ์๋ค.
- ์ด๋ฌํ ๋ฌธ์ ๋ค์ ํด๊ฒฐํ๊ธฐ ์ํด ๋๋ถ๋ถ์ pruning ๋ ผ๋ฌธ์ ํ ํฐ์ ์ ๊ฑฐํ๋ ๋์ ํ๋ จ ์ค์ mask๋ฅผ ์ ์ฉํ์ฌ pruning์๋๋ฅผ ์ ํ์ํจ๋ค. ๋ฐ๋ฉด์, ๋ณธ ๋ ผ๋ฌธ์ ๋ฐฉ๋ฒ์ inference์ training ๋ชจ๋์ ์ ์ฉ๋์ด ๋ ๊ฒฝ์ฐ ๋ชจ๋ ์ค์ ์๋ ํฅ์์ ๋ฌ์ฑ์ด ๊ฐ๋ฅํ๋ค.
Combining Tokens
- Token์ ์ ๊ฑฐํ๋ ์์ ์ ๋ง์ง๋ง, combineํ๋ ์์ ์ ๊ฑฐ์ ์๋ค.
- Token Polling(2021)์ด ๊ฐ์ฅ ์ฐ๋ฆฌ์ ๋ฐฉ๋ฒ๊ณผ ์ ์ฌํ๋ค. ๊ทธ๋ฌ๋, ๋๋ฆฐ k-means ์ ๊ทผ๋ฒ์ ์ฌ์ฉํ์๋ค.
- ์ง๊ธ๊น์ง Training์์ด ํ ํฐ์ ๊ฒฐํฉํ ๋ speed-accuracy trade-off๋ฅผ ์ ๊ณตํ๋๋ฐ ์ฑ๊ณตํ ์ ๊ทผ ๋ฐฉ์์ด ์์๋ค.
TOKEN MERGING
- ๋ ผ๋ฌธ์ ๋ชฉํ๋ ๊ธฐ์กด์ ViT์ token merging module์ ์ฝ์ ํ๋ ๊ฒ์ด๋ค.
- ์ค๋ณต๋๋ tokens๋ค์ ํฉ์นจ์ผ๋ก์จ ๋ฐ๋์ trainํ ํ์๋ ์์ง๋ง ์ฐ์ฐ๋์ ์ฆ๊ฐ์ํฌ ์ ์๊ธฐ๋ฅผ ํฌ๋งํ๋ค.
Strategy
- $ r $ ์ token๋ค ์์ด๋ค.
- Transformer์ ๊ฐ ๋ธ๋ก์์ token์ ๋ณํฉํ์ฌ layer๋น $ r $๊ฐ๋ฅผ ์ค์ธ๋ค.
- ๋คํธ์ํฌ์ $ L $ ๋ธ๋ก์ ํตํด, ์ ์ง์ ์ผ๋ก $ rL $ token์ ํฉ์น๋ค. ํ ํฐ ์๊ฐ ์ ์์๋ก ์ ํ๋๋ ๋ฎ์ง๋ง ์ฒ๋ฆฌ๋์ ๋์์ง๊ธฐ ๋๋ฌธ์ $ r $์ ๋ณ๊ฒฝํ๋ฉด speed - accuracy์ ๋ฐธ๋ฐ์ค๋ฅผ ๋ง์ถ ์ ์๋ค.
- ์ค์ํ ๊ฒ์ ์ด๋ฏธ์ง์ ๋ด์ฉ์ ์๊ด์ ์ด $ rL $ ํ ํฐ์ ์ค์ธ๋ค๋ ๊ฒ์ด๋ค.
- ๋ณธ ๋ ผ๋ฌธ์ token merging step์ attention๊ณผ MLP ์ฌ์ด์ ๋ฐฐ์นํจ์ผ๋ก์จ token์์ ์ ๋ณด๋ฅผ ํ๋ ค์ค ์ ์์ผ๋ฉฐ, attention์ ๊ฒฐ๊ณผ๋ฅผ ํตํด ์ด๋ค ๋์์ ๋ณํ ํ ์ง๋ฅผ ๊ฒฐ์ ํ ์ ์๋ค. -> ๋ ๊ฐ์ง ๋ชจ๋ ์ ํ๋๋ฅผ ๋์ผ ์ ์๋ค.
Token Similarity
- Similar token์ ๋ณํฉํ๊ธฐ ์ ์, ์ฐ์ 'similar'ํ๋ค๋ ๊ธฐ์ค์ ๋ด๋ ค์ผํ๋ค.
- ๋ Token์ ํน์ง ์ฌ์ด์ distance๊ฐ ์์ ๊ฒฝ์ฐ์ ์ ์ฌํ๋ค๊ณ ํ ์ ์์ง๋ง, ์ด๊ฒ์ ๋ฐ๋์ ์ต์ ์ ์๋๋ค.
- Transformer๋ overparameterized ๋์ด์๋ค. ์ด๋ features์ ์ค์ํ์ง ์์ ๋ ธ์ด์ฆ๋ฅผ ํฌํจํ ๊ฐ๋ฅ์ฑ์ด ์์์ ์๋ฏธํ๋ค.
- ๋ณธ ๋ ผ๋ฌธ์์๋ ์ด๋ค ๊ฒ์ด ์ ์ฌํ์ง๋ฅผ ๊ฒฐ์ ํ๊ธฐ ์ํด ๊ฐ Token์ Key ์ฌ์ด์ dot product similarity metric์ ์ฌ์ฉํ๋ค.
Biparite Soft Matching
- Token similarity๊ฐ ๊ฒฐ์ ๋๊ณ ์ ์ฒด ์๋ฅผ $ r $๋งํผ ์ค์ด๊ธฐ ์ํด ์ผ์น์ํฌ Token์ ๊ฒฐ์ ํ๋ ๋น ๋ฅธ ๋ฐฉ๋ฒ์ด ํ์ํ๋ค.
- ๋
ผ๋ฌธ์์๋ ๋ค์๊ณผ ๊ฐ์ ๋ชฉํ๋ก ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํ์๋ค.
- 1) ๋ณ๋ ฌํํ ์ ์๋ ๋ฐ๋ณต์ ์ธ ๊ฒ๋ค์ ํผํ๋ค
- 2) ๋ณํฉ์ด ๋ง๋๋ ๋ณํ๊ฐ ์ ์ง์ ์ด๊ธฐ๋ฅผ
- ๋
ผ๋ฌธ์์๋ ๋ค์๊ณผ ๊ฐ์ ์์๋ก ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํ์๋ค.
- 1) Token๋ค์ ๊ฐ์ ์ฌ์ด์ฆ์ A์ B ๊ทธ๋ฃน์ผ๋ก ๋๋๋ค.
- 2) A์ ๊ฐ ํ ํฐ์์ B์ ๊ฐ์ฅ ์ ์ฌํ ํ ํฐ์ผ๋ก ์ ์ ๊ทธ๋ฆฐ๋ค.
- 3) $ r $๊ฐ์ ๊ฐ์ฅ ์ ์ฌํ๊ฒ ์ ์งํ๋ค.
- 4) ์ฐ๊ฒฐ๋ Token๋ค์ ๋ณํฉํ๋ค.
- 5) ๋ ๊ทธ๋ฃน์ ๋ค์ ํฉ์น๋ค.
- ์ ๋ฐฉ์์ฒ๋ผ ์ํํ๊ฒ ๋๋ฉด A์ ๊ฐ ํ ํฐ์๋ ํ๋์ ์ฃ์ง๊ฐ ์๊ธฐ ๋๋ฌธ์ 4๋จ๊ณ์์ ์ฐ๊ฒฐ๋ ๊ตฌ์ฑ ์์๋ฅผ ์ฐพ๋ ๊ฒ์ ๊ฐ๋จํ๋ค.
- ์์ ๊ฐ์ ๋ฐฉ์์ "bipartite soft matching"์ด๋ผ ํ๊ณ , ์ด๋ Token์ ๋ฌด์์๋ก ๋จ์ด๋จ๋ฆฌ๋ ๊ฒ๊ณผ ๊ฑฐ์ ๊ฐ์ ์๋์ด๋ค.
Tracking Token Size
- Token์ด ๋ณํฉ๋๋ฉด ๋ ์ด์ ํ๋์ ์
๋ ฅ Patch๋ฅผ ๋ํ๋ด์ง ์๊ฒ ๋๋ ๊ฒ์ด๋ค.
- ์ด๋ ๊ฒ ๋๋ฉด softmax attention์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ๊ฟ์ ์๋ค. ์๋ฅผ๋ค์ด ๊ฐ์ key๋ฅผ ๊ฐ์ง ๋ ๊ฐ์ ํ ํฐ์ ๋ณํฉํ๋ค๋ฉด, ๊ทธ key๋ softmax term์ ๋ ์ํฅ์ ๋ฐ์ ๊ฒ์ด๋ค.
- $ s $๋ ๊ฐ ํ ํฐ์ ํฌ๊ธฐ(ํ ํฐ์ด ๋ํ๋ด๋ patch ์)๋ฅผ ํฌํจํ๋ ํ ๋ฒกํฐ์ด๋ค.
IMAGE EXPERIMENTS
- ImageNet-1k์ 4๊ฐ์ง ๋ค๋ฅธ ๋ฐฉ๋ฒ์ผ๋ก ํ์ตํ์๋ค.
- AugReg, MAE, SWAG, DeiT
- ๋ ผ๋ฌธ์ ๋ฐฉ๋ฒ์ผ๋ก off-the-shelf(๋ฏธ๋ฆฌ ํ์ต๋ ๋ชจ๋ธ)์ ๋๋ฆฌ๊ฑฐ๋, MAE์ DeiT์ ๊ฒฝ์ฐ, ๋ ผ๋ฌธ์ ๋ฐฉ๋ฒ์ ์ ์ฉํ์ฌ Training์ ์ํด.
Token Similarity
- attention ํ attention key๋ฅผ ์ฌ์ฉํ์ฌ ๋ณํฉ์ ์ํํ๋๊ฒ์ด ํจ์ฌ ์ ํํ๋ค.
- Table 1b์ ๊ฐ์ด Token ๊ฑฐ๋ฆฌ๋ฅผ ์ธก์ ํ๋ ๊ฒ์ด cosine similarity๊ฐ ๊ฐ์ฅ ์ข๋ค. ์ดํ ํจ์จ์ฑ์ ์ํด Attention Head๋ฅผ ์ฐ๊ฒฐํ๋ ๋์ K ํ๊ท ์ ์ํํ๋ค.
Algorithmic Choices
- ๋ณํฉํ Token์ ๊ฒฐ์ ํ ํ์๋ Token ํฌ๊ธฐ $ s $๋ก weighted averaging์ ์ํํ์ฌ ๊ฒฐํฉํ๋ค.
- ๋ ๊ฐ์ ์ธํธ A, B๋ฅผ ๋ถํ ํ ๋ ๊ต๋๋ก Token์ ํ ๋นํ๋ ๊ฒ์ด ๊ฐ์ฅ ์ ์๋ํ๋ค.
Proportional Attention.
- off-the-shelf MAE ๋ชจ๋ธ์ ์ ์ธํ๊ณ ๋ชจ๋ ๋ชจ๋ธ์ ๋ํด Proportional attention์ ์ฌ์ฉํ๋ค.
Comparing Matching Algorithms
- Table2์ ๋ณธ ๋ ผ๋ฌธ์ ๋ฐฉ๋ฒ์ธ bipartite matching๊ณผ ๋ค๋ฅธ token reduction algorithm์ ๋น๊ตํ์๋ค.
- ๋ ผ๋ฌธ์ bipartite matching์ greedyํ ์ ๊ทผ๋ฒ์ ์ ํ์ฑ๊ณผ pruning ์๋๋ฅผ ๊ฐ์ง๊ณ ์๋ค.
Selecting a Merging Schedule.
- ๊ธฐ๋ณธ์ ์ผ๋ก layer๋น Token์ ๋ณํฉํ๋ค.
- ์ด ๋ฐฉ๋ฒ์ ์ต์ ํจ์ ํ ์คํธํ๊ธฐ ์ํด ์คํ์ ํ ๊ฒฐ๊ณผ ์ผ์ ํ๊ฒ ๋ณํฉ๋ ํ ํฐ์ด ์ฆ๊ฐํจ์ ๋ฐ๋ผ ์ต์ ์ ๊ฐ๊น๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์๋ค.
์ด์ธ ๋ค๋ฅธ ์คํ์ ์ธ ์ธํ ์ ์คํต
COMPARISON TO OTHER WORKS