Instructions to use dynatrace-oss/chunkable-mamba2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use dynatrace-oss/chunkable-mamba2 with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("dynatrace-oss/chunkable-mamba2", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from .configuration_chunkable_mamba2 import ChunkableMamba2Config | |
| from transformers.cache_utils import Cache, is_torchdynamo_compiling | |
| from transformers.models.mamba2.modeling_mamba2 import ( | |
| Mamba2Block, | |
| Mamba2Mixer, | |
| Mamba2Model, | |
| Mamba2RMSNorm, | |
| apply_mask_to_padding_states, | |
| ) | |
| import torch | |
| from torch import nn | |
# Module-level handle to the fused conv1d + scan kernel. It starts as None and
# is filled in lazily by ChunkableMamba2Mixer.__init__ the first time a mixer is
# built with ``use_mem_eff_path`` enabled, so the kernel module is imported only
# when it is actually needed.
mamba_split_conv1d_scan_combined = None


class ChunkableMamba2Mixer(Mamba2Mixer):
    """Mamba2 mixer whose CUDA path can be seeded from cached conv/SSM states.

    When ``config.use_mem_eff_path`` is set, the CUDA forward runs
    ``chunkable_mamba_split_conv1d_scan_combined``. It passes in the conv and
    recurrent states stored in ``cache_params`` for this layer (if a cache is
    given) and writes the kernel's final states back into the cache, so a
    subsequent call continues from where the previous one stopped.

    When the flag is not set, the parent ``Mamba2Mixer`` CUDA implementation is
    used for every call.
    """

    def __init__(self, config: ChunkableMamba2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.use_mem_eff_path = config.use_mem_eff_path
        global mamba_split_conv1d_scan_combined
        # Import the custom kernel lazily and only once per process.
        if self.use_mem_eff_path and mamba_split_conv1d_scan_combined is None:
            from .chunkable_ssd_combined import chunkable_mamba_split_conv1d_scan_combined

            mamba_split_conv1d_scan_combined = chunkable_mamba_split_conv1d_scan_combined

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Run the mixer with CUDA kernels.

        Args:
            hidden_states: Input activations. ``hidden_states.size(1)`` is taken
                as the number of tokens in this call.
            cache_params: Optional cache. When given, this layer's conv and
                recurrent states seed the fused kernel and are replaced by the
                kernel's final states afterwards.
            attention_mask: Optional mask. Its last ``hidden_states.size(1)``
                columns, minus one, are passed to the kernel as ``seq_idx``.

        Returns:
            The mixer output: from the fused kernel, or from the parent
            implementation when the memory-efficient path is disabled.
        """
        # Without the memory-efficient path the chunkable kernel is never
        # imported (the module-level handle may still be None). Defer to the
        # parent for *every* call, not just decode steps; otherwise prefill
        # would call None and raise TypeError.
        if not self.use_mem_eff_path:
            return super().cuda_kernels_forward(
                hidden_states=hidden_states,
                cache_params=cache_params,
                attention_mask=attention_mask,
            )
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)
        A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
        # Only pass dt_limit when it actually restricts the range.
        dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
        # Per-token sequence ids from the mask columns belonging to this call.
        # NOTE(review): assumes a 0/1 mask, giving -1 for padding and 0 for
        # real tokens — confirm against the kernel's seq_idx contract.
        seq_idx = (
            (attention_mask[:, -hidden_states.size(1) :] - 1).to(torch.int32)
            if attention_mask is not None
            else None
        )
        layer_cache = cache_params.layers[self.layer_idx] if cache_params is not None else None
        # 2-4. Fused kernel for conv1d, SSM, and the final projection
        out = mamba_split_conv1d_scan_combined(
            projected_states,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.dt_bias,
            A,
            D=self.D,
            chunk_size=self.chunk_size,
            seq_idx=seq_idx,
            activation=self.activation,
            rmsnorm_weight=self.norm.weight,
            rmsnorm_eps=self.norm.variance_epsilon,
            outproj_weight=self.out_proj.weight,
            outproj_bias=self.out_proj.bias,
            headdim=self.head_dim,
            ngroups=self.n_groups,
            norm_before_gate=False,
            initial_conv_states=layer_cache.conv_states if layer_cache is not None else None,
            initial_ssm_states=layer_cache.recurrent_states if layer_cache is not None else None,
            return_final_states=layer_cache is not None,
            **dt_limit_kwargs,
        )
        if layer_cache is None:
            return out
        out, conv_states, ssm_state = out
        # NOTE(review): the flag is cleared before storing the kernel's final
        # states; presumably this makes update_conv_state overwrite the state
        # instead of rolling it by one step — confirm against the Cache API.
        layer_cache.has_previous_state = False
        cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx)
        cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx)
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Dispatch to the CUDA path on CUDA devices, else to ``torch_forward``.

        The CUDA path is also skipped while TorchDynamo is compiling.
        """
        if "cuda" in self.in_proj.weight.device.type and not is_torchdynamo_compiling():
            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
        return self.torch_forward(hidden_states, cache_params, attention_mask)
class ChunkableMamba2Block(Mamba2Block):
    """Mamba2 block whose mixer is a ``ChunkableMamba2Mixer``."""

    def __init__(self, config, layer_idx):
        # Skip Mamba2Block.__init__ (presumably it would build a stock mixer)
        # and run the next initializer in the MRO instead; the attributes the
        # block needs are then set up by hand below.
        super(Mamba2Block, self).__init__()
        self.config, self.layer_idx = config, layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        norm_eps = config.layer_norm_epsilon
        self.norm = Mamba2RMSNorm(config.hidden_size, eps=norm_eps)
        self.mixer = ChunkableMamba2Mixer(config, layer_idx=layer_idx)
class ChunkableMamba2Model(Mamba2Model):
    """Mamba2 backbone whose layers are ``ChunkableMamba2Block`` instances.

    Registers, in order: token embeddings, the stack of chunkable blocks, and a
    final RMS norm.
    """

    config_class = ChunkableMamba2Config

    def __init__(self, config):
        # Skip Mamba2Model.__init__ (presumably it would build stock blocks) and
        # run the next initializer in the MRO with the config instead.
        super(Mamba2Model, self).__init__(config)
        hidden_size = config.hidden_size
        self.embeddings = nn.Embedding(config.vocab_size, hidden_size)
        self.layers = nn.ModuleList(
            ChunkableMamba2Block(config, layer_idx=block_idx)
            for block_idx in range(config.num_hidden_layers)
        )
        self.gradient_checkpointing = False
        self.norm_f = Mamba2RMSNorm(hidden_size, eps=config.layer_norm_epsilon)
        # `load_hook` is inherited, not defined here. NOTE(review): presumably
        # it remaps legacy checkpoint keys — confirm in the parent class.
        self._register_load_state_dict_pre_hook(self.load_hook)
        # Initialize weights and apply final processing
        self.post_init()