Image Diffusion Preview with Consistency Solver (Google DeepMind)
Quick Start
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from scheduler_ppo import PPOScheduler # Provided in this repo
from huggingface_hub import hf_hub_download
# Fetch the trained factor_net checkpoint from the Hugging Face Hub.
factor_net_path = hf_hub_download(repo_id="wangfuyun/consolver", filename="model.ckpt")

# Base model that both pipelines are built from.
model_id = "runwayml/stable-diffusion-v1-5"

# Sampling settings used for both the DDIM and the ConSolver run.
prompt = "an astronaut riding a horse on the moon, highly detailed, 8k"
num_inference_steps = 8
guidance_scale = 3.0
seed = 43

# Square output resolution.
width = 512
height = width
def load_pipeline(scheduler_type="ddim", device="cuda"):
    """Build a Stable Diffusion pipeline that uses the requested scheduler.

    Args:
        scheduler_type: ``"ppo"`` builds a ``PPOScheduler`` and, when the
            module-level ``factor_net_path`` is truthy, loads that checkpoint
            into the scheduler's ``factor_net``. Any other value builds a
            ``DDIMScheduler`` from the scheduler config of ``model_id``.
        device: Device the pipeline (and the ``factor_net``) is moved to.
            Defaults to ``"cuda"``, which was previously hard-coded.

    Returns:
        The ``StableDiffusionPipeline`` for ``model_id`` on ``device``.
    """
    use_ppo = scheduler_type == "ppo"

    if use_ppo:
        scheduler = PPOScheduler(
            beta_end=0.012,
            beta_schedule="scaled_linear",
            beta_start=0.00085,
            num_train_timesteps=1000,
            steps_offset=1,
            timestep_spacing="trailing",
            order_dim=4,
            scaler_dim=0,
            use_conv=False,
            factor_net_kwargs=dict(embedding_dim=64, hidden_dim=256, num_actions=11),
        )
    else:
        # Same "trailing" timestep spacing as the PPO branch above.
        scheduler = DDIMScheduler.from_pretrained(
            model_id, subfolder="scheduler", timestep_spacing="trailing"
        )

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        scheduler=scheduler,
        safety_checker=None,
        # torch_dtype=torch.float16,  # Uncomment for GPU memory savings
    ).to(device)

    if use_ppo and factor_net_path:
        # NOTE(review): torch.load unpickles a file downloaded from the Hub.
        # Consider passing weights_only=True once it is confirmed that the
        # checkpoint is a plain state dict of tensors.
        weight = torch.load(factor_net_path, map_location="cpu")
        pipe.scheduler.factor_net.load_state_dict(weight)
        # The scheduler is not moved by pipe.to(), so move factor_net explicitly.
        pipe.scheduler.factor_net.to(device)

    return pipe
# One seeded generator drives both runs; it is re-seeded before the second
# run because the first pipeline call advances its state.
generator = torch.Generator("cuda").manual_seed(seed)

# DDIM baseline (8 steps)
pipe_ddim = load_pipeline("ddim")
image_ddim = pipe_ddim(
    prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    height=height,
    width=width,
).images[0]
image_ddim.save("ddim_result.jpg")

# ConSolver (8 steps)
pipe_consolver = load_pipeline("ppo")
# Reset the generator so ConSolver starts from the same initial noise as the
# DDIM baseline; otherwise the two images are not directly comparable.
generator.manual_seed(seed)
image_consolver = pipe_consolver(
    prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    height=height,
    width=width,
).images[0]
image_consolver.save("consolver_result.jpg")
|
|
|
| DDIM | ConsistencySolver |
- Downloads last month
- 16
Model tree for wangfuyun/consolver
Base model
runwayml/stable-diffusion-v1-5