
How to create a custom LLM class

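This notebook goes over how to create a custom LLM wrapper, in case you want to use your own LLM or a different wrapper than one that is supported in LangChain.

Wrapping your LLM with the standard `LLM` interface allows you to use your LLM in existing LangChain programs with minimal code modifications.

As a bonus, your LLM will automatically become a LangChain `Runnable` and will benefit from some optimizations out of the box, async support, the `astream_events` API, etc.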

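Caution

You are currently on a page documenting the use of text completion models. Many of the latest and most popular models are chat completion models.

Unless you are specifically using more advanced prompting techniques, you are probably looking for the guide on creating a custom chat model instead.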

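Implementation

There are only two required things that a custom LLM needs to implement:

| Method | Description |
| --- | --- |
| `_call` | Takes in a string and some optional stop words, and returns a string. Used by `invoke`. |
| `_llm_type` | A property that returns a string, used for logging purposes only. |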

Optional implementations:

| Method | Description |
| --- | --- |
| `_identifying_params` | Used to help with identifying the model and printing the LLM; should return a dictionary. This is a `@property`. |
| `_acall` | Provides an async native implementation of `_call`, used by `ainvoke`. |
| `_stream` | Method to stream the output token by token. |
| `_astream` | Provides an async native implementation of `_stream`; in newer LangChain versions, defaults to `_stream`. |

Let's implement a simple custom LLM that just returns the first n characters of the input.

```python
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk


class CustomLLM(LLM):
    """A custom LLM that echoes the first `n` characters of the input.

    When contributing an implementation to LangChain, carefully document
    the model including the initialization parameters, include
    an example of how to initialize the model and include any relevant
    links to the underlying model's documentation or API.

    Example:

        .. code-block:: python

            model = CustomLLM(n=2)
            result = model.invoke("hello")
            result = model.batch(["hello", "world"])
    """

    n: int
    """The number of characters from the prompt to be echoed."""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Override this method to implement the LLM logic.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            The model output as a string. Actual completions SHOULD NOT include the prompt.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return prompt[: self.n]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the LLM on the given prompt.

        This method should be overridden by subclasses that support streaming.

        If not implemented, the default behavior of calls to stream will be to
        fall back to the non-streaming version of the model and return
        the output as a single chunk.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            An iterator of GenerationChunks.
        """
        for char in prompt[: self.n]:
            chunk = GenerationChunk(text=char)
            if run_manager:
                # Notify callbacks before handing the chunk to the caller.
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

            yield chunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": "CustomChatModel",
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this LLM. Used for logging purposes only."""
        return "custom"
```

Let's test it 🧪

This LLM will implement the standard `Runnable` interface, which many of the LangChain abstractions support!

```python
llm = CustomLLM(n=5)
print(llm)
```

CustomLLM
Params: {'model_name': 'CustomChatModel'}

```python
llm.invoke("This is a foobar thing")
```

'This '

```python
await llm.ainvoke("world")
```

'world'

```python
llm.batch(["woof woof woof", "meow meow meow"])
```

['woof ', 'meow ']

```python
await llm.abatch(["woof woof woof", "meow meow meow"])
```

['woof ', 'meow ']

```python
async for token in llm.astream("hello"):
    print(token, end="|", flush=True)
```

h|e|l|l|o|

Let's confirm that it integrates nicely with other LangChain APIs.

```python
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [("system", "you are a bot"), ("human", "{input}")]
)
llm = CustomLLM(n=7)
chain = prompt | llm

idx = 0
async for event in chain.astream_events({"input": "hello there!"}, version="v1"):
    print(event)
    idx += 1
    if idx > 7:
        # Truncate the output after the first few events
        break
```
{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}
{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\nHuman: hello there!']}}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'S'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}

Contributing

We appreciate all LLM integration contributions.

Here's a checklist to help make sure your contribution gets added to LangChain:

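Documentation:

  • The model contains doc-strings for all initialization arguments, as these will be surfaced in the API Reference.
  • The class doc-string for the model contains a link to the model API if the model is powered by a service.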

Tests:

  • Add unit or integration tests to the overridden methods. Verify that `invoke`, `ainvoke`, `batch`, and `stream` work if you've overridden the corresponding code.

Streaming (if you are implementing it):

  • Make sure to invoke the `on_llm_new_token` callback
  • `on_llm_new_token` is invoked BEFORE yielding the chunk

Stop token behavior:

  • Stop tokens should be respected
  • Stop tokens should be INCLUDED as part of the response
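The example `CustomLLM` simply rejects stop words. A model that supports them needs to cut its output at the earliest stop sequence. The helper below, `apply_stop`, is a hypothetical plain-Python sketch and not a LangChain API. It defaults to keeping the matched stop sequence, following the checklist item above. `include_stop=False` instead drops it, which matches the "cut off at the first occurrence" wording in the `_call` docstring.

```python
from typing import List, Optional


def apply_stop(
    text: str, stop: Optional[List[str]], include_stop: bool = True
) -> str:
    """Truncate `text` at the earliest occurrence of any stop sequence.

    Hypothetical helper, not a LangChain API. With include_stop=True the
    matched stop sequence is kept at the end of the result.
    """
    if not stop:
        return text
    best_index, best_stop = -1, ""
    for seq in stop:
        if not seq:
            continue  # ignore empty stop strings
        idx = text.find(seq)
        if idx == -1:
            continue
        # Prefer the earliest match; on a tie, prefer the longer sequence.
        if best_index == -1 or idx < best_index or (
            idx == best_index and len(seq) > len(best_stop)
        ):
            best_index, best_stop = idx, seq
    if best_index == -1:
        return text  # no stop sequence found
    end = best_index + len(best_stop) if include_stop else best_index
    return text[:end]


print(apply_stop("Hello world. Bye.", ["."]))  # Hello world.
print(apply_stop("Hello world. Bye.", ["."], include_stop=False))  # Hello world
```

Inside `_call` or `_stream`, a model that supports stop sequences could pass its raw output through a helper like this instead of raising. When streaming, it would also need to stop yielding chunks once a stop sequence appears.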

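Secret API keys:

  • If your model connects to an API, it will likely accept API keys as part of its initialization. Use Pydantic's `SecretStr` type for secrets, so they don't get accidentally printed out when folks print the model.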
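LangChain `LLM` subclasses are Pydantic models; the example above declares `n: int` as an ordinary field. A secret can therefore be declared the same way. The sketch uses a plain Pydantic model with a hypothetical `ProviderSettings` class and a fake key so that it stays independent of LangChain. The same `api_key: SecretStr` field would work on an `LLM` subclass.

```python
from pydantic import BaseModel, SecretStr


class ProviderSettings(BaseModel):
    """Hypothetical settings; an LLM subclass can declare the same field."""

    api_key: SecretStr
    model_name: str = "example-model"


settings = ProviderSettings(api_key="sk-not-a-real-key")
# Printing the model masks the secret value.
print(settings)
# The raw value is only available through an explicit call.
print(settings.api_key.get_secret_value())  # sk-not-a-real-key
```

The design point is that the raw key is reachable only through the explicit `get_secret_value()` call, which is what the integration would pass to its HTTP client.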