feat: chat robot interface migration

This commit is contained in:
zhouchengrong
2024-05-29 11:12:59 +08:00
parent a9dcd444c8
commit 13fec64125
23 changed files with 1139 additions and 1 deletion

View File

@@ -0,0 +1,7 @@
from .agent_executor import CustomAgentExecutor
from .conversational_functions_agent import ConversationalFunctionsAgent
__all__ = [
"CustomAgentExecutor",
"ConversationalFunctionsAgent"
]
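
For reference, a minimal sketch of how these exports are meant to be consumed downstream; the package import path `agents` is an assumption, since the file paths are not shown in this view:

# Hypothetical consumer of the package above; the import path is assumed.
from agents import ConversationalFunctionsAgent, CustomAgentExecutor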

View File

@@ -0,0 +1,132 @@
import inspect
import json
import logging
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain.agents import AgentExecutor
from langchain.callbacks.manager import Callbacks, CallbackManager
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, RunInfo
from langchain_core.agents import AgentAction, AgentFinish
class CustomAgentExecutor(AgentExecutor):
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
session_key: str = "",
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs, session_key)
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
logging.exception(e)
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs, session_key
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
session_key: str = ""
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None and outputs['need_record']:
self.memory.save_context(inputs, outputs, session_key)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs, session_key)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str]
) -> Optional[AgentFinish]:
"""Check if the tool is a returning tool."""
agent_action, observation = next_step_output
name_to_tool_map = {tool.name: tool for tool in self.tools}
return_value_key = "output"
if len(self.agent.return_values) > 0:
return_value_key = self.agent.return_values[0]
try:
# Short-circuit: when sql_db_query returns a non-empty JSON list, hand the
# raw observation back to the caller instead of another LLM round-trip.
observation_list = json.loads(observation)
if (agent_action.tool == "sql_db_query"
and isinstance(observation_list, list)
and len(observation_list) != 0):
return AgentFinish(
{return_value_key: observation},
"",
)
except (json.JSONDecodeError, TypeError):
pass
# Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map:
if name_to_tool_map[agent_action.tool].return_direct:
return AgentFinish(
{return_value_key: observation},
"",
)
return None
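
A minimal usage sketch for the executor above, not part of this commit: it assumes an OpenAI key in the environment, and the model name, package path, and tool are all placeholders.

# A sketch only, not part of this commit: wires a toy tool into the new agent
# and executor; model name, package path, and tool are assumptions.
from agents import ConversationalFunctionsAgent, CustomAgentExecutor  # path assumed
from langchain.chat_models import ChatOpenAI
from langchain.tools import StructuredTool

def lookup_orders(query: str) -> str:
    """Toy tool body used only for this sketch."""
    return "[]"

tools = [StructuredTool.from_function(
    lookup_orders, name="order_lookup", description="Look up order records.")]
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent = ConversationalFunctionsAgent.from_llm_and_tools(llm=llm, tools=tools)
executor = CustomAgentExecutor(agent=agent, tools=tools)  # no memory attached here
# session_key scopes memory loads/saves to a single conversation session.
result = executor({"input": "How many orders were placed today?"},
                  session_key="session-42")
print(result["output"])

If a memory object is attached, its load_memory_variables and save_context methods must accept the extra session_key argument this executor passes through.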

View File

@@ -0,0 +1,198 @@
import json
from json import JSONDecodeError
from typing import List, Tuple, Any, Union
from dataclasses import dataclass
from langchain.callbacks.manager import Callbacks
from langchain.agents import (
OpenAIFunctionsAgent,
)
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
OutputParserException
)
from langchain.schema.messages import (
AIMessage,
FunctionMessage
)
from langchain.tools import BaseTool
# from langchain.tools.convert_to_openai import FunctionDescription
from langchain.utils.openai_functions import FunctionDescription
@dataclass
class _FunctionsAgentAction(AgentAction):
"""Add message_log to AgentAction class for the _FunctionAgentAction
"""
message_log: List[BaseMessage]
def __init__(
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
):
"""Override init to support instantiation by position for backward compat."""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agents action to a message.
This code is used to reconstruct the original AI message from the agents action.
Args:
agent_action: Agent action to convert.
Returns:
AIMessage that corresponds to the original tools invocation.
"""
if isinstance(agent_action, _FunctionsAgentAction):
return agent_action.message_log + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
"""Convert agents action and observation into a function message.
Args:
agent_action: the tools invocation request from the agents
observation: the result of the tools invocation
Returns:
FunctionMessage that corresponds to the original tools invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
def _format_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format a tool into the OpenAI function API, exposing a single string
parameter named "query" (described by the tool's optional param_description)."""
parameters = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": tool.param_description if hasattr(tool, 'param_description') else "",
},
},
"required": ["query"],
}
return {
"name": tool.name,
"description": tool.description,
"parameters": parameters,
}
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message but got {type(message)}")
function_call = message.additional_kwargs.get("function_call", {})
if function_call:
function_name = function_call["name"]
try:
_tool_input = json.loads(function_call["arguments"])
except JSONDecodeError:
raise OutputParserException(
f"Could not parse tools input: {function_call} because"
f"the `arguments` is not valid JSON."
)
if "query" in _tool_input:
tool_input = _tool_input["query"]
else:
tool_input = _tool_input
return _FunctionsAgentAction(
tool=function_name,
tool_input=tool_input,
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n",
message_log=[message]
)
return AgentFinish(return_values={"output": message.content}, log=message.content)
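# Illustrative behavior of _parse_ai_message above (a sketch; the values are
# hypothetical): an AIMessage whose additional_kwargs["function_call"] is
#   {"name": "sql_db_query", "arguments": "{\"query\": \"SELECT count(*) FROM orders\"}"}
# parses to _FunctionsAgentAction(tool="sql_db_query",
#   tool_input="SELECT count(*) FROM orders", ...), while a plain
# AIMessage(content="Hello") parses to
# AgentFinish(return_values={"output": "Hello"}, log="Hello").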
class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
@property
def functions(self) -> List[dict]:
return [dict(_format_tool_to_openai_function(t)) for t in self.tools]
def plan(self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""Decide how agents should move after receiving an input. The difference between
OpenAIFunctionsAgent lies in the '_parse_ai_message' function. We add an OutputParser
into it.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
**kwargs: Including user's input string
Returns:
Action specifying what tools to use.
"""
agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages: List[BaseMessage] = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
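
For illustration, a hedged sketch (not part of this commit) of the schema that _format_tool_to_openai_function emits; the tool class below is hypothetical and exists only to show the single-"query" parameter convention and the optional param_description hook:

# Hypothetical tool; every tool is exposed to OpenAI as a one-parameter
# {"query": <string>} function, described by the optional param_description.
from langchain.tools import BaseTool

class OrderLookupTool(BaseTool):
    name: str = "order_lookup"
    description: str = "Look up order records."
    param_description: str = "Natural-language description of the orders wanted."

    def _run(self, query: str) -> str:
        return "[]"

print(_format_tool_to_openai_function(OrderLookupTool()))
# {"name": "order_lookup", "description": "Look up order records.",
#  "parameters": {"type": "object",
#                 "properties": {"query": {"type": "string", "description":
#                                          "Natural-language description of the orders wanted."}},
#                 "required": ["query"]}}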