smhavens committed
Commit 0473071 · 1 Parent(s): 05a2e2d

Begin Fine-tuning

Files changed (1)
  1. app.py +53 -3
app.py CHANGED
@@ -3,9 +3,15 @@ import spacy
import math
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
- from transformers import AutoTokenizer, AutoModel
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
+ from transformers import TrainingArguments, Trainer
import torch
import torch.nn.functional as F
+ import numpy as np
+ import evaluate
+
+
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


#Mean Pooling - Take attention mask into account for correct averaging
@@ -15,13 +21,56 @@ def mean_pooling(model_output, attention_mask):
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


+ def tokenize_function(examples):
+     return tokenizer(examples["text"], padding="max_length", truncation=True)
+
+
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     predictions = np.argmax(logits, axis=-1)
+     metric = evaluate.load("accuracy")
+     return metric.compute(predictions=predictions, references=labels)
+
+
def training():
    dataset = load_dataset("glue", "cola")
    dataset = dataset["train"]
+     tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+     small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
+     small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
+
+
+
+     finetune(small_train_dataset, small_eval_dataset)
+

+ def finetune(train, eval):
+     model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
+
+     training_args = TrainingArguments(output_dir="test_trainer")
+
+     # USE THIS LINK
+     # https://huggingface.co/blog/how-to-train-sentence-transformers
+
+
+     # accuracy = compute_metrics(eval, metric)
+
+     training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train,
+         eval_dataset=eval,
+         compute_metrics=compute_metrics,
+     )
+
+     trainer.train()
+
    sentences = ["This is an example sentence", "Each sentence is converted"]

-     model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+     # model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    embeddings = model.encode(sentences)
    print(embeddings)

@@ -47,7 +96,8 @@ def training():

    print("Sentence embeddings:")
    print(sentence_embeddings)
-
+
+

def greet(name):
    return "Hello " + name + "!!"