Spaces:
Running on Zero
Running on Zero
Commit ·
aa5c7de
1
Parent(s): 81c8dbd
Add image similarity mode using forward_embedding FEATURE_DIM descriptors
Browse files- app.py +55 -0
- inference_tagger_standalone.py +136 -82
- tagger_ui/templates/index.html +238 -3
app.py
CHANGED
|
@@ -183,6 +183,20 @@ def _build_custom_pca(
|
|
| 183 |
# ---------------------------------------------------------------------------
|
| 184 |
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
@spaces.GPU
|
| 187 |
def _gpu_infer(pixel_values: torch.Tensor) -> torch.Tensor:
|
| 188 |
"""Move tensor to device, run model forward, return CPU logits."""
|
|
@@ -282,6 +296,47 @@ def get_pca(
|
|
| 282 |
)
|
| 283 |
|
| 284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
# ---- FastAPI routes --------------------------------------------------------
|
| 286 |
|
| 287 |
|
|
|
|
| 183 |
# ---------------------------------------------------------------------------
|
| 184 |
|
| 185 |
|
| 186 |
+
@spaces.GPU
def _gpu_extract_descriptor(pixel_values: torch.Tensor) -> np.ndarray:
    """Extract the FEATURE_DIM=6400 image descriptor via ``forward_embedding``.

    Args:
        pixel_values: Preprocessed image tensor. It is moved to
            ``model.device`` before the forward pass. Only the first batch
            element of the result is returned, so callers are expected to
            pass a batch of one image.

    Returns:
        The descriptor of the first image as a ``[6400]`` NumPy array on CPU
        (``forward_embedding`` casts its output to float32).
    """
    pv = pixel_values.to(model.device)
    # Inference only: disable autograd and run the forward pass under
    # autocast using the model's own device type and dtype.
    with (
        torch.no_grad(),
        torch.autocast(device_type=model.device.type, dtype=model.dtype),
    ):
        features = model.model.forward_embedding(pv)  # [1, 6400]
    # Drop the batch dimension and hand back a host-side array.
    return features[0].cpu().numpy()  # [6400]
|
| 198 |
+
|
| 199 |
+
|
| 200 |
@spaces.GPU
|
| 201 |
def _gpu_infer(pixel_values: torch.Tensor) -> torch.Tensor:
|
| 202 |
"""Move tensor to device, run model forward, return CPU logits."""
|
|
|
|
| 296 |
)
|
| 297 |
|
| 298 |
|
| 299 |
+
@app.api(name="get_similarity")
def get_similarity(image_a: str, image_b: str, max_size: int = 1024) -> str:
    """Extract FEATURE_DIM=6400 descriptors for two images and return their
    cosine similarity.

    Args:
        image_a: Source of the first image, passed to ``_resolve_image_source``.
        image_b: Source of the second image, passed to ``_resolve_image_source``.
        max_size: Size cap forwarded to ``_preprocess`` for both images.

    Returns JSON:
      {
        "score":  float,           # cosine similarity in [-1, 1]
        "desc_a": [6400 floats],   # L2-normalised descriptor for image A
        "desc_b": [6400 floats],   # L2-normalised descriptor for image B
      }
    """
    src_a = _resolve_image_source(image_a)
    src_b = _resolve_image_source(image_b)

    img_a = _open_image(src_a)
    img_b = _open_image(src_b)

    pv_a = _preprocess(img_a, max_size)
    pv_b = _preprocess(img_b, max_size)

    # Run both through the backbone in separate GPU calls
    # (spaces.GPU does not support batching across different-sized tensors)
    feat_a = _gpu_extract_descriptor(pv_a)  # [6400]
    feat_b = _gpu_extract_descriptor(pv_b)  # [6400]

    # L2-normalise; the epsilon keeps an all-zero descriptor from dividing by 0.
    feat_a = feat_a / (np.linalg.norm(feat_a) + 1e-8)
    feat_b = feat_b / (np.linalg.norm(feat_b) + 1e-8)

    # Rounding error in the float dot product can land marginally outside
    # [-1, 1]; clip so the documented range holds. np.clip propagates NaN,
    # so a broken descriptor is not silently turned into a valid score.
    score = float(np.clip(np.dot(feat_a, feat_b), -1.0, 1.0))

    return json.dumps(
        {
            "score": round(score, 6),
            "desc_a": feat_a.tolist(),
            "desc_b": feat_b.tolist(),
        }
    )
