|
from gradio.components import Component |
|
import torch |
|
from hydra import Hydra |
|
from transformers import AutoTokenizer |
|
import gradio as gr |
|
from hydra import Hydra |
|
import os |
|
from typing import Any, Optional |
|
|
|
# Identifier used to load both the tokenizer and the Hydra classifier
# (a Hugging Face Hub repo id or local path).
model_name = "ellenhp/query2osm-bert-v1"

# NOTE(review): `padding` is normally a per-call tokenizer option; passing it
# at load time is likely ignored — confirm before relying on it.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)

# Load the classifier weights and keep the model on the CPU.
model = Hydra.from_pretrained(model_name)
model = model.to("cpu")
|
|
|
def predict(input_query):
    """Classify a free-text query and return per-class scores for ``gr.Label``.

    Args:
        input_query: Raw text from the query textbox. May be ``None``
            (handled defensively, e.g. for API calls with no value).

    Returns:
        A dict mapping each class label to its score, built from the first
        entry of ``outputs.classifications``; or ``None`` when the query is
        missing or blank, which leaves the label output empty.
    """
    print(input_query)

    # Guard against a missing value before calling string methods on it.
    if input_query is None:
        return None
    input_text = input_query.strip().lower()
    # A blank query has nothing meaningful to classify.
    if not input_text:
        return None

    with torch.no_grad():
        inputs = tokenizer(input_text, return_tensors="pt")
        # Call the module itself (not .forward) so nn.Module hooks still run.
        outputs = model(inputs.input_ids)

    # assumes each classification is an indexable (label, score) pair — TODO confirm against Hydra
    return {classification[0]: classification[1] for classification in outputs.classifications[0]}
|
|
|
|
|
# Input component: free-text query with an example placeholder.
textbox = gr.Textbox(
    label="Query",
    placeholder="Quick bite to eat near me",
)

# Output component: shows the five highest-scoring classes.
label = gr.Label(num_top_classes=5, label="Result")
|
|
|
# Wire the classifier into a single-input, single-output web interface.
gradio_app = gr.Interface(
    fn=predict,
    title="Query Classification",
    inputs=[textbox],
    outputs=[label],
)


if __name__ == "__main__":
    # Start the server only when this file is executed directly, not on import.
    gradio_app.launch()
|
|