silveroxides committed on
Commit
aa5c7de
·
1 Parent(s): 81c8dbd

Add image similarity mode using forward_embedding FEATURE_DIM descriptors

Browse files
app.py CHANGED
@@ -183,6 +183,20 @@ def _build_custom_pca(
183
  # ---------------------------------------------------------------------------
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  @spaces.GPU
187
  def _gpu_infer(pixel_values: torch.Tensor) -> torch.Tensor:
188
  """Move tensor to device, run model forward, return CPU logits."""
@@ -282,6 +296,47 @@ def get_pca(
282
  )
283
 
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  # ---- FastAPI routes --------------------------------------------------------
286
 
287
 
 
183
  # ---------------------------------------------------------------------------
184
 
185
 
186
@spaces.GPU
def _gpu_extract_descriptor(pixel_values: torch.Tensor) -> np.ndarray:
    """Compute the image descriptor for one preprocessed image.

    The tensor is moved to ``model.device`` and passed through
    ``model.model.forward_embedding`` with gradients disabled and autocast
    enabled for the model's device type and dtype.

    Only the first row of the result is returned, converted to a NumPy
    array on CPU, so callers are expected to pass a batch of one.
    NOTE(review): the descriptor is presumably FEATURE_DIM=6400 float32
    values (forward_embedding casts with ``.float()``) — confirm against
    the model wrapper.
    """
    device = model.device
    batch = pixel_values.to(device)
    with torch.no_grad(), torch.autocast(device_type=device.type, dtype=model.dtype):
        embedding = model.model.forward_embedding(batch)  # one row per image
    first_row = embedding[0]
    return first_row.cpu().numpy()
198
+
199
+
200
  @spaces.GPU
201
  def _gpu_infer(pixel_values: torch.Tensor) -> torch.Tensor:
202
  """Move tensor to device, run model forward, return CPU logits."""
 
296
  )
297
 
298
 
299
@app.api(name="get_similarity")
def get_similarity(image_a: str, image_b: str, max_size: int = 1024) -> str:
    """Return the cosine similarity of two images' descriptors as JSON.

    Both images are resolved, opened and preprocessed (using ``max_size``),
    then each is passed through ``_gpu_extract_descriptor``. The two
    descriptors are L2-normalised and their dot product is the score.

    Returns a JSON string:
        {
            "score":  float,      # cosine similarity, clamped to [-1, 1]
            "desc_a": [floats],   # L2-normalised descriptor for image A
            "desc_b": [floats],   # L2-normalised descriptor for image B
        }
    """
    src_a = _resolve_image_source(image_a)
    src_b = _resolve_image_source(image_b)

    img_a = _open_image(src_a)
    img_b = _open_image(src_b)

    pv_a = _preprocess(img_a, max_size)
    pv_b = _preprocess(img_b, max_size)

    # Each image goes through the backbone in its own GPU call.
    # NOTE(review): presumably because the two preprocessed tensors can have
    # different spatial sizes and so cannot be stacked into one batch — confirm.
    feat_a = _gpu_extract_descriptor(pv_a)
    feat_b = _gpu_extract_descriptor(pv_b)

    # L2-normalise; the epsilon keeps an all-zero descriptor from dividing by 0.
    feat_a = feat_a / (np.linalg.norm(feat_a) + 1e-8)
    feat_b = feat_b / (np.linalg.norm(feat_b) + 1e-8)

    score = float(np.dot(feat_a, feat_b))
    # Rounding error in the dot product can push the result marginally outside
    # the documented range (e.g. 1.0000001 for identical images). Clamp with
    # explicit comparisons so a NaN score is passed through, not masked.
    if score > 1.0:
        score = 1.0
    elif score < -1.0:
        score = -1.0

    return json.dumps(
        {
            "score": round(score, 6),
            "desc_a": feat_a.tolist(),
            "desc_b": feat_b.tolist(),
        }
    )
338
+
339
+
340
  # ---- FastAPI routes --------------------------------------------------------
341
 
342
 
inference_tagger_standalone.py CHANGED
@@ -83,6 +83,7 @@ FEATURE_DIM = (1 + N_REGISTERS) * D_MODEL # 6400
83
  # RoPE helpers
84
  # ---------------------------------------------------------------------------
85
 
 
86
  @lru_cache(maxsize=32)
87
  def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
88
  device = torch.device(device_str)
@@ -94,11 +95,14 @@ def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
94
  return coords # [h*w, 2]
95
 
96
 
97
- def _build_rope(h_patches: int, w_patches: int,
98
- dtype: torch.dtype, device: torch.device):
 
99
  coords = _patch_coords_cached(h_patches, w_patches, str(device))
100
- inv_freq = 1.0 / (ROPE_THETA ** torch.arange(
101
- 0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device))
 
 
102
  angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
103
  angles = angles.flatten(1, 2).tile(2)
104
  cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
@@ -111,8 +115,7 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
111
  return torch.cat((-x[..., h:], x[..., :h]), dim=-1)
112
 
113
 
114
- def _apply_rope(q: torch.Tensor, k: torch.Tensor,
115
- cos: torch.Tensor, sin: torch.Tensor):
116
  n_pre = 1 + N_REGISTERS
117
  q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
118
  k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
