feat chat robot 接口迁移

This commit is contained in:
zhouchengrong
2024-05-29 11:12:59 +08:00
parent a9dcd444c8
commit 13fec64125
23 changed files with 1139 additions and 1 deletions

View File

@@ -0,0 +1,3 @@
from .user_buffer_window import UserConversationBufferWindowMemory
__all__ = ['UserConversationBufferWindowMemory']

View File

@@ -0,0 +1,93 @@
import logging
from typing import Any, Dict, List, Tuple
import json
import redis
from redis import Redis
from langchain.memory import RedisChatMessageHistory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
from langchain.schema.messages import _message_to_dict, messages_from_dict
from langchain.memory.utils import get_prompt_input_key
from app.core.config import *
class UserConversationBufferWindowMemory(BaseChatMemory):
"""Buffer for storing conversation memory."""
redis_client: Redis
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
k: int = 5
@classmethod
def from_redis(
cls,
host: str = REDIS_HOST,
port: int = REDIS_PORT,
db: int = 3,
**kwargs
):
redis_client = Redis(host=host, port=port, db=db)
try:
response = redis_client.ping()
if response:
print("Connect to redis server successfully.")
logging.info("Connect to redis server successfully.")
else:
print("Fail to connect to redis server")
logging.info("Fail to connect to redis server")
except redis.RedisError as e:
logging.info(f"Error occurs when connecting to redis server: {str(e)}")
return cls(redis_client=redis_client, **kwargs)
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]:
"""Return history buffer."""
_items: Any = self.redis_client.lrange(key, 0, self.k * 2) if self.k > 0 else []
items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
buffer = messages_from_dict(items)
if not self.return_messages:
buffer = get_buffer_string(
buffer,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
return {self.memory_key: buffer}
def _get_input_output(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> Tuple[str, str]:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
return inputs[prompt_input_key], outputs[output_key]
def add_message(self, key: str, message: BaseMessage) -> None:
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.add_message(key, HumanMessage(content=input_str))
self.add_message(key, AIMessage(content=output_str))
# def clear(self, key) -> None:
# """Clear memory contents."""
# self.redis_client.delete(key)