# Hybrid Mamba
`foreblocks.hybrid_mamba` provides custom SSM (State Space Model) building blocks that combine selective-scan dynamics with sliding-window attention. Two block variants are available:

- `HybridMambaBlock`: pure SSM (Mamba v1-style selective scan) with optional pre-norm
- `HybridMamba2Block`: parallel SSM + sliding-window attention branches fused with a learned gate, RoPE, GQA, and output normalisation (Mamba-2 / SSD-based)
## Installation requirements
The module is pure PyTorch by default. Optional acceleration layers:
| Backend | What it enables | How to activate |
|---|---|---|
| Triton | Fused causal conv and grouped-SSD scan | `pip install triton` |
| CUDA extension | Selective scan CUDA kernel | `precompile_selective_scan_extension()` |
Check availability at runtime:
```python
from foreblocks.hybrid_mamba import TRITON_AVAILABLE, extension_available

print(TRITON_AVAILABLE)       # True if triton is installed
print(extension_available())  # True if the CUDA extension is built
```
## RotaryEmbedding
Rotary Position Embedding (Su et al. 2021, RoFormer). Applied to Q and K inside `SlidingWindowAttention`. Available as a standalone module if you need to add RoPE to your own attention implementation.
```python
import torch
from foreblocks.hybrid_mamba import RotaryEmbedding

rope = RotaryEmbedding(head_dim=64, base=10_000, max_seq_len=2048)

# tensors are (B, H, T, head_dim)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
q_rot, k_rot = rope(q, k)
```
The cache is extended automatically when T > max_seq_len, so setting a generous upper bound is fine.
### Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| `head_dim` | required | Dimension of each attention head. Must be even. |
| `base` | `10_000` | Frequency base (original paper default). |
| `max_seq_len` | `8192` | Pre-built cache length; extended on demand. |
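For intuition, the rotation RoPE applies can be sketched in plain PyTorch. This is a minimal re-implementation for illustration, not the module's actual code; `apply_rope` is a hypothetical helper:

```python
import torch

def apply_rope(x: torch.Tensor, base: float = 10_000.0) -> torch.Tensor:
    """Rotate channel pairs by position-dependent angles (RoPE sketch).

    x: (B, H, T, head_dim) with even head_dim.
    """
    B, H, T, D = x.shape
    # One frequency per channel pair, geometrically spaced as in RoFormer.
    inv_freq = base ** (-torch.arange(0, D, 2, dtype=torch.float32) / D)
    angles = torch.arange(T, dtype=torch.float32)[:, None] * inv_freq  # (T, D/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin  # 2-D rotation of each (x1, x2) pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(2, 8, 16, 64)
q_rot = apply_rope(q)  # same shape; position 0 is left unrotated
```

Because the angle at position 0 is zero, the first time step passes through unchanged; later positions are rotated progressively faster in the low-index channel pairs.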
## HybridMambaBlock
A single Mamba-style block. Expands `d_model` → `d_inner` with a linear input projection, applies a causal depthwise conv, runs the selective scan, then projects back to `d_model`. An optional pre-norm (`use_pre_norm=True`) stabilises training in deep stacks.
```python
import torch
from foreblocks.hybrid_mamba import HybridMambaBlock

block = HybridMambaBlock(
    d_model=256,
    d_inner=512,        # defaults to 2 * d_model
    d_state=16,
    d_conv=4,
    dt_rank=None,       # auto: max(4, ceil(d_model / 16))
    use_cuda_scan=True,
    use_pre_norm=True,  # LayerNorm before in_proj — recommended
)

x = torch.randn(8, 64, 256)  # (batch, seq_len, d_model)
y = block(x)                 # same shape as x
```
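For intuition, the selective scan at the heart of the block is a discretised, input-dependent linear recurrence. A sequential reference sketch (illustrative shapes and parameterisation, not the fused kernel; `selective_scan_ref` is a hypothetical name):

```python
import torch

def selective_scan_ref(x, dt, A, B, C):
    """Sequential reference for a Mamba v1-style selective scan.

    x:  (bsz, T, d_inner)   input sequence
    dt: (bsz, T, d_inner)   per-step time-steps (post-softplus)
    A:  (d_inner, d_state)  state matrix (log-parameterised in practice)
    B:  (bsz, T, d_state)   input-dependent input projection
    C:  (bsz, T, d_state)   input-dependent output projection
    """
    bsz, T, d_inner = x.shape
    d_state = A.shape[1]
    h = torch.zeros(bsz, d_inner, d_state)
    ys = []
    for t in range(T):
        # Zero-order-hold discretisation: h <- exp(dt * A) * h + dt * B * x
        dA = torch.exp(dt[:, t, :, None] * A)                          # (bsz, d_inner, d_state)
        dBx = dt[:, t, :, None] * B[:, t, None, :] * x[:, t, :, None]  # (bsz, d_inner, d_state)
        h = dA * h + dBx
        ys.append(torch.einsum("bds,bs->bd", h, C[:, t]))              # read out the state
    return torch.stack(ys, dim=1)                                      # (bsz, T, d_inner)

x = torch.randn(2, 5, 4)
dt = torch.rand(2, 5, 4) * 0.1
A = -torch.rand(4, 3)          # negative real part keeps the state stable
B_in = torch.randn(2, 5, 3)
C_out = torch.randn(2, 5, 3)
y = selective_scan_ref(x, dt, A, B_in, C_out)  # (2, 5, 4)
```

The fused CUDA/Triton kernels compute the same recurrence in parallel over the sequence dimension.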
### Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| `d_model` | required | Input / output feature dimension |
| `d_inner` | `2 * d_model` | Inner (expanded) dimension |
| `d_state` | `16` | SSM state dimension per feature |
| `d_conv` | `4` | Causal conv kernel size |
| `dt_rank` | auto | Low-rank projection for Δt; `None` → `max(4, ceil(d_model / 16))` |
| `dt_min` / `dt_max` | `1e-4` / `1.0` | Clamp range for the time-step after softplus |
| `use_cuda_scan` | `True` | Use the CUDA kernel if the extension is loaded; falls back to PyTorch otherwise |
| `use_pre_norm` | `True` | Apply LayerNorm to the input before the expansion projection |
## HybridMamba2Block
Combines a multi-head SSD branch (`StructuredStateSpaceDualityBranch`) with a `SlidingWindowAttention` branch. The two outputs are mixed with a sigmoid gate, then passed through an output norm before the final projection:

```text
ssm_out  = SSD( LayerNorm(x) )
attn_out = SlidingWindowAttn( LayerNorm(x) )
gate     = sigmoid( Linear( LayerNorm(x) ) )
mixed    = gate * ssm_out + (1 - gate) * attn_out
output   = out_proj( LayerNorm(mixed) )
```

The output LayerNorm (`out_norm`) stabilises gradients through the mixing gate.
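The mixing step itself is compact. A minimal sketch with the two branch outputs stubbed out as random tensors (the real branches are the SSD and attention modules):

```python
import torch
import torch.nn as nn

d_model = 256
gate_proj = nn.Linear(d_model, d_model)
norm = nn.LayerNorm(d_model)

x = torch.randn(4, 128, d_model)
ssm_out = torch.randn_like(x)   # stand-in for the SSD branch output
attn_out = torch.randn_like(x)  # stand-in for the attention branch output

gate = torch.sigmoid(gate_proj(norm(x)))        # per-feature gate in (0, 1)
mixed = gate * ssm_out + (1 - gate) * attn_out  # convex combination of branches
```

Because the gate is a convex weight per feature, the block can learn to lean on local attention for some channels and on the SSM for others.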
```python
import torch
from foreblocks.hybrid_mamba import HybridMamba2Block

block = HybridMamba2Block(
    d_model=256,
    d_inner=512,
    d_state=16,
    d_conv=4,
    dt_rank=None,
    num_heads=8,
    n_kv_heads=2,  # GQA: 2 KV heads shared across 8 query heads
    window_size=128,
    attn_dropout=0.0,
    use_gated_delta=False,
    rope_base=10_000,
    max_seq_len=2048,
)

x = torch.randn(4, 128, 256)
y = block(x)  # (4, 128, 256)
```
### Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| `d_model` | required | Model dimension |
| `d_inner` | `2 * d_model` | SSM inner dimension |
| `d_state` | `16` | SSM state dimension per head |
| `d_conv` | `4` | Causal conv kernel size in the SSM branch |
| `dt_rank` | auto | Low-rank Δt projection size |
| `num_heads` | `8` | Query heads for attention; head count for SSD |
| `n_kv_heads` | `None` | KV heads for GQA. `None` → standard MHA. Must divide `num_heads`. |
| `window_size` | `128` | Sliding-window size for local causal attention |
| `attn_dropout` | `0.0` | Attention dropout during training |
| `use_gated_delta` | `False` | Add a per-head sigmoid gate on Δt in the SSD branch |
| `rope_base` | `10_000` | RoPE frequency base |
| `max_seq_len` | `8192` | Pre-built RoPE cache length |
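Sliding-window attention restricts each query to the most recent `window_size` positions. A sketch of how such a mask can be built and used with PyTorch SDPA (illustrative, not the module's actual implementation):

```python
import torch
import torch.nn.functional as F

def sliding_window_mask(T: int, window_size: int) -> torch.Tensor:
    """Boolean mask (T, T): True where attention is allowed.

    Query position t may attend to key positions in [t - window_size + 1, t].
    """
    idx = torch.arange(T)
    rel = idx[None, :] - idx[:, None]          # key_pos - query_pos
    return (rel <= 0) & (rel > -window_size)   # causal AND within the window

mask = sliding_window_mask(T=8, window_size=3)
q = k = v = torch.randn(1, 2, 8, 16)           # (B, H, T, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

With a boolean `attn_mask`, `True` entries take part in attention; the `(T, T)` mask broadcasts over batch and head dimensions.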
## Grouped Query Attention (GQA)

Set `n_kv_heads` to a divisor of `num_heads` to enable GQA. With `num_heads=8, n_kv_heads=2` the model uses 4× fewer KV parameters and KV-cache entries than MHA, matching the Llama 3 / Mistral configuration:
```python
block = HybridMamba2Block(
    d_model=512,
    num_heads=16,
    n_kv_heads=4,  # 4 query heads share each KV head
    window_size=256,
)
```
KV heads are broadcast to query heads via `repeat_interleave` before SDPA; the only extra memory cost is the materialised repeat.
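The broadcast step can be sketched as follows (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

B, T, head_dim = 2, 32, 64
num_heads, n_kv_heads = 8, 2

q = torch.randn(B, num_heads, T, head_dim)
k = torch.randn(B, n_kv_heads, T, head_dim)
v = torch.randn(B, n_kv_heads, T, head_dim)

# Each KV head serves num_heads // n_kv_heads query heads.
repeat = num_heads // n_kv_heads
k = k.repeat_interleave(repeat, dim=1)  # (B, num_heads, T, head_dim)
v = v.repeat_interleave(repeat, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Only the projections and cache shrink; after the repeat, SDPA sees standard per-head tensors.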
## Stacking blocks into a model

`TinyHybridMamba2LM` shows the recommended stacking pattern: every `attn_every_n`-th layer uses a `HybridMamba2Block`; all other layers use the cheaper `HybridMambaBlock`.
```python
from foreblocks.hybrid_mamba import TinyHybridMamba2LM

model = TinyHybridMamba2LM(
    vocab_size=50257,
    d_model=512,
    n_layers=12,
    d_state=16,
    d_conv=4,
    num_heads=8,
    n_kv_heads=2,     # GQA across all hybrid blocks
    window_size=256,
    attn_every_n=4,   # HybridMamba2Block at layers 0, 4, 8; rest are plain Mamba
    tie_embeddings=True,
    use_pre_norm=True,
    rope_base=10_000,
    max_seq_len=4096,
)
```
For use as a time-series backbone, replace the embedding and LM head with your own projection layers and feed patch embeddings or raw-feature vectors in place of `input_ids`.
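One possible shape for such a backbone, sketched with `nn.Identity` placeholders standing in for the hybrid blocks so it runs without foreblocks installed (the projection layers and forecasting head are illustrative, not part of the library):

```python
import torch
import torch.nn as nn

class TimeSeriesBackbone(nn.Module):
    """Illustrative wrapper: project features in, stack blocks, project out.

    The nn.Identity placeholders stand in for HybridMambaBlock /
    HybridMamba2Block, both of which are shape-preserving.
    """
    def __init__(self, n_features: int, d_model: int, horizon: int, n_layers: int = 4):
        super().__init__()
        self.in_proj = nn.Linear(n_features, d_model)  # replaces the token embedding
        self.blocks = nn.ModuleList(nn.Identity() for _ in range(n_layers))
        self.head = nn.Linear(d_model, horizon)        # replaces the LM head

    def forward(self, x):            # x: (batch, seq_len, n_features)
        h = self.in_proj(x)
        for block in self.blocks:
            h = block(h)             # each block maps (B, T, d_model) -> (B, T, d_model)
        return self.head(h[:, -1])   # forecast from the last time step

model = TimeSeriesBackbone(n_features=7, d_model=256, horizon=24)
y = model(torch.randn(8, 96, 7))    # (8, 24)
```

Swapping the placeholders for real blocks leaves the surrounding shapes unchanged, since every hybrid block maps `(B, T, d_model)` to `(B, T, d_model)`.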
## Diagnostics

```python
from foreblocks.hybrid_mamba import run_default_diagnostics, benchmark_block

run_default_diagnostics()  # quick correctness checks for ops on the current device

stats = benchmark_block(d_model=256, seq_len=512, batch=8)
print(stats)               # wall-clock and memory stats
```
## Ops reference

| Symbol | Where | Description |
|---|---|---|
| `causal_depthwise_conv1d` | `ops/` | Causal grouped conv; Triton kernel when available |
| `selective_scan` | `ops/` | S4/Mamba v1 scan; CUDA kernel when loaded |
| `grouped_ssd_scan` | `ops/` | Multi-head SSD scan (Mamba-2); Triton kernel when available |
| `dt_prep` | `ops/` | Δt bias add + softplus + clamp |
| `fused_out` | `ops/` | Fused RMSNorm + gate + residual add |
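For reference, `dt_prep` corresponds to the Δt preparation used by both block variants; a plain-PyTorch sketch (`dt_prep_ref` is a hypothetical reference name, not the fused op):

```python
import torch
import torch.nn.functional as F

def dt_prep_ref(dt_raw, dt_bias, dt_min=1e-4, dt_max=1.0):
    """Reference for dt_prep: bias add, softplus, clamp.

    Softplus keeps the time-step positive; the clamp matches the
    dt_min / dt_max constructor defaults.
    """
    return torch.clamp(F.softplus(dt_raw + dt_bias), dt_min, dt_max)

dt = dt_prep_ref(torch.randn(2, 16, 512), torch.zeros(512))  # (2, 16, 512)
```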