|
|
|
|
|
import re |
|
|
|
""" |
|
Extracts code from the file "./Libraries.ts". |
|
(Note that "Libraries.ts", must be in the same directory as |
|
this script). |
|
""" |
|
|
|
file = None |
|
|
|
def read_file(library: str, model_name: str) -> str: |
|
text = file |
|
|
|
match = re.search('const ' + library + '.*', text, re.DOTALL).group() |
|
if match: |
|
text = match[match.index('`') + 1:match.index('`;')].replace('${model.id}', model_name) |
|
|
|
return text |
|
|
|
file = """ |
|
import type { ModelData } from "./Types"; |
|
/** |
|
* Add your new library here. |
|
*/ |
|
export enum ModelLibrary { |
|
"adapter-transformers" = "Adapter Transformers", |
|
"allennlp" = "allenNLP", |
|
"asteroid" = "Asteroid", |
|
"diffusers" = "Diffusers", |
|
"espnet" = "ESPnet", |
|
"fairseq" = "Fairseq", |
|
"flair" = "Flair", |
|
"keras" = "Keras", |
|
"nemo" = "NeMo", |
|
"pyannote-audio" = "pyannote.audio", |
|
"sentence-transformers" = "Sentence Transformers", |
|
"sklearn" = "Scikit-learn", |
|
"spacy" = "spaCy", |
|
"speechbrain" = "speechbrain", |
|
"tensorflowtts" = "TensorFlowTTS", |
|
"timm" = "Timm", |
|
"fastai" = "fastai", |
|
"transformers" = "Transformers", |
|
"stanza" = "Stanza", |
|
"fasttext" = "fastText", |
|
"stable-baselines3" = "Stable-Baselines3", |
|
"ml-agents" = "ML-Agents", |
|
} |
|
|
|
export const ALL_MODEL_LIBRARY_KEYS = Object.keys(ModelLibrary) as (keyof typeof ModelLibrary)[]; |
|
|
|
|
|
/** |
|
* Elements configurable by a model library. |
|
*/ |
|
export interface LibraryUiElement { |
|
/** |
|
* Name displayed on the main |
|
* call-to-action button on the model page. |
|
*/ |
|
btnLabel: string; |
|
/** |
|
* Repo name |
|
*/ |
|
repoName: string; |
|
/** |
|
* URL to library's repo |
|
*/ |
|
repoUrl: string; |
|
/** |
|
* Code snippet displayed on model page |
|
*/ |
|
snippet: (model: ModelData) => string; |
|
} |
|
|
|
function nameWithoutNamespace(modelId: string): string { |
|
const splitted = modelId.split("/"); |
|
return splitted.length === 1 ? splitted[0] : splitted[1]; |
|
} |
|
|
|
//#region snippets |
|
|
|
const adapter_transformers = (model: ModelData) => |
|
`from transformers import ${model.config?.adapter_transformers?.model_class} |
|
|
|
model = ${model.config?.adapter_transformers?.model_class}.from_pretrained("${model.config?.adapter_transformers?.{model.id}}") |
|
model.load_adapter("${model.id}", source="hf")`; |
|
|
|
const allennlpUnknown = (model: ModelData) => |
|
`import allennlp_models |
|
from allennlp.predictors.predictor import Predictor |
|
|
|
predictor = Predictor.from_path("hf://${model.id}")`; |
|
|
|
const allennlpQuestionAnswering = (model: ModelData) => |
|
`import allennlp_models |
|
from allennlp.predictors.predictor import Predictor |
|
|
|
predictor = Predictor.from_path("hf://${model.id}") |
|
predictor_input = {"passage": "My name is Wolfgang and I live in Berlin", "question": "Where do I live?"} |
|
predictions = predictor.predict_json(predictor_input)`; |
|
|
|
const allennlp = (model: ModelData) => { |
|
if (model.tags?.includes("question-answering")) { |
|
return allennlpQuestionAnswering(model); |
|
} |
|
return allennlpUnknown(model); |
|
}; |
|
|
|
const asteroid = (model: ModelData) => |
|
`from asteroid.models import BaseModel |
|
|
|
model = BaseModel.from_pretrained("${model.id}")`; |
|
|
|
const diffusers = (model: ModelData) => |
|
`from diffusers import DiffusionPipeline |
|
|
|
pipeline = DiffusionPipeline.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`; |
|
|
|
const espnetTTS = (model: ModelData) => |
|
`from espnet2.bin.tts_inference import Text2Speech |
|
|
|
model = Text2Speech.from_pretrained("${model.id}") |
|
|
|
speech, *_ = model("text to generate speech from")`; |
|
|
|
const espnetASR = (model: ModelData) => |
|
`from espnet2.bin.asr_inference import Speech2Text |
|
|
|
model = Speech2Text.from_pretrained( |
|
"${model.id}" |
|
) |
|
|
|
speech, rate = soundfile.read("speech.wav") |
|
text, *_ = model(speech)`; |
|
|
|
const espnetUnknown = () => |
|
`unknown model type (must be text-to-speech or automatic-speech-recognition)`; |
|
|
|
const espnet = (model: ModelData) => { |
|
if (model.tags?.includes("text-to-speech")) { |
|
return espnetTTS(model); |
|
} else if (model.tags?.includes("automatic-speech-recognition")) { |
|
return espnetASR(model); |
|
} |
|
return espnetUnknown(); |
|
}; |
|
|
|
const fairseq = (model: ModelData) => |
|
`from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub |
|
|
|
models, cfg, task = load_model_ensemble_and_task_from_hf_hub( |
|
"${model.id}" |
|
)`; |
|
|
|
|
|
const flair = (model: ModelData) => |
|
`from flair.models import SequenceTagger |
|
|
|
tagger = SequenceTagger.load("${model.id}")`; |
|
|
|
const keras = (model: ModelData) => |
|
`from huggingface_hub import from_pretrained_keras |
|
|
|
model = from_pretrained_keras("${model.id}") |
|
`; |
|
|
|
const pyannote_audio_pipeline = (model: ModelData) => |
|
`from pyannote.audio import Pipeline |
|
|
|
pipeline = Pipeline.from_pretrained("${model.id}") |
|
|
|
# inference on the whole file |
|
pipeline("file.wav") |
|
|
|
# inference on an excerpt |
|
from pyannote.core import Segment |
|
excerpt = Segment(start=2.0, end=5.0) |
|
|
|
from pyannote.audio import Audio |
|
waveform, sample_rate = Audio().crop("file.wav", excerpt) |
|
pipeline({"waveform": waveform, "sample_rate": sample_rate})`; |
|
|
|
const pyannote_audio_model = (model: ModelData) => |
|
`from pyannote.audio import Model, Inference |
|
|
|
model = Model.from_pretrained("${model.id}") |
|
inference = Inference(model) |
|
|
|
# inference on the whole file |
|
inference("file.wav") |
|
|
|
# inference on an excerpt |
|
from pyannote.core import Segment |
|
excerpt = Segment(start=2.0, end=5.0) |
|
inference.crop("file.wav", excerpt)`; |
|
|
|
const pyannote_audio = (model: ModelData) => { |
|
if (model.tags?.includes("pyannote-audio-pipeline")) { |
|
return pyannote_audio_pipeline(model); |
|
} |
|
return pyannote_audio_model(model); |
|
}; |
|
|
|
const tensorflowttsTextToMel = (model: ModelData) => |
|
`from tensorflow_tts.inference import AutoProcessor, TFAutoModel |
|
|
|
processor = AutoProcessor.from_pretrained("${model.id}") |
|
model = TFAutoModel.from_pretrained("${model.id}") |
|
`; |
|
|
|
const tensorflowttsMelToWav = (model: ModelData) => |
|
`from tensorflow_tts.inference import TFAutoModel |
|
|
|
model = TFAutoModel.from_pretrained("${model.id}") |
|
audios = model.inference(mels) |
|
`; |
|
|
|
const tensorflowttsUnknown = (model: ModelData) => |
|
`from tensorflow_tts.inference import TFAutoModel |
|
|
|
model = TFAutoModel.from_pretrained("${model.id}") |
|
`; |
|
|
|
const tensorflowtts = (model: ModelData) => { |
|
if (model.tags?.includes("text-to-mel")) { |
|
return tensorflowttsTextToMel(model); |
|
} else if (model.tags?.includes("mel-to-wav")) { |
|
return tensorflowttsMelToWav(model); |
|
} |
|
return tensorflowttsUnknown(model); |
|
}; |
|
|
|
const timm = (model: ModelData) => |
|
`import timm |
|
|
|
model = timm.create_model("hf_hub:${model.id}", pretrained=True)`; |
|
|
|
const sklearn = (model: ModelData) => |
|
`from huggingface_hub import hf_hub_download |
|
import joblib |
|
|
|
model = joblib.load( |
|
hf_hub_download("${model.id}", "sklearn_model.joblib") |
|
)`; |
|
|
|
const fastai = (model: ModelData) => |
|
`from huggingface_hub import from_pretrained_fastai |
|
|
|
learn = from_pretrained_fastai("${model.id}")`; |
|
|
|
const sentenceTransformers = (model: ModelData) => |
|
`from sentence_transformers import SentenceTransformer |
|
|
|
model = SentenceTransformer("${model.id}")`; |
|
|
|
const spacy = (model: ModelData) => |
|
`!pip install https://huggingface.co/${model.id}/resolve/main/${nameWithoutNamespace(model.id)}-any-py3-none-any.whl |
|
|
|
# Using spacy.load(). |
|
import spacy |
|
nlp = spacy.load("${nameWithoutNamespace(model.id)}") |
|
|
|
# Importing as module. |
|
import ${nameWithoutNamespace(model.id)} |
|
nlp = ${nameWithoutNamespace(model.id)}.load()`; |
|
|
|
const stanza = (model: ModelData) => |
|
`import stanza |
|
|
|
stanza.download("${nameWithoutNamespace(model.id).replace("stanza-", "")}") |
|
nlp = stanza.Pipeline("${nameWithoutNamespace(model.id).replace("stanza-", "")}")`; |
|
|
|
|
|
const speechBrainMethod = (speechbrainInterface: string) => { |
|
switch (speechbrainInterface) { |
|
case "EncoderClassifier": |
|
return "classify_file"; |
|
case "EncoderDecoderASR": |
|
case "EncoderASR": |
|
return "transcribe_file"; |
|
case "SpectralMaskEnhancement": |
|
return "enhance_file"; |
|
case "SepformerSeparation": |
|
return "separate_file"; |
|
default: |
|
return undefined; |
|
} |
|
}; |
|
|
|
const speechbrain = (model: ModelData) => { |
|
const speechbrainInterface = model.config?.speechbrain?.interface; |
|
if (speechbrainInterface === undefined) { |
|
return `# interface not specified in config.json`; |
|
} |
|
|
|
const speechbrainMethod = speechBrainMethod(speechbrainInterface); |
|
if (speechbrainMethod === undefined) { |
|
return `# interface in config.json invalid`; |
|
} |
|
|
|
return `from speechbrain.pretrained import ${speechbrainInterface} |
|
model = ${speechbrainInterface}.from_hparams( |
|
"${model.id}" |
|
) |
|
model.${speechbrainMethod}("file.wav")`; |
|
}; |
|
|
|
const transformers = (model: ModelData) => { |
|
const info = model.transformersInfo; |
|
if (!info) { |
|
return `# ⚠️ Type of model unknown`; |
|
} |
|
if (info.processor) { |
|
const varName = info.processor === "AutoTokenizer" ? "tokenizer" |
|
: info.processor === "AutoFeatureExtractor" ? "extractor" |
|
: "processor" |
|
; |
|
return [ |
|
`from transformers import ${info.processor}, ${info.auto_model}`, |
|
"", |
|
`${varName} = ${info.processor}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`, |
|
"", |
|
`model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`, |
|
].join("\n"); |
|
} else { |
|
return [ |
|
`from transformers import ${info.auto_model}`, |
|
"", |
|
`model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`, |
|
].join("\n"); |
|
} |
|
}; |
|
|
|
const fasttext = (model: ModelData) => |
|
`from huggingface_hub import hf_hub_download |
|
import fasttext |
|
|
|
model = fasttext.load_model(hf_hub_download("${model.id}", "model.bin"))`; |
|
|
|
const stableBaselines3 = (model: ModelData) => |
|
`from huggingface_sb3 import load_from_hub |
|
checkpoint = load_from_hub( |
|
repo_id="${model.id}", |
|
filename="{MODEL FILENAME}.zip", |
|
)`; |
|
|
|
const nemoDomainResolver = (domain: string, model: ModelData): string | undefined => { |
|
const modelName = `${nameWithoutNamespace(model.id)}.nemo`; |
|
|
|
switch (domain) { |
|
case "ASR": |
|
return `import nemo.collections.asr as nemo_asr |
|
asr_model = nemo_asr.models.ASRModel.from_pretrained("${model.id}") |
|
|
|
transcriptions = asr_model.transcribe(["file.wav"])`; |
|
default: |
|
return undefined; |
|
} |
|
}; |
|
|
|
const mlAgents = (model: ModelData) => |
|
`mlagents-load-from-hf --repo-id="${model.id}" --local-dir="./downloads"`; |
|
|
|
const nemo = (model: ModelData) => { |
|
let command: string | undefined = undefined; |
|
// Resolve the tag to a nemo domain/sub-domain |
|
if (model.tags?.includes("automatic-speech-recognition")) { |
|
command = nemoDomainResolver("ASR", model); |
|
} |
|
|
|
return command ?? `# tag did not correspond to a valid NeMo domain.`; |
|
}; |
|
|
|
//#endregion |
|
|
|
|
|
|
|
export const MODEL_LIBRARIES_UI_ELEMENTS: { [key in keyof typeof ModelLibrary]?: LibraryUiElement } = { |
|
// ^^ TODO(remove the optional ? marker when Stanza snippet is available) |
|
"adapter-transformers": { |
|
btnLabel: "Adapter Transformers", |
|
repoName: "adapter-transformers", |
|
repoUrl: "https://github.com/Adapter-Hub/adapter-transformers", |
|
snippet: adapter_transformers, |
|
}, |
|
"allennlp": { |
|
btnLabel: "AllenNLP", |
|
repoName: "AllenNLP", |
|
repoUrl: "https://github.com/allenai/allennlp", |
|
snippet: allennlp, |
|
}, |
|
"asteroid": { |
|
btnLabel: "Asteroid", |
|
repoName: "Asteroid", |
|
repoUrl: "https://github.com/asteroid-team/asteroid", |
|
snippet: asteroid, |
|
}, |
|
"diffusers": { |
|
btnLabel: "Diffusers", |
|
repoName: "🤗/diffusers", |
|
repoUrl: "https://github.com/huggingface/diffusers", |
|
snippet: diffusers, |
|
}, |
|
"espnet": { |
|
btnLabel: "ESPnet", |
|
repoName: "ESPnet", |
|
repoUrl: "https://github.com/espnet/espnet", |
|
snippet: espnet, |
|
}, |
|
"fairseq": { |
|
btnLabel: "Fairseq", |
|
repoName: "fairseq", |
|
repoUrl: "https://github.com/pytorch/fairseq", |
|
snippet: fairseq, |
|
}, |
|
"flair": { |
|
btnLabel: "Flair", |
|
repoName: "Flair", |
|
repoUrl: "https://github.com/flairNLP/flair", |
|
snippet: flair, |
|
}, |
|
"keras": { |
|
btnLabel: "Keras", |
|
repoName: "Keras", |
|
repoUrl: "https://github.com/keras-team/keras", |
|
snippet: keras, |
|
}, |
|
"nemo": { |
|
btnLabel: "NeMo", |
|
repoName: "NeMo", |
|
repoUrl: "https://github.com/NVIDIA/NeMo", |
|
snippet: nemo, |
|
}, |
|
"pyannote-audio": { |
|
btnLabel: "pyannote.audio", |
|
repoName: "pyannote-audio", |
|
repoUrl: "https://github.com/pyannote/pyannote-audio", |
|
snippet: pyannote_audio, |
|
}, |
|
"sentence-transformers": { |
|
btnLabel: "sentence-transformers", |
|
repoName: "sentence-transformers", |
|
repoUrl: "https://github.com/UKPLab/sentence-transformers", |
|
snippet: sentenceTransformers, |
|
}, |
|
"sklearn": { |
|
btnLabel: "Scikit-learn", |
|
repoName: "Scikit-learn", |
|
repoUrl: "https://github.com/scikit-learn/scikit-learn", |
|
snippet: sklearn, |
|
}, |
|
"fastai": { |
|
btnLabel: "fastai", |
|
repoName: "fastai", |
|
repoUrl: "https://github.com/fastai/fastai", |
|
snippet: fastai, |
|
}, |
|
"spacy": { |
|
btnLabel: "spaCy", |
|
repoName: "spaCy", |
|
repoUrl: "https://github.com/explosion/spaCy", |
|
snippet: spacy, |
|
}, |
|
"speechbrain": { |
|
btnLabel: "speechbrain", |
|
repoName: "speechbrain", |
|
repoUrl: "https://github.com/speechbrain/speechbrain", |
|
snippet: speechbrain, |
|
}, |
|
"stanza": { |
|
btnLabel: "Stanza", |
|
repoName: "stanza", |
|
repoUrl: "https://github.com/stanfordnlp/stanza", |
|
snippet: stanza, |
|
}, |
|
"tensorflowtts": { |
|
btnLabel: "TensorFlowTTS", |
|
repoName: "TensorFlowTTS", |
|
repoUrl: "https://github.com/TensorSpeech/TensorFlowTTS", |
|
snippet: tensorflowtts, |
|
}, |
|
"timm": { |
|
btnLabel: "timm", |
|
repoName: "pytorch-image-models", |
|
repoUrl: "https://github.com/rwightman/pytorch-image-models", |
|
snippet: timm, |
|
}, |
|
"transformers": { |
|
btnLabel: "Transformers", |
|
repoName: "🤗/transformers", |
|
repoUrl: "https://github.com/huggingface/transformers", |
|
snippet: transformers, |
|
}, |
|
"fasttext": { |
|
btnLabel: "fastText", |
|
repoName: "fastText", |
|
repoUrl: "https://fasttext.cc/", |
|
snippet: fasttext, |
|
}, |
|
"stable-baselines3": { |
|
btnLabel: "stable-baselines3", |
|
repoName: "stable-baselines3", |
|
repoUrl: "https://github.com/huggingface/huggingface_sb3", |
|
snippet: stableBaselines3, |
|
}, |
|
"ml-agents": { |
|
btnLabel: "ml-agents", |
|
repoName: "ml-agents", |
|
repoUrl: "https://github.com/huggingface/ml-agents", |
|
snippet: mlAgents, |
|
}, |
|
} as const; |
|
""" |
|
|
|
|
|
if __name__ == '__main__': |
|
import sys |
|
library_name = "keras" |
|
model_name = "Distillgpt2" |
|
print(read_file(library_name, model_name)) |
|
|
|
"""" |
|
try: |
|
args = sys.argv[1:] |
|
if args: |
|
print(read_file(args[0], args[1])) |
|
except IndexError: |
|
pass |
|
""" |