Spaces:
Runtime error
Runtime error
# ########################################################################### | |
# | |
# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) | |
# (C) Cloudera, Inc. 2022 | |
# All rights reserved. | |
# | |
# Applicable Open Source License: Apache 2.0 | |
# | |
# NOTE: Cloudera open source products are modular software products | |
# made up of hundreds of individual components, each of which was | |
# individually copyrighted. Each Cloudera open source product is a | |
# collective work under U.S. Copyright Law. Your license to use the | |
# collective work is as provided in your written agreement with | |
# Cloudera. Used apart from the collective work, this file is | |
# licensed for your use pursuant to the open source license | |
# identified above. | |
# | |
# This code is provided to you pursuant a written agreement with | |
# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute | |
# this code. If you do not have a written agreement with Cloudera nor | |
# with an authorized and properly licensed third party, you do not | |
# have any rights to access nor to use this code. | |
# | |
# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY | |
# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED | |
# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO | |
# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND | |
# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, | |
# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS | |
# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE | |
# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY | |
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR | |
# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES | |
# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF | |
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF | |
# DATA. | |
# | |
# ########################################################################### | |
import functools
from typing import List

import streamlit as st
import tokenizers

from apps.data_utils import StyleAttributeData, string_to_list_string
from src.content_preservation import ContentPreservationScorer
from src.style_classification import StyleIntensityClassifier
from src.style_transfer import StyleTransfer
from src.transformer_interpretability import InterpretTransformer
# CALLBACKS | |
def increment_page_progress():
    """Advance the page-progress counter stored in session state by one.

    Intended for use as a Streamlit `on_click` callback. Reads and writes
    `st.session_state.page_progress`, which must already be initialized.
    """
    current_page = st.session_state.page_progress
    st.session_state.page_progress = current_page + 1
def reset_page_progress_state():
    """Return the multi-step flow to its first page.

    Discards any stored style-transfer result (`st_result`) and sets
    `st.session_state.page_progress` back to 1.

    Safe to call when no result has been stored yet (e.g. a repeated click
    on a reset button): the absent key is skipped instead of raising.
    """
    # A bare `del st.session_state.st_result` raises when the key is missing,
    # which would crash the callback; guard it so the reset is idempotent.
    if "st_result" in st.session_state:
        del st.session_state["st_result"]
    st.session_state.page_progress = 1
# UTILITY CLASSES | |
class DisableableButton:
    """
    A Streamlit button that can be replaced by a greyed-out copy.

    `create_enabled_button` creates an `st.empty()` placeholder and draws a
    clickable button into it; `disable` redraws that same placeholder with a
    disabled button carrying the same label. Widget keys are derived from
    `button_number`, so instances with different numbers do not collide.
    """

    def __init__(self, button_number, button_text):
        # Used to build the widget keys for the enabled/disabled buttons.
        self.button_number = button_number
        # Label shared by the enabled and the disabled button.
        self.button_text = button_text

    def _init_placeholder_container(self):
        """Create the single-element container the buttons are drawn into."""
        self.ph = st.empty()

    def create_enabled_button(self):
        """Draw the clickable button; clicking it calls `increment_page_progress`."""
        self._init_placeholder_container()
        enabled_key = f"ph{self.button_number}_before"
        self.ph.button(
            self.button_text,
            key=enabled_key,
            on_click=increment_page_progress,
            disabled=False,
        )

    def disable(self):
        """Swap the placeholder's content for a non-clickable copy of the button.

        Reuses the placeholder made by `create_enabled_button`, so that method
        must have been called first.
        """
        disabled_key = f"ph{self.button_number}_after"
        self.ph.button(self.button_text, key=disabled_key, disabled=True)
# CACHED FUNCTIONS | |
@functools.lru_cache(maxsize=4)
def _load_style_intensity_classifier(
    cls_model_path, source_attribute, target_attribute
) -> StyleIntensityClassifier:
    """Build a relabelled StyleIntensityClassifier; memoized per argument tuple."""
    sic = StyleIntensityClassifier(cls_model_path)
    config_attrs = sic.pipeline.model.config.__dict__
    # Index 0 -> source attribute, index 1 -> target attribute.
    id2label = {
        i: a
        for i, a in enumerate(
            [source_attribute.capitalize(), target_attribute.capitalize()]
        )
    }
    # Written through `__dict__` (as before) to bypass any custom __setattr__.
    config_attrs["id2label"] = id2label
    config_attrs["label2id"] = {v: k for k, v in config_attrs["id2label"].items()}
    return sic


def get_cached_style_intensity_classifier(
    style_data: StyleAttributeData,
) -> StyleIntensityClassifier:
    """
    Return a cached style classifier.

    The classifier is built once per combination of
    (`cls_model_path`, `source_attribute`, `target_attribute`) and reused on
    later calls, so Streamlit reruns do not rebuild it. Only a few
    combinations are kept (least-recently-used eviction).

    This function overwrites the model's config values for `id2label` and
    `label2id`: id 0 is the capitalized source attribute, id 1 the
    capitalized target attribute.

    Args:
        style_data (StyleAttributeData): supplies the classifier model path
            and the source/target attribute names.

    Returns:
        StyleIntensityClassifier: a shared instance. Later calls with the same
            style data return the same object, so callers should not mutate it.
    """
    return _load_style_intensity_classifier(
        style_data.cls_model_path,
        style_data.source_attribute,
        style_data.target_attribute,
    )
