add base_llm.py

This commit is contained in:
2025-08-03 18:09:27 +07:00
parent cd2c2d235d
commit 5c585544d3

72
app/base_llm.py Normal file
View File

@@ -0,0 +1,72 @@
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import threading
# Пути к модели
model_name = "/models/Qwen3-8B"
# Загрузка токенизатора и модели (один раз при старте)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="cuda",
trust_remote_code=True
)
# Отключим кэширование в истории, чтобы каждый запрос был независимым
def generate_response(message, history):
# Форматируем диалог: используем только текущую историю
prompt = ""
for human, assistant in history:
prompt += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n{assistant}<|im_end|>\n"
prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
# Токенизация
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Создаём уникальный streamer для каждого запроса
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
# Параметры генерации
generation_kwargs = {
"input_ids": inputs["input_ids"],
"max_new_tokens": 1024,
"temperature": 0.6,
"top_p": 0.9,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"streamer": streamer,
}
# Запускаем генерацию в отдельном потоке
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# Постепенно возвращаем результат
buffer = ""
for new_text in streamer:
buffer += new_text
yield buffer.strip()
# Создаем интерфейс
demo = gr.ChatInterface(
fn=generate_response,
title="Qwen3 Chat",
description="Общайтесь с моделью Qwen3 в режиме реального времени с потоковой генерацией",
theme="soft",
)
if __name__ == "__main__":
# ВАЖНО: используем .queue() для поддержки асинхронной обработки
demo.queue(max_size=20, default_concurrency_limit=10).launch(
server_port=8080,
server_name="0.0.0.0",
share=False,
# Можно добавить: max_batch_size=1, concurrency_count=4
)