KhadijaAsehnoune12 commited on
Commit
7e0ee60
·
verified ·
1 Parent(s): d353c03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -23,22 +23,29 @@ id2label = {
23
  }
24
 
25
  def predict(image):
26
- # Preprocess the image
27
  inputs = feature_extractor(images=image, return_tensors="pt")
28
 
29
- # Forward pass through the model
30
  outputs = model(**inputs)
31
 
32
- # Get the predicted label and confidence score
33
  logits = outputs.logits
34
- predicted_class_idx = logits.argmax(-1).item()
35
- confidence_score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
36
 
37
- # Get the label name
38
- predicted_label = id2label[str(predicted_class_idx)]
39
 
40
- # Return the predicted label and confidence score
41
- return f"{predicted_label}: {confidence_score:.2f}"
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  # Create the Gradio interface
 
23
  }
24
 
25
  def predict(image):
26
+ # Prétraiter l'image
27
  inputs = feature_extractor(images=image, return_tensors="pt")
28
 
29
+ # Passage en avant dans le modèle
30
  outputs = model(**inputs)
31
 
32
+ # Obtenir les logits
33
  logits = outputs.logits
 
 
34
 
35
+ # Calculer les scores de confiance avec softmax
36
+ probs = torch.nn.functional.softmax(logits, dim=-1)[0]
37
 
38
+ # Obtenir les indices des trois classes les plus probables
39
+ top_3_indices = torch.topk(probs, 3).indices.tolist()
40
+
41
+ # Obtenir les labels et les scores de confiance pour les trois classes les plus probables
42
+ top_3_labels_and_scores = [(id2label[str(idx)], probs[idx].item()) for idx in top_3_indices]
43
+
44
+ # Formater les résultats
45
+ results = [f"{label}: {score:.2f}" for label, score in top_3_labels_and_scores]
46
+
47
+ # Retourner les résultats
48
+ return "\n".join(results)
49
 
50
 
51
  # Create the Gradio interface