Add asymmetric hybrid block-causal mask for efficient vision path
Browse files- modeling.py +138 -15
modeling.py
CHANGED
|
@@ -81,6 +81,49 @@ def hybrid_block_causal_mask_multiturn(b, h, q_idx, kv_idx, response_block_idx=N
|
|
| 81 |
return block_diagonal | offset_block_causal | x0_causal
|
| 82 |
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
|
| 85 |
# Compute block indices
|
| 86 |
block_q = q_idx // block_size
|
|
@@ -710,16 +753,42 @@ class Fast_dVLMAttention(nn.Module):
|
|
| 710 |
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
| 711 |
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
| 712 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
cos, sin = position_embeddings
|
| 714 |
if self.training:
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
#
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
query_states = torch.cat((q_1, q_2), dim=-2)
|
| 724 |
key_states = torch.cat((k_1, k_2), dim=-2)
|
| 725 |
else:
|
|
@@ -1504,6 +1573,11 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1504 |
self.minimum_noise_level = getattr(config, 'minimum_noise_level', 0.0)
|
| 1505 |
self.im_end_token_id = 151645 # <|im_end|> token id
|
| 1506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1507 |
# Vision-to-text aligner (if vision output dim != text hidden dim)
|
| 1508 |
vision_out_dim = config.vision_config.out_hidden_size
|
| 1509 |
text_hidden = config.text_config.hidden_size
|
|
@@ -1614,6 +1688,27 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1614 |
)
|
| 1615 |
return mask
|
| 1616 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1617 |
@can_return_tuple
|
| 1618 |
@auto_docstring
|
| 1619 |
def forward(
|
|
@@ -1726,6 +1821,8 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1726 |
labels = torch.cat([labels, complementary_labels], dim=0)
|
| 1727 |
|
| 1728 |
attention_mask = self.gen_hybrid_block_causal_mask(seq_len, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
|
|
|
|
|
|
|
| 1729 |
|
| 1730 |
else:
|
| 1731 |
# Multimodal block diffusion path.
|
|
@@ -1808,11 +1905,26 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1808 |
labels_noisy = labels.clone()
|
| 1809 |
labels_noisy[~mask_indices] = -100
|
| 1810 |
|
| 1811 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1812 |
input_ids_pair1 = torch.cat([noisy_input_ids, original_input_ids], dim=1)
|
| 1813 |
embeds_pair1 = torch.cat([noisy_embeds, original_embeds], dim=1)
|
| 1814 |
labels_pair1 = labels_noisy
|
| 1815 |
-
position_ids_pair1 =
|
| 1816 |
|
| 1817 |
# Complementary pair: mask the positions that were left clean above.
|
| 1818 |
complementary_mask_indices = response_mask & ~mask_indices
|
|
@@ -1825,13 +1937,16 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1825 |
complementary_noisy_embeds_raw = self.model.language_model.embed_tokens(complementary_noisy_input_ids)
|
| 1826 |
complementary_noisy_embeds = torch.where(vision_mask_3d, original_embeds, complementary_noisy_embeds_raw)
|
| 1827 |
|
|
|
|
|
|
|
| 1828 |
complementary_labels = original_labels.clone()
|
| 1829 |
complementary_labels[~complementary_mask_indices] = -100
|
|
|
|
| 1830 |
|
| 1831 |
input_ids_pair2 = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
|
| 1832 |
embeds_pair2 = torch.cat([complementary_noisy_embeds, original_embeds], dim=1)
|
| 1833 |
labels_pair2 = complementary_labels
|
| 1834 |
-
position_ids_pair2 =
|
| 1835 |
|
| 1836 |
# Stack the complementary pair along the batch dimension.
|
| 1837 |
input_ids = torch.cat([input_ids_pair1, input_ids_pair2], dim=0)
|
|
@@ -1839,11 +1954,18 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1839 |
labels = torch.cat([labels_pair1, labels_pair2], dim=0)
|
| 1840 |
position_ids = torch.cat([position_ids_pair1, position_ids_pair2], dim=1)
|
| 1841 |
|
| 1842 |
-
attention_mask = self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1843 |
|
| 1844 |
# Phase D: forward through the inner model. Vision features (if any)
|
| 1845 |
# have already been scattered into inputs_embeds, so pixel_values are
|
| 1846 |
-
# cleared to skip re-processing inside `Fast_dVLMModel`.
|
|
|
|
|
|
|
|
|
|
| 1847 |
outputs = self.model(
|
| 1848 |
input_ids=input_ids,
|
| 1849 |
pixel_values=None,
|
|
@@ -1889,7 +2011,8 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1889 |
loss = None
|
| 1890 |
|
| 1891 |
if self.training:
|
| 1892 |
-
|
|
|
|
| 1893 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1894 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1895 |
logits = self.lm_head(mdm_hidden_states[:, slice_indices, :])
|
|
@@ -1901,7 +2024,7 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
|
|
| 1901 |
loss = self.loss_function(
|
| 1902 |
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
|
| 1903 |
) * 0.5
|
| 1904 |
-
causal_hidden_states = hidden_states[:hidden_states.shape[0]//2,
|
| 1905 |
causal_logits = self.lm_head(causal_hidden_states[:, slice_indices, :])
|
| 1906 |
loss += self.loss_function(
|
| 1907 |
logits=causal_logits, labels=original_labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
|
|
|
|
| 81 |
return block_diagonal | offset_block_causal | x0_causal
|
| 82 |
|
| 83 |
|
| 84 |
+
def hybrid_block_causal_mask_multiturn_asymmetric(
    b, h, q_idx, kv_idx,
    turn_idx_noisy=None,
    turn_idx_clean=None,
    n_noisy=None,
):
    """
    Asymmetric variant of `hybrid_block_causal_mask_multiturn` used by the
    efficient vision path.

    Layout: ``[noisy(L_text) | clean(L)]`` where the noisy half drops vision
    tokens, so ``L_text < L``. Separate ``turn_idx`` tensors are required for
    the two halves (the noisy half indexes into the compressed positions, the
    clean half into the original positions). Mask rules are identical to the
    symmetric version:
      * block_diagonal: x_t ↔ x_t within the same turn.
      * offset_block_causal: x_t may attend to x_0 of strictly earlier turns.
      * x0_causal: standard causal masking inside the x_0 region.

    Args:
        b: Batch index. Part of the mask-function signature; not used here.
        h: Head index. Part of the mask-function signature; not used here.
        q_idx: Query position(s) in the concatenated ``[noisy | clean]``
            sequence.
        kv_idx: Key/value position(s) in the same concatenated sequence.
        turn_idx_noisy: Tensor of turn ids for the noisy half, indexed along
            dim 0 by position within that half.
        turn_idx_clean: Tensor of turn ids for the clean half, indexed along
            dim 0 by position within that half.
        n_noisy: Length of the noisy half; positions ``>= n_noisy`` belong to
            the clean (x_0) half.

    Returns:
        Boolean tensor that is True where the query may attend to the key.

    Raises:
        ValueError: If ``turn_idx_noisy``, ``turn_idx_clean`` or ``n_noisy``
            is not provided. They default to None only so the function can be
            bound with keyword arguments; all three are required.
    """
    if turn_idx_noisy is None or turn_idx_clean is None or n_noisy is None:
        raise ValueError(
            "hybrid_block_causal_mask_multiturn_asymmetric requires "
            "turn_idx_noisy, turn_idx_clean and n_noisy"
        )

    # True for positions that fall in the clean (x_0) half.
    x0_flag_q = (q_idx >= n_noisy)
    x0_flag_kv = (kv_idx >= n_noisy)

    # Position relative to the start of whichever half the index lives in.
    pos_q = torch.where(x0_flag_q, q_idx - n_noisy, q_idx)
    pos_kv = torch.where(x0_flag_kv, kv_idx - n_noisy, kv_idx)

    # `torch.where` evaluates both candidate tensors for every element, so
    # each lookup is clamped to its own table length: a clean-half position
    # can exceed the noisy table's last index (and vice versa), and the
    # unselected lookup must still be in bounds.
    turn_q = torch.where(
        x0_flag_q,
        turn_idx_clean[torch.clamp(pos_q, max=turn_idx_clean.shape[0] - 1)],
        turn_idx_noisy[torch.clamp(pos_q, max=turn_idx_noisy.shape[0] - 1)],
    )
    turn_kv = torch.where(
        x0_flag_kv,
        turn_idx_clean[torch.clamp(pos_kv, max=turn_idx_clean.shape[0] - 1)],
        turn_idx_noisy[torch.clamp(pos_kv, max=turn_idx_noisy.shape[0] - 1)],
    )

    # noisy -> noisy, same turn only.
    block_diagonal = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv)
    # noisy -> clean, strictly earlier turns only.
    offset_block_causal = (turn_q > turn_kv) & x0_flag_kv & (~x0_flag_q)
    # clean -> clean, causal by position within the clean half.
    x0_causal = x0_flag_q & x0_flag_kv & (pos_q >= pos_kv)

    return block_diagonal | offset_block_causal | x0_causal
|
| 125 |
+
|
| 126 |
+
|
| 127 |
def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
|
| 128 |
# Compute block indices
|
| 129 |
block_q = q_idx // block_size
|
|
|
|
| 753 |
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
| 754 |
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
| 755 |
|
| 756 |
+
# `noisy_seq_len` is an MDM-specific kwarg; pop it before the value can
|
| 757 |
+
# leak into attention backends (e.g. flash_attention_2) that don't
|
| 758 |
+
# understand it.
|
| 759 |
+
noisy_seq_len = kwargs.pop("noisy_seq_len", None)
|
| 760 |
+
|
| 761 |
cos, sin = position_embeddings
|
| 762 |
if self.training:
|
| 763 |
+
total_seq_len = query_states.shape[2]
|
| 764 |
+
# The noisy half can be shorter than the clean half (multimodal
|
| 765 |
+
# batches drop vision tokens from the noisy side). When the caller
|
| 766 |
+
# tells us its length explicitly we honor it; otherwise we fall back
|
| 767 |
+
# to the symmetric split used for text-only batches.
|
| 768 |
+
if noisy_seq_len is not None:
|
| 769 |
+
noisy_len = int(noisy_seq_len)
|
| 770 |
+
else:
|
| 771 |
+
noisy_len = total_seq_len // 2
|
| 772 |
+
|
| 773 |
+
q_1 = query_states[:, :, :noisy_len]
|
| 774 |
+
q_2 = query_states[:, :, noisy_len:]
|
| 775 |
+
k_1 = key_states[:, :, :noisy_len]
|
| 776 |
+
k_2 = key_states[:, :, noisy_len:]
|
| 777 |
+
|
| 778 |
+
if cos.shape[2] >= total_seq_len:
|
| 779 |
+
cos_1 = cos[:, :, :noisy_len, :]
|
| 780 |
+
sin_1 = sin[:, :, :noisy_len, :]
|
| 781 |
+
cos_2 = cos[:, :, noisy_len:, :]
|
| 782 |
+
sin_2 = sin[:, :, noisy_len:, :]
|
| 783 |
+
else:
|
| 784 |
+
# `position_ids` only covers the clean half length. Both halves
|
| 785 |
+
# share the same RoPE — valid only for the symmetric layout
|
| 786 |
+
# where noisy_len == clean_len.
|
| 787 |
+
cos_1, sin_1 = cos, sin
|
| 788 |
+
cos_2, sin_2 = cos, sin
|
| 789 |
+
|
| 790 |
+
q_1, k_1 = apply_multimodal_rotary_pos_emb(q_1, k_1, cos_1, sin_1, self.rope_scaling["mrope_section"])
|
| 791 |
+
q_2, k_2 = apply_multimodal_rotary_pos_emb(q_2, k_2, cos_2, sin_2, self.rope_scaling["mrope_section"])
|
| 792 |
query_states = torch.cat((q_1, q_2), dim=-2)
|
| 793 |
key_states = torch.cat((k_1, k_2), dim=-2)
|
| 794 |
else:
|
|
|
|
| 1573 |
self.minimum_noise_level = getattr(config, 'minimum_noise_level', 0.0)
|
| 1574 |
self.im_end_token_id = 151645 # <|im_end|> token id
|
| 1575 |
|
| 1576 |
+
# Length of the noisy half passed through to attention. For text-only
|
| 1577 |
+
# batches it equals the (symmetric) sequence length; for multimodal
|
| 1578 |
+
# batches the noisy half drops vision tokens, so it is shorter.
|
| 1579 |
+
self._noisy_seq_len: Optional[int] = None
|
| 1580 |
+
|
| 1581 |
# Vision-to-text aligner (if vision output dim != text hidden dim)
|
| 1582 |
vision_out_dim = config.vision_config.out_hidden_size
|
| 1583 |
text_hidden = config.text_config.hidden_size
|
|
|
|
| 1688 |
)
|
| 1689 |
return mask
|
| 1690 |
|
| 1691 |
+
def gen_hybrid_block_causal_mask_asymmetric(
    self, L_text, L_clean, turn_idx_noisy, turn_idx_clean, B, H
):
    """Generate the asymmetric hybrid mask used by the efficient vision path.

    Layout: ``[noisy(L_text) | clean(L_clean)]`` where vision tokens have
    been removed from the noisy half.

    Args:
        L_text: Length of the noisy half.
        L_clean: Length of the clean half.
        turn_idx_noisy: Turn ids for the noisy half, forwarded to
            `hybrid_block_causal_mask_multiturn_asymmetric`.
        turn_idx_clean: Turn ids for the clean half, forwarded likewise.
        B: Batch size passed to `create_block_mask`.
        H: Number of heads passed to `create_block_mask`.

    Returns:
        The block mask returned by `create_block_mask` for a square
        ``(L_text + L_clean)`` query/key layout.
    """
    # The noisy length is handed to the mask function as a tensor on the
    # module's device.
    n_noisy_t = torch.tensor(L_text, device=self.device, dtype=torch.int32)
    total = L_text + L_clean
    mask = create_block_mask(
        partial(
            hybrid_block_causal_mask_multiturn_asymmetric,
            turn_idx_noisy=turn_idx_noisy,
            turn_idx_clean=turn_idx_clean,
            n_noisy=n_noisy_t,
        ),
        B=B, H=H, Q_LEN=total, KV_LEN=total,
        # Build the mask indices on the same device as `n_noisy_t` instead of
        # relying on the library default, which may be a different device.
        # NOTE(review): assumes the turn-index tensors also live on
        # `self.device` — confirm against the caller.
        device=self.device,
    )
    return mask
|
| 1711 |
+
|
| 1712 |
@can_return_tuple
|
| 1713 |
@auto_docstring
|
| 1714 |
def forward(
|
|
|
|
| 1821 |
labels = torch.cat([labels, complementary_labels], dim=0)
|
| 1822 |
|
| 1823 |
attention_mask = self.gen_hybrid_block_causal_mask(seq_len, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
|
| 1824 |
+
# Text-only path: noisy and clean halves have identical length.
|
| 1825 |
+
self._noisy_seq_len = seq_len
|
| 1826 |
|
| 1827 |
else:
|
| 1828 |
# Multimodal block diffusion path.
|
|
|
|
| 1905 |
labels_noisy = labels.clone()
|
| 1906 |
labels_noisy[~mask_indices] = -100
|
| 1907 |
|
| 1908 |
+
# Efficient vision: drop vision tokens from the noisy half so the
|
| 1909 |
+
# model only attends to the (much shorter) text portion on that
|
| 1910 |
+
# side, while the clean half keeps the full sequence so visual
|
| 1911 |
+
# context is still available via cross-attention.
|
| 1912 |
+
text_positions = (~vision_token_mask[0]).nonzero(as_tuple=True)[0]
|
| 1913 |
+
L_text = text_positions.shape[0]
|
| 1914 |
+
|
| 1915 |
+
noisy_embeds = noisy_embeds[:, text_positions, :]
|
| 1916 |
+
noisy_input_ids = noisy_input_ids[:, text_positions]
|
| 1917 |
+
labels_noisy = labels_noisy[:, text_positions]
|
| 1918 |
+
|
| 1919 |
+
noisy_position_ids = original_position_ids[:, :, text_positions]
|
| 1920 |
+
combined_position_ids = torch.cat([noisy_position_ids, original_position_ids], dim=2)
|
| 1921 |
+
turn_idx_noisy = turn_idx[text_positions]
|
| 1922 |
+
|
| 1923 |
+
# Concatenate [noisy(L_text) | clean(L)] along the sequence dim.
|
| 1924 |
input_ids_pair1 = torch.cat([noisy_input_ids, original_input_ids], dim=1)
|
| 1925 |
embeds_pair1 = torch.cat([noisy_embeds, original_embeds], dim=1)
|
| 1926 |
labels_pair1 = labels_noisy
|
| 1927 |
+
position_ids_pair1 = combined_position_ids
|
| 1928 |
|
| 1929 |
# Complementary pair: mask the positions that were left clean above.
|
| 1930 |
complementary_mask_indices = response_mask & ~mask_indices
|
|
|
|
| 1937 |
complementary_noisy_embeds_raw = self.model.language_model.embed_tokens(complementary_noisy_input_ids)
|
| 1938 |
complementary_noisy_embeds = torch.where(vision_mask_3d, original_embeds, complementary_noisy_embeds_raw)
|
| 1939 |
|
| 1940 |
+
complementary_noisy_embeds = complementary_noisy_embeds[:, text_positions, :]
|
| 1941 |
+
complementary_noisy_input_ids = complementary_noisy_input_ids[:, text_positions]
|
| 1942 |
complementary_labels = original_labels.clone()
|
| 1943 |
complementary_labels[~complementary_mask_indices] = -100
|
| 1944 |
+
complementary_labels = complementary_labels[:, text_positions]
|
| 1945 |
|
| 1946 |
input_ids_pair2 = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
|
| 1947 |
embeds_pair2 = torch.cat([complementary_noisy_embeds, original_embeds], dim=1)
|
| 1948 |
labels_pair2 = complementary_labels
|
| 1949 |
+
position_ids_pair2 = combined_position_ids
|
| 1950 |
|
| 1951 |
# Stack the complementary pair along the batch dimension.
|
| 1952 |
input_ids = torch.cat([input_ids_pair1, input_ids_pair2], dim=0)
|
|
|
|
| 1954 |
labels = torch.cat([labels_pair1, labels_pair2], dim=0)
|
| 1955 |
position_ids = torch.cat([position_ids_pair1, position_ids_pair2], dim=1)
|
| 1956 |
|
| 1957 |
+
attention_mask = self.gen_hybrid_block_causal_mask_asymmetric(
|
| 1958 |
+
L_text, L, turn_idx_noisy, turn_idx,
|
| 1959 |
+
input_ids.shape[0], self.config.num_attention_heads,
|
| 1960 |
+
)
|
| 1961 |
+
self._noisy_seq_len = L_text
|
| 1962 |
|
| 1963 |
# Phase D: forward through the inner model. Vision features (if any)
|
| 1964 |
# have already been scattered into inputs_embeds, so pixel_values are
|
| 1965 |
+
# cleared to skip re-processing inside `Fast_dVLMModel`. The noisy
|
| 1966 |
+
# half length is forwarded as a kwarg so attention can split the
|
| 1967 |
+
# asymmetric `[noisy | clean]` layout correctly.
|
| 1968 |
+
kwargs['noisy_seq_len'] = self._noisy_seq_len
|
| 1969 |
outputs = self.model(
|
| 1970 |
input_ids=input_ids,
|
| 1971 |
pixel_values=None,
|
|
|
|
| 2011 |
loss = None
|
| 2012 |
|
| 2013 |
if self.training:
|
| 2014 |
+
noisy_len = self._noisy_seq_len
|
| 2015 |
+
mdm_hidden_states = hidden_states[:, :noisy_len, :]
|
| 2016 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 2017 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 2018 |
logits = self.lm_head(mdm_hidden_states[:, slice_indices, :])
|
|
|
|
| 2024 |
loss = self.loss_function(
|
| 2025 |
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
|
| 2026 |
) * 0.5
|
| 2027 |
+
causal_hidden_states = hidden_states[:hidden_states.shape[0]//2, noisy_len:, :]
|
| 2028 |
causal_logits = self.lm_head(causal_hidden_states[:, slice_indices, :])
|
| 2029 |
loss += self.loss_function(
|
| 2030 |
logits=causal_logits, labels=original_labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
|