# Framer / app.py
# multimodalart's picture
# multimodalart HF Staff
# [Admin maintenance] Support new ZeroGPU hardware (#6)
# 32096a3
import datetime
import os
import sys
import uuid
import warnings
import spaces
# Stub huggingface_hub symbols removed in 0.26+ that diffusers 0.24 still imports.
# We keep diffusers 0.24 (the vendored models_diffusers/ uses APIs removed in 0.30+),
# but gradio 5.49.1 requires huggingface-hub >= 0.33.5 which dropped HfFolder/cached_download/hf_cache_home.
# NOTE: huggingface_hub uses lazy __getattr__, so `hasattr` returns False even when the attribute
# is "exposed" lazily. We probe with attribute access in a try/except instead.
import huggingface_hub as _hf_hub
import huggingface_hub.constants as _hf_hub_constants

# Shim 1: alias `hf_cache_home` to `HF_HOME` when the constant is missing.
try:
    _hf_hub_constants.hf_cache_home
except AttributeError:
    _hf_hub_constants.hf_cache_home = _hf_hub_constants.HF_HOME

# Shim 2: install a placeholder `cached_download` that only fails if it is actually called.
try:
    _hf_hub.cached_download
except AttributeError:

    def _cached_download_removed(*a, **k):
        """Placeholder for the removed ``cached_download``; always raises NotImplementedError."""
        raise NotImplementedError(
            "huggingface_hub.cached_download was removed in 0.26. "
            "This stub satisfies diffusers 0.24's import; "
            "the function is only used for GitHub community-pipeline downloads which Framer does not invoke."
        )

    _hf_hub.cached_download = _cached_download_removed

# Shim 3: provide a minimal `HfFolder` replacement when the class is missing.
try:
    _hf_hub.HfFolder
except AttributeError:

    class _HfFolderShim:
        """Minimal stand-in for ``HfFolder`` exposing ``get_token`` and ``save_token``."""

        @staticmethod
        def get_token():
            """Return ``huggingface_hub.get_token()``, or None if that call raises."""
            # Best-effort lookup: any failure is reported as "no token".
            try:
                return _hf_hub.get_token()
            except Exception:
                return None

        @staticmethod
        def save_token(token):
            """Accept and discard ``token`` (no-op)."""
            pass

    _hf_hub.HfFolder = _HfFolderShim
import cv2
import gradio as gr
import numpy as np
import torch
import torchvision
from huggingface_hub import snapshot_download
from PIL import Image
from scipy.interpolate import PchipInterpolator
sys.path.insert(0, os.getcwd())
from gradio_demo.utils_drag import *
from models_diffusers.controlnet_svd import ControlNetSVDModel
from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from pipelines.pipeline_stable_video_diffusion_interp_control import StableVideoDiffusionInterpControlPipeline
print("gr file", gr.__file__)

# --- Checkpoint download -------------------------------------------------
os.makedirs("checkpoints", exist_ok=True)
# Read the access token without raising KeyError when TOKEN is unset; passing
# None lets huggingface_hub fall back to its own default token resolution.
_hf_token = os.environ.get("TOKEN")
snapshot_download(
    "wwen1997/framer_512x320",
    local_dir="checkpoints/framer_512x320",
    token=_hf_token,
)
snapshot_download(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    local_dir="checkpoints/stable-video-diffusion-img2vid-xt",
    token=_hf_token,
)

# --- Module-level configuration ------------------------------------------
model_id = "checkpoints/framer_512x320"
device = "cuda"
dtype = torch.float16
OUTPUT_DIR = "gradio_demo/outputs"
HEIGHT = 320
WIDTH = 512
MODEL_LENGTH = 14
USE_SIFT = False

