CIMP Style Transfer Generator (ResNet-18 CIMP, crop 512, batch 16)

A metadata-conditioned style-transfer generator that translates HAADF-STEM images between acquisition settings. Given an input image $x$ and a pair of CIMP metadata embeddings $(e_{\text{id}}, e_{\text{tgt}})$, where $e_{\text{id}}$ encodes the acquisition settings of $x$ and $e_{\text{tgt}}$ encodes the desired target settings, the network produces an image that preserves the content of $x$ but matches the style associated with $e_{\text{tgt}}$.

The conditioning embeddings come from the frozen CIMP encoder Stemson-AI/cmmp-resnet18-512. Training used four objectives: an LSGAN adversarial loss, an LPIPS cycle-consistency loss, an LPIPS identity loss, and an embedding-alignment term in CIMP space.
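The four objectives can be written out as follows. This is a sketch under stated assumptions, not the training code: $G(x, e_{\text{tgt}}, e_{\text{id}})$ takes its arguments in the order used in the usage example, $\phi_V$ is a stand-in symbol for the frozen CIMP image encoder, the cycle pass translates $y$ back by swapping the two embeddings, and the terms are combined as a weighted sum with the adversarial weight fixed at 1 and $\lambda_1, \lambda_2, \lambda_3$ as listed in the training configuration.

$$
\begin{aligned}
y &= G(x,\, e_{\text{tgt}},\, e_{\text{id}}) \\
\mathcal{L}_{\text{cyc}} &= \mathrm{LPIPS}\big(x,\; G(y,\, e_{\text{id}},\, e_{\text{tgt}})\big) \\
\mathcal{L}_{\text{id}} &= \mathrm{LPIPS}\big(x,\; G(x,\, e_{\text{id}},\, e_{\text{id}})\big) \\
\mathcal{L}_{\text{emb}} &= \mathrm{MSE}\big(\phi_V(y),\; e_{\text{tgt}}\big) \\
\mathcal{L}_G &= \mathcal{L}_{\text{LSGAN}} + \lambda_1 \mathcal{L}_{\text{cyc}} + \lambda_2 \mathcal{L}_{\text{id}} + \lambda_3 \mathcal{L}_{\text{emb}}
\end{aligned}
$$

The identity term asks the generator to leave $x$ unchanged when the target settings equal the source settings, which is why it complements the cycle term in protecting content.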

What's new vs. the previous upload

  • batch_size increased from 8 to 16.
  • lambda_cycle and lambda_id ($\lambda_1$ and $\lambda_2$ in the training configuration) increased from 0.1 to 0.125 to put slightly more weight on content preservation.

Architecture

  • Generator: StyleUNet with FiLM conditioning at every convolutional block. base_filters=32, embed_dim=256 (concatenated 128-d target + identity CIMP embeddings). ~9.9M params.
  • Discriminator: NoisePatchGAN, base_filters=64, meta_embed_dim=128. ~0.66M params.
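FiLM (feature-wise linear modulation) conditions a convolutional block by predicting a per-channel scale $\gamma$ and shift $\beta$ from the conditioning vector and applying $\gamma \odot h + \beta$ to the feature map $h$. Below is a minimal PyTorch sketch of one such block. The class name `FiLMConvBlock`, the conv / InstanceNorm / LeakyReLU layout, and the target-then-identity concatenation order are assumptions for illustration, not the StyleUNet source.

```python
import torch
import torch.nn as nn


class FiLMConvBlock(nn.Module):
    """Hypothetical conv block with FiLM conditioning (not the StyleUNet source)."""

    def __init__(self, in_ch: int, out_ch: int, embed_dim: int = 256):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.InstanceNorm2d(out_ch)
        # One linear layer predicts a per-channel scale (gamma) and shift (beta)
        self.film = nn.Linear(embed_dim, 2 * out_ch)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        h = self.norm(self.conv(h))
        gamma, beta = self.film(cond).chunk(2, dim=-1)  # each (B, out_ch)
        # Broadcast over spatial dims: (B, C) -> (B, C, 1, 1)
        h = gamma[:, :, None, None] * h + beta[:, :, None, None]
        return self.act(h)


# Conditioning vector: target and identity embeddings concatenated (128 + 128 = 256);
# the order is an assumption
e_tgt = torch.randn(2, 128)
e_id = torch.randn(2, 128)
cond = torch.cat([e_tgt, e_id], dim=-1)

block = FiLMConvBlock(in_ch=1, out_ch=32, embed_dim=256)
out = block(torch.rand(2, 1, 64, 64), cond)
print(out.shape)  # torch.Size([2, 32, 64, 64])
```

Because the same conditioning vector is injected at every block, the style signal reaches every resolution of the U-Net rather than only the bottleneck.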

Training Configuration

| Parameter | Value |
|---|---|
| CIMP encoder | Stemson-AI/cmmp-resnet18-512 (frozen) |
| Crop size | $512 \times 512$ |
| Batch size | 16 |
| Epochs | 250 |
| Optimizer | Adam, $\text{lr}_G = \text{lr}_D = 2 \cdot 10^{-4}$ |
| Adversarial loss | LSGAN ($\mathcal{L}_{\text{LSGAN}}$) |
| Cycle loss | LPIPS, $\lambda_1 = 0.125$ |
| Identity loss | LPIPS, $\lambda_2 = 0.125$ |
| Embedding alignment | MSE in CIMP visual-embedding space, $\lambda_3 = 0.5$ |
| Hardware | 1 $\times$ H100 |
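A minimal PyTorch sketch of the LSGAN losses and a weighted generator objective. It assumes the standard least-squares targets (1 for real, 0 for fake, with the conventional factor of 1/2) and a plain weighted sum with the adversarial weight fixed at 1. The function names are hypothetical and not taken from the training code; the default weights are the $\lambda$ values in the table above.

```python
import torch


def lsgan_d_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # Discriminator: push patch scores on real images toward 1, on generated toward 0
    return 0.5 * (((d_real - 1) ** 2).mean() + (d_fake ** 2).mean())


def lsgan_g_loss(d_fake: torch.Tensor) -> torch.Tensor:
    # Generator: push the discriminator's patch scores on generated images toward 1
    return 0.5 * ((d_fake - 1) ** 2).mean()


def generator_objective(adv, cyc, idt, emb,
                        lambda_cycle=0.125, lambda_id=0.125, lambda_emb=0.5):
    # Weighted sum of the four terms; adversarial weight fixed at 1 (assumption)
    return adv + lambda_cycle * cyc + lambda_id * idt + lambda_emb * emb


# Illustrative call with a random PatchGAN-style score map and placeholder loss values
d_fake = torch.rand(4, 1, 62, 62)
total = generator_objective(lsgan_g_loss(d_fake),
                            cyc=torch.tensor(0.3), idt=torch.tensor(0.1), emb=torch.tensor(0.05))
```

With $\lambda_3 = 0.5$ the embedding-alignment term carries four times the weight of either LPIPS term, so the CIMP-space target match is the dominant non-adversarial signal under this weighting.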

Files

  • generator.pth - inference-ready state dict for the StyleUNet generator.
  • full_checkpoint.pth - dict with both generator and discriminator state dicts, for resuming training.
  • config.json - architecture + training hyperparameters.
  • training_log.csv - per-epoch training / validation losses.
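To pick an epoch from training_log.csv, a small standard-library helper is enough. This is a sketch only: the column names `epoch` and `val_loss` are hypothetical placeholders (check the file's header row for the real names), and the inline log below contains toy numbers, not values from this model's training.

```python
import csv
import io


def best_epoch(csv_text: str, column: str, epoch_col: str = "epoch"):
    """Return (epoch, value) for the row with the lowest value in `column`.

    Column names are hypothetical; adapt them to the header of training_log.csv.
    """
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    best = min(rows, key=lambda r: float(r[column]))
    return int(best[epoch_col]), float(best[column])


# Toy log in the assumed format, for illustration only
log = "epoch,train_loss,val_loss\n1,0.90,0.80\n2,0.70,0.65\n3,0.60,0.67\n"
print(best_epoch(log, "val_loss"))  # (2, 0.65)
```

For the real file, pass `open("training_log.csv").read()` as `csv_text`.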

Usage

```python
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from models import CMMP, StyleUNet  # CMMP and StyleUNet come from a local `models` module

device = "cuda" if torch.cuda.is_available() else "cpu"

# Frozen CIMP encoder that supplies the metadata embeddings
cimp = CMMP(
    meta_input_dim=7, embed_dim=128,
    image_encoder="resnet18", image_size=512,
    meta_hidden_dim=256, meta_num_layers=3,
).to(device)
cimp.load_state_dict(torch.load(hf_hub_download("Stemson-AI/cmmp-resnet18-512", "model.pth"),
                                map_location=device))
cimp.eval()

# Style-transfer generator; embed_dim=256 = 128-d target + 128-d identity embedding
gen = StyleUNet(embed_dim=256, in_channels=1, out_channels=1, base_filters=32, use_FiLM=True).to(device)
gen.load_state_dict(torch.load(hf_hub_download("Stemson-AI/cimp-style-transfer-512", "generator.pth"),
                               map_location=device))
gen.eval()

# x: (1, 1, 512, 512) in [0, 1]; src_meta, tgt_meta: (1, 7) z-scored
with torch.no_grad():
    # L2-normalized embeddings of the source (identity) and target acquisition settings
    e_id  = F.normalize(cimp.meta(src_meta.to(device)), p=2, dim=-1)
    e_tgt = F.normalize(cimp.meta(tgt_meta.to(device)), p=2, dim=-1)
    # Argument order: image, target embedding, identity embedding
    y = gen(x.to(device), e_tgt, e_id)
```
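The usage snippet assumes `x`, `src_meta`, and `tgt_meta` already exist. One plausible way to build them is sketched below. Min-max scaling to [0, 1] is an assumption about the preprocessing, and the per-field mean/std used for z-scoring are hypothetical placeholders: the card does not publish the training statistics or the order of the seven metadata fields.

```python
import torch


def to_unit_range(img: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Min-max scale one HAADF image (H, W) to [0, 1], then add batch/channel dims
    img = img.float()
    img = (img - img.min()) / (img.max() - img.min() + eps)
    return img[None, None]  # (1, 1, H, W)


def zscore(meta, mean, std) -> torch.Tensor:
    # Standardize a 7-d metadata vector with per-field statistics -> (1, 7)
    meta = torch.as_tensor(meta, dtype=torch.float32)
    mean = torch.as_tensor(mean, dtype=torch.float32)
    std = torch.as_tensor(std, dtype=torch.float32)
    return ((meta - mean) / std)[None]


# Hypothetical placeholders: substitute the real per-field training statistics
MEAN = [0.0] * 7
STD = [1.0] * 7

x = to_unit_range(torch.rand(512, 512) * 4000)  # stand-in for a raw HAADF frame
src_meta = zscore(torch.rand(7), MEAN, STD)      # settings of the input image
tgt_meta = zscore(torch.rand(7), MEAN, STD)      # settings to translate to
```

Whatever statistics are used here must match the ones the CIMP metadata encoder saw during training; otherwise the embeddings will not correspond to the intended settings.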


Citation

```bibtex
@misc{cimp2026,
  title={Contrastive Image-Metadata Pre-training for Materials Transmission Electron Microscopy},
  author={Channing, Georgia and Keller, Debora and Rossell, Marta D. and Torr, Philip and Erni, Rolf and Helveg, Stig and Eliasson, Henrik},
  year={2026},
}
```