# Spaces:
# Runtime error
# Runtime error
"""Base callback handler and callback managers used to handle callbacks from LangChain."""
import asyncio
import functools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(ABC):
    """Base callback handler that can be used to handle callbacks from langchain.

    Every ``on_*`` hook is a no-op by default; subclasses override only the
    events they care about. Hooks receive extra keyword arguments via
    ``**kwargs``, so overrides should accept them too.

    The ``always_verbose`` and ``ignore_*`` flags are properties because callback
    managers read them as plain attributes (``if not handler.ignore_llm:``).
    As ordinary methods they would evaluate to a bound method, which is always
    truthy, so every ``ignore_*`` check would skip the handler.
    """

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return False

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return False

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return False

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return False

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Run when LLM starts running."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        """Run on new LLM token. Only available when streaming is enabled."""

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        """Run when LLM ends running."""

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when LLM errors."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """Run when chain starts running."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Run when chain ends running."""

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when chain errors."""

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        """Run when tool starts running."""

    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Run when tool ends running."""

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when tool errors."""

    def on_text(self, text: str, **kwargs: Any) -> Any:
        """Run on arbitrary text."""

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
    """Base callback manager that can be used to handle callbacks from LangChain.

    A callback manager is itself a callback handler that holds other handlers.
    Subclasses are expected to override ``add_handler``, ``remove_handler`` and
    ``set_handlers``. The implementations here do nothing, so a subclass that
    forgets one will silently ignore the call.
    """

    @property
    def is_async(self) -> bool:
        """Whether the callback manager is async.

        This is a property, like the handler flags, so callers can test it as
        ``manager.is_async``. As a plain method that expression is a bound
        method, which is always truthy.
        """
        return False

    def add_handler(self, callback: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager (no-op here; override)."""

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager (no-op here; override)."""

    def set_handler(self, handler: BaseCallbackHandler) -> None:
        """Set handler as the only handler on the callback manager.

        Delegates to ``set_handlers`` with a one-element list.
        """
        self.set_handlers([handler])

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager (no-op here; override)."""

    def on_agent_clarify(self, output: Any, color: Any, verbose: Any) -> None:
        """Hook for an agent "clarify" event; does nothing by default.

        NOTE(review): this hook has no counterpart on ``BaseCallbackHandler``,
        so there is no handler method a manager could forward it to. The
        parameter meanings are not visible here. Presumably ``output`` is the
        clarification text, ``color`` a display colour and ``verbose`` a
        verbosity flag; confirm against the callers.
        """
        pass
