hf_moondream.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import torch.nn as nn
 
3
  from transformers import PreTrainedModel, PretrainedConfig
4
  from typing import Union
5
 
@@ -43,14 +44,6 @@ class HfMoondream(PreTrainedModel):
43
  MoondreamConfig.from_dict(config.config), setup_caches=False
44
  )
45
  self._is_kv_cache_setup = False
46
- self.post_init()
47
-
48
- @classmethod
49
- def from_pretrained(cls, *args, **kwargs):
50
- output = super().from_pretrained(*args, **kwargs)
51
- model = output[0] if isinstance(output, tuple) else output
52
- model.model._refresh_runtime_buffers()
53
- return output
54
 
55
  def _setup_caches(self):
56
  if not self._is_kv_cache_setup:
 
1
  import torch
2
  import torch.nn as nn
3
+
4
  from transformers import PreTrainedModel, PretrainedConfig
5
  from typing import Union
6
 
 
44
  MoondreamConfig.from_dict(config.config), setup_caches=False
45
  )
46
  self._is_kv_cache_setup = False
 
 
 
 
 
 
 
 
47
 
48
  def _setup_caches(self):
49
  if not self._is_kv_cache_setup:
layers.py CHANGED
@@ -5,14 +5,6 @@ import torch.nn.functional as F
5
  from dataclasses import dataclass
6
  from typing import Literal, Optional
7
 
8
- from .lora import (
9
- DenseLoRALayer,
10
- MoELoRALayer,
11
- apply_dense_lora,
12
- apply_moe_lora_fc1_flat,
13
- apply_moe_lora_fc2_flat,
14
- )
15
-
16
  try:
17
  from torchao import quantize_
18
  from torchao.quantization import int4_weight_only
@@ -134,12 +126,11 @@ class MLPWeights:
134
  act: Literal["gelu_approx"] = "gelu_approx"
135
 
136
 
137
- def mlp(
138
- x: torch.Tensor, w: MLPWeights, lora: Optional[DenseLoRALayer] = None
139
- ) -> torch.Tensor:
140
  x0 = w.fc1(x)
141
  if lora is not None:
142
- x = x0 + apply_dense_lora(x, lora.up_a, lora.up_b)
 
143
  else:
144
  x = x0
145
 
@@ -147,7 +138,8 @@ def mlp(
147
 
148
  x0 = w.fc2(x)
149
  if lora is not None:
150
- x = x0 + apply_dense_lora(x, lora.down_a, lora.down_b)
 
151
  else:
152
  x = x0
153
 
@@ -155,10 +147,7 @@ def mlp(
155
 
156
 
157
  def moe_mlp(
158
- x: torch.Tensor,
159
- mlp_module: nn.Module,
160
- experts_per_token: int,
161
- lora: Optional[MoELoRALayer] = None,
162
  ) -> torch.Tensor:
163
  B, T, C = x.shape
164
  x = x.reshape(-1, C)
@@ -178,23 +167,21 @@ def moe_mlp(
178
  flat_weights = topk_weights.view(-1) # [T*A]
179
 
180
  # Select expert weights
181
- w1_selected = w1_weight[flat_idxs]
182
- w2_selected = w2_weight[flat_idxs]
183
 
184
  # Expand input for all token-expert pairs
185
  x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D]
186
 
187
  # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
188
- x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(-1) # [T*A, H]
189
- if lora is not None:
190
- x1_full = x1_full + apply_moe_lora_fc1_flat(x_expanded, lora, flat_idxs)
191
  x1, g = x1_full.chunk(2, dim=-1)
192
  x1 = F.gelu(x1) * (g + 1)
193
 
194
  # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
195
  expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D]
196
- if lora is not None:
197
- expert_outs = expert_outs + apply_moe_lora_fc2_flat(x1, lora, flat_idxs)
198
 
199
  # Apply weights and reshape
200
  weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D]
@@ -216,22 +203,10 @@ def moe_mlp(
216
  x_tok = x.index_select(0, token_pos)
217
  gate_tok = topk_weights[token_pos, which_k]
218
 
219
- w1 = mlp_module.fc1.weight[expert_id]
220
- h_full = F.linear(x_tok, w1)
221
- if lora is not None:
222
- lora_up_a = lora.up_a[expert_id]
223
- lora_up_b = lora.up_b[expert_id]
224
- lora_mid = F.linear(x_tok, lora_up_a)
225
- h_full = h_full + F.linear(lora_mid, lora_up_b)
226
  h, g = h_full.chunk(2, dim=-1)
227
  h = F.gelu(h) * (g + 1)
228
- w2 = mlp_module.fc2.weight[expert_id]
229
- y = F.linear(h, w2)
230
- if lora is not None:
231
- lora_down_a = lora.down_a[expert_id]
232
- lora_down_b = lora.down_b[expert_id]
233
- lora_mid = F.linear(h, lora_down_a)
234
- y = y + F.linear(lora_mid, lora_down_b)
235
 
236
  y.mul_(gate_tok.unsqueeze(-1))
237
  out.index_add_(0, token_pos, y)
 
5
  from dataclasses import dataclass
6
  from typing import Literal, Optional
7
 
 
 
 
 
 
 
 
 
8
  try:
9
  from torchao import quantize_
10
  from torchao.quantization import int4_weight_only
 
126
  act: Literal["gelu_approx"] = "gelu_approx"
127
 
128
 
129
+ def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
 
 
130
  x0 = w.fc1(x)
131
  if lora is not None:
132
+ x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
133
+ x = x0 + x1
134
  else:
135
  x = x0
136
 
 
138
 
139
  x0 = w.fc2(x)
140
  if lora is not None:
141
+ x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
142
+ x = x0 + x1
143
  else:
144
  x = x0
145
 
 
147
 
148
 
149
  def moe_mlp(
150
+ x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int
 
 
 
151
  ) -> torch.Tensor:
152
  B, T, C = x.shape
153
  x = x.reshape(-1, C)
 
167
  flat_weights = topk_weights.view(-1) # [T*A]
168
 
169
  # Select expert weights
170
+ w1_selected = w1_weight[flat_idxs] # [T*A, H, D]
171
+ w2_selected = w2_weight[flat_idxs] # [T*A, D, H]
172
 
173
  # Expand input for all token-expert pairs
174
  x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D]
