如何创建自定义聊天模型类
本指南假定读者熟悉以下概念:
在本指南中,我们将学习如何使用 LangChain 抽象来创建一个自定义的 聊天模型。
使用标准 BaseChatModel 接口包装你的大型语言模型,即可在现有的 LangChain 程序中以最少的代码修改轻松使用你的 LLM!
作为额外福利,你的大语言模型将自动成为 LangChain 可运行的 实例,并且无需额外配置即可享受一些优化功能(例如,通过线程池实现批量处理)、异步支持、astream_events API 等。
输入和输出
首先,我们需要讨论一下 消息,它们是聊天模型的输入和输出。
消息
聊天模型以消息作为输入,并返回一条消息作为输出。
LangChain 具有几种 内置消息类型:
| 消息类型 | 描述 |
|---|---|
SystemMessage | Used for priming AI behavior, usually passed in as the first of a sequence of input messages. |
HumanMessage | Represents a message from a person interacting with the chat model. |
AIMessage | Represents a message from the chat model. This can be either text or a request to invoke a tool. |
FunctionMessage / ToolMessage | Message for passing the results of tool invocation back to the model. |
AIMessageChunk / HumanMessageChunk / ... | Chunk variant of each type of message. |
ToolMessage 和 FunctionMessage 紧密遵循 OpenAI 的 function 和 tool 角色。
这是一个快速发展的领域,随着越来越多的模型增加函数调用功能,预计该模式将不断更新。
from langchain_core.messages import (
AIMessage,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
流式变体
所有聊天消息都有一个带有 Chunk 的流式版本。
from langchain_core.messages import (
AIMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)
这些块在从聊天模型流式输出时使用,且它们都定义了一个累加属性!
AIMessageChunk(content="Hello") + AIMessageChunk(content=" World!")
AIMessageChunk(content='Hello World!')
基础聊天模型
让我们实现一个聊天模型,该模型将回显提示中最后一条消息的前 n 个字符!
为此,我们将从 BaseChatModel 继承,并且需要实现以下内容:
| Method/Property | 描述 | Required/Optional |
|---|---|---|
_generate | Use to generate a chat result from a prompt | Required |
_llm_type (property) | Used to uniquely identify the type of the model. Used for logging. | Required |
_identifying_params (property) | Represent model parameterization for tracing purposes. | Optional |
_stream | Use to implement streaming. | Optional |
_agenerate | Use to implement a native async method. | Optional |
_astream | Use to implement async version of _stream. | Optional |
_astream 的实现使用 run_in_executor 在单独线程中启动同步 _stream,如果已实现 _stream,否则会回退到使用 _agenerate。
如果你想复用 _stream 的实现,可以使用这个技巧,但如果你能够实现原生异步的代码,那将是一个更好的解决方案,因为这样的代码运行时开销更小。
实现
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field
class ChatParrotLink(BaseChatModel):
"""A custom chat model that echoes the first `parrot_buffer_length` characters
of the input.
When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.
Example:
.. code-block:: python
model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""
model_name: str = Field(alias="model")
"""The name of the model"""
parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed."""
temperature: Optional[float] = None
max_tokens: Optional[int] = None
timeout: Optional[int] = None
stop: Optional[List[str]] = None
max_retries: int = 2
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override the _generate method to implement the chat model logic.
This can be a call to an API, a call to a local model, or any other
implementation that generates a response to the input prompt.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
# Replace this with actual logic to generate a response from a list
# of messages.
last_message = messages[-1]
tokens = last_message.content[: self.parrot_buffer_length]
ct_input_tokens = sum(len(message.content) for message in messages)
ct_output_tokens = len(tokens)
message = AIMessage(
content=tokens,
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
"model_name": self.model_name,
},
usage_metadata={
"input_tokens": ct_input_tokens,
"output_tokens": ct_output_tokens,
"total_tokens": ct_input_tokens + ct_output_tokens,
},
)
##
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model.
This method should be implemented if the model can generate output
in a streaming fashion. If the model does not support streaming,
do not implement it. In that case streaming requests will be automatically
handled by the _generate method.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
last_message = messages[-1]
tokens = str(last_message.content[: self.parrot_buffer_length])
ct_input_tokens = sum(len(message.content) for message in messages)
for token in tokens:
usage_metadata = UsageMetadata(
{
"input_tokens": ct_input_tokens,
"output_tokens": 1,
"total_tokens": ct_input_tokens + 1,
}
)
ct_input_tokens = 0
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
)
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
# Let's add some other information (e.g., response metadata)
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content="",
response_metadata={"time_in_sec": 3, "model_name": self.model_name},
)
)
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return "echoing-chat-model-advanced"
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.
This information is used by the LangChain callback system, which
is used for tracing purposes make it possible to monitor LLMs.
"""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": self.model_name,
}
让我们来测试一下 🧪
聊天模型将实现 LangChain 的标准 Runnable 接口,该接口受到 LangChain 多个抽象功能的支持!
model = ChatParrotLink(parrot_buffer_length=3, model="my_custom_model")
model.invoke(
[
HumanMessage(content="hello!"),
AIMessage(content="Hi there human!"),
HumanMessage(content="Meow!"),
]
)
AIMessage(content='Meo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-cf11aeb6-8ab6-43d7-8c68-c1ef89b6d78e-0', usage_metadata={'input_tokens': 26, 'output_tokens': 3, 'total_tokens': 29})
model.invoke("hello")
AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-618e5ed4-d611-4083-8cf1-c270726be8d9-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8})
model.batch(["hello", "goodbye"])
[AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-eea4ed7d-d750-48dc-90c0-7acca1ff388f-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8}),
AIMessage(content='goo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-07cfc5c1-3c62-485f-b1e0-3d46e1547287-0', usage_metadata={'input_tokens': 7, 'output_tokens': 3, 'total_tokens': 10})]
for chunk in model.stream("cat"):
print(chunk.content, end="|")
c|a|t||
请查看模型中 _astream 的实现!如果没有实现,将不会有任何输出流!
async for chunk in model.astream("cat"):
print(chunk.content, end="|")
c|a|t||
让我们尝试使用 astream 事件 API,这也有助于双重确认所有回调都已正确实现!
async for event in model.astream_events("cat", version="v1"):
print(event)
{'event': 'on_chat_model_start', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'name': 'ChatParrotLink', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='c', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 1, 'total_tokens': 4})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='a', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='t', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a')}, 'parent_ids': []}
{'event': 'on_chat_model_end', 'name': 'ChatParrotLink', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 3, 'total_tokens': 6})}, 'parent_ids': []}
贡献
我们感谢所有关于聊天模型集成的贡献。
以下是帮助确保您的贡献被添加到 LangChain 的检查清单:
Documentation:
- 该模型包含所有初始化参数的文档字符串,因为这些内容将在 API 参考 中显示。
- 该模型的类文档字符串中包含一个链接,指向模型的API(如果该模型由服务提供支持的话)。
Tests:
- 为重写的方法添加单元测试或集成测试。验证
invoke,ainvoke,batch,stream如果已覆盖相应的代码,则可以正常工作。
流式传输(如果你在实现它的话):
- 实现_stream方法以启用流式传输
停止标记行为:
- 停止标记应被尊重
- 停止标记应包含在响应结果中
秘密API密钥:
- 如果您的模型连接到API,它很可能在初始化时接受API密钥。使用Pydantic的
SecretStr用于密钥类型,以防止在打印模型时被意外输出。
识别参数:
- 包含一个
model_name在识别参数方面
Optimizations:
考虑提供原生异步支持,以减少模型的开销!
- 提供原生异步支持
_agenerate(被用于ainvoke) - 提供原生异步支持
_astream(被用于astream)
下一步
你现在已经学会了如何创建自己的自定义聊天模型。
接下来,查看本节中的其他如何使用聊天模型的指南,例如 如何让模型返回结构化输出 或 如何跟踪聊天模型的令牌使用情况。