# Alain Vaucher
# Add the pretrained model
# 8bbfef5
# (file-viewer residue: raw | history | blame | 4.92 kB)
import functools
import html
import logging
import textwrap
import traceback
from pathlib import Path
from typing import List
import gradio as gr
import pandas as pd
from rxn.utilities.logging import setup_console_logger
from rxn.utilities.strings import remove_postfix
from utils import TranslatorWithSentencePiece, download_cde_data, split_into_sentences
# Module-level logger. A NullHandler is attached, following the standard
# library's recommendation; handlers that actually print are configured elsewhere.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Labels displayed in the "Model" dropdown. They double as the keys of
# `model_type_to_models`, so their exact spelling is part of the behavior.
# NOTE(review): the examples CSV presumably refers to these labels verbatim;
# confirm before renaming any of them.
_SAC_MODEL_TAG = "Heterogenous SAC model (ACE)"
_ORGANIC_MODEL_TAG = "Organic chemistry model"
_PRETRAINED_MODEL_TAG = "Pretrained model"

# Translation checkpoint file(s) to load for each model label; the organic
# model is made of three checkpoints.
model_type_to_models = {
    _SAC_MODEL_TAG: ["sac.pt"],
    _ORGANIC_MODEL_TAG: [f"organic-{index}.pt" for index in (1, 2, 3)],
    _PRETRAINED_MODEL_TAG: ["pretrained.pt"],
}

# Hint shown in the empty synthesis-text box.
SYNTHESIS_TEXT_PLACEHOLDER = (
    "Enter the synthesis procedure here, "
    "or click on one of the examples below."
)
@functools.lru_cache
def load_model(model_type: str) -> TranslatorWithSentencePiece:
    """Return the translator for ``model_type``, building it on first use.

    The result is memoized with ``functools.lru_cache``, so repeated requests
    for the same model type reuse the already-loaded translator.

    Args:
        model_type: one of the keys of ``model_type_to_models``; an unknown
            key raises ``KeyError`` (after the first log message).

    Returns:
        A ``TranslatorWithSentencePiece`` wrapping the model's checkpoint
        file(s) and the ``sp_model.model`` SentencePiece model.
    """
    logger.info(f"Loading model {model_type}... ")
    translator = TranslatorWithSentencePiece(
        translation_model=model_type_to_models[model_type],
        sentencepiece_model="sp_model.model",
    )
    logger.info(f"Loading model {model_type}... Done.")
    return translator
def sentence_and_actions_to_html(
    sentence: str, action_string: str, show_sentences: bool
) -> str:
    """Render one sentence and its extracted actions as an HTML fragment.

    Args:
        sentence: the source sentence, taken from the user's input text.
        action_string: the model output for this sentence. Individual actions
            are separated by "; ", and a trailing "." is dropped.
        show_sentences: whether to include the sentence itself. When True, the
            fragment is wrapped in its own ``<ol>`` (one list per sentence).
            When False, only the bare ``<li>`` items are returned and the
            caller must add the list delimiters.

    Returns:
        The HTML fragment. Both the sentence (user input) and the actions
        (model output) are HTML-escaped, so characters such as ``<``, ``>``
        or ``&`` are displayed literally instead of being parsed as markup.
    """
    action_string = remove_postfix(action_string, ".")
    li_start = '<li style="margin-left: 12px;">' if show_sentences else "<li>"
    items = "".join(
        f"{li_start}{html.escape(action, quote=False)}</li>"
        for action in action_string.split("; ")
    )
    if not show_sentences:
        return items
    # If we show the sentence, we need the list/enumeration delimiters,
    # as there is one list per sentence.
    return f"<ol><p>{html.escape(sentence, quote=False)}</p>{items}</ol>"
