import numpy as np
import torch as th
import pickle
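
# Overview: build a dictionary of random unit feature directions, sample sparse
# activations over it, superpose them into dense vectors, and compare the
# distribution of the off-diagonal entries of the resulting Gram matrix against
# a fitted normal distribution (plotted as a histogram at the end).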

LOAD_FROM_SAVED = False

# feature dictionary: random unit-norm feature directions
DEV = "cuda"
features = th.randn(10000, 100000, device=DEV)
features /= th.linalg.norm(features, dim=1, keepdim=True)
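
# the rows of `features` are now unit vectors; in high dimensions random unit
# vectors are nearly orthogonal, so distinct features have dot products close
# to zero (illustrative check: print(features[:5] @ features[:5].T))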

def sample_activations(batch_size, features, max_act=1, p_act=0.01):
    # each feature is independently active with probability 100 / n_features (the p_act argument is unused)
    p = 100 / features.shape[0]
    print(p, "p act")
    active = th.rand(batch_size, features.shape[0], device=DEV) < p
    print("here")
    # active features get a uniform activation magnitude in [0, max_act)
    activation = th.rand(batch_size, features.shape[0], device=DEV) * max_act
    print("there")
    activation[~active] = 0
    # superpose the active features: (n_features, dim), (batch, n_features) -> (batch, dim)
    return th.einsum('ij,bi->bj', features, activation)
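
# e.g. sample_activations(4, features) returns a (4, 100000) batch in which each
# row is a sparse superposition of roughly 100 dictionary features on average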

def calc_gram_matrix(activations):
    # pairwise dot products between the activation vectors in a batch
    return th.einsum('bi,ci->bc', activations, activations)


sample_sizes = [1000]

for sample_size in sample_sizes:
    if LOAD_FROM_SAVED:
        with open('acts.pkl', 'rb') as f:
            acts = pickle.load(f)
    else:
        acts = sample_activations(sample_size, features)

        with open('acts.pkl', 'wb') as f:
            pickle.dump(acts, f)

    print("sampled")

    # fit a normal distribution to the activations
    # (torch.cov treats rows as variables and columns as observations)
    means = th.mean(acts, dim=1)
    cov = th.cov(acts)
    # diagonal loading to keep the covariance positive definite
    cov = cov + th.eye(cov.shape[0], cov.shape[1], device=DEV) * 1e2

    print("fitted")

    # sample from the fitted normal distribution
    print(means.shape, cov.shape, sample_size, acts.shape)
    normal_acts = th.distributions.multivariate_normal.MultivariateNormal(means, cov).sample((sample_size,)).to(DEV)
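    # note: normal_acts is only consumed by the commented-out normal_gram comparison further down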

    if LOAD_FROM_SAVED:
        with open('gram.pkl', 'rb') as f:
            gram = pickle.load(f)
    else:
        gram = calc_gram_matrix(acts)

        with open('gram.pkl', 'wb') as f:
            pickle.dump(gram, f)

    # set diagonal to 0
    #gram.fill_diagonal_(0)

    print("grammed")

    #normal_gram = calc_gram_matrix(normal_acts)

    print("normal grammed")

    # flatten the off-diagonal Gram entries & plot a histogram
    #gram_flat = gram.flatten().cpu().numpy()
    mask = th.ones_like(gram, device=DEV, dtype=th.bool) ^ th.eye(*gram.shape, device=DEV, dtype=th.bool)
    gram_flat = gram[mask].cpu().numpy()
    #normal_gram_flat = normal_gram.flatten()

    # fit a normal distribution to the off-diagonal Gram entries
    gram_mean = np.mean(gram_flat)
    gram_std = np.std(gram_flat)

    print("gram fitted")

    import matplotlib.pyplot as plt

    print("plotting")

    y, x, _ = plt.hist(gram_flat, bins=200, alpha=0.5, label='sampled', density=True)
    #plt.hist(normal_gram_flat, bins=250, alpha=0.5, label='normal', density=True)

    # range and mean of the histogram data
    hist_min = np.min(gram_flat)
    hist_max = np.max(gram_flat)
    mean = np.mean(gram_flat)
    # count entries of gram_flat below and above the mean
    print(np.sum(gram_flat < mean), np.sum(gram_flat > mean), mean)

    # plot the fitted normal distribution's pdf
    x = np.linspace(hist_min, hist_max, 200)
    plt.plot(x, 1 / (gram_std * np.sqrt(2 * np.pi)) * np.exp(-(x - gram_mean)**2 / (2 * gram_std**2)), linewidth=2, color='r', label='normal fit')

    ax = plt.gca()
    #ax.set_yscale("function", functions=(lambda x: x**(1/3), lambda x: x**3.0))

    ax.set_ylim(0, y.max())

    plt.legend(loc='upper right')
    plt.savefig("/media/plot.png", dpi=1000)