Hybrid Mamba

foreblocks.hybrid_mamba provides custom SSM (State Space Model) building blocks that combine selective-scan dynamics with sliding-window attention. Two block variants are available:

  • HybridMambaBlock — pure SSM (Mamba v1-style selective scan) with optional pre-norm
  • HybridMamba2Block — parallel SSM + sliding-window attention branches fused with a learned gate, RoPE, GQA, and output normalisation (Mamba-2 / SSD-based)

Installation requirements

The module is pure PyTorch by default. Optional acceleration layers:

| Backend | What it enables | How to activate |
|---|---|---|
| Triton | Fused causal conv and grouped-SSD scan | `pip install triton` |
| CUDA extension | Selective scan CUDA kernel | `precompile_selective_scan_extension()` |

Check availability at runtime:

```python
from foreblocks.hybrid_mamba import TRITON_AVAILABLE, extension_available

print(TRITON_AVAILABLE)       # True if triton is installed
print(extension_available())  # True if CUDA extension is built
```

RotaryEmbedding

Rotary Position Embedding (Su et al. 2021 — RoFormer). Applied to Q and K inside SlidingWindowAttention. Available as a standalone module if you need to add RoPE to your own attention implementation.

```python
from foreblocks.hybrid_mamba import RotaryEmbedding

rope = RotaryEmbedding(head_dim=64, base=10_000, max_seq_len=2048)

# tensors are (B, H, T, head_dim)
q_rot, k_rot = rope(q, k)
```

The cache is extended automatically when T > max_seq_len, so setting a generous upper bound is fine.

Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| `head_dim` | required | Dimension of each attention head. Must be even. |
| `base` | `10_000` | Frequency base (original paper default). |
| `max_seq_len` | `8192` | Pre-built cache length; extended on demand. |
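If you want RoPE without the library dependency, the core computation is small. A minimal standalone sketch (illustrative only, not the library's implementation; it assumes the interleaved-pair rotation convention):

```python
import torch

def rotary_embed(q, k, base=10_000):
    """Apply rotary position embeddings to q and k of shape (B, H, T, D), D even."""
    _, _, T, D = q.shape
    # one inverse frequency per even/odd dimension pair
    inv_freq = 1.0 / (base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
    freqs = torch.outer(torch.arange(T, dtype=torch.float32), inv_freq)  # (T, D/2)
    cos, sin = freqs.cos(), freqs.sin()

    def rotate(x):
        x1, x2 = x[..., 0::2], x[..., 1::2]   # split into pairs
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * cos - x2 * sin  # 2-D rotation of each pair
        out[..., 1::2] = x1 * sin + x2 * cos
        return out

    return rotate(q), rotate(k)
```

Because each pair is rotated by a pure rotation, vector norms are preserved and `q·k` depends only on the relative offset between positions.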

HybridMambaBlock

A single Mamba-style block. Expands d_model → d_inner, applies a causal depthwise conv, runs the selective scan, then projects back. An optional pre-norm (use_pre_norm=True) stabilises training in deep stacks.

```python
import torch
from foreblocks.hybrid_mamba import HybridMambaBlock

block = HybridMambaBlock(
    d_model=256,
    d_inner=512,       # defaults to 2 * d_model
    d_state=16,
    d_conv=4,
    dt_rank=None,      # auto: max(4, ceil(d_model / 16))
    use_cuda_scan=True,
    use_pre_norm=True, # LayerNorm before in_proj — recommended
)

x = torch.randn(8, 64, 256)  # (batch, seq_len, d_model)
y = block(x)                 # same shape as x
```

Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| `d_model` | required | Input / output feature dimension |
| `d_inner` | `2 * d_model` | Inner (expanded) dimension |
| `d_state` | `16` | SSM state dimension per feature |
| `d_conv` | `4` | Causal conv kernel size |
| `dt_rank` | auto | Low-rank projection size for Δt; `None` → `max(4, ceil(d_model/16))` |
| `dt_min` / `dt_max` | `1e-4` / `1.0` | Clamp range for the time-step after softplus |
| `use_cuda_scan` | `True` | Use CUDA kernel if extension is loaded; falls back to PyTorch otherwise |
| `use_pre_norm` | `True` | Apply LayerNorm on the input before the expansion projection |
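For intuition about what the selective scan computes, here is a sequential reference of the Mamba v1 recurrence (a sketch of the standard formulation, not the library's kernel; the fused kernels compute the same thing in parallel):

```python
import torch

def selective_scan_ref(u, delta, A, B, C):
    """Sequential Mamba v1 selective scan.
    u, delta: (batch, T, D); A: (D, N); B, C: (batch, T, N)."""
    batch, T, D = u.shape
    N = A.shape[1]
    x = torch.zeros(batch, D, N)                 # SSM state
    ys = []
    for t in range(T):
        dt = delta[:, t].unsqueeze(-1)           # (batch, D, 1) per-feature step size
        dA = torch.exp(dt * A)                   # ZOH discretisation of A
        dB = dt * B[:, t].unsqueeze(1)           # Euler discretisation of B, (batch, D, N)
        x = dA * x + dB * u[:, t].unsqueeze(-1)  # state update
        y = (x * C[:, t]. unsqueeze(1)).sum(-1)  # readout, (batch, D)
        ys.append(y)
    return torch.stack(ys, dim=1)                # (batch, T, D)
```

The "selective" part is that `delta`, `B`, and `C` are input-dependent, so the recurrence can gate what enters and leaves the state at each step.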

HybridMamba2Block

Combines a multi-head SSD branch (StructuredStateSpaceDualityBranch) with a SlidingWindowAttention branch. The two outputs are mixed with a sigmoid gate, then passed through an output norm before the final projection:

```
ssm_out  = SSD( LayerNorm(x) )
attn_out = SlidingWindowAttn( LayerNorm(x) )
gate     = sigmoid( Linear( LayerNorm(x) ) )
mixed    = gate * ssm_out + (1 - gate) * attn_out
output   = out_proj( LayerNorm(mixed) )
```

The output LayerNorm (out_norm) stabilises gradients through the mixing gate.
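The gate-and-mix stage can be sketched in a few lines (module and attribute names here are illustrative, not the library's internals):

```python
import torch
import torch.nn as nn

class GatedBranchMix(nn.Module):
    """Sigmoid-gated mix of two branch outputs, following the pseudocode above."""
    def __init__(self, d_model):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_model)
        self.out_norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, ssm_out, attn_out):
        gate = torch.sigmoid(self.gate_proj(x))         # per-feature weight in (0, 1)
        mixed = gate * ssm_out + (1 - gate) * attn_out  # convex combination of branches
        return self.out_proj(self.out_norm(mixed))
```

Because the gate is a convex combination, the block can learn to lean on the SSM branch for long-range structure and the attention branch for local detail, per feature and per position.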

```python
from foreblocks.hybrid_mamba import HybridMamba2Block

block = HybridMamba2Block(
    d_model=256,
    d_inner=512,
    d_state=16,
    d_conv=4,
    dt_rank=None,
    num_heads=8,
    n_kv_heads=2,       # GQA: 2 KV heads shared across 8 query heads
    window_size=128,
    attn_dropout=0.0,
    use_gated_delta=False,
    rope_base=10_000,
    max_seq_len=2048,
)

x = torch.randn(4, 128, 256)
y = block(x)  # (4, 128, 256)
```

Constructor parameters

| Parameter | Default | Description |
|---|---|---|
| `d_model` | required | Model dimension |
| `d_inner` | `2 * d_model` | SSM inner dimension |
| `d_state` | `16` | SSM state dimension per head |
| `d_conv` | `4` | Causal conv kernel size in the SSM branch |
| `dt_rank` | auto | Low-rank Δt projection size |
| `num_heads` | `8` | Query heads for attention; head count for SSD |
| `n_kv_heads` | `None` | KV heads for GQA. `None` → standard MHA. Must divide `num_heads`. |
| `window_size` | `128` | Sliding-window size for local causal attention |
| `attn_dropout` | `0.0` | Attention dropout during training |
| `use_gated_delta` | `False` | Add per-head sigmoid gate on Δt in the SSD branch |
| `rope_base` | `10_000` | RoPE frequency base |
| `max_seq_len` | `8192` | Pre-built RoPE cache length |
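The sliding-window constraint amounts to a banded causal mask: position `i` may attend to itself and the previous `window_size - 1` positions. A minimal sketch of such a mask (illustrative; the library's attention may build it differently):

```python
import torch

def sliding_window_causal_mask(T, window_size):
    """Boolean (T, T) mask: True where query i may attend key j."""
    i = torch.arange(T).unsqueeze(1)  # query positions, column vector
    j = torch.arange(T).unsqueeze(0)  # key positions, row vector
    return (j <= i) & (i - j < window_size)
```

A mask like this can be passed to `torch.nn.functional.scaled_dot_product_attention` via `attn_mask`, keeping per-token attention cost O(window_size) instead of O(T).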

Grouped Query Attention (GQA)

Set n_kv_heads to a divisor of num_heads to enable GQA. With num_heads=8, n_kv_heads=2 the model uses 4× fewer KV parameters and KV cache entries compared to MHA, matching the Llama 3 / Mistral configuration:

```python
block = HybridMamba2Block(
    d_model=512,
    num_heads=16,
    n_kv_heads=4,   # 4 query heads share each KV head
    window_size=256,
)
```

KV heads are broadcast to query heads via repeat_interleave before SDPA; the only extra memory cost is the repeated K/V tensors themselves.
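The broadcast-then-attend pattern is short enough to show directly (a sketch of the general GQA mechanism, not the library's exact code path):

```python
import torch
import torch.nn.functional as F

def gqa_attention(q, k, v, num_heads, n_kv_heads):
    """q: (B, num_heads, T, D); k, v: (B, n_kv_heads, T, D)."""
    groups = num_heads // n_kv_heads
    k = k.repeat_interleave(groups, dim=1)  # expand KV heads to match query heads
    v = v.repeat_interleave(groups, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

The saving comes from the K/V projections and cache, which are sized for `n_kv_heads`; only the transient repeated copies match the full query-head count.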

Stacking blocks into a model

TinyHybridMamba2LM shows the recommended stacking pattern: every attn_every_n layers uses a HybridMamba2Block; all others use the cheaper HybridMambaBlock.

```python
from foreblocks.hybrid_mamba import TinyHybridMamba2LM

model = TinyHybridMamba2LM(
    vocab_size=50257,
    d_model=512,
    n_layers=12,
    d_state=16,
    d_conv=4,
    num_heads=8,
    n_kv_heads=2,       # GQA across all hybrid blocks
    window_size=256,
    attn_every_n=4,     # HybridMamba2Block at layers 0, 4, 8; rest are plain Mamba
    tie_embeddings=True,
    use_pre_norm=True,
    rope_base=10_000,
    max_seq_len=4096,
)
```

For use as a time-series backbone, replace the embedding + LM-head with your own projection layers and feed patch embeddings or raw-feature vectors in place of input_ids.
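One hypothetical shape such a frontend could take (all names here are made up for illustration; adapt them to however you wire the blocks together):

```python
import torch
import torch.nn as nn

class TimeSeriesFrontend(nn.Module):
    """Hypothetical frontend: projects raw feature vectors into d_model
    (replacing the token embedding) and maps the final hidden state to a
    forecast (replacing the LM head)."""
    def __init__(self, n_features, d_model, horizon):
        super().__init__()
        self.in_proj = nn.Linear(n_features, d_model)
        self.head = nn.Linear(d_model, horizon)

    def embed(self, x):          # x: (B, T, n_features)
        return self.in_proj(x)   # (B, T, d_model) — feed this into the block stack

    def predict(self, h):        # h: (B, T, d_model) from the backbone
        return self.head(h[:, -1])  # (B, horizon) forecast from the last step
```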

Diagnostics

```python
from foreblocks.hybrid_mamba import run_default_diagnostics, benchmark_block

run_default_diagnostics()   # quick correctness checks for ops on current device

stats = benchmark_block(d_model=256, seq_len=512, batch=8)
print(stats)  # wall-clock and memory stats
```

Ops reference

| Symbol | Where | Description |
|---|---|---|
| `causal_depthwise_conv1d` | `ops/` | Causal grouped conv; Triton kernel when available |
| `selective_scan` | `ops/` | S4/Mamba v1 scan; CUDA kernel when loaded |
| `grouped_ssd_scan` | `ops/` | Multi-head SSD scan (Mamba-2); Triton kernel when available |
| `dt_prep` | `ops/` | Δt bias add + softplus + clamp |
| `fused_out` | `ops/` | Fused RMSNorm + gate + residual add |
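As described in the table, `dt_prep` reduces to three elementwise steps. A sketch of that logic (the library's fused op computes this in one kernel):

```python
import torch
import torch.nn.functional as F

def dt_prep(dt_raw, dt_bias, dt_min=1e-4, dt_max=1.0):
    """Δt preparation: bias add, softplus for positivity, clamp to [dt_min, dt_max]."""
    return torch.clamp(F.softplus(dt_raw + dt_bias), dt_min, dt_max)
```

Softplus keeps the time-step strictly positive while staying differentiable; the clamp bounds the effective discretisation step so the scan stays numerically stable.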

MIT License