search_category.py
import os
import json
import torch
import math
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
from datasets import load_dataset
from tqdm import tqdm
from sklearn.preprocessing import LabelBinarizer
from scipy.stats import pointbiserialr
model_id = 'holistic-ai/gpt2-EMGSD'
sae_path = 'sae'
dataset_name = 'holistic-ai/EMGSD'
target_layer = 11
output_file = 'feature_label_correlations.json'
# Load tokenizer and model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model.eval()
# Load the SAE
sae = SAE.load_from_pretrained(sae_path, device=device)
# sae, cfg_dict, sparsity = SAE.from_pretrained('jbloom/GPT2-Small-SAEs-Reformatted', 'blocks.11.hook_resid_post', device=device)
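# Note: the active line above loads a locally trained SAE from `sae_path`; the
# commented-out alternative would instead load one of the public GPT-2-small SAEs
# released through sae_lens (trained on the blocks.11.hook_resid_post hook point).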
d_model = model.config.n_embd
# Load the test split and keep the category/text columns
def preprocess_dataset(dataset_name):
    dataset = load_dataset(dataset_name)
    data = dataset['test'].select_columns(['category', 'text'])
    return data
data = preprocess_dataset(dataset_name)
# Collect labels and texts
labels = []
texts = []
print("Preparing data...")
for example in tqdm(data, desc="Processing"):
    labels.append(example['category'])
    texts.append(example['text'])
lb = LabelBinarizer()
binary_labels = lb.fit_transform(labels)
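# binary_labels is a one-vs-rest indicator matrix of shape (num_samples, num_labels),
# one column per category (assuming the dataset has more than two categories)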
# Tokenize the dataset
tokens_list = []
print("Tokenizing dataset...")
for text in tqdm(texts, desc="Tokenizing"):
    encoding = tokenizer(text, return_tensors='pt', truncation=True, max_length=512).to(device)
    tokens_list.append(encoding)
# Collect feature activations
feature_activations = []
print("Collecting feature activations...")
for encoding in tqdm(tokens_list, desc="Collecting activations"):
    input_ids = encoding['input_ids']
    with torch.no_grad():
        outputs = model(**encoding, output_hidden_states=True)
        # hidden_states[0] is the embedding output; hidden_states[i] is the output of block i-1
        hidden_states = outputs.hidden_states[target_layer]
        activations = sae.encode(hidden_states)
    activations = activations.squeeze(0).cpu().numpy()  # (sequence_length, num_features)
    binary_activations = (activations > 0).astype(int)
    # Aggregate over the sequence: a feature counts as active if it fires on any token
    aggregated_activations = binary_activations.max(axis=0)  # (num_features,)
    feature_activations.append(aggregated_activations)
feature_activations = np.array(feature_activations) # (num_samples, num_features)
num_labels = binary_labels.shape[1]
num_features = feature_activations.shape[1]
correlations = {}
print(f"Number of featuers is {num_features}")
# Compute point biserial correlation
print("Computing correlations between features and labels...")
for label_idx in range(num_labels):
    label_name = lb.classes_[label_idx]
    label_values = binary_labels[:, label_idx]
    correlations[label_name] = []
    for feature_idx in range(num_features):
        feature_values = feature_activations[:, feature_idx]
        corr, p_value = pointbiserialr(label_values, feature_values)
        correlations[label_name].append({
            'feature_index': feature_idx,
            'correlation': corr,
            'p_value': p_value
        })
# Find features with highest correlation for each label
top_features = {}
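# Constant features (e.g. ones that never fire) give NaN correlations; treat
# those as zero so they sort to the bottom instead of breaking the comparison.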
for label_name in correlations:
    sorted_features = sorted(
        correlations[label_name],
        key=lambda x: 0 if math.isnan(x['correlation']) else abs(x['correlation']),
        reverse=True
    )
    top_features[label_name] = sorted_features[:10]
# Save results to JSON
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(top_features, f, indent=2, ensure_ascii=False)
print(f"Correlation results saved to '{output_file}'.")