# Lint as: python3
r"""Example demo loading a handful of GLUE models.
For a quick-start set of models, run:
python -m lit_nlp.examples.glue_demo \
--quickstart --port=5432
To run with the 'normal' defaults, including full-size BERT models:
python -m lit_nlp.examples.glue_demo --port=5432
Then navigate to localhost:5432 to access the demo UI.
"""
import sys

from absl import app
from absl import flags
from absl import logging

from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.examples.datasets import glue
from lit_nlp.examples.models import glue_models

import transformers  # for path caching

# NOTE: additional flags defined in server_flags.py
FLAGS = flags.FLAGS
FLAGS.set_default("development_demo", True)

flags.DEFINE_bool(
    "quickstart", False,
    "Quick-start mode, loads smaller models and a subset of the full data.")

flags.DEFINE_list(
    "models", [
        "sst2-tiny:sst2:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz",
        "sst2-base:sst2:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_base.tar.gz",
        "stsb:stsb:https://storage.googleapis.com/what-if-tool-resources/lit-models/stsb_base.tar.gz",
        "mnli:mnli:https://storage.googleapis.com/what-if-tool-resources/lit-models/mnli_base.tar.gz",
    ], "List of models to load, as <name>:<task>:<path>. "
    "See MODELS_BY_TASK for available tasks. Path should be the output of "
    "saving a transformers model, e.g. model.save_pretrained(path) and "
    "tokenizer.save_pretrained(path). Remote .tar.gz files will be downloaded "
    "and cached locally.")

flags.DEFINE_integer(
    "max_examples", None, "Maximum number of examples to load into LIT. "
    "Note: MNLI eval set is 10k examples, so will take a while to run and may "
    "be slow on older machines. Set --max_examples=200 for a quick start.")

MODELS_BY_TASK = {
    "sst2": glue_models.SST2Model,
    "stsb": glue_models.STSBModel,
    "mnli": glue_models.MNLIModel,
}

# Pre-specified set of small models, which will load and run much faster.
QUICK_START_MODELS = (
    "sst2-tiny:sst2:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz",
    "sst2-small:sst2:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_small.tar.gz",
    "stsb-tiny:stsb:https://storage.googleapis.com/what-if-tool-resources/lit-models/stsb_tiny.tar.gz",
    "mnli-small:mnli:https://storage.googleapis.com/what-if-tool-resources/lit-models/mnli_small.tar.gz",
)


def get_wsgi_app():
  """Return WSGI app for container-hosted demos."""
  FLAGS.set_default("server_type", "external")
  FLAGS.set_default("demo_mode", True)
  # Parse flags without calling app.run(main), to avoid conflict with
  # gunicorn command line flags.
  unused = flags.FLAGS(sys.argv, known_only=True)
  return main(unused)
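
# Illustrative only (assumed deployment, not defined in this file): a
# container-hosted demo could serve the app above with a WSGI server, e.g.
#   gunicorn "lit_nlp.examples.glue_demo:get_wsgi_app()"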


def main(_):
  # Quick-start mode.
  if FLAGS.quickstart:
    FLAGS.models = QUICK_START_MODELS  # smaller, faster models
    if FLAGS.max_examples is None or FLAGS.max_examples > 1000:
      FLAGS.max_examples = 1000  # truncate larger eval sets
    logging.info("Quick-start mode; overriding --models and --max_examples.")

  models = {}
  datasets = {}

  tasks_to_load = set()
  for model_string in FLAGS.models:
    # Only split on the first two ':', because path may be a URL
    # containing 'https://'
    name, task, path = model_string.split(":", 2)
    logging.info("Loading model '%s' for task '%s' from '%s'", name, task,
                 path)
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    if path.endswith(".tar.gz"):
      path = transformers.file_utils.cached_path(
          path, extract_compressed_file=True)
    # Load the model from disk.
    models[name] = MODELS_BY_TASK[task](path)
    tasks_to_load.add(task)

  ##
  # Load datasets for each task that we have a model for.
  if "sst2" in tasks_to_load:
    logging.info("Loading data for SST-2 task.")
    datasets["sst_dev"] = glue.SST2Data("validation")
  if "stsb" in tasks_to_load:
    logging.info("Loading data for STS-B task.")
    datasets["stsb_dev"] = glue.STSBData("validation")
  if "mnli" in tasks_to_load:
    logging.info("Loading data for MultiNLI task.")
    datasets["mnli_dev"] = glue.MNLIData("validation_matched")
    datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  return lit_demo.serve()


if __name__ == "__main__":
  app.run(main)