
Hybrid Mamba

foreblocks.hybrid_mamba provides custom SSM (State Space Model) building blocks that combine selective-scan dynamics with sliding-window attention. Two block variants are available:

  • HybridMambaBlock — pure SSM (Mamba v1-style selective scan) with optional pre-norm
  • HybridMamba2Block — parallel SSM + sliding-window attention branches fused with a learned gate, RoPE, GQA, and output normalisation (Mamba-2 / SSD-based)

Installation requirements

The module is pure PyTorch by default. Optional acceleration layers:

| Backend | What it enables | How to activate |
|---|---|---|
| Triton | Fused causal conv and grouped-SSD scan | pip install triton |
| CUDA extension | Selective scan CUDA kernel | precompile_selective_scan_extension() |

Check availability at runtime:

from foreblocks.hybrid_mamba import TRITON_AVAILABLE, extension_available

print(TRITON_AVAILABLE)       # True if triton is installed
print(extension_available())  # True if CUDA extension is built

RotaryEmbedding

Rotary Position Embedding (Su et al. 2021 — RoFormer). Applied to Q and K inside SlidingWindowAttention. Available as a standalone module if you need to add RoPE to your own attention implementation.

from foreblocks.hybrid_mamba import RotaryEmbedding

rope = RotaryEmbedding(head_dim=64, base=10_000, max_seq_len=2048)

# tensors are (B, H, T, head_dim)
q_rot, k_rot = rope(q, k)

The cache is extended automatically when T > max_seq_len, so setting a generous upper bound is fine.
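The rotation itself is easy to illustrate. The sketch below is a self-contained NumPy illustration of the standard rotate-half RoPE formulation from the RoFormer paper — an approximation for intuition, not the module's actual implementation, which operates on (B, H, T, head_dim) tensors:

```python
import numpy as np

def rope_rotate(x, base=10_000):
    """Apply rotary position embedding to x of shape (T, head_dim).

    Illustrative rotate-half formulation; head_dim must be even.
    Each adjacent pair of dimensions is rotated by a position- and
    frequency-dependent angle.
    """
    T, d = x.shape
    # one frequency per pair of dimensions
    inv_freq = 1.0 / (base ** (np.arange(0, d, 2) / d))   # (d/2,)
    angles = np.outer(np.arange(T), inv_freq)             # (T, d/2)
    cos = np.repeat(np.cos(angles), 2, axis=-1)           # (T, d)
    sin = np.repeat(np.sin(angles), 2, axis=-1)
    # rotate-half: (x1, x2) -> (-x2, x1) within each adjacent pair
    x_rot = np.empty_like(x)
    x_rot[:, 0::2] = -x[:, 1::2]
    x_rot[:, 1::2] = x[:, 0::2]
    return x * cos + x_rot * sin
```

Because each pair is a pure 2-D rotation, vector norms are preserved: Q·K magnitudes are untouched and only the relative phase between positions changes, which is what makes the attention scores depend on relative offsets.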

Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| head_dim | required | Dimension of each attention head. Must be even. |
| base | 10_000 | Frequency base (original paper default). |
| max_seq_len | 8192 | Pre-built cache length; extended on demand. |

HybridMambaBlock

A single Mamba-style block. Expands d_model → d_inner, applies a causal depthwise conv, runs the selective scan, then projects back to d_model. An optional pre-norm (use_pre_norm=True) stabilises training in deep stacks.

import torch
from foreblocks.hybrid_mamba import HybridMambaBlock

block = HybridMambaBlock(
    d_model=256,
    d_inner=512,       # defaults to 2 * d_model
    d_state=16,
    d_conv=4,
    dt_rank=None,      # auto: max(4, ceil(d_model / 16))
    use_cuda_scan=True,
    use_pre_norm=True, # LayerNorm before in_proj — recommended
)

x = torch.randn(8, 64, 256)  # (batch, seq_len, d_model)
y = block(x)                 # same shape as x

Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| d_model | required | Input / output feature dimension |
| d_inner | 2 * d_model | Inner (expanded) dimension |
| d_state | 16 | SSM state dimension per feature |
| d_conv | 4 | Causal conv kernel size |
| dt_rank | auto | Low-rank projection for Δt; None → max(4, ceil(d_model / 16)) |
| dt_min / dt_max | 1e-4 / 1.0 | Clamp range for the time-step after softplus |
| use_cuda_scan | True | Use CUDA kernel if extension is loaded; falls back to PyTorch otherwise |
| use_pre_norm | True | Apply LayerNorm on the input before the expansion projection |
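The dt_min / dt_max clamp applied after softplus (the dt_prep op listed in the Ops reference) can be sketched in a few lines. This is an illustration of the described behaviour under the default bounds, not the fused kernel itself:

```python
import numpy as np

def dt_prep(dt_raw, dt_bias, dt_min=1e-4, dt_max=1.0):
    """Δt preparation: bias add, softplus, then clamp to [dt_min, dt_max].

    Softplus keeps Δt strictly positive; the clamp bounds the effective
    discretisation step so the scan stays numerically well behaved.
    """
    dt = np.logaddexp(0.0, dt_raw + dt_bias)  # softplus(x) = log(1 + e^x)
    return np.clip(dt, dt_min, dt_max)
```

Very negative logits land on dt_min (a near-frozen state), very positive ones saturate at dt_max (state fully refreshed each step).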

HybridMamba2Block

Combines a multi-head SSD branch (StructuredStateSpaceDualityBranch) with a SlidingWindowAttention branch. The two outputs are mixed with a sigmoid gate, then passed through an output norm before the final projection:

ssm_out  = SSD( LayerNorm(x) )
attn_out = SlidingWindowAttn( LayerNorm(x) )
gate     = sigmoid( Linear( LayerNorm(x) ) )
mixed    = gate * ssm_out + (1 − gate) * attn_out
output   = out_proj( LayerNorm(mixed) )

The output LayerNorm (out_norm) stabilises gradients through the mixing gate.
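The mixing path above can be mirrored in a short NumPy sketch. Branch internals are replaced by placeholder callables and the weights are random — this illustrates the gate/mix/norm ordering, it is not the module's code:

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

d = 16
W_gate = rng.standard_normal((d, d)) / np.sqrt(d)  # gate projection (toy)
W_out = rng.standard_normal((d, d)) / np.sqrt(d)   # out_proj (toy)

def hybrid_mix(x, ssm_branch, attn_branch):
    """Gate/mix/norm ordering of the block, branches passed as callables."""
    h = layer_norm(x)
    ssm_out = ssm_branch(h)
    attn_out = attn_branch(h)
    gate = sigmoid(h @ W_gate)                 # per-element gate in (0, 1)
    mixed = gate * ssm_out + (1 - gate) * attn_out
    return layer_norm(mixed) @ W_out           # out_norm then out_proj
```

Since the gate forms a convex combination, the mixed activations stay within the range spanned by the two branch outputs; when both branches agree, the gate value is irrelevant.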

from foreblocks.hybrid_mamba import HybridMamba2Block

block = HybridMamba2Block(
    d_model=256,
    d_inner=512,
    d_state=16,
    d_conv=4,
    dt_rank=None,
    num_heads=8,
    n_kv_heads=2,       # GQA: 2 KV heads shared across 8 query heads
    window_size=128,
    attn_dropout=0.0,
    use_gated_delta=False,
    rope_base=10_000,
    max_seq_len=2048,
)

x = torch.randn(4, 128, 256)
y = block(x)  # (4, 128, 256)

Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| d_model | required | Model dimension |
| d_inner | 2 * d_model | SSM inner dimension |
| d_state | 16 | SSM state dimension per head |
| d_conv | 4 | Causal conv kernel size in the SSM branch |
| dt_rank | auto | Low-rank Δt projection size |
| num_heads | 8 | Query heads for attention; head count for SSD |
| n_kv_heads | None | KV heads for GQA. None → standard MHA. Must divide num_heads. |
| window_size | 128 | Sliding-window size for local causal attention |
| attn_dropout | 0.0 | Attention dropout during training |
| use_gated_delta | False | Add per-head sigmoid gate on Δt in the SSD branch |
| rope_base | 10_000 | RoPE frequency base |
| max_seq_len | 8192 | Pre-built RoPE cache length |
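The effect of window_size on the attention pattern can be sketched as a mask. This assumes one common convention — each query attends to itself and the previous window_size − 1 positions; the module's exact boundary convention may differ:

```python
import numpy as np

def sliding_window_mask(T, window_size):
    """Boolean mask of shape (T, T): True where query i may attend key j.

    Causal and local: j <= i and i - j < window_size.
    """
    i = np.arange(T)[:, None]
    j = np.arange(T)[None, :]
    return (j <= i) & (i - j < window_size)
```

With window_size=128 each token sees at most 128 keys regardless of sequence length, so attention cost grows linearly in T while the SSM branch carries the long-range state.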

Grouped Query Attention (GQA)

Set n_kv_heads to a divisor of num_heads to enable GQA. With num_heads=8, n_kv_heads=2 the model stores 4× fewer KV parameters and KV-cache entries than full MHA — the same 4:1 query-to-KV ratio used by Llama 3 and Mistral:

block = HybridMamba2Block(
    d_model=512,
    num_heads=16,
    n_kv_heads=4,   # 4 query heads share each KV head
    window_size=256,
)

KV heads are broadcast to query heads via repeat_interleave before SDPA — no extra memory allocation beyond the repeat.
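The broadcast step can be illustrated in NumPy, where np.repeat along the head axis mirrors torch.repeat_interleave(dim=1) — consecutive query heads share one KV head. An illustration of the described mechanism, not the module's code:

```python
import numpy as np

def gqa_expand_kv(k, v, num_heads):
    """Broadcast KV heads up to the query-head count for GQA.

    k, v: (B, n_kv_heads, T, head_dim); returns tensors with num_heads
    heads, where each group of num_heads // n_kv_heads consecutive query
    heads shares a single KV head.
    """
    n_kv = k.shape[1]
    assert num_heads % n_kv == 0, "n_kv_heads must divide num_heads"
    group = num_heads // n_kv
    return np.repeat(k, group, axis=1), np.repeat(v, group, axis=1)
```

Only the expanded views feed the attention product; the parameters and cache remain at n_kv_heads, which is where the memory saving comes from.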

Stacking blocks into a model

TinyHybridMamba2LM shows the recommended stacking pattern: every attn_every_n-th layer is a HybridMamba2Block; the remaining layers use the cheaper HybridMambaBlock.

from foreblocks.hybrid_mamba import TinyHybridMamba2LM

model = TinyHybridMamba2LM(
    vocab_size=50257,
    d_model=512,
    n_layers=12,
    d_state=16,
    d_conv=4,
    num_heads=8,
    n_kv_heads=2,       # GQA across all hybrid blocks
    window_size=256,
    attn_every_n=4,     # HybridMamba2Block at layers 0, 4, 8; rest are plain Mamba
    tie_embeddings=True,
    use_pre_norm=True,
    rope_base=10_000,
    max_seq_len=4096,
)
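The layer-selection rule implied by the attn_every_n comment can be written out explicitly. This assumes the indexing shown in the example above (hybrid blocks at layers 0, attn_every_n, 2·attn_every_n, …) — an assumption about the model's internals, shown only to make the pattern concrete:

```python
def layer_plan(n_layers, attn_every_n):
    """Block type per layer: hybrid (SSM + attention) every
    attn_every_n-th layer starting at 0, plain Mamba elsewhere."""
    return [
        "HybridMamba2Block" if i % attn_every_n == 0 else "HybridMambaBlock"
        for i in range(n_layers)
    ]
```

With n_layers=12 and attn_every_n=4 this yields hybrid blocks at layers 0, 4, and 8, matching the comment in the constructor call above.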

For use as a time-series backbone, replace the embedding + LM-head with your own projection layers and feed patch embeddings or raw-feature vectors in place of input_ids.

Diagnostics

from foreblocks.hybrid_mamba import run_default_diagnostics, benchmark_block

run_default_diagnostics()   # quick correctness checks for ops on current device

stats = benchmark_block(d_model=256, seq_len=512, batch=8)
print(stats)  # wall-clock and memory stats

Ops reference

| Symbol | Where | Description |
|---|---|---|
| causal_depthwise_conv1d | ops/ | Causal grouped conv; Triton kernel when available |
| selective_scan | ops/ | S4/Mamba v1 scan; CUDA kernel when loaded |
| grouped_ssd_scan | ops/ | Multi-head SSD scan (Mamba-2); Triton kernel when available |
| dt_prep | ops/ | Δt bias add + softplus + clamp |
| fused_out | ops/ | Fused RMSNorm + gate + residual add |