pszemraj commited on
Commit
5a43bd7
·
verified ·
1 Parent(s): 57a7aa0

remove BetterTransformer for classifier

Browse files

already integrated into transformers based on error

Files changed (1) hide show
  1. app.py +11 -19
app.py CHANGED
@@ -1,38 +1,30 @@
1
- import re
2
- import os
3
  import gc
 
 
 
4
 
 
5
  from cleantext import clean
6
  import gradio as gr
7
  from tqdm.auto import tqdm
8
  from transformers import pipeline
9
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
10
 
 
 
11
 
12
  checker_model_name = "textattack/roberta-base-CoLA"
13
  corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
14
 
15
  # pipelines
16
-
17
-
18
- if os.environ.get("HF_DEMO_NO_USE_ONNX") is None:
19
- from optimum.bettertransformer import BetterTransformer
20
-
21
- model_hf = AutoModelForSequenceClassification.from_pretrained(checker_model_name)
22
- tokenizer = AutoTokenizer.from_pretrained(checker_model_name)
23
- model = BetterTransformer.transform(model_hf, keep_original_model=False)
24
-
25
- checker = pipeline(
26
- "text-classification",
27
- model=model,
28
- tokenizer=tokenizer,
29
- )
30
- else:
31
- checker = pipeline(
32
  "text-classification",
33
  checker_model_name,
34
  )
 
 
35
  gc.collect()
 
36
  if os.environ.get("HF_DEMO_NO_USE_ONNX") is None:
37
  # load onnx runtime unless HF_DEMO_NO_USE_ONNX is set
38
  from optimum.pipelines import pipeline
@@ -130,4 +122,4 @@ with gr.Blocks() as demo:
130
  "- see the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
131
  )
132
  gr.Markdown("- if experiencing long wait times, feel free to duplicate the space!")
133
- demo.launch()
 
 
 
1
  import gc
2
+ import logging
3
+ import os
4
+ import re
5
 
6
+ import torch
7
  from cleantext import clean
8
  import gradio as gr
9
  from tqdm.auto import tqdm
10
  from transformers import pipeline
11
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
12
 
13
+ logging.basicConfig(level=logging.INFO)
14
+ logging.info(f"torch version:\t{torch.__version__}")
15
 
16
  checker_model_name = "textattack/roberta-base-CoLA"
17
  corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
18
 
19
  # pipelines
20
+ checker = pipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  "text-classification",
22
  checker_model_name,
23
  )
24
+ checker.model = torch.compile(checker.model)
25
+
26
  gc.collect()
27
+
28
  if os.environ.get("HF_DEMO_NO_USE_ONNX") is None:
29
  # load onnx runtime unless HF_DEMO_NO_USE_ONNX is set
30
  from optimum.pipelines import pipeline
 
122
  "- see the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
123
  )
124
  gr.Markdown("- if experiencing long wait times, feel free to duplicate the space!")
125
+ demo.launch(debug=True)