# Spaces: Runtime error (page-status residue captured with this file; not part of the module)
"""Unit tests for ReAct."""
from typing import Any, List, Mapping, Optional, Union | |
from pydantic import BaseModel | |
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent | |
from langchain.agents.tools import Tool | |
from langchain.docstore.base import Docstore | |
from langchain.docstore.document import Document | |
from langchain.llms.base import LLM | |
from langchain.prompts.prompt import PromptTemplate | |
from langchain.schema import AgentAction | |
# Fixed page text used as the content of the fake document in these tests.
# The trailing " | |" table residue that had leaked into every line of the
# literal has been removed so the string holds only the page text.
# NOTE(review): blank lines were lost from this file during extraction;
# confirm whether the paragraphs here were originally separated by blank lines.
_PAGE_CONTENT = """This is a page about LangChain.
It is a really cool framework.
What isn't there to love about langchain?
Made in 2022."""
# Pass-through prompt: the template consists solely of the "input" variable.
_FAKE_PROMPT = PromptTemplate(
    template="{input}",
    input_variables=["input"],
)
class FakeListLLM(LLM, BaseModel):
    """Fake LLM for testing that outputs elements of a list.

    Each call to ``_call`` advances an internal counter and returns the
    response stored at that position; the prompt and ``stop`` arguments
    are not inspected.
    """

    # Canned completions, handed out one per call in list order.
    responses: List[str]
    # Index of the most recently returned response; -1 means none yet.
    i: int = -1

    # Exposed as a property, matching how LangChain's custom-LLM interface
    # defines ``_llm_type`` (attribute access must yield the string itself,
    # not a bound method).
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake_list"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Increment counter, and then return response in that index.

        Raises:
            IndexError: If called more times than there are responses.
        """
        self.i += 1
        return self.responses[self.i]

    # Also a property in LangChain's custom-LLM interface; see note above.
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the identifying parameters (an empty mapping for this fake)."""
        return {}
class FakeDocstore(Docstore):
    """Docstore stub for tests: every search yields the same canned page."""

    def search(self, search: str) -> Union[str, Document]:
        """Ignore the query and return a Document wrapping ``_PAGE_CONTENT``."""
        return Document(page_content=_PAGE_CONTENT)
def test_predict_until_observation_normal() -> None:
    """Check plan() when the first LLM response already contains an action.

    The fake LLM emits a thought followed by ``Action 1: Search[foo]``; the
    planned action is expected to name the ``Search`` tool with input
    ``foo`` and carry the raw LLM text as its log.
    """
    llm_text = "foo\nAction 1: Search[foo]"
    search_tool = Tool(name="Search", func=lambda x: x, description="foo")
    lookup_tool = Tool(name="Lookup", func=lambda x: x, description="bar")
    agent = ReActDocstoreAgent.from_llm_and_tools(
        FakeListLLM(responses=[llm_text]),
        [search_tool, lookup_tool],
    )

    planned = agent.plan([], input="")

    assert planned == AgentAction("Search", "foo", llm_text)
def test_predict_until_observation_repeat() -> None:
    """Check plan() when the first LLM response contains no action.

    The fake LLM first answers with a bare thought and then with
    `` Search[foo]``; the expected action's log is the thought, a newline,
    and ``Action 1: Search[foo]``.
    """
    llm_replies = ["foo", " Search[foo]"]
    search_tool = Tool(name="Search", func=lambda x: x, description="foo")
    lookup_tool = Tool(name="Lookup", func=lambda x: x, description="bar")
    agent = ReActDocstoreAgent.from_llm_and_tools(
        FakeListLLM(responses=llm_replies),
        [search_tool, lookup_tool],
    )

    planned = agent.plan([], input="")

    wanted = AgentAction("Search", "foo", "foo\nAction 1: Search[foo]")
    assert planned == wanted
def test_react_chain() -> None:
    """Run a ReActChain through search, lookup, and finish steps.

    The scripted LLM issues ``Search``, then ``Lookup``, then ``Finish[2022]``
    against the fake docstore; the chain's output must be ``"2022"``.
    """
    scripted_steps = [
        "I should probably search\nAction 1: Search[langchain]",
        "I should probably lookup\nAction 2: Lookup[made]",
        "Ah okay now I know the answer\nAction 3: Finish[2022]",
    ]
    chain = ReActChain(
        llm=FakeListLLM(responses=scripted_steps),
        docstore=FakeDocstore(),
    )

    answer = chain.run("when was langchain made")

    assert answer == "2022"
def test_react_chain_bad_action() -> None:
    """Run a ReActChain whose first scripted action names an unknown tool.

    The scripted LLM first requests ``BadAction`` and then finishes; the
    chain's output must be the text given to ``Finish``.
    """
    unknown_tool = "BadAction"
    scripted_steps = [
        f"I'm turning evil\nAction 1: {unknown_tool}[langchain]",
        "Oh well\nAction 2: Finish[curses foiled again]",
    ]
    chain = ReActChain(
        llm=FakeListLLM(responses=scripted_steps),
        docstore=FakeDocstore(),
    )

    answer = chain.run("when was langchain made")

    assert answer == "curses foiled again"