class CallbackManager(BaseCallbackManager):
    """Callback manager that can be used to handle callbacks from langchain.

    Each ``on_*`` event is forwarded, in list order, to every registered
    handler that satisfies both of these conditions:

    * It has not opted out of the event's family via ``ignore_llm``,
      ``ignore_chain`` or ``ignore_agent``. ``on_text`` has no opt-out.
    * It is selected by verbosity: the caller's ``verbose`` is true, or the
      handler's ``always_verbose`` is true.

    Extra keyword arguments given to any ``on_*`` method are passed through to
    the handler unchanged.
    """

    def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
        """Initialize callback manager.

        Args:
            handlers: Handlers to notify. The list object is stored, not copied,
                so ``add_handler``/``remove_handler`` also mutate the caller's
                list.
        """
        self.handlers: List[BaseCallbackHandler] = handlers

    def _handle_event(
        self,
        event: str,
        ignore_flag: Optional[str],
        verbose: bool,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        """Call handler method ``event`` on every handler selected for it.

        ``args`` and ``kwargs`` are taken as a tuple and a dict rather than
        ``*args``/``**kwargs``. This way, caller-supplied keyword names can
        never collide with this helper's own parameter names.

        Args:
            event: Name of the handler method to call, e.g. ``"on_llm_start"``.
            ignore_flag: Name of the handler attribute that opts out of this
                event family, or ``None`` when the event cannot be ignored.
            verbose: Caller's verbosity flag.
            args: Positional arguments for the handler method.
            kwargs: Keyword arguments for the handler method.
        """
        for handler in self.handlers:
            if ignore_flag is not None and getattr(handler, ignore_flag):
                continue
            if verbose or handler.always_verbose:
                getattr(handler, event)(*args, **kwargs)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        self._handle_event(
            "on_llm_start", "ignore_llm", verbose, (serialized, prompts), kwargs
        )

    def on_llm_new_token(
        self, token: str, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when LLM generates a new token."""
        self._handle_event("on_llm_new_token", "ignore_llm", verbose, (token,), kwargs)

    def on_llm_end(
        self, response: LLMResult, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when LLM ends running."""
        self._handle_event("on_llm_end", "ignore_llm", verbose, (response,), kwargs)

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self._handle_event("on_llm_error", "ignore_llm", verbose, (error,), kwargs)

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self._handle_event(
            "on_chain_start", "ignore_chain", verbose, (serialized, inputs), kwargs
        )

    def on_chain_end(
        self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when chain ends running."""
        self._handle_event("on_chain_end", "ignore_chain", verbose, (outputs,), kwargs)

    def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when chain errors."""
        self._handle_event("on_chain_error", "ignore_chain", verbose, (error,), kwargs)

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self._handle_event(
            "on_tool_start", "ignore_agent", verbose, (serialized, input_str), kwargs
        )

    def on_agent_action(
        self, action: AgentAction, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run on agent action."""
        self._handle_event(
            "on_agent_action", "ignore_agent", verbose, (action,), kwargs
        )

    def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self._handle_event("on_tool_end", "ignore_agent", verbose, (output,), kwargs)

    def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when tool errors."""
        self._handle_event("on_tool_error", "ignore_agent", verbose, (error,), kwargs)

    def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
        """Run on additional input from chains and agents (no ignore_* opt-out)."""
        self._handle_event("on_text", None, verbose, (text,), kwargs)

    def on_agent_finish(
        self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        self._handle_event(
            "on_agent_finish", "ignore_agent", verbose, (finish,), kwargs
        )

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager.

        Raises:
            ValueError: If ``handler`` is not registered (from ``list.remove``).
        """
        self.handlers.remove(handler)

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = handlers
class AsyncCallbackHandler(BaseCallbackHandler):
    """Callback handler whose event hooks are coroutines.

    Every hook below is an ``async`` no-op that returns ``None``. Subclass it
    and override only the events you want to react to.
    """

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Awaited as an LLM begins processing ``prompts``."""

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Awaited for each freshly generated token (streaming mode only)."""

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Awaited once the LLM has produced ``response``."""

    async def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Awaited when the LLM call fails with ``error``."""

    async def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Awaited as a chain begins running on ``inputs``."""

    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Awaited once a chain has produced ``outputs``."""

    async def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Awaited when a chain fails with ``error``."""

    async def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Awaited as a tool begins running on ``input_str``."""

    async def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Awaited once a tool has produced ``output``."""

    async def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Awaited when a tool fails with ``error``."""

    async def on_text(self, text: str, **kwargs: Any) -> None:
        """Awaited for arbitrary ``text`` emitted during a run."""

    async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
        """Awaited when the agent decides on ``action``."""

    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Awaited when the agent completes with ``finish``."""
class AsyncCallbackManager(BaseCallbackManager):
    """Async callback manager that can be used to handle callbacks from LangChain.

    Every ``on_*`` method is a coroutine. Handlers are notified sequentially,
    in list order. A handler is skipped if it opted out of the event's family
    via ``ignore_llm``, ``ignore_chain`` or ``ignore_agent``; ``on_text`` has
    no opt-out. The remaining handlers are called only when ``verbose`` or the
    handler's ``always_verbose`` is true.

    Handler methods that are coroutine functions are awaited directly. Plain
    methods run in the event loop's default executor (a thread pool), and that
    call is awaited. Extra keyword arguments are passed through unchanged.
    """

    @property
    def is_async(self) -> bool:
        """Whether the callback manager is async (always ``True`` here).

        This is a property so callers can test ``manager.is_async``. As a plain
        method that expression would be a bound method, which is always truthy.
        """
        return True

    def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
        """Initialize callback manager.

        Args:
            handlers: Handlers to notify. The list object is stored, not copied,
                so ``add_handler``/``remove_handler`` also mutate the caller's
                list.
        """
        self.handlers: List[BaseCallbackHandler] = handlers

    async def _handle_event(
        self,
        event: str,
        ignore_flag: Optional[str],
        verbose: bool,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        """Dispatch handler method ``event`` to every handler selected for it.

        ``args`` and ``kwargs`` are taken as a tuple and a dict rather than
        ``*args``/``**kwargs``. This way, caller-supplied keyword names can
        never collide with this helper's own parameter names.

        Args:
            event: Name of the handler method to call, e.g. ``"on_llm_start"``.
            ignore_flag: Name of the handler attribute that opts out of this
                event family, or ``None`` when the event cannot be ignored.
            verbose: Caller's verbosity flag.
            args: Positional arguments for the handler method.
            kwargs: Keyword arguments for the handler method.
        """
        for handler in self.handlers:
            if ignore_flag is not None and getattr(handler, ignore_flag):
                continue
            if not (verbose or handler.always_verbose):
                continue
            callback = getattr(handler, event)
            if asyncio.iscoroutinefunction(callback):
                await callback(*args, **kwargs)
            else:
                # Run synchronous handlers in the default executor so they do
                # not block the event loop; we are inside a coroutine, so a
                # running loop is guaranteed.
                await asyncio.get_running_loop().run_in_executor(
                    None, functools.partial(callback, *args, **kwargs)
                )

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        await self._handle_event(
            "on_llm_start", "ignore_llm", verbose, (serialized, prompts), kwargs
        )

    async def on_llm_new_token(
        self, token: str, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        await self._handle_event(
            "on_llm_new_token", "ignore_llm", verbose, (token,), kwargs
        )

    async def on_llm_end(
        self, response: LLMResult, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when LLM ends running."""
        await self._handle_event(
            "on_llm_end", "ignore_llm", verbose, (response,), kwargs
        )

    async def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        await self._handle_event(
            "on_llm_error", "ignore_llm", verbose, (error,), kwargs
        )

    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        await self._handle_event(
            "on_chain_start", "ignore_chain", verbose, (serialized, inputs), kwargs
        )

    async def on_chain_end(
        self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when chain ends running."""
        await self._handle_event(
            "on_chain_end", "ignore_chain", verbose, (outputs,), kwargs
        )

    async def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when chain errors."""
        await self._handle_event(
            "on_chain_error", "ignore_chain", verbose, (error,), kwargs
        )

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        await self._handle_event(
            "on_tool_start", "ignore_agent", verbose, (serialized, input_str), kwargs
        )

    async def on_tool_end(
        self, output: str, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when tool ends running."""
        await self._handle_event(
            "on_tool_end", "ignore_agent", verbose, (output,), kwargs
        )

    async def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        verbose: bool = False,
        **kwargs: Any
    ) -> None:
        """Run when tool errors."""
        await self._handle_event(
            "on_tool_error", "ignore_agent", verbose, (error,), kwargs
        )

    async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
        """Run when text is printed (no ignore_* opt-out)."""
        await self._handle_event("on_text", None, verbose, (text,), kwargs)

    async def on_agent_action(
        self, action: AgentAction, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run on agent action."""
        await self._handle_event(
            "on_agent_action", "ignore_agent", verbose, (action,), kwargs
        )

    async def on_agent_finish(
        self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
    ) -> None:
        """Run when agent finishes."""
        await self._handle_event(
            "on_agent_finish", "ignore_agent", verbose, (finish,), kwargs
        )

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager.

        Raises:
            ValueError: If ``handler`` is not registered (from ``list.remove``).
        """
        self.handlers.remove(handler)

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = handlers