Spaces:
Running
Running
Ffftdtd5dtft
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -13,14 +13,20 @@ import multiprocessing
|
|
13 |
import io
|
14 |
import time
|
15 |
from tqdm import tqdm
|
|
|
|
|
16 |
|
17 |
hf_token = os.getenv("HF_TOKEN")
|
18 |
redis_host = os.getenv("REDIS_HOST")
|
19 |
redis_port = int(os.getenv("REDIS_PORT", 6379))
|
20 |
redis_password = os.getenv("REDIS_PASSWORD")
|
|
|
|
|
21 |
|
22 |
HfFolder.save_token(hf_token)
|
23 |
|
|
|
|
|
24 |
def connect_to_redis():
|
25 |
while True:
|
26 |
try:
|
@@ -57,6 +63,16 @@ def save_object_to_redis(key, obj):
|
|
57 |
except redis.exceptions.RedisError as e:
|
58 |
print(f"Failed to save object to Redis: {e}")
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
def get_model_or_download(model_id, redis_key, loader_func):
|
61 |
model = load_object_from_redis(redis_key)
|
62 |
if model:
|
@@ -66,6 +82,8 @@ def get_model_or_download(model_id, redis_key, loader_func):
|
|
66 |
model = loader_func(model_id, torch_dtype=torch.float16)
|
67 |
pbar.update(1)
|
68 |
save_object_to_redis(redis_key, model)
|
|
|
|
|
69 |
except Exception as e:
|
70 |
print(f"Failed to load or save model: {e}")
|
71 |
return None
|
@@ -82,6 +100,7 @@ def generate_image(prompt):
|
|
82 |
image.save(buffered, format="JPEG")
|
83 |
image_bytes = buffered.getvalue()
|
84 |
save_object_to_redis(redis_key, image_bytes)
|
|
|
85 |
except Exception as e:
|
86 |
print(f"Failed to generate image: {e}")
|
87 |
return None
|
@@ -100,6 +119,7 @@ def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
|
|
100 |
edited_image.save(buffered, format="JPEG")
|
101 |
edited_image_bytes = buffered.getvalue()
|
102 |
save_object_to_redis(redis_key, edited_image_bytes)
|
|
|
103 |
except Exception as e:
|
104 |
print(f"Failed to edit image: {e}")
|
105 |
return None
|
@@ -115,6 +135,7 @@ def generate_song(prompt, duration=10):
|
|
115 |
pbar.update(1)
|
116 |
song_bytes = song[0].getvalue()
|
117 |
save_object_to_redis(redis_key, song_bytes)
|
|
|
118 |
except Exception as e:
|
119 |
print(f"Failed to generate song: {e}")
|
120 |
return None
|
@@ -129,6 +150,7 @@ def generate_text(prompt):
|
|
129 |
text = text_gen_pipeline(prompt, max_new_tokens=256)[0]["generated_text"].strip()
|
130 |
pbar.update(1)
|
131 |
save_object_to_redis(redis_key, text)
|
|
|
132 |
except Exception as e:
|
133 |
print(f"Failed to generate text: {e}")
|
134 |
return None
|
@@ -152,6 +174,7 @@ def generate_flux_image(prompt):
|
|
152 |
flux_image.save(buffered, format="JPEG")
|
153 |
flux_image_bytes = buffered.getvalue()
|
154 |
save_object_to_redis(redis_key, flux_image_bytes)
|
|
|
155 |
except Exception as e:
|
156 |
print(f"Failed to generate flux image: {e}")
|
157 |
return None
|
@@ -168,6 +191,7 @@ def generate_code(prompt):
|
|
168 |
code = starcoder_tokenizer.decode(outputs[0])
|
169 |
pbar.update(1)
|
170 |
save_object_to_redis(redis_key, code)
|
|
|
171 |
except Exception as e:
|
172 |
print(f"Failed to generate code: {e}")
|
173 |
return None
|
@@ -185,6 +209,7 @@ def generate_video(prompt):
|
|
185 |
video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
|
186 |
pbar.update(1)
|
187 |
save_object_to_redis(redis_key, video)
|
|
|
188 |
except Exception as e:
|
189 |
print(f"Failed to generate video: {e}")
|
190 |
return None
|
@@ -203,6 +228,7 @@ def test_model_meta_llama():
|
|
203 |
response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
|
204 |
pbar.update(1)
|
205 |
save_object_to_redis(redis_key, response)
|
|
|
206 |
except Exception as e:
|
207 |
print(f"Failed to test Meta-Llama: {e}")
|
208 |
return None
|
@@ -223,6 +249,8 @@ def train_model(model, dataset, epochs, batch_size, learning_rate):
|
|
223 |
pbar.update(epochs)
|
224 |
save_object_to_redis("trained_model", model)
|
225 |
save_object_to_redis("training_results", output_dir.getvalue())
|
|
|
|
|
226 |
except Exception as e:
|
227 |
print(f"Failed to train model: {e}")
|
228 |
|
|
|
13 |
import io
|
14 |
import time
|
15 |
from tqdm import tqdm
|
16 |
+
from google.cloud import storage
|
17 |
+
import json
|
18 |
|
19 |
hf_token = os.getenv("HF_TOKEN")
|
20 |
redis_host = os.getenv("REDIS_HOST")
|
21 |
redis_port = int(os.getenv("REDIS_PORT", 6379))
|
22 |
redis_password = os.getenv("REDIS_PASSWORD")
|
23 |
+
gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
|
24 |
+
gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")
|
25 |
|
26 |
HfFolder.save_token(hf_token)
|
27 |
|
28 |
+
storage_client = storage.Client.from_service_account_info(gcs_credentials)
|
29 |
+
|
30 |
def connect_to_redis():
|
31 |
while True:
|
32 |
try:
|
|
|
63 |
except redis.exceptions.RedisError as e:
|
64 |
print(f"Failed to save object to Redis: {e}")
|
65 |
|
66 |
+
def upload_to_gcs(bucket_name, blob_name, data):
|
67 |
+
bucket = storage_client.bucket(bucket_name)
|
68 |
+
blob = bucket.blob(blob_name)
|
69 |
+
blob.upload_from_string(data)
|
70 |
+
|
71 |
+
def download_from_gcs(bucket_name, blob_name):
|
72 |
+
bucket = storage_client.bucket(bucket_name)
|
73 |
+
blob = bucket.blob(blob_name)
|
74 |
+
return blob.download_as_bytes()
|
75 |
+
|
76 |
def get_model_or_download(model_id, redis_key, loader_func):
|
77 |
model = load_object_from_redis(redis_key)
|
78 |
if model:
|
|
|
82 |
model = loader_func(model_id, torch_dtype=torch.float16)
|
83 |
pbar.update(1)
|
84 |
save_object_to_redis(redis_key, model)
|
85 |
+
model_bytes = pickle.dumps(model)
|
86 |
+
upload_to_gcs(gcs_bucket_name, redis_key, model_bytes)
|
87 |
except Exception as e:
|
88 |
print(f"Failed to load or save model: {e}")
|
89 |
return None
|
|
|
100 |
image.save(buffered, format="JPEG")
|
101 |
image_bytes = buffered.getvalue()
|
102 |
save_object_to_redis(redis_key, image_bytes)
|
103 |
+
upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
|
104 |
except Exception as e:
|
105 |
print(f"Failed to generate image: {e}")
|
106 |
return None
|
|
|
119 |
edited_image.save(buffered, format="JPEG")
|
120 |
edited_image_bytes = buffered.getvalue()
|
121 |
save_object_to_redis(redis_key, edited_image_bytes)
|
122 |
+
upload_to_gcs(gcs_bucket_name, redis_key, edited_image_bytes)
|
123 |
except Exception as e:
|
124 |
print(f"Failed to edit image: {e}")
|
125 |
return None
|
|
|
135 |
pbar.update(1)
|
136 |
song_bytes = song[0].getvalue()
|
137 |
save_object_to_redis(redis_key, song_bytes)
|
138 |
+
upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
|
139 |
except Exception as e:
|
140 |
print(f"Failed to generate song: {e}")
|
141 |
return None
|
|
|
150 |
text = text_gen_pipeline(prompt, max_new_tokens=256)[0]["generated_text"].strip()
|
151 |
pbar.update(1)
|
152 |
save_object_to_redis(redis_key, text)
|
153 |
+
upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
|
154 |
except Exception as e:
|
155 |
print(f"Failed to generate text: {e}")
|
156 |
return None
|
|
|
174 |
flux_image.save(buffered, format="JPEG")
|
175 |
flux_image_bytes = buffered.getvalue()
|
176 |
save_object_to_redis(redis_key, flux_image_bytes)
|
177 |
+
upload_to_gcs(gcs_bucket_name, redis_key, flux_image_bytes)
|
178 |
except Exception as e:
|
179 |
print(f"Failed to generate flux image: {e}")
|
180 |
return None
|
|
|
191 |
code = starcoder_tokenizer.decode(outputs[0])
|
192 |
pbar.update(1)
|
193 |
save_object_to_redis(redis_key, code)
|
194 |
+
upload_to_gcs(gcs_bucket_name, redis_key, code.encode())
|
195 |
except Exception as e:
|
196 |
print(f"Failed to generate code: {e}")
|
197 |
return None
|
|
|
209 |
video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
|
210 |
pbar.update(1)
|
211 |
save_object_to_redis(redis_key, video)
|
212 |
+
upload_to_gcs(gcs_bucket_name, redis_key, video.encode())
|
213 |
except Exception as e:
|
214 |
print(f"Failed to generate video: {e}")
|
215 |
return None
|
|
|
228 |
response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
|
229 |
pbar.update(1)
|
230 |
save_object_to_redis(redis_key, response)
|
231 |
+
upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
|
232 |
except Exception as e:
|
233 |
print(f"Failed to test Meta-Llama: {e}")
|
234 |
return None
|
|
|
249 |
pbar.update(epochs)
|
250 |
save_object_to_redis("trained_model", model)
|
251 |
save_object_to_redis("training_results", output_dir.getvalue())
|
252 |
+
upload_to_gcs(gcs_bucket_name, "trained_model", pickle.dumps(model))
|
253 |
+
upload_to_gcs(gcs_bucket_name, "training_results", output_dir.getvalue())
|
254 |
except Exception as e:
|
255 |
print(f"Failed to train model: {e}")
|
256 |
|