import numpy as np
import gradio as gr
import tensorflow as tf  # version 2.13.0
from keras.models import load_model
import cv2
import json
import os

# Load label_disease.json (label id -> disease name)
with open('data/label_disease.json', 'r') as f:
    label_disease = json.load(f)

# Load plant_label_disease.json (plant type -> list of label ids)
with open('data/plant_label_disease.json', 'r') as f:
    plant_label_disease = json.load(f)

# Model input size and file locations
HEIGHT = 256
WIDTH = 256
modelArchitecturePath = 'model/model_architecture.h5'
modelWeightsPath = 'model/model_weights.h5'

# Load the model once at startup instead of on every request
dnn_model = load_model(modelArchitecturePath, compile=False)
dnn_model.load_weights(modelWeightsPath)


def analyse(img, plant_type):
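    """Classify a leaf image and return overall and plant-specific disease predictions."""
    # Guard against missing inputs from the UI
    if img is None or plant_type is None:
        raise gr.Error("Please upload an image and select a plant type.")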

    # Preprocess the image
    process_img = cv2.resize(img, (HEIGHT, WIDTH), interpolation=cv2.INTER_LINEAR)
    process_img = process_img / 255.0
    process_img = np.expand_dims(process_img, axis=0)

    # Predict using the model
    y_pred = dnn_model.predict(process_img)
    y_pred = y_pred[0]

    # Restrict the prediction to the disease classes that belong to the selected plant
    plant_label_ids = plant_label_disease[plant_type.lower()]
    plant_predicted_id = plant_label_ids[0]
    for disease in plant_label_ids:
        if y_pred[disease] > y_pred[plant_predicted_id]:
            plant_predicted_id = disease

    # Determine overall prediction
    overall_predicted_id = int(np.argmax(y_pred))
    overall_predicted_name = label_disease[str(overall_predicted_id)]
    overall_predicted_confidence = float(y_pred[overall_predicted_id])

    # Determine plant-specific prediction
    plant_predicted_name = label_disease[str(plant_predicted_id)]
    plant_predicted_confidence = float(y_pred[plant_predicted_id])

    # Determine health status
    is_plant_specific_healthy = "healthy" in plant_predicted_name.lower()
    is_overall_healthy = "healthy" in overall_predicted_name.lower()

    # Return results as a JSON object
    result = {
        "plant_specific_prediction_id": plant_predicted_id,
        "plant_specific_prediction_name": plant_predicted_name,
        "plant_specific_confidence": plant_predicted_confidence,
        "is_plant_specific_healthy": is_plant_specific_healthy,
        "overall_prediction_id": overall_predicted_id,
        "overall_prediction_name": overall_predicted_name,
        "overall_confidence": overall_predicted_confidence,
        "is_overall_healthy": is_overall_healthy
    }

    return result

# Build the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("## Plant Disease Detection")
    gr.Markdown("Upload an image of a plant leaf and select the plant type to detect diseases.")
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="numpy")
            plant_type = gr.Radio(
                ["Apple", "Blueberry", "Cherry", "Corn", "Grape", "Orange", "Peach",
                 "Pepper", "Potato", "Raspberry", "Soybean", "Squash", "Strawberry", "Tomato"],
                label="Plant Type"
            )
            submit = gr.Button("Analyze")
            
        with gr.Column():
            result_json = gr.JSON(label="Analysis Result")

    # Example images section
    gr.Examples(
        examples=[os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))],
        inputs=[input_image],
        label="Examples",
        cache_examples=False,
        examples_per_page=8
    )
    
    # Define interaction
    submit.click(fn=analyse, inputs=[input_image, plant_type], outputs=result_json)

# Launch the application; share=True creates a temporary public link, show_error surfaces exceptions in the UI
demo.launch(share=True, show_error=True)