feat chat robot 接口迁移
This commit is contained in:
7
app/service/chat_robot/script/agents/__init__.py
Normal file
7
app/service/chat_robot/script/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .agent_executor import CustomAgentExecutor
|
||||
from .conversational_functions_agent import ConversationalFunctionsAgent
|
||||
|
||||
__all__ = [
|
||||
"CustomAgentExecutor",
|
||||
"ConversationalFunctionsAgent"
|
||||
]
|
||||
132
app/service/chat_robot/script/agents/agent_executor.py
Normal file
132
app/service/chat_robot/script/agents/agent_executor.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.callbacks.manager import Callbacks, CallbackManager
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import RUN_KEY, RunInfo
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
session_key: str = "",
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs, session_key)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
logging.exception(e)
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs, session_key
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
session_key: str = ""
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None and outputs['need_record']:
|
||||
self.memory.save_context(inputs, outputs, session_key)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]:
|
||||
"""Validate and prep inputs."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs, session_key)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def _get_tool_return(
|
||||
self, next_step_output: Tuple[AgentAction, str]
|
||||
) -> Optional[AgentFinish]:
|
||||
"""Check if the tool is a returning tool."""
|
||||
agent_action, observation = next_step_output
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
return_value_key = "output"
|
||||
|
||||
if len(self.agent.return_values) > 0:
|
||||
return_value_key = self.agent.return_values[0]
|
||||
|
||||
try:
|
||||
observation_list = json.loads(observation)
|
||||
if agent_action.tool == "sql_db_query" and isinstance(observation_list,
|
||||
list) and observation_list.__len__() != 0:
|
||||
return AgentFinish(
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Invalid tools won't be in the map, so we return False.
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
if name_to_tool_map[agent_action.tool].return_direct:
|
||||
return AgentFinish(
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
return None
|
||||
@@ -0,0 +1,198 @@
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Tuple, Any, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.agents import (
|
||||
OpenAIFunctionsAgent,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
OutputParserException
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage
|
||||
)
|
||||
from langchain.tools import BaseTool, StructuredTool
|
||||
# from langchain.tools.convert_to_openai import FunctionDescription
|
||||
from langchain.utils.openai_functions import FunctionDescription
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FunctionsAgentAction(AgentAction):
|
||||
"""Add message_log to AgentAction class for the _FunctionAgentAction
|
||||
"""
|
||||
message_log: List[BaseMessage]
|
||||
|
||||
def __init__(
|
||||
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
|
||||
):
|
||||
"""Override init to support instantiation by position for backward compat."""
|
||||
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
|
||||
|
||||
|
||||
def _convert_agent_action_to_messages(
|
||||
agent_action: AgentAction, observation: str
|
||||
) -> List[BaseMessage]:
|
||||
"""Convert an agents action to a message.
|
||||
|
||||
This code is used to reconstruct the original AI message from the agents action.
|
||||
|
||||
Args:
|
||||
agent_action: Agent action to convert.
|
||||
|
||||
Returns:
|
||||
AIMessage that corresponds to the original tools invocation.
|
||||
"""
|
||||
if isinstance(agent_action, _FunctionsAgentAction):
|
||||
return agent_action.message_log + [
|
||||
_create_function_message(agent_action, observation)
|
||||
]
|
||||
else:
|
||||
return [AIMessage(content=agent_action.log)]
|
||||
|
||||
|
||||
def _create_function_message(
|
||||
agent_action: AgentAction, observation: str
|
||||
) -> FunctionMessage:
|
||||
"""Convert agents action and observation into a function message.
|
||||
Args:
|
||||
agent_action: the tools invocation request from the agents
|
||||
observation: the result of the tools invocation
|
||||
Returns:
|
||||
FunctionMessage that corresponds to the original tools invocation
|
||||
"""
|
||||
if not isinstance(observation, str):
|
||||
try:
|
||||
content = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception:
|
||||
content = str(observation)
|
||||
else:
|
||||
content = observation
|
||||
return FunctionMessage(
|
||||
name=agent_action.tool,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _format_intermediate_steps(
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
) -> List[BaseMessage]:
|
||||
"""Format intermediate steps.
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
Returns:
|
||||
list of messages to send to the LLM for the next prediction
|
||||
"""
|
||||
messages = []
|
||||
|
||||
for intermediate_step in intermediate_steps:
|
||||
agent_action, observation = intermediate_step
|
||||
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tools into the OpenAI function API."""
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": tool.param_description if hasattr(tool, 'param_description') else "",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
|
||||
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
|
||||
if not isinstance(message, AIMessage):
|
||||
raise TypeError(f"Expected an AI message but got {type(message)}")
|
||||
|
||||
function_call = message.additional_kwargs.get("function_call", {})
|
||||
|
||||
if function_call:
|
||||
function_call = message.additional_kwargs["function_call"]
|
||||
function_name = function_call["name"]
|
||||
try:
|
||||
_tool_input = json.loads(function_call["arguments"])
|
||||
except JSONDecodeError:
|
||||
raise OutputParserException(
|
||||
f"Could not parse tools input: {function_call} because"
|
||||
f"the `arguments` is not valid JSON."
|
||||
)
|
||||
|
||||
if "query" in _tool_input:
|
||||
tool_input = _tool_input["query"]
|
||||
else:
|
||||
tool_input = _tool_input
|
||||
|
||||
return _FunctionsAgentAction(
|
||||
tool=function_name,
|
||||
tool_input=tool_input,
|
||||
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n",
|
||||
message_log=[message]
|
||||
)
|
||||
|
||||
# pattern = r'\((.*?)\)'
|
||||
# matches = re.findall(pattern, message.content)
|
||||
# result = []
|
||||
#
|
||||
# for match in matches:
|
||||
# result.append(match)
|
||||
#
|
||||
# if result:
|
||||
# output = result
|
||||
# else:
|
||||
# output = message.content
|
||||
|
||||
return AgentFinish(return_values={"output": message.content}, log=message.content)
|
||||
|
||||
|
||||
class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
|
||||
@property
|
||||
def functions(self) -> List[dict]:
|
||||
return [dict(_format_tool_to_openai_function(t)) for t in self.tools]
|
||||
|
||||
def plan(self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Decide how agents should move after receiving an input. The difference between
|
||||
OpenAIFunctionsAgent lies in the '_parse_ai_message' function. We add an OutputParser
|
||||
into it.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
**kwargs: Including user's input string
|
||||
|
||||
Returns:
|
||||
Action specifying what tools to use.
|
||||
"""
|
||||
agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages: List[BaseMessage] = prompt.to_messages()
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
Reference in New Issue
Block a user