# --- Model loading -------------------------------------------------------
# The UNet and ControlNet come from the Framer checkpoint; the remaining
# pipeline components are loaded from the local SVD checkpoint directory.
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    os.path.join(model_id, "unet"),
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    custom_resume=True,
)
unet = unet.to(device, dtype)
controlnet = ControlNetSVDModel.from_pretrained(
    os.path.join(model_id, "controlnet"),
)
controlnet = controlnet.to(device, dtype)
pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
    "checkpoints/stable-video-diffusion-img2vid-xt",
    unet=unet,
    controlnet=controlnet,
    low_cpu_mem_usage=False,
    torch_dtype=dtype,
    variant="fp16",
    local_files_only=True,
)
pipe.to(device)
def interpolate_trajectory(points, n_points):
    """Resample a polyline to ``n_points`` samples using PCHIP interpolation.

    The input points are placed at evenly spaced parameter values on [0, 1];
    x and y are interpolated independently over that parameter.

    Args:
        points: Sequence of points; element 0 is read as x and element 1 as y.
        n_points: Number of samples to generate along the path.

    Returns:
        A list of ``n_points`` ``(x, y)`` tuples.
    """
    xs, ys = [], []
    for pt in points:
        xs.append(pt[0])
        ys.append(pt[1])

    knots = np.linspace(0, 1, len(points))
    samples = np.linspace(0, 1, n_points)

    interp_x = PchipInterpolator(knots, xs)(samples)
    interp_y = PchipInterpolator(knots, ys)(samples)
    return [(px, py) for px, py in zip(interp_x, interp_y)]
def gen_gaussian_heatmap(imgSize=200, sigma=40):
    """Build a square uint8 heatmap: an isotropic Gaussian clipped to a filled disk.

    Args:
        imgSize: Side length of the output in pixels.
        sigma: Standard deviation of the Gaussian in pixels (previously the
            hard-coded constant 40).

    Returns:
        ``(imgSize, imgSize)`` uint8 array whose maximum value is 255.
    """
    # Filled disk of radius imgSize // 2 centred in the image; cv2.circle
    # draws in place, so `disk` holds 1 inside the disk and 0 elsewhere.
    center = imgSize // 2
    disk = np.zeros((imgSize, imgSize), np.float32)
    cv2.circle(disk, (center, center), center, 1, -1)

    # The exponent is separable in row/column offsets, so evaluate it once per
    # axis and broadcast instead of looping over every pixel in Python.
    offsets_sq = (np.arange(imgSize) - imgSize / 2) ** 2 / (sigma**2)
    gaussian = 1 / 2 / np.pi / (sigma**2) * np.exp(-1 / 2 * (offsets_sq[:, None] + offsets_sq[None, :]))

    heatmap = gaussian.astype(np.float32) * disk
    # Normalise to [0, 1]; a second division by the (now unit) maximum is unnecessary.
    heatmap = (heatmap / np.max(heatmap)).astype(np.float32)
    return (heatmap * 255).astype(np.uint8)
def get_vis_image(
    target_size=(512, 512),
    points=None,
    side=20,
    num_frames=14,
):
    """Render one visualisation frame per time step for a set of point tracks.

    Each frame shows a Gaussian blob at every track's current position plus a
    polyline through the positions visited before the current step.

    Args:
        target_size: Shape passed to ``np.zeros`` for each frame; index 0 is
            used as the row (y) limit and index 1 as the column (x) limit.
        points: Iterable of tracks, each an iterable of points whose element 0
            is x and element 1 is y. ``None`` or an empty iterable yields
            ``num_frames`` blank frames. Every track is indexed with the step
            index of the first track, so tracks are expected to be at least as
            long as the rendered range.
        side: Half-size in pixels of the blob window around each point
            (capped at 50).
        num_frames: Maximum number of frames to render.

    Returns:
        List of PIL images.
    """
    trajectory_list = []
    if points is not None:
        trajectory_list = [[[int(p[0]), int(p[1])] for p in track] for track in points]

    if not trajectory_list:
        return [Image.fromarray(np.zeros(target_size, np.uint8)) for _ in range(num_frames)]

    heatmap = gen_gaussian_heatmap()
    half_side = min(side, 50)
    height, width = target_size[0], target_size[1]
    n_steps = min(len(trajectory_list[0]), num_frames)

    vis_images = []
    for step in range(n_steps):
        vis_img = np.zeros(target_size, np.uint8)
        for trajectory in trajectory_list:
            cx, cy = trajectory[step]

            # Blob window around the current position, clipped to the frame.
            y1 = max(cy - half_side, 0)
            y2 = min(cy + half_side, height - 1)
            x1 = max(cx - half_side, 0)
            x2 = min(cx + half_side, width - 1)
            if x2 - x1 > 3 and y2 - y1 > 3:
                vis_img[y1:y2, x1:x2] = cv2.resize(heatmap, (x2 - x1, y2 - y1))

            # Positions visited strictly before the current step.
            history = trajectory[:step]
            if len(history) == 1:
                hx, hy = history[0]
                # Guard the direct pixel write: a negative index would wrap
                # around and an index past the edge would raise.
                if 0 <= hx < width and 0 <= hy < height:
                    vis_img[hy, hx] = 255
            else:
                for (ax, ay), (bx, by) in zip(history, history[1:]):
                    cv2.line(vis_img, (ax, ay), (bx, by), (255, 255, 255), 3)

        # Ensure the frame handed to PIL has three channels.
        if vis_img.ndim == 2:
            vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2RGB)
        elif vis_img.ndim == 3 and vis_img.shape[2] == 3:
            vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
        vis_images.append(Image.fromarray(vis_img))
    return vis_images
