yhavinga commited on
Commit
f3e8368
Β·
1 Parent(s): 0d30451

Fix double model load

Browse files
Files changed (1) hide show
  1. generator.py +19 -21
generator.py CHANGED
@@ -18,7 +18,7 @@ def get_access_token():
18
  return access_token
19
 
20
 
21
- # @st.cache(hash_funcs={_thread.RLock: lambda _: None}, suppress_st_warning=True, allow_output_mutation=True)
22
  def load_model(model_name):
23
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
24
  tokenizer = AutoTokenizer.from_pretrained(
@@ -67,25 +67,24 @@ class Generator:
67
  self.load()
68
 
69
  def load(self):
70
- if not self.model:
71
- print(f"Loading model {self.model_name}")
72
- self.tokenizer, self.model = load_model(self.model_name)
73
-
74
- for key in self.gen_kwargs:
75
- if key in self.model.config.__dict__:
76
- self.gen_kwargs[key] = self.model.config.__dict__[key]
77
- try:
78
- if self.task in self.model.config.task_specific_params:
79
- task_specific_params = self.model.config.task_specific_params[
80
- self.task
81
- ]
82
- if "prefix" in task_specific_params:
83
- self.prefix = task_specific_params["prefix"]
84
- for key in self.gen_kwargs:
85
- if key in task_specific_params:
86
- self.gen_kwargs[key] = task_specific_params[key]
87
- except TypeError:
88
- pass
89
 
90
  def generate(self, text: str, streamer=None, **generate_kwargs) -> (str, dict):
91
  # Replace two or more newlines with a single newline in text
@@ -148,7 +147,6 @@ class GeneratorFactory:
148
  # If the generator is not yet present, add it
149
  if not self.get_generator(model_name=model_name, task=task, desc=desc):
150
  g = Generator(model_name, task, desc, split_sentences)
151
- g.load()
152
  self.generators.append(g)
153
 
154
  def get_generator(self, **kwargs):
 
18
  return access_token
19
 
20
 
21
+ @st.cache_resource
22
  def load_model(model_name):
23
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
24
  tokenizer = AutoTokenizer.from_pretrained(
 
67
  self.load()
68
 
69
  def load(self):
70
+ print(f"Loading model {self.model_name}")
71
+ self.tokenizer, self.model = load_model(self.model_name)
72
+
73
+ for key in self.gen_kwargs:
74
+ if key in self.model.config.__dict__:
75
+ self.gen_kwargs[key] = self.model.config.__dict__[key]
76
+ try:
77
+ if self.task in self.model.config.task_specific_params:
78
+ task_specific_params = self.model.config.task_specific_params[
79
+ self.task
80
+ ]
81
+ if "prefix" in task_specific_params:
82
+ self.prefix = task_specific_params["prefix"]
83
+ for key in self.gen_kwargs:
84
+ if key in task_specific_params:
85
+ self.gen_kwargs[key] = task_specific_params[key]
86
+ except TypeError:
87
+ pass
 
88
 
89
  def generate(self, text: str, streamer=None, **generate_kwargs) -> (str, dict):
90
  # Replace two or more newlines with a single newline in text
 
147
  # If the generator is not yet present, add it
148
  if not self.get_generator(model_name=model_name, task=task, desc=desc):
149
  g = Generator(model_name, task, desc, split_sentences)
 
150
  self.generators.append(g)
151
 
152
  def get_generator(self, **kwargs):