ludolara committed
Commit 5cd6bff · 1 Parent(s): 7e60f53

Create LightningViTRegressor.py

Files changed (1)
  1. LightningViTRegressor.py +56 -0
LightningViTRegressor.py ADDED
@@ -0,0 +1,56 @@
+ import lightning.pytorch as pl
+ import torchmetrics
+ from torch.optim import AdamW
+ from transformers import ViTForImageClassification
+ from torch import nn
+ from transformers.optimization import get_scheduler
+
+ class LightningViTRegressor(pl.LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.model = ViTForImageClassification.from_pretrained(
+             "google/vit-base-patch16-224-in21k",
+             num_labels=1,  # single-output head, so the classifier acts as a regressor
+         )
+         self.mse = torchmetrics.MeanSquaredError()
+         self.mae = torchmetrics.MeanAbsoluteError()
+         self.r2_score = torchmetrics.R2Score()
+
+     def common_step(self, step_type, batch, batch_idx):
+         x, y = batch
+         x = self.model(x)
+         x = x.logits  # (batch_size, 1) regression predictions
+         loss = nn.functional.mse_loss(x, y)
+         mean_squared_error = self.mse(x, y)
+         mean_absolute_error = self.mae(x, y)
+         r2_score = self.r2_score(x, y)
+         to_log = {step_type + "_loss": loss,
+                   step_type + "_mse": mean_squared_error,
+                   step_type + "_mae": mean_absolute_error,
+                   step_type + "_r2_score": r2_score}  # add more items if needed
+         self.log_dict(to_log)
+         return loss
+
+     def training_step(self, batch, batch_idx):
+         loss = self.common_step("train", batch, batch_idx)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss = self.common_step("val", batch, batch_idx)
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         loss = self.common_step("test", batch, batch_idx)
+         return loss
+
+     # def configure_optimizers(self):
+     #     optimizer = optim.Adam(self.parameters(), lr=1e-5)
+     #     return optimizer
+
+     def configure_optimizers(self):
+         # optimizer = AdamW(optimizer_grouped_params, lr=self.hparams.lr, betas=(0.9, 0.999), eps=1e-7)
+         optimizer = AdamW(self.parameters(), lr=1e-5)
+         # Configure the learning rate scheduler: linear decay over all training steps.
+         scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=self.trainer.estimated_stepping_batches)
+         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
+         return [optimizer], [scheduler]
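
A minimal training sketch for the module above, assuming Lightning 2.x and transformers are installed. The dataloaders and tensor shapes below are illustrative assumptions, not part of this commit: batches are expected as (pixel_values, target) pairs with pixel_values of shape (3, 224, 224) (e.g. produced by ViTImageProcessor) and float targets of shape (1,), so predictions and targets align for the MSE loss and the metrics.

import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset
from LightningViTRegressor import LightningViTRegressor

# Dummy tensors standing in for preprocessed images and regression targets.
images = torch.randn(8, 3, 224, 224)
targets = torch.randn(8, 1)
train_loader = DataLoader(TensorDataset(images, targets), batch_size=4)
val_loader = DataLoader(TensorDataset(images, targets), batch_size=4)

model = LightningViTRegressor()  # downloads the pretrained ViT weights on first use
trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1, log_every_n_steps=1)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model, dataloaders=val_loader)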