mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
Rewrite entire application (well, backend) in Rust and also Go
I decided I wanted to integrate the experimental OCR thing better, so I rewrote in Go and also integrated the thumbnailer. However, Go is a bad langauge and I only used it out of spite. It turned out to have a very hard-to-fix memory leak due to some unclear interaction between libvips and both sets of bindings I tried, so I had Claude-3 transpile it to Rust then spent a while fixing the several mistakes it made and making tweaks. The new Rust version works, although I need to actually do something with the OCR data and make the index queryable concurrently.
This commit is contained in:
parent
fa863c2075
commit
7cb42e028f
3
.gitignore
vendored
3
.gitignore
vendored
@ -6,3 +6,6 @@ meme-rater/meta/
|
||||
meme-rater/*.sqlite3*
|
||||
meme-rater/deploy_for_training.sh
|
||||
node_modules/*
|
||||
node_modules
|
||||
*sqlite3*
|
||||
thumbtemp
|
@ -0,0 +1,12 @@
|
||||
{
|
||||
"db_name": "SQLite",
|
||||
"query": "INSERT OR IGNORE INTO files (filename) VALUES (?)",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "0d5b91c01acf72be0cd78f1a0c58c417e06d7c4e53e1ec542243ccf2808bbab7"
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
{
|
||||
"db_name": "SQLite",
|
||||
"query": "UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Right": 4
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "63edaa9692deb1a9fb17d9e16905a299878bc0fe4af582c6791a411741ee41d9"
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
{
|
||||
"db_name": "SQLite",
|
||||
"query": "UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Right": 3
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "b6803e2443445de725290dde85de5b5ef87958bec4d6db6bd06660b71f7a1ad0"
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
{
|
||||
"db_name": "SQLite",
|
||||
"query": "UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Right": 3
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "bed71d48c691bff7464b1aa767162df98ee2fcbe8df11f7db18a3647b2e0f1a2"
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
{
|
||||
"db_name": "SQLite",
|
||||
"query": "SELECT * FROM files WHERE filename = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "filename",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "embedding_time",
|
||||
"ordinal": 1,
|
||||
"type_info": "Int64"
|
||||
},
|
||||
{
|
||||
"name": "ocr_time",
|
||||
"ordinal": 2,
|
||||
"type_info": "Int64"
|
||||
},
|
||||
{
|
||||
"name": "thumbnail_time",
|
||||
"ordinal": 3,
|
||||
"type_info": "Int64"
|
||||
},
|
||||
{
|
||||
"name": "embedding",
|
||||
"ordinal": 4,
|
||||
"type_info": "Blob"
|
||||
},
|
||||
{
|
||||
"name": "ocr",
|
||||
"ordinal": 5,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "raw_ocr_segments",
|
||||
"ordinal": 6,
|
||||
"type_info": "Blob"
|
||||
},
|
||||
{
|
||||
"name": "thumbnails",
|
||||
"ordinal": 7,
|
||||
"type_info": "Blob"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "ec2da4ab11ede7a9a468ff3a50c55e0f6503fddd369f2c3031f39c0759bb97a0"
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
{
|
||||
"db_name": "SQLite",
|
||||
"query": "DELETE FROM files WHERE filename = ?",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "ee6eca5b34c3fbf76cd10932db35c6d8631e48be9166c02b593020a17fcf2686"
|
||||
}
|
3320
Cargo.lock
generated
Normal file
3320
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
34
Cargo.toml
Normal file
34
Cargo.toml
Normal file
@ -0,0 +1,34 @@
|
||||
[package]
|
||||
name = "meme-search-engine"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
axum = "0.7"
|
||||
image = { version = "0.25", features = ["avif"] }
|
||||
reqwest = { version = "0.12", features = ["multipart"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
|
||||
walkdir = "1"
|
||||
log = "0.4"
|
||||
rmp-serde = "1"
|
||||
serde_json = "1"
|
||||
chrono = "0.4"
|
||||
base64 = "0.22"
|
||||
anyhow = "1"
|
||||
fnv = "1"
|
||||
faiss = "0.12"
|
||||
ndarray = "0.15"
|
||||
half = { version = "2" }
|
||||
regex = "1"
|
||||
pretty_env_logger = "0.5"
|
||||
futures-util = "0.3"
|
||||
tokio-stream = "0.1"
|
||||
num_cpus = "1"
|
||||
serde_bytes = "0.11"
|
||||
tower-http = { version = "0.5", features = ["cors"] }
|
||||
tower = "0.4"
|
||||
json5 = "0.4"
|
@ -62,6 +62,8 @@
|
||||
|
||||
.result
|
||||
border: 1px solid gray
|
||||
*
|
||||
display: block
|
||||
.result img
|
||||
width: 100%
|
||||
</style>
|
||||
@ -109,17 +111,22 @@
|
||||
<Loading />
|
||||
{/if}
|
||||
{#if results}
|
||||
{#if displayedResults.length === 0}
|
||||
No results. Wait for index rebuild.
|
||||
{/if}
|
||||
<Masonry bind:refreshLayout={refreshLayout} colWidth="minmax(Min(20em, 100%), 1fr)" items={displayedResults}>
|
||||
{#each displayedResults as result}
|
||||
{#key result.file}
|
||||
<div class="result">
|
||||
<a href={util.getURL(result.file)}>
|
||||
<a href={util.getURL(result)}>
|
||||
<picture>
|
||||
{#if util.hasThumbnails(result.file)}
|
||||
<source srcset={util.thumbnailPath(result.file, "avif-lq") + ", " + util.thumbnailPath(result.file, "avif-hq") + " 2x"} type="image/avif" />
|
||||
<source srcset={util.thumbnailPath(result.file, "jpeg-800") + " 800w, " + util.thumbnailPath(result.file, "jpeg-fullscale")} type="image/jpeg" />
|
||||
{#if util.hasFormat(results, result, "avifl")}
|
||||
<source srcset={util.thumbnailURL(results, result, "avifl") + (util.hasFormat(results, result, "avifh") ? ", " + util.thumbnailURL(results, result, "avifh") + " 2x" : "")} type="image/avif" />
|
||||
{/if}
|
||||
<img src={util.getURL(result.file)} on:load={updateCounter} on:error={updateCounter} alt={result.caption || result.file}>
|
||||
{#if util.hasFormat(results, result, "jpegl")}
|
||||
<source srcset={util.thumbnailURL(results, result, "jpegl") + (util.hasFormat(results, result, "jpegh") ? ", " + util.thumbnailURL(results, result, "jpegh") + " 2x" : "")} type="image/jpeg" />
|
||||
{/if}
|
||||
<img src={util.getURL(result)} on:load={updateCounter} on:error={updateCounter} alt={result[1]}>
|
||||
</picture>
|
||||
</a>
|
||||
</div>
|
||||
@ -171,9 +178,7 @@
|
||||
let displayedResults = []
|
||||
const runSearch = async () => {
|
||||
if (!resultPromise) {
|
||||
let args = {}
|
||||
args.text = queryTerms.filter(x => x.type === "text" && x.text).map(({ text, weight, sign }) => [ text, weight * { "+": 1, "-": -1 }[sign] ])
|
||||
args.images = queryTerms.filter(x => x.type === "image").map(({ imageData, weight, sign }) => [ imageData, weight * { "+": 1, "-": -1 }[sign] ])
|
||||
let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
|
||||
resultPromise = util.doQuery(args).then(res => {
|
||||
error = null
|
||||
results = res
|
||||
@ -181,7 +186,8 @@
|
||||
displayedResults = []
|
||||
pendingImageLoads = 0
|
||||
for (let i = 0; i < chunkSize; i++) {
|
||||
displayedResults.push(results[i])
|
||||
if (i >= results.matches.length) break
|
||||
displayedResults.push(results.matches[i])
|
||||
pendingImageLoads += 1
|
||||
}
|
||||
redrawGrid()
|
||||
@ -195,7 +201,8 @@
|
||||
if (window.scrollY + window.innerHeight < heightThreshold) return;
|
||||
let init = displayedResults.length
|
||||
for (let i = 0; i < chunkSize; i++) {
|
||||
displayedResults.push(results[init + i])
|
||||
if (init + i >= results.matches.length) break
|
||||
displayedResults.push(results.matches[init + i])
|
||||
pendingImageLoads += 1
|
||||
}
|
||||
displayedResults = displayedResults
|
||||
|
@ -7,7 +7,7 @@ esbuild
|
||||
.build({
|
||||
entryPoints: [path.join(__dirname, "app.js")],
|
||||
bundle: true,
|
||||
minify: true,
|
||||
minify: false,
|
||||
outfile: path.join(__dirname, "../static/app.js"),
|
||||
plugins: [sveltePlugin({
|
||||
preprocess: {
|
||||
|
@ -1,7 +1,8 @@
|
||||
import * as config from "../../frontend_config.json"
|
||||
import * as backendConfig from "../../mse_config.json"
|
||||
import * as formats from "../../formats.json"
|
||||
|
||||
export const getURL = x => config.image_path + x
|
||||
export const getURL = x => config.image_path + x[1]
|
||||
|
||||
export const doQuery = args => fetch(config.backend_url, {
|
||||
method: "POST",
|
||||
@ -11,15 +12,11 @@ export const doQuery = args => fetch(config.backend_url, {
|
||||
body: JSON.stringify(args)
|
||||
}).then(x => x.json())
|
||||
|
||||
const filesafeCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"
|
||||
export const thumbnailPath = (originalPath, format) => {
|
||||
const extension = formats.formats[format][0]
|
||||
// Python and JS have minor differences in string handling wrt. astral characters which could result in incorrect quantities of dashes. Fortunately, Array.from handles this correctly.
|
||||
return config.thumb_path + `${Array.from(originalPath).map(x => filesafeCharset.includes(x) ? x : "_").join("")}.${format}${extension}`
|
||||
export const hasFormat = (results, result, format) => {
|
||||
return result[3] && (1 << results.formats.indexOf(format)) !== 0
|
||||
}
|
||||
|
||||
const thumbedExtensions = formats.extensions
|
||||
export const hasThumbnails = t => {
|
||||
const parts = t.split(".")
|
||||
return thumbedExtensions.includes("." + parts[parts.length - 1])
|
||||
export const thumbnailURL = (results, result, format) => {
|
||||
console.log("RES", results)
|
||||
return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}`
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
{
|
||||
"backend_url": "https://mse.osmarks.net/backend",
|
||||
"image_path": "https://i2.osmarks.net/memes-or-something/",
|
||||
"thumb_path": "https://i2.osmarks.net/thumbs/memes-or-something_"
|
||||
"backend_url": "http://localhost:1707/",
|
||||
"image_path": "http://localhost:7858/",
|
||||
"thumb_path": "http://localhost:7857/"
|
||||
}
|
26
misc/bad-go-version/go.mod
Normal file
26
misc/bad-go-version/go.mod
Normal file
@ -0,0 +1,26 @@
|
||||
module meme-search
|
||||
|
||||
go 1.22.2
|
||||
|
||||
require (
|
||||
github.com/DataIntelligenceCrew/go-faiss v0.2.0
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/mattn/go-sqlite3 v1.14.22
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/titanous/json5 v1.0.0
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.7.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davidbyttow/govips/v2 v2.14.0 // indirect
|
||||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/h2non/bimg v1.1.9 // indirect
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
|
||||
golang.org/x/image v0.16.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
google.golang.org/appengine v1.6.8 // indirect
|
||||
google.golang.org/protobuf v1.26.0 // indirect
|
||||
)
|
100
misc/bad-go-version/go.sum
Normal file
100
misc/bad-go-version/go.sum
Normal file
@ -0,0 +1,100 @@
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/DataIntelligenceCrew/go-faiss v0.2.0 h1:c0pxAr0vldXIuE4DZnqpl6FuuH1uZd45d+NiQHKg1uU=
|
||||
github.com/DataIntelligenceCrew/go-faiss v0.2.0/go.mod h1:4Gi7G3PF78IwZigTL2M1AJXOaAgxyL66vCqUYVaNgwk=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davidbyttow/govips/v2 v2.14.0 h1:il3pX0XMZ5nlwipkFJHRZ3vGzcdXWApARalJxNpRHJU=
|
||||
github.com/davidbyttow/govips/v2 v2.14.0/go.mod h1:eglyvgm65eImDiJJk4wpj9LSz4pWivPzWgDqkxWJn5k=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/h2non/bimg v1.1.9 h1:WH20Nxko9l/HFm4kZCA3Phbgu2cbHvYzxwxn9YROEGg=
|
||||
github.com/h2non/bimg v1.1.9/go.mod h1:R3+UiYwkK4rQl6KVFTOFJHitgLbZXBZNFh2cv3AEbp8=
|
||||
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
||||
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
||||
github.com/json5/json5-go v0.0.0-20160331055859-40c2958e3bf8 h1:BQuwfXQRDQMI8YNqINKNlFV23P0h07ZvOQAtezAEsP8=
|
||||
github.com/json5/json5-go v0.0.0-20160331055859-40c2958e3bf8/go.mod h1:7n1PdYNh4RIHTvILru80IEstTADqQz/wmjeNXTcC9rA=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
|
||||
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
|
||||
github.com/titanous/json5 v1.0.0 h1:hJf8Su1d9NuI/ffpxgxQfxh/UiBFZX7bMPid0rIL/7s=
|
||||
github.com/titanous/json5 v1.0.0/go.mod h1:7JH1M8/LHKc6cyP5o5g3CSaRj+mBrIimTxzpvmckH8c=
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0=
|
||||
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
|
||||
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
|
||||
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
877
misc/bad-go-version/meme_search.go
Normal file
877
misc/bad-go-version/meme_search.go
Normal file
@ -0,0 +1,877 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/DataIntelligenceCrew/go-faiss"
|
||||
"github.com/h2non/bimg"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/samber/lo"
|
||||
"github.com/vmihailenco/msgpack"
|
||||
"github.com/x448/float16"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ClipServer string `json:"clip_server"`
|
||||
DbPath string `json:"db_path"`
|
||||
Port int16 `json:"port"`
|
||||
Files string `json:"files"`
|
||||
EnableOCR bool `json:"enable_ocr"`
|
||||
ThumbsPath string `json:"thumbs_path"`
|
||||
EnableThumbnails bool `json:"enable_thumbs"`
|
||||
}
|
||||
|
||||
type Index struct {
|
||||
vectors *faiss.IndexImpl
|
||||
filenames []string
|
||||
formatCodes []int64
|
||||
formatNames []string
|
||||
}
|
||||
|
||||
var schema = `
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
embedding_time INTEGER,
|
||||
ocr_time INTEGER,
|
||||
thumbnail_time INTEGER,
|
||||
embedding BLOB,
|
||||
ocr TEXT,
|
||||
raw_ocr_segments BLOB,
|
||||
thumbnails BLOB
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
|
||||
filename,
|
||||
ocr,
|
||||
tokenize='unicode61 remove_diacritics 2',
|
||||
content='ocr'
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
|
||||
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
|
||||
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER UPDATE ON files BEGIN
|
||||
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
|
||||
INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
END;
|
||||
`
|
||||
|
||||
type FileRecord struct {
|
||||
Filename string `db:"filename"`
|
||||
EmbedTime int64 `db:"embedding_time"`
|
||||
OcrTime int64 `db:"ocr_time"`
|
||||
ThumbnailTime int64 `db:"thumbnail_time"`
|
||||
Embedding []byte `db:"embedding"`
|
||||
Ocr string `db:"ocr"`
|
||||
RawOcrSegments []byte `db:"raw_ocr_segments"`
|
||||
Thumbnails []byte `db:"thumbnails"`
|
||||
}
|
||||
|
||||
type InferenceServerConfig struct {
|
||||
BatchSize uint `msgpack:"batch"`
|
||||
ImageSize []uint `msgpack:"image_size"`
|
||||
EmbeddingSize uint `msgpack:"embedding_size"`
|
||||
}
|
||||
|
||||
func decodeMsgpackFrom[O interface{}](resp *http.Response) (O, error) {
|
||||
var result O
|
||||
respData, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
err = msgpack.Unmarshal(respData, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func queryClipServer[I interface{}, O interface{}](config Config, path string, data I) (O, error) {
|
||||
var result O
|
||||
b, err := msgpack.Marshal(data)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
resp, err := http.Post(config.ClipServer+path, "application/msgpack", bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return decodeMsgpackFrom[O](resp)
|
||||
}
|
||||
|
||||
type LoadedImage struct {
|
||||
image *bimg.Image
|
||||
filename string
|
||||
originalSize int
|
||||
}
|
||||
|
||||
type EmbeddingInput struct {
|
||||
image []byte
|
||||
filename string
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Images [][]byte `msgpack:"images"`
|
||||
Text []string `msgpack:"text"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse = [][]byte
|
||||
|
||||
func timestamp() int64 {
|
||||
return time.Now().UnixMicro()
|
||||
}
|
||||
|
||||
type ImageFormatConfig struct {
|
||||
targetWidth int
|
||||
targetFilesize int
|
||||
quality int
|
||||
format bimg.ImageType
|
||||
extension string
|
||||
}
|
||||
|
||||
func generateFilenameHash(filename string) string {
|
||||
hasher := fnv.New128()
|
||||
hasher.Write([]byte(filename))
|
||||
hash := hasher.Sum(make([]byte, 0))
|
||||
return base64.RawURLEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
func generateThumbnailFilename(filename string, formatName string, formatConfig ImageFormatConfig) string {
|
||||
return fmt.Sprintf("%s%s.%s", generateFilenameHash(filename), formatName, formatConfig.extension)
|
||||
}
|
||||
|
||||
func initializeDatabase(config Config) (*sqlx.DB, error) {
|
||||
db, err := sqlx.Connect("sqlite3", config.DbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = db.Exec("PRAGMA busy_timeout = 2000; PRAGMA journal_mode = WAL")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func imageFormats(config Config) map[string]ImageFormatConfig {
|
||||
return map[string]ImageFormatConfig{
|
||||
"jpegl": {
|
||||
targetWidth: 800,
|
||||
quality: 70,
|
||||
format: bimg.JPEG,
|
||||
extension: "jpg",
|
||||
},
|
||||
"jpegh": {
|
||||
targetWidth: 1600,
|
||||
quality: 80,
|
||||
format: bimg.JPEG,
|
||||
extension: "jpg",
|
||||
},
|
||||
"jpeg256kb": {
|
||||
targetWidth: 500,
|
||||
targetFilesize: 256000,
|
||||
format: bimg.JPEG,
|
||||
extension: "jpg",
|
||||
},
|
||||
"avifh": {
|
||||
targetWidth: 1600,
|
||||
quality: 80,
|
||||
format: bimg.AVIF,
|
||||
extension: "avif",
|
||||
},
|
||||
"avifl": {
|
||||
targetWidth: 800,
|
||||
quality: 30,
|
||||
format: bimg.AVIF,
|
||||
extension: "avif",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ingestFiles(config Config, backend InferenceServerConfig) error {
|
||||
var wg errgroup.Group
|
||||
var iwg errgroup.Group
|
||||
|
||||
// We assume everything is either a modern browser (low-DPI or high-DPI), an ancient browser or a ComputerCraft machine abusing Extra Utilities 2 screens.
|
||||
var formats = imageFormats(config)
|
||||
|
||||
db, err := initializeDatabase(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
toProcess := make(chan FileRecord, 100)
|
||||
toEmbed := make(chan EmbeddingInput, backend.BatchSize)
|
||||
toThumbnail := make(chan LoadedImage, 30)
|
||||
toOCR := make(chan LoadedImage, 30)
|
||||
embedBatches := make(chan []EmbeddingInput, 1)
|
||||
|
||||
// image loading and preliminary resizing
|
||||
for range runtime.NumCPU() {
|
||||
iwg.Go(func() error {
|
||||
for record := range toProcess {
|
||||
path := filepath.Join(config.Files, record.Filename)
|
||||
buffer, err := bimg.Read(path)
|
||||
if err != nil {
|
||||
log.Println("could not read ", record.Filename)
|
||||
}
|
||||
img := bimg.NewImage(buffer)
|
||||
if record.Embedding == nil {
|
||||
resized, err := img.Process(bimg.Options{
|
||||
Width: int(backend.ImageSize[0]),
|
||||
Height: int(backend.ImageSize[1]),
|
||||
Force: true,
|
||||
Type: bimg.PNG,
|
||||
Interpretation: bimg.InterpretationSRGB,
|
||||
})
|
||||
if err != nil {
|
||||
log.Println("resize failure", record.Filename, err)
|
||||
} else {
|
||||
toEmbed <- EmbeddingInput{
|
||||
image: resized,
|
||||
filename: record.Filename,
|
||||
}
|
||||
}
|
||||
}
|
||||
if record.Thumbnails == nil && config.EnableThumbnails {
|
||||
toThumbnail <- LoadedImage{
|
||||
image: img,
|
||||
filename: record.Filename,
|
||||
originalSize: len(buffer),
|
||||
}
|
||||
}
|
||||
if record.RawOcrSegments == nil && config.EnableOCR {
|
||||
toOCR <- LoadedImage{
|
||||
image: img,
|
||||
filename: record.Filename,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if config.EnableThumbnails {
|
||||
for range runtime.NumCPU() {
|
||||
wg.Go(func() error {
|
||||
for image := range toThumbnail {
|
||||
generatedFormats := make([]string, 0)
|
||||
for formatName, formatConfig := range formats {
|
||||
var err error
|
||||
var resized []byte
|
||||
if formatConfig.targetFilesize != 0 {
|
||||
lb := 1
|
||||
ub := 100
|
||||
for {
|
||||
quality := (lb + ub) / 2
|
||||
resized, err = image.image.Process(bimg.Options{
|
||||
Width: formatConfig.targetWidth,
|
||||
Type: formatConfig.format,
|
||||
Speed: 4,
|
||||
Quality: quality,
|
||||
StripMetadata: true,
|
||||
Enlarge: false,
|
||||
})
|
||||
if len(resized) > image.originalSize {
|
||||
ub = quality
|
||||
} else {
|
||||
lb = quality + 1
|
||||
}
|
||||
if lb >= ub {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resized, err = image.image.Process(bimg.Options{
|
||||
Width: formatConfig.targetWidth,
|
||||
Type: formatConfig.format,
|
||||
Speed: 4,
|
||||
Quality: formatConfig.quality,
|
||||
StripMetadata: true,
|
||||
Enlarge: false,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
log.Println("thumbnailing failure", image.filename, err)
|
||||
continue
|
||||
}
|
||||
if len(resized) < image.originalSize {
|
||||
generatedFormats = append(generatedFormats, formatName)
|
||||
err = bimg.Write(filepath.Join(config.ThumbsPath, generateThumbnailFilename(image.filename, formatName, formatConfig)), resized)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
formatsData, err := msgpack.Marshal(generatedFormats)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.Exec("UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?", formatsData, timestamp(), image.filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if config.EnableOCR {
|
||||
for range 100 {
|
||||
wg.Go(func() error {
|
||||
for image := range toOCR {
|
||||
scan, err := scanImage(image.image)
|
||||
if err != nil {
|
||||
log.Println("OCR failure", image.filename, err)
|
||||
continue
|
||||
}
|
||||
ocrText := ""
|
||||
for _, segment := range scan {
|
||||
ocrText += segment.text
|
||||
ocrText += "\n"
|
||||
}
|
||||
ocrData, err := msgpack.Marshal(scan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.Exec("UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?", ocrText, ocrData, timestamp(), image.filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
wg.Go(func() error {
|
||||
buffer := make([]EmbeddingInput, 0, backend.BatchSize)
|
||||
for input := range toEmbed {
|
||||
buffer = append(buffer, input)
|
||||
if len(buffer) == int(backend.BatchSize) {
|
||||
embedBatches <- buffer
|
||||
buffer = make([]EmbeddingInput, 0, backend.BatchSize)
|
||||
}
|
||||
}
|
||||
if len(buffer) > 0 {
|
||||
embedBatches <- buffer
|
||||
}
|
||||
close(embedBatches)
|
||||
return nil
|
||||
})
|
||||
|
||||
for range 3 {
|
||||
wg.Go(func() error {
|
||||
for batch := range embedBatches {
|
||||
result, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "", EmbeddingRequest{
|
||||
Images: lo.Map(batch, func(item EmbeddingInput, _ int) []byte { return item.image }),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, vector := range result {
|
||||
_, err = tx.Exec("UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?", timestamp(), vector, batch[i].filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
filenamesOnDisk := make(map[string]struct{})
|
||||
|
||||
err = filepath.WalkDir(config.Files, func(path string, d os.DirEntry, err error) error {
|
||||
filename := strings.TrimPrefix(path, config.Files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
filenamesOnDisk[filename] = struct{}{}
|
||||
records := []FileRecord{}
|
||||
err = db.Select(&records, "SELECT * FROM files WHERE filename = ?", filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stat, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modtime := stat.ModTime().UnixMicro()
|
||||
if len(records) == 0 || modtime > records[0].EmbedTime || modtime > records[0].OcrTime || modtime > records[0].ThumbnailTime {
|
||||
_, err = db.Exec("INSERT OR IGNORE INTO files VALUES (?, 0, 0, 0, '', '', '', '')", filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
record := FileRecord{
|
||||
Filename: filename,
|
||||
}
|
||||
if len(records) > 0 {
|
||||
record = records[0]
|
||||
}
|
||||
if modtime > record.EmbedTime || len(record.Embedding) == 0 {
|
||||
record.Embedding = nil
|
||||
}
|
||||
if modtime > record.OcrTime || len(record.RawOcrSegments) == 0 {
|
||||
record.RawOcrSegments = nil
|
||||
}
|
||||
if modtime > record.ThumbnailTime || len(record.Thumbnails) == 0 {
|
||||
record.Thumbnails = nil
|
||||
}
|
||||
toProcess <- record
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
close(toProcess)
|
||||
|
||||
err = iwg.Wait()
|
||||
close(toEmbed)
|
||||
close(toThumbnail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = wg.Wait()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := db.Queryx("SELECT filename FROM files")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for rows.Next() {
|
||||
var filename string
|
||||
err := rows.Scan(&filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, ok := filenamesOnDisk[filename]; !ok {
|
||||
_, err = tx.Exec("DELETE FROM files WHERE filename = ?", filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const INDEX_ADD_BATCH = 512
|
||||
|
||||
func buildIndex(config Config, backend InferenceServerConfig) (Index, error) {
|
||||
var index Index
|
||||
|
||||
db, err := initializeDatabase(config)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
newFAISSIndex, err := faiss.IndexFactory(int(backend.EmbeddingSize), "SQfp16", faiss.MetricInnerProduct)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
index.vectors = newFAISSIndex
|
||||
|
||||
var count int
|
||||
err = db.Get(&count, "SELECT COUNT(*) FROM files")
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
|
||||
index.filenames = make([]string, 0, count)
|
||||
index.formatCodes = make([]int64, 0, count)
|
||||
buffer := make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
|
||||
index.formatNames = make([]string, 0, 5)
|
||||
|
||||
record := FileRecord{}
|
||||
rows, err := db.Queryx("SELECT * FROM files")
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
for rows.Next() {
|
||||
err := rows.StructScan(&record)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
if len(record.Embedding) > 0 {
|
||||
index.filenames = append(index.filenames, record.Filename)
|
||||
for i := 0; i < len(record.Embedding); i += 2 {
|
||||
buffer = append(buffer, float16.Frombits(uint16(record.Embedding[i])+uint16(record.Embedding[i+1])<<8).Float32())
|
||||
}
|
||||
if len(buffer) == cap(buffer) {
|
||||
index.vectors.Add(buffer)
|
||||
buffer = make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
|
||||
}
|
||||
|
||||
formats := make([]string, 0, 5)
|
||||
if len(record.Thumbnails) > 0 {
|
||||
err := msgpack.Unmarshal(record.Thumbnails, &formats)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
}
|
||||
|
||||
formatCode := int64(0)
|
||||
for _, formatString := range formats {
|
||||
found := false
|
||||
for i, name := range index.formatNames {
|
||||
if name == formatString {
|
||||
formatCode |= 1 << i
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
newIndex := len(index.formatNames)
|
||||
formatCode |= 1 << newIndex
|
||||
index.formatNames = append(index.formatNames, formatString)
|
||||
}
|
||||
}
|
||||
index.formatCodes = append(index.formatCodes, formatCode)
|
||||
}
|
||||
}
|
||||
if len(buffer) > 0 {
|
||||
index.vectors.Add(buffer)
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func decodeFP16Buffer(buf []byte) []float32 {
|
||||
out := make([]float32, 0, len(buf)/2)
|
||||
for i := 0; i < len(buf); i += 2 {
|
||||
out = append(out, float16.Frombits(uint16(buf[i])+uint16(buf[i+1])<<8).Float32())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type EmbeddingVector []float32
|
||||
|
||||
type QueryResult struct {
|
||||
Matches [][]interface{} `json:"matches"`
|
||||
Formats []string `json:"formats"`
|
||||
Extensions map[string]string `json:"extensions"`
|
||||
}
|
||||
|
||||
// this terrible language cannot express tagged unions
|
||||
type QueryTerm struct {
|
||||
Embedding *EmbeddingVector `json:"embedding"`
|
||||
Image *string `json:"image"` // base64
|
||||
Text *string `json:"text"`
|
||||
Weight *float32 `json:"weight"`
|
||||
}
|
||||
|
||||
type QueryRequest struct {
|
||||
Terms []QueryTerm `json:"terms"`
|
||||
K *int `json:"k"`
|
||||
}
|
||||
|
||||
func queryIndex(index *Index, query EmbeddingVector, k int) (QueryResult, error) {
|
||||
var qr QueryResult
|
||||
distances, ids, err := index.vectors.Search(query, int64(k))
|
||||
if err != nil {
|
||||
return qr, err
|
||||
}
|
||||
items := lo.Map(lo.Zip2(distances, ids), func(x lo.Tuple2[float32, int64], i int) []interface{} {
|
||||
return []interface{}{
|
||||
x.A,
|
||||
index.filenames[x.B],
|
||||
generateFilenameHash(index.filenames[x.B]),
|
||||
index.formatCodes[x.B],
|
||||
}
|
||||
})
|
||||
|
||||
return QueryResult{
|
||||
Matches: items,
|
||||
Formats: index.formatNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func handleRequest(config Config, backendConfig InferenceServerConfig, index *Index, w http.ResponseWriter, req *http.Request) error {
|
||||
if req.Body == nil {
|
||||
io.WriteString(w, "OK") // health check
|
||||
return nil
|
||||
}
|
||||
dec := json.NewDecoder(req.Body)
|
||||
var qreq QueryRequest
|
||||
err := dec.Decode(&qreq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalEmbedding := make(EmbeddingVector, backendConfig.EmbeddingSize)
|
||||
|
||||
imageBatch := make([][]byte, 0)
|
||||
imageWeights := make([]float32, 0)
|
||||
textBatch := make([]string, 0)
|
||||
textWeights := make([]float32, 0)
|
||||
|
||||
for _, term := range qreq.Terms {
|
||||
if term.Image != nil {
|
||||
bytes, err := base64.StdEncoding.DecodeString(*term.Image)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
loaded := bimg.NewImage(bytes)
|
||||
resized, err := loaded.Process(bimg.Options{
|
||||
Width: int(backendConfig.ImageSize[0]),
|
||||
Height: int(backendConfig.ImageSize[1]),
|
||||
Force: true,
|
||||
Type: bimg.PNG,
|
||||
Interpretation: bimg.InterpretationSRGB,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
imageBatch = append(imageBatch, resized)
|
||||
if term.Weight != nil {
|
||||
imageWeights = append(imageWeights, *term.Weight)
|
||||
} else {
|
||||
imageWeights = append(imageWeights, 1)
|
||||
}
|
||||
}
|
||||
if term.Text != nil {
|
||||
textBatch = append(textBatch, *term.Text)
|
||||
if term.Weight != nil {
|
||||
textWeights = append(textWeights, *term.Weight)
|
||||
} else {
|
||||
textWeights = append(textWeights, 1)
|
||||
}
|
||||
}
|
||||
if term.Embedding != nil {
|
||||
weight := float32(1.0)
|
||||
if term.Weight != nil {
|
||||
weight = *term.Weight
|
||||
}
|
||||
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
|
||||
totalEmbedding[i] += (*term.Embedding)[i] * weight
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(imageBatch) > 0 {
|
||||
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
|
||||
Images: imageBatch,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for j, emb := range embs {
|
||||
embd := decodeFP16Buffer(emb)
|
||||
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
|
||||
totalEmbedding[i] += embd[i] * imageWeights[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(textBatch) > 0 {
|
||||
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
|
||||
Text: textBatch,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for j, emb := range embs {
|
||||
embd := decodeFP16Buffer(emb)
|
||||
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
|
||||
totalEmbedding[i] += embd[i] * textWeights[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
k := 1000
|
||||
if qreq.K != nil {
|
||||
k = *qreq.K
|
||||
}
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
|
||||
qres, err := queryIndex(index, totalEmbedding, k)
|
||||
|
||||
qres.Extensions = make(map[string]string)
|
||||
for k, v := range imageFormats(config) {
|
||||
qres.Extensions[k] = v.extension
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = enc.Encode(qres)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
os.Setenv("VIPS_WARNING", "FALSE") // this does not actually work
|
||||
bimg.VipsCacheSetMax(0)
|
||||
bimg.VipsCacheSetMaxMem(0)
|
||||
}
|
||||
|
||||
func main() {
|
||||
content, err := os.ReadFile(os.Args[1])
|
||||
if err != nil {
|
||||
log.Fatal("config file unreadable ", err)
|
||||
}
|
||||
var config Config
|
||||
err = json.Unmarshal(content, &config)
|
||||
if err != nil {
|
||||
log.Fatal("config file wrong ", err)
|
||||
}
|
||||
fmt.Println(config)
|
||||
|
||||
db, err := sqlx.Connect("sqlite3", config.DbPath)
|
||||
if err != nil {
|
||||
log.Fatal("DB connection failure ", db)
|
||||
}
|
||||
db.MustExec(schema)
|
||||
|
||||
var backend InferenceServerConfig
|
||||
for {
|
||||
resp, err := http.Get(config.ClipServer + "/config")
|
||||
if err != nil {
|
||||
log.Println("backend failed (fetch) ", err)
|
||||
}
|
||||
backend, err = decodeMsgpackFrom[InferenceServerConfig](resp)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Println("backend failed (parse) ", err)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
requestIngest := make(chan struct{}, 1)
|
||||
|
||||
var index *Index
|
||||
// maybe this ought to be mutexed?
|
||||
var lastError *error
|
||||
// there's not a neat way to reusably broadcast to multiple channels, but I *can* abuse WaitGroups probably
|
||||
// this might cause horrible concurrency issues, but you brought me to this point, Go designers
|
||||
var wg sync.WaitGroup
|
||||
|
||||
go func() {
|
||||
for {
|
||||
wg.Add(1)
|
||||
log.Println("ingest running")
|
||||
err := ingestFiles(config, backend)
|
||||
if err != nil {
|
||||
log.Println("ingest failed ", err)
|
||||
lastError = &err
|
||||
} else {
|
||||
newIndex, err := buildIndex(config, backend)
|
||||
if err != nil {
|
||||
log.Println("index build failed ", err)
|
||||
lastError = &err
|
||||
} else {
|
||||
lastError = nil
|
||||
index = &newIndex
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
<-requestIngest
|
||||
}
|
||||
}()
|
||||
newIndex, err := buildIndex(config, backend)
|
||||
index = &newIndex
|
||||
if err != nil {
|
||||
log.Fatal("index build failed ", err)
|
||||
}
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Add("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Add("Access-Control-Allow-Headers", "Content-Type")
|
||||
if req.Method == "OPTIONS" {
|
||||
w.WriteHeader(204)
|
||||
return
|
||||
}
|
||||
err := handleRequest(config, backend, index, w, req)
|
||||
if err != nil {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(500)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
})
|
||||
http.HandleFunc("/reload", func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method == "POST" {
|
||||
log.Println("requesting index reload")
|
||||
select {
|
||||
case requestIngest <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
wg.Wait()
|
||||
if lastError == nil {
|
||||
w.Write([]byte("OK"))
|
||||
} else {
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte((*lastError).Error()))
|
||||
}
|
||||
}
|
||||
})
|
||||
http.HandleFunc("/profile", func(w http.ResponseWriter, req *http.Request) {
|
||||
f, err := os.Create("mem.pprof")
|
||||
if err != nil {
|
||||
log.Fatal("could not create memory profile: ", err)
|
||||
}
|
||||
defer f.Close()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
log.Printf("Memory usage: Alloc=%v, TotalAlloc=%v, Sys=%v", m.Alloc, m.TotalAlloc, m.Sys)
|
||||
log.Println(bimg.VipsMemory())
|
||||
bimg.VipsDebugInfo()
|
||||
runtime.GC() // Trigger garbage collection
|
||||
if err := pprof.WriteHeapProfile(f); err != nil {
|
||||
log.Fatal("could not write memory profile: ", err)
|
||||
}
|
||||
})
|
||||
log.Println("starting server")
|
||||
http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)
|
||||
}
|
264
misc/bad-go-version/ocr.go
Normal file
264
misc/bad-go-version/ocr.go
Normal file
@ -0,0 +1,264 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/bimg"
|
||||
"github.com/samber/lo"
|
||||
"github.com/titanous/json5"
|
||||
)
|
||||
|
||||
const CALLBACK_REGEX string = ">AF_initDataCallback\\(({key: 'ds:1'.*?)\\);</script>"
|
||||
|
||||
type SegmentCoords struct {
|
||||
x int
|
||||
y int
|
||||
w int
|
||||
h int
|
||||
}
|
||||
|
||||
type Segment struct {
|
||||
coords SegmentCoords
|
||||
text string
|
||||
}
|
||||
|
||||
type ScanResult []Segment
|
||||
|
||||
// TODO coordinates are negative sometimes and I think they shouldn't be
|
||||
func rationalizeCoordsFormat1(imageW float64, imageH float64, centerXFraction float64, centerYFraction float64, widthFraction float64, heightFraction float64) SegmentCoords {
|
||||
return SegmentCoords{
|
||||
x: int(math.Round((centerXFraction - widthFraction/2) * imageW)),
|
||||
y: int(math.Round((centerYFraction - heightFraction/2) * imageH)),
|
||||
w: int(math.Round(widthFraction * imageW)),
|
||||
h: int(math.Round(heightFraction * imageH)),
|
||||
}
|
||||
}
|
||||
|
||||
func scanImageChunk(image []byte, imageWidth int, imageHeight int) (ScanResult, error) {
|
||||
var result ScanResult
|
||||
timestamp := time.Now().UnixMicro()
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
defer w.Close()
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="encoded_image"; filename="ocr%d.png"`, timestamp))
|
||||
h.Set("Content-Type", "image/png")
|
||||
fw, err := w.CreatePart(h)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
fw.Write(image)
|
||||
w.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://lens.google.com/v3/upload?stcs=%d", timestamp), &b)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
req.Header.Add("User-Agent", "Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36")
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "SOCS",
|
||||
Value: "CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg",
|
||||
})
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
client := http.Client{}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
re, _ := regexp.Compile(CALLBACK_REGEX)
|
||||
matches := re.FindStringSubmatch(string(body[:]))
|
||||
if len(matches) == 0 {
|
||||
return result, fmt.Errorf("invalid API response")
|
||||
}
|
||||
match := matches[1]
|
||||
var lensObject map[string]interface{}
|
||||
err = json5.Unmarshal([]byte(match), &lensObject)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
if _, ok := lensObject["errorHasStatus"]; ok {
|
||||
return result, errors.New("lens failed")
|
||||
}
|
||||
|
||||
root := lensObject["data"].([]interface{})
|
||||
|
||||
var textSegments []string
|
||||
var textRegions []SegmentCoords
|
||||
|
||||
// I don't know why Google did this.
|
||||
// Text segments are in one place and their locations are in another, using a very strange coordinate system.
|
||||
// At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
|
||||
// TODO: on a few images, this seems to not work for some reason.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
|
||||
// In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
|
||||
textSegments = []string{}
|
||||
textRegions = []SegmentCoords{}
|
||||
}
|
||||
}()
|
||||
|
||||
textSegmentsRaw := root[3].([]interface{})[4].([]interface{})[0].([]interface{})[0].([]interface{})
|
||||
textRegionsRaw := root[2].([]interface{})[3].([]interface{})[0].([]interface{})
|
||||
for _, x := range textRegionsRaw {
|
||||
if strings.HasPrefix(x.([]interface{})[11].(string), "text:") {
|
||||
rawCoords := x.([]interface{})[1].([]interface{})
|
||||
coords := rationalizeCoordsFormat1(float64(imageWidth), float64(imageHeight), rawCoords[0].(float64), rawCoords[1].(float64), rawCoords[2].(float64), rawCoords[3].(float64))
|
||||
textRegions = append(textRegions, coords)
|
||||
}
|
||||
}
|
||||
for _, x := range textSegmentsRaw {
|
||||
textSegment := x.(string)
|
||||
textSegments = append(textSegments, textSegment)
|
||||
}
|
||||
|
||||
return lo.Map(lo.Zip2(textSegments, textRegions), func(x lo.Tuple2[string, SegmentCoords], _ int) Segment {
|
||||
return Segment{
|
||||
text: x.A,
|
||||
coords: x.B,
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
|
||||
const MAX_DIM int = 1024
|
||||
|
||||
func scanImage(image *bimg.Image) (ScanResult, error) {
|
||||
result := ScanResult{}
|
||||
metadata, err := image.Metadata()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
width := metadata.Size.Width
|
||||
height := metadata.Size.Height
|
||||
if width > MAX_DIM {
|
||||
width = MAX_DIM
|
||||
height = int(math.Round(float64(height) * (float64(width) / float64(metadata.Size.Width))))
|
||||
}
|
||||
for y := 0; y < height; y += MAX_DIM {
|
||||
chunkHeight := MAX_DIM
|
||||
if y+chunkHeight > height {
|
||||
chunkHeight = height - y
|
||||
}
|
||||
chunk, err := image.Process(bimg.Options{
|
||||
Height: height, // these are for overall image dimensions (resize then crop)
|
||||
Width: width,
|
||||
Top: y,
|
||||
AreaHeight: chunkHeight,
|
||||
AreaWidth: width,
|
||||
Crop: true,
|
||||
Type: bimg.PNG,
|
||||
})
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
res, err := scanImageChunk(chunk, width, chunkHeight)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
for _, segment := range res {
|
||||
result = append(result, Segment{
|
||||
text: segment.text,
|
||||
coords: SegmentCoords{
|
||||
y: segment.coords.y + y,
|
||||
x: segment.coords.x,
|
||||
w: segment.coords.w,
|
||||
h: segment.coords.h,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
/*
|
||||
async def scan_image_chunk(sess, image):
|
||||
# send data to inscrutable undocumented Google service
|
||||
# https://github.com/AuroraWright/owocr/blob/master/owocr/ocr.py#L193
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
data = aiohttp.FormData()
|
||||
data.add_field(
|
||||
"encoded_image",
|
||||
encode_img(image),
|
||||
filename="ocr" + str(timestamp) + ".png",
|
||||
content_type="image/png"
|
||||
)
|
||||
async with sess.post(url, headers=headers, cookies=cookies, data=data, timeout=10) as res:
|
||||
body = await res.text()
|
||||
|
||||
# I really worry about Google sometimes. This is not a sensible format.
|
||||
match = CALLBACK_REGEX.search(body)
|
||||
if match == None:
|
||||
raise ValueError("Invalid callback")
|
||||
|
||||
lens_object = pyjson5.loads(match.group(1))
|
||||
if "errorHasStatus" in lens_object:
|
||||
raise RuntimeError("Lens failed")
|
||||
|
||||
text_segments = []
|
||||
text_regions = []
|
||||
|
||||
root = lens_object["data"]
|
||||
|
||||
# I don't know why Google did this.
|
||||
# Text segments are in one place and their locations are in another, using a very strange coordinate system.
|
||||
# At least I don't need whatever is contained in the base64 partss (which I assume are protobufs).
|
||||
# TODO: on a few images, this seems to not work for some reason.
|
||||
try:
|
||||
text_segments = root[3][4][0][0]
|
||||
text_regions = [ rationalize_coords_format1(image.width, image.height, *x[1]) for x in root[2][3][0] if x[11].startswith("text:") ]
|
||||
except (KeyError, IndexError):
|
||||
# https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
|
||||
# In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
|
||||
return [], []
|
||||
|
||||
return text_segments, text_regions
|
||||
|
||||
MAX_SCAN_DIM = 1000 # not actually true but close enough
|
||||
def chunk_image(image: Image):
|
||||
chunks = []
|
||||
# Cut image down in X axis (I'm assuming images aren't too wide to scan in downscaled form because merging text horizontally would be annoying)
|
||||
if image.width > MAX_SCAN_DIM:
|
||||
image = image.resize((MAX_SCAN_DIM, round(image.height * (image.width / MAX_SCAN_DIM))), Image.LANCZOS)
|
||||
for y in range(0, image.height, MAX_SCAN_DIM):
|
||||
chunks.append(image.crop((0, y, image.width, min(y + MAX_SCAN_DIM, image.height))))
|
||||
return chunks
|
||||
|
||||
async def scan_chunks(sess: aiohttp.ClientSession, chunks: [Image]):
|
||||
# If text happens to be split across the cut line it won't get read.
|
||||
# This is because doing overlap read areas would be really annoying.
|
||||
text = ""
|
||||
regions = []
|
||||
for chunk in chunks:
|
||||
new_segments, new_regions = await scan_image_chunk(sess, chunk)
|
||||
for segment in new_segments:
|
||||
text += segment + "\n"
|
||||
for i, (segment, region) in enumerate(zip(new_segments, new_regions)):
|
||||
regions.append({ **region, "y": region["y"] + (MAX_SCAN_DIM * i), "text": segment })
|
||||
return text, regions
|
||||
|
||||
async def scan_image(sess: aiohttp.ClientSession, image: Image):
|
||||
return await scan_chunks(sess, chunk_image(image))
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
print(await scan_image(sess, Image.open("/data/public/memes-or-something/linear-algebra-chess.png")))
|
||||
asyncio.run(main())
|
||||
*/
|
891
misc/bad-go-version/problematic_thing.go
Normal file
891
misc/bad-go-version/problematic_thing.go
Normal file
@ -0,0 +1,891 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/DataIntelligenceCrew/go-faiss"
|
||||
"github.com/davidbyttow/govips/v2/vips"
|
||||
"github.com/h2non/bimg"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/samber/lo"
|
||||
"github.com/vmihailenco/msgpack"
|
||||
"github.com/x448/float16"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ClipServer string `json:"clip_server"`
|
||||
DbPath string `json:"db_path"`
|
||||
Port int16 `json:"port"`
|
||||
Files string `json:"files"`
|
||||
EnableOCR bool `json:"enable_ocr"`
|
||||
ThumbsPath string `json:"thumbs_path"`
|
||||
EnableThumbnails bool `json:"enable_thumbs"`
|
||||
}
|
||||
|
||||
type Index struct {
|
||||
vectors *faiss.IndexImpl
|
||||
filenames []string
|
||||
formatCodes []int64
|
||||
formatNames []string
|
||||
}
|
||||
|
||||
var schema = `
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
embedding_time INTEGER,
|
||||
ocr_time INTEGER,
|
||||
thumbnail_time INTEGER,
|
||||
embedding BLOB,
|
||||
ocr TEXT,
|
||||
raw_ocr_segments BLOB,
|
||||
thumbnails BLOB
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
|
||||
filename,
|
||||
ocr,
|
||||
tokenize='unicode61 remove_diacritics 2',
|
||||
content='ocr'
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
|
||||
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
|
||||
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER UPDATE ON files BEGIN
|
||||
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
|
||||
INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
END;
|
||||
`
|
||||
|
||||
type FileRecord struct {
|
||||
Filename string `db:"filename"`
|
||||
EmbedTime int64 `db:"embedding_time"`
|
||||
OcrTime int64 `db:"ocr_time"`
|
||||
ThumbnailTime int64 `db:"thumbnail_time"`
|
||||
Embedding []byte `db:"embedding"`
|
||||
Ocr string `db:"ocr"`
|
||||
RawOcrSegments []byte `db:"raw_ocr_segments"`
|
||||
Thumbnails []byte `db:"thumbnails"`
|
||||
filesize int64
|
||||
}
|
||||
|
||||
type InferenceServerConfig struct {
|
||||
BatchSize uint `msgpack:"batch"`
|
||||
ImageSize []uint `msgpack:"image_size"`
|
||||
EmbeddingSize uint `msgpack:"embedding_size"`
|
||||
}
|
||||
|
||||
func decodeMsgpackFrom[O interface{}](resp *http.Response) (O, error) {
|
||||
var result O
|
||||
respData, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
err = msgpack.Unmarshal(respData, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func queryClipServer[I interface{}, O interface{}](config Config, path string, data I) (O, error) {
|
||||
var result O
|
||||
b, err := msgpack.Marshal(data)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
resp, err := http.Post(config.ClipServer+path, "application/msgpack", bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return decodeMsgpackFrom[O](resp)
|
||||
}
|
||||
|
||||
type LoadedImage struct {
|
||||
image *vips.ImageRef
|
||||
filename string
|
||||
originalSize int
|
||||
}
|
||||
|
||||
type EmbeddingInput struct {
|
||||
image []byte
|
||||
filename string
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Images [][]byte `msgpack:"images"`
|
||||
Text []string `msgpack:"text"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse = [][]byte
|
||||
|
||||
func timestamp() int64 {
|
||||
return time.Now().UnixMicro()
|
||||
}
|
||||
|
||||
type ImageFormatConfig struct {
|
||||
targetWidth int
|
||||
targetFilesize int
|
||||
quality int
|
||||
format vips.ImageType
|
||||
extension string
|
||||
}
|
||||
|
||||
func generateFilenameHash(filename string) string {
|
||||
hasher := fnv.New128()
|
||||
hasher.Write([]byte(filename))
|
||||
hash := hasher.Sum(make([]byte, 0))
|
||||
return base64.RawURLEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
func generateThumbnailFilename(filename string, formatName string, formatConfig ImageFormatConfig) string {
|
||||
return fmt.Sprintf("%s%s.%s", generateFilenameHash(filename), formatName, formatConfig.extension)
|
||||
}
|
||||
|
||||
func initializeDatabase(config Config) (*sqlx.DB, error) {
|
||||
db, err := sqlx.Connect("sqlite3", config.DbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = db.Exec("PRAGMA busy_timeout = 2000; PRAGMA journal_mode = WAL")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func imageFormats(config Config) map[string]ImageFormatConfig {
|
||||
return map[string]ImageFormatConfig{
|
||||
"jpegl": {
|
||||
targetWidth: 800,
|
||||
quality: 70,
|
||||
format: vips.ImageTypeJPEG,
|
||||
extension: "jpg",
|
||||
},
|
||||
"jpegh": {
|
||||
targetWidth: 1600,
|
||||
quality: 80,
|
||||
format: vips.ImageTypeJPEG,
|
||||
extension: "jpg",
|
||||
},
|
||||
"jpeg256kb": {
|
||||
targetWidth: 500,
|
||||
targetFilesize: 256000,
|
||||
format: vips.ImageTypeJPEG,
|
||||
extension: "jpg",
|
||||
},
|
||||
"avifh": {
|
||||
targetWidth: 1600,
|
||||
quality: 80,
|
||||
format: vips.ImageTypeAVIF,
|
||||
extension: "avif",
|
||||
},
|
||||
"avifl": {
|
||||
targetWidth: 800,
|
||||
quality: 30,
|
||||
format: vips.ImageTypeAVIF,
|
||||
extension: "avif",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ingestFiles(config Config, backend InferenceServerConfig) error {
|
||||
var wg errgroup.Group
|
||||
var iwg errgroup.Group
|
||||
|
||||
// We assume everything is either a modern browser (low-DPI or high-DPI), an ancient browser or a ComputerCraft machine abusing Extra Utilities 2 screens.
|
||||
var formats = imageFormats(config)
|
||||
|
||||
db, err := initializeDatabase(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
toProcess := make(chan FileRecord, 100)
|
||||
toEmbed := make(chan EmbeddingInput, backend.BatchSize)
|
||||
toThumbnail := make(chan LoadedImage, 30)
|
||||
toOCR := make(chan LoadedImage, 30)
|
||||
embedBatches := make(chan []EmbeddingInput, 1)
|
||||
|
||||
// image loading and preliminary resizing
|
||||
for range runtime.NumCPU() {
|
||||
iwg.Go(func() error {
|
||||
for record := range toProcess {
|
||||
path := filepath.Join(config.Files, record.Filename)
|
||||
img, err := vips.LoadImageFromFile(path, &vips.ImportParams{})
|
||||
if err != nil {
|
||||
log.Println("could not read", record.Filename)
|
||||
continue
|
||||
}
|
||||
if record.Embedding == nil {
|
||||
i, err := img.Copy() // TODO this is ugly, we should not need to do in-place operations
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = i.ResizeWithVScale(float64(backend.ImageSize[0])/float64(i.Width()), float64(backend.ImageSize[1])/float64(i.Height()), vips.KernelLanczos3)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resized, _, err := i.ExportPng(vips.NewPngExportParams())
|
||||
if err != nil {
|
||||
log.Println("resize failure", record.Filename, err)
|
||||
} else {
|
||||
toEmbed <- EmbeddingInput{
|
||||
image: resized,
|
||||
filename: record.Filename,
|
||||
}
|
||||
}
|
||||
}
|
||||
if record.Thumbnails == nil && config.EnableThumbnails {
|
||||
toThumbnail <- LoadedImage{
|
||||
image: img,
|
||||
filename: record.Filename,
|
||||
originalSize: int(record.filesize),
|
||||
}
|
||||
}
|
||||
if record.RawOcrSegments == nil && config.EnableOCR {
|
||||
toOCR <- LoadedImage{
|
||||
image: img,
|
||||
filename: record.Filename,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if config.EnableThumbnails {
|
||||
for range runtime.NumCPU() {
|
||||
wg.Go(func() error {
|
||||
for image := range toThumbnail {
|
||||
generatedFormats := make([]string, 0)
|
||||
for formatName, formatConfig := range formats {
|
||||
var err error
|
||||
var resized []byte
|
||||
if formatConfig.targetFilesize != 0 {
|
||||
lb := 1
|
||||
ub := 100
|
||||
for {
|
||||
quality := (lb + ub) / 2
|
||||
i, err := image.image.Copy()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.Resize(float64(formatConfig.targetWidth)/float64(i.Width()), vips.KernelLanczos3)
|
||||
resized, _, err = i.Export(&vips.ExportParams{
|
||||
Format: formatConfig.format,
|
||||
Speed: 4,
|
||||
Quality: quality,
|
||||
StripMetadata: true,
|
||||
})
|
||||
if len(resized) > image.originalSize {
|
||||
ub = quality
|
||||
} else {
|
||||
lb = quality + 1
|
||||
}
|
||||
if lb >= ub {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
i, err := image.image.Copy()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.Resize(float64(formatConfig.targetWidth)/float64(i.Width()), vips.KernelLanczos3)
|
||||
resized, _, err = i.Export(&vips.ExportParams{
|
||||
Format: formatConfig.format,
|
||||
Speed: 4,
|
||||
Quality: formatConfig.quality,
|
||||
StripMetadata: true,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
log.Println("thumbnailing failure", image.filename, err)
|
||||
continue
|
||||
}
|
||||
if len(resized) < image.originalSize {
|
||||
generatedFormats = append(generatedFormats, formatName)
|
||||
err = bimg.Write(filepath.Join(config.ThumbsPath, generateThumbnailFilename(image.filename, formatName, formatConfig)), resized)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
formatsData, err := msgpack.Marshal(generatedFormats)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.Exec("UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?", formatsData, timestamp(), image.filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if config.EnableOCR {
|
||||
for range 100 {
|
||||
wg.Go(func() error {
|
||||
for image := range toOCR {
|
||||
scan, err := scanImage(image.image)
|
||||
if err != nil {
|
||||
log.Println("OCR failure", image.filename, err)
|
||||
continue
|
||||
}
|
||||
ocrText := ""
|
||||
for _, segment := range scan {
|
||||
ocrText += segment.text
|
||||
ocrText += "\n"
|
||||
}
|
||||
ocrData, err := msgpack.Marshal(scan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.Exec("UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?", ocrText, ocrData, timestamp(), image.filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
wg.Go(func() error {
|
||||
buffer := make([]EmbeddingInput, 0, backend.BatchSize)
|
||||
for input := range toEmbed {
|
||||
buffer = append(buffer, input)
|
||||
if len(buffer) == int(backend.BatchSize) {
|
||||
embedBatches <- buffer
|
||||
buffer = make([]EmbeddingInput, 0, backend.BatchSize)
|
||||
}
|
||||
}
|
||||
if len(buffer) > 0 {
|
||||
embedBatches <- buffer
|
||||
}
|
||||
close(embedBatches)
|
||||
return nil
|
||||
})
|
||||
|
||||
for range 3 {
|
||||
wg.Go(func() error {
|
||||
for batch := range embedBatches {
|
||||
result, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "", EmbeddingRequest{
|
||||
Images: lo.Map(batch, func(item EmbeddingInput, _ int) []byte { return item.image }),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, vector := range result {
|
||||
_, err = tx.Exec("UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?", timestamp(), vector, batch[i].filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
filenamesOnDisk := make(map[string]struct{})
|
||||
|
||||
err = filepath.WalkDir(config.Files, func(path string, d os.DirEntry, err error) error {
|
||||
filename := strings.TrimPrefix(path, config.Files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
filenamesOnDisk[filename] = struct{}{}
|
||||
records := []FileRecord{}
|
||||
err = db.Select(&records, "SELECT * FROM files WHERE filename = ?", filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stat, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modtime := stat.ModTime().UnixMicro()
|
||||
if len(records) == 0 || modtime > records[0].EmbedTime || modtime > records[0].OcrTime || modtime > records[0].ThumbnailTime {
|
||||
_, err = db.Exec("INSERT OR IGNORE INTO files VALUES (?, 0, 0, 0, '', '', '', '')", filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
record := FileRecord{
|
||||
Filename: filename,
|
||||
filesize: stat.Size(),
|
||||
}
|
||||
if len(records) > 0 {
|
||||
record = records[0]
|
||||
}
|
||||
if modtime > record.EmbedTime || len(record.Embedding) == 0 {
|
||||
record.Embedding = nil
|
||||
}
|
||||
if modtime > record.OcrTime || len(record.RawOcrSegments) == 0 {
|
||||
record.RawOcrSegments = nil
|
||||
}
|
||||
if modtime > record.ThumbnailTime || len(record.Thumbnails) == 0 {
|
||||
record.Thumbnails = nil
|
||||
}
|
||||
toProcess <- record
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
close(toProcess)
|
||||
|
||||
err = iwg.Wait()
|
||||
close(toEmbed)
|
||||
close(toThumbnail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = wg.Wait()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := db.Queryx("SELECT filename FROM files")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for rows.Next() {
|
||||
var filename string
|
||||
err := rows.Scan(&filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, ok := filenamesOnDisk[filename]; !ok {
|
||||
_, err = tx.Exec("DELETE FROM files WHERE filename = ?", filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const INDEX_ADD_BATCH = 512
|
||||
|
||||
func buildIndex(config Config, backend InferenceServerConfig) (Index, error) {
|
||||
var index Index
|
||||
|
||||
db, err := initializeDatabase(config)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
newFAISSIndex, err := faiss.IndexFactory(int(backend.EmbeddingSize), "SQfp16", faiss.MetricInnerProduct)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
index.vectors = newFAISSIndex
|
||||
|
||||
var count int
|
||||
err = db.Get(&count, "SELECT COUNT(*) FROM files")
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
|
||||
index.filenames = make([]string, 0, count)
|
||||
index.formatCodes = make([]int64, 0, count)
|
||||
buffer := make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
|
||||
index.formatNames = make([]string, 0, 5)
|
||||
|
||||
record := FileRecord{}
|
||||
rows, err := db.Queryx("SELECT * FROM files")
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
for rows.Next() {
|
||||
err := rows.StructScan(&record)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
if len(record.Embedding) > 0 {
|
||||
index.filenames = append(index.filenames, record.Filename)
|
||||
for i := 0; i < len(record.Embedding); i += 2 {
|
||||
buffer = append(buffer, float16.Frombits(uint16(record.Embedding[i])+uint16(record.Embedding[i+1])<<8).Float32())
|
||||
}
|
||||
if len(buffer) == cap(buffer) {
|
||||
index.vectors.Add(buffer)
|
||||
buffer = make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
|
||||
}
|
||||
|
||||
formats := make([]string, 0, 5)
|
||||
if len(record.Thumbnails) > 0 {
|
||||
err := msgpack.Unmarshal(record.Thumbnails, &formats)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
}
|
||||
|
||||
formatCode := int64(0)
|
||||
for _, formatString := range formats {
|
||||
found := false
|
||||
for i, name := range index.formatNames {
|
||||
if name == formatString {
|
||||
formatCode |= 1 << i
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
newIndex := len(index.formatNames)
|
||||
formatCode |= 1 << newIndex
|
||||
index.formatNames = append(index.formatNames, formatString)
|
||||
}
|
||||
}
|
||||
index.formatCodes = append(index.formatCodes, formatCode)
|
||||
}
|
||||
}
|
||||
if len(buffer) > 0 {
|
||||
index.vectors.Add(buffer)
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func decodeFP16Buffer(buf []byte) []float32 {
|
||||
out := make([]float32, 0, len(buf)/2)
|
||||
for i := 0; i < len(buf); i += 2 {
|
||||
out = append(out, float16.Frombits(uint16(buf[i])+uint16(buf[i+1])<<8).Float32())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type EmbeddingVector []float32
|
||||
|
||||
type QueryResult struct {
|
||||
Matches [][]interface{} `json:"matches"`
|
||||
Formats []string `json:"formats"`
|
||||
Extensions map[string]string `json:"extensions"`
|
||||
}
|
||||
|
||||
// this terrible language cannot express tagged unions
|
||||
type QueryTerm struct {
|
||||
Embedding *EmbeddingVector `json:"embedding"`
|
||||
Image *string `json:"image"` // base64
|
||||
Text *string `json:"text"`
|
||||
Weight *float32 `json:"weight"`
|
||||
}
|
||||
|
||||
type QueryRequest struct {
|
||||
Terms []QueryTerm `json:"terms"`
|
||||
K *int `json:"k"`
|
||||
}
|
||||
|
||||
func queryIndex(index *Index, query EmbeddingVector, k int) (QueryResult, error) {
|
||||
var qr QueryResult
|
||||
distances, ids, err := index.vectors.Search(query, int64(k))
|
||||
if err != nil {
|
||||
return qr, err
|
||||
}
|
||||
items := lo.Map(lo.Zip2(distances, ids), func(x lo.Tuple2[float32, int64], i int) []interface{} {
|
||||
return []interface{}{
|
||||
x.A,
|
||||
index.filenames[x.B],
|
||||
generateFilenameHash(index.filenames[x.B]),
|
||||
index.formatCodes[x.B],
|
||||
}
|
||||
})
|
||||
|
||||
return QueryResult{
|
||||
Matches: items,
|
||||
Formats: index.formatNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func handleRequest(config Config, backendConfig InferenceServerConfig, index *Index, w http.ResponseWriter, req *http.Request) error {
|
||||
if req.Body == nil {
|
||||
io.WriteString(w, "OK") // health check
|
||||
return nil
|
||||
}
|
||||
dec := json.NewDecoder(req.Body)
|
||||
var qreq QueryRequest
|
||||
err := dec.Decode(&qreq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalEmbedding := make(EmbeddingVector, backendConfig.EmbeddingSize)
|
||||
|
||||
imageBatch := make([][]byte, 0)
|
||||
imageWeights := make([]float32, 0)
|
||||
textBatch := make([]string, 0)
|
||||
textWeights := make([]float32, 0)
|
||||
|
||||
for _, term := range qreq.Terms {
|
||||
if term.Image != nil {
|
||||
bytes, err := base64.StdEncoding.DecodeString(*term.Image)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
loaded := bimg.NewImage(bytes)
|
||||
resized, err := loaded.Process(bimg.Options{
|
||||
Width: int(backendConfig.ImageSize[0]),
|
||||
Height: int(backendConfig.ImageSize[1]),
|
||||
Force: true,
|
||||
Type: bimg.PNG,
|
||||
Interpretation: bimg.InterpretationSRGB,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
imageBatch = append(imageBatch, resized)
|
||||
if term.Weight != nil {
|
||||
imageWeights = append(imageWeights, *term.Weight)
|
||||
} else {
|
||||
imageWeights = append(imageWeights, 1)
|
||||
}
|
||||
}
|
||||
if term.Text != nil {
|
||||
textBatch = append(textBatch, *term.Text)
|
||||
if term.Weight != nil {
|
||||
textWeights = append(textWeights, *term.Weight)
|
||||
} else {
|
||||
textWeights = append(textWeights, 1)
|
||||
}
|
||||
}
|
||||
if term.Embedding != nil {
|
||||
weight := float32(1.0)
|
||||
if term.Weight != nil {
|
||||
weight = *term.Weight
|
||||
}
|
||||
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
|
||||
totalEmbedding[i] += (*term.Embedding)[i] * weight
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(imageBatch) > 0 {
|
||||
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
|
||||
Images: imageBatch,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for j, emb := range embs {
|
||||
embd := decodeFP16Buffer(emb)
|
||||
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
|
||||
totalEmbedding[i] += embd[i] * imageWeights[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(textBatch) > 0 {
|
||||
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
|
||||
Text: textBatch,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for j, emb := range embs {
|
||||
embd := decodeFP16Buffer(emb)
|
||||
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
|
||||
totalEmbedding[i] += embd[i] * textWeights[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
k := 1000
|
||||
if qreq.K != nil {
|
||||
k = *qreq.K
|
||||
}
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
|
||||
qres, err := queryIndex(index, totalEmbedding, k)
|
||||
|
||||
qres.Extensions = make(map[string]string)
|
||||
for k, v := range imageFormats(config) {
|
||||
qres.Extensions[k] = v.extension
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = enc.Encode(qres)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
os.Setenv("VIPS_WARNING", "FALSE") // this does not actually work
|
||||
bimg.VipsCacheSetMax(0)
|
||||
bimg.VipsCacheSetMaxMem(0)
|
||||
}
|
||||
|
||||
func main() {
|
||||
vips.Startup(&vips.Config{})
|
||||
defer vips.Shutdown()
|
||||
|
||||
content, err := os.ReadFile(os.Args[1])
|
||||
if err != nil {
|
||||
log.Fatal("config file unreadable ", err)
|
||||
}
|
||||
var config Config
|
||||
err = json.Unmarshal(content, &config)
|
||||
if err != nil {
|
||||
log.Fatal("config file wrong ", err)
|
||||
}
|
||||
fmt.Println(config)
|
||||
|
||||
db, err := sqlx.Connect("sqlite3", config.DbPath)
|
||||
if err != nil {
|
||||
log.Fatal("DB connection failure ", db)
|
||||
}
|
||||
db.MustExec(schema)
|
||||
|
||||
var backend InferenceServerConfig
|
||||
for {
|
||||
resp, err := http.Get(config.ClipServer + "/config")
|
||||
if err != nil {
|
||||
log.Println("backend failed (fetch) ", err)
|
||||
}
|
||||
backend, err = decodeMsgpackFrom[InferenceServerConfig](resp)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Println("backend failed (parse) ", err)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
requestIngest := make(chan struct{}, 1)
|
||||
|
||||
var index *Index
|
||||
// maybe this ought to be mutexed?
|
||||
var lastError *error
|
||||
// there's not a neat way to reusably broadcast to multiple channels, but I *can* abuse WaitGroups probably
|
||||
// this might cause horrible concurrency issues, but you brought me to this point, Go designers
|
||||
var wg sync.WaitGroup
|
||||
|
||||
go func() {
|
||||
for {
|
||||
wg.Add(1)
|
||||
log.Println("ingest running")
|
||||
err := ingestFiles(config, backend)
|
||||
if err != nil {
|
||||
log.Println("ingest failed ", err)
|
||||
lastError = &err
|
||||
} else {
|
||||
newIndex, err := buildIndex(config, backend)
|
||||
if err != nil {
|
||||
log.Println("index build failed ", err)
|
||||
lastError = &err
|
||||
} else {
|
||||
lastError = nil
|
||||
index = &newIndex
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
<-requestIngest
|
||||
}
|
||||
}()
|
||||
newIndex, err := buildIndex(config, backend)
|
||||
index = &newIndex
|
||||
if err != nil {
|
||||
log.Fatal("index build failed ", err)
|
||||
}
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Add("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Add("Access-Control-Allow-Headers", "Content-Type")
|
||||
if req.Method == "OPTIONS" {
|
||||
w.WriteHeader(204)
|
||||
return
|
||||
}
|
||||
err := handleRequest(config, backend, index, w, req)
|
||||
if err != nil {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(500)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
})
|
||||
http.HandleFunc("/reload", func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method == "POST" {
|
||||
log.Println("requesting index reload")
|
||||
select {
|
||||
case requestIngest <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
wg.Wait()
|
||||
if lastError == nil {
|
||||
w.Write([]byte("OK"))
|
||||
} else {
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte((*lastError).Error()))
|
||||
}
|
||||
}
|
||||
})
|
||||
http.HandleFunc("/profile", func(w http.ResponseWriter, req *http.Request) {
|
||||
f, err := os.Create("mem.pprof")
|
||||
if err != nil {
|
||||
log.Fatal("could not create memory profile: ", err)
|
||||
}
|
||||
defer f.Close()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
log.Printf("Memory usage: Alloc=%v, TotalAlloc=%v, Sys=%v", m.Alloc, m.TotalAlloc, m.Sys)
|
||||
log.Println(bimg.VipsMemory())
|
||||
bimg.VipsDebugInfo()
|
||||
runtime.GC() // Trigger garbage collection
|
||||
if err := pprof.WriteHeapProfile(f); err != nil {
|
||||
log.Fatal("could not write memory profile: ", err)
|
||||
}
|
||||
})
|
||||
log.Println("starting server")
|
||||
http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)
|
||||
}
|
265
misc/bad-go-version/problematic_thing_2.go
Normal file
265
misc/bad-go-version/problematic_thing_2.go
Normal file
@ -0,0 +1,265 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/davidbyttow/govips/v2/vips"
|
||||
"github.com/samber/lo"
|
||||
"github.com/titanous/json5"
|
||||
)
|
||||
|
||||
const CALLBACK_REGEX string = ">AF_initDataCallback\\(({key: 'ds:1'.*?)\\);</script>"
|
||||
|
||||
type SegmentCoords struct {
|
||||
x int
|
||||
y int
|
||||
w int
|
||||
h int
|
||||
}
|
||||
|
||||
type Segment struct {
|
||||
coords SegmentCoords
|
||||
text string
|
||||
}
|
||||
|
||||
type ScanResult []Segment
|
||||
|
||||
// TODO coordinates are negative sometimes and I think they shouldn't be
|
||||
func rationalizeCoordsFormat1(imageW float64, imageH float64, centerXFraction float64, centerYFraction float64, widthFraction float64, heightFraction float64) SegmentCoords {
|
||||
return SegmentCoords{
|
||||
x: int(math.Round((centerXFraction - widthFraction/2) * imageW)),
|
||||
y: int(math.Round((centerYFraction - heightFraction/2) * imageH)),
|
||||
w: int(math.Round(widthFraction * imageW)),
|
||||
h: int(math.Round(heightFraction * imageH)),
|
||||
}
|
||||
}
|
||||
|
||||
func scanImageChunk(image []byte, imageWidth int, imageHeight int) (ScanResult, error) {
|
||||
var result ScanResult
|
||||
timestamp := time.Now().UnixMicro()
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
defer w.Close()
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="encoded_image"; filename="ocr%d.png"`, timestamp))
|
||||
h.Set("Content-Type", "image/png")
|
||||
fw, err := w.CreatePart(h)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
fw.Write(image)
|
||||
w.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://lens.google.com/v3/upload?stcs=%d", timestamp), &b)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
req.Header.Add("User-Agent", "Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36")
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "SOCS",
|
||||
Value: "CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg",
|
||||
})
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
client := http.Client{}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
re, _ := regexp.Compile(CALLBACK_REGEX)
|
||||
matches := re.FindStringSubmatch(string(body[:]))
|
||||
if len(matches) == 0 {
|
||||
return result, fmt.Errorf("invalid API response")
|
||||
}
|
||||
match := matches[1]
|
||||
var lensObject map[string]interface{}
|
||||
err = json5.Unmarshal([]byte(match), &lensObject)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
if _, ok := lensObject["errorHasStatus"]; ok {
|
||||
return result, errors.New("lens failed")
|
||||
}
|
||||
|
||||
root := lensObject["data"].([]interface{})
|
||||
|
||||
var textSegments []string
|
||||
var textRegions []SegmentCoords
|
||||
|
||||
// I don't know why Google did this.
|
||||
// Text segments are in one place and their locations are in another, using a very strange coordinate system.
|
||||
// At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
|
||||
// TODO: on a few images, this seems to not work for some reason.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
|
||||
// In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
|
||||
textSegments = []string{}
|
||||
textRegions = []SegmentCoords{}
|
||||
}
|
||||
}()
|
||||
|
||||
textSegmentsRaw := root[3].([]interface{})[4].([]interface{})[0].([]interface{})[0].([]interface{})
|
||||
textRegionsRaw := root[2].([]interface{})[3].([]interface{})[0].([]interface{})
|
||||
for _, x := range textRegionsRaw {
|
||||
if strings.HasPrefix(x.([]interface{})[11].(string), "text:") {
|
||||
rawCoords := x.([]interface{})[1].([]interface{})
|
||||
coords := rationalizeCoordsFormat1(float64(imageWidth), float64(imageHeight), rawCoords[0].(float64), rawCoords[1].(float64), rawCoords[2].(float64), rawCoords[3].(float64))
|
||||
textRegions = append(textRegions, coords)
|
||||
}
|
||||
}
|
||||
for _, x := range textSegmentsRaw {
|
||||
textSegment := x.(string)
|
||||
textSegments = append(textSegments, textSegment)
|
||||
}
|
||||
|
||||
return lo.Map(lo.Zip2(textSegments, textRegions), func(x lo.Tuple2[string, SegmentCoords], _ int) Segment {
|
||||
return Segment{
|
||||
text: x.A,
|
||||
coords: x.B,
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
|
||||
const MAX_DIM int = 1024
|
||||
|
||||
func scanImage(image *vips.ImageRef) (ScanResult, error) {
|
||||
result := ScanResult{}
|
||||
width := image.Width()
|
||||
height := image.Height()
|
||||
if width > MAX_DIM {
|
||||
width = MAX_DIM
|
||||
height = int(math.Round(float64(height) * (float64(width) / float64(image.Width()))))
|
||||
}
|
||||
downscaled, err := image.Copy()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
downscaled.Resize(float64(width)/float64(image.Width()), vips.KernelLanczos3)
|
||||
for y := 0; y < height; y += MAX_DIM {
|
||||
chunkHeight := MAX_DIM
|
||||
if y+chunkHeight > height {
|
||||
chunkHeight = height - y
|
||||
}
|
||||
chunk, err := image.Copy() // TODO this really really should not be in-place
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
err = chunk.ExtractArea(0, y, width, height)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
buf, _, err := chunk.ExportPng(&vips.PngExportParams{})
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
res, err := scanImageChunk(buf, width, chunkHeight)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
for _, segment := range res {
|
||||
result = append(result, Segment{
|
||||
text: segment.text,
|
||||
coords: SegmentCoords{
|
||||
y: segment.coords.y + y,
|
||||
x: segment.coords.x,
|
||||
w: segment.coords.w,
|
||||
h: segment.coords.h,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
/*
|
||||
async def scan_image_chunk(sess, image):
|
||||
# send data to inscrutable undocumented Google service
|
||||
# https://github.com/AuroraWright/owocr/blob/master/owocr/ocr.py#L193
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
data = aiohttp.FormData()
|
||||
data.add_field(
|
||||
"encoded_image",
|
||||
encode_img(image),
|
||||
filename="ocr" + str(timestamp) + ".png",
|
||||
content_type="image/png"
|
||||
)
|
||||
async with sess.post(url, headers=headers, cookies=cookies, data=data, timeout=10) as res:
|
||||
body = await res.text()
|
||||
|
||||
# I really worry about Google sometimes. This is not a sensible format.
|
||||
match = CALLBACK_REGEX.search(body)
|
||||
if match == None:
|
||||
raise ValueError("Invalid callback")
|
||||
|
||||
lens_object = pyjson5.loads(match.group(1))
|
||||
if "errorHasStatus" in lens_object:
|
||||
raise RuntimeError("Lens failed")
|
||||
|
||||
text_segments = []
|
||||
text_regions = []
|
||||
|
||||
root = lens_object["data"]
|
||||
|
||||
# I don't know why Google did this.
|
||||
# Text segments are in one place and their locations are in another, using a very strange coordinate system.
|
||||
# At least I don't need whatever is contained in the base64 partss (which I assume are protobufs).
|
||||
# TODO: on a few images, this seems to not work for some reason.
|
||||
try:
|
||||
text_segments = root[3][4][0][0]
|
||||
text_regions = [ rationalize_coords_format1(image.width, image.height, *x[1]) for x in root[2][3][0] if x[11].startswith("text:") ]
|
||||
except (KeyError, IndexError):
|
||||
# https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
|
||||
# In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
|
||||
return [], []
|
||||
|
||||
return text_segments, text_regions
|
||||
|
||||
MAX_SCAN_DIM = 1000 # not actually true but close enough
|
||||
def chunk_image(image: Image):
|
||||
chunks = []
|
||||
# Cut image down in X axis (I'm assuming images aren't too wide to scan in downscaled form because merging text horizontally would be annoying)
|
||||
if image.width > MAX_SCAN_DIM:
|
||||
image = image.resize((MAX_SCAN_DIM, round(image.height * (image.width / MAX_SCAN_DIM))), Image.LANCZOS)
|
||||
for y in range(0, image.height, MAX_SCAN_DIM):
|
||||
chunks.append(image.crop((0, y, image.width, min(y + MAX_SCAN_DIM, image.height))))
|
||||
return chunks
|
||||
|
||||
async def scan_chunks(sess: aiohttp.ClientSession, chunks: [Image]):
|
||||
# If text happens to be split across the cut line it won't get read.
|
||||
# This is because doing overlap read areas would be really annoying.
|
||||
text = ""
|
||||
regions = []
|
||||
for chunk in chunks:
|
||||
new_segments, new_regions = await scan_image_chunk(sess, chunk)
|
||||
for segment in new_segments:
|
||||
text += segment + "\n"
|
||||
for i, (segment, region) in enumerate(zip(new_segments, new_regions)):
|
||||
regions.append({ **region, "y": region["y"] + (MAX_SCAN_DIM * i), "text": segment })
|
||||
return text, regions
|
||||
|
||||
async def scan_image(sess: aiohttp.ClientSession, image: Image):
|
||||
return await scan_chunks(sess, chunk_image(image))
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
print(await scan_image(sess, Image.open("/data/public/memes-or-something/linear-algebra-chess.png")))
|
||||
asyncio.run(main())
|
||||
*/
|
112
mse.py
112
mse.py
@ -12,8 +12,11 @@ import os
|
||||
import aiohttp_cors
|
||||
import json
|
||||
import io
|
||||
import time
|
||||
import sys
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
import threading
|
||||
|
||||
with open(sys.argv[1], "r") as config_file:
|
||||
CONFIG = json.load(config_file)
|
||||
@ -21,26 +24,26 @@ with open(sys.argv[1], "r") as config_file:
|
||||
app = web.Application(client_max_size=32*1024**2)
|
||||
routes = web.RouteTableDef()
|
||||
|
||||
async def clip_server(query, unpack_buffer=True):
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
|
||||
response = umsgpack.loads(await res.read())
|
||||
if res.status == 200:
|
||||
if unpack_buffer:
|
||||
response = [ numpy.frombuffer(x, dtype="float16") for x in response ]
|
||||
return response
|
||||
else:
|
||||
raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))
|
||||
async def clip_server(sess: aiohttp.ClientSession, query, unpack_buffer=True):
|
||||
async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
|
||||
response = umsgpack.loads(await res.read())
|
||||
if res.status == 200:
|
||||
if unpack_buffer:
|
||||
response = [ numpy.frombuffer(x, dtype="float16") for x in response ]
|
||||
return response
|
||||
else:
|
||||
raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))
|
||||
|
||||
@routes.post("/")
|
||||
async def run_query(request):
|
||||
sess = app["session"]
|
||||
data = await request.json()
|
||||
embeddings = []
|
||||
if images := data.get("images", []):
|
||||
target_image_size = app["index"].inference_server_config["image_size"]
|
||||
embeddings.extend(await clip_server({ "images": [ load_image(io.BytesIO(base64.b64decode(x)), target_image_size)[0] for x, w in images ] }))
|
||||
embeddings.extend(await clip_server(sess, { "images": [ load_image(io.BytesIO(base64.b64decode(x)), target_image_size)[0] for x, w in images ] }))
|
||||
if text := data.get("text", []):
|
||||
embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
|
||||
embeddings.extend(await clip_server(sess, { "text": [ x for x, w in text ] }))
|
||||
weights = [ w for x, w in images ] + [ w for x, w in text ]
|
||||
weighted_embeddings = [ e * w for e, w in zip(embeddings, weights) ]
|
||||
weighted_embeddings.extend([ numpy.array(x) for x in data.get("embeddings", []) ])
|
||||
@ -65,11 +68,12 @@ def load_image(path, image_size):
|
||||
return buf.getvalue(), path
|
||||
|
||||
class Index:
|
||||
def __init__(self, inference_server_config):
|
||||
def __init__(self, inference_server_config, http_session):
|
||||
self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
|
||||
self.associated_filenames = []
|
||||
self.inference_server_config = inference_server_config
|
||||
self.lock = asyncio.Lock()
|
||||
self.session = http_session
|
||||
|
||||
def search(self, query, top_k):
|
||||
distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
|
||||
@ -80,18 +84,77 @@ class Index:
|
||||
except IndexError: pass
|
||||
return [ { "score": float(distance), "file": self.associated_filenames[index] } for index, distance in zip(indices, distances) ]
|
||||
|
||||
async def run_ocr(self):
|
||||
if not CONFIG.get("enable_ocr"): return
|
||||
|
||||
import ocr
|
||||
|
||||
print("Running OCR")
|
||||
|
||||
conn = await aiosqlite.connect(CONFIG["db_path"])
|
||||
unocred = await conn.execute_fetchall("SELECT files.filename FROM files LEFT JOIN ocr ON files.filename = ocr.filename WHERE ocr.scan_time IS NULL OR ocr.scan_time < files.modtime")
|
||||
|
||||
ocr_sem = asyncio.Semaphore(20) # Google has more concurrency than our internal CLIP backend. I am sure they will be fine.
|
||||
load_sem = threading.Semaphore(100) # provide backpressure in loading to avoid using 50GB of RAM (this happened)
|
||||
|
||||
async def run_image(filename, chunks):
|
||||
try:
|
||||
text, regions = await ocr.scan_chunks(self.session, chunks)
|
||||
await conn.execute("INSERT OR REPLACE INTO ocr VALUES (?, ?, ?, ?)", (filename, time.time(), text, json.dumps(regions)))
|
||||
await conn.commit()
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
except:
|
||||
print("OCR failed on", filename)
|
||||
finally:
|
||||
ocr_sem.release()
|
||||
|
||||
def load_and_chunk_image(filename):
|
||||
load_sem.acquire()
|
||||
im = Image.open(Path(CONFIG["files"]) / filename)
|
||||
return filename, ocr.chunk_image(im)
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
|
||||
for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_and_chunk_image, file[0]) for file in unocred ]):
|
||||
filename, chunks = await task
|
||||
await ocr_sem.acquire()
|
||||
tg.create_task(run_image(filename, chunks))
|
||||
load_sem.release()
|
||||
|
||||
async def reload(self):
|
||||
async with self.lock:
|
||||
with ProcessPoolExecutor(max_workers=12) as executor:
|
||||
with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
|
||||
print("Indexing")
|
||||
conn = await aiosqlite.connect(CONFIG["db_path"])
|
||||
conn.row_factory = aiosqlite.Row
|
||||
await conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
modtime REAL NOT NULL,
|
||||
embedding_vector BLOB NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
modtime REAL NOT NULL,
|
||||
embedding_vector BLOB NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS ocr (
|
||||
filename TEXT PRIMARY KEY REFERENCES files(filename),
|
||||
scan_time INTEGER NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
raw_segments TEXT
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
|
||||
filename,
|
||||
text,
|
||||
tokenize='unicode61 remove_diacritics 2',
|
||||
content='ocr'
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON ocr BEGIN
|
||||
INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, new.text);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON ocr BEGIN
|
||||
INSERT INTO ocr_fts (ocr_fts, rowid, filename, text) VALUES ('delete', old.rowid, old.filename, old.text);
|
||||
END;
|
||||
""")
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
@ -102,7 +165,7 @@ class Index:
|
||||
async def do_batch(batch):
|
||||
try:
|
||||
query = { "images": [ arg[2] for arg in batch ] }
|
||||
embeddings = await clip_server(query, False)
|
||||
embeddings = await clip_server(self.session, query, False)
|
||||
await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
|
||||
(filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
|
||||
])
|
||||
@ -188,6 +251,8 @@ class Index:
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
await self.run_ocr()
|
||||
|
||||
app.router.add_routes(routes)
|
||||
|
||||
cors = aiohttp_cors.setup(app, defaults={
|
||||
@ -201,8 +266,8 @@ for route in list(app.router.routes()):
|
||||
cors.add(route)
|
||||
|
||||
async def main():
|
||||
sess = aiohttp.ClientSession()
|
||||
while True:
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
try:
|
||||
async with await sess.get(CONFIG["clip_server"] + "config") as res:
|
||||
inference_server_config = umsgpack.unpackb(await res.read())
|
||||
@ -211,8 +276,9 @@ async def main():
|
||||
except:
|
||||
traceback.print_exc()
|
||||
await asyncio.sleep(1)
|
||||
index = Index(inference_server_config)
|
||||
index = Index(inference_server_config, sess)
|
||||
app["index"] = index
|
||||
app["session"] = sess
|
||||
await index.reload()
|
||||
print("Ready")
|
||||
if CONFIG.get("no_run_server", False): return
|
||||
|
@ -1,6 +1,9 @@
|
||||
{
|
||||
"clip_server": "http://localhost:1708/",
|
||||
"db_path": "/srv/mse/data.sqlite3",
|
||||
"clip_server": "http://100.64.0.10:1708",
|
||||
"db_path": "data.sqlite3",
|
||||
"port": 1707,
|
||||
"files": "/data/public/memes-or-something/"
|
||||
"files": "/data/public/memes-or-something/",
|
||||
"enable_ocr": false,
|
||||
"thumbs_path": "./thumbtemp",
|
||||
"enable_thumbs": false
|
||||
}
|
101
ocr.py
Normal file
101
ocr.py
Normal file
@ -0,0 +1,101 @@
|
||||
import pyjson5
|
||||
import re
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
import time
|
||||
import io
|
||||
|
||||
CALLBACK_REGEX = re.compile(r">AF_initDataCallback\(({key: 'ds:1'.*?)\);</script>")
|
||||
|
||||
def encode_img(img):
|
||||
image_bytes = io.BytesIO()
|
||||
img.save(image_bytes, format="PNG", compress_level=6)
|
||||
return image_bytes.getvalue()
|
||||
|
||||
def rationalize_coords_format1(image_w, image_h, center_x_fraction, center_y_fraction, width_fraction, height_fraction, mysterious):
|
||||
return {
|
||||
"x": round((center_x_fraction - width_fraction / 2) * image_w),
|
||||
"y": round((center_y_fraction - height_fraction / 2) * image_h),
|
||||
"w": round(width_fraction * image_w),
|
||||
"h": round(height_fraction * image_h)
|
||||
}
|
||||
|
||||
async def scan_image_chunk(sess, image):
|
||||
timestamp = int(time.time() * 1000)
|
||||
url = f"https://lens.google.com/v3/upload?stcs={timestamp}"
|
||||
headers = {"User-Agent": "Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36"}
|
||||
cookies = {"SOCS": "CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg"}
|
||||
|
||||
# send data to inscrutable undocumented Google service
|
||||
# https://github.com/AuroraWright/owocr/blob/master/owocr/ocr.py#L193
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
data = aiohttp.FormData()
|
||||
data.add_field(
|
||||
"encoded_image",
|
||||
encode_img(image),
|
||||
filename="ocr" + str(timestamp) + ".png",
|
||||
content_type="image/png"
|
||||
)
|
||||
async with sess.post(url, headers=headers, cookies=cookies, data=data, timeout=10) as res:
|
||||
body = await res.text()
|
||||
|
||||
# I really worry about Google sometimes. This is not a sensible format.
|
||||
match = CALLBACK_REGEX.search(body)
|
||||
if match == None:
|
||||
raise ValueError("Invalid callback")
|
||||
|
||||
lens_object = pyjson5.loads(match.group(1))
|
||||
if "errorHasStatus" in lens_object:
|
||||
raise RuntimeError("Lens failed")
|
||||
|
||||
text_segments = []
|
||||
text_regions = []
|
||||
|
||||
root = lens_object["data"]
|
||||
|
||||
# I don't know why Google did this.
|
||||
# Text segments are in one place and their locations are in another, using a very strange coordinate system.
|
||||
# At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
|
||||
# TODO: on a few images, this seems to not work for some reason.
|
||||
try:
|
||||
text_segments = root[3][4][0][0]
|
||||
text_regions = [ rationalize_coords_format1(image.width, image.height, *x[1]) for x in root[2][3][0] if x[11].startswith("text:") ]
|
||||
except (KeyError, IndexError):
|
||||
# https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
|
||||
# In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
|
||||
return [], []
|
||||
|
||||
return text_segments, text_regions
|
||||
|
||||
MAX_SCAN_DIM = 1000 # not actually true but close enough
|
||||
def chunk_image(image: Image):
|
||||
chunks = []
|
||||
# Cut image down in X axis (I'm assuming images aren't too wide to scan in downscaled form because merging text horizontally would be annoying)
|
||||
if image.width > MAX_SCAN_DIM:
|
||||
image = image.resize((MAX_SCAN_DIM, round(image.height * (image.width / MAX_SCAN_DIM))), Image.LANCZOS)
|
||||
for y in range(0, image.height, MAX_SCAN_DIM):
|
||||
chunks.append(image.crop((0, y, image.width, min(y + MAX_SCAN_DIM, image.height))))
|
||||
return chunks
|
||||
|
||||
async def scan_chunks(sess: aiohttp.ClientSession, chunks: [Image]):
|
||||
# If text happens to be split across the cut line it won't get read.
|
||||
# This is because doing overlap read areas would be really annoying.
|
||||
text = ""
|
||||
regions = []
|
||||
for chunk in chunks:
|
||||
new_segments, new_regions = await scan_image_chunk(sess, chunk)
|
||||
for segment in new_segments:
|
||||
text += segment + "\n"
|
||||
for i, (segment, region) in enumerate(zip(new_segments, new_regions)):
|
||||
regions.append({ **region, "y": region["y"] + (MAX_SCAN_DIM * i), "text": segment })
|
||||
return text, regions
|
||||
|
||||
async def scan_image(sess: aiohttp.ClientSession, image: Image):
|
||||
return await scan_chunks(sess, chunk_image(image))
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
print(await scan_image(sess, Image.open("/data/public/memes-or-something/linear-algebra-chess.png")))
|
||||
asyncio.run(main())
|
892
src/main.rs
Normal file
892
src/main.rs
Normal file
@ -0,0 +1,892 @@
|
||||
use std::{collections::HashMap, io::Cursor};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, Context};
|
||||
use axum::body::Body;
|
||||
use axum::response::Response;
|
||||
use axum::{
|
||||
extract::Json,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
http::StatusCode
|
||||
};
|
||||
use image::{imageops::FilterType, io::Reader as ImageReader, DynamicImage, ImageFormat};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tokio::task::JoinHandle;
|
||||
use walkdir::WalkDir;
|
||||
use base64::prelude::*;
|
||||
use faiss::Index;
|
||||
use futures_util::stream::{StreamExt, TryStreamExt};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
||||
mod ocr;
|
||||
|
||||
use crate::ocr::scan_image;
|
||||
|
||||
fn function_which_returns_50() -> usize { 50 }
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
struct Config {
|
||||
clip_server: String,
|
||||
db_path: String,
|
||||
port: u16,
|
||||
files: String,
|
||||
#[serde(default)]
|
||||
enable_ocr: bool,
|
||||
#[serde(default)]
|
||||
thumbs_path: String,
|
||||
#[serde(default)]
|
||||
enable_thumbs: bool,
|
||||
#[serde(default="function_which_returns_50")]
|
||||
ocr_concurrency: usize,
|
||||
#[serde(default)]
|
||||
no_run_server: bool
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IIndex {
|
||||
vectors: faiss::index::IndexImpl,
|
||||
filenames: Vec<String>,
|
||||
format_codes: Vec<u64>,
|
||||
format_names: Vec<String>,
|
||||
}
|
||||
|
||||
const SCHEMA: &str = r#"
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
filename TEXT NOT NULL PRIMARY KEY,
|
||||
embedding_time INTEGER,
|
||||
ocr_time INTEGER,
|
||||
thumbnail_time INTEGER,
|
||||
embedding BLOB,
|
||||
ocr TEXT,
|
||||
raw_ocr_segments BLOB,
|
||||
thumbnails BLOB
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
|
||||
filename,
|
||||
ocr,
|
||||
tokenize='unicode61 remove_diacritics 2',
|
||||
content='ocr'
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
|
||||
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
|
||||
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER UPDATE ON files BEGIN
|
||||
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
|
||||
INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
END;
|
||||
"#;
|
||||
|
||||
#[derive(Debug, sqlx::FromRow, Clone, Default)]
|
||||
struct FileRecord {
|
||||
filename: String,
|
||||
embedding_time: Option<i64>,
|
||||
ocr_time: Option<i64>,
|
||||
thumbnail_time: Option<i64>,
|
||||
embedding: Option<Vec<u8>>,
|
||||
// this totally "will" be used later
|
||||
ocr: Option<String>,
|
||||
raw_ocr_segments: Option<Vec<u8>>,
|
||||
thumbnails: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
struct InferenceServerConfig {
|
||||
batch: usize,
|
||||
image_size: (u32, u32),
|
||||
embedding_size: usize,
|
||||
}
|
||||
|
||||
async fn query_clip_server<I, O>(
|
||||
client: &Client,
|
||||
config: &Config,
|
||||
path: &str,
|
||||
data: I,
|
||||
) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
|
||||
{
|
||||
let response = client
|
||||
.post(&format!("{}{}", config.clip_server, path))
|
||||
.header("Content-Type", "application/msgpack")
|
||||
.body(rmp_serde::to_vec_named(&data)?)
|
||||
.send()
|
||||
.await?;
|
||||
let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LoadedImage {
|
||||
image: Arc<DynamicImage>,
|
||||
filename: String,
|
||||
original_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EmbeddingInput {
|
||||
image: Vec<u8>,
|
||||
filename: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
enum EmbeddingRequest {
|
||||
Images { images: Vec<serde_bytes::ByteBuf> },
|
||||
Text { text: Vec<String> }
|
||||
}
|
||||
|
||||
fn timestamp() -> i64 {
|
||||
chrono::Utc::now().timestamp_micros()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ImageFormatConfig {
|
||||
target_width: u32,
|
||||
target_filesize: u32,
|
||||
quality: u8,
|
||||
format: ImageFormat,
|
||||
extension: String,
|
||||
}
|
||||
|
||||
fn generate_filename_hash(filename: &str) -> String {
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = fnv::FnvHasher::default();
|
||||
filename.hash(&mut hasher);
|
||||
BASE64_URL_SAFE_NO_PAD.encode(hasher.finish().to_le_bytes())
|
||||
}
|
||||
|
||||
fn generate_thumbnail_filename(
|
||||
filename: &str,
|
||||
format_name: &str,
|
||||
format_config: &ImageFormatConfig,
|
||||
) -> String {
|
||||
format!(
|
||||
"{}{}.{}",
|
||||
generate_filename_hash(filename),
|
||||
format_name,
|
||||
format_config.extension
|
||||
)
|
||||
}
|
||||
|
||||
async fn initialize_database(config: &Config) -> Result<SqlitePool> {
|
||||
let connection_options = SqliteConnectOptions::new()
|
||||
.filename(&config.db_path)
|
||||
.create_if_missing(true);
|
||||
let pool = SqlitePool::connect_with(connection_options).await?;
|
||||
sqlx::query(SCHEMA).execute(&pool).await?;
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
|
||||
let mut formats = HashMap::new();
|
||||
formats.insert(
|
||||
"jpegl".to_string(),
|
||||
ImageFormatConfig {
|
||||
target_width: 800,
|
||||
target_filesize: 0,
|
||||
quality: 70,
|
||||
format: ImageFormat::Jpeg,
|
||||
extension: "jpg".to_string(),
|
||||
},
|
||||
);
|
||||
formats.insert(
|
||||
"jpegh".to_string(),
|
||||
ImageFormatConfig {
|
||||
target_width: 1600,
|
||||
target_filesize: 0,
|
||||
quality: 80,
|
||||
format: ImageFormat::Jpeg,
|
||||
extension: "jpg".to_string(),
|
||||
},
|
||||
);
|
||||
formats.insert(
|
||||
"jpeg256kb".to_string(),
|
||||
ImageFormatConfig {
|
||||
target_width: 500,
|
||||
target_filesize: 256000,
|
||||
quality: 0,
|
||||
format: ImageFormat::Jpeg,
|
||||
extension: "jpg".to_string(),
|
||||
},
|
||||
);
|
||||
formats.insert(
|
||||
"avifh".to_string(),
|
||||
ImageFormatConfig {
|
||||
target_width: 1600,
|
||||
target_filesize: 0,
|
||||
quality: 80,
|
||||
format: ImageFormat::Avif,
|
||||
extension: "avif".to_string(),
|
||||
},
|
||||
);
|
||||
formats.insert(
|
||||
"avifl".to_string(),
|
||||
ImageFormatConfig {
|
||||
target_width: 800,
|
||||
target_filesize: 0,
|
||||
quality: 30,
|
||||
format: ImageFormat::Avif,
|
||||
extension: "avif".to_string(),
|
||||
},
|
||||
);
|
||||
formats
|
||||
}
|
||||
|
||||
async fn resize_for_embed(backend_config: Arc<InferenceServerConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
|
||||
let resized = tokio::task::spawn_blocking(move || {
|
||||
let new = image.resize(
|
||||
backend_config.image_size.0,
|
||||
backend_config.image_size.1,
|
||||
FilterType::Lanczos3
|
||||
);
|
||||
let mut buf = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
new.write_to(&mut csr, ImageFormat::Png)?;
|
||||
Ok::<Vec<u8>, anyhow::Error>(buf)
|
||||
}).await??;
|
||||
Ok(resized)
|
||||
}
|
||||
|
||||
async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<()> {
|
||||
let pool = initialize_database(&config).await?;
|
||||
let client = Client::new();
|
||||
|
||||
let formats = image_formats(&config);
|
||||
|
||||
let (to_process_tx, to_process_rx) = mpsc::channel::<FileRecord>(100);
|
||||
let (to_embed_tx, to_embed_rx) = mpsc::channel(backend.batch as usize);
|
||||
let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30);
|
||||
let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30);
|
||||
|
||||
let cpus = num_cpus::get();
|
||||
|
||||
// Image loading and preliminary resizing
|
||||
let image_loading: JoinHandle<Result<()>> = tokio::spawn({
|
||||
let config = config.clone();
|
||||
let backend = backend.clone();
|
||||
let stream = ReceiverStream::new(to_process_rx).map(Ok);
|
||||
stream.try_for_each_concurrent(Some(cpus), move |record| {
|
||||
let config = config.clone();
|
||||
let backend = backend.clone();
|
||||
let to_embed_tx = to_embed_tx.clone();
|
||||
let to_thumbnail_tx = to_thumbnail_tx.clone();
|
||||
let to_ocr_tx = to_ocr_tx.clone();
|
||||
async move {
|
||||
let path = Path::new(&config.files).join(&record.filename);
|
||||
let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
|
||||
let image = match image {
|
||||
Ok(image) => image,
|
||||
Err(e) => {
|
||||
log::error!("Could not read {}: {}", record.filename, e);
|
||||
return Ok(())
|
||||
}
|
||||
};
|
||||
if record.embedding.is_none() {
|
||||
let resized = resize_for_embed(backend.clone(), image.clone()).await?;
|
||||
|
||||
to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await?
|
||||
}
|
||||
if record.thumbnails.is_none() && config.enable_thumbs {
|
||||
to_thumbnail_tx
|
||||
.send(LoadedImage {
|
||||
image: image.clone(),
|
||||
filename: record.filename.clone(),
|
||||
original_size: std::fs::metadata(&path)?.len() as usize,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
if record.raw_ocr_segments.is_none() && config.enable_ocr {
|
||||
to_ocr_tx
|
||||
.send(LoadedImage {
|
||||
image,
|
||||
filename: record.filename.clone(),
|
||||
original_size: 0,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
// Thumbnail generation
|
||||
let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.enable_thumbs {
|
||||
let config = config.clone();
|
||||
let pool = pool.clone();
|
||||
let stream = ReceiverStream::new(to_thumbnail_rx).map(Ok);
|
||||
let formats = Arc::new(formats);
|
||||
Some(tokio::spawn({
|
||||
stream.try_for_each_concurrent(Some(cpus), move |image| {
|
||||
use image::codecs::*;
|
||||
|
||||
let formats = formats.clone();
|
||||
let config = config.clone();
|
||||
let pool = pool.clone();
|
||||
async move {
|
||||
let filename = image.filename.clone();
|
||||
log::debug!("thumbnailing {}", filename);
|
||||
let generated_formats = tokio::task::spawn_blocking(move || {
|
||||
let mut generated_formats = Vec::new();
|
||||
let rgb = DynamicImage::from(image.image.to_rgb8());
|
||||
for (format_name, format_config) in &*formats {
|
||||
let resized = if format_config.target_filesize != 0 {
|
||||
let mut lb = 1;
|
||||
let mut ub = 100;
|
||||
loop {
|
||||
let quality = (lb + ub) / 2;
|
||||
let thumbnail = rgb.resize(
|
||||
format_config.target_width,
|
||||
u32::MAX,
|
||||
FilterType::Lanczos3,
|
||||
);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
// this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
|
||||
match format_config.format {
|
||||
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
|
||||
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
|
||||
_ => unimplemented!()
|
||||
}?;
|
||||
if buf.len() > image.original_size {
|
||||
ub = quality;
|
||||
} else {
|
||||
lb = quality + 1;
|
||||
}
|
||||
if lb >= ub {
|
||||
break buf;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let thumbnail = rgb.resize(
|
||||
format_config.target_width,
|
||||
u32::MAX,
|
||||
FilterType::Lanczos3,
|
||||
);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
match format_config.format {
|
||||
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
|
||||
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
|
||||
ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
|
||||
_ => unimplemented!()
|
||||
}?;
|
||||
buf
|
||||
};
|
||||
if resized.len() < image.original_size {
|
||||
generated_formats.push(format_name.clone());
|
||||
let thumbnail_path = Path::new(&config.thumbs_path).join(
|
||||
generate_thumbnail_filename(
|
||||
&image.filename,
|
||||
format_name,
|
||||
format_config,
|
||||
),
|
||||
);
|
||||
std::fs::write(thumbnail_path, resized)?;
|
||||
}
|
||||
}
|
||||
Ok::<Vec<String>, anyhow::Error>(generated_formats)
|
||||
}).await??;
|
||||
let formats_data = rmp_serde::to_vec(&generated_formats)?;
|
||||
let ts = timestamp();
|
||||
sqlx::query!(
|
||||
"UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
|
||||
formats_data,
|
||||
ts,
|
||||
filename
|
||||
)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// OCR
|
||||
let ocr: Option<JoinHandle<Result<()>>> = if config.enable_ocr {
|
||||
let client = client.clone();
|
||||
let pool = pool.clone();
|
||||
let stream = ReceiverStream::new(to_ocr_rx).map(Ok);
|
||||
Some(tokio::spawn({
|
||||
stream.try_for_each_concurrent(Some(config.ocr_concurrency), move |image| {
|
||||
let client = client.clone();
|
||||
let pool = pool.clone();
|
||||
async move {
|
||||
log::debug!("OCRing {}", image.filename);
|
||||
let scan = match scan_image(&client, &image.image).await {
|
||||
Ok(scan) => scan,
|
||||
Err(e) => {
|
||||
log::error!("OCR failure {}: {}", image.filename, e);
|
||||
return Ok(())
|
||||
}
|
||||
};
|
||||
let ocr_text = scan
|
||||
.iter()
|
||||
.map(|segment| segment.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let ocr_data = rmp_serde::to_vec(&scan)?;
|
||||
let ts = timestamp();
|
||||
sqlx::query!(
|
||||
"UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
|
||||
ocr_text,
|
||||
ocr_data,
|
||||
ts,
|
||||
image.filename
|
||||
)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let embedding_generation: JoinHandle<Result<()>> = tokio::spawn({
|
||||
let stream = ReceiverStream::new(to_embed_rx).chunks(backend.batch);
|
||||
let client = client.clone();
|
||||
let config = config.clone();
|
||||
let pool = pool.clone();
|
||||
// keep multiple embedding requests in flight
|
||||
stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
|
||||
let client = client.clone();
|
||||
let config = config.clone();
|
||||
let pool = pool.clone();
|
||||
async move {
|
||||
let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
|
||||
&client,
|
||||
&config,
|
||||
"",
|
||||
EmbeddingRequest::Images {
|
||||
images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
|
||||
},
|
||||
).await.context("querying CLIP server")?;
|
||||
|
||||
let mut tx = pool.begin().await?;
|
||||
let ts = timestamp();
|
||||
for (i, vector) in result.into_iter().enumerate() {
|
||||
let vector = vector.into_vec();
|
||||
log::debug!("embedded {}", batch[i].filename);
|
||||
sqlx::query!(
|
||||
"UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
|
||||
ts,
|
||||
vector,
|
||||
batch[i].filename
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
tx.commit().await?;
|
||||
anyhow::Result::Ok(())
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let mut filenames = HashMap::new();
|
||||
|
||||
// blocking OS calls
|
||||
tokio::task::block_in_place(|| -> anyhow::Result<()> {
|
||||
for entry in WalkDir::new(config.files.as_str()) {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_file() {
|
||||
let filename = path.strip_prefix(&config.files)?.to_str().unwrap().to_string();
|
||||
let modtime = entry.metadata()?.modified()?.duration_since(std::time::UNIX_EPOCH)?;
|
||||
let modtime = modtime.as_micros() as i64;
|
||||
filenames.insert(filename.clone(), (path.to_path_buf(), modtime));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
log::debug!("finished reading filenames");
|
||||
|
||||
for (filename, (_path, modtime)) in filenames.iter() {
|
||||
let modtime = *modtime;
|
||||
let record = sqlx::query_as!(FileRecord, "SELECT * FROM files WHERE filename = ?", filename)
|
||||
.fetch_optional(&pool)
|
||||
.await?;
|
||||
|
||||
let new_record = match record {
|
||||
None => Some(FileRecord {
|
||||
filename: filename.clone(),
|
||||
..Default::default()
|
||||
}),
|
||||
Some(r) if modtime > r.embedding_time.unwrap_or(i64::MIN) || (modtime > r.ocr_time.unwrap_or(i64::MIN) && config.enable_ocr) || (modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.enable_thumbs) => {
|
||||
Some(r)
|
||||
},
|
||||
_ => None
|
||||
};
|
||||
if let Some(mut record) = new_record {
|
||||
log::debug!("processing {}", record.filename);
|
||||
sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
if modtime > record.embedding_time.unwrap_or(i64::MIN) {
|
||||
record.embedding = None;
|
||||
}
|
||||
if modtime > record.ocr_time.unwrap_or(i64::MIN) {
|
||||
record.raw_ocr_segments = None;
|
||||
}
|
||||
if modtime > record.thumbnail_time.unwrap_or(i64::MIN) {
|
||||
record.thumbnails = None;
|
||||
}
|
||||
// we need to exit here to actually capture the error
|
||||
if !to_process_tx.send(record).await.is_ok() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
drop(to_process_tx);
|
||||
|
||||
embedding_generation.await?.context("generating embeddings")?;
|
||||
|
||||
if let Some(thumbnail_generation) = thumbnail_generation {
|
||||
thumbnail_generation.await?.context("generating thumbnails")?;
|
||||
}
|
||||
|
||||
if let Some(ocr) = ocr {
|
||||
ocr.await?.context("OCRing")?;
|
||||
}
|
||||
|
||||
image_loading.await?.context("loading images")?;
|
||||
|
||||
let stored: Vec<String> = sqlx::query_scalar("SELECT filename FROM files").fetch_all(&pool).await?;
|
||||
let mut tx = pool.begin().await?;
|
||||
for filename in stored {
|
||||
if !filenames.contains_key(&filename) {
|
||||
sqlx::query!("DELETE FROM files WHERE filename = ?", filename)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
tx.commit().await?;
|
||||
|
||||
log::info!("ingest done");
|
||||
|
||||
Result::Ok(())
|
||||
}
|
||||
|
||||
const INDEX_ADD_BATCH: usize = 512;
|
||||
|
||||
async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<IIndex> {
|
||||
let pool = initialize_database(&config).await?;
|
||||
|
||||
let mut index = IIndex {
|
||||
// Use a suitable vector similarity search library for Rust
|
||||
vectors: faiss::index_factory(backend.embedding_size as u32, "SQfp16", faiss::MetricType::InnerProduct)?,
|
||||
filenames: Vec::new(),
|
||||
format_codes: Vec::new(),
|
||||
format_names: Vec::new(),
|
||||
};
|
||||
|
||||
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM files")
|
||||
.fetch_one(&pool)
|
||||
.await?;
|
||||
|
||||
index.filenames = Vec::with_capacity(count as usize);
|
||||
index.format_codes = Vec::with_capacity(count as usize);
|
||||
let mut buffer = Vec::with_capacity(INDEX_ADD_BATCH * backend.embedding_size as usize);
|
||||
index.format_names = Vec::with_capacity(5);
|
||||
|
||||
let mut rows = sqlx::query_as::<_, FileRecord>("SELECT * FROM files").fetch(&pool);
|
||||
while let Some(record) = rows.try_next().await? {
|
||||
if let Some(emb) = record.embedding {
|
||||
index.filenames.push(record.filename);
|
||||
for i in (0..emb.len()).step_by(2) {
|
||||
buffer.push(
|
||||
half::f16::from_le_bytes([emb[i], emb[i + 1]])
|
||||
.to_f32(),
|
||||
);
|
||||
}
|
||||
if buffer.len() == buffer.capacity() {
|
||||
index.vectors.add(&buffer)?;
|
||||
buffer.clear();
|
||||
}
|
||||
|
||||
let mut formats: Vec<String> = Vec::new();
|
||||
if let Some(t) = record.thumbnails {
|
||||
formats = rmp_serde::from_slice(&t)?;
|
||||
}
|
||||
|
||||
let mut format_code = 0;
|
||||
for format_string in &formats {
|
||||
let mut found = false;
|
||||
for (i, name) in index.format_names.iter().enumerate() {
|
||||
if name == format_string {
|
||||
format_code |= 1 << i;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
let new_index = index.format_names.len();
|
||||
format_code |= 1 << new_index;
|
||||
index.format_names.push(format_string.clone());
|
||||
}
|
||||
}
|
||||
index.format_codes.push(format_code);
|
||||
}
|
||||
}
|
||||
if !buffer.is_empty() {
|
||||
index.vectors.add(&buffer)?;
|
||||
}
|
||||
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
|
||||
buf.chunks_exact(2)
|
||||
.map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
|
||||
.collect()
|
||||
}
|
||||
|
||||
type EmbeddingVector = Vec<f32>;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct QueryResult {
|
||||
matches: Vec<(f32, String, String, u64)>,
|
||||
formats: Vec<String>,
|
||||
extensions: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct QueryTerm {
|
||||
embedding: Option<EmbeddingVector>,
|
||||
image: Option<String>,
|
||||
text: Option<String>,
|
||||
weight: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct QueryRequest {
|
||||
terms: Vec<QueryTerm>,
|
||||
k: Option<usize>,
|
||||
}
|
||||
|
||||
async fn query_index(index: &mut IIndex, query: EmbeddingVector, k: usize) -> Result<QueryResult> {
|
||||
let result = index.vectors.search(&query, k as usize)?;
|
||||
|
||||
let items = result.distances
|
||||
.into_iter()
|
||||
.zip(result.labels)
|
||||
.filter_map(|(distance, id)| {
|
||||
let id = id.get()? as usize;
|
||||
Some((
|
||||
distance,
|
||||
index.filenames[id].clone(),
|
||||
generate_filename_hash(&index.filenames[id as usize]).clone(),
|
||||
index.format_codes[id]
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(QueryResult {
|
||||
matches: items,
|
||||
formats: index.format_names.clone(),
|
||||
extensions: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
config: &Config,
|
||||
backend_config: Arc<InferenceServerConfig>,
|
||||
client: Arc<Client>,
|
||||
index: &mut IIndex,
|
||||
req: Json<QueryRequest>,
|
||||
) -> Result<Response<Body>> {
|
||||
let mut total_embedding = ndarray::Array::from(vec![0.0; backend_config.embedding_size]);
|
||||
|
||||
let mut image_batch = Vec::new();
|
||||
let mut image_weights = Vec::new();
|
||||
let mut text_batch = Vec::new();
|
||||
let mut text_weights = Vec::new();
|
||||
|
||||
for term in &req.terms {
|
||||
if let Some(image) = &term.image {
|
||||
let bytes = BASE64_STANDARD.decode(image)?;
|
||||
let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
|
||||
image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(backend_config.clone(), image).await?));
|
||||
image_weights.push(term.weight.unwrap_or(1.0));
|
||||
}
|
||||
if let Some(text) = &term.text {
|
||||
text_batch.push(text.clone());
|
||||
text_weights.push(term.weight.unwrap_or(1.0));
|
||||
}
|
||||
if let Some(embedding) = &term.embedding {
|
||||
let weight = term.weight.unwrap_or(1.0);
|
||||
for (i, value) in embedding.iter().enumerate() {
|
||||
total_embedding[i] += value * weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut batches = vec![];
|
||||
|
||||
if !image_batch.is_empty() {
|
||||
batches.push(
|
||||
EmbeddingRequest::Images {
|
||||
images: image_batch
|
||||
}
|
||||
);
|
||||
}
|
||||
if !text_batch.is_empty() {
|
||||
batches.push(
|
||||
EmbeddingRequest::Text {
|
||||
text: text_batch,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
for batch in batches {
|
||||
let embs: Vec<Vec<u8>> = query_clip_server(&client, config, "/", batch).await?;
|
||||
for emb in embs {
|
||||
total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb));
|
||||
}
|
||||
}
|
||||
|
||||
let k = req.k.unwrap_or(1000);
|
||||
let qres = query_index(index, total_embedding.to_vec(), k).await?;
|
||||
|
||||
let mut extensions = HashMap::new();
|
||||
for (k, v) in image_formats(config) {
|
||||
extensions.insert(k, v.extension);
|
||||
}
|
||||
|
||||
Ok(Json(QueryResult {
|
||||
matches: qres.matches,
|
||||
formats: qres.formats,
|
||||
extensions,
|
||||
}).into_response())
|
||||
}
|
||||
|
||||
async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
|
||||
let res = Client::new().get(&format!("{}/config", config.clip_server)).send().await?;
|
||||
Ok(rmp_serde::from_slice(&res.bytes().await?)?)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
pretty_env_logger::init();
|
||||
|
||||
let config_path = std::env::args().nth(1).expect("Missing config file path");
|
||||
let config: Arc<Config> = Arc::new(serde_json::from_slice(&std::fs::read(config_path)?)?);
|
||||
|
||||
let pool = initialize_database(&config).await?;
|
||||
sqlx::query(SCHEMA).execute(&pool).await?;
|
||||
|
||||
let backend = Arc::new(loop {
|
||||
match get_backend_config(&config).await {
|
||||
Ok(backend) => break backend,
|
||||
Err(e) => {
|
||||
log::error!("Backend failed (fetch): {}", e);
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if config.no_run_server {
|
||||
ingest_files(config.clone(), backend.clone()).await?;
|
||||
return Ok(())
|
||||
}
|
||||
|
||||
let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1);
|
||||
|
||||
let index = Arc::new(tokio::sync::Mutex::new(build_index(config.clone(), backend.clone()).await?));
|
||||
|
||||
let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1);
|
||||
let done_tx = Arc::new(ingest_done_tx.clone());
|
||||
|
||||
let _ingest_task = tokio::spawn({
|
||||
let config = config.clone();
|
||||
let backend = backend.clone();
|
||||
let index = index.clone();
|
||||
async move {
|
||||
loop {
|
||||
log::info!("Ingest running");
|
||||
match ingest_files(config.clone(), backend.clone()).await {
|
||||
Ok(_) => {
|
||||
match build_index(config.clone(), backend.clone()).await {
|
||||
Ok(new_index) => {
|
||||
*index.lock().await = new_index;
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Index build failed: {:?}", e);
|
||||
ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Ingest failed: {:?}", e);
|
||||
ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
|
||||
}
|
||||
}
|
||||
ingest_done_tx.send((true, format!("OK"))).unwrap();
|
||||
request_ingest_rx.recv().await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let cors = CorsLayer::permissive();
|
||||
|
||||
let config_ = config.clone();
|
||||
let client = Arc::new(Client::new());
|
||||
let app = Router::new()
|
||||
.route("/", post(|req| async move {
|
||||
let config = config.clone();
|
||||
let backend_config = backend.clone();
|
||||
let mut index = index.lock().await; // TODO: use ConcurrentIndex here
|
||||
let client = client.clone();
|
||||
handle_request(&config, backend_config, client.clone(), &mut index, req).await.map_err(|e| format!("{:?}", e))
|
||||
}))
|
||||
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
|
||||
"OK"
|
||||
}))
|
||||
.route("/reload", post(|_req: axum::http::Request<axum::body::Body>| async move {
|
||||
log::info!("Requesting index reload");
|
||||
let mut done_rx = done_tx.clone().subscribe();
|
||||
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
|
||||
match done_rx.recv().await {
|
||||
Ok((true, status)) => {
|
||||
let mut res = status.into_response();
|
||||
*res.status_mut() = StatusCode::OK;
|
||||
res
|
||||
},
|
||||
Ok((false, status)) => {
|
||||
let mut res = status.into_response();
|
||||
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
res
|
||||
},
|
||||
Err(_) => {
|
||||
let mut res = "internal error".into_response();
|
||||
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
res
|
||||
}
|
||||
}
|
||||
}))
|
||||
.layer(cors);
|
||||
|
||||
let addr = format!("0.0.0.0:{}", config_.port);
|
||||
log::info!("Starting server on {}", addr);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
173
src/ocr.rs
Normal file
173
src/ocr.rs
Normal file
@ -0,0 +1,173 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use image::{DynamicImage, GenericImageView, ImageFormat};
|
||||
use regex::Regex;
|
||||
use reqwest::{
|
||||
header::{HeaderMap, HeaderValue},
|
||||
multipart::{Form, Part},
|
||||
Client,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use std::{io::Cursor, time::{SystemTime, UNIX_EPOCH}};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const CALLBACK_REGEX: &str = r">AF_initDataCallback\((\{key: 'ds:1'.*?\})\);</script>";
|
||||
const MAX_DIM: u32 = 1024;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SegmentCoords {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub w: i32,
|
||||
pub h: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Segment {
|
||||
pub coords: SegmentCoords,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub type ScanResult = Vec<Segment>;
|
||||
|
||||
fn rationalize_coords_format1(
|
||||
image_w: f64,
|
||||
image_h: f64,
|
||||
center_x_fraction: f64,
|
||||
center_y_fraction: f64,
|
||||
width_fraction: f64,
|
||||
height_fraction: f64,
|
||||
) -> SegmentCoords {
|
||||
SegmentCoords {
|
||||
x: ((center_x_fraction - width_fraction / 2.0) * image_w).round() as i32,
|
||||
y: ((center_y_fraction - height_fraction / 2.0) * image_h).round() as i32,
|
||||
w: (width_fraction * image_w).round() as i32,
|
||||
h: (height_fraction * image_h).round() as i32,
|
||||
}
|
||||
}
|
||||
|
||||
async fn scan_image_chunk(
|
||||
client: &Client,
|
||||
image: &[u8],
|
||||
image_width: u32,
|
||||
image_height: u32,
|
||||
) -> Result<ScanResult> {
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_micros();
|
||||
|
||||
let part = Part::bytes(image.to_vec())
|
||||
.file_name(format!("ocr{}.png", timestamp))
|
||||
.mime_str("image/png")?;
|
||||
|
||||
let form = Form::new().part("encoded_image", part);
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"User-Agent",
|
||||
HeaderValue::from_static("Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36"),
|
||||
);
|
||||
headers.insert("Cookie", HeaderValue::from_str(&format!("SOCS=CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg; stcs={}", timestamp))?);
|
||||
|
||||
let response = client
|
||||
.post(&format!("https://lens.google.com/v3/upload?stcs={}", timestamp))
|
||||
.multipart(form)
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body = response.text().await?;
|
||||
|
||||
let re = Regex::new(CALLBACK_REGEX)?;
|
||||
let captures = re
|
||||
.captures(&body)
|
||||
.ok_or_else(|| anyhow!("invalid API response"))?;
|
||||
let match_str = captures.get(1).unwrap().as_str();
|
||||
|
||||
let lens_object: Value = json5::from_str(match_str)?;
|
||||
|
||||
if lens_object.get("errorHasStatus").is_some() {
|
||||
return Err(anyhow!("lens failed"));
|
||||
}
|
||||
|
||||
let root = lens_object["data"].as_array().unwrap();
|
||||
|
||||
let mut text_segments = Vec::new();
|
||||
let mut text_regions = Vec::new();
|
||||
|
||||
let text_segments_raw = root[3][4][0][0]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow!("invalid text segments"))?;
|
||||
let text_regions_raw = root[2][3][0]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow!("invalid text regions"))?;
|
||||
|
||||
for region in text_regions_raw {
|
||||
let region_data = region.as_array().unwrap();
|
||||
if region_data[11].as_str().unwrap().starts_with("text:") {
|
||||
let raw_coords = region_data[1].as_array().unwrap();
|
||||
let coords = rationalize_coords_format1(
|
||||
image_width as f64,
|
||||
image_height as f64,
|
||||
raw_coords[0].as_f64().unwrap(),
|
||||
raw_coords[1].as_f64().unwrap(),
|
||||
raw_coords[2].as_f64().unwrap(),
|
||||
raw_coords[3].as_f64().unwrap(),
|
||||
);
|
||||
text_regions.push(coords);
|
||||
}
|
||||
}
|
||||
|
||||
for segment in text_segments_raw {
|
||||
let text_segment = segment.as_str().unwrap().to_string();
|
||||
text_segments.push(text_segment);
|
||||
}
|
||||
|
||||
Ok(text_segments
|
||||
.into_iter()
|
||||
.zip(text_regions.into_iter())
|
||||
.map(|(text, coords)| Segment { text, coords })
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn scan_image(client: &Client, image: &DynamicImage) -> Result<ScanResult> {
|
||||
let mut result = ScanResult::new();
|
||||
let (width, height) = image.dimensions();
|
||||
|
||||
let (width, height, image) = if width > MAX_DIM {
|
||||
let height = ((height as f64) * (MAX_DIM as f64) / (width as f64)).round() as u32;
|
||||
let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::Lanczos3));
|
||||
(MAX_DIM, height, std::borrow::Cow::Owned(new_image))
|
||||
} else {
|
||||
(width, height, std::borrow::Cow::Borrowed(image))
|
||||
};
|
||||
|
||||
let mut y = 0;
|
||||
while y < height {
|
||||
let chunk_height = (height - y).min(MAX_DIM);
|
||||
let chunk = tokio::task::block_in_place(|| {
|
||||
let chunk = image.view(0, y, width, chunk_height).to_image();
|
||||
let mut buf = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
chunk.write_to(&mut csr, ImageFormat::Png)?;
|
||||
Ok::<Vec<u8>, anyhow::Error>(buf)
|
||||
})?;
|
||||
|
||||
let res = scan_image_chunk(client, &chunk, width, chunk_height).await?;
|
||||
for segment in res {
|
||||
result.push(Segment {
|
||||
text: segment.text,
|
||||
coords: SegmentCoords {
|
||||
y: segment.coords.y + y as i32,
|
||||
x: segment.coords.x,
|
||||
w: segment.coords.w,
|
||||
h: segment.coords.h,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
y += chunk_height;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
Loading…
Reference in New Issue
Block a user