Spaces:
Runtime error
Runtime error
import transformers | |
import argparse | |
import json | |
from petals.client.remote_model import DistributedBloomForCausalLM | |
from personalized_chat_bot import PersonalizedChatBot, PersonalityManager | |
from models.personality_clustering import PersonalityClustering | |
def load_config(path) -> argparse.Namespace:
    """Load a JSON config file and expose its top-level keys as attributes.

    Args:
        path: Path to a JSON file whose top-level value is an object.

    Returns:
        An ``argparse.Namespace`` with one attribute per top-level key, so
        callers can write ``config.MODEL_NAME`` instead of
        ``config['MODEL_NAME']``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        TypeError: If the top-level JSON value is not an object.
    """
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    # Namespace(**config) needs a mapping; fail with a clear message rather
    # than the opaque "argument after ** must be a mapping" error.
    if not isinstance(config, dict):
        raise TypeError(
            f'config file {path!r} must contain a JSON object at the top '
            f'level, got {type(config).__name__}'
        )
    return argparse.Namespace(**config)
def _select_prompt(
    persona_description,
    clustering_path='./data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl',
    prompt_paths_file='prompt_paths.json',
):
    """Map a free-text persona description to a trained prompt file.

    Args:
        persona_description: The user's description of the desired persona.
        clustering_path: Saved ``PersonalityClustering`` model to load.
        prompt_paths_file: JSON object mapping integer-like keys to prompt
            file paths (keys presumably are cluster ids — confirm against
            ``PersonalityManager``).

    Returns:
        ``(prompt_path, closest_persona)`` as returned by
        ``PersonalityManager.get_prompt``.
    """
    personality_clustering = PersonalityClustering()
    # NOTE(review): if ``load`` unpickles this file, only point it at
    # trusted model files.
    personality_clustering.load(clustering_path)

    # JSON object keys are always strings; convert them back to ints.
    # object_hook runs on EVERY decoded object, so any nested object in the
    # file must also have integer-like keys or int() will raise ValueError.
    def int_keys(dct):
        return {int(k): v for k, v in dct.items()}

    with open(prompt_paths_file, 'r', encoding='utf-8') as f:
        prompt_paths = json.load(f, object_hook=int_keys)

    pm = PersonalityManager(prompt_paths, personality_clustering)
    return pm.get_prompt(persona_description)


def _build_chatbot(
    config_path='./scripts/config_176b.json',
    generation_config_path='generation_config.json',
):
    """Load the distributed Bloom model and tokenizer and wrap them in a chatbot.

    Args:
        config_path: JSON config providing ``MODEL_NAME``,
            ``NUM_PREFIX_TOKENS``, ``TUNING_MODE``, ``DEVICE`` and
            ``MODEL_MAX_LENGTH``.
        generation_config_path: JSON config passed to the chatbot as its
            ``generation_config``.

    Returns:
        A ``PersonalizedChatBot`` with no prompt loaded yet.
    """
    config = load_config(config_path)
    model = DistributedBloomForCausalLM.from_pretrained(
        config.MODEL_NAME,
        pre_seq_len=config.NUM_PREFIX_TOKENS,
        tuning_mode=config.TUNING_MODE,
    ).to(config.DEVICE)
    generation_config = load_config(generation_config_path)
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = 'right'
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH
    return PersonalizedChatBot(model, tokenizer, generation_config=generation_config)


def _chat_loop(chatbot):
    """Read user turns from stdin and print the bot's answers until the user quits."""
    try:
        while True:
            text = input('You: ')
            answer = chatbot.answer(text)
            print(f'Bloom: {answer}')
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C raises KeyboardInterrupt; Ctrl-D / closed stdin makes input()
        # raise EOFError. Both end the session. The bare print() moves off the
        # dangling "You: " prompt line first.
        print()
        print('Thank you for the conversation!')


def main():
    """Run an interactive console chat with a Bloom model tuned to a persona.

    Steps:
      1. Ask the user to describe a persona, then pick the closest known
         personality and its prompt file.
      2. Load the distributed Bloom model, tokenizer and generation settings,
         then load the selected prompt into the chatbot.
      3. Chat on stdin/stdout until the user presses Ctrl-C or closes stdin.

    Relative file paths are resolved against the current working directory.
    """
    print('Describe the person you want to talk:')
    persona_description = input()
    print('Cool! wait a few seconds...')
    prompt_path, closest_persona = _select_prompt(persona_description)
    print(f'The closest personality is: {closest_persona}')

    print('Wait a little longer...')
    chatbot = _build_chatbot()
    chatbot.load_prompt(prompt_path)
    print('Done! You can start a dialogue.')

    _chat_loop(chatbot)
if __name__ == "__main__":
    # Start the interactive chat only when run as a script, not on import.
    main()