hysts HF staff commited on
Commit
26fa884
·
1 Parent(s): 5ed3dd9
Files changed (7) hide show
  1. .pre-commit-config.yaml +37 -0
  2. README.md +4 -1
  3. app.py +73 -114
  4. images/README.md +0 -1
  5. model.py +12 -16
  6. requirements.txt +1 -1
  7. style.css +8 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^patch.*
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
README.md CHANGED
@@ -4,9 +4,12 @@ emoji: ⚡
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.0.17
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
+ suggested_hardware: t4-small
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
14
+
15
+ https://arxiv.org/abs/2112.05142
app.py CHANGED
@@ -2,24 +2,16 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import pathlib
7
 
8
  import gradio as gr
9
 
10
  from model import Model
11
 
 
12
 
13
- def parse_args() -> argparse.Namespace:
14
- parser = argparse.ArgumentParser()
15
- parser.add_argument('--device', type=str, default='cpu')
16
- parser.add_argument('--theme', type=str)
17
- parser.add_argument('--share', action='store_true')
18
- parser.add_argument('--port', type=int)
19
- parser.add_argument('--disable-queue',
20
- dest='enable_queue',
21
- action='store_false')
22
- return parser.parse_args()
23
 
24
 
25
  def load_hairstyle_list() -> list[str]:
@@ -40,106 +32,73 @@ def update_step2_components(choice: str) -> tuple[dict, dict]:
40
  )
41
 
42
 
43
- def main():
44
- args = parse_args()
45
- model = Model(device=args.device)
46
-
47
- css = '''
48
- h1#title {
49
- text-align: center;
50
- }
51
- img#teaser {
52
- max-width: 1000px;
53
- max-height: 600px;
54
- }
55
- '''
56
-
57
- with gr.Blocks(theme=args.theme, css=css) as demo:
58
- gr.Markdown('''<h1 id="title">HairCLIP</h1>
59
-
60
- This is an unofficial demo for <a href="https://github.com/wty-ustc/HairCLIP">https://github.com/wty-ustc/HairCLIP</a>.
61
-
62
- <center><img id="teaser" src="https://raw.githubusercontent.com/wty-ustc/HairCLIP/main/assets/teaser.png" alt="teaser"></center>
63
- ''')
64
- with gr.Box():
65
- gr.Markdown('## Step 1')
66
- with gr.Row():
67
- with gr.Column():
68
- with gr.Row():
69
- input_image = gr.Image(label='Input Image',
70
- type='file')
71
- with gr.Row():
72
- preprocess_button = gr.Button('Preprocess')
73
- with gr.Column():
74
- aligned_face = gr.Image(label='Aligned Face',
75
- type='pil',
76
- interactive=False)
77
- with gr.Column():
78
- reconstructed_face = gr.Image(label='Reconstructed Face',
79
- type='numpy')
80
- latent = gr.Variable()
81
-
82
- with gr.Row():
83
- paths = sorted(pathlib.Path('images').glob('*.jpg'))
84
- example_images = gr.Dataset(components=[input_image],
85
- samples=[[path.as_posix()]
86
- for path in paths])
87
-
88
- with gr.Box():
89
- gr.Markdown('## Step 2')
90
- with gr.Row():
91
- with gr.Column():
92
- with gr.Row():
93
- editing_type = gr.Radio(['hairstyle', 'color', 'both'],
94
- value='both',
95
- type='value',
96
- label='Editing Type')
97
- with gr.Row():
98
- hairstyles = load_hairstyle_list()
99
- hairstyle_index = gr.Dropdown(hairstyles,
100
- value='afro',
101
- type='index',
102
- label='Hairstyle')
103
- with gr.Row():
104
- color_description = gr.Textbox(value='red',
105
- label='Color')
106
- with gr.Row():
107
- run_button = gr.Button('Run')
108
-
109
- with gr.Column():
110
- result = gr.Image(label='Result')
111
-
112
- gr.Markdown(
113
- '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.hairclip" alt="visitor badge"/></center>'
114
- )
115
-
116
- preprocess_button.click(fn=model.detect_and_align_face,
117
- inputs=[input_image],
118
- outputs=[aligned_face])
119
- aligned_face.change(fn=model.reconstruct_face,
120
- inputs=[aligned_face],
121
- outputs=[reconstructed_face, latent])
122
- editing_type.change(fn=update_step2_components,
123
- inputs=[editing_type],
124
- outputs=[hairstyle_index, color_description])
125
- run_button.click(fn=model.generate,
126
- inputs=[
127
- editing_type,
128
- hairstyle_index,
129
- color_description,
130
- latent,
131
- ],
132
- outputs=[result])
133
- example_images.click(fn=set_example_image,
134
- inputs=example_images,
135
- outputs=example_images.components)
136
-
137
- demo.launch(
138
- enable_queue=args.enable_queue,
139
- server_port=args.port,
140
- share=args.share,
141
- )
142
-
143
-
144
- if __name__ == '__main__':
145
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import pathlib
6
 
7
  import gradio as gr
8
 
9
  from model import Model
10
 
11
+ DESCRIPTION = '''# [HairCLIP](https://github.com/wty-ustc/HairCLIP)
12
 
13
+ <center><img id="teaser" src="https://raw.githubusercontent.com/wty-ustc/HairCLIP/main/assets/teaser.png" alt="teaser"></center>
14
+ '''
 
 
 
 
 
 
 
 
15
 
16
 
17
  def load_hairstyle_list() -> list[str]:
 
32
  )
33
 
34
 
35
+ model = Model()
36
+
37
+ with gr.Blocks(css='style.css') as demo:
38
+ gr.Markdown(DESCRIPTION)
39
+ with gr.Box():
40
+ gr.Markdown('## Step 1')
41
+ with gr.Row():
42
+ with gr.Column():
43
+ with gr.Row():
44
+ input_image = gr.Image(label='Input Image',
45
+ type='filepath')
46
+ with gr.Row():
47
+ preprocess_button = gr.Button('Preprocess')
48
+ with gr.Column():
49
+ aligned_face = gr.Image(label='Aligned Face',
50
+ type='pil',
51
+ interactive=False)
52
+ with gr.Column():
53
+ reconstructed_face = gr.Image(label='Reconstructed Face',
54
+ type='numpy')
55
+ latent = gr.Variable()
56
+
57
+ with gr.Row():
58
+ paths = sorted(pathlib.Path('images').glob('*.jpg'))
59
+ gr.Examples(examples=[[path.as_posix()] for path in paths],
60
+ inputs=input_image)
61
+
62
+ with gr.Box():
63
+ gr.Markdown('## Step 2')
64
+ with gr.Row():
65
+ with gr.Column():
66
+ with gr.Row():
67
+ editing_type = gr.Radio(
68
+ label='Editing Type',
69
+ choices=['hairstyle', 'color', 'both'],
70
+ value='both',
71
+ type='value')
72
+ with gr.Row():
73
+ hairstyles = load_hairstyle_list()
74
+ hairstyle_index = gr.Dropdown(label='Hairstyle',
75
+ choices=hairstyles,
76
+ value='afro',
77
+ type='index')
78
+ with gr.Row():
79
+ color_description = gr.Textbox(label='Color', value='red')
80
+ with gr.Row():
81
+ run_button = gr.Button('Run')
82
+
83
+ with gr.Column():
84
+ result = gr.Image(label='Result')
85
+
86
+ preprocess_button.click(fn=model.detect_and_align_face,
87
+ inputs=input_image,
88
+ outputs=aligned_face)
89
+ aligned_face.change(fn=model.reconstruct_face,
90
+ inputs=aligned_face,
91
+ outputs=[reconstructed_face, latent])
92
+ editing_type.change(fn=update_step2_components,
93
+ inputs=editing_type,
94
+ outputs=[hairstyle_index, color_description])
95
+ run_button.click(fn=model.generate,
96
+ inputs=[
97
+ editing_type,
98
+ hairstyle_index,
99
+ color_description,
100
+ latent,
101
+ ],
102
+ outputs=result)
103
+
104
+ demo.queue(max_size=10).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
images/README.md CHANGED
@@ -4,4 +4,3 @@ These images are freely-usable ones from [Unsplash](https://unsplash.com/).
4
  - https://unsplash.com/photos/et_78QkMMQs
5
  - https://unsplash.com/photos/ILip77SbmOE
6
  - https://unsplash.com/photos/95UF6LXe-Lo
7
-
 
4
  - https://unsplash.com/photos/et_78QkMMQs
5
  - https://unsplash.com/photos/ILip77SbmOE
6
  - https://unsplash.com/photos/95UF6LXe-Lo
 
model.py CHANGED
@@ -15,7 +15,7 @@ import torch
15
  import torch.nn as nn
16
  import torchvision.transforms as T
17
 
18
- if os.getenv('SYSTEM') == 'spaces':
19
  with open('patch.e4e') as f:
20
  subprocess.run('patch -p1'.split(), cwd='encoder4editing', stdin=f)
21
  with open('patch.hairclip') as f:
@@ -37,12 +37,11 @@ sys.path.insert(0, mapper_dir.as_posix())
37
  from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
38
  from mapper.hairclip_mapper import HairCLIPMapper
39
 
40
- HF_TOKEN = os.environ['HF_TOKEN']
41
-
42
 
43
  class Model:
44
- def __init__(self, device: Union[torch.device, str]):
45
- self.device = torch.device(device)
 
46
  self.landmark_model = self._create_dlib_landmark_model()
47
  self.e4e = self._load_e4e()
48
  self.hairclip = self._load_hairclip()
@@ -51,15 +50,13 @@ class Model:
51
  @staticmethod
52
  def _create_dlib_landmark_model():
53
  path = huggingface_hub.hf_hub_download(
54
- 'hysts/dlib_face_landmark_model',
55
- 'shape_predictor_68_face_landmarks.dat',
56
- use_auth_token=HF_TOKEN)
57
  return dlib.shape_predictor(path)
58
 
59
  def _load_e4e(self) -> nn.Module:
60
- ckpt_path = huggingface_hub.hf_hub_download('hysts/e4e',
61
- 'e4e_ffhq_encode.pt',
62
- use_auth_token=HF_TOKEN)
63
  ckpt = torch.load(ckpt_path, map_location='cpu')
64
  opts = ckpt['opts']
65
  opts['device'] = self.device.type
@@ -71,9 +68,8 @@ class Model:
71
  return model
72
 
73
  def _load_hairclip(self) -> nn.Module:
74
- ckpt_path = huggingface_hub.hf_hub_download('hysts/HairCLIP',
75
- 'hairclip.pt',
76
- use_auth_token=HF_TOKEN)
77
  ckpt = torch.load(ckpt_path, map_location='cpu')
78
  opts = ckpt['opts']
79
  opts['device'] = self.device.type
@@ -98,8 +94,8 @@ class Model:
98
  ])
99
  return transform
100
 
101
- def detect_and_align_face(self, image) -> PIL.Image.Image:
102
- image = align_face(filepath=image.name, predictor=self.landmark_model)
103
  return image
104
 
105
  @staticmethod
 
15
  import torch.nn as nn
16
  import torchvision.transforms as T
17
 
18
+ if os.getenv('SYSTEM') == 'spaces' and not torch.cuda.is_available():
19
  with open('patch.e4e') as f:
20
  subprocess.run('patch -p1'.split(), cwd='encoder4editing', stdin=f)
21
  with open('patch.hairclip') as f:
 
37
  from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
38
  from mapper.hairclip_mapper import HairCLIPMapper
39
 
 
 
40
 
41
  class Model:
42
+ def __init__(self):
43
+ self.device = torch.device(
44
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
45
  self.landmark_model = self._create_dlib_landmark_model()
46
  self.e4e = self._load_e4e()
47
  self.hairclip = self._load_hairclip()
 
50
  @staticmethod
51
  def _create_dlib_landmark_model():
52
  path = huggingface_hub.hf_hub_download(
53
+ 'public-data/dlib_face_landmark_model',
54
+ 'shape_predictor_68_face_landmarks.dat')
 
55
  return dlib.shape_predictor(path)
56
 
57
  def _load_e4e(self) -> nn.Module:
58
+ ckpt_path = huggingface_hub.hf_hub_download('public-data/e4e',
59
+ 'e4e_ffhq_encode.pt')
 
60
  ckpt = torch.load(ckpt_path, map_location='cpu')
61
  opts = ckpt['opts']
62
  opts['device'] = self.device.type
 
68
  return model
69
 
70
  def _load_hairclip(self) -> nn.Module:
71
+ ckpt_path = huggingface_hub.hf_hub_download('public-data/HairCLIP',
72
+ 'hairclip.pt')
 
73
  ckpt = torch.load(ckpt_path, map_location='cpu')
74
  opts = ckpt['opts']
75
  opts['device'] = self.device.type
 
94
  ])
95
  return transform
96
 
97
+ def detect_and_align_face(self, image: str) -> PIL.Image.Image:
98
+ image = align_face(filepath=image, predictor=self.landmark_model)
99
  return image
100
 
101
  @staticmethod
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
  dlib==19.23.0
 
2
  numpy==1.22.3
3
  opencv-python-headless==4.5.5.64
4
  Pillow==9.1.0
5
  scipy==1.8.0
6
  torch==1.11.0
7
  torchvision==0.12.0
8
- git+https://github.com/openai/CLIP.git
 
1
  dlib==19.23.0
2
+ git+https://github.com/openai/CLIP.git
3
  numpy==1.22.3
4
  opencv-python-headless==4.5.5.64
5
  Pillow==9.1.0
6
  scipy==1.8.0
7
  torch==1.11.0
8
  torchvision==0.12.0
 
style.css ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ img#teaser {
6
+ max-width: 1000px;
7
+ max-height: 600px;
8
+ }