Spaces:
Runtime error
Runtime error
Stefan Heimersheim
commited on
Commit
·
ba97e60
1
Parent(s):
eef2607
Fix x-axis bug
Browse files
app.py
CHANGED
@@ -114,7 +114,8 @@ def imshow(tensor, xlabel="X", ylabel="Y", zlabel=None, xticks=None, yticks=None
|
|
114 |
|
115 |
def plot_residual_stream_patch(clean_prompt=None, answer=None, corrupt_prompt=None, corrupt_answer=None):
|
116 |
layers = ["blocks.0.hook_resid_pre", *[f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]]
|
117 |
-
|
|
|
118 |
patching_effect = compute_residual_stream_patch(clean_prompt=clean_prompt, answer=answer, corrupt_prompt=corrupt_prompt, corrupt_answer=corrupt_answer, layers=layers)
|
119 |
fig = imshow(patching_effect, xticks=token_labels, yticks=layers, xlabel="Position", ylabel="Layer",
|
120 |
zlabel="Logit Difference", title="Patching residual stream at specific layer and position")
|
|
|
114 |
|
115 |
def plot_residual_stream_patch(clean_prompt=None, answer=None, corrupt_prompt=None, corrupt_answer=None):
|
116 |
layers = ["blocks.0.hook_resid_pre", *[f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]]
|
117 |
+
clean_tokens = model.to_str_tokens(clean_prompt)
|
118 |
+
token_labels = [f"(pos {i:2}) {t}" for i, t in enumerate(clean_tokens)]
|
119 |
patching_effect = compute_residual_stream_patch(clean_prompt=clean_prompt, answer=answer, corrupt_prompt=corrupt_prompt, corrupt_answer=corrupt_answer, layers=layers)
|
120 |
fig = imshow(patching_effect, xticks=token_labels, yticks=layers, xlabel="Position", ylabel="Layer",
|
121 |
zlabel="Logit Difference", title="Patching residual stream at specific layer and position")
|