def frames_to_video(frames_folder, output_video_path, fps=7):
    """Assemble the image files in ``frames_folder`` into a video file.

    Files are ordered by the integer before the first ``.`` in their name,
    read with torchvision, stacked, and written with ``torchvision.io.write_video``.

    Args:
        frames_folder: Directory containing the frame images.
        output_video_path: Destination path of the video.
        fps: Frames per second of the written video.
    """

    def frame_number(filename):
        return int(filename.split(".")[0])

    ordered_names = sorted(os.listdir(frames_folder), key=frame_number)
    frames = [torchvision.io.read_image(os.path.join(frames_folder, name)) for name in ordered_names]
    # write_video receives the clip channel-last.
    clip = rearrange(torch.stack(frames), "T C H W -> T H W C")
    torchvision.io.write_video(output_video_path, clip, fps=fps)
def save_gifs_side_by_side(
    batch_output,
    validation_control_images,
    output_folder,
    target_size=(512, 512),
    duration=200,
    point_tracks=None,
):
    """Save control frames and generated frames, then combine them side by side.

    For each of the two image lists a temporary GIF, a folder of PNG frames and
    an MP4 are written next to ``output_folder``. The two GIFs are then
    concatenated horizontally into ``output_folder`` and into an extra
    ``*_v.mp4`` file.

    Args:
        batch_output: Generated frames (PIL images or CxHxW tensors).
        validation_control_images: Control/visualisation frames.
        output_folder: Path of the combined output file. Despite the name it is
            a file path; the substring ``"vis_gif.gif"`` is stripped from it to
            obtain the directory for temporary files.
        target_size: Size passed to ``validate_and_convert_image``.
        duration: Per-frame duration used when the combined output is a GIF.
        point_tracks: Optional tensor saved as ``.npy`` next to each temporary
            GIF; skipped when ``None``.

    Returns:
        ``output_folder``, i.e. the path of the combined file.
    """

    def create_gif(image_list, gif_path, duration=100):
        """Write ``image_list`` as a GIF, as numbered PNG frames, and as an MP4."""
        pil_images = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
        pil_images = [img for img in pil_images if img is not None]
        if pil_images:
            pil_images[0].save(gif_path, save_all=True, append_images=pil_images[1:], loop=0, duration=duration)

        # Also dump every frame as <idx>.png into a folder named after the GIF.
        tmp_folder = gif_path.replace(".gif", "")
        print(tmp_folder)
        ensure_dirname(tmp_folder)
        for idx, pil_image in enumerate(pil_images):
            pil_image.save(os.path.join(tmp_folder, f"{idx}.png"))

        # ... and encode those frames as an MP4 beside the GIF.
        output_video_path = gif_path.replace(".gif", ".mp4")
        frames_to_video(tmp_folder, output_video_path, fps=7)

    def get_concat_h(im1, im2, gap=10):
        """Paste ``im1`` and ``im2`` left to right on a white canvas with a gap."""
        dst = Image.new("RGB", (im1.width + im2.width + gap, max(im1.height, im2.height)), (255, 255, 255))
        dst.paste(im1, (0, 0))
        dst.paste(im2, (im1.width + gap, 0))
        return dst

    def combine_gifs_side_by_side(gif_paths, output_path):
        """Concatenate the GIFs frame by frame and write a GIF or an MP4."""
        print(gif_paths)
        gifs = [Image.open(path) for path in gif_paths]
        try:
            # The last GIF drives the frame count; shorter GIFs hold their last frame.
            frames = []
            for frame_idx in range(gifs[-1].n_frames):
                combined_frame = None
                for gif in gifs:
                    gif.seek(min(frame_idx, gif.n_frames - 1))
                    if combined_frame is None:
                        combined_frame = gif.copy()
                    else:
                        combined_frame = get_concat_h(combined_frame, gif.copy(), gap=10)
                frames.append(combined_frame)
        finally:
            # Frames were copied out above, so the source files can be released.
            for gif in gifs:
                gif.close()

        if output_path.endswith(".mp4"):
            video = [torchvision.transforms.functional.pil_to_tensor(frame) for frame in frames]
            video = torch.stack(video)
            video = rearrange(video, "T C H W -> T H W C")
            torchvision.io.write_video(output_path, video, fps=7)
            print(f"Saved video to {output_path}")
        else:
            frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)

    # Create one temporary GIF (plus frames/MP4) per image list.
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tmp_dir = output_folder.replace("vis_gif.gif", "")
    gif_paths = []
    for idx, image_list in enumerate([validation_control_images, batch_output]):
        gif_path = os.path.join(tmp_dir, f"temp_{idx}_{timestamp}.gif")
        create_gif(image_list, gif_path)
        gif_paths.append(gif_path)

        # Persist the point tracks next to each GIF when they were provided.
        if point_tracks is not None:
            np.save(gif_path.replace(".gif", ".npy"), point_tracks.cpu().numpy())

    # Combine into the requested output file.
    combined_gif_path = output_folder
    combine_gifs_side_by_side(gif_paths, combined_gif_path)

    # Additional MP4 copy named after the last temporary GIF.
    combined_gif_path_v = gif_path.replace(".gif", "_v.mp4")
    ensure_dirname(combined_gif_path_v.replace(".mp4", ""))
    combine_gifs_side_by_side(gif_paths, combined_gif_path_v)

    return combined_gif_path
