winglian committed
Commit 262dc29 · unverified · 2 Parent(s): 28fd429 a032c9f

Merge pull request #300 from OpenAccess-AI-Collective/pytorch-201

.github/workflows/base.yml CHANGED
@@ -18,12 +18,12 @@ jobs:
       - cuda: "118"
         cuda_version: 11.8.0
         python_version: "3.9"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras:
       - cuda: "118"
         cuda_version: 11.8.0
         python_version: "3.10"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras:
       - cuda: "117"
         cuda_version: 11.7.1
@@ -33,7 +33,7 @@ jobs:
       - cuda: "118"
         cuda_version: 11.8.0
         python_version: "3.9"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras: gptq
     steps:
     - name: Checkout
.github/workflows/main.yml CHANGED
@@ -17,17 +17,17 @@ jobs:
       - cuda: cu118
         cuda_version: 11.8.0
         python_version: "3.9"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras:
       - cuda: cu118
         cuda_version: 11.8.0
         python_version: "3.10"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras:
       - cuda: cu118
         cuda_version: 11.8.0
         python_version: "3.9"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras: gptq
       - cuda: cu117
         cuda_version: 11.7.1
@@ -72,17 +72,17 @@ jobs:
       - cuda: cu118
         cuda_version: 11.8.0
         python_version: "3.9"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras:
       - cuda: cu118
         cuda_version: 11.8.0
         python_version: "3.10"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras:
       - cuda: cu118
         cuda_version: 11.8.0
         python_version: "3.9"
-        pytorch: 2.0.0
+        pytorch: 2.0.1
         axolotl_extras: gptq
       - cuda: cu117
         cuda_version: 11.7.1
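
Not part of this diff, but a quick way to sanity-check an image built from one of the cu118 matrix entries above is to confirm the PyTorch and CUDA versions inside the container. A minimal sketch; the expected version strings are assumptions mirrored from the matrix, adjust per build variant:

# hypothetical post-build check, not included in this PR
import torch

# "2.0.1" and "11.8" mirror the cu118 matrix entries bumped above
assert torch.__version__.startswith("2.0.1"), torch.__version__
assert torch.version.cuda and torch.version.cuda.startswith("11.8"), torch.version.cuda
print(f"torch {torch.__version__}, CUDA {torch.version.cuda}")
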
docker/Dockerfile-base CHANGED
@@ -38,8 +38,9 @@ WORKDIR /workspace
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
-RUN git clone https://github.com/HazyResearch/flash-attention.git && \
+RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
+    git checkout v1.0.9 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
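
Since the clone is now pinned to the v1.0.9 tag rather than the moving default branch, a runtime check inside the built image can confirm the expected flash-attention wheel was installed. A small sketch, assuming the installed distribution is named "flash-attn" (that name is an assumption, not stated in the Dockerfile):

# hypothetical check, not part of this PR
from importlib.metadata import version

import flash_attn  # wheel built from the repo cloned above

# "flash-attn" as the distribution name is an assumption; the pinned tag is v1.0.9
assert version("flash-attn").startswith("1.0.9"), version("flash-attn")
print("flash-attn", version("flash-attn"))
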
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -184,14 +184,15 @@ def sdp_attention_forward(
 
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
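
For context on the wrapper added above: in PyTorch 2.0.x, torch.backends.cuda.sdp_kernel() is a context manager that controls which fused backends scaled_dot_product_attention may dispatch to; called with no arguments, all backends remain eligible. A standalone sketch, separate from the patched module, with illustrative tensor shapes (a CUDA device is assumed):

# illustration only, not part of the patched module
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) tensors, shapes chosen only for the example
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# default form, as in the change above: flash, memory-efficient, and math backends all stay enabled
with torch.backends.cuda.sdp_kernel():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# the keyword form can restrict dispatch, e.g. to the math fallback only
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    out_math = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
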