Transformers
mamba2
vertical-chunking
chunkable-mamba2 / modeling_chunkable_mamba2.py
grantner's picture
chore: clean up masking
b69ee0c verified
from .configuration_chunkable_mamba2 import ChunkableMamba2Config
from transformers.cache_utils import Cache, is_torchdynamo_compiling
from transformers.models.mamba2.modeling_mamba2 import (
Mamba2Block,
Mamba2Mixer,
Mamba2Model,
Mamba2RMSNorm,
apply_mask_to_padding_states,
)
import torch
from torch import nn
# Fused conv1d + SSM scan + output-projection kernel used by
# ChunkableMamba2Mixer.cuda_kernels_forward. It stays ``None`` until a mixer
# built with ``config.use_mem_eff_path`` enabled imports the implementation.
mamba_split_conv1d_scan_combined = None
class ChunkableMamba2Mixer(Mamba2Mixer):
    """Mamba2 mixer whose fused CUDA path can start from cached states.

    When ``config.use_mem_eff_path`` is enabled, the fused kernel from
    ``chunkable_ssd_combined`` is used and is handed the conv / recurrent
    states stored in the cache as initial states; the final states it returns
    are written back to the cache. In every other situation the work is
    delegated to the parent ``Mamba2Mixer`` implementation.
    """

    def __init__(self, config: ChunkableMamba2Config, layer_idx: int):
        """Build the mixer and, if requested, load the fused kernel.

        Args:
            config: Model configuration; ``config.use_mem_eff_path`` selects
                the fused kernel path.
            layer_idx: Index of the layer this mixer belongs to; used to
                address this layer's entry in the cache.
        """
        super().__init__(config, layer_idx)
        self.use_mem_eff_path = config.use_mem_eff_path

        # The kernel is imported lazily, the first time a mixer that enables
        # the fused path is built, and then shared through the module global.
        global mamba_split_conv1d_scan_combined
        if self.use_mem_eff_path and mamba_split_conv1d_scan_combined is None:
            from .chunkable_ssd_combined import chunkable_mamba_split_conv1d_scan_combined

            mamba_split_conv1d_scan_combined = chunkable_mamba_split_conv1d_scan_combined

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Run the mixer through the fused kernel, or defer to the parent.

        Args:
            hidden_states: Input activations; dimension 1 is used as the
                sequence length when slicing ``attention_mask``.
            cache_params: Optional cache. When given, its per-layer conv and
                recurrent states seed the kernel and are replaced by the
                final states the kernel returns.
            attention_mask: Optional mask; its last ``hidden_states.size(1)``
                columns are converted into the kernel's ``seq_idx`` argument.

        Returns:
            The kernel output (after the output projection), or whatever the
            parent implementation returns when the call is delegated.
        """
        has_previous_state = cache_params is not None and cache_params.has_previous_state(self.layer_idx)

        # Delegate to the stock implementation when
        #   * the fused kernel was never loaded (no mixer enabled
        #     ``use_mem_eff_path``) -- calling it would raise
        #     "'NoneType' object is not callable"; or
        #   * a previous state exists and the fused path is disabled.
        if mamba_split_conv1d_scan_combined is None or (has_previous_state and not self.use_mem_eff_path):
            return super().cuda_kernels_forward(
                hidden_states=hidden_states,
                cache_params=cache_params,
                attention_mask=attention_mask,
            )

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)
        A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
        dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

        # Mask value 1 becomes 0 and mask value 0 becomes -1.
        # NOTE(review): presumably the kernel treats -1 as "padding" -- confirm
        # against chunkable_ssd_combined.
        seq_idx = None
        if attention_mask is not None:
            seq_idx = (attention_mask[:, -hidden_states.size(1) :] - 1).to(torch.int32)

        use_cache = cache_params is not None
        layer_cache = cache_params.layers[self.layer_idx] if use_cache else None

        # 2-4. Fused kernel for conv1d, SSM, and the final projection
        out = mamba_split_conv1d_scan_combined(
            projected_states,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.dt_bias,
            A,
            D=self.D,
            chunk_size=self.chunk_size,
            seq_idx=seq_idx,
            activation=self.activation,
            rmsnorm_weight=self.norm.weight,
            rmsnorm_eps=self.norm.variance_epsilon,
            outproj_weight=self.out_proj.weight,
            outproj_bias=self.out_proj.bias,
            headdim=self.head_dim,
            ngroups=self.n_groups,
            norm_before_gate=False,
            initial_conv_states=layer_cache.conv_states if use_cache else None,
            initial_ssm_states=layer_cache.recurrent_states if use_cache else None,
            return_final_states=use_cache,
            **dt_limit_kwargs,
        )

        if use_cache:
            # With ``return_final_states`` the kernel returns a 3-tuple.
            out, conv_states, ssm_state = out
            # NOTE(review): the flag is cleared before the update calls,
            # presumably so they store the new states wholesale instead of
            # applying a single-step update -- confirm against the cache class.
            layer_cache.has_previous_state = False
            cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx)
            cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx)
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Dispatch to the CUDA kernel path or the pure PyTorch path.

        The kernel path is taken only when the ``in_proj`` weights live on a
        CUDA device and the call is not being traced by torchdynamo.
        """
        if "cuda" in self.in_proj.weight.device.type and not is_torchdynamo_compiling():
            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
        return self.torch_forward(hidden_states, cache_params, attention_mask)
class ChunkableMamba2Block(Mamba2Block):
    """Mamba2 block that wires in ``ChunkableMamba2Mixer`` as its mixer."""

    def __init__(self, config, layer_idx):
        """Create the block's norm and mixer for layer ``layer_idx``.

        Args:
            config: Model configuration supplying ``residual_in_fp32``,
                ``hidden_size`` and ``layer_norm_epsilon``.
            layer_idx: Index of this block in the layer stack.
        """
        # Start the initializer chain *after* Mamba2Block, so the parent's own
        # __init__ is not run and every attribute is set here instead.
        # NOTE(review): presumably done to avoid building the stock mixer.
        super(Mamba2Block, self).__init__()

        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32

        norm_layer = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.norm = norm_layer

        chunkable_mixer = ChunkableMamba2Mixer(config, layer_idx=layer_idx)
        self.mixer = chunkable_mixer
class ChunkableMamba2Model(Mamba2Model):
    """Mamba2 backbone assembled from ``ChunkableMamba2Block`` layers."""

    config_class = ChunkableMamba2Config

    def __init__(self, config):
        """Build embeddings, the block stack and the final norm.

        Args:
            config: Model configuration supplying ``vocab_size``,
                ``hidden_size``, ``num_hidden_layers`` and
                ``layer_norm_epsilon``.
        """
        # Start the initializer chain *after* Mamba2Model, so the parent's own
        # __init__ is not run and the submodules are created here instead.
        super(Mamba2Model, self).__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

        blocks = [
            ChunkableMamba2Block(config, layer_idx=block_idx)
            for block_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        self.gradient_checkpointing = False
        self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # Initialize weights and apply final processing
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()