# hemm_space / sotopia_pi_generate.py
# Added by talha1503 in commit f1f9b0c ("add application file").
import re
import os
from typing import TypeVar
from functools import cache
import logging
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import PeftModel
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.chat_models import ChatLiteLLM
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
)
from langchain.schema import BaseOutputParser, OutputParserException
from message_classes import ActionType, AgentAction
from utils import format_docstring
from langchain_callback_handler import LoggingCallbackHandler
# Optional local file holding a Hugging Face access token (one line).
HF_TOKEN_KEY_FILE = "./hf_token.key"


def _export_hf_token(path: str = HF_TOKEN_KEY_FILE) -> None:
    """Export the token stored in ``path`` as the ``HF_TOKEN`` environment variable.

    A missing file is silently ignored. Surrounding whitespace is stripped.
    An empty/blank file is also ignored, so it cannot overwrite an
    ``HF_TOKEN`` that is already set in the environment with an empty string.
    """
    try:
        with open(path, "r", encoding="utf-8") as token_file:
            token = token_file.read().strip()
    except FileNotFoundError:
        return
    if token:
        os.environ["HF_TOKEN"] = token


# Runs at import time so model downloads later in this module can use the token.
_export_hf_token()
# Type of the value an output parser produces; `generate` returns this type.
OutputType = TypeVar("OutputType", bound=object)

# Module logger for parse failures and generated results.
log = logging.getLogger(name="generate")

# Single LangChain callback handler shared by every `chain.predict` call in this
# module; `generate` reads the rendered prompt back from it after each call.
logging_handler = LoggingCallbackHandler("langchain")
def generate_action(
    model_name: str,
    history: str,
    turn_number: int,
    action_types: list[ActionType],
    agent: str,
    temperature: float = 0.7,
) -> AgentAction:
    """Generate the next action for ``agent`` at the given turn of an interaction.

    Builds a role-play prompt and delegates to ``generate``. The prompt contains:

    * a reminder of the agent's social goal;
    * the interaction history;
    * the current turn number;
    * the allowed action types.

    ``generate`` runs the chain for ``model_name`` and parses the model output
    into an ``AgentAction`` with a ``PydanticOutputParser``.

    Args:
        model_name: Model identifier understood by ``obtain_chain``.
        history: Interaction context / conversation so far, inserted verbatim
            into the prompt.
        turn_number: Current turn index, shown to the model as ``Turn #<n>``.
        action_types: Allowed action types, joined with single spaces into the
            prompt (so each element must be a string).
        agent: Name of the agent the model should act as.
        temperature: Sampling temperature forwarded to ``generate``.

    Returns:
        The parsed ``AgentAction``.

    Raises:
        Whatever ``generate`` raises. For example, a parse error that persists
        after its one reformatting retry propagates to the caller.
    """
    template = """
Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
You can find {agent}'s goal (or background) in the 'Here is the context of the interaction' field.
Note that {agent}'s goal is only visible to you.
You should try your best to achieve {agent}'s goal in a way that align with their character traits.
Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n
{history}.
You are at Turn #{turn_number}. Your available action types are
{action_list}.
Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
Please only generate a JSON string including the action type and the argument.
Your action should follow the given format:
{format_instructions}
"""
    # `format_instructions` is intentionally not supplied here: `generate`
    # fills it from the output parser.
    return generate(
        model_name=model_name,
        template=template,
        input_values=dict(
            agent=agent,
            turn_number=str(turn_number),
            history=history,
            action_list=" ".join(action_types),
        ),
        output_parser=PydanticOutputParser(pydantic_object=AgentAction),
        temperature=temperature,
    )
