m3hrdadfi commited on
Commit
757486f
·
1 Parent(s): 233e97e

Fix permission

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
-
3
  import torch
4
  from transformers import pipeline
5
  from transformers import AutoConfig, AutoTokenizer, AutoModelForTokenClassification
@@ -16,6 +16,7 @@ MODELS = {
16
  "Persian (fa)": "m3hrdadfi/typo-detector-distilbert-fa",
17
  "Icelandic (is)": "m3hrdadfi/typo-detector-distilbert-is",
18
  }
 
19
 
20
 
21
  class TypoDetector:
@@ -34,11 +35,15 @@ class TypoDetector:
34
  self.nlp = None
35
  self.normalizer = None
36
 
37
- def load(self):
 
38
  if not self.debug:
39
- self.config = AutoConfig.from_pretrained(self.model_name_or_path)
40
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
41
- self.model = AutoModelForTokenClassification.from_pretrained(self.model_name_or_path, config=self.config)
 
 
 
42
  self.nlp = pipeline(
43
  self.task_name,
44
  model=self.model,
@@ -70,7 +75,7 @@ def load_typo_detectors():
70
  is_detector.load()
71
 
72
  fa_detector = TypoDetector(MODELS["Persian (fa)"])
73
- fa_detector.load()
74
 
75
  return {
76
  "en": en_detector,
 
1
  import streamlit as st
2
+ import os
3
  import torch
4
  from transformers import pipeline
5
  from transformers import AutoConfig, AutoTokenizer, AutoModelForTokenClassification
 
16
  "Persian (fa)": "m3hrdadfi/typo-detector-distilbert-fa",
17
  "Icelandic (is)": "m3hrdadfi/typo-detector-distilbert-is",
18
  }
19
+ API_TOKEN = os.environ.get("API_TOKEN")
20
 
21
 
22
  class TypoDetector:
 
35
  self.nlp = None
36
  self.normalizer = None
37
 
38
+ def load(self, api_token=None):
39
+ api_token = api_token if api_token else False
40
  if not self.debug:
41
+ self.config = AutoConfig.from_pretrained(self.model_name_or_path, use_auth_token=api_token)
42
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_auth_token=api_token)
43
+ self.model = AutoModelForTokenClassification.from_pretrained(
44
+ self.model_name_or_path,
45
+ config=self.config,
46
+ use_auth_token=api_token)
47
  self.nlp = pipeline(
48
  self.task_name,
49
  model=self.model,
 
75
  is_detector.load()
76
 
77
  fa_detector = TypoDetector(MODELS["Persian (fa)"])
78
+ fa_detector.load(api_token=API_TOKEN)
79
 
80
  return {
81
  "en": en_detector,