PANH committed on
Commit
1f76ea6
·
verified ·
1 Parent(s): b6c6cdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -9
app.py CHANGED
@@ -1,15 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
def get_safetensors():
    """Return the raw bytes of ``AlignScore-base.safetensors``.

    The file is looked up relative to the current working directory.
    """
    model_file = open("AlignScore-base.safetensors", "rb")
    try:
        payload = model_file.read()
    finally:
        model_file.close()
    return payload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
# Keyword arguments for the single-button download page, gathered first so
# the Interface construction itself stays a one-liner.
_interface_options = {
    "fn": get_safetensors,
    "inputs": [],
    "outputs": gr.outputs.File(label="Download SafeTensors Model"),
    "title": "Download SafeTensors Model",
    "description": "Click the button below to download the SafeTensors version of the model.",
}
iface = gr.Interface(**_interface_options)

# Start serving as soon as the module is executed.
iface.launch()
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import tempfile
4
+ import shutil
5
+
6
+ import torch
7
+ from pytorch_lightning import LightningModule
8
+ from safetensors.torch import save_file
9
+ from torch import nn
10
+ from modelalign import BERTAlignModel
11
+
12
  import gradio as gr
13
 
14
+
15
+ # ===========================
16
+ # Utility Functions
17
+ # ===========================
18
+
19
def download_checkpoint(url: str, dest_path: str, timeout: float = 60.0):
    """
    Download the checkpoint at ``url`` and write it to ``dest_path``.

    The response body is streamed to disk in chunks, so the whole checkpoint
    is never held in memory at once.

    Args:
        url: Direct HTTP(S) link to the checkpoint file.
        dest_path: Local path the downloaded bytes are written to.
        timeout: Seconds to wait for the server to connect or to send more
            data before giving up.

    Returns:
        A ``(success, message)`` tuple. Network and file-system failures are
        reported through the tuple rather than raised.
    """
    url = (url or "").strip()
    if not url:
        return False, "Failed to download checkpoint: no URL was provided."

    try:
        # Without a timeout a stalled server would block this call forever.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(dest_path, 'wb') as f:
                # iter_content (unlike response.raw) undoes any gzip/deflate
                # content encoding, so the file holds the real payload.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        return True, "Checkpoint downloaded successfully."
    except (requests.RequestException, OSError) as e:
        return False, f"Failed to download checkpoint: {e}"
31
+
32
def initialize_model(model_name: str, device: str = 'cpu'):
    """
    Build a ``BERTAlignModel`` for ``model_name``, move it to ``device`` and
    switch it to evaluation mode.

    Returns:
        ``(True, model)`` when construction succeeds, otherwise
        ``(False, error_message)``.
    """
    try:
        aligner = BERTAlignModel(base_model_name=model_name)
        aligner.to(device)
        aligner.eval()  # inference mode: disables dropout and the like
    except Exception as err:
        return False, f"Failed to initialize model: {str(err)}"
    else:
        return True, aligner
43
+
44
def load_checkpoint(model: nn.Module, checkpoint_path: str, device: str = 'cpu'):
    """
    Load the weights stored at ``checkpoint_path`` into ``model``.

    Both layouts are accepted: a checkpoint dict that keeps the weights under
    the ``'state_dict'`` key, and a file that is itself a bare state dict.

    Args:
        model: Module that receives the weights (modified in place).
        checkpoint_path: Path of the checkpoint file on disk.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        A ``(success, message)`` tuple. Loading is non-strict, so partially
        matching checkpoints succeed and the message reports how many keys
        were ignored. A checkpoint whose parameter names match nothing in the
        model is reported as a failure, because the model would otherwise
        silently keep its freshly initialised weights.
    """
    try:
        # SECURITY: torch.load unpickles the file, and unpickling can execute
        # arbitrary code. Only feed this function checkpoints from a trusted
        # source (or pass weights_only=True if the checkpoint permits it).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if not isinstance(checkpoint, dict):
            return False, "Failed to load checkpoint: file does not contain a state dict."
        state_dict = checkpoint.get('state_dict', checkpoint)
        result = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        return False, f"Failed to load checkpoint: {str(e)}"

    unexpected = set(result.unexpected_keys)
    matched = [key for key in state_dict if key not in unexpected]
    if result.missing_keys and not matched:
        return False, "Failed to load checkpoint: none of its parameter names match the model."
    if result.missing_keys or unexpected:
        return True, (
            "Checkpoint loaded successfully. "
            f"({len(result.missing_keys)} missing and {len(unexpected)} unexpected keys were ignored.)"
        )
    return True, "Checkpoint loaded successfully."
