🧠

SHA-RNN

Single Headed Attention RNN: Stop Thinking With Your Head

A language model architecture proving that a single attention head combined with LSTMs can match Transformer performance at a fraction of the cost. Train on a single GPU in 24 hours instead of TPU clusters for weeks.

1.07 BPC · 63M params · 1 GPU · 24 hours

🎯

The Provocation

Why transformers might not be needed

The Core Thesis

The paper title is deliberately provocative: "Stop Thinking With Your Head" is a critique of the ML community's obsession with multi-headed attention in Transformers. Stephen Merity demonstrates that a single attention head integrated into an LSTM backbone achieves nearly identical performance to Transformer-XL.

  • Transformer-XL (12 layers): 1.06 BPC, 41M params, multi-GPU infrastructure and much longer training
  • SHA-LSTM (4 layers): 1.07 BPC, 63M params, single Titan V GPU in ~24 hours

Computational Accessibility

Transformers require massive GPU infrastructure and long training times. SHA-RNN achieves competitive results on consumer hardware in ~24 hours.

Overengineering Critique

The field became obsessed with the Transformer architecture without exploring whether simpler approaches could achieve similar results. Multi-head attention may be overkill.

Premature RNN Death

The research community declared RNNs 'dead' even though properly designed hybrid approaches can be competitive with far less complexity.

🏗️

Architecture

Hybrid LSTM + Single-Head Attention

High-Level Structure

Input Embedding (1024 dim)
    ↓
Block 0: LSTM → Feed-Forward (Boom)
    ↓
Block 1: LSTM → Feed-Forward (Boom)
    ↓
Block 2: LSTM → Single-Head Attention → Feed-Forward (Boom)  ← Attention only here!
    ↓
Block 3: LSTM → Feed-Forward (Boom)
    ↓
Output Linear (to vocabulary)

Notice: Attention is only added to one layer (the second-to-last). This minimizes overhead while capturing long-range dependencies.
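
A minimal sketch (a hypothetical SHARNNSketch wrapper, not the repo's exact SHARNN class) of how this stack might be assembled from the Block module shown later on this page, with attention enabled only in the second-to-last block:

import torch.nn as nn

class SHARNNSketch(nn.Module):
    def __init__(self, ntokens, embed_dim=1024, hidden_dim=4096, nlayers=4):
        super().__init__()
        self.embedding = nn.Embedding(ntokens, embed_dim)
        self.blocks = nn.ModuleList([
            Block(embed_dim, hidden_dim, use_attn=(i == nlayers - 2))  # attention in Block 2 only
            for i in range(nlayers)
        ])
        self.decoder = nn.Linear(embed_dim, ntokens)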

Why This Works

  • LSTM provides inherent sequential ordering (no positional embeddings needed)
  • Single attention head captures long-range dependencies up to 5000 tokens
  • Recurrent nature gives strong local inductive bias
  • Feed-forward (Boom) modules add capacity similar to Transformer FFN

Key Differences from Transformers

  • 1 attention head vs 8-16 in Transformers
  • LSTM backbone vs pure attention
  • Attention in 1 layer vs all layers
  • No positional encoding required
  • Memory caching for long context (5000 tokens)

💡

Key Innovations

What makes SHA-RNN special

1. Learnable Scaling from Zero

Instead of letting attention operate at full strength from the start, SHA-RNN initializes its query, key, and value scaling parameters at zero and passes them through sigmoid gates, so the model gradually learns how strongly to use attention. This allows the model to 'discover' attention rather than rely on it from the start.

# Query and key scaling parameters start at zero
self.qs = nn.Parameter(torch.zeros(1, 1, nhid))
self.ks = nn.Parameter(torch.zeros(1, 1, nhid))
self.vs = Overparam(nhid)    # value gating module (innovation 2 below)

# Sigmoid gates scale the projections (sigmoid(0) = 0.5, learned during training)
qs = torch.sigmoid(self.qs)
ks = torch.sigmoid(self.ks)

q, k, v = qs * query, ks * key, self.vs(value)

2. Overparam Module for Value Scaling

The value scaling uses an overparameterized transformation similar to LSTM gating, projecting to 2x dimensions and using sigmoid-tanh gating.

class Overparam(nn.Module):
    def __init__(self, nhid):
        super().__init__()
        self.linear = nn.Linear(nhid, 2 * nhid)

    def forward(self, x):
        x = self.linear(x)
        c, f = x.chunk(2, dim=-1)  # Split into content and forget-style gate
        return torch.sigmoid(f) * torch.tanh(c)  # LSTM-like sigmoid-tanh gating
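
For example, assuming the usual torch imports, the module keeps the input shape while gating every value (a quick smoke test, not code from the repo):

over = Overparam(nhid=1024)
v = torch.randn(128, 16, 1024)   # (seq_len, batch, nhid)
print(over(v).shape)             # torch.Size([128, 16, 1024]) -- same shape, gated values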

3. Memory Caching for Long Context

SHA-RNN caches hidden states from previous sequences, allowing attention to look back up to 5000 tokens. This enables efficient long-context modeling without reprocessing the entire history.

def forward(self, x, hidden, mems=None):
    # Run the LSTM over the current segment
    h, new_hidden = self.lstm(x, hidden)

    # Concatenate cached memory with the current hidden states
    bigh = torch.cat([mems, h], dim=0) if mems is not None else h  # past + current

    # Attention over the extended context
    attn_out = self.attention(query=h, key=bigh, value=bigh)

    # Keep only the most recent tokens in memory (sliding window)
    new_mem = bigh[-self.num_max_positions:]

    return attn_out, new_hidden, new_mem
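
A minimal sketch of how that cached memory might be threaded across consecutive training segments; model, its return signature, and segments are assumptions for illustration, not the repo's actual training loop:

hidden, mems = None, None
for x in segments:                    # consecutive BPTT-length chunks, in order
    output, hidden, mems = model(x, hidden, mems)
    mems = mems.detach()              # attention can look back, but gradients stop here
    hidden = tuple(h.detach() for h in hidden)   # likewise for the recurrent state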

4. Boom: Efficient Feed-Forward

The 'Boom' module is a shortcut-based feed-forward block that is cheaper than a standard Transformer FFN: instead of projecting back down with a second large linear layer, it reshapes the expanded activation into chunks and sums them.

class Boom(nn.Module):
    """Feed-forward with a shortcut reduction for efficiency."""
    def __init__(self, nhid, hidden=None, shortcut=True):
        super().__init__()
        self.nhid = nhid
        hidden = hidden or 4 * nhid
        self.linear = nn.Linear(nhid, hidden)  # Expand 4x (e.g. 1024 -> 4096)
        self.shortcut = shortcut

    def forward(self, x):
        x = F.gelu(self.linear(x))
        if self.shortcut:
            # Reshape into nhid-sized chunks and sum, instead of a second linear layer
            x = x.view(*x.shape[:-1], -1, self.nhid).sum(dim=-2)
        return x
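
The shortcut trick in isolation (a standalone illustration, assuming import torch): the 4x-expanded activation is reshaped into four nhid-sized chunks and summed, rather than projected back down by a second Linear.

x = torch.randn(16, 4096)                      # (batch, 4 * nhid) after the expansion
y = x.view(*x.shape[:-1], -1, 1024).sum(dim=-2)
print(y.shape)                                 # torch.Size([16, 1024])
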
👁️

Single-Head Attention

The core mechanism

Attention Implementation

class Attention(nn.Module):
    def __init__(self, nhid):
        super().__init__()
        self.num_heads = 1  # Single head!

        # Projections
        self.q_proj = nn.Linear(nhid, nhid)
        self.k_proj = nn.Linear(nhid, nhid)
        self.v_proj = nn.Linear(nhid, nhid)
        self.out_proj = nn.Linear(nhid, nhid)

        # Learnable scaling (initialized near zero)
        self.qs = nn.Parameter(torch.zeros(1, 1, nhid))
        self.ks = nn.Parameter(torch.zeros(1, 1, nhid))
        self.vs = Overparam(nhid)  # Overparameterized scaling

    def forward(self, query, key, value, mask=None):
        # Apply learned scaling
        qs = torch.sigmoid(self.qs)
        ks = torch.sigmoid(self.ks)

        q = qs * self.q_proj(query)
        k = ks * self.k_proj(key)
        v = self.vs(self.v_proj(value))

        # Scaled dot-product attention
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)

        return self.out_proj(output)
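
A quick smoke test of this simplified module (assuming batch-first tensors of shape (batch, seq_len, nhid) and the imports plus the Overparam definition from above; the repository's real implementation works with sequence-first tensors and cached memory):

attn = Attention(nhid=1024)
x = torch.randn(8, 128, 1024)
out = attn(query=x, key=x, value=x)   # single-head self-attention over the segment
print(out.shape)                      # torch.Size([8, 128, 1024])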

Why Single Head Works

Multi-head attention in Transformers theoretically allows attending to different representation subspaces. However, SHA-RNN shows that with an LSTM backbone providing strong sequential modeling, a single attention head is sufficient to capture the long-range dependencies that LSTMs miss. The LSTM already provides "multiple heads" of local context processing.

🔄

LSTM + Attention Synergy

Why the combination works

What LSTMs Provide

  • Inherent positional information - No need for positional embeddings, the recurrent structure encodes sequence order
  • Strong local modeling - Excellent at capturing local patterns and n-gram-like features
  • Gated memory - Selective retention of important information through forget/input gates
  • Efficient sequential processing - O(n) complexity for sequence length

What Attention Adds

  • Direct long-range access - Can attend to any position in the 5000-token context window
  • Content-based retrieval - Find relevant context based on semantic similarity, not just position
  • Parallel context access - All positions in the context are equally accessible in one step
  • Gradient highways - Direct paths for gradients to flow to distant positions

The Block Structure

class Block(nn.Module):
    def __init__(self, embed_dim, hidden_dim, use_attn=False):
        super().__init__()
        self.lstm = nn.LSTM(embed_dim, embed_dim)
        self.attn = Attention(embed_dim) if use_attn else None
        self.boom = Boom(embed_dim, hidden_dim)  # Feed-forward

        # Layer norms
        self.ln_start = nn.LayerNorm(embed_dim)
        self.ln_mid = nn.LayerNorm(embed_dim)
        self.ln_mem = nn.LayerNorm(embed_dim)
        self.ln_ff = nn.LayerNorm(embed_dim)

    def forward(self, x, hidden, mem=None):
        # LSTM processing
        x = self.ln_start(x)
        lstm_out, new_hidden = self.lstm(x, hidden)

        # Optional attention (only in layer 2)
        if self.attn is not None:
            x = self.ln_mid(lstm_out)
            if mem is not None:
                context = torch.cat([mem, x], dim=0)
            else:
                context = x
            attn_out = self.attn(x, self.ln_mem(context), context)
            lstm_out = lstm_out + attn_out  # Residual

        # Feed-forward
        x = self.ln_ff(lstm_out)
        x = lstm_out + self.boom(x)  # Residual

        return x, new_hidden

🏋️

Training

How to reproduce

Training Configuration

python -u main.py \
    --epochs 32 \
    --dropouth 0.1 --dropouti 0.1 --dropout 0.1 \
    --data data/enwik8/ \
    --save ENWIK8.pt \
    --log-interval 10 \
    --seed 5512 \
    --optimizer lamb \
    --bptt 1024 \
    --warmup 800 \
    --lr 2e-3 \
    --emsize 1024 \
    --nhid 4096 \
    --nlayers 4 \
    --batch_size 16

The key flags: --optimizer lamb selects the LAMB optimizer (critical for stable training), --bptt 1024 sets the sequence length, --warmup 800 the number of learning-rate warmup steps, --emsize 1024 the embedding dimension, --nhid 4096 the hidden (Boom) dimension, and --nlayers 4 the number of SHA-LSTM layers.

Key Training Choices

  • LAMB optimizer - Layer-wise Adaptive Moments for batch training, crucial for stability
  • 800 step warmup - Gradual LR increase prevents early instability
  • Mixed precision (AMP) - Faster training with minimal quality loss
  • Gradient clipping at 0.3 - Prevents exploding gradients in LSTMs
  • Two-stage training - First 32 epochs at 2e-3, then fine-tune at 1e-3
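
A minimal sketch of how these choices might combine in a single training step. model, optimizer (any LAMB implementation, e.g. the pytorch-lamb package), criterion, and the batch tensors are assumptions for illustration, not the exact code in main.py:

import torch

scaler = torch.cuda.amp.GradScaler()
warmup_steps, base_lr, clip = 800, 2e-3, 0.3

def training_step(model, optimizer, criterion, x, y, hidden, step):
    # Linear learning-rate warmup over the first 800 steps
    lr = base_lr * min(1.0, (step + 1) / warmup_steps)
    for group in optimizer.param_groups:
        group['lr'] = lr

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():              # mixed-precision forward pass
        output, hidden = model(x, hidden)
        loss = criterion(output, y)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                   # unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)   # clip gradients at 0.3
    scaler.step(optimizer)
    scaler.update()
    return loss.item(), hidden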

Hardware Requirements

  • GPU - Single Titan V (12GB) or similar
  • Training time - ~24 hours for full training
  • Epoch time - 30-60 minutes per epoch
  • Memory - Fits in 12GB with batch_size 16
  • Can reduce - Lower batch size or BPTT if memory constrained

Reproducibility Advantage

Unlike Transformers, which often require careful hyperparameter tuning, complex learning rate schedules, and fragile warmup procedures, SHA-RNN training is remarkably stable. If memory becomes an issue you can reduce the batch size or BPTT length, and if you drop the attention layer entirely the model gracefully degrades to a standard LSTM.

📊

Benchmarks

enwik8 byte-level language modeling

Model                       Test BPC   Parameters   LSTM?   Training
Krause mLSTM                1.24       46M          Yes     -
AWD-LSTM                    1.23       44M          Yes     -
SHA-LSTM (this work)        1.07       63M          Yes     ~24h, 1 GPU
Transformer-XL (12L)        1.06       41M          No      Much longer, multi-GPU
Transformer-XL (18L)        1.03       88M          No      Much longer, multi-GPU
Adaptive Span Transformer   1.02       38M          No      ~24h, multi-GPU

Key Takeaway

SHA-LSTM achieves within 1% of Transformer-XL performance using only a single consumer GPU and ~24 hours of training. The only models that significantly beat it require either more layers (18L Transformer-XL) or specialized attention mechanisms (Adaptive Span). For researchers and practitioners without access to large compute clusters, SHA-RNN represents a practical path to competitive language modeling.

🚀

Production Benefits

Why this matters for deployment

Standard Components

Only uses LSTM, single-head attention, and feed-forward layers. No exotic ops, easy to export to ONNX and deploy with existing optimized frameworks.
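
As a rough sketch of what such an export could look like (model, the dummy input shape, and the forward signature are assumptions, and the recurrent hidden state and attention memory handling would need adapting for the real SHA-RNN):

import torch

model.eval()
dummy_input = torch.zeros(128, 1, dtype=torch.long)   # (seq_len, batch) of byte ids
torch.onnx.export(model, (dummy_input,), "sha_rnn.onnx", opset_version=13)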

Smaller Memory

Single attention head and LSTM backbone require less memory than multi-head Transformers, especially for long sequences.

Graceful Degradation

If resources are constrained, can fall back to pure LSTM by removing attention. The model still works, just with reduced long-range capability.

🤔

The Bigger Picture

Lessons for ML research

Architectural Fashion vs. Practical Efficiency

The ML community has a tendency to declare previous architectures "dead" when new ones emerge. RNNs were declared obsolete after "Attention Is All You Need" (2017), but SHA-RNN shows this was premature. The question isn't "RNNs vs Transformers" but rather "what's the right tool for the job?"

Compute Accessibility

The trend toward ever-larger models trained on massive GPU clusters excludes most researchers and institutions from participating in cutting-edge work. SHA-RNN demonstrates that careful architectural design can achieve competitive results without requiring TPU pods, democratizing access to language modeling research.

The "Found Bug That Works" Phenomenon

The codebase contains this delightful comment:

# BUG: This does _nothing_ as mix isn't set to r ...
# But ... I got good results with this ... so ...
# Let's leave it as is for right now ...

This illustrates a broader point: empirical validation often trumps theoretical purity. Sometimes the "wrong" implementation works well, and understanding why can lead to new insights.

📁

Key Files

Repository structure

File            Purpose
model.py        Core architecture: SHARNN, Block, Attention, Boom classes
main.py         Training script with LAMB optimizer and mixed precision
generate.py     Inference and text generation utilities
splitcross.py   Adaptive softmax for large vocabularies
lookahead.py    Lookahead optimizer wrapper

🔗

Resources