vshulev commited on
Commit
bb0609f
·
1 Parent(s): 47e4e3e
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -126,7 +126,8 @@ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str)
126
  if method == "fine_tuned_model":
127
  bert_inputs = tokenize(dna_sequence)
128
  logits = classification_model(bert_inputs, torch.zeros(1, 7))
129
- probs = torch.softmax(logits, dim=1).squeeze()
 
130
  top_k = torch.topk(probs, 10)
131
  top_k = pd.Series(
132
  top_k.values.detach().numpy(),
@@ -136,10 +137,12 @@ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str)
136
 
137
  fig, ax = plt.subplots()
138
  ax.bar(top_k.index.astype(str), top_k.values)
 
139
  ax.set_title("Genus Prediction")
140
  ax.set_xlabel("Genus")
141
  ax.set_ylabel("Probability")
142
- ax.set_xticklabels(top_k.index.astype(str), rotation=90)
 
143
  fig.canvas.draw()
144
 
145
  return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
 
126
  if method == "fine_tuned_model":
127
  bert_inputs = tokenize(dna_sequence)
128
  logits = classification_model(bert_inputs, torch.zeros(1, 7))
129
+ temperature = 0.5
130
+ probs = torch.softmax(logits / temperature, dim=1).squeeze()
131
  top_k = torch.topk(probs, 10)
132
  top_k = pd.Series(
133
  top_k.values.detach().numpy(),
 
137
 
138
  fig, ax = plt.subplots()
139
  ax.bar(top_k.index.astype(str), top_k.values)
140
+ ax.set_ylim(0, 1)
141
  ax.set_title("Genus Prediction")
142
  ax.set_xlabel("Genus")
143
  ax.set_ylabel("Probability")
144
+ ax.set_xticklabels(top_k.index.astype(str), rotation=90)
145
+ fig.subplots_adjust(bottom=0.3)
146
  fig.canvas.draw()
147
 
148
  return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())