Spaces:
Runtime error
Runtime error
fix pipeline referencing
Browse files
main.py
CHANGED
@@ -19,7 +19,7 @@ from typing import Dict, List, Union
|
|
19 |
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
|
20 |
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
|
21 |
from optimum.pipelines import pipeline as ort_pipeline
|
22 |
-
from transformers import BertTokenizer, BertForSequenceClassification,
|
23 |
|
24 |
from utils import calculate_inference_time
|
25 |
|
@@ -105,7 +105,7 @@ def load_pipeline(pipeline_name: str) -> None:
|
|
105 |
"""
|
106 |
if pipeline_name == "pt_pipeline":
|
107 |
model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
|
108 |
-
pipeline =
|
109 |
elif pipeline_name == "ort_pipeline":
|
110 |
model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
|
111 |
if not ONNX_MODEL_PATH.exists():
|
@@ -120,7 +120,7 @@ def load_pipeline(pipeline_name: str) -> None:
|
|
120 |
model = ORTModelForSequenceClassification.from_pretrained(
|
121 |
OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
|
122 |
)
|
123 |
-
pipeline =
|
124 |
elif pipeline_name == "ort_quantized_pipeline":
|
125 |
if not QUANTIZED_MODEL_PATH.exists():
|
126 |
quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
|
@@ -130,7 +130,7 @@ def load_pipeline(pipeline_name: str) -> None:
|
|
130 |
model = ORTModelForSequenceClassification.from_pretrained(
|
131 |
QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
|
132 |
)
|
133 |
-
pipeline =
|
134 |
print(type(pipeline))
|
135 |
return pipeline
|
136 |
|
|
|
19 |
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
|
20 |
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
|
21 |
from optimum.pipelines import pipeline as ort_pipeline
|
22 |
+
from transformers import BertTokenizer, BertForSequenceClassification, pt_pipeline
|
23 |
|
24 |
from utils import calculate_inference_time
|
25 |
|
|
|
105 |
"""
|
106 |
if pipeline_name == "pt_pipeline":
|
107 |
model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
|
108 |
+
pipeline = pt_pipeline("sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=model)
|
109 |
elif pipeline_name == "ort_pipeline":
|
110 |
model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
|
111 |
if not ONNX_MODEL_PATH.exists():
|
|
|
120 |
model = ORTModelForSequenceClassification.from_pretrained(
|
121 |
OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
|
122 |
)
|
123 |
+
pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
|
124 |
elif pipeline_name == "ort_quantized_pipeline":
|
125 |
if not QUANTIZED_MODEL_PATH.exists():
|
126 |
quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
|
|
|
130 |
model = ORTModelForSequenceClassification.from_pretrained(
|
131 |
QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
|
132 |
)
|
133 |
+
pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
|
134 |
print(type(pipeline))
|
135 |
return pipeline
|
136 |
|