
SageMaker Tracking

This notebook shows how LangChain callbacks can be used to log and track prompts and other LLM hyperparameters into SageMaker Experiments. Here, we use different scenarios to showcase the capability:

  • Scenario 1: Single LLM - A case where a single LLM model is used to generate output based on a given prompt.
  • Scenario 2: Sequential Chain - A case where a sequential chain of two LLM models is used.
  • Scenario 3: Agent with Tools (Chain of Thought) - A case where multiple tools (search and math) are used in addition to an LLM.

Amazon SageMaker is a fully managed service that is used to quickly and easily build, train, and deploy machine learning (ML) models.

Amazon SageMaker Experiments is a capability of Amazon SageMaker that lets you organize, track, compare, and evaluate ML experiments and model versions.

In this notebook, we will create a single experiment to log the prompts from each scenario.

Installation and Setup

pip install sagemaker
pip install openai
pip install google-search-results

First, set up the required API keys.

import os

## Add your API keys below
os.environ["OPENAI_API_KEY"] = "<ADD-KEY-HERE>"
os.environ["SERPAPI_API_KEY"] = "<ADD-KEY-HERE>"
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain.agents import initialize_agent, load_tools
from langchain.agents import Tool
from langchain.callbacks import SageMakerCallbackHandler

from sagemaker.analytics import ExperimentAnalytics
from sagemaker.session import Session
from sagemaker.experiments.run import Run

LLM Prompt Tracking

# LLM hyperparameters
HPARAMS = {
    "temperature": 0.1,
    "model_name": "text-davinci-003",
}

# Bucket used to save prompt logs (use `None` to use the default bucket, otherwise change it)
BUCKET_NAME = None

# Experiment name
EXPERIMENT_NAME = "langchain-sagemaker-tracker"

# Create SageMaker Session with the given bucket
session = Session(default_bucket=BUCKET_NAME)

Scenario 1 - LLM

RUN_NAME = "run-scenario-1"
PROMPT_TEMPLATE = "tell me a joke about {topic}"
INPUT_VARIABLES = {"topic": "fish"}
with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:

    # Create SageMaker Callback
    sagemaker_callback = SageMakerCallbackHandler(run)

    # Define LLM model with callback
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)

    # Create prompt template
    prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)

    # Create LLM Chain
    chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])

    # Run chain
    chain.run(**INPUT_VARIABLES)

    # Reset the callback
    sagemaker_callback.flush_tracker()

Scenario 2 - Sequential Chain

RUN_NAME = "run-scenario-2"

PROMPT_TEMPLATE_1 = """You are a playwright. Given the title of play, it is your job to write a synopsis for that title.
Title: {title}
Playwright: This is a synopsis for the above play:"""
PROMPT_TEMPLATE_2 = """You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.
Play Synopsis: {synopsis}
Review from a New York Times play critic of the above play:"""

INPUT_VARIABLES = {
    "input": "documentary about good video games that push the boundary of game design"
}
with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:

    # Create SageMaker Callback
    sagemaker_callback = SageMakerCallbackHandler(run)

    # Create prompt templates for the chain
    prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)
    prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)

    # Define LLM model with callback
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)

    # Create chain1
    chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])

    # Create chain2
    chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])

    # Create Sequential chain
    overall_chain = SimpleSequentialChain(chains=[chain1, chain2], callbacks=[sagemaker_callback])

    # Run overall sequential chain
    overall_chain.run(**INPUT_VARIABLES)

    # Reset the callback
    sagemaker_callback.flush_tracker()

Scenario 3 - Agent with Tools

RUN_NAME = "run-scenario-3"
PROMPT_TEMPLATE = "Who is the oldest person alive? And what is their current age raised to the power of 1.51?"
with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:

    # Create SageMaker Callback
    sagemaker_callback = SageMakerCallbackHandler(run)

    # Define LLM model with callback
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)

    # Define tools
    tools = load_tools(["serpapi", "llm-math"], llm=llm, callbacks=[sagemaker_callback])

    # Initialize agent with all the tools
    agent = initialize_agent(tools, llm, agent="zero-shot-react-description", callbacks=[sagemaker_callback])

    # Run agent
    agent.run(input=PROMPT_TEMPLATE)

    # Reset the callback
    sagemaker_callback.flush_tracker()

Load Log Data

Once the prompts are logged, we can easily load them and convert them to a Pandas DataFrame as follows.

# Load the experiment
logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)

# Convert to pandas dataframe
df = logs.dataframe(force_refresh=True)

print(df.shape)
df.head()

As can be seen above, there are three runs (rows) in the experiment, corresponding to each scenario. Each run logs the prompts and related LLM settings/hyperparameters as JSON and saves them in an S3 bucket. Feel free to load and explore the log data from each JSON path.
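For example, if a run exposes an `s3://` URI pointing to one of these JSON log files, a minimal sketch like the one below can be used to download and inspect it. This assumes `boto3` is available and that you substitute the actual bucket and key from your own experiment; `load_json_log` is a hypothetical helper for illustration, not part of the SageMaker or LangChain APIs.

import json
import boto3

s3 = boto3.client("s3")

def load_json_log(s3_uri):
    """Download a logged JSON artifact from S3 and parse it into a dict."""
    # Split "s3://bucket/key" into its bucket and key parts
    bucket, key = s3_uri.replace("s3://", "", 1).split("/", 1)
    body = s3.get_object(Bucket=bucket, Key=key)["Body"].read()
    return json.loads(body)

# Hypothetical usage with a path copied from the DataFrame above:
# log = load_json_log("s3://<your-bucket>/<path-to-json-log>")
# print(json.dumps(log, indent=2))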