File size: 2,335 Bytes
8d12dce
bc18618
 
a332ed4
5da6a11
c9d6d7a
bc18618
c9d6d7a
efac948
6bb1d3a
bc18618
 
c9d6d7a
bc18618
 
 
 
 
 
 
 
 
 
 
 
 
c9d6d7a
 
 
 
 
 
 
 
5da6a11
01f4a12
bc18618
c9d6d7a
bc18618
 
 
7e0ee60
b26ba26
 
492b5e5
da42d5c
fc5c4f6
c9d6d7a
9da9b42
 
fc5c4f6
00b5ecc
1af1578
 
 
 
82e3bb8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image 
import numpy as np
import rembg

# Define the model and feature extractor
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Define the label mapping
id2label = {
    "0": "Aleurocanthus spiniferus",
    "1": "Chancre citrique",
    "2": "Cochenille blanche",
    "3": "Dépérissement des agrumes",
    "4": "Feuille saine",
    "5": "Jaunissement des feuilles",
    "6": "Maladie de l'oïdium",
    "7": "Maladie du dragon jaune",
    "8": "Mineuse des agrumes",
    "9": "Trou de balle"
}

def remove_background(image):
    image = image.convert("RGBA")
    image_np = np.array(image)
    output_np = rembg.remove(image_np)
    white_bg = Image.new("RGBA", image.size, "WHITE")
    output_image = Image.alpha_composite(white_bg, Image.fromarray(output_np))
    output_image = output_image.convert("RGB")
    return output_image

    
def predict(image):
    image = remove_background(image)
    inputs = feature_extractor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_class_idx = probs.argmax().item()
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100  
    return f"{predicted_label}: {confidence_score:.2f}%"

# Create the Gradio interface
image = gr.Image(type="pil") 
label = gr.Textbox(label="Prediction")

gr.Interface(fn=predict,
             inputs=image,
             outputs=label,
             title="Classification des maladies des agrumes",
             description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
             examples=["maladie_du_dragon_jaune.jpg",  "feuille_saine.jpg"]).launch(share=True)