|
| 338 |
+
|
| 339 |
+
|
| 340 |
# ---- FastAPI routes --------------------------------------------------------
|
| 341 |
|
| 342 |
|
inference_tagger_standalone.py
CHANGED
|
@@ -83,6 +83,7 @@ FEATURE_DIM = (1 + N_REGISTERS) * D_MODEL # 6400
|
|
| 83 |
# RoPE helpers
|
| 84 |
# ---------------------------------------------------------------------------
|
| 85 |
|
|
|
|
| 86 |
@lru_cache(maxsize=32)
|
| 87 |
def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
|
| 88 |
device = torch.device(device_str)
|
|
@@ -94,11 +95,14 @@ def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
|
|
| 94 |
return coords # [h*w, 2]
|
| 95 |
|
| 96 |
|
| 97 |
-
def _build_rope(
|
| 98 |
-
|
|
|
|
| 99 |
coords = _patch_coords_cached(h_patches, w_patches, str(device))
|
| 100 |
-
inv_freq = 1.0 / (
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
|
| 103 |
angles = angles.flatten(1, 2).tile(2)
|
| 104 |
cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
|
|
@@ -111,8 +115,7 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
| 111 |
return torch.cat((-x[..., h:], x[..., :h]), dim=-1)
|
| 112 |
|
| 113 |
|
| 114 |
-
def _apply_rope(q: torch.Tensor, k: torch.Tensor,
|
| 115 |
-
cos: torch.Tensor, sin: torch.Tensor):
|
| 116 |
n_pre = 1 + N_REGISTERS
|
| 117 |
q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
|
| 118 |
k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
|
|
@@ -125,6 +128,7 @@ def _apply_rope(q: torch.Tensor, k: torch.Tensor,
|
|
| 125 |
# Transformer blocks
|
| 126 |
# ---------------------------------------------------------------------------
|
| 127 |
|
|
|
|
| 128 |
class _Attention(nn.Module):
|
| 129 |
def __init__(self):
|
| 130 |
super().__init__()
|
|
@@ -139,7 +143,7 @@ class _Attention(nn.Module):
|
|
| 139 |
k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
|
| 140 |
v = self.v_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
|
| 141 |
q, k = _apply_rope(q, k, cos, sin)
|
| 142 |
-
out = F.scaled_dot_product_attention(q, k, v, scale=HEAD_DIM
|
| 143 |
return self.o_proj(out.transpose(1, 2).reshape(B, S, D_MODEL))
|
| 144 |
|
| 145 |
|
|
@@ -179,13 +183,15 @@ class _Embeddings(nn.Module):
|
|
| 179 |
self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
|
| 180 |
self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
|
| 181 |
self.patch_embeddings = nn.Conv2d(
|
| 182 |
-
3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE
|
|
|
|
| 183 |
|
| 184 |
def forward(self, pixel_values):
|
| 185 |
B = pixel_values.shape[0]
|
| 186 |
dtype = self.patch_embeddings.weight.dtype
|
| 187 |
-
patches =
|
| 188 |
-
pixel_values.to(dtype)).flatten(2).transpose(1, 2)
|
|
|
|
| 189 |
cls = self.cls_token.expand(B, -1, -1)
|
| 190 |
regs = self.register_tokens.expand(B, -1, -1)
|
| 191 |
return torch.cat([cls, regs, patches], dim=1)
|
|
@@ -224,7 +230,7 @@ class DINOv3ViTH(nn.Module):
|
|
| 224 |
x = block(x, cos, sin)
|
| 225 |
x = self.norm(x)
|
| 226 |
# token layout: [CLS, reg_0..reg_R-1, patch_0..patch_N]
|
| 227 |
-
patch_tokens = x[:, 1 + N_REGISTERS:, :] # [B, h_p*w_p, D_MODEL]
|
| 228 |
return patch_tokens, h_p, w_p
|
| 229 |
|
| 230 |
|
|
@@ -232,16 +238,18 @@ class DINOv3ViTH(nn.Module):
|
|
| 232 |
# Head — auto-detected from the checkpoint
|
| 233 |
# =============================================================================
|
| 234 |
|
|
|
|
| 235 |
class _LowRankHead(nn.Module):
|
| 236 |
"""Two-matrix low-rank projection head.
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
"""
|
| 242 |
|
| 243 |
-
def __init__(
|
| 244 |
-
|
|
|
|
| 245 |
super().__init__()
|
| 246 |
self.proj_down = nn.Linear(in_dim, rank, bias=down_bias)
|
| 247 |
self.proj_up = nn.Linear(rank, num_tags, bias=up_bias)
|
|
@@ -265,15 +273,15 @@ def _build_head_from_checkpoint(
|
|
| 265 |
Returns (module, remapped_state_dict) where the remapped state dict
|
| 266 |
matches the module's own key names so strict loading works.
|
| 267 |
"""
|
| 268 |
-
weights_2d = [
|
| 269 |
-
|
|
|
|
| 270 |
|
| 271 |
# --- Case 1: single dense linear ---------------------------------------
|
| 272 |
-
singles = [(k, v) for k, v in weights_2d
|
| 273 |
-
if tuple(v.shape) == (num_tags, in_dim)]
|
| 274 |
if len(weights_2d) <= 2 and len(singles) == 1:
|
| 275 |
wkey, wval = singles[0]
|
| 276 |
-
base = wkey[:-len(".weight")]
|
| 277 |
bias_key = base + ".bias"
|
| 278 |
has_bias = bias_key in head_sd
|
| 279 |
module = nn.Linear(in_dim, num_tags, bias=has_bias)
|
|
@@ -285,12 +293,13 @@ def _build_head_from_checkpoint(
|
|
| 285 |
extra = set(head_sd) - expected_src
|
| 286 |
if extra:
|
| 287 |
raise RuntimeError(
|
| 288 |
-
f"Head has single-linear shape but extra unknown keys: {sorted(extra)}"
|
|
|
|
| 289 |
return module, remapped
|
| 290 |
|
| 291 |
# --- Case 2: low-rank pair ---------------------------------------------
|
| 292 |
down = None # (key, tensor) with shape [rank, in_dim]
|
| 293 |
-
up = None
|
| 294 |
for k, v in weights_2d:
|
| 295 |
if v.shape[1] == in_dim and v.shape[0] != num_tags:
|
| 296 |
down = (k, v)
|
|
@@ -303,12 +312,13 @@ def _build_head_from_checkpoint(
|
|
| 303 |
if rank_down != rank_up:
|
| 304 |
raise RuntimeError(
|
| 305 |
f"Low-rank head: inner dims disagree "
|
| 306 |
-
f"(down out={rank_down}, up in={rank_up})"
|
|
|
|
| 307 |
|
| 308 |
down_key, down_w = down
|
| 309 |
up_key, up_w = up
|
| 310 |
-
down_base = down_key[:-len(".weight")]
|
| 311 |
-
up_base = up_key[:-len(".weight")]
|
| 312 |
down_bias_key = down_base + ".bias"
|
| 313 |
up_bias_key = up_base + ".bias"
|
| 314 |
has_down_bias = down_bias_key in head_sd
|
|
@@ -340,11 +350,14 @@ def _build_head_from_checkpoint(
|
|
| 340 |
if extra:
|
| 341 |
raise RuntimeError(
|
| 342 |
f"Low-rank head detected but checkpoint has extra unknown "
|
| 343 |
-
f"head keys: {sorted(extra)}"
|
|
|
|
| 344 |
|
| 345 |
-
print(
|
| 346 |
-
|
| 347 |
-
|
|
|
|
|
|
|
| 348 |
return module, remapped
|
| 349 |
|
| 350 |
raise RuntimeError(
|
|
@@ -357,6 +370,7 @@ def _build_head_from_checkpoint(
|
|
| 357 |
# Tagger wrapper module
|
| 358 |
# =============================================================================
|
| 359 |
|
|
|
|
| 360 |
class DINOv3Tagger(nn.Module):
|
| 361 |
"""Backbone + head. The head is attached after the checkpoint is
|
| 362 |
inspected (so we can build the right shape)."""
|
|
@@ -369,15 +383,26 @@ class DINOv3Tagger(nn.Module):
|
|
| 369 |
def forward(self, pixel_values):
|
| 370 |
hidden = self.backbone(pixel_values)
|
| 371 |
cls = hidden[:, 0, :]
|
| 372 |
-
regs = hidden[:, 1: 1 + N_REGISTERS, :].flatten(1)
|
| 373 |
features = torch.cat([cls, regs], dim=-1).float() # fp32 for head
|
| 374 |
return self.head(features)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
# =============================================================================
|
| 378 |
# Checkpoint loading helpers
|
| 379 |
# =============================================================================
|
| 380 |
|
|
|
|
| 381 |
def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
|
| 382 |
"""Split full state dict into (backbone_sd, head_sd), stripping the
|
| 383 |
``backbone.`` prefix and applying the remaps needed to match
|
|
@@ -395,10 +420,10 @@ def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
|
|
| 395 |
head_sd: dict = {}
|
| 396 |
for k, v in sd.items():
|
| 397 |
if k.startswith("backbone."):
|
| 398 |
-
nk = k[len("backbone."):]
|
| 399 |
# Remap (1): strip intermediate "model." before "layer."
|
| 400 |
if nk.startswith("model.layer."):
|
| 401 |
-
nk = nk[len("model."):]
|
| 402 |
backbone_sd[nk] = v
|
| 403 |
else:
|
| 404 |
head_sd[k] = v
|
|
@@ -406,7 +431,7 @@ def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
|
|
| 406 |
# Remap (2): layer.N.layer_scale{1,2}.lambda1 → layer.N.layer_scale{1,2}
|
| 407 |
for k in list(backbone_sd.keys()):
|
| 408 |
if ".layer_scale" in k and k.endswith(".lambda1"):
|
| 409 |
-
backbone_sd[k[:-len(".lambda1")]] = backbone_sd.pop(k)
|
| 410 |
|
| 411 |
# Remap (3): drop rope buffers (recomputed on the fly)
|
| 412 |
for k in list(backbone_sd.keys()):
|
|
@@ -454,18 +479,21 @@ def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
|
|
| 454 |
new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
|
| 455 |
new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
|
| 456 |
|
| 457 |
-
return v2.Compose(
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
| 463 |
|
| 464 |
|
| 465 |
# =============================================================================
|
| 466 |
# Tagger wrapper
|
| 467 |
# =============================================================================
|
| 468 |
|
|
|
|
| 469 |
class Tagger:
|
| 470 |
"""Inference wrapper for DINOv3Tagger (ViT-H/16+).
|
| 471 |
|
|
@@ -519,12 +547,15 @@ class Tagger:
|
|
| 519 |
|
| 520 |
if not head_sd:
|
| 521 |
raise RuntimeError(
|
| 522 |
-
"Checkpoint contains no non-backbone keys — cannot build head."
|
|
|
|
| 523 |
|
| 524 |
# --- Build model, inferring head shape from the checkpoint --------
|
| 525 |
self.model = DINOv3Tagger()
|
| 526 |
head_module, head_sd_remapped = _build_head_from_checkpoint(
|
| 527 |
-
head_sd,
|
|
|
|
|
|
|
| 528 |
)
|
| 529 |
self.model.head = head_module
|
| 530 |
|
|
@@ -533,10 +564,8 @@ class Tagger:
|
|
| 533 |
self.model.head.load_state_dict(head_sd_remapped, strict=True)
|
| 534 |
|
| 535 |
# --- Move to device. Backbone → bf16/fp16; head stays fp32. --------
|
| 536 |
-
self.model.backbone = self.model.backbone.to(
|
| 537 |
-
|
| 538 |
-
self.model.head = self.model.head.to(
|
| 539 |
-
device=self.device, dtype=torch.float32)
|
| 540 |
self.model.eval()
|
| 541 |
print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
|
| 542 |
|
|
@@ -571,12 +600,20 @@ class Tagger:
|
|
| 571 |
scale = min(1.0, max_size / max(w, h))
|
| 572 |
new_w = _snap(round(w * scale), PATCH_SIZE)
|
| 573 |
new_h = _snap(round(h * scale), PATCH_SIZE)
|
| 574 |
-
pv =
|
| 575 |
-
v2.
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
else:
|
| 581 |
pv = preprocess_image(image, max_size=max_size).to(self.device)
|
| 582 |
|
|
@@ -606,8 +643,9 @@ class Tagger:
|
|
| 606 |
return Image.fromarray(rgb_uint8, mode="RGB")
|
| 607 |
|
| 608 |
@torch.no_grad()
|
| 609 |
-
def predict(
|
| 610 |
-
|
|
|
|
| 611 |
"""Tag a single image (local path or URL)."""
|
| 612 |
if topk is None and threshold is None:
|
| 613 |
topk = 30
|
|
@@ -625,20 +663,23 @@ class Tagger:
|
|
| 625 |
order = values.argsort(descending=True)
|
| 626 |
indices, values = indices[order], values[order]
|
| 627 |
|
| 628 |
-
return [
|
| 629 |
-
|
|
|
|
|
|
|
| 630 |
|
| 631 |
@torch.no_grad()
|
| 632 |
-
def predict_batch(
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
|
| 637 |
|
| 638 |
# =============================================================================
|
| 639 |
# Output formatters
|
| 640 |
# =============================================================================
|
| 641 |
|
|
|
|
| 642 |
def _fmt_pretty(path: str, results) -> str:
|
| 643 |
lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
|
| 644 |
for rank, (tag, score) in enumerate(results, 1):
|
|
@@ -652,38 +693,53 @@ def _fmt_tags(results) -> str:
|
|
| 652 |
|
| 653 |
|
| 654 |
def _fmt_json(path: str, results) -> dict:
|
| 655 |
-
return {
|
| 656 |
-
|
|
|
|
|
|
|
| 657 |
|
| 658 |
|
| 659 |
# =============================================================================
|
| 660 |
# CLI
|
| 661 |
# =============================================================================
|
| 662 |
|
|
|
|
| 663 |
def main():
|
| 664 |
parser = argparse.ArgumentParser(
|
| 665 |
description="DINOv3 ViT-H/16+ tagger inference (standalone)",
|
| 666 |
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 667 |
)
|
| 668 |
-
parser.add_argument(
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
parser.add_argument(
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
mode = parser.add_mutually_exclusive_group()
|
| 680 |
-
mode.add_argument(
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
|
|
|
|
|
|
| 684 |
|
| 685 |
-
parser.add_argument(
|
| 686 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
args = parser.parse_args()
|
| 688 |
|
| 689 |
tagger = Tagger(
|
|
@@ -693,9 +749,7 @@ def main():
|
|
| 693 |
max_size=args.max_size,
|
| 694 |
)
|
| 695 |
|
| 696 |
-
topk, threshold = (
|
| 697 |
-
(None, args.threshold) if args.threshold else (args.topk, None)
|
| 698 |
-
)
|
| 699 |
json_out = []
|
| 700 |
|
| 701 |
for src in args.images:
|
|
@@ -716,4 +770,4 @@ def main():
|
|
| 716 |
|
| 717 |
|
| 718 |
if __name__ == "__main__":
|
| 719 |
-
main()
|
|
|
|
| 83 |
# RoPE helpers
|
| 84 |
# ---------------------------------------------------------------------------
|
| 85 |
|
| 86 |
+
|
| 87 |
@lru_cache(maxsize=32)
|
| 88 |
def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
|
| 89 |
device = torch.device(device_str)
|
|
|
|
| 95 |
return coords # [h*w, 2]
|
| 96 |
|
| 97 |
|
| 98 |
+
def _build_rope(
|
| 99 |
+
h_patches: int, w_patches: int, dtype: torch.dtype, device: torch.device
|
| 100 |
+
):
|
| 101 |
coords = _patch_coords_cached(h_patches, w_patches, str(device))
|
| 102 |
+
inv_freq = 1.0 / (
|
| 103 |
+
ROPE_THETA
|
| 104 |
+
** torch.arange(0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device)
|
| 105 |
+
)
|
| 106 |
angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
|
| 107 |
angles = angles.flatten(1, 2).tile(2)
|
| 108 |
cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
|
|
|
|
| 115 |
return torch.cat((-x[..., h:], x[..., :h]), dim=-1)
|
| 116 |
|
| 117 |
|
| 118 |
+
def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
|
|
|
| 119 |
n_pre = 1 + N_REGISTERS
|
| 120 |
q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
|
| 121 |
k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
|
|
|
|
| 128 |
# Transformer blocks
|
| 129 |
# ---------------------------------------------------------------------------
|
| 130 |
|
| 131 |
+
|
| 132 |
class _Attention(nn.Module):
|
| 133 |
def __init__(self):
|
| 134 |
super().__init__()
|
|
|
|
| 143 |
k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
|
| 144 |
v = self.v_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
|
| 145 |
q, k = _apply_rope(q, k, cos, sin)
|
| 146 |
+
out = F.scaled_dot_product_attention(q, k, v, scale=HEAD_DIM**-0.5)
|
| 147 |
return self.o_proj(out.transpose(1, 2).reshape(B, S, D_MODEL))
|
| 148 |
|
| 149 |
|
|
|
|
| 183 |
self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
|
| 184 |
self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
|
| 185 |
self.patch_embeddings = nn.Conv2d(
|
| 186 |
+
3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE
|
| 187 |
+
)
|
| 188 |
|
| 189 |
def forward(self, pixel_values):
    """Embed an image batch into the token sequence [CLS, registers, patches].

    Args:
        pixel_values: Image batch; its first dimension is the batch size.
            It is cast to the dtype of the patch-embedding weights before
            the convolution.

    Returns:
        A tensor formed by concatenating, along dim 1, the CLS token, the
        register tokens (both expanded to the batch size) and the flattened
        patch embeddings.
    """
    B = pixel_values.shape[0]
    # Match the conv weights' dtype so a reduced-precision backbone accepts
    # the input as-is.
    dtype = self.patch_embeddings.weight.dtype
    # Conv output [B, C, h, w] -> [B, C, h*w] -> [B, h*w, C]
    patches = (
        self.patch_embeddings(pixel_values.to(dtype)).flatten(2).transpose(1, 2)
    )
    cls = self.cls_token.expand(B, -1, -1)
    regs = self.register_tokens.expand(B, -1, -1)
    return torch.cat([cls, regs, patches], dim=1)
|
|
|
|
| 230 |
x = block(x, cos, sin)
|
| 231 |
x = self.norm(x)
|
| 232 |
# token layout: [CLS, reg_0..reg_R-1, patch_0..patch_N]
|
| 233 |
+
patch_tokens = x[:, 1 + N_REGISTERS :, :] # [B, h_p*w_p, D_MODEL]
|
| 234 |
return patch_tokens, h_p, w_p
|
| 235 |
|
| 236 |
|
|
|
|
| 238 |
# Head — auto-detected from the checkpoint
|
| 239 |
# =============================================================================
|
| 240 |
|
| 241 |
+
|
| 242 |
class _LowRankHead(nn.Module):
|
| 243 |
"""Two-matrix low-rank projection head.
|
| 244 |
|
| 245 |
+
features (in_dim)
|
| 246 |
+
→ Linear(in_dim, rank, bias=?)
|
| 247 |
+
→ Linear(rank, num_tags, bias=?)
|
| 248 |
"""
|
| 249 |
|
| 250 |
+
def __init__(
|
| 251 |
+
self, in_dim: int, rank: int, num_tags: int, down_bias: bool, up_bias: bool
|
| 252 |
+
):
|
| 253 |
super().__init__()
|
| 254 |
self.proj_down = nn.Linear(in_dim, rank, bias=down_bias)
|
| 255 |
self.proj_up = nn.Linear(rank, num_tags, bias=up_bias)
|
|
|
|
| 273 |
Returns (module, remapped_state_dict) where the remapped state dict
|
| 274 |
matches the module's own key names so strict loading works.
|
| 275 |
"""
|
| 276 |
+
weights_2d = [
|
| 277 |
+
(k, v) for k, v in head_sd.items() if k.endswith(".weight") and v.ndim == 2
|
| 278 |
+
]
|
| 279 |
|
| 280 |
# --- Case 1: single dense linear ---------------------------------------
|
| 281 |
+
singles = [(k, v) for k, v in weights_2d if tuple(v.shape) == (num_tags, in_dim)]
|
|
|
|
| 282 |
if len(weights_2d) <= 2 and len(singles) == 1:
|
| 283 |
wkey, wval = singles[0]
|
| 284 |
+
base = wkey[: -len(".weight")]
|
| 285 |
bias_key = base + ".bias"
|
| 286 |
has_bias = bias_key in head_sd
|
| 287 |
module = nn.Linear(in_dim, num_tags, bias=has_bias)
|
|
|
|
| 293 |
extra = set(head_sd) - expected_src
|
| 294 |
if extra:
|
| 295 |
raise RuntimeError(
|
| 296 |
+
f"Head has single-linear shape but extra unknown keys: {sorted(extra)}"
|
| 297 |
+
)
|
| 298 |
return module, remapped
|
| 299 |
|
| 300 |
# --- Case 2: low-rank pair ---------------------------------------------
|
| 301 |
down = None # (key, tensor) with shape [rank, in_dim]
|
| 302 |
+
up = None # (key, tensor) with shape [num_tags, rank]
|
| 303 |
for k, v in weights_2d:
|
| 304 |
if v.shape[1] == in_dim and v.shape[0] != num_tags:
|
| 305 |
down = (k, v)
|
|
|
|
| 312 |
if rank_down != rank_up:
|
| 313 |
raise RuntimeError(
|
| 314 |
f"Low-rank head: inner dims disagree "
|
| 315 |
+
f"(down out={rank_down}, up in={rank_up})"
|
| 316 |
+
)
|
| 317 |
|
| 318 |
down_key, down_w = down
|
| 319 |
up_key, up_w = up
|
| 320 |
+
down_base = down_key[: -len(".weight")]
|
| 321 |
+
up_base = up_key[: -len(".weight")]
|
| 322 |
down_bias_key = down_base + ".bias"
|
| 323 |
up_bias_key = up_base + ".bias"
|
| 324 |
has_down_bias = down_bias_key in head_sd
|
|
|
|
| 350 |
if extra:
|
| 351 |
raise RuntimeError(
|
| 352 |
f"Low-rank head detected but checkpoint has extra unknown "
|
| 353 |
+
f"head keys: {sorted(extra)}"
|
| 354 |
+
)
|
| 355 |
|
| 356 |
+
print(
|
| 357 |
+
f"[Tagger] Detected low-rank head: "
|
| 358 |
+
f"in_dim={in_dim}, rank={rank_down}, num_tags={num_tags} "
|
| 359 |
+
f"(down_bias={has_down_bias}, up_bias={has_up_bias})"
|
| 360 |
+
)
|
| 361 |
return module, remapped
|
| 362 |
|
| 363 |
raise RuntimeError(
|
|
|
|
| 370 |
# Tagger wrapper module
|
| 371 |
# =============================================================================
|
| 372 |
|
| 373 |
+
|
| 374 |
class DINOv3Tagger(nn.Module):
|
| 375 |
"""Backbone + head. The head is attached after the checkpoint is
|
| 376 |
inspected (so we can build the right shape)."""
|
|
|
|
| 383 |
def forward(self, pixel_values):
    """Run the backbone and apply the head to the resulting features.

    Args:
        pixel_values: Image batch passed through to the backbone.

    Returns:
        The output of ``self.head`` applied to the fp32 feature vector
        (CLS token concatenated with the flattened register tokens).
    """
    # forward_embedding builds exactly the [CLS + registers] fp32 features
    # the head consumes, so reuse it rather than duplicating the slicing.
    return self.head(self.forward_embedding(pixel_values))
|
| 389 |
|
| 390 |
+
def forward_embedding(self, pixel_values):
    """Return the FEATURE_DIM=6400 image descriptor without applying the head.

    Builds the same feature vector that ``forward()`` feeds to the head but
    stops before ``self.head`` — use this for similarity queries.

    Args:
        pixel_values: Image batch passed through to the backbone.

    Returns:
        A float32 tensor with one row per image: the CLS token followed by
        the flattened register tokens.
    """
    hidden = self.backbone(pixel_values)
    # token layout: index 0 is CLS, the next N_REGISTERS tokens are registers
    cls = hidden[:, 0, :]
    regs = hidden[:, 1 : 1 + N_REGISTERS, :].flatten(1)
    # Cast to fp32 so the descriptor matches what the head would receive,
    # whatever precision the backbone ran in.
    return torch.cat([cls, regs], dim=-1).float()
|
| 399 |
+
|
| 400 |
|
| 401 |
# =============================================================================
|
| 402 |
# Checkpoint loading helpers
|
| 403 |
# =============================================================================
|
| 404 |
|
| 405 |
+
|
| 406 |
def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
|
| 407 |
"""Split full state dict into (backbone_sd, head_sd), stripping the
|
| 408 |
``backbone.`` prefix and applying the remaps needed to match
|
|
|
|
| 420 |
head_sd: dict = {}
|
| 421 |
for k, v in sd.items():
|
| 422 |
if k.startswith("backbone."):
|
| 423 |
+
nk = k[len("backbone.") :]
|
| 424 |
# Remap (1): strip intermediate "model." before "layer."
|
| 425 |
if nk.startswith("model.layer."):
|
| 426 |
+
nk = nk[len("model.") :]
|
| 427 |
backbone_sd[nk] = v
|
| 428 |
else:
|
| 429 |
head_sd[k] = v
|
|
|
|
| 431 |
# Remap (2): layer.N.layer_scale{1,2}.lambda1 → layer.N.layer_scale{1,2}
|
| 432 |
for k in list(backbone_sd.keys()):
|
| 433 |
if ".layer_scale" in k and k.endswith(".lambda1"):
|
| 434 |
+
backbone_sd[k[: -len(".lambda1")]] = backbone_sd.pop(k)
|
| 435 |
|
| 436 |
# Remap (3): drop rope buffers (recomputed on the fly)
|
| 437 |
for k in list(backbone_sd.keys()):
|
|
|
|
| 479 |
new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
|
| 480 |
new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
|
| 481 |
|
| 482 |
+
return v2.Compose(
|
| 483 |
+
[
|
| 484 |
+
v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
|
| 485 |
+
v2.ToImage(),
|
| 486 |
+
v2.ToDtype(torch.float32, scale=True),
|
| 487 |
+
v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
|
| 488 |
+
]
|
| 489 |
+
)(img).unsqueeze(0)
|
| 490 |
|
| 491 |
|
| 492 |
# =============================================================================
|
| 493 |
# Tagger wrapper
|
| 494 |
# =============================================================================
|
| 495 |
|
| 496 |
+
|
| 497 |
class Tagger:
|
| 498 |
"""Inference wrapper for DINOv3Tagger (ViT-H/16+).
|
| 499 |
|
|
|
|
| 547 |
|
| 548 |
if not head_sd:
|
| 549 |
raise RuntimeError(
|
| 550 |
+
"Checkpoint contains no non-backbone keys — cannot build head."
|
| 551 |
+
)
|
| 552 |
|
| 553 |
# --- Build model, inferring head shape from the checkpoint --------
|
| 554 |
self.model = DINOv3Tagger()
|
| 555 |
head_module, head_sd_remapped = _build_head_from_checkpoint(
|
| 556 |
+
head_sd,
|
| 557 |
+
in_dim=FEATURE_DIM,
|
| 558 |
+
num_tags=self.num_tags,
|
| 559 |
)
|
| 560 |
self.model.head = head_module
|
| 561 |
|
|
|
|
| 564 |
self.model.head.load_state_dict(head_sd_remapped, strict=True)
|
| 565 |
|
| 566 |
# --- Move to device. Backbone → bf16/fp16; head stays fp32. --------
|
| 567 |
+
self.model.backbone = self.model.backbone.to(device=self.device, dtype=dtype)
|
| 568 |
+
self.model.head = self.model.head.to(device=self.device, dtype=torch.float32)
|
|
|
|
|
|
|
| 569 |
self.model.eval()
|
| 570 |
print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
|
| 571 |
|
|
|
|
| 600 |
scale = min(1.0, max_size / max(w, h))
|
| 601 |
new_w = _snap(round(w * scale), PATCH_SIZE)
|
| 602 |
new_h = _snap(round(h * scale), PATCH_SIZE)
|
| 603 |
+
pv = (
|
| 604 |
+
v2.Compose(
|
| 605 |
+
[
|
| 606 |
+
v2.Resize(
|
| 607 |
+
(new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS
|
| 608 |
+
),
|
| 609 |
+
v2.ToImage(),
|
| 610 |
+
v2.ToDtype(torch.float32, scale=True),
|
| 611 |
+
v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
|
| 612 |
+
]
|
| 613 |
+
)(img)
|
| 614 |
+
.unsqueeze(0)
|
| 615 |
+
.to(self.device)
|
| 616 |
+
)
|
| 617 |
else:
|
| 618 |
pv = preprocess_image(image, max_size=max_size).to(self.device)
|
| 619 |
|
|
|
|
| 643 |
return Image.fromarray(rgb_uint8, mode="RGB")
|
| 644 |
|
| 645 |
@torch.no_grad()
|
| 646 |
+
def predict(
|
| 647 |
+
self, image, topk: int | None = 30, threshold: float | None = None
|
| 648 |
+
) -> list[tuple[str, float]]:
|
| 649 |
"""Tag a single image (local path or URL)."""
|
| 650 |
if topk is None and threshold is None:
|
| 651 |
topk = 30
|
|
|
|
| 663 |
order = values.argsort(descending=True)
|
| 664 |
indices, values = indices[order], values[order]
|
| 665 |
|
| 666 |
+
return [
|
| 667 |
+
(self.idx2tag[i], float(v))
|
| 668 |
+
for i, v in zip(indices.tolist(), values.tolist())
|
| 669 |
+
]
|
| 670 |
|
| 671 |
@torch.no_grad()
|
| 672 |
+
def predict_batch(
|
| 673 |
+
self, images, topk: int | None = 30, threshold: float | None = None
|
| 674 |
+
):
|
| 675 |
+
return [self.predict(img, topk=topk, threshold=threshold) for img in images]
|
| 676 |
|
| 677 |
|
| 678 |
# =============================================================================
|
| 679 |
# Output formatters
|
| 680 |
# =============================================================================
|
| 681 |
|
| 682 |
+
|
| 683 |
def _fmt_pretty(path: str, results) -> str:
|
| 684 |
lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
|
| 685 |
for rank, (tag, score) in enumerate(results, 1):
|
|
|
|
| 693 |
|
| 694 |
|
| 695 |
def _fmt_json(path: str, results) -> dict:
|
| 696 |
+
return {
|
| 697 |
+
"file": path,
|
| 698 |
+
"tags": [{"tag": t, "score": round(s, 4)} for t, s in results],
|
| 699 |
+
}
|
| 700 |
|
| 701 |
|
| 702 |
# =============================================================================
|
| 703 |
# CLI
|
| 704 |
# =============================================================================
|
| 705 |
|
| 706 |
+
|
| 707 |
def main():
|
| 708 |
parser = argparse.ArgumentParser(
|
| 709 |
description="DINOv3 ViT-H/16+ tagger inference (standalone)",
|
| 710 |
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 711 |
)
|
| 712 |
+
parser.add_argument(
|
| 713 |
+
"--checkpoint", required=True, help="Path to .safetensors or .pt checkpoint"
|
| 714 |
+
)
|
| 715 |
+
parser.add_argument("--vocab", required=True, help="Path to tagger_vocab*.json")
|
| 716 |
+
parser.add_argument(
|
| 717 |
+
"--images", nargs="+", required=True, help="Image paths and/or http(s) URLs"
|
| 718 |
+
)
|
| 719 |
+
parser.add_argument(
|
| 720 |
+
"--device", default="cuda", help="Device: cuda, cuda:0, cpu (default: cuda)"
|
| 721 |
+
)
|
| 722 |
+
parser.add_argument(
|
| 723 |
+
"--max-size",
|
| 724 |
+
type=int,
|
| 725 |
+
default=1024,
|
| 726 |
+
help="Long-edge cap in pixels (default: 1024)",
|
| 727 |
+
)
|
| 728 |
|
| 729 |
mode = parser.add_mutually_exclusive_group()
|
| 730 |
+
mode.add_argument(
|
| 731 |
+
"--topk", type=int, default=30, help="Return top-k tags (default: 30)"
|
| 732 |
+
)
|
| 733 |
+
mode.add_argument(
|
| 734 |
+
"--threshold", type=float, help="Return all tags with score >= threshold"
|
| 735 |
+
)
|
| 736 |
|
| 737 |
+
parser.add_argument(
|
| 738 |
+
"--format",
|
| 739 |
+
choices=["pretty", "tags", "json"],
|
| 740 |
+
default="pretty",
|
| 741 |
+
help="Output format (default: pretty)",
|
| 742 |
+
)
|
| 743 |
args = parser.parse_args()
|
| 744 |
|
| 745 |
tagger = Tagger(
|
|
|
|
| 749 |
max_size=args.max_size,
|
| 750 |
)
|
| 751 |
|
| 752 |
+
topk, threshold = (None, args.threshold) if args.threshold else (args.topk, None)
|
|
|
|
|
|
|
| 753 |
json_out = []
|
| 754 |
|
| 755 |
for src in args.images:
|
|
|
|
| 770 |
|
| 771 |
|
| 772 |
if __name__ == "__main__":
|
| 773 |
+
main()
|
tagger_ui/templates/index.html
CHANGED
|
@@ -259,6 +259,81 @@
|
|
| 259 |
.tag-pill:hover { opacity: .8; }
|
| 260 |
.tag-pill .score { font-size: .66rem; opacity: .7; }
|
| 261 |
.tag-pill.hidden { display: none; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
</style>
|
| 263 |
</head>
|
| 264 |
<body>
|
|
@@ -266,7 +341,14 @@
|
|
| 266 |
<h1>DINOv3 <span>Tagger</span></h1>
|
| 267 |
<p class="subtitle">ViT-H/16+ · {{ num_tags | format_number }} tags · {{ vocab_path }}</p>
|
| 268 |
|
| 269 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
<!-- ====== LEFT PANEL ====== -->
|
| 272 |
<div class="panel-left">
|
|
@@ -367,7 +449,57 @@
|
|
| 367 |
</div>
|
| 368 |
</div><!-- /panel-right -->
|
| 369 |
|
| 370 |
-
</div><!-- /
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
<script>
|
| 373 |
// ---- category metadata from server ----
|
|
@@ -653,7 +785,7 @@
|
|
| 653 |
if (e.key === 'Enter') runFromUrl();
|
| 654 |
});
|
| 655 |
|
| 656 |
-
// drag & drop
|
| 657 |
const dz = document.getElementById('drop-zone');
|
| 658 |
dz.addEventListener('dragover', e => { e.preventDefault(); dz.classList.add('drag-over'); });
|
| 659 |
dz.addEventListener('dragleave', () => dz.classList.remove('drag-over'));
|
|
@@ -662,6 +794,55 @@
|
|
| 662 |
const file = e.dataTransfer.files[0];
|
| 663 |
if (file) stageFile(file);
|
| 664 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
</script>
|
| 666 |
|
| 667 |
<!-- @gradio/client module: patches runFromUrl / submitFile / runPca on window
|
|
@@ -773,6 +954,60 @@
|
|
| 773 |
// Re-run with current colour pickers — re-submits full request (backbone
|
| 774 |
// result is cached by the Gradio queue so subsequent calls are fast if
|
| 775 |
// the same image/max_size is used, but ZeroGPU requires a full round-trip).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
window.rerunCustomPca = async function() {
|
| 777 |
if (!_lastPcaRequest) return;
|
| 778 |
const spinner = document.getElementById('pca-spinner');
|
|
|
|
| 259 |
.tag-pill:hover { opacity: .8; }
|
| 260 |
.tag-pill .score { font-size: .66rem; opacity: .7; }
|
| 261 |
.tag-pill.hidden { display: none; }
|
| 262 |
+
|
| 263 |
+
/* ---- similarity drop zones ---- */
|
| 264 |
+
.drop-zone-sim {
|
| 265 |
+
border: 2px dashed var(--border); border-radius: var(--radius);
|
| 266 |
+
color: var(--muted); cursor: pointer; font-size: .85rem;
|
| 267 |
+
padding: 1.2rem; text-align: center;
|
| 268 |
+
transition: border-color .15s, background .15s;
|
| 269 |
+
}
|
| 270 |
+
.drop-zone-sim.drag-over { border-color: var(--accent); background: rgba(124,106,247,.06); }
|
| 271 |
+
|
| 272 |
+
/* ---- mode toggle ---- */
|
| 273 |
+
.mode-toggle {
|
| 274 |
+
display: flex; gap: .5rem; margin-bottom: 1.5rem;
|
| 275 |
+
}
|
| 276 |
+
.mode-btn {
|
| 277 |
+
background: var(--surface); border: 1px solid var(--border);
|
| 278 |
+
border-radius: var(--radius); color: var(--muted); cursor: pointer;
|
| 279 |
+
font-size: .9rem; font-weight: 600; padding: .5rem 1.4rem;
|
| 280 |
+
transition: border-color .15s, color .15s, background .15s;
|
| 281 |
+
}
|
| 282 |
+
.mode-btn:hover { border-color: var(--accent); color: var(--text); }
|
| 283 |
+
.mode-btn.active { border-color: var(--accent); color: #fff; background: var(--accent); }
|
| 284 |
+
|
| 285 |
+
/* ---- similarity panel ---- */
|
| 286 |
+
#similarity-panel { display: none; width: 100%; max-width: 1600px; }
|
| 287 |
+
|
| 288 |
+
.sim-inputs {
|
| 289 |
+
display: flex; gap: 1.25rem; flex-wrap: wrap; margin-bottom: 1rem;
|
| 290 |
+
}
|
| 291 |
+
.sim-inputs .card { flex: 1 1 0; min-width: 260px; }
|
| 292 |
+
|
| 293 |
+
.sim-run-row {
|
| 294 |
+
display: flex; justify-content: center; margin-bottom: 1.5rem;
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
/* score display */
|
| 298 |
+
.sim-score-card {
|
| 299 |
+
background: var(--surface); border: 1px solid var(--border);
|
| 300 |
+
border-radius: var(--radius); padding: 1.25rem 1.5rem;
|
| 301 |
+
margin-bottom: 1.25rem; display: none;
|
| 302 |
+
}
|
| 303 |
+
.sim-score-label {
|
| 304 |
+
font-size: .8rem; color: var(--muted); margin-bottom: .5rem;
|
| 305 |
+
text-transform: uppercase; letter-spacing: .05em;
|
| 306 |
+
}
|
| 307 |
+
.sim-score-value {
|
| 308 |
+
font-size: 2.4rem; font-weight: 700; letter-spacing: -.02em;
|
| 309 |
+
margin-bottom: .6rem;
|
| 310 |
+
}
|
| 311 |
+
.sim-score-bar-bg {
|
| 312 |
+
height: 8px; background: var(--border); border-radius: 4px; overflow: hidden;
|
| 313 |
+
}
|
| 314 |
+
.sim-score-bar-fill {
|
| 315 |
+
height: 100%; border-radius: 4px; background: var(--accent);
|
| 316 |
+
transition: width .4s ease;
|
| 317 |
+
}
|
| 318 |
+
.sim-score-note {
|
| 319 |
+
font-size: .75rem; color: var(--muted); margin-top: .4rem;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
/* image previews in similarity mode */
|
| 323 |
+
.sim-previews {
|
| 324 |
+
display: flex; gap: 1.25rem; flex-wrap: wrap;
|
| 325 |
+
}
|
| 326 |
+
.sim-preview-col {
|
| 327 |
+
flex: 1 1 0; min-width: 0;
|
| 328 |
+
}
|
| 329 |
+
.sim-preview-col img {
|
| 330 |
+
width: 100%; border-radius: var(--radius); border: 1px solid var(--border);
|
| 331 |
+
object-fit: contain; max-height: 480px; display: block;
|
| 332 |
+
}
|
| 333 |
+
.sim-preview-label {
|
| 334 |
+
font-size: .75rem; color: var(--muted); margin-bottom: .4rem;
|
| 335 |
+
text-transform: uppercase; letter-spacing: .04em;
|
| 336 |
+
}
|
| 337 |
</style>
|
| 338 |
</head>
|
| 339 |
<body>
|
|
|
|
| 341 |
<h1>DINOv3 <span>Tagger</span></h1>
|
| 342 |
<p class="subtitle">ViT-H/16+ · {{ num_tags | format_number }} tags · {{ vocab_path }}</p>
|
| 343 |
|
| 344 |
+
<!-- mode toggle -->
|
| 345 |
+
<div class="mode-toggle">
|
| 346 |
+
<button class="mode-btn active" id="mode-btn-tagger" onclick="setMode('tagger')">Tagger</button>
|
| 347 |
+
<button class="mode-btn" id="mode-btn-similarity" onclick="setMode('similarity')">Similarity</button>
|
| 348 |
+
</div>
|
| 349 |
+
|
| 350 |
+
<!-- ===== TAGGER MODE ===== -->
|
| 351 |
+
<div id="tagger-panels" class="layout">
|
| 352 |
|
| 353 |
<!-- ====== LEFT PANEL ====== -->
|
| 354 |
<div class="panel-left">
|
|
|
|
| 449 |
</div>
|
| 450 |
</div><!-- /panel-right -->
|
| 451 |
|
| 452 |
+
</div><!-- /tagger-panels -->
|
| 453 |
+
|
| 454 |
+
<!-- ===== SIMILARITY MODE ===== -->
|
| 455 |
+
<div id="similarity-panel">
|
| 456 |
+
|
| 457 |
+
<div class="sim-inputs">
|
| 458 |
+
<!-- Image A -->
|
| 459 |
+
<div class="card">
|
| 460 |
+
<div style="font-size:.8rem;font-weight:600;color:var(--muted);text-transform:uppercase;letter-spacing:.05em;margin-bottom:.75rem">Image A</div>
|
| 461 |
+
<div class="input-row">
|
| 462 |
+
<input type="text" id="sim-url-a" placeholder="Paste URL…" />
|
| 463 |
+
</div>
|
| 464 |
+
<div id="sim-drop-a" class="drop-zone-sim" onclick="document.getElementById('sim-file-a').click()">
|
| 465 |
+
<input type="file" id="sim-file-a" accept="image/*" style="display:none" onchange="simStageFile('a', this)" />
|
| 466 |
+
Drop image A here or <strong>click to browse</strong>
|
| 467 |
+
</div>
|
| 468 |
+
<img id="sim-preview-a" src="" alt="" style="display:none;width:100%;margin-top:.75rem;border-radius:var(--radius);border:1px solid var(--border);max-height:300px;object-fit:contain" />
|
| 469 |
+
</div>
|
| 470 |
+
|
| 471 |
+
<!-- Image B -->
|
| 472 |
+
<div class="card">
|
| 473 |
+
<div style="font-size:.8rem;font-weight:600;color:var(--muted);text-transform:uppercase;letter-spacing:.05em;margin-bottom:.75rem">Image B</div>
|
| 474 |
+
<div class="input-row">
|
| 475 |
+
<input type="text" id="sim-url-b" placeholder="Paste URL…" />
|
| 476 |
+
</div>
|
| 477 |
+
<div id="sim-drop-b" class="drop-zone-sim" onclick="document.getElementById('sim-file-b').click()">
|
| 478 |
+
<input type="file" id="sim-file-b" accept="image/*" style="display:none" onchange="simStageFile('b', this)" />
|
| 479 |
+
Drop image B here or <strong>click to browse</strong>
|
| 480 |
+
</div>
|
| 481 |
+
<img id="sim-preview-b" src="" alt="" style="display:none;width:100%;margin-top:.75rem;border-radius:var(--radius);border:1px solid var(--border);max-height:300px;object-fit:contain" />
|
| 482 |
+
</div>
|
| 483 |
+
</div>
|
| 484 |
+
|
| 485 |
+
<div class="sim-run-row">
|
| 486 |
+
<button class="btn" id="sim-run-btn" onclick="runSimilarity()">Compare</button>
|
| 487 |
+
</div>
|
| 488 |
+
|
| 489 |
+
<div class="spinner" id="sim-spinner" style="display:none"></div>
|
| 490 |
+
<div class="error-msg" id="sim-error" style="display:none"></div>
|
| 491 |
+
|
| 492 |
+
<!-- score -->
|
| 493 |
+
<div class="sim-score-card" id="sim-score-card">
|
| 494 |
+
<div class="sim-score-label">Cosine Similarity (FEATURE_DIM descriptor)</div>
|
| 495 |
+
<div class="sim-score-value" id="sim-score-value">—</div>
|
| 496 |
+
<div class="sim-score-bar-bg">
|
| 497 |
+
<div class="sim-score-bar-fill" id="sim-score-bar" style="width:0%"></div>
|
| 498 |
+
</div>
|
| 499 |
+
<div class="sim-score-note" id="sim-score-note"></div>
|
| 500 |
+
</div>
|
| 501 |
+
|
| 502 |
+
</div><!-- /similarity-panel -->
|
| 503 |
|
| 504 |
<script>
|
| 505 |
// ---- category metadata from server ----
|
|
|
|
| 785 |
if (e.key === 'Enter') runFromUrl();
|
| 786 |
});
|
| 787 |
|
| 788 |
+
// drag & drop — tagger
|
| 789 |
const dz = document.getElementById('drop-zone');
|
| 790 |
dz.addEventListener('dragover', e => { e.preventDefault(); dz.classList.add('drag-over'); });
|
| 791 |
dz.addEventListener('dragleave', () => dz.classList.remove('drag-over'));
|
|
|
|
| 794 |
const file = e.dataTransfer.files[0];
|
| 795 |
if (file) stageFile(file);
|
| 796 |
});
|
| 797 |
+
|
| 798 |
+
// ---- mode toggle ----
|
| 799 |
+
/**
 * Switch the page between the tagger view and the similarity view.
 *
 * Shows the panel that matches `mode`, hides the other one, and moves the
 * `active` class onto the matching toggle button.  Any value other than
 * 'tagger' or 'similarity' hides both panels and deactivates both buttons.
 *
 * @param {string} mode  'tagger' or 'similarity'.
 */
function setMode(mode) {
  const byId = id => document.getElementById(id);
  const taggerOn = mode === 'tagger';
  const similarityOn = mode === 'similarity';

  // The tagger container is a flex layout; the similarity panel is a plain block.
  const taggerPanels = byId('tagger-panels');
  const similarityPanel = byId('similarity-panel');
  if (taggerOn) {
    taggerPanels.style.display = 'flex';
  } else {
    taggerPanels.style.display = 'none';
  }
  if (similarityOn) {
    similarityPanel.style.display = 'block';
  } else {
    similarityPanel.style.display = 'none';
  }

  byId('mode-btn-tagger').classList.toggle('active', taggerOn);
  byId('mode-btn-similarity').classList.toggle('active', similarityOn);
}
|
| 805 |
+
|
| 806 |
+
// ---- similarity: staged files ----
|
| 807 |
+
const _simStaged = { a: null, b: null };
|
| 808 |
+
|
| 809 |
+
/**
 * Stage a local file for one side of the similarity comparison and preview it.
 *
 * Stores the file in `_simStaged[side]`, clears that side's URL field, and
 * loads the file as a data URL into `#sim-preview-<side>`.
 *
 * @param {string} side   Slot key, 'a' or 'b'.
 * @param {{files: FileList}} input  Object exposing a `files` list (the file
 *                        `<input>` element); only the first file is used.
 */
function simStageFile(side, input) {
  const file = input.files[0];
  if (!file) return;  // nothing selected — keep whatever was staged before
  _simStaged[side] = file;

  // runSimilarity() prefers a non-empty URL over the staged file, so a URL left
  // over from an earlier attempt would silently win over the image the user
  // just picked (and whose preview is shown).  Clear it so the file is used.
  const urlField = document.getElementById(`sim-url-${side}`);
  if (urlField) urlField.value = '';

  const reader = new FileReader();
  reader.onload = e => {
    const img = document.getElementById(`sim-preview-${side}`);
    img.src = e.target.result;
    img.style.display = 'block';  // the preview <img> starts out hidden
  };
  reader.readAsDataURL(file);
}
|
| 821 |
+
|
| 822 |
+
// drag & drop — similarity (shared wiring for the A and B drop zones)
/**
 * Attach drag-and-drop handlers to one similarity drop zone.
 *
 * While a drag hovers over the zone it carries the `drag-over` class; on drop
 * the first dropped file is staged for `side` exactly as if it had been chosen
 * through the file picker.
 *
 * @param {string} dzId  Element id of the drop zone.
 * @param {string} side  Slot key, 'a' or 'b'.
 */
function _wireDrop(dzId, side) {
  const el = document.getElementById(dzId);
  // preventDefault on dragover is what allows the element to receive a drop.
  el.addEventListener('dragover', e => { e.preventDefault(); el.classList.add('drag-over'); });
  el.addEventListener('dragleave', () => el.classList.remove('drag-over'));
  el.addEventListener('drop', e => {
    e.preventDefault();
    el.classList.remove('drag-over');
    // DataTransfer exposes the same `.files` list as a file <input>, so reuse
    // simStageFile instead of duplicating the staging + preview logic here.
    simStageFile(side, e.dataTransfer);
  });
}
|
| 841 |
+
_wireDrop('sim-drop-a', 'a');
|
| 842 |
+
_wireDrop('sim-drop-b', 'b');
|
| 843 |
+
|
| 844 |
+
// placeholder — replaced by module script
|
| 845 |
+
// No-op stub: keeps the Compare button's inline onclick="runSimilarity()" callable
// until the module script assigns the real implementation to window.runSimilarity.
function runSimilarity() {}
|
| 846 |
</script>
|
| 847 |
|
| 848 |
<!-- @gradio/client module: patches runFromUrl / submitFile / runPca on window
|
|
|
|
| 954 |
// rerunCustomPca (defined after the similarity section below): re-run with current colour pickers — re-submits full request (backbone
|
| 955 |
// result is cached by the Gradio queue so subsequent calls are fast if
|
| 956 |
// the same image/max_size is used, but ZeroGPU requires a full round-trip).
|
| 957 |
+
// ---- similarity ----
|
| 958 |
+
|
| 959 |
+
/**
 * Compare image A and image B and render their cosine similarity.
 *
 * For each side a non-empty URL field takes precedence over a staged file.
 * Sends both images to the "/get_similarity" endpoint, then fills in the
 * score value, the progress bar (score mapped from [-1, 1] to 0–100 %), a
 * verbal interpretation and a traffic-light bar colour.  Failures are shown
 * in `#sim-error`; the button and spinner are always restored.
 */
window.runSimilarity = async function() {
  const byId = id => document.getElementById(id);
  const errorEl = byId('sim-error');
  const scoreCard = byId('sim-score-card');
  const runBtn = byId('sim-run-btn');
  const spinner = byId('sim-spinner');

  const showError = msg => {
    errorEl.textContent = msg;
    errorEl.style.display = 'block';
  };

  const urlA = byId('sim-url-a').value.trim();
  const urlB = byId('sim-url-b').value.trim();

  // Hide any previous result up front so a stale score is never displayed
  // next to a validation error or while a new request is in flight.
  scoreCard.style.display = 'none';

  if (!(urlA || _simStaged.a) || !(urlB || _simStaged.b)) {
    showError('Provide both images before comparing.');
    return;
  }

  errorEl.style.display = 'none';
  runBtn.disabled = true;
  spinner.style.display = 'block';

  try {
    // Wrapped inside the try so a failure while preparing a file surfaces in
    // the UI instead of becoming an unhandled rejection.
    const imageArgA = urlA ? urlA : await handle_file(_simStaged.a);
    const imageArgB = urlB ? urlB : await handle_file(_simStaged.b);

    const res = await gradioApp.predict("/get_similarity", {
      image_a: imageArgA,
      image_b: imageArgB,
      max_size: 1024,
    });
    const data = JSON.parse(res.data[0]);
    const score = data.score;                                  // [-1, 1]
    if (typeof score !== 'number' || !Number.isFinite(score)) {
      throw new Error('Server response did not contain a numeric similarity score.');
    }
    // Map [-1, 1] to [0, 100]; clamp in case rounding pushes the cosine
    // marginally outside its nominal range.
    const pct = Math.min(100, Math.max(0, Math.round(((score + 1) / 2) * 100)));

    byId('sim-score-value').textContent = score.toFixed(4);

    const barEl = byId('sim-score-bar');
    barEl.style.width = pct + '%';
    // colour the bar by score
    barEl.style.background =
      score > 0.7 ? '#4ade80' :
      score > 0.4 ? '#facc15' : '#f87171';

    byId('sim-score-note').textContent =
      score > 0.9 ? 'Very high similarity — nearly identical semantic content.' :
      score > 0.7 ? 'High similarity — strongly related images.' :
      score > 0.5 ? 'Moderate similarity — some shared features.' :
      score > 0.2 ? 'Low similarity — loosely related.' :
                    'Very low similarity — likely unrelated images.';

    scoreCard.style.display = 'block';
  } catch (err) {
    // NOTE(review): the Gradio client may reject with a plain status object
    // rather than an Error; String() on that yields "[object Object]", so
    // prefer its `message` field when present — confirm against client docs.
    const detail =
      err instanceof Error ? String(err) :
      (err && err.message) ? err.message : String(err);
    showError(detail);
  } finally {
    runBtn.disabled = false;
    spinner.style.display = 'none';
  }
};
|
| 1010 |
+
|
| 1011 |
window.rerunCustomPca = async function() {
|
| 1012 |
if (!_lastPcaRequest) return;
|
| 1013 |
const spinner = document.getElementById('pca-spinner');
|