def try_action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Extract actions from a synthesis paragraph and render them as HTML.

    The paragraph is split into sentences, each sentence is translated into an
    action string by the selected model, and the results are rendered as an
    ordered list (one list per sentence if ``show_sentences`` is True, a single
    global list otherwise).

    Args:
        model_type: key of ``model_type_to_models`` selecting the model.
        text: the synthesis paragraph entered by the user.
        show_sentences: whether to display each sentence above its actions.

    Returns:
        The HTML to display.

    Raises:
        Any exception raised by the data download, model loading, sentence
        splitting or translation steps is propagated to the caller.
    """
    short_text = textwrap.shorten(text, 60)
    logger.info(f'Extracting actions from paragraph "{short_text}"...')
    download_cde_data()
    model = load_model(model_type)

    logger.info("Splitting paragraph into sentences...")
    sentences = split_into_sentences(text)
    logger.info("Splitting paragraph into sentences... Done.")

    logger.info("Translation with OpenNMT...")
    action_strings = model.translate(sentences)
    logger.info("Translation with OpenNMT... Done.")

    # PostTreatment was renamed to ThermalTreatment, old model still relies on
    # the former. Rewrite only the model output: the user's sentences must be
    # echoed back unchanged.
    action_strings = [
        action_string.replace("POSTTREATMENT", "THERMALTREATMENT")
        for action_string in action_strings
    ]

    output = "".join(
        sentence_and_actions_to_html(sentence, action_string, show_sentences)
        for sentence, action_string in zip(sentences, action_strings)
    )
    if not show_sentences:
        # If the sentences were not shown, we need to add the list/enumeration
        # delimiters here (globally)
        output = f"<ol>{output}</ol>"

    logger.info(f'Extracting actions from paragraph "{short_text}"... Done.')
    return output
def action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Gradio entry point: run ``try_action_extraction`` without ever raising.

    Any exception is logged server-side. It is then turned into an HTML error
    message, followed by the escaped traceback, so it appears in the output
    panel.

    Args:
        model_type: key selecting the model (see ``model_type_to_models``).
        text: the synthesis paragraph entered by the user.
        show_sentences: whether to display each sentence above its actions.

    Returns:
        The extraction HTML, or an HTML error report if extraction failed.
    """
    try:
        return try_action_extraction(model_type, text, show_sentences)
    except Exception as e:
        # Top-level UI boundary: catching everything is deliberate, but the
        # failure must still be visible in the server logs.
        logger.exception("The action extraction failed.")
        tb = "".join(traceback.TracebackException.from_exception(e).format())
        tb_html = f"<pre>{html.escape(tb)}</pre>"
        # The message can echo user input, so escape it like the traceback.
        return (
            "<p><b>Error!</b> The action extraction failed: "
            f"{html.escape(str(e))}</p>{tb_html}"
        )
def launch() -> gr.Interface:
    """Build the Gradio action-extraction interface and launch it.

    The examples (CSV), article and description (markdown) are read from the
    ``model_cards`` directory next to this file.

    Returns:
        The Gradio interface. It is returned only after
        ``demo.launch(debug=True, show_error=True)`` returns.
        NOTE(review): with ``debug=True`` that is presumably after the server
        stops; confirm against the gradio documentation.
    """
    logger.info("Launching the Gradio app")
    metadata_dir = Path(__file__).parent / "model_cards"

    # One example per CSV row, with no header line; empty cells become "".
    examples_df: pd.DataFrame = pd.read_csv(
        metadata_dir / "sac_synthesis_mining_examples.csv", header=None
    ).fillna("")
    examples: List[List[str]] = examples_df.to_numpy().tolist()

    # Read the model cards explicitly as UTF-8 (pandas also reads the CSV as
    # UTF-8) instead of relying on the platform's locale encoding.
    article = (metadata_dir / "sac_synthesis_mining_article.md").read_text(
        encoding="utf-8"
    )
    description = (metadata_dir / "sac_synthesis_mining_description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=action_extraction,
        title="Extraction of synthesis protocols from paragraphs in machine readable format",
        inputs=[
            gr.Dropdown(
                # Derived from the mapping so that the dropdown always offers
                # exactly the models that can be loaded (same order).
                list(model_type_to_models),
                label="Model",
                value=_SAC_MODEL_TAG,
            ),
            gr.Textbox(
                label="Synthesis text",
                lines=7,
                placeholder=SYNTHESIS_TEXT_PLACEHOLDER,
            ),
            gr.Checkbox(label="Show sentences in the output"),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
        allow_flagging="never",
        theme="gradio/base",
    )
    demo.launch(debug=True, show_error=True)
    return demo
# Script entry point, executed at import time. Console logging is set up
# first (setup_console_logger presumably attaches an INFO-level console
# handler; confirm), so that messages logged while building the app are shown.
setup_console_logger(level="INFO")
# Builds and launches the app; `demo` is kept as a module-level name.
# NOTE(review): launch() calls demo.launch(debug=True), which presumably
# blocks until the server stops; confirm.
demo = launch()