Mamba: Linear-Time Sequence Modeling with Selective State Spaces — Under the Hood¶
Source: Mamba: Linear-Time Sequence Modeling with Selective State Spaces — Albert Gu & Tri Dao (arXiv:2312.00752v2, CMU + Princeton)
Overview¶
Mamba is an architecture that replaces the Transformer's quadratic-scaling self-attention with a selective state space model (S6) that runs in linear time. The key insight is that classical SSMs are fast but dumb (they cannot select what to remember), while Transformers are smart but slow (O(L²) compute in the sequence length, with a KV cache that grows with context). Mamba breaks this tradeoff through a hardware-aware algorithm that materializes the expanded state only in fast GPU SRAM rather than slow HBM.
1. The Fundamental Tradeoff: Context Compression¶
All sequence models trade off between two forces:
flowchart LR
subgraph Transformer
attn["Self-Attention\nStores full KV cache\nO(L) memory per layer\nO(L²) compute\nCan selectively attend anywhere"]
end
subgraph RNN
rnn["Recurrence\nFixed-size hidden state h\nO(1) memory per step\nO(L) compute\nCannot selectively forget"]
end
subgraph Mamba
ssm["Selective SSM\nInput-dependent state transitions\nO(1) inference memory\nO(L) compute\nCAN selectively remember/forget"]
end
Transformer -->|"too slow"| gap["efficiency gap"]
RNN -->|"too rigid"| gap
gap -->|"solved by"| Mamba
The efficiency vs. effectiveness axis can be understood as a state compression problem:
- Attention: does not compress (stores all context) → effective but O(L²)
- LTI SSM: compresses to a fixed-size state → fast but cannot select relevant content
- Mamba (S6): input-dependent compression → fast AND content-aware
2. State Space Model Mathematics¶
2.1 The Continuous System¶
A structured SSM maps a 1D input sequence x(t) → y(t) through a hidden state h(t):

h′(t) = A·h(t) + B·x(t)
y(t) = C·h(t)

Where A ∈ ℝᴺˣᴺ is the state transition matrix, B ∈ ℝᴺˣ¹ is the input projection, and C ∈ ℝ¹ˣᴺ is the output projection.
2.2 Discretization — ZOH Rule¶
The continuous parameters (Δ, A, B) are converted to discrete parameters (Ā, B̄) using Zero-Order Hold:

Ā = exp(Δ·A)
B̄ = (Δ·A)⁻¹(exp(Δ·A) − I)·Δ·B
Δ (timescale/step size) is a learnable parameter that controls how much the model "samples" from the input at each discrete time step. Large Δ → more weight on current input; small Δ → rely more on hidden state.
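As a sketch, the ZOH rule for a diagonal A reduces to elementwise operations (NumPy; shapes and values here are illustrative, not from the paper):

```python
import numpy as np

def zoh_discretize(delta, A, B):
    """Zero-Order Hold for diagonal A, elementwise (illustrative sketch).
    Ā = exp(Δ·A);  B̄ = (Δ·A)⁻¹(exp(Δ·A) − 1)·Δ·B = (exp(Δ·A) − 1)/A · B."""
    A_bar = np.exp(delta * A)
    B_bar = (A_bar - 1.0) / A * B
    return A_bar, B_bar

# For small Δ, B̄ approaches the first-order (Euler) approximation Δ·B
A_bar, B_bar = zoh_discretize(np.array([0.01]), np.array([-1.0]), np.array([1.0]))
```

Note the small-Δ limit: B̄ → Δ·B, which is the first-order simplification some SSM implementations use in place of the full ZOH expression for B.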
flowchart LR
cont_params["Continuous params\n(Δ, A, B)"]
zoh["Zero-Order Hold:\nĀ = exp(Δ·A)\nB̄ = ZOH formula"]
disc_params["Discrete params\n(Ā, B̄, C)"]
recurrence["Recurrence mode:\nhₜ = Ā·hₜ₋₁ + B̄·xₜ\nyₜ = C·hₜ"]
convolution["Convolution mode:\ny = x * K\nK = (CB̄, CĀB̄, CĀ²B̄, ..., CĀᵏB̄, ...)"]
cont_params --> zoh --> disc_params
disc_params --> recurrence
disc_params --> convolution
recurrence -->|inference| out["O(1) per step"]
convolution -->|training| train["O(L log L) parallel"]
2.3 Dual Computation Paths¶
stateDiagram-v2
[*] --> Mode
Mode --> ConvMode: Training (full sequence visible)
Mode --> RecMode: Inference (token-by-token)
ConvMode --> Parallel: Compute global kernel K\nFFT-based O(L log L)\nAll tokens processed simultaneously
RecMode --> Sequential: Maintain hidden state h\nO(1) per new token\nOnly h + xₜ needed
Critical property: S4 (classical SSM) can switch modes because A, B, C, Δ are time-invariant constants — the convolution kernel K is the same for all positions.
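This equivalence can be checked numerically; a small NumPy demonstration (all sizes and values are arbitrary, chosen only to illustrate that the two modes agree for a time-invariant SSM):

```python
import numpy as np

# Illustrative discrete LTI SSM: fixed Ā, B̄, C for every position
rng = np.random.default_rng(0)
N, L = 4, 8
A_bar = np.diag(rng.uniform(0.1, 0.9, N))   # stable, time-invariant
B_bar = rng.normal(size=(N, 1))
C = rng.normal(size=(1, N))
x = rng.normal(size=L)

# Recurrence mode: h_t = Ā·h_{t-1} + B̄·x_t, y_t = C·h_t
h = np.zeros((N, 1))
y_rec = []
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y_rec.append((C @ h).item())

# Convolution mode: global kernel K = (C·B̄, C·Ā·B̄, C·Ā²·B̄, ...), causal conv
K = [(C @ np.linalg.matrix_power(A_bar, k) @ B_bar).item() for k in range(L)]
y_conv = [sum(K[k] * x[t - k] for k in range(t + 1)) for t in range(L)]
```

The same kernel K works at every position precisely because the parameters are constant; once they depend on t, no single K exists.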
3. The Selection Mechanism — S6¶
3.1 The Problem with LTI (Linear Time Invariance)¶
flowchart TD
input["Input: [A, B, C, D, -, -, A]"]
task["Task: Copy only token A, ignore others"]
lti["LTI SSM: A,B,C,Δ = constants\nCannot distinguish relevant vs. irrelevant tokens\nSame recurrence dynamics for every token"]
select["S6 (Selective): B,C,Δ = f(xₜ)\nDifferent dynamics per token\nCan gate out irrelevant inputs"]
input --> task
task --> lti
task --> select
lti -->|"fails Selective Copying"| fail["All tokens treated equally"]
select -->|"passes Selective Copying"| pass["Filters irrelevant, retains relevant"]
3.2 Making Parameters Input-Dependent¶
The key change from S4 → S6: B, C, Δ become functions of the current input xₜ:
| Parameter | S4 (LTI) | S6 (Selective) | Tensor shape |
|---|---|---|---|
| A | (D, N) constant | (D, N) constant | Fixed |
| B | (D, N) constant | B = Linearₙ(x) | (B, L, N) |
| C | (D, N) constant | C = Linearₙ(x) | (B, L, N) |
| Δ | (D,) constant | Δ = Broadcast(Linear₁(x)) | (B, L, D) |
sequenceDiagram
participant Input as xₜ ∈ ℝᴰ
participant Linear as Linear Projections
participant SSM as Selective SSM
participant State as Hidden State hₜ
Input->>Linear: project input
Linear-->>SSM: Bₜ = Linearₙ(xₜ)
Linear-->>SSM: Cₜ = Linearₙ(xₜ)
Linear-->>SSM: Δₜ = softplus(param + Linear₁(xₜ))
SSM->>SSM: Discretize: Āₜ = exp(Δₜ·A), B̄ₜ = ZOH(Δₜ,A,Bₜ)
SSM->>State: hₜ = Āₜ·hₜ₋₁ + B̄ₜ·xₜ
State-->>SSM: yₜ = Cₜ·hₜ
Consequence: The model is now time-varying — the recurrence parameters change at every step. This breaks the LTI equivalence to convolution, requiring a new computation strategy.
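The time-varying recurrence can be sketched as a sequential reference implementation (not the fused kernel), assuming the common first-order simplification B̄ₜ ≈ Δₜ·Bₜ; shapes follow the table above, and all names here are illustrative:

```python
import numpy as np

def selective_scan_ref(x, A, B, C, delta):
    """Sequential reference for the S6 recurrence (sketch, not the fused kernel).
    x: (L, D); A: (D, N) constant; B, C: (L, N) input-dependent; delta: (L, D).
    Uses Ā_t = exp(Δ_t·A) and the first-order simplification B̄_t ≈ Δ_t·B_t."""
    L, D = x.shape
    N = A.shape[1]
    h = np.zeros((D, N))
    y = np.zeros((L, D))
    for t in range(L):
        A_bar = np.exp(delta[t][:, None] * A)        # (D, N), elementwise
        B_bar = delta[t][:, None] * B[t][None, :]    # (D, N)
        h = A_bar * h + B_bar * x[t][:, None]        # h_t = Ā_t·h_{t-1} + B̄_t·x_t
        y[t] = h @ C[t]                              # y_t = C_t·h_t, per channel
    return y

rng = np.random.default_rng(0)
L, D, N = 6, 3, 4
y = selective_scan_ref(rng.normal(size=(L, D)),
                       -np.abs(rng.normal(size=(D, N))),   # A: negative reals
                       rng.normal(size=(L, N)), rng.normal(size=(L, N)),
                       np.abs(rng.normal(size=(L, D))))
```

Because Āₜ and B̄ₜ are recomputed from xₜ at every step, this loop cannot be replaced by a single convolution; it must be computed as a scan.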
4. Hardware-Aware Selective Scan — The GPU Algorithm¶
4.1 The Memory Hierarchy Problem¶
block-beta
columns 1
block:gpu["GPU Memory Hierarchy"]:1
sram["SRAM (on-chip): ~20MB, ~19 TB/s bandwidth\nFast but tiny"]
hbm["HBM (off-chip): ~80GB, ~2 TB/s bandwidth\nSlow but large"]
end
note1["Naive approach: materialize full state h ∈ ℝ^(B,L,D,N)\nB=batch, L=seq len, D=channels, N=state dim\nFor L=1000, D=1024, N=16 (fp32): ~65MB per batch element per layer\n→ far exceeds ~20MB SRAM"]
Root cause of the bottleneck: The expanded state h has shape (B, L, D, N) — a factor of N (~16–64) larger than input/output (B, L, D). Writing this to HBM on every step is bandwidth-bound.
4.2 The Solution: Kernel Fusion + SRAM-Only State¶
flowchart TD
subgraph Standard["Naive Approach (slow)"]
load1["Load (Δ,A,B,C) from HBM"]
discretize1["Discretize → (Ā,B̄) — write to HBM"]
scan1["Load (Ā,B̄) from HBM for scan"]
state1["Compute h — materialize in HBM (LARGE)"]
output1["Compute y — write to HBM"]
end
subgraph Fused["Mamba Fused Kernel (fast)"]
load2["Load (Δ,A,B,C) from HBM to SRAM"]
fused["All in SRAM:\n• discretize Δ → Ā,B̄\n• run recurrence scan\n• compute y = C·h"]
write2["Write only y ∈ (B,L,D) back to HBM\nNever materialize expanded h in HBM"]
end
Standard -->|"memory bandwidth bottleneck"| slow["3-10× slower"]
Fused -->|"all heavy computation in SRAM"| fast["3× faster on A100"]
4.3 Parallel Scan Algorithm¶
Even though the recurrence hₜ = Āₜ·hₜ₋₁ + B̄ₜ·xₜ appears sequential, it can be parallelized using prefix scan (Blelloch 1990):
flowchart TD
seq["Sequence: h₀, h₁, h₂, h₃, h₄, h₅, h₆, h₇"]
subgraph "Parallel Scan (O(log L) depth)"
step1["Step 1: pairs (0,1)(2,3)(4,5)(6,7)\nCompute Ā₁h₀+B̄₁x₁, etc."]
step2["Step 2: pairs (0..1,2..3)(4..5,6..7)\nMerge prefix computations"]
step3["Step 3: global prefix\nAll hₜ known simultaneously"]
end
seq --> step1 --> step2 --> step3
step3 --> out["All outputs in O(L) work, O(log L) depth"]
Key property: An associative binary operator exists for the (Ā, B̄x) recurrence, enabling prefix scan. Each merge operation: (Ā₂, B̄₂x₂) ∘ (Ā₁, B̄₁x₁) = (Ā₂·Ā₁, Ā₂·B̄₁x₁ + B̄₂x₂).
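The merge operator can be verified directly; a scalar sketch (values are arbitrary) showing that any grouping of merges reproduces the sequential recurrence:

```python
from functools import reduce

def combine(e1, e2):
    """Merge for the scan: apply e1 first, then e2.
    (a2, b2) ∘ (a1, b1) = (a2·a1, a2·b1 + b2), here in left-to-right order."""
    a1, b1 = e1
    a2, b2 = e2
    return (a2 * a1, a2 * b1 + b2)

# Recurrence h_t = a_t·h_{t-1} + b_t with h_0 = 0; each element is (a_t, b_t)
elems = [(0.5, 1.0), (0.8, 2.0), (0.9, -1.0), (0.3, 0.5)]

# Sequential reference
h = 0.0
for a, b in elems:
    h = a * h + b

# Left-to-right reduce and a balanced (tree-shaped) grouping agree with it,
# which is exactly the property a parallel prefix scan exploits
left = reduce(combine, elems)
grouped = combine(combine(elems[0], elems[1]), combine(elems[2], elems[3]))
```

The balanced grouping is what gives O(log L) depth: independent pairs merge in parallel at each level.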
4.4 Recomputation for Backpropagation¶
sequenceDiagram
participant Forward
participant SRAM
participant HBM
Forward->>SRAM: load (Δ,A,B,C,x)
SRAM->>SRAM: compute all intermediate h states
SRAM->>HBM: save only y (small), discard h (large)
Note over Forward: Backward pass needs intermediate h
Forward->>SRAM: reload (Δ,A,B,C,x) from HBM
SRAM->>SRAM: recompute h on the fly during backward
Note over SRAM: Extra compute but saves N× memory
This is the same technique as FlashAttention's recomputation — trade compute for memory. Mamba achieves the same memory footprint as FlashAttention despite operating on a larger expanded state.
5. Mamba Block Architecture¶
5.1 Single Block Design¶
flowchart TD
input["Input x ∈ ℝ^(B,L,D)"]
norm["LayerNorm"]
linear1["Linear: D → E·D (expand, E=2)"]
linear2["Linear: D → E·D (gating branch)"]
conv["Causal Conv1D (local convolution)"]
silu["SiLU activation"]
s6["Selective SSM (S6)"]
gate["Element-wise × (gating)"]
proj["Linear: E·D → D (project back)"]
residual["+ Residual connection"]
input --> norm
norm --> linear1
norm --> linear2
linear1 --> conv --> silu --> s6
s6 --> gate
linear2 --> gate
gate --> proj --> residual
input --> residual
Parameter count for one block with expansion E=2:
- Input/output projections: 2ED² + ED² = 3ED²
- SSM parameters (Δ, A, B, C projections): much smaller, on the order of D·N
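Plugging in a concrete width makes the split obvious (D=1024 and E=2 are illustrative choices, not values from the paper):

```python
def mamba_block_proj_params(D, E=2):
    """Dominant projection parameters per block: 2·E·D² (in) + E·D² (out) = 3·E·D²."""
    return 3 * E * D * D

# D = 1024, E = 2 → 6·D² ≈ 6.3M parameters per block,
# dwarfing the ~D·N-scale SSM parameters (N is typically 16-64)
n = mamba_block_proj_params(1024)
```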
5.2 Full Mamba Architecture¶
flowchart TD
embed["Token Embedding: vocab → D"]
b1["Mamba Block 1"]
b2["Mamba Block 2"]
dots["..."]
bn["Mamba Block n"]
norm_final["Final LayerNorm"]
head["LM Head: D → vocab"]
embed --> b1 --> b2 --> dots --> bn --> norm_final --> head
b1 -->|"residual stream"| b2
b2 -->|"residual stream"| dots
vs. Transformer: No MHA blocks, no separate MLP blocks. Mamba block integrates both the mixing (SSM) and transformation (gating) in a single homogeneous unit.
6. Selection Mechanism Interpretations¶
6.1 Connection to Gating in RNNs¶
The Δ parameter with softplus activation has a deep connection to LSTM forget gates:
flowchart LR
delta_small["Small Δₜ\n(Δ→0)"]
delta_large["Large Δₜ\n(Δ→∞)"]
delta_small -->|"Ā = exp(Δ·A) ≈ I"| forget_not["Keep hidden state\n(remember context)"]
delta_large -->|"Ā → 0 (for A<0)"| forget_yes["Reset hidden state\n(ignore context, focus on current input)"]
delta_large -->|"B̄ → −A⁻¹·Bₜ (ZOH limit)"| attend["Current token fully attended"]
When Δ is large: the model attends to the current token. When Δ is small: the model passes the hidden state unchanged. This is the mechanism that enables selective copying and induction heads.
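The two limits can be checked with a single scalar recurrence step (a, h_prev, and x_t are arbitrary illustrative values):

```python
import numpy as np

a = -1.0                 # one diagonal entry of A (negative real)
h_prev, x_t = 5.0, 2.0   # illustrative previous state and current input

def step(delta):
    """One scalar recurrence step with ZOH discretization."""
    A_bar = np.exp(delta * a)
    B_bar = (A_bar - 1.0) / a          # ZOH formula in the scalar case
    return A_bar * h_prev + B_bar * x_t

h_small = step(1e-4)   # Δ → 0: Ā ≈ 1, state passes through almost unchanged
h_large = step(50.0)   # Δ → ∞: Ā ≈ 0, state reset; output ≈ (−1/a)·x_t
```

So a single learned scalar per channel interpolates between "ignore this token" and "reset on this token", which is the forget-gate behavior described above.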
6.2 Selective Copy Task — Data Flow¶
sequenceDiagram
participant Tokens
participant Δ
participant State as Hidden State h
participant Output
Tokens->>Δ: token A (relevant) → Δ large
Δ->>State: Ā≈0: reset state, B̄ large: write A into h
Tokens->>Δ: token - (irrelevant) → Δ small
Δ->>State: Ā≈I: keep A in h unchanged
Tokens->>Δ: token A (relevant again)
Δ->>State: write A again
State->>Output: recall A from h at output position
7. Memory Layout — Tensor Dimensions¶
block-beta
columns 3
B["B (batch)"] L["L (sequence)"] D["D (channels/d_model)"]
block:input["Input x"]:3
xshape["(B, L, D)"]
end
block:expanded["Expanded state h (SRAM only)"]:3
hshape["(B, L, D, N)\nN = SSM state dim (~16-64)\nNEVER written to HBM"]
end
block:output["Output y"]:3
yshape["(B, L, D)"]
end
block:params["Selective params"]:3
bshape["B: (B,L,N)"] cshape["C: (B,L,N)"] dshape["Δ: (B,L,D)"]
end
Memory savings vs. naive: Instead of writing (B,L,D,N) to HBM (~N× larger), only (B,L,D) is written. For N=16, that's 16× less memory bandwidth per layer.
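A back-of-envelope check of the bandwidth claim (the fp32 dtype and sizes are assumptions for illustration):

```python
def hbm_write_bytes(batch, L, D, N, dtype_bytes=4):
    """Bytes written to HBM per layer: naive (materialize h) vs fused (only y)."""
    naive = batch * L * D * N * dtype_bytes   # h has shape (B, L, D, N)
    fused = batch * L * D * dtype_bytes       # y has shape (B, L, D)
    return naive, fused

naive, fused = hbm_write_bytes(1, 1000, 1024, 16)
ratio = naive // fused   # the N× savings factor
```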
8. Performance Characteristics¶
8.1 Complexity Comparison¶
| Model | Training FLOPs | Inference Memory | Inference per step |
|---|---|---|---|
| Transformer | O(L²·D) | O(L·D) KV cache | O(L·D) growing |
| SSM (S4) | O(L·D·N·log L) | O(D·N) constant | O(D·N) constant |
| Mamba (S6) | O(L·D·N) | O(D·N) constant | O(D·N) constant |
graph LR
subgraph "Inference Memory per token"
transformer_mem["Transformer: KV cache grows with L\nAt L=100K: ~10GB just for cache"]
mamba_mem["Mamba: fixed hidden state h\nAt any L: O(D·N) = constant"]
end
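The contrast in the diagram can be made concrete; the layer count, model width, and fp16 dtype below are assumptions chosen only to show the scaling:

```python
def kv_cache_bytes(L, D, n_layers, dtype_bytes=2):
    """Transformer inference cache: K and V of shape (L, D) per layer; grows with L."""
    return 2 * L * D * n_layers * dtype_bytes

def ssm_state_bytes(D, N, n_layers, dtype_bytes=2):
    """Mamba inference state: h of shape (D, N) per layer; independent of L."""
    return D * N * n_layers * dtype_bytes

kv_1k   = kv_cache_bytes(1_000,   4096, 32)   # KV cache at 1K context
kv_100k = kv_cache_bytes(100_000, 4096, 32)   # 100× larger at 100K context
state   = ssm_state_bytes(4096, 16, 32)       # same size at any context length
```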
8.2 Throughput — 5× vs Transformers¶
flowchart LR
trans_inf["Transformer inference:\n1. Load K,V cache (grows with context)\n2. Compute attention over all past tokens\n3. Bandwidth-limited by cache size"]
mamba_inf["Mamba inference:\n1. Load only current state h (fixed size)\n2. One matrix-vector multiply per token\n3. Cache-friendly, constant bandwidth"]
trans_inf -->|"L=1000 context"| slow2["~10ms per token"]
mamba_inf -->|"L=1000 context"| fast2["~2ms per token (5× faster)"]
9. Why A Must Remain Constant¶
Keeping A input-independent (and diagonal) is what keeps the parallel scan cheap. The recurrence is linear, so an associative merge operator always exists; but with a dense, input-dependent Aₜ, computing Āₜ = exp(Δₜ·Aₜ) would require a full matrix exponential per step, and every scan merge would become an N×N matrix product. Selectivity still reaches Āₜ through the input-dependent Δₜ in Āₜ = exp(Δₜ·A):
flowchart TD
cheap{"Is the scan merge cheap?"}
yes["A constant and diagonal\nĀₜ = exp(Δₜ·A) computed elementwise\nMerge = elementwise multiply-add\nO(log L) depth, hardware-friendly"]
no["A dense and input-dependent\nMatrix exponential per step\nN×N matrix product in every merge\nParallel scan possible but no longer cheap"]
cheap -->|"A constant"| yes
cheap -->|"A input-dependent"| no
The authors keep A fixed (as a diagonal matrix of negative reals), while making B, C, Δ input-dependent. This is the minimal change that enables selectivity while preserving hardware efficiency.
10. Data Flow Summary — Forward Pass¶
sequenceDiagram
participant Input as x ∈ (B,L,D)
participant LinearProj as Linear Projections
participant Discretize as Discretizer (ZOH)
participant Scan as Parallel Scan (SRAM)
participant Output as y ∈ (B,L,D)
Input->>LinearProj: project x → B:(B,L,N), C:(B,L,N), Δ:(B,L,D)
LinearProj->>Discretize: compute Ā=exp(Δ·A), B̄=ZOH(Δ,A,B)
Note over Discretize: All in SRAM — no HBM writes
Discretize->>Scan: Ā:(B,L,D,N), B̄:(B,L,D,N), C:(B,L,N)
Scan->>Scan: Parallel prefix scan\nhₜ = Āₜhₜ₋₁ + B̄ₜxₜ
Scan->>Output: yₜ = Cₜ · hₜ\nwrite (B,L,D) to HBM
Key Invariants¶
| Property | Classical S4 | Mamba S6 |
|---|---|---|
| Parameters | Time-invariant | Input-dependent (B,C,Δ) |
| Computation mode | Conv (train) + RNN (infer) | Scan (train) + RNN (infer) |
| State materialization | Can use convolution kernel K | SRAM-only fused kernel |
| Selectivity | No content-awareness | Full content-aware gating |
| Inference memory | O(D·N) | O(D·N) |
| Training complexity | O(L log L) | O(L) |
| Backprop strategy | Standard | Recomputation (like FlashAttention) |
Mamba demonstrates that the Transformer's quadratic complexity is not a fundamental requirement for powerful sequence modeling; it is an artifact of not compressing context selectively. By making compression content-aware via input-dependent parameters and computing efficiently via an SRAM-resident parallel scan, Mamba achieves linear-time training, constant memory per inference step, and Transformer-matching quality.