@@ -125,6 +128,7 @@ def _apply_rope(q: torch.Tensor, k: torch.Tensor,
125
  # Transformer blocks
126
  # ---------------------------------------------------------------------------
127
 
 
128
  class _Attention(nn.Module):
129
  def __init__(self):
130
  super().__init__()
@@ -139,7 +143,7 @@ class _Attention(nn.Module):
139
  k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
140
  v = self.v_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
141
  q, k = _apply_rope(q, k, cos, sin)
142
- out = F.scaled_dot_product_attention(q, k, v, scale=HEAD_DIM ** -0.5)
143
  return self.o_proj(out.transpose(1, 2).reshape(B, S, D_MODEL))
144
 
145
 
@@ -179,13 +183,15 @@ class _Embeddings(nn.Module):
179
  self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
180
  self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
181
  self.patch_embeddings = nn.Conv2d(
182
- 3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
 
183
 
184
  def forward(self, pixel_values):
185
  B = pixel_values.shape[0]
186
  dtype = self.patch_embeddings.weight.dtype
187
- patches = self.patch_embeddings(
188
- pixel_values.to(dtype)).flatten(2).transpose(1, 2)
 
189
  cls = self.cls_token.expand(B, -1, -1)
190
  regs = self.register_tokens.expand(B, -1, -1)
191
  return torch.cat([cls, regs, patches], dim=1)
@@ -224,7 +230,7 @@ class DINOv3ViTH(nn.Module):
224
  x = block(x, cos, sin)
225
  x = self.norm(x)
226
  # token layout: [CLS, reg_0..reg_R-1, patch_0..patch_N]
227
- patch_tokens = x[:, 1 + N_REGISTERS:, :] # [B, h_p*w_p, D_MODEL]
228
  return patch_tokens, h_p, w_p
229
 
230
 
@@ -232,16 +238,18 @@ class DINOv3ViTH(nn.Module):
232
  # Head — auto-detected from the checkpoint
233
  # =============================================================================
234
 
 
235
  class _LowRankHead(nn.Module):
236
  """Two-matrix low-rank projection head.
237
 
238
- features (in_dim)
239
- → Linear(in_dim, rank, bias=?)
240
- → Linear(rank, num_tags, bias=?)
241
  """
242
 
243
- def __init__(self, in_dim: int, rank: int, num_tags: int,
244
- down_bias: bool, up_bias: bool):
 
245
  super().__init__()
246
  self.proj_down = nn.Linear(in_dim, rank, bias=down_bias)
247
  self.proj_up = nn.Linear(rank, num_tags, bias=up_bias)
@@ -265,15 +273,15 @@ def _build_head_from_checkpoint(
265
  Returns (module, remapped_state_dict) where the remapped state dict
266
  matches the module's own key names so strict loading works.
267
  """
268
- weights_2d = [(k, v) for k, v in head_sd.items()
269
- if k.endswith(".weight") and v.ndim == 2]
 
270
 
271
  # --- Case 1: single dense linear ---------------------------------------
272
- singles = [(k, v) for k, v in weights_2d
273
- if tuple(v.shape) == (num_tags, in_dim)]
274
  if len(weights_2d) <= 2 and len(singles) == 1:
275
  wkey, wval = singles[0]
276
- base = wkey[:-len(".weight")]
277
  bias_key = base + ".bias"
278
  has_bias = bias_key in head_sd
279
  module = nn.Linear(in_dim, num_tags, bias=has_bias)
@@ -285,12 +293,13 @@ def _build_head_from_checkpoint(
285
  extra = set(head_sd) - expected_src
286
  if extra:
287
  raise RuntimeError(
288
- f"Head has single-linear shape but extra unknown keys: {sorted(extra)}")
 
289
  return module, remapped
290
 
291
  # --- Case 2: low-rank pair ---------------------------------------------
292
  down = None # (key, tensor) with shape [rank, in_dim]
293
- up = None # (key, tensor) with shape [num_tags, rank]
294
  for k, v in weights_2d:
295
  if v.shape[1] == in_dim and v.shape[0] != num_tags:
296
  down = (k, v)
@@ -303,12 +312,13 @@ def _build_head_from_checkpoint(
303
  if rank_down != rank_up:
304
  raise RuntimeError(
305
  f"Low-rank head: inner dims disagree "
306
- f"(down out={rank_down}, up in={rank_up})")
 
307
 
308
  down_key, down_w = down
309
  up_key, up_w = up
310
- down_base = down_key[:-len(".weight")]
311
- up_base = up_key[:-len(".weight")]
312
  down_bias_key = down_base + ".bias"
313
  up_bias_key = up_base + ".bias"
314
  has_down_bias = down_bias_key in head_sd
@@ -340,11 +350,14 @@ def _build_head_from_checkpoint(
340
  if extra:
341
  raise RuntimeError(
342
  f"Low-rank head detected but checkpoint has extra unknown "
343
- f"head keys: {sorted(extra)}")
 
344
 
345
- print(f"[Tagger] Detected low-rank head: "
346
- f"in_dim={in_dim}, rank={rank_down}, num_tags={num_tags} "
347
- f"(down_bias={has_down_bias}, up_bias={has_up_bias})")
 
 
348
  return module, remapped
349
 
350
  raise RuntimeError(
@@ -357,6 +370,7 @@ def _build_head_from_checkpoint(
357
  # Tagger wrapper module
358
  # =============================================================================
359
 
 
360
  class DINOv3Tagger(nn.Module):
361
  """Backbone + head. The head is attached after the checkpoint is
362
  inspected (so we can build the right shape)."""
@@ -369,15 +383,26 @@ class DINOv3Tagger(nn.Module):
369
  def forward(self, pixel_values):
370
  hidden = self.backbone(pixel_values)
371
  cls = hidden[:, 0, :]
372
- regs = hidden[:, 1: 1 + N_REGISTERS, :].flatten(1)
373
  features = torch.cat([cls, regs], dim=-1).float() # fp32 for head
374
  return self.head(features)
375
 
 
 
 
 
 
 
 
 
 
 
376
 
377
  # =============================================================================
378
  # Checkpoint loading helpers
379
  # =============================================================================
380
 
 
381
  def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
382
  """Split full state dict into (backbone_sd, head_sd), stripping the
383
  ``backbone.`` prefix and applying the remaps needed to match
@@ -395,10 +420,10 @@ def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
395
  head_sd: dict = {}
396
  for k, v in sd.items():
397
  if k.startswith("backbone."):
398
- nk = k[len("backbone."):]
399
  # Remap (1): strip intermediate "model." before "layer."
400
  if nk.startswith("model.layer."):
401
- nk = nk[len("model."):]
402
  backbone_sd[nk] = v
403
  else:
404
  head_sd[k] = v
@@ -406,7 +431,7 @@ def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
406
  # Remap (2): layer.N.layer_scale{1,2}.lambda1 → layer.N.layer_scale{1,2}
407
  for k in list(backbone_sd.keys()):
408
  if ".layer_scale" in k and k.endswith(".lambda1"):
409
- backbone_sd[k[:-len(".lambda1")]] = backbone_sd.pop(k)
410
 
411
  # Remap (3): drop rope buffers (recomputed on the fly)
412
  for k in list(backbone_sd.keys()):
@@ -454,18 +479,21 @@ def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
454
  new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
455
  new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
456
 
457
- return v2.Compose([
458
- v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
459
- v2.ToImage(),
460
- v2.ToDtype(torch.float32, scale=True),
461
- v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
462
- ])(img).unsqueeze(0)
 
 
463
 
464
 
465
  # =============================================================================
466
  # Tagger wrapper
467
  # =============================================================================
468
 
 
469
  class Tagger:
470
  """Inference wrapper for DINOv3Tagger (ViT-H/16+).
471
 
@@ -519,12 +547,15 @@ class Tagger:
519
 
520
  if not head_sd:
521
  raise RuntimeError(
522
- "Checkpoint contains no non-backbone keys — cannot build head.")
 
523
 
524
  # --- Build model, inferring head shape from the checkpoint --------
525
  self.model = DINOv3Tagger()
526
  head_module, head_sd_remapped = _build_head_from_checkpoint(
527
- head_sd, in_dim=FEATURE_DIM, num_tags=self.num_tags,
 
 
528
  )
529
  self.model.head = head_module
530
 
@@ -533,10 +564,8 @@ class Tagger:
533
  self.model.head.load_state_dict(head_sd_remapped, strict=True)
534
 
535
  # --- Move to device. Backbone → bf16/fp16; head stays fp32. --------
536
- self.model.backbone = self.model.backbone.to(
537
- device=self.device, dtype=dtype)
538
- self.model.head = self.model.head.to(
539
- device=self.device, dtype=torch.float32)
540
  self.model.eval()
541
  print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
542
 
@@ -571,12 +600,20 @@ class Tagger:
571
  scale = min(1.0, max_size / max(w, h))
572
  new_w = _snap(round(w * scale), PATCH_SIZE)
573
  new_h = _snap(round(h * scale), PATCH_SIZE)
574
- pv = v2.Compose([
575
- v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
576
- v2.ToImage(),
577
- v2.ToDtype(torch.float32, scale=True),
578
- v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
579
- ])(img).unsqueeze(0).to(self.device)
 
 
 
 
 
 
 
 
580
  else:
581
  pv = preprocess_image(image, max_size=max_size).to(self.device)
582
 
@@ -606,8 +643,9 @@ class Tagger:
606
  return Image.fromarray(rgb_uint8, mode="RGB")
607
 
608
  @torch.no_grad()
609
- def predict(self, image, topk: int | None = 30,
610
- threshold: float | None = None) -> list[tuple[str, float]]:
 
611
  """Tag a single image (local path or URL)."""
612
  if topk is None and threshold is None:
613
  topk = 30
@@ -625,20 +663,23 @@ class Tagger:
625
  order = values.argsort(descending=True)
626
  indices, values = indices[order], values[order]
627
 
628
- return [(self.idx2tag[i], float(v))
629
- for i, v in zip(indices.tolist(), values.tolist())]
 
 
630
 
631
  @torch.no_grad()
632
- def predict_batch(self, images, topk: int | None = 30,
633
- threshold: float | None = None):
634
- return [self.predict(img, topk=topk, threshold=threshold)
635
- for img in images]
636
 
637
 
638
  # =============================================================================
639
  # Output formatters
640
  # =============================================================================
641
 
 
642
  def _fmt_pretty(path: str, results) -> str:
643
  lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
644
  for rank, (tag, score) in enumerate(results, 1):
@@ -652,38 +693,53 @@ def _fmt_tags(results) -> str:
652
 
653
 
654
  def _fmt_json(path: str, results) -> dict:
655
- return {"file": path,
656
- "tags": [{"tag": t, "score": round(s, 4)} for t, s in results]}
 
 
657
 
658
 
659
  # =============================================================================
660
  # CLI
661
  # =============================================================================
662
 
 
663
  def main():
664
  parser = argparse.ArgumentParser(
665
  description="DINOv3 ViT-H/16+ tagger inference (standalone)",
666
  formatter_class=argparse.RawDescriptionHelpFormatter,
667
  )
668
- parser.add_argument("--checkpoint", required=True,
669
- help="Path to .safetensors or .pt checkpoint")
670
- parser.add_argument("--vocab", required=True,
671
- help="Path to tagger_vocab*.json")
672
- parser.add_argument("--images", nargs="+", required=True,
673
- help="Image paths and/or http(s) URLs")
674
- parser.add_argument("--device", default="cuda",
675
- help="Device: cuda, cuda:0, cpu (default: cuda)")
676
- parser.add_argument("--max-size", type=int, default=1024,
677
- help="Long-edge cap in pixels (default: 1024)")
 
 
 
 
 
 
678
 
679
  mode = parser.add_mutually_exclusive_group()
680
- mode.add_argument("--topk", type=int, default=30,
681
- help="Return top-k tags (default: 30)")
682
- mode.add_argument("--threshold", type=float,
683
- help="Return all tags with score >= threshold")
 
 
684
 
685
- parser.add_argument("--format", choices=["pretty", "tags", "json"],
686
- default="pretty", help="Output format (default: pretty)")
 
 
 
 
687
  args = parser.parse_args()
688
 
689
  tagger = Tagger(
@@ -693,9 +749,7 @@ def main():
693
  max_size=args.max_size,
694
  )
695
 
696
- topk, threshold = (
697
- (None, args.threshold) if args.threshold else (args.topk, None)
698
- )
699
  json_out = []
700
 
701
  for src in args.images:
@@ -716,4 +770,4 @@ def main():
716
 
717
 
718
  if __name__ == "__main__":
719
- main()
 
83
  # RoPE helpers
84
  # ---------------------------------------------------------------------------
85
 
86
+
87
  @lru_cache(maxsize=32)
88
  def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
89
  device = torch.device(device_str)
 
95
  return coords # [h*w, 2]
96
 
97
 
98
+ def _build_rope(
99
+ h_patches: int, w_patches: int, dtype: torch.dtype, device: torch.device
100
+ ):
101
  coords = _patch_coords_cached(h_patches, w_patches, str(device))
102
+ inv_freq = 1.0 / (
103
+ ROPE_THETA
104
+ ** torch.arange(0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device)
105
+ )
106
  angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
107
  angles = angles.flatten(1, 2).tile(2)
108
  cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
 
115
  return torch.cat((-x[..., h:], x[..., :h]), dim=-1)
116
 
117
 
118
+ def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
 
119
  n_pre = 1 + N_REGISTERS
120
  q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
121
  k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
 
128
  # Transformer blocks
129
  # ---------------------------------------------------------------------------
130
 
131
+
132
  class _Attention(nn.Module):
133
  def __init__(self):
134
  super().__init__()
 
143
  k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
144
  v = self.v_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
145
  q, k = _apply_rope(q, k, cos, sin)
146
+ out = F.scaled_dot_product_attention(q, k, v, scale=HEAD_DIM**-0.5)
147
  return self.o_proj(out.transpose(1, 2).reshape(B, S, D_MODEL))
148
 
149
 
 
183
  self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
184
  self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
185
  self.patch_embeddings = nn.Conv2d(
186
+ 3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE
187
+ )
188
 
189
  def forward(self, pixel_values):
190
  B = pixel_values.shape[0]
191
  dtype = self.patch_embeddings.weight.dtype
192
+ patches = (
193
+ self.patch_embeddings(pixel_values.to(dtype)).flatten(2).transpose(1, 2)
194
+ )
195
  cls = self.cls_token.expand(B, -1, -1)
196
  regs = self.register_tokens.expand(B, -1, -1)
197
  return torch.cat([cls, regs, patches], dim=1)
 
230
  x = block(x, cos, sin)
231
  x = self.norm(x)
232
  # token layout: [CLS, reg_0..reg_R-1, patch_0..patch_N]
233
+ patch_tokens = x[:, 1 + N_REGISTERS :, :] # [B, h_p*w_p, D_MODEL]
234
  return patch_tokens, h_p, w_p
235
 
236
 
 
238
  # Head — auto-detected from the checkpoint
239
  # =============================================================================
240
 
241
+
242
  class _LowRankHead(nn.Module):
243
  """Two-matrix low-rank projection head.
244
 
245
+ features (in_dim)
246
+ → Linear(in_dim, rank, bias=?)
247
+ → Linear(rank, num_tags, bias=?)
248
  """
249
 
250
+ def __init__(
251
+ self, in_dim: int, rank: int, num_tags: int, down_bias: bool, up_bias: bool
252
+ ):
253
  super().__init__()
254
  self.proj_down = nn.Linear(in_dim, rank, bias=down_bias)
255
  self.proj_up = nn.Linear(rank, num_tags, bias=up_bias)
 
273
  Returns (module, remapped_state_dict) where the remapped state dict
274
  matches the module's own key names so strict loading works.
275
  """
276
+ weights_2d = [
277
+ (k, v) for k, v in head_sd.items() if k.endswith(".weight") and v.ndim == 2
278
+ ]
279
 
280
  # --- Case 1: single dense linear ---------------------------------------
281
+ singles = [(k, v) for k, v in weights_2d if tuple(v.shape) == (num_tags, in_dim)]
 
282
  if len(weights_2d) <= 2 and len(singles) == 1:
283
  wkey, wval = singles[0]
284
+ base = wkey[: -len(".weight")]
285
  bias_key = base + ".bias"
286
  has_bias = bias_key in head_sd
287
  module = nn.Linear(in_dim, num_tags, bias=has_bias)
 
293
  extra = set(head_sd) - expected_src
294
  if extra:
295
  raise RuntimeError(
296
+ f"Head has single-linear shape but extra unknown keys: {sorted(extra)}"
297
+ )
298
  return module, remapped
299
 
300
  # --- Case 2: low-rank pair ---------------------------------------------
301
  down = None # (key, tensor) with shape [rank, in_dim]
302
+ up = None # (key, tensor) with shape [num_tags, rank]
303
  for k, v in weights_2d:
304
  if v.shape[1] == in_dim and v.shape[0] != num_tags:
305
  down = (k, v)
 
312
  if rank_down != rank_up:
313
  raise RuntimeError(
314
  f"Low-rank head: inner dims disagree "
315
+ f"(down out={rank_down}, up in={rank_up})"
316
+ )
317
 
318
  down_key, down_w = down
319
  up_key, up_w = up
320
+ down_base = down_key[: -len(".weight")]
321
+ up_base = up_key[: -len(".weight")]
322
  down_bias_key = down_base + ".bias"
323
  up_bias_key = up_base + ".bias"
324
  has_down_bias = down_bias_key in head_sd
 
350
  if extra:
351
  raise RuntimeError(
352
  f"Low-rank head detected but checkpoint has extra unknown "
353
+ f"head keys: {sorted(extra)}"
354
+ )
355
 
356
+ print(
357
+ f"[Tagger] Detected low-rank head: "
358
+ f"in_dim={in_dim}, rank={rank_down}, num_tags={num_tags} "
359
+ f"(down_bias={has_down_bias}, up_bias={has_up_bias})"
360
+ )
361
  return module, remapped
362
 
363
  raise RuntimeError(
 
370
  # Tagger wrapper module
371
  # =============================================================================
372
 
373
+
374
  class DINOv3Tagger(nn.Module):
375
  """Backbone + head. The head is attached after the checkpoint is
376
  inspected (so we can build the right shape)."""
 
383
  def forward(self, pixel_values):
384
  hidden = self.backbone(pixel_values)
385
  cls = hidden[:, 0, :]
386
+ regs = hidden[:, 1 : 1 + N_REGISTERS, :].flatten(1)
387
  features = torch.cat([cls, regs], dim=-1).float() # fp32 for head
388
  return self.head(features)
389
 
390
+ def forward_embedding(self, pixel_values):
391
+ """Return the FEATURE_DIM=6400 image descriptor without applying the head.
392
+ Same as forward() but stops before self.head — use this for similarity queries.
393
+ """
394
+ hidden = self.backbone(pixel_values)
395
+ cls = hidden[:, 0, :]
396
+ regs = hidden[:, 1 : 1 + N_REGISTERS, :].flatten(1)
397
+ features = torch.cat([cls, regs], dim=-1).float() # fp32 for head
398
+ return features
399
+
400
 
401
  # =============================================================================
402
  # Checkpoint loading helpers
403
  # =============================================================================
404
 
405
+
406
  def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
407
  """Split full state dict into (backbone_sd, head_sd), stripping the
408
  ``backbone.`` prefix and applying the remaps needed to match
 
420
  head_sd: dict = {}
421
  for k, v in sd.items():
422
  if k.startswith("backbone."):
423
+ nk = k[len("backbone.") :]
424
  # Remap (1): strip intermediate "model." before "layer."
425
  if nk.startswith("model.layer."):
426
+ nk = nk[len("model.") :]
427
  backbone_sd[nk] = v
428
  else:
429
  head_sd[k] = v
 
431
  # Remap (2): layer.N.layer_scale{1,2}.lambda1 → layer.N.layer_scale{1,2}
432
  for k in list(backbone_sd.keys()):
433
  if ".layer_scale" in k and k.endswith(".lambda1"):
434
+ backbone_sd[k[: -len(".lambda1")]] = backbone_sd.pop(k)
435
 
436
  # Remap (3): drop rope buffers (recomputed on the fly)
437
  for k in list(backbone_sd.keys()):
 
479
  new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
480
  new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
481
 
482
+ return v2.Compose(
483
+ [
484
+ v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
485
+ v2.ToImage(),
486
+ v2.ToDtype(torch.float32, scale=True),
487
+ v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
488
+ ]
489
+ )(img).unsqueeze(0)
490
 
491
 
492
  # =============================================================================
493
  # Tagger wrapper
494
  # =============================================================================
495
 
496
+
497
  class Tagger:
498
  """Inference wrapper for DINOv3Tagger (ViT-H/16+).
499
 
 
547
 
548
  if not head_sd:
549
  raise RuntimeError(
550
+ "Checkpoint contains no non-backbone keys — cannot build head."
551
+ )
552
 
553
  # --- Build model, inferring head shape from the checkpoint --------
554
  self.model = DINOv3Tagger()
555
  head_module, head_sd_remapped = _build_head_from_checkpoint(
556
+ head_sd,
557
+ in_dim=FEATURE_DIM,
558
+ num_tags=self.num_tags,
559
  )
560
  self.model.head = head_module
561
 
 
564
  self.model.head.load_state_dict(head_sd_remapped, strict=True)
565
 
566
  # --- Move to device. Backbone → bf16/fp16; head stays fp32. --------
567
+ self.model.backbone = self.model.backbone.to(device=self.device, dtype=dtype)
568
+ self.model.head = self.model.head.to(device=self.device, dtype=torch.float32)
 
 
569
  self.model.eval()
570
  print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
571
 
 
600
  scale = min(1.0, max_size / max(w, h))
601
  new_w = _snap(round(w * scale), PATCH_SIZE)
602
  new_h = _snap(round(h * scale), PATCH_SIZE)
603
+ pv = (
604
+ v2.Compose(
605
+ [
606
+ v2.Resize(
607
+ (new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS
608
+ ),
609
+ v2.ToImage(),
610
+ v2.ToDtype(torch.float32, scale=True),
611
+ v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
612
+ ]
613
+ )(img)
614
+ .unsqueeze(0)
615
+ .to(self.device)
616
+ )
617
  else:
618
  pv = preprocess_image(image, max_size=max_size).to(self.device)
619
 
 
643
  return Image.fromarray(rgb_uint8, mode="RGB")
644
 
645
  @torch.no_grad()
646
+ def predict(
647
+ self, image, topk: int | None = 30, threshold: float | None = None
648
+ ) -> list[tuple[str, float]]:
649
  """Tag a single image (local path or URL)."""
650
  if topk is None and threshold is None:
651
  topk = 30
 
663
  order = values.argsort(descending=True)
664
  indices, values = indices[order], values[order]
665
 
666
+ return [
667
+ (self.idx2tag[i], float(v))
668
+ for i, v in zip(indices.tolist(), values.tolist())
669
+ ]
670
 
671
  @torch.no_grad()
672
+ def predict_batch(
673
+ self, images, topk: int | None = 30, threshold: float | None = None
674
+ ):
675
+ return [self.predict(img, topk=topk, threshold=threshold) for img in images]
676
 
677
 
678
  # =============================================================================
679
  # Output formatters
680
  # =============================================================================
681
 
682
+
683
  def _fmt_pretty(path: str, results) -> str:
684
  lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
685
  for rank, (tag, score) in enumerate(results, 1):
 
693
 
694
 
695
  def _fmt_json(path: str, results) -> dict:
696
+ return {
697
+ "file": path,
698
+ "tags": [{"tag": t, "score": round(s, 4)} for t, s in results],
699
+ }
700
 
701
 
702
  # =============================================================================
703
  # CLI
704
  # =============================================================================
705
 
706
+
707
  def main():
708
  parser = argparse.ArgumentParser(
709
  description="DINOv3 ViT-H/16+ tagger inference (standalone)",
710
  formatter_class=argparse.RawDescriptionHelpFormatter,
711
  )
712
+ parser.add_argument(
713
+ "--checkpoint", required=True, help="Path to .safetensors or .pt checkpoint"
714
+ )
715
+ parser.add_argument("--vocab", required=True, help="Path to tagger_vocab*.json")
716
+ parser.add_argument(
717
+ "--images", nargs="+", required=True, help="Image paths and/or http(s) URLs"
718
+ )
719
+ parser.add_argument(
720
+ "--device", default="cuda", help="Device: cuda, cuda:0, cpu (default: cuda)"
721
+ )
722
+ parser.add_argument(
723
+ "--max-size",
724
+ type=int,
725
+ default=1024,
726
+ help="Long-edge cap in pixels (default: 1024)",
727
+ )
728
 
729
  mode = parser.add_mutually_exclusive_group()
730
+ mode.add_argument(
731
+ "--topk", type=int, default=30, help="Return top-k tags (default: 30)"
732
+ )
733
+ mode.add_argument(
734
+ "--threshold", type=float, help="Return all tags with score >= threshold"
735
+ )
736
 
737
+ parser.add_argument(
738
+ "--format",
739
+ choices=["pretty", "tags", "json"],
740
+ default="pretty",
741
+ help="Output format (default: pretty)",
742
+ )
743
  args = parser.parse_args()
744
 
745
  tagger = Tagger(
 
749
  max_size=args.max_size,
750
  )
751
 
752
+ topk, threshold = (None, args.threshold) if args.threshold else (args.topk, None)
 
 
753
  json_out = []
754
 
755
  for src in args.images:
 
770
 
771
 
772
  if __name__ == "__main__":
773
+ main()
tagger_ui/templates/index.html CHANGED
@@ -259,6 +259,81 @@
259
  .tag-pill:hover { opacity: .8; }
260
  .tag-pill .score { font-size: .66rem; opacity: .7; }
261
  .tag-pill.hidden { display: none; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  </style>
263
  </head>
264
  <body>
@@ -266,7 +341,14 @@
266
  <h1>DINOv3 <span>Tagger</span></h1>
267
  <p class="subtitle">ViT-H/16+ · {{ num_tags | format_number }} tags · {{ vocab_path }}</p>
268
 
269
- <div class="layout">
 
 
 
 
 
 
 
270
 
271
  <!-- ====== LEFT PANEL ====== -->
272
  <div class="panel-left">
@@ -367,7 +449,57 @@
367
  </div>
368
  </div><!-- /panel-right -->
369
 
370
- </div><!-- /layout -->
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
  <script>
373
  // ---- category metadata from server ----
@@ -653,7 +785,7 @@
653
  if (e.key === 'Enter') runFromUrl();
654
  });
655
 
656
- // drag & drop
657
  const dz = document.getElementById('drop-zone');
658
  dz.addEventListener('dragover', e => { e.preventDefault(); dz.classList.add('drag-over'); });
659
  dz.addEventListener('dragleave', () => dz.classList.remove('drag-over'));
@@ -662,6 +794,55 @@
662
  const file = e.dataTransfer.files[0];
663
  if (file) stageFile(file);
664
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
  </script>
666
 
667
  <!-- @gradio/client module: patches runFromUrl / submitFile / runPca on window
@@ -773,6 +954,60 @@
773
  // Re-run with current colour pickers — re-submits full request (backbone
774
  // result is cached by the Gradio queue so subsequent calls are fast if
775
  // the same image/max_size is used, but ZeroGPU requires a full round-trip).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
  window.rerunCustomPca = async function() {
777
  if (!_lastPcaRequest) return;
778
  const spinner = document.getElementById('pca-spinner');
 
259
  .tag-pill:hover { opacity: .8; }
260
  .tag-pill .score { font-size: .66rem; opacity: .7; }
261
  .tag-pill.hidden { display: none; }
262
+
263
+ /* ---- similarity drop zones ---- */
264
+ .drop-zone-sim {
265
+ border: 2px dashed var(--border); border-radius: var(--radius);
266
+ color: var(--muted); cursor: pointer; font-size: .85rem;
267
+ padding: 1.2rem; text-align: center;
268
+ transition: border-color .15s, background .15s;
269
+ }
270
+ .drop-zone-sim.drag-over { border-color: var(--accent); background: rgba(124,106,247,.06); }
271
+
272
+ /* ---- mode toggle ---- */
273
+ .mode-toggle {
274
+ display: flex; gap: .5rem; margin-bottom: 1.5rem;
275
+ }
276
+ .mode-btn {
277
+ background: var(--surface); border: 1px solid var(--border);
278
+ border-radius: var(--radius); color: var(--muted); cursor: pointer;
279
+ font-size: .9rem; font-weight: 600; padding: .5rem 1.4rem;
280
+ transition: border-color .15s, color .15s, background .15s;
281
+ }
282
+ .mode-btn:hover { border-color: var(--accent); color: var(--text); }
283
+ .mode-btn.active { border-color: var(--accent); color: #fff; background: var(--accent); }
284
+
285
+ /* ---- similarity panel ---- */
286
+ #similarity-panel { display: none; width: 100%; max-width: 1600px; }
287
+
288
+ .sim-inputs {
289
+ display: flex; gap: 1.25rem; flex-wrap: wrap; margin-bottom: 1rem;
290
+ }
291
+ .sim-inputs .card { flex: 1 1 0; min-width: 260px; }
292
+
293
+ .sim-run-row {
294
+ display: flex; justify-content: center; margin-bottom: 1.5rem;
295
+ }
296
+
297
+ /* score display */
298
+ .sim-score-card {
299
+ background: var(--surface); border: 1px solid var(--border);
300
+ border-radius: var(--radius); padding: 1.25rem 1.5rem;
301
+ margin-bottom: 1.25rem; display: none;
302
+ }
303
+ .sim-score-label {
304
+ font-size: .8rem; color: var(--muted); margin-bottom: .5rem;
305
+ text-transform: uppercase; letter-spacing: .05em;
306
+ }
307
+ .sim-score-value {
308
+ font-size: 2.4rem; font-weight: 700; letter-spacing: -.02em;
309
+ margin-bottom: .6rem;
310
+ }
311
+ .sim-score-bar-bg {
312
+ height: 8px; background: var(--border); border-radius: 4px; overflow: hidden;
313
+ }
314
+ .sim-score-bar-fill {
315
+ height: 100%; border-radius: 4px; background: var(--accent);
316
+ transition: width .4s ease;
317
+ }
318
+ .sim-score-note {
319
+ font-size: .75rem; color: var(--muted); margin-top: .4rem;
320
+ }
321
+
322
+ /* image previews in similarity mode */
323
+ .sim-previews {
324
+ display: flex; gap: 1.25rem; flex-wrap: wrap;
325
+ }
326
+ .sim-preview-col {
327
+ flex: 1 1 0; min-width: 0;
328
+ }
329
+ .sim-preview-col img {
330
+ width: 100%; border-radius: var(--radius); border: 1px solid var(--border);
331
+ object-fit: contain; max-height: 480px; display: block;
332
+ }
333
+ .sim-preview-label {
334
+ font-size: .75rem; color: var(--muted); margin-bottom: .4rem;
335
+ text-transform: uppercase; letter-spacing: .04em;
336
+ }
337
  </style>
338
  </head>
339
  <body>
 
341
  <h1>DINOv3 <span>Tagger</span></h1>
342
  <p class="subtitle">ViT-H/16+ · {{ num_tags | format_number }} tags · {{ vocab_path }}</p>
343
 
344
+ <!-- mode toggle -->
345
+ <div class="mode-toggle">
346
+ <button class="mode-btn active" id="mode-btn-tagger" onclick="setMode('tagger')">Tagger</button>
347
+ <button class="mode-btn" id="mode-btn-similarity" onclick="setMode('similarity')">Similarity</button>
348
+ </div>
349
+
350
+ <!-- ===== TAGGER MODE ===== -->
351
+ <div id="tagger-panels" class="layout">
352
 
353
  <!-- ====== LEFT PANEL ====== -->
354
  <div class="panel-left">
 
449
  </div>
450
  </div><!-- /panel-right -->
451
 
452
+ </div><!-- /tagger-panels -->
453
+
454
+ <!-- ===== SIMILARITY MODE ===== -->
455
+ <div id="similarity-panel">
456
+
457
+ <div class="sim-inputs">
458
+ <!-- Image A -->
459
+ <div class="card">
460
+ <div style="font-size:.8rem;font-weight:600;color:var(--muted);text-transform:uppercase;letter-spacing:.05em;margin-bottom:.75rem">Image A</div>
461
+ <div class="input-row">
462
+ <input type="text" id="sim-url-a" placeholder="Paste URL…" />
463
+ </div>
464
+ <div id="sim-drop-a" class="drop-zone-sim" onclick="document.getElementById('sim-file-a').click()">
465
+ <input type="file" id="sim-file-a" accept="image/*" style="display:none" onchange="simStageFile('a', this)" />
466
+ Drop image A here or <strong>click to browse</strong>
467
+ </div>
468
+ <img id="sim-preview-a" src="" alt="" style="display:none;width:100%;margin-top:.75rem;border-radius:var(--radius);border:1px solid var(--border);max-height:300px;object-fit:contain" />
469
+ </div>
470
+
471
+ <!-- Image B -->
472
+ <div class="card">
473
+ <div style="font-size:.8rem;font-weight:600;color:var(--muted);text-transform:uppercase;letter-spacing:.05em;margin-bottom:.75rem">Image B</div>
474
+ <div class="input-row">
475
+ <input type="text" id="sim-url-b" placeholder="Paste URL…" />
476
+ </div>
477
+ <div id="sim-drop-b" class="drop-zone-sim" onclick="document.getElementById('sim-file-b').click()">
478
+ <input type="file" id="sim-file-b" accept="image/*" style="display:none" onchange="simStageFile('b', this)" />
479
+ Drop image B here or <strong>click to browse</strong>
480
+ </div>
481
+ <img id="sim-preview-b" src="" alt="" style="display:none;width:100%;margin-top:.75rem;border-radius:var(--radius);border:1px solid var(--border);max-height:300px;object-fit:contain" />
482
+ </div>
483
+ </div>
484
+
485
+ <div class="sim-run-row">
486
+ <button class="btn" id="sim-run-btn" onclick="runSimilarity()">Compare</button>
487
+ </div>
488
+
489
+ <div class="spinner" id="sim-spinner" style="display:none"></div>
490
+ <div class="error-msg" id="sim-error" style="display:none"></div>
491
+
492
+ <!-- score -->
493
+ <div class="sim-score-card" id="sim-score-card">
494
+ <div class="sim-score-label">Cosine Similarity (FEATURE_DIM descriptor)</div>
495
+ <div class="sim-score-value" id="sim-score-value">—</div>
496
+ <div class="sim-score-bar-bg">
497
+ <div class="sim-score-bar-fill" id="sim-score-bar" style="width:0%"></div>
498
+ </div>
499
+ <div class="sim-score-note" id="sim-score-note"></div>
500
+ </div>
501
+
502
+ </div><!-- /similarity-panel -->
503
 
504
  <script>
505
  // ---- category metadata from server ----
 
785
  if (e.key === 'Enter') runFromUrl();
786
  });
787
 
788
+ // drag & drop — tagger
789
  const dz = document.getElementById('drop-zone');
790
  dz.addEventListener('dragover', e => { e.preventDefault(); dz.classList.add('drag-over'); });
791
  dz.addEventListener('dragleave', () => dz.classList.remove('drag-over'));
 
794
  const file = e.dataTransfer.files[0];
795
  if (file) stageFile(file);
796
  });
797
+
798
// ---- mode toggle ----
/**
 * Switch the page between the tagger view and the similarity view.
 *
 * @param {string} mode  'tagger' or 'similarity'. Any other value hides both
 *                       panels and clears the active state of both buttons.
 */
function setMode(mode) {
  const isTagger = mode === 'tagger';
  const isSimilarity = mode === 'similarity';
  const byId = id => document.getElementById(id);

  // The tagger layout is shown as a flex container, the similarity panel as a block.
  byId('tagger-panels').style.display = isTagger ? 'flex' : 'none';
  byId('similarity-panel').style.display = isSimilarity ? 'block' : 'none';

  // Keep the toggle buttons' highlighted state in sync with the visible panel.
  byId('mode-btn-tagger').classList.toggle('active', isTagger);
  byId('mode-btn-similarity').classList.toggle('active', isSimilarity);
}
805
+
806
// ---- similarity: staged files ----
// Most recently picked or dropped File for each side; null until one is staged.
const _simStaged = { a: null, b: null };

/**
 * Record `file` as the staged image for one side and show its preview.
 * Shared by the file picker and the drop zones so both paths behave alike.
 *
 * @param {string} side  'a' or 'b' — selects the `sim-url-*` / `sim-preview-*` elements.
 * @param {File|undefined} file  The chosen file; a missing file is ignored.
 */
function _simSetStaged(side, file) {
  if (!file) return;
  _simStaged[side] = file;

  // runSimilarity() prefers a non-empty URL over a staged file, so a leftover
  // URL would be compared while the preview shows this file. Clear it.
  const urlInput = document.getElementById(`sim-url-${side}`);
  if (urlInput) urlInput.value = '';

  const reader = new FileReader();
  reader.onload = ev => {
    const img = document.getElementById(`sim-preview-${side}`);
    img.src = ev.target.result;
    img.style.display = 'block';
  };
  reader.readAsDataURL(file);
}

/**
 * `onchange` handler of the hidden file inputs.
 *
 * @param {string} side  'a' or 'b'.
 * @param {HTMLInputElement} input  The file input that fired the event.
 */
function simStageFile(side, input) {
  _simSetStaged(side, input.files[0]);
}

// drag & drop — similarity (one drop zone per side)
/**
 * Attach drag-and-drop staging to a similarity drop zone.
 *
 * @param {string} dzId  Element id of the drop zone.
 * @param {string} side  'a' or 'b' — the side the dropped file is staged for.
 */
function _wireDrop(dzId, side) {
  const el = document.getElementById(dzId);
  // preventDefault on dragover is what allows the element to receive a drop.
  el.addEventListener('dragover', e => { e.preventDefault(); el.classList.add('drag-over'); });
  el.addEventListener('dragleave', () => el.classList.remove('drag-over'));
  el.addEventListener('drop', e => {
    e.preventDefault();
    el.classList.remove('drag-over');
    // Only the first dropped file is used.
    _simSetStaged(side, e.dataTransfer.files[0]);
  });
}
_wireDrop('sim-drop-a', 'a');
_wireDrop('sim-drop-b', 'b');

// placeholder — replaced by module script
function runSimilarity() {}
846
  </script>
847
 
848
  <!-- @gradio/client module: patches runFromUrl / submitFile / runPca on window
 
954
  // Re-run with current colour pickers — re-submits full request (backbone
955
  // result is cached by the Gradio queue so subsequent calls are fast if
956
  // the same image/max_size is used, but ZeroGPU requires a full round-trip).
957
+ // ---- similarity ----
958
+
959
+ window.runSimilarity = async function() {
960
+ const urlA = document.getElementById('sim-url-a').value.trim();
961
+ const urlB = document.getElementById('sim-url-b').value.trim();
962
+
963
+ const imageArgA = urlA ? urlA : (_simStaged.a ? await handle_file(_simStaged.a) : null);
964
+ const imageArgB = urlB ? urlB : (_simStaged.b ? await handle_file(_simStaged.b) : null);
965
+
966
+ if (!imageArgA || !imageArgB) {
967
+ document.getElementById('sim-error').textContent = 'Provide both images before comparing.';
968
+ document.getElementById('sim-error').style.display = 'block';
969
+ return;
970
+ }
971
+ document.getElementById('sim-error').style.display = 'none';
972
+ document.getElementById('sim-run-btn').disabled = true;
973
+ document.getElementById('sim-spinner').style.display = 'block';
974
+ document.getElementById('sim-score-card').style.display = 'none';
975
+
976
+ try {
977
+ const res = await gradioApp.predict("/get_similarity", {
978
+ image_a: imageArgA,
979
+ image_b: imageArgB,
980
+ max_size: 1024,
981
+ });
982
+ const data = JSON.parse(res.data[0]);
983
+ const score = data.score; // [-1, 1]
984
+ const pct = Math.round(((score + 1) / 2) * 100); // map to [0,100]
985
+
986
+ document.getElementById('sim-score-value').textContent = score.toFixed(4);
987
+ document.getElementById('sim-score-bar').style.width = pct + '%';
988
+ document.getElementById('sim-score-note').textContent =
989
+ score > 0.9 ? 'Very high similarity — nearly identical semantic content.' :
990
+ score > 0.7 ? 'High similarity — strongly related images.' :
991
+ score > 0.5 ? 'Moderate similarity — some shared features.' :
992
+ score > 0.2 ? 'Low similarity — loosely related.' :
993
+ 'Very low similarity — likely unrelated images.';
994
+
995
+ // colour the bar by score
996
+ const barEl = document.getElementById('sim-score-bar');
997
+ barEl.style.background =
998
+ score > 0.7 ? '#4ade80' :
999
+ score > 0.4 ? '#facc15' : '#f87171';
1000
+
1001
+ document.getElementById('sim-score-card').style.display = 'block';
1002
+ } catch (err) {
1003
+ document.getElementById('sim-error').textContent = String(err);
1004
+ document.getElementById('sim-error').style.display = 'block';
1005
+ } finally {
1006
+ document.getElementById('sim-run-btn').disabled = false;
1007
+ document.getElementById('sim-spinner').style.display = 'none';
1008
+ }
1009
+ };
1010
+
1011
  window.rerunCustomPca = async function() {
1012
  if (!_lastPcaRequest) return;
1013
  const spinner = document.getElementById('pca-spinner');