Why Multi-Head Self Attention Works: 10+1 Insights

AI Summer | Nikolas Adaloglou | 2021-03-25 Research-grounded Analysis | Multiple papers synthesized

핵심 Takeaways

개념적 통찰

  • Self-attention은 정보 라우팅(Information Routing) 메커니즘
  • Multi-head는 “독립적” 계산이 아니라 공통 부분공간에서 협력
  • Attention 헤드들을 분류하고 가지치기 가능 (66% 제거 가능)

실무적 인사이트

통찰함의
Insight 3: Heads가 공통 투사 공간 학습Low-rank 구조 활용 가능
Insight 8: Layer norm만으로 fine-tuning 가능전이학습 효율성 극대
Insight 7: Rank collapse 방지 필수Skip connection 필수

10+1 Insights 상세 분석

1️⃣ Self-Attention은 비대칭적 (Not Symmetric)

수학적 증명

QK^T / √d_k = (X·W^Q)(X·W^K)^T / √d_k 
           = X·W^Q·W^K^T·X^T / √d_k

W^Q ≠ W^K이므로 결과 행렬은 비대칭

의미

  • Attention weight[i,j] ≠ Attention weight[j,i]
  • Directed graph: i → j 방향성 있음
  • 정보 라우팅의 한 방향성

같은 가중치 사용 시 (W^Q = W^K)

QK^T = X·W·W^T·X^T (대칭!)
  • 무방향 그래프 (undirected)
  • 일부 논문에서 이 방식 채택

2️⃣ Attention as Routing of Multiple Local Information

핵심 논문 (Schlag et al., 2021)

“Attention은 다양한 국소 정보 소스를 글로벌 트리 구조로 라우팅하는 메커니즘”

발견

  • 헤드들이 다른 부분에 attend하는 것 아님
  • 각 헤드는 거의 모든 정보 보존
  • 차이점: 정보를 어떻게 결합하는가

의미: 아키텍처 설계 시 고려

목표: 정보 보존, 유연한 라우팅
→ Dense attention보다 효율적 가능
→ Sparse attention 여전히 가능

3️⃣ Encoder Weights Classification & Pruning (Voita et al., 2019)

3가지 헤드 분류

헤드 유형주의 패턴예시
Positional인접 토큰만i-1, i, i+1 위치
Syntactic문법적 관계주어-술어, 전치사-목적어
Rare word희귀 단어”Transformer” 같은 특수 용어

가지치기 결과 (48개 헤드)

원본: 모든 48개 헤드 사용 (BLEU score = 100%)
가지치기: 17개 헤드만 유지
결과: BLEU score = 97% (거의 동일)

→ 인코더의 2/3 가지치기 가능!

실무 적용

  • 학습 후 가지치기: 모델 압축
  • 선택적 업데이트: 불필요한 헤드 미세조정 스킵

4️⃣ Multiple Heads on Encoder-Decoder Attention (Michel et al., 2019)

발견

  • 인코더 self-attention: 쉽게 가지치기 가능
  • 인코더-디코더 cross-attention: 매우 중요 (60% 이상 제거 불가)

이유

Encoder self-attn: 입력 인코딩 (redundancy 허용)
Cross-attention: 소스-타겟 정렬 (unique 역할)
Decoder self-attn: 생성 순서 강제 (masking 담당)

설계 시사점

  • 용량 제약 시: 인코더 헤드 줄임
  • 성능 보장 시: Cross-attention 헤드 유지

5️⃣ Softmax 이후 Self-Attention은 Low-Rank (Wang et al., 2020)

관찰 (Linformer 논문)

Singular value decomposition of normalized attention P:
- 처음 128개 singular value로 90% 정보 복구
- 나머지 640개는 무시 가능

의미: 근사 가능성

P ≈ U₁₂₈ Σ₁₂₈ V₁₂₈^T (rank-128 근사)

계산량: O(n²) → O(128n)  (선형!)

실제 Linformer 구현

# Keys, Values를 low-dimensional space로 투사
K_proj = K @ W_proj  # (n, 128)
V_proj = V @ W_proj  # (n, 128)
 
# Attention 계산: O(128n) 시간

6️⃣ Attention Weights as Fast Weight Memory Systems (Schmidhuber ‘91 → Modern)

1990년대 개념 (Fast Weights)

느린 네트워크가 빠른 네트워크의 가중치 동적 생성
→ 문맥에 따라 가중치 변함

Modern Attention의 해석

Softmax 제거 시:
y^(i) = (Σ_j v^(j) ⊗ k^(j)) q^(i)

W^(i) = outer product of values & keys
→ "문맥에 따라 동적 생성되는 가중치"

핵심 요구: Orthogonal Keys

“간섭을 피하려면 키들이 직교해야 한다. 그렇지 않으면 dot product가 여러 키에 attend하여 값들의 선형결합을 반환”


7️⃣ Rank Collapse (Dong et al., 2021)

발견: Attention은 Rank 1로 수렴

스택된 attention 레이어 수 n에 대해:
rank(output) ∝ e^(-c·n)  (지수적 붕괴!)

깊이 8 → rank 1에 가까워짐