# Define functions
def validate_and_convert_image(image, target_size=(512, 512)):
    """Convert ``image`` to a PIL image, or return None when it is unusable.

    * ``None`` -> None (a message is printed).
    * 3-D tensor with 1 or 3 leading channels -> scaled by 255, clamped,
      converted to uint8 HxWxC and wrapped in a PIL image (a single channel is
      repeated three times). Tensors are not resized.
    * Any other tensor shape -> None (a message is printed).
    * PIL image -> resized to ``target_size``.
    * Anything else -> None (a message is printed).
    """
    if image is None:
        print("Encountered a None image")
        return None

    if isinstance(image, torch.Tensor):
        is_chw = image.ndim == 3 and image.shape[0] in [1, 3]
        if not is_chw:
            print(f"Invalid image tensor shape: {image.shape}")
            return None
        if image.shape[0] == 1:
            # Expand single-channel input to three identical channels.
            image = image.repeat(3, 1, 1)
        pixels = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(pixels)

    if isinstance(image, Image.Image):
        return image.resize(target_size)

    print("Image is not a PIL Image or a PyTorch tensor")
    return None
def reset_states():
    """Return cleared values for the six outputs wired to the Reset button.

    Five ``None`` values followed by a fresh empty list.
    """
    cleared = (None,) * 5
    return cleared + ([],)
