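How to create a custom chat model class
This guide assumes familiarity with the following concepts:
In this guide, we'll learn how to create a custom chat model using LangChain abstractions.
Wrapping your LLM with the standard `BaseChatModel` interface allows you to use your LLM in existing LangChain programs with minimal code modifications!
As a bonus, your LLM automatically becomes a LangChain `Runnable` and benefits from some out-of-the-box features: optimizations such as batching via a thread pool, async support, the `astream_events` API, and more.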
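Inputs and outputs
First, we need to talk about messages, which are the inputs and outputs of chat models.
Messages
Chat models take messages as inputs and return a message as output.
LangChain has a few built-in message types: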
| Message Type | Description |
|---|---|
| `SystemMessage` | Used for priming AI behavior, usually passed in as the first of a sequence of input messages. |
| `HumanMessage` | Represents a message from a person interacting with the chat model. |
| `AIMessage` | Represents a message from the chat model. This can be either text or a request to invoke a tool. |
| `FunctionMessage` / `ToolMessage` | Message for passing the results of tool invocation back to the model. |
| `AIMessageChunk` / `HumanMessageChunk` / ... | Chunk variant of each type of message. |
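`ToolMessage` and `FunctionMessage` closely follow OpenAI's `function` and `tool` roles.
This is a rapidly developing field. As more models add function-calling capabilities, expect additions to this schema.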
from langchain_core.messages import (
AIMessage,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
Streaming variant
All the chat messages have a streaming variant that contains `Chunk` in the name.
from langchain_core.messages import (
AIMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)
These chunks are used when streaming output from chat models, and they are all additive: two chunks can be combined with `+`.
AIMessageChunk(content="Hello") + AIMessageChunk(content=" World!")
AIMessageChunk(content='Hello World!')
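To make the additive behavior concrete, here is a minimal stdlib-only sketch of how chunk concatenation can work. `ToyChunk` is a hypothetical stand-in, not LangChain's `AIMessageChunk`. It assumes string content is concatenated and metadata dictionaries are merged with later keys winning. This mirrors the idea but not the exact merge rules of the real classes.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyChunk:
    """Hypothetical stand-in for a message chunk (not a LangChain class)."""

    content: str
    response_metadata: Dict[str, Any] = field(default_factory=dict)

    def __add__(self, other: "ToyChunk") -> "ToyChunk":
        # Concatenate the text and merge metadata; in this toy version,
        # later keys simply overwrite earlier ones on conflict.
        return ToyChunk(
            content=self.content + other.content,
            response_metadata={**self.response_metadata, **other.response_metadata},
        )


# Simulate a stream: three text chunks, then a metadata-only chunk.
chunks = [ToyChunk("c"), ToyChunk("a"), ToyChunk("t"), ToyChunk("", {"time_in_sec": 3})]
full = chunks[0]
for chunk in chunks[1:]:
    full = full + chunk
print(full)
```

Because `+` is defined on the chunk type, a consumer can fold a whole stream into one final message without knowing how many chunks will arrive.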
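Base chat model
Let's implement a chat model that echoes back the first `n` characters of the last message in the prompt!
To do so, we will inherit from `BaseChatModel` and implement the following: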
| Method/Property | Description | Required/Optional |
|---|---|---|
| `_generate` | Use to generate a chat result from a prompt. | Required |
| `_llm_type` (property) | Used to uniquely identify the type of the model. Used for logging. | Required |
| `_identifying_params` (property) | Represent model parameterization for tracing purposes. | Optional |
| `_stream` | Use to implement streaming. | Optional |
| `_agenerate` | Use to implement a native async method. | Optional |
| `_astream` | Use to implement async version of `_stream`. | Optional |
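By default, the `_astream` implementation uses `run_in_executor` to launch the sync `_stream` in a separate thread if `_stream` is implemented. Otherwise, it falls back to `_agenerate`.
You can rely on this trick if you want to reuse the `_stream` implementation. However, if you are able to write natively async code, that is a better solution, because it runs with less overhead.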
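The thread-executor fallback described above can be sketched with the standard library alone. `sync_stream` and `astream_via_executor` are hypothetical names. The sketch assumes that pulling each item from a synchronous generator with `loop.run_in_executor` is an adequate model of the default behavior. It does not reproduce LangChain's actual code.

```python
import asyncio
from typing import AsyncIterator, Iterator, List


def sync_stream(text: str, n: int) -> Iterator[str]:
    """Hypothetical synchronous streamer: yields the first n characters one by one."""
    for ch in text[:n]:
        yield ch


_DONE = object()  # sentinel returned by next() once the sync iterator is exhausted


async def astream_via_executor(text: str, n: int) -> AsyncIterator[str]:
    """Drive a sync generator from async code by pulling each item in a worker thread."""
    loop = asyncio.get_running_loop()
    iterator = sync_stream(text, n)
    while True:
        # next() may block (e.g. on network I/O), so run it off the event loop.
        item = await loop.run_in_executor(None, next, iterator, _DONE)
        if item is _DONE:
            break
        yield item


async def main() -> List[str]:
    return [token async for token in astream_via_executor("Meow!", 3)]


print(asyncio.run(main()))
```

Each `next()` call hops to a worker thread and back. That round trip is the overhead a native async implementation avoids.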
Implementation
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field


class ChatParrotLink(BaseChatModel):
    """A custom chat model that echoes the first `parrot_buffer_length` characters
    of the input.

    When contributing an implementation to LangChain, carefully document
    the model including the initialization parameters, include
    an example of how to initialize the model and include any relevant
    links to the underlying models documentation or API.

    Example:

        .. code-block:: python

            model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                 [HumanMessage(content="world")]])
    """

    model_name: str = Field(alias="model")
    """The name of the model"""
    parrot_buffer_length: int
    """The number of characters from the last message of the prompt to be echoed."""
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    timeout: Optional[int] = None
    stop: Optional[List[str]] = None
    max_retries: int = 2

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override the _generate method to implement the chat model logic.

        This can be a call to an API, a call to a local model, or any other
        implementation that generates a response to the input prompt.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        # Replace this with actual logic to generate a response from a list
        # of messages.
        last_message = messages[-1]
        tokens = last_message.content[: self.parrot_buffer_length]
        ct_input_tokens = sum(len(message.content) for message in messages)
        ct_output_tokens = len(tokens)
        message = AIMessage(
            content=tokens,
            additional_kwargs={},  # Used to add additional payload to the message
            response_metadata={  # Use for response metadata
                "time_in_seconds": 3,
                "model_name": self.model_name,
            },
            usage_metadata={
                "input_tokens": ct_input_tokens,
                "output_tokens": ct_output_tokens,
                "total_tokens": ct_input_tokens + ct_output_tokens,
            },
        )

        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        This method should be implemented if the model can generate output
        in a streaming fashion. If the model does not support streaming,
        do not implement it. In that case streaming requests will be automatically
        handled by the _generate method.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        last_message = messages[-1]
        tokens = str(last_message.content[: self.parrot_buffer_length])
        ct_input_tokens = sum(len(message.content) for message in messages)

        for token in tokens:
            usage_metadata = UsageMetadata(
                {
                    "input_tokens": ct_input_tokens,
                    "output_tokens": 1,
                    "total_tokens": ct_input_tokens + 1,
                }
            )
            # Report input tokens only on the first chunk so that summing
            # the chunks' usage does not overcount them.
            ct_input_tokens = 0
            chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
            )

            if run_manager:
                # This is optional in newer versions of LangChain
                # The on_llm_new_token will be called automatically
                run_manager.on_llm_new_token(token, chunk=chunk)

            yield chunk

        # Let's add some other information (e.g., response metadata)
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(
                content="",
                response_metadata={"time_in_sec": 3, "model_name": self.model_name},
            )
        )
        if run_manager:
            # This is optional in newer versions of LangChain
            # The on_llm_new_token will be called automatically
            # Pass this chunk's own (empty) text: reusing `token` from the loop
            # would repeat the last character, or fail if nothing was echoed.
            run_manager.on_llm_new_token("", chunk=chunk)
        yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing purposes to make it possible to monitor LLMs.
        """
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": self.model_name,
        }
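`ChatParrotLink` accepts `stop` but ignores it. The `_generate` docstring asks that generation end at a stop string and that the stop string itself be included in the output. Below is a minimal stdlib sketch of that post-processing rule. `apply_stop` is a hypothetical helper, not a LangChain API. It assumes the text should be cut right after the first point at which any stop string has been fully produced.

```python
from typing import List, Optional


def apply_stop(text: str, stop: Optional[List[str]] = None) -> str:
    """Hypothetical helper: cut text right after the first completed stop string."""
    if not stop:
        return text
    cut = len(text)
    for token in stop:
        if not token:
            continue  # ignore empty stop strings
        index = text.find(token)
        if index != -1:
            # Use the end position *after* the stop string, so it stays in the output.
            cut = min(cut, index + len(token))
    return text[:cut]


print(apply_stop("Hello world. Bye.", stop=["."]))
print(apply_stop("Hello world", stop=["!"]))
```

Inside `_generate`, a helper like this could be applied to the generated text before building the `AIMessage`. A model backed by an API that supports stop sequences server-side would more likely forward `stop` to that API instead.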
Let's test it 🧪
The chat model implements LangChain's standard `Runnable` interface, which many LangChain abstractions support!
model = ChatParrotLink(parrot_buffer_length=3, model="my_custom_model")
model.invoke(
[
HumanMessage(content="hello!"),
AIMessage(content="Hi there human!"),
HumanMessage(content="Meow!"),
]
)
AIMessage(content='Meo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-cf11aeb6-8ab6-43d7-8c68-c1ef89b6d78e-0', usage_metadata={'input_tokens': 26, 'output_tokens': 3, 'total_tokens': 29})
model.invoke("hello")
AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-618e5ed4-d611-4083-8cf1-c270726be8d9-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8})
model.batch(["hello", "goodbye"])
[AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-eea4ed7d-d750-48dc-90c0-7acca1ff388f-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8}),
AIMessage(content='goo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-07cfc5c1-3c62-485f-b1e0-3d46e1547287-0', usage_metadata={'input_tokens': 7, 'output_tokens': 3, 'total_tokens': 10})]
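`batch` returns results in input order, even though the calls may run concurrently. The sketch below mimics the thread-pool batching mentioned at the start of this guide using `concurrent.futures`. `parrot` is a hypothetical stand-in for `model.invoke`, and the sketch illustrates the idea rather than reproducing LangChain's own `batch` code.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def parrot(prompt: str, n: int = 3) -> str:
    """Hypothetical stand-in for model.invoke: echo the first n characters."""
    return prompt[:n]


def batch(prompts: List[str], max_workers: int = 4) -> List[str]:
    # Executor.map runs the calls concurrently but yields results in input order.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(parrot, prompts))


print(batch(["hello", "goodbye"]))
```

On a real `Runnable`, the degree of concurrency can be limited through the `max_concurrency` key of the run config, for example `model.batch(inputs, config={"max_concurrency": 2})`.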
for chunk in model.stream("cat"):
print(chunk.content, end="|")
c|a|t||
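`ChatParrotLink` does not override `_astream`. The default implementation runs the synchronous `_stream` in a thread executor, so `astream` still streams chunk by chunk. If neither `_stream` nor `_astream` were implemented, `astream` would yield the whole response as a single chunk instead of streaming it.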
async for chunk in model.astream("cat"):
print(chunk.content, end="|")
c|a|t||
Let's try the `astream_events` API, which also helps double-check that all the callbacks were implemented!
async for event in model.astream_events("cat", version="v1"):
print(event)
{'event': 'on_chat_model_start', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'name': 'ChatParrotLink', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='c', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 1, 'total_tokens': 4})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='a', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='t', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a')}, 'parent_ids': []}
{'event': 'on_chat_model_end', 'name': 'ChatParrotLink', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 3, 'total_tokens': 6})}, 'parent_ids': []}
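The final `on_chat_model_end` event reports `input_tokens: 3, output_tokens: 3, total_tokens: 6`. This is consistent with the per-chunk usage metadata being added together when the chunks are merged. `_stream` reports the input count only on the first chunk, because it resets `ct_input_tokens` to 0 afterwards. The stdlib sketch below reproduces that arithmetic with plain dictionaries. `sum_usage` is a hypothetical helper, not a LangChain function.

```python
from typing import Dict, List

Usage = Dict[str, int]


def sum_usage(chunks: List[Usage]) -> Usage:
    """Hypothetical helper: add up per-chunk usage dictionaries key by key."""
    total: Usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
    for usage in chunks:
        for key in total:
            total[key] += usage.get(key, 0)
    return total


# Per-chunk usage as emitted by ChatParrotLink._stream for the prompt "cat":
# the input count (3 characters) appears only on the first chunk.
stream_usage = [
    {"input_tokens": 3, "output_tokens": 1, "total_tokens": 4},
    {"input_tokens": 0, "output_tokens": 1, "total_tokens": 1},
    {"input_tokens": 0, "output_tokens": 1, "total_tokens": 1},
    {},  # the trailing metadata-only chunk carries no usage
]
print(sum_usage(stream_usage))
```

If every chunk repeated the input count, the aggregate would overcount input tokens (9 instead of 3 here).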
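Contributing
We appreciate all chat model integration contributions.
Here's a checklist to help make sure your contribution gets added to LangChain:
Documentation:
- The model contains doc-strings for all initialization arguments, as these will be surfaced in the API Reference.
- If the model is powered by a service, the class doc-string for the model contains a link to the model API.
Tests:
- Add unit or integration tests to the overridden methods. Verify that `invoke`, `ainvoke`, `batch`, and `stream` work if you've overridden the corresponding code.
Streaming (if you are implementing it):
- Implement the `_stream` method to get streaming working.
Stop token behavior:
- Stop tokens should be respected.
- Stop tokens should be INCLUDED as part of the response.
Secret API keys:
- If your model connects to an API, it will likely accept API keys as part of its initialization. Use Pydantic's `SecretStr` type for secrets, so they don't get accidentally printed out when people print the model.
Identifying params:
- Include a `model_name` in the identifying params.
Optimizations:
Consider providing native async support to reduce the overhead from the model!
- Provide a native async implementation of `_agenerate` (used by `ainvoke`).
- Provide a native async implementation of `_astream` (used by `astream`).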
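For contrast with the thread-executor fallback, a natively async streamer awaits I/O directly on the event loop. The sketch below uses `asyncio.sleep(0)` as a hypothetical stand-in for awaiting a network read from a model API. `native_astream` is an illustrative name, and a real `_astream` override would yield `ChatGenerationChunk` objects rather than strings.

```python
import asyncio
from typing import AsyncIterator, List


async def native_astream(text: str, n: int) -> AsyncIterator[str]:
    """Hypothetical native async streamer: no worker threads involved."""
    for ch in text[:n]:
        # Stand-in for awaiting the next piece of a streamed API response.
        await asyncio.sleep(0)
        yield ch


async def collect() -> List[str]:
    return [token async for token in native_astream("cat", 3)]


print(asyncio.run(collect()))
```

Because each step is an `await` on the event loop rather than a hop to a worker thread, many concurrent streams can share one thread with little overhead.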
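Next steps
You've now learned how to create your own custom chat models.
Next, check out the other how-to guides on chat models in this section, such as how to get a model to return structured output or how to track chat model token usage.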