Files
llm/app/test.py
2025-08-03 06:52:55 +00:00

161 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Any, Dict, List, Literal, Optional
import torch
from config import settings
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
class ChatWithAI:
def __init__(self, provider: "qwen3"):
self.provider = provider
self.embeddings = HuggingFaceEmbeddings(
model_name=settings.LM_MODEL_NAME,
model_kwargs={"device": "cuda"},
encode_kwargs={"normalize_embeddings": True},
)
if provider == "qwen3":
model_name = getattr(settings, "LOCAL_LLM_NAME", "/models/Qwen3-4B")
logger.info(f"Загрузка локальной модели: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32,
device_map="cuda",
)
#Создаём text-generation pipeline
self.llm = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device_map="cuda", # 0 = GPU, -1 = CPU
temperature=0.7,
max_new_tokens=512,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
else:
raise ValueError(f"Неподдерживаемый провайдер: {provider}")
self.chroma_db = Chroma(
persist_directory=settings.DOCS_CHROMA_PATH,
embedding_function=self.embeddings,
collection_name=settings.DOCS_COLLECTION_NAME,
)
def get_relevant_context(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
"""Получение релевантного контекста из базы данных."""
try:
results = self.chroma_db.similarity_search(query, k=k)
return [
{
"text": doc.page_content,
"metadata": doc.metadata,
}
for doc in results
]
except Exception as e:
logger.error(f"Ошибка при получении контекста: {e}")
return []
def format_context(self, context: List[Dict[str, Any]]) -> str:
"""Форматирование контекста для промпта."""
formatted_context = []
for item in context:
metadata_str = "\n".join(f"{k}: {v}" for k, v in item["metadata"].items())
formatted_context.append(
f"Текст: {item['text']}\nМетаданные:\n{metadata_str}\n"
)
return "\n---\n".join(formatted_context)
def generate_response(self, query: str) -> Optional[str]:
"""Генерация ответа на основе запроса и контекста."""
try:
context = self.get_relevant_context(query)
if not context:
return "Извините, не удалось найти релевантный контекст для ответа."
formatted_context = self.format_context(context)
messages = [
{
"role": "system",
"content": """Ты — внутренний менеджер компании Amvera Cloud. Отвечаешь по делу без лишних вступлений.
Правила:
1. Сразу переходи к сути, без фраз типа "На основе контекста"
2. Используй только факты. Если точных данных нет — отвечай общими фразами об Marzban, но не придумывай конкретику
3. Используй обычный текст без форматирования
4. Включай ссылки только если они есть в контексте
5. Говори от первого лица множественного числа: "Мы предоставляем", "У нас есть"
6. При упоминании файлов делай это естественно, например: "Я прикреплю инструкцию, где подробно описаны шаги"
7. На приветствия отвечай доброжелательно, на негатив — с легким юмором
8. Можешь при ответах использовать общую информацию из открытых источников по Marzban, но опирайся на контекст
9. Если пользователь спрашивает о ценах, планах или технических характеристиках — давай конкретные ответы из контекста
10. При технических вопросах предлагай практические решения
Персонализируй ответы, упоминая имя клиента если оно есть в контексте. Будь краток, информативен и полезен.""",
},
{
"role": "user",
"content": f"Вопрос: {query}\nКонтекст: {formatted_context}",
},
]
# Генерация через transformers
if self.provider == "qwen3":
# Используем токенизатор модели для форматирования чата
tokenizer = self.llm.tokenizer
model = self.llm.model
# Применяем chat template (поддерживается в современных моделях: Zephyr, Llama3, Qwen и т.д.)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
# Генерация
outputs = self.llm(
prompt,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
)
response_text = outputs[0]["generated_text"]
# Убираем входной промпт, оставляем только ответ
if prompt in response_text:
response_text = response_text[len(prompt):].strip()
return response_text
else:
# Остальные провайдеры (deepseek, openai) используют langchain
response = self.llm.invoke(messages)
if hasattr(response, "content"):
return str(response.content)
return str(response).strip()
except Exception as e:
logger.error(f"Ошибка при генерации ответа: {e}")
return "Произошла ошибка при генерации ответа."
if __name__ == "__main__":
chat = ChatWithAI(provider="qwen3")
print("\n=== Чат с ИИ ===\n")
while True:
query = input("Вы: ")
if query.lower() == "выход":
print("\nДо свидания!")
break
print("\nИИ печатает...", end="\r")
response = chat.generate_response(query)
print(" " * 20, end="\r") # Очищаем "ИИ печатает..."
print(f"ИИ: {response}\n")