Spaces:
Runtime error
Runtime error
"""Select examples based on length.""" | |
import re | |
from typing import Callable, Dict, List | |
from pydantic import BaseModel, validator | |
from langchain.prompts.example_selector.base import BaseExampleSelector | |
from langchain.prompts.prompt import PromptTemplate | |
def _get_length_based(text: str) -> int: | |
return len(re.split("\n| ", text)) | |
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
    """Select examples based on length.

    Examples are taken in their stored order until the length budget
    (``max_length`` minus the measured length of the inputs) is used up.
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    get_text_length: Callable[[str], int] = _get_length_based
    """Function to measure prompt length. Defaults to word count."""

    max_length: int = 2048
    """Max length for the prompt, beyond which examples are cut."""

    example_text_lengths: List[int] = []  #: :meta private:

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list.

        The example is formatted with ``example_prompt`` and its measured
        length is appended to ``example_text_lengths`` so both lists stay
        index-aligned.

        Args:
            example: Mapping of template variable names to values.
        """
        self.examples.append(example)
        string_example = self.example_prompt.format(**example)
        self.example_text_lengths.append(self.get_text_length(string_example))

    # Registered as a pydantic validator so the lengths are filled in at
    # construction time. ``always=True`` makes it run even when the field is
    # left at its default; without the registration the list stays empty and
    # ``select_examples`` fails when it indexes into it.
    @validator("example_text_lengths", always=True)
    def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
        """Calculate text lengths if they don't exist.

        Args:
            v: Lengths supplied by the caller, possibly empty.
            values: Previously validated fields; ``example_prompt``,
                ``get_text_length`` and ``examples`` are read from it.

        Returns:
            ``v`` unchanged when it is non-empty, otherwise one measured
            length per example, in the same order as ``examples``.
        """
        # Check if text lengths were passed in
        if v:
            return v
        # If they were not, calculate them
        example_prompt = values["example_prompt"]
        get_text_length = values["get_text_length"]
        string_examples = [example_prompt.format(**eg) for eg in values["examples"]]
        return [get_text_length(eg) for eg in string_examples]

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the input lengths.

        Args:
            input_variables: Input values for the prompt; they are joined
                with spaces and measured with ``get_text_length``.

        Returns:
            The leading examples, in stored order, whose combined measured
            length fits within ``max_length`` minus the length of the inputs.
        """
        inputs = " ".join(input_variables.values())
        remaining_length = self.max_length - self.get_text_length(inputs)
        examples = []
        for i, example in enumerate(self.examples):
            if remaining_length <= 0:
                # Budget exhausted; nothing more can be added.
                break
            new_length = remaining_length - self.example_text_lengths[i]
            if new_length < 0:
                # Stop at the first example that does not fit; later
                # (possibly shorter) examples are deliberately not tried.
                break
            examples.append(example)
            remaining_length = new_length
        return examples