from threading import Thread
from typing import Generator, Tuple

import streamlit as st
import torch
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.quantization import quantize_dynamic
from transformers import T5ForConditionalGeneration, T5Tokenizer, TextIteratorStreamer


@st.cache_resource(show_spinner=False)
def get_resources(
    quantize: bool = True, no_cuda: bool = False
) -> Tuple[T5ForConditionalGeneration, T5Tokenizer, TextIteratorStreamer]:
    """Load the simplification model, its tokenizer, and a token streamer.

    The result is cached by Streamlit so the model is only loaded once per session.

    :param quantize: whether to dynamically quantize the model (CPU only)
    :param no_cuda: whether to disable CUDA even when a GPU is available
    :return: the model, its tokenizer, and a TextIteratorStreamer for incremental decoding
    """
    tokenizer = T5Tokenizer.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023", use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023")
    # Swap in BetterTransformer's fast attention kernels
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:
        # Dynamic quantization is not supported on CUDA, so only quantize on CPU
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    model.eval()

    # Decoding options must be passed as plain keyword arguments: TextIteratorStreamer
    # collects them via **decode_kwargs and forwards them to tokenizer.decode.
    # (Passing a literal decode_kwargs={...} argument would be silently ignored.)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    return model, tokenizer, streamer


def simplify(
    text: str,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    streamer: TextIteratorStreamer,
) -> Generator[str, None, None]:
    """Simplify the given text, yielding the accumulated output as it is generated.

    :param text: the text to simplify
    :param model: the simplification model
    :param tokenizer: the model's tokenizer
    :param streamer: a TextIteratorStreamer that receives tokens as they are generated
    :return: a generator that yields the full generated text so far after every new chunk
    """
    # Prepend the UL2 [NLG] mode token that the model expects
    text = "[NLG] " + text
    encoded = tokenizer(text, return_tensors="pt")
    encoded = {k: v.to(model.device) for k, v in encoded.items()}

    gen_kwargs = {
        **encoded,
        "max_new_tokens": 128,
        "streamer": streamer,
    }

    # Run generation in a background thread so we can consume the streamer here.
    # No torch.no_grad() is needed: model.generate already disables gradients
    # internally, and grad mode is thread-local, so a no_grad context opened in
    # this thread would not apply to the worker thread anyway.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
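
# --- Usage sketch (not part of the original snippet): a minimal Streamlit
# front-end wiring get_resources() and simplify() together. The widget labels
# and layout below are assumptions for illustration; only the two functions
# above come from the original code.
if __name__ == "__main__":
    model, tokenizer, streamer = get_resources()

    input_text = st.text_area("Text to simplify")  # hypothetical label
    if st.button("Simplify") and input_text.strip():
        placeholder = st.empty()
        # simplify() yields the accumulated output after every new chunk, so
        # overwriting the placeholder each iteration gives a live typing effect.
        for partial_text in simplify(input_text, model, tokenizer, streamer):
            placeholder.markdown(partial_text)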