"""Application entry point: checks dependencies, pre-caches models, launches the UI."""
import functools
import os
import sys

# `spaces` is optional: when it is not installed the ImportError is ignored
# and the rest of the module loads normally.
# NOTE(review): presumably this import has to stay ahead of the torch-based
# project imports below (Hugging Face ZeroGPU convention) -- confirm before
# reordering.
try:
    import spaces
except ImportError:
    pass

# Push out anything already buffered, then rebind this module's `print` so
# every later call in this file flushes immediately.
sys.stdout.flush()
print = functools.partial(print, flush=True)

# Not referenced directly in this file.
# NOTE(review): presumably imported eagerly so a missing tokenizer dependency
# fails at startup rather than on first use -- confirm.
import ftfy
import sentencepiece

from FlowFacade import FlowFacade
from BackgroundEngine import BackgroundEngine
from style_transfer import StyleTransferEngine
from ui_manager import UIManager
def preload_models():
    """
    Pre-download model files into the Hugging Face cache on Spaces startup.

    Backup method if YAML preload_from_hub doesn't work.
    Only runs in HF Spaces environment (``SPACE_ID`` set); otherwise it
    returns immediately.

    The files are only downloaded, never instantiated, so no model weights
    are loaded into memory here. Any failure is reported and swallowed so
    that startup continues and the models download on first use instead.

    Returns:
        None
    """
    if not os.environ.get('SPACE_ID'):
        return

    # Resolve the hub cache the same way huggingface_hub does, so the check
    # still looks in the right place when the cache has been relocated.
    # With no overrides this is the usual ~/.cache/huggingface/hub.
    hf_home = os.environ.get("HF_HOME") or os.path.join(
        os.environ.get("XDG_CACHE_HOME") or "~/.cache", "huggingface"
    )
    cache_dir = os.path.expanduser(
        os.environ.get("HF_HUB_CACHE") or os.path.join(hf_home, "hub")
    )

    # Heuristic: a single cache entry matching one of the markers is taken
    # as proof that the YAML preload already ran.
    cache_markers = ("wan2.2", "models--kijai")
    if os.path.isdir(cache_dir):
        cached_models = os.listdir(cache_dir)
        if any(
            marker in entry.lower()
            for entry in cached_models
            for marker in cache_markers
        ):
            print("✓ Models already cached (YAML preload worked)")
            return

    print("→ Pre-caching models to disk (first-time setup)...")
    print(" This may take 2-3 minutes, please wait...")

    try:
        # Imported lazily: only needed on the (rare) first-time download path.
        from huggingface_hub import hf_hub_download, snapshot_download

        video_repo = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"
        text_repo = "Qwen/Qwen2.5-0.5B-Instruct"

        print(" [1/4] Downloading video model transformer...")
        snapshot_download(video_repo, allow_patterns=["transformer/*"])

        print(" [2/4] Downloading video model transformer_2...")
        snapshot_download(video_repo, allow_patterns=["transformer_2/*"])

        print(" [3/4] Downloading Lightning LoRA...")
        hf_hub_download(
            "Kijai/WanVideo_comfy",
            "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
        )

        print(" [4/4] Downloading text model (optional)...")
        # Fetches the weights and the tokenizer files in one call.
        snapshot_download(text_repo)

        print("✓ All models cached successfully!")
        print(" Future users will load instantly from cache")

    except Exception as e:
        # Best effort by design: a failed pre-cache must not block startup.
        print(f"⚠ Pre-cache warning: {e}")
        print(" Models will download on first generation instead")
def _is_importable(package):
    """Return True when *package* can be imported, False on ImportError."""
    import importlib

    try:
        importlib.import_module(package)
    except ImportError:
        return False
    return True


def check_environment():
    """
    Verify that the packages this application needs can be imported.

    Required packages that fail to import are listed together with the
    install commands, after which the process exits with status 1.
    Missing optional packages are only reported when the ``DEBUG``
    environment variable is set.

    Raises:
        SystemExit: with code 1 when any required package is missing.
    """
    required_packages = [
        "torch", "transformers", "diffusers", "gradio", "PIL",
        "accelerate", "numpy", "ftfy", "sentencepiece",
    ]

    optional_packages = {
        "torchao": "INT8/FP8 quantization",
        "xformers": "Memory efficient attention",
        "aoti": "AoT compilation",
    }

    missing_packages = [
        package for package in required_packages
        if not _is_importable(package)
    ]
    missing_optional = [
        f"{package} ({description})"
        for package, description in optional_packages.items()
        if not _is_importable(package)
    ]

    if missing_packages:
        print("\n❌ Missing required packages:", ", ".join(missing_packages))
        print("\nInstall commands:")
        print("!pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 --index-url https://download.pytorch.org/whl/cu126")
        # Version specifiers are quoted: an unquoted ">" is a shell redirect,
        # which would drop the pin and write a stray "=0.32.0" file.
        print('!pip install "diffusers>=0.32.0" "transformers>=4.46.0" accelerate gradio pillow numpy spaces ftfy sentencepiece protobuf imageio-ffmpeg')
        print("!pip install torchao xformers")
        sys.exit(1)

    if missing_optional and os.environ.get('DEBUG'):
        print("⚠ Optional packages missing:", ", ".join(missing_optional))
def main():
    """
    Run the application: check the environment, pre-cache models, build the
    engines and UI, then launch the interface.

    The interface is shared publicly only when running inside Google Colab
    (detected through ``google.colab`` being present in ``sys.modules``).

    Raises:
        SystemExit: code 0 after a keyboard interrupt, code 1 after any
            other startup error (the traceback is printed first). It may
            also propagate from ``check_environment``.
    """
    check_environment()
    preload_models()

    # Bound up front so the interrupt handler can tell whether there is a
    # facade to clean up, without inspecting locals().
    facade = None
    try:
        facade = FlowFacade()
        background_engine = BackgroundEngine()
        style_engine = StyleTransferEngine()
        ui_manager = UIManager(facade, background_engine, style_engine)
        interface = ui_manager.create_interface()
        is_colab = 'google.colab' in sys.modules

        print("✓ Ready")
        interface.launch(
            share=is_colab,
            server_name="0.0.0.0",
            server_port=None,
            show_error=True
        )

    except KeyboardInterrupt:
        print("\n⚠ Shutdown requested")
        if facade is not None:
            facade.cleanup()
        sys.exit(0)

    except Exception as e:
        print(f"\n❌ Startup error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
# Start the application only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()