def _save_resized_frame(upload_path, prefix):
    """Load an uploaded image, resize it to WIDTH x HEIGHT and save it under OUTPUT_DIR.

    Args:
        upload_path: Path of the uploaded file.
        prefix: Filename prefix of the saved PNG.

    Returns:
        Path of the saved PNG.
    """
    frame = image2pil(upload_path).resize((WIDTH, HEIGHT), Image.BILINEAR)
    # Make sure the target directory exists even when the module was imported
    # rather than executed as a script.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Use the full UUID: a 4-character suffix collides easily, which would let
    # one upload overwrite another's frame.
    frame_path = os.path.join(OUTPUT_DIR, f"{prefix}_{uuid.uuid4().hex}.png")
    frame.save(frame_path)
    return frame_path


def preprocess_image(image):
    """Handle the start-image upload.

    Args:
        image: Uploaded file object; its ``name`` attribute is the file path.

    Returns:
        ``(path, path, [])``: the saved frame path twice, plus an empty list.
    """
    first_frame_path = _save_resized_frame(image.name, "first_frame")
    return first_frame_path, first_frame_path, []


def preprocess_image_end(image_end):
    """Handle the end-image upload.

    Args:
        image_end: Uploaded file object; its ``name`` attribute is the file path.

    Returns:
        ``(path, path, [])``: the saved frame path twice, plus an empty list.
    """
    last_frame_path = _save_resized_frame(image_end.name, "last_frame")
    return last_frame_path, last_frame_path, []
def add_drag(tracking_points):
    """Start a new, empty drag track unless the last track is already empty.

    The list is modified in place and also returned.
    """
    has_open_empty_track = bool(tracking_points) and not tracking_points[-1]
    if not has_open_empty_track:
        tracking_points.append([])
    return tracking_points
def _draw_tracks(tracking_points, width, height):
    """Draw every track onto a transparent RGBA layer.

    A single-point track is drawn as a filled circle. Longer tracks are drawn
    as connected line segments with an arrow head on the final segment. Empty
    tracks are skipped.

    Returns:
        ``(height, width, 4)`` uint8 array.
    """
    red = (255, 0, 0, 255)
    layer = np.zeros((height, width, 4), np.uint8)
    for track in tracking_points:
        if not track:
            continue
        if len(track) == 1:
            cv2.circle(layer, tuple(track[0]), 5, red, -1)
            continue

        last_segment = len(track) - 2
        for i, (start_point, end_point) in enumerate(zip(track, track[1:])):
            arrow_length = np.hypot(end_point[0] - start_point[0], end_point[1] - start_point[1])
            # A zero-length segment (same pixel clicked twice) cannot carry an
            # arrow head: 8 / 0 would produce an infinite tipLength.
            if i == last_segment and arrow_length > 0:
                cv2.arrowedLine(
                    layer,
                    tuple(start_point),
                    tuple(end_point),
                    red,
                    2,
                    tipLength=8 / arrow_length,
                )
            else:
                cv2.line(layer, tuple(start_point), tuple(end_point), red, 2)
    return layer


def _compose_trajectory_maps(tracking_points, first_frame_path, last_frame_path, alpha_coef=1.0):
    """Overlay the drawn tracks on the start and end frames.

    Args:
        tracking_points: List of tracks, each a list of ``[x, y]`` points.
        first_frame_path: Path of the start frame image.
        last_frame_path: Path of the end frame image.
        alpha_coef: Factor applied to the overlay's alpha channel.

    Returns:
        ``(start_overlay, end_overlay)`` as RGBA PIL images.
    """
    with Image.open(first_frame_path) as img:
        background = img.convert("RGBA")
    with Image.open(last_frame_path) as img:
        background_end = img.convert("RGBA")

    w, h = background.size
    layer = _draw_tracks(tracking_points, w, h)
    if alpha_coef != 1.0:
        # Scale the alpha channel in one vectorised step (truncating like int()).
        layer[..., 3] = (layer[..., 3] * alpha_coef).astype(np.uint8)
    overlay = Image.fromarray(layer)

    trajectory_map = Image.alpha_composite(background, overlay)
    trajectory_map_end = Image.alpha_composite(background_end, overlay)
    return trajectory_map, trajectory_map_end


def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
    """Remove the most recent track and redraw both frames.

    Returns:
        ``(tracking_points, start_overlay, end_overlay)``.
    """
    if tracking_points:
        tracking_points.pop()
    trajectory_map, trajectory_map_end = _compose_trajectory_maps(
        tracking_points, first_frame_path, last_frame_path
    )
    return tracking_points, trajectory_map, trajectory_map_end


