1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-13 07:19:54 +00:00

performance improvements

This commit is contained in:
osmarks 2024-11-07 16:52:58 +00:00
parent 7fa14d45ae
commit b9bb629e6f
7 changed files with 62 additions and 17 deletions

30
Cargo.lock generated
View File

@ -697,6 +697,15 @@ dependencies = [
"subtle",
]
[[package]]
name = "document-features"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb6969eaabd2421f8a2775cfd2471a2b634372b4a25d41e3bd647b79912850a0"
dependencies = [
"litrs",
]
[[package]]
name = "dotenvy"
version = "0.15.7"
@ -794,6 +803,20 @@ dependencies = [
"hashbrown 0.13.2",
]
[[package]]
name = "fast_image_resize"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a66a61fbfc84ef99a839499cf9e5a7c2951d2da874ea00f29ee938bc50d1b396"
dependencies = [
"bytemuck",
"cfg-if",
"document-features",
"image",
"num-traits",
"thiserror",
]
[[package]]
name = "fastrand"
version = "2.1.1"
@ -1558,6 +1581,12 @@ version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]]
name = "litrs"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce301924b7887e9d637144fdade93f9dfff9b60981d4ac161db09720d39aa5"
[[package]]
name = "lock_api"
version = "0.4.12"
@ -1645,6 +1674,7 @@ dependencies = [
"compact_str",
"console-subscriber",
"faiss",
"fast_image_resize",
"fastrand",
"ffmpeg-the-third",
"fnv",

View File

@ -43,6 +43,7 @@ ffmpeg-the-third = "2.0"
compact_str = { version = "0.8.0-beta", features = ["serde"] }
itertools = "0.13"
async-recursion = "1"
fast_image_resize = { version = "5", features = ["image"] }
[[bin]]
name = "reddit-dump"

View File

@ -5,7 +5,7 @@ from aiohttp import web
import aiohttp
import asyncio
import traceback
import umsgpack
import msgpack
import collections
import queue
import open_clip
@ -58,7 +58,7 @@ if CONFIG.get("aitemplate_image_models"):
if "patch_embed.proj.weight" not in key:
params[key.replace(".", "_")] = value.cuda()
#print(orig_key, key.replace(".", "_"))
params["patch_embed_proj_weight"] = conv_weights
return params
@ -151,7 +151,7 @@ routes = web.RouteTableDef()
@routes.post("/")
async def run_inference(request):
loop = asyncio.get_event_loop()
data = umsgpack.loads(await request.read())
data = msgpack.loads(await request.read())
event = asyncio.Event()
results = None
def callback(*argv):
@ -167,11 +167,11 @@ async def run_inference(request):
else:
status = 500
print(results[1])
return web.Response(body=umsgpack.dumps(body_data), status=status, content_type="application/msgpack")
return web.Response(body=msgpack.dumps(body_data), status=status, content_type="application/msgpack")
@routes.get("/config")
async def config(request):
return web.Response(body=umsgpack.dumps({
return web.Response(body=msgpack.dumps({
"model": CONFIG["model"],
"batch": BS,
"image_size": [ t for t in preprocess.transforms if isinstance(t, transforms.Resize) ][0].size,
@ -206,4 +206,4 @@ try:
loop.run_forever()
except KeyboardInterrupt:
import sys
sys.exit(0)
sys.exit(0)

View File

@ -1,8 +1,8 @@
Pillow==10.0.1
prometheus-client==0.17.1
u-msgpack-python==2.8.0
aiohttp==3.8.5
aiohttp-cors==0.7.0
faiss-cpu==1.7.4
aiosqlite==0.19.0
open-clip-torch==2.23.0
open-clip-torch==2.23.0
msgpack==1.1.0

View File

@ -1,10 +1,19 @@
use image::codecs::bmp::BmpEncoder;
use serde::{Serialize, Deserialize};
use std::borrow::Borrow;
use image::{DynamicImage, imageops::FilterType, ImageFormat};
use std::cell::RefCell;
use image::{DynamicImage, ExtendedColorType, ImageEncoder};
use anyhow::Result;
use std::io::Cursor;
use reqwest::Client;
use tracing::instrument;
use fast_image_resize::Resizer;
use fast_image_resize::images::Image;
use anyhow::Context;
std::thread_local! {
static RESIZER: RefCell<Resizer> = RefCell::new(Resizer::new());
}
#[derive(Debug, Deserialize, Clone)]
pub struct InferenceServerConfig {
@ -16,14 +25,18 @@ pub struct InferenceServerConfig {
pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
// the model currently in use wants aspect ratio 1:1 regardless of input
// I think this was previously being handled in the CLIP server but that is slightly lossy
let new = image.borrow().resize_exact(
config.image_size.0,
config.image_size.1,
FilterType::CatmullRom
).into_rgb8();
let src_rgb = DynamicImage::from(image.borrow().to_rgb8()); // TODO this might be significantly inefficient for RGB8->RGB8 case
let mut dst_image = Image::new(config.image_size.0, config.image_size.1, fast_image_resize::PixelType::U8x3);
RESIZER.with_borrow_mut(|resizer| {
resizer.resize(&src_rgb, &mut dst_image, None)
}).context("resize failure")?;
let mut buf = Vec::new();
let mut csr = Cursor::new(&mut buf);
new.write_to(&mut csr, ImageFormat::Png)?;
BmpEncoder::new(&mut csr).write_image(dst_image.buffer(), config.image_size.0, config.image_size.1, ExtendedColorType::Rgb8)?;
Ok::<Vec<u8>, anyhow::Error>(buf)
}

View File

@ -3,6 +3,7 @@ use serde::{Serialize, Deserialize};
use std::io::BufReader;
use rmp_serde::decode::Error as DecodeError;
use std::fs;
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
// TODO refactor
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
@ -40,7 +41,7 @@ fn main() -> Result<()> {
match res {
Ok(x) => {
if x.timestamp > latest_timestamp {
println!("{} {} https://reddit.com/r/{}/comments/{}", x.timestamp, count, x.subreddit, x.id);
println!("{} {} https://reddit.com/r/{}/comments/{} {} https://mse.osmarks.net/?e={}", x.timestamp, count, x.subreddit, x.id, x.metadata.final_url, URL_SAFE.encode(x.embedding));
latest_timestamp = x.timestamp;
}
},

View File

@ -16,7 +16,7 @@ use axum::{
use common::resize_for_embed_sync;
use compact_str::CompactString;
use image::RgbImage;
use image::{imageops::FilterType, io::Reader as ImageReader, DynamicImage, ImageFormat};
use image::{imageops::FilterType, ImageReader, DynamicImage, ImageFormat};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::SqliteConnection;