176 lines
8.2 KiB
Python
176 lines
8.2 KiB
Python
from typing import Any, Dict, List, Optional
|
||
import gradio as gr
|
||
import torch
|
||
import threading
|
||
from loguru import logger
|
||
from langchain_chroma import Chroma
|
||
from langchain_huggingface import HuggingFaceEmbeddings
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||
|
||
# Предполагается, что у тебя есть config.py с settings
|
||
from config import settings
|
||
|
||
|
||
class ChatWithAI:
|
||
def __init__(self, provider: str = "qwen3"):
|
||
self.provider = provider
|
||
self.embeddings = HuggingFaceEmbeddings(
|
||
model_name=settings.LM_MODEL_NAME,
|
||
model_kwargs={"device": "cuda"},
|
||
encode_kwargs={"normalize_embeddings": True},
|
||
)
|
||
|
||
if provider == "qwen3":
|
||
model_name = getattr(settings, "LOCAL_LLM_NAME", "/models/Qwen3-4B")
|
||
logger.info(f"Загрузка локальной модели: {model_name}")
|
||
|
||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||
self.model = AutoModelForCausalLM.from_pretrained(
|
||
model_name,
|
||
torch_dtype=torch.float16,
|
||
device_map="cuda",
|
||
)
|
||
|
||
# Streamer для потоковой генерации
|
||
self.streamer = TextIteratorStreamer(
|
||
self.tokenizer,
|
||
skip_prompt=True,
|
||
skip_special_tokens=True
|
||
)
|
||
else:
|
||
raise ValueError(f"Неподдерживаемый провайдер: {provider}")
|
||
|
||
self.chroma_db = Chroma(
|
||
persist_directory=settings.DOCS_CHROMA_PATH,
|
||
embedding_function=self.embeddings,
|
||
collection_name=settings.DOCS_COLLECTION_NAME,
|
||
)
|
||
|
||
def get_relevant_context(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
|
||
"""Получение релевантного контекста из базы данных."""
|
||
try:
|
||
results = self.chroma_db.similarity_search(query, k=k)
|
||
return [
|
||
{
|
||
"text": doc.page_content,
|
||
"metadata": doc.metadata,
|
||
}
|
||
for doc in results
|
||
]
|
||
except Exception as e:
|
||
logger.error(f"Ошибка при получении контекста: {e}")
|
||
return []
|
||
|
||
def format_context(self, context: List[Dict[str, Any]]) -> str:
|
||
"""Форматирование контекста для промпта."""
|
||
formatted_context = []
|
||
for item in context:
|
||
metadata_str = "\n".join(f"{k}: {v}" for k, v in item["metadata"].items())
|
||
formatted_context.append(
|
||
f"Текст: {item['text']}\nМетаданные:\n{metadata_str}\n"
|
||
)
|
||
return "\n---\n".join(formatted_context)
|
||
|
||
def generate_response_stream(self, query: str):
|
||
"""Генерация ответа с потоковой передачей токенов."""
|
||
try:
|
||
logger.info(f"Пользовательский запрос: {query}")
|
||
context = self.get_relevant_context(query)
|
||
if not context:
|
||
yield "Извините, не удалось найти релевантный контекст для ответа."
|
||
return
|
||
|
||
formatted_context = self.format_context(context)
|
||
|
||
messages = [
|
||
{
|
||
"role": "system",
|
||
"content": """Ты — внутренний менеджер помощи пользоватклям по вопросам настройки серверов. Отвечаешь по делу без лишних вступлений.
|
||
|
||
Правила:
|
||
1. Сразу переходи к сути, без фраз типа "На основе контекста"
|
||
2. Используй только факты. Если точных данных нет — отвечай общими фразами об настройки серверов и подбору железа, но не придумывай конкретику
|
||
3. Используй обычный текст без форматирования
|
||
4. Включай ссылки только если они есть в контексте
|
||
5. Говори от первого лица множественного числа: "Мы предоставляем", "У нас есть"
|
||
6. При упоминании файлов делай это естественно, например: "Я прикреплю инструкцию, где подробно описаны шаги"
|
||
7. На приветствия отвечай доброжелательно, на негатив — с легким юмором
|
||
8. Можешь при ответах использовать общую информацию из открытых источников по настройке сервером и подбору железа, но опирайся на контекст
|
||
9. Если пользователь спрашивает о ценах, планах или технических характеристиках — давай конкретные ответы из контекста
|
||
10. При технических вопросах предлагай практические решения
|
||
|
||
Персонализируй ответы, упоминая имя клиента если оно есть в контексте. Будь краток, информативен и полезен.""",
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": f"Вопрос: {query}\nКонтекст: {formatted_context}",
|
||
},
|
||
]
|
||
|
||
# Применяем шаблон чата
|
||
prompt = self.tokenizer.apply_chat_template(
|
||
messages,
|
||
tokenize=False,
|
||
add_generation_prompt=True,
|
||
enable_thinking=False
|
||
)
|
||
|
||
# Подготавливаем вход
|
||
inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda")
|
||
|
||
# Очищаем streamer и запускаем генерацию в отдельном потоке
|
||
self.streamer = TextIteratorStreamer(
|
||
self.tokenizer,
|
||
skip_prompt=True,
|
||
skip_special_tokens=True
|
||
)
|
||
|
||
generate_kwargs = {
|
||
"input_ids": inputs["input_ids"],
|
||
"max_new_tokens": 512,
|
||
"temperature": 0.5,
|
||
"do_sample": True,
|
||
"top_p": 0.9,
|
||
"pad_token_id": self.tokenizer.eos_token_id,
|
||
"streamer": self.streamer,
|
||
}
|
||
|
||
thread = threading.Thread(target=self.model.generate, kwargs=generate_kwargs)
|
||
thread.start()
|
||
|
||
# Потоковая передача токенов
|
||
buffer = ""
|
||
for token in self.streamer:
|
||
buffer += token
|
||
yield buffer # Отправляем частичный ответ
|
||
|
||
except Exception as e:
|
||
logger.error(f"Ошибка при генерации ответа: {e}")
|
||
yield "Произошла ошибка при генерации ответа."
|
||
|
||
|
||
# === Gradio интерфейс ===
|
||
def main():
|
||
chat = ChatWithAI(provider="qwen3")
|
||
|
||
def respond(message, history):
|
||
# Генерируем ответ по частям
|
||
for token in chat.generate_response_stream(message):
|
||
yield token
|
||
|
||
demo = gr.ChatInterface(
|
||
fn=respond,
|
||
title="Помощник настройки сервера и подбора железа",
|
||
description="Задайте вопрос — получите ответ от внутреннего менеджера.",
|
||
examples=[
|
||
"Как определить цели и требования для домашнего сервера?",
|
||
"Как выбрать ОС для домашнего сервера?",
|
||
"Проверка совместимости и выбор процессора",
|
||
"Установка Unraid и первоначальная настройка"
|
||
],
|
||
)
|
||
demo.launch(server_name="0.0.0.0", server_port=8080, share=False)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |