Spaces: Running on Zero

Ray committed · Commit 4178132
1 Parent(s): 77d6865
Update app and add models via LFS
- {omini → Genfocus}/__init__.py +0 -0
- omini/pipeline/flux_omini.py → Genfocus/pipeline/flux.py +12 -7
- app.py +261 -91
- bokehNet.safetensors +3 -0
- default.safetensors → deblurNet.safetensors +0 -0
- example/female.jpg +3 -0
- requirements.txt +3 -1
{omini → Genfocus}/__init__.py
RENAMED
File without changes
omini/pipeline/flux_omini.py → Genfocus/pipeline/flux.py
RENAMED
@@ -35,12 +35,12 @@ def clip_hidden_states(hidden_states: torch.FloatTensor) -> torch.FloatTensor:
    hidden_states = hidden_states.clip(-65504, 65504)
    return hidden_states

-
-def encode_images(pipeline: FluxPipeline, images: torch.Tensor):
+def encode_images(pipeline: FluxPipeline, images: torch.Tensor,No_preprocess=False):
    """
    Encodes the images into tokens and ids for FLUX pipeline.
    """
-
+    if not No_preprocess:
+        images = pipeline.image_processor.preprocess(images)
    images = images.to(pipeline.device).to(pipeline.dtype)
    images = pipeline.vae.encode(images).latent_dist.sample()
    images = (
@@ -105,6 +105,7 @@ class Condition(object):
        position_scale=1.0,
        latent_mask=None,
        is_complement=False,
+        No_preprocess=False,
    ) -> None:
        self.condition = condition
        self.adapter = adapter_setting
@@ -114,12 +115,17 @@ class Condition(object):
            latent_mask.T.reshape(-1) if latent_mask is not None else None
        )
        self.is_complement = is_complement
+        self.No_preprocess=No_preprocess

    def encode(
        self, pipe: FluxPipeline, empty: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
-
-
+        if isinstance(self.condition, Image.Image):
+            condition_empty = Image.new("RGB", self.condition.size, (0, 0, 0))
+        elif torch.is_tensor(self.condition):
+            H, W = self.condition.shape[-2], self.condition.shape[-1]
+            condition_empty = Image.fromarray(np.zeros((H, W, 3), dtype=np.uint8), "RGB")
+        tokens, ids = encode_images(pipe, condition_empty if empty else self.condition,self.No_preprocess)

        if self.position_delta is not None:
            ids[:, 1] += self.position_delta[0]
@@ -136,7 +142,6 @@ class Condition(object):

        return tokens, ids

-
@contextmanager
def specify_lora(lora_modules: List[BaseTunerLayer], specified_lora):
    # Filter valid lora modules
@@ -259,7 +264,6 @@ def attn_forward(
        with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
            h = attn.to_out[0](h)
        h_out.append(h)
-
    return (h_out, h2_out) if h2_n else h_out


@@ -450,6 +454,7 @@ def transformer_forward(

    return (output,)

+
@torch.no_grad()
def generate(
    pipeline: FluxPipeline,
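
The practical effect of the new No_preprocess flag (used by app.py below) is that a Condition can now wrap a tensor that is already scaled the way the VAE expects, so encode_images() skips pipeline.image_processor.preprocess() for it. A minimal sketch of the two call patterns, assuming the renamed Genfocus.pipeline.flux module is importable; the inputs here are placeholders, not values from the commit:

import torch
from PIL import Image

from Genfocus.pipeline.flux import Condition  # module introduced by this commit

# Illustrative inputs only: a 512x512 RGB photo and a 1x3x512x512 map in [0, 1]
# (the shape app.py uses for its normalized defocus map).
photo = Image.new("RGB", (512, 512))
defocus_map = torch.rand(1, 3, 512, 512).clamp(0, 1)

# Default path: a PIL image is still run through pipeline.image_processor.preprocess()
# inside encode_images() before being encoded by the VAE.
cond_img = Condition(photo, "bokeh")

# New path: No_preprocess=True hands the tensor to the VAE as-is, which is what
# app.py relies on when it encodes the defocus map as a condition.
cond_dmf = Condition(defocus_map, "bokeh", [0, 0], 1.0, No_preprocess=True)
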
app.py
CHANGED
@@ -1,139 +1,309 @@
-import
 import gradio as gr
 import torch
 import numpy as np
-
-from
-import

-# ===
-

-#
 MODEL_ID = "black-forest-labs/FLUX.1-dev"
-

-
-
-dtype = torch.bfloat16
-pipe_flux = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)

-
-try:
-    print("🔄 Loading LoRA weights...")
-    pipe_flux.load_lora_weights(".", weight_name=LORA_WEIGHT_NAME, adapter_name="deblurring")
-    pipe_flux.set_adapters(["deblurring"])
-except Exception as e:
-    print(f"⚠️ LoRA Error: {e}")

-

-
 def center_crop_512(img: Image.Image) -> Image.Image:
     w, h = img.size
     target = 512
-
-    # Only resize if the image is smaller than the target size
     if min(w, h) < target:
         scale = target / min(w, h)
         new_w, new_h = int(w * scale), int(h * scale)
         img = img.resize((new_w, new_h), Image.LANCZOS)
         w, h = new_w, new_h
-
-    # Calculate center coordinates
     left = (w - target) // 2
     top = (h - target) // 2
     right = left + target
     bottom = top + target
-
     return img.crop((left, top, right, bottom))

-
-
-
-

-
-

-
-

-

-    condition_1_img = center_crop_512(input_image)
-    # Create a black image for condition 0
     condition_0_img = Image.new("RGB", (512, 512), (0, 0, 0))

-
-

-
-
-

-

-
-
-        height=512,
-        width=512,
-        prompt=prompt,
-        conditions=conditions
-    ).images[0]

-

-#
 css = """
-#col-container { margin: 0 auto; max-width:
 """

-
-
-
-
-
-
-
-

 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
-        #
-        gr.Markdown("# 📷 Genfocus DeblurNet Demo")
-
-        # Description
-        gr.Markdown("""
-        ### Description
-        This demo showcases the functionality of our **first-stage defocus deblurring**.
-
-        ⚠️ **Note**: For demonstration purposes, input images will be automatically **Center Cropped to 512x512**.
-        """)

         with gr.Row():
-            with gr.Column():
-
-
-
-
-
-
-
-
-
-
-
-
-
             )

-
-
-
-
-
-    )

 if __name__ == "__main__":
     demo.launch()
+import os
+import cv2
 import gradio as gr
 import torch
 import numpy as np
+import tempfile
+from PIL import Image, ImageDraw
+from skimage import color, img_as_float32, img_as_ubyte
+
+# GPU decorator specific to Hugging Face Spaces
+import spaces
+from huggingface_hub import login

+# === Import Logic ===
+# Make sure the Genfocus and depth_pro folders have been uploaded to the Space root directory
+try:
+    from Genfocus.pipeline.flux import Condition, generate, seed_everything
+    print("✅ Loaded Condition/generate from Genfocus.pipeline.flux")
+except ImportError:
+    raise RuntimeError("❌ Cannot find 'Genfocus'. Please upload the folder to the Space.")
+
+import depth_pro

+# ==========================================
+# 2. Global Settings
+# ==========================================
 MODEL_ID = "black-forest-labs/FLUX.1-dev"
+# Assume the .safetensors files live in the repository root
+DEBLUR_LORA_PATH = "."
+DEBLUR_WEIGHT_NAME = "deblurNet.safetensors"
+BOKEH_LORA_DIR = "."
+BOKEH_WEIGHT_NAME = "bokehNet.safetensors"

+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

+print(f"🚀 Device detected: {device}")

+# Global variable initialization (lazy loading to save startup time)
+pipe_flux = None
+depth_model = None
+depth_transform = None
+current_adapter = None

+def load_models():
+    """Load the models on first run"""
+    global pipe_flux, depth_model, depth_transform
+
+    if pipe_flux is None:
+        print("🔄 Loading FLUX pipeline...")
+        # Note: FLUX.1-dev requires HF token permissions
+        pipe_flux = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
+        if device == "cuda":
+            pipe_flux.to("cuda")
+
+    if depth_model is None:
+        print("🔄 Loading Depth Pro model...")
+        try:
+            checkpoint_path = hf_hub_download(
+                repo_id=WEIGHTS_REPO_ID,
+                filename=DEPTH_FILENAME,
+                repo_type="model"
+            )
+            print(f"📂 Depth checkpoint cached at: {checkpoint_path}")
+
+            depth_model, depth_transform = depth_pro.create_model_and_transforms(
+                device=device,
+                checkpoint_path=checkpoint_path
+            )
+
+            if device == "cuda":
+                depth_model.eval().to("cuda")
+            else:
+                depth_model.eval()
+            print("✅ Depth Pro loaded.")
+        except Exception as e:
+            print(f"❌ Failed to load Depth Pro: {e}")
+
+# ==========================================
+# 3. Helper Functions
+# ==========================================
 def center_crop_512(img: Image.Image) -> Image.Image:
     w, h = img.size
     target = 512
     if min(w, h) < target:
         scale = target / min(w, h)
         new_w, new_h = int(w * scale), int(h * scale)
         img = img.resize((new_w, new_h), Image.LANCZOS)
         w, h = new_w, new_h
     left = (w - target) // 2
     top = (h - target) // 2
     right = left + target
     bottom = top + target
     return img.crop((left, top, right, bottom))

+def switch_lora(target_mode):
+    global pipe_flux, current_adapter
+    if current_adapter == target_mode:
+        return
+    print(f"🔄 Switching LoRA to [{target_mode}]...")
+    pipe_flux.unload_lora_weights()
+    if target_mode == "deblur":
+        try:
+            pipe_flux.load_lora_weights(DEBLUR_LORA_PATH, weight_name=DEBLUR_WEIGHT_NAME, adapter_name="deblurring")
+            pipe_flux.set_adapters(["deblurring"])
+            current_adapter = "deblur"
+        except Exception as e:
+            print(f"❌ Failed to load Deblur LoRA: {e}")
+    elif target_mode == "bokeh":
+        try:
+            pipe_flux.load_lora_weights(BOKEH_LORA_DIR, weight_name=BOKEH_WEIGHT_NAME, adapter_name="bokeh")
+            pipe_flux.set_adapters(["bokeh"])
+            current_adapter = "bokeh"
+        except Exception as e:
+            print(f"❌ Failed to load Bokeh LoRA: {e}")
+
+# ==========================================
+# 4. Processing Logic
+# ==========================================
+
+def preprocess_input_image(raw_img, do_resize):
+    if raw_img is None: return None, None, None
+    print(f"🔄 Preprocessing Input... Resize={do_resize}")
+    img_to_process = raw_img
+    if do_resize:
+        w, h = img_to_process.size
+        scale = 512 / min(w, h)
+        new_w, new_h = int(w * scale), int(h * scale)
+        img_to_process = img_to_process.resize((new_w, new_h), Image.LANCZOS)

+    final_input = center_crop_512(img_to_process)
+    return final_input, final_input, None

+def draw_red_dot_on_preview(clean_img, evt: gr.SelectData):
+    if clean_img is None: return None, None
+    img_copy = clean_img.copy()
+    draw = ImageDraw.Draw(img_copy)
+    x, y = evt.index
+    r = 8
+    draw.ellipse((x-r, y-r, x+r, y+r), outline="red", width=2)
+    draw.line((x-r, y, x+r, y), fill="red", width=2)
+    draw.line((x, y-r, x, y+r), fill="red", width=2)
+    return img_copy, evt.index

+# !!! Key change: add the @spaces.GPU decorator !!!
+# This tells HF Spaces to allocate this function to a GPU when it runs
+@spaces.GPU(duration=120)
+def run_genfocus_pipeline(clean_input_512, click_coords, K_value, cached_latents):
+    global pipe_flux, depth_model
+
+    # Make sure the models are loaded
+    load_models()
+
+    if clean_input_512 is None:
+        raise gr.Error("Please complete Step 1 (Upload Image) first.")
+
+    print("🚀 Starting Genfocus Pipeline...")
+
+    # 1. Run Deblur (Stage 1)
+    switch_lora("deblur")

     condition_0_img = Image.new("RGB", (512, 512), (0, 0, 0))
+    cond0 = Condition(condition_0_img, "deblurring", [0, 32], 1.0)
+    cond1 = Condition(clean_input_512, "deblurring", [0, 0], 1.0)
+
+    seed_everything(42)
+    deblurred_img = generate(
+        pipe_flux, height=512, width=512,
+        prompt="a sharp photo with everything in focus",
+        conditions=[cond0, cond1]
+    ).images[0]
+
+    if K_value == 0:
+        return deblurred_img, cached_latents
+
+    # 2. Run Bokeh (Stage 2)
+    if click_coords is None:
+        click_coords = [256, 256]
+
+    # Depth Estimation
+    try:
+        img_t = depth_transform(deblurred_img)
+        if device == "cuda": img_t = img_t.to("cuda")
+        with torch.no_grad():
+            pred = depth_model.infer(img_t, f_px=None)
+        depth_map = pred["depth"].cpu().numpy().squeeze()
+        safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max)
+        disp_orig = 1.0 / safe_depth
+        disp = cv2.resize(disp_orig, (512, 512), interpolation=cv2.INTER_LINEAR)
+    except Exception as e:
+        print(f"❌ Depth Error: {e}")
+        return deblurred_img, cached_latents

+    # Defocus Map
+    tx, ty = click_coords
+    tx = min(max(int(tx), 0), 511)
+    ty = min(max(int(ty), 0), 511)

+    disp_focus = float(disp[ty, tx])
+    dmf = disp - np.float32(disp_focus)
+    defocus_abs = np.abs(K_value * dmf)
+    MAX_COC = 100.0
+    defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float()
+    cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3,1,1).unsqueeze(0)

+    # Latents
+    if cached_latents is None:
+        seed_everything(42)
+        gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
+        latents, _ = pipe_flux.prepare_latents(
+            batch_size=1, num_channels_latents=16, height=512, width=512,
+            dtype=pipe_flux.dtype, device=pipe_flux.device, generator=gen, latents=None
+        )
+        current_latents = latents
+    else:
+        current_latents = cached_latents
+
+    # Generate
+    switch_lora("bokeh")
+    cond_img = Condition(deblurred_img, "bokeh")
+    cond_dmf = Condition(cond_map, "bokeh", [0,0], 1.0, No_preprocess=True)

+    seed_everything(42)
+    gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)

+    with torch.no_grad():
+        res = generate(
+            pipe_flux, height=512, width=512,
+            prompt="an excellent photo with a large aperture",
+            conditions=[cond_img, cond_dmf],
+            guidance_scale=1.0, kv_cache=False, generator=gen,
+            latents=current_latents,
+        )
+    generated_bokeh = res.images[0]
+    return generated_bokeh, current_latents

+# ==========================================
+# 5. UI Setup
+# ==========================================
 css = """
+#col-container { margin: 0 auto; max-width: 1400px; }
 """

+base_path = os.getcwd()
+# Simplified example path check
+example_dir = os.path.join(base_path, "example")
+valid_examples = []
+if os.path.exists(example_dir):
+    files = os.listdir(example_dir)
+    for f in files:
+        if f.lower().endswith(('.jpg', '.jpeg', '.png')):
+            valid_examples.append([os.path.join(example_dir, f)])

 with gr.Blocks(css=css) as demo:
+    clean_processed_state = gr.State(value=None)
+    click_coords_state = gr.State(value=None)
+    latents_state = gr.State(value=None)
+
     with gr.Column(elem_id="col-container"):
+        gr.Markdown("# 📷 Genfocus Pipeline: Interactive Refocusing (HF Demo)")

         with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("### Step 1: Upload & Preprocess")
+                input_raw = gr.Image(label="Raw Input Image", type="pil")
+                resize_chk = gr.Checkbox(label="Resize min edge to 512", value=False)
+
+                if valid_examples:
+                    gr.Examples(examples=valid_examples, inputs=input_raw, label="Examples")
+
+            with gr.Column(scale=1):
+                gr.Markdown("### Step 2: Set Focus & K")
+                focus_preview_img = gr.Image(label="Focus Point Selection", type="pil", interactive=False)
+                with gr.Row():
+                    click_status = gr.Textbox(label="Coords", value="Center", interactive=False, scale=1)
+                    k_slider = gr.Slider(0, 50, value=0, step=1, label="Blur Strength (K)", scale=2)
+                run_btn = gr.Button("✨ Run Genfocus", variant="primary", scale=1)
+
+        with gr.Row():
+            output_img = gr.Image(label="Result", type="pil", interactive=False)
+
+        # Events
+        update_trigger = [input_raw.change, resize_chk.change, input_raw.upload]
+        for trigger in update_trigger:
+            trigger(
+                fn=preprocess_input_image,
+                inputs=[input_raw, resize_chk],
+                outputs=[focus_preview_img, clean_processed_state, latents_state]
+            )
+
+        focus_preview_img.select(
+            fn=draw_red_dot_on_preview,
+            inputs=[clean_processed_state],
+            outputs=[focus_preview_img, click_coords_state]
+        ).then(
+            fn=lambda x: f"x={x[0]}, y={x[1]}",
+            inputs=[click_coords_state],
+            outputs=[click_status]
         )

+        run_btn.click(
+            fn=run_genfocus_pipeline,
+            inputs=[clean_processed_state, click_coords_state, k_slider, latents_state],
+            outputs=[output_img, latents_state]
+        )

 if __name__ == "__main__":
+    # HF Spaces does not need server_name or allowed_paths
     demo.launch()
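
The Blur Strength slider K drives a simple defocus model inside run_genfocus_pipeline: Depth Pro's metric depth is inverted to disparity, the disparity at the clicked pixel defines the focal plane, and the per-pixel defocus is |K · (disparity − disparity at focus)|, normalized by a 100-unit circle-of-confusion cap before being handed to the bokeh LoRA as a condition map. A self-contained sketch of that map construction with synthetic depth (no model weights required); variable names mirror app.py, the depth values are made up:

import numpy as np
import torch

K_value = 20.0          # slider value; app.py returns the deblurred image when K == 0
MAX_COC = 100.0         # normalization cap used in app.py
tx, ty = 256, 256       # clicked focus point on the 512x512 preview

# Synthetic stand-in for the Depth Pro output (depth in meters).
depth_map = np.random.uniform(0.5, 10.0, size=(512, 512)).astype(np.float32)

# Depth -> disparity, guarding against non-positive depths.
safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max)
disp = 1.0 / safe_depth

# Signed disparity difference to the focal plane, scaled by K and made absolute.
disp_focus = float(disp[ty, tx])
defocus_abs = np.abs(K_value * (disp - np.float32(disp_focus)))

# Normalized 1x3x512x512 condition map, as passed to Condition(..., No_preprocess=True).
defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float()
cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3, 1, 1).unsqueeze(0)
print(cond_map.shape)  # torch.Size([1, 3, 512, 512])
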
bokehNet.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e87c32b368b1af66e2aa5bc8f58c53ebba75b1afae85e8279cb5f9c5c608d13
+size 463703368
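
These three lines are a Git LFS pointer, not the roughly 460 MB weight file itself; the actual bokehNet.safetensors payload is only materialized when LFS content is pulled, which is what pipe_flux.load_lora_weights(...) in app.py expects to find in the repository root. A small, hypothetical helper (not part of this commit) to check whether a local checkout has the real weights or only the pointer stub:

import os

def is_lfs_pointer(path: str) -> bool:
    """True if the file is a Git LFS pointer stub rather than the real payload."""
    # Pointer files are tiny text files starting with the spec line shown above.
    if os.path.getsize(path) > 1024:
        return False
    with open(path, "rb") as f:
        return f.read(42).startswith(b"version https://git-lfs.github.com/spec/v1")

for name in ("bokehNet.safetensors", "deblurNet.safetensors"):
    if os.path.exists(name):
        state = "LFS pointer only" if is_lfs_pointer(name) else "real weights"
        print(f"{name}: {state}")
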
default.safetensors → deblurNet.safetensors
RENAMED
File without changes
example/female.jpg
ADDED
Git LFS Details
requirements.txt
CHANGED
@@ -6,4 +6,6 @@ protobuf
 sentencepiece
 gradio
 jupyter
-torchao
+torchao
+git+https://github.com/apple/ml-depth-pro.git
+scikit-image
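
The two new entries pull in Apple's Depth Pro package (installed directly from GitHub) and scikit-image, both of which app.py now imports; cv2 is presumably satisfied by an opencv entry in the part of requirements.txt outside this hunk. A hypothetical smoke test, not part of the commit, to confirm the new dependencies resolve after pip install -r requirements.txt:

import importlib

# Modules the updated app.py imports at startup.
for module in ("depth_pro", "skimage", "cv2", "gradio", "torch"):
    try:
        importlib.import_module(module)
        print(f"ok: {module}")
    except ImportError as exc:  # missing or broken dependency
        print(f"missing: {module} ({exc})")
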