SHA-RNN
Single Headed Attention RNN: Stop Thinking With Your Head
A language model architecture showing that a single attention head combined with an LSTM backbone can nearly match Transformer performance at a fraction of the cost: train on a single GPU in about 24 hours instead of on TPU clusters for weeks.
The Provocation
Why transformers might not be needed
The Core Thesis
The paper title is deliberately provocative: "Stop Thinking With Your Head" is a critique of the ML community's obsession with multi-headed attention in Transformers. Stephen Merity demonstrates that a single attention head integrated into an LSTM backbone achieves nearly identical performance to Transformer-XL.
Computational Accessibility
Transformers require massive GPU infrastructure and long training times. SHA-RNN achieves competitive results on consumer hardware in ~24 hours.
Overengineering Critique
The field became obsessed with the Transformer architecture without exploring whether simpler approaches could achieve similar results. Multi-head attention may be overkill.
Premature RNN Death
The research community declared RNNs 'dead' even though properly designed hybrid approaches can remain competitive at much lower complexity.
Architecture
Hybrid LSTM + Single-Head Attention
High-Level Structure
Input Embedding (1024 dim)
↓
Block 0: LSTM → Feed-Forward (Boom)
↓
Block 1: LSTM → Feed-Forward (Boom)
↓
Block 2: LSTM → Single-Head Attention → Feed-Forward (Boom) ← Attention only here!
↓
Block 3: LSTM → Feed-Forward (Boom)
↓
Output Linear (to vocabulary)

Notice: attention is added in only one block (the second-to-last). This minimizes overhead while still capturing long-range dependencies; a minimal stacking sketch follows the two lists below.
Why This Works
- LSTM provides inherent sequential ordering (no positional embeddings needed)
- Single attention head captures long-range dependencies up to 5000 tokens
- Recurrent nature gives strong local inductive bias
- Feed-forward (Boom) modules add capacity similar to Transformer FFN
Key Differences from Transformers
- 1 attention head vs 8-16 in Transformers
- LSTM backbone vs pure attention
- Attention in 1 layer vs all layers
- No positional encoding required
- Memory caching for long context (5000 tokens)
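To make the layout above concrete, here is a minimal stacking sketch, assuming the Block class shown later on this page (and its dependencies). The repository does define SHARNN, Block, Attention, and Boom classes in model.py, but the constructor arguments, memory handling, and wrapper code below are illustrative rather than the repository's exact API.

import torch.nn as nn

class SHARNN(nn.Module):
    def __init__(self, vocab_size, embed_dim=1024, hidden_dim=4096, nlayers=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Attention is enabled only in the second-to-last block (index nlayers - 2)
        self.blocks = nn.ModuleList([
            Block(embed_dim, hidden_dim, use_attn=(i == nlayers - 2))
            for i in range(nlayers)
        ])
        self.decoder = nn.Linear(embed_dim, vocab_size)

    def forward(self, tokens, hiddens, mem=None):
        # hiddens: one LSTM (h, c) state per block; mem: cached states for attention
        x = self.embedding(tokens)
        new_hiddens = []
        for block, hidden in zip(self.blocks, hiddens):
            # Only the attention block actually consumes the cached memory
            x, new_hidden = block(x, hidden, mem=mem if block.attn is not None else None)
            new_hiddens.append(new_hidden)
        return self.decoder(x), new_hiddens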
Key Innovations
What makes SHA-RNN special
Learnable Scaling from Zero
Instead of letting attention act at full strength from the start, SHA-RNN initializes its query/key/value scaling parameters at zero and passes them through a sigmoid gate, so the projections begin half-scaled and the model gradually learns how much to rely on attention. This lets the model 'discover' attention rather than depend on it from the outset.
# Query, Key, Value scaling parameters start at zero
self.qs = nn.Parameter(torch.zeros(...))
self.ks = nn.Parameter(torch.zeros(...))
# Sigmoid gates scale the projections
qs = torch.sigmoid(self.qs) # Starts near 0.5, learns to scale
ks = torch.sigmoid(self.ks)
q, k, v = qs * query, ks * key, vs * value

Overparam Module for Value Scaling
The value scaling uses an overparameterized transformation similar to LSTM gating: it projects to twice the hidden dimension and applies sigmoid-tanh gating.
class Overparam(nn.Module):
    def __init__(self, nhid):
        super().__init__()
        self.linear = nn.Linear(nhid, 2 * nhid)

    def forward(self, x):
        x = self.linear(x)
        c, f = x.chunk(2, dim=-1)                 # Split into content and forget
        return torch.sigmoid(f) * torch.tanh(c)   # LSTM-like gating

Memory Caching for Long Context
SHA-RNN caches hidden states from previous sequences, allowing attention to look back up to 5000 tokens. This enables efficient long-context modeling without reprocessing the entire history.
def forward(self, x, hidden, mems=None):
    # ... LSTM / block computation produces current hidden states `h` ...
    # Concatenate cached memory with the current hidden states
    if mems is not None:
        bigh = torch.cat([mems, h], dim=0)   # Past + current
    else:
        bigh = h
    # Attention over the extended context
    attn_out = self.attention(query=h, key=bigh, value=bigh)
    # Keep only the most recent tokens in memory (sliding window)
    new_mem = bigh[-self.num_max_positions:]
    # (feed-forward and output computation omitted in this sketch)
    return output, hidden, new_mem

Boom: Efficient Feed-Forward
The 'Boom' module is a feed-forward layer that expands the hidden state (e.g. 1024 → 4096) and then collapses it back by reshaping and summing, rather than applying a second large linear layer as a standard Transformer FFN does. This saves parameters and memory.
class Boom(nn.Module):
    """Feed-forward with a reshape-and-sum shortcut instead of a second linear layer."""
    def __init__(self, nhid, nout=None, shortcut=True):
        super().__init__()
        nout = nout or 4 * nhid              # Default: expand 4x
        self.linear = nn.Linear(nhid, nout)
        self.nhid = nhid
        self.shortcut = shortcut

    def forward(self, x):
        x = F.gelu(self.linear(x))
        if self.shortcut:
            # Reshape into (nout // nhid) chunks of size nhid and sum them,
            # instead of projecting back down with a second linear layer
            x = x.view(*x.shape[:-1], -1, self.nhid).sum(dim=-2)
        return x

Single-Head Attention
The core mechanism
Attention Implementation
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, nhid):
        super().__init__()
        self.num_heads = 1  # Single head!
        # Projections
        self.q_proj = nn.Linear(nhid, nhid)
        self.k_proj = nn.Linear(nhid, nhid)
        self.v_proj = nn.Linear(nhid, nhid)
        self.out_proj = nn.Linear(nhid, nhid)
        # Learnable scaling (parameters initialized at zero)
        self.qs = nn.Parameter(torch.zeros(1, 1, nhid))
        self.ks = nn.Parameter(torch.zeros(1, 1, nhid))
        self.vs = Overparam(nhid)  # Overparameterized value scaling

    def forward(self, query, key, value, mask=None):
        # Apply learned scaling
        qs = torch.sigmoid(self.qs)
        ks = torch.sigmoid(self.ks)
        q = qs * self.q_proj(query)
        k = ks * self.k_proj(key)
        v = self.vs(self.v_proj(value))
        # Scaled dot-product attention
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)
        return self.out_proj(output)

Why Single Head Works
Multi-head attention in Transformers theoretically allows attending to different representation subspaces. However, SHA-RNN shows that with an LSTM backbone providing strong sequential modeling, a single attention head is sufficient to capture the long-range dependencies that LSTMs miss. The LSTM already provides "multiple heads" of local context processing.
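A quick usage sketch of the Attention module above (together with its Overparam dependency). A batch-first (batch, seq, dim) layout is assumed here so the simplified matmul shapes line up; the sizes are illustrative.

import torch

nhid, batch, seq_len, mem_len = 1024, 2, 128, 512
attn = Attention(nhid)

h = torch.randn(batch, seq_len, nhid)        # current LSTM outputs
mem = torch.randn(batch, mem_len, nhid)      # cached hidden states from earlier windows
context = torch.cat([mem, h], dim=1)         # past + current as key/value

out = attn(query=h, key=context, value=context)
print(out.shape)  # torch.Size([2, 128, 1024]) -- one output vector per query position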
LSTM + Attention Synergy
Why the combination works
What LSTMs Provide
- Inherent positional information - No need for positional embeddings, the recurrent structure encodes sequence order
- Strong local modeling - Excellent at capturing local patterns and n-gram-like features
- Gated memory - Selective retention of important information through forget/input gates
- Efficient sequential processing - O(n) complexity for sequence length
What Attention Adds
- Direct long-range access - Can attend to any position in the 5000-token context window
- Content-based retrieval - Find relevant context based on semantic similarity, not just position
- Parallel context access - All positions in the context are equally accessible in one step
- Gradient highways - Direct paths for gradients to flow to distant positions
The Block Structure
class Block(nn.Module):
    def __init__(self, embed_dim, hidden_dim, use_attn=False):
        super().__init__()
        self.lstm = nn.LSTM(embed_dim, embed_dim)
        self.attn = Attention(embed_dim) if use_attn else None
        self.boom = Boom(embed_dim, hidden_dim)  # Feed-forward
        # Layer norms
        self.ln_start = nn.LayerNorm(embed_dim)
        self.ln_mid = nn.LayerNorm(embed_dim)
        self.ln_mem = nn.LayerNorm(embed_dim)
        self.ln_ff = nn.LayerNorm(embed_dim)

    def forward(self, x, hidden, mem=None):
        # LSTM processing
        x = self.ln_start(x)
        lstm_out, new_hidden = self.lstm(x, hidden)
        # Optional attention (only in the second-to-last block)
        if self.attn is not None:
            x = self.ln_mid(lstm_out)
            if mem is not None:
                context = torch.cat([mem, x], dim=0)
            else:
                context = x
            attn_out = self.attn(x, self.ln_mem(context), context)
            lstm_out = lstm_out + attn_out  # Residual
        # Feed-forward
        x = self.ln_ff(lstm_out)
        x = lstm_out + self.boom(x)  # Residual
        return x, new_hidden

Training
How to reproduce
Training Configuration
python -u main.py \
    --epochs 32 \
    --dropouth 0.1 --dropouti 0.1 --dropout 0.1 \
    --data data/enwik8/ \
    --save ENWIK8.pt \
    --log-interval 10 \
    --seed 5512 \
    --optimizer lamb \
    --bptt 1024 \
    --warmup 800 \
    --lr 2e-3 \
    --emsize 1024 \
    --nhid 4096 \
    --nlayers 4 \
    --batch_size 16

Flag notes: --optimizer lamb selects the LAMB optimizer (critical for stability), --bptt 1024 is the sequence length, --warmup 800 the number of LR warmup steps, --emsize 1024 the embedding dimension, --nhid 4096 the hidden dimension, and --nlayers 4 the number of LSTM layers.

Key Training Choices
- LAMB optimizer - Layer-wise Adaptive Moments for batch training, crucial for stability
- 800 step warmup - Gradual LR increase prevents early instability
- Mixed precision (AMP) - Faster training with minimal quality loss
- Gradient clipping at 0.3 - Prevents exploding gradients in LSTMs
- Two-stage training - First 32 epochs at 2e-3, then fine-tune at 1e-3
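The warmup and gradient-clipping choices above boil down to a simple pattern. Here is a rough sketch in plain PyTorch; AdamW stands in for LAMB purely for illustration (the repository provides its own LAMB implementation), and model, criterion, and train_batches are assumed to exist.

import torch

base_lr, warmup_steps, clip_norm = 2e-3, 800, 0.3
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)

step = 0
for x, y in train_batches:
    step += 1
    # Linear warmup: ramp the learning rate up over the first 800 steps
    lr = base_lr * min(1.0, step / warmup_steps)
    for group in optimizer.param_groups:
        group['lr'] = lr

    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Clip gradient norm at 0.3 to keep the LSTM stable
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
    optimizer.step()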
Hardware Requirements
- GPU - Single Titan V (12GB) or similar
- Training time - ~24 hours for full training
- Epoch time - 30-60 minutes per epoch
- Memory - Fits in 12GB with batch_size 16
- Can reduce - Lower batch size or BPTT if memory constrained
Reproducibility Advantage
Unlike Transformers, which often require careful hyperparameter tuning, complex learning rate schedules, and fragile warmup procedures, SHA-RNN training is remarkably stable. If memory becomes an issue, you can reduce the batch size or BPTT length, or drop the attention layer entirely, in which case the model gracefully degrades to a standard LSTM.
Benchmarks
enwik8 byte-level language modeling
| Model | Test BPC | Parameters | LSTM? | Training |
|---|---|---|---|---|
| Krause mLSTM | 1.24 | 46M | ✓ | - |
| AWD-LSTM | 1.23 | 44M | ✓ | - |
| SHA-LSTM (this work) | 1.07 | 63M | ✓ | ~24h, 1 GPU |
| Transformer-XL (12L) | 1.06 | 41M | ✗ | Much longer, multi-GPU |
| Transformer-XL (18L) | 1.03 | 88M | ✗ | Much longer, multi-GPU |
| Adaptive Span Transformer | 1.02 | 38M | ✗ | ~24h, multi-GPU |
Key Takeaway
SHA-LSTM achieves within 1% of Transformer-XL performance using only a single consumer GPU and ~24 hours of training. The only models that significantly beat it require either more layers (18L Transformer-XL) or specialized attention mechanisms (Adaptive Span). For researchers and practitioners without access to large compute clusters, SHA-RNN represents a practical path to competitive language modeling.
Production Benefits
Why this matters for deployment
Standard Components
Only uses LSTM, single-head attention, and feed-forward layers. No exotic ops, so it is easy to export to ONNX and deploy with existing optimized frameworks.
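As a sketch of that export path, a toy LSTM + feed-forward block (a stand-in, not the repository's model class) goes through the standard ONNX exporter; this assumes a reasonably recent PyTorch and opset.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in block: the ops involved (LSTM, Linear, GELU) all have
# standard ONNX export support.
class TinyBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.lstm = nn.LSTM(dim, dim)
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.ff(F.gelu(out))

model = TinyBlock().eval()
dummy = torch.randn(32, 1, 64)  # (seq_len, batch, dim)
torch.onnx.export(model, dummy, "tiny_block.onnx", opset_version=13)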
Smaller Memory
Single attention head and LSTM backbone require less memory than multi-head Transformers, especially for long sequences.
Graceful Degradation
If resources are constrained, can fall back to pure LSTM by removing attention. The model still works, just with reduced long-range capability.
The Bigger Picture
Lessons for ML research
Architectural Fashion vs. Practical Efficiency
The ML community has a tendency to declare previous architectures "dead" when new ones emerge. RNNs were declared obsolete after "Attention Is All You Need" (2017), but SHA-RNN shows this was premature. The question isn't "RNNs vs Transformers" but rather "what's the right tool for the job?"
Compute Accessibility
The trend toward ever-larger models trained on massive GPU clusters excludes most researchers and institutions from participating in cutting-edge work. SHA-RNN demonstrates that careful architectural design can achieve competitive results without requiring TPU pods, democratizing access to language modeling research.
The "Found Bug That Works" Phenomenon
The codebase contains this delightful comment:
# BUG: This does _nothing_ as mix isn't set to r ...
# But ... I got good results with this ... so ...
# Let's leave it as is for right now ...

This illustrates a broader point: empirical validation often trumps theoretical purity. Sometimes the "wrong" implementation works well, and understanding why can lead to new insights.
Key Files
Repository structure
| File | Purpose |
|---|---|
| model.py | Core architecture: SHARNN, Block, Attention, Boom classes |
| main.py | Training script with LAMB optimizer and mixed precision |
| generate.py | Inference and text generation utilities |
| splitcross.py | Adaptive softmax for large vocabularies |
| lookahead.py | Lookahead optimizer wrapper |