CLIP-CAE: Interpretable Composition Attribution Enhancement for Visio-linguistic Compositional Understanding

This repository contains the official model weights for the EMNLP 2024 paper: "Interpretable Composition Attribution Enhancement for Visio-linguistic Compositional Understanding".

📌 Abstract

Contrastively trained vision-language models (e.g., CLIP) often struggle with compositional reasoning, failing to distinguish fine-grained relations or attributes. We identify the root cause as the composition attribution issue, where models assign significantly lower attribution scores to relation and attribute terms compared to object terms.
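To make the attribution gap concrete, here is a minimal sketch that averages per-token attribution scores by term type. The function name `mean_attribution_by_role`, the role labels, and the numbers are all hypothetical and chosen only for illustration. They are not results from the paper, and this is not the attribution method used to train the released checkpoints.

```python
from collections import defaultdict


def mean_attribution_by_role(scores, roles):
    """Average per-token attribution scores within each term type."""
    buckets = defaultdict(list)
    for score, role in zip(scores, roles):
        buckets[role].append(score)
    return {role: sum(vals) / len(vals) for role, vals in buckets.items()}


# Made-up scores for the caption "a black dog chasing a cat" (illustration only).
tokens = ["black", "dog", "chasing", "cat"]
roles = ["attribute", "object", "relation", "object"]
scores = [0.10, 0.40, 0.10, 0.40]

means = mean_attribution_by_role(scores, roles)
for role, value in means.items():
    print(f"{role}: {value:.2f}")
```

In this toy input the object terms receive most of the attribution, which is the kind of imbalance the abstract calls the composition attribution issue.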

CAE (Composition Attribution Enhancement) is a generic framework designed to rectify text attribution by shifting the model's focus beyond isolated object terms toward essential compositional concepts. Our approach significantly enhances the model's ability to discern intricate visual and linguistic details across seven benchmarks.

📂 Model Variants

We provide four model variants based on different attribution methods used during training. All models use CLIP-ViT-B/32 as the backbone.

| Model Variant | Filename | Attribution Method |
| --- | --- | --- |
| CLIP-CAE (Attention-Based) | CLIP-CAE-AB.pt | Internal attention weight. |
| CLIP-CAE (GradCAM-Based) | CLIP-CAE-GCB.pt | GradCAM score. |
| CLIP-CAE (Perturbation-Based) | CLIP-CAE-PB.pt | Input perturbation. |
| CLIP-CAE (Gradient-Based) | CLIP-CAE-GB.pt | Gradients. |
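If you switch between variants in a script, a small lookup can keep the filenames in one place. This is a convenience sketch only: the dictionary keys and the `checkpoint_path` helper are hypothetical names, not part of the repository, and the filenames are the ones listed above.

```python
from pathlib import Path

# Short keys (hypothetical) mapped to the checkpoint filenames listed above.
CAE_CHECKPOINTS = {
    "attention": "CLIP-CAE-AB.pt",
    "gradcam": "CLIP-CAE-GCB.pt",
    "perturbation": "CLIP-CAE-PB.pt",
    "gradient": "CLIP-CAE-GB.pt",
}


def checkpoint_path(variant, root="."):
    """Return the local path of a variant's checkpoint under `root`."""
    try:
        filename = CAE_CHECKPOINTS[variant]
    except KeyError:
        known = ", ".join(sorted(CAE_CHECKPOINTS))
        raise ValueError(f"Unknown variant {variant!r}; expected one of: {known}") from None
    return Path(root) / filename
```

For example, `checkpoint_path("gradcam", "weights")` points at `weights/CLIP-CAE-GCB.pt`, assuming you downloaded the files into a local `weights` directory.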

💻 Usage

You can load these checkpoints using the open_clip library.

import open_clip
import torch

# Path to the downloaded .pt file (e.g., 'CLIP-CAE-AB.pt')
pretrained_path = 'path/to/CLIP-CAE-AB.pt'
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create model and load the CAE weights
model, _, image_preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', 
    pretrained=pretrained_path, 
    device=device
)
model = model.eval()

print("CLIP-CAE model loaded successfully!")
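
Once the model is loaded, a typical compositional check is to score one image against captions that differ only in a relation or attribute. The sketch below assumes the `model`, `image_preprocess`, and `device` objects from the loading snippet above. The helper name `rank_captions` is hypothetical, and the fixed scale of 100.0 follows the zero-shot example in the open_clip README rather than anything specific to CAE.

```python
import torch


def rank_captions(model, tokenizer, image_tensor, captions, device="cpu"):
    """Return softmax probabilities over `captions` for one preprocessed image."""
    text_tokens = tokenizer(captions).to(device)
    with torch.no_grad():
        # Add a batch dimension to the single preprocessed image.
        image_features = model.encode_image(image_tensor.unsqueeze(0).to(device))
        text_features = model.encode_text(text_tokens)
        # L2-normalise so the dot product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits = 100.0 * image_features @ text_features.T
    return logits.softmax(dim=-1).squeeze(0).tolist()


# Example call (requires open_clip, Pillow, and a local image file):
# from PIL import Image
# tokenizer = open_clip.get_tokenizer('ViT-B-32')
# image = image_preprocess(Image.open('path/to/image.jpg'))
# captions = ["a dog chasing a cat", "a cat chasing a dog"]
# print(rank_captions(model, tokenizer, image, captions, device))
```

The function only relies on `encode_image` and `encode_text`, so the same call should work unchanged for any of the four checkpoints.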