# codegen/services/model_generator.py
# Author: AP\VivekIsh — commit 6fadbbc ("codegen: Stage the code"), 1.85 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.config import Config
from utils.logger import Logger
logger = Logger.get_logger(__name__)
class ModelGenerator:
    """
    Singleton responsible for generating text using a configured language model.

    On first instantiation, loads a causal language model and its tokenizer
    and moves the model to the best available device. Provides visitor-based
    entry points for text generation and for extracting code blocks from
    generated text.

    Attributes:
        device (torch.device): Device the model runs on (CUDA if available,
            otherwise CPU).
        model (AutoModelForCausalLM): Causal language model used for
            text generation.
        tokenizer (AutoTokenizer): Tokenizer matching the loaded model.

    Methods:
        acceptTextGenerator(self, visitor, *args, **kwargs):
            Accepts a visitor that generates text using this model generator.
        acceptExtractCodeBlock(self, visitor, *args, **kwargs):
            Accepts a visitor that extracts code blocks from the output text.
    """

    # Singleton storage: the one shared instance, created lazily in __new__.
    _instance = None
    # NOTE(review): not referenced anywhere in this file — kept only for
    # backward compatibility. Name looks like a typo for "_format_date_time";
    # verify against external users before renaming or removing.
    _format_data_time = "%Y-%m-%d %H:%M:%S"

    def __new__(cls, model_name=None):
        """Return the singleton instance, creating and initializing it on first call.

        Args:
            model_name (str | None): Hugging Face model identifier. When None,
                falls back to the 'app'/'model' entry from Config, resolved
                lazily here rather than in the parameter default. (The
                original default argument called Config.read at class
                definition time, causing an import-time side effect and
                freezing the config value for the life of the process.)

        Note:
            Because this is a singleton, model_name is silently ignored on
            every call after the first; the already-initialized instance is
            returned unchanged.
        """
        if cls._instance is None:
            if model_name is None:
                model_name = Config.read('app', 'model')
            cls._instance = super().__new__(cls)
            cls._instance._initialize(model_name)
        return cls._instance

    def _initialize(self, model_name):
        """Load the model and tokenizer for *model_name* onto the best device.

        Downloads/loads the pretrained causal LM and its tokenizer, placing
        the model on CUDA when available, otherwise CPU.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def acceptTextGenerator(self, visitor, *args, **kwargs):
        """Accept a visitor that generates text; forwards self and all args to visitor.visit."""
        return visitor.visit(self, *args, **kwargs)

    def acceptExtractCodeBlock(self, visitor, *args, **kwargs):
        """Accept a visitor that extracts code blocks; forwards self and all args to visitor.visit."""
        return visitor.visit(self, *args, **kwargs)