@cache
def prepare_model(model_name):
    """Load the model and tokenizer for one of the supported Mistral-based names.

    Supported names:

    * ``"mistralai/Mistral-7B-Instruct-v0.1"``: the base instruct model.
    * ``"cmu-lti/sotopia-pi-mistral-7b-BC_SR"``: the base model plus the PEFT
      adapter of that name.
    * ``"cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit"``: the same adapter on top
      of a 4-bit (NF4, double-quantized, float16 compute) base model. The
      adapter id is the name without the ``"_4bit"`` suffix.

    The base model is loaded with ``device_map="cuda"`` and cached under
    ``./.cache``. The tokenizer is loaded with ``model_max_length=4096``.
    ``functools.cache`` memoizes the result per ``model_name``, so each
    variant is loaded at most once per process.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        RuntimeError: if ``model_name`` is not one of the supported names.
    """
    base_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
    adapter_name = "cmu-lti/sotopia-pi-mistral-7b-BC_SR"
    quantized_adapter_name = adapter_name + "_4bit"
    if model_name not in (base_model_name, adapter_name, quantized_adapter_name):
        raise RuntimeError(f"Model {model_name} not supported")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name, model_max_length=4096)

    load_kwargs = {"cache_dir": "./.cache", "device_map": "cuda"}
    if model_name == quantized_adapter_name:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    model = AutoModelForCausalLM.from_pretrained(base_model_name, **load_kwargs)

    if model_name != base_model_name:
        # Both adapter variants share one adapter repository.
        # NOTE(review): `.to("cuda")` is kept from the original, presumably to
        # move adapter weights onto the GPU; confirm it is needed for the
        # 4-bit base model.
        model = PeftModel.from_pretrained(
            model, model_name.removesuffix("_4bit")
        ).to("cuda")
    return model, tokenizer
def obtain_chain_hf(
    model_name: str,
    template: str,
    input_variables: list[str],
    temperature: float = 0.7,
    max_retries: int = 6,
    max_tokens: int = 2700
) -> LLMChain:
    """Build an ``LLMChain`` backed by a local Hugging Face text-generation pipeline.

    The prompt is a ``ChatPromptTemplate`` holding a single human message
    rendered from ``template``. The model and tokenizer come from
    ``prepare_model(model_name)``. Generation settings:

    * at most 100 new tokens;
    * ``do_sample=True`` with ``num_beams=3``, using the given ``temperature``;
    * only the completion is returned (``return_full_text=False``).

    Note:
        ``max_retries`` and ``max_tokens`` are accepted for signature parity
        with ``obtain_chain`` but are not used by this function.
    """
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            HumanMessagePromptTemplate(
                prompt=PromptTemplate(template=template, input_variables=input_variables)
            )
        ]
    )

    model, tokenizer = prepare_model(model_name)
    generation_settings = dict(
        max_new_tokens=100,
        temperature=temperature,
        return_full_text=False,
        do_sample=True,
        num_beams=3,
    )
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        **generation_settings,
    )

    return LLMChain(llm=HuggingFacePipeline(pipeline=text_generator), prompt=chat_prompt)
def generate(
    model_name: str,
    template: str,
    input_values: dict[str, str],
    output_parser: BaseOutputParser[OutputType],
    temperature: float = 0.7,
) -> OutputType:
    """Render ``template`` with ``input_values``, run it on ``model_name`` and parse the output.

    Placeholder rules:

    * The ``{placeholder}`` names in ``template`` must equal the keys of
      ``input_values``, optionally plus ``format_instructions``.
    * If the caller does not supply ``format_instructions``, it is filled from
      ``output_parser.get_format_instructions()``.
    * The caller's dict is not modified.

    The rendered prompt and the raw result are printed to stdout. If the first
    parse fails, the raw output is sent once to ``format_bad_output`` for JSON
    repair and the repaired text is parsed again.

    Args:
        model_name: Model identifier understood by ``obtain_chain``.
        template: Prompt template with ``{name}`` placeholders.
        input_values: Values for the template placeholders.
        output_parser: Parser that turns the model text into the result.
        temperature: Sampling temperature forwarded to ``obtain_chain``.

    Returns:
        The value produced by ``output_parser.parse``.

    Raises:
        AssertionError: If the template placeholders do not match
            ``input_values``. This check uses ``assert``, so it is skipped
            under ``python -O``.
        Exception: Whatever ``output_parser.parse`` raises on the repaired
            output propagates unchanged.
    """
    input_variables = re.findall(r"{(.*?)}", template)
    template_vars = set(input_variables)
    provided_vars = set(input_values)
    assert template_vars in (provided_vars, provided_vars | {"format_instructions"}), (
        "The variables in the template must match input_values except for "
        f"format_instructions. Got {sorted(input_values.keys())}, expect {sorted(input_variables)}"
    )
    # process template
    template = format_docstring(template)
    chain = obtain_chain(model_name, template, input_variables, temperature)

    # Copy before filling defaults so the caller's dict is left untouched.
    input_values = {**input_values}
    if "format_instructions" not in input_values:
        input_values["format_instructions"] = output_parser.get_format_instructions()

    result = chain.predict([logging_handler], **input_values)
    prompt = logging_handler.retrive_prompt()
    print(f"Prompt:\n {prompt}")
    print(f"Result:\n {result}")

    try:
        parsed_result = output_parser.parse(result)
    except Exception as e:
        # KeyboardInterrupt derives from BaseException, not Exception, so it is
        # not caught here and propagates with its original traceback.
        log.debug(
            f"[red] Failed to parse result: {result}\nEncounter Exception {e}\nstart to reparse",
            extra={"markup": True},
        )
        reformat_parsed_result = format_bad_output(
            result, format_instructions=output_parser.get_format_instructions()
        )
        print(f"Reformatted result:\n {reformat_parsed_result}")
        parsed_result = output_parser.parse(reformat_parsed_result)
    log.info(f"Generated result: {parsed_result}")
    return parsed_result
