rynmurdock commited on
Commit
86d2837
·
1 Parent(s): 9c7e8e1

limit to 10 rows from 1 user for diversity.

Browse files
Files changed (1) hide show
  1. app.py +39 -23
app.py CHANGED
@@ -1,6 +1,6 @@
1
 
2
 
3
-
4
  # TODO save & restart from (if it exists) dataframe parquet
5
  import torch
6
 
@@ -37,12 +37,9 @@ torch.set_grad_enabled(False)
37
  torch.backends.cuda.matmul.allow_tf32 = True
38
  torch.backends.cudnn.allow_tf32 = True
39
 
40
- prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
41
 
42
  import spaces
43
- prompt_list = [p for p in list(set(
44
- pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
45
-
46
  start_time = time.time()
47
 
48
  ####################### Setup Model
@@ -55,13 +52,13 @@ from transformers import CLIPVisionModelWithProjection
55
  import uuid
56
  import av
57
 
58
- def write_video(file_name, images, fps=17):
59
  print('Saving')
60
  container = av.open(file_name, mode="w")
61
 
62
  stream = container.add_stream("h264", rate=fps)
63
  # stream.options = {'preset': 'faster'}
64
- stream.thread_count = 0
65
  stream.width = 512
66
  stream.height = 512
67
  stream.pix_fmt = "yuv420p"
@@ -79,8 +76,16 @@ def write_video(file_name, images, fps=17):
79
  container.close()
80
  print('Saved')
81
 
 
 
 
 
 
 
 
82
 
83
- image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype).to(DEVICE)
 
84
  #vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
85
 
86
  # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
@@ -91,8 +96,9 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter",
91
  #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
92
 
93
 
94
- unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet').to(dtype)
95
- text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder').to(dtype)
 
96
 
97
  adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
98
  pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, unet=unet, text_encoder=text_encoder)
@@ -101,6 +107,7 @@ pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_
101
  pipe.set_adapters(["lcm-lora"], [.9])
102
  pipe.fuse_lora()
103
 
 
104
  #pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
105
  #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
106
  #repo = "ByteDance/AnimateDiff-Lightning"
@@ -116,8 +123,7 @@ pipe.unet.fuse_qkv_projections()
116
  pipe.to(device=DEVICE)
117
  #pipe.unet = torch.compile(pipe.unet)
118
  #pipe.vae = torch.compile(pipe.vae)
119
-
120
-
121
  #im_embs = torch.zeros(1, 1, 1, 1280, device=DEVICE, dtype=dtype)
122
  #output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
123
  #leave_im_emb, _ = pipe.encode_image(
@@ -126,13 +132,13 @@ pipe.to(device=DEVICE)
126
  #assert len(output.frames[0]) == 16
127
  #leave_im_emb.detach().to('cpu')
128
 
129
- @spaces.GPU(duration=20)
130
  def generate_gpu(in_im_embs):
131
  print('start gen')
132
  in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
133
  #im_embs = torch.cat((torch.zeros(1, 1280, device=DEVICE, dtype=dtype), in_im_embs), 0)
134
 
135
- output = pipe(prompt='a scene', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
136
  print('image is made')
137
  im_emb, _ = pipe.encode_image(
138
  output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
@@ -163,10 +169,6 @@ def generate(in_im_embs):
163
 
164
  #######################
165
 
166
-
167
- # TODO only generate ~5 new images ahead from a specific user embedding. Do this by tracking a column of who's embedding it was and
168
- # taking the intersection for unrated by that user and from that users' embedding. Then we keep styles less consistent for better variety.
169
-
170
  def get_user_emb(embs, ys):
171
  # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
172
  if len(list(set(ys))) <= 1:
@@ -245,7 +247,17 @@ def background_next_image():
245
  for uid in user_id_list:
246
  rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is not None for i in prevs_df.iterrows()]]
247
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is None for i in prevs_df.iterrows()]]
248
- if len(rated_rows) < 4:# or len(not_rated_rows) > 7:
 
 
 
 
 
 
 
 
 
 
249
  print(f'latest user {uid} has < 4 rows') # or > 7 unrated rows')
250
  continue
251
 
@@ -260,6 +272,7 @@ def background_next_image():
260
  tmp_df['paths'] = [img]
261
  tmp_df['embeddings'] = [embs]
262
  tmp_df['user:rating'] = [{' ': ' '}]
 
263
  prevs_df = pd.concat((prevs_df, tmp_df))
264
  # we can free up storage by deleting the image
265
  if len(prevs_df) > 50:
@@ -345,7 +358,9 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
345
  choice = 0
346
 
347
  row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
348
- if len(prevs_df.loc[row_mask, 'user:rating'][0]) > 0:
 
 
349
  prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
350
  prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
351
  img, calibrate_prompts = next_image(calibrate_prompts, user_id)
@@ -411,6 +426,7 @@ Explore the latent space without text prompts based on your preferences. Learn m
411
  ''', elem_id="description")
412
  user_id = gr.State()
413
  print('USER_ID: ',user_id)
 
414
  calibrate_prompts = gr.State([
415
  './first.mp4',
416
  './second.mp4',
@@ -429,7 +445,7 @@ Explore the latent space without text prompts based on your preferences. Learn m
429
  interactive=False,
430
  height=512,
431
  width=512,
432
- include_audio=False,
433
  elem_id="video_output"
434
  )
435
  img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
@@ -471,12 +487,12 @@ log = logging.getLogger('log_here')
471
  log.setLevel(logging.ERROR)
472
 
473
  scheduler = BackgroundScheduler()
474
- scheduler.add_job(func=background_next_image, trigger="interval", seconds=4)
475
  scheduler.start()
476
 
477
  def encode_space(x):
478
  im_emb, _ = pipe.encode_image(
479
- image, 'cpu', 1, output_hidden_state
480
  )
481
  return im_emb.detach().to('cpu').to(torch.float32)
482
 
 
1
 
2
 
3
+ # TODO unify/merge origin and this
4
  # TODO save & restart from (if it exists) dataframe parquet
5
  import torch
6
 
 
37
  torch.backends.cuda.matmul.allow_tf32 = True
38
  torch.backends.cudnn.allow_tf32 = True
39
 
40
+ prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id'])
41
 
42
  import spaces
 
 
 
43
  start_time = time.time()
44
 
45
  ####################### Setup Model
 
52
  import uuid
53
  import av
54
 
55
+ def write_video_av(file_name, images, fps=17):
56
  print('Saving')
57
  container = av.open(file_name, mode="w")
58
 
59
  stream = container.add_stream("h264", rate=fps)
60
  # stream.options = {'preset': 'faster'}
61
+ stream.thread_count = -1
62
  stream.width = 512
63
  stream.height = 512
64
  stream.pix_fmt = "yuv420p"
 
76
  container.close()
77
  print('Saved')
78
 
79
+ def write_video(file_name, images, fps=15):
80
+ writer = imageio.get_writer(file_name, fps=fps)
81
+
82
+ for im in images:
83
+ writer.append_data(np.array(im))
84
+ writer.close()
85
+
86
 
87
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype,
88
+ device_map='cpu')
89
  #vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
90
 
91
  # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
 
96
  #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
97
 
98
 
99
+ unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')
100
+ text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder',
101
+ device_map='cpu').to(dtype)
102
 
103
  adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
104
  pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, unet=unet, text_encoder=text_encoder)
 
107
  pipe.set_adapters(["lcm-lora"], [.9])
108
  pipe.fuse_lora()
109
 
110
+
111
  #pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
112
  #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
113
  #repo = "ByteDance/AnimateDiff-Lightning"
 
123
  pipe.to(device=DEVICE)
124
  #pipe.unet = torch.compile(pipe.unet)
125
  #pipe.vae = torch.compile(pipe.vae)
126
+ # TODO cannot compile on Spaces or we time out; don't run leave_imb stuff either
 
127
  #im_embs = torch.zeros(1, 1, 1, 1280, device=DEVICE, dtype=dtype)
128
  #output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
129
  #leave_im_emb, _ = pipe.encode_image(
 
132
  #assert len(output.frames[0]) == 16
133
  #leave_im_emb.detach().to('cpu')
134
 
135
+ @spaces.GPU(duration=10)
136
  def generate_gpu(in_im_embs):
137
  print('start gen')
138
  in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
139
  #im_embs = torch.cat((torch.zeros(1, 1280, device=DEVICE, dtype=dtype), in_im_embs), 0)
140
 
141
+ output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
142
  print('image is made')
143
  im_emb, _ = pipe.encode_image(
144
  output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
 
169
 
170
  #######################
171
 
 
 
 
 
172
  def get_user_emb(embs, ys):
173
  # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
174
  if len(list(set(ys))) <= 1:
 
247
  for uid in user_id_list:
248
  rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is not None for i in prevs_df.iterrows()]]
249
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is None for i in prevs_df.iterrows()]]
250
+
251
+ # we need to intersect not_rated_rows from this user's embed > 7. Just add a new column on which user_id spawned the
252
+ # media.
253
+
254
+ from_user = prevs_df[[i[1]['from_user_id'] == uid for i in prevs_df.iterrows()]]
255
+ if len(from_user) >= 10:
256
+ oldest = from_user.iloc[-1]['paths']
257
+ print(f'User has {len(from_user)} rows. Popping oldest: {oldest}')
258
+ prevs_df = prevs_df[prevs_df['paths'] != oldest]
259
+
260
+ if len(rated_rows) < 4:
261
  print(f'latest user {uid} has < 4 rows') # or > 7 unrated rows')
262
  continue
263
 
 
272
  tmp_df['paths'] = [img]
273
  tmp_df['embeddings'] = [embs]
274
  tmp_df['user:rating'] = [{' ': ' '}]
275
+ tmp_df['from_user_id'] = [uid]
276
  prevs_df = pd.concat((prevs_df, tmp_df))
277
  # we can free up storage by deleting the image
278
  if len(prevs_df) > 50:
 
358
  choice = 0
359
 
360
  row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
361
+
362
+
363
+ if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
364
  prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
365
  prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
366
  img, calibrate_prompts = next_image(calibrate_prompts, user_id)
 
426
  ''', elem_id="description")
427
  user_id = gr.State()
428
  print('USER_ID: ',user_id)
429
+ # calibration videos -- this is a misnomer now :D
430
  calibrate_prompts = gr.State([
431
  './first.mp4',
432
  './second.mp4',
 
445
  interactive=False,
446
  height=512,
447
  width=512,
448
+ #include_audio=False,
449
  elem_id="video_output"
450
  )
451
  img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
 
487
  log.setLevel(logging.ERROR)
488
 
489
  scheduler = BackgroundScheduler()
490
+ scheduler.add_job(func=background_next_image, trigger="interval", seconds=.1)
491
  scheduler.start()
492
 
493
  def encode_space(x):
494
  im_emb, _ = pipe.encode_image(
495
+ image, DEVICE, 1, output_hidden_state
496
  )
497
  return im_emb.detach().to('cpu').to(torch.float32)
498