thankrandomness committed
Commit f2ca0de · Parent: 22a06c1

add efficiency metrics

Files changed (2):
  1. app.py +72 -33
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,6 +4,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModel
 import chromadb
 import gradio as gr
+from sklearn.metrics import precision_score, recall_score, f1_score
 
 # Mean Pooling - Take attention mask into account for correct averaging
 def meanpooling(output, mask):
@@ -11,7 +12,7 @@ def meanpooling(output, mask):
     mask = mask.unsqueeze(-1).expand(embeddings.size()).float()
     return torch.sum(embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
 
-# Load the private dataset using the token
+# Load the dataset
 dataset = load_dataset("thankrandomness/mimic-iii-sample")
 
 # Load the model and tokenizer
@@ -30,36 +31,37 @@ def embed_text(text):
 client = chromadb.Client()
 collection = client.create_collection(name="pubmedbert_matryoshka_embeddings")
 
-# Process the dataset and upsert into ChromaDB
-for i, row in enumerate(dataset['train']):
-    for note in row['notes']:
-        text = note.get('text', '')
-        annotations_list = []
-
-        for annotation in note.get('annotations', []):
-            try:
-                code = annotation['code']
-                code_system = annotation['code_system']
-                description = annotation['description']
-                #annotations_list.append(f"{code}: {code_system}: {description}")
-                annotations_list.append({"code": code, "code_system": code_system, "description": description})
-            except KeyError as e:
-                print(f"Skipping annotation due to missing key: {e}")
-
-        print(f"Processed annotations for note {note['note_id']}: {annotations_list}")
-
-        if text and annotations_list:
-            embeddings = embed_text([text])[0]
-
-            # Upsert data, embeddings, and annotations into ChromaDB
-            for j, annotation in enumerate(annotations_list):
-                collection.upsert(
-                    ids=[f"note_{note['note_id']}_{j}"],
-                    embeddings=[embeddings],
-                    metadatas=[annotation]
-                )
-        else:
-            print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
+# Function to upsert data into ChromaDB
+def upsert_data(dataset_split):
+    for i, row in enumerate(dataset_split):
+        for note in row['notes']:
+            text = note.get('text', '')
+            annotations_list = []
+
+            for annotation in note.get('annotations', []):
+                try:
+                    code = annotation['code']
+                    code_system = annotation['code_system']
+                    description = annotation['description']
+                    annotations_list.append({"code": code, "code_system": code_system, "description": description})
+                except KeyError as e:
+                    print(f"Skipping annotation due to missing key: {e}")
+
+            if text and annotations_list:
+                embeddings = embed_text([text])[0]
+
+                # Upsert data, embeddings, and annotations into ChromaDB
+                for j, annotation in enumerate(annotations_list):
+                    collection.upsert(
+                        ids=[f"note_{note['note_id']}_{j}"],
+                        embeddings=[embeddings],
+                        metadatas=[annotation]
+                    )
+            else:
+                print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
 
 # Define retrieval function
 def retrieve_relevant_text(input_text):
@@ -81,6 +83,33 @@ def retrieve_relevant_text(input_text):
         })
     return output
 
+# Evaluate retrieval efficiency on the validation/test set
+def evaluate_efficiency(dataset_split):
+    y_true = []
+    y_pred = []
+    for i, row in enumerate(dataset_split):
+        for note in row['notes']:
+            text = note.get('text', '')
+            annotations_list = [annotation['code'] for annotation in note.get('annotations', []) if 'code' in annotation]
+
+            if text and annotations_list:
+                retrieved_results = retrieve_relevant_text(text)
+                retrieved_codes = [result['code'] for result in retrieved_results]
+
+                # Ground truth
+                y_true.extend(annotations_list)
+                # Predictions
+                y_pred.extend(retrieved_codes[:len(annotations_list)])  # Assuming we compare the top-k results
+
+    precision = precision_score(y_true, y_pred, average='macro')
+    recall = recall_score(y_true, y_pred, average='macro')
+    f1 = f1_score(y_true, y_pred, average='macro')
+
+    return precision, recall, f1
+
+# Calculate retrieval efficiency metrics
+precision, recall, f1 = evaluate_efficiency(dataset['validation'])
+
 # Gradio interface
 def gradio_interface(input_text):
     results = retrieve_relevant_text(input_text)
@@ -88,7 +117,17 @@ def gradio_interface(input_text):
         f"Similarity Score: {result['similarity_score']:.2f}, Code: {result['code']}, Description: {result['description']}"
         for result in results
     ]
-    return formatted_results
+    metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}"
+    return formatted_results, metrics
+
+interface = gr.Interface(
+    fn=gradio_interface,
+    inputs="text",
+    outputs=["text", "text"],
+    live=True
+)
+
+# Display retrieval efficiency metrics
+print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}")
 
-interface = gr.Interface(fn=gradio_interface, inputs="text", outputs="text")
 interface.launch()
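
The new evaluate_efficiency scores retrieval with scikit-learn's macro-averaged classification metrics. A minimal sketch of just that scoring step, using made-up ICD-style code labels rather than real dataset values (zero_division=0 is added here to silence warnings for codes never predicted; the committed code uses the defaults):

from sklearn.metrics import precision_score, recall_score, f1_score

# Hypothetical labels: in app.py, y_true comes from note annotations and
# y_pred from the codes returned by retrieve_relevant_text, truncated to len(y_true).
y_true = ["401.9", "250.00", "414.01"]
y_pred = ["401.9", "272.4", "414.01"]

# 'macro' averages the per-code scores, so rare codes weigh as much as common ones.
precision = precision_score(y_true, y_pred, average="macro", zero_division=0)
recall = recall_score(y_true, y_pred, average="macro", zero_division=0)
f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}")

One caveat: these scikit-learn functions require y_true and y_pred to be the same length, so evaluate_efficiency implicitly assumes retrieve_relevant_text returns at least len(annotations_list) results per note; if it returns fewer, the two lists drift apart and scikit-learn raises a ValueError.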
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch
 transformers
 chromadb
 gradio
-numpy
+numpy
+scikit-learn
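
With scikit-learn appended to requirements.txt, a standard pip workflow (assumed here) picks up the new dependency:

pip install -r requirements.txt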