Ffftdtd5dtft commited on
Commit
e856b6e
·
verified ·
1 Parent(s): 820691f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -0
app.py CHANGED
@@ -13,14 +13,20 @@ import multiprocessing
13
  import io
14
  import time
15
  from tqdm import tqdm
 
 
16
 
17
  hf_token = os.getenv("HF_TOKEN")
18
  redis_host = os.getenv("REDIS_HOST")
19
  redis_port = int(os.getenv("REDIS_PORT", 6379))
20
  redis_password = os.getenv("REDIS_PASSWORD")
 
 
21
 
22
  HfFolder.save_token(hf_token)
23
 
 
 
24
  def connect_to_redis():
25
  while True:
26
  try:
@@ -57,6 +63,16 @@ def save_object_to_redis(key, obj):
57
  except redis.exceptions.RedisError as e:
58
  print(f"Failed to save object to Redis: {e}")
59
 
 
 
 
 
 
 
 
 
 
 
60
  def get_model_or_download(model_id, redis_key, loader_func):
61
  model = load_object_from_redis(redis_key)
62
  if model:
@@ -66,6 +82,8 @@ def get_model_or_download(model_id, redis_key, loader_func):
66
  model = loader_func(model_id, torch_dtype=torch.float16)
67
  pbar.update(1)
68
  save_object_to_redis(redis_key, model)
 
 
69
  except Exception as e:
70
  print(f"Failed to load or save model: {e}")
71
  return None
@@ -82,6 +100,7 @@ def generate_image(prompt):
82
  image.save(buffered, format="JPEG")
83
  image_bytes = buffered.getvalue()
84
  save_object_to_redis(redis_key, image_bytes)
 
85
  except Exception as e:
86
  print(f"Failed to generate image: {e}")
87
  return None
@@ -100,6 +119,7 @@ def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
100
  edited_image.save(buffered, format="JPEG")
101
  edited_image_bytes = buffered.getvalue()
102
  save_object_to_redis(redis_key, edited_image_bytes)
 
103
  except Exception as e:
104
  print(f"Failed to edit image: {e}")
105
  return None
@@ -115,6 +135,7 @@ def generate_song(prompt, duration=10):
115
  pbar.update(1)
116
  song_bytes = song[0].getvalue()
117
  save_object_to_redis(redis_key, song_bytes)
 
118
  except Exception as e:
119
  print(f"Failed to generate song: {e}")
120
  return None
@@ -129,6 +150,7 @@ def generate_text(prompt):
129
  text = text_gen_pipeline(prompt, max_new_tokens=256)[0]["generated_text"].strip()
130
  pbar.update(1)
131
  save_object_to_redis(redis_key, text)
 
132
  except Exception as e:
133
  print(f"Failed to generate text: {e}")
134
  return None
@@ -152,6 +174,7 @@ def generate_flux_image(prompt):
152
  flux_image.save(buffered, format="JPEG")
153
  flux_image_bytes = buffered.getvalue()
154
  save_object_to_redis(redis_key, flux_image_bytes)
 
155
  except Exception as e:
156
  print(f"Failed to generate flux image: {e}")
157
  return None
@@ -168,6 +191,7 @@ def generate_code(prompt):
168
  code = starcoder_tokenizer.decode(outputs[0])
169
  pbar.update(1)
170
  save_object_to_redis(redis_key, code)
 
171
  except Exception as e:
172
  print(f"Failed to generate code: {e}")
173
  return None
@@ -185,6 +209,7 @@ def generate_video(prompt):
185
  video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
186
  pbar.update(1)
187
  save_object_to_redis(redis_key, video)
 
188
  except Exception as e:
189
  print(f"Failed to generate video: {e}")
190
  return None
@@ -203,6 +228,7 @@ def test_model_meta_llama():
203
  response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
204
  pbar.update(1)
205
  save_object_to_redis(redis_key, response)
 
206
  except Exception as e:
207
  print(f"Failed to test Meta-Llama: {e}")
208
  return None
@@ -223,6 +249,8 @@ def train_model(model, dataset, epochs, batch_size, learning_rate):
223
  pbar.update(epochs)
224
  save_object_to_redis("trained_model", model)
225
  save_object_to_redis("training_results", output_dir.getvalue())
 
 
226
  except Exception as e:
227
  print(f"Failed to train model: {e}")
228
 
 
13
  import io
14
  import time
15
  from tqdm import tqdm
16
+ from google.cloud import storage
17
+ import json
18
 
