mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-13 23:34:49 +00:00
performance improvements
This commit is contained in:
parent
7fa14d45ae
commit
b9bb629e6f
30
Cargo.lock
generated
30
Cargo.lock
generated
@ -697,6 +697,15 @@ dependencies = [
|
|||||||
"subtle",
|
"subtle",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "document-features"
|
||||||
|
version = "0.2.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cb6969eaabd2421f8a2775cfd2471a2b634372b4a25d41e3bd647b79912850a0"
|
||||||
|
dependencies = [
|
||||||
|
"litrs",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dotenvy"
|
name = "dotenvy"
|
||||||
version = "0.15.7"
|
version = "0.15.7"
|
||||||
@ -794,6 +803,20 @@ dependencies = [
|
|||||||
"hashbrown 0.13.2",
|
"hashbrown 0.13.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fast_image_resize"
|
||||||
|
version = "5.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a66a61fbfc84ef99a839499cf9e5a7c2951d2da874ea00f29ee938bc50d1b396"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
"cfg-if",
|
||||||
|
"document-features",
|
||||||
|
"image",
|
||||||
|
"num-traits",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastrand"
|
name = "fastrand"
|
||||||
version = "2.1.1"
|
version = "2.1.1"
|
||||||
@ -1558,6 +1581,12 @@ version = "0.4.14"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
|
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "litrs"
|
||||||
|
version = "0.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b4ce301924b7887e9d637144fdade93f9dfff9b60981d4ac161db09720d39aa5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lock_api"
|
name = "lock_api"
|
||||||
version = "0.4.12"
|
version = "0.4.12"
|
||||||
@ -1645,6 +1674,7 @@ dependencies = [
|
|||||||
"compact_str",
|
"compact_str",
|
||||||
"console-subscriber",
|
"console-subscriber",
|
||||||
"faiss",
|
"faiss",
|
||||||
|
"fast_image_resize",
|
||||||
"fastrand",
|
"fastrand",
|
||||||
"ffmpeg-the-third",
|
"ffmpeg-the-third",
|
||||||
"fnv",
|
"fnv",
|
||||||
|
@ -43,6 +43,7 @@ ffmpeg-the-third = "2.0"
|
|||||||
compact_str = { version = "0.8.0-beta", features = ["serde"] }
|
compact_str = { version = "0.8.0-beta", features = ["serde"] }
|
||||||
itertools = "0.13"
|
itertools = "0.13"
|
||||||
async-recursion = "1"
|
async-recursion = "1"
|
||||||
|
fast_image_resize = { version = "5", features = ["image"] }
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "reddit-dump"
|
name = "reddit-dump"
|
||||||
|
@ -5,7 +5,7 @@ from aiohttp import web
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import umsgpack
|
import msgpack
|
||||||
import collections
|
import collections
|
||||||
import queue
|
import queue
|
||||||
import open_clip
|
import open_clip
|
||||||
@ -58,7 +58,7 @@ if CONFIG.get("aitemplate_image_models"):
|
|||||||
if "patch_embed.proj.weight" not in key:
|
if "patch_embed.proj.weight" not in key:
|
||||||
params[key.replace(".", "_")] = value.cuda()
|
params[key.replace(".", "_")] = value.cuda()
|
||||||
#print(orig_key, key.replace(".", "_"))
|
#print(orig_key, key.replace(".", "_"))
|
||||||
|
|
||||||
params["patch_embed_proj_weight"] = conv_weights
|
params["patch_embed_proj_weight"] = conv_weights
|
||||||
|
|
||||||
return params
|
return params
|
||||||
@ -151,7 +151,7 @@ routes = web.RouteTableDef()
|
|||||||
@routes.post("/")
|
@routes.post("/")
|
||||||
async def run_inference(request):
|
async def run_inference(request):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
data = umsgpack.loads(await request.read())
|
data = msgpack.loads(await request.read())
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
results = None
|
results = None
|
||||||
def callback(*argv):
|
def callback(*argv):
|
||||||
@ -167,11 +167,11 @@ async def run_inference(request):
|
|||||||
else:
|
else:
|
||||||
status = 500
|
status = 500
|
||||||
print(results[1])
|
print(results[1])
|
||||||
return web.Response(body=umsgpack.dumps(body_data), status=status, content_type="application/msgpack")
|
return web.Response(body=msgpack.dumps(body_data), status=status, content_type="application/msgpack")
|
||||||
|
|
||||||
@routes.get("/config")
|
@routes.get("/config")
|
||||||
async def config(request):
|
async def config(request):
|
||||||
return web.Response(body=umsgpack.dumps({
|
return web.Response(body=msgpack.dumps({
|
||||||
"model": CONFIG["model"],
|
"model": CONFIG["model"],
|
||||||
"batch": BS,
|
"batch": BS,
|
||||||
"image_size": [ t for t in preprocess.transforms if isinstance(t, transforms.Resize) ][0].size,
|
"image_size": [ t for t in preprocess.transforms if isinstance(t, transforms.Resize) ][0].size,
|
||||||
@ -206,4 +206,4 @@ try:
|
|||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
import sys
|
import sys
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
Pillow==10.0.1
|
Pillow==10.0.1
|
||||||
prometheus-client==0.17.1
|
prometheus-client==0.17.1
|
||||||
u-msgpack-python==2.8.0
|
|
||||||
aiohttp==3.8.5
|
aiohttp==3.8.5
|
||||||
aiohttp-cors==0.7.0
|
aiohttp-cors==0.7.0
|
||||||
faiss-cpu==1.7.4
|
faiss-cpu==1.7.4
|
||||||
aiosqlite==0.19.0
|
aiosqlite==0.19.0
|
||||||
open-clip-torch==2.23.0
|
open-clip-torch==2.23.0
|
||||||
|
msgpack==1.1.0
|
||||||
|
@ -1,10 +1,19 @@
|
|||||||
|
use image::codecs::bmp::BmpEncoder;
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
use image::{DynamicImage, imageops::FilterType, ImageFormat};
|
use std::cell::RefCell;
|
||||||
|
use image::{DynamicImage, ExtendedColorType, ImageEncoder};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
use fast_image_resize::Resizer;
|
||||||
|
use fast_image_resize::images::Image;
|
||||||
|
use anyhow::Context;
|
||||||
|
|
||||||
|
std::thread_local! {
|
||||||
|
static RESIZER: RefCell<Resizer> = RefCell::new(Resizer::new());
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
pub struct InferenceServerConfig {
|
pub struct InferenceServerConfig {
|
||||||
@ -16,14 +25,18 @@ pub struct InferenceServerConfig {
|
|||||||
pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
|
pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
|
||||||
// the model currently in use wants aspect ratio 1:1 regardless of input
|
// the model currently in use wants aspect ratio 1:1 regardless of input
|
||||||
// I think this was previously being handled in the CLIP server but that is slightly lossy
|
// I think this was previously being handled in the CLIP server but that is slightly lossy
|
||||||
let new = image.borrow().resize_exact(
|
|
||||||
config.image_size.0,
|
let src_rgb = DynamicImage::from(image.borrow().to_rgb8()); // TODO this might be significantly inefficient for RGB8->RGB8 case
|
||||||
config.image_size.1,
|
|
||||||
FilterType::CatmullRom
|
let mut dst_image = Image::new(config.image_size.0, config.image_size.1, fast_image_resize::PixelType::U8x3);
|
||||||
).into_rgb8();
|
|
||||||
|
RESIZER.with_borrow_mut(|resizer| {
|
||||||
|
resizer.resize(&src_rgb, &mut dst_image, None)
|
||||||
|
}).context("resize failure")?;
|
||||||
|
|
||||||
let mut buf = Vec::new();
|
let mut buf = Vec::new();
|
||||||
let mut csr = Cursor::new(&mut buf);
|
let mut csr = Cursor::new(&mut buf);
|
||||||
new.write_to(&mut csr, ImageFormat::Png)?;
|
BmpEncoder::new(&mut csr).write_image(dst_image.buffer(), config.image_size.0, config.image_size.1, ExtendedColorType::Rgb8)?;
|
||||||
Ok::<Vec<u8>, anyhow::Error>(buf)
|
Ok::<Vec<u8>, anyhow::Error>(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ use serde::{Serialize, Deserialize};
|
|||||||
use std::io::BufReader;
|
use std::io::BufReader;
|
||||||
use rmp_serde::decode::Error as DecodeError;
|
use rmp_serde::decode::Error as DecodeError;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
||||||
|
|
||||||
// TODO refactor
|
// TODO refactor
|
||||||
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
|
||||||
@ -40,7 +41,7 @@ fn main() -> Result<()> {
|
|||||||
match res {
|
match res {
|
||||||
Ok(x) => {
|
Ok(x) => {
|
||||||
if x.timestamp > latest_timestamp {
|
if x.timestamp > latest_timestamp {
|
||||||
println!("{} {} https://reddit.com/r/{}/comments/{}", x.timestamp, count, x.subreddit, x.id);
|
println!("{} {} https://reddit.com/r/{}/comments/{} {} https://mse.osmarks.net/?e={}", x.timestamp, count, x.subreddit, x.id, x.metadata.final_url, URL_SAFE.encode(x.embedding));
|
||||||
latest_timestamp = x.timestamp;
|
latest_timestamp = x.timestamp;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -16,7 +16,7 @@ use axum::{
|
|||||||
use common::resize_for_embed_sync;
|
use common::resize_for_embed_sync;
|
||||||
use compact_str::CompactString;
|
use compact_str::CompactString;
|
||||||
use image::RgbImage;
|
use image::RgbImage;
|
||||||
use image::{imageops::FilterType, io::Reader as ImageReader, DynamicImage, ImageFormat};
|
use image::{imageops::FilterType, ImageReader, DynamicImage, ImageFormat};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::SqliteConnection;
|
use sqlx::SqliteConnection;
|
||||||
|
Loading…
Reference in New Issue
Block a user