# CIMP Style Transfer Generator (ResNet-18 CIMP, crop 512, batch 16)
A metadata-conditioned style-transfer generator that translates HAADF-STEM images between acquisition settings. Given an input image $x$ and a pair of CIMP metadata embeddings $(e_{\text{id}}, e_{\text{tgt}})$, the network produces an image that preserves the content of $x$ but matches the style associated with $e_{\text{tgt}}$.
The conditioning embeddings come from the frozen Stemson-AI/cmmp-resnet18-512 CIMP encoder. Training used four objectives: LSGAN adversarial, LPIPS cycle-consistency, LPIPS identity, and a CIMP-space embedding-alignment term.
## What's new vs. the previous upload
- `batch_size` increased from 8 to 16.
- `lambda_cycle` and `lambda_id` increased from 0.1 to 0.125 to put slightly more weight on content preservation.
## Architecture
- Generator: `StyleUNet` with FiLM conditioning at every convolutional block. `base_filters=32`, `embed_dim=256` (concatenated 128-d target + identity CIMP embeddings). ~9.9M params.
- Discriminator: `NoisePatchGAN`, `base_filters=64`, `meta_embed_dim=128`. ~0.66M params.
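FiLM conditioning modulates each block's feature maps with a per-channel scale and shift predicted from the conditioning embedding. A minimal sketch of such a layer (the class and method names here are illustrative assumptions, not the actual `StyleUNet` internals):

```python
import torch
import torch.nn as nn


class FiLM(nn.Module):
    """Hypothetical FiLM layer: project an embedding to per-channel
    (gamma, beta) and apply h * (1 + gamma) + beta to a feature map."""

    def __init__(self, embed_dim: int, num_channels: int):
        super().__init__()
        self.proj = nn.Linear(embed_dim, 2 * num_channels)

    def forward(self, h: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        gamma, beta = self.proj(e).chunk(2, dim=-1)  # each (B, C)
        return h * (1 + gamma[..., None, None]) + beta[..., None, None]


# Toy check: a 32-channel feature map modulated by a 256-d embedding,
# matching base_filters=32 and embed_dim=256 from the card.
film = FiLM(embed_dim=256, num_channels=32)
h = torch.randn(2, 32, 64, 64)
e = torch.randn(2, 256)
out = film(h, e)
```

The `1 + gamma` form keeps the layer close to identity at initialization, a common choice for stable conditioning.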
## Training Configuration
| Parameter | Value |
|---|---|
| CIMP encoder | Stemson-AI/cmmp-resnet18-512 (frozen) |
| Crop size | $512 \times 512$ |
| Batch size | 16 |
| Epochs | 250 |
| Optimizer | Adam, $\text{lr}_G = \text{lr}_D = 2 \cdot 10^{-4}$ |
| Adversarial | LSGAN, $\mathcal{L}_{\text{LSGAN}}$ |
| Cycle | LPIPS, $\lambda_1 = 0.125$ |
| Identity | LPIPS, $\lambda_2 = 0.125$ |
| Emb alignment | MSE in CIMP visual-embedding space, $\lambda_3 = 0.5$ |
| Hardware | 1 $\times$ H100 |
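Assuming the standard weighted-sum formulation, the four objectives above would combine into a generator loss of the form (a sketch; the exact bookkeeping of the GAN term is not specified in this card):

```latex
\mathcal{L}_G
  = \mathcal{L}_{\text{LSGAN}}
  + \lambda_1\, \mathcal{L}_{\text{cycle}}^{\text{LPIPS}}
  + \lambda_2\, \mathcal{L}_{\text{id}}^{\text{LPIPS}}
  + \lambda_3\, \mathcal{L}_{\text{emb}}^{\text{MSE}},
\qquad \lambda_1 = \lambda_2 = 0.125,\quad \lambda_3 = 0.5
```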
## Files
- `generator.pth` - inference-ready state dict for the `StyleUNet` generator.
- `full_checkpoint.pth` - dict with both `generator` and `discriminator` state dicts, for resuming training.
- `config.json` - architecture + training hyperparameters.
- `training_log.csv` - per-epoch training / validation losses.
## Usage
```python
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from models import CMMP, StyleUNet

device = "cuda"

cimp = CMMP(
    meta_input_dim=7, embed_dim=128,
    image_encoder="resnet18", image_size=512,
    meta_hidden_dim=256, meta_num_layers=3,
).to(device)
cimp.load_state_dict(torch.load(hf_hub_download("Stemson-AI/cmmp-resnet18-512", "model.pth"),
                                map_location=device))
cimp.eval()

gen = StyleUNet(embed_dim=256, in_channels=1, out_channels=1, base_filters=32, use_FiLM=True).to(device)
gen.load_state_dict(torch.load(hf_hub_download("Stemson-AI/cimp-style-transfer-512", "generator.pth"),
                               map_location=device))
gen.eval()

# x: (1, 1, 512, 512) in [0, 1]; src_meta, tgt_meta: (1, 7) z-scored
with torch.no_grad():
    e_id = F.normalize(cimp.meta(src_meta.to(device)), p=2, dim=-1)
    e_tgt = F.normalize(cimp.meta(tgt_meta.to(device)), p=2, dim=-1)
    y = gen(x.to(device), e_tgt, e_id)
```
## Related Models
- Stemson-AI/cmmp-resnet18-512 - the CIMP encoder.
- Stemson-AI/cmmp-resnet18-256 - earlier CIMP variant.
## Citation
```bibtex
@misc{cimp2026,
  title={Contrastive Image-Metadata Pre-training for Materials Transmission Electron Microscopy},
  author={Channing, Georgia and Keller, Debora and Rossell, Marta D. and Torr, Philip and Erni, Rolf and Helveg, Stig and Eliasson, Henrik},
  year={2026},
}
```