Skip to main content

MultiQueryRetriever

基于距离的向量数据库检索将查询嵌入到高维空间中,并根据“距离”找到相似的嵌入文档。但是,如果查询措辞发生细微变化或嵌入不很好地捕捉到数据的语义,检索可能会产生不同的结果。有时需要手动进行提示工程/调整来解决这些问题,但这可能很繁琐。

MultiQueryRetriever通过使用LLM从不同角度生成多个查询来自动化提示调整过程,以适应给定用户输入查询。对于每个查询,它检索一组相关文档,并对所有查询进行唯一并集操作,以获得更大的一组可能相关的文档。通过在同一个问题上生成多个视角,MultiQueryRetriever可能能够克服基于距离的检索的一些限制,并获得更丰富的结果集。

# 构建示例向量数据库
from langchain.vectorstores import Chroma
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

# 加载博客文章
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
data = loader.load()

# 分割
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
splits = text_splitter.split_documents(data)

# 向量数据库
embedding = OpenAIEmbeddings()
vectordb = Chroma.from_documents(documents=splits, embedding=embedding)

API 参考:

简单用法

指定用于查询生成的LLM,剩下的工作由检索器完成。

from langchain.chat_models import ChatOpenAI
from langchain.retrievers.multi_query import MultiQueryRetriever

question = "任务分解的方法有哪些?"
llm = ChatOpenAI(temperature=0)
retriever_from_llm = MultiQueryRetriever.from_llm(
retriever=vectordb.as_retriever(), llm=llm
)

API 参考:

# 设置查询的日志记录
import logging

logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

unique_docs = retriever_from_llm.get_relevant_documents(query=question)
len(unique_docs)

提供自定义提示

您还可以提供提示以及输出解析器,将结果拆分为查询列表。

from typing import List
from langchain import LLMChain
from pydantic import BaseModel, Field
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser


# 输出解析器将LLM结果拆分为查询列表
class LineList(BaseModel):
# "lines" 是解析输出的键(属性名)
lines: List[str] = Field(description="文本行列表")


class LineListOutputParser(PydanticOutputParser):
def __init__(self) -> None:
super().__init__(pydantic_object=LineList)

def parse(self, text: str) -> LineList:
lines = text.strip().split("\n")
return LineList(lines=lines)


output_parser = LineListOutputParser()

QUERY_PROMPT = PromptTemplate(
input_variables=["question"],
template="""您是一个AI语言模型助手。您的任务是从向量数据库中生成给定用户问题的五个不同版本的查询以检索相关文档。通过在用户问题上生成多个视角,您的目标是帮助用户克服基于距离的相似性搜索的一些限制。请提供这些替代问题,用换行符分隔。
原始问题:{question}""",
)
llm = ChatOpenAI(temperature=0)

# 链
llm_chain = LLMChain(llm=llm, prompt=QUERY_PROMPT, output_parser=output_parser)

# 其他输入
question = "任务分解的方法有哪些?"

API 参考:

# 运行
retriever = MultiQueryRetriever(
retriever=vectordb.as_retriever(), llm_chain=llm_chain, parser_key="lines"
) # "lines" 是解析输出的键(属性名)

# 结果
unique_docs = retriever.get_relevant_documents(
query="课程中对回归有什么说法?"
)
len(unique_docs)

API 参考: