Spaces:
Running
Running
Ffftdtd5dtft
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,20 +1,18 @@
|
|
|
|
|
|
1 |
import redis
|
2 |
import pickle
|
3 |
import torch
|
4 |
from PIL import Image
|
5 |
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
|
6 |
from diffusers.utils import export_to_video
|
7 |
-
from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
8 |
from audiocraft.models import MusicGen
|
9 |
import gradio as gr
|
10 |
from huggingface_hub import snapshot_download, HfApi, HfFolder
|
11 |
import multiprocessing
|
12 |
import io
|
13 |
-
|
14 |
-
import os
|
15 |
-
|
16 |
-
# Cargar las variables del archivo .env
|
17 |
-
load_dotenv()
|
18 |
|
19 |
# Obtener las variables de entorno
|
20 |
hf_token = os.getenv("HF_TOKEN")
|
@@ -22,26 +20,44 @@ redis_host = os.getenv("REDIS_HOST")
|
|
22 |
redis_port = os.getenv("REDIS_PORT")
|
23 |
redis_password = os.getenv("REDIS_PASSWORD")
|
24 |
|
25 |
-
# Usar las variables de huggingface
|
26 |
HfFolder.save_token(hf_token)
|
27 |
|
28 |
-
# Usar las variables de redis
|
29 |
def connect_to_redis():
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def load_object_from_redis(key):
|
33 |
-
|
34 |
-
|
|
|
35 |
return pickle.loads(obj_data) if obj_data else None
|
36 |
|
37 |
def save_object_to_redis(key, obj):
|
38 |
-
|
39 |
-
|
|
|
40 |
|
41 |
def get_model_or_download(model_id, redis_key, loader_func):
|
42 |
model = load_object_from_redis(redis_key)
|
43 |
if not model:
|
44 |
-
model = loader_func(model_id,
|
45 |
save_object_to_redis(redis_key, model)
|
46 |
return model
|
47 |
|
@@ -55,7 +71,7 @@ def generate_song(prompt, duration=10):
|
|
55 |
return music_gen.generate(prompt, duration=duration)
|
56 |
|
57 |
def generate_text(prompt):
|
58 |
-
return text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"]
|
59 |
|
60 |
def generate_flux_image(prompt):
|
61 |
return flux_pipeline(
|
@@ -72,7 +88,7 @@ def generate_code(prompt):
|
|
72 |
return starcoder_tokenizer.decode(outputs[0])
|
73 |
|
74 |
def generate_video(prompt):
|
75 |
-
pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
|
76 |
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
77 |
pipe.enable_model_cpu_offload()
|
78 |
return export_to_video(pipe(prompt, num_inference_steps=25).frames)
|
@@ -82,7 +98,7 @@ def test_model_meta_llama():
|
|
82 |
{"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
|
83 |
{"role": "user", "content": "Who are you?"}
|
84 |
]
|
85 |
-
return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"]
|
86 |
|
87 |
def train_model(model, dataset, epochs, batch_size, learning_rate):
|
88 |
output_dir = io.BytesIO()
|
@@ -156,4 +172,4 @@ app.launch(share=True)
|
|
156 |
for _ in range(num_processes):
|
157 |
task_queue.put(None)
|
158 |
for p in processes:
|
159 |
-
p.join()
|
|
|
1 |
+
!pip install redis diffusers transformers accelerate torch gradio audiocraft huggingface_hub
|
2 |
+
|
3 |
import redis
|
4 |
import pickle
|
5 |
import torch
|
6 |
from PIL import Image
|
7 |
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
|
8 |
from diffusers.utils import export_to_video
|
9 |
+
from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
|
10 |
from audiocraft.models import MusicGen
|
11 |
import gradio as gr
|
12 |
from huggingface_hub import snapshot_download, HfApi, HfFolder
|
13 |
import multiprocessing
|
14 |
import io
|
15 |
+
import time
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# Obtener las variables de entorno
|
18 |
hf_token = os.getenv("HF_TOKEN")
|
|
|
20 |
redis_port = os.getenv("REDIS_PORT")
|
21 |
redis_password = os.getenv("REDIS_PASSWORD")
|
22 |
|
|
|
23 |
HfFolder.save_token(hf_token)
|
24 |
|
|
|
25 |
def connect_to_redis():
|
26 |
+
max_retries = 5
|
27 |
+
retry_delay = 1
|
28 |
+
for attempt in range(max_retries):
|
29 |
+
try:
|
30 |
+
redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
|
31 |
+
redis_client.ping()
|
32 |
+
return redis_client
|
33 |
+
except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
|
34 |
+
print(f"Attempt {attempt + 1}: Connection to Redis failed: {e}. Retrying in {retry_delay} seconds...")
|
35 |
+
time.sleep(retry_delay)
|
36 |
+
raise ConnectionError("Failed to connect to Redis after multiple retries.")
|
37 |
+
|
38 |
+
def reconnect_if_needed(redis_client):
|
39 |
+
try:
|
40 |
+
redis_client.ping()
|
41 |
+
except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError):
|
42 |
+
print("Reconnecting to Redis...")
|
43 |
+
return connect_to_redis()
|
44 |
+
return redis_client
|
45 |
|
46 |
def load_object_from_redis(key):
|
47 |
+
redis_client = connect_to_redis()
|
48 |
+
redis_client = reconnect_if_needed(redis_client)
|
49 |
+
obj_data = redis_client.get(key)
|
50 |
return pickle.loads(obj_data) if obj_data else None
|
51 |
|
52 |
def save_object_to_redis(key, obj):
|
53 |
+
redis_client = connect_to_redis()
|
54 |
+
redis_client = reconnect_if_needed(redis_client)
|
55 |
+
redis_client.set(key, pickle.dumps(obj))
|
56 |
|
57 |
def get_model_or_download(model_id, redis_key, loader_func):
|
58 |
model = load_object_from_redis(redis_key)
|
59 |
if not model:
|
60 |
+
model = loader_func(model_id, torch_dtype=torch.float16)
|
61 |
save_object_to_redis(redis_key, model)
|
62 |
return model
|
63 |
|
|
|
71 |
return music_gen.generate(prompt, duration=duration)
|
72 |
|
73 |
def generate_text(prompt):
|
74 |
+
return text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
|
75 |
|
76 |
def generate_flux_image(prompt):
|
77 |
return flux_pipeline(
|
|
|
88 |
return starcoder_tokenizer.decode(outputs[0])
|
89 |
|
90 |
def generate_video(prompt):
|
91 |
+
pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
|
92 |
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
93 |
pipe.enable_model_cpu_offload()
|
94 |
return export_to_video(pipe(prompt, num_inference_steps=25).frames)
|
|
|
98 |
{"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
|
99 |
{"role": "user", "content": "Who are you?"}
|
100 |
]
|
101 |
+
return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
|
102 |
|
103 |
def train_model(model, dataset, epochs, batch_size, learning_rate):
|
104 |
output_dir = io.BytesIO()
|
|
|
172 |
for _ in range(num_processes):
|
173 |
task_queue.put(None)
|
174 |
for p in processes:
|
175 |
+
p.join()
|