XThomasBU
commited on
Commit
Β·
f51bb92
1
Parent(s):
48f8268
init commit
Browse files- .chainlit/translations/en-US.json +0 -231
- .chainlit/translations/pt-BR.json +0 -155
- .gitignore +7 -1
- {.chainlit β code/.chainlit}/config.toml +0 -0
- code/__init__.py +1 -0
- chainlit.md β code/chainlit.md +0 -0
- code/main.py +40 -33
- code/modules/__init__.py +2 -0
- code/modules/chat/__init__.py +2 -0
- code/modules/{chat_model_loader.py β chat/chat_model_loader.py} +2 -3
- code/modules/{helpers.py β chat/helpers.py} +26 -259
- code/modules/{llm_tutor.py β chat/llm_tutor.py} +97 -54
- code/modules/config/__init__.py +1 -0
- code/{config.yml β modules/config/config.yml} +19 -10
- code/modules/{constants.py β config/constants.py} +0 -0
- code/modules/dataloader/__init__.py +2 -0
- code/modules/{data_loader.py β dataloader/data_loader.py} +181 -117
- code/modules/dataloader/helpers.py +108 -0
- code/modules/dataloader/webpage_crawler.py +115 -0
- code/modules/retriever/__init__.py +2 -0
- code/modules/retriever/base.py +6 -0
- code/modules/retriever/chroma_retriever.py +24 -0
- code/modules/retriever/faiss_retriever.py +23 -0
- code/modules/retriever/helpers.py +39 -0
- code/modules/vectorstore/__init__.py +2 -0
- code/modules/vectorstore/base.py +18 -0
- code/modules/vectorstore/chroma.py +41 -0
- code/modules/{embedding_model_loader.py β vectorstore/embedding_model_loader.py} +5 -8
- code/modules/vectorstore/faiss.py +45 -0
- code/modules/vectorstore/helpers.py +0 -0
- code/modules/{vector_db.py β vectorstore/store_manager.py} +147 -138
.chainlit/translations/en-US.json
DELETED
@@ -1,231 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"components": {
|
3 |
-
"atoms": {
|
4 |
-
"buttons": {
|
5 |
-
"userButton": {
|
6 |
-
"menu": {
|
7 |
-
"settings": "Settings",
|
8 |
-
"settingsKey": "S",
|
9 |
-
"APIKeys": "API Keys",
|
10 |
-
"logout": "Logout"
|
11 |
-
}
|
12 |
-
}
|
13 |
-
}
|
14 |
-
},
|
15 |
-
"molecules": {
|
16 |
-
"newChatButton": {
|
17 |
-
"newChat": "New Chat"
|
18 |
-
},
|
19 |
-
"tasklist": {
|
20 |
-
"TaskList": {
|
21 |
-
"title": "\ud83d\uddd2\ufe0f Task List",
|
22 |
-
"loading": "Loading...",
|
23 |
-
"error": "An error occured"
|
24 |
-
}
|
25 |
-
},
|
26 |
-
"attachments": {
|
27 |
-
"cancelUpload": "Cancel upload",
|
28 |
-
"removeAttachment": "Remove attachment"
|
29 |
-
},
|
30 |
-
"newChatDialog": {
|
31 |
-
"createNewChat": "Create new chat?",
|
32 |
-
"clearChat": "This will clear the current messages and start a new chat.",
|
33 |
-
"cancel": "Cancel",
|
34 |
-
"confirm": "Confirm"
|
35 |
-
},
|
36 |
-
"settingsModal": {
|
37 |
-
"settings": "Settings",
|
38 |
-
"expandMessages": "Expand Messages",
|
39 |
-
"hideChainOfThought": "Hide Chain of Thought",
|
40 |
-
"darkMode": "Dark Mode"
|
41 |
-
},
|
42 |
-
"detailsButton": {
|
43 |
-
"using": "Using",
|
44 |
-
"running": "Running",
|
45 |
-
"took_one": "Took {{count}} step",
|
46 |
-
"took_other": "Took {{count}} steps"
|
47 |
-
},
|
48 |
-
"auth": {
|
49 |
-
"authLogin": {
|
50 |
-
"title": "Login to access the app.",
|
51 |
-
"form": {
|
52 |
-
"email": "Email address",
|
53 |
-
"password": "Password",
|
54 |
-
"noAccount": "Don't have an account?",
|
55 |
-
"alreadyHaveAccount": "Already have an account?",
|
56 |
-
"signup": "Sign Up",
|
57 |
-
"signin": "Sign In",
|
58 |
-
"or": "OR",
|
59 |
-
"continue": "Continue",
|
60 |
-
"forgotPassword": "Forgot password?",
|
61 |
-
"passwordMustContain": "Your password must contain:",
|
62 |
-
"emailRequired": "email is a required field",
|
63 |
-
"passwordRequired": "password is a required field"
|
64 |
-
},
|
65 |
-
"error": {
|
66 |
-
"default": "Unable to sign in.",
|
67 |
-
"signin": "Try signing in with a different account.",
|
68 |
-
"oauthsignin": "Try signing in with a different account.",
|
69 |
-
"redirect_uri_mismatch": "The redirect URI is not matching the oauth app configuration.",
|
70 |
-
"oauthcallbackerror": "Try signing in with a different account.",
|
71 |
-
"oauthcreateaccount": "Try signing in with a different account.",
|
72 |
-
"emailcreateaccount": "Try signing in with a different account.",
|
73 |
-
"callback": "Try signing in with a different account.",
|
74 |
-
"oauthaccountnotlinked": "To confirm your identity, sign in with the same account you used originally.",
|
75 |
-
"emailsignin": "The e-mail could not be sent.",
|
76 |
-
"emailverify": "Please verify your email, a new email has been sent.",
|
77 |
-
"credentialssignin": "Sign in failed. Check the details you provided are correct.",
|
78 |
-
"sessionrequired": "Please sign in to access this page."
|
79 |
-
}
|
80 |
-
},
|
81 |
-
"authVerifyEmail": {
|
82 |
-
"almostThere": "You're almost there! We've sent an email to ",
|
83 |
-
"verifyEmailLink": "Please click on the link in that email to complete your signup.",
|
84 |
-
"didNotReceive": "Can't find the email?",
|
85 |
-
"resendEmail": "Resend email",
|
86 |
-
"goBack": "Go Back",
|
87 |
-
"emailSent": "Email sent successfully.",
|
88 |
-
"verifyEmail": "Verify your email address"
|
89 |
-
},
|
90 |
-
"providerButton": {
|
91 |
-
"continue": "Continue with {{provider}}",
|
92 |
-
"signup": "Sign up with {{provider}}"
|
93 |
-
},
|
94 |
-
"authResetPassword": {
|
95 |
-
"newPasswordRequired": "New password is a required field",
|
96 |
-
"passwordsMustMatch": "Passwords must match",
|
97 |
-
"confirmPasswordRequired": "Confirm password is a required field",
|
98 |
-
"newPassword": "New password",
|
99 |
-
"confirmPassword": "Confirm password",
|
100 |
-
"resetPassword": "Reset Password"
|
101 |
-
},
|
102 |
-
"authForgotPassword": {
|
103 |
-
"email": "Email address",
|
104 |
-
"emailRequired": "email is a required field",
|
105 |
-
"emailSent": "Please check the email address {{email}} for instructions to reset your password.",
|
106 |
-
"enterEmail": "Enter your email address and we will send you instructions to reset your password.",
|
107 |
-
"resendEmail": "Resend email",
|
108 |
-
"continue": "Continue",
|
109 |
-
"goBack": "Go Back"
|
110 |
-
}
|
111 |
-
}
|
112 |
-
},
|
113 |
-
"organisms": {
|
114 |
-
"chat": {
|
115 |
-
"history": {
|
116 |
-
"index": {
|
117 |
-
"showHistory": "Show history",
|
118 |
-
"lastInputs": "Last Inputs",
|
119 |
-
"noInputs": "Such empty...",
|
120 |
-
"loading": "Loading..."
|
121 |
-
}
|
122 |
-
},
|
123 |
-
"inputBox": {
|
124 |
-
"input": {
|
125 |
-
"placeholder": "Type your message here..."
|
126 |
-
},
|
127 |
-
"speechButton": {
|
128 |
-
"start": "Start recording",
|
129 |
-
"stop": "Stop recording"
|
130 |
-
},
|
131 |
-
"SubmitButton": {
|
132 |
-
"sendMessage": "Send message",
|
133 |
-
"stopTask": "Stop Task"
|
134 |
-
},
|
135 |
-
"UploadButton": {
|
136 |
-
"attachFiles": "Attach files"
|
137 |
-
},
|
138 |
-
"waterMark": {
|
139 |
-
"text": "Built with"
|
140 |
-
}
|
141 |
-
},
|
142 |
-
"Messages": {
|
143 |
-
"index": {
|
144 |
-
"running": "Running",
|
145 |
-
"executedSuccessfully": "executed successfully",
|
146 |
-
"failed": "failed",
|
147 |
-
"feedbackUpdated": "Feedback updated",
|
148 |
-
"updating": "Updating"
|
149 |
-
}
|
150 |
-
},
|
151 |
-
"dropScreen": {
|
152 |
-
"dropYourFilesHere": "Drop your files here"
|
153 |
-
},
|
154 |
-
"index": {
|
155 |
-
"failedToUpload": "Failed to upload",
|
156 |
-
"cancelledUploadOf": "Cancelled upload of",
|
157 |
-
"couldNotReachServer": "Could not reach the server",
|
158 |
-
"continuingChat": "Continuing previous chat"
|
159 |
-
},
|
160 |
-
"settings": {
|
161 |
-
"settingsPanel": "Settings panel",
|
162 |
-
"reset": "Reset",
|
163 |
-
"cancel": "Cancel",
|
164 |
-
"confirm": "Confirm"
|
165 |
-
}
|
166 |
-
},
|
167 |
-
"threadHistory": {
|
168 |
-
"sidebar": {
|
169 |
-
"filters": {
|
170 |
-
"FeedbackSelect": {
|
171 |
-
"feedbackAll": "Feedback: All",
|
172 |
-
"feedbackPositive": "Feedback: Positive",
|
173 |
-
"feedbackNegative": "Feedback: Negative"
|
174 |
-
},
|
175 |
-
"SearchBar": {
|
176 |
-
"search": "Search"
|
177 |
-
}
|
178 |
-
},
|
179 |
-
"DeleteThreadButton": {
|
180 |
-
"confirmMessage": "This will delete the thread as well as it's messages and elements.",
|
181 |
-
"cancel": "Cancel",
|
182 |
-
"confirm": "Confirm",
|
183 |
-
"deletingChat": "Deleting chat",
|
184 |
-
"chatDeleted": "Chat deleted"
|
185 |
-
},
|
186 |
-
"index": {
|
187 |
-
"pastChats": "Past Chats"
|
188 |
-
},
|
189 |
-
"ThreadList": {
|
190 |
-
"empty": "Empty...",
|
191 |
-
"today": "Today",
|
192 |
-
"yesterday": "Yesterday",
|
193 |
-
"previous7days": "Previous 7 days",
|
194 |
-
"previous30days": "Previous 30 days"
|
195 |
-
},
|
196 |
-
"TriggerButton": {
|
197 |
-
"closeSidebar": "Close sidebar",
|
198 |
-
"openSidebar": "Open sidebar"
|
199 |
-
}
|
200 |
-
},
|
201 |
-
"Thread": {
|
202 |
-
"backToChat": "Go back to chat",
|
203 |
-
"chatCreatedOn": "This chat was created on"
|
204 |
-
}
|
205 |
-
},
|
206 |
-
"header": {
|
207 |
-
"chat": "Chat",
|
208 |
-
"readme": "Readme"
|
209 |
-
}
|
210 |
-
}
|
211 |
-
},
|
212 |
-
"hooks": {
|
213 |
-
"useLLMProviders": {
|
214 |
-
"failedToFetchProviders": "Failed to fetch providers:"
|
215 |
-
}
|
216 |
-
},
|
217 |
-
"pages": {
|
218 |
-
"Design": {},
|
219 |
-
"Env": {
|
220 |
-
"savedSuccessfully": "Saved successfully",
|
221 |
-
"requiredApiKeys": "Required API Keys",
|
222 |
-
"requiredApiKeysInfo": "To use this app, the following API keys are required. The keys are stored on your device's local storage."
|
223 |
-
},
|
224 |
-
"Page": {
|
225 |
-
"notPartOfProject": "You are not part of this project."
|
226 |
-
},
|
227 |
-
"ResumeButton": {
|
228 |
-
"resumeChat": "Resume Chat"
|
229 |
-
}
|
230 |
-
}
|
231 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.chainlit/translations/pt-BR.json
DELETED
@@ -1,155 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"components": {
|
3 |
-
"atoms": {
|
4 |
-
"buttons": {
|
5 |
-
"userButton": {
|
6 |
-
"menu": {
|
7 |
-
"settings": "Configura\u00e7\u00f5es",
|
8 |
-
"settingsKey": "S",
|
9 |
-
"APIKeys": "Chaves de API",
|
10 |
-
"logout": "Sair"
|
11 |
-
}
|
12 |
-
}
|
13 |
-
}
|
14 |
-
},
|
15 |
-
"molecules": {
|
16 |
-
"newChatButton": {
|
17 |
-
"newChat": "Nova Conversa"
|
18 |
-
},
|
19 |
-
"tasklist": {
|
20 |
-
"TaskList": {
|
21 |
-
"title": "\ud83d\uddd2\ufe0f Lista de Tarefas",
|
22 |
-
"loading": "Carregando...",
|
23 |
-
"error": "Ocorreu um erro"
|
24 |
-
}
|
25 |
-
},
|
26 |
-
"attachments": {
|
27 |
-
"cancelUpload": "Cancelar envio",
|
28 |
-
"removeAttachment": "Remover anexo"
|
29 |
-
},
|
30 |
-
"newChatDialog": {
|
31 |
-
"createNewChat": "Criar novo chat?",
|
32 |
-
"clearChat": "Isso limpar\u00e1 as mensagens atuais e iniciar\u00e1 uma nova conversa.",
|
33 |
-
"cancel": "Cancelar",
|
34 |
-
"confirm": "Confirmar"
|
35 |
-
},
|
36 |
-
"settingsModal": {
|
37 |
-
"expandMessages": "Expandir Mensagens",
|
38 |
-
"hideChainOfThought": "Esconder Sequ\u00eancia de Pensamento",
|
39 |
-
"darkMode": "Modo Escuro"
|
40 |
-
}
|
41 |
-
},
|
42 |
-
"organisms": {
|
43 |
-
"chat": {
|
44 |
-
"history": {
|
45 |
-
"index": {
|
46 |
-
"lastInputs": "\u00daltimas Entradas",
|
47 |
-
"noInputs": "Vazio...",
|
48 |
-
"loading": "Carregando..."
|
49 |
-
}
|
50 |
-
},
|
51 |
-
"inputBox": {
|
52 |
-
"input": {
|
53 |
-
"placeholder": "Digite sua mensagem aqui..."
|
54 |
-
},
|
55 |
-
"speechButton": {
|
56 |
-
"start": "Iniciar grava\u00e7\u00e3o",
|
57 |
-
"stop": "Parar grava\u00e7\u00e3o"
|
58 |
-
},
|
59 |
-
"SubmitButton": {
|
60 |
-
"sendMessage": "Enviar mensagem",
|
61 |
-
"stopTask": "Parar Tarefa"
|
62 |
-
},
|
63 |
-
"UploadButton": {
|
64 |
-
"attachFiles": "Anexar arquivos"
|
65 |
-
},
|
66 |
-
"waterMark": {
|
67 |
-
"text": "Constru\u00eddo com"
|
68 |
-
}
|
69 |
-
},
|
70 |
-
"Messages": {
|
71 |
-
"index": {
|
72 |
-
"running": "Executando",
|
73 |
-
"executedSuccessfully": "executado com sucesso",
|
74 |
-
"failed": "falhou",
|
75 |
-
"feedbackUpdated": "Feedback atualizado",
|
76 |
-
"updating": "Atualizando"
|
77 |
-
}
|
78 |
-
},
|
79 |
-
"dropScreen": {
|
80 |
-
"dropYourFilesHere": "Solte seus arquivos aqui"
|
81 |
-
},
|
82 |
-
"index": {
|
83 |
-
"failedToUpload": "Falha ao enviar",
|
84 |
-
"cancelledUploadOf": "Envio cancelado de",
|
85 |
-
"couldNotReachServer": "N\u00e3o foi poss\u00edvel conectar ao servidor",
|
86 |
-
"continuingChat": "Continuando o chat anterior"
|
87 |
-
},
|
88 |
-
"settings": {
|
89 |
-
"settingsPanel": "Painel de Configura\u00e7\u00f5es",
|
90 |
-
"reset": "Redefinir",
|
91 |
-
"cancel": "Cancelar",
|
92 |
-
"confirm": "Confirmar"
|
93 |
-
}
|
94 |
-
},
|
95 |
-
"threadHistory": {
|
96 |
-
"sidebar": {
|
97 |
-
"filters": {
|
98 |
-
"FeedbackSelect": {
|
99 |
-
"feedbackAll": "Feedback: Todos",
|
100 |
-
"feedbackPositive": "Feedback: Positivo",
|
101 |
-
"feedbackNegative": "Feedback: Negativo"
|
102 |
-
},
|
103 |
-
"SearchBar": {
|
104 |
-
"search": "Buscar"
|
105 |
-
}
|
106 |
-
},
|
107 |
-
"DeleteThreadButton": {
|
108 |
-
"confirmMessage": "Isso deletar\u00e1 a conversa, assim como suas mensagens e elementos.",
|
109 |
-
"cancel": "Cancelar",
|
110 |
-
"confirm": "Confirmar",
|
111 |
-
"deletingChat": "Deletando conversa",
|
112 |
-
"chatDeleted": "Conversa deletada"
|
113 |
-
},
|
114 |
-
"index": {
|
115 |
-
"pastChats": "Conversas Anteriores"
|
116 |
-
},
|
117 |
-
"ThreadList": {
|
118 |
-
"empty": "Vazio..."
|
119 |
-
},
|
120 |
-
"TriggerButton": {
|
121 |
-
"closeSidebar": "Fechar barra lateral",
|
122 |
-
"openSidebar": "Abrir barra lateral"
|
123 |
-
}
|
124 |
-
},
|
125 |
-
"Thread": {
|
126 |
-
"backToChat": "Voltar para a conversa",
|
127 |
-
"chatCreatedOn": "Esta conversa foi criada em"
|
128 |
-
}
|
129 |
-
},
|
130 |
-
"header": {
|
131 |
-
"chat": "Conversa",
|
132 |
-
"readme": "Leia-me"
|
133 |
-
}
|
134 |
-
},
|
135 |
-
"hooks": {
|
136 |
-
"useLLMProviders": {
|
137 |
-
"failedToFetchProviders": "Falha ao buscar provedores:"
|
138 |
-
}
|
139 |
-
},
|
140 |
-
"pages": {
|
141 |
-
"Design": {},
|
142 |
-
"Env": {
|
143 |
-
"savedSuccessfully": "Salvo com sucesso",
|
144 |
-
"requiredApiKeys": "Chaves de API necess\u00e1rias",
|
145 |
-
"requiredApiKeysInfo": "Para usar este aplicativo, as seguintes chaves de API s\u00e3o necess\u00e1rias. As chaves s\u00e3o armazenadas localmente em seu dispositivo."
|
146 |
-
},
|
147 |
-
"Page": {
|
148 |
-
"notPartOfProject": "Voc\u00ea n\u00e3o faz parte deste projeto."
|
149 |
-
},
|
150 |
-
"ResumeButton": {
|
151 |
-
"resumeChat": "Continuar Conversa"
|
152 |
-
}
|
153 |
-
}
|
154 |
-
}
|
155 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
CHANGED
@@ -160,4 +160,10 @@ cython_debug/
|
|
160 |
#.idea/
|
161 |
|
162 |
# log files
|
163 |
-
*.log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
#.idea/
|
161 |
|
162 |
# log files
|
163 |
+
*.log
|
164 |
+
|
165 |
+
.ragatouille/*
|
166 |
+
*/__pycache__/*
|
167 |
+
*/.chainlit/translations/*
|
168 |
+
storage/logs/*
|
169 |
+
vectorstores/*
|
{.chainlit β code/.chainlit}/config.toml
RENAMED
File without changes
|
code/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .modules import *
|
chainlit.md β code/chainlit.md
RENAMED
File without changes
|
code/main.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
-
from
|
2 |
from langchain import PromptTemplate
|
3 |
-
from
|
4 |
-
from
|
5 |
from langchain.chains import RetrievalQA
|
6 |
-
from langchain.llms import CTransformers
|
7 |
import chainlit as cl
|
8 |
from langchain_community.chat_models import ChatOpenAI
|
9 |
from langchain_community.embeddings import OpenAIEmbeddings
|
@@ -11,13 +10,22 @@ import yaml
|
|
11 |
import logging
|
12 |
from dotenv import load_dotenv
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
from modules.helpers import get_sources
|
17 |
|
|
|
|
|
|
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
logger = logging.getLogger(__name__)
|
20 |
logger.setLevel(logging.INFO)
|
|
|
21 |
|
22 |
# Console Handler
|
23 |
console_handler = logging.StreamHandler()
|
@@ -26,13 +34,6 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
|
26 |
console_handler.setFormatter(formatter)
|
27 |
logger.addHandler(console_handler)
|
28 |
|
29 |
-
# File Handler
|
30 |
-
log_file_path = "log_file.log" # Change this to your desired log file path
|
31 |
-
file_handler = logging.FileHandler(log_file_path)
|
32 |
-
file_handler.setLevel(logging.INFO)
|
33 |
-
file_handler.setFormatter(formatter)
|
34 |
-
logger.addHandler(file_handler)
|
35 |
-
|
36 |
|
37 |
# Adding option to select the chat profile
|
38 |
@cl.set_chat_profiles
|
@@ -66,12 +67,26 @@ def rename(orig_author: str):
|
|
66 |
# chainlit code
|
67 |
@cl.on_chat_start
|
68 |
async def start():
|
69 |
-
with open("
|
70 |
config = yaml.safe_load(f)
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
chat_profile = cl.user_session.get("chat_profile")
|
77 |
if chat_profile is not None:
|
@@ -93,8 +108,7 @@ async def start():
|
|
93 |
llm_tutor = LLMTutor(config, logger=logger)
|
94 |
|
95 |
chain = llm_tutor.qa_bot()
|
96 |
-
|
97 |
-
msg = cl.Message(content=f"Starting the bot {model}...")
|
98 |
await msg.send()
|
99 |
msg.content = opening_message
|
100 |
await msg.update()
|
@@ -104,24 +118,17 @@ async def start():
|
|
104 |
|
105 |
@cl.on_message
|
106 |
async def main(message):
|
|
|
107 |
user = cl.user_session.get("user")
|
108 |
chain = cl.user_session.get("chain")
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
#
|
113 |
-
# res=await chain.acall(message, callbacks=[cb])
|
114 |
-
res = await chain.acall(message.content)
|
115 |
-
print(f"response: {res}")
|
116 |
try:
|
117 |
answer = res["answer"]
|
118 |
except:
|
119 |
answer = res["result"]
|
120 |
-
print(f"answer: {answer}")
|
121 |
-
|
122 |
-
logger.info(f"Question: {res['question']}")
|
123 |
-
logger.info(f"History: {res['chat_history']}")
|
124 |
-
logger.info(f"Answer: {answer}\n")
|
125 |
|
126 |
answer_with_sources, source_elements = get_sources(res, answer)
|
127 |
|
|
|
1 |
+
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
|
2 |
from langchain import PromptTemplate
|
3 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
4 |
+
from langchain_community.vectorstores import FAISS
|
5 |
from langchain.chains import RetrievalQA
|
|
|
6 |
import chainlit as cl
|
7 |
from langchain_community.chat_models import ChatOpenAI
|
8 |
from langchain_community.embeddings import OpenAIEmbeddings
|
|
|
10 |
import logging
|
11 |
from dotenv import load_dotenv
|
12 |
|
13 |
+
import os
|
14 |
+
import sys
|
|
|
15 |
|
16 |
+
# Add the 'code' directory to the Python path
|
17 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
sys.path.append(current_dir)
|
19 |
|
20 |
+
from modules.chat.llm_tutor import LLMTutor
|
21 |
+
from modules.config.constants import *
|
22 |
+
from modules.chat.helpers import get_sources
|
23 |
+
|
24 |
+
|
25 |
+
global logger
|
26 |
logger = logging.getLogger(__name__)
|
27 |
logger.setLevel(logging.INFO)
|
28 |
+
logger.propagate = False
|
29 |
|
30 |
# Console Handler
|
31 |
console_handler = logging.StreamHandler()
|
|
|
34 |
console_handler.setFormatter(formatter)
|
35 |
logger.addHandler(console_handler)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
# Adding option to select the chat profile
|
39 |
@cl.set_chat_profiles
|
|
|
67 |
# chainlit code
|
68 |
@cl.on_chat_start
|
69 |
async def start():
|
70 |
+
with open("modules/config/config.yml", "r") as f:
|
71 |
config = yaml.safe_load(f)
|
72 |
+
|
73 |
+
# Ensure log directory exists
|
74 |
+
log_directory = config["log_dir"]
|
75 |
+
if not os.path.exists(log_directory):
|
76 |
+
os.makedirs(log_directory)
|
77 |
+
|
78 |
+
# File Handler
|
79 |
+
log_file_path = (
|
80 |
+
f"{log_directory}/tutor.log" # Change this to your desired log file path
|
81 |
+
)
|
82 |
+
file_handler = logging.FileHandler(log_file_path, mode="w")
|
83 |
+
file_handler.setLevel(logging.INFO)
|
84 |
+
file_handler.setFormatter(formatter)
|
85 |
+
logger.addHandler(file_handler)
|
86 |
+
|
87 |
+
logger.info("Config file loaded")
|
88 |
+
logger.info(f"Config: {config}")
|
89 |
+
logger.info("Creating llm_tutor instance")
|
90 |
|
91 |
chat_profile = cl.user_session.get("chat_profile")
|
92 |
if chat_profile is not None:
|
|
|
108 |
llm_tutor = LLMTutor(config, logger=logger)
|
109 |
|
110 |
chain = llm_tutor.qa_bot()
|
111 |
+
msg = cl.Message(content=f"Starting the bot {chat_profile}...")
|
|
|
112 |
await msg.send()
|
113 |
msg.content = opening_message
|
114 |
await msg.update()
|
|
|
118 |
|
119 |
@cl.on_message
|
120 |
async def main(message):
|
121 |
+
global logger
|
122 |
user = cl.user_session.get("user")
|
123 |
chain = cl.user_session.get("chain")
|
124 |
+
cb = cl.AsyncLangchainCallbackHandler() # TODO: fix streaming here
|
125 |
+
cb.answer_reached = True
|
126 |
+
res = await chain.acall(message.content, callbacks=[cb])
|
127 |
+
# res = await chain.acall(message.content)
|
|
|
|
|
|
|
128 |
try:
|
129 |
answer = res["answer"]
|
130 |
except:
|
131 |
answer = res["result"]
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
answer_with_sources, source_elements = get_sources(res, answer)
|
134 |
|
code/modules/__init__.py
CHANGED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from . import vectorstore
|
2 |
+
from . import dataloader
|
code/modules/chat/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .llm_tutor import LLMTutor
|
2 |
+
from .chat_model_loader import ChatModelLoader
|
code/modules/{chat_model_loader.py β chat/chat_model_loader.py}
RENAMED
@@ -1,8 +1,7 @@
|
|
1 |
from langchain_community.chat_models import ChatOpenAI
|
2 |
-
from
|
3 |
-
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
4 |
from transformers import AutoTokenizer, TextStreamer
|
5 |
-
from
|
6 |
import torch
|
7 |
import transformers
|
8 |
import os
|
|
|
1 |
from langchain_community.chat_models import ChatOpenAI
|
2 |
+
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
|
|
3 |
from transformers import AutoTokenizer, TextStreamer
|
4 |
+
from langchain_community.llms import LlamaCpp
|
5 |
import torch
|
6 |
import transformers
|
7 |
import os
|
code/modules/{helpers.py β chat/helpers.py}
RENAMED
@@ -1,176 +1,6 @@
|
|
1 |
-
import
|
2 |
-
from bs4 import BeautifulSoup
|
3 |
-
from tqdm import tqdm
|
4 |
import chainlit as cl
|
5 |
-
from
|
6 |
-
import requests
|
7 |
-
from bs4 import BeautifulSoup
|
8 |
-
from urllib.parse import urlparse, urljoin, urldefrag
|
9 |
-
import asyncio
|
10 |
-
import aiohttp
|
11 |
-
from aiohttp import ClientSession
|
12 |
-
from typing import Dict, Any, List
|
13 |
-
|
14 |
-
try:
|
15 |
-
from modules.constants import *
|
16 |
-
except:
|
17 |
-
from constants import *
|
18 |
-
|
19 |
-
"""
|
20 |
-
Ref: https://python.plainenglish.io/scraping-the-subpages-on-a-website-ea2d4e3db113
|
21 |
-
"""
|
22 |
-
|
23 |
-
|
24 |
-
class WebpageCrawler:
|
25 |
-
def __init__(self):
|
26 |
-
self.dict_href_links = {}
|
27 |
-
|
28 |
-
async def fetch(self, session: ClientSession, url: str) -> str:
|
29 |
-
async with session.get(url) as response:
|
30 |
-
try:
|
31 |
-
return await response.text()
|
32 |
-
except UnicodeDecodeError:
|
33 |
-
return await response.text(encoding="latin1")
|
34 |
-
|
35 |
-
def url_exists(self, url: str) -> bool:
|
36 |
-
try:
|
37 |
-
response = requests.head(url)
|
38 |
-
return response.status_code == 200
|
39 |
-
except requests.ConnectionError:
|
40 |
-
return False
|
41 |
-
|
42 |
-
async def get_links(self, session: ClientSession, website_link: str, base_url: str):
|
43 |
-
html_data = await self.fetch(session, website_link)
|
44 |
-
soup = BeautifulSoup(html_data, "html.parser")
|
45 |
-
list_links = []
|
46 |
-
for link in soup.find_all("a", href=True):
|
47 |
-
href = link["href"].strip()
|
48 |
-
full_url = urljoin(base_url, href)
|
49 |
-
normalized_url = self.normalize_url(full_url) # sections removed
|
50 |
-
if (
|
51 |
-
normalized_url not in self.dict_href_links
|
52 |
-
and self.is_child_url(normalized_url, base_url)
|
53 |
-
and self.url_exists(normalized_url)
|
54 |
-
):
|
55 |
-
self.dict_href_links[normalized_url] = None
|
56 |
-
list_links.append(normalized_url)
|
57 |
-
|
58 |
-
return list_links
|
59 |
-
|
60 |
-
async def get_subpage_links(
|
61 |
-
self, session: ClientSession, urls: list, base_url: str
|
62 |
-
):
|
63 |
-
tasks = [self.get_links(session, url, base_url) for url in urls]
|
64 |
-
results = await asyncio.gather(*tasks)
|
65 |
-
all_links = [link for sublist in results for link in sublist]
|
66 |
-
return all_links
|
67 |
-
|
68 |
-
async def get_all_pages(self, url: str, base_url: str):
|
69 |
-
async with aiohttp.ClientSession() as session:
|
70 |
-
dict_links = {url: "Not-checked"}
|
71 |
-
counter = None
|
72 |
-
while counter != 0:
|
73 |
-
unchecked_links = [
|
74 |
-
link
|
75 |
-
for link, status in dict_links.items()
|
76 |
-
if status == "Not-checked"
|
77 |
-
]
|
78 |
-
if not unchecked_links:
|
79 |
-
break
|
80 |
-
new_links = await self.get_subpage_links(
|
81 |
-
session, unchecked_links, base_url
|
82 |
-
)
|
83 |
-
for link in unchecked_links:
|
84 |
-
dict_links[link] = "Checked"
|
85 |
-
print(f"Checked: {link}")
|
86 |
-
dict_links.update(
|
87 |
-
{
|
88 |
-
link: "Not-checked"
|
89 |
-
for link in new_links
|
90 |
-
if link not in dict_links
|
91 |
-
}
|
92 |
-
)
|
93 |
-
counter = len(
|
94 |
-
[
|
95 |
-
status
|
96 |
-
for status in dict_links.values()
|
97 |
-
if status == "Not-checked"
|
98 |
-
]
|
99 |
-
)
|
100 |
-
|
101 |
-
checked_urls = [
|
102 |
-
url for url, status in dict_links.items() if status == "Checked"
|
103 |
-
]
|
104 |
-
return checked_urls
|
105 |
-
|
106 |
-
def is_webpage(self, url: str) -> bool:
|
107 |
-
try:
|
108 |
-
response = requests.head(url, allow_redirects=True)
|
109 |
-
content_type = response.headers.get("Content-Type", "").lower()
|
110 |
-
return "text/html" in content_type
|
111 |
-
except requests.RequestException:
|
112 |
-
return False
|
113 |
-
|
114 |
-
def clean_url_list(self, urls):
|
115 |
-
files, webpages = [], []
|
116 |
-
|
117 |
-
for url in urls:
|
118 |
-
if self.is_webpage(url):
|
119 |
-
webpages.append(url)
|
120 |
-
else:
|
121 |
-
files.append(url)
|
122 |
-
|
123 |
-
return files, webpages
|
124 |
-
|
125 |
-
def is_child_url(self, url, base_url):
|
126 |
-
return url.startswith(base_url)
|
127 |
-
|
128 |
-
def normalize_url(self, url: str):
|
129 |
-
# Strip the fragment identifier
|
130 |
-
defragged_url, _ = urldefrag(url)
|
131 |
-
return defragged_url
|
132 |
-
|
133 |
-
|
134 |
-
def get_urls_from_file(file_path: str):
|
135 |
-
"""
|
136 |
-
Function to get urls from a file
|
137 |
-
"""
|
138 |
-
with open(file_path, "r") as f:
|
139 |
-
urls = f.readlines()
|
140 |
-
urls = [url.strip() for url in urls]
|
141 |
-
return urls
|
142 |
-
|
143 |
-
|
144 |
-
def get_base_url(url):
|
145 |
-
parsed_url = urlparse(url)
|
146 |
-
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
|
147 |
-
return base_url
|
148 |
-
|
149 |
-
|
150 |
-
def get_prompt(config):
|
151 |
-
if config["llm_params"]["use_history"]:
|
152 |
-
if config["llm_params"]["llm_loader"] == "local_llm":
|
153 |
-
custom_prompt_template = tinyllama_prompt_template_with_history
|
154 |
-
elif config["llm_params"]["llm_loader"] == "openai":
|
155 |
-
custom_prompt_template = openai_prompt_template_with_history
|
156 |
-
# else:
|
157 |
-
# custom_prompt_template = tinyllama_prompt_template_with_history # default
|
158 |
-
prompt = PromptTemplate(
|
159 |
-
template=custom_prompt_template,
|
160 |
-
input_variables=["context", "chat_history", "question"],
|
161 |
-
)
|
162 |
-
else:
|
163 |
-
if config["llm_params"]["llm_loader"] == "local_llm":
|
164 |
-
custom_prompt_template = tinyllama_prompt_template
|
165 |
-
elif config["llm_params"]["llm_loader"] == "openai":
|
166 |
-
custom_prompt_template = openai_prompt_template
|
167 |
-
# else:
|
168 |
-
# custom_prompt_template = tinyllama_prompt_template
|
169 |
-
prompt = PromptTemplate(
|
170 |
-
template=custom_prompt_template,
|
171 |
-
input_variables=["context", "question"],
|
172 |
-
)
|
173 |
-
return prompt
|
174 |
|
175 |
|
176 |
def get_sources(res, answer):
|
@@ -248,90 +78,27 @@ def get_sources(res, answer):
|
|
248 |
return full_answer, source_elements
|
249 |
|
250 |
|
251 |
-
def
|
252 |
-
"""
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
)
|
276 |
-
description_div = row.find("div", {"data-label": "Description"})
|
277 |
-
slides_link_tag = description_div.find("a", title="Download slides")
|
278 |
-
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
279 |
-
slides_link = (
|
280 |
-
f"https://dl4ds.github.io{slides_link}" if slides_link else None
|
281 |
-
)
|
282 |
-
if slides_link:
|
283 |
-
date_mapping[slides_link] = date
|
284 |
-
except Exception as e:
|
285 |
-
print(f"Error processing schedule row: {e}")
|
286 |
-
continue
|
287 |
-
|
288 |
-
for block in lecture_blocks:
|
289 |
-
try:
|
290 |
-
# Extract the lecture title
|
291 |
-
title = block.find("span", style="font-weight: bold;").text.strip()
|
292 |
-
|
293 |
-
# Extract the TL;DR
|
294 |
-
tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
|
295 |
-
|
296 |
-
# Extract the link to the slides
|
297 |
-
slides_link_tag = block.find("a", title="Download slides")
|
298 |
-
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
299 |
-
slides_link = (
|
300 |
-
f"https://dl4ds.github.io{slides_link}" if slides_link else None
|
301 |
-
)
|
302 |
-
|
303 |
-
# Extract the link to the lecture recording
|
304 |
-
recording_link_tag = block.find("a", title="Download lecture recording")
|
305 |
-
recording_link = (
|
306 |
-
recording_link_tag["href"].strip() if recording_link_tag else None
|
307 |
-
)
|
308 |
-
|
309 |
-
# Extract suggested readings or summary if available
|
310 |
-
suggested_readings_tag = block.find("p", text="Suggested Readings:")
|
311 |
-
if suggested_readings_tag:
|
312 |
-
suggested_readings = suggested_readings_tag.find_next_sibling("ul")
|
313 |
-
if suggested_readings:
|
314 |
-
suggested_readings = suggested_readings.get_text(
|
315 |
-
separator="\n"
|
316 |
-
).strip()
|
317 |
-
else:
|
318 |
-
suggested_readings = "No specific readings provided."
|
319 |
-
else:
|
320 |
-
suggested_readings = "No specific readings provided."
|
321 |
-
|
322 |
-
# Get the date from the schedule
|
323 |
-
date = date_mapping.get(slides_link, "No date available")
|
324 |
-
|
325 |
-
# Add to the dictionary
|
326 |
-
lecture_metadata[slides_link] = {
|
327 |
-
"date": date,
|
328 |
-
"tldr": tldr,
|
329 |
-
"title": title,
|
330 |
-
"lecture_recording": recording_link,
|
331 |
-
"suggested_readings": suggested_readings,
|
332 |
-
}
|
333 |
-
except Exception as e:
|
334 |
-
print(f"Error processing block: {e}")
|
335 |
-
continue
|
336 |
-
|
337 |
-
return lecture_metadata
|
|
|
1 |
+
from modules.config.constants import *
|
|
|
|
|
2 |
import chainlit as cl
|
3 |
+
from langchain_core.prompts import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
|
6 |
def get_sources(res, answer):
|
|
|
78 |
return full_answer, source_elements
|
79 |
|
80 |
|
81 |
+
def get_prompt(config):
|
82 |
+
if config["llm_params"]["use_history"]:
|
83 |
+
if config["llm_params"]["llm_loader"] == "local_llm":
|
84 |
+
custom_prompt_template = tinyllama_prompt_template_with_history
|
85 |
+
elif config["llm_params"]["llm_loader"] == "openai":
|
86 |
+
custom_prompt_template = openai_prompt_template_with_history
|
87 |
+
# else:
|
88 |
+
# custom_prompt_template = tinyllama_prompt_template_with_history # default
|
89 |
+
prompt = PromptTemplate(
|
90 |
+
template=custom_prompt_template,
|
91 |
+
input_variables=["context", "chat_history", "question"],
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
if config["llm_params"]["llm_loader"] == "local_llm":
|
95 |
+
custom_prompt_template = tinyllama_prompt_template
|
96 |
+
elif config["llm_params"]["llm_loader"] == "openai":
|
97 |
+
custom_prompt_template = openai_prompt_template
|
98 |
+
# else:
|
99 |
+
# custom_prompt_template = tinyllama_prompt_template
|
100 |
+
prompt = PromptTemplate(
|
101 |
+
template=custom_prompt_template,
|
102 |
+
input_variables=["context", "question"],
|
103 |
+
)
|
104 |
+
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code/modules/{llm_tutor.py β chat/llm_tutor.py}
RENAMED
@@ -1,24 +1,52 @@
|
|
1 |
-
from langchain import PromptTemplate
|
2 |
-
from langchain.embeddings import HuggingFaceEmbeddings
|
3 |
-
from langchain_community.chat_models import ChatOpenAI
|
4 |
-
from langchain_community.embeddings import OpenAIEmbeddings
|
5 |
-
from langchain.vectorstores import FAISS
|
6 |
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
|
7 |
-
from langchain.
|
8 |
-
|
|
|
|
|
9 |
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
10 |
import os
|
11 |
-
from modules.constants import *
|
12 |
-
from modules.helpers import get_prompt
|
13 |
-
from modules.chat_model_loader import ChatModelLoader
|
14 |
-
from modules.
|
15 |
-
|
|
|
|
|
|
|
16 |
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
|
17 |
import inspect
|
18 |
from langchain.chains.conversational_retrieval.base import _get_chat_history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
async def _acall(
|
23 |
self,
|
24 |
inputs: Dict[str, Any],
|
@@ -26,13 +54,31 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
26 |
) -> Dict[str, Any]:
|
27 |
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
28 |
question = inputs["question"]
|
29 |
-
get_chat_history = self.
|
30 |
chat_history_str = get_chat_history(inputs["chat_history"])
|
31 |
-
print(f"chat_history_str: {chat_history_str}")
|
32 |
if chat_history_str:
|
33 |
-
callbacks = _run_manager.get_child()
|
34 |
-
new_question = await self.question_generator.arun(
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
)
|
37 |
else:
|
38 |
new_question = question
|
@@ -56,27 +102,24 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
56 |
# Prepare the final prompt with metadata
|
57 |
context = "\n\n".join(
|
58 |
[
|
59 |
-
f"Document content: {doc.page_content}\nMetadata: {doc.metadata}"
|
60 |
-
for doc in docs
|
61 |
]
|
62 |
)
|
63 |
-
final_prompt =
|
64 |
-
You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos.
|
65 |
-
|
66 |
-
|
67 |
-
Use the
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
Chat History
|
72 |
-
{
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
Question: {new_question}
|
78 |
-
AI Tutor:
|
79 |
-
"""
|
80 |
|
81 |
new_inputs["input"] = final_prompt
|
82 |
new_inputs["question"] = final_prompt
|
@@ -98,8 +141,9 @@ class LLMTutor:
|
|
98 |
def __init__(self, config, logger=None):
|
99 |
self.config = config
|
100 |
self.llm = self.load_llm()
|
101 |
-
self.
|
102 |
-
|
|
|
103 |
self.vector_db.create_database()
|
104 |
self.vector_db.save_database()
|
105 |
|
@@ -114,24 +158,20 @@ class LLMTutor:
|
|
114 |
|
115 |
# Retrieval QA Chain
|
116 |
def retrieval_qa_chain(self, llm, prompt, db):
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
# "k": self.config["embedding_options"]["search_top_k"],
|
126 |
-
# },
|
127 |
-
)
|
128 |
-
elif self.config["embedding_options"]["db_option"] == "RAGatouille":
|
129 |
retriever = db.as_langchain_retriever(
|
130 |
-
k=self.config["
|
131 |
)
|
|
|
132 |
if self.config["llm_params"]["use_history"]:
|
133 |
-
memory =
|
134 |
-
llm = llm,
|
135 |
k=self.config["llm_params"]["memory_window"],
|
136 |
memory_key="chat_history",
|
137 |
return_messages=True,
|
@@ -145,6 +185,7 @@ class LLMTutor:
|
|
145 |
return_source_documents=True,
|
146 |
memory=memory,
|
147 |
combine_docs_chain_kwargs={"prompt": prompt},
|
|
|
148 |
)
|
149 |
else:
|
150 |
qa_chain = RetrievalQA.from_chain_type(
|
@@ -166,7 +207,9 @@ class LLMTutor:
|
|
166 |
def qa_bot(self):
|
167 |
db = self.vector_db.load_database()
|
168 |
qa_prompt = self.set_custom_prompt()
|
169 |
-
qa = self.retrieval_qa_chain(
|
|
|
|
|
170 |
|
171 |
return qa
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
|
2 |
+
from langchain.memory import (
|
3 |
+
ConversationBufferWindowMemory,
|
4 |
+
ConversationSummaryBufferMemory,
|
5 |
+
)
|
6 |
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
7 |
import os
|
8 |
+
from modules.config.constants import *
|
9 |
+
from modules.chat.helpers import get_prompt
|
10 |
+
from modules.chat.chat_model_loader import ChatModelLoader
|
11 |
+
from modules.vectorstore.store_manager import VectorStoreManager
|
12 |
+
|
13 |
+
from modules.retriever import FaissRetriever, ChromaRetriever
|
14 |
+
|
15 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
16 |
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
|
17 |
import inspect
|
18 |
from langchain.chains.conversational_retrieval.base import _get_chat_history
|
19 |
+
from langchain_core.messages import BaseMessage
|
20 |
+
|
21 |
+
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
|
22 |
+
|
23 |
+
from langchain_core.output_parsers import StrOutputParser
|
24 |
+
from langchain_core.prompts import ChatPromptTemplate
|
25 |
+
from langchain_community.chat_models import ChatOpenAI
|
26 |
|
27 |
|
28 |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
29 |
+
|
30 |
+
def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
|
31 |
+
_ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
|
32 |
+
buffer = ""
|
33 |
+
for dialogue_turn in chat_history:
|
34 |
+
if isinstance(dialogue_turn, BaseMessage):
|
35 |
+
role_prefix = _ROLE_MAP.get(
|
36 |
+
dialogue_turn.type, f"{dialogue_turn.type}: "
|
37 |
+
)
|
38 |
+
buffer += f"\n{role_prefix}{dialogue_turn.content}"
|
39 |
+
elif isinstance(dialogue_turn, tuple):
|
40 |
+
human = "Student: " + dialogue_turn[0]
|
41 |
+
ai = "AI Tutor: " + dialogue_turn[1]
|
42 |
+
buffer += "\n" + "\n".join([human, ai])
|
43 |
+
else:
|
44 |
+
raise ValueError(
|
45 |
+
f"Unsupported chat history format: {type(dialogue_turn)}."
|
46 |
+
f" Full chat history: {chat_history} "
|
47 |
+
)
|
48 |
+
return buffer
|
49 |
+
|
50 |
async def _acall(
|
51 |
self,
|
52 |
inputs: Dict[str, Any],
|
|
|
54 |
) -> Dict[str, Any]:
|
55 |
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
56 |
question = inputs["question"]
|
57 |
+
get_chat_history = self._get_chat_history
|
58 |
chat_history_str = get_chat_history(inputs["chat_history"])
|
|
|
59 |
if chat_history_str:
|
60 |
+
# callbacks = _run_manager.get_child()
|
61 |
+
# new_question = await self.question_generator.arun(
|
62 |
+
# question=question, chat_history=chat_history_str, callbacks=callbacks
|
63 |
+
# )
|
64 |
+
system = (
|
65 |
+
"You are an AI Tutor helping a student. Your task is to rephrase the student's question to provide more context from their chat history (only if relevant), ensuring the rephrased question still reflects the student's point of view. "
|
66 |
+
"The rephrased question should incorporate relevant details from the chat history to make it clearer and more specific. It should also expand upon the original question to provide more context on only what the student provided."
|
67 |
+
"Always end the rephrased question with the original question in parentheses for reference. "
|
68 |
+
"Do not change the meaning of the question, and keep the tone and perspective as if it were asked by the student. "
|
69 |
+
"Here is the chat history for context: \n{chat_history_str}\n"
|
70 |
+
"Now, rephrase the following question: '{question}'"
|
71 |
+
)
|
72 |
+
prompt = ChatPromptTemplate.from_messages(
|
73 |
+
[
|
74 |
+
("system", system),
|
75 |
+
("human", "{question}, {chat_history_str}"),
|
76 |
+
]
|
77 |
+
)
|
78 |
+
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
79 |
+
step_back = prompt | llm | StrOutputParser()
|
80 |
+
new_question = step_back.invoke(
|
81 |
+
{"question": question, "chat_history_str": chat_history_str}
|
82 |
)
|
83 |
else:
|
84 |
new_question = question
|
|
|
102 |
# Prepare the final prompt with metadata
|
103 |
context = "\n\n".join(
|
104 |
[
|
105 |
+
f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source']}))"
|
106 |
+
for idx, doc in enumerate(docs)
|
107 |
]
|
108 |
)
|
109 |
+
final_prompt = (
|
110 |
+
"You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. "
|
111 |
+
"Use the following pieces of information to answer the user's question. "
|
112 |
+
"If you don't know the answer, try your best, but don't try to make up an answer. Keep the flow of the conversation going. "
|
113 |
+
"Use the chat history just as a gist to answer the question only if it's relevant; otherwise, ignore it. Do not repeat responses in the history. Use the context as a guide to construct your answer. The context for the answer will be under 'Document context:'. Remember, the conext may include text not directly related to the question."
|
114 |
+
"Make sure to use the source_file field in metadata from each document to provide links to the user to the correct sources. "
|
115 |
+
"The context is ordered by relevance to the question. "
|
116 |
+
"Talk in a friendly and personalized manner, similar to how you would speak to a friend who needs help. Make the conversation engaging and avoid sounding repetitive or robotic.\n\n"
|
117 |
+
f"Chat History:\n{chat_history_str}\n\n"
|
118 |
+
f"Context:\n{context}\n\n"
|
119 |
+
f"Student: {new_question}\n"
|
120 |
+
"Anwer the student's question in a friendly, concise, and engaging manner.\n"
|
121 |
+
"AI Tutor:"
|
122 |
+
)
|
|
|
|
|
|
|
123 |
|
124 |
new_inputs["input"] = final_prompt
|
125 |
new_inputs["question"] = final_prompt
|
|
|
141 |
def __init__(self, config, logger=None):
|
142 |
self.config = config
|
143 |
self.llm = self.load_llm()
|
144 |
+
self.logger = logger
|
145 |
+
self.vector_db = VectorStoreManager(config, logger=self.logger)
|
146 |
+
if self.config["vectorstore"]["embedd_files"]:
|
147 |
self.vector_db.create_database()
|
148 |
self.vector_db.save_database()
|
149 |
|
|
|
158 |
|
159 |
# Retrieval QA Chain
|
160 |
def retrieval_qa_chain(self, llm, prompt, db):
|
161 |
+
|
162 |
+
if self.config["vectorstore"]["db_option"] == "FAISS":
|
163 |
+
retriever = FaissRetriever().return_retriever(db, self.config)
|
164 |
+
|
165 |
+
elif self.config["vectorstore"]["db_option"] == "Chroma":
|
166 |
+
retriever = ChromaRetriever().return_retriever(db, self.config)
|
167 |
+
|
168 |
+
elif self.config["vectorstore"]["db_option"] == "RAGatouille":
|
|
|
|
|
|
|
|
|
169 |
retriever = db.as_langchain_retriever(
|
170 |
+
k=self.config["vectorstore"]["search_top_k"]
|
171 |
)
|
172 |
+
|
173 |
if self.config["llm_params"]["use_history"]:
|
174 |
+
memory = ConversationBufferWindowMemory(
|
|
|
175 |
k=self.config["llm_params"]["memory_window"],
|
176 |
memory_key="chat_history",
|
177 |
return_messages=True,
|
|
|
185 |
return_source_documents=True,
|
186 |
memory=memory,
|
187 |
combine_docs_chain_kwargs={"prompt": prompt},
|
188 |
+
response_if_no_docs_found="No context found",
|
189 |
)
|
190 |
else:
|
191 |
qa_chain = RetrievalQA.from_chain_type(
|
|
|
207 |
def qa_bot(self):
|
208 |
db = self.vector_db.load_database()
|
209 |
qa_prompt = self.set_custom_prompt()
|
210 |
+
qa = self.retrieval_qa_chain(
|
211 |
+
self.llm, qa_prompt, db
|
212 |
+
) # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
|
213 |
|
214 |
return qa
|
215 |
|
code/modules/config/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .constants import *
|
code/{config.yml β modules/config/config.yml}
RENAMED
@@ -1,23 +1,32 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
2 |
embedd_files: False # bool
|
3 |
-
data_path: 'storage/data' # str
|
4 |
-
url_file_path: 'storage/data/urls.txt' # str
|
5 |
-
expand_urls:
|
6 |
-
db_option : '
|
7 |
-
db_path : 'vectorstores' # str
|
8 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
9 |
search_top_k : 3 # int
|
10 |
score_threshold : 0.2 # float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
llm_params:
|
12 |
use_history: True # bool
|
13 |
memory_window: 3 # int
|
14 |
llm_loader: 'openai' # str [local_llm, openai]
|
15 |
openai_params:
|
16 |
model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
|
17 |
-
|
18 |
-
model: "storage/models/llama-2-7b-chat.Q4_0.gguf"
|
19 |
-
model_type: "llama"
|
20 |
-
temperature: 0.2
|
21 |
splitter_options:
|
22 |
use_splitter: True # bool
|
23 |
split_by_token : True # bool
|
|
|
1 |
+
log_dir: '../storage/logs' # str
|
2 |
+
log_chunk_dir: '../storage/logs/chunks' # str
|
3 |
+
device: 'cpu' # str [cuda, cpu]
|
4 |
+
|
5 |
+
vectorstore:
|
6 |
embedd_files: False # bool
|
7 |
+
data_path: '../storage/data' # str
|
8 |
+
url_file_path: '../storage/data/urls.txt' # str
|
9 |
+
expand_urls: False # bool
|
10 |
+
db_option : 'Chroma' # str [FAISS, Chroma, RAGatouille]
|
11 |
+
db_path : '../vectorstores' # str
|
12 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
13 |
search_top_k : 3 # int
|
14 |
score_threshold : 0.2 # float
|
15 |
+
|
16 |
+
faiss_params: # Not used as of now
|
17 |
+
index_path: '../vectorstores/faiss.index' # str
|
18 |
+
index_type: 'Flat' # str [Flat, HNSW, IVF]
|
19 |
+
index_dimension: 384 # int
|
20 |
+
index_nlist: 100 # int
|
21 |
+
index_nprobe: 10 # int
|
22 |
+
|
23 |
llm_params:
|
24 |
use_history: True # bool
|
25 |
memory_window: 3 # int
|
26 |
llm_loader: 'openai' # str [local_llm, openai]
|
27 |
openai_params:
|
28 |
model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
|
29 |
+
|
|
|
|
|
|
|
30 |
splitter_options:
|
31 |
use_splitter: True # bool
|
32 |
split_by_token : True # bool
|
code/modules/{constants.py β config/constants.py}
RENAMED
File without changes
|
code/modules/dataloader/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .webpage_crawler import WebpageCrawler
|
2 |
+
from .data_loader import DataLoader
|
code/modules/{data_loader.py β dataloader/data_loader.py}
RENAMED
@@ -16,15 +16,12 @@ import logging
|
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from ragatouille import RAGPretrainedModel
|
18 |
from langchain.chains import LLMChain
|
19 |
-
from
|
20 |
from langchain import PromptTemplate
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
from modules.helpers import get_metadata
|
24 |
-
except:
|
25 |
-
from helpers import get_metadata
|
26 |
-
|
27 |
-
logger = logging.getLogger(__name__)
|
28 |
|
29 |
|
30 |
class PDFReader:
|
@@ -40,8 +37,9 @@ class PDFReader:
|
|
40 |
|
41 |
|
42 |
class FileReader:
|
43 |
-
def __init__(self):
|
44 |
self.pdf_reader = PDFReader()
|
|
|
45 |
|
46 |
def extract_text_from_pdf(self, pdf_path):
|
47 |
text = ""
|
@@ -61,7 +59,7 @@ class FileReader:
|
|
61 |
temp_file_path = temp_file.name
|
62 |
return temp_file_path
|
63 |
else:
|
64 |
-
|
65 |
return None
|
66 |
|
67 |
def read_pdf(self, temp_file_path: str):
|
@@ -99,13 +97,18 @@ class FileReader:
|
|
99 |
if response.status_code == 200:
|
100 |
return [Document(page_content=response.text)]
|
101 |
else:
|
102 |
-
|
103 |
return None
|
104 |
|
105 |
|
106 |
class ChunkProcessor:
|
107 |
-
def __init__(self, config):
|
108 |
self.config = config
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
if config["splitter_options"]["use_splitter"]:
|
111 |
if config["splitter_options"]["split_by_token"]:
|
@@ -124,7 +127,7 @@ class ChunkProcessor:
|
|
124 |
)
|
125 |
else:
|
126 |
self.splitter = None
|
127 |
-
logger.info("ChunkProcessor instance created")
|
128 |
|
129 |
def remove_delimiters(self, document_chunks: list):
|
130 |
for chunk in document_chunks:
|
@@ -139,7 +142,6 @@ class ChunkProcessor:
|
|
139 |
del document_chunks[0]
|
140 |
for _ in range(end):
|
141 |
document_chunks.pop()
|
142 |
-
logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
|
143 |
return document_chunks
|
144 |
|
145 |
def process_chunks(
|
@@ -172,122 +174,184 @@ class ChunkProcessor:
|
|
172 |
|
173 |
return document_chunks
|
174 |
|
175 |
-
def
|
176 |
-
self.document_chunks_full = []
|
177 |
-
self.parent_document_names = []
|
178 |
-
self.child_document_names = []
|
179 |
-
self.documents = []
|
180 |
-
self.document_metadata = []
|
181 |
-
|
182 |
addl_metadata = get_metadata(
|
183 |
"https://dl4ds.github.io/sp2024/lectures/",
|
184 |
"https://dl4ds.github.io/sp2024/schedule/",
|
185 |
) # For any additional metadata
|
186 |
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
self.parent_document_names.append(file_name)
|
219 |
-
if self.config["embedding_options"]["db_option"] not in [
|
220 |
-
"RAGatouille"
|
221 |
-
]:
|
222 |
-
document_chunks = self.process_chunks(
|
223 |
-
self.documents[-1],
|
224 |
-
file_type,
|
225 |
-
source=file_path,
|
226 |
-
page=page_num,
|
227 |
-
metadata=metadata,
|
228 |
-
)
|
229 |
-
self.document_chunks_full.extend(document_chunks)
|
230 |
-
|
231 |
-
# except Exception as e:
|
232 |
-
# logger.error(f"Error processing file {file_name}: {str(e)}")
|
233 |
-
|
234 |
-
self.process_weblinks(file_reader, weblinks)
|
235 |
-
|
236 |
-
logger.info(
|
237 |
f"Total document chunks extracted: {len(self.document_chunks_full)}"
|
238 |
)
|
239 |
-
return (
|
240 |
-
self.document_chunks_full,
|
241 |
-
self.child_document_names,
|
242 |
-
self.documents,
|
243 |
-
self.document_metadata,
|
244 |
-
)
|
245 |
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
|
284 |
|
285 |
class DataLoader:
|
286 |
-
def __init__(self, config):
|
287 |
-
self.file_reader = FileReader()
|
288 |
-
self.chunk_processor = ChunkProcessor(config)
|
289 |
|
290 |
def get_chunks(self, uploaded_files, weblinks):
|
291 |
-
return self.chunk_processor.
|
292 |
self.file_reader, uploaded_files, weblinks
|
293 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from ragatouille import RAGPretrainedModel
|
18 |
from langchain.chains import LLMChain
|
19 |
+
from langchain_community.llms import OpenAI
|
20 |
from langchain import PromptTemplate
|
21 |
+
import json
|
22 |
+
from concurrent.futures import ThreadPoolExecutor
|
23 |
|
24 |
+
from modules.dataloader.helpers import get_metadata
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
class PDFReader:
|
|
|
37 |
|
38 |
|
39 |
class FileReader:
|
40 |
+
def __init__(self, logger):
|
41 |
self.pdf_reader = PDFReader()
|
42 |
+
self.logger = logger
|
43 |
|
44 |
def extract_text_from_pdf(self, pdf_path):
|
45 |
text = ""
|
|
|
59 |
temp_file_path = temp_file.name
|
60 |
return temp_file_path
|
61 |
else:
|
62 |
+
self.logger.error(f"Failed to download PDF from URL: {pdf_url}")
|
63 |
return None
|
64 |
|
65 |
def read_pdf(self, temp_file_path: str):
|
|
|
97 |
if response.status_code == 200:
|
98 |
return [Document(page_content=response.text)]
|
99 |
else:
|
100 |
+
self.logger.error(f"Failed to fetch .tex file from URL: {tex_url}")
|
101 |
return None
|
102 |
|
103 |
|
104 |
class ChunkProcessor:
|
105 |
+
def __init__(self, config, logger):
|
106 |
self.config = config
|
107 |
+
self.logger = logger
|
108 |
+
|
109 |
+
self.document_data = {}
|
110 |
+
self.document_metadata = {}
|
111 |
+
self.document_chunks_full = []
|
112 |
|
113 |
if config["splitter_options"]["use_splitter"]:
|
114 |
if config["splitter_options"]["split_by_token"]:
|
|
|
127 |
)
|
128 |
else:
|
129 |
self.splitter = None
|
130 |
+
self.logger.info("ChunkProcessor instance created")
|
131 |
|
132 |
def remove_delimiters(self, document_chunks: list):
|
133 |
for chunk in document_chunks:
|
|
|
142 |
del document_chunks[0]
|
143 |
for _ in range(end):
|
144 |
document_chunks.pop()
|
|
|
145 |
return document_chunks
|
146 |
|
147 |
def process_chunks(
|
|
|
174 |
|
175 |
return document_chunks
|
176 |
|
177 |
+
def chunk_docs(self, file_reader, uploaded_files, weblinks):
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
addl_metadata = get_metadata(
|
179 |
"https://dl4ds.github.io/sp2024/lectures/",
|
180 |
"https://dl4ds.github.io/sp2024/schedule/",
|
181 |
) # For any additional metadata
|
182 |
|
183 |
+
with ThreadPoolExecutor() as executor:
|
184 |
+
executor.map(
|
185 |
+
self.process_file,
|
186 |
+
uploaded_files,
|
187 |
+
range(len(uploaded_files)),
|
188 |
+
[file_reader] * len(uploaded_files),
|
189 |
+
[addl_metadata] * len(uploaded_files),
|
190 |
+
)
|
191 |
+
executor.map(
|
192 |
+
self.process_weblink,
|
193 |
+
weblinks,
|
194 |
+
range(len(weblinks)),
|
195 |
+
[file_reader] * len(weblinks),
|
196 |
+
[addl_metadata] * len(weblinks),
|
197 |
+
)
|
198 |
+
|
199 |
+
document_names = [
|
200 |
+
f"{file_name}_{page_num}"
|
201 |
+
for file_name, pages in self.document_data.items()
|
202 |
+
for page_num in pages.keys()
|
203 |
+
]
|
204 |
+
documents = [
|
205 |
+
page for doc in self.document_data.values() for page in doc.values()
|
206 |
+
]
|
207 |
+
document_metadata = [
|
208 |
+
page for doc in self.document_metadata.values() for page in doc.values()
|
209 |
+
]
|
210 |
+
|
211 |
+
self.save_document_data()
|
212 |
+
|
213 |
+
self.logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
f"Total document chunks extracted: {len(self.document_chunks_full)}"
|
215 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
+
return self.document_chunks_full, document_names, documents, document_metadata
|
218 |
+
|
219 |
+
def process_documents(
|
220 |
+
self, documents, file_path, file_type, metadata_source, addl_metadata
|
221 |
+
):
|
222 |
+
file_data = {}
|
223 |
+
file_metadata = {}
|
224 |
+
|
225 |
+
for doc in documents:
|
226 |
+
if len(doc.page_content) <= 400:
|
227 |
+
continue
|
228 |
+
|
229 |
+
page_num = doc.metadata.get("page", 0)
|
230 |
+
file_data[page_num] = doc.page_content
|
231 |
+
metadata = (
|
232 |
+
addl_metadata.get(file_path, {})
|
233 |
+
if metadata_source == "file"
|
234 |
+
else {"source": file_path, "page": page_num}
|
235 |
+
)
|
236 |
+
file_metadata[page_num] = metadata
|
237 |
+
|
238 |
+
if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
|
239 |
+
document_chunks = self.process_chunks(
|
240 |
+
doc.page_content,
|
241 |
+
file_type,
|
242 |
+
source=file_path,
|
243 |
+
page=page_num,
|
244 |
+
metadata=metadata,
|
245 |
+
)
|
246 |
+
self.document_chunks_full.extend(document_chunks)
|
247 |
+
|
248 |
+
self.document_data[file_path] = file_data
|
249 |
+
self.document_metadata[file_path] = file_metadata
|
250 |
+
|
251 |
+
def process_file(self, file_path, file_index, file_reader, addl_metadata):
|
252 |
+
file_name = os.path.basename(file_path)
|
253 |
+
if file_name in self.document_data:
|
254 |
+
return
|
255 |
+
|
256 |
+
file_type = file_name.split(".")[-1].lower()
|
257 |
+
self.logger.info(f"Reading file {file_index + 1}: {file_path}")
|
258 |
+
|
259 |
+
read_methods = {
|
260 |
+
"pdf": file_reader.read_pdf,
|
261 |
+
"txt": file_reader.read_txt,
|
262 |
+
"docx": file_reader.read_docx,
|
263 |
+
"srt": file_reader.read_srt,
|
264 |
+
"tex": file_reader.read_tex_from_url,
|
265 |
+
}
|
266 |
+
if file_type not in read_methods:
|
267 |
+
self.logger.warning(f"Unsupported file type: {file_type}")
|
268 |
+
return
|
269 |
+
|
270 |
+
try:
|
271 |
+
documents = read_methods[file_type](file_path)
|
272 |
+
self.process_documents(
|
273 |
+
documents, file_path, file_type, "file", addl_metadata
|
274 |
+
)
|
275 |
+
except Exception as e:
|
276 |
+
self.logger.error(f"Error processing file {file_name}: {str(e)}")
|
277 |
+
|
278 |
+
def process_weblink(self, link, link_index, file_reader, addl_metadata):
|
279 |
+
if link in self.document_data:
|
280 |
+
return
|
281 |
+
|
282 |
+
self.logger.info(f"Reading link {link_index + 1} : {link}")
|
283 |
+
|
284 |
+
try:
|
285 |
+
if "youtube" in link:
|
286 |
+
documents = file_reader.read_youtube_transcript(link)
|
287 |
+
else:
|
288 |
+
documents = file_reader.read_html(link)
|
289 |
+
|
290 |
+
self.process_documents(documents, link, "txt", "link", addl_metadata)
|
291 |
+
except Exception as e:
|
292 |
+
self.logger.error(f"Error Reading link {link_index + 1} : {link}: {str(e)}")
|
293 |
+
|
294 |
+
def save_document_data(self):
|
295 |
+
if not os.path.exists(f"{self.config['log_chunk_dir']}/docs"):
|
296 |
+
os.makedirs(f"{self.config['log_chunk_dir']}/docs")
|
297 |
+
self.logger.info(
|
298 |
+
f"Creating directory {self.config['log_chunk_dir']}/docs for document data"
|
299 |
+
)
|
300 |
+
self.logger.info(
|
301 |
+
f"Saving document content to {self.config['log_chunk_dir']}/docs/doc_content.json"
|
302 |
+
)
|
303 |
+
if not os.path.exists(f"{self.config['log_chunk_dir']}/metadata"):
|
304 |
+
os.makedirs(f"{self.config['log_chunk_dir']}/metadata")
|
305 |
+
self.logger.info(
|
306 |
+
f"Creating directory {self.config['log_chunk_dir']}/metadata for document metadata"
|
307 |
+
)
|
308 |
+
self.logger.info(
|
309 |
+
f"Saving document metadata to {self.config['log_chunk_dir']}/metadata/doc_metadata.json"
|
310 |
+
)
|
311 |
+
with open(
|
312 |
+
f"{self.config['log_chunk_dir']}/docs/doc_content.json", "w"
|
313 |
+
) as json_file:
|
314 |
+
json.dump(self.document_data, json_file, indent=4)
|
315 |
+
with open(
|
316 |
+
f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "w"
|
317 |
+
) as json_file:
|
318 |
+
json.dump(self.document_metadata, json_file, indent=4)
|
319 |
+
|
320 |
+
def load_document_data(self):
|
321 |
+
with open(
|
322 |
+
f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
|
323 |
+
) as json_file:
|
324 |
+
self.document_data = json.load(json_file)
|
325 |
+
with open(
|
326 |
+
f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
|
327 |
+
) as json_file:
|
328 |
+
self.document_metadata = json.load(json_file)
|
329 |
|
330 |
|
331 |
class DataLoader:
|
332 |
+
def __init__(self, config, logger=None):
|
333 |
+
self.file_reader = FileReader(logger=logger)
|
334 |
+
self.chunk_processor = ChunkProcessor(config, logger=logger)
|
335 |
|
336 |
def get_chunks(self, uploaded_files, weblinks):
|
337 |
+
return self.chunk_processor.chunk_docs(
|
338 |
self.file_reader, uploaded_files, weblinks
|
339 |
)
|
340 |
+
|
341 |
+
|
342 |
+
if __name__ == "__main__":
|
343 |
+
import yaml
|
344 |
+
|
345 |
+
logger = logging.getLogger(__name__)
|
346 |
+
logger.setLevel(logging.INFO)
|
347 |
+
|
348 |
+
with open("../code/config.yml", "r") as f:
|
349 |
+
config = yaml.safe_load(f)
|
350 |
+
|
351 |
+
data_loader = DataLoader(config, logger=logger)
|
352 |
+
document_chunks, document_names, documents, document_metadata = (
|
353 |
+
data_loader.get_chunks(
|
354 |
+
[],
|
355 |
+
["https://dl4ds.github.io/sp2024/"],
|
356 |
+
)
|
357 |
+
)
|
code/modules/dataloader/helpers.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
def get_urls_from_file(file_path: str):
|
7 |
+
"""
|
8 |
+
Function to get urls from a file
|
9 |
+
"""
|
10 |
+
with open(file_path, "r") as f:
|
11 |
+
urls = f.readlines()
|
12 |
+
urls = [url.strip() for url in urls]
|
13 |
+
return urls
|
14 |
+
|
15 |
+
|
16 |
+
def get_base_url(url):
|
17 |
+
parsed_url = urlparse(url)
|
18 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
|
19 |
+
return base_url
|
20 |
+
|
21 |
+
|
22 |
+
def get_metadata(lectures_url, schedule_url):
|
23 |
+
"""
|
24 |
+
Function to get the lecture metadata from the lectures and schedule URLs.
|
25 |
+
"""
|
26 |
+
lecture_metadata = {}
|
27 |
+
|
28 |
+
# Get the main lectures page content
|
29 |
+
r_lectures = requests.get(lectures_url)
|
30 |
+
soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
|
31 |
+
|
32 |
+
# Get the main schedule page content
|
33 |
+
r_schedule = requests.get(schedule_url)
|
34 |
+
soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
|
35 |
+
|
36 |
+
# Find all lecture blocks
|
37 |
+
lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
|
38 |
+
|
39 |
+
# Create a mapping from slides link to date
|
40 |
+
date_mapping = {}
|
41 |
+
schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
|
42 |
+
for row in schedule_rows:
|
43 |
+
try:
|
44 |
+
date = (
|
45 |
+
row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
|
46 |
+
)
|
47 |
+
description_div = row.find("div", {"data-label": "Description"})
|
48 |
+
slides_link_tag = description_div.find("a", title="Download slides")
|
49 |
+
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
50 |
+
slides_link = (
|
51 |
+
f"https://dl4ds.github.io{slides_link}" if slides_link else None
|
52 |
+
)
|
53 |
+
if slides_link:
|
54 |
+
date_mapping[slides_link] = date
|
55 |
+
except Exception as e:
|
56 |
+
print(f"Error processing schedule row: {e}")
|
57 |
+
continue
|
58 |
+
|
59 |
+
for block in lecture_blocks:
|
60 |
+
try:
|
61 |
+
# Extract the lecture title
|
62 |
+
title = block.find("span", style="font-weight: bold;").text.strip()
|
63 |
+
|
64 |
+
# Extract the TL;DR
|
65 |
+
tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
|
66 |
+
|
67 |
+
# Extract the link to the slides
|
68 |
+
slides_link_tag = block.find("a", title="Download slides")
|
69 |
+
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
70 |
+
slides_link = (
|
71 |
+
f"https://dl4ds.github.io{slides_link}" if slides_link else None
|
72 |
+
)
|
73 |
+
|
74 |
+
# Extract the link to the lecture recording
|
75 |
+
recording_link_tag = block.find("a", title="Download lecture recording")
|
76 |
+
recording_link = (
|
77 |
+
recording_link_tag["href"].strip() if recording_link_tag else None
|
78 |
+
)
|
79 |
+
|
80 |
+
# Extract suggested readings or summary if available
|
81 |
+
suggested_readings_tag = block.find("p", text="Suggested Readings:")
|
82 |
+
if suggested_readings_tag:
|
83 |
+
suggested_readings = suggested_readings_tag.find_next_sibling("ul")
|
84 |
+
if suggested_readings:
|
85 |
+
suggested_readings = suggested_readings.get_text(
|
86 |
+
separator="\n"
|
87 |
+
).strip()
|
88 |
+
else:
|
89 |
+
suggested_readings = "No specific readings provided."
|
90 |
+
else:
|
91 |
+
suggested_readings = "No specific readings provided."
|
92 |
+
|
93 |
+
# Get the date from the schedule
|
94 |
+
date = date_mapping.get(slides_link, "No date available")
|
95 |
+
|
96 |
+
# Add to the dictionary
|
97 |
+
lecture_metadata[slides_link] = {
|
98 |
+
"date": date,
|
99 |
+
"tldr": tldr,
|
100 |
+
"title": title,
|
101 |
+
"lecture_recording": recording_link,
|
102 |
+
"suggested_readings": suggested_readings,
|
103 |
+
}
|
104 |
+
except Exception as e:
|
105 |
+
print(f"Error processing block: {e}")
|
106 |
+
continue
|
107 |
+
|
108 |
+
return lecture_metadata
|
code/modules/dataloader/webpage_crawler.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import aiohttp
|
2 |
+
from aiohttp import ClientSession
|
3 |
+
import asyncio
|
4 |
+
import requests
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
from urllib.parse import urlparse, urljoin, urldefrag
|
7 |
+
|
8 |
+
class WebpageCrawler:
|
9 |
+
def __init__(self):
|
10 |
+
self.dict_href_links = {}
|
11 |
+
|
12 |
+
async def fetch(self, session: ClientSession, url: str) -> str:
|
13 |
+
async with session.get(url) as response:
|
14 |
+
try:
|
15 |
+
return await response.text()
|
16 |
+
except UnicodeDecodeError:
|
17 |
+
return await response.text(encoding="latin1")
|
18 |
+
|
19 |
+
def url_exists(self, url: str) -> bool:
|
20 |
+
try:
|
21 |
+
response = requests.head(url)
|
22 |
+
return response.status_code == 200
|
23 |
+
except requests.ConnectionError:
|
24 |
+
return False
|
25 |
+
|
26 |
+
async def get_links(self, session: ClientSession, website_link: str, base_url: str):
|
27 |
+
html_data = await self.fetch(session, website_link)
|
28 |
+
soup = BeautifulSoup(html_data, "html.parser")
|
29 |
+
list_links = []
|
30 |
+
for link in soup.find_all("a", href=True):
|
31 |
+
href = link["href"].strip()
|
32 |
+
full_url = urljoin(base_url, href)
|
33 |
+
normalized_url = self.normalize_url(full_url) # sections removed
|
34 |
+
if (
|
35 |
+
normalized_url not in self.dict_href_links
|
36 |
+
and self.is_child_url(normalized_url, base_url)
|
37 |
+
and self.url_exists(normalized_url)
|
38 |
+
):
|
39 |
+
self.dict_href_links[normalized_url] = None
|
40 |
+
list_links.append(normalized_url)
|
41 |
+
|
42 |
+
return list_links
|
43 |
+
|
44 |
+
async def get_subpage_links(
|
45 |
+
self, session: ClientSession, urls: list, base_url: str
|
46 |
+
):
|
47 |
+
tasks = [self.get_links(session, url, base_url) for url in urls]
|
48 |
+
results = await asyncio.gather(*tasks)
|
49 |
+
all_links = [link for sublist in results for link in sublist]
|
50 |
+
return all_links
|
51 |
+
|
52 |
+
async def get_all_pages(self, url: str, base_url: str):
|
53 |
+
async with aiohttp.ClientSession() as session:
|
54 |
+
dict_links = {url: "Not-checked"}
|
55 |
+
counter = None
|
56 |
+
while counter != 0:
|
57 |
+
unchecked_links = [
|
58 |
+
link
|
59 |
+
for link, status in dict_links.items()
|
60 |
+
if status == "Not-checked"
|
61 |
+
]
|
62 |
+
if not unchecked_links:
|
63 |
+
break
|
64 |
+
new_links = await self.get_subpage_links(
|
65 |
+
session, unchecked_links, base_url
|
66 |
+
)
|
67 |
+
for link in unchecked_links:
|
68 |
+
dict_links[link] = "Checked"
|
69 |
+
print(f"Checked: {link}")
|
70 |
+
dict_links.update(
|
71 |
+
{
|
72 |
+
link: "Not-checked"
|
73 |
+
for link in new_links
|
74 |
+
if link not in dict_links
|
75 |
+
}
|
76 |
+
)
|
77 |
+
counter = len(
|
78 |
+
[
|
79 |
+
status
|
80 |
+
for status in dict_links.values()
|
81 |
+
if status == "Not-checked"
|
82 |
+
]
|
83 |
+
)
|
84 |
+
|
85 |
+
checked_urls = [
|
86 |
+
url for url, status in dict_links.items() if status == "Checked"
|
87 |
+
]
|
88 |
+
return checked_urls
|
89 |
+
|
90 |
+
def is_webpage(self, url: str) -> bool:
|
91 |
+
try:
|
92 |
+
response = requests.head(url, allow_redirects=True)
|
93 |
+
content_type = response.headers.get("Content-Type", "").lower()
|
94 |
+
return "text/html" in content_type
|
95 |
+
except requests.RequestException:
|
96 |
+
return False
|
97 |
+
|
98 |
+
def clean_url_list(self, urls):
|
99 |
+
files, webpages = [], []
|
100 |
+
|
101 |
+
for url in urls:
|
102 |
+
if self.is_webpage(url):
|
103 |
+
webpages.append(url)
|
104 |
+
else:
|
105 |
+
files.append(url)
|
106 |
+
|
107 |
+
return files, webpages
|
108 |
+
|
109 |
+
def is_child_url(self, url, base_url):
|
110 |
+
return url.startswith(base_url)
|
111 |
+
|
112 |
+
def normalize_url(self, url: str):
|
113 |
+
# Strip the fragment identifier
|
114 |
+
defragged_url, _ = urldefrag(url)
|
115 |
+
return defragged_url
|
code/modules/retriever/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .faiss_retriever import FaissRetriever
|
2 |
+
from .chroma_retriever import ChromaRetriever
|
code/modules/retriever/base.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BaseRetriever:
|
2 |
+
def __init__(self, config):
|
3 |
+
self.config = config
|
4 |
+
|
5 |
+
def return_retriever(self):
|
6 |
+
raise NotImplementedError
|
code/modules/retriever/chroma_retriever.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .helpers import VectorStoreRetrieverScore
|
2 |
+
from .base import BaseRetriever
|
3 |
+
|
4 |
+
|
5 |
+
class ChromaRetriever(BaseRetriever):
|
6 |
+
def __init__(self):
|
7 |
+
pass
|
8 |
+
|
9 |
+
def return_retriever(self, db, config):
|
10 |
+
retriever = VectorStoreRetrieverScore(
|
11 |
+
vectorstore=db,
|
12 |
+
# search_type="similarity_score_threshold",
|
13 |
+
# search_kwargs={
|
14 |
+
# "score_threshold": self.config["vectorstore"][
|
15 |
+
# "score_threshold"
|
16 |
+
# ],
|
17 |
+
# "k": self.config["vectorstore"]["search_top_k"],
|
18 |
+
# },
|
19 |
+
search_kwargs={
|
20 |
+
"k": config["vectorstore"]["search_top_k"],
|
21 |
+
},
|
22 |
+
)
|
23 |
+
|
24 |
+
return retriever
|
code/modules/retriever/faiss_retriever.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .helpers import VectorStoreRetrieverScore
|
2 |
+
from .base import BaseRetriever
|
3 |
+
|
4 |
+
|
5 |
+
class FaissRetriever(BaseRetriever):
|
6 |
+
def __init__(self):
|
7 |
+
pass
|
8 |
+
|
9 |
+
def return_retriever(self, db, config):
|
10 |
+
retriever = VectorStoreRetrieverScore(
|
11 |
+
vectorstore=db,
|
12 |
+
# search_type="similarity_score_threshold",
|
13 |
+
# search_kwargs={
|
14 |
+
# "score_threshold": self.config["vectorstore"][
|
15 |
+
# "score_threshold"
|
16 |
+
# ],
|
17 |
+
# "k": self.config["vectorstore"]["search_top_k"],
|
18 |
+
# },
|
19 |
+
search_kwargs={
|
20 |
+
"k": config["vectorstore"]["search_top_k"],
|
21 |
+
},
|
22 |
+
)
|
23 |
+
return retriever
|
code/modules/retriever/helpers.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.schema.vectorstore import VectorStoreRetriever
|
2 |
+
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
3 |
+
from langchain.schema.document import Document
|
4 |
+
from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
|
8 |
+
class VectorStoreRetrieverScore(VectorStoreRetriever):
|
9 |
+
|
10 |
+
# See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
|
11 |
+
def _get_relevant_documents(
|
12 |
+
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
13 |
+
) -> List[Document]:
|
14 |
+
docs_and_similarities = (
|
15 |
+
self.vectorstore.similarity_search_with_relevance_scores(
|
16 |
+
query, **self.search_kwargs
|
17 |
+
)
|
18 |
+
)
|
19 |
+
# Make the score part of the document metadata
|
20 |
+
for doc, similarity in docs_and_similarities:
|
21 |
+
doc.metadata["score"] = similarity
|
22 |
+
|
23 |
+
docs = [doc for doc, _ in docs_and_similarities]
|
24 |
+
return docs
|
25 |
+
|
26 |
+
async def _aget_relevant_documents(
|
27 |
+
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
28 |
+
) -> List[Document]:
|
29 |
+
docs_and_similarities = (
|
30 |
+
self.vectorstore.similarity_search_with_relevance_scores(
|
31 |
+
query, **self.search_kwargs
|
32 |
+
)
|
33 |
+
)
|
34 |
+
# Make the score part of the document metadata
|
35 |
+
for doc, similarity in docs_and_similarities:
|
36 |
+
doc.metadata["score"] = similarity
|
37 |
+
|
38 |
+
docs = [doc for doc, _ in docs_and_similarities]
|
39 |
+
return docs
|
code/modules/vectorstore/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import VectorStoreBase
|
2 |
+
from .faiss import FAISS
|
code/modules/vectorstore/base.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class VectorStoreBase:
|
2 |
+
def __init__(self, config):
|
3 |
+
self.config = config
|
4 |
+
|
5 |
+
def _init_vector_db(self):
|
6 |
+
raise NotImplementedError
|
7 |
+
|
8 |
+
def create_database(self, database_name):
|
9 |
+
raise NotImplementedError
|
10 |
+
|
11 |
+
def load_database(self, database_name):
|
12 |
+
raise NotImplementedError
|
13 |
+
|
14 |
+
def as_retriever(self):
|
15 |
+
raise NotImplementedError
|
16 |
+
|
17 |
+
def __str__(self):
|
18 |
+
return self.__class__.__name__
|
code/modules/vectorstore/chroma.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.vectorstores import Chroma
|
2 |
+
from modules.vectorstore.base import VectorStoreBase
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class ChromaVectorStore(VectorStoreBase):
|
7 |
+
def __init__(self, config):
|
8 |
+
self.config = config
|
9 |
+
self._init_vector_db()
|
10 |
+
|
11 |
+
def _init_vector_db(self):
|
12 |
+
self.chroma = Chroma()
|
13 |
+
|
14 |
+
def create_database(self, document_chunks, embedding_model):
|
15 |
+
self.vectorstore = self.chroma.from_documents(
|
16 |
+
documents=document_chunks,
|
17 |
+
embedding=embedding_model,
|
18 |
+
persist_directory=os.path.join(
|
19 |
+
self.config["vectorstore"]["db_path"],
|
20 |
+
"db_"
|
21 |
+
+ self.config["vectorstore"]["db_option"]
|
22 |
+
+ "_"
|
23 |
+
+ self.config["vectorstore"]["model"],
|
24 |
+
),
|
25 |
+
)
|
26 |
+
|
27 |
+
def load_database(self, embedding_model):
|
28 |
+
self.vectorstore = Chroma(
|
29 |
+
persist_directory=os.path.join(
|
30 |
+
self.config["vectorstore"]["db_path"],
|
31 |
+
"db_"
|
32 |
+
+ self.config["vectorstore"]["db_option"]
|
33 |
+
+ "_"
|
34 |
+
+ self.config["vectorstore"]["model"],
|
35 |
+
),
|
36 |
+
embedding_function=embedding_model,
|
37 |
+
)
|
38 |
+
return self.vectorstore
|
39 |
+
|
40 |
+
def as_retriever(self):
|
41 |
+
return self.vectorstore.as_retriever()
|
code/modules/{embedding_model_loader.py β vectorstore/embedding_model_loader.py}
RENAMED
@@ -2,10 +2,7 @@ from langchain_community.embeddings import OpenAIEmbeddings
|
|
2 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
3 |
from langchain_community.embeddings import LlamaCppEmbeddings
|
4 |
|
5 |
-
|
6 |
-
from modules.constants import *
|
7 |
-
except:
|
8 |
-
from constants import *
|
9 |
import os
|
10 |
|
11 |
|
@@ -14,19 +11,19 @@ class EmbeddingModelLoader:
|
|
14 |
self.config = config
|
15 |
|
16 |
def load_embedding_model(self):
|
17 |
-
if self.config["
|
18 |
embedding_model = OpenAIEmbeddings(
|
19 |
deployment="SL-document_embedder",
|
20 |
-
model=self.config["
|
21 |
show_progress_bar=True,
|
22 |
openai_api_key=OPENAI_API_KEY,
|
23 |
disallowed_special=(),
|
24 |
)
|
25 |
else:
|
26 |
embedding_model = HuggingFaceEmbeddings(
|
27 |
-
model_name=self.config["
|
28 |
model_kwargs={
|
29 |
-
"device": "
|
30 |
"token": f"{HUGGINGFACE_TOKEN}",
|
31 |
"trust_remote_code": True,
|
32 |
},
|
|
|
2 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
3 |
from langchain_community.embeddings import LlamaCppEmbeddings
|
4 |
|
5 |
+
from modules.config.constants import *
|
|
|
|
|
|
|
6 |
import os
|
7 |
|
8 |
|
|
|
11 |
self.config = config
|
12 |
|
13 |
def load_embedding_model(self):
|
14 |
+
if self.config["vectorstore"]["model"] in ["text-embedding-ada-002"]:
|
15 |
embedding_model = OpenAIEmbeddings(
|
16 |
deployment="SL-document_embedder",
|
17 |
+
model=self.config["vectorestore"]["model"],
|
18 |
show_progress_bar=True,
|
19 |
openai_api_key=OPENAI_API_KEY,
|
20 |
disallowed_special=(),
|
21 |
)
|
22 |
else:
|
23 |
embedding_model = HuggingFaceEmbeddings(
|
24 |
+
model_name=self.config["vectorstore"]["model"],
|
25 |
model_kwargs={
|
26 |
+
"device": f"{self.config['device']}",
|
27 |
"token": f"{HUGGINGFACE_TOKEN}",
|
28 |
"trust_remote_code": True,
|
29 |
},
|
code/modules/vectorstore/faiss.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.vectorstores import FAISS
|
2 |
+
from modules.vectorstore.base import VectorStoreBase
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class FaissVectorStore(VectorStoreBase):
|
7 |
+
def __init__(self, config):
|
8 |
+
self.config = config
|
9 |
+
self._init_vector_db()
|
10 |
+
|
11 |
+
def _init_vector_db(self):
|
12 |
+
self.faiss = FAISS(
|
13 |
+
embedding_function=None, index=0, index_to_docstore_id={}, docstore={}
|
14 |
+
)
|
15 |
+
|
16 |
+
def create_database(self, document_chunks, embedding_model):
|
17 |
+
self.vectorstore = self.faiss.from_documents(
|
18 |
+
documents=document_chunks, embedding=embedding_model
|
19 |
+
)
|
20 |
+
self.vectorstore.save_local(
|
21 |
+
os.path.join(
|
22 |
+
self.config["vectorstore"]["db_path"],
|
23 |
+
"db_"
|
24 |
+
+ self.config["vectorstore"]["db_option"]
|
25 |
+
+ "_"
|
26 |
+
+ self.config["vectorstore"]["model"],
|
27 |
+
)
|
28 |
+
)
|
29 |
+
|
30 |
+
def load_database(self, embedding_model):
|
31 |
+
self.vectorstore = self.faiss.load_local(
|
32 |
+
os.path.join(
|
33 |
+
self.config["vectorstore"]["db_path"],
|
34 |
+
"db_"
|
35 |
+
+ self.config["vectorstore"]["db_option"]
|
36 |
+
+ "_"
|
37 |
+
+ self.config["vectorstore"]["model"],
|
38 |
+
),
|
39 |
+
embedding_model,
|
40 |
+
allow_dangerous_deserialization=True,
|
41 |
+
)
|
42 |
+
return self.vectorstore
|
43 |
+
|
44 |
+
def as_retriever(self):
|
45 |
+
return self.vectorstore.as_retriever()
|
code/modules/vectorstore/helpers.py
ADDED
File without changes
|
code/modules/{vector_db.py β vectorstore/store_manager.py}
RENAMED
@@ -1,72 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
-
import
|
4 |
-
|
5 |
-
from langchain.schema.vectorstore import VectorStoreRetriever
|
6 |
-
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
7 |
-
from langchain.schema.document import Document
|
8 |
-
from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
|
9 |
-
from ragatouille import RAGPretrainedModel
|
10 |
-
|
11 |
-
try:
|
12 |
-
from modules.embedding_model_loader import EmbeddingModelLoader
|
13 |
-
from modules.data_loader import DataLoader
|
14 |
-
from modules.constants import *
|
15 |
-
from modules.helpers import *
|
16 |
-
except:
|
17 |
-
from embedding_model_loader import EmbeddingModelLoader
|
18 |
-
from data_loader import DataLoader
|
19 |
-
from constants import *
|
20 |
-
from helpers import *
|
21 |
-
|
22 |
-
from typing import List
|
23 |
-
|
24 |
-
|
25 |
-
class VectorDBScore(VectorStoreRetriever):
|
26 |
-
|
27 |
-
# See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
|
28 |
-
def _get_relevant_documents(
|
29 |
-
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
30 |
-
) -> List[Document]:
|
31 |
-
docs_and_similarities = (
|
32 |
-
self.vectorstore.similarity_search_with_relevance_scores(
|
33 |
-
query, **self.search_kwargs
|
34 |
-
)
|
35 |
-
)
|
36 |
-
# Make the score part of the document metadata
|
37 |
-
for doc, similarity in docs_and_similarities:
|
38 |
-
doc.metadata["score"] = similarity
|
39 |
-
|
40 |
-
docs = [doc for doc, _ in docs_and_similarities]
|
41 |
-
return docs
|
42 |
-
|
43 |
-
async def _aget_relevant_documents(
|
44 |
-
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
45 |
-
) -> List[Document]:
|
46 |
-
docs_and_similarities = (
|
47 |
-
self.vectorstore.similarity_search_with_relevance_scores(
|
48 |
-
query, **self.search_kwargs
|
49 |
-
)
|
50 |
-
)
|
51 |
-
# Make the score part of the document metadata
|
52 |
-
for doc, similarity in docs_and_similarities:
|
53 |
-
doc.metadata["score"] = similarity
|
54 |
-
|
55 |
-
docs = [doc for doc, _ in docs_and_similarities]
|
56 |
-
return docs
|
57 |
|
58 |
|
59 |
-
class
|
60 |
def __init__(self, config, logger=None):
|
61 |
self.config = config
|
62 |
-
self.db_option = config["
|
63 |
self.document_names = None
|
64 |
-
self.webpage_crawler = WebpageCrawler()
|
65 |
|
66 |
# Set up logging to both console and a file
|
67 |
if logger is None:
|
68 |
self.logger = logging.getLogger(__name__)
|
69 |
self.logger.setLevel(logging.INFO)
|
|
|
70 |
|
71 |
# Console Handler
|
72 |
console_handler = logging.StreamHandler()
|
@@ -75,8 +30,13 @@ class VectorDB:
|
|
75 |
console_handler.setFormatter(formatter)
|
76 |
self.logger.addHandler(console_handler)
|
77 |
|
|
|
|
|
|
|
|
|
|
|
78 |
# File Handler
|
79 |
-
log_file_path = "vector_db.log" # Change this to your desired log file path
|
80 |
file_handler = logging.FileHandler(log_file_path, mode="w")
|
81 |
file_handler.setLevel(logging.INFO)
|
82 |
file_handler.setFormatter(formatter)
|
@@ -84,16 +44,18 @@ class VectorDB:
|
|
84 |
else:
|
85 |
self.logger = logger
|
86 |
|
|
|
|
|
87 |
self.logger.info("VectorDB instance instantiated")
|
88 |
|
89 |
def load_files(self):
|
90 |
-
files = os.listdir(self.config["
|
91 |
files = [
|
92 |
-
os.path.join(self.config["
|
93 |
for file in files
|
94 |
]
|
95 |
-
urls = get_urls_from_file(self.config["
|
96 |
-
if self.config["
|
97 |
all_urls = []
|
98 |
for url in urls:
|
99 |
loop = asyncio.get_event_loop()
|
@@ -109,8 +71,9 @@ class VectorDB:
|
|
109 |
|
110 |
def create_embedding_model(self):
|
111 |
self.logger.info("Creating embedding function")
|
112 |
-
|
113 |
-
|
|
|
114 |
|
115 |
def initialize_database(
|
116 |
self,
|
@@ -120,107 +83,153 @@ class VectorDB:
|
|
120 |
document_metadata: list,
|
121 |
):
|
122 |
if self.db_option in ["FAISS", "Chroma"]:
|
123 |
-
self.create_embedding_model()
|
124 |
-
|
125 |
self.logger.info("Initializing vector_db")
|
126 |
self.logger.info("\tUsing {} as db_option".format(self.db_option))
|
127 |
if self.db_option == "FAISS":
|
128 |
-
self.vector_db =
|
129 |
-
|
130 |
-
)
|
131 |
elif self.db_option == "Chroma":
|
132 |
-
self.vector_db =
|
133 |
-
|
134 |
-
embedding=self.embedding_model,
|
135 |
-
persist_directory=os.path.join(
|
136 |
-
self.config["embedding_options"]["db_path"],
|
137 |
-
"db_"
|
138 |
-
+ self.config["embedding_options"]["db_option"]
|
139 |
-
+ "_"
|
140 |
-
+ self.config["embedding_options"]["model"],
|
141 |
-
),
|
142 |
-
)
|
143 |
elif self.db_option == "RAGatouille":
|
144 |
self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
|
145 |
-
index_path = self.RAG.index(
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
)
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
def create_database(self):
|
154 |
-
|
|
|
155 |
self.logger.info("Loading data")
|
156 |
files, urls = self.load_files()
|
157 |
files, webpages = self.webpage_crawler.clean_url_list(urls)
|
158 |
-
|
159 |
-
|
|
|
|
|
160 |
document_chunks, document_names, documents, document_metadata = (
|
161 |
data_loader.get_chunks(files, webpages)
|
162 |
)
|
|
|
|
|
|
|
|
|
163 |
self.logger.info("Completed loading data")
|
164 |
self.initialize_database(
|
165 |
document_chunks, document_names, documents, document_metadata
|
166 |
)
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
-
def save_database(self):
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
def load_database(self):
|
188 |
-
|
|
|
|
|
189 |
if self.db_option == "FAISS":
|
190 |
-
self.vector_db =
|
191 |
-
|
192 |
-
self.config["embedding_options"]["db_path"],
|
193 |
-
"db_"
|
194 |
-
+ self.config["embedding_options"]["db_option"]
|
195 |
-
+ "_"
|
196 |
-
+ self.config["embedding_options"]["model"],
|
197 |
-
),
|
198 |
-
self.embedding_model,
|
199 |
-
allow_dangerous_deserialization=True,
|
200 |
-
)
|
201 |
elif self.db_option == "Chroma":
|
202 |
-
self.vector_db =
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
self.logger.info("Loaded database")
|
217 |
-
return self.
|
218 |
|
219 |
|
220 |
if __name__ == "__main__":
|
221 |
-
|
|
|
|
|
222 |
config = yaml.safe_load(f)
|
223 |
print(config)
|
224 |
-
|
|
|
225 |
vector_db.create_database()
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.vectorstore.faiss import FaissVectorStore
|
2 |
+
from modules.vectorstore.chroma import ChromaVectorStore
|
3 |
+
from modules.vectorstore.helpers import *
|
4 |
+
from modules.dataloader.webpage_crawler import WebpageCrawler
|
5 |
+
from modules.dataloader.data_loader import DataLoader
|
6 |
+
from modules.dataloader.helpers import *
|
7 |
+
from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
|
8 |
import logging
|
9 |
import os
|
10 |
+
import time
|
11 |
+
import asyncio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
+
class VectorStoreManager:
|
15 |
def __init__(self, config, logger=None):
|
16 |
self.config = config
|
17 |
+
self.db_option = config["vectorstore"]["db_option"]
|
18 |
self.document_names = None
|
|
|
19 |
|
20 |
# Set up logging to both console and a file
|
21 |
if logger is None:
|
22 |
self.logger = logging.getLogger(__name__)
|
23 |
self.logger.setLevel(logging.INFO)
|
24 |
+
self.logger.propagate = False
|
25 |
|
26 |
# Console Handler
|
27 |
console_handler = logging.StreamHandler()
|
|
|
30 |
console_handler.setFormatter(formatter)
|
31 |
self.logger.addHandler(console_handler)
|
32 |
|
33 |
+
# Ensure log directory exists
|
34 |
+
log_directory = self.config["log_dir"]
|
35 |
+
if not os.path.exists(log_directory):
|
36 |
+
os.makedirs(log_directory)
|
37 |
+
|
38 |
# File Handler
|
39 |
+
log_file_path = f"{log_directory}/vector_db.log" # Change this to your desired log file path
|
40 |
file_handler = logging.FileHandler(log_file_path, mode="w")
|
41 |
file_handler.setLevel(logging.INFO)
|
42 |
file_handler.setFormatter(formatter)
|
|
|
44 |
else:
|
45 |
self.logger = logger
|
46 |
|
47 |
+
self.webpage_crawler = WebpageCrawler()
|
48 |
+
|
49 |
self.logger.info("VectorDB instance instantiated")
|
50 |
|
51 |
def load_files(self):
|
52 |
+
files = os.listdir(self.config["vectorstore"]["data_path"])
|
53 |
files = [
|
54 |
+
os.path.join(self.config["vectorstore"]["data_path"], file)
|
55 |
for file in files
|
56 |
]
|
57 |
+
urls = get_urls_from_file(self.config["vectorstore"]["url_file_path"])
|
58 |
+
if self.config["vectorstore"]["expand_urls"]:
|
59 |
all_urls = []
|
60 |
for url in urls:
|
61 |
loop = asyncio.get_event_loop()
|
|
|
71 |
|
72 |
def create_embedding_model(self):
|
73 |
self.logger.info("Creating embedding function")
|
74 |
+
embedding_model_loader = EmbeddingModelLoader(self.config)
|
75 |
+
embedding_model = embedding_model_loader.load_embedding_model()
|
76 |
+
return embedding_model
|
77 |
|
78 |
def initialize_database(
|
79 |
self,
|
|
|
83 |
document_metadata: list,
|
84 |
):
|
85 |
if self.db_option in ["FAISS", "Chroma"]:
|
86 |
+
self.embedding_model = self.create_embedding_model()
|
87 |
+
|
88 |
self.logger.info("Initializing vector_db")
|
89 |
self.logger.info("\tUsing {} as db_option".format(self.db_option))
|
90 |
if self.db_option == "FAISS":
|
91 |
+
self.vector_db = FaissVectorStore(self.config)
|
92 |
+
self.vector_db.create_database(document_chunks, self.embedding_model)
|
|
|
93 |
elif self.db_option == "Chroma":
|
94 |
+
self.vector_db = ChromaVectorStore(self.config)
|
95 |
+
self.vector_db.create_database(document_chunks, self.embedding_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
elif self.db_option == "RAGatouille":
|
97 |
self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
|
98 |
+
# index_path = self.RAG.index(
|
99 |
+
# index_name="new_idx",
|
100 |
+
# collection=documents,
|
101 |
+
# document_ids=document_names,
|
102 |
+
# document_metadatas=document_metadata,
|
103 |
+
# )
|
104 |
+
batch_size = 32
|
105 |
+
for i in range(0, len(documents), batch_size):
|
106 |
+
if i == 0:
|
107 |
+
self.RAG.index(
|
108 |
+
index_name="new_idx",
|
109 |
+
collection=documents[i : i + batch_size],
|
110 |
+
document_ids=document_names[i : i + batch_size],
|
111 |
+
document_metadatas=document_metadata[i : i + batch_size],
|
112 |
+
)
|
113 |
+
else:
|
114 |
+
self.RAG = RAGPretrainedModel.from_index(
|
115 |
+
".ragatouille/colbert/indexes/new_idx"
|
116 |
+
)
|
117 |
+
self.RAG.add_to_index(
|
118 |
+
new_collection=documents[i : i + batch_size],
|
119 |
+
new_document_ids=document_names[i : i + batch_size],
|
120 |
+
new_document_metadatas=document_metadata[i : i + batch_size],
|
121 |
+
)
|
122 |
|
123 |
def create_database(self):
|
124 |
+
start_time = time.time() # Start time for creating database
|
125 |
+
data_loader = DataLoader(self.config, self.logger)
|
126 |
self.logger.info("Loading data")
|
127 |
files, urls = self.load_files()
|
128 |
files, webpages = self.webpage_crawler.clean_url_list(urls)
|
129 |
+
self.logger.info(f"Number of files: {len(files)}")
|
130 |
+
self.logger.info(f"Number of webpages: {len(webpages)}")
|
131 |
+
if f"{self.config['vectorstore']['url_file_path']}" in files:
|
132 |
+
files.remove(f"{self.config['vectorstores']['url_file_path']}") # cleanup
|
133 |
document_chunks, document_names, documents, document_metadata = (
|
134 |
data_loader.get_chunks(files, webpages)
|
135 |
)
|
136 |
+
num_documents = len(document_chunks)
|
137 |
+
self.logger.info(f"Number of documents in the DB: {num_documents}")
|
138 |
+
metadata_keys = list(document_metadata[0].keys())
|
139 |
+
self.logger.info(f"Metadata keys: {metadata_keys}")
|
140 |
self.logger.info("Completed loading data")
|
141 |
self.initialize_database(
|
142 |
document_chunks, document_names, documents, document_metadata
|
143 |
)
|
144 |
+
end_time = time.time() # End time for creating database
|
145 |
+
self.logger.info("Created database")
|
146 |
+
self.logger.info(
|
147 |
+
f"Time taken to create database: {end_time - start_time} seconds"
|
148 |
+
)
|
149 |
|
150 |
+
# def save_database(self):
|
151 |
+
# start_time = time.time() # Start time for saving database
|
152 |
+
# if self.db_option == "FAISS":
|
153 |
+
# self.vector_db.save_local(
|
154 |
+
# os.path.join(
|
155 |
+
# self.config["vectorstore"]["db_path"],
|
156 |
+
# "db_"
|
157 |
+
# + self.config["vectorstore"]["db_option"]
|
158 |
+
# + "_"
|
159 |
+
# + self.config["vectorstore"]["model"],
|
160 |
+
# )
|
161 |
+
# )
|
162 |
+
# elif self.db_option == "Chroma":
|
163 |
+
# # db is saved in the persist directory during initialization
|
164 |
+
# pass
|
165 |
+
# elif self.db_option == "RAGatouille":
|
166 |
+
# # index is saved during initialization
|
167 |
+
# pass
|
168 |
+
# self.logger.info("Saved database")
|
169 |
+
# end_time = time.time() # End time for saving database
|
170 |
+
# self.logger.info(
|
171 |
+
# f"Time taken to save database: {end_time - start_time} seconds"
|
172 |
+
# )
|
173 |
|
174 |
def load_database(self):
|
175 |
+
start_time = time.time() # Start time for loading database
|
176 |
+
if self.db_option in ["FAISS", "Chroma"]:
|
177 |
+
self.embedding_model = self.create_embedding_model()
|
178 |
if self.db_option == "FAISS":
|
179 |
+
self.vector_db = FaissVectorStore(self.config)
|
180 |
+
self.loaded_vector_db = self.vector_db.load_database(self.embedding_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
elif self.db_option == "Chroma":
|
182 |
+
self.vector_db = ChromaVectorStore(self.config)
|
183 |
+
self.loaded_vector_db = self.vector_db.load_database(self.embedding_model)
|
184 |
+
# if self.db_option == "FAISS":
|
185 |
+
# self.vector_db = FAISS.load_local(
|
186 |
+
# os.path.join(
|
187 |
+
# self.config["vectorstore"]["db_path"],
|
188 |
+
# "db_"
|
189 |
+
# + self.config["vectorstore"]["db_option"]
|
190 |
+
# + "_"
|
191 |
+
# + self.config["vectorstore"]["model"],
|
192 |
+
# ),
|
193 |
+
# self.embedding_model,
|
194 |
+
# allow_dangerous_deserialization=True,
|
195 |
+
# )
|
196 |
+
# elif self.db_option == "Chroma":
|
197 |
+
# self.vector_db = Chroma(
|
198 |
+
# persist_directory=os.path.join(
|
199 |
+
# self.config["embedding_options"]["db_path"],
|
200 |
+
# "db_"
|
201 |
+
# + self.config["embedding_options"]["db_option"]
|
202 |
+
# + "_"
|
203 |
+
# + self.config["embedding_options"]["model"],
|
204 |
+
# ),
|
205 |
+
# embedding_function=self.embedding_model,
|
206 |
+
# )
|
207 |
+
# elif self.db_option == "RAGatouille":
|
208 |
+
# self.vector_db = RAGPretrainedModel.from_index(
|
209 |
+
# ".ragatouille/colbert/indexes/new_idx"
|
210 |
+
# )
|
211 |
+
end_time = time.time() # End time for loading database
|
212 |
+
self.logger.info(
|
213 |
+
f"Time taken to load database: {end_time - start_time} seconds"
|
214 |
+
)
|
215 |
self.logger.info("Loaded database")
|
216 |
+
return self.loaded_vector_db
|
217 |
|
218 |
|
219 |
if __name__ == "__main__":
|
220 |
+
import yaml
|
221 |
+
|
222 |
+
with open("modules/config/config.yml", "r") as f:
|
223 |
config = yaml.safe_load(f)
|
224 |
print(config)
|
225 |
+
print(f"Trying to create database with config: {config}")
|
226 |
+
vector_db = VectorStoreManager(config)
|
227 |
vector_db.create_database()
|
228 |
+
print("Created database")
|
229 |
+
|
230 |
+
print(f"Trying to load the database")
|
231 |
+
vector_db = VectorStoreManager(config)
|
232 |
+
vector_db.load_database()
|
233 |
+
print("Loaded database")
|
234 |
+
|
235 |
+
print(f"View the logs at {config['log_dir']}/vector_db.log")
|