zetavg commited on
Commit
570c043
·
1 Parent(s): 4870204

extract and fix get_device

Browse files
llama_lora/lib/get_device.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_device():
5
+ device ="cpu"
6
+ if torch.cuda.is_available():
7
+ device = "cuda"
8
+
9
+ try:
10
+ if torch.backends.mps.is_available():
11
+ device = "mps"
12
+ except: # noqa: E722
13
+ pass
14
+
15
+ return device
llama_lora/lib/inference.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import transformers
3
 
 
4
  from .streaming_generation_utils import Iteratorize, Stream
5
 
6
 
@@ -62,16 +63,3 @@ def generate(
62
  decoded_output = tokenizer.decode(output, skip_special_tokens=True)
63
  yield decoded_output, output
64
  return
65
-
66
-
67
- def get_device():
68
- if torch.cuda.is_available():
69
- return "cuda"
70
- else:
71
- return "cpu"
72
-
73
- try:
74
- if torch.backends.mps.is_available():
75
- return "mps"
76
- except: # noqa: E722
77
- pass
 
1
  import torch
2
  import transformers
3
 
4
+ from .get_device import get_device
5
  from .streaming_generation_utils import Iteratorize, Stream
6
 
7
 
 
63
  decoded_output = tokenizer.decode(output, skip_special_tokens=True)
64
  yield decoded_output, output
65
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
llama_lora/models.py CHANGED
@@ -8,19 +8,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
8
  from peft import PeftModel
9
 
10
  from .globals import Global
11
-
12
-
13
- def get_device():
14
- if torch.cuda.is_available():
15
- return "cuda"
16
- else:
17
- return "cpu"
18
-
19
- try:
20
- if torch.backends.mps.is_available():
21
- return "mps"
22
- except: # noqa: E722
23
- pass
24
 
25
 
26
  def get_new_base_model(base_model_name):
 
8
  from peft import PeftModel
9
 
10
  from .globals import Global
11
+ from .lib.get_device import get_device
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  def get_new_base_model(base_model_name):