Mirror of https://github.com/osmarks/random-stuff
incomprehensible pytorch
This commit is contained in: parent df51de06e6, commit 61f69bc003
sparse_act_gram_matrix.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import numpy as np
import torch as th
import pickle

LOAD_FROM_SAVED = False

# feature dictionary: random directions (one feature per row), normalised to unit length
DEV = "cuda"
features = th.randn(10000, 100000, device=DEV)
features /= th.linalg.norm(features, dim=1, keepdim=True)

def sample_activations(batch_size, features, max_act=1, p_act=0.01):
    # note: p_act is currently unused; the activation probability is hard-coded
    # so that on average ~100 features are active per sample
    print(100 / features.shape[0], "p act")
    active = th.rand(batch_size, features.shape[0], device=DEV) < (100 / features.shape[0])
    print("here")
    # active features get a uniform activation magnitude in [0, max_act)
    activation = th.rand(batch_size, features.shape[0], device=DEV) * max_act
    print("there")
    activation[~active] = 0
    # superpose the active features: (n_features, dim) x (batch, n_features) -> (batch, dim)
    return th.einsum('ij,bi->bj', features, activation)

def calc_gram_matrix(activations):
    # pairwise dot products between the sampled activation vectors
    return th.einsum('bi,ci->bc', activations, activations)

sample_sizes = [1000]

for sample_size in sample_sizes:
    if LOAD_FROM_SAVED:
        with open('acts.pkl', 'rb') as f:
            acts = pickle.load(f)
    else:
        acts = sample_activations(sample_size, features)

        with open('acts.pkl', 'wb') as f:
            pickle.dump(acts, f)

    print("sampled")

    # fit a multivariate normal distribution to the activations
    means = th.mean(acts, dim=1)
    cov = th.cov(acts)
    # add a large diagonal term to regularise the covariance
    cov = cov + th.eye(cov.shape[0], cov.shape[1], device=DEV) * 1e2

    print("fitted")

    # sample from the fitted normal distribution
    print(means.shape, cov.shape, sample_size, acts.shape)
    normal_acts = th.distributions.multivariate_normal.MultivariateNormal(means, cov).sample((sample_size,)).to(DEV)

    if LOAD_FROM_SAVED:
        with open('gram.pkl', 'rb') as f:
            gram = pickle.load(f)
    else:
        gram = calc_gram_matrix(acts)

        with open('gram.pkl', 'wb') as f:
            pickle.dump(gram, f)

    # set diagonal to 0
    #gram.fill_diagonal_(0)

    print("grammed")

    #normal_gram = calc_gram_matrix(normal_acts)

    print("normal grammed")

    # flatten gram matrix, excluding the diagonal (self-similarities), & plot histogram
    #gram_flat = gram.flatten().cpu().numpy()
    mask = th.ones_like(gram, device=DEV, dtype=th.bool) ^ th.eye(*gram.shape, device=DEV, dtype=th.bool)
    gram_flat = gram[mask].cpu().numpy()
    #normal_gram_flat = normal_gram.flatten()

    # fit a normal distribution to the off-diagonal gram matrix entries
    gram_mean = np.mean(gram_flat)
    gram_std = np.std(gram_flat)

    print("gram fitted")

    import matplotlib.pyplot as plt

    print("plotting")

    y, x, _ = plt.hist(gram_flat, bins=200, alpha=0.5, label='sampled', density=True)
    #plt.hist(normal_gram_flat, bins=250, alpha=0.5, label='normal', density=True)

    # take the min, max and mean of the off-diagonal entries
    hist_min = np.min(gram_flat)
    hist_max = np.max(gram_flat)
    mean = np.mean(gram_flat)
    # count the entries of gram_flat below and above the mean
    print(np.sum(gram_flat < mean), np.sum(gram_flat > mean), mean)

    # plot the pdf of the fitted normal distribution over the histogram
    x = np.linspace(hist_min, hist_max, 200)
    plt.plot(x, 1 / (gram_std * np.sqrt(2 * np.pi)) * np.exp(-(x - gram_mean)**2 / (2 * gram_std**2)), linewidth=2, color='r', label='normal fit')

    ax = plt.gca()
    #ax.set_yscale("function", functions=(lambda x: x**(1/3), lambda x: x**3.0))

    ax.set_ylim(0, y.max())

    plt.legend(loc='upper right')
    plt.savefig("/media/plot.png", dpi=1000)
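
For reference, a minimal CPU-scale sketch of what the script computes: sparse random feature activations are superposed into dense vectors, and the off-diagonal entries of their Gram matrix are summarised. The sizes, activation probability, and printed statistics below are illustrative stand-ins, not the values used in the committed file.

import torch as th

def sample_sparse_activations(batch_size, features, max_act=1.0, p_act=0.01):
    # each feature fires independently with probability p_act,
    # with a uniform magnitude in [0, max_act)
    n_features = features.shape[0]
    active = th.rand(batch_size, n_features) < p_act
    activation = th.rand(batch_size, n_features) * max_act
    activation[~active] = 0
    # superpose: (n_features, dim) x (batch, n_features) -> (batch, dim)
    return th.einsum('ij,bi->bj', features, activation)

# small stand-in for the 10000 x 100000 feature dictionary
features = th.randn(500, 2000)
features /= th.linalg.norm(features, dim=1, keepdim=True)
acts = sample_sparse_activations(64, features)
gram = th.einsum('bi,ci->bc', acts, acts)
off_diag = gram[~th.eye(gram.shape[0], dtype=th.bool)]
print(off_diag.mean().item(), off_diag.std().item())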