Skip to main content

自定义带有工具检索的代理

本章是基于这个笔记本构建的,并假设您熟悉代理的工作原理。

这个笔记本中引入的新颖想法是使用检索来选择用于回答代理查询的工具集。当您有很多工具可供选择时,这是非常有用的。您不能将所有工具的描述放在提示中(由于上下文长度问题),因此您可以在运行时动态选择要考虑使用的N个工具。

在这个笔记本中,我们将创建一个有些人为的示例。我们将有一个合法的工具(搜索)和99个虚假的工具,这些工具只是胡言乱语。然后,我们将在提示模板中添加一步,该步骤获取用户输入并检索与查询相关的工具。

设置环境

进行必要的导入等操作。

from langchain.agents import (  
Tool,
AgentExecutor,
LLMSingleActionAgent,
AgentOutputParser,
)
from langchain.prompts import StringPromptTemplate
from langchain import OpenAI, SerpAPIWrapper, LLMChain
from typing import List, Union
from langchain.schema import AgentAction, AgentFinish
import re

API 参考:

设置工具

我们将创建一个合法的工具(搜索)和99个虚假的工具。

# 定义代理可以用来回答用户查询的工具
search = SerpAPIWrapper()
search_tool = Tool(
name="搜索",
func=search.run,
description="当您需要回答有关当前事件的问题时很有用",
)


def fake_func(inp: str) -> str:
return "foo"


fake_tools = [
Tool(
name=f"foo-{i}",
func=fake_func,
description=f"一个愚蠢的函数,您可以使用它来获取有关数字{i}的更多信息",
)
for i in range(99)
]
ALL_TOOLS = [search_tool] + fake_tools

工具检索器

我们将使用一个向量存储库为每个工具描述创建嵌入。然后,对于传入的查询,我们可以为该查询创建嵌入,并进行相似性搜索以获取相关的工具。

from langchain.vectorstores import FAISS  
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document

docs = [
Document(page_content=t.description, metadata={"index": i})
for i, t in enumerate(ALL_TOOLS)
]

vector_store = FAISS.from_documents(docs, OpenAIEmbeddings())

retriever = vector_store.as_retriever()


def get_tools(query):
docs = retriever.get_relevant_documents(query)
return [ALL_TOOLS[d.metadata["index"]] for d in docs]

现在我们可以测试这个检索器,看看它是否工作正常。

get_tools("天气如何?")
[Tool(name='搜索', description='当您需要回答有关当前事件的问题时很有用', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<bound method SerpAPIWrapper.run of SerpAPIWrapper(search_engine=<class 'serpapi.google_search.GoogleSearch'>, params={'engine': 'google', 'google_domain': 'google.com', 'gl': 'us', 'hl': 'en'}, serpapi_api_key='', aiosession=None)>, coroutine=None),
Tool(name='foo-95', description='一个愚蠢的函数,您可以使用它来获取有关数字95的更多信息', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),
Tool(name='foo-12', description='一个愚蠢的函数,您可以使用它来获取有关数字12的更多信息', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),
Tool(name='foo-15', description='一个愚蠢的函数,您可以使用它来获取有关数字15的更多信息', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None)]
get_tools("数字13是什么?")
[Tool(name='foo-13', description='一个愚蠢的函数,您可以使用它来获取有关数字13的更多信息', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),
Tool(name='foo-12', description='一个愚蠢的函数,您可以使用它来获取有关数字12的更多信息', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),
Tool(name='foo-14', description='一个愚蠢的函数,您可以使用它来获取有关数字14的更多信息', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None),
Tool(name='foo-11', description='一个愚蠢的函数,您可以使用它来获取有关数字11的更多信息', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x114b28a90>, func=<function fake_func at 0x15e5bd1f0>, coroutine=None)]

提示模板

提示模板非常标准,因为我们实际上并没有改变实际提示模板的逻辑,而是改变了检索的方式。

# 设置基本模板
template = """Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin! Remember to speak as a pirate when giving your final answer. Use lots of "Arg"s

Question: {input}
{agent_scratchpad}"""

自定义提示模板现在具有 tools_getter 的概念,我们在输入上调用它以选择要使用的工具。

from typing import Callable


# 设置一个提示模板
class CustomPromptTemplate(StringPromptTemplate):
# 要使用的模板
template: str
############## NEW ######################
# 可用的工具列表
tools_getter: Callable

def format(self, **kwargs) -> str:
# 获取中间步骤(AgentAction,Observation 元组)
# 以特定方式格式化它们
intermediate_steps = kwargs.pop("intermediate_steps")
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
# 将 agent_scratchpad 变量设置为该值
kwargs["agent_scratchpad"] = thoughts
############## NEW ######################
tools = self.tools_getter(kwargs["input"])
# 从提供的工具列表创建一个 tools 变量
kwargs["tools"] = "\n".join(
[f"{tool.name}: {tool.description}" for tool in tools]
)
# 创建一个工具名称列表
kwargs["tool_names"] = ", ".join([tool.name for tool in tools])
return self.template.format(**kwargs)


prompt = CustomPromptTemplate(
template=template,
tools_getter=get_tools,
# 这里省略了 `agent_scratchpad`、`tools` 和 `tool_names` 变量,因为这些是动态生成的
# 这里包括了 `intermediate_steps` 变量,因为这是需要的
input_variables=["input", "intermediate_steps"],
)

输出解析器

输出解析器与之前的笔记本相同,因为我们没有改变输出格式的任何内容。

class CustomOutputParser(AgentOutputParser):
def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
# 检查代理是否应该结束
if "Final Answer:" in llm_output:
return AgentFinish(
# 返回值通常是一个带有单个 `output` 键的字典
# 目前不建议尝试其他任何内容 :)
return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
log=llm_output,
)
# 解析出 action 和 action input
regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
match = re.search(regex, llm_output, re.DOTALL)
if not match:
raise ValueError(f"Could not parse LLM output: `{llm_output}`")
action = match.group(1).strip()
action_input = match.group(2)
# 返回 action 和 action input
return AgentAction(
tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output
)


output_parser = CustomOutputParser()

设置 LLM、停止序列和代理

与之前的笔记本相同。

llm = OpenAI(temperature=0)

# LLM 链包括 LLM 和提示
llm_chain = LLMChain(llm=llm, prompt=prompt)

tools = get_tools("天气如何?")
tool_names = [tool.name for tool in tools]
agent = LLMSingleActionAgent(
llm_chain=llm_chain,
output_parser=output_parser,
stop=["\nObservation:"],
allowed_tools=tool_names,
)

使用代理

现在我们可以使用它了!

agent_executor = AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True
)

agent_executor.run("旧金山的天气如何?")
> 进入新的 AgentExecutor 链...
Thought: 我需要找出旧金山的天气
Action: 搜索
Action Input: 旧金山的天气

Observation: 早上大部分多云,下午部分多云。最高温度约为60华氏度。东北偏东风转向西风,风速为10到15英里/小时。湿度71%。紫外线指数10的6。我现在知道最终答案
Final Answer: '嗯,早上大部分多云,下午部分多云。最高温度约为60华氏度。东北偏东风转向西风,风速为10到15英里/小时。湿度71%。紫外线指数10的6。
> 完成链。
"'嗯,早上大部分多云,下午部分多云。最高温度约为60华氏度。东北偏东风转向西风,风速为10到15英里/小时。湿度71%。紫外线指数10的6。"