import time

import torch


class ModelHandler:
    """Wrap a causal-LM model and tokenizer and stream generated text."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Prefer GPU when available; every input tensor is moved here.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def generate_response(self, conversation):
        """Generate a reply to *conversation*, yielding cumulative text.

        Yields the cumulative output string at most once per 10 ms while
        tokens are being produced, and ALWAYS yields the complete output
        once generation ends, so callers never lose the tail.

        Args:
            conversation: Prompt text (presumably a pre-formatted chat
                transcript — confirm against the caller).

        Returns:
            The final output with surrounding whitespace stripped
            (reachable via ``StopIteration.value``).
        """
        inputs = self.tokenizer(
            conversation, return_tensors="pt", truncation=True, max_length=1024
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        last_flush = time.time()
        output = ""
        with torch.no_grad():
            for _ in range(150):  # hard cap on generated tokens
                generated = self.model.generate(
                    **inputs,
                    max_new_tokens=1,
                    do_sample=True,
                    top_k=50,
                    top_p=0.95,
                )
                new_token = generated[0, -1].item()
                # Check EOS BEFORE decoding so its surface text (e.g.
                # "</s>") never leaks into the streamed output. The
                # original appended it first and broke afterwards.
                if new_token == self.tokenizer.eos_token_id:
                    break
                # NOTE(review): per-token decode can mangle multi-byte /
                # multi-piece tokens; acceptable for streaming previews,
                # but the final text comes from the same concatenation.
                output += self.tokenizer.decode([new_token])
                # Feed the id tensor straight back instead of re-encoding
                # conversation+output each step. This avoids O(n^2)
                # tokenization work, decode/re-encode drift at token
                # boundaries, and the original truncation bug where
                # max_length=1024 silently dropped the NEWEST tokens once
                # the running text exceeded the limit.
                inputs["input_ids"] = generated
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = torch.ones_like(generated)
                now = time.time()
                if now - last_flush >= 0.01:  # throttle streaming updates
                    yield output
                    last_flush = now
        # Always deliver the final text: the original only yielded inside
        # the 10 ms timer window, so the tail could be silently lost.
        yield output
        return output.strip()