175
 
176
  # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
177
+ x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(
178
+ -1
179
+ ) # [T*A, H]
180
  x1, g = x1_full.chunk(2, dim=-1)
181
  x1 = F.gelu(x1) * (g + 1)
182
 
183
  # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
184
  expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D]
 
 
185
 
186
  # Apply weights and reshape
187
  weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D]
 
203
  x_tok = x.index_select(0, token_pos)
204
  gate_tok = topk_weights[token_pos, which_k]
205
 
206
+ h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])
 
 
 
 
 
 
207
  h, g = h_full.chunk(2, dim=-1)
208
  h = F.gelu(h) * (g + 1)
209
+ y = F.linear(h, mlp_module.fc2.weight[expert_id])
 
 
 
 
 
 
210
 
211
  y.mul_(gate_tok.unsqueeze(-1))
212
  out.index_add_(0, token_pos, y)
lora.py CHANGED
@@ -1,437 +1,82 @@
1
- import json
2
  import os
3
- import re
4
  import shutil
5
- from dataclasses import dataclass
6
- from pathlib import Path
7
- from typing import Any, Dict, Optional, Tuple
8
- from urllib.request import Request, urlopen
9
-
10
  import torch
11
 
12
- from .config import TextConfig
13
-
14
-
15
- class AdapterLoadError(RuntimeError):
16
- pass
17
 
18
 
19
- def _cache_root() -> Path:
20
  hf_hub_cache = os.environ.get("HF_HUB_CACHE")
21
- if hf_hub_cache:
22
- return Path(hf_hub_cache)
23
 
24
  hf_home = os.environ.get("HF_HOME")
25
- if hf_home:
26
- return Path(hf_home) / "hub"
27
 
28
- return Path("~/.cache/huggingface/hub").expanduser()
29
 
30
 
31
- def adapter_cache_dir() -> Path:
32
- return _cache_root() / "md_finetunes"
 
33
 
 
 
 
 
 
34
 
35
- def normalize_adapter_id(value: Optional[str]) -> Optional[str]:
36
- if not value:
37
- return None
38
- tail = value.split("/")[-1].strip()
39
- if "@" not in tail:
40
- return None
41
- return tail
42
 
43
-
44
- def parse_adapter_id(adapter_id: str) -> Tuple[str, str]:
45
- if not adapter_id or "@" not in adapter_id:
46
- raise AdapterLoadError(
47
- f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
48
- )
49
- finetune_id, step = adapter_id.split("@", 1)
50
- if not finetune_id or not step:
51
- raise AdapterLoadError(
52
- f"Invalid adapter id '{adapter_id}'. Expected 'finetune_id@step'."
53
- )
54
- return finetune_id, step
55
-
56
-
57
- def _fetch_presigned_url(finetune_id: str, step: str) -> str:
58
- endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai").rstrip("/")
59
  api_key = os.getenv("MOONDREAM_API_KEY")
60
- if not api_key:
61
- raise AdapterLoadError("MOONDREAM_API_KEY is required to load finetune adapters.")
62
-
63
- headers = {"User-Agent": "moondream-torch", "X-Moondream-Auth": api_key}
64
- url = f"{endpoint}/v1/tuning/finetunes/{finetune_id}/checkpoints/{step}/download"
65
- req = Request(url, headers=headers)
66
- try:
67
- with urlopen(req) as r:
68
- payload = json.loads(r.read().decode("utf-8"))
69
- except Exception as e:
70
- raise AdapterLoadError(f"Failed to fetch adapter URL: {e}") from e
71
-
72
- presigned = payload.get("url")
73
- if not presigned:
74
- raise AdapterLoadError("Adapter URL response missing 'url' field.")
75
- return presigned
76
-
77
-
78
- def cached_adapter_path(adapter_id: str) -> Path:
79
- finetune_id, step = parse_adapter_id(adapter_id)
80
-
81
- cache_dir = adapter_cache_dir() / finetune_id / step
82
- cache_dir.mkdir(parents=True, exist_ok=True)
83
 
