openai 替换为 通义千问
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# flake8: noqa
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
@@ -12,9 +12,11 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
|
||||
@@ -135,32 +137,70 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
raise NotImplementedError("ListTablesSqlDbTool does not support async")
|
||||
|
||||
|
||||
# class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
# """Use an LLM to check if a query is correct.
|
||||
# Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
#
|
||||
# template: str = QUERY_CHECKER
|
||||
# llm: BaseLanguageModel
|
||||
# llm_chain: LLMChain = Field(init=False)
|
||||
# name = "sql_db_query_checker"
|
||||
# description = (
|
||||
# "Use this tools to double check if your query is correct before executing it."
|
||||
# "Always use this tools before executing a query with sql_db_query!"
|
||||
# )
|
||||
#
|
||||
# @root_validator(pre=True)
|
||||
# def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# if "llm_chain" not in values:
|
||||
# values["llm_chain"] = LLMChain(
|
||||
# llm=values.get("llm"),
|
||||
# prompt=PromptTemplate(
|
||||
# template=QUERY_CHECKER,
|
||||
# input_variables=["query", "dialect"]
|
||||
# ),
|
||||
# )
|
||||
#
|
||||
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
# # if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
||||
# raise ValueError(
|
||||
# "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
||||
# )
|
||||
#
|
||||
# return values
|
||||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Use an LLM to check if a query is correct.
|
||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
|
||||
template: str = QUERY_CHECKER
|
||||
llm: BaseLanguageModel
|
||||
llm_chain: LLMChain = Field(init=False)
|
||||
name = "sql_db_query_checker"
|
||||
description = (
|
||||
"Use this tools to double check if your query is correct before executing it."
|
||||
"Always use this tools before executing a query with sql_db_query!"
|
||||
)
|
||||
llm_chain: Any = Field(init=False)
|
||||
name: str = "sql_db_query_checker"
|
||||
description: str = """
|
||||
Use this tool to double check if your query is correct before executing it.
|
||||
Always use this tool before executing a query with sql_db_query!
|
||||
"""
|
||||
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
values["llm_chain"] = LLMChain(
|
||||
llm=values.get("llm"),
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER,
|
||||
input_variables=["query", "dialect"]
|
||||
),
|
||||
)
|
||||
# from langchain.chains.llm import LLMChain
|
||||
|
||||
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
||||
llm = values.get("llm") # type: ignore[arg-type]
|
||||
prompt = PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||
)
|
||||
values["llm_chain"] = prompt | llm
|
||||
# values["llm_chain"] = LLMChain(
|
||||
# llm=values.get("llm"), # type: ignore[arg-type]
|
||||
# prompt=PromptTemplate(
|
||||
# template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||
# ),
|
||||
# )
|
||||
|
||||
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
if values["llm_chain"].first.input_variables != ["dialect", "query"]:
|
||||
raise ValueError(
|
||||
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user