openai 替换为 通义千问

This commit is contained in:
2024-07-08 18:50:01 +08:00
parent d772adcd7a
commit 8ad3e8ac0f
8 changed files with 412 additions and 89 deletions

View File

@@ -1,6 +1,6 @@
# flake8: noqa
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Extra, Field, root_validator
@@ -12,9 +12,11 @@ from langchain.callbacks.manager import (
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
# from langchain.sql_database import SQLDatabase
from langchain.utilities import SQLDatabase
from langchain_community.utilities import SQLDatabase
from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
class BaseSQLDatabaseTool(BaseModel):
@@ -135,32 +137,70 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
raise NotImplementedError("ListTablesSqlDbTool does not support async")
# class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
# """Use an LLM to check if a query is correct.
# Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
#
# template: str = QUERY_CHECKER
# llm: BaseLanguageModel
# llm_chain: LLMChain = Field(init=False)
# name = "sql_db_query_checker"
# description = (
# "Use this tools to double check if your query is correct before executing it."
# "Always use this tools before executing a query with sql_db_query!"
# )
#
# @root_validator(pre=True)
# def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
# if "llm_chain" not in values:
# values["llm_chain"] = LLMChain(
# llm=values.get("llm"),
# prompt=PromptTemplate(
# template=QUERY_CHECKER,
# input_variables=["query", "dialect"]
# ),
# )
#
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
# # if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
# raise ValueError(
# "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
# )
#
# return values
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False)
name = "sql_db_query_checker"
description = (
"Use this tools to double check if your query is correct before executing it."
"Always use this tools before executing a query with sql_db_query!"
)
llm_chain: Any = Field(init=False)
name: str = "sql_db_query_checker"
description: str = """
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with sql_db_query!
"""
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER,
input_variables=["query", "dialect"]
),
)
# from langchain.chains.llm import LLMChain
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
llm = values.get("llm") # type: ignore[arg-type]
prompt = PromptTemplate(
template=QUERY_CHECKER, input_variables=["dialect", "query"]
)
values["llm_chain"] = prompt | llm
# values["llm_chain"] = LLMChain(
# llm=values.get("llm"), # type: ignore[arg-type]
# prompt=PromptTemplate(
# template=QUERY_CHECKER, input_variables=["dialect", "query"]
# ),
# )
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
if values["llm_chain"].first.input_variables != ["dialect", "query"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
)