84
- for name in ("adapter.pt", "adapter.safetensors"):
85
- path = cache_dir / name
86
- if path.exists() and path.stat().st_size > 0:
87
- return path
88
-
89
- presigned_url = _fetch_presigned_url(finetune_id, step)
90
- dest = cache_dir / "adapter.pt"
91
-
92
- try:
93
- with urlopen(presigned_url) as r, open(dest, "wb") as f:
94
- shutil.copyfileobj(r, f)
95
- except Exception as e:
96
- raise AdapterLoadError(f"Failed to download adapter: {e}") from e
97
  return dest
98
 
99
 
100
- def _load_state_dict(path: Path, device: torch.device) -> Dict[str, Any]:
101
- if path.suffix == ".safetensors":
102
- try:
103
- from safetensors.torch import safe_open
104
- except Exception as e:
105
- raise AdapterLoadError(
106
- "safetensors is required to load .safetensors adapters."
107
- ) from e
108
- data = {}
109
- with safe_open(str(path), framework="pt") as f:
110
- for key in f.keys():
111
- data[key] = f.get_tensor(key).to(device=device)
112
- return data
113
-
114
- try:
115
- return torch.load(path, map_location=device, weights_only=True)
116
- except TypeError:
117
- return torch.load(path, map_location=device)
118
-
119
-
120
- @dataclass
121
- class DenseLoRALayer:
122
- up_a: torch.Tensor
123
- up_b: torch.Tensor
124
- down_a: torch.Tensor
125
- down_b: torch.Tensor
126
-
127
-
128
- @dataclass
129
- class MoELoRALayer:
130
- up_a: torch.Tensor
131
- up_b: torch.Tensor
132
- down_a: torch.Tensor
133
- down_b: torch.Tensor
134
-
135
 
136
- class TextLoRA:
137
- def __init__(
138
- self,
139
- text_config: TextConfig,
140
- *,
141
- rank: int,
142
- max_rank: int,
143
- dtype: torch.dtype,
144
- device: torch.device,
145
- adapter_id: Optional[str] = None,
146
- ) -> None:
147
- if rank <= 0:
148
- raise AdapterLoadError("LoRA rank must be positive.")
149
- if max_rank < rank:
150
- raise AdapterLoadError("max_rank must be >= rank.")
151
-
152
- self.text_config = text_config
153
- self.rank = rank
154
- self.max_rank = max_rank
155
- self.adapter_id = adapter_id
156
-
157
- moe_cfg = text_config.moe
158
- self.start_layer = moe_cfg.start_layer if moe_cfg else text_config.n_layers
159
-
160
- if moe_cfg is not None:
161
- self.rank_per_expert = rank // moe_cfg.experts_per_token
162
- if self.rank_per_expert < 1:
163
- raise AdapterLoadError(
164
- f"rank ({rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
165
- )
166
- self.max_rank_per_expert = max_rank // moe_cfg.experts_per_token
167
- if self.max_rank_per_expert < 1:
168
- raise AdapterLoadError(
169
- f"max_rank ({max_rank}) must be >= experts_per_token ({moe_cfg.experts_per_token})"
170
- )
171
- else:
172
- self.rank_per_expert = 0
173
- self.max_rank_per_expert = 0
174
-
175
- d_model = text_config.dim
176
- d_ffn = text_config.ff_dim
177
-
178
- self.dense: list[DenseLoRALayer] = []
179
- for _ in range(self.start_layer):
180
- self.dense.append(
181
- DenseLoRALayer(
182
- up_a=torch.zeros((max_rank, d_model), device=device, dtype=dtype),
183
- up_b=torch.zeros((d_ffn, max_rank), device=device, dtype=dtype),
184
- down_a=torch.zeros((max_rank, d_ffn), device=device, dtype=dtype),
185
- down_b=torch.zeros((d_model, max_rank), device=device, dtype=dtype),
186
- )
187
- )
188
-
189
- self.moe: list[MoELoRALayer] = []
190
- if moe_cfg is not None:
191
- num_experts = moe_cfg.num_experts
192
- d_expert = moe_cfg.expert_inner_dim
193
- for _ in range(text_config.n_layers - self.start_layer):
194
- self.moe.append(
195
- MoELoRALayer(
196
- up_a=torch.zeros(
197
- (num_experts, self.max_rank_per_expert, d_model),
198
- device=device,
199
- dtype=dtype,
200
- ),
201
- up_b=torch.zeros(
202
- (num_experts, d_expert * 2, self.max_rank_per_expert),
203
- device=device,
204
- dtype=dtype,
205
- ),
206
- down_a=torch.zeros(
207
- (num_experts, self.max_rank_per_expert, d_expert),
208
- device=device,
209
- dtype=dtype,
210
- ),
211
- down_b=torch.zeros(
212
- (num_experts, d_model, self.max_rank_per_expert),
213
- device=device,
214
- dtype=dtype,
215
- ),
216
- )
217
- )
218
-
219
- def dense_layer(self, layer_idx: int) -> Optional[DenseLoRALayer]:
220
- if layer_idx < len(self.dense):
221
- return self.dense[layer_idx]
222
- return None
223
 
224
- def moe_layer(self, layer_idx: int) -> Optional[MoELoRALayer]:
225
- moe_idx = layer_idx - self.start_layer
226
- if 0 <= moe_idx < len(self.moe):
227
- return self.moe[moe_idx]
228
  return None
229
 
