mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-09-21 01:59:37 +00:00

Rewrite entire application (well, backend) in Rust and also Go

I decided I wanted to integrate the experimental OCR thing better, so I rewrote it in Go and also integrated the thumbnailer.
However, Go is a bad language and I only used it out of spite.
It turned out to have a very hard-to-fix memory leak caused by some unclear interaction between libvips and both sets of bindings I tried, so I had Claude-3 transpile it to Rust, then spent a while fixing the several mistakes it made and making tweaks.
The new Rust version works, although I need to actually do something with the OCR data and make the index queryable concurrently.
This commit is contained in:
osmarks 2024-05-21 00:09:04 +01:00
parent fa863c2075
commit 7cb42e028f
24 changed files with 7192 additions and 51 deletions
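
For orientation, the backend serves a single JSON query endpoint taking weighted terms, matching the QueryTerm/QueryRequest/QueryResult structs in the Go prototype below and the updated frontend util.js. A minimal client sketch (URL and query values purely illustrative, taken from the localhost frontend_config.json in this commit):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Shapes mirror QueryTerm/QueryRequest/QueryResult in the backend source below.
type QueryTerm struct {
	Text   *string  `json:"text,omitempty"`
	Image  *string  `json:"image,omitempty"` // base64-encoded image bytes
	Weight *float32 `json:"weight,omitempty"`
}

type QueryRequest struct {
	Terms []QueryTerm `json:"terms"`
	K     *int        `json:"k,omitempty"`
}

type QueryResult struct {
	Matches    [][]interface{}   `json:"matches"` // [distance, filename, filename hash, format bitmask]
	Formats    []string          `json:"formats"`
	Extensions map[string]string `json:"extensions"`
}

func main() {
	text := "example search text"
	weight := float32(1.0)
	body, _ := json.Marshal(QueryRequest{Terms: []QueryTerm{{Text: &text, Weight: &weight}}})
	// backend_url as in the updated frontend_config.json (hypothetical local instance)
	resp, err := http.Post("http://localhost:1707/", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	var res QueryResult
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
		panic(err)
	}
	fmt.Println(len(res.Matches), res.Formats)
}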

3
.gitignore vendored
View File

@ -6,3 +6,6 @@ meme-rater/meta/
meme-rater/*.sqlite3* meme-rater/*.sqlite3*
meme-rater/deploy_for_training.sh meme-rater/deploy_for_training.sh
node_modules/* node_modules/*
node_modules
*sqlite3*
thumbtemp

View File

@ -0,0 +1,12 @@
{
"db_name": "SQLite",
"query": "INSERT OR IGNORE INTO files (filename) VALUES (?)",
"describe": {
"columns": [],
"parameters": {
"Right": 1
},
"nullable": []
},
"hash": "0d5b91c01acf72be0cd78f1a0c58c417e06d7c4e53e1ec542243ccf2808bbab7"
}

View File

@ -0,0 +1,12 @@
{
"db_name": "SQLite",
"query": "UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 4
},
"nullable": []
},
"hash": "63edaa9692deb1a9fb17d9e16905a299878bc0fe4af582c6791a411741ee41d9"
}

View File

@ -0,0 +1,12 @@
{
"db_name": "SQLite",
"query": "UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 3
},
"nullable": []
},
"hash": "b6803e2443445de725290dde85de5b5ef87958bec4d6db6bd06660b71f7a1ad0"
}

View File

@ -0,0 +1,12 @@
{
"db_name": "SQLite",
"query": "UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 3
},
"nullable": []
},
"hash": "bed71d48c691bff7464b1aa767162df98ee2fcbe8df11f7db18a3647b2e0f1a2"
}

View File

@ -0,0 +1,62 @@
{
"db_name": "SQLite",
"query": "SELECT * FROM files WHERE filename = ?",
"describe": {
"columns": [
{
"name": "filename",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "embedding_time",
"ordinal": 1,
"type_info": "Int64"
},
{
"name": "ocr_time",
"ordinal": 2,
"type_info": "Int64"
},
{
"name": "thumbnail_time",
"ordinal": 3,
"type_info": "Int64"
},
{
"name": "embedding",
"ordinal": 4,
"type_info": "Blob"
},
{
"name": "ocr",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "raw_ocr_segments",
"ordinal": 6,
"type_info": "Blob"
},
{
"name": "thumbnails",
"ordinal": 7,
"type_info": "Blob"
}
],
"parameters": {
"Right": 1
},
"nullable": [
false,
true,
true,
true,
true,
true,
true,
true
]
},
"hash": "ec2da4ab11ede7a9a468ff3a50c55e0f6503fddd369f2c3031f39c0759bb97a0"
}

View File

@ -0,0 +1,12 @@
{
"db_name": "SQLite",
"query": "DELETE FROM files WHERE filename = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 1
},
"nullable": []
},
"hash": "ee6eca5b34c3fbf76cd10932db35c6d8631e48be9166c02b593020a17fcf2686"
}

3320
Cargo.lock generated Normal file

File diff suppressed because it is too large

34
Cargo.toml Normal file
View File

@ -0,0 +1,34 @@
[package]
name = "meme-search-engine"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1", features = ["full"] }
axum = "0.7"
image = { version = "0.25", features = ["avif"] }
reqwest = { version = "0.12", features = ["multipart"] }
serde = { version = "1", features = ["derive"] }
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
walkdir = "1"
log = "0.4"
rmp-serde = "1"
serde_json = "1"
chrono = "0.4"
base64 = "0.22"
anyhow = "1"
fnv = "1"
faiss = "0.12"
ndarray = "0.15"
half = { version = "2" }
regex = "1"
pretty_env_logger = "0.5"
futures-util = "0.3"
tokio-stream = "0.1"
num_cpus = "1"
serde_bytes = "0.11"
tower-http = { version = "0.5", features = ["cors"] }
tower = "0.4"
json5 = "0.4"

View File

@ -62,6 +62,8 @@
.result .result
border: 1px solid gray border: 1px solid gray
*
display: block
.result img .result img
width: 100% width: 100%
</style> </style>
@ -109,17 +111,22 @@
<Loading /> <Loading />
{/if} {/if}
{#if results} {#if results}
{#if displayedResults.length === 0}
No results. Wait for index rebuild.
{/if}
<Masonry bind:refreshLayout={refreshLayout} colWidth="minmax(Min(20em, 100%), 1fr)" items={displayedResults}> <Masonry bind:refreshLayout={refreshLayout} colWidth="minmax(Min(20em, 100%), 1fr)" items={displayedResults}>
{#each displayedResults as result} {#each displayedResults as result}
{#key result.file} {#key result.file}
<div class="result"> <div class="result">
<a href={util.getURL(result.file)}> <a href={util.getURL(result)}>
<picture> <picture>
{#if util.hasThumbnails(result.file)} {#if util.hasFormat(results, result, "avifl")}
<source srcset={util.thumbnailPath(result.file, "avif-lq") + ", " + util.thumbnailPath(result.file, "avif-hq") + " 2x"} type="image/avif" /> <source srcset={util.thumbnailURL(results, result, "avifl") + (util.hasFormat(results, result, "avifh") ? ", " + util.thumbnailURL(results, result, "avifh") + " 2x" : "")} type="image/avif" />
<source srcset={util.thumbnailPath(result.file, "jpeg-800") + " 800w, " + util.thumbnailPath(result.file, "jpeg-fullscale")} type="image/jpeg" />
{/if} {/if}
<img src={util.getURL(result.file)} on:load={updateCounter} on:error={updateCounter} alt={result.caption || result.file}> {#if util.hasFormat(results, result, "jpegl")}
<source srcset={util.thumbnailURL(results, result, "jpegl") + (util.hasFormat(results, result, "jpegh") ? ", " + util.thumbnailURL(results, result, "jpegh") + " 2x" : "")} type="image/jpeg" />
{/if}
<img src={util.getURL(result)} on:load={updateCounter} on:error={updateCounter} alt={result[1]}>
</picture> </picture>
</a> </a>
</div> </div>
@ -171,9 +178,7 @@
let displayedResults = [] let displayedResults = []
const runSearch = async () => { const runSearch = async () => {
if (!resultPromise) { if (!resultPromise) {
let args = {} let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
args.text = queryTerms.filter(x => x.type === "text" && x.text).map(({ text, weight, sign }) => [ text, weight * { "+": 1, "-": -1 }[sign] ])
args.images = queryTerms.filter(x => x.type === "image").map(({ imageData, weight, sign }) => [ imageData, weight * { "+": 1, "-": -1 }[sign] ])
resultPromise = util.doQuery(args).then(res => { resultPromise = util.doQuery(args).then(res => {
error = null error = null
results = res results = res
@ -181,7 +186,8 @@
displayedResults = [] displayedResults = []
pendingImageLoads = 0 pendingImageLoads = 0
for (let i = 0; i < chunkSize; i++) { for (let i = 0; i < chunkSize; i++) {
displayedResults.push(results[i]) if (i >= results.matches.length) break
displayedResults.push(results.matches[i])
pendingImageLoads += 1 pendingImageLoads += 1
} }
redrawGrid() redrawGrid()
@ -195,7 +201,8 @@
if (window.scrollY + window.innerHeight < heightThreshold) return; if (window.scrollY + window.innerHeight < heightThreshold) return;
let init = displayedResults.length let init = displayedResults.length
for (let i = 0; i < chunkSize; i++) { for (let i = 0; i < chunkSize; i++) {
displayedResults.push(results[init + i]) if (init + i >= results.matches.length) break
displayedResults.push(results.matches[init + i])
pendingImageLoads += 1 pendingImageLoads += 1
} }
displayedResults = displayedResults displayedResults = displayedResults

View File

@ -7,7 +7,7 @@ esbuild
.build({ .build({
entryPoints: [path.join(__dirname, "app.js")], entryPoints: [path.join(__dirname, "app.js")],
bundle: true, bundle: true,
minify: true, minify: false,
outfile: path.join(__dirname, "../static/app.js"), outfile: path.join(__dirname, "../static/app.js"),
plugins: [sveltePlugin({ plugins: [sveltePlugin({
preprocess: { preprocess: {

View File

@ -1,7 +1,8 @@
import * as config from "../../frontend_config.json" import * as config from "../../frontend_config.json"
import * as backendConfig from "../../mse_config.json"
import * as formats from "../../formats.json" import * as formats from "../../formats.json"
export const getURL = x => config.image_path + x export const getURL = x => config.image_path + x[1]
export const doQuery = args => fetch(config.backend_url, { export const doQuery = args => fetch(config.backend_url, {
method: "POST", method: "POST",
@ -11,15 +12,11 @@ export const doQuery = args => fetch(config.backend_url, {
body: JSON.stringify(args) body: JSON.stringify(args)
}).then(x => x.json()) }).then(x => x.json())
const filesafeCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-" export const hasFormat = (results, result, format) => {
export const thumbnailPath = (originalPath, format) => { return result[3] && (1 << results.formats.indexOf(format)) !== 0
const extension = formats.formats[format][0]
// Python and JS have minor differences in string handling wrt. astral characters which could result in incorrect quantities of dashes. Fortunately, Array.from handles this correctly.
return config.thumb_path + `${Array.from(originalPath).map(x => filesafeCharset.includes(x) ? x : "_").join("")}.${format}${extension}`
} }
const thumbedExtensions = formats.extensions export const thumbnailURL = (results, result, format) => {
export const hasThumbnails = t => { console.log("RES", results)
const parts = t.split(".") return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}`
return thumbedExtensions.includes("." + parts[parts.length - 1])
} }
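
The new helpers read fields of each match row: result[1] is the filename, result[2] the filename hash, and result[3] a bitmask over results.formats, as assembled by buildIndex/queryIndex in the Go source further down. A small illustrative Go helper (names hypothetical) for turning that bitmask back into format names:

package main

import "fmt"

// decodeFormats mirrors how buildIndex assigns one bit per entry of
// Index.formatNames: bit i of a match's format code is set when thumbnail
// format formatNames[i] was generated for that file.
func decodeFormats(formatCode int64, formatNames []string) []string {
	available := make([]string, 0, len(formatNames))
	for i, name := range formatNames {
		if formatCode&(1<<i) != 0 {
			available = append(available, name)
		}
	}
	return available
}

func main() {
	// hypothetical: formats in the order the backend reported them
	fmt.Println(decodeFormats(0b101, []string{"jpegl", "avifl", "avifh"})) // [jpegl avifh]
}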

View File

@ -1,5 +1,5 @@
{ {
"backend_url": "https://mse.osmarks.net/backend", "backend_url": "http://localhost:1707/",
"image_path": "https://i2.osmarks.net/memes-or-something/", "image_path": "http://localhost:7858/",
"thumb_path": "https://i2.osmarks.net/thumbs/memes-or-something_" "thumb_path": "http://localhost:7857/"
} }

View File

@ -0,0 +1,26 @@
module meme-search
go 1.22.2
require (
github.com/DataIntelligenceCrew/go-faiss v0.2.0
github.com/jmoiron/sqlx v1.4.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/samber/lo v1.39.0
github.com/titanous/json5 v1.0.0
github.com/vmihailenco/msgpack v4.0.4+incompatible
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.7.0
)
require (
github.com/davidbyttow/govips/v2 v2.14.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/h2non/bimg v1.1.9 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/image v0.16.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/text v0.15.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/protobuf v1.26.0 // indirect
)

100
misc/bad-go-version/go.sum Normal file
View File

@ -0,0 +1,100 @@
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DataIntelligenceCrew/go-faiss v0.2.0 h1:c0pxAr0vldXIuE4DZnqpl6FuuH1uZd45d+NiQHKg1uU=
github.com/DataIntelligenceCrew/go-faiss v0.2.0/go.mod h1:4Gi7G3PF78IwZigTL2M1AJXOaAgxyL66vCqUYVaNgwk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davidbyttow/govips/v2 v2.14.0 h1:il3pX0XMZ5nlwipkFJHRZ3vGzcdXWApARalJxNpRHJU=
github.com/davidbyttow/govips/v2 v2.14.0/go.mod h1:eglyvgm65eImDiJJk4wpj9LSz4pWivPzWgDqkxWJn5k=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/h2non/bimg v1.1.9 h1:WH20Nxko9l/HFm4kZCA3Phbgu2cbHvYzxwxn9YROEGg=
github.com/h2non/bimg v1.1.9/go.mod h1:R3+UiYwkK4rQl6KVFTOFJHitgLbZXBZNFh2cv3AEbp8=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/json5/json5-go v0.0.0-20160331055859-40c2958e3bf8 h1:BQuwfXQRDQMI8YNqINKNlFV23P0h07ZvOQAtezAEsP8=
github.com/json5/json5-go v0.0.0-20160331055859-40c2958e3bf8/go.mod h1:7n1PdYNh4RIHTvILru80IEstTADqQz/wmjeNXTcC9rA=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
github.com/titanous/json5 v1.0.0 h1:hJf8Su1d9NuI/ffpxgxQfxh/UiBFZX7bMPid0rIL/7s=
github.com/titanous/json5 v1.0.0/go.mod h1:7JH1M8/LHKc6cyP5o5g3CSaRj+mBrIimTxzpvmckH8c=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0=
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -0,0 +1,877 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"hash/fnv"
"io"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"strings"
"sync"
"time"
"github.com/DataIntelligenceCrew/go-faiss"
"github.com/h2non/bimg"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"github.com/samber/lo"
"github.com/vmihailenco/msgpack"
"github.com/x448/float16"
"golang.org/x/sync/errgroup"
)
type Config struct {
ClipServer string `json:"clip_server"`
DbPath string `json:"db_path"`
Port int16 `json:"port"`
Files string `json:"files"`
EnableOCR bool `json:"enable_ocr"`
ThumbsPath string `json:"thumbs_path"`
EnableThumbnails bool `json:"enable_thumbs"`
}
type Index struct {
vectors *faiss.IndexImpl
filenames []string
formatCodes []int64
formatNames []string
}
var schema = `
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
embedding_time INTEGER,
ocr_time INTEGER,
thumbnail_time INTEGER,
embedding BLOB,
ocr TEXT,
raw_ocr_segments BLOB,
thumbnails BLOB
);
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
filename,
ocr,
tokenize='unicode61 remove_diacritics 2',
content='ocr'
);
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_upd AFTER UPDATE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
`
type FileRecord struct {
Filename string `db:"filename"`
EmbedTime int64 `db:"embedding_time"`
OcrTime int64 `db:"ocr_time"`
ThumbnailTime int64 `db:"thumbnail_time"`
Embedding []byte `db:"embedding"`
Ocr string `db:"ocr"`
RawOcrSegments []byte `db:"raw_ocr_segments"`
Thumbnails []byte `db:"thumbnails"`
}
type InferenceServerConfig struct {
BatchSize uint `msgpack:"batch"`
ImageSize []uint `msgpack:"image_size"`
EmbeddingSize uint `msgpack:"embedding_size"`
}
func decodeMsgpackFrom[O interface{}](resp *http.Response) (O, error) {
var result O
respData, err := io.ReadAll(resp.Body)
if err != nil {
return result, err
}
err = msgpack.Unmarshal(respData, &result)
return result, err
}
func queryClipServer[I interface{}, O interface{}](config Config, path string, data I) (O, error) {
var result O
b, err := msgpack.Marshal(data)
if err != nil {
return result, err
}
resp, err := http.Post(config.ClipServer+path, "application/msgpack", bytes.NewReader(b))
if err != nil {
return result, err
}
defer resp.Body.Close()
return decodeMsgpackFrom[O](resp)
}
type LoadedImage struct {
image *bimg.Image
filename string
originalSize int
}
type EmbeddingInput struct {
image []byte
filename string
}
type EmbeddingRequest struct {
Images [][]byte `msgpack:"images"`
Text []string `msgpack:"text"`
}
type EmbeddingResponse = [][]byte
func timestamp() int64 {
return time.Now().UnixMicro()
}
type ImageFormatConfig struct {
targetWidth int
targetFilesize int
quality int
format bimg.ImageType
extension string
}
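// Filenames are keyed by a 128-bit FNV-1 hash, base64url-encoded without padding;
// this serves as the thumbnail filename prefix and is the third element of each query match.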
func generateFilenameHash(filename string) string {
hasher := fnv.New128()
hasher.Write([]byte(filename))
hash := hasher.Sum(make([]byte, 0))
return base64.RawURLEncoding.EncodeToString(hash)
}
func generateThumbnailFilename(filename string, formatName string, formatConfig ImageFormatConfig) string {
return fmt.Sprintf("%s%s.%s", generateFilenameHash(filename), formatName, formatConfig.extension)
}
func initializeDatabase(config Config) (*sqlx.DB, error) {
db, err := sqlx.Connect("sqlite3", config.DbPath)
if err != nil {
return nil, err
}
_, err = db.Exec("PRAGMA busy_timeout = 2000; PRAGMA journal_mode = WAL")
if err != nil {
return nil, err
}
return db, nil
}
func imageFormats(config Config) map[string]ImageFormatConfig {
return map[string]ImageFormatConfig{
"jpegl": {
targetWidth: 800,
quality: 70,
format: bimg.JPEG,
extension: "jpg",
},
"jpegh": {
targetWidth: 1600,
quality: 80,
format: bimg.JPEG,
extension: "jpg",
},
"jpeg256kb": {
targetWidth: 500,
targetFilesize: 256000,
format: bimg.JPEG,
extension: "jpg",
},
"avifh": {
targetWidth: 1600,
quality: 80,
format: bimg.AVIF,
extension: "avif",
},
"avifl": {
targetWidth: 800,
quality: 30,
format: bimg.AVIF,
extension: "avif",
},
}
}
func ingestFiles(config Config, backend InferenceServerConfig) error {
var wg errgroup.Group
var iwg errgroup.Group
// We assume everything is either a modern browser (low-DPI or high-DPI), an ancient browser or a ComputerCraft machine abusing Extra Utilities 2 screens.
var formats = imageFormats(config)
db, err := initializeDatabase(config)
if err != nil {
return err
}
defer db.Close()
toProcess := make(chan FileRecord, 100)
toEmbed := make(chan EmbeddingInput, backend.BatchSize)
toThumbnail := make(chan LoadedImage, 30)
toOCR := make(chan LoadedImage, 30)
embedBatches := make(chan []EmbeddingInput, 1)
// image loading and preliminary resizing
for range runtime.NumCPU() {
iwg.Go(func() error {
for record := range toProcess {
path := filepath.Join(config.Files, record.Filename)
buffer, err := bimg.Read(path)
if err != nil {
log.Println("could not read ", record.Filename)
}
img := bimg.NewImage(buffer)
if record.Embedding == nil {
resized, err := img.Process(bimg.Options{
Width: int(backend.ImageSize[0]),
Height: int(backend.ImageSize[1]),
Force: true,
Type: bimg.PNG,
Interpretation: bimg.InterpretationSRGB,
})
if err != nil {
log.Println("resize failure", record.Filename, err)
} else {
toEmbed <- EmbeddingInput{
image: resized,
filename: record.Filename,
}
}
}
if record.Thumbnails == nil && config.EnableThumbnails {
toThumbnail <- LoadedImage{
image: img,
filename: record.Filename,
originalSize: len(buffer),
}
}
if record.RawOcrSegments == nil && config.EnableOCR {
toOCR <- LoadedImage{
image: img,
filename: record.Filename,
}
}
}
return nil
})
}
if config.EnableThumbnails {
for range runtime.NumCPU() {
wg.Go(func() error {
for image := range toThumbnail {
generatedFormats := make([]string, 0)
for formatName, formatConfig := range formats {
var err error
var resized []byte
if formatConfig.targetFilesize != 0 {
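// binary search on encoder quality: keep the highest quality whose output is no
// larger than the original file (targetFilesize only selects this path; the loop
// itself compares against the original size)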
lb := 1
ub := 100
for {
quality := (lb + ub) / 2
resized, err = image.image.Process(bimg.Options{
Width: formatConfig.targetWidth,
Type: formatConfig.format,
Speed: 4,
Quality: quality,
StripMetadata: true,
Enlarge: false,
})
if len(resized) > image.originalSize {
ub = quality
} else {
lb = quality + 1
}
if lb >= ub {
break
}
}
} else {
resized, err = image.image.Process(bimg.Options{
Width: formatConfig.targetWidth,
Type: formatConfig.format,
Speed: 4,
Quality: formatConfig.quality,
StripMetadata: true,
Enlarge: false,
})
}
if err != nil {
log.Println("thumbnailing failure", image.filename, err)
continue
}
if len(resized) < image.originalSize {
generatedFormats = append(generatedFormats, formatName)
err = bimg.Write(filepath.Join(config.ThumbsPath, generateThumbnailFilename(image.filename, formatName, formatConfig)), resized)
if err != nil {
return err
}
}
}
formatsData, err := msgpack.Marshal(generatedFormats)
if err != nil {
return err
}
_, err = db.Exec("UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?", formatsData, timestamp(), image.filename)
if err != nil {
return err
}
}
return nil
})
}
}
if config.EnableOCR {
for range 100 {
wg.Go(func() error {
for image := range toOCR {
scan, err := scanImage(image.image)
if err != nil {
log.Println("OCR failure", image.filename, err)
continue
}
ocrText := ""
for _, segment := range scan {
ocrText += segment.text
ocrText += "\n"
}
ocrData, err := msgpack.Marshal(scan)
if err != nil {
return err
}
_, err = db.Exec("UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?", ocrText, ocrData, timestamp(), image.filename)
if err != nil {
return err
}
}
return nil
})
}
}
wg.Go(func() error {
buffer := make([]EmbeddingInput, 0, backend.BatchSize)
for input := range toEmbed {
buffer = append(buffer, input)
if len(buffer) == int(backend.BatchSize) {
embedBatches <- buffer
buffer = make([]EmbeddingInput, 0, backend.BatchSize)
}
}
if len(buffer) > 0 {
embedBatches <- buffer
}
close(embedBatches)
return nil
})
for range 3 {
wg.Go(func() error {
for batch := range embedBatches {
result, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "", EmbeddingRequest{
Images: lo.Map(batch, func(item EmbeddingInput, _ int) []byte { return item.image }),
})
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
for i, vector := range result {
_, err = tx.Exec("UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?", timestamp(), vector, batch[i].filename)
if err != nil {
return err
}
}
err = tx.Commit()
if err != nil {
return err
}
}
return nil
})
}
filenamesOnDisk := make(map[string]struct{})
err = filepath.WalkDir(config.Files, func(path string, d os.DirEntry, err error) error {
filename := strings.TrimPrefix(path, config.Files)
if err != nil {
return err
}
if d.IsDir() {
return nil
}
filenamesOnDisk[filename] = struct{}{}
records := []FileRecord{}
err = db.Select(&records, "SELECT * FROM files WHERE filename = ?", filename)
if err != nil {
return err
}
stat, err := d.Info()
if err != nil {
return err
}
modtime := stat.ModTime().UnixMicro()
if len(records) == 0 || modtime > records[0].EmbedTime || modtime > records[0].OcrTime || modtime > records[0].ThumbnailTime {
_, err = db.Exec("INSERT OR IGNORE INTO files VALUES (?, 0, 0, 0, '', '', '', '')", filename)
if err != nil {
return err
}
record := FileRecord{
Filename: filename,
}
if len(records) > 0 {
record = records[0]
}
if modtime > record.EmbedTime || len(record.Embedding) == 0 {
record.Embedding = nil
}
if modtime > record.OcrTime || len(record.RawOcrSegments) == 0 {
record.RawOcrSegments = nil
}
if modtime > record.ThumbnailTime || len(record.Thumbnails) == 0 {
record.Thumbnails = nil
}
toProcess <- record
}
return nil
})
if err != nil {
return err
}
close(toProcess)
err = iwg.Wait()
close(toEmbed)
close(toThumbnail)
if err != nil {
return err
}
err = wg.Wait()
if err != nil {
return err
}
rows, err := db.Queryx("SELECT filename FROM files")
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
for rows.Next() {
var filename string
err := rows.Scan(&filename)
if err != nil {
return err
}
if _, ok := filenamesOnDisk[filename]; !ok {
_, err = tx.Exec("DELETE FROM files WHERE filename = ?", filename)
if err != nil {
return err
}
}
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
const INDEX_ADD_BATCH = 512
func buildIndex(config Config, backend InferenceServerConfig) (Index, error) {
var index Index
db, err := initializeDatabase(config)
if err != nil {
return index, err
}
defer db.Close()
newFAISSIndex, err := faiss.IndexFactory(int(backend.EmbeddingSize), "SQfp16", faiss.MetricInnerProduct)
if err != nil {
return index, err
}
index.vectors = newFAISSIndex
var count int
err = db.Get(&count, "SELECT COUNT(*) FROM files")
if err != nil {
return index, err
}
index.filenames = make([]string, 0, count)
index.formatCodes = make([]int64, 0, count)
buffer := make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
index.formatNames = make([]string, 0, 5)
record := FileRecord{}
rows, err := db.Queryx("SELECT * FROM files")
if err != nil {
return index, err
}
for rows.Next() {
err := rows.StructScan(&record)
if err != nil {
return index, err
}
if len(record.Embedding) > 0 {
index.filenames = append(index.filenames, record.Filename)
for i := 0; i < len(record.Embedding); i += 2 {
buffer = append(buffer, float16.Frombits(uint16(record.Embedding[i])+uint16(record.Embedding[i+1])<<8).Float32())
}
if len(buffer) == cap(buffer) {
index.vectors.Add(buffer)
buffer = make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
}
formats := make([]string, 0, 5)
if len(record.Thumbnails) > 0 {
err := msgpack.Unmarshal(record.Thumbnails, &formats)
if err != nil {
return index, err
}
}
formatCode := int64(0)
for _, formatString := range formats {
found := false
for i, name := range index.formatNames {
if name == formatString {
formatCode |= 1 << i
found = true
break
}
}
if !found {
newIndex := len(index.formatNames)
formatCode |= 1 << newIndex
index.formatNames = append(index.formatNames, formatString)
}
}
index.formatCodes = append(index.formatCodes, formatCode)
}
}
if len(buffer) > 0 {
index.vectors.Add(buffer)
}
return index, nil
}
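// embeddings are stored as packed little-endian IEEE 754 half-precision floats, two bytes per dimension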
func decodeFP16Buffer(buf []byte) []float32 {
out := make([]float32, 0, len(buf)/2)
for i := 0; i < len(buf); i += 2 {
out = append(out, float16.Frombits(uint16(buf[i])+uint16(buf[i+1])<<8).Float32())
}
return out
}
type EmbeddingVector []float32
type QueryResult struct {
Matches [][]interface{} `json:"matches"`
Formats []string `json:"formats"`
Extensions map[string]string `json:"extensions"`
}
// this terrible language cannot express tagged unions
type QueryTerm struct {
Embedding *EmbeddingVector `json:"embedding"`
Image *string `json:"image"` // base64
Text *string `json:"text"`
Weight *float32 `json:"weight"`
}
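// A term is intended to carry one of embedding/image/text plus an optional weight
// (default 1); handleRequest sums the weighted embeddings of every field that is set.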
type QueryRequest struct {
Terms []QueryTerm `json:"terms"`
K *int `json:"k"`
}
func queryIndex(index *Index, query EmbeddingVector, k int) (QueryResult, error) {
var qr QueryResult
distances, ids, err := index.vectors.Search(query, int64(k))
if err != nil {
return qr, err
}
items := lo.Map(lo.Zip2(distances, ids), func(x lo.Tuple2[float32, int64], i int) []interface{} {
return []interface{}{
x.A,
index.filenames[x.B],
generateFilenameHash(index.filenames[x.B]),
index.formatCodes[x.B],
}
})
return QueryResult{
Matches: items,
Formats: index.formatNames,
}, nil
}
func handleRequest(config Config, backendConfig InferenceServerConfig, index *Index, w http.ResponseWriter, req *http.Request) error {
if req.Body == nil {
io.WriteString(w, "OK") // health check
return nil
}
dec := json.NewDecoder(req.Body)
var qreq QueryRequest
err := dec.Decode(&qreq)
if err != nil {
return err
}
totalEmbedding := make(EmbeddingVector, backendConfig.EmbeddingSize)
imageBatch := make([][]byte, 0)
imageWeights := make([]float32, 0)
textBatch := make([]string, 0)
textWeights := make([]float32, 0)
for _, term := range qreq.Terms {
if term.Image != nil {
bytes, err := base64.StdEncoding.DecodeString(*term.Image)
if err != nil {
return err
}
loaded := bimg.NewImage(bytes)
resized, err := loaded.Process(bimg.Options{
Width: int(backendConfig.ImageSize[0]),
Height: int(backendConfig.ImageSize[1]),
Force: true,
Type: bimg.PNG,
Interpretation: bimg.InterpretationSRGB,
})
if err != nil {
return err
}
imageBatch = append(imageBatch, resized)
if term.Weight != nil {
imageWeights = append(imageWeights, *term.Weight)
} else {
imageWeights = append(imageWeights, 1)
}
}
if term.Text != nil {
textBatch = append(textBatch, *term.Text)
if term.Weight != nil {
textWeights = append(textWeights, *term.Weight)
} else {
textWeights = append(textWeights, 1)
}
}
if term.Embedding != nil {
weight := float32(1.0)
if term.Weight != nil {
weight = *term.Weight
}
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += (*term.Embedding)[i] * weight
}
}
}
if len(imageBatch) > 0 {
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
Images: imageBatch,
})
if err != nil {
return err
}
for j, emb := range embs {
embd := decodeFP16Buffer(emb)
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += embd[i] * imageWeights[j]
}
}
}
if len(textBatch) > 0 {
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
Text: textBatch,
})
if err != nil {
return err
}
for j, emb := range embs {
embd := decodeFP16Buffer(emb)
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += embd[i] * textWeights[j]
}
}
}
k := 1000
if qreq.K != nil {
k = *qreq.K
}
w.Header().Add("Content-Type", "application/json")
enc := json.NewEncoder(w)
qres, err := queryIndex(index, totalEmbedding, k)
if err != nil {
return err
}
qres.Extensions = make(map[string]string)
for k, v := range imageFormats(config) {
qres.Extensions[k] = v.extension
}
err = enc.Encode(qres)
if err != nil {
return err
}
return nil
}
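// Presumably aimed at the libvips memory growth mentioned in the commit message:
// silence vips warnings and disable the libvips operation/memory caches entirely.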
func init() {
os.Setenv("VIPS_WARNING", "FALSE") // this does not actually work
bimg.VipsCacheSetMax(0)
bimg.VipsCacheSetMaxMem(0)
}
func main() {
content, err := os.ReadFile(os.Args[1])
if err != nil {
log.Fatal("config file unreadable ", err)
}
var config Config
err = json.Unmarshal(content, &config)
if err != nil {
log.Fatal("config file wrong ", err)
}
fmt.Println(config)
db, err := sqlx.Connect("sqlite3", config.DbPath)
if err != nil {
log.Fatal("DB connection failure ", db)
}
db.MustExec(schema)
var backend InferenceServerConfig
for {
resp, err := http.Get(config.ClipServer + "/config")
if err != nil {
log.Println("backend failed (fetch) ", err)
time.Sleep(time.Second)
continue
}
backend, err = decodeMsgpackFrom[InferenceServerConfig](resp)
resp.Body.Close()
if err != nil {
log.Println("backend failed (parse) ", err)
} else {
break
}
time.Sleep(time.Second)
}
requestIngest := make(chan struct{}, 1)
var index *Index
// maybe this ought to be mutexed?
var lastError *error
// there's not a neat way to reusably broadcast to multiple channels, but I *can* abuse WaitGroups probably
// this might cause horrible concurrency issues, but you brought me to this point, Go designers
var wg sync.WaitGroup
go func() {
for {
wg.Add(1)
log.Println("ingest running")
err := ingestFiles(config, backend)
if err != nil {
log.Println("ingest failed ", err)
lastError = &err
} else {
newIndex, err := buildIndex(config, backend)
if err != nil {
log.Println("index build failed ", err)
lastError = &err
} else {
lastError = nil
index = &newIndex
}
}
wg.Done()
<-requestIngest
}
}()
newIndex, err := buildIndex(config, backend)
index = &newIndex
if err != nil {
log.Fatal("index build failed ", err)
}
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Headers", "Content-Type")
if req.Method == "OPTIONS" {
w.WriteHeader(204)
return
}
err := handleRequest(config, backend, index, w, req)
if err != nil {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]string{
"error": err.Error(),
})
}
})
http.HandleFunc("/reload", func(w http.ResponseWriter, req *http.Request) {
if req.Method == "POST" {
log.Println("requesting index reload")
select {
case requestIngest <- struct{}{}:
default:
}
wg.Wait()
if lastError == nil {
w.Write([]byte("OK"))
} else {
w.WriteHeader(500)
w.Write([]byte((*lastError).Error()))
}
}
})
http.HandleFunc("/profile", func(w http.ResponseWriter, req *http.Request) {
f, err := os.Create("mem.pprof")
if err != nil {
log.Fatal("could not create memory profile: ", err)
}
defer f.Close()
var m runtime.MemStats
runtime.ReadMemStats(&m)
log.Printf("Memory usage: Alloc=%v, TotalAlloc=%v, Sys=%v", m.Alloc, m.TotalAlloc, m.Sys)
log.Println(bimg.VipsMemory())
bimg.VipsDebugInfo()
runtime.GC() // Trigger garbage collection
if err := pprof.WriteHeapProfile(f); err != nil {
log.Fatal("could not write memory profile: ", err)
}
})
log.Println("starting server")
http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)
}
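
In the Go prototype above, the ingest goroutine republishes index by plain pointer assignment while HTTP handlers read it concurrently (the code itself asks "maybe this ought to be mutexed?"). Purely as an illustrative sketch under that observation, not something this commit does, a race-free publish/load of a rebuilt index could look like:

package main

import (
	"fmt"
	"sync/atomic"
)

// Stand-in for the Index struct above (filenames, FAISS vectors, format codes).
type Index struct {
	filenames []string
}

// IndexHolder publishes a rebuilt index to concurrent readers without a data race,
// instead of writing through a plain shared *Index.
type IndexHolder struct {
	current atomic.Pointer[Index]
}

func (h *IndexHolder) Publish(idx *Index) { h.current.Store(idx) }
func (h *IndexHolder) Load() *Index       { return h.current.Load() }

func main() {
	var holder IndexHolder
	holder.Publish(&Index{filenames: []string{"example.png"}})
	fmt.Println(len(holder.Load().filenames))
}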

264
misc/bad-go-version/ocr.go Normal file
View File

@ -0,0 +1,264 @@
package main
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"net/textproto"
"regexp"
"strings"
"time"
"github.com/h2non/bimg"
"github.com/samber/lo"
"github.com/titanous/json5"
)
const CALLBACK_REGEX string = ">AF_initDataCallback\\(({key: 'ds:1'.*?)\\);</script>"
type SegmentCoords struct {
x int
y int
w int
h int
}
type Segment struct {
coords SegmentCoords
text string
}
type ScanResult []Segment
// TODO coordinates are negative sometimes and I think they shouldn't be
func rationalizeCoordsFormat1(imageW float64, imageH float64, centerXFraction float64, centerYFraction float64, widthFraction float64, heightFraction float64) SegmentCoords {
return SegmentCoords{
x: int(math.Round((centerXFraction - widthFraction/2) * imageW)),
y: int(math.Round((centerYFraction - heightFraction/2) * imageH)),
w: int(math.Round(widthFraction * imageW)),
h: int(math.Round(heightFraction * imageH)),
}
}
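// e.g. a 1000x800 chunk with center (0.5, 0.25) and size (0.2, 0.1) maps to x=400, y=160, w=200, h=80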
func scanImageChunk(image []byte, imageWidth int, imageHeight int) (ScanResult, error) {
var result ScanResult
timestamp := time.Now().UnixMicro()
var b bytes.Buffer
w := multipart.NewWriter(&b)
defer w.Close()
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="encoded_image"; filename="ocr%d.png"`, timestamp))
h.Set("Content-Type", "image/png")
fw, err := w.CreatePart(h)
if err != nil {
return result, err
}
fw.Write(image)
w.Close()
req, err := http.NewRequest("POST", fmt.Sprintf("https://lens.google.com/v3/upload?stcs=%d", timestamp), &b)
if err != nil {
return result, err
}
req.Header.Add("User-Agent", "Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36")
req.AddCookie(&http.Cookie{
Name: "SOCS",
Value: "CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg",
})
req.Header.Set("Content-Type", w.FormDataContentType())
client := http.Client{}
res, err := client.Do(req)
if err != nil {
return result, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return result, err
}
re, _ := regexp.Compile(CALLBACK_REGEX)
matches := re.FindStringSubmatch(string(body[:]))
if len(matches) == 0 {
return result, fmt.Errorf("invalid API response")
}
match := matches[1]
var lensObject map[string]interface{}
err = json5.Unmarshal([]byte(match), &lensObject)
if err != nil {
return result, err
}
if _, ok := lensObject["errorHasStatus"]; ok {
return result, errors.New("lens failed")
}
root := lensObject["data"].([]interface{})
var textSegments []string
var textRegions []SegmentCoords
// I don't know why Google did this.
// Text segments are in one place and their locations are in another, using a very strange coordinate system.
// At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
// TODO: on a few images, this seems to not work for some reason.
defer func() {
if r := recover(); r != nil {
// https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
// In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
textSegments = []string{}
textRegions = []SegmentCoords{}
}
}()
textSegmentsRaw := root[3].([]interface{})[4].([]interface{})[0].([]interface{})[0].([]interface{})
textRegionsRaw := root[2].([]interface{})[3].([]interface{})[0].([]interface{})
for _, x := range textRegionsRaw {
if strings.HasPrefix(x.([]interface{})[11].(string), "text:") {
rawCoords := x.([]interface{})[1].([]interface{})
coords := rationalizeCoordsFormat1(float64(imageWidth), float64(imageHeight), rawCoords[0].(float64), rawCoords[1].(float64), rawCoords[2].(float64), rawCoords[3].(float64))
textRegions = append(textRegions, coords)
}
}
for _, x := range textSegmentsRaw {
textSegment := x.(string)
textSegments = append(textSegments, textSegment)
}
return lo.Map(lo.Zip2(textSegments, textRegions), func(x lo.Tuple2[string, SegmentCoords], _ int) Segment {
return Segment{
text: x.A,
coords: x.B,
}
}), nil
}
const MAX_DIM int = 1024
func scanImage(image *bimg.Image) (ScanResult, error) {
result := ScanResult{}
metadata, err := image.Metadata()
if err != nil {
return result, err
}
width := metadata.Size.Width
height := metadata.Size.Height
if width > MAX_DIM {
width = MAX_DIM
height = int(math.Round(float64(height) * (float64(width) / float64(metadata.Size.Width))))
}
for y := 0; y < height; y += MAX_DIM {
chunkHeight := MAX_DIM
if y+chunkHeight > height {
chunkHeight = height - y
}
chunk, err := image.Process(bimg.Options{
Height: height, // these are for overall image dimensions (resize then crop)
Width: width,
Top: y,
AreaHeight: chunkHeight,
AreaWidth: width,
Crop: true,
Type: bimg.PNG,
})
if err != nil {
return result, err
}
res, err := scanImageChunk(chunk, width, chunkHeight)
if err != nil {
return result, err
}
for _, segment := range res {
result = append(result, Segment{
text: segment.text,
coords: SegmentCoords{
y: segment.coords.y + y,
x: segment.coords.x,
w: segment.coords.w,
h: segment.coords.h,
},
})
}
}
return result, nil
}
/*
async def scan_image_chunk(sess, image):
# send data to inscrutable undocumented Google service
# https://github.com/AuroraWright/owocr/blob/master/owocr/ocr.py#L193
async with aiohttp.ClientSession() as sess:
data = aiohttp.FormData()
data.add_field(
"encoded_image",
encode_img(image),
filename="ocr" + str(timestamp) + ".png",
content_type="image/png"
)
async with sess.post(url, headers=headers, cookies=cookies, data=data, timeout=10) as res:
body = await res.text()
# I really worry about Google sometimes. This is not a sensible format.
match = CALLBACK_REGEX.search(body)
if match == None:
raise ValueError("Invalid callback")
lens_object = pyjson5.loads(match.group(1))
if "errorHasStatus" in lens_object:
raise RuntimeError("Lens failed")
text_segments = []
text_regions = []
root = lens_object["data"]
# I don't know why Google did this.
# Text segments are in one place and their locations are in another, using a very strange coordinate system.
# At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
# TODO: on a few images, this seems to not work for some reason.
try:
text_segments = root[3][4][0][0]
text_regions = [ rationalize_coords_format1(image.width, image.height, *x[1]) for x in root[2][3][0] if x[11].startswith("text:") ]
except (KeyError, IndexError):
# https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
# In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
return [], []
return text_segments, text_regions
MAX_SCAN_DIM = 1000 # not actually true but close enough
def chunk_image(image: Image):
chunks = []
# Cut image down in X axis (I'm assuming images aren't too wide to scan in downscaled form because merging text horizontally would be annoying)
if image.width > MAX_SCAN_DIM:
image = image.resize((MAX_SCAN_DIM, round(image.height * (image.width / MAX_SCAN_DIM))), Image.LANCZOS)
for y in range(0, image.height, MAX_SCAN_DIM):
chunks.append(image.crop((0, y, image.width, min(y + MAX_SCAN_DIM, image.height))))
return chunks
async def scan_chunks(sess: aiohttp.ClientSession, chunks: [Image]):
# If text happens to be split across the cut line it won't get read.
# This is because doing overlap read areas would be really annoying.
text = ""
regions = []
for chunk in chunks:
new_segments, new_regions = await scan_image_chunk(sess, chunk)
for segment in new_segments:
text += segment + "\n"
for i, (segment, region) in enumerate(zip(new_segments, new_regions)):
regions.append({ **region, "y": region["y"] + (MAX_SCAN_DIM * i), "text": segment })
return text, regions
async def scan_image(sess: aiohttp.ClientSession, image: Image):
return await scan_chunks(sess, chunk_image(image))
if __name__ == "__main__":
async def main():
async with aiohttp.ClientSession() as sess:
print(await scan_image(sess, Image.open("/data/public/memes-or-something/linear-algebra-chess.png")))
asyncio.run(main())
*/

View File

@ -0,0 +1,891 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"hash/fnv"
"io"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"strings"
"sync"
"time"
"github.com/DataIntelligenceCrew/go-faiss"
"github.com/davidbyttow/govips/v2/vips"
"github.com/h2non/bimg"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"github.com/samber/lo"
"github.com/vmihailenco/msgpack"
"github.com/x448/float16"
"golang.org/x/sync/errgroup"
)
type Config struct {
ClipServer string `json:"clip_server"`
DbPath string `json:"db_path"`
Port int16 `json:"port"`
Files string `json:"files"`
EnableOCR bool `json:"enable_ocr"`
ThumbsPath string `json:"thumbs_path"`
EnableThumbnails bool `json:"enable_thumbs"`
}
type Index struct {
vectors *faiss.IndexImpl
filenames []string
formatCodes []int64
formatNames []string
}
var schema = `
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
embedding_time INTEGER,
ocr_time INTEGER,
thumbnail_time INTEGER,
embedding BLOB,
ocr TEXT,
raw_ocr_segments BLOB,
thumbnails BLOB
);
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
filename,
ocr,
tokenize='unicode61 remove_diacritics 2',
content='ocr'
);
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_upd AFTER UPDATE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
`
type FileRecord struct {
Filename string `db:"filename"`
EmbedTime int64 `db:"embedding_time"`
OcrTime int64 `db:"ocr_time"`
ThumbnailTime int64 `db:"thumbnail_time"`
Embedding []byte `db:"embedding"`
Ocr string `db:"ocr"`
RawOcrSegments []byte `db:"raw_ocr_segments"`
Thumbnails []byte `db:"thumbnails"`
filesize int64
}
type InferenceServerConfig struct {
BatchSize uint `msgpack:"batch"`
ImageSize []uint `msgpack:"image_size"`
EmbeddingSize uint `msgpack:"embedding_size"`
}
func decodeMsgpackFrom[O interface{}](resp *http.Response) (O, error) {
var result O
respData, err := io.ReadAll(resp.Body)
if err != nil {
return result, err
}
err = msgpack.Unmarshal(respData, &result)
return result, err
}
func queryClipServer[I interface{}, O interface{}](config Config, path string, data I) (O, error) {
var result O
b, err := msgpack.Marshal(data)
if err != nil {
return result, err
}
resp, err := http.Post(config.ClipServer+path, "application/msgpack", bytes.NewReader(b))
if err != nil {
return result, err
}
defer resp.Body.Close()
return decodeMsgpackFrom[O](resp)
}
type LoadedImage struct {
image *vips.ImageRef
filename string
originalSize int
}
type EmbeddingInput struct {
image []byte
filename string
}
type EmbeddingRequest struct {
Images [][]byte `msgpack:"images"`
Text []string `msgpack:"text"`
}
type EmbeddingResponse = [][]byte
func timestamp() int64 {
return time.Now().UnixMicro()
}
type ImageFormatConfig struct {
targetWidth int
targetFilesize int
quality int
format vips.ImageType
extension string
}
func generateFilenameHash(filename string) string {
hasher := fnv.New128()
hasher.Write([]byte(filename))
hash := hasher.Sum(make([]byte, 0))
return base64.RawURLEncoding.EncodeToString(hash)
}
func generateThumbnailFilename(filename string, formatName string, formatConfig ImageFormatConfig) string {
return fmt.Sprintf("%s%s.%s", generateFilenameHash(filename), formatName, formatConfig.extension)
}
func initializeDatabase(config Config) (*sqlx.DB, error) {
db, err := sqlx.Connect("sqlite3", config.DbPath)
if err != nil {
return nil, err
}
_, err = db.Exec("PRAGMA busy_timeout = 2000; PRAGMA journal_mode = WAL")
if err != nil {
return nil, err
}
return db, nil
}
func imageFormats(config Config) map[string]ImageFormatConfig {
return map[string]ImageFormatConfig{
"jpegl": {
targetWidth: 800,
quality: 70,
format: vips.ImageTypeJPEG,
extension: "jpg",
},
"jpegh": {
targetWidth: 1600,
quality: 80,
format: vips.ImageTypeJPEG,
extension: "jpg",
},
"jpeg256kb": {
targetWidth: 500,
targetFilesize: 256000,
format: vips.ImageTypeJPEG,
extension: "jpg",
},
"avifh": {
targetWidth: 1600,
quality: 80,
format: vips.ImageTypeAVIF,
extension: "avif",
},
"avifl": {
targetWidth: 800,
quality: 30,
format: vips.ImageTypeAVIF,
extension: "avif",
},
}
}
func ingestFiles(config Config, backend InferenceServerConfig) error {
var wg errgroup.Group
var iwg errgroup.Group
// We assume everything is either a modern browser (low-DPI or high-DPI), an ancient browser or a ComputerCraft machine abusing Extra Utilities 2 screens.
var formats = imageFormats(config)
db, err := initializeDatabase(config)
if err != nil {
return err
}
defer db.Close()
toProcess := make(chan FileRecord, 100)
toEmbed := make(chan EmbeddingInput, backend.BatchSize)
toThumbnail := make(chan LoadedImage, 30)
toOCR := make(chan LoadedImage, 30)
embedBatches := make(chan []EmbeddingInput, 1)
// image loading and preliminary resizing
for range runtime.NumCPU() {
iwg.Go(func() error {
for record := range toProcess {
path := filepath.Join(config.Files, record.Filename)
img, err := vips.LoadImageFromFile(path, &vips.ImportParams{})
if err != nil {
log.Println("could not read", record.Filename)
continue
}
if record.Embedding == nil {
i, err := img.Copy() // TODO this is ugly, we should not need to do in-place operations
if err != nil {
return err
}
err = i.ResizeWithVScale(float64(backend.ImageSize[0])/float64(i.Width()), float64(backend.ImageSize[1])/float64(i.Height()), vips.KernelLanczos3)
if err != nil {
return err
}
resized, _, err := i.ExportPng(vips.NewPngExportParams())
if err != nil {
log.Println("resize failure", record.Filename, err)
} else {
toEmbed <- EmbeddingInput{
image: resized,
filename: record.Filename,
}
}
}
if record.Thumbnails == nil && config.EnableThumbnails {
toThumbnail <- LoadedImage{
image: img,
filename: record.Filename,
originalSize: int(record.filesize),
}
}
if record.RawOcrSegments == nil && config.EnableOCR {
toOCR <- LoadedImage{
image: img,
filename: record.Filename,
}
}
}
return nil
})
}
if config.EnableThumbnails {
for range runtime.NumCPU() {
wg.Go(func() error {
for image := range toThumbnail {
generatedFormats := make([]string, 0)
for formatName, formatConfig := range formats {
var err error
var resized []byte
if formatConfig.targetFilesize != 0 {
lb := 1
ub := 100
for {
quality := (lb + ub) / 2
i, err := image.image.Copy()
if err != nil {
return err
}
i.Resize(float64(formatConfig.targetWidth)/float64(i.Width()), vips.KernelLanczos3)
resized, _, err = i.Export(&vips.ExportParams{
Format: formatConfig.format,
Speed: 4,
Quality: quality,
StripMetadata: true,
})
if len(resized) > image.originalSize {
ub = quality
} else {
lb = quality + 1
}
if lb >= ub {
break
}
}
} else {
i, err := image.image.Copy()
if err != nil {
return err
}
i.Resize(float64(formatConfig.targetWidth)/float64(i.Width()), vips.KernelLanczos3)
resized, _, err = i.Export(&vips.ExportParams{
Format: formatConfig.format,
Speed: 4,
Quality: formatConfig.quality,
StripMetadata: true,
})
}
if err != nil {
log.Println("thumbnailing failure", image.filename, err)
continue
}
if len(resized) < image.originalSize {
generatedFormats = append(generatedFormats, formatName)
err = bimg.Write(filepath.Join(config.ThumbsPath, generateThumbnailFilename(image.filename, formatName, formatConfig)), resized)
if err != nil {
return err
}
}
}
formatsData, err := msgpack.Marshal(generatedFormats)
if err != nil {
return err
}
_, err = db.Exec("UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?", formatsData, timestamp(), image.filename)
if err != nil {
return err
}
}
return nil
})
}
}
if config.EnableOCR {
for range 100 {
wg.Go(func() error {
for image := range toOCR {
scan, err := scanImage(image.image)
if err != nil {
log.Println("OCR failure", image.filename, err)
continue
}
ocrText := ""
for _, segment := range scan {
ocrText += segment.text
ocrText += "\n"
}
ocrData, err := msgpack.Marshal(scan)
if err != nil {
return err
}
_, err = db.Exec("UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?", ocrText, ocrData, timestamp(), image.filename)
if err != nil {
return err
}
}
return nil
})
}
}
wg.Go(func() error {
buffer := make([]EmbeddingInput, 0, backend.BatchSize)
for input := range toEmbed {
buffer = append(buffer, input)
if len(buffer) == int(backend.BatchSize) {
embedBatches <- buffer
buffer = make([]EmbeddingInput, 0, backend.BatchSize)
}
}
if len(buffer) > 0 {
embedBatches <- buffer
}
close(embedBatches)
return nil
})
for range 3 {
wg.Go(func() error {
for batch := range embedBatches {
result, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "", EmbeddingRequest{
Images: lo.Map(batch, func(item EmbeddingInput, _ int) []byte { return item.image }),
})
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
for i, vector := range result {
_, err = tx.Exec("UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?", timestamp(), vector, batch[i].filename)
if err != nil {
return err
}
}
err = tx.Commit()
if err != nil {
return err
}
}
return nil
})
}
filenamesOnDisk := make(map[string]struct{})
err = filepath.WalkDir(config.Files, func(path string, d os.DirEntry, err error) error {
filename := strings.TrimPrefix(path, config.Files)
if err != nil {
return err
}
if d.IsDir() {
return nil
}
filenamesOnDisk[filename] = struct{}{}
records := []FileRecord{}
err = db.Select(&records, "SELECT * FROM files WHERE filename = ?", filename)
if err != nil {
return err
}
stat, err := d.Info()
if err != nil {
return err
}
modtime := stat.ModTime().UnixMicro()
if len(records) == 0 || modtime > records[0].EmbedTime || modtime > records[0].OcrTime || modtime > records[0].ThumbnailTime {
_, err = db.Exec("INSERT OR IGNORE INTO files VALUES (?, 0, 0, 0, '', '', '', '')", filename)
if err != nil {
return err
}
record := FileRecord{
Filename: filename,
filesize: stat.Size(),
}
if len(records) > 0 {
record = records[0]
}
if modtime > record.EmbedTime || len(record.Embedding) == 0 {
record.Embedding = nil
}
if modtime > record.OcrTime || len(record.RawOcrSegments) == 0 {
record.RawOcrSegments = nil
}
if modtime > record.ThumbnailTime || len(record.Thumbnails) == 0 {
record.Thumbnails = nil
}
toProcess <- record
}
return nil
})
if err != nil {
return err
}
close(toProcess)
err = iwg.Wait()
close(toEmbed)
close(toThumbnail)
close(toOCR)
if err != nil {
return err
}
err = wg.Wait()
if err != nil {
return err
}
rows, err := db.Queryx("SELECT filename FROM files")
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
for rows.Next() {
var filename string
err := rows.Scan(&filename)
if err != nil {
return err
}
if _, ok := filenamesOnDisk[filename]; !ok {
_, err = tx.Exec("DELETE FROM files WHERE filename = ?", filename)
if err != nil {
return err
}
}
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
const INDEX_ADD_BATCH = 512
func buildIndex(config Config, backend InferenceServerConfig) (Index, error) {
var index Index
db, err := initializeDatabase(config)
if err != nil {
return index, err
}
defer db.Close()
newFAISSIndex, err := faiss.IndexFactory(int(backend.EmbeddingSize), "SQfp16", faiss.MetricInnerProduct)
if err != nil {
return index, err
}
index.vectors = newFAISSIndex
var count int
err = db.Get(&count, "SELECT COUNT(*) FROM files")
if err != nil {
return index, err
}
index.filenames = make([]string, 0, count)
index.formatCodes = make([]int64, 0, count)
buffer := make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
index.formatNames = make([]string, 0, 5)
record := FileRecord{}
rows, err := db.Queryx("SELECT * FROM files")
if err != nil {
return index, err
}
for rows.Next() {
err := rows.StructScan(&record)
if err != nil {
return index, err
}
if len(record.Embedding) > 0 {
index.filenames = append(index.filenames, record.Filename)
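// embeddings are stored as packed little-endian float16; widen to float32 before adding to FAISS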
for i := 0; i < len(record.Embedding); i += 2 {
buffer = append(buffer, float16.Frombits(uint16(record.Embedding[i])+uint16(record.Embedding[i+1])<<8).Float32())
}
if len(buffer) == cap(buffer) {
index.vectors.Add(buffer)
buffer = make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
}
formats := make([]string, 0, 5)
if len(record.Thumbnails) > 0 {
err := msgpack.Unmarshal(record.Thumbnails, &formats)
if err != nil {
return index, err
}
}
formatCode := int64(0)
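// formatCode is a bitmask over index.formatNames: bit i set means this file has a thumbnail in formatNames[i]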
for _, formatString := range formats {
found := false
for i, name := range index.formatNames {
if name == formatString {
formatCode |= 1 << i
found = true
break
}
}
if !found {
newIndex := len(index.formatNames)
formatCode |= 1 << newIndex
index.formatNames = append(index.formatNames, formatString)
}
}
index.formatCodes = append(index.formatCodes, formatCode)
}
}
if len(buffer) > 0 {
index.vectors.Add(buffer)
}
return index, nil
}
func decodeFP16Buffer(buf []byte) []float32 {
out := make([]float32, 0, len(buf)/2)
for i := 0; i < len(buf); i += 2 {
out = append(out, float16.Frombits(uint16(buf[i])+uint16(buf[i+1])<<8).Float32())
}
return out
}
type EmbeddingVector []float32
type QueryResult struct {
Matches [][]interface{} `json:"matches"`
Formats []string `json:"formats"`
Extensions map[string]string `json:"extensions"`
}
// this terrible language cannot express tagged unions
type QueryTerm struct {
Embedding *EmbeddingVector `json:"embedding"`
Image *string `json:"image"` // base64
Text *string `json:"text"`
Weight *float32 `json:"weight"`
}
type QueryRequest struct {
Terms []QueryTerm `json:"terms"`
K *int `json:"k"`
}
func queryIndex(index *Index, query EmbeddingVector, k int) (QueryResult, error) {
var qr QueryResult
distances, ids, err := index.vectors.Search(query, int64(k))
if err != nil {
return qr, err
}
items := lo.Map(lo.Zip2(distances, ids), func(x lo.Tuple2[float32, int64], i int) []interface{} {
return []interface{}{
x.A,
index.filenames[x.B],
generateFilenameHash(index.filenames[x.B]),
index.formatCodes[x.B],
}
})
return QueryResult{
Matches: items,
Formats: index.formatNames,
}, nil
}
func handleRequest(config Config, backendConfig InferenceServerConfig, index *Index, w http.ResponseWriter, req *http.Request) error {
if req.Method != "POST" {
io.WriteString(w, "OK") // health check (server requests always have a non-nil Body, so check the method instead)
return nil
}
dec := json.NewDecoder(req.Body)
var qreq QueryRequest
err := dec.Decode(&qreq)
if err != nil {
return err
}
totalEmbedding := make(EmbeddingVector, backendConfig.EmbeddingSize)
imageBatch := make([][]byte, 0)
imageWeights := make([]float32, 0)
textBatch := make([]string, 0)
textWeights := make([]float32, 0)
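// each term contributes weight * embedding: image and text terms are batched, embedded by the CLIP backend, and summed into totalEmbedding along with any raw embedding terms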
for _, term := range qreq.Terms {
if term.Image != nil {
bytes, err := base64.StdEncoding.DecodeString(*term.Image)
if err != nil {
return err
}
loaded := bimg.NewImage(bytes)
resized, err := loaded.Process(bimg.Options{
Width: int(backendConfig.ImageSize[0]),
Height: int(backendConfig.ImageSize[1]),
Force: true,
Type: bimg.PNG,
Interpretation: bimg.InterpretationSRGB,
})
if err != nil {
return err
}
imageBatch = append(imageBatch, resized)
if term.Weight != nil {
imageWeights = append(imageWeights, *term.Weight)
} else {
imageWeights = append(imageWeights, 1)
}
}
if term.Text != nil {
textBatch = append(textBatch, *term.Text)
if term.Weight != nil {
textWeights = append(textWeights, *term.Weight)
} else {
textWeights = append(textWeights, 1)
}
}
if term.Embedding != nil {
weight := float32(1.0)
if term.Weight != nil {
weight = *term.Weight
}
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += (*term.Embedding)[i] * weight
}
}
}
if len(imageBatch) > 0 {
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
Images: imageBatch,
})
if err != nil {
return err
}
for j, emb := range embs {
embd := decodeFP16Buffer(emb)
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += embd[i] * imageWeights[j]
}
}
}
if len(textBatch) > 0 {
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
Text: textBatch,
})
if err != nil {
return err
}
for j, emb := range embs {
embd := decodeFP16Buffer(emb)
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += embd[i] * textWeights[j]
}
}
}
k := 1000
if qreq.K != nil {
k = *qreq.K
}
w.Header().Add("Content-Type", "application/json")
enc := json.NewEncoder(w)
qres, err := queryIndex(index, totalEmbedding, k)
if err != nil {
return err
}
qres.Extensions = make(map[string]string)
for k, v := range imageFormats(config) {
qres.Extensions[k] = v.extension
}
err = enc.Encode(qres)
if err != nil {
return err
}
return nil
}
func init() {
os.Setenv("VIPS_WARNING", "FALSE") // this does not actually work
bimg.VipsCacheSetMax(0)
bimg.VipsCacheSetMaxMem(0)
}
func main() {
vips.Startup(&vips.Config{})
defer vips.Shutdown()
content, err := os.ReadFile(os.Args[1])
if err != nil {
log.Fatal("config file unreadable ", err)
}
var config Config
err = json.Unmarshal(content, &config)
if err != nil {
log.Fatal("config file wrong ", err)
}
fmt.Println(config)
db, err := sqlx.Connect("sqlite3", config.DbPath)
if err != nil {
log.Fatal("DB connection failure ", db)
}
db.MustExec(schema)
var backend InferenceServerConfig
for {
resp, err := http.Get(config.ClipServer + "/config")
if err != nil {
log.Println("backend failed (fetch) ", err)
time.Sleep(time.Second)
continue
}
backend, err = decodeMsgpackFrom[InferenceServerConfig](resp)
resp.Body.Close()
if err != nil {
log.Println("backend failed (parse) ", err)
} else {
break
}
time.Sleep(time.Second)
}
requestIngest := make(chan struct{}, 1)
var index *Index
// maybe this ought to be mutexed?
var lastError *error
// there's not a neat way to reusably broadcast to multiple channels, but I *can* abuse WaitGroups probably
// this might cause horrible concurrency issues, but you brought me to this point, Go designers
var wg sync.WaitGroup
go func() {
for {
wg.Add(1)
log.Println("ingest running")
err := ingestFiles(config, backend)
if err != nil {
log.Println("ingest failed ", err)
lastError = &err
} else {
newIndex, err := buildIndex(config, backend)
if err != nil {
log.Println("index build failed ", err)
lastError = &err
} else {
lastError = nil
index = &newIndex
}
}
wg.Done()
<-requestIngest
}
}()
newIndex, err := buildIndex(config, backend)
index = &newIndex
if err != nil {
log.Fatal("index build failed ", err)
}
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Headers", "Content-Type")
if req.Method == "OPTIONS" {
w.WriteHeader(204)
return
}
err := handleRequest(config, backend, index, w, req)
if err != nil {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]string{
"error": err.Error(),
})
}
})
http.HandleFunc("/reload", func(w http.ResponseWriter, req *http.Request) {
if req.Method == "POST" {
log.Println("requesting index reload")
select {
case requestIngest <- struct{}{}:
default:
}
wg.Wait()
if lastError == nil {
w.Write([]byte("OK"))
} else {
w.WriteHeader(500)
w.Write([]byte((*lastError).Error()))
}
}
})
http.HandleFunc("/profile", func(w http.ResponseWriter, req *http.Request) {
f, err := os.Create("mem.pprof")
if err != nil {
log.Fatal("could not create memory profile: ", err)
}
defer f.Close()
var m runtime.MemStats
runtime.ReadMemStats(&m)
log.Printf("Memory usage: Alloc=%v, TotalAlloc=%v, Sys=%v", m.Alloc, m.TotalAlloc, m.Sys)
log.Println(bimg.VipsMemory())
bimg.VipsDebugInfo()
runtime.GC() // Trigger garbage collection
if err := pprof.WriteHeapProfile(f); err != nil {
log.Fatal("could not write memory profile: ", err)
}
})
log.Println("starting server")
http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)
}

View File

@ -0,0 +1,265 @@
package main
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"net/textproto"
"regexp"
"strings"
"time"
"github.com/davidbyttow/govips/v2/vips"
"github.com/samber/lo"
"github.com/titanous/json5"
)
const CALLBACK_REGEX string = ">AF_initDataCallback\\(({key: 'ds:1'.*?)\\);</script>"
type SegmentCoords struct {
x int
y int
w int
h int
}
type Segment struct {
coords SegmentCoords
text string
}
type ScanResult []Segment
// TODO coordinates are negative sometimes and I think they shouldn't be
func rationalizeCoordsFormat1(imageW float64, imageH float64, centerXFraction float64, centerYFraction float64, widthFraction float64, heightFraction float64) SegmentCoords {
return SegmentCoords{
x: int(math.Round((centerXFraction - widthFraction/2) * imageW)),
y: int(math.Round((centerYFraction - heightFraction/2) * imageH)),
w: int(math.Round(widthFraction * imageW)),
h: int(math.Round(heightFraction * imageH)),
}
}
func scanImageChunk(image []byte, imageWidth int, imageHeight int) (ScanResult, error) {
var result ScanResult
timestamp := time.Now().UnixMicro()
var b bytes.Buffer
w := multipart.NewWriter(&b)
defer w.Close()
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="encoded_image"; filename="ocr%d.png"`, timestamp))
h.Set("Content-Type", "image/png")
fw, err := w.CreatePart(h)
if err != nil {
return result, err
}
fw.Write(image)
w.Close()
req, err := http.NewRequest("POST", fmt.Sprintf("https://lens.google.com/v3/upload?stcs=%d", timestamp), &b)
if err != nil {
return result, err
}
req.Header.Add("User-Agent", "Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36")
req.AddCookie(&http.Cookie{
Name: "SOCS",
Value: "CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg",
})
req.Header.Set("Content-Type", w.FormDataContentType())
client := http.Client{}
res, err := client.Do(req)
if err != nil {
return result, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return result, err
}
re := regexp.MustCompile(CALLBACK_REGEX)
matches := re.FindStringSubmatch(string(body[:]))
if len(matches) == 0 {
return result, fmt.Errorf("invalid API response")
}
match := matches[1]
var lensObject map[string]interface{}
err = json5.Unmarshal([]byte(match), &lensObject)
if err != nil {
return result, err
}
if _, ok := lensObject["errorHasStatus"]; ok {
return result, errors.New("lens failed")
}
root := lensObject["data"].([]interface{})
var textSegments []string
var textRegions []SegmentCoords
// I don't know why Google did this.
// Text segments are in one place and their locations are in another, using a very strange coordinate system.
// At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
// TODO: on a few images, this seems to not work for some reason.
defer func() {
if r := recover(); r != nil {
// https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
// In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
textSegments = []string{}
textRegions = []SegmentCoords{}
}
}()
textSegmentsRaw := root[3].([]interface{})[4].([]interface{})[0].([]interface{})[0].([]interface{})
textRegionsRaw := root[2].([]interface{})[3].([]interface{})[0].([]interface{})
for _, x := range textRegionsRaw {
if strings.HasPrefix(x.([]interface{})[11].(string), "text:") {
rawCoords := x.([]interface{})[1].([]interface{})
coords := rationalizeCoordsFormat1(float64(imageWidth), float64(imageHeight), rawCoords[0].(float64), rawCoords[1].(float64), rawCoords[2].(float64), rawCoords[3].(float64))
textRegions = append(textRegions, coords)
}
}
for _, x := range textSegmentsRaw {
textSegment := x.(string)
textSegments = append(textSegments, textSegment)
}
return lo.Map(lo.Zip2(textSegments, textRegions), func(x lo.Tuple2[string, SegmentCoords], _ int) Segment {
return Segment{
text: x.A,
coords: x.B,
}
}), nil
}
const MAX_DIM int = 1024
func scanImage(image *vips.ImageRef) (ScanResult, error) {
result := ScanResult{}
width := image.Width()
height := image.Height()
if width > MAX_DIM {
width = MAX_DIM
height = int(math.Round(float64(height) * (float64(width) / float64(image.Width()))))
}
downscaled, err := image.Copy()
if err != nil {
return result, err
}
downscaled.Resize(float64(width)/float64(image.Width()), vips.KernelLanczos3)
for y := 0; y < height; y += MAX_DIM {
chunkHeight := MAX_DIM
if y+chunkHeight > height {
chunkHeight = height - y
}
chunk, err := downscaled.Copy() // TODO this really really should not be in-place
if err != nil {
return result, err
}
err = chunk.ExtractArea(0, y, width, chunkHeight)
if err != nil {
return result, err
}
buf, _, err := chunk.ExportPng(&vips.PngExportParams{})
if err != nil {
return result, err
}
res, err := scanImageChunk(buf, width, chunkHeight)
if err != nil {
return result, err
}
for _, segment := range res {
result = append(result, Segment{
text: segment.text,
coords: SegmentCoords{
y: segment.coords.y + y,
x: segment.coords.x,
w: segment.coords.w,
h: segment.coords.h,
},
})
}
}
return result, nil
}
/*
async def scan_image_chunk(sess, image):
# send data to inscrutable undocumented Google service
# https://github.com/AuroraWright/owocr/blob/master/owocr/ocr.py#L193
async with aiohttp.ClientSession() as sess:
data = aiohttp.FormData()
data.add_field(
"encoded_image",
encode_img(image),
filename="ocr" + str(timestamp) + ".png",
content_type="image/png"
)
async with sess.post(url, headers=headers, cookies=cookies, data=data, timeout=10) as res:
body = await res.text()
# I really worry about Google sometimes. This is not a sensible format.
match = CALLBACK_REGEX.search(body)
if match == None:
raise ValueError("Invalid callback")
lens_object = pyjson5.loads(match.group(1))
if "errorHasStatus" in lens_object:
raise RuntimeError("Lens failed")
text_segments = []
text_regions = []
root = lens_object["data"]
# I don't know why Google did this.
# Text segments are in one place and their locations are in another, using a very strange coordinate system.
# At least I don't need whatever is contained in the base64 partss (which I assume are protobufs).
# TODO: on a few images, this seems to not work for some reason.
try:
text_segments = root[3][4][0][0]
text_regions = [ rationalize_coords_format1(image.width, image.height, *x[1]) for x in root[2][3][0] if x[11].startswith("text:") ]
except (KeyError, IndexError):
# https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
# In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
return [], []
return text_segments, text_regions
MAX_SCAN_DIM = 1000 # not actually true but close enough
def chunk_image(image: Image):
chunks = []
# Cut image down in X axis (I'm assuming images aren't too wide to scan in downscaled form because merging text horizontally would be annoying)
if image.width > MAX_SCAN_DIM:
image = image.resize((MAX_SCAN_DIM, round(image.height * (image.width / MAX_SCAN_DIM))), Image.LANCZOS)
for y in range(0, image.height, MAX_SCAN_DIM):
chunks.append(image.crop((0, y, image.width, min(y + MAX_SCAN_DIM, image.height))))
return chunks
async def scan_chunks(sess: aiohttp.ClientSession, chunks: [Image]):
# If text happens to be split across the cut line it won't get read.
# This is because doing overlap read areas would be really annoying.
text = ""
regions = []
for chunk in chunks:
new_segments, new_regions = await scan_image_chunk(sess, chunk)
for segment in new_segments:
text += segment + "\n"
for i, (segment, region) in enumerate(zip(new_segments, new_regions)):
regions.append({ **region, "y": region["y"] + (MAX_SCAN_DIM * i), "text": segment })
return text, regions
async def scan_image(sess: aiohttp.ClientSession, image: Image):
return await scan_chunks(sess, chunk_image(image))
if __name__ == "__main__":
async def main():
async with aiohttp.ClientSession() as sess:
print(await scan_image(sess, Image.open("/data/public/memes-or-something/linear-algebra-chess.png")))
asyncio.run(main())
*/

90
mse.py
View File

@ -12,8 +12,11 @@ import os
import aiohttp_cors
import json
import io
+import time
import sys
-from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+import threading
with open(sys.argv[1], "r") as config_file:
CONFIG = json.load(config_file)
@ -21,8 +24,7 @@ with open(sys.argv[1], "r") as config_file:
app = web.Application(client_max_size=32*1024**2)
routes = web.RouteTableDef()
-async def clip_server(query, unpack_buffer=True):
+async def clip_server(sess: aiohttp.ClientSession, query, unpack_buffer=True):
-async with aiohttp.ClientSession() as sess:
async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
response = umsgpack.loads(await res.read())
if res.status == 200:
@ -34,13 +36,14 @@ async def clip_server(query, unpack_buffer=True):
@routes.post("/")
async def run_query(request):
+sess = app["session"]
data = await request.json()
embeddings = []
if images := data.get("images", []):
target_image_size = app["index"].inference_server_config["image_size"]
-embeddings.extend(await clip_server({ "images": [ load_image(io.BytesIO(base64.b64decode(x)), target_image_size)[0] for x, w in images ] }))
+embeddings.extend(await clip_server(sess, { "images": [ load_image(io.BytesIO(base64.b64decode(x)), target_image_size)[0] for x, w in images ] }))
if text := data.get("text", []):
-embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
+embeddings.extend(await clip_server(sess, { "text": [ x for x, w in text ] }))
weights = [ w for x, w in images ] + [ w for x, w in text ]
weighted_embeddings = [ e * w for e, w in zip(embeddings, weights) ]
weighted_embeddings.extend([ numpy.array(x) for x in data.get("embeddings", []) ])
@ -65,11 +68,12 @@ def load_image(path, image_size):
return buf.getvalue(), path
class Index:
-def __init__(self, inference_server_config):
+def __init__(self, inference_server_config, http_session):
self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
self.associated_filenames = []
self.inference_server_config = inference_server_config
self.lock = asyncio.Lock()
+self.session = http_session
def search(self, query, top_k):
distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
@ -80,18 +84,77 @@ class Index:
except IndexError: pass
return [ { "score": float(distance), "file": self.associated_filenames[index] } for index, distance in zip(indices, distances) ]
+async def run_ocr(self):
+if not CONFIG.get("enable_ocr"): return
+import ocr
+print("Running OCR")
+conn = await aiosqlite.connect(CONFIG["db_path"])
+unocred = await conn.execute_fetchall("SELECT files.filename FROM files LEFT JOIN ocr ON files.filename = ocr.filename WHERE ocr.scan_time IS NULL OR ocr.scan_time < files.modtime")
+ocr_sem = asyncio.Semaphore(20) # Google has more concurrency than our internal CLIP backend. I am sure they will be fine.
+load_sem = threading.Semaphore(100) # provide backpressure in loading to avoid using 50GB of RAM (this happened)
+async def run_image(filename, chunks):
+try:
+text, regions = await ocr.scan_chunks(self.session, chunks)
+await conn.execute("INSERT OR REPLACE INTO ocr VALUES (?, ?, ?, ?)", (filename, time.time(), text, json.dumps(regions)))
+await conn.commit()
+sys.stdout.write(".")
+sys.stdout.flush()
+except:
+print("OCR failed on", filename)
+finally:
+ocr_sem.release()
+def load_and_chunk_image(filename):
+load_sem.acquire()
+im = Image.open(Path(CONFIG["files"]) / filename)
+return filename, ocr.chunk_image(im)
+async with asyncio.TaskGroup() as tg:
+with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
+for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_and_chunk_image, file[0]) for file in unocred ]):
+filename, chunks = await task
+await ocr_sem.acquire()
+tg.create_task(run_image(filename, chunks))
+load_sem.release()
async def reload(self):
async with self.lock:
-with ProcessPoolExecutor(max_workers=12) as executor:
+with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
print("Indexing")
conn = await aiosqlite.connect(CONFIG["db_path"])
conn.row_factory = aiosqlite.Row
await conn.executescript("""
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
modtime REAL NOT NULL,
embedding_vector BLOB NOT NULL
);
+CREATE TABLE IF NOT EXISTS ocr (
+filename TEXT PRIMARY KEY REFERENCES files(filename),
+scan_time INTEGER NOT NULL,
+text TEXT NOT NULL,
+raw_segments TEXT
+);
+CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
+filename,
+text,
+tokenize='unicode61 remove_diacritics 2',
+content='ocr'
+);
+CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON ocr BEGIN
+INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, new.text);
+END;
+CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON ocr BEGIN
+INSERT INTO ocr_fts (ocr_fts, rowid, filename, text) VALUES ('delete', old.rowid, old.filename, old.text);
+END;
""")
try:
async with asyncio.TaskGroup() as tg:
@ -102,7 +165,7 @@ class Index:
async def do_batch(batch):
try:
query = { "images": [ arg[2] for arg in batch ] }
-embeddings = await clip_server(query, False)
+embeddings = await clip_server(self.session, query, False)
await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
(filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
])
@ -188,6 +251,8 @@ class Index:
finally:
await conn.close()
+await self.run_ocr()
app.router.add_routes(routes)
cors = aiohttp_cors.setup(app, defaults={
@ -201,8 +266,8 @@ for route in list(app.router.routes()):
cors.add(route)
async def main():
+sess = aiohttp.ClientSession()
while True:
-async with aiohttp.ClientSession() as sess:
try:
async with await sess.get(CONFIG["clip_server"] + "config") as res:
inference_server_config = umsgpack.unpackb(await res.read())
@ -211,8 +276,9 @@ async def main():
except:
traceback.print_exc()
await asyncio.sleep(1)
-index = Index(inference_server_config)
+index = Index(inference_server_config, sess)
app["index"] = index
+app["session"] = sess
await index.reload()
print("Ready")
if CONFIG.get("no_run_server", False): return

View File

@ -1,6 +1,9 @@
{
-"clip_server": "http://localhost:1708/",
+"clip_server": "http://100.64.0.10:1708",
-"db_path": "/srv/mse/data.sqlite3",
+"db_path": "data.sqlite3",
"port": 1707,
-"files": "/data/public/memes-or-something/"
+"files": "/data/public/memes-or-something/",
+"enable_ocr": false,
+"thumbs_path": "./thumbtemp",
+"enable_thumbs": false
}

101
ocr.py Normal file
View File

@ -0,0 +1,101 @@
import pyjson5
import re
import asyncio
import aiohttp
from PIL import Image
import time
import io
CALLBACK_REGEX = re.compile(r">AF_initDataCallback\(({key: 'ds:1'.*?)\);</script>")
def encode_img(img):
image_bytes = io.BytesIO()
img.save(image_bytes, format="PNG", compress_level=6)
return image_bytes.getvalue()
def rationalize_coords_format1(image_w, image_h, center_x_fraction, center_y_fraction, width_fraction, height_fraction, mysterious):
return {
"x": round((center_x_fraction - width_fraction / 2) * image_w),
"y": round((center_y_fraction - height_fraction / 2) * image_h),
"w": round(width_fraction * image_w),
"h": round(height_fraction * image_h)
}
async def scan_image_chunk(sess, image):
timestamp = int(time.time() * 1000)
url = f"https://lens.google.com/v3/upload?stcs={timestamp}"
headers = {"User-Agent": "Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36"}
cookies = {"SOCS": "CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg"}
# send data to inscrutable undocumented Google service
# https://github.com/AuroraWright/owocr/blob/master/owocr/ocr.py#L193
async with aiohttp.ClientSession() as sess:
data = aiohttp.FormData()
data.add_field(
"encoded_image",
encode_img(image),
filename="ocr" + str(timestamp) + ".png",
content_type="image/png"
)
async with sess.post(url, headers=headers, cookies=cookies, data=data, timeout=10) as res:
body = await res.text()
# I really worry about Google sometimes. This is not a sensible format.
match = CALLBACK_REGEX.search(body)
if match == None:
raise ValueError("Invalid callback")
lens_object = pyjson5.loads(match.group(1))
if "errorHasStatus" in lens_object:
raise RuntimeError("Lens failed")
text_segments = []
text_regions = []
root = lens_object["data"]
# I don't know why Google did this.
# Text segments are in one place and their locations are in another, using a very strange coordinate system.
# At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
# TODO: on a few images, this seems to not work for some reason.
try:
text_segments = root[3][4][0][0]
text_regions = [ rationalize_coords_format1(image.width, image.height, *x[1]) for x in root[2][3][0] if x[11].startswith("text:") ]
except (KeyError, IndexError):
# https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
# In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
return [], []
return text_segments, text_regions
MAX_SCAN_DIM = 1000 # not actually true but close enough
def chunk_image(image: Image):
chunks = []
# Cut image down in X axis (I'm assuming images aren't too wide to scan in downscaled form because merging text horizontally would be annoying)
if image.width > MAX_SCAN_DIM:
image = image.resize((MAX_SCAN_DIM, round(image.height * (image.width / MAX_SCAN_DIM))), Image.LANCZOS)
for y in range(0, image.height, MAX_SCAN_DIM):
chunks.append(image.crop((0, y, image.width, min(y + MAX_SCAN_DIM, image.height))))
return chunks
async def scan_chunks(sess: aiohttp.ClientSession, chunks: [Image]):
# If text happens to be split across the cut line it won't get read.
# This is because doing overlap read areas would be really annoying.
text = ""
regions = []
for chunk in chunks:
new_segments, new_regions = await scan_image_chunk(sess, chunk)
for segment in new_segments:
text += segment + "\n"
for i, (segment, region) in enumerate(zip(new_segments, new_regions)):
regions.append({ **region, "y": region["y"] + (MAX_SCAN_DIM * i), "text": segment })
return text, regions
async def scan_image(sess: aiohttp.ClientSession, image: Image):
return await scan_chunks(sess, chunk_image(image))
if __name__ == "__main__":
async def main():
async with aiohttp.ClientSession() as sess:
print(await scan_image(sess, Image.open("/data/public/memes-or-something/linear-algebra-chess.png")))
asyncio.run(main())

892
src/main.rs Normal file
View File

@ -0,0 +1,892 @@
use std::{collections::HashMap, io::Cursor};
use std::path::Path;
use std::sync::Arc;
use anyhow::{Result, Context};
use axum::body::Body;
use axum::response::Response;
use axum::{
extract::Json,
response::IntoResponse,
routing::{get, post},
Router,
http::StatusCode
};
use image::{imageops::FilterType, io::Reader as ImageReader, DynamicImage, ImageFormat};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
use tokio::sync::{broadcast, mpsc};
use tokio::task::JoinHandle;
use walkdir::WalkDir;
use base64::prelude::*;
use faiss::Index;
use futures_util::stream::{StreamExt, TryStreamExt};
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
mod ocr;
use crate::ocr::scan_image;
fn function_which_returns_50() -> usize { 50 }
#[derive(Debug, Deserialize, Clone)]
struct Config {
clip_server: String,
db_path: String,
port: u16,
files: String,
#[serde(default)]
enable_ocr: bool,
#[serde(default)]
thumbs_path: String,
#[serde(default)]
enable_thumbs: bool,
#[serde(default="function_which_returns_50")]
ocr_concurrency: usize,
#[serde(default)]
no_run_server: bool
}
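// For reference, a config file for this struct might look like the following (sketch only;
// the values are illustrative, mirroring the mse_config.json shown elsewhere in this commit):
// { "clip_server": "http://localhost:1708", "db_path": "data.sqlite3", "port": 1707,
//   "files": "/data/memes/", "enable_ocr": false, "thumbs_path": "./thumbtemp", "enable_thumbs": false }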
#[derive(Debug)]
struct IIndex {
vectors: faiss::index::IndexImpl,
filenames: Vec<String>,
format_codes: Vec<u64>,
format_names: Vec<String>,
}
const SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS files (
filename TEXT NOT NULL PRIMARY KEY,
embedding_time INTEGER,
ocr_time INTEGER,
thumbnail_time INTEGER,
embedding BLOB,
ocr TEXT,
raw_ocr_segments BLOB,
thumbnails BLOB
);
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
filename,
ocr,
tokenize='unicode61 remove_diacritics 2',
content='ocr'
);
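-- the triggers below keep the external-content FTS index in step with inserts, deletes and updates on files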
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_upd AFTER UPDATE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
"#;
#[derive(Debug, sqlx::FromRow, Clone, Default)]
struct FileRecord {
filename: String,
embedding_time: Option<i64>,
ocr_time: Option<i64>,
thumbnail_time: Option<i64>,
embedding: Option<Vec<u8>>,
// this totally "will" be used later
ocr: Option<String>,
raw_ocr_segments: Option<Vec<u8>>,
thumbnails: Option<Vec<u8>>,
}
#[derive(Debug, Deserialize, Clone)]
struct InferenceServerConfig {
batch: usize,
image_size: (u32, u32),
embedding_size: usize,
}
async fn query_clip_server<I, O>(
client: &Client,
config: &Config,
path: &str,
data: I,
) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
{
let response = client
.post(&format!("{}{}", config.clip_server, path))
.header("Content-Type", "application/msgpack")
.body(rmp_serde::to_vec_named(&data)?)
.send()
.await?;
let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
Ok(result)
}
#[derive(Debug)]
struct LoadedImage {
image: Arc<DynamicImage>,
filename: String,
original_size: usize,
}
#[derive(Debug)]
struct EmbeddingInput {
image: Vec<u8>,
filename: String,
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum EmbeddingRequest {
Images { images: Vec<serde_bytes::ByteBuf> },
Text { text: Vec<String> }
}
fn timestamp() -> i64 {
chrono::Utc::now().timestamp_micros()
}
#[derive(Debug, Clone)]
struct ImageFormatConfig {
target_width: u32,
target_filesize: u32,
quality: u8,
format: ImageFormat,
extension: String,
}
fn generate_filename_hash(filename: &str) -> String {
use std::hash::{Hash, Hasher};
let mut hasher = fnv::FnvHasher::default();
filename.hash(&mut hasher);
BASE64_URL_SAFE_NO_PAD.encode(hasher.finish().to_le_bytes())
}
fn generate_thumbnail_filename(
filename: &str,
format_name: &str,
format_config: &ImageFormatConfig,
) -> String {
format!(
"{}{}.{}",
generate_filename_hash(filename),
format_name,
format_config.extension
)
}
async fn initialize_database(config: &Config) -> Result<SqlitePool> {
let connection_options = SqliteConnectOptions::new()
.filename(&config.db_path)
.create_if_missing(true);
let pool = SqlitePool::connect_with(connection_options).await?;
sqlx::query(SCHEMA).execute(&pool).await?;
Ok(pool)
}
fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
let mut formats = HashMap::new();
formats.insert(
"jpegl".to_string(),
ImageFormatConfig {
target_width: 800,
target_filesize: 0,
quality: 70,
format: ImageFormat::Jpeg,
extension: "jpg".to_string(),
},
);
formats.insert(
"jpegh".to_string(),
ImageFormatConfig {
target_width: 1600,
target_filesize: 0,
quality: 80,
format: ImageFormat::Jpeg,
extension: "jpg".to_string(),
},
);
formats.insert(
"jpeg256kb".to_string(),
ImageFormatConfig {
target_width: 500,
target_filesize: 256000,
quality: 0,
format: ImageFormat::Jpeg,
extension: "jpg".to_string(),
},
);
formats.insert(
"avifh".to_string(),
ImageFormatConfig {
target_width: 1600,
target_filesize: 0,
quality: 80,
format: ImageFormat::Avif,
extension: "avif".to_string(),
},
);
formats.insert(
"avifl".to_string(),
ImageFormatConfig {
target_width: 800,
target_filesize: 0,
quality: 30,
format: ImageFormat::Avif,
extension: "avif".to_string(),
},
);
formats
}
async fn resize_for_embed(backend_config: Arc<InferenceServerConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
let resized = tokio::task::spawn_blocking(move || {
let new = image.resize(
backend_config.image_size.0,
backend_config.image_size.1,
FilterType::Lanczos3
);
let mut buf = Vec::new();
let mut csr = Cursor::new(&mut buf);
new.write_to(&mut csr, ImageFormat::Png)?;
Ok::<Vec<u8>, anyhow::Error>(buf)
}).await??;
Ok(resized)
}
async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<()> {
let pool = initialize_database(&config).await?;
let client = Client::new();
let formats = image_formats(&config);
let (to_process_tx, to_process_rx) = mpsc::channel::<FileRecord>(100);
let (to_embed_tx, to_embed_rx) = mpsc::channel(backend.batch as usize);
let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30);
let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30);
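// Pipeline shape: the filesystem scan below feeds to_process; the loader stage decodes each image
// once and fans it out to to_embed, to_thumbnail and to_ocr, which are drained by the embedding,
// thumbnailing and OCR tasks spawned further down.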
let cpus = num_cpus::get();
// Image loading and preliminary resizing
let image_loading: JoinHandle<Result<()>> = tokio::spawn({
let config = config.clone();
let backend = backend.clone();
let stream = ReceiverStream::new(to_process_rx).map(Ok);
stream.try_for_each_concurrent(Some(cpus), move |record| {
let config = config.clone();
let backend = backend.clone();
let to_embed_tx = to_embed_tx.clone();
let to_thumbnail_tx = to_thumbnail_tx.clone();
let to_ocr_tx = to_ocr_tx.clone();
async move {
let path = Path::new(&config.files).join(&record.filename);
let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
let image = match image {
Ok(image) => image,
Err(e) => {
log::error!("Could not read {}: {}", record.filename, e);
return Ok(())
}
};
if record.embedding.is_none() {
let resized = resize_for_embed(backend.clone(), image.clone()).await?;
to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await?
}
if record.thumbnails.is_none() && config.enable_thumbs {
to_thumbnail_tx
.send(LoadedImage {
image: image.clone(),
filename: record.filename.clone(),
original_size: std::fs::metadata(&path)?.len() as usize,
})
.await?;
}
if record.raw_ocr_segments.is_none() && config.enable_ocr {
to_ocr_tx
.send(LoadedImage {
image,
filename: record.filename.clone(),
original_size: 0,
})
.await?;
}
Ok(())
}
})
});
// Thumbnail generation
let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.enable_thumbs {
let config = config.clone();
let pool = pool.clone();
let stream = ReceiverStream::new(to_thumbnail_rx).map(Ok);
let formats = Arc::new(formats);
Some(tokio::spawn({
stream.try_for_each_concurrent(Some(cpus), move |image| {
use image::codecs::*;
let formats = formats.clone();
let config = config.clone();
let pool = pool.clone();
async move {
let filename = image.filename.clone();
log::debug!("thumbnailing {}", filename);
let generated_formats = tokio::task::spawn_blocking(move || {
let mut generated_formats = Vec::new();
let rgb = DynamicImage::from(image.image.to_rgb8());
for (format_name, format_config) in &*formats {
let resized = if format_config.target_filesize != 0 {
let mut lb = 1;
let mut ub = 100;
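// binary search the quality setting for the largest value that still produces a thumbnail smaller than the original file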
loop {
let quality = (lb + ub) / 2;
let thumbnail = rgb.resize(
format_config.target_width,
u32::MAX,
FilterType::Lanczos3,
);
let mut buf: Vec<u8> = Vec::new();
let mut csr = Cursor::new(&mut buf);
// this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
match format_config.format {
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
_ => unimplemented!()
}?;
if buf.len() > image.original_size {
ub = quality;
} else {
lb = quality + 1;
}
if lb >= ub {
break buf;
}
}
} else {
let thumbnail = rgb.resize(
format_config.target_width,
u32::MAX,
FilterType::Lanczos3,
);
let mut buf: Vec<u8> = Vec::new();
let mut csr = Cursor::new(&mut buf);
match format_config.format {
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
_ => unimplemented!()
}?;
buf
};
if resized.len() < image.original_size {
generated_formats.push(format_name.clone());
let thumbnail_path = Path::new(&config.thumbs_path).join(
generate_thumbnail_filename(
&image.filename,
format_name,
format_config,
),
);
std::fs::write(thumbnail_path, resized)?;
}
}
Ok::<Vec<String>, anyhow::Error>(generated_formats)
}).await??;
let formats_data = rmp_serde::to_vec(&generated_formats)?;
let ts = timestamp();
sqlx::query!(
"UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
formats_data,
ts,
filename
)
.execute(&pool)
.await?;
Ok(())
}
})
}))
} else {
None
};
// OCR
let ocr: Option<JoinHandle<Result<()>>> = if config.enable_ocr {
let client = client.clone();
let pool = pool.clone();
let stream = ReceiverStream::new(to_ocr_rx).map(Ok);
Some(tokio::spawn({
stream.try_for_each_concurrent(Some(config.ocr_concurrency), move |image| {
let client = client.clone();
let pool = pool.clone();
async move {
log::debug!("OCRing {}", image.filename);
let scan = match scan_image(&client, &image.image).await {
Ok(scan) => scan,
Err(e) => {
log::error!("OCR failure {}: {}", image.filename, e);
return Ok(())
}
};
let ocr_text = scan
.iter()
.map(|segment| segment.text.clone())
.collect::<Vec<_>>()
.join("\n");
let ocr_data = rmp_serde::to_vec(&scan)?;
let ts = timestamp();
sqlx::query!(
"UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
ocr_text,
ocr_data,
ts,
image.filename
)
.execute(&pool)
.await?;
Ok(())
}
})
}))
} else {
None
};
let embedding_generation: JoinHandle<Result<()>> = tokio::spawn({
let stream = ReceiverStream::new(to_embed_rx).chunks(backend.batch);
let client = client.clone();
let config = config.clone();
let pool = pool.clone();
// keep multiple embedding requests in flight
stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
let client = client.clone();
let config = config.clone();
let pool = pool.clone();
async move {
let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
&client,
&config,
"",
EmbeddingRequest::Images {
images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
},
).await.context("querying CLIP server")?;
let mut tx = pool.begin().await?;
let ts = timestamp();
for (i, vector) in result.into_iter().enumerate() {
let vector = vector.into_vec();
log::debug!("embedded {}", batch[i].filename);
sqlx::query!(
"UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
ts,
vector,
batch[i].filename
)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
anyhow::Result::Ok(())
}
})
});
let mut filenames = HashMap::new();
// blocking OS calls
tokio::task::block_in_place(|| -> anyhow::Result<()> {
for entry in WalkDir::new(config.files.as_str()) {
let entry = entry?;
let path = entry.path();
if path.is_file() {
let filename = path.strip_prefix(&config.files)?.to_str().unwrap().to_string();
let modtime = entry.metadata()?.modified()?.duration_since(std::time::UNIX_EPOCH)?;
let modtime = modtime.as_micros() as i64;
filenames.insert(filename.clone(), (path.to_path_buf(), modtime));
}
}
Ok(())
})?;
log::debug!("finished reading filenames");
for (filename, (_path, modtime)) in filenames.iter() {
let modtime = *modtime;
let record = sqlx::query_as!(FileRecord, "SELECT * FROM files WHERE filename = ?", filename)
.fetch_optional(&pool)
.await?;
let new_record = match record {
None => Some(FileRecord {
filename: filename.clone(),
..Default::default()
}),
Some(r) if modtime > r.embedding_time.unwrap_or(i64::MIN) || (modtime > r.ocr_time.unwrap_or(i64::MIN) && config.enable_ocr) || (modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.enable_thumbs) => {
Some(r)
},
_ => None
};
if let Some(mut record) = new_record {
log::debug!("processing {}", record.filename);
sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename)
.execute(&pool)
.await?;
if modtime > record.embedding_time.unwrap_or(i64::MIN) {
record.embedding = None;
}
if modtime > record.ocr_time.unwrap_or(i64::MIN) {
record.raw_ocr_segments = None;
}
if modtime > record.thumbnail_time.unwrap_or(i64::MIN) {
record.thumbnails = None;
}
// we need to exit here to actually capture the error
if to_process_tx.send(record).await.is_err() {
break
}
}
}
drop(to_process_tx);
embedding_generation.await?.context("generating embeddings")?;
if let Some(thumbnail_generation) = thumbnail_generation {
thumbnail_generation.await?.context("generating thumbnails")?;
}
if let Some(ocr) = ocr {
ocr.await?.context("OCRing")?;
}
image_loading.await?.context("loading images")?;
let stored: Vec<String> = sqlx::query_scalar("SELECT filename FROM files").fetch_all(&pool).await?;
let mut tx = pool.begin().await?;
for filename in stored {
if !filenames.contains_key(&filename) {
sqlx::query!("DELETE FROM files WHERE filename = ?", filename)
.execute(&mut *tx)
.await?;
}
}
tx.commit().await?;
log::info!("ingest done");
Result::Ok(())
}
const INDEX_ADD_BATCH: usize = 512;
async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<IIndex> {
let pool = initialize_database(&config).await?;
let mut index = IIndex {
// SQfp16: flat FAISS index with scalar-quantized fp16 storage, inner-product metric
vectors: faiss::index_factory(backend.embedding_size as u32, "SQfp16", faiss::MetricType::InnerProduct)?,
filenames: Vec::new(),
format_codes: Vec::new(),
format_names: Vec::new(),
};
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM files")
.fetch_one(&pool)
.await?;
index.filenames = Vec::with_capacity(count as usize);
index.format_codes = Vec::with_capacity(count as usize);
let mut buffer = Vec::with_capacity(INDEX_ADD_BATCH * backend.embedding_size as usize);
index.format_names = Vec::with_capacity(5);
let mut rows = sqlx::query_as::<_, FileRecord>("SELECT * FROM files").fetch(&pool);
while let Some(record) = rows.try_next().await? {
if let Some(emb) = record.embedding {
index.filenames.push(record.filename);
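// embeddings are stored as packed little-endian f16; widen to f32 before adding to the FAISS index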
for i in (0..emb.len()).step_by(2) {
buffer.push(
half::f16::from_le_bytes([emb[i], emb[i + 1]])
.to_f32(),
);
}
if buffer.len() == buffer.capacity() {
index.vectors.add(&buffer)?;
buffer.clear();
}
let mut formats: Vec<String> = Vec::new();
if let Some(t) = record.thumbnails {
formats = rmp_serde::from_slice(&t)?;
}
let mut format_code = 0;
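// format_code is a bitmask over index.format_names: bit i set means this file has a thumbnail in format_names[i]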
for format_string in &formats {
let mut found = false;
for (i, name) in index.format_names.iter().enumerate() {
if name == format_string {
format_code |= 1 << i;
found = true;
break;
}
}
if !found {
let new_index = index.format_names.len();
format_code |= 1 << new_index;
index.format_names.push(format_string.clone());
}
}
index.format_codes.push(format_code);
}
}
if !buffer.is_empty() {
index.vectors.add(&buffer)?;
}
Ok(index)
}
fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
buf.chunks_exact(2)
.map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
.collect()
}
type EmbeddingVector = Vec<f32>;
#[derive(Debug, Serialize)]
struct QueryResult {
matches: Vec<(f32, String, String, u64)>,
formats: Vec<String>,
extensions: HashMap<String, String>,
}
#[derive(Debug, Deserialize)]
struct QueryTerm {
embedding: Option<EmbeddingVector>,
image: Option<String>,
text: Option<String>,
weight: Option<f32>,
}
#[derive(Debug, Deserialize)]
struct QueryRequest {
terms: Vec<QueryTerm>,
k: Option<usize>,
}
async fn query_index(index: &mut IIndex, query: EmbeddingVector, k: usize) -> Result<QueryResult> {
let result = index.vectors.search(&query, k as usize)?;
let items = result.distances
.into_iter()
.zip(result.labels)
.filter_map(|(distance, id)| {
let id = id.get()? as usize;
Some((
distance,
index.filenames[id].clone(),
generate_filename_hash(&index.filenames[id as usize]).clone(),
index.format_codes[id]
))
})
.collect();
Ok(QueryResult {
matches: items,
formats: index.format_names.clone(),
extensions: HashMap::new(),
})
}
async fn handle_request(
config: &Config,
backend_config: Arc<InferenceServerConfig>,
client: Arc<Client>,
index: &mut IIndex,
req: Json<QueryRequest>,
) -> Result<Response<Body>> {
let mut total_embedding = ndarray::Array::from(vec![0.0; backend_config.embedding_size]);
let mut image_batch = Vec::new();
let mut image_weights = Vec::new();
let mut text_batch = Vec::new();
let mut text_weights = Vec::new();
for term in &req.terms {
if let Some(image) = &term.image {
let bytes = BASE64_STANDARD.decode(image)?;
let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(backend_config.clone(), image).await?));
image_weights.push(term.weight.unwrap_or(1.0));
}
if let Some(text) = &term.text {
text_batch.push(text.clone());
text_weights.push(term.weight.unwrap_or(1.0));
}
if let Some(embedding) = &term.embedding {
let weight = term.weight.unwrap_or(1.0);
for (i, value) in embedding.iter().enumerate() {
total_embedding[i] += value * weight;
}
}
}
let mut batches = vec![];
if !image_batch.is_empty() {
batches.push(
EmbeddingRequest::Images {
images: image_batch
}
);
}
if !text_batch.is_empty() {
batches.push(
EmbeddingRequest::Text {
text: text_batch,
}
);
}
for batch in batches {
let embs: Vec<Vec<u8>> = query_clip_server(&client, config, "/", batch).await?;
for emb in embs {
total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb));
}
}
let k = req.k.unwrap_or(1000);
let qres = query_index(index, total_embedding.to_vec(), k).await?;
let mut extensions = HashMap::new();
for (k, v) in image_formats(config) {
extensions.insert(k, v.extension);
}
Ok(Json(QueryResult {
matches: qres.matches,
formats: qres.formats,
extensions,
}).into_response())
}
async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
let res = Client::new().get(&format!("{}/config", config.clip_server)).send().await?;
Ok(rmp_serde::from_slice(&res.bytes().await?)?)
}
#[tokio::main]
async fn main() -> Result<()> {
pretty_env_logger::init();
let config_path = std::env::args().nth(1).expect("Missing config file path");
let config: Arc<Config> = Arc::new(serde_json::from_slice(&std::fs::read(config_path)?)?);
let pool = initialize_database(&config).await?;
sqlx::query(SCHEMA).execute(&pool).await?;
let backend = Arc::new(loop {
match get_backend_config(&config).await {
Ok(backend) => break backend,
Err(e) => {
log::error!("Backend failed (fetch): {}", e);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
});
if config.no_run_server {
ingest_files(config.clone(), backend.clone()).await?;
return Ok(())
}
let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1);
let index = Arc::new(tokio::sync::Mutex::new(build_index(config.clone(), backend.clone()).await?));
let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1);
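// each ingest pass broadcasts a (success, message) pair; the /reload handler subscribes and reports the outcome of the pass it triggered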
let done_tx = Arc::new(ingest_done_tx.clone());
let _ingest_task = tokio::spawn({
let config = config.clone();
let backend = backend.clone();
let index = index.clone();
async move {
loop {
log::info!("Ingest running");
match ingest_files(config.clone(), backend.clone()).await {
Ok(_) => {
match build_index(config.clone(), backend.clone()).await {
Ok(new_index) => {
*index.lock().await = new_index;
// report success only once the new index is actually installed
ingest_done_tx.send((true, "OK".to_string())).unwrap();
}
Err(e) => {
log::error!("Index build failed: {:?}", e);
ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
}
}
}
Err(e) => {
log::error!("Ingest failed: {:?}", e);
ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
}
}
request_ingest_rx.recv().await;
}
}
});
let cors = CorsLayer::permissive();
let config_ = config.clone();
let client = Arc::new(Client::new());
let app = Router::new()
.route("/", post(|req| async move {
let config = config.clone();
let backend_config = backend.clone();
let mut index = index.lock().await; // TODO: use ConcurrentIndex here
let client = client.clone();
handle_request(&config, backend_config, client.clone(), &mut index, req).await.map_err(|e| format!("{:?}", e))
}))
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
"OK"
}))
.route("/reload", post(|_req: axum::http::Request<axum::body::Body>| async move {
log::info!("Requesting index reload");
let mut done_rx = done_tx.clone().subscribe();
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
match done_rx.recv().await {
Ok((true, status)) => {
let mut res = status.into_response();
*res.status_mut() = StatusCode::OK;
res
},
Ok((false, status)) => {
let mut res = status.into_response();
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res
},
Err(_) => {
let mut res = "internal error".into_response();
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res
}
}
}))
.layer(cors);
let addr = format!("0.0.0.0:{}", config_.port);
log::info!("Starting server on {}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await?;
Ok(())
}

173
src/ocr.rs Normal file
View File

@ -0,0 +1,173 @@
use anyhow::{anyhow, Result};
use image::{DynamicImage, GenericImageView, ImageFormat};
use regex::Regex;
use reqwest::{
header::{HeaderMap, HeaderValue},
multipart::{Form, Part},
Client,
};
use serde_json::Value;
use std::{io::Cursor, time::{SystemTime, UNIX_EPOCH}};
use serde::{Deserialize, Serialize};
const CALLBACK_REGEX: &str = r">AF_initDataCallback\((\{key: 'ds:1'.*?\})\);</script>";
const MAX_DIM: u32 = 1024;
#[derive(Debug, Serialize, Deserialize)]
pub struct SegmentCoords {
pub x: i32,
pub y: i32,
pub w: i32,
pub h: i32,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Segment {
pub coords: SegmentCoords,
pub text: String,
}
pub type ScanResult = Vec<Segment>;
fn rationalize_coords_format1(
image_w: f64,
image_h: f64,
center_x_fraction: f64,
center_y_fraction: f64,
width_fraction: f64,
height_fraction: f64,
) -> SegmentCoords {
SegmentCoords {
x: ((center_x_fraction - width_fraction / 2.0) * image_w).round() as i32,
y: ((center_y_fraction - height_fraction / 2.0) * image_h).round() as i32,
w: (width_fraction * image_w).round() as i32,
h: (height_fraction * image_h).round() as i32,
}
}
async fn scan_image_chunk(
client: &Client,
image: &[u8],
image_width: u32,
image_height: u32,
) -> Result<ScanResult> {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_micros();
let part = Part::bytes(image.to_vec())
.file_name(format!("ocr{}.png", timestamp))
.mime_str("image/png")?;
let form = Form::new().part("encoded_image", part);
let mut headers = HeaderMap::new();
headers.insert(
"User-Agent",
HeaderValue::from_static("Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36"),
);
headers.insert("Cookie", HeaderValue::from_str(&format!("SOCS=CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg; stcs={}", timestamp))?);
let response = client
.post(&format!("https://lens.google.com/v3/upload?stcs={}", timestamp))
.multipart(form)
.headers(headers)
.send()
.await?;
let body = response.text().await?;
let re = Regex::new(CALLBACK_REGEX)?;
let captures = re
.captures(&body)
.ok_or_else(|| anyhow!("invalid API response"))?;
let match_str = captures.get(1).unwrap().as_str();
let lens_object: Value = json5::from_str(match_str)?;
if lens_object.get("errorHasStatus").is_some() {
return Err(anyhow!("lens failed"));
}
let root = lens_object["data"].as_array().unwrap();
let mut text_segments = Vec::new();
let mut text_regions = Vec::new();
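// as in the Go/Python versions above: text segments and their coordinates live in separate,
// positionally-indexed parts of the Lens response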
let text_segments_raw = root[3][4][0][0]
.as_array()
.ok_or_else(|| anyhow!("invalid text segments"))?;
let text_regions_raw = root[2][3][0]
.as_array()
.ok_or_else(|| anyhow!("invalid text regions"))?;
for region in text_regions_raw {
let region_data = region.as_array().unwrap();
if region_data[11].as_str().unwrap().starts_with("text:") {
let raw_coords = region_data[1].as_array().unwrap();
let coords = rationalize_coords_format1(
image_width as f64,
image_height as f64,
raw_coords[0].as_f64().unwrap(),
raw_coords[1].as_f64().unwrap(),
raw_coords[2].as_f64().unwrap(),
raw_coords[3].as_f64().unwrap(),
);
text_regions.push(coords);
}
}
for segment in text_segments_raw {
let text_segment = segment.as_str().unwrap().to_string();
text_segments.push(text_segment);
}
Ok(text_segments
.into_iter()
.zip(text_regions.into_iter())
.map(|(text, coords)| Segment { text, coords })
.collect())
}
pub async fn scan_image(client: &Client, image: &DynamicImage) -> Result<ScanResult> {
let mut result = ScanResult::new();
let (width, height) = image.dimensions();
let (width, height, image) = if width > MAX_DIM {
let height = ((height as f64) * (MAX_DIM as f64) / (width as f64)).round() as u32;
let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::Lanczos3));
(MAX_DIM, height, std::borrow::Cow::Owned(new_image))
} else {
(width, height, std::borrow::Cow::Borrowed(image))
};
let mut y = 0;
while y < height {
let chunk_height = (height - y).min(MAX_DIM);
let chunk = tokio::task::block_in_place(|| {
let chunk = image.view(0, y, width, chunk_height).to_image();
let mut buf = Vec::new();
let mut csr = Cursor::new(&mut buf);
chunk.write_to(&mut csr, ImageFormat::Png)?;
Ok::<Vec<u8>, anyhow::Error>(buf)
})?;
let res = scan_image_chunk(client, &chunk, width, chunk_height).await?;
for segment in res {
result.push(Segment {
text: segment.text,
coords: SegmentCoords {
y: segment.coords.y + y as i32,
x: segment.coords.x,
w: segment.coords.w,
h: segment.coords.h,
},
});
}
y += chunk_height;
}
Ok(result)
}