3
3
"""
4
4
5
5
import os
6
+ from matplotlib .pylab import f
6
7
import torch as t
7
8
import numpy as np
8
9
import seaborn as sns
9
10
import matplotlib .pyplot as plt
10
- from sklearn .decomposition import PCA
11
11
from behaviors import ALL_BEHAVIORS , get_analysis_dir , HUMAN_NAMES , get_steering_vector , ANALYSIS_PATH
12
12
from utils .helpers import get_model_path , model_name_format , set_plotting_settings
13
13
from tqdm import tqdm
@@ -37,17 +37,17 @@ def plot_per_layer_similarities(model_size: str, is_base: bool, behavior: str):
37
37
for layer2 in range (n_layers ):
38
38
cosine_sim = t .nn .functional .cosine_similarity (all_vectors [layer1 ], all_vectors [layer2 ], dim = 0 ).item ()
39
39
matrix [layer1 , layer2 ] = cosine_sim
40
- plt .figure (figsize = (5 , 5 ))
40
+ plt .figure (figsize = (3 , 3 ))
41
41
sns .heatmap (matrix , annot = False , cmap = 'coolwarm' )
42
42
# Set ticks for every 5th layer
43
43
plt .xticks (list (range (n_layers ))[::5 ], list (range (n_layers ))[::5 ])
44
44
plt .yticks (list (range (n_layers ))[::5 ], list (range (n_layers ))[::5 ])
45
- plt .title (f"Inter-layer similarity, { model_name } " )
46
- plt .savefig (os .path .join (analysis_dir , f"cosine_similarities_{ model_name .replace (' ' , '_' )} _{ behavior } .png " ), format = 'png ' )
45
+ plt .title (f"Layer similarity, { model_name } " , fontsize = 11 )
46
+ plt .savefig (os .path .join (analysis_dir , f"cosine_similarities_{ model_name .replace (' ' , '_' )} _{ behavior } .svg " ), format = 'svg ' )
47
47
plt .close ()
48
48
49
49
def plot_base_chat_similarities ():
50
- plt .figure (figsize = (8 , 4 ))
50
+ plt .figure (figsize = (5 , 3 ))
51
51
for behavior in ALL_BEHAVIORS :
52
52
base_caa_info = get_caa_info (behavior , "7b" , True )
53
53
chat_caa_info = get_caa_info (behavior , "7b" , False )
@@ -57,48 +57,19 @@ def plot_base_chat_similarities():
57
57
for layer in range (base_caa_info ["n_layers" ]):
58
58
cos_sim = t .nn .functional .cosine_similarity (vectors_base [layer ], vectors_chat [layer ], dim = 0 ).item ()
59
59
cos_sims .append (cos_sim )
60
- plt .plot (list (range (base_caa_info ["n_layers" ])), cos_sims , label = HUMAN_NAMES [behavior ])
60
+ plt .plot (list (range (base_caa_info ["n_layers" ])), cos_sims , label = HUMAN_NAMES [behavior ], linestyle = "solid" , linewidth = 2 )
61
61
plt .xlabel ("Layer" )
62
62
plt .ylabel ("Cosine Similarity" )
63
- plt .title ("Steering vector similarity between Llama 2 base and chat" )
64
- plt .legend ()
63
+ plt .title ("Base vs. Chat model vector similarity" , fontsize = 12 )
64
+ # legend in bottom right
65
+ plt .legend (loc = "lower right" )
65
66
plt .tight_layout ()
66
67
plt .savefig (os .path .join (ANALYSIS_PATH , "base_chat_similarities.png" ), format = 'png' )
67
68
plt .close ()
68
-
69
def plot_pca_of_all_vectors():
    """
    Plot a 2-component PCA of all steering vectors for Llama 2 7b chat.

    For every behavior in ALL_BEHAVIORS the vectors returned by
    get_caa_info(behavior, "7b", False) are collected, stacked into one
    matrix, standardized per dimension (mean 0, std 1) and projected onto
    the first two principal components. One scatter series is drawn per
    behavior and the figure is written to ANALYSIS_PATH/pca_all_vectors.png.
    """
    # Local import: the module-level PCA import is marked as removed in this
    # change, so the function brings its own dependency into scope.
    from sklearn.decomposition import PCA

    # Keep the vectors grouped per behavior so the scatter slices below follow
    # the real group sizes instead of assuming a fixed 32 vectors each.
    vectors_by_behavior = []
    for behavior in ALL_BEHAVIORS:
        caa_info = get_caa_info(behavior, "7b", False)
        vectors_by_behavior.append(list(caa_info["vectors"]))

    all_vectors = t.stack([vec for group in vectors_by_behavior for vec in group])

    # Normalize vectors for PCA (mean 0, std 1). A dimension with zero spread
    # would divide by zero and hand NaNs to PCA; dividing by 1 instead leaves
    # that (already centered) dimension at 0.
    std = all_vectors.std(dim=0)
    std = t.where(std == 0, t.ones_like(std), std)
    all_vectors = (all_vectors - all_vectors.mean(dim=0)) / std

    pca = PCA(n_components=2)
    pca.fit(all_vectors)
    pca_vectors = pca.transform(all_vectors)

    plt.figure(figsize=(5, 5))
    start = 0
    for behavior, group in zip(ALL_BEHAVIORS, vectors_by_behavior):
        end = start + len(group)
        plt.scatter(
            pca_vectors[start:end, 0],
            pca_vectors[start:end, 1],
            label=HUMAN_NAMES[behavior],
        )
        start = end
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.title("PCA of all steering vectors")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(ANALYSIS_PATH, "pca_all_vectors.png"), format='png')
    plt.close()
97
69
98
70
if __name__ == "__main__":
    # (model_size, is_base) pairs passed to plot_per_layer_similarities,
    # in the order the figures are produced for each behavior.
    model_configs = (("7b", True), ("7b", False), ("13b", False))

    # One inter-layer similarity heatmap per behavior and model configuration;
    # tqdm reports progress over the behaviors.
    for behavior in tqdm(ALL_BEHAVIORS):
        for model_size, is_base in model_configs:
            plot_per_layer_similarities(model_size, is_base, behavior)

    # Single figure comparing base and chat steering vectors across behaviors.
    plot_base_chat_similarities()
0 commit comments