incomprehensible pytorch

This commit is contained in:
osmarks 2023-10-18 20:55:02 +01:00
parent df51de06e6
commit 61f69bc003
1 changed files with 105 additions and 0 deletions

105
sparse_act_gram_matrix.py Normal file
View File

@ -0,0 +1,105 @@
import numpy as np
import torch as th
import pickle
LOAD_FROM_SAVED = False
# feature dictionary
DEV = "cuda"
features = th.randn(10000, 100000, device=DEV)
features /= th.linalg.norm(features, axis=1, keepdims=True)
def sample_activations(batch_size, features, max_act=1, p_act=0.01):
print(100 / features.shape[0], "p act")
active = th.rand(batch_size, features.shape[0], device=DEV) < (100 / features.shape[0])
print("here")
activation = th.rand(batch_size, features.shape[0], device=DEV) * max_act
print("there")
activation[active == False] = 0
return th.einsum('ij,bi->bj', features, activation)
def calc_gram_matrix(activations):
return th.einsum('bi,ci->bc', activations, activations)
sample_sizes = [1000]
for sample_size in sample_sizes:
if LOAD_FROM_SAVED:
with open('acts.pkl', 'rb') as f:
acts = pickle.load(f)
else:
acts = sample_activations(sample_size, features)
with open('acts.pkl', 'wb') as f:
pickle.dump(acts, f)
print("sampled")
# fit normal distribution to activations
means = th.mean(acts, axis=1)
cov = th.cov(acts)
cov = cov + th.eye(cov.shape[0], cov.shape[1], device=DEV) * 1e2
print("fitted")
# sample from normal distribution
print(means.shape, cov.shape, sample_size, acts.shape)
normal_acts = th.distributions.multivariate_normal.MultivariateNormal(means, cov).sample_n((sample_size)).to(DEV)
if LOAD_FROM_SAVED:
with open('gram.pkl', 'rb') as f:
gram = pickle.load(f)
else:
gram = calc_gram_matrix(acts)
with open('gram.pkl', 'wb') as f:
pickle.dump(gram, f)
# set diagonal to 0
#gram.fill_diagonal_(0)
print("grammed")
#normal_gram = calc_gram_matrix(normal_acts)
print("normal grammed")
# flatten gram matrix & plot histogram
#gram_flat = gram.flatten().cpu().numpy()
mask = th.ones_like(gram, device=DEV, dtype=th.bool) ^ th.eye(*gram.shape, device=DEV, dtype=th.bool)
gram_flat = gram[mask].cpu().numpy()
#normal_gram_flat = normal_gram.flatten()
# fit normal distribution to gram matrix
gram_mean = np.mean(gram_flat)
gram_std = np.std(gram_flat)
print("gram fitted")
import matplotlib.pyplot as plt
print("plotting")
y, x, _ = plt.hist(gram_flat, bins=200, alpha=0.5, label='sampled', density=True)
#plt.hist(normal_gram_flat, bins=250, alpha=0.5, label='normal', density=True)
# take min and max of both histograms
hist_min = np.min(gram_flat)
hist_max = np.max(gram_flat)
mean = np.mean(gram_flat)
# count things in gram_flat which are greater than and less than 0
print(np.sum(gram_flat < mean), np.sum(gram_flat > mean), mean)
# plot normal distribution pdf
x = np.linspace(hist_min, hist_max, 200)
plt.plot(x, 1 / (gram_std * np.sqrt(2 * np.pi)) * np.exp( - (x - gram_mean)**2 / (2 * gram_std**2) ), linewidth=2, color='r', label='normal fit')
ax = plt.gca()
#ax.set_yscale("function", functions=(lambda x: x**(1/3), lambda x: x**3.0))
ax.set_ylim(0, y.max())
plt.legend(loc='upper right')
plt.savefig("/media/plot.png", dpi=1000)