77 lines
2.9 KiB
Python
77 lines
2.9 KiB
Python
import gradio as gr
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||
import torch
|
||
import threading
|
||
|
||
# Пути к модели
|
||
model_name = "/models/Qwen3-8B"
|
||
|
||
# Загрузка токенизатора и модели (один раз при старте)
|
||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||
model = AutoModelForCausalLM.from_pretrained(
|
||
model_name,
|
||
torch_dtype=torch.float16,
|
||
device_map="cuda",
|
||
trust_remote_code=True
|
||
)
|
||
|
||
# Отключим кэширование в истории, чтобы каждый запрос был независимым
|
||
def generate_response(message, history):
|
||
# Форматируем диалог: используем только текущую историю
|
||
prompt = ""
|
||
for human, assistant in history:
|
||
prompt += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n{assistant}<|im_end|>\n"
|
||
prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
|
||
|
||
# Токенизация
|
||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||
|
||
# Создаём уникальный streamer для каждого запроса
|
||
streamer = TextIteratorStreamer(
|
||
tokenizer,
|
||
skip_prompt=True,
|
||
skip_special_tokens=True
|
||
)
|
||
|
||
# Параметры генерации
|
||
generation_kwargs = {
|
||
"input_ids": inputs["input_ids"],
|
||
"max_new_tokens": 1024,
|
||
"temperature": 0.6,
|
||
"top_p": 0.9,
|
||
"do_sample": True,
|
||
"pad_token_id": tokenizer.eos_token_id,
|
||
"streamer": streamer,
|
||
}
|
||
|
||
# Запускаем генерацию в отдельном потоке
|
||
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
|
||
thread.start()
|
||
|
||
# Постепенно возвращаем результат
|
||
buffer = ""
|
||
for new_text in streamer:
|
||
buffer += new_text
|
||
yield buffer.strip()
|
||
|
||
# Создаем интерфейс
|
||
demo = gr.ChatInterface(
|
||
fn=generate_response,
|
||
title="Qwen3-4B-Base Chat",
|
||
description="Общайтесь с моделью Qwen3-4B-Base в режиме реального времени с потоковой генерацией",
|
||
examples=[
|
||
"Объясни, как работает квантование AWQ?",
|
||
"Напиши стихотворение про ИИ",
|
||
"Какие преимущества у Qwen3 перед предыдущими версиями?"
|
||
],
|
||
theme="soft",
|
||
)
|
||
|
||
if __name__ == "__main__":
|
||
# ВАЖНО: используем .queue() для поддержки асинхронной обработки
|
||
demo.queue(max_size=20, default_concurrency_limit=10).launch(
|
||
server_port=8080,
|
||
server_name="0.0.0.0",
|
||
share=False,
|
||
# Можно добавить: max_batch_size=1, concurrency_count=4
|
||
) |