chainyo committed
Commit d015acd · 1 Parent(s): 6a961c3

fix loading pipelines

Files changed (1)
  1. main.py +52 -68
main.py CHANGED
@@ -18,7 +18,6 @@ from typing import Dict, List, Union
 
 from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
 from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
-from optimum.onnxruntime.model import ORTModel
 from optimum.pipelines import pipeline as ort_pipeline
 from transformers import BertTokenizer, BertForSequenceClassification, pipeline
 
@@ -39,6 +38,11 @@ VAR2LABEL = {
     "ort_quantized_pipeline": "ONNXRuntime (Quantized)",
 }
 
+# Check if the output directories exist; if not, create them
+BASE_PATH.mkdir(exist_ok=True)
+QUANTIZED_BASE_PATH.mkdir(exist_ok=True)
+OPTIMIZED_BASE_PATH.mkdir(exist_ok=True)
+
 
 def get_timers(
     samples: Union[List[str], str], exp_number: int, only_mean: bool = False
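
The mkdir calls above assume BASE_PATH, OPTIMIZED_BASE_PATH, and QUANTIZED_BASE_PATH are defined near the top of main.py, outside the hunks shown here. A minimal sketch of definitions consistent with how this commit uses them (the folder and file names below are assumptions, not the repository's actual values):

from pathlib import Path

BASE_PATH = Path("onnx")                       # hypothetical root folder for exported artifacts
ONNX_MODEL_PATH = BASE_PATH / "model.onnx"     # plain ONNX export
OPTIMIZED_BASE_PATH = BASE_PATH / "optimized"  # graph-optimized model folder
OPTIMIZED_MODEL_PATH = OPTIMIZED_BASE_PATH / "model-optimized.onnx"
QUANTIZED_BASE_PATH = BASE_PATH / "quantized"  # quantized model folder
QUANTIZED_MODEL_PATH = QUANTIZED_BASE_PATH / "model-quantized.onnx"

Under a layout like this, creating BASE_PATH before the two nested folders matters, and exist_ok=True keeps the calls idempotent across Streamlit reruns.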
@@ -64,9 +68,10 @@ def get_timers(
     timers: Dict[str, float] = {}
     for model in VAR2LABEL.keys():
         time_buffer = []
+        st.session_state["pipeline"] = load_pipeline(model)
         for _ in range(exp_number):
             with calculate_inference_time(time_buffer):
-                st.session_state[model](samples)
+                st.session_state["pipeline"](samples)
         timers[VAR2LABEL[model]] = np.mean(time_buffer) if only_mean else time_buffer
     return timers
 
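
get_timers relies on a calculate_inference_time context manager that is defined elsewhere in main.py and never appears in this diff. A minimal sketch of such a helper, assuming it simply appends the elapsed wall-clock seconds of the with-block to the caller's buffer:

import time
from contextlib import contextmanager
from typing import List

@contextmanager
def calculate_inference_time(time_buffer: List[float]):
    # Time the body of the `with` statement and hand the result
    # back through the mutable buffer the caller passed in.
    start = time.perf_counter()
    try:
        yield
    finally:
        time_buffer.append(time.perf_counter() - start)

Loading the pipeline inside the loop also means only one model is resident in memory at a time, which appears to be the point of this commit: the four permanently cached pipelines are replaced by a single "pipeline" slot in st.session_state.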
@@ -87,6 +92,47 @@ def get_plot(timers: Dict[str, Union[float, List[float]]]) -> plotly.graph_objs.
     )
     fig.update_layout(title_text="Inference Time", xaxis_title="Inference Time (s)", yaxis_title="Number of Samples")
     return fig
+
+
+def load_pipeline(pipeline_name: str):
+    """
+    Load a pipeline for a given model.
+
+    Parameters
+    ----------
+    pipeline_name : str
+        Name of the pipeline to load.
+    """
+    if pipeline_name == "pt_pipeline":
+        model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
+        loaded_pipeline = pipeline("sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=model)
+    elif pipeline_name == "ort_pipeline":
+        model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
+        if not ONNX_MODEL_PATH.exists():
+            model.save_pretrained(ONNX_MODEL_PATH)
+        loaded_pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
+    elif pipeline_name == "ort_optimized_pipeline":
+        if not OPTIMIZED_MODEL_PATH.exists():
+            optimization_config = OptimizationConfig(optimization_level=99)
+            optimizer = ORTOptimizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
+            optimizer.export(ONNX_MODEL_PATH, OPTIMIZED_MODEL_PATH, optimization_config=optimization_config)
+            optimizer.model.config.save_pretrained(OPTIMIZED_BASE_PATH)
+        model = ORTModelForSequenceClassification.from_pretrained(
+            OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
+        )
+        loaded_pipeline = pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
+    elif pipeline_name == "ort_quantized_pipeline":
+        if not QUANTIZED_MODEL_PATH.exists():
+            quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
+            quantizer = ORTQuantizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
+            quantizer.export(ONNX_MODEL_PATH, QUANTIZED_MODEL_PATH, quantization_config=quantization_config)
+            quantizer.model.config.save_pretrained(QUANTIZED_BASE_PATH)
+        model = ORTModelForSequenceClassification.from_pretrained(
+            QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
+        )
+        loaded_pipeline = pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
+    print(type(loaded_pipeline))
+    return loaded_pipeline
 
 
 st.set_page_config(page_title="Optimum Text Classification", page_icon="⭐")
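
Taken in isolation, the new load_pipeline can be exercised from within the app roughly as follows; the sample sentence is invented, and the snippet assumes main.py's constants and imports are in scope, with the tokenizer cached first as the app does at startup:

st.session_state["tokenizer"] = BertTokenizer.from_pretrained(HUB_MODEL_PATH)

# The first call exports and quantizes the model to disk; later calls
# reuse the saved files thanks to the exists() checks above.
classifier = load_pipeline("ort_quantized_pipeline")
result = classifier("An illustrative sample sentence.")  # e.g. a list like [{'label': ..., 'score': ...}]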
@@ -125,72 +171,10 @@ if st.session_state["init_models"]:
         tokenizer = BertTokenizer.from_pretrained(HUB_MODEL_PATH)
         st.session_state["tokenizer"] = tokenizer
         st.text("✅ Tokenizer loaded.")
-
-    if "pt_model" not in st.session_state:
-        pt_model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
-        st.session_state["pt_model"] = pt_model
-        st.text("✅ PyTorch model loaded.")
-
-    if "ort_model" not in st.session_state:
-        ort_model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
-        if not ONNX_MODEL_PATH.exists():
-            ort_model.save_pretrained(ONNX_MODEL_PATH)
-        st.session_state["ort_model"] = ort_model
-        st.text("✅ ONNX Model loaded.")
-
-    if "optimized_model" not in st.session_state:
-        optimization_config = OptimizationConfig(optimization_level=99)
-        optimizer = ORTOptimizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
-        if not OPTIMIZED_MODEL_PATH.exists():
-            optimizer.export(ONNX_MODEL_PATH, OPTIMIZED_MODEL_PATH, optimization_config=optimization_config)
-            optimizer.model.config.save_pretrained(OPTIMIZED_BASE_PATH)
-        optimized_model = ORTModelForSequenceClassification.from_pretrained(
-            OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
-        )
-        st.session_state["optimized_model"] = optimized_model
-        st.text("✅ Optimized ONNX model loaded.")
-
-    if "quantized_model" not in st.session_state:
-        quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
-        quantizer = ORTQuantizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
-        if not QUANTIZED_MODEL_PATH.exists():
-            quantizer.export(ONNX_MODEL_PATH, QUANTIZED_MODEL_PATH, quantization_config=quantization_config)
-            quantizer.model.config.save_pretrained(QUANTIZED_BASE_PATH)
-        quantized_model = ORTModelForSequenceClassification.from_pretrained(
-            QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
-        )
-        st.session_state["quantized_model"] = quantized_model
-        st.text("✅ Quantized ONNX model loaded.")
-
-    if "pt_pipeline" not in st.session_state:
-        pt_pipeline = pipeline(
-            "sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=st.session_state["pt_model"]
-        )
-        st.session_state["pt_pipeline"] = pt_pipeline
-
-    if "ort_pipeline" not in st.session_state:
-        ort_pipeline = ort_pipeline(
-            "text-classification", tokenizer=st.session_state["tokenizer"], model=st.session_state["ort_model"]
-        )
-        st.session_state["ort_pipeline"] = ort_pipeline
-
-    if "ort_optimized_pipeline" not in st.session_state:
-        ort_optimized_pipeline = pipeline(
-            "text-classification",
-            tokenizer=st.session_state["tokenizer"],
-            model=st.session_state["optimized_model"],
-        )
-        st.session_state["ort_optimized_pipeline"] = ort_optimized_pipeline
-
-    if "ort_quantized_pipeline" not in st.session_state:
-        ort_quantized_pipeline = pipeline(
-            "text-classification",
-            tokenizer=st.session_state["tokenizer"],
-            model=st.session_state["quantized_model"],
-        )
-        st.session_state["ort_quantized_pipeline"] = ort_quantized_pipeline
-
-    st.text("✅ All pipelines are ready.")
+    if "pipeline" not in st.session_state:
+        for model_name in VAR2LABEL.keys():
+            st.session_state["pipeline"] = load_pipeline(model_name)
+        st.text("✅ Models ready.")
     sleep(2)
     loading_logs.success("🎉 Everything is ready!")
     st.session_state["init_models"] = False
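
After this change, a full benchmark run touches one pipeline at a time instead of holding all four. A hypothetical driver, assuming main.py's helpers and imports are in scope, with invented sample sentences and an arbitrary repetition count:

samples = [
    "I really enjoyed this.",
    "The product broke after a day.",
    "It was neither good nor bad.",
]
timers = get_timers(samples, exp_number=10)  # loads each pipeline via load_pipeline in turn
fig = get_plot(timers)
st.plotly_chart(fig)  # render the latency comparison in the Streamlit app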