"""An agent designed to hold a conversation in addition to using tools.""" | |
from __future__ import annotations | |
import re | |
from typing import Any, List, Optional, Sequence, Tuple | |
from langchain.agents.agent import Agent | |
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX | |
from langchain.callbacks.base import BaseCallbackManager | |
from langchain.chains import LLMChain | |
from langchain.llms import BaseLLM | |
from langchain.prompts import PromptTemplate | |
from langchain.tools.base import BaseTool | |


class ConversationalAgent(Agent):
    """An agent designed to hold a conversation in addition to using tools."""

    ai_prefix: str = "AI"

    @property
    def _agent_type(self) -> str:
        """Return identifier of agent type."""
        return "conversational-react-description"

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""
        return "Thought:"

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        ai_prefix: str = "AI",
        human_prefix: str = "Human",
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero-shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            suffix: String to put after the list of tools.
            format_instructions: String describing how the agent should format
                its thoughts, actions, and final answers.
            ai_prefix: String to use before AI output.
            human_prefix: String to use before human output.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        tool_strings = "\n".join(
            [f"> {tool.name}: {tool.description}" for tool in tools]
        )
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(
            tool_names=tool_names, ai_prefix=ai_prefix, human_prefix=human_prefix
        )
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "chat_history", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)
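
    # Illustration only (not in the original source): the layout `create_prompt`
    # assembles, with the four sections joined by blank lines.
    #     {prefix}
    #
    #     > <tool name>: <tool description>   (one line per tool)
    #
    #     {format_instructions with tool_names / ai_prefix / human_prefix filled in}
    #
    #     {suffix}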

    @property
    def finish_tool_name(self) -> str:
        """Name of the tool to use to finish the chain."""
        return self.ai_prefix

    def _extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]:
        """Parse the LLM output into a final answer or an (action, action input) pair."""
        if f"{self.ai_prefix}:" in llm_output:
            return self.ai_prefix, llm_output.split(f"{self.ai_prefix}:")[-1].strip()
        regex = r"Action: (.*?)[\n]*Action Input: (.*)"
        match = re.search(regex, llm_output)
        if not match:
            raise ValueError(f"Could not parse LLM output: `{llm_output}`")
        action = match.group(1)
        action_input = match.group(2)
        return action.strip(), action_input.strip(" ").strip('"')
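
    # Illustration only (not in the original source): the two output shapes
    # `_extract_tool_and_input` accepts, assuming the default ai_prefix "AI".
    #     "Thought: I can answer directly\nAI: Hello!"
    #         -> ("AI", "Hello!")               (final-answer path)
    #     "Action: Search\nAction Input: \"weather\""
    #         -> ("Search", "weather")          (tool-call path; surrounding quotes stripped)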

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLLM,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        ai_prefix: str = "AI",
        human_prefix: str = "Human",
        input_variables: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            ai_prefix=ai_prefix,
            human_prefix=human_prefix,
            prefix=prefix,
            suffix=suffix,
            format_instructions=format_instructions,
            input_variables=input_variables,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        return cls(
            llm_chain=llm_chain, allowed_tools=tool_names, ai_prefix=ai_prefix, **kwargs
        )
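

# --- Usage sketch (illustrative, not part of the original module) ------------
# A minimal sketch of wiring this agent up, assuming a langchain 0.0.x-era API
# surface: the OpenAI LLM wrapper, Tool, and AgentExecutor. Import paths and
# signatures outside this file are assumptions about the surrounding library
# version and may differ in other releases.
#
# from langchain.agents import Tool
# from langchain.agents.agent import AgentExecutor
# from langchain.llms import OpenAI
#
# tools = [
#     Tool(
#         name="Echo",
#         func=lambda text: text,
#         description="Repeats its input; stands in for a real tool here.",
#     )
# ]
# agent = ConversationalAgent.from_llm_and_tools(llm=OpenAI(temperature=0), tools=tools)
# executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
#
# # The prompt expects a `chat_history` variable: pass it explicitly, or attach
# # a ConversationBufferMemory with memory_key="chat_history" to the executor.
# print(executor.run(input="Hi! Which tools can you use?", chat_history=""))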