[논문리뷰]
ICLR 2023
ABSTRACT
- 본 논문에서 저자들은 Transformer의 attention operations에 대한 복잡도뿐만 아니라 linear layers를 줄일 수 있는 Token Pruning에 포커스를 맞췄다.
- 이전에 work들은 추후 layer의 attention의 영향에 대해 고려 없이 feed-forward 단계에서 token을 제거하였다. 이러한 이슈를 해결하고자 최종 예측에 큰 영향을 미치는 토큰을 보존하기 위해 output에서 input까지 각 attention의 중요성을 back-tracking 하는 방법을 제안한다.
- NLP와 CV에서의 효율성을 실험적으로 입증하였다.
Introduction
- Transformer의 Pruning 접근 방법은 주로 불필요한 모델 가중치나 attention의 제거하는 것을 목표로 한다.
- 어떤 Token을 제거할지에 대해 선행 연구들은 각 Token들에 대해 average attention scores를 계산하거나, 각 layer에서 추가적인 neural network로 각 token의 importance score를 구하였다.
- 이전의 Token Pruning 방법은 최종적인 예측 작업에서 각 Token의 중요성이 기준이 아닌, 입력 시퀀스에서 Token을 Pruning 한다는 문제점이 존재한다. -> 각 attention layer에서 입력 token에 대한 importance score를 계산하기 때문 -> 초기 layer에서 token을 prune 함.
- 왼쪽 Feed-Forward Token Pruning의 경우, "Love" Toekn이 Pruning 된다. 이러한 경우 "Love" Token이 다음 layer에서 높은 attention probability를 나타낼 수 있다는 사실을 고려하지 않고 Pruning 한다.
- 따라서 본 논문에서는 Attention back-tracking 방법은 제안한다.
- 마지막 layer의 token에서 input token으로 backward pass를 수행하여 attention score를 누적하여 각 input token의 importance score에 따라 token을 Pruning 한다.
- 이렇게 수행함으로써 input과 output에 최소한의 영향을 미칠 중요한 token을 잘 선택하고 유지할 수 있게 된다.
- 한 가지 어려운 이슈는 이러한 backtracking을 수행하기 위해서는 feed-forwarding 하기 전에 각 token의 attention score를 알아야 한다는 것이다.
- 이러한 문제를 해결하기 위해, KD로 trained 된 lightweight attention approximation network인 ApproxNet를 도입하였다.
- 또한 token을 선택하기 위해 Concrete masking이라는 방법을 제안한다. 구체적으로, task loss를 최소화하기 위한 objective function을 가지고 backtracking에 의해 계산되는 각 layer에서의 importance score를 공동으로 학습한다.
- 본 논문에서는 text와 image classification task에 대해서 평가하였다.
- 결과적으로 본 논문이 제안한 방법은
- 새로운 token pruning 방법 : 최종적인 결과에 대해 각 token의 중요성을 고려하는 attention back-tracking을 기반으로 transformer에 대한 새로운 token pruning 방법
- Concrete dropout으로 각 layer의 pruning 속도를 자동으로 학습하는 Concrete masking을 제안한다.
- 결과적으로 훨씬 더 나은 accuracy-computation tradeoff를 달성하였다.
RELATED WORK
Attention Pruning for Transformers
- self-attention의 computational cost가 크기 때문에, 최근 몇몇 연구에서는 시간 복잡도를 $ O(N^2) $에서 $ O(N) $으로 줄이는데 초점을 맞추고 있다.
- 하지만, 이 전의 work들은 static 한 sparsity pattern을 사용하기 때문에, input sequence의 content에 대해 고려하지 않는다. 이를 해결하기 위해 최근 연구에서는 입력 데이터를 고려하여 transformer의 attention을 동적으로 분산시키는 것을 목표로 한다.
Token Pruning for Transformers
- 최근 연구에서는 주어진 task에 중요한 token들을 선택함으로써 token의 수를 줄이는 것이 초점을 맞추고 있다.
- 하지만 기존의 방법들은 input token을 기준으로 feed-forward pass 중에 token을 제거한다. 그러므로 초기 layer에 중요한 token들을 제거하는 것이다.
- 이러한 방법들과 달리 본 논문은 각 token이 output vector에 어떠한 영향을 미칠지 고려하여 pruning을 진행한다.
Sparse Token Transformer with Attention Back Tracking
- Attention Back Tracking : attention approximation network와 각 입력에 대한 token별 임계값을 동적으로 학습하기 위한 Concrete masking으로 각 token에 대한 중요도 점수를 얻는 방법
- 목표는 각 layer에서 token을 pruning 하여 주어진 pretrained transformer의 inference cost를 낮추는 것
3.1 Transformer
- 아래 수식과 같이 l-th layer의 attention score와 확률을 표시한다. (기존의 Transformer 아키텍처의 구조와 동일)
3.2 Attention Back-Tracking (ABT)
- 기본 아이디어는 attention을 역추적하여 예측을 위한 최종 sequence에 미치는 영향을 측정하여 그 점수를 기반으로 token을 pruning 하는 것이다.
- 마지막 layer에서 attention probabilities가 더 높은 token으로 시작하며 classification task과 관련하여 전체 $ L $ Layer로 구성된 transformer incoder의 첫 번째 layer로 누적 attention probabilites로 back-tracking 한다.
- 하지만 위 기술을 적용하기 위해서 attention score를 사전에 계산하기 위해 모든 transformer layer를 통해 input sequence를 feed-forward 해야 한다. -> 이러한 방법은 각 input에 대해 모델의 straight-forward를 두 번 수행해야 하기 때문에 기술을 직접 구현하는 것은 상당히 비효율적일 수 있다.
Attention Approximation Network
- cost를 완화하기 위해, attention approximation network(ApproxNet)을 사용한다. -> 적은 수의 paremeter로 기존의 네트워크의 attention probability를 approximation 하는 distilled transformer encoder이다.
- TinyBERT의 distillation loss를 기반으로 knowledge distillation을 수행하여 ApproxNet을 수행한다.
- 구체적으로. embedding 기반 distillation loss $ L_{embd} $, hidden state based distillation loss $ L_{hidn} $, attention-based distillation loss $ L_{attn} $ 및 prediction-layer distillation loss $ L_{pred} $를 활용한다.
- 특히, attention-based distillation에 대해서는 mean squared error 대신에 MiniLM의 KL divergence loss를 사용한다.
- $ A^l_{prob} $ 는 오리지널 network(teacher)의 $ l $번째 layer attention probability이다.
- $ A(hat)^l_{prob} $ 는 ApproxNet(student)의 $ l $번째 layer attention probability이다.
- $ H $는 multi-head self-attention layer의 head 개수이다.
- $ T $는 input sequence의 길이
- $ L $은 transformer network의 layer 길이
Token Pruning with Attention Back-Tracking on Approximated Attention
- ApproxNet을 통해 얻은 approximated attentions을 통해 main network에서 각 layer의 token들을 pruning 할 수 있다.
- Fig 2. 의 Algorithm1처럼 token pruning에 대해 mask를 업데이트하는 알고리즘을 이용한다.
- UpdateTokenMaskOfLayer 알고리즘에서 input 값인 token mask $ i^{l+1} $과 token의 importance score $ s^{l+1} $이 입력된다
- $ l + 1 $번째 layer에서 선택된 token의 수인 $ o_{l+1} $이 존재한다.
- $ p $는 상위 layer에서 전파되어야 하는 importance score를 결정한다. 만약 $ p = 0 $이면 $ l $번째 layer의 token을 pruning 할 때 상위 layer의 importance score의 영향이 없어진다.
- $ k $ 는 top-k 방법을 사용할 때 몇 개의 token을 선택해야 하는지 결정한다. $ k $를 조정하여 각 layer에서 보유할 token 수를 변경할 수 있다.
3.3 Concrete Masking
- Transformer 모델의 각 layer의 top-k token은 정적인 값 k를 사용하여 선택할 수 있지만, input을 기반으로 pruning을 조절하기 위해서 pruning 비율을 동적으로 변경하는 메커니즘도 필요하다.
- 따라서 본 논문에서는 Concrete Masking이라 부르는 dynamic thresholding 방법을 제안한다.
- 위 식처럼 MHSA 행렬을 위 식과 같이 균일한 분포로 변환한다.
- 위 식을 통해 ApproxNet에서 예측된 attention score의 모든 행을 변환한다.
- 훈련 중 모델은 각 layer에 적용된 soft mask function을 통해 적절한 값을 학습한다.
- 충분한 epoch으로 훈련을 거친 후, soft -> hard mask로 변경한다. hard masking의 경우 임의로 설정된 threshold을 사용하여 soft Concrete token mask $ m^l $을 설정한다.
Experiments
- NLP와 Vision task 모두에서 평가하였다.
- Dataset
- NLP : use nine datasets from GLUE, BERT(base)
- Vision : ImageNet-1k, DeiT-small (base architecture), DynamicViT(primary baseline)
'Paper Review' 카테고리의 다른 글
Swin Transformer : Hierarchical Vision Transformer using Shifted Windows (3) | 2023.10.14 |
---|---|
TOKEN MERGING: YOUR VIT BUT FASTER (0) | 2023.03.28 |
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration (0) | 2023.01.04 |
Convolutional Neural Network Pruning: A Survey (0) | 2022.12.30 |
Understanding the difficulty of training deep feedforward neural networks (0) | 2022.12.28 |