LanceY2004 committed on
Commit 209c441 · verified · 1 Parent(s): 3750f0c
Files changed (1)
  1. localrag_no_rewrite.py +193 -0
localrag_no_rewrite.py ADDED
@@ -0,0 +1,193 @@
+ import torch
+ import ollama
+ import os
+ from openai import OpenAI
+ import argparse
+
+ # ANSI escape codes for colors
+ PINK = '\033[95m'
+ CYAN = '\033[96m'
+ YELLOW = '\033[93m'
+ NEON_GREEN = '\033[92m'
+ RESET_COLOR = '\033[0m'
+
+ # Function to open a file and return its contents as a string
+ def open_file(filepath):
+     with open(filepath, 'r', encoding='utf-8') as infile:
+         return infile.read()
+
+ # Function to get relevant context from the vault based on user input
+ def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3):
+     if vault_embeddings.nelement() == 0:  # Check if the tensor has any elements
+         return []
+     # Encode the rewritten input
+     input_embedding = ollama.embeddings(model='mxbai-embed-large', prompt=rewritten_input)["embedding"]
+     # Compute cosine similarity between the input and vault embeddings
+     cos_scores = torch.cosine_similarity(torch.tensor(input_embedding).unsqueeze(0), vault_embeddings)
+     # Adjust top_k if it's greater than the number of available scores
+     top_k = min(top_k, len(cos_scores))
+     # Sort the scores and get the top-k indices
+     top_indices = torch.topk(cos_scores, k=top_k)[1].tolist()
+     # Get the corresponding context from the vault
+     relevant_context = [vault_content[idx].strip() for idx in top_indices]
+     return relevant_context
+
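+ # Note: torch.cosine_similarity broadcasts the (1, D) query embedding against
+ # the (N, D) vault tensor along dim=1, yielding one score per vault line
+ # (mxbai-embed-large typically returns 1024-dim vectors, though nothing here
+ # depends on the exact dimension).
+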
+ # Function to interact with the Ollama model
+ def ollama_chat(user_input, system_message, vault_embeddings, vault_content, ollama_model, conversation_history):
+     # Get relevant context from the vault
+     relevant_context = get_relevant_context(user_input, vault_embeddings, vault_content, top_k=3)
+     if relevant_context:
+         # Convert list to a single string with newlines between items
+         context_str = "\n".join(relevant_context)
+         print("Context Pulled from Documents: \n\n" + CYAN + context_str + RESET_COLOR)
+     else:
+         print(CYAN + "No relevant context found." + RESET_COLOR)
+
+     # Prepare the user's input by concatenating it with the relevant context
+     user_input_with_context = user_input
+     if relevant_context:
+         user_input_with_context = context_str + "\n\n" + user_input
+
+     # Append the user's input to the conversation history
+     conversation_history.append({"role": "user", "content": user_input_with_context})
+
+     # Create a message history including the system message and the conversation history
+     messages = [
+         {"role": "system", "content": system_message},
+         *conversation_history
+     ]
+
+     # Send the completion request to the Ollama model
+     response = client.chat.completions.create(
+         model=ollama_model,
+         messages=messages
+     )
+
+     # Append the model's response to the conversation history
+     conversation_history.append({"role": "assistant", "content": response.choices[0].message.content})
+
+     # Return the content of the response from the model
+     return response.choices[0].message.content
+
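+ # Note: as the filename suggests, the raw user input is used for retrieval
+ # as-is; there is no query-rewriting step before get_relevant_context.
+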
+ def process_text_files(user_input):
+     text_parse_directory = os.path.join("local-rag", "text_parse")
+     temp_file_path = os.path.join("local-rag", "temp.txt")
+
+     # Check if text_parse directory exists
+     if not os.path.exists(text_parse_directory):
+         print(f"Directory '{text_parse_directory}' does not exist.")
+         return False
+
+     # Check if temp.txt exists
+     if not os.path.exists(temp_file_path):
+         print("temp.txt does not exist.")
+         return False
+
+     # Read the first line of temp.txt (the name of the source text file)
+     with open(temp_file_path, 'r', encoding='utf-8') as temp_file:
+         first_line = temp_file.readline().strip()
+
+     # Get all text files in the text_parse directory
+     text_files = [f for f in os.listdir(text_parse_directory) if f.endswith('.txt')]
+
+     # Check if the first line matches any of the text files
+     if first_line not in text_files:
+         print(f"No matching file found for '{first_line}' in text_parse directory.")
+         return False
+
+     # Proceed to check for the NOT FINISHED flag
+     file_path = os.path.join(text_parse_directory, first_line)
+     with open(file_path, 'r', encoding='utf-8') as f:
+         lines = f.readlines()
+
+     # The flag sits on the second-to-last line of an unprocessed file
+     if len(lines) >= 2 and lines[-2].strip() == "====================NOT FINISHED====================":
+         print(f"'{first_line}' contains the 'NOT FINISHED' flag. Computing embeddings.")
+
+         vault_content = []
+         if os.path.exists(temp_file_path):
+             with open(temp_file_path, "r", encoding='utf-8') as vault_file:
+                 vault_content = vault_file.readlines()
+
+         # Generate embeddings for the vault content using Ollama
+         vault_embeddings = []
+         for content in vault_content:
+             response = ollama.embeddings(model='mxbai-embed-large', prompt=content)
+             vault_embeddings.append(response["embedding"])
+
+         # Convert to tensor and print embeddings
+         vault_embeddings_tensor = torch.tensor(vault_embeddings)
+         print("Embeddings for each line in the vault:")
+         print(vault_embeddings_tensor)
+
+         # Save the tensor so later runs can reload it instead of recomputing
+         with open(os.path.join(text_parse_directory, f"{first_line}_embedding.pt"), "wb") as tensor_file:
+             torch.save(vault_embeddings_tensor, tensor_file)
+
+         # Write back all lines except the NOT FINISHED flag and the line after it
+         with open(file_path, 'w', encoding='utf-8') as f:
+             f.writelines(lines[:-2])
+
+     else:
+         print(f"'{first_line}' does not contain the 'NOT FINISHED' flag or is already complete. Loading tensor if it exists.")
+
+         # Try to load the tensor from the corresponding file
+         tensor_file_path = os.path.join(text_parse_directory, f"{first_line}_embedding.pt")
+         if os.path.exists(tensor_file_path):
+             vault_embeddings_tensor = torch.load(tensor_file_path)
+             print("Loaded Vault Embedding Tensor:")
+             print(vault_embeddings_tensor)
+
+             vault_content = []
+             if os.path.exists(temp_file_path):
+                 with open(temp_file_path, "r", encoding='utf-8') as vault_file:
+                     vault_content = vault_file.readlines()
+         else:
+             # Without a saved tensor there is nothing to retrieve against
+             print(f"No tensor file found for '{first_line}'.")
+             return False
+
+     # Conversation loop
+     conversation_history = []
+     system_message = "You are a helpful assistant that is an expert at extracting the most useful information from a given text"
+
+     response = ollama_chat(user_input, system_message, vault_embeddings_tensor, vault_content, args.model, conversation_history)
+
+     return response
+
+
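+ # Note: the function above embeds the vault once per source file, caches the
+ # tensor as '<first_line>_embedding.pt', strips the NOT FINISHED flag, and on
+ # later runs reloads the cached tensor instead of re-embedding.
+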
+ # # Read each file in the text_parse directory and check for the NOT FINISHED flag
+ # for txt_file in text_files:
+ #     file_path = os.path.join(text_parse_directory, txt_file)
+ #     with open(file_path, 'r', encoding='utf-8') as f:
+ #         lines = f.readlines()
+ #     # Check if the last line contains the "NOT FINISHED" flag
+ #     if lines and lines[-1].strip() == "==========NOT FINISHED==========":
+ #         print(f"'{txt_file}' contains the 'NOT FINISHED' flag. Proceeding to next step.")
+ #         # Append the content of this file to the vault
+ #         with open(temp_file_path, 'a', encoding='utf-8') as vault_file:
+ #             vault_file.write('\n'.join(lines[:-1]) + '\n')  # Append content without the last flag line
+ #     else:
+ #         print(f"'{txt_file}' does not contain the 'NOT FINISHED' flag. Skipping.")
+
+ # Parse command-line arguments
+ parser = argparse.ArgumentParser(description="Ollama Chat")
+ parser.add_argument("--model", default="llama3", help="Ollama model to use (default: llama3)")
+ args = parser.parse_args()
+
+ # Configuration for the Ollama API client
+ client = OpenAI(
+     base_url='http://localhost:11434/v1',
+     api_key='llama3'
+ )
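+
+ # Note: Ollama's OpenAI-compatible endpoint ignores the api_key value, but the
+ # OpenAI client requires a non-empty string, so any placeholder works here.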
+
+ if __name__ == "__main__":
+     print(process_text_files("tell me about iterators"))