230
- @staticmethod
231
- def _pad_axis(tensor: torch.Tensor, target: int, axis: int) -> torch.Tensor:
232
- if tensor.shape[axis] == target:
233
- return tensor
234
- if tensor.shape[axis] > target:
235
- raise AdapterLoadError(
236
- f"LoRA tensor rank {tensor.shape[axis]} exceeds max {target}"
237
- )
238
- pad_shape = list(tensor.shape)
239
- pad_shape[axis] = target - tensor.shape[axis]
240
- pad = torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)
241
- return torch.cat([tensor, pad], dim=axis)
242
-
243
- @staticmethod
244
- def detect_rank(state_dict: Dict[str, Any], text_config: TextConfig) -> int:
245
- for key, tensor in state_dict.items():
246
- if "dense" in key and "up_a" in key:
247
- return int(tensor.shape[0])
248
- for key, tensor in state_dict.items():
249
- if "moe" in key and "up_a" in key:
250
- rank_per_expert = int(tensor.shape[1])
251
- moe_cfg = text_config.moe
252
- if moe_cfg:
253
- return rank_per_expert * moe_cfg.experts_per_token
254
- return rank_per_expert
255
- raise AdapterLoadError("Could not detect LoRA rank from state dict.")
256
-
257
- @classmethod
258
- def from_state_dict(
259
- cls,
260
- state_dict: Dict[str, Any],
261
- *,
262
- text_config: TextConfig,
263
- max_rank: int,
264
- dtype: torch.dtype,
265
- device: torch.device,
266
- adapter_id: Optional[str] = None,
267
- ) -> "TextLoRA":
268
- rank = cls.detect_rank(state_dict, text_config)
269
- if rank > max_rank:
270
- raise AdapterLoadError(
271
- f"Adapter rank ({rank}) exceeds max_rank ({max_rank})."
272
- )
273
-
274
- lora = cls(
275
- text_config,
276
- rank=rank,
277
- max_rank=max_rank,
278
- dtype=dtype,
279
- device=device,
280
- adapter_id=adapter_id,
281
- )
282
-
283
- dense_seen = set()
284
- moe_seen = set()
285
-
286
- pattern = re.compile(r"(dense|moe)\.(\d+)\.(up_a|up_b|down_a|down_b)$")
287
- for key, tensor in state_dict.items():
288
- match = pattern.search(key)
289
- if not match:
290
- continue
291
- kind, idx_str, name = match.group(1), match.group(2), match.group(3)
292
- idx = int(idx_str)
293
- arr = tensor.to(device=device, dtype=dtype)
294
-
295
- if kind == "dense":
296
- if idx >= len(lora.dense):
297
- raise AdapterLoadError(f"Dense LoRA layer index {idx} out of range.")
298
- layer = lora.dense[idx]
299
- if name in ("up_a", "down_a"):
300
- arr = cls._pad_axis(arr, lora.max_rank, axis=0)
301
- else:
302
- arr = cls._pad_axis(arr, lora.max_rank, axis=1)
303
- setattr(layer, name, arr)
304
- dense_seen.add((idx, name))
305
- else:
306
- if idx >= len(lora.moe):
307
- raise AdapterLoadError(f"MoE LoRA layer index {idx} out of range.")
308
- layer = lora.moe[idx]
309
- if name in ("up_a", "down_a"):
310
- arr = cls._pad_axis(arr, lora.max_rank_per_expert, axis=1)
311
- else:
312
- arr = cls._pad_axis(arr, lora.max_rank_per_expert, axis=2)
313
- setattr(layer, name, arr)
314
- moe_seen.add((idx, name))
315
-
316
- for layer_idx in range(len(lora.dense)):
317
- for name in ("up_a", "up_b", "down_a", "down_b"):
318
- if (layer_idx, name) not in dense_seen:
319
- raise AdapterLoadError(
320
- f"Adapter missing dense LoRA for layer {layer_idx} ({name})."
321
- )
322
- for layer_idx in range(len(lora.moe)):
323
- for name in ("up_a", "up_b", "down_a", "down_b"):
324
- if (layer_idx, name) not in moe_seen:
325
- raise AdapterLoadError(
326
- f"Adapter missing MoE LoRA for layer {layer_idx} ({name})."
327
- )
328
-
329
- return lora
330
-
331
-
332
- def select_layer_lora(
333
- lora: Optional[TextLoRA], layer_idx: int, *, is_moe: bool
334
- ) -> Optional[object]:
335
- if lora is None:
336
- return None
337
- return lora.moe_layer(layer_idx) if is_moe else lora.dense_layer(layer_idx)
338
-
339
-
340
- def apply_dense_lora(
341
- x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor
342
- ) -> torch.Tensor:
343
- b, t, c = x.shape
344
- x_flat = x.reshape(-1, c)
345
- lora_mid = torch.matmul(x_flat, lora_a.t())
346
- lora_out = torch.matmul(lora_mid, lora_b.t())
347
- return lora_out.reshape(b, t, -1)
348
-
349
-
350
- def apply_moe_lora_fc1_flat(
351
- x_expanded: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
352
- ) -> torch.Tensor:
353
- lora_up_a = lora.up_a[flat_idxs]
354
- lora_up_b = lora.up_b[flat_idxs]
355
- lora_mid = torch.bmm(lora_up_a, x_expanded.unsqueeze(-1)).squeeze(-1)
356
- lora_up = torch.bmm(lora_up_b, lora_mid.unsqueeze(-1)).squeeze(-1)
357
- return lora_up
358
-
359
-
360
- def apply_moe_lora_fc2_flat(
361
- h: torch.Tensor, lora: MoELoRALayer, flat_idxs: torch.Tensor
362
- ) -> torch.Tensor:
363
- lora_down_a = lora.down_a[flat_idxs]
364
- lora_down_b = lora.down_b[flat_idxs]
365
- lora_mid = torch.bmm(lora_down_a, h.unsqueeze(-1)).squeeze(-1)
366
- lora_down = torch.bmm(lora_down_b, lora_mid.unsqueeze(-1)).squeeze(-1)
367
- return lora_down
368
-
369
-
370
- _ADAPTER_CACHE: Dict[Tuple[str, str, str, Tuple], TextLoRA] = {}
371
- _CACHE_ORDER: list[Tuple[str, str, str, Tuple]] = []
372
- _CACHE_SIZE = 8
373
-
374
-
375
- def _config_key(text_config: TextConfig) -> Tuple:
376
- moe = text_config.moe
377
- moe_key = None
378
- if moe is not None:
379
- moe_key = (
380
- moe.num_experts,
381
- moe.start_layer,
382
- moe.experts_per_token,
383
- moe.expert_inner_dim,
384
- )
385
- return (
386
- text_config.dim,
387
- text_config.ff_dim,
388
- text_config.n_layers,
389
- moe_key,
390
- )
391
-
392
-
393
- def load_adapter(
394
- adapter_id: Optional[str],
395
- *,
396
- text_config: TextConfig,
397
- device: torch.device,
398
- dtype: torch.dtype,
399
- max_rank: int = 16,
400
- ) -> Optional[TextLoRA]:
401
- if adapter_id is None:
402
- return None
403
-
404
- adapter_id = normalize_adapter_id(adapter_id)
405
- if adapter_id is None:
406
- return None
407
-
408
- key = (adapter_id, str(device), str(dtype), _config_key(text_config))
409
- cached = _ADAPTER_CACHE.get(key)
410
- if cached is not None:
411
- return cached
412
-
413
- path = cached_adapter_path(adapter_id)
414
- checkpoint = _load_state_dict(path, device)
415
- if not isinstance(checkpoint, dict):
416
- raise AdapterLoadError("Invalid adapter checkpoint format.")
417
-
418
- state_dict = checkpoint.get("lora_state_dict", checkpoint)
419
- if not isinstance(state_dict, dict):
420
- raise AdapterLoadError("Adapter checkpoint missing lora_state_dict.")
421
-
422
- lora = TextLoRA.from_state_dict(
423
- state_dict,
424
- text_config=text_config,
425
- max_rank=max_rank,
426
- dtype=dtype,
427
- device=device,
428
- adapter_id=adapter_id,
429
  )
