mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-13 07:19:54 +00:00
performance improvements
This commit is contained in:
parent
7fa14d45ae
commit
b9bb629e6f
30
Cargo.lock
generated
30
Cargo.lock
generated
@ -697,6 +697,15 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "document-features"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb6969eaabd2421f8a2775cfd2471a2b634372b4a25d41e3bd647b79912850a0"
|
||||
dependencies = [
|
||||
"litrs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dotenvy"
|
||||
version = "0.15.7"
|
||||
@ -794,6 +803,20 @@ dependencies = [
|
||||
"hashbrown 0.13.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fast_image_resize"
|
||||
version = "5.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a66a61fbfc84ef99a839499cf9e5a7c2951d2da874ea00f29ee938bc50d1b396"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cfg-if",
|
||||
"document-features",
|
||||
"image",
|
||||
"num-traits",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.1.1"
|
||||
@ -1558,6 +1581,12 @@ version = "0.4.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
|
||||
|
||||
[[package]]
|
||||
name = "litrs"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4ce301924b7887e9d637144fdade93f9dfff9b60981d4ac161db09720d39aa5"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.12"
|
||||
@ -1645,6 +1674,7 @@ dependencies = [
|
||||
"compact_str",
|
||||
"console-subscriber",
|
||||
"faiss",
|
||||
"fast_image_resize",
|
||||
"fastrand",
|
||||
"ffmpeg-the-third",
|
||||
"fnv",
|
||||
|
@ -43,6 +43,7 @@ ffmpeg-the-third = "2.0"
|
||||
compact_str = { version = "0.8.0-beta", features = ["serde"] }
|
||||
itertools = "0.13"
|
||||
async-recursion = "1"
|
||||
fast_image_resize = { version = "5", features = ["image"] }
|
||||
|
||||
[[bin]]
|
||||
name = "reddit-dump"
|
||||
|
@ -5,7 +5,7 @@ from aiohttp import web
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import traceback
|
||||
import umsgpack
|
||||
import msgpack
|
||||
import collections
|
||||
import queue
|
||||
import open_clip
|
||||
@ -58,7 +58,7 @@ if CONFIG.get("aitemplate_image_models"):
|
||||
if "patch_embed.proj.weight" not in key:
|
||||
params[key.replace(".", "_")] = value.cuda()
|
||||
#print(orig_key, key.replace(".", "_"))
|
||||
|
||||
|
||||
params["patch_embed_proj_weight"] = conv_weights
|
||||
|
||||
return params
|
||||
@ -151,7 +151,7 @@ routes = web.RouteTableDef()
|
||||
@routes.post("/")
|
||||
async def run_inference(request):
|
||||
loop = asyncio.get_event_loop()
|
||||
data = umsgpack.loads(await request.read())
|
||||
data = msgpack.loads(await request.read())
|
||||
event = asyncio.Event()
|
||||
results = None
|
||||
def callback(*argv):
|
||||
@ -167,11 +167,11 @@ async def run_inference(request):
|
||||
else:
|
||||
status = 500
|
||||
print(results[1])
|
||||
return web.Response(body=umsgpack.dumps(body_data), status=status, content_type="application/msgpack")
|
||||
return web.Response(body=msgpack.dumps(body_data), status=status, content_type="application/msgpack")
|
||||
|
||||
@routes.get("/config")
|
||||
async def config(request):
|
||||
return web.Response(body=umsgpack.dumps({
|
||||
return web.Response(body=msgpack.dumps({
|
||||
"model": CONFIG["model"],
|
||||
"batch": BS,
|
||||
"image_size": [ t for t in preprocess.transforms if isinstance(t, transforms.Resize) ][0].size,
|
||||
@ -206,4 +206,4 @@ try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
import sys
|
||||
sys.exit(0)
|
||||
sys.exit(0)
|
||||
|
@ -1,8 +1,8 @@
|
||||
Pillow==10.0.1
|
||||
prometheus-client==0.17.1
|
||||
u-msgpack-python==2.8.0
|
||||
aiohttp==3.8.5
|
||||
aiohttp-cors==0.7.0
|
||||
faiss-cpu==1.7.4
|
||||
aiosqlite==0.19.0
|
||||
open-clip-torch==2.23.0
|
||||
open-clip-torch==2.23.0
|
||||
msgpack==1.1.0
|
||||
|
@ -1,10 +1,19 @@
|
||||
use image::codecs::bmp::BmpEncoder;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use std::borrow::Borrow;
|
||||
use image::{DynamicImage, imageops::FilterType, ImageFormat};
|
||||
use std::cell::RefCell;
|
||||
use image::{DynamicImage, ExtendedColorType, ImageEncoder};
|
||||
use anyhow::Result;
|
||||
use std::io::Cursor;
|
||||
use reqwest::Client;
|
||||
use tracing::instrument;
|
||||
use fast_image_resize::Resizer;
|
||||
use fast_image_resize::images::Image;
|
||||
use anyhow::Context;
|
||||
|
||||
std::thread_local! {
|
||||
static RESIZER: RefCell<Resizer> = RefCell::new(Resizer::new());
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct InferenceServerConfig {
|
||||
@ -16,14 +25,18 @@ pub struct InferenceServerConfig {
|
||||
pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
|
||||
// the model currently in use wants aspect ratio 1:1 regardless of input
|
||||
// I think this was previously being handled in the CLIP server but that is slightly lossy
|
||||
let new = image.borrow().resize_exact(
|
||||
config.image_size.0,
|
||||
config.image_size.1,
|
||||
FilterType::CatmullRom
|
||||
).into_rgb8();
|
||||
|
||||
let src_rgb = DynamicImage::from(image.borrow().to_rgb8()); // TODO this might be significantly inefficient for RGB8->RGB8 case
|
||||
|
||||
let mut dst_image = Image::new(config.image_size.0, config.image_size.1, fast_image_resize::PixelType::U8x3);
|
||||
|
||||
RESIZER.with_borrow_mut(|resizer| {
|
||||
resizer.resize(&src_rgb, &mut dst_image, None)
|
||||
}).context("resize failure")?;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
new.write_to(&mut csr, ImageFormat::Png)?;
|
||||
BmpEncoder::new(&mut csr).write_image(dst_image.buffer(), config.image_size.0, config.image_size.1, ExtendedColorType::Rgb8)?;
|
||||
Ok::<Vec<u8>, anyhow::Error>(buf)
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,7 @@ use serde::{Serialize, Deserialize};
|
||||
use std::io::BufReader;
|
||||
use rmp_serde::decode::Error as DecodeError;
|
||||
use std::fs;
|
||||
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
||||
|
||||
// TODO refactor
|
||||
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
|
||||
@ -40,7 +41,7 @@ fn main() -> Result<()> {
|
||||
match res {
|
||||
Ok(x) => {
|
||||
if x.timestamp > latest_timestamp {
|
||||
println!("{} {} https://reddit.com/r/{}/comments/{}", x.timestamp, count, x.subreddit, x.id);
|
||||
println!("{} {} https://reddit.com/r/{}/comments/{} {} https://mse.osmarks.net/?e={}", x.timestamp, count, x.subreddit, x.id, x.metadata.final_url, URL_SAFE.encode(x.embedding));
|
||||
latest_timestamp = x.timestamp;
|
||||
}
|
||||
},
|
||||
|
@ -16,7 +16,7 @@ use axum::{
|
||||
use common::resize_for_embed_sync;
|
||||
use compact_str::CompactString;
|
||||
use image::RgbImage;
|
||||
use image::{imageops::FilterType, io::Reader as ImageReader, DynamicImage, ImageFormat};
|
||||
use image::{imageops::FilterType, ImageReader, DynamicImage, ImageFormat};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::SqliteConnection;
|
||||
|
Loading…
Reference in New Issue
Block a user