如何创建自定义的LLM类
本笔记本将介绍如何创建自定义的 LLM 包装器,以便在 LangChain 不支持的 LLM 或您希望使用其他包装器时进行使用。
使用标准 LLM 接口包装你的大语言模型,即可在现有的 LangChain 程序中以最少的代码修改来使用你的大语言模型。
作为额外福利,您的大型语言模型将自动成为 LangChain Runnable,并立即享受一些优化功能,如异步支持、astream_events API 等。
实现
自定义大型语言模型需要实现的只有两个必要事项:
| 方法 | 描述 |
|---|---|
_call | Takes in a string and some optional stop words, and returns a string. Used by invoke. |
_llm_type | A property that returns a string, used for logging purposes only. |
可选实现:
| 方法 | 描述 |
|---|---|
_identifying_params | Used to help with identifying the model and printing the LLM; should return a dictionary. This is a @property. |
_acall | Provides an async native implementation of _call, used by ainvoke. |
_stream | Method to stream the output token by token. |
_astream | Provides an async native implementation of _stream; in newer LangChain versions, defaults to _stream. |
让我们实现一个简单的自定义大型语言模型,该模型仅返回输入的前 n 个字符。
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
class CustomLLM(LLM):
"""A custom chat model that echoes the first `n` characters of the input.
When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.
Example:
.. code-block:: python
model = CustomChatModel(n=2)
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""
n: int
"""The number of characters from the last message of the prompt to be echoed."""
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given input.
Override this method to implement the LLM logic.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of the stop substrings.
If stop tokens are not supported consider raising NotImplementedError.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
The model output as a string. Actual completions SHOULD NOT include the prompt.
"""
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
return prompt[: self.n]
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream the LLM on the given prompt.
This method should be overridden by subclasses that support streaming.
If not implemented, the default behavior of calls to stream will be to
fallback to the non-streaming version of the model and return
the output as a single chunk.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An iterator of GenerationChunks.
"""
for char in prompt[: self.n]:
chunk = GenerationChunk(text=char)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters."""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": "CustomChatModel",
}
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model. Used for logging purposes only."""
return "custom"
让我们来测试一下 🧪
此大型语言模型将实现 LangChain 的标准 Runnable 接口,该接口受到许多 LangChain 抽象的支持!
llm = CustomLLM(n=5)
print(llm)
[1mCustomLLM[0m
Params: {'model_name': 'CustomChatModel'}
llm.invoke("This is a foobar thing")
'This '
await llm.ainvoke("world")
'world'
llm.batch(["woof woof woof", "meow meow meow"])
['woof ', 'meow ']
await llm.abatch(["woof woof woof", "meow meow meow"])
['woof ', 'meow ']
async for token in llm.astream("hello"):
print(token, end="|", flush=True)
h|e|l|l|o|
让我们确认它是否能与其他 LangChain API 无缝集成。
from langchain_core.prompts import ChatPromptTemplate
API 参考:ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
[("system", "you are a bot"), ("human", "{input}")]
)
llm = CustomLLM(n=7)
chain = prompt | llm
idx = 0
async for event in chain.astream_events({"input": "hello there!"}, version="v1"):
print(event)
idx += 1
if idx > 7:
# Truncate
break
{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}
{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\nHuman: hello there!']}}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'S'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}
贡献
我们感谢所有关于聊天模型集成的贡献。
以下是帮助确保您的贡献被添加到 LangChain 的检查清单:
Documentation:
- 该模型包含所有初始化参数的文档字符串,因为这些内容将在 API 参考 中显示。
- 该模型的类文档字符串中包含一个链接,指向模型的API(如果该模型由服务提供支持的话)。
Tests:
- 为重写的方法添加单元测试或集成测试。验证
invoke,ainvoke,batch,stream如果已覆盖相应的代码,则可以正常工作。
流式传输(如果你在实现它的话):
- 确保调用
on_llm_new_token回调 -
on_llm_new_token在生成数据块之前被调用
停止标记行为:
- 停止标记应被尊重
- 停止标记应包含在响应结果中
秘密API密钥:
- 如果您的模型连接到API,它很可能在初始化时接受API密钥。使用Pydantic的
SecretStr用于密钥类型,以防止在打印模型时被意外输出。