Soumic commited on
Commit
ef3f3d9
·
1 Parent(s): 4fbeb5c

:rocket: Code is stable, save progress

Browse files
Files changed (6) hide show
  1. .env_sample +2 -0
  2. .gitignore +173 -0
  3. Dockerfile +36 -0
  4. app.py +282 -0
  5. app_v1.py +516 -0
  6. requirements.txt +33 -0
.env_sample ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ HF_TOKEN=hf_YOUR_AWESOME_TOKEN
2
+ WAND_DB_API_KEY=YOUR_WAND_DB_TOKEN
.gitignore ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # c++ generated files
163
+ *.out
164
+ *.exe
165
+
166
+ # my custom gitignores
167
+ lightning_logs/
168
+ *.pth
169
+ my-awesome-model/
170
+ my-awesome-model-200/
171
+ my-awesome-model-4000/
172
+ output_hyena_dna-mqtl_classification/
173
+ wandb/
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use the official PyTorch Docker image as a base (includes CUDA and PyTorch)
2
+ FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
3
+
4
+ # Install required dependencies (add any additional system dependencies you need)
5
+ RUN apt update && apt install -y ffmpeg
6
+
7
+ # Create a non-root user with a home directory
8
+ RUN useradd -m -u 1000 user
9
+
10
+ # Switch to the new non-root user
11
+ USER user
12
+
13
+ # Set environment variables for the new user
14
+ ENV HOME=/home/user \
15
+ PATH=/home/user/.local/bin:$PATH
16
+
17
+ # Set a working directory
18
+ WORKDIR $HOME/app
19
+
20
+ # Set the TRANSFORMERS_CACHE directory to be within the user's home directory
21
+ ENV TRANSFORMERS_CACHE=$HOME/cache
22
+
23
+ # Copy the app code and set ownership to the non-root user
24
+ COPY --chown=user . $HOME/app
25
+
26
+ # Install Python dependencies in the virtual environment
27
+ RUN python -m venv /home/user/venv
28
+ ENV PATH="/home/user/venv/bin:$PATH"
29
+
30
+ # Install pip dependencies within the virtual environment
31
+ COPY requirements.txt .
32
+ RUN pip install --upgrade pip
33
+ RUN pip install -r requirements.txt
34
+
35
+ # Run the training script
36
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import huggingface_hub
5
+ from datasets import load_dataset, Dataset
6
+ from dotenv import load_dotenv
7
+ from pytorch_lightning import LightningDataModule
8
+ from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
9
+ from torch.utils.data import DataLoader, IterableDataset
10
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
+ from transformers import TrainingArguments, Trainer
12
+ import torch
13
+ import logging
14
+ import wandb
15
+
16
+ timber = logging.getLogger()
17
+ # logging.basicConfig(level=logging.DEBUG)
18
+ logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs...
19
+
20
+ black = "\u001b[30m"
21
+ red = "\u001b[31m"
22
+ green = "\u001b[32m"
23
+ yellow = "\u001b[33m"
24
+ blue = "\u001b[34m"
25
+ magenta = "\u001b[35m"
26
+ cyan = "\u001b[36m"
27
+ white = "\u001b[37m"
28
+
29
+ FORWARD = "FORWARD_INPUT"
30
+ BACKWARD = "BACKWARD_INPUT"
31
+
32
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+ PRETRAINED_MODEL_NAME: str = "LongSafari/hyenadna-small-32k-seqlen-hf"
35
+
36
+
37
+ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
38
+ start = 0
39
+ end = len(seq)
40
+ rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
41
+ random_end = rand_pos + len(DEBUG_MOTIF)
42
+ output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
43
+ assert len(seq) == len(output)
44
+ return output
45
+
46
+
47
+ class PagingMQTLDataset(IterableDataset):
48
+ def __init__(self,
49
+ m_dataset,
50
+ seq_len,
51
+ tokenizer,
52
+ max_length=512,
53
+ check_if_pipeline_is_ok_by_inserting_debug_motif=False):
54
+ self.dataset = m_dataset
55
+ self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
56
+ self.debug_motif = "ATCGCCTA"
57
+ self.seq_len = seq_len
58
+
59
+ self.bert_tokenizer = tokenizer
60
+ self.max_length = max_length
61
+ pass
62
+
63
+ def __iter__(self):
64
+ for row in self.dataset:
65
+ processed = self.preprocess(row)
66
+ if processed is not None:
67
+ yield processed
68
+
69
+ def preprocess(self, row):
70
+ sequence = row['sequence'] # Fetch the 'sequence' column
71
+ if len(sequence) != self.seq_len:
72
+ return None # skip problematic row!
73
+ label = row['label'] # Fetch the 'label' column (or whatever target you use)
74
+ if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
75
+ sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
76
+
77
+ input_ids = self.bert_tokenizer(sequence)["input_ids"]
78
+ tokenized_tensor = torch.tensor(input_ids)
79
+ label_tensor = torch.tensor(label)
80
+ output_dict = {"input_ids": tokenized_tensor, "labels": label_tensor} # so this is now you do it?
81
+ return output_dict # tokenized_tensor, label_tensor
82
+
83
+
84
+ class MqtlDataModule(LightningDataModule):
85
+ def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
86
+ super().__init__()
87
+ self.batch_size = batch_size
88
+ self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
89
+ # collate_fn=collate_fn,
90
+ num_workers=1,
91
+ # persistent_workers=True
92
+ )
93
+ self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
94
+ # collate_fn=collate_fn,
95
+ num_workers=1,
96
+ # persistent_workers=True
97
+ )
98
+ self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
99
+ # collate_fn=collate_fn,
100
+ num_workers=1,
101
+ # persistent_workers=True
102
+ )
103
+ pass
104
+
105
+ def prepare_data(self):
106
+ pass
107
+
108
+ def setup(self, stage: str) -> None:
109
+ timber.info(f"inside setup: {stage = }")
110
+ pass
111
+
112
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
113
+ return self.train_loader
114
+
115
+ def val_dataloader(self) -> EVAL_DATALOADERS:
116
+ return self.validate_loader
117
+
118
+ def test_dataloader(self) -> EVAL_DATALOADERS:
119
+ return self.test_loader
120
+
121
+
122
+ def create_paging_train_val_test_datasets(tokenizer, WINDOW, is_debug, batch_size=1000):
123
+ data_files = {
124
+ # small samples
125
+ "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
126
+ "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
127
+ "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
128
+ # medium samples
129
+ "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
130
+ "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
131
+ "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
132
+
133
+ # large samples
134
+ "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
135
+ "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
136
+ "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
137
+ }
138
+
139
+ dataset_map = None
140
+ is_my_laptop = os.path.isfile("/src/inputdata/dataset_4000_test_binned.csv")
141
+ if is_my_laptop:
142
+ dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
143
+ else:
144
+ dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
145
+
146
+ train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
147
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
148
+ tokenizer=tokenizer,
149
+ seq_len=WINDOW
150
+ )
151
+ val_dataset = PagingMQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
152
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
153
+ tokenizer=tokenizer,
154
+ seq_len=WINDOW)
155
+ test_dataset = PagingMQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
156
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
157
+ tokenizer=tokenizer,
158
+ seq_len=WINDOW)
159
+ # data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset, batch_size=batch_size)
160
+ return train_dataset, val_dataset, test_dataset
161
+
162
+
163
+ def login_inside_huggingface_virtualmachine():
164
+ # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
165
+ try:
166
+ load_dotenv() # Only useful on your laptop if .env exists
167
+ print(".env file loaded successfully.")
168
+ except Exception as e:
169
+ print(f"Warning: Could not load .env file. Exception: {e}")
170
+
171
+ # Try to get the token from environment variables
172
+ try:
173
+ token = os.getenv("HF_TOKEN")
174
+
175
+ if not token:
176
+ raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
177
+
178
+ # Log in to Hugging Face Hub
179
+ huggingface_hub.login(token)
180
+ print("Logged in to Hugging Face Hub successfully.")
181
+
182
+ except Exception as e:
183
+ print(f"Error during Hugging Face login: {e}")
184
+ # Handle the error appropriately (e.g., exit or retry)
185
+
186
+ # wand db login
187
+ try:
188
+ api_key = os.getenv("WAND_DB_API_KEY")
189
+ timber.info(f"{api_key = }")
190
+
191
+ if not api_key:
192
+ raise ValueError("WAND_DB_API_KEY not found. Make sure to set it in the environment variables or .env file.")
193
+
194
+ # Log in to Hugging Face Hub
195
+ wandb.login(key=api_key)
196
+ print("Logged in to wand db successfully.")
197
+
198
+ except Exception as e:
199
+ print(f"Error during wand db Face login: {e}")
200
+ pass
201
+
202
+
203
+ def start():
204
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
205
+
206
+ login_inside_huggingface_virtualmachine()
207
+ WINDOW = 1000
208
+ batch_size = 100
209
+ tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
210
+ model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME, torch_dtype=torch.bfloat16,
211
+ device_map="auto",
212
+ trust_remote_code=True)
213
+ args = {
214
+ "output_dir": "output_hyena_dna-mqtl_classification",
215
+ "num_train_epochs": 2,
216
+ "max_steps": 20,
217
+ # Set the number of steps you expect to train, originally 1000, takes too much time. So I set it to 10 to run faster and check my code/pipeline
218
+ "run_name": "laptop_run_hyena_dna-mqtl_classification", # Override run_name here
219
+ "per_device_train_batch_size": 1,
220
+ "gradient_accumulation_steps": 4,
221
+ "gradient_checkpointing": True,
222
+ "learning_rate": 2e-5,
223
+ "save_safetensors": False # I added it. this solves the runtime error!
224
+ }
225
+
226
+ # """
227
+ # got this error at the end!
228
+ # raise RuntimeError(
229
+ # RuntimeError: The weights trying to be saved contained shared tensors [{'hyena.backbone.layers.0.mixer.filter_fn.implicit_filter.3.freq', 'hyena.backbone.layers.0.mixer.filter_fn.implicit_filter.1.freq', 'hyena.backbone.layers.0.mixer.filter_fn.implicit_filter.5.freq'}] that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.
230
+ # """
231
+
232
+ training_args = TrainingArguments(**args)
233
+ # train_dataset, eval_dataset, test_dataset = create_data_module(tokenizer=tokenizer, WINDOW=WINDOW,
234
+ # batch_size=batch_size,
235
+ # is_debug=False)
236
+ max_length = 32_000
237
+ sequence = 'ACTG' * int(max_length / 4)
238
+ # sequence = 'ACTG' * int(1000) # seq_len = 4000 it works!
239
+ sequence = [sequence] * 8 # Create 8 identical samples
240
+ tokenized = tokenizer(sequence)["input_ids"]
241
+ labels = [0, 1] * 4
242
+
243
+ # Create a dataset for training
244
+ run_the_code_ds = Dataset.from_dict({"input_ids": tokenized, "labels": labels})
245
+ run_the_code_ds.set_format("pt")
246
+
247
+ # train_ds, val_ds, test_ds = create_paging_train_val_test_datasets(tokenizer, WINDOW=WINDOW, is_debug=False)
248
+ train_ds, val_ds, test_ds = run_the_code_ds, run_the_code_ds, run_the_code_ds
249
+ # train_ds.set_format("pt") # doesn't work!
250
+
251
+ trainer = Trainer(
252
+ model=model,
253
+ args=training_args,
254
+ train_dataset=train_ds,
255
+ eval_dataset=val_ds,
256
+ )
257
+ # train, and validate
258
+ result = trainer.train()
259
+ try:
260
+ print(f"{result = }")
261
+ except Exception as x:
262
+ print(f"{x = }")
263
+
264
+ # testing
265
+ try:
266
+ # with torch.no_grad(): # didn't work :/
267
+ test_results = trainer.evaluate(eval_dataset=test_ds)
268
+ print(f"{test_results = }")
269
+ except Exception as oome:
270
+ print(f"{oome = }")
271
+
272
+
273
+
274
+ if __name__ == '__main__':
275
+ start()
276
+ pass
277
+
278
+ """
279
+ git submodule add https://huggingface.co/spaces/fahimfarhan/hyenadna-sm-32k-mqtl-classifier-space src/huggingface-mqtl-classification-hyena-dna
280
+
281
+ """
282
+
app_v1.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from pytorch_lightning import Trainer, LightningModule, LightningDataModule
9
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
10
+ from torch.nn.utils.rnn import pad_sequence
11
+ from torch.utils.data import DataLoader, Dataset
12
+ from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
13
+ from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments, AutoModelForSequenceClassification
14
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
+ import torch
16
+ from torch import nn
17
+ from datasets import load_dataset, IterableDataset
18
+ from huggingface_hub import PyTorchModelHubMixin
19
+
20
+ from dotenv import load_dotenv
21
+ from huggingface_hub import login
22
+
23
+ timber = logging.getLogger()
24
+ # logging.basicConfig(level=logging.DEBUG)
25
+ logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs...
26
+
27
+ black = "\u001b[30m"
28
+ red = "\u001b[31m"
29
+ green = "\u001b[32m"
30
+ yellow = "\u001b[33m"
31
+ blue = "\u001b[34m"
32
+ magenta = "\u001b[35m"
33
+ cyan = "\u001b[36m"
34
+ white = "\u001b[37m"
35
+
36
+ FORWARD = "FORWARD_INPUT"
37
+ BACKWARD = "BACKWARD_INPUT"
38
+
39
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+ PRETRAINED_MODEL_NAME: str = "LongSafari/hyenadna-small-32k-seqlen-hf"
42
+
43
+
44
+ def login_inside_huggingface_virtualmachine():
45
+ # Load the .env file, but don't crash if it's not found (e.g., in Hugging Face Space)
46
+ try:
47
+ load_dotenv() # Only useful on your laptop if .env exists
48
+ print(".env file loaded successfully.")
49
+ except Exception as e:
50
+ print(f"Warning: Could not load .env file. Exception: {e}")
51
+
52
+ # Try to get the token from environment variables
53
+ try:
54
+ token = os.getenv("HF_TOKEN")
55
+
56
+ if not token:
57
+ raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
58
+
59
+ # Log in to Hugging Face Hub
60
+ login(token)
61
+ print("Logged in to Hugging Face Hub successfully.")
62
+
63
+ except Exception as e:
64
+ print(f"Error during Hugging Face login: {e}")
65
+ # Handle the error appropriately (e.g., exit or retry)
66
+
67
+
68
+ def one_hot_e(dna_seq: str) -> np.ndarray:
69
+ mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
70
+ 'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
71
+ 'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
72
+ 'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
73
+ 'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
74
+ 'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
75
+
76
+ size_of_a_seq: int = len(dna_seq)
77
+
78
+ # forward = np.zeros(shape=(size_of_a_seq, 4))
79
+
80
+ forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
81
+ encoded = np.asarray(forward_list)
82
+ encoded_transposed = encoded.transpose() # todo: Needs review
83
+ return encoded_transposed
84
+
85
+
86
+ def one_hot_e_column(column: pd.Series) -> np.ndarray:
87
+ tmp_list: list = [one_hot_e(seq) for seq in column]
88
+ encoded_column = np.asarray(tmp_list).astype(np.float32)
89
+ return encoded_column
90
+
91
+
92
+ def reverse_dna_seq(dna_seq: str) -> str:
93
+ # m_reversed = ""
94
+ # for i in range(0, len(dna_seq)):
95
+ # m_reversed = dna_seq[i] + m_reversed
96
+ # return m_reversed
97
+ return dna_seq[::-1]
98
+
99
+
100
+ def complement_dna_seq(dna_seq: str) -> str:
101
+ comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
102
+ "a": "t", "c": "g", "t": "a", "g": "c",
103
+ "N": "N", "H": "H", "-": "-",
104
+ "n": "n", "h": "h"
105
+ }
106
+
107
+ comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
108
+ comp_dna_seq: str = "".join(comp_dna_seq_list)
109
+ return comp_dna_seq
110
+
111
+
112
+ def reverse_complement_dna_seq(dna_seq: str) -> str:
113
+ return reverse_dna_seq(complement_dna_seq(dna_seq))
114
+
115
+
116
+ def reverse_complement_column(column: pd.Series) -> np.ndarray:
117
+ rc_column: list = [reverse_complement_dna_seq(seq) for seq in column]
118
+ return rc_column
119
+
120
+
121
+ class TorchMetrics:
122
+ def __init__(self, device=DEVICE):
123
+ self.binary_accuracy = BinaryAccuracy().to(device)
124
+ self.binary_auc = BinaryAUROC().to(device)
125
+ self.binary_f1_score = BinaryF1Score().to(device)
126
+ self.binary_precision = BinaryPrecision().to(device)
127
+ self.binary_recall = BinaryRecall().to(device)
128
+ pass
129
+
130
+ def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed
131
+ self.binary_accuracy.update(preds=batch_predicted_labels, target=batch_actual_labels)
132
+ self.binary_auc.update(preds=batch_predicted_labels, target=batch_actual_labels)
133
+ self.binary_f1_score.update(preds=batch_predicted_labels, target=batch_actual_labels)
134
+ self.binary_precision.update(preds=batch_predicted_labels, target=batch_actual_labels)
135
+ self.binary_recall.update(preds=batch_predicted_labels, target=batch_actual_labels)
136
+ pass
137
+
138
+ def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
139
+ b_accuracy = self.binary_accuracy.compute()
140
+ b_auc = self.binary_auc.compute()
141
+ b_f1_score = self.binary_f1_score.compute()
142
+ b_precision = self.binary_precision.compute()
143
+ b_recall = self.binary_recall.compute()
144
+ timber.info(
145
+ log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
146
+ log(f"{log_prefix}_accuracy", b_accuracy)
147
+ log(f"{log_prefix}_auc", b_auc)
148
+ log(f"{log_prefix}_f1_score", b_f1_score)
149
+ log(f"{log_prefix}_precision", b_precision)
150
+ log(f"{log_prefix}_recall", b_recall)
151
+
152
+ self.binary_accuracy.reset()
153
+ self.binary_auc.reset()
154
+ self.binary_f1_score.reset()
155
+ self.binary_precision.reset()
156
+ self.binary_recall.reset()
157
+ pass
158
+
159
+
160
+ def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
161
+ start = 0
162
+ end = len(seq)
163
+ rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
164
+ random_end = rand_pos + len(DEBUG_MOTIF)
165
+ output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
166
+ assert len(seq) == len(output)
167
+ return output
168
+
169
+
170
+ class PagingMQTLDataset(IterableDataset):
171
+ def __init__(self,
172
+ m_dataset,
173
+ seq_len,
174
+ tokenizer,
175
+ max_length=512,
176
+ check_if_pipeline_is_ok_by_inserting_debug_motif=False):
177
+ self.dataset = m_dataset
178
+ self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
179
+ self.debug_motif = "ATCGCCTA"
180
+ self.seq_len = seq_len
181
+
182
+ self.bert_tokenizer = tokenizer
183
+ self.max_length = max_length
184
+ pass
185
+
186
+ def __iter__(self):
187
+ for row in self.dataset:
188
+ processed = self.preprocess(row)
189
+ if processed is not None:
190
+ yield processed
191
+
192
+ def preprocess(self, row):
193
+ sequence = row['sequence'] # Fetch the 'sequence' column
194
+ if len(sequence) != self.seq_len:
195
+ return None # skip problematic row!
196
+ label = row['label'] # Fetch the 'label' column (or whatever target you use)
197
+ if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
198
+ sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
199
+ ohe_sequence = one_hot_e(dna_seq=sequence)
200
+ one_seq_tensor = torch.from_numpy(ohe_sequence).to(torch.int64)
201
+ # Tokenize the sequence
202
+ encoded_sequence_tokenized: BatchEncoding = self.bert_tokenizer(one_seq_tensor)
203
+ input_ids = encoded_sequence_tokenized["input_ids"]
204
+ # encoded_sequence_squeezed = {key: val.squeeze() for key, val in encoded_sequence.items()}
205
+ return input_ids, label
206
+
207
+
208
+ # def collate_fn(batch):
209
+ # sequences, labels = zip(*batch)
210
+ # ohe_seq, ohe_seq_rc = sequences[0], sequences[1]
211
+ # # Pad sequences to the maximum length in this batch
212
+ # padded_sequences = pad_sequence(ohe_seq, batch_first=True, padding_value=0)
213
+ # padded_sequences_rc = pad_sequence(ohe_seq_rc, batch_first=True, padding_value=0)
214
+ # # Convert labels to a tensor
215
+ # labels = torch.stack(labels)
216
+ # return [padded_sequences, padded_sequences_rc], labels
217
+
218
+
219
+ class MqtlDataModule(LightningDataModule):
220
+ def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
221
+ super().__init__()
222
+ self.batch_size = batch_size
223
+ self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
224
+ # collate_fn=collate_fn,
225
+ num_workers=1,
226
+ # persistent_workers=True
227
+ )
228
+ self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
229
+ # collate_fn=collate_fn,
230
+ num_workers=1,
231
+ # persistent_workers=True
232
+ )
233
+ self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
234
+ # collate_fn=collate_fn,
235
+ num_workers=1,
236
+ # persistent_workers=True
237
+ )
238
+ pass
239
+
240
+ def prepare_data(self):
241
+ pass
242
+
243
+ def setup(self, stage: str) -> None:
244
+ timber.info(f"inside setup: {stage = }")
245
+ pass
246
+
247
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
248
+ return self.train_loader
249
+
250
+ def val_dataloader(self) -> EVAL_DATALOADERS:
251
+ return self.validate_loader
252
+
253
+ def test_dataloader(self) -> EVAL_DATALOADERS:
254
+ return self.test_loader
255
+
256
+
257
+ class MQtlBertClassifierLightningModule(LightningModule):
258
+ def __init__(self,
259
+ classifier: nn.Module,
260
+ criterion=None, # nn.BCEWithLogitsLoss(),
261
+ regularization: int = 2, # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
262
+ l1_lambda=0.001,
263
+ l2_wright_decay=0.001,
264
+ *args: Any,
265
+ **kwargs: Any):
266
+ super().__init__(*args, **kwargs)
267
+ self.classifier = classifier
268
+ self.criterion = criterion
269
+ self.train_metrics = TorchMetrics()
270
+ self.validate_metrics = TorchMetrics()
271
+ self.test_metrics = TorchMetrics()
272
+
273
+ self.regularization = regularization
274
+ self.l1_lambda = l1_lambda
275
+ self.l2_weight_decay = l2_wright_decay
276
+ pass
277
+
278
+ def forward(self, x, *args: Any, **kwargs: Any) -> Any:
279
+ input_ids: torch.tensor = x["input_ids"]
280
+ return self.classifier.forward(input_ids)
281
+
282
+ def configure_optimizers(self) -> OptimizerLRScheduler:
283
+ # Here we add weight decay (L2 regularization) to the optimizer
284
+ weight_decay = 0.0
285
+ if self.regularization == 2 or self.regularization == 3:
286
+ weight_decay = self.l2_weight_decay
287
+ return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay) # , weight_decay=0.005)
288
+
289
+ def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
290
+ # Accuracy on training batch data
291
+ x, y = batch
292
+ preds = self.forward(x)
293
+ loss = self.criterion(preds, y)
294
+
295
+ if self.regularization == 1 or self.regularization == 3: # apply l1 regularization
296
+ l1_norm = sum(p.abs().sum() for p in self.parameters())
297
+ loss += self.l1_lambda * l1_norm
298
+
299
+ self.log("train_loss", loss)
300
+ # calculate the scores start
301
+ self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
302
+ # calculate the scores end
303
+ return loss
304
+
305
+ def on_train_epoch_end(self) -> None:
306
+ self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
307
+ pass
308
+
309
+ def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
310
+ # Accuracy on validation batch data
311
+ # print(f"debug { batch = }")
312
+ x, y = batch
313
+ preds = self.forward(x)
314
+ loss = self.criterion(preds, y)
315
+ self.log("valid_loss", loss)
316
+ # calculate the scores start
317
+ self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
318
+ # calculate the scores end
319
+ return loss
320
+
321
+ def on_validation_epoch_end(self) -> None:
322
+ self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
323
+ return None
324
+
325
+ def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
326
+ # Accuracy on validation batch data
327
+ x, y = batch
328
+ preds = self.forward(x)
329
+ loss = self.criterion(preds, y)
330
+ self.log("test_loss", loss) # do we need this?
331
+ # calculate the scores start
332
+ self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
333
+ # calculate the scores end
334
+ return loss
335
+
336
+ def on_test_epoch_end(self) -> None:
337
+ self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
338
+ return None
339
+
340
+ pass
341
+
342
+
343
+ def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW=200,
344
+ is_binned=True, is_debug=False, max_epochs=10, batch_size=8):
345
+ file_suffix = ""
346
+ if is_binned:
347
+ file_suffix = "_binned"
348
+
349
+ data_files = {
350
+ # small samples
351
+ "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
352
+ "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
353
+ "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
354
+ # medium samples
355
+ "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
356
+ "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
357
+ "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
358
+
359
+ # large samples
360
+ "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
361
+ "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
362
+ "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
363
+ }
364
+
365
+ dataset_map = None
366
+ is_my_laptop = os.path.isfile("/src/inputdata/dataset_4000_test_binned.csv")
367
+ if is_my_laptop:
368
+ dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
369
+ else:
370
+ dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)
371
+
372
+ tokenizer = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=PRETRAINED_MODEL_NAME,
373
+ trust_remote_code=True)
374
+
375
+ train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
376
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
377
+ tokenizer=tokenizer,
378
+ seq_len=WINDOW
379
+ )
380
+ val_dataset = PagingMQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
381
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
382
+ tokenizer=tokenizer,
383
+ seq_len=WINDOW)
384
+ test_dataset = PagingMQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
385
+ check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
386
+ tokenizer=tokenizer,
387
+ seq_len=WINDOW)
388
+
389
+ data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset, batch_size=batch_size)
390
+
391
+ classifier_model = classifier_model #.to(DEVICE)
392
+ try:
393
+ classifier_model = classifier_model.from_pretrained(classifier_model.model_repository_name)
394
+ except Exception as x:
395
+ print(x)
396
+
397
+ # classifier_module = MQtlBertClassifierLightningModule(
398
+ # classifier=classifier_model,
399
+ # regularization=2, criterion=criterion)
400
+
401
+ # if os.path.exists(model_save_path):
402
+ # classifier_module.load_state_dict(torch.load(model_save_path))
403
+ args = {
404
+ "output_dir": "tmp",
405
+ "num_train_epochs": 1,
406
+ "per_device_train_batch_size": 1,
407
+ "gradient_accumulation_steps": 4,
408
+ "gradient_checkpointing": True,
409
+ "learning_rate": 2e-5,
410
+ }
411
+ training_args = TrainingArguments(**args)
412
+
413
+ trainer = Trainer(model=classifier_model, args=training_args, datamodule=data_module, max_epochs=max_epochs,
414
+ precision="32")
415
+ trainer.fit(model=classifier_model)
416
+ timber.info("\n\n")
417
+ trainer.test(model=classifier_model)
418
+ timber.info("\n\n")
419
+ # torch.save(classifier_module.state_dict(), model_save_path) # deprecated, use classifier_model.save_pretrained(model_subdirectory) instead
420
+
421
+ # save locally
422
+ model_subdirectory = classifier_model.model_repository_name
423
+ classifier_model.save_pretrained(model_subdirectory)
424
+
425
+ # push to the hub
426
+ commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
427
+ if is_my_laptop:
428
+ commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
429
+
430
+ classifier_model.push_to_hub(
431
+ repo_id=f"fahimfarhan/{classifier_model.model_repository_name}",
432
+ # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
433
+ commit_message=commit_message # f":tada: Push model for window size {WINDOW}"
434
+ )
435
+
436
+ # reload
437
+ # classifier_model = classifier_model.from_pretrained(f"fahimfarhan/{classifier_model.model_repository_name}")
438
+ # classifier_model = classifier_model.from_pretrained(model_subdirectory)
439
+
440
+ pass
441
+
442
+
443
+ class CommonAttentionLayer(nn.Module):
444
+ def __init__(self, hidden_size, *args, **kwargs):
445
+ super().__init__(*args, **kwargs)
446
+ self.attention_linear = nn.Linear(hidden_size, 1)
447
+ pass
448
+
449
+ def forward(self, hidden_states):
450
+ # Apply linear layer
451
+ attn_weights = self.attention_linear(hidden_states)
452
+ # Apply softmax to get attention scores
453
+ attn_weights = torch.softmax(attn_weights, dim=1)
454
+ # Apply attention weights to hidden states
455
+ context_vector = torch.sum(attn_weights * hidden_states, dim=1)
456
+ return context_vector, attn_weights
457
+
458
+
459
+ class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
460
+ def forward(self, input, target):
461
+ return super().forward(input.squeeze(), target.float())
462
+
463
+
464
+ class HyenaDnaMQTLClassifier(nn.Module):
465
+ def __init__(self,
466
+ seq_len: int, model_repository_name: str,
467
+ bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=PRETRAINED_MODEL_NAME),
468
+ hidden_size=768,
469
+ num_classes=1,
470
+ *args,
471
+ **kwargs
472
+ ):
473
+ super().__init__(*args, **kwargs)
474
+ self.seq_len = seq_len
475
+ self.model_repository_name = model_repository_name
476
+
477
+ self.model_name = "MQtlDnaBERT6Classifier"
478
+
479
+ self.bert_model = bert_model
480
+ self.attention = CommonAttentionLayer(hidden_size)
481
+ self.classifier = nn.Linear(hidden_size, num_classes)
482
+ pass
483
+
484
+ def forward(self, input_ids: torch.tensor):
485
+ """
486
+ # torch.Size([128, 1, 512]) --> [128, 512]
487
+ input_ids = input_ids.squeeze(dim=1).to(DEVICE)
488
+ # torch.Size([16, 1, 512]) --> [16, 512]
489
+ attention_mask = attention_mask.squeeze(dim=1).to(DEVICE)
490
+ token_type_ids = token_type_ids.squeeze(dim=1).to(DEVICE)
491
+ """
492
+ bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(input_ids=input_ids)
493
+
494
+ last_hidden_state = bert_output.last_hidden_state
495
+ context_vector, ignore_attention_weight = self.attention(last_hidden_state)
496
+ y = self.classifier(context_vector)
497
+ return y
498
+
499
+
500
+ if __name__ == '__main__':
501
+ login_inside_huggingface_virtualmachine()
502
+
503
+ WINDOW = 1000
504
+ some_model = BertModel.from_pretrained(
505
+ pretrained_model_name_or_path=PRETRAINED_MODEL_NAME) # HyenaDnaMQTLClassifier(seq_len=WINDOW, model_repository_name="hyenadna-sm-32k-mqtl-classifier")
506
+ criterion = None
507
+
508
+ start_bert(
509
+ classifier_model=some_model,
510
+ criterion=criterion,
511
+ WINDOW=WINDOW,
512
+ is_debug=False,
513
+ max_epochs=20,
514
+ batch_size=16
515
+ )
516
+ pass
requirements.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate # required by HayenaDNA
2
+ datasets
3
+ pandas
4
+ polars
5
+ numpy
6
+ matplotlib
7
+ scipy
8
+ shap
9
+ scikit-learn
10
+ skorch==1.0.0
11
+ six
12
+ hyperopt
13
+ requests
14
+ pyyaml
15
+ Bio
16
+ plotly
17
+ Levenshtein
18
+ # pytorch
19
+ captum
20
+ torch==2.4.0
21
+ torchvision
22
+ torchaudio
23
+ torchsummary
24
+ torcheval
25
+ pydot
26
+ pydotplus
27
+ PySide2 # matplotlib dependency on ubuntu. you may need sth else for other os/env setup
28
+ torchviz
29
+ gReLU # luckily now available in pip!
30
+ # gReLU @ git+https://github.com/Genentech/gReLU # @623fee8023aabcef89f0afeedbeafff4b71453af
31
+ # lightning[extra] # cz I got a stupid warning in the console logs
32
+ torchmetrics
33
+ python-dotenv