diff --git a/sparse_act_gram_matrix.py b/sparse_act_gram_matrix.py new file mode 100644 index 0000000..1480bab --- /dev/null +++ b/sparse_act_gram_matrix.py @@ -0,0 +1,105 @@ +import numpy as np +import torch as th + +import pickle + +LOAD_FROM_SAVED = False + +# feature dictionary + +DEV = "cuda" +features = th.randn(10000, 100000, device=DEV) +features /= th.linalg.norm(features, axis=1, keepdims=True) + +def sample_activations(batch_size, features, max_act=1, p_act=0.01): + print(100 / features.shape[0], "p act") + active = th.rand(batch_size, features.shape[0], device=DEV) < (100 / features.shape[0]) + print("here") + activation = th.rand(batch_size, features.shape[0], device=DEV) * max_act + print("there") + activation[active == False] = 0 + return th.einsum('ij,bi->bj', features, activation) + +def calc_gram_matrix(activations): + return th.einsum('bi,ci->bc', activations, activations) + +sample_sizes = [1000] + +for sample_size in sample_sizes: + if LOAD_FROM_SAVED: + with open('acts.pkl', 'rb') as f: + acts = pickle.load(f) + else: + acts = sample_activations(sample_size, features) + + with open('acts.pkl', 'wb') as f: + pickle.dump(acts, f) + + print("sampled") + + # fit normal distribution to activations + means = th.mean(acts, axis=1) + cov = th.cov(acts) + cov = cov + th.eye(cov.shape[0], cov.shape[1], device=DEV) * 1e2 + + print("fitted") + + # sample from normal distribution + print(means.shape, cov.shape, sample_size, acts.shape) + normal_acts = th.distributions.multivariate_normal.MultivariateNormal(means, cov).sample_n((sample_size)).to(DEV) + + if LOAD_FROM_SAVED: + with open('gram.pkl', 'rb') as f: + gram = pickle.load(f) + else: + gram = calc_gram_matrix(acts) + + with open('gram.pkl', 'wb') as f: + pickle.dump(gram, f) + + # set diagonal to 0 + #gram.fill_diagonal_(0) + + print("grammed") + + #normal_gram = calc_gram_matrix(normal_acts) + + print("normal grammed") + + # flatten gram matrix & plot histogram + #gram_flat = gram.flatten().cpu().numpy() + mask = th.ones_like(gram, device=DEV, dtype=th.bool) ^ th.eye(*gram.shape, device=DEV, dtype=th.bool) + gram_flat = gram[mask].cpu().numpy() + #normal_gram_flat = normal_gram.flatten() + + # fit normal distribution to gram matrix + gram_mean = np.mean(gram_flat) + gram_std = np.std(gram_flat) + + print("gram fitted") + + import matplotlib.pyplot as plt + + print("plotting") + + y, x, _ = plt.hist(gram_flat, bins=200, alpha=0.5, label='sampled', density=True) + #plt.hist(normal_gram_flat, bins=250, alpha=0.5, label='normal', density=True) + + # take min and max of both histograms + hist_min = np.min(gram_flat) + hist_max = np.max(gram_flat) + mean = np.mean(gram_flat) + # count things in gram_flat which are greater than and less than 0 + print(np.sum(gram_flat < mean), np.sum(gram_flat > mean), mean) + + # plot normal distribution pdf + x = np.linspace(hist_min, hist_max, 200) + plt.plot(x, 1 / (gram_std * np.sqrt(2 * np.pi)) * np.exp( - (x - gram_mean)**2 / (2 * gram_std**2) ), linewidth=2, color='r', label='normal fit') + + ax = plt.gca() + #ax.set_yscale("function", functions=(lambda x: x**(1/3), lambda x: x**3.0)) + + ax.set_ylim(0, y.max()) + + plt.legend(loc='upper right') + plt.savefig("/media/plot.png", dpi=1000) \ No newline at end of file