430
 
431
- _ADAPTER_CACHE[key] = lora
432
- _CACHE_ORDER.append(key)
433
- if len(_CACHE_ORDER) > _CACHE_SIZE:
434
- old = _CACHE_ORDER.pop(0)
435
- _ADAPTER_CACHE.pop(old, None)
436
-
437
- return lora
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
  import os
 
3
  import shutil
 
 
 
 
 
4
  import torch
5
 
6
+ from pathlib import Path
7
+ from urllib.request import Request, urlopen
8
+ from typing import Optional
 
 
9
 
10
 
11
+ def variant_cache_dir():
12
  hf_hub_cache = os.environ.get("HF_HUB_CACHE")
13
+ if hf_hub_cache is not None:
14
+ return Path(hf_hub_cache) / "md_variants"
15
 
16
  hf_home = os.environ.get("HF_HOME")
17
+ if hf_home is not None:
18
+ return Path(hf_home) / "hub" / "md_variants"
19
 
20
+ return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
21
 
22
 
23
+ def cached_variant_path(variant_id: str):
24
+ variant, *rest = variant_id.split("/", 1)
25
+ step = rest[0] if rest else "final"
26
 
27
+ cache_dir = variant_cache_dir() / variant
28
+ os.makedirs(cache_dir, exist_ok=True)
29
+ dest = cache_dir / f"{step}.pt"
30
+ if dest.exists():
31
+ return dest
32
 
33
+ md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
 
 
 
 
 
 
34
 