def format_bad_output(
    ill_formed_output: str,
    format_instructions: str,
    model_name: str = "gpt-3.5-turbo",
) -> str:
    """Ask an LLM to rewrite ``ill_formed_output`` as JSON that a parser can accept.

    The original text and the expected ``format_instructions`` are sent to
    ``model_name`` through ``obtain_chain``, using the chain's default
    temperature. The raw model reply is logged at INFO level and returned
    without further validation.
    """
    template = """
Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
Original string: {ill_formed_output}
Format instructions: {format_instructions}
Please only generate the JSON:
"""
    placeholders = re.findall(r"{(.*?)}", template)
    repair_chain = obtain_chain(
        model_name=model_name,
        template=template,
        input_variables=placeholders,
    )
    reformatted = repair_chain.predict(
        [logging_handler],
        ill_formed_output=ill_formed_output,
        format_instructions=format_instructions,
    )
    log.info(f"Reformated output: {reformatted}")
    return reformatted
def obtain_chain(
    model_name: str,
    template: str,
    input_variables: list[str],
    temperature: float = 0.7,
    max_retries: int = 6,
    max_tokens: int = 2700,
) -> LLMChain:
    """Build an ``LLMChain`` that renders ``template`` as one human message for ``model_name``.

    The three Mistral-based model names below are served locally through
    ``obtain_chain_hf``. Any other name is first mapped through
    ``_return_fixed_model_version``, then sent to ``ChatLiteLLM``.

    Args:
        model_name: Local model name or LiteLLM model alias / id.
        template: Prompt template with ``{name}`` placeholders.
        input_variables: Placeholder names used by ``template``.
        temperature: Sampling temperature.
        max_retries: Retry count passed to ``ChatLiteLLM`` (and forwarded to
            ``obtain_chain_hf``, which does not use it).
        max_tokens: Completion token limit for ``ChatLiteLLM``. The default of
            2700 matches the previously hard-coded value. Not applied to local
            models, whose pipeline has its own length limit.

    Returns:
        The configured ``LLMChain``.
    """
    local_hf_models = (
        "cmu-lti/sotopia-pi-mistral-7b-BC_SR",
        "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
        "mistralai/Mistral-7B-Instruct-v0.1",
    )
    if model_name in local_hf_models:
        return obtain_chain_hf(
            model_name=model_name,
            template=template,
            input_variables=input_variables,
            temperature=temperature,
            max_retries=max_retries,
        )
    model_name = _return_fixed_model_version(model_name)
    chat = ChatLiteLLM(
        model=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        max_retries=max_retries,
    )
    human_message_prompt = HumanMessagePromptTemplate(
        prompt=PromptTemplate(template=template, input_variables=input_variables)
    )
    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
    return LLMChain(llm=chat, prompt=chat_prompt_template)
def _return_fixed_model_version(model_name: str) -> str:
    """Translate a short model alias into its pinned model id.

    Known aliases map to a dated or fine-tuned OpenAI model id. Any other name
    is returned unchanged.
    """
    pinned_versions = {
        "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
        "gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
        "gpt-4": "gpt-4-0613",
        "gpt-4-turbo": "gpt-4-1106-preview",
    }
    return pinned_versions.get(model_name, model_name)