def delete_last_step(tracking_points, first_frame_path, last_frame_path):
    """Remove the last point of the most recent track and redraw both frames.

    Returns:
        ``(tracking_points, start_overlay, end_overlay)``.
    """
    if tracking_points and tracking_points[-1]:
        tracking_points[-1].pop()
    trajectory_map, trajectory_map_end = _compose_trajectory_maps(
        tracking_points, first_frame_path, last_frame_path
    )
    return tracking_points, trajectory_map, trajectory_map_end


def add_tracking_points(
    tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData
):  # SelectData is a subclass of EventData
    """Append the clicked position to the current track and redraw both frames.

    Args:
        tracking_points: List of tracks; a first track is created when empty.
        first_frame_path: Path of the start frame image.
        last_frame_path: Path of the end frame image.
        evt: Gradio select event; ``evt.index`` is appended as the new point.

    Returns:
        ``(tracking_points, start_overlay, end_overlay)``.
    """
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
    if not tracking_points:
        tracking_points = [[]]
    tracking_points[-1].append(evt.index)

    # The overlay is drawn slightly translucent (alpha scaled by 0.99).
    trajectory_map, trajectory_map_end = _compose_trajectory_maps(
        tracking_points, first_frame_path, last_frame_path, alpha_coef=0.99
    )
    return tracking_points, trajectory_map, trajectory_map_end