원인 및 대책

원인효과대책
반복된 attention토큰 균일화(uniformity)❌ Skip connection 필수
누적 정보 손실차원 축약✅ MLP (높은 차원)
정규화 과정정보 소실❓ Layer norm 무관

설계 원칙

✅ Skip connection: 필수 (rank collapse 방지)
✅ MLP: 필수 (차원성 회복)
❌ Layer norm: Rank collapse 방지 무용지물

8️⃣ Layer Norm: 전이학습의 핵심 (Aghajanyan et al., 2021)

놀라운 발견

BERT (768D) → MNIST/CIFAR-10 fine-tuning:

① Freeze all (제로샷): 35% 정확도
② Freeze + LayerNorm fine-tune (0.1% params): 80% ✅
③ Full fine-tune (100% params): 82%

→ LayerNorm만으로 98% 성능 달성!

이유 분석

LayerNorm의 학습 파라미터: γ(스케일), β(시프트)
→ 미리 학습된 Q, K, V 투사들을 rescale/shift
→ 새 도메인의 분포에 맞춤

실무적 함의

# 전이학습 최적 전략
for param in model.parameters():
    param.requires_grad = False
 
# LayerNorm만 활성화
for layer in model.modules():
    if isinstance(layer, LayerNorm):
        for param in layer.parameters():
            param.requires_grad = True

9️⃣ Pretraining의 위력: 계산 원시 요소(Computation Primitives)

발견

Pretrained transformer (frozen) + LayerNorm fine-tune
≈ Fully fine-tuned transformer

→ Q, K, V 투사가 이미 좋은 "계산 블록"

의미

Pretraining = 기본 계산 단위 학습
Fine-tuning = 이들을 새 도메인에 맞춤

결과: 매우 데이터 효율적 (작은 데이터셋 OK)

🔟 Quadratic Complexity: 해결 진행 중

문제

Attention O(n²) 복잡도
- n = 512: 262K operations (가능)
- n = 4096: 16M+ operations (병목)
- n = 65K: 4B operations (불가능)

주요 해결책

1. 수학적 근사 (Linformer, Performer)

Low-rank approximation 활용
O(n²) → O(n) 시간 복잡도

2. 희소화 (Big Bird, Longformer)

Local window + Global tokens + Random
세 가지 조합으로 O(n) 달성

3. 상대 위치 (Shaw et al., 2018)

절대 위치 대신 상대 거리
→ 학습 길이보다 긴 시퀀스 외삽 가능

1️⃣1️⃣ Self-Attention은 완전히 이해된 구조가 아님

미해결 질문들

  1. 정확히 multi-head attention이 작동하는가?
  2. 각 헤드의 **특화(specialization)**는 학습된 결과인가, 아니면 초기화 때문인가?
  3. Skip connection + MLP + Attention의 상호작용?

최근 주목할 연구 방향

  • Attention의 정보 이론적 분석
  • Lottery ticket hypothesis 적용 (pruning)
  • Mechanistic interpretability (어떤 회로가 specific behavior 생성?)

ABCD 학습 목표

Understand (이해)

  • A: 기계학습 배경 있는 IT 전문가
  • B: Multi-head attention의 10가지 주요 성질 설명 가능
  • C: 학술 논문이나 발표 자료 제시
  • D: 5가지 이상 정확한 설명 (대칭성, 라우팅, 가지치기, low-rank, rank collapse)

Apply (적용)

  • A: 모델 최적화 엔지니어
  • B: 주어진 제약(메모리, 레이턴시)에서 최적의 attention 설정 결정
  • C: 768D 임베딩, 512 길이 시퀀스, 2GB VRAM 제약
  • D: HeadCount, head dimension 결정 + 추론 시 가지치기 전략

Analyze (분석)

  • A: Transformer 모델 연구자, 평가 담당자
  • B: 학습된 attention 패턴으로부터 모델의 강점/약점 진단
  • C: Attention weight 행렬 및 singular value 분포 데이터
  • D: 헤드별 특화 패턴, 불필요한 헤드 발견, 가지치기 추천

교육 설계 배치

Lecture Module 2: “LLM 모델 이해 및 활용” (16시간)

강의 흐름

  1. 이론 (4h):

    • Insight 1-3: 기본 개념 (비대칭성, 라우팅, 분류)
    • Insight 7-8: 아키텍처 안정성
  2. 실습 (3h):

    • Hugging Face로 attention weight 시각화
    • Head-wise contribution 분석
    • Pruning 시뮬레이션
  3. 심화 (3h):

    • Insight 5-6: 선형 attention 이해
    • Insight 4: 실제 pruning 적용

선수과목 요구

← Lecture 1: “생성형 AI 기초” (Transformer 개요)

후속 모듈

→ Lecture 3: “Vector DB” (임베딩 최적화) → Lecture 5: “RAG Pipeline” (attention의 가지치기 영향)


관련 문서


교육 설계 노트: 이 자료는 “왜 이렇게 설계했는가”에 대한 실증적 답변 제공. 추상적 설명 대신 구체적 수치와 논문 결과로 신뢰도 확보.