19
  hf_token = os.getenv("HF_TOKEN")
20
  redis_host = os.getenv("REDIS_HOST")
21
  redis_port = int(os.getenv("REDIS_PORT", 6379))
22
  redis_password = os.getenv("REDIS_PASSWORD")
23
+ gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
24
+ gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")
25
 
26
  HfFolder.save_token(hf_token)
27
 
28
+ storage_client = storage.Client.from_service_account_info(gcs_credentials)
29
+
30
  def connect_to_redis():
31
  while True:
32
  try:
 
63
  except redis.exceptions.RedisError as e:
64
  print(f"Failed to save object to Redis: {e}")
65
 
66
+ def upload_to_gcs(bucket_name, blob_name, data):
67
+ bucket = storage_client.bucket(bucket_name)
68
+ blob = bucket.blob(blob_name)
69
+ blob.upload_from_string(data)
70
+
71
+ def download_from_gcs(bucket_name, blob_name):
72
+ bucket = storage_client.bucket(bucket_name)
73
+ blob = bucket.blob(blob_name)
74
+ return blob.download_as_bytes()
75
+
76
  def get_model_or_download(model_id, redis_key, loader_func):
77
  model = load_object_from_redis(redis_key)
78
  if model:
 
82
  model = loader_func(model_id, torch_dtype=torch.float16)
83
  pbar.update(1)
84
  save_object_to_redis(redis_key, model)
85
+ model_bytes = pickle.dumps(model)
86
+ upload_to_gcs(gcs_bucket_name, redis_key, model_bytes)
87
  except Exception as e:
88
  print(f"Failed to load or save model: {e}")
89
  return None
 
100
  image.save(buffered, format="JPEG")
101
  image_bytes = buffered.getvalue()
102
  save_object_to_redis(redis_key, image_bytes)
103
+ upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
104
  except Exception as e:
105
  print(f"Failed to generate image: {e}")
106
  return None
 
119
  edited_image.save(buffered, format="JPEG")
120
  edited_image_bytes = buffered.getvalue()
121
  save_object_to_redis(redis_key, edited_image_bytes)
122
+ upload_to_gcs(gcs_bucket_name, redis_key, edited_image_bytes)
123
  except Exception as e:
124
  print(f"Failed to edit image: {e}")
125
  return None
 
135
  pbar.update(1)
136
  song_bytes = song[0].getvalue()
137
  save_object_to_redis(redis_key, song_bytes)
138
+ upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
139
  except Exception as e:
140
  print(f"Failed to generate song: {e}")
141
  return None
 
150
  text = text_gen_pipeline(prompt, max_new_tokens=256)[0]["generated_text"].strip()
151
  pbar.update(1)
152
  save_object_to_redis(redis_key, text)
153
+ upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
154
  except Exception as e:
155
  print(f"Failed to generate text: {e}")
156
  return None
 
174
  flux_image.save(buffered, format="JPEG")
175
  flux_image_bytes = buffered.getvalue()
176
  save_object_to_redis(redis_key, flux_image_bytes)
177
+ upload_to_gcs(gcs_bucket_name, redis_key, flux_image_bytes)
178
  except Exception as e:
179
  print(f"Failed to generate flux image: {e}")
180
  return None
 
191
  code = starcoder_tokenizer.decode(outputs[0])
192
  pbar.update(1)
193
  save_object_to_redis(redis_key, code)
194
+ upload_to_gcs(gcs_bucket_name, redis_key, code.encode())
195
  except Exception as e:
196
  print(f"Failed to generate code: {e}")
197
  return None
 
209
  video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
210
  pbar.update(1)
211
  save_object_to_redis(redis_key, video)
212
+ upload_to_gcs(gcs_bucket_name, redis_key, video.encode())
213
  except Exception as e:
214
  print(f"Failed to generate video: {e}")
215
  return None
 
228
  response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
229
  pbar.update(1)
230
  save_object_to_redis(redis_key, response)
231
+ upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
232
  except Exception as e:
233
  print(f"Failed to test Meta-Llama: {e}")
234
  return None
 
249
  pbar.update(epochs)
250
  save_object_to_redis("trained_model", model)
251
  save_object_to_redis("training_results", output_dir.getvalue())
252
+ upload_to_gcs(gcs_bucket_name, "trained_model", pickle.dumps(model))
253
+ upload_to_gcs(gcs_bucket_name, "training_results", output_dir.getvalue())
254
  except Exception as e:
255
  print(f"Failed to train model: {e}")
256