@spaces.GPU
def run(
    first_frame_path,
    last_frame_path,
    tracking_points,
    controlnet_cond_scale,
    motion_bucket_id,
    progress=gr.Progress(track_tqdm=True),
):
    """Interpolate between the start and end frames, optionally guided by drag tracks.

    Args:
        first_frame_path: Path of the start frame image.
        last_frame_path: Path of the end frame image.
        tracking_points: List of tracks, each a list of ``[x, y]`` points in the
            coordinate system of the start frame image.
        controlnet_cond_scale: Control strength forwarded to the pipeline
            (overridden to 0.0 when no usable track exists).
        motion_bucket_id: Forwarded to the pipeline unchanged.
        progress: Gradio progress tracker (unused directly; tracks tqdm).

    Returns:
        Path of the combined side-by-side GIF.
    """
    # Load both frames; the track coordinates refer to the start frame as stored on disk.
    image = Image.open(first_frame_path).convert("RGB")
    original_width, original_height = image.size
    image = image.resize((WIDTH, HEIGHT))
    image_end = Image.open(last_frame_path).convert("RGB")
    image_end = image_end.resize((WIDTH, HEIGHT))

    input_all_points = tracking_points or []
    sift_track_update = False
    anchor_points_flag = None

    if (len(input_all_points) == 0) and USE_SIFT:
        sift_track_update = True
        controlnet_cond_scale = 0.5
        from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
        from models_diffusers.sift_match import sift_match

        output_file_sift = os.path.join(OUTPUT_DIR, "sift.png")
        # (f, topk, 2), f=2 (before interpolation)
        pred_tracks = sift_match(
            image,
            image_end,
            thr=0.5,
            topk=5,
            method="random",
            output_path=output_file_sift,
        )
        # NOTE(review): when sift_match returns None, pred_tracks stays None and
        # the shape check below will fail — confirm the intended fallback.
        if pred_tracks is not None:
            # interpolate the tracks, following draganything gradio demo
            pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=MODEL_LENGTH)
            anchor_points_flag = torch.zeros((MODEL_LENGTH, pred_tracks.shape[1])).to(pred_tracks.device)
            anchor_points_flag[0] = 1
            anchor_points_flag[-1] = 1
            pred_tracks = pred_tracks.permute(1, 0, 2)  # (num_points, num_frames, 2)
    else:
        tracks = []
        for raw_track in input_all_points:
            if len(raw_track) == 0:
                # Drop empty tracks: leaving them in would make the tensor
                # construction below fail on ragged input.
                warnings.warn("running without point trajectory control")
                continue
            # Rescale the clicked coordinates to the model resolution.
            track = [
                (int(p[0] * WIDTH / original_width), int(p[1] * HEIGHT / original_height))
                for p in raw_track
            ]
            if len(track) == 1:  # stationary point
                track.append((track[0][0] + 1, track[0][1] + 1))
            # Interpolate the track to one position per generated frame.
            tracks.append(interpolate_trajectory(track, MODEL_LENGTH)[:MODEL_LENGTH])
        pred_tracks = torch.tensor(tracks)  # (num_points, num_frames, 2)

    vis_images = get_vis_image(
        target_size=(HEIGHT, WIDTH),
        points=pred_tracks,
        num_frames=MODEL_LENGTH,
    )

    if len(pred_tracks.shape) != 3:
        # No usable track: run the pipeline without trajectory control.
        print("pred_tracks.shape", pred_tracks.shape)
        with_control = False
        controlnet_cond_scale = 0.0
    else:
        with_control = True
        pred_tracks = pred_tracks.permute(1, 0, 2).to(device, dtype)  # (num_frames, num_points, 2)

    point_embedding = None
    video_frames = pipe(
        image,
        image_end,
        # trajectory control
        with_control=with_control,
        point_tracks=pred_tracks,
        point_embedding=point_embedding,
        with_id_feature=False,
        controlnet_cond_scale=controlnet_cond_scale,
        # others
        num_frames=MODEL_LENGTH,
        # The frames were resized to WIDTH x HEIGHT above, so pass that size.
        width=WIDTH,
        height=HEIGHT,
        motion_bucket_id=motion_bucket_id,
        fps=7,
        num_inference_steps=30,
        # track
        sift_track_update=sift_track_update,
        anchor_points_flag=anchor_points_flag,
    ).frames[0]

    # Colour the control frames with the JET colormap and convert BGR -> RGB.
    vis_images = [
        Image.fromarray(
            cv2.cvtColor(
                cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET),
                cv2.COLOR_BGR2RGB,
            )
        )
        for img in vis_images
    ]

    # Write each request into its own directory so concurrent requests do not
    # overwrite each other's result. The file keeps the name "vis_gif.gif"
    # because save_gifs_side_by_side strips that substring to find its folder.
    run_dir = os.path.join(OUTPUT_DIR, f"run_{uuid.uuid4().hex}")
    os.makedirs(run_dir, exist_ok=True)
    val_save_dir = os.path.join(run_dir, "vis_gif.gif")
    save_gifs_side_by_side(
        video_frames,
        vis_images[:MODEL_LENGTH],
        val_save_dir,
        target_size=(WIDTH, HEIGHT),
        duration=110,
        point_tracks=pred_tracks,
    )
    return val_save_dir
