"""Gradio demo with two tabs: taxi trip-duration prediction and a dog/cat classifier.

Both models are loaded from ``./models`` when this module is imported. The
taxi model also needs ``./data/zone_lookup.csv`` to translate ``borough_zone``
names into ``LocationID`` values.
"""
import pickle
import random

import pandas as pd
import gradio as gr
# fastai convention: the star import provides ``load_learner`` (and whatever
# names the pickled Learner needs when it is unpickled).
from fastai.vision.all import *

zone_lookup = pd.read_csv('./data/zone_lookup.csv')
# Plain list of zone names, built once: used for the dropdown choices and for
# picking random defaults without depending on the DataFrame's index.
zone_names = list(zone_lookup["borough_zone"])

# NOTE(review): pickle.load / load_learner execute arbitrary code while
# loading; only point these paths at model files you produced yourself.
with open('./models/lin_reg.bin', 'rb') as handle:
    # A (vectorizer, regressor) pair: ``dv.transform`` turns a feature dict into
    # the model's input matrix (presumably a scikit-learn DictVectorizer).
    dv, model = pickle.load(handle)


def _zone_id(zone):
    """Return the ``LocationID`` of the first lookup row whose ``borough_zone`` is *zone*.

    Raises:
        ValueError: if *zone* does not appear in ``zone_lookup``.
    """
    matches = zone_lookup.loc[zone_lookup["borough_zone"] == zone, "LocationID"]
    if matches.empty:
        raise ValueError("Unknown location: %r" % (zone,))
    # Take the scalar: %-formatting the whole Series would embed its repr
    # (index, name, dtype) in the feature string instead of the bare ID.
    return matches.iloc[0]


def prepare_features(pickup, dropoff, trip_distance):
    """Build the feature dict the vectorizer expects for a single trip.

    Args:
        pickup: ``borough_zone`` name of the pickup location.
        dropoff: ``borough_zone`` name of the dropoff location.
        trip_distance: trip distance in miles.

    Returns:
        ``{'PU_DO': '<pickup id>_<dropoff id>', 'trip_distance': <distance>}``,
        with the distance rounded to 4 decimal places.

    Raises:
        ValueError: if either location is not in ``zone_lookup``.
    """
    pickup_id = _zone_id(pickup)
    dropoff_id = _zone_id(dropoff)
    return {
        'PU_DO': '%s_%s' % (pickup_id, dropoff_id),
        'trip_distance': round(trip_distance, 4),
    }


def predict(pickup, dropoff, trip_distance):
    """Predict the trip duration and return it as a human-readable sentence.

    Raises:
        gr.Error: if either location is unknown, so the UI shows a readable message.
    """
    try:
        features = prepare_features(pickup, dropoff, trip_distance)
    except ValueError as err:
        raise gr.Error(str(err)) from err
    X = dv.transform(features)
    preds = model.predict(X)
    duration = float(preds[0])
    return "The predicted duration is %.4f minutes." % duration


def is_cat(x):
    """Label function: True when the first character of *x* is uppercase."""
    # Never called in this file. It must still exist at module level, presumably
    # because the pickled fastai Learner references it by name and unpickling
    # would fail without it.
    return x[0].isupper()


learn = load_learner('./models/model.pkl')
# Must follow the learner's output order. Presumably the vocab is
# [False, True] from is_cat: index 0 = not a cat (dog), index 1 = cat.
categories = ('Dog', 'Cat')


def classify_image(img):
    """Return a ``{category: probability}`` dict for *img*, as gr.Label expects."""
    _, _, probs = learn.predict(img)
    return dict(zip(categories, map(float, probs)))


with gr.Blocks() as demo:
    gr.Markdown("Predict Taxi Duration or Classify dog breeds using this demo")

    with gr.Tab("Predict Taxi Duration"):
        with gr.Row():
            # Callable ``value``s are evaluated on each page load, so every
            # visitor starts from a fresh random example.
            pickup = gr.Dropdown(
                choices=zone_names,
                label="Pickup Location",
                info="The location where the passenger(s) were picked up",
                value=lambda: random.choice(zone_names),
            )
            dropoff = gr.Dropdown(
                choices=zone_names,
                label="Dropoff Location",
                info="The location where the passenger(s) were dropped off",
                value=lambda: random.choice(zone_names),
            )
            trip_distance = gr.Slider(
                minimum=0.0,
                maximum=100.0,
                step=0.1,
                label="Trip Distance",
                info="The trip distance in miles calculated by the taximeter",
                value=lambda: random.uniform(0.0, 100.0),
            )
        with gr.Column():
            output = gr.Textbox(label="Output Box")
            predict_btn = gr.Button("Predict")

    # NOTE(review): the tab and heading say "dog breed", but the model only
    # separates Dog from Cat (see ``categories``).
    with gr.Tab("Classify Dog Breed"):
        # gr.Image / gr.Label replace the deprecated gr.inputs / gr.outputs aliases.
        image = gr.Image(shape=(192, 192))
        label = gr.Label()
        # TODO(review): never wired up; gr.Examples(examples, inputs=image)
        # would display these sample files.
        examples = ['dog.jpg', 'cat.jpg', 'dunno.jpg']
        classify_btn = gr.Button("Predict")

    predict_btn.click(
        fn=predict,
        inputs=[pickup, dropoff, trip_distance],
        outputs=output,
        api_name="predict-duration",
    )
    classify_btn.click(
        fn=classify_image,
        inputs=image,
        outputs=label,
        api_name="classify-dog-breed",
    )

if __name__ == "__main__":
    demo.launch()