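"""Multi-tool Gradio app: image, video, music, text, and code generation.

Model pipelines are pickled into Redis so restarts can reuse them instead of
re-downloading. Configuration (HF_TOKEN, REDIS_*) comes from a .env file.
"""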
import redis
import pickle
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video
from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from audiocraft.models import MusicGen
import gradio as gr
from huggingface_hub import snapshot_download, HfApi, HfFolder
import multiprocessing
import io
from dotenv import load_dotenv
import os
# Load variables from the .env file
load_dotenv()
# Read configuration from the environment
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = int(os.getenv("REDIS_PORT", 6379))  # redis-py expects an int port
redis_password = os.getenv("REDIS_PASSWORD")
# Persist the Hugging Face token for all hub downloads
HfFolder.save_token(hf_token)
# Redis connection helpers
def connect_to_redis():
    return redis.Redis(host=redis_host, port=redis_port, password=redis_password)
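# Pickle round-trip helpers: cache arbitrary Python objects (whole model
# pipelines included) in Redis so restarts can skip expensive reloads.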
def load_object_from_redis(key):
    with connect_to_redis() as redis_client:
        obj_data = redis_client.get(key)
        return pickle.loads(obj_data) if obj_data else None

def save_object_to_redis(key, obj):
    with connect_to_redis() as redis_client:
        redis_client.set(key, pickle.dumps(obj))
def get_model_or_download(model_id, redis_key, loader_func):
    # Prefer the cached copy; otherwise download and cache it for next time.
    model = load_object_from_redis(redis_key)
    if model is None:
        model = loader_func(model_id, token=hf_token, torch_dtype=torch.float16)
        save_object_to_redis(redis_key, model)
    return model
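# Generation helpers (not all of them are wired into the UI below).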
def generate_image(prompt):
    return text_to_image_pipeline(prompt).images[0]

def edit_image_with_prompt(image, prompt, strength=0.75):
    # Current diffusers img2img pipelines take `image=`, not `init_image=`.
    return img2img_pipeline(prompt=prompt, image=image.convert("RGB"), strength=strength).images[0]

def generate_song(prompt, duration=10):
    # MusicGen sets the duration via generation params and takes a list of
    # prompts; return (sample_rate, samples) as Gradio's numpy Audio expects.
    music_gen.set_generation_params(duration=duration)
    wav = music_gen.generate([prompt])[0]
    return music_gen.sample_rate, wav.cpu().numpy().squeeze()

def generate_text(prompt):
    return text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"][-1]["content"].strip()
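# FLUX.1-schnell is distilled for few-step sampling, so a handful of steps
# with guidance_scale=0.0 is the intended regime.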
def generate_flux_image(prompt):
    return flux_pipeline(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
def generate_code(prompt):
    # Follow the model onto whatever device device_map="auto" placed it.
    inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(starcoder_model.device)
    outputs = starcoder_model.generate(inputs, max_new_tokens=256)
    return starcoder_tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_video(prompt):
    # The video pipeline is loaded on demand rather than cached in Redis.
    pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    return export_to_video(pipe(prompt, num_inference_steps=25).frames)
def test_model_meta_llama(test_input):
    # The Gradio tab passes a textbox value; fall back to a canned question.
    messages = [
        {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
        {"role": "user", "content": test_input or "Who are you?"},
    ]
    return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"][-1]["content"]
def train_model(model, dataset, epochs, batch_size, learning_rate):
    # TrainingArguments needs a filesystem path, not an in-memory buffer.
    output_dir = "./training_output"
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()
    save_object_to_redis("trained_model", model)
    save_object_to_redis("training_results", trainer.state.log_history)
def run_task(task_queue):
    # Worker loop: run queued calls until a None sentinel arrives.
    while True:
        task = task_queue.get()
        if task is None:
            break
        func, args, kwargs = task
        func(*args, **kwargs)
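# Background worker pool. Tasks are (func, args, kwargs) tuples, e.g.
#   task_queue.put((generate_image, ("a sunset over the sea",), {}))
# A None sentinel (queued at shutdown below) stops each worker.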
task_queue = multiprocessing.Queue()
num_processes = multiprocessing.cpu_count()
processes = []
for _ in range(num_processes):
    p = multiprocessing.Process(target=run_task, args=(task_queue,))
    p.start()
    processes.append(p)
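# Load every model up front, preferring the Redis cache over a fresh download.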
device = "cuda" if torch.cuda.is_available() else "cpu"
text_to_image_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained).to(device)
# Img2img needs a standard SD 1.x checkpoint; the inpainting checkpoint's UNet is incompatible.
img2img_pipeline = get_model_or_download("runwayml/stable-diffusion-v1-5", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained).to(device)
flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
flux_pipeline.enable_model_cpu_offload()
# audiocraft's MusicGen.get_pretrained takes a model name (and optional device) only.
music_gen = load_object_from_redis("music_gen") or MusicGen.get_pretrained("facebook/musicgen-melody")
save_object_to_redis("music_gen", music_gen)
text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=device,
    token=hf_token,
)
save_object_to_redis("text_gen_pipeline", text_gen_pipeline)
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b", token=hf_token)
starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-15b", device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
meta_llama_pipeline = transformers_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    token=hf_token,
)
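# One gr.Interface per capability; gr.TabbedInterface groups them below.
# Note: the old gr.inputs/gr.outputs namespaces were removed in recent Gradio.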
gen_image_tab = gr.Interface(generate_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate Images")
edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.Image(type="pil", label="Image:"), gr.Textbox(label="Prompt:"), gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.Image(type="pil"), title="Edit Images")
generate_song_tab = gr.Interface(generate_song, [gr.Textbox(label="Prompt:"), gr.Slider(5, 60, 10, step=1, label="Duration (s):")], gr.Audio(type="numpy"), title="Generate Songs")
generate_text_tab = gr.Interface(generate_text, gr.Textbox(label="Prompt:"), gr.Textbox(label="Generated Text:"), title="Generate Text")
generate_flux_image_tab = gr.Interface(generate_flux_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate FLUX Images")
model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.Textbox(label="Test Input:"), gr.Textbox(label="Model Output:"), title="Test Meta-Llama")
app = gr.TabbedInterface(
    [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, model_meta_llama_test_tab],
    ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Test Meta-Llama"],
)
app.launch(share=True)
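# Once the UI server exits, send one sentinel per worker and wait for them.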
for _ in range(num_processes):
    task_queue.put(None)
for p in processes:
    p.join()