Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
from typing import List, Optional | |
from pydantic import ValidationError | |
from langchain.chains.llm import LLMChain | |
from langchain.chat_models.base import BaseChatModel | |
from .output_parser import ( | |
AutoGPTOutputParser, | |
BaseAutoGPTOutputParser, | |
) | |
from .prompt import AutoGPTPrompt | |
from .prompt_generator import ( | |
FINISH_NAME, | |
) | |
from langchain.schema import ( | |
AIMessage, | |
BaseMessage, | |
Document, | |
HumanMessage, | |
SystemMessage, | |
) | |
from langchain.tools.base import BaseTool | |
from langchain.tools.human.tool import HumanInputRun | |
from langchain.vectorstores.base import VectorStoreRetriever | |
import json | |
class AutoGPT:
    """Agent that runs an Auto-GPT style think/act loop over a set of tools.

    Each iteration asks ``chain`` for the next command, parses the reply with
    ``output_parser``, runs the named tool (or reports an error back to the
    model), and records the exchange in ``full_message_history``.
    """

    def __init__(
        self,
        ai_name: str,
        memory: VectorStoreRetriever,
        chain: LLMChain,
        output_parser: BaseAutoGPTOutputParser,
        tools: List[BaseTool],
        feedback_tool: Optional[HumanInputRun] = None,
    ):
        """Store the agent's collaborators.

        Args:
            ai_name: Display name of the agent.
            memory: Retriever passed to the chain as ``memory`` on every turn.
            chain: LLM chain that produces the assistant's reply.
            output_parser: Turns a reply into an action (name + args).
            tools: Tools the agent may invoke by name.
            feedback_tool: Optional human-input tool. It is stored here but
                not consulted by ``__call__`` in this implementation.
        """
        self.ai_name = ai_name
        self.memory = memory
        # Conversation replayed to the model on every turn.
        self.full_message_history: List[BaseMessage] = []
        self.next_action_count = 0
        self.chain = chain
        self.output_parser = output_parser
        self.tools = tools
        self.feedback_tool = feedback_tool

    @classmethod
    def from_llm_and_tools(
        cls,
        ai_name: str,
        ai_role: str,
        memory: VectorStoreRetriever,
        tools: List[BaseTool],
        llm: BaseChatModel,
        human_in_the_loop: bool = False,
        output_parser: Optional[BaseAutoGPTOutputParser] = None,
    ) -> AutoGPT:
        """Build an agent whose prompt and chain are derived from ``llm``.

        Args:
            ai_name: Display name of the agent, also injected into the prompt.
            ai_role: Role description injected into the prompt.
            memory: Retriever handed to the constructed agent.
            tools: Tools listed in the prompt and available to the agent.
            llm: Chat model used both for the chain and for token counting.
            human_in_the_loop: When true, attach a ``HumanInputRun`` feedback
                tool.
            output_parser: Custom parser; defaults to ``AutoGPTOutputParser()``.

        Returns:
            A new instance of ``cls``.
        """
        prompt = AutoGPTPrompt(
            ai_name=ai_name,
            ai_role=ai_role,
            tools=tools,
            input_variables=["memory", "messages", "goals", "user_input"],
            token_counter=llm.get_num_tokens,
        )
        human_feedback_tool = HumanInputRun() if human_in_the_loop else None
        chain = LLMChain(llm=llm, prompt=prompt)
        return cls(
            ai_name,
            memory,
            chain,
            output_parser or AutoGPTOutputParser(),
            tools,
            feedback_tool=human_feedback_tool,
        )

    @staticmethod
    def _format_tool_input(args: dict, history: List[str], goals: List[str]) -> str:
        """Encode a tool call's arguments, plus agent context, as a JSON string.

        The parsed command arguments are copied (the caller's dict is not
        mutated) and extended with:
          * ``"history context"``: repr of the last 5 history entries,
            truncated to its final 500 characters;
          * ``"user message"``: the first goal, or ``""`` if there are none.

        ``json.dumps`` escapes embedded quotes correctly. The previous
        ``str(dict).replace("'", '"')`` emitted invalid JSON whenever a value
        contained a quote or was ``None``/``True``/``False``.
        NOTE(review): presumes the tools parse their string input as JSON,
        as the original quote-replacement implied; confirm against the tools.
        """
        payload = dict(args)
        payload["history context"] = str(history[-5:])[-500:]
        payload["user message"] = goals[0] if goals else ""
        return json.dumps(payload, ensure_ascii=False)

    def __call__(self, goals: List[str]) -> str:
        """Run the agent loop until the model issues the finish command.

        Args:
            goals: Goals passed to the prompt. The first goal is also
                forwarded to every tool call as ``"user message"``.

        Returns:
            The ``"response"`` argument of the ``FINISH_NAME`` command.

        Note:
            There is no iteration limit. The loop only exits when the model
            emits ``FINISH_NAME`` or an exception propagates.
        """
        user_input = (
            "Determine which next command to use, "
            "and respond using the format specified above:"
        )
        # The tool registry does not change during the loop; build it once.
        tools = {t.name: t for t in self.tools}
        # Tool outputs forwarded to later tool calls as "history context".
        history_rec: List[str] = []
        while True:
            # Send message to AI, get response
            assistant_reply = self.chain.run(
                goals=goals,
                messages=self.full_message_history,
                memory=self.memory,
                user_input=user_input,
            )
            # Drop any text the model emitted before the first "{".
            pos = assistant_reply.find("{")
            if pos > 0:
                assistant_reply = assistant_reply[pos:]
            # Print Assistant thoughts
            print(assistant_reply)
            self.full_message_history.append(HumanMessage(content=user_input))
            self.full_message_history.append(AIMessage(content=assistant_reply))

            # Get command name and arguments
            action = self.output_parser.parse(assistant_reply)
            if action.name == FINISH_NAME:
                return action.args["response"]

            tool = tools.get(action.name)
            if tool is not None:
                try:
                    json_args = self._format_tool_input(
                        action.args, history_rec, goals
                    )
                    observation = tool.run(json_args)
                except ValidationError as e:
                    observation = f"Error in args: {str(e)}"
                result = f"Command {tool.name} returned: {observation}"
                # Presumably these phrases mark a tool failing to answer;
                # such outputs are kept out of the forwarded history.
                if (
                    "using the given APIs" not in result
                    and "no answer" not in result.lower()
                ):
                    history_rec.append(f"Tool {action.name} returned: {observation}")
            elif action.name == "ERROR":
                result = f"Error: {action.args}. "
            else:
                result = (
                    f"Unknown command '{action.name}'. "
                    f"Please refer to the 'COMMANDS' list for available "
                    f"commands and only respond in the specified JSON format."
                )
            self.full_message_history.append(SystemMessage(content=result))