BertChristiaens commited on
Commit
71af695
·
1 Parent(s): 9bfe550

add segmentation

Browse files
Files changed (1) hide show
  1. app.py +34 -6
app.py CHANGED
@@ -29,11 +29,11 @@ def on_upload() -> None:
29
  if 'input_image' in st.session_state and st.session_state['input_image'] is not None:
30
  image = Image.open(st.session_state['input_image']).convert('RGB')
31
  st.session_state['initial_image'] = image
32
- # st.session_state['history'] = [{'image': image.resize((512, 512)),
33
- # 'message': "initial image",
34
- # "positive_prompt": "",
35
- # "negative_prompt": "",
36
- # "index": 0}]
37
 
38
 
39
  def check_reset_state() -> bool:
@@ -63,6 +63,10 @@ def move_image(source: Union[str, Image.Image],
63
 
64
  if remove_state:
65
  st.session_state['reset_canvas'] = True
 
 
 
 
66
 
67
  st.session_state[dest] = source_image
68
  if rerun:
@@ -161,8 +165,32 @@ def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, pain
161
  brush=brush,
162
  _reset_state=_reset_state
163
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- if generation_mode == "Re-generate objects":
166
  canvas = st_canvas(
167
  **canvas_dict,
168
  )
 
29
  if 'input_image' in st.session_state and st.session_state['input_image'] is not None:
30
  image = Image.open(st.session_state['input_image']).convert('RGB')
31
  st.session_state['initial_image'] = image
32
+ if 'seg' in st.session_state:
33
+ del st.session_state['seg']
34
+ if 'unique_colors' in st.session_state:
35
+ del st.session_state['unique_colors']
36
+
37
 
38
 
39
  def check_reset_state() -> bool:
 
63
 
64
  if remove_state:
65
  st.session_state['reset_canvas'] = True
66
+ if 'seg' in st.session_state:
67
+ del st.session_state['seg']
68
+ if 'unique_colors' in st.session_state:
69
+ del st.session_state['unique_colors']
70
 
71
  st.session_state[dest] = source_image
72
  if rerun:
 
165
  brush=brush,
166
  _reset_state=_reset_state
167
  )
168
+ if generation_mode == "Segmentation conditioning":
169
+ canvas = st_canvas(
170
+ **canvas_dict,
171
+ )
172
+
173
+ if st.button("generate image", key='generate_button'):
174
+ image = get_image()
175
+ print("Preparing image segmentation")
176
+ real_seg = segment_image(Image.fromarray(image))
177
+ mask, seg = preprocess_seg_mask(canvas, real_seg)
178
+
179
+ with st.spinner(text="Generating image"):
180
+ print("Making image")
181
+ result_image = make_image_controlnet(image=image,
182
+ mask_image=mask,
183
+ controlnet_conditioning_image=seg,
184
+ positive_prompt=st.session_state['positive_prompt'],
185
+ negative_prompt=st.session_state['negative_prompt'],
186
+ seed=random.randint(0, 100000) # nosec
187
+ )[0]
188
+ if isinstance(result_image, np.ndarray):
189
+ result_image = Image.fromarray(result_image)
190
+ st.session_state['output_image'] = result_image
191
+
192
 
193
+ elif generation_mode == "Re-generate objects":
194
  canvas = st_canvas(
195
  **canvas_dict,
196
  )