# Script entry point: build the Gradio UI, wire the callbacks, and launch the app.
if __name__ == "__main__":
    ensure_dirname(OUTPUT_DIR)
    # 20 random 4-component colours; read as a module global by add_tracking_points.
    color_list = []
    for i in range(20):
        color = np.concatenate([np.random.random(4) * 255], axis=0)
        color_list.append(color)
    with gr.Blocks() as demo:
        # --- Header, teaser image and usage instructions ---
        gr.Markdown("""<h1 align="center">Framer: Interactive Frame Interpolation</h1><br>""")
        gr.Markdown(
            """Gradio Demo for <a href='https://arxiv.org/abs/2410.18978'><b>Framer: Interactive Frame Interpolation</b></a>.<br>
Github Repo can be found at https://github.com/aim-uofa/Framer<br>
The template is inspired by DragAnything."""
        )
        gr.Image(label="Framer: Interactive Frame Interpolation", value="assets/demos.gif", height=432, width=768)
        gr.Markdown(
            """## Usage: <br>
1. Upload images<br>
&ensp; 1.1 Upload the start image via the "Upload Start Image" button.<br>
&ensp; 1.2. Upload the end image via the "Upload End Image" button.<br>
2. (Optional) Draw some drags.<br>
&ensp; 2.1. Click "Add Drag Trajectory" to add the motion trajectory.<br>
&ensp; 2.2. You can click several points on either start or end image to forms a path.<br>
&ensp; 2.3. Click "Delete last drag" to delete the whole lastest path.<br>
&ensp; 2.4. Click "Delete last step" to delete the lastest clicked control point.<br>
3. Interpolate the images (according the path) with a click on "Run" button. <br>"""
        )
        # --- Session state: saved frame paths and the list of drag tracks ---
        first_frame_path = gr.State()
        last_frame_path = gr.State()
        tracking_points = gr.State([])
        # --- Controls (left) and the two clickable frame views (right) ---
        with gr.Row():
            with gr.Column(scale=1):
                image_upload_button = gr.UploadButton(label="Upload Start Image", file_types=["image"])
                image_end_upload_button = gr.UploadButton(label="Upload End Image", file_types=["image"])
                # select_area_button = gr.Button(value="Select Area with SAM")
                add_drag_button = gr.Button(value="Add New Drag Trajectory")
                reset_button = gr.Button(value="Reset")
                run_button = gr.Button(value="Run")
                delete_last_drag_button = gr.Button(value="Delete last drag")
                delete_last_step_button = gr.Button(value="Delete last step")
            with gr.Column(scale=7):
                with gr.Row():
                    with gr.Column(scale=6):
                        input_image = gr.Image(
                            label="start frame",
                            interactive=True,
                            height=320,
                            width=512,
                            sources=[],
                        )
                    with gr.Column(scale=6):
                        input_image_end = gr.Image(
                            label="end frame",
                            interactive=True,
                            height=320,
                            width=512,
                            sources=[],
                        )
        # --- Generation parameters and the output view ---
        with gr.Row():
            with gr.Column(scale=1):
                controlnet_cond_scale = gr.Slider(
                    label="Control Scale",
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    value=1.0,
                )
                motion_bucket_id = gr.Slider(
                    label="Motion Bucket",
                    minimum=1,
                    maximum=180,
                    step=1,
                    value=100,
                )
            with gr.Column(scale=5):
                output_video = gr.Image(
                    label="Output Video",
                    height=320,
                    width=1152,
                )
        # --- Citation ---
        with gr.Row():
            gr.Markdown(
                """
## Citation
```bibtex
@article{wang2024framer,
title={Framer: Interactive Frame Interpolation},
author={Wang, Wen and Wang, Qiuyu and Zheng, Kecheng and Ouyang, Hao and Chen, Zhekai and Gong, Biao and Chen, Hao and Shen, Yujun and Shen, Chunhua},
journal={arXiv preprint https://arxiv.org/abs/2410.18978},
year={2024}
}
```
"""
            )
        # --- Event wiring ---
        # Uploading either frame also writes the third return value (an empty list) to tracking_points.
        image_upload_button.upload(
            fn=preprocess_image,
            inputs=image_upload_button,
            outputs=[input_image, first_frame_path, tracking_points],
        )
        image_end_upload_button.upload(
            fn=preprocess_image_end,
            inputs=image_end_upload_button,
            outputs=[input_image_end, last_frame_path, tracking_points],
        )
        add_drag_button.click(
            fn=add_drag,
            inputs=tracking_points,
            outputs=tracking_points,
        )
        delete_last_drag_button.click(
            fn=delete_last_drag,
            inputs=[tracking_points, first_frame_path, last_frame_path],
            outputs=[tracking_points, input_image, input_image_end],
        )
        delete_last_step_button.click(
            fn=delete_last_step,
            inputs=[tracking_points, first_frame_path, last_frame_path],
            outputs=[tracking_points, input_image, input_image_end],
        )
        reset_button.click(
            fn=reset_states,
            outputs=[input_image, input_image_end, first_frame_path, last_frame_path, output_video, tracking_points],
        )
        # A click on either frame view adds a point to the current track.
        gr.on(
            triggers=[input_image.select, input_image_end.select],
            fn=add_tracking_points,
            inputs=[tracking_points, first_frame_path, last_frame_path],
            outputs=[tracking_points, input_image, input_image_end],
        )
        run_button.click(
            fn=run,
            inputs=[first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id],
            outputs=output_video,
        )
    demo.launch()