Spaces:
Runtime error
Runtime error
File size: 4,035 Bytes
97822ab 3547e01 97822ab 9af533e 97822ab abcfdac 97822ab 4cb428b 97822ab 4cb428b 97822ab 4cb428b 97822ab 4cb428b 97822ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from transformers import AutoTokenizer, MistralForCausalLM
import torch
import gradio as gr
import random
from textwrap import wrap
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
from peft import PeftModel, PeftConfig
import torch
import gradio as gr
import os
import huggingface
from huggingface_hub import login
# Read the Hugging Face access token from the Space secret, if configured.
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
# Only authenticate when a token is actually present: login(None) falls back
# to an interactive token prompt, which cannot be answered in a headless
# Space. Without a token, the later from_pretrained calls still receive
# token=None and work for publicly accessible repositories.
if hf_token:
    login(hf_token)
# Functions to Wrap the Prompt Correctly
def wrap_text(text, width=90):
    """Wrap every line of *text* to at most *width* characters.

    Existing line breaks are preserved: each line is wrapped on its own and
    the results are re-joined with newlines. Empty (or whitespace-only)
    lines come back as empty lines.

    Args:
        text: The text to wrap.
        width: Maximum line width handed to ``textwrap.wrap``.

    Returns:
        The wrapped text as a single string.
    """
    # Only ``wrap`` is imported from textwrap in this file, so the previous
    # ``textwrap.fill`` call raised NameError. ``"\n".join(wrap(...))`` is
    # exactly what ``textwrap.fill`` computes.
    wrapped_lines = ["\n".join(wrap(line, width=width)) for line in text.split("\n")]
    return "\n".join(wrapped_lines)
def multimodal_prompt(user_input, system_prompt="You are an expert medical analyst:", max_length=1200):
    """Generate a single-turn completion for *user_input* with ``peft_model``.

    The prompt is *user_input* immediately followed by *system_prompt*
    (no separator), tokenized without special tokens.

    Args:
        user_input: Text typed by the user.
        system_prompt: Instruction text appended right after *user_input*.
        max_length: Upper bound on the total number of tokens (prompt plus
            generated tokens) passed to ``generate``.

    Returns:
        The decoded output sequence (prompt text followed by the sampled
        continuation), with special tokens removed.
    """
    formatted_input = f"{user_input}{system_prompt}"

    # The previous version referenced an undefined global ``model``; the only
    # model loaded in this file is ``peft_model``. Inputs go to the device the
    # model's weights are on, rather than ``device``, because the model is
    # not moved to ``device`` when it is loaded.
    model_device = next(peft_model.parameters()).device
    model_inputs = tokenizer(
        formatted_input, return_tensors="pt", add_special_tokens=False
    ).to(model_device)

    # Low-temperature sampling; EOS doubles as the padding token.
    # (early_stopping was dropped: it only affects beam search, which is
    # not used here.)
    output = peft_model.generate(
        **model_inputs,
        max_length=max_length,
        use_cache=True,
        bos_token_id=peft_model.config.bos_token_id,
        eos_token_id=peft_model.config.eos_token_id,
        pad_token_id=peft_model.config.eos_token_id,
        temperature=0.1,
        do_sample=True,
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hub repositories: the base model and the PEFT adapter applied on top of it.
base_model_id = "stabilityai/stablelm-3b-4e1t"
model_directory = "vaishakgkumar/stablemedv3"

# Instantiate the base model's tokenizer with left-side padding
# (presumably intended for decoder-only generation — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id, token=hf_token, trust_remote_code=True, padding_side="left"
)
# Reuse EOS as the padding token (presumably the base tokenizer defines no
# pad token of its own — TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

# Load the PEFT model: base weights first, then the adapter from the Hub.
# NOTE(review): peft_config is not read anywhere else in the visible file.
peft_config = PeftConfig.from_pretrained(model_directory, token=hf_token)
peft_model = AutoModelForCausalLM.from_pretrained(
    base_model_id, token=hf_token, trust_remote_code=True
)
peft_model = PeftModel.from_pretrained(peft_model, model_directory, token=hf_token)
# NOTE(review): the model is not moved to `device`; it stays where
# from_pretrained placed it (CPU unless configured otherwise). Callers must
# place inputs on the model's own device.
class ChatBot:
    """Multi-turn chat wrapper around the module-level ``peft_model``.

    The token ids of the whole conversation (every prompt and every
    generated reply) are accumulated in ``self.history`` and fed back to the
    model on each turn.

    NOTE(review): when a single instance backs a Gradio app, every visitor
    shares this one conversation history. Isolating users would need
    per-session state (e.g. ``gr.State``).
    """

    def __init__(self, max_length=1200):
        """Start an empty conversation.

        Args:
            max_length: Upper bound on the total token count (conversation
                so far plus newly generated tokens) passed to ``generate``.
        """
        # Empty list until the first turn; afterwards the id tensor returned
        # by ``generate`` for the full conversation.
        self.history = []
        self.max_length = max_length

    def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
        """Run one chat turn and return the model's reply for that turn.

        Args:
            user_input: Text typed by the user.
            system_prompt: Instruction text appended directly after
                *user_input* (no separator).

        Returns:
            Only the text newly generated in this turn, decoded with special
            tokens removed.
        """
        formatted_input = f"{user_input}{system_prompt}"

        # Keep every tensor on the device holding the model's weights.
        model_device = next(peft_model.parameters()).device
        user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt").to(model_device)

        # Append this turn's prompt to the conversation so far.
        if len(self.history) > 0:
            chat_history_ids = torch.cat([self.history.to(model_device), user_input_ids], dim=-1)
        else:
            chat_history_ids = user_input_ids

        # NOTE(review): the history grows every turn and is never trimmed.
        # Once it approaches max_length, replies shrink to almost nothing.
        response = peft_model.generate(
            input_ids=chat_history_ids,
            max_length=self.max_length,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Store the full returned sequence (earlier turns + this prompt + the
        # model's reply) so the next turn also sees what the bot said. The
        # previous version kept only the user prompts.
        self.history = response

        # generate() returns the prompt followed by the continuation. Slice
        # the prompt off so the UI shows only this turn's reply, not the
        # whole transcript and the internal system prompt.
        new_tokens = response[0, chat_history_ids.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
# One shared chatbot instance serves every request to this app.
bot = ChatBot()

title = "👋🏻Welcome to Tonic's 😷StableMed⚕️ Chat🦟"
description = """
You can use this Space to test out the current model vaishakgkumar/stablemedv3
"""
# Each example row fills the two inputs in order: user input, system prompt.
examples = [["What is the proper treatment for buccal herpes?", "Please provide information on the most effective antiviral medications and home remedies for treating buccal herpes."]]

iface = gr.Interface(
    fn=bot.predict,
    title=title,
    description=description,
    examples=examples,
    # Take user input and system prompt separately, with visible labels.
    inputs=[
        gr.Textbox(label="User input"),
        # Gradio always passes the textbox value to predict. An empty box
        # would therefore send "" and predict's default system prompt would
        # never apply, so pre-fill the box with that same default.
        gr.Textbox(label="System prompt", value="You are an expert medical analyst:"),
    ],
    outputs="text",
    theme="ParityError/Anime",
)
iface.launch()