Spaces:
Runtime error
Runtime error
# ########################################################################### | |
# | |
# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) | |
# (C) Cloudera, Inc. 2022 | |
# All rights reserved. | |
# | |
# Applicable Open Source License: Apache 2.0 | |
# | |
# NOTE: Cloudera open source products are modular software products | |
# made up of hundreds of individual components, each of which was | |
# individually copyrighted. Each Cloudera open source product is a | |
# collective work under U.S. Copyright Law. Your license to use the | |
# collective work is as provided in your written agreement with | |
# Cloudera. Used apart from the collective work, this file is | |
# licensed for your use pursuant to the open source license | |
# identified above. | |
# | |
# This code is provided to you pursuant a written agreement with | |
# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute | |
# this code. If you do not have a written agreement with Cloudera nor | |
# with an authorized and properly licensed third party, you do not | |
# have any rights to access nor to use this code. | |
# | |
# Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the | |
# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY | |
# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED | |
# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO | |
# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND | |
# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, | |
# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS | |
# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE | |
# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY | |
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR | |
# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES | |
# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF | |
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF | |
# DATA. | |
# | |
# ########################################################################### | |
from typing import List, Union | |
import torch | |
from transformers import pipeline | |
class StyleTransfer:
    """
    Wrap a Hugging Face "text2text-generation" pipeline that rewrites a piece of
    text so that a style attribute is transferred onto it.

    Attributes:
        model_identifier (str) - Model path or identifier handed to the pipeline
        max_gen_length (int) - Upper limit on the number of tokens the model may
            generate; forwarded to the pipeline as `max_length`
        num_beams (int) - Beam count forwarded to the pipeline for generation
        temperature (float) - Forwarded to the pipeline for generation.
            NOTE(review): no `do_sample` flag is passed here, so this presumably
            has no effect under beam search; confirm against the installed
            transformers version
        device (int) - Index of the current CUDA device when CUDA is available,
            otherwise -1 (CPU)
        pipeline - The transformers pipeline object that `transfer` calls
    """

    def __init__(
        self,
        model_identifier: str,
        max_gen_length: int = 200,
        num_beams=4,
        temperature=1,
    ):
        """Record the generation settings, pick a device, then build the pipeline."""
        self.model_identifier = model_identifier
        self.max_gen_length = max_gen_length
        self.num_beams = num_beams
        self.temperature = temperature

        # Prefer the active GPU; -1 is the pipeline's convention for "run on CPU".
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
        else:
            self.device = -1

        self._build_pipeline()

    def _build_pipeline(self):
        """Construct the pipeline from the current attribute values and store it on `self.pipeline`."""
        # Generation-time settings; the pipeline forwards these to the model's generate step.
        generation_options = dict(
            max_length=self.max_gen_length,
            num_beams=self.num_beams,
            temperature=self.temperature,
        )
        self.pipeline = pipeline(
            task="text2text-generation",
            model=self.model_identifier,
            device=self.device,
            **generation_options,
        )

    def transfer(self, input_text: Union[str, List[str]]) -> List[str]:
        """
        Run style transfer on the given text with the configured model.

        Args:
            input_text (`str` or `List[str]`) - A single string or a batch of strings
        Returns:
            generated_text (`List[str]`) - One generated string per pipeline result
        """
        # The pipeline yields dicts; only the "generated_text" field is returned.
        predictions = self.pipeline(input_text)
        return [prediction["generated_text"] for prediction in predictions]