Why Multi-Head Self Attention Works: 10+1 Insights
AI Summer | Nikolas Adaloglou | 2021-03-25 Research-grounded Analysis | Multiple papers synthesized
핵심 Takeaways
개념적 통찰
- Self-attention은 정보 라우팅(Information Routing) 메커니즘
- Multi-head는 “독립적” 계산이 아니라 공통 부분공간에서 협력
- Attention 헤드들을 분류하고 가지치기 가능 (66% 제거 가능)
실무적 인사이트
| 통찰 | 함의 |
|---|---|
| Insight 3: Heads가 공통 투사 공간 학습 | Low-rank 구조 활용 가능 |
| Insight 8: Layer norm만으로 fine-tuning 가능 | 전이학습 효율성 극대 |
| Insight 7: Rank collapse 방지 필수 | Skip connection 필수 |
10+1 Insights 상세 분석
1️⃣ Self-Attention은 비대칭적 (Not Symmetric)
수학적 증명
QK^T / √d_k = (X·W^Q)(X·W^K)^T / √d_k
= X·W^Q·W^K^T·X^T / √d_k
W^Q ≠ W^K이므로 결과 행렬은 비대칭
의미
- Attention weight[i,j] ≠ Attention weight[j,i]
- Directed graph: i → j 방향성 있음
- 정보 라우팅의 한 방향성
같은 가중치 사용 시 (W^Q = W^K)
QK^T = X·W·W^T·X^T (대칭!)
- 무방향 그래프 (undirected)
- 일부 논문에서 이 방식 채택
2️⃣ Attention as Routing of Multiple Local Information
핵심 논문 (Schlag et al., 2021)
“Attention은 다양한 국소 정보 소스를 글로벌 트리 구조로 라우팅하는 메커니즘”
발견
- 헤드들이 다른 부분에 attend하는 것 아님
- 각 헤드는 거의 모든 정보 보존
- 차이점: 정보를 어떻게 결합하는가
의미: 아키텍처 설계 시 고려
목표: 정보 보존, 유연한 라우팅
→ Dense attention보다 효율적 가능
→ Sparse attention 여전히 가능
3️⃣ Encoder Weights Classification & Pruning (Voita et al., 2019)
3가지 헤드 분류
| 헤드 유형 | 주의 패턴 | 예시 |
|---|---|---|
| Positional | 인접 토큰만 | i-1, i, i+1 위치 |
| Syntactic | 문법적 관계 | 주어-술어, 전치사-목적어 |
| Rare word | 희귀 단어 | ”Transformer” 같은 특수 용어 |
가지치기 결과 (48개 헤드)
원본: 모든 48개 헤드 사용 (BLEU score = 100%)
가지치기: 17개 헤드만 유지
결과: BLEU score = 97% (거의 동일)
→ 인코더의 2/3 가지치기 가능!
실무 적용
- 학습 후 가지치기: 모델 압축
- 선택적 업데이트: 불필요한 헤드 미세조정 스킵
4️⃣ Multiple Heads on Encoder-Decoder Attention (Michel et al., 2019)
발견
- 인코더 self-attention: 쉽게 가지치기 가능
- 인코더-디코더 cross-attention: 매우 중요 (60% 이상 제거 불가)
이유
Encoder self-attn: 입력 인코딩 (redundancy 허용)
Cross-attention: 소스-타겟 정렬 (unique 역할)
Decoder self-attn: 생성 순서 강제 (masking 담당)
설계 시사점
- 용량 제약 시: 인코더 헤드 줄임
- 성능 보장 시: Cross-attention 헤드 유지
5️⃣ Softmax 이후 Self-Attention은 Low-Rank (Wang et al., 2020)
관찰 (Linformer 논문)
Singular value decomposition of normalized attention P:
- 처음 128개 singular value로 90% 정보 복구
- 나머지 640개는 무시 가능
의미: 근사 가능성
P ≈ U₁₂₈ Σ₁₂₈ V₁₂₈^T (rank-128 근사)
계산량: O(n²) → O(128n) (선형!)
실제 Linformer 구현
# Keys, Values를 low-dimensional space로 투사
K_proj = K @ W_proj # (n, 128)
V_proj = V @ W_proj # (n, 128)
# Attention 계산: O(128n) 시간6️⃣ Attention Weights as Fast Weight Memory Systems (Schmidhuber ‘91 → Modern)
1990년대 개념 (Fast Weights)
느린 네트워크가 빠른 네트워크의 가중치 동적 생성
→ 문맥에 따라 가중치 변함
Modern Attention의 해석
Softmax 제거 시:
y^(i) = (Σ_j v^(j) ⊗ k^(j)) q^(i)
W^(i) = outer product of values & keys
→ "문맥에 따라 동적 생성되는 가중치"
핵심 요구: Orthogonal Keys
“간섭을 피하려면 키들이 직교해야 한다. 그렇지 않으면 dot product가 여러 키에 attend하여 값들의 선형결합을 반환”
7️⃣ Rank Collapse (Dong et al., 2021)
발견: Attention은 Rank 1로 수렴
스택된 attention 레이어 수 n에 대해:
rank(output) ∝ e^(-c·n) (지수적 붕괴!)
깊이 8 → rank 1에 가까워짐
원인 및 대책
| 원인 | 효과 | 대책 |
|---|---|---|
| 반복된 attention | 토큰 균일화(uniformity) | ❌ Skip connection 필수 |
| 누적 정보 손실 | 차원 축약 | ✅ MLP (높은 차원) |
| 정규화 과정 | 정보 소실 | ❓ Layer norm 무관 |
설계 원칙
✅ Skip connection: 필수 (rank collapse 방지)
✅ MLP: 필수 (차원성 회복)
❌ Layer norm: Rank collapse 방지 무용지물
8️⃣ Layer Norm: 전이학습의 핵심 (Aghajanyan et al., 2021)
놀라운 발견
BERT (768D) → MNIST/CIFAR-10 fine-tuning:
① Freeze all (제로샷): 35% 정확도
② Freeze + LayerNorm fine-tune (0.1% params): 80% ✅
③ Full fine-tune (100% params): 82%
→ LayerNorm만으로 98% 성능 달성!
이유 분석
LayerNorm의 학습 파라미터: γ(스케일), β(시프트)
→ 미리 학습된 Q, K, V 투사들을 rescale/shift
→ 새 도메인의 분포에 맞춤
실무적 함의
# 전이학습 최적 전략
for param in model.parameters():
param.requires_grad = False
# LayerNorm만 활성화
for layer in model.modules():
if isinstance(layer, LayerNorm):
for param in layer.parameters():
param.requires_grad = True9️⃣ Pretraining의 위력: 계산 원시 요소(Computation Primitives)
발견
Pretrained transformer (frozen) + LayerNorm fine-tune
≈ Fully fine-tuned transformer
→ Q, K, V 투사가 이미 좋은 "계산 블록"
의미
Pretraining = 기본 계산 단위 학습
Fine-tuning = 이들을 새 도메인에 맞춤
결과: 매우 데이터 효율적 (작은 데이터셋 OK)
🔟 Quadratic Complexity: 해결 진행 중
문제
Attention O(n²) 복잡도
- n = 512: 262K operations (가능)
- n = 4096: 16M+ operations (병목)
- n = 65K: 4B operations (불가능)
주요 해결책
1. 수학적 근사 (Linformer, Performer)
Low-rank approximation 활용
O(n²) → O(n) 시간 복잡도
2. 희소화 (Big Bird, Longformer)
Local window + Global tokens + Random
세 가지 조합으로 O(n) 달성
3. 상대 위치 (Shaw et al., 2018)
절대 위치 대신 상대 거리
→ 학습 길이보다 긴 시퀀스 외삽 가능
1️⃣1️⃣ Self-Attention은 완전히 이해된 구조가 아님
미해결 질문들
- 왜 정확히 multi-head attention이 작동하는가?
- 각 헤드의 **특화(specialization)**는 학습된 결과인가, 아니면 초기화 때문인가?
- Skip connection + MLP + Attention의 상호작용?
최근 주목할 연구 방향
- Attention의 정보 이론적 분석
- Lottery ticket hypothesis 적용 (pruning)
- Mechanistic interpretability (어떤 회로가 specific behavior 생성?)
ABCD 학습 목표
Understand (이해)
- A: 기계학습 배경 있는 IT 전문가
- B: Multi-head attention의 10가지 주요 성질 설명 가능
- C: 학술 논문이나 발표 자료 제시
- D: 5가지 이상 정확한 설명 (대칭성, 라우팅, 가지치기, low-rank, rank collapse)
Apply (적용)
- A: 모델 최적화 엔지니어
- B: 주어진 제약(메모리, 레이턴시)에서 최적의 attention 설정 결정
- C: 768D 임베딩, 512 길이 시퀀스, 2GB VRAM 제약
- D: HeadCount, head dimension 결정 + 추론 시 가지치기 전략
Analyze (분석)
- A: Transformer 모델 연구자, 평가 담당자
- B: 학습된 attention 패턴으로부터 모델의 강점/약점 진단
- C: Attention weight 행렬 및 singular value 분포 데이터
- D: 헤드별 특화 패턴, 불필요한 헤드 발견, 가지치기 추천
교육 설계 배치
Lecture Module 2: “LLM 모델 이해 및 활용” (16시간)
강의 흐름
-
이론 (4h):
- Insight 1-3: 기본 개념 (비대칭성, 라우팅, 분류)
- Insight 7-8: 아키텍처 안정성
-
실습 (3h):
- Hugging Face로 attention weight 시각화
- Head-wise contribution 분석
- Pruning 시뮬레이션
-
심화 (3h):
- Insight 5-6: 선형 attention 이해
- Insight 4: 실제 pruning 적용
선수과목 요구
← Lecture 1: “생성형 AI 기초” (Transformer 개요)
후속 모듈
→ Lecture 3: “Vector DB” (임베딩 최적화) → Lecture 5: “RAG Pipeline” (attention의 가지치기 영향)
관련 문서
- attention-in-transformers-visualized: 3Blue1Brown의 시각적 설명
- the-illustrated-transformer: 전체 아키텍처 개요
- transformer-attention-mechanism: 기본 메커니즘
- low-rank-approximation: Low-rank attention 심화
- attention-pruning: 헤드 가지치기 전략
- research-papers-attention: 원본 논문 모음
교육 설계 노트: 이 자료는 “왜 이렇게 설계했는가”에 대한 실증적 답변 제공. 추상적 설명 대신 구체적 수치와 논문 결과로 신뢰도 확보.