
Tool Input Schema

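By default, a tool infers its argument schema by inspecting the function signature. When you need stricter requirements, you can specify a custom input schema along with custom validation logic.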
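To make the default behaviour concrete, here is a minimal sketch of deriving an argument schema from a function signature with the standard library's `inspect` module. The helper `infer_args_schema` and the example function `fetch_page` are hypothetical, and this is an illustration of the idea, not how LangChain implements it.

```python
import inspect
from typing import Any, Dict


def infer_args_schema(func) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: map each parameter of `func` to its type and whether it is required."""
    schema: Dict[str, Dict[str, Any]] = {}
    for name, param in inspect.signature(func).parameters.items():
        schema[name] = {
            # Fall back to Any when the parameter has no type annotation
            "type": Any if param.annotation is inspect.Parameter.empty else param.annotation,
            # A parameter without a default value must be supplied by the caller
            "required": param.default is inspect.Parameter.empty,
        }
    return schema


def fetch_page(url: str, timeout: float = 10.0) -> str:
    """Hypothetical tool function used only for this illustration."""
    return f"GET {url} (timeout={timeout})"


print(infer_args_schema(fetch_page))
# {'url': {'type': <class 'str'>, 'required': True}, 'timeout': {'type': <class 'float'>, 'required': False}}
```

A schema derived this way only captures parameter names, types, and defaults. It cannot express a rule such as "the URL must point to an approved domain", which is what a custom input schema with its own validator adds.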

```python
from typing import Any, Dict

from langchain.agents import AgentType, initialize_agent
from langchain.llms import OpenAI
from langchain.tools.requests.tool import RequestsGetTool, TextRequestsWrapper
from pydantic import BaseModel, Field, root_validator
```


```python
llm = OpenAI(temperature=0)
```

Install `tldextract`, which the validator uses to extract the domain from a URL:

```bash
pip install tldextract > /dev/null
```

```python
import tldextract

# Only URLs whose registered domain is in this set are accepted
_APPROVED_DOMAINS = {
    "langchain",
    "wikipedia",
}


class ToolInputSchema(BaseModel):
    url: str = Field(...)

    # pydantic v1-style root validator: runs after field validation
    @root_validator
    def validate_query(cls, values: Dict[str, Any]) -> Dict:
        url = values["url"]
        # e.g. "https://www.langchain.com" -> "langchain"
        domain = tldextract.extract(url).domain
        if domain not in _APPROVED_DOMAINS:
            raise ValueError(
                f"Domain {domain} is not on the approved list:"
                f" {sorted(_APPROVED_DOMAINS)}"
            )
        return values


# Attach the custom schema so the tool validates its input before making a request
tool = RequestsGetTool(
    args_schema=ToolInputSchema, requests_wrapper=TextRequestsWrapper()
)

agent = initialize_agent(
    [tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False
)
```
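The validator relies on `tldextract` to pull the registered domain label out of a URL (for example `langchain` from `https://www.langchain.com`). As a rough, dependency-free sketch of the same check, the code below takes the second-to-last label of the hostname using `urllib.parse`. The helper names `naive_domain` and `validate_url` are hypothetical, and the naive split is an assumption that breaks on multi-part public suffixes such as `.co.uk`. That case is exactly what `tldextract` handles by consulting the Public Suffix List.

```python
from urllib.parse import urlparse

_APPROVED_DOMAINS = {"langchain", "wikipedia"}


def naive_domain(url: str) -> str:
    """Hypothetical, naive stand-in for tldextract.extract(url).domain.

    Returns the second-to-last hostname label; wrong for suffixes like .co.uk.
    """
    # urlparse only fills in the hostname when the URL has a scheme or a leading //
    host = urlparse(url if "//" in url else f"//{url}").hostname or ""
    labels = host.split(".")
    return labels[-2] if len(labels) >= 2 else labels[0]


def validate_url(url: str) -> str:
    """Mirror the root validator: raise ValueError unless the domain is approved."""
    domain = naive_domain(url)
    if domain not in _APPROVED_DOMAINS:
        raise ValueError(
            f"Domain {domain} is not on the approved list:"
            f" {sorted(_APPROVED_DOMAINS)}"
        )
    return url


print(validate_url("https://www.langchain.com"))  # https://www.langchain.com
try:
    validate_url("https://google.com")
except ValueError as err:
    print(err)  # Domain google is not on the approved list: ['langchain', 'wikipedia']
```

With this naive split, `naive_domain("https://www.bbc.co.uk")` returns `"co"` instead of `"bbc"`, which is why a suffix-aware library is the safer choice for an allow-list.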

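```python
# This succeeds: the domain of langchain.com ("langchain") is on the approved list, so validation passes
answer = agent.run("What's the main title on langchain.com?")
print(answer)
```

```
The main title of langchain.com is "LANG CHAIN 🦜️🔗 Official Home Page"
```

```python
# This fails: the domain of google.com ("google") is not on the approved list
agent.run("What's the main title on google.com?")
```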

Output:

```
---------------------------------------------------------------------------
ValidationError Traceback (most recent call last)

Cell In[7], line 1
----> 1 agent.run("What's the main title on google.com?")

File ~/code/lc/lckg/langchain/chains/base.py:213, in Chain.run(self, *args, **kwargs)
211 if len(args) != 1:
212 raise ValueError("`run` supports only one positional argument.")
--> 213 return self(args[0])[self.output_keys[0]]
215 if kwargs and not args:
216 return self(kwargs)[self.output_keys[0]]

File ~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526, in pydantic.main.BaseModel.parse_obj()

File ~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341, in pydantic.main.BaseModel.__init__()

ValidationError: 1 validation error for ToolInputSchema
__root__
Domain google is not on the approved list: ['langchain', 'wikipedia'] (type=value_error)
```
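The traceback ends in pydantic's `parse_obj`, meaning the input is rejected during schema validation, before any request is sent, and the error propagates out of `agent.run`. The stdlib-only sketch below models that validate-then-execute ordering with a hypothetical `call_tool` helper and a fake request function that only records calls. To keep it short it takes an already-extracted domain rather than a URL. It is an illustration, not LangChain's code.

```python
from typing import Callable, List

_APPROVED_DOMAINS = {"langchain", "wikipedia"}


def check_domain(domain: str) -> None:
    """Hypothetical stand-in for the schema's root validator."""
    if domain not in _APPROVED_DOMAINS:
        raise ValueError(
            f"Domain {domain} is not on the approved list:"
            f" {sorted(_APPROVED_DOMAINS)}"
        )


def call_tool(domain: str, run: Callable[[str], str]) -> str:
    """Hypothetical wrapper: validate first; `run` is reached only if validation passes."""
    check_domain(domain)
    return run(domain)


requests_made: List[str] = []


def fake_get(domain: str) -> str:
    """Record the call instead of performing a real HTTP request."""
    requests_made.append(domain)
    return f"<html>{domain}</html>"


print(call_tool("langchain", fake_get))  # <html>langchain</html>
try:
    call_tool("google", fake_get)
except ValueError as err:
    print(f"Rejected: {err}")
print(requests_made)  # ['langchain']
```

In the real example, the analogous step would be wrapping `agent.run(...)` in `try`/`except`. pydantic's `ValidationError` is a subclass of `ValueError`, so a rejected domain can be reported to the user instead of ending the program.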