| |
| import gradio as gr |
| import torch |
| import numpy as np |
| from PIL import Image |
| import fitz |
| import pandas as pd |
| from huggingface_hub import login |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| |
| |
| WhisperProcessor, |
| WhisperForConditionalGeneration, |
| ViltProcessor, |
| ViltForQuestionAnswering |
| ) |
| from parler_tts import ParlerTTSForConditionalGeneration |
| from transformers import AutoFeatureExtractor |
|
|
| import os |
| import scipy.io.wavfile as wavfile |
| import io |
|
|
| |
|
|
| |
| token = os.getenv("HF_API_TOKEN") |
| if token: |
| login(token=token) |
| else: |
| print("تحذير: لم يتم تعيين متغير البيئة HF_API_TOKEN. بعض النماذج قد تتطلب تسجيل الدخول.") |
|
|
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| print(f"استخدام الجهاز: {device}") |
|
|
| |
| text_model_name = "distilgpt2" |
| print(f"تحميل نموذج النص: {text_model_name}") |
| text_tokenizer = AutoTokenizer.from_pretrained(text_model_name) |
| text_model = AutoModelForCausalLM.from_pretrained( |
| text_model_name, |
| torch_dtype=torch.float32, |
| device_map="auto" |
| ) |
| if text_tokenizer.pad_token is None: |
| text_tokenizer.pad_token = text_tokenizer.eos_token |
| print("تم تحميل نموذج النص.") |
|
|
| |
| image_model_name = "dandelin/vilt-b32-finetuned-vqa" |
| print(f"تحميل نموذج الصور (VQA): {image_model_name}") |
| image_processor = ViltProcessor.from_pretrained(image_model_name) |
| image_model = ViltForQuestionAnswering.from_pretrained( |
| image_model_name, |
| torch_dtype=torch.float16 if device == "cuda" else torch.float32, |
| device_map="auto" |
| ) |
| print("تم تحميل نموذج الصور (VQA).") |
|
|
| |
| stt_model_name = "openai/whisper-tiny" |
| print(f"تحميل نموذج تحويل الكلام إلى نص: {stt_model_name}") |
| stt_processor = WhisperProcessor.from_pretrained(stt_model_name) |
| stt_model = WhisperForConditionalGeneration.from_pretrained(stt_model_name).to(device) |
| stt_model.config.forced_decoder_ids = None |
| print("تم تحميل نموذج تحويل الكلام إلى نص.") |
|
|
| |
| tts_model_repo_id = "parler-tts/parler-tts-tiny-v1" |
| print(f"تحميل نموذج تحويل النص إلى كلام: {tts_model_repo_id}") |
| tts_model = ParlerTTSForConditionalGeneration.from_pretrained(tts_model_repo_id).to(device) |
| tts_feature_extractor = AutoFeatureExtractor.from_pretrained(tts_model_repo_id) |
| print("تم تحميل مكونات تحويل النص إلى كلام.") |
|
|
| |
|
|
| |
| def generate_text_response(prompt_text): |
| try: |
| full_prompt = f"السؤال: {prompt_text}\nالإجابة الودية والواضحة:" |
| inputs = text_tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(text_model.device) |
| outputs = text_model.generate( |
| **inputs, |
| max_new_tokens=150, |
| temperature=0.7, |
| top_k=50, |
| do_sample=True, |
| pad_token_id=text_tokenizer.eos_token_id, |
| no_repeat_ngram_size=2 |
| ) |
| response_text = text_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) |
| return response_text.strip() |
| except Exception as e: |
| print(f"خطأ في توليد النص: {e}") |
| return f"خطأ في معالجة النص: {str(e)}" |
|
|
| |
| def analyze_image(pil_image, question_text=None): |
| try: |
| if pil_image is None: |
| return "الرجاء رفع صورة أولاً." |
| if isinstance(pil_image, np.ndarray): |
| pil_image = Image.fromarray(pil_image).convert("RGB") |
| else: |
| pil_image = pil_image.convert("RGB") |
|
|
| if not question_text or question_text.strip() == "": |
| |
| |
| |
| question_text = "What is in this image?" |
| |
| |
| encoding = image_processor(pil_image, question_text, return_tensors="pt").to(image_model.device) |
| |
| |
| with torch.no_grad(): |
| outputs = image_model(**encoding) |
| |
| logits = outputs.logits |
| idx = logits.argmax(-1).item() |
| response_text = image_model.config.id2label[idx] |
| |
| return response_text |
| except Exception as e: |
| print(f"خطأ في تحليل الصورة: {e}") |
| return f"خطأ في تحليل الصورة: {str(e)}" |
|
|
| |
| def process_audio(audio_input): |
| try: |
| if audio_input is None: |
| return "الرجاء تسجيل الصوت أولاً.", "", (16000, np.array([], dtype=np.int16)) |
| |
| sample_rate, audio_data = audio_input |
| |
| if audio_data.dtype != np.float32: |
| audio_data = audio_data.astype(np.float32) |
| if np.max(np.abs(audio_data)) > 0: |
| audio_data = audio_data / np.max(np.abs(audio_data)) |
| else: |
| return "تم استقبال صوت صامت.", "", (16000, np.array([], dtype=np.int16)) |
|
|
| input_features = stt_processor(audio_data, sampling_rate=sample_rate, return_tensors="pt").input_features.to(device) |
| predicted_ids = stt_model.generate(input_features, language="ar") |
| transcription = stt_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip() |
|
|
| if not transcription: |
| return "لم يتمكن النموذج من استخراج نص من الصوت.", "", (16000, np.array([], dtype=np.int16)) |
|
|
| text_response = generate_text_response(transcription) |
| |
| prompt = text_response |
| with torch.no_grad(): |
| generation_output = tts_model.generate(input_ids=None, |
| prompt=prompt, |
| do_sample=True, |
| temperature=1.0).cpu().numpy().squeeze() |
| audio_output_np = generation_output |
| tts_sample_rate = tts_model.config.sampling_rate |
|
|
| return transcription, text_response, (tts_sample_rate, audio_output_np) |
| except Exception as e: |
| print(f"خطأ في معالجة الصوت: {e}") |
| empty_audio_data = np.array([], dtype=np.float32) |
| return f"خطأ في معالجة الصوت: {str(e)}", "", (16000, empty_audio_data) |
|
|
| |
| def process_file(file_obj): |
| try: |
| if file_obj is None: |
| return "الرجاء رفع ملف أولاً." |
| file_path = file_obj.name |
| text_content = "" |
| if file_path.endswith(".pdf"): |
| with fitz.open(file_path) as doc: |
| text_content = "\n".join(page.get_text() for page in doc) |
| elif file_path.endswith((".xlsx", ".xls")): |
| df = pd.read_excel(file_path) |
| text_content = df.to_string() |
| elif file_path.endswith(".csv"): |
| df = pd.read_csv(file_path) |
| text_content = df.to_string() |
| else: |
| return "❌ نوع الملف غير مدعوم حالياً (يدعم PDF, Excel, CSV)." |
|
|
| if not text_content.strip(): |
| return "الملف فارغ أو لا يمكن قراءة محتواه النصي." |
|
|
| max_context_len = 1000 |
| if len(text_content) > max_context_len: |
| text_content = text_content[:max_context_len] + "... [المحتوى تم اختصاره]" |
| |
| response = generate_text_response(f"لخص المحتوى التالي من الملف: \n\n{text_content}") |
| return response |
| except Exception as e: |
| print(f"خطأ في معالجة الملف: {e}") |
| return f"خطأ في قراءة الملف: {str(e)}" |
|
|
| |
| with gr.Blocks(css=".gradio-container {background-color: #f0f4f8; font-family: Arial; color: #333; padding: 20px;}", theme=gr.themes.Soft()) as demo: |
| gr.Markdown("# 🤖 Kemo Chat V3.2 - مساعد ذكي متعدد الوسائط (نماذج خفيفة الوزن - ViLT VQA)") |
| gr.Markdown("🎯 تفاعل معي عبر النصوص، الصور، الصوت أو الملفات! (باستخدام نماذج أقل استهلاكًا للذاكرة).") |
| gr.Markdown("📁 يدعم الملفات: PDF، Excel، CSV\n🖼️ يدعم الإجابة على الأسئلة حول الصور (VQA)\n🎙️ تحويل الصوت إلى نص والرد صوتياً") |
|
|
| with gr.Tab("💬 المحادثة النصية"): |
| text_input = gr.Textbox(label="اكتب سؤالك أو رسالتك هنا", lines=3) |
| text_output = gr.Textbox(label="الرد", lines=5, interactive=False) |
| text_submit = gr.Button("إرسال", variant="primary") |
| text_submit.click(fn=generate_text_response, inputs=text_input, outputs=text_output) |
|
|
| with gr.Tab("🖼️ تحليل الصور (سؤال وجواب)"): |
| gr.Markdown("ارفع صورة واطرح سؤالاً عنها.") |
| with gr.Row(): |
| image_input = gr.Image(type="pil", label="ارفع صورة") |
| with gr.Column(): |
| image_question = gr.Textbox(label="اطرح سؤالاً عن الصورة (مطلوب لـ ViLT)", lines=2, placeholder="مثال: What color is the car?") |
| image_output = gr.Textbox(label="الإجابة", lines=5, interactive=False) |
| image_submit = gr.Button("تحليل الصورة", variant="primary") |
| image_submit.click(fn=analyze_image, inputs=[image_input, image_question], outputs=image_output) |
|
|
| with gr.Tab("🎤 التفاعل الصوتي"): |
| gr.Markdown("سجّل رسالة صوتية، سيتم تحويلها إلى نص، ثم الرد عليها نصيًا وصوتيًا.") |
| audio_input = gr.Audio(sources=["microphone"], type="numpy", label="سجّل رسالتك الصوتية") |
| with gr.Row(): |
| audio_transcription = gr.Textbox(label="النص المستخرج من صوتك", interactive=False, lines=2) |
| audio_text_response = gr.Textbox(label="الرد النصي على رسالتك", interactive=False, lines=3) |
| audio_output_player = gr.Audio(label="الرد الصوتي من المساعد", type="numpy", interactive=False) |
| audio_submit = gr.Button("معالجة الصوت", variant="primary") |
| audio_submit.click(fn=process_audio, |
| inputs=audio_input, |
| outputs=[audio_transcription, audio_text_response, audio_output_player]) |
|
|
| with gr.Tab("📄 تحليل الملفات"): |
| gr.Markdown("ارفع ملف (PDF, Excel, CSV) وسأقوم بتلخيص محتواه أو الإجابة على أسئلتك حوله.") |
| file_input = gr.File(label="ارفع ملفك (PDF, Excel, CSV)", file_types=[".pdf", ".xls", ".xlsx", ".csv"]) |
| file_output = gr.Textbox(label="الرد على محتوى الملف", lines=5, interactive=False) |
| file_submit = gr.Button("تحليل الملف", variant="primary") |
| file_submit.click(fn=process_file, inputs=file_input, outputs=file_output) |
|
|
| if __name__ == "__main__": |
| print("Launching Gradio Demo (Lightweight Models with ViLT VQA)...") |
| demo.launch(share=True) |
|
|
|
|