Image Diffusion Preview with Consistency Solver (Google DeepMind)

Paper | Code | Hugging Face model

Quick Start

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from scheduler_ppo import PPOScheduler  # Provided in this repo
from huggingface_hub import hf_hub_download

# Download the trained factor_net checkpoint; load_pipeline("ppo") loads these
# weights into the scheduler's factor network.
factor_net_path = hf_hub_download(
    repo_id="wangfuyun/consolver",
    filename="model.ckpt",
)

# Base Stable Diffusion 1.5 model. The original "runwayml/stable-diffusion-v1-5"
# repository has been removed from the Hugging Face Hub; this is the maintained
# mirror of the same weights.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
prompt = "an astronaut riding a horse on the moon, highly detailed, 8k"
num_inference_steps = 8  # sampling steps used for BOTH the DDIM and ConSolver runs
guidance_scale = 3.0  # classifier-free guidance weight passed to the pipeline
seed = 43  # shared seed so both runs can start from identical noise
height = width = 512  # output resolution in pixels

def load_pipeline(scheduler_type="ddim", device="cuda"):
    """Build a Stable Diffusion pipeline for ``model_id`` with the chosen scheduler.

    Args:
        scheduler_type: ``"ppo"`` selects the learned ConSolver ``PPOScheduler``
            and loads its ``factor_net`` weights from ``factor_net_path``.
            Any other value selects the DDIM baseline scheduler.
        device: Device that the pipeline (and, for ``"ppo"``, the factor
            network) is moved to. Defaults to ``"cuda"``, as before.

    Returns:
        The ``StableDiffusionPipeline`` placed on ``device``. The safety
        checker is disabled.
    """
    if scheduler_type == "ppo":
        # The noise-schedule settings appear to mirror the base model's scheduler
        # config (TODO confirm against model_id's scheduler_config.json).
        # "trailing" spacing matches the DDIM branch below.
        scheduler = PPOScheduler(
            beta_end=0.012,
            beta_schedule="scaled_linear",
            beta_start=0.00085,
            num_train_timesteps=1000,
            steps_offset=1,
            timestep_spacing="trailing",
            order_dim=4,
            scaler_dim=0,
            use_conv=False,
            factor_net_kwargs=dict(embedding_dim=64, hidden_dim=256, num_actions=11),
        )
    else:
        scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", timestep_spacing="trailing")

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        scheduler=scheduler,
        safety_checker=None,
        # torch_dtype=torch.float16,  # Uncomment for GPU memory savings
    ).to(device)

    if scheduler_type == "ppo" and factor_net_path:
        # The checkpoint is downloaded from the Hub, so restrict unpickling to
        # tensors/plain containers (weights_only=True). This prevents a
        # tampered file from executing code. The loaded object is passed
        # straight to load_state_dict, so it is expected to be a plain
        # state_dict, which weights_only loading supports.
        weight = torch.load(factor_net_path, map_location="cpu", weights_only=True)
        pipe.scheduler.factor_net.load_state_dict(weight)
        pipe.scheduler.factor_net.to(device)

    return pipe

# One RNG shared by both runs. It must be re-seeded before EACH pipeline call:
# every call draws its initial latent noise from the generator and so advances
# its state. Without re-seeding, ConSolver would start from different noise
# than DDIM, and the comparison would not be like-for-like.
generator = torch.Generator("cuda").manual_seed(seed)

# DDIM baseline (8 steps)
pipe_ddim = load_pipeline("ddim")
image_ddim = pipe_ddim(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
                       generator=generator, height=height, width=width).images[0]
image_ddim.save("ddim_result.jpg")

# ConSolver (8 steps): reset the RNG so it samples the same starting noise as DDIM.
pipe_consolver = load_pipeline("ppo")
generator.manual_seed(seed)
image_consolver = pipe_consolver(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
                                 generator=generator, height=height, width=width).images[0]
image_consolver.save("consolver_result.jpg")
Sample comparison: DDIM vs. ConsistencySolver
Downloads last month: 16
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for wangfuyun/consolver

Finetuned
(605)
this model