# query2osm / app.py
# Author: ellenhp
# Commit 1e324b4: "Remove flagging"
from gradio.components import Component
import torch
from hydra import Hydra
from transformers import AutoTokenizer
import gradio as gr
from hydra import Hydra
import os
from typing import Any, Optional
# Model identifier (Hub repo id or local path) shared by the tokenizer and
# the classifier, so both are always loaded from the same checkpoint.
model_name = "ellenhp/query2osm-bert-v1"

# NOTE(review): `padding` is normally a tokenizer *call* argument rather than a
# `from_pretrained` option; it is kept here only to preserve existing behavior.
# predict() tokenizes one query at a time, so no padding is needed anyway.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding=True,
)

# Load the classifier weights and place them on the CPU for inference.
model = Hydra.from_pretrained(model_name).to(torch.device("cpu"))
def predict(input_query):
    """Classify a free-text query and return a label -> score mapping.

    The raw query is echoed to stdout. It is then stripped of surrounding
    whitespace, lower-cased, tokenized and run through the module-level
    ``model`` with gradient tracking disabled.

    Args:
        input_query: The query string from the Gradio textbox. ``None`` is
            treated as an empty query instead of raising ``AttributeError``.

    Returns:
        dict mapping each class label to its score. The dict is built from the
        pairs in ``outputs.classifications[0]``, which is presumably the entry
        for the single query in this batch of one. TODO confirm against Hydra.
    """
    with torch.no_grad():
        # Echo the raw query so it appears in the server logs.
        print(input_query)
        # Lower-casing presumably matches the model's training-time
        # normalization (an uncased BERT is assumed; confirm).
        input_text = (input_query or "").strip().lower()
        # truncation=True caps the sequence at the tokenizer's model_max_length.
        # Without it, an over-long pasted query overflows the encoder's
        # position embeddings and the request crashes.
        # NOTE(review): this relies on the checkpoint's tokenizer config
        # defining model_max_length (512 for BERT-derived tokenizers).
        # If it is undefined, transformers only warns and skips truncation.
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
        # Call the module instead of .forward() so nn.Module hooks still run.
        outputs = model(inputs.input_ids)
        # Each classification entry is indexed as (label, score) by position.
        return {
            classification[0]: classification[1]
            for classification in outputs.classifications[0]
        }
# Input widget: the free-text query the user wants classified.
textbox = gr.Textbox(
    placeholder="Quick bite to eat near me",
    label="Query",
)

# Output widget: renders predict()'s label -> score dict and shows at most
# the five highest-ranked classes.
label = gr.Label(num_top_classes=5, label="Result")
# Assemble the web UI: one text query goes in, one ranked label display comes out.
gradio_app = gr.Interface(
    fn=predict,
    title="Query Classification",
    inputs=[textbox],
    outputs=[label],
)

# Start the web server only when this file is executed directly. Importing the
# module just builds `gradio_app` without launching it.
if __name__ == "__main__":
    gradio_app.launch()