35
+ headers = {"User-Agent": "moondream-torch"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  api_key = os.getenv("MOONDREAM_API_KEY")
37
+ if api_key is not None:
38
+ headers["X-Moondream-Auth"] = api_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
41
+ with urlopen(req) as r, open(dest, "wb") as f:
42
+ shutil.copyfileobj(r, f)
 
 
 
 
 
 
 
 
 
 
43
  return dest
44
 
45
 
46
+ def nest(flat):
47
+ tree = {}
48
+ for k, v in flat.items():
49
+ parts = k.split(".")
50
+ d = tree
51
+ for p in parts[:-1]:
52
+ d = d.setdefault(p, {})
53
+ d[parts[-1]] = v
54
+ return tree
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ @functools.lru_cache(maxsize=5)
58
+ def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
59
+ if variant_id is None:
 
60
  return None
61
 
62
+ state_dict = torch.load(
63
+ cached_variant_path(variant_id), map_location=device, weights_only=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  )
65
 
66
+ # TODO: Move these into the training code that saves checkpoints...
67
+ rename_rules = [
68
+ ("text_model.transformer.h", "text.blocks"),
69
+ (".mixer", ".attn"),
70
+ (".out_proj", ".proj"),
71
+ (".Wqkv", ".qkv"),
72
+ (".parametrizations.weight.0", ""),
73
+ ]
74
+ new_state_dict = {}
75
+ for key, tensor in state_dict.items():
76
+ new_key = key
77
+ for old, new in rename_rules:
78
+ if old in new_key:
79
+ new_key = new_key.replace(old, new)
80
+ new_state_dict[new_key] = tensor
81
+
82
+ return nest(new_state_dict)
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
 
model_fp8.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:699bd3876f9e105440d60a5fe30c26bc33fbdf008f5bd611a3557663b24bd371
3
- size 10505451019
 
 
 
 
modelv2-00001-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:79006ed488cca15b173cd5c0c7c1a467c20aaf5508e13934c36378d071d48c13
3
- size 4907406296
 
 
 
 
modelv2-00002-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:40202c61286ec7386d9bbce31d87af3064e42931b10323ed4b3e44158c0521e3
3
- size 4736548872
 
 
 
 
modelv2-00003-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ff46835f23bac47c7409032391e02a095821e274f3faaeea3f826a960db9bf80
3
- size 4502742464
 
 
 
 
modelv2-00004-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0a4d39e1bcb0ab835b9a00c7f458dedca4faf8741fc0b23fd2caf2af4547bca6
3
- size 4390628760
 
 
 
 
moondream.py CHANGED
@@ -21,13 +21,12 @@ from .region import (
21
  SpatialRefs,
22
  )
23
  from .layers import QuantizedLinear
24
- from .lora import load_adapter, normalize_adapter_id
25
- from .rope import precompute_freqs_cis
26
  from .utils import remove_outlier_points
27
 
28
  ImageEncodingSettings = TypedDict(
29
  "ImageEncodingSettings",
30
- {"adapter": str, "model": str},
31
  total=False,
32
  )
33
 
@@ -37,15 +36,14 @@ TextSamplingSettings = TypedDict(
37
  "max_tokens": int,
38
  "temperature": float,
39
  "top_p": float,
40
- "adapter": str,
41
- "model": str,
42
  },
43
  total=False,
44
  )
45
 
46
  ObjectSamplingSettings = TypedDict(
47
  "ObjectSamplingSettings",
48
- {"max_objects": int, "adapter": str, "model": str},
49
  total=False,
50
  )
51
 
@@ -122,7 +120,6 @@ class MoondreamModel(nn.Module):
122
  "size_decoder": linear_cls(
123
  config.region.dim, config.region.size_out_dim, dtype=dtype
124
  ),
125
- "ln": nn.LayerNorm(config.region.dim, dtype=dtype),
126
  }
127
  )
128
  self.region.coord_features = nn.Parameter(
@@ -172,26 +169,6 @@ class MoondreamModel(nn.Module):
172
  )
173
  return self._point_gen_indices
174
 
175
- def _refresh_runtime_buffers(self):
176
- attn_mask = torch.tril(
177
- torch.ones(
178
- 1,
179
- 1,
180
- self.config.text.max_context,
181
- self.config.text.max_context,
182
- dtype=torch.bool,
183
- device=self.device,
184
- )
185
- )
186
- patch_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
187
- prefix_attn_len = 1 + patch_w**2
188
- attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
189
- self.attn_mask = attn_mask
190
- self.text.freqs_cis = precompute_freqs_cis(
191
- self.config.text.dim // (2 * self.config.text.n_heads),
192
- self.config.text.max_context,
193
- ).to(device=self.device)
194
-
195
  def _setup_caches(self):
196
  c = self.config.text
197
  for b in self.text.blocks:
@@ -204,29 +181,6 @@ class MoondreamModel(nn.Module):
204
  dtype=self.vision.pos_emb.dtype,
205
  )
206
 
207
- def _adapter_id_from_settings(self, settings: Optional[dict]) -> Optional[str]:
208
- if settings is None:
209
- return None
210
- adapter = settings.get("adapter")
211
- if adapter is not None:
212
- return normalize_adapter_id(adapter)
213
-
214
- model_value = settings.get("model")
215
- if isinstance(model_value, str):
216
- return normalize_adapter_id(model_value)
217
- return None
218
-
219
- def _resolve_lora(self, settings: Optional[dict]) -> Optional[object]:
220
- adapter_id = self._adapter_id_from_settings(settings)
221
- if adapter_id is None:
222
- return None
223
- return load_adapter(
224
- adapter_id,
225
- text_config=self.config.text,
226
- device=self.device,
227
- dtype=self.vision.pos_emb.dtype,
228
- )
229
-
230
  @property
231
  def device(self):
232
  return self.vision.pos_emb.device
@@ -349,7 +303,11 @@ class MoondreamModel(nn.Module):
349
  elif not isinstance(image, Image.Image):
350
  raise ValueError("image must be a PIL Image or EncodedImage")
351
 
352
- lora = self._resolve_lora(settings)
 
 
 
 
353
 
354
  # Run through text model in addition to the vision encoder, to minimize
355
  # re-computation if multiple queries are performed on this image.
@@ -450,7 +408,11 @@ class MoondreamModel(nn.Module):
450
  if settings
