Update README.md
Browse files
README.md
CHANGED
@@ -24,9 +24,13 @@ The model is used to generate Japanese lyrics.
|
|
24 |
import torch
|
25 |
from transformers import T5Tokenizer, GPT2LMHeadModel
|
26 |
|
|
|
|
|
|
|
|
|
27 |
tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
|
28 |
model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
|
29 |
-
|
30 |
|
31 |
def gen_lyric(title: str, prompt_text: str):
|
32 |
if len(title)!= 0 or len(prompt_text)!= 0:
|
|
|
24 |
import torch
|
25 |
from transformers import T5Tokenizer, GPT2LMHeadModel
|
26 |
|
27 |
+
device = torch.device("cpu")
|
28 |
+
if torch.cuda.is_available():
|
29 |
+
device = torch.device("cuda")
|
30 |
+
|
31 |
tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
|
32 |
model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
|
33 |
+
model = model.to(device)
|
34 |
|
35 |
def gen_lyric(title: str, prompt_text: str):
|
36 |
if len(title)!= 0 or len(prompt_text)!= 0:
|