Reggie commited on
Commit
b11da2d
·
1 Parent(s): e7351df

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -15
README.md CHANGED
@@ -34,22 +34,13 @@ You'll need to pip install transformers & maybe sentencepiece
34
 
35
  ### How to use
36
  ```python
37
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
38
- import torch, time
39
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
40
  model_name = 'Reggie/muppet-roberta-base-joke_detector'
41
  max_seq_len = 510
42
 
43
- tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_seq_len)
44
- model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
45
-
46
- premise = """A nervous passenger is about to book a flight ticket, and he asks the airlines' ticket seller, "I hope your planes are safe. Do they have a good track record for safety?" The airline agent replies, "Sir, I can guarantee you, we've never had a plane that has crashed more than once." """
47
- hypothesis = ""
48
-
49
- input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
50
- output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
51
- prediction = torch.softmax(output["logits"][0], -1).tolist()
52
- is_joke = True if prediction[0] < prediction[1] else False
53
-
54
- print(is_joke)
55
  ```
 
34
 
35
  ### How to use
36
  ```python
37
+ from transformers import pipeline
38
+ device = 0 if torch.cuda.is_available() else -1
 
39
  model_name = 'Reggie/muppet-roberta-base-joke_detector'
40
  max_seq_len = 510
41
 
42
+ pipe = pipeline(model=model_name, device=device, truncation=True, max_length=max_seq_len)
43
+ is_it_a_joke = """A nervous passenger is about to book a flight ticket, and he asks the airlines' ticket seller, "I hope your planes are safe. Do they have a good track record for safety?" The airline agent replies, "Sir, I can guarantee you, we've never had a plane that has crashed more than once." """
44
+ result = pipe(is_it_a_joke) # [{'label': 'LABEL_1', 'score': 0.7313136458396912}]
45
+ print('This is a joke') if result[0]['label'] == 'LABEL_1' else print('This is not a joke')
 
 
 
 
 
 
 
 
46
  ```