451
  else DEFAULT_TEMPERATURE
452
  )
453
- lora = self._resolve_lora(settings)
 
 
 
 
454
 
455
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
456
  eos_id = self.config.tokenizer.answer_id
@@ -562,7 +524,11 @@ class MoondreamModel(nn.Module):
562
  )
563
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
564
  eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
565
- lora = self._resolve_lora(settings)
 
 
 
 
566
 
567
  _, _, next_token, pos = self._prefill_prompt(
568
  prompt_tokens,
@@ -705,7 +671,6 @@ class MoondreamModel(nn.Module):
705
  reasoning_dict = {
706
  "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
707
  }
708
- spatial_refs = None
709
  else:
710
  prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
711
  reasoning_dict = {}
@@ -869,7 +834,11 @@ class MoondreamModel(nn.Module):
869
  device=self.device,
870
  )
871
 
872
- lora = self._resolve_lora(settings)
 
 
 
 
873
 
874
  _, hidden, next_token, pos = self._prefill_prompt(
875
  prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
@@ -913,7 +882,11 @@ class MoondreamModel(nn.Module):
913
  device=self.device,
914
  )
915
 
916
- lora = self._resolve_lora(settings)
 
 
 
 
917
 
918
  _, hidden, next_token, pos = self._prefill_prompt(
919
  prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
 
21
  SpatialRefs,
22
  )
23
  from .layers import QuantizedLinear
24
+ from .lora import variant_state_dict
 
25
  from .utils import remove_outlier_points
26
 
27
  ImageEncodingSettings = TypedDict(
28
  "ImageEncodingSettings",
29
+ {"variant": str},
30
  total=False,
31
  )
32
 
 
36
  "max_tokens": int,
37
  "temperature": float,
38
  "top_p": float,
39
+ "variant": str,
 
40
  },
41
  total=False,
42
  )
43
 
44
  ObjectSamplingSettings = TypedDict(
45
  "ObjectSamplingSettings",
46
+ {"max_objects": int, "variant": str},
47
  total=False,
48
  )
49
 
 
120
  "size_decoder": linear_cls(
121
  config.region.dim, config.region.size_out_dim, dtype=dtype
122
  ),
 
123
  }
124
  )
125
  self.region.coord_features = nn.Parameter(
 
169
  )
170
  return self._point_gen_indices
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  def _setup_caches(self):
173
  c = self.config.text
174
  for b in self.text.blocks:
 
181
  dtype=self.vision.pos_emb.dtype,
182
  )
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  @property
185
  def device(self):
186
  return self.vision.pos_emb.device
 
303
  elif not isinstance(image, Image.Image):
304
  raise ValueError("image must be a PIL Image or EncodedImage")
305
 
306
+ lora = (
307
+ variant_state_dict(settings["variant"], device=self.device)
308
+ if settings is not None and "variant" in settings
309
+ else None
310
+ )
311
 
312
  # Run through text model in addition to the vision encoder, to minimize
313
  # re-computation if multiple queries are performed on this image.
 
408
  if settings
409
  else DEFAULT_TEMPERATURE
410
  )
411
+ lora = (
412
+ variant_state_dict(settings["variant"], device=self.device)
413
+ if settings is not None and "variant" in settings
414
+ else None
415
+ )
416
 
417
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
418
  eos_id = self.config.tokenizer.answer_id
 
524
  )
525
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
526
  eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
527
+ lora = (
528
+ variant_state_dict(settings["variant"], device=self.device)
529
+ if settings is not None and "variant" in settings
530
+ else None
531
+ )
532
 
533
  _, _, next_token, pos = self._prefill_prompt(
534
  prompt_tokens,
 
671
  reasoning_dict = {
672
  "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
673
  }
 
674
  else:
675
  prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
676
  reasoning_dict = {}
 
834
  device=self.device,
835
  )
836
 
837
+ lora = (
838
+ variant_state_dict(settings["variant"], device=self.device)
839
+ if settings is not None and "variant" in settings
840
+ else None
841
+ )
842
 
843
  _, hidden, next_token, pos = self._prefill_prompt(
844
  prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
 
882
  device=self.device,
883
  )
884
 
885
+ lora = (
886
+ variant_state_dict(settings["variant"], device=self.device)
887
+ if settings is not None and "variant" in settings
888
+ else None
889
+ )
890
 