55
+
56
def convert_to_safetensors(model: LightningModule, save_path: str):
    """
    Write the model's ``state_dict`` to ``save_path`` in safetensors format.

    Args:
        model: Model whose weights are exported.
        save_path: Destination path of the ``.safetensors`` file.

    Returns:
        A ``(success, message)`` tuple; failures are reported through the
        tuple rather than raised.
    """
    try:
        # save_file rejects tensors that share storage (e.g. tied embedding /
        # output weights) as well as non-contiguous tensors. Saving an
        # independent, contiguous copy of every entry sidesteps both, at the
        # cost of temporarily holding a second copy of the weights.
        tensors = {
            name: tensor.detach().clone().contiguous()
            for name, tensor in model.state_dict().items()
        }
        save_file(tensors, save_path)
        return True, "Model converted to SafeTensors successfully."
    except Exception as e:
        return False, f"Failed to convert to SafeTensors: {str(e)}"
66
+
67
+ # ===========================
68
+ # Gradio Interface Function
69
+ # ===========================
70
+
71
def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
    """
    Download a checkpoint, load it into a model and export it as safetensors.

    Args:
        checkpoint_url: Direct link to the ``.ckpt`` file.
        model_name: Name of the base model used to build the architecture.

    Returns:
        A ``(file_path, status)`` tuple for the File and Status outputs.
        ``file_path`` is the path of the generated ``.safetensors`` file on
        success and ``None`` on failure; ``status`` is a human-readable
        message in both cases.
    """
    # Multi-line text boxes easily pick up stray whitespace or newlines.
    checkpoint_url = (checkpoint_url or "").strip()
    model_name = (model_name or "").strip()

    # A File output expects a path on disk, not raw bytes, and the path must
    # still exist after this function returns. The result therefore lives
    # outside the scratch directory, which is deleted on exit.
    try:
        out_fd, safetensors_path = tempfile.mkstemp(suffix=".safetensors")
        os.close(out_fd)
    except OSError as e:
        return None, f"Failed to prepare download: {str(e)}"

    converted = False
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = os.path.join(tmpdir, "model.ckpt")

            # Step 1: Download the checkpoint
            success, message = download_checkpoint(checkpoint_url, checkpoint_path)
            if not success:
                return None, message

            # Step 2: Initialize the model
            success, model_or_msg = initialize_model(model_name)
            if not success:
                return None, model_or_msg
            model = model_or_msg

            # Step 3: Load the checkpoint
            success, message = load_checkpoint(model, checkpoint_path)
            if not success:
                return None, message

            # Step 4: Convert to SafeTensors
            success, message = convert_to_safetensors(model, safetensors_path)
            if not success:
                return None, message

        converted = True
        return safetensors_path, "Conversion successful! Download your SafeTensors file below."
    finally:
        if not converted:
            # Best-effort cleanup of the unused output file; a failure to
            # delete it must not hide the real error message.
            try:
                os.remove(safetensors_path)
            except OSError:
                pass
108
+
109
+ # ===========================
110
+ # Gradio Interface Setup
111
+ # ===========================
112
+
113
title = "Checkpoint to SafeTensors Converter"
description = """
Convert your PyTorch Lightning `.ckpt` checkpoints to the secure `safetensors` format.

**Inputs**:
- **Checkpoint URL**: Direct link to the `.ckpt` file.
- **Model Name**: Name of the base model (e.g., `roberta-base`, `bert-base-uncased`).

**Output**:
- Downloadable `safetensors` file.
"""

# Components come from the top-level gradio namespace: the legacy
# gr.inputs / gr.outputs namespaces are deprecated.
iface = gr.Interface(
    fn=convert_checkpoint_to_safetensors,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter the checkpoint URL here...", label="Checkpoint URL"),
        gr.Textbox(lines=1, placeholder="e.g., roberta-base", label="Model Name"),
    ],
    outputs=[
        gr.File(label="Download SafeTensors File"),
        gr.Textbox(label="Status"),
    ],
    title=title,
    description=description,
    examples=[
        [
            "https://huggingface.co/yzha/AlignScore/resolve/main/AlignScore-base.ckpt?download=true",
            "roberta-base",
        ],
        [
            "https://path.to/your/checkpoint.ckpt",
            "bert-base-uncased",
        ],
    ],
    # Examples only pre-fill the inputs; never run a full download and
    # conversion at startup just to cache their outputs.
    cache_examples=False,
    allow_flagging="never",
)

# ===========================
# Launch the Interface
# ===========================

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()