|
from __future__ import annotations

from copy import deepcopy
from importlib.util import resolve_name
from typing import Any, Dict

import hydra
from langchain.tools import BaseTool

from flows.base_flows import AtomicFlow
|
|
|
|
|
class LCToolFlow(AtomicFlow):
    """Atomic flow that wraps a single LangChain ``BaseTool``.

    The tool (the "backend") is built with Hydra from the ``backend`` entry of
    the flow config. Each call to :meth:`run` forwards the input data to the
    tool and returns the tool's result under the ``"observation"`` key.
    """

    # Config keys this flow needs. NOTE(review): presumably checked by the
    # AtomicFlow base class — confirm against flows.base_flows.
    REQUIRED_KEYS_CONFIG = ["backend"]

    # Opts this flow out of caching. NOTE(review): assumed to be consulted by
    # the base class; confirm in flows.base_flows.
    SUPPORTS_CACHING: bool = False

    # The LangChain tool invoked by `run`.
    backend: BaseTool

    def __init__(self, backend: BaseTool, **kwargs) -> None:
        """Store the tool backend; forward everything else to ``AtomicFlow``.

        Args:
            backend: The instantiated LangChain tool that :meth:`run` calls.
            **kwargs: Passed unchanged to ``AtomicFlow.__init__`` (e.g. the
                ``flow_config`` supplied by :meth:`instantiate_from_config`).
        """
        super().__init__(**kwargs)
        self.backend = backend

    @classmethod
    def _set_up_backend(cls, config: Dict[str, Any]) -> BaseTool:
        """Instantiate the tool described by ``config`` using Hydra.

        A ``_target_`` that starts with ``.`` is relative to the package that
        contains this class's module, following Python's relative-import rules:
        one leading dot means that package, and each additional dot goes one
        level up. For example, for a class defined in ``pkg.flows.my_flow``,
        ``".tools.MyTool"`` resolves to ``"pkg.flows.tools.MyTool"``.

        The caller's ``config`` is never modified; resolution happens on a copy.

        Args:
            config: Hydra instantiation config for the tool. It must contain a
                ``_target_`` key.

        Returns:
            The object built by ``hydra.utils.instantiate`` from the config.
            It is expected to be a ``BaseTool``.

        Raises:
            KeyError: If ``config`` has no ``_target_`` key.
            ImportError: If a relative ``_target_`` cannot be resolved. This
                happens when the class lives in a top-level module, or when
                there are more leading dots than package levels.
        """
        # Copy first, so rewriting `_target_` does not leak back into the
        # caller's config (e.g. the `backend` entry of the flow config).
        config = deepcopy(config)

        target = config["_target_"]
        if target.startswith("."):
            # Parent package of this class's module ("" for a top-level module).
            cls_parent_module = cls.__module__.rpartition(".")[0]
            config["_target_"] = resolve_name(target, cls_parent_module)

        # See Hydra's `_convert_` documentation for what "partial" converts.
        return hydra.utils.instantiate(config, _convert_="partial")

    @classmethod
    def instantiate_from_config(cls, config: Dict[str, Any]) -> LCToolFlow:
        """Build an ``LCToolFlow`` from a flow config.

        Args:
            config: Flow configuration. Its ``"backend"`` entry is the Hydra
                config of the tool to wrap.

        Returns:
            A new instance. Its ``flow_config`` is a deep copy of ``config``,
            and its backend is the tool instantiated from ``config["backend"]``.
        """
        # The flow keeps its own copy, independent of later caller changes.
        flow_config = deepcopy(config)
        backend = cls._set_up_backend(config["backend"])
        return cls(flow_config=flow_config, backend=backend)

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke the wrapped tool on ``input_data``.

        Args:
            input_data: Passed as-is to ``self.backend.run`` as ``tool_input``.
                NOTE(review): LangChain parses dict inputs against the tool's
                args schema, so the keys presumably must match it — confirm.

        Returns:
            ``{"observation": <value returned by the tool>}``.
        """
        observation = self.backend.run(tool_input=input_data)
        return {"observation": observation}
|
|
|
|