WuChengyue commited on
Commit
d7977da
·
verified ·
1 Parent(s): cb43b83

Add asymmetric hybrid block-causal mask for efficient vision path

Browse files
Files changed (1) hide show
  1. modeling.py +138 -15
modeling.py CHANGED
@@ -81,6 +81,49 @@ def hybrid_block_causal_mask_multiturn(b, h, q_idx, kv_idx, response_block_idx=N
81
  return block_diagonal | offset_block_causal | x0_causal
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
85
  # Compute block indices
86
  block_q = q_idx // block_size
@@ -710,16 +753,42 @@ class Fast_dVLMAttention(nn.Module):
710
  key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
711
  value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
712
 
 
 
 
 
 
713
  cos, sin = position_embeddings
714
  if self.training:
715
- #split q into two parts
716
- q_1 = query_states[:,:,:query_states.shape[2]//2]
717
- q_2 = query_states[:,:,query_states.shape[2]//2:]
718
- #split k into two parts
719
- k_1 = key_states[:,:,:key_states.shape[2]//2]
720
- k_2 = key_states[:,:,key_states.shape[2]//2:]
721
- q_1, k_1 = apply_multimodal_rotary_pos_emb(q_1, k_1, cos, sin, self.rope_scaling["mrope_section"])
722
- q_2, k_2 = apply_multimodal_rotary_pos_emb(q_2, k_2, cos, sin, self.rope_scaling["mrope_section"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
723
  query_states = torch.cat((q_1, q_2), dim=-2)
724
  key_states = torch.cat((k_1, k_2), dim=-2)
725
  else:
@@ -1504,6 +1573,11 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1504
  self.minimum_noise_level = getattr(config, 'minimum_noise_level', 0.0)
1505
  self.im_end_token_id = 151645 # <|im_end|> token id
1506
 
 
 
 
 
 
1507
  # Vision-to-text aligner (if vision output dim != text hidden dim)
1508
  vision_out_dim = config.vision_config.out_hidden_size
1509
  text_hidden = config.text_config.hidden_size
@@ -1614,6 +1688,27 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1614
  )
1615
  return mask
1616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1617
  @can_return_tuple
1618
  @auto_docstring
1619
  def forward(
@@ -1726,6 +1821,8 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1726
  labels = torch.cat([labels, complementary_labels], dim=0)
1727
 
1728
  attention_mask = self.gen_hybrid_block_causal_mask(seq_len, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
 
 
1729
 
1730
  else:
1731
  # Multimodal block diffusion path.
@@ -1808,11 +1905,26 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1808
  labels_noisy = labels.clone()
1809
  labels_noisy[~mask_indices] = -100
1810
 
1811
- # Concatenate [noisy | clean] along the sequence dimension.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1812
  input_ids_pair1 = torch.cat([noisy_input_ids, original_input_ids], dim=1)
1813
  embeds_pair1 = torch.cat([noisy_embeds, original_embeds], dim=1)
1814
  labels_pair1 = labels_noisy
1815
- position_ids_pair1 = original_position_ids
1816
 
1817
  # Complementary pair: mask the positions that were left clean above.
1818
  complementary_mask_indices = response_mask & ~mask_indices
@@ -1825,13 +1937,16 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1825
  complementary_noisy_embeds_raw = self.model.language_model.embed_tokens(complementary_noisy_input_ids)
1826
  complementary_noisy_embeds = torch.where(vision_mask_3d, original_embeds, complementary_noisy_embeds_raw)
1827
 
 
 
1828
  complementary_labels = original_labels.clone()
1829
  complementary_labels[~complementary_mask_indices] = -100
 
1830
 
1831
  input_ids_pair2 = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
1832
  embeds_pair2 = torch.cat([complementary_noisy_embeds, original_embeds], dim=1)
1833
  labels_pair2 = complementary_labels
1834
- position_ids_pair2 = original_position_ids
1835
 
1836
  # Stack the complementary pair along the batch dimension.
1837
  input_ids = torch.cat([input_ids_pair1, input_ids_pair2], dim=0)
@@ -1839,11 +1954,18 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1839
  labels = torch.cat([labels_pair1, labels_pair2], dim=0)
1840
  position_ids = torch.cat([position_ids_pair1, position_ids_pair2], dim=1)
1841
 
1842
- attention_mask = self.gen_hybrid_block_causal_mask(L, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
 
 
 
 
1843
 
1844
  # Phase D: forward through the inner model. Vision features (if any)
1845
  # have already been scattered into inputs_embeds, so pixel_values are
1846
- # cleared to skip re-processing inside `Fast_dVLMModel`.
 
 
 
1847
  outputs = self.model(
1848
  input_ids=input_ids,
1849
  pixel_values=None,
@@ -1889,7 +2011,8 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1889
  loss = None
1890
 
1891
  if self.training:
1892
- mdm_hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
 
1893
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1894
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1895
  logits = self.lm_head(mdm_hidden_states[:, slice_indices, :])
@@ -1901,7 +2024,7 @@ class Fast_dVLMForConditionalGeneration(Fast_dVLMPreTrainedModel, GenerationMixi
1901
  loss = self.loss_function(
1902
  logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
1903
  ) * 0.5
1904
- causal_hidden_states = hidden_states[:hidden_states.shape[0]//2, hidden_states.shape[1]//2:, :]
1905
  causal_logits = self.lm_head(causal_hidden_states[:, slice_indices, :])
1906
  loss += self.loss_function(
1907
  logits=causal_logits, labels=original_labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
 
81
  return block_diagonal | offset_block_causal | x0_causal
82
 
83
 
84
+ def hybrid_block_causal_mask_multiturn_asymmetric(
85
+ b, h, q_idx, kv_idx,
86
+ turn_idx_noisy=None,
87
+ turn_idx_clean=None,
88
+ n_noisy=None,
89
+ ):
90
+ """
91
+ Asymmetric variant of `hybrid_block_causal_mask_multiturn` used by the
92
+ efficient vision path.
93
+
94
+ Layout: ``[noisy(L_text) | clean(L)]`` where the noisy half drops vision
95
+ tokens, so ``L_text < L``. Separate ``turn_idx`` tensors are required for
96
+ the two halves (the noisy half indexes into the compressed positions, the
97
+ clean half into the original positions). Mask rules are identical to the
98
+ symmetric version:
99
+ * block_diagonal: x_t ↔ x_t within the same turn.
100
+ * offset_block_causal: x_t may attend to x_0 of strictly earlier turns.
101
+ * x0_causal: standard causal masking inside the x_0 region.
102
+ """
103
+ x0_flag_q = (q_idx >= n_noisy)
104
+ x0_flag_kv = (kv_idx >= n_noisy)
105
+
106
+ pos_q = torch.where(x0_flag_q, q_idx - n_noisy, q_idx)
107
+ pos_kv = torch.where(x0_flag_kv, kv_idx - n_noisy, kv_idx)
108
+
109
+ turn_q = torch.where(
110
+ x0_flag_q,
111
+ turn_idx_clean[torch.clamp(pos_q, max=turn_idx_clean.shape[0] - 1)],
112
+ turn_idx_noisy[torch.clamp(pos_q, max=turn_idx_noisy.shape[0] - 1)],
113
+ )
114
+ turn_kv = torch.where(
115
+ x0_flag_kv,
116
+ turn_idx_clean[torch.clamp(pos_kv, max=turn_idx_clean.shape[0] - 1)],
117
+ turn_idx_noisy[torch.clamp(pos_kv, max=turn_idx_noisy.shape[0] - 1)],
118
+ )
119
+
120
+ block_diagonal = ~x0_flag_q & ~x0_flag_kv & (turn_q == turn_kv)
121
+ offset_block_causal = (turn_q > turn_kv) & x0_flag_kv & (~x0_flag_q)
122
+ x0_causal = x0_flag_q & x0_flag_kv & (pos_q >= pos_kv)
123
+
124
+ return block_diagonal | offset_block_causal | x0_causal
125
+
126
+
127
  def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
128
  # Compute block indices
129
  block_q = q_idx // block_size
 
753
  key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
754
  value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
755
 
756
+ # `noisy_seq_len` is an MDM-specific kwarg; pop it before the value can
757
+ # leak into attention backends (e.g. flash_attention_2) that don't
758
+ # understand it.
759
+ noisy_seq_len = kwargs.pop("noisy_seq_len", None)
760
+
761
  cos, sin = position_embeddings
762
  if self.training:
763
+ total_seq_len = query_states.shape[2]
764
+ # The noisy half can be shorter than the clean half (multimodal
765
+ # batches drop vision tokens from the noisy side). When the caller
766
+ # tells us its length explicitly we honor it; otherwise we fall back
767
+ # to the symmetric split used for text-only batches.
768
+ if noisy_seq_len is not None:
769
+ noisy_len = int(noisy_seq_len)
770
+ else:
771
+ noisy_len = total_seq_len // 2
772
+
773
+ q_1 = query_states[:, :, :noisy_len]
774
+ q_2 = query_states[:, :, noisy_len:]
775
+ k_1 = key_states[:, :, :noisy_len]
776
+ k_2 = key_states[:, :, noisy_len:]
777
+
778
+ if cos.shape[2] >= total_seq_len:
779
+ cos_1 = cos[:, :, :noisy_len, :]
780
+ sin_1 = sin[:, :, :noisy_len, :]
781
+ cos_2 = cos[:, :, noisy_len:, :]
782
+ sin_2 = sin[:, :, noisy_len:, :]
783
+ else:
784
+ # `position_ids` only covers the clean half length. Both halves
785
+ # share the same RoPE — valid only for the symmetric layout
786
+ # where noisy_len == clean_len.
787
+ cos_1, sin_1 = cos, sin
788
+ cos_2, sin_2 = cos, sin
789
+
790
+ q_1, k_1 = apply_multimodal_rotary_pos_emb(q_1, k_1, cos_1, sin_1, self.rope_scaling["mrope_section"])
791
+ q_2, k_2 = apply_multimodal_rotary_pos_emb(q_2, k_2, cos_2, sin_2, self.rope_scaling["mrope_section"])
792
  query_states = torch.cat((q_1, q_2), dim=-2)
793
  key_states = torch.cat((k_1, k_2), dim=-2)
794
  else:
 
1573
  self.minimum_noise_level = getattr(config, 'minimum_noise_level', 0.0)
1574
  self.im_end_token_id = 151645 # <|im_end|> token id
1575
 
1576
+ # Length of the noisy half passed through to attention. For text-only
1577
+ # batches it equals the (symmetric) sequence length; for multimodal
1578
+ # batches the noisy half drops vision tokens, so it is shorter.
1579
+ self._noisy_seq_len: Optional[int] = None
1580
+
1581
  # Vision-to-text aligner (if vision output dim != text hidden dim)
1582
  vision_out_dim = config.vision_config.out_hidden_size
1583
  text_hidden = config.text_config.hidden_size
 
1688
  )
1689
  return mask
1690
 
1691
+ def gen_hybrid_block_causal_mask_asymmetric(
1692
+ self, L_text, L_clean, turn_idx_noisy, turn_idx_clean, B, H
1693
+ ):
1694
+ """Generate the asymmetric hybrid mask used by the efficient vision path.
1695
+
1696
+ Layout: ``[noisy(L_text) | clean(L_clean)]`` where vision tokens have
1697
+ been removed from the noisy half.
1698
+ """
1699
+ n_noisy_t = torch.tensor(L_text, device=self.device, dtype=torch.int32)
1700
+ total = L_text + L_clean
1701
+ mask = create_block_mask(
1702
+ partial(
1703
+ hybrid_block_causal_mask_multiturn_asymmetric,
1704
+ turn_idx_noisy=turn_idx_noisy,
1705
+ turn_idx_clean=turn_idx_clean,
1706
+ n_noisy=n_noisy_t,
1707
+ ),
1708
+ B=B, H=H, Q_LEN=total, KV_LEN=total,
1709
+ )
1710
+ return mask
1711
+
1712
  @can_return_tuple
1713
  @auto_docstring
1714
  def forward(
 
1821
  labels = torch.cat([labels, complementary_labels], dim=0)
1822
 
1823
  attention_mask = self.gen_hybrid_block_causal_mask(seq_len, response_block_idx, turn_idx, input_ids.shape[0], self.config.num_attention_heads)
1824
+ # Text-only path: noisy and clean halves have identical length.
1825
+ self._noisy_seq_len = seq_len
1826
 
1827
  else:
1828
  # Multimodal block diffusion path.
 
1905
  labels_noisy = labels.clone()
1906
  labels_noisy[~mask_indices] = -100
1907
 
1908
+ # Efficient vision: drop vision tokens from the noisy half so the
1909
+ # model only attends to the (much shorter) text portion on that
1910
+ # side, while the clean half keeps the full sequence so visual
1911
+ # context is still available via cross-attention.
1912
+ text_positions = (~vision_token_mask[0]).nonzero(as_tuple=True)[0]
1913
+ L_text = text_positions.shape[0]
1914
+
1915
+ noisy_embeds = noisy_embeds[:, text_positions, :]
1916
+ noisy_input_ids = noisy_input_ids[:, text_positions]
1917
+ labels_noisy = labels_noisy[:, text_positions]
1918
+
1919
+ noisy_position_ids = original_position_ids[:, :, text_positions]
1920
+ combined_position_ids = torch.cat([noisy_position_ids, original_position_ids], dim=2)
1921
+ turn_idx_noisy = turn_idx[text_positions]
1922
+
1923
+ # Concatenate [noisy(L_text) | clean(L)] along the sequence dim.
1924
  input_ids_pair1 = torch.cat([noisy_input_ids, original_input_ids], dim=1)
1925
  embeds_pair1 = torch.cat([noisy_embeds, original_embeds], dim=1)
1926
  labels_pair1 = labels_noisy
1927
+ position_ids_pair1 = combined_position_ids
1928
 
1929
  # Complementary pair: mask the positions that were left clean above.
1930
  complementary_mask_indices = response_mask & ~mask_indices
 
1937
  complementary_noisy_embeds_raw = self.model.language_model.embed_tokens(complementary_noisy_input_ids)
1938
  complementary_noisy_embeds = torch.where(vision_mask_3d, original_embeds, complementary_noisy_embeds_raw)
1939
 
1940
+ complementary_noisy_embeds = complementary_noisy_embeds[:, text_positions, :]
1941
+ complementary_noisy_input_ids = complementary_noisy_input_ids[:, text_positions]
1942
  complementary_labels = original_labels.clone()
1943
  complementary_labels[~complementary_mask_indices] = -100
1944
+ complementary_labels = complementary_labels[:, text_positions]
1945
 
1946
  input_ids_pair2 = torch.cat([complementary_noisy_input_ids, original_input_ids], dim=1)
1947
  embeds_pair2 = torch.cat([complementary_noisy_embeds, original_embeds], dim=1)
1948
  labels_pair2 = complementary_labels
1949
+ position_ids_pair2 = combined_position_ids
1950
 
1951
  # Stack the complementary pair along the batch dimension.
1952
  input_ids = torch.cat([input_ids_pair1, input_ids_pair2], dim=0)
 
1954
  labels = torch.cat([labels_pair1, labels_pair2], dim=0)
1955
  position_ids = torch.cat([position_ids_pair1, position_ids_pair2], dim=1)
1956
 
1957
+ attention_mask = self.gen_hybrid_block_causal_mask_asymmetric(
1958
+ L_text, L, turn_idx_noisy, turn_idx,
1959
+ input_ids.shape[0], self.config.num_attention_heads,
1960
+ )
1961
+ self._noisy_seq_len = L_text
1962
 
1963
  # Phase D: forward through the inner model. Vision features (if any)
1964
  # have already been scattered into inputs_embeds, so pixel_values are
1965
+ # cleared to skip re-processing inside `Fast_dVLMModel`. The noisy
1966
+ # half length is forwarded as a kwarg so attention can split the
1967
+ # asymmetric `[noisy | clean]` layout correctly.
1968
+ kwargs['noisy_seq_len'] = self._noisy_seq_len
1969
  outputs = self.model(
1970
  input_ids=input_ids,
1971
  pixel_values=None,
 
2011
  loss = None
2012
 
2013
  if self.training:
2014
+ noisy_len = self._noisy_seq_len
2015
+ mdm_hidden_states = hidden_states[:, :noisy_len, :]
2016
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
2017
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
2018
  logits = self.lm_head(mdm_hidden_states[:, slice_indices, :])
 
2024
  loss = self.loss_function(
2025
  logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs
2026
  ) * 0.5
2027
+ causal_hidden_states = hidden_states[:hidden_states.shape[0]//2, noisy_len:, :]
2028
  causal_logits = self.lm_head(causal_hidden_states[:, slice_indices, :])
2029
  loss += self.loss_function(
2030
  logits=causal_logits, labels=original_labels, vocab_size=self.config.text_config.vocab_size, **new_kwargs