import numpy as np
import msgpack
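# Per-dimension scalar quantization of float embeddings to u8 codes, plus the
# integer-dot-product parameters (offsets/scales) needed to approximate the
# float dot product at query time. Writes the parameters to quantizer.msgpack.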
n_dims = 1152
n_buckets = n_dims
#n_buckets = n_dims // 2 # we now have one quant scale per pair of components
#pair_separation = 16 # for efficient dot product computation, we need to have the second element of a pair exactly chunk_size after the first
n_dims_per_bucket = n_dims // n_buckets
data = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # sorry
CUTOFF = 1e-3 / 2
print("computing quantiles")
smin = np.quantile(data, CUTOFF, axis=0)
smax = np.quantile(data, 1 - CUTOFF, axis=0)
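# Clipping to the [CUTOFF, 1 - CUTOFF] quantiles keeps per-dimension outliers
# from inflating the quantization range; values outside it saturate in rquantize.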
# naive O(n²) greedy algorithm
# probably overbuilt for the 2-components-per-bucket case but I'm not getting rid of it
def assign_buckets():
    import random
    intervals = list(enumerate(zip(smin, smax)))
    random.shuffle(intervals)
    # seed each bucket with one random interval
    buckets = [ [ intervals.pop() ] for _ in range(n_buckets) ]
    def bucket_cost(bucket):
        # cost = total slack between each interval and the bucket's bounding interval
        bmin = min(cmin for id, (cmin, cmax) in bucket)
        bmax = max(cmax for id, (cmin, cmax) in bucket)
        #print("MIN", bmin, "MAX", bmax)
        return sum(abs(cmin - bmin) + abs(cmax - bmax) for id, (cmin, cmax) in bucket)
    # greedily hand each bucket the remaining interval which increases its cost the least
    while intervals:
        for bucket in buckets:
            if not intervals:
                break
            def new_interval_cost(interval):
                return bucket_cost(bucket + [interval[1]])
            i, interval = min(enumerate(intervals), key=new_interval_cost)
            bucket.append(intervals.pop(i))
    return buckets
ranges = smax - smin
# TODO: it is possible to do better assignment to buckets
#order = np.argsort(ranges)
print("bucket assignment")
order = np.arange(n_dims) # np.concatenate(np.stack([ [ id for id, (cmin, cmax) in bucket ] for bucket in assign_buckets() ]))
bucket_ranges = []
bucket_centres = []
bucket_absmax = []
bucket_gmins = []
for bucket_min in range(0, n_dims, n_dims_per_bucket):
    bucket_max = bucket_min + n_dims_per_bucket
    indices = order[bucket_min:bucket_max]
    gmin = float(np.min(smin[indices]))
    gmax = float(np.max(smax[indices]))
    bucket_range = gmax - gmin
    bucket_centre = (gmax + gmin) / 2
    bucket_gmins.append(gmin)
    bucket_ranges.append(bucket_range)
    bucket_centres.append(bucket_centre)
    bucket_absmax.append(max(abs(gmin), abs(gmax)))
print("determining scales")
scales = [] # multiply a float component by this, then round, to quantize
offsets = []
q_offsets = [] # int16 value to add at dot time
q_scales = [] # rescales channel up at dot time; must be proportional(ish) to square of scale factor but NOT cause overflow in accumulation or PMULLW
scale_factor_bound = float("inf")
for bucket in range(n_buckets):
    step_size = bucket_ranges[bucket] / 255
    scales.append(1 / step_size)
    q_offset = int(bucket_gmins[bucket] / step_size)
    q_offsets.append(q_offset)
    # worst-case magnitude of one offset-shifted product, evaluated at q = 255
    nsfb = (2**31 - 1) / (n_dims_per_bucket * abs((255**2) + 2 * q_offset * 255 + q_offset ** 2)) / 2
    # we are bounded both by overflow in accumulation and PMULLW (u8 plus offset times scale factor)
    scale_factor_bound = min(scale_factor_bound, nsfb, (2**15 - 1) // (q_offset + 255))
    offsets.append(bucket_gmins[bucket])
# q_scales must be proportional(ish) to step_size**2 so every bucket's scaled
# products come out in the same units; range**2 differs from step_size**2 only
# by the constant factor 255**2
sfb = scale_factor_bound / max(r ** 2 for r in bucket_ranges) # loop-invariant
for bucket in range(n_buckets):
    q_scales.append(int(bucket_ranges[bucket] ** 2 * sfb))
print(bucket_ranges, bucket_centres, bucket_absmax)
print(scales, offsets, q_offsets, q_scales)
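# Why the integer dot product approximates the float one (sketch of the algebra):
# each code q represents v ≈ (q + q_offset) * step_size, so
#   v_x * v_y ≈ (q_x + q_offset) * (q_y + q_offset) * step_size**2.
# Since q_scale ≈ range**2 * sfb = (sfb * 255**2) * step_size**2 with the same
# constant for every bucket, summing (q_x + q_offset) * q_scale * (q_y + q_offset)
# yields roughly (sfb * 255**2) times the true dot product.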
"""
interleave = np.concatenate([
np.arange(0, n_dims, n_dims_per_bucket) + a
for a in range(n_dims_per_bucket)
])
"""
"""
interleave = np.arange(0, n_dims)
for base in range(0, n_dims, 2 * pair_separation):
interleave[base:base + pair_separation] = np.arange(base, base + 2 * pair_separation, 2)
interleave[base + pair_separation:base + 2 * pair_separation] = np.arange(base + 1, base + 2 * pair_separation + 1, 2)
"""
#print(bucket_ranges, bucket_centres, order[interleave])
#print(ranges[order][interleave].tolist())
#print(ranges.tolist())
interleave = np.arange(n_dims) # identity; both interleaving schemes above are commented out, but the helpers below still index with it
with open("quantizer.msgpack", "wb") as f:
    msgpack.pack({
        "permutation": order.tolist(),
        "offsets": offsets,
        "scales": scales,
        "q_offsets": q_offsets,
        "q_scales": q_scales
    }, f)
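# Usage sketch (not part of the original flow): the parameters can be read back
# with msgpack.unpack, e.g.
#   with open("quantizer.msgpack", "rb") as f:
#       params = msgpack.unpack(f)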
def rquantize(vec):
    # quantize a float vector to u8 codes, permuted into storage order
    out = np.zeros(len(vec), dtype=np.uint8)
    for i, p in enumerate(order[interleave]):
        bucket = p % n_buckets
        raw = vec[i]
        raw = (raw - offsets[bucket]) * scales[bucket]
        raw = min(max(raw, 0.0), 255.0) # saturate rather than wrap
        out[p] = round(raw)
    return out
def rdquantize(buf):
    # inverse of rquantize, up to rounding error
    vec = np.zeros(n_dims, dtype=np.float32)
    for i, p in enumerate(order[interleave]):
        bucket = p % n_buckets
        vec[i] = float(buf[p]) / scales[bucket] + offsets[bucket]
    return vec
def rdot(x, y):
    # integer dot product mirroring the PMULLW-based kernel described above:
    # i16 adds/multiplies with wraparound, widening to i32 only to accumulate
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        x1 *= xq_scales
        acc += np.dot(x1.astype(np.int32), y1.astype(np.int32))
    return acc
def cmp(i, j):
    return np.dot(data[i], data[j]) / rdot(rquantize(data[i]), rquantize(data[j]))
def rdot_cmp(a, b):
    # like rdot, but prints each chunk's integer result against the exact float
    # dot product over the same components, for debugging the scale factors
    x = rquantize(a)
    y = rquantize(b)
    a = a[order[interleave]]
    b = b[order[interleave]]
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        x1 *= xq_scales
        component = np.dot(x1.astype(np.int32), y1.astype(np.int32))
        a1 = a[i:i+n_buckets]
        b1 = b[i:i+n_buckets]
        component_exact = np.dot(a1, b1)
        print(x1, a1, sep="\n")
        print(component, component_exact, component / component_exact)
        acc += component
    return acc
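# Quick accuracy check (a sketch added for illustration; `expected_rdot_scale`
# and `check_quantization` are not part of the original script): per the
# derivation above, rdot should exceed the true dot product by a roughly
# constant factor, so the rescaled error should be small.
expected_rdot_scale = scale_factor_bound / max(r ** 2 for r in bucket_ranges) * 255 ** 2
def check_quantization(n_pairs=4, seed=0):
    rng = np.random.default_rng(seed)
    for i, j in rng.integers(0, len(data), size=(n_pairs, 2)):
        exact = float(np.dot(data[i], data[j]))
        approx = rdot(rquantize(data[i]), rquantize(data[j])) / expected_rdot_scale
        print(f"exact={exact:.5f} approx={approx:.5f} rel_err={abs(approx - exact) / abs(exact):.3%}")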