891
  _, hidden, next_token, pos = self._prefill_prompt(
892
  prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
region.py CHANGED
@@ -52,7 +52,6 @@ def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
52
  Returns:
53
  A single logit representing the predicted coordinate value (x or y)
54
  """
55
- hidden_state = w.ln(hidden_state)
56
  return w.coord_decoder(hidden_state)
57
 
58
 
@@ -89,7 +88,6 @@ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
89
  A tensor containing logits for 1024 bins for width and height.
90
  Shape is (2, 1024) where the first dimension corresponds to width and height.
91
  """
92
- hidden_state = w.ln(hidden_state)
93
  return w.size_decoder(hidden_state).view(2, -1)
94
 
95
 
 
52
  Returns:
53
  A single logit representing the predicted coordinate value (x or y)
54
  """
 
55
  return w.coord_decoder(hidden_state)
56
 
57
 
 
88
  A tensor containing logits for 1024 bins for width and height.
89
  Shape is (2, 1024) where the first dimension corresponds to width and height.
90
  """
 
91
  return w.size_decoder(hidden_state).view(2, -1)
92
 
93
 
text.py CHANGED
@@ -8,7 +8,6 @@ from typing import Optional
8
  from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
9
  from .rope import apply_rotary_emb, precompute_freqs_cis
10
  from .config import TextConfig
11
- from .lora import select_layer_lora
12
 
13
 
14
  def text_encoder(input_ids: torch.Tensor, w: nn.Module):
@@ -24,12 +23,15 @@ def attn(
24
  n_heads: int,
25
  n_kv_heads: int,
26
  position_ids: torch.Tensor,
 
27
  flex_block_mask_slice=None,
28
  ):
29
  bsz, q_len, d_model = x.shape
30
  head_dim = d_model // n_heads
31
 
32
  qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
 
 
33
  q_dim = n_heads * head_dim
34
  kv_dim = n_kv_heads * head_dim
35
  q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
@@ -67,7 +69,14 @@ def attn(
67
 
68
  out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
69
 
70
- return w.proj(out)
 
 
 
 
 
 
 
71
 
72
 
73
  def text_decoder(
@@ -76,13 +85,17 @@ def text_decoder(
76
  attn_mask: torch.Tensor,
77
  position_ids: torch.Tensor,
78
  config: TextConfig,
79
- lora: Optional[object] = None,
80
  flex_block_mask_slice=None,
81
  ):
82
  for i, block in enumerate(w.blocks):
83
- layer_lora = select_layer_lora(
84
- lora, i, is_moe=config.moe is not None and i >= config.moe.start_layer
85
- )
 
 
 
 
86
 
87
  l_in = layer_norm(x, block.ln)
88
  l_attn = attn(
@@ -94,15 +107,14 @@ def text_decoder(
94
  n_heads=config.n_heads,
95
  n_kv_heads=config.n_kv_heads,
96
  position_ids=position_ids,
 
97
  flex_block_mask_slice=flex_block_mask_slice,
98
  )
99
 
100
  if config.moe is not None and i >= config.moe.start_layer:
101
- l_mlp = moe_mlp(
102
- l_in, block.mlp, config.moe.experts_per_token, lora=layer_lora
103
- )
104
  else:
105
- l_mlp = mlp(l_in, block.mlp, lora=layer_lora)
106
 
107
  x = x + l_attn + l_mlp
108
 
@@ -133,7 +145,7 @@ def build_dense_mlp(d_model, d_ffn, dtype, linear_cls):
133
 
134
  def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
135
  # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
136
- mlp = nn.ModuleDict(
137
  {
138
  "router": nn.Linear(d_model, n_experts, dtype=dtype),
139
  "fc1": nn.ParameterDict(
@@ -152,7 +164,6 @@ def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
152
  ),
153
  }
154
  )
155
- return mlp
156
 
157
 
158
  def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
 
8
  from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
9
  from .rope import apply_rotary_emb, precompute_freqs_cis
10
  from .config import TextConfig
 
11
 
12
 
13
  def text_encoder(input_ids: torch.Tensor, w: nn.Module):
 
23
  n_heads: int,
24
  n_kv_heads: int,
25
  position_ids: torch.Tensor,
26
+ lora: Optional[dict] = None,
27
  flex_block_mask_slice=None,
28
  ):
29
  bsz, q_len, d_model = x.shape
30
  head_dim = d_model // n_heads
31
 
32
  qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
33
+ if lora is not None:
34
+ qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
35
  q_dim = n_heads * head_dim
36
  kv_dim = n_kv_heads * head_dim
37
  q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
 
69
 
70
  out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
71
 
72
+ out0 = w.proj(out)
73
+ if lora is not None:
74
+ out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
75
+ out = out0 + out1
76
+ else:
77
+ out = out0
78
+
79
+ return out
80
 
81
 
82
  def text_decoder(
 
85
  attn_mask: torch.Tensor,
86
  position_ids: torch.Tensor,
87
  config: TextConfig,
88
+ lora: Optional[dict] = None,
89
  flex_block_mask_slice=None,
90
  ):
91
  for i, block in enumerate(w.blocks):
92
+ if lora is not None:
93
+ layer_lora = lora["text"]["blocks"][str(i)]
94
+ mlp_lora = layer_lora["mlp"]
95
+ attn_lora = layer_lora["attn"]
96
+ else:
97
+ mlp_lora = None
98
+ attn_lora = None
99
 
100
  l_in = layer_norm(x, block.ln)
101
  l_attn = attn(
 
107
  n_heads=config.n_heads,
108
  n_kv_heads=config.n_kv_heads,
109
  position_ids=position_ids,
110
+ lora=attn_lora,
111
  flex_block_mask_slice=flex_block_mask_slice,
112
  )
113
 
114
  if config.moe is not None and i >= config.moe.start_layer:
115
+ l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token)
 
 
116
  else:
117
+ l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)
118
 
119
  x = x + l_attn + l_mlp
120
 
 
145
 
146
  def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
147
  # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
148
+ return nn.ModuleDict(
149
  {
150
  "router": nn.Linear(d_model, n_experts, dtype=dtype),
151
  "fc1": nn.ParameterDict(
 
164
  ),
165
  }
166
  )
 
167
 
168
 
169
  def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: