161 lines
7.6 KiB
Python
161 lines
7.6 KiB
Python
from typing import Any, Dict, List, Literal, Optional
|
||
|
||
import torch
|
||
from config import settings
|
||
from langchain_chroma import Chroma
|
||
from langchain_huggingface import HuggingFaceEmbeddings
|
||
from loguru import logger
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||
|
||
|
||
class ChatWithAI:
|
||
def __init__(self, provider: "qwen3"):
|
||
self.provider = provider
|
||
self.embeddings = HuggingFaceEmbeddings(
|
||
model_name=settings.LM_MODEL_NAME,
|
||
model_kwargs={"device": "cuda"},
|
||
encode_kwargs={"normalize_embeddings": True},
|
||
)
|
||
|
||
if provider == "qwen3":
|
||
model_name = getattr(settings, "LOCAL_LLM_NAME", "/models/Qwen3-4B")
|
||
|
||
logger.info(f"Загрузка локальной модели: {model_name}")
|
||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||
model = AutoModelForCausalLM.from_pretrained(
|
||
model_name,
|
||
torch_dtype=torch.float32,
|
||
device_map="cuda",
|
||
)
|
||
|
||
#Создаём text-generation pipeline
|
||
self.llm = pipeline(
|
||
"text-generation",
|
||
model=model,
|
||
tokenizer=tokenizer,
|
||
device_map="cuda", # 0 = GPU, -1 = CPU
|
||
temperature=0.7,
|
||
max_new_tokens=512,
|
||
do_sample=True,
|
||
pad_token_id=tokenizer.eos_token_id,
|
||
)
|
||
else:
|
||
raise ValueError(f"Неподдерживаемый провайдер: {provider}")
|
||
|
||
self.chroma_db = Chroma(
|
||
persist_directory=settings.DOCS_CHROMA_PATH,
|
||
embedding_function=self.embeddings,
|
||
collection_name=settings.DOCS_COLLECTION_NAME,
|
||
)
|
||
|
||
def get_relevant_context(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
|
||
"""Получение релевантного контекста из базы данных."""
|
||
try:
|
||
results = self.chroma_db.similarity_search(query, k=k)
|
||
return [
|
||
{
|
||
"text": doc.page_content,
|
||
"metadata": doc.metadata,
|
||
}
|
||
for doc in results
|
||
]
|
||
except Exception as e:
|
||
logger.error(f"Ошибка при получении контекста: {e}")
|
||
return []
|
||
|
||
def format_context(self, context: List[Dict[str, Any]]) -> str:
|
||
"""Форматирование контекста для промпта."""
|
||
formatted_context = []
|
||
for item in context:
|
||
metadata_str = "\n".join(f"{k}: {v}" for k, v in item["metadata"].items())
|
||
formatted_context.append(
|
||
f"Текст: {item['text']}\nМетаданные:\n{metadata_str}\n"
|
||
)
|
||
return "\n---\n".join(formatted_context)
|
||
|
||
def generate_response(self, query: str) -> Optional[str]:
|
||
"""Генерация ответа на основе запроса и контекста."""
|
||
try:
|
||
context = self.get_relevant_context(query)
|
||
if not context:
|
||
return "Извините, не удалось найти релевантный контекст для ответа."
|
||
|
||
formatted_context = self.format_context(context)
|
||
|
||
messages = [
|
||
{
|
||
"role": "system",
|
||
"content": """Ты — внутренний менеджер компании Amvera Cloud. Отвечаешь по делу без лишних вступлений.
|
||
|
||
Правила:
|
||
1. Сразу переходи к сути, без фраз типа "На основе контекста"
|
||
2. Используй только факты. Если точных данных нет — отвечай общими фразами об Marzban, но не придумывай конкретику
|
||
3. Используй обычный текст без форматирования
|
||
4. Включай ссылки только если они есть в контексте
|
||
5. Говори от первого лица множественного числа: "Мы предоставляем", "У нас есть"
|
||
6. При упоминании файлов делай это естественно, например: "Я прикреплю инструкцию, где подробно описаны шаги"
|
||
7. На приветствия отвечай доброжелательно, на негатив — с легким юмором
|
||
8. Можешь при ответах использовать общую информацию из открытых источников по Marzban, но опирайся на контекст
|
||
9. Если пользователь спрашивает о ценах, планах или технических характеристиках — давай конкретные ответы из контекста
|
||
10. При технических вопросах предлагай практические решения
|
||
|
||
Персонализируй ответы, упоминая имя клиента если оно есть в контексте. Будь краток, информативен и полезен.""",
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": f"Вопрос: {query}\nКонтекст: {formatted_context}",
|
||
},
|
||
]
|
||
# Генерация через transformers
|
||
if self.provider == "qwen3":
|
||
# Используем токенизатор модели для форматирования чата
|
||
tokenizer = self.llm.tokenizer
|
||
model = self.llm.model
|
||
|
||
# Применяем chat template (поддерживается в современных моделях: Zephyr, Llama3, Qwen и т.д.)
|
||
prompt = tokenizer.apply_chat_template(
|
||
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
||
)
|
||
|
||
# Генерация
|
||
outputs = self.llm(
|
||
prompt,
|
||
max_new_tokens=512,
|
||
temperature=0.7,
|
||
do_sample=True,
|
||
top_p=0.9,
|
||
pad_token_id=tokenizer.eos_token_id,
|
||
)
|
||
response_text = outputs[0]["generated_text"]
|
||
|
||
# Убираем входной промпт, оставляем только ответ
|
||
if prompt in response_text:
|
||
response_text = response_text[len(prompt):].strip()
|
||
|
||
return response_text
|
||
|
||
else:
|
||
# Остальные провайдеры (deepseek, openai) используют langchain
|
||
response = self.llm.invoke(messages)
|
||
if hasattr(response, "content"):
|
||
return str(response.content)
|
||
return str(response).strip()
|
||
except Exception as e:
|
||
logger.error(f"Ошибка при генерации ответа: {e}")
|
||
return "Произошла ошибка при генерации ответа."
|
||
|
||
|
||
if __name__ == "__main__":
|
||
chat = ChatWithAI(provider="qwen3")
|
||
print("\n=== Чат с ИИ ===\n")
|
||
|
||
while True:
|
||
query = input("Вы: ")
|
||
if query.lower() == "выход":
|
||
print("\nДо свидания!")
|
||
break
|
||
|
||
print("\nИИ печатает...", end="\r")
|
||
response = chat.generate_response(query)
|
||
print(" " * 20, end="\r") # Очищаем "ИИ печатает..."
|
||
print(f"ИИ: {response}\n") |