Tonic commited on
Commit
a86354d
·
1 Parent(s): 399e250

learning how to code with the post-introspector

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -65,18 +65,16 @@ class EmbeddingGenerator:
65
  {"role": "user", "content": escaped_input_text}
66
  ]
67
  )
68
- intention_output = intention_completion.choices[0].message['content']
69
-
70
  # Parse and route the intention
71
  parsed_task = parse_and_route(intention_output)
72
- selected_task = list(parsed_task.keys())[0]
73
-
74
  # Construct the prompt
75
- try:
76
  task_description = tasks[selected_task]
77
- except KeyError:
 
78
  print(f"Selected task not found: {selected_task}")
79
- return f"Error: Task '{selected_task}' not found. Please select a valid task."
80
 
81
  query_prefix = f"Instruct: {task_description}\nQuery: "
82
  queries = [escaped_input_text]
@@ -89,13 +87,14 @@ class EmbeddingGenerator:
89
  {"role": "user", "content": escaped_input_text}
90
  ]
91
  )
92
- metadata_output = metadata_completion.choices[0].message['content']
93
  metadata = self.extract_metadata(metadata_output)
94
 
95
  # Get the embeddings
96
  with torch.no_grad():
97
  inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
98
- outputs = self.model(**inputs)
 
99
  query_embeddings = outputs.last_hidden_state.mean(dim=1)
100
 
101
  # Normalize embeddings
@@ -118,7 +117,7 @@ class MyEmbeddingFunction(EmbeddingFunction):
118
  self.embedding_generator = embedding_generator
119
 
120
  def __call__(self, input: Documents) -> (Embeddings, list):
121
- embeddings_with_metadata = [self.embedding_generator.compute_embeddings(doc) for doc in input]
122
  embeddings = [item[0] for item in embeddings_with_metadata]
123
  metadata = [item[1] for item in embeddings_with_metadata]
124
  embeddings_flattened = [emb for sublist in embeddings for emb in sublist]
 
65
  {"role": "user", "content": escaped_input_text}
66
  ]
67
  )
68
+ intention_output = intention_completion.choices[0].message.content
 
69
  # Parse and route the intention
70
  parsed_task = parse_and_route(intention_output)
71
+ selected_task = parsed_task
 
72
  # Construct the prompt
73
+ if selected_task in tasks:
74
  task_description = tasks[selected_task]
75
+ else:
76
+ task_description = tasks["DEFAULT"]
77
  print(f"Selected task not found: {selected_task}")
 
78
 
79
  query_prefix = f"Instruct: {task_description}\nQuery: "
80
  queries = [escaped_input_text]
 
87
  {"role": "user", "content": escaped_input_text}
88
  ]
89
  )
90
+ metadata_output = metadata_completion.choices[0].message.content
91
  metadata = self.extract_metadata(metadata_output)
92
 
93
  # Get the embeddings
94
  with torch.no_grad():
95
  inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
96
+ outputs = self.model(**inputs)
97
+ query_embeddings = outputs["sentence_embeddings"].mean(dim=1)
98
  query_embeddings = outputs.last_hidden_state.mean(dim=1)
99
 
100
  # Normalize embeddings
 
117
  self.embedding_generator = embedding_generator
118
 
119
  def __call__(self, input: Documents) -> (Embeddings, list):
120
+ embeddings_with_metadata = [self.embedding_generator.compute_embeddings(doc.page_content) for doc in input]
121
  embeddings = [item[0] for item in embeddings_with_metadata]
122
  metadata = [item[1] for item in embeddings_with_metadata]
123
  embeddings_flattened = [emb for sublist in embeddings for emb in sublist]