add base_llm.py
This commit is contained in:
72
app/base_llm.py
Normal file
72
app/base_llm.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import gradio as gr
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
import torch
|
||||
import threading
|
||||
|
||||
# Пути к модели
|
||||
model_name = "/models/Qwen3-8B"
|
||||
|
||||
# Загрузка токенизатора и модели (один раз при старте)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Отключим кэширование в истории, чтобы каждый запрос был независимым
|
||||
def generate_response(message, history):
|
||||
# Форматируем диалог: используем только текущую историю
|
||||
prompt = ""
|
||||
for human, assistant in history:
|
||||
prompt += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n{assistant}<|im_end|>\n"
|
||||
prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
# Токенизация
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Создаём уникальный streamer для каждого запроса
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer,
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Параметры генерации
|
||||
generation_kwargs = {
|
||||
"input_ids": inputs["input_ids"],
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0.6,
|
||||
"top_p": 0.9,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"streamer": streamer,
|
||||
}
|
||||
|
||||
# Запускаем генерацию в отдельном потоке
|
||||
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
# Постепенно возвращаем результат
|
||||
buffer = ""
|
||||
for new_text in streamer:
|
||||
buffer += new_text
|
||||
yield buffer.strip()
|
||||
|
||||
# Создаем интерфейс
|
||||
demo = gr.ChatInterface(
|
||||
fn=generate_response,
|
||||
title="Qwen3 Chat",
|
||||
description="Общайтесь с моделью Qwen3 в режиме реального времени с потоковой генерацией",
|
||||
theme="soft",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ВАЖНО: используем .queue() для поддержки асинхронной обработки
|
||||
demo.queue(max_size=20, default_concurrency_limit=10).launch(
|
||||
server_port=8080,
|
||||
server_name="0.0.0.0",
|
||||
share=False,
|
||||
# Можно добавить: max_batch_size=1, concurrency_count=4
|
||||
)
|
||||
Reference in New Issue
Block a user