@functools.lru_cache(maxsize=64)
def _compute_word_attributions_html(
    text_sample, cls_model_path, source_attribute, target_attribute
) -> str:
    """Compute the attribution HTML for one text; memoized per argument tuple."""
    it = InterpretTransformer(cls_model_identifier=cls_model_path)
    # Index 0 -> source attribute, index 1 -> target attribute.
    it.explainer.id2label = {
        i: a
        for i, a in enumerate(
            [source_attribute.capitalize(), target_attribute.capitalize()]
        )
    }
    it.explainer.label2id = {v: k for k, v in it.explainer.id2label.items()}
    return it.visualize_feature_attribution_scores(text_sample).data


def get_cached_word_attributions(
    text_sample: str, style_data: StyleAttributeData
) -> str:
    """
    Calculate word attributions and return them as an HTML visual.

    Results are cached per (`text_sample`, `cls_model_path`,
    `source_attribute`, `target_attribute`), so a Streamlit rerun with the
    same input does not recompute the attributions. Only the returned string
    is cached, not the explainer.

    The explainer's `id2label` and `label2id` lookups are overwritten: id 0 is
    the capitalized source attribute, id 1 the capitalized target attribute.

    Args:
        text_sample (str): text to explain.
        style_data (StyleAttributeData): supplies the classifier model path
            and the source/target attribute names.

    Returns:
        str: the `.data` of the visualization returned by
            `InterpretTransformer.visualize_feature_attribution_scores`.
    """
    return _compute_word_attributions_html(
        text_sample,
        style_data.cls_model_path,
        style_data.source_attribute,
        style_data.target_attribute,
    )
def get_sti_metric(
    input_text: str, output_text: str, style_data: StyleAttributeData
) -> List[float]:
    """
    Compute the Style Transfer Intensity (STI) of an input/output text pair.

    A new `StyleIntensityClassifier` is constructed from
    `style_data.cls_model_path` on every call. Both texts are passed through
    `string_to_list_string` before scoring.

    Args:
        input_text (str): text before style transfer.
        output_text (str): text after style transfer.
        style_data (StyleAttributeData): supplies the classifier model path.

    Returns:
        List[float]: result of
            `StyleIntensityClassifier.calculate_transfer_intensity_fraction`.
    """
    classifier = StyleIntensityClassifier(
        model_identifier=style_data.cls_model_path,
    )
    source_texts = string_to_list_string(input_text)
    transferred_texts = string_to_list_string(output_text)
    return classifier.calculate_transfer_intensity_fraction(
        source_texts, transferred_texts
    )
def get_cps_metric(
    input_text: str, output_text: str, style_data: StyleAttributeData
) -> List[float]:
    """
    Compute the Content Preservation Score (CPS) of an input/output text pair.

    A new `ContentPreservationScorer` is constructed from the classifier and
    SBERT model paths in `style_data` on every call. Both texts are passed
    through `string_to_list_string`, and scoring uses `mask_type="none"`.

    Args:
        input_text (str): text before style transfer.
        output_text (str): text after style transfer.
        style_data (StyleAttributeData): supplies the classifier and SBERT
            model paths.

    Returns:
        List[float]: result of
            `ContentPreservationScorer.calculate_content_preservation_score`.
    """
    scorer = ContentPreservationScorer(
        cls_model_identifier=style_data.cls_model_path,
        sbert_model_identifier=style_data.sbert_model_path,
    )
    source_texts = string_to_list_string(input_text)
    transferred_texts = string_to_list_string(output_text)
    return scorer.calculate_content_preservation_score(
        source_texts, transferred_texts, mask_type="none"
    )
def generate_style_transfer(
    text_sample: str,
    style_data: StyleAttributeData,
    max_gen_length: int,
    num_beams: int,
    temperature: float,
) -> None:
    """
    Run inference on the seq2seq model and persist the result to the
    `st.session_state.st_result` session-state variable.

    A spinner is shown while a `StyleTransfer` instance is constructed from
    `style_data.seq2seq_model_path` and applied to `text_sample`.

    Args:
        text_sample (str): text whose style should be transferred.
        style_data (StyleAttributeData): supplies the seq2seq model path.
        max_gen_length (int): forwarded to `StyleTransfer` as `max_gen_length`.
        num_beams (int): forwarded to `StyleTransfer` as `num_beams`.
        temperature (float): forwarded to `StyleTransfer` as `temperature`.
            Annotated as float because temperatures are typically
            fractional; int values are still accepted.

    Returns:
        None. The output of `StyleTransfer.transfer` is stored in
        `st.session_state.st_result`.
    """
    with st.spinner("Transferring style, hang tight!"):
        generate_kwargs = {
            "max_gen_length": max_gen_length,
            "num_beams": num_beams,
            "temperature": temperature,
        }
        st_class = StyleTransfer(
            model_identifier=style_data.seq2seq_model_path,
            **generate_kwargs,
        )
        st_result = st_class.transfer(text_sample)
    st.session_state.st_result = st_result