"""Mamba2 subclasses whose CUDA path calls ``chunkable_mamba_split_conv1d_scan_combined``."""

from .configuration_chunkable_mamba2 import ChunkableMamba2Config
from transformers.cache_utils import Cache, is_torchdynamo_compiling
from transformers.models.mamba2.modeling_mamba2 import (
    Mamba2Block,
    Mamba2Mixer,
    Mamba2Model,
    Mamba2RMSNorm,
    apply_mask_to_padding_states,
)
import torch
from torch import nn

# Fused kernel entry point. Stays ``None`` until a mixer is built with
# ``config.use_mem_eff_path`` enabled, at which point it is imported lazily.
mamba_split_conv1d_scan_combined = None


class ChunkableMamba2Mixer(Mamba2Mixer):
    """Mamba2 mixer that can run the fused chunkable kernel with initial states."""

    def __init__(self, config: ChunkableMamba2Config, layer_idx: int):
        """Build the parent mixer and, if requested, import the fused kernel.

        Args:
            config: Model configuration; ``config.use_mem_eff_path`` selects
                whether the fused chunkable kernel is used.
            layer_idx: Index of this layer, used to address the cache.
        """
        super().__init__(config, layer_idx)
        self.use_mem_eff_path = config.use_mem_eff_path

        global mamba_split_conv1d_scan_combined
        if self.use_mem_eff_path and mamba_split_conv1d_scan_combined is None:
            # Imported on first use only, so the kernel module is not needed
            # when the memory-efficient path is disabled.
            from .chunkable_ssd_combined import chunkable_mamba_split_conv1d_scan_combined

            mamba_split_conv1d_scan_combined = chunkable_mamba_split_conv1d_scan_combined

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Run the mixer through the fused kernel, or defer to the parent.

        When ``use_mem_eff_path`` is disabled the fused kernel has not been
        imported by this class, so every call is delegated to
        ``Mamba2Mixer.cuda_kernels_forward``.

        Args:
            hidden_states: Input passed to ``self.in_proj``.
            cache_params: Optional cache. When given, its per-layer conv and
                recurrent states are passed to the kernel as initial states
                and are updated with the states the kernel returns.
            attention_mask: Optional mask used to derive ``seq_idx``.

        Returns:
            The output produced by the fused kernel (or by the parent
            implementation when the call is delegated).
        """
        if not self.use_mem_eff_path:
            # The fused kernel is only loaded when use_mem_eff_path is set;
            # calling it otherwise would hit the ``None`` placeholder.
            return super().cuda_kernels_forward(
                hidden_states=hidden_states,
                cache_params=cache_params,
                attention_mask=attention_mask,
            )

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)
        A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
        dt_limit_kwargs = (
            {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
        )

        # Mask columns matching the current input length, shifted by -1 and
        # cast to int32.
        # NOTE(review): presumably this marks padded positions with -1 for the
        # kernel's seq_idx argument - confirm against the kernel.
        seq_idx = (
            (attention_mask[:, -hidden_states.size(1) :] - 1).to(torch.int32)
            if attention_mask is not None
            else None
        )

        if cache_params is None:
            layer_cache = None
            initial_conv_states = None
            initial_ssm_states = None
        else:
            layer_cache = cache_params.layers[self.layer_idx]
            initial_conv_states = layer_cache.conv_states
            initial_ssm_states = layer_cache.recurrent_states

        # 2-4. Fused kernel for conv1d, SSM, and the final projection
        out = mamba_split_conv1d_scan_combined(
            projected_states,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.dt_bias,
            A,
            D=self.D,
            chunk_size=self.chunk_size,
            seq_idx=seq_idx,
            activation=self.activation,
            rmsnorm_weight=self.norm.weight,
            rmsnorm_eps=self.norm.variance_epsilon,
            outproj_weight=self.out_proj.weight,
            outproj_bias=self.out_proj.bias,
            headdim=self.head_dim,
            ngroups=self.n_groups,
            norm_before_gate=False,
            initial_conv_states=initial_conv_states,
            initial_ssm_states=initial_ssm_states,
            return_final_states=cache_params is not None,
            **dt_limit_kwargs,
        )

        if cache_params is not None:
            # With return_final_states=True the kernel result is unpacked into
            # the output plus the final conv and SSM states.
            out, conv_states, ssm_state = out
            # NOTE(review): the flag is cleared before the cache update calls;
            # presumably so they store the full states rather than perform an
            # incremental update - confirm against the cache implementation.
            layer_cache.has_previous_state = False
            cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx)
            cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx)
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Dispatch to the CUDA kernel path or the pure-torch path.

        The CUDA path is taken when the ``in_proj`` weight lives on a CUDA
        device and torchdynamo is not compiling; otherwise ``torch_forward``
        is used.
        """
        on_cuda = "cuda" in self.in_proj.weight.device.type
        if on_cuda and not is_torchdynamo_compiling():
            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
        return self.torch_forward(hidden_states, cache_params, attention_mask)


class ChunkableMamba2Block(Mamba2Block):
    """Mamba2 block whose mixer is a ``ChunkableMamba2Mixer``."""

    def __init__(self, config: ChunkableMamba2Config, layer_idx: int):
        """Create the block's norm and chunkable mixer.

        Args:
            config: Model configuration.
            layer_idx: Index of this block within the model.
        """
        # Bypass Mamba2Block.__init__ and run the next initializer in the MRO,
        # then build the submodules here with the chunkable mixer.
        super(Mamba2Block, self).__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = ChunkableMamba2Mixer(config, layer_idx=layer_idx)


class ChunkableMamba2Model(Mamba2Model):
    """Mamba2 model assembled from ``ChunkableMamba2Block`` layers."""

    config_class = ChunkableMamba2Config

    def __init__(self, config: ChunkableMamba2Config):
        """Create embeddings, the chunkable layers, and the final norm.

        Args:
            config: Model configuration.
        """
        # Bypass Mamba2Model.__init__ and run the next initializer in the MRO,
        # then build the layer stack here with chunkable blocks.
        super(Mamba2Model, self).__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [ChunkableMamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
        )

        self.gradient_checkpointing = False
        self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()