CamiloVega commited on
Commit
d283cbc
·
verified ·
1 Parent(s): 32f8ac6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -14
app.py CHANGED
@@ -47,7 +47,7 @@ class ModelManager:
47
  """Initialize models with optimized settings"""
48
  try:
49
  import torch
50
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
51
 
52
  HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
53
  if not HUGGINGFACE_TOKEN:
@@ -56,14 +56,6 @@ class ModelManager:
56
  logger.info("Starting model initialization...")
57
  model_name = "meta-llama/Llama-2-7b-chat-hf"
58
 
59
- # Configure 8-bit quantization instead of 4-bit
60
- bnb_config = BitsAndBytesConfig(
61
- load_in_8bit=True,
62
- bnb_8bit_use_double_quant=True,
63
- bnb_8bit_quant_type="nf8",
64
- bnb_8bit_compute_dtype=torch.float16
65
- )
66
-
67
  # Load tokenizer with optimized settings
68
  logger.info("Loading tokenizer...")
69
  self.tokenizer = AutoTokenizer.from_pretrained(
@@ -74,18 +66,18 @@ class ModelManager:
74
  )
75
  self.tokenizer.pad_token = self.tokenizer.eos_token
76
 
77
- # Initialize model with optimized settings
78
  logger.info("Loading model...")
79
  self.model = AutoModelForCausalLM.from_pretrained(
80
  model_name,
81
  token=HUGGINGFACE_TOKEN,
82
  device_map="auto",
83
  torch_dtype=torch.float16,
84
- quantization_config=bnb_config,
85
  low_cpu_mem_usage=True,
86
  )
87
 
88
- # Create optimized pipeline
89
  logger.info("Creating pipeline...")
90
  from transformers import pipeline
91
  self.news_generator = pipeline(
@@ -103,11 +95,11 @@ class ModelManager:
103
  early_stopping=True
104
  )
105
 
106
- # Load Whisper model with optimized settings
107
  logger.info("Loading Whisper model...")
108
  self.whisper_model = whisper.load_model(
109
  "tiny",
110
- device="cuda",
111
  download_root="/tmp/whisper",
112
  in_memory=True
113
  )
 
47
  """Initialize models with optimized settings"""
48
  try:
49
  import torch
50
+ from transformers import AutoModelForCausalLM, AutoTokenizer
51
 
52
  HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
53
  if not HUGGINGFACE_TOKEN:
 
56
  logger.info("Starting model initialization...")
57
  model_name = "meta-llama/Llama-2-7b-chat-hf"
58
 
 
 
 
 
 
 
 
 
59
  # Load tokenizer with optimized settings
60
  logger.info("Loading tokenizer...")
61
  self.tokenizer = AutoTokenizer.from_pretrained(
 
66
  )
67
  self.tokenizer.pad_token = self.tokenizer.eos_token
68
 
69
+ # Initialize model with basic settings
70
  logger.info("Loading model...")
71
  self.model = AutoModelForCausalLM.from_pretrained(
72
  model_name,
73
  token=HUGGINGFACE_TOKEN,
74
  device_map="auto",
75
  torch_dtype=torch.float16,
76
+ load_in_8bit=True,
77
  low_cpu_mem_usage=True,
78
  )
79
 
80
+ # Create pipeline
81
  logger.info("Creating pipeline...")
82
  from transformers import pipeline
83
  self.news_generator = pipeline(
 
95
  early_stopping=True
96
  )
97
 
98
+ # Load Whisper model with basic settings
99
  logger.info("Loading Whisper model...")
100
  self.whisper_model = whisper.load_model(
101
  "tiny",
102
+ device="cuda" if torch.cuda.is_available() else "cpu",
103
  download_root="/tmp/whisper",
104
  in_memory=True
105
  )