# Page-header residue from the hosting site (not part of the module):
#   jfeng1115's picture
#   init commit 58d33f0
"""Common schema objects."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, NamedTuple, Optional
from pydantic import BaseModel, Extra, Field, root_validator
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render messages as a newline-separated ``"<role>: <content>"`` transcript.

    Args:
        messages: Messages to render, in order.
        human_prefix: Role label used for ``HumanMessage`` instances.
        ai_prefix: Role label used for ``AIMessage`` instances.

    Returns:
        One line per message joined with ``"\\n"``; an empty string when
        ``messages`` is empty.

    Raises:
        ValueError: If a message is not a Human/AI/System/Chat message.
    """

    def _role_for(message: BaseMessage) -> str:
        # Checks run in the same order as the public message classes are
        # documented: human, AI, system, then arbitrary-role chat messages.
        if isinstance(message, HumanMessage):
            return human_prefix
        if isinstance(message, AIMessage):
            return ai_prefix
        if isinstance(message, SystemMessage):
            return "System"
        if isinstance(message, ChatMessage):
            return message.role
        raise ValueError(f"Got unsupported message type: {message}")

    return "\n".join(f"{_role_for(msg)}: {msg.content}" for msg in messages)
class AgentAction(NamedTuple):
    """Action an agent has decided to take.

    Immutable 3-tuple laid out as ``(tool, tool_input, log)``.
    """

    # Identifier of the tool to use (presumably the tool's name -- confirm
    # against callers).
    tool: str
    # String input handed to that tool.
    tool_input: str
    # Free-form log text carried along with the action.
    log: str
class AgentFinish(NamedTuple):
    """Final return value of an agent.

    Immutable 2-tuple laid out as ``(return_values, log)``.
    """

    # Dictionary of values the agent returns.
    return_values: dict
    # Free-form log text carried along with the result.
    log: str
class AgentClarify(NamedTuple):
    """Clarification request issued by an agent.

    Immutable 2-tuple laid out as ``(question, log)``.
    """

    # The clarifying question text.
    question: str
    # Free-form log text carried along with the request.
    log: str
class Generation(BaseModel):
    """A single piece of generated output."""

    # Generated text output.
    text: str

    # Raw generation info response from the provider; may include things like
    # the reason for finishing (e.g. in OpenAI). ``None`` when not supplied.
    generation_info: Optional[Dict[str, Any]] = None
    # TODO: add log probs
class BaseMessage(BaseModel):
    """Base class for a single message.

    Subclasses must provide the ``type`` property.
    """

    # Text body of the message.
    content: str
    # Additional keyword data; every instance gets its own empty dict by default.
    additional_kwargs: dict = Field(default_factory=dict)

    @property
    @abstractmethod
    def type(self) -> str:
        """Return the string tag identifying this message class when serialized."""
class HumanMessage(BaseMessage):
    """Message spoken by the human."""

    @property
    def type(self) -> str:
        """Return the serialization tag for this class, ``"human"``."""
        return "human"
class AIMessage(BaseMessage):
    """Message spoken by the AI."""

    @property
    def type(self) -> str:
        """Return the serialization tag for this class, ``"ai"``."""
        return "ai"
class SystemMessage(BaseMessage):
    """System message."""

    @property
    def type(self) -> str:
        """Return the serialization tag for this class, ``"system"``."""
        return "system"
class ChatMessage(BaseMessage):
    """Message attributed to an arbitrary speaker."""

    # Label of the speaker this message belongs to.
    role: str

    @property
    def type(self) -> str:
        """Return the serialization tag for this class, ``"chat"``."""
        return "chat"
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
    """Serialize every message with ``_message_to_dict``, preserving order.

    Args:
        messages: Messages to serialize.

    Returns:
        A new list with one serialized dictionary per input message.
    """
    serialized = []
    for message in messages:
        serialized.append(_message_to_dict(message))
    return serialized
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
else:
raise ValueError(f"Got unexpected type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
    """Rebuild message objects from their dictionary form, preserving order.

    Args:
        messages: Serialized messages, each accepted by ``_message_from_dict``.

    Returns:
        A new list with one message object per input dictionary.
    """
    restored = []
    for raw in messages:
        restored.append(_message_from_dict(raw))
    return restored
class ChatGeneration(Generation):
    """Output of a single chat generation.

    ``text`` is derived from ``message.content`` by the ``set_text`` root
    validator, so callers only need to supply ``message``.
    """

    # Placeholder default; overwritten from ``message.content`` during validation.
    text = ""
    # The generated message.
    message: BaseMessage

    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Copy the message content into ``text``.

        Args:
            values: Field values validated so far.

        Returns:
            ``values`` with ``text`` set to ``message.content`` when a valid
            ``message`` is present.
        """
        # When ``message`` is missing or failed validation it is absent from
        # ``values``; indexing it directly would raise a bare KeyError that
        # masks pydantic's own validation error, so skip the copy instead.
        message = values.get("message")
        if message is not None:
            values["text"] = message.content
        return values
class ChatResult(BaseModel):
    """Container for everything returned by a chat model call."""

    # List of the things generated.
    generations: List[ChatGeneration]

    # Arbitrary LLM-provider-specific output; ``None`` when not supplied.
    llm_output: Optional[dict] = None
class LLMResult(BaseModel):
    """Container for everything returned by an LLM call."""

    # List of the things generated. This is a list of lists because each
    # input could have multiple generations.
    generations: List[List[Generation]]

    # Arbitrary LLM-provider-specific output; ``None`` when not supplied.
    llm_output: Optional[dict] = None
class PromptValue(BaseModel, ABC):
    """Abstract prompt that can be rendered as a string or as messages."""

    @abstractmethod
    def to_string(self) -> str:
        """Render the prompt as a single string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Render the prompt as a list of messages."""
class BaseLanguageModel(BaseModel, ABC):
    """Abstract interface for language models driven by prompt values.

    Subclasses implement the sync/async ``generate_prompt`` pair; token
    counting helpers are provided with a default implementation.
    """

    @abstractmethod
    def generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult.

        Args:
            prompts: Prompt values to generate from.
            stop: Optional list of strings; how it is applied is left to the
                implementation.
        """

    @abstractmethod
    async def agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Asynchronously take in a list of prompt values and return an LLMResult.

        Args:
            prompts: Prompt values to generate from.
            stop: Optional list of strings; how it is applied is left to the
                implementation.
        """

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Args:
            text: Text to tokenize.

        Returns:
            Number of tokens produced by the pretrained ``"gpt2"`` tokenizer.

        Raises:
            ValueError: If the ``transformers`` package cannot be imported.
        """
        # TODO: this method may not be exact.
        # TODO: this method may differ based on model (eg codex).
        try:
            # Imported lazily so `transformers` stays an optional dependency.
            from transformers import GPT2TokenizerFast
        except ImportError as exc:
            # ValueError is kept for callers that already catch it; the
            # original ImportError is chained so its details are not lost.
            raise ValueError(
                "Could not import transformers python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install transformers`."
            ) from exc

        # Load the pretrained GPT-2 tokenizer and count the tokens it yields.
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        return len(tokenizer.tokenize(text))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the total number of tokens across the given messages.

        Each message is rendered on its own with ``get_buffer_string`` (so its
        role prefix is included) and counted with ``get_num_tokens``.

        Args:
            messages: Messages to count.

        Returns:
            Sum of the per-message token counts; ``0`` for an empty list.
        """
        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
class BaseMemory(BaseModel, ABC):
    """Abstract interface that chain memory implementations must satisfy."""

    class Config:
        """Pydantic settings for memory objects."""

        arbitrary_types_allowed = True
        extra = Extra.forbid

    @property
    @abstractmethod
    def memory_variables(self) -> List[str]:
        """Return the input keys this memory class loads dynamically."""

    @abstractmethod
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain.

        If None, return all memories.
        """

    @abstractmethod
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Persist the inputs and outputs of one model run to memory."""

    @abstractmethod
    def clear(self) -> None:
        """Remove everything held in memory."""
class Document(BaseModel):
    """A page of text with metadata and a stateful, cmd-F style lookup."""

    # Full text of the page.
    page_content: str
    # Lower-cased term of the most recent ``lookup`` call ("" before any call).
    lookup_str: str = ""
    # Zero-based position within the matches for ``lookup_str``.
    lookup_index = 0
    # Arbitrary metadata; every instance gets its own empty dict by default.
    metadata: dict = Field(default_factory=dict)

    @property
    def paragraphs(self) -> List[str]:
        """Return the page content split on blank lines (``"\\n\\n"``)."""
        return self.page_content.split("\n\n")

    @property
    def summary(self) -> str:
        """Return the first paragraph of the page."""
        return self.paragraphs[0]

    def lookup(self, string: str) -> str:
        """Find the next paragraph containing ``string``, case-insensitively.

        Repeating the same term advances to the next matching paragraph; a
        different term restarts from the first match.

        Args:
            string: Term to search for.

        Returns:
            ``"(Result i/n) <paragraph>"`` for a hit, ``"No Results"`` when no
            paragraph matches, or ``"No More Results"`` once every match has
            been returned.
        """
        needle = string.lower()
        if needle == self.lookup_str:
            # Same term as last time: step to the next match.
            self.lookup_index += 1
        else:
            self.lookup_str = needle
            self.lookup_index = 0

        matches = [para for para in self.paragraphs if self.lookup_str in para.lower()]
        if not matches:
            return "No Results"
        if self.lookup_index >= len(matches):
            return "No More Results"
        position = f"(Result {self.lookup_index + 1}/{len(matches)})"
        return f"{position} {matches[self.lookup_index]}"
class BaseRetriever(ABC):
    """Abstract interface for objects that fetch documents for a query."""

    @abstractmethod
    def get_relevant_texts(self, query: str) -> List[Document]:
        """Get texts relevant for a query.

        Args:
            query: String to find relevant texts for.

        Returns:
            List of relevant documents.
        """
# For backwards compatibility: keep `Memory` available as an alias of
# `BaseMemory` (both names refer to the same class object).
Memory = BaseMemory
class BaseOutputParser(BaseModel, ABC):
    """Class to parse the output of an LLM call.

    Subclasses must implement ``parse``. ``get_format_instructions`` and
    ``_type`` raise ``NotImplementedError`` unless overridden.
    """

    @abstractmethod
    def parse(self, text: str) -> Any:
        """Parse the output of an LLM call.

        Args:
            text: Raw text to parse.

        Returns:
            The parsed representation; its shape is defined by the subclass.
        """

    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
        """Parse a completion, with the originating prompt available as context.

        The default implementation ignores ``prompt`` and delegates to
        ``parse``.

        Args:
            completion: Text to parse.
            prompt: Prompt value associated with the completion.

        Returns:
            Whatever ``parse(completion)`` returns.
        """
        return self.parse(completion)

    def get_format_instructions(self) -> str:
        """Return instructions describing the expected output format.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def _type(self) -> str:
        """Return the type key.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of output parser.

        Args:
            **kwargs: Forwarded to the parent ``dict()`` implementation, so
                its export options are honored rather than silently dropped.

        Returns:
            The parent's dictionary plus a ``"_type"`` entry holding
            ``self._type``.

        Raises:
            NotImplementedError: If the subclass does not define ``_type``.
        """
        output_parser_dict = super().dict(**kwargs)
        output_parser_dict["_type"] = self._type
        return output_parser_dict
class OutputParserException(Exception):
    """Error type that output parsers raise to signal a parsing failure.

    Having a dedicated exception separates parsing problems from other code
    or execution errors that may also occur inside an output parser: an
    OutputParserException can be caught and handled in ways that fix the
    parsing error, while any other error is simply raised.
    """