# MEDIC-Benchmark / eval_metrics_app.py
# Last commit by tathagataraha: "[ADD] Harness tasks, data display" (09b313f)
import gradio as gr
# Exact-match span metrics (precision / recall / F1).
def compute_metrics(gt_spans, pred_spans):
    """Compute exact-match precision, recall and F1 between two span collections.

    Spans are compared as hashable items, e.g. ``(start, end)`` character-offset
    tuples. Duplicates within either collection are ignored.

    Args:
        gt_spans: Iterable of ground-truth spans (may be a one-shot iterator).
        pred_spans: Iterable of predicted spans (may be a one-shot iterator).

    Returns:
        Dict with float values for ``"precision"``, ``"recall"`` and
        ``"f1_score"``. A metric whose denominator is zero (no predictions,
        no ground truth, or no overlap) is reported as ``0.0``.
    """
    # Materialize each input exactly once: rebuilding set() from a one-shot
    # iterator would see it already exhausted and undercount fp / fn.
    gold = set(gt_spans)
    predicted = set(pred_spans)

    tp = len(gold & predicted)
    # tp + fp == len(predicted) and tp + fn == len(gold).
    precision = tp / len(predicted) if predicted else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1_score = (
        2 * precision * recall / (precision + recall)
        if precision + recall
        else 0.0
    )
    return {"precision": precision, "recall": recall, "f1_score": f1_score}
def _highlight_spans(text, spans):
    """Return a ``gr.HighlightedText`` value marking every ``(start, end)`` span of *text*."""
    # HighlightedText accepts {"text": ..., "entities": [...]} with character
    # offsets (end exclusive); a dict keyed by (start, end) tuples is not a
    # format it understands.
    return {
        "text": text or "",
        "entities": [
            {"entity": "highlight", "start": start, "end": end}
            for start, end in sorted(spans)
        ],
    }


def _add_span(text, snippet, spans):
    """Return a new sorted, de-duplicated span list with *snippet*'s offsets in *text* added."""
    # Only the first occurrence of the snippet is located. An empty snippet, or
    # one absent from the text (str.find -> -1), is ignored rather than being
    # recorded as a bogus (0, 0) / (-1, len - 1) span.
    start = text.find(snippet) if text and snippet else -1
    # Build a new collection instead of appending to the session-state list.
    current = {tuple(span) for span in spans}
    if start != -1:
        current.add((start, start + len(snippet)))
    return sorted(current)


def create_app():
    """Build the Gradio Blocks UI for annotating spans and scoring predictions.

    The user enters a text, types a snippet of it into the second textbox, and
    adds that snippet's character offsets to either the ground-truth or the
    prediction set. "Compute Metrics" scores the two sets with
    ``compute_metrics``.

    Returns:
        The un-launched ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        # Input components
        text_input = gr.Textbox(label="Input Text")
        highlight_input = gr.Textbox(label="Highlight Text and Press Add")
        # Per-session lists of (start, end) character offsets into text_input.
        gt_spans_state = gr.State([])
        pred_spans_state = gr.State([])
        # Buttons for ground truth and prediction
        add_gt_button = gr.Button("Add to Ground Truth")
        add_pred_button = gr.Button("Add to Predictions")
        # Outputs for highlighted spans
        gt_output = gr.HighlightedText(label="Ground Truth Spans")
        pred_output = gr.HighlightedText(label="Predicted Spans")
        # Compute metrics button and its output
        compute_button = gr.Button("Compute Metrics")
        metrics_output = gr.JSON(label="Metrics")

        span_inputs = [text_input, highlight_input, gt_spans_state, pred_spans_state]
        span_outputs = [gt_spans_state, pred_spans_state, gt_output, pred_output]

        def render(text, gt_spans, pred_spans):
            """Return new states plus both highlight views, in span_outputs order."""
            return (
                gt_spans,
                pred_spans,
                _highlight_spans(text, gt_spans),
                _highlight_spans(text, pred_spans),
            )

        def add_to_gt(text, snippet, gt_spans, pred_spans):
            """Add the snippet to the ground-truth spans and refresh both views."""
            return render(text, _add_span(text, snippet, gt_spans), pred_spans)

        def add_to_pred(text, snippet, gt_spans, pred_spans):
            """Add the snippet to the predicted spans and refresh both views."""
            return render(text, gt_spans, _add_span(text, snippet, pred_spans))

        def reset_spans(text):
            """Clear both span sets and the metrics when the input text changes."""
            # Offsets recorded against the previous text no longer line up with
            # the new one, so keeping them would highlight/score the wrong spans.
            return render(text, [], []) + (None,)

        # Event handlers
        add_gt_button.click(fn=add_to_gt, inputs=span_inputs, outputs=span_outputs)
        add_pred_button.click(fn=add_to_pred, inputs=span_inputs, outputs=span_outputs)
        compute_button.click(
            fn=compute_metrics,
            inputs=[gt_spans_state, pred_spans_state],
            outputs=metrics_output,
        )
        text_input.change(
            fn=reset_spans,
            inputs=text_input,
            outputs=span_outputs + [metrics_output],
        )
    return demo
# Build the app at import time so ``demo`` remains a module attribute (the
# `gradio <file>` reload mode looks it up by that name), but only start the
# blocking server when this file is executed as a script.
demo = create_app()

if __name__ == "__main__":
    demo.launch()