Ffftdtd5dtft committed on
Commit
2779600
·
verified ·
1 Parent(s): 5fcbc21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -18
app.py CHANGED
@@ -1,20 +1,18 @@
 
 
1
  import redis
2
  import pickle
3
  import torch
4
  from PIL import Image
5
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
6
  from diffusers.utils import export_to_video
7
- from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments
8
  from audiocraft.models import MusicGen
9
  import gradio as gr
10
  from huggingface_hub import snapshot_download, HfApi, HfFolder
11
  import multiprocessing
12
  import io
13
- from dotenv import load_dotenv
14
- import os
15
-
16
- # Cargar las variables del archivo .env
17
- load_dotenv()
18
 
19
  # Obtener las variables de entorno
20
  hf_token = os.getenv("HF_TOKEN")
@@ -22,26 +20,44 @@ redis_host = os.getenv("REDIS_HOST")
22
  redis_port = os.getenv("REDIS_PORT")
23
  redis_password = os.getenv("REDIS_PASSWORD")
24
 
25
- # Usar las variables de huggingface
26
  HfFolder.save_token(hf_token)
27
 
28
- # Usar las variables de redis
29
  def connect_to_redis():
30
- return redis.Redis(host=redis_host, port=redis_port, password=redis_password)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def load_object_from_redis(key):
33
- with connect_to_redis() as redis_client:
34
- obj_data = redis_client.get(key)
 
35
  return pickle.loads(obj_data) if obj_data else None
36
 
37
  def save_object_to_redis(key, obj):
38
- with connect_to_redis() as redis_client:
39
- redis_client.set(key, pickle.dumps(obj))
 
40
 
41
  def get_model_or_download(model_id, redis_key, loader_func):
42
  model = load_object_from_redis(redis_key)
43
  if not model:
44
- model = loader_func(model_id, use_auth_token=hf_token, torch_dtype=torch.float16)
45
  save_object_to_redis(redis_key, model)
46
  return model
47
 
@@ -55,7 +71,7 @@ def generate_song(prompt, duration=10):
55
  return music_gen.generate(prompt, duration=duration)
56
 
57
  def generate_text(prompt):
58
- return text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"][-1]["content"].strip()
59
 
60
  def generate_flux_image(prompt):
61
  return flux_pipeline(
@@ -72,7 +88,7 @@ def generate_code(prompt):
72
  return starcoder_tokenizer.decode(outputs[0])
73
 
74
  def generate_video(prompt):
75
- pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
76
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
77
  pipe.enable_model_cpu_offload()
78
  return export_to_video(pipe(prompt, num_inference_steps=25).frames)
@@ -82,7 +98,7 @@ def test_model_meta_llama():
82
  {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
83
  {"role": "user", "content": "Who are you?"}
84
  ]
85
- return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"][-1]
86
 
87
  def train_model(model, dataset, epochs, batch_size, learning_rate):
88
  output_dir = io.BytesIO()
@@ -156,4 +172,4 @@ app.launch(share=True)
156
  for _ in range(num_processes):
157
  task_queue.put(None)
158
  for p in processes:
159
- p.join()
 
1
# NOTE(review): `!pip install ...` is IPython/notebook magic syntax and is a
# SyntaxError in a plain Python script such as app.py — the module would fail
# to import at all. Dependencies belong in requirements.txt; if runtime
# installation is truly required, use:
#   subprocess.check_call([sys.executable, "-m", "pip", "install", <pkgs...>])
# Required packages: redis diffusers transformers accelerate torch gradio
# audiocraft huggingface_hub
3
import io
import multiprocessing
import os
import pickle
import time

import gradio as gr
import redis
import torch
from PIL import Image
from audiocraft.models import MusicGen
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download, HfApi, HfFolder
from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
 
 
 
 
16
 
17
# Read service configuration from environment variables.
# NOTE(review): `redis_host` is read by connect_to_redis() below but no
# `redis_host = os.getenv("REDIS_HOST")` assignment is visible in this view —
# confirm it exists in the full file (the pre-commit version had it).
hf_token = os.getenv("HF_TOKEN")

redis_port = os.getenv("REDIS_PORT")
redis_password = os.getenv("REDIS_PASSWORD")

# Persist the Hugging Face token so subsequent hub calls authenticate.
# NOTE(review): os.getenv returns None when HF_TOKEN is unset, and
# HfFolder.save_token(None) would then fail — confirm the deployment always
# provides HF_TOKEN.
HfFolder.save_token(hf_token)
24
 
 
25
def connect_to_redis():
    """Open a Redis connection, retrying on transient failures.

    Makes up to ``max_retries`` attempts, verifying each connection with a
    PING, and sleeps ``retry_delay`` seconds between attempts.

    Returns:
        redis.Redis: a client whose connectivity was just verified via PING.

    Raises:
        ConnectionError: if every attempt fails.
    """
    max_retries = 5
    retry_delay = 1
    for attempt in range(max_retries):
        try:
            # NOTE(review): redis_port comes from os.getenv and is a string
            # (or None when unset) — confirm redis-py accepts it, or coerce
            # with int() once the deployment guarantees REDIS_PORT is set.
            redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
            redis_client.ping()  # a created client is lazy; PING proves the connection works
            return redis_client
        except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
            print(f"Attempt {attempt + 1}: Connection to Redis failed: {e}. Retrying in {retry_delay} seconds...")
            # Bug fix: the original slept after the FINAL failed attempt too,
            # delaying the ConnectionError below for no benefit.
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
    raise ConnectionError("Failed to connect to Redis after multiple retries.")
37
+
38
def reconnect_if_needed(redis_client):
    """Return a live Redis client.

    Hands back *redis_client* unchanged when it still answers PING; otherwise
    announces the reconnect and returns a freshly established connection.
    """
    transient_errors = (
        redis.exceptions.ConnectionError,
        redis.exceptions.TimeoutError,
        BrokenPipeError,
    )
    try:
        redis_client.ping()
    except transient_errors:
        print("Reconnecting to Redis...")
        redis_client = connect_to_redis()
    return redis_client
45
 
46
def load_object_from_redis(key):
    """Fetch and unpickle the value stored under *key*.

    Returns:
        The unpickled object, or None when the key is absent/empty.
    """
    # connect_to_redis() already verifies the connection with a PING, so the
    # original's extra reconnect_if_needed() call was a redundant second PING.
    redis_client = connect_to_redis()
    try:
        obj_data = redis_client.get(key)
    finally:
        # Bug fix: the connection was never released, leaking one socket per call.
        redis_client.close()
    # SECURITY NOTE(review): pickle.loads executes arbitrary code from the
    # payload — acceptable only if this Redis instance is fully trusted.
    return pickle.loads(obj_data) if obj_data else None
51
 
52
def save_object_to_redis(key, obj):
    """Pickle *obj* and store it under *key* in Redis."""
    # connect_to_redis() already verifies the connection with a PING, so the
    # original's extra reconnect_if_needed() call was a redundant second PING.
    redis_client = connect_to_redis()
    try:
        redis_client.set(key, pickle.dumps(obj))
    finally:
        # Bug fix: the connection was never released, leaking one socket per call.
        redis_client.close()
56
 
57
def get_model_or_download(model_id, redis_key, loader_func):
    """Return the model cached under *redis_key*, loading and caching it on a miss.

    Args:
        model_id: identifier passed to *loader_func* (e.g. a HF hub repo id).
        redis_key: Redis key used as the cache slot.
        loader_func: callable ``loader_func(model_id, torch_dtype=...)`` that
            loads the model on a cache miss.

    Returns:
        The cached or freshly loaded model.
    """
    model = load_object_from_redis(redis_key)
    # Bug fix: `if not model:` would treat any falsy-but-valid cached object as
    # a miss and re-download it; load_object_from_redis signals a miss with None.
    if model is None:
        model = loader_func(model_id, torch_dtype=torch.float16)
        save_object_to_redis(redis_key, model)
    return model
63
 
 
71
  return music_gen.generate(prompt, duration=duration)
72
 
73
def generate_text(prompt):
    """Generate a chat completion for *prompt* via the global text_gen_pipeline.

    Returns:
        str: the assistant's reply, stripped of surrounding whitespace.
    """
    messages = [{"role": "user", "content": prompt}]
    result = text_gen_pipeline(messages, max_new_tokens=256)
    # Bug fix: with chat-format input, a transformers text-generation pipeline
    # returns the full conversation as a list of message dicts under
    # "generated_text" — calling .strip() on that list (as this commit did)
    # raises AttributeError. Take the last message's content, as the
    # pre-commit version did.
    return result[0]["generated_text"][-1]["content"].strip()
75
 
76
  def generate_flux_image(prompt):
77
  return flux_pipeline(
 
88
  return starcoder_tokenizer.decode(outputs[0])
89
 
90
def generate_video(prompt):
    """Render a short video for *prompt* with the ModelScope text-to-video model.

    Returns:
        The path of the exported video file produced by export_to_video.
    """
    text2video = DiffusionPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b",
        torch_dtype=torch.float16,
    )
    text2video.scheduler = DPMSolverMultistepScheduler.from_config(text2video.scheduler.config)
    text2video.enable_model_cpu_offload()  # keep VRAM usage down by offloading idle submodules
    frames = text2video(prompt, num_inference_steps=25).frames
    return export_to_video(frames)
 
98
  {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
99
  {"role": "user", "content": "Who are you?"}
100
  ]
101
+ return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
102
 
103
  def train_model(model, dataset, epochs, batch_size, learning_rate):
104
  output_dir = io.BytesIO()
 
172
  for _ in range(num_processes):
173
  task_queue.put(None)
174
  for p in processes:
175
+ p.join()