mirror of https://github.com/osmarks/random-stuff (synced 2025-10-31 13:53:01 +00:00)
	incomprehensible pytorch
sparse_act_gram_matrix.py (new file, 105 lines)
@@ -0,0 +1,105 @@
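# Overview (a summary inferred from the code below, not from any separate
# documentation): build a random dictionary of unit-norm feature directions,
# sample sparse activations superposed over that dictionary, compute the Gram
# matrix of the resulting dense vectors, and compare the histogram of its
# off-diagonal entries against a fitted Gaussian.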
import numpy as np
import torch as th

import pickle

LOAD_FROM_SAVED = False

# feature dictionary: 10000 random feature directions in a 100000-dimensional
# ambient space, each row normalized to unit length
DEV = "cuda"
features = th.randn(10000, 100000, device=DEV)
features /= th.linalg.norm(features, dim=1, keepdim=True)

def sample_activations(batch_size, features, max_act=1, p_act=0.01):
    # each feature is independently active with probability p_act, with a
    # magnitude drawn uniformly from [0, max_act)
    print("p_act =", p_act)
    active = th.rand(batch_size, features.shape[0], device=DEV) < p_act
    activation = th.rand(batch_size, features.shape[0], device=DEV) * max_act
    activation[~active] = 0
    # superpose the active features into dense vectors of shape (batch, dim)
    return th.einsum('ij,bi->bj', features, activation)

def calc_gram_matrix(activations):
    # pairwise dot products between activation vectors: entry (b, c) = <act_b, act_c>
    return th.einsum('bi,ci->bc', activations, activations)

sample_sizes = [1000]

for sample_size in sample_sizes:
    if LOAD_FROM_SAVED:
        with open('acts.pkl', 'rb') as f:
            acts = pickle.load(f)
    else:
        acts = sample_activations(sample_size, features)

        with open('acts.pkl', 'wb') as f:
            pickle.dump(acts, f)

    print("sampled")

    # fit normal distribution to activations
    means = th.mean(acts, dim=1)
    cov = th.cov(acts)
    # add a large ridge to the diagonal to keep the covariance positive definite
    cov = cov + th.eye(cov.shape[0], cov.shape[1], device=DEV) * 1e2

    print("fitted")

    # sample from the fitted normal distribution
    print(means.shape, cov.shape, sample_size, acts.shape)
    normal_acts = th.distributions.multivariate_normal.MultivariateNormal(means, cov).sample((sample_size,)).to(DEV)

    if LOAD_FROM_SAVED:
        with open('gram.pkl', 'rb') as f:
            gram = pickle.load(f)
    else:
        gram = calc_gram_matrix(acts)

        with open('gram.pkl', 'wb') as f:
            pickle.dump(gram, f)

    # set diagonal to 0
    #gram.fill_diagonal_(0)

    print("grammed")

    #normal_gram = calc_gram_matrix(normal_acts)

    print("normal grammed")

    # flatten gram matrix (off-diagonal entries only) & plot histogram
    #gram_flat = gram.flatten().cpu().numpy()
    mask = th.ones_like(gram, device=DEV, dtype=th.bool) ^ th.eye(*gram.shape, device=DEV, dtype=th.bool)
    gram_flat = gram[mask].cpu().numpy()
    #normal_gram_flat = normal_gram.flatten()

    # fit normal distribution to gram matrix
    gram_mean = np.mean(gram_flat)
    gram_std = np.std(gram_flat)

    print("gram fitted")

    import matplotlib.pyplot as plt

    print("plotting")

    y, x, _ = plt.hist(gram_flat, bins=200, alpha=0.5, label='sampled', density=True)
    #plt.hist(normal_gram_flat, bins=250, alpha=0.5, label='normal', density=True)

    # take min and max of the histogram data
    hist_min = np.min(gram_flat)
    hist_max = np.max(gram_flat)
    mean = np.mean(gram_flat)
    # count entries in gram_flat below and above the mean
    print(np.sum(gram_flat < mean), np.sum(gram_flat > mean), mean)

    # plot normal distribution pdf
    x = np.linspace(hist_min, hist_max, 200)
    plt.plot(x, 1 / (gram_std * np.sqrt(2 * np.pi)) * np.exp(-(x - gram_mean)**2 / (2 * gram_std**2)), linewidth=2, color='r', label='normal fit')

    ax = plt.gca()
    #ax.set_yscale("function", functions=(lambda x: x**(1/3), lambda x: x**3.0))

    ax.set_ylim(0, y.max())

    plt.legend(loc='upper right')
    plt.savefig("/media/plot.png", dpi=1000)