# Lord-Raven — "Removing domain check." (commit 04b11ba)
import spaces
import torch
import gradio
import json
import onnxruntime
import time
from datetime import datetime
from transformers import pipeline
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# CORS configuration.
# NOTE(review): per the original author's comment this isn't actually working
# and may no longer be wanted — confirm before relying on these origins.
app = FastAPI()

# Origins of the chub.ai stage deployments that call this service.
ALLOWED_ORIGINS = [
    "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
    "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win",
    "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win",
    "https://expressions-plus-plus-a8db5f0bd422.c5v4v4jx6pq5.win",
    "https://crunchatize-zero-55cf8960d31a.c5v4v4jx6pq5.win",
    "https://misty-mississippi-87473a288fb9.c5v4v4jx6pq5.win",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Log CUDA availability at startup so the Space logs show which device
# path (GPU vs CPU) the pipelines will take.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    print(f"CUDA version: {torch.version.cuda}")
# Zero-shot classification models: ModernBERT for the CPU path,
# RoBERTa for the GPU path.
model_name_cpu = "MoritzLaurer/ModernBERT-large-zeroshot-v2.0"
tokenizer_name_cpu = "MoritzLaurer/ModernBERT-large-zeroshot-v2.0"
model_name_gpu = "MoritzLaurer/roberta-large-zeroshot-v2.0-c"
tokenizer_name_gpu = "MoritzLaurer/roberta-large-zeroshot-v2.0-c"

# The CPU pipeline is always built; it doubles as the GPU "pipeline"
# when no CUDA device is present.
classifier_cpu = pipeline(
    task="zero-shot-classification",
    model=model_name_cpu,
    tokenizer=tokenizer_name_cpu,
    device="cpu",
    torch_dtype=torch.bfloat16,
)
if torch.cuda.is_available():
    classifier_gpu = pipeline(
        task="zero-shot-classification",
        model=model_name_gpu,
        tokenizer=tokenizer_name_gpu,
        device="cuda",
        torch_dtype=torch.bfloat16,
    )
else:
    classifier_gpu = classifier_cpu
def classify(data_string, request: gradio.Request):
    """Gradio entry point: parse a JSON request payload, classify it
    (GPU first unless the payload contains a 'cpu' key), and return the
    pipeline result serialized back to JSON.

    Falls back to the CPU pipeline whenever the GPU attempt raises or
    produces an empty result.
    """
    data = json.loads(data_string)

    # Reset call counts to try to suppress the batch-size suggestion
    # warning the pipelines emit in the log.
    classifier_cpu.call_count = 0
    classifier_gpu.call_count = 0

    start_time = time.time()

    def log_duration(device):
        # Timestamped duration line; elapsed is measured from the start
        # of classify, so a CPU fallback includes the failed GPU attempt.
        stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"{stamp} - {device} Classification took {time.time() - start_time}")

    result = {}
    prefer_gpu = 'cpu' not in data
    try:
        if prefer_gpu:
            result = zero_shot_classification_gpu(data)
            log_duration("GPU")
    except Exception as e:
        print(f"GPU classification failed: {e}\nFall back to CPU.")
    if not result:
        result = zero_shot_classification_cpu(data)
        log_duration("CPU")
    return json.dumps(result)
def zero_shot_classification_cpu(data):
    """Run the CPU zero-shot pipeline over a request payload dict."""
    options = {
        key: data[key]
        for key in ('candidate_labels', 'hypothesis_template', 'multi_label')
    }
    return classifier_cpu(data['sequence'], **options)
@spaces.GPU(duration=3)
def zero_shot_classification_gpu(data):
    """Run the GPU zero-shot pipeline over a request payload dict.

    Decorated for ZeroGPU scheduling with a short (3s) duration budget.
    """
    options = {
        key: data[key]
        for key in ('candidate_labels', 'hypothesis_template', 'multi_label')
    }
    return classifier_gpu(data['sequence'], **options)
def create_sequences(data):
    """Build one 'sequence\\nhypothesis' string per candidate label.

    Formats the hypothesis template with each label and appends it to the
    input sequence on a new line.
    """
    sequences = []
    for label in data['candidate_labels']:
        hypothesis = data['hypothesis_template'].format(label)
        sequences.append(data['sequence'] + '\n' + hypothesis)
    return sequences
# Minimal Gradio UI over classify(); the Space is really consumed
# programmatically by the chub.ai stages, not through this page.
gradio_interface = gradio.Interface(
    fn=classify,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(label="JSON Output"),
    title="Statosphere Backend",
    description="This Space is a classification service for a set of chub.ai stages and not really intended for use through this UI.",
)

# NOTE(review): app.mount expects an ASGI app; mounting a gradio.Interface
# directly may not serve correctly — gradio.mount_gradio_app is the usual
# pattern. Confirm this route actually works.
app.mount("/gradio", gradio_interface)
gradio_interface.launch()