1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-06-19 23:04:07 +00:00
This commit is contained in:
osmarks 2025-01-18 07:25:21 +00:00
parent 0a542ef579
commit d3fcedda09
9 changed files with 3 additions and 2654 deletions

3
.gitignore vendored
View File

@ -17,3 +17,6 @@ diskann/target
*/flamegraph.svg
*.hdf5
*.v
shards
index
queries.txt

View File

@ -1,26 +0,0 @@
module meme-search
go 1.22.2
require (
github.com/DataIntelligenceCrew/go-faiss v0.2.0
github.com/jmoiron/sqlx v1.4.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/samber/lo v1.39.0
github.com/titanous/json5 v1.0.0
github.com/vmihailenco/msgpack v4.0.4+incompatible
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.7.0
)
require (
github.com/davidbyttow/govips/v2 v2.14.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/h2non/bimg v1.1.9 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/image v0.16.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/text v0.15.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/protobuf v1.26.0 // indirect
)

View File

@ -1,100 +0,0 @@
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DataIntelligenceCrew/go-faiss v0.2.0 h1:c0pxAr0vldXIuE4DZnqpl6FuuH1uZd45d+NiQHKg1uU=
github.com/DataIntelligenceCrew/go-faiss v0.2.0/go.mod h1:4Gi7G3PF78IwZigTL2M1AJXOaAgxyL66vCqUYVaNgwk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davidbyttow/govips/v2 v2.14.0 h1:il3pX0XMZ5nlwipkFJHRZ3vGzcdXWApARalJxNpRHJU=
github.com/davidbyttow/govips/v2 v2.14.0/go.mod h1:eglyvgm65eImDiJJk4wpj9LSz4pWivPzWgDqkxWJn5k=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/h2non/bimg v1.1.9 h1:WH20Nxko9l/HFm4kZCA3Phbgu2cbHvYzxwxn9YROEGg=
github.com/h2non/bimg v1.1.9/go.mod h1:R3+UiYwkK4rQl6KVFTOFJHitgLbZXBZNFh2cv3AEbp8=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/json5/json5-go v0.0.0-20160331055859-40c2958e3bf8 h1:BQuwfXQRDQMI8YNqINKNlFV23P0h07ZvOQAtezAEsP8=
github.com/json5/json5-go v0.0.0-20160331055859-40c2958e3bf8/go.mod h1:7n1PdYNh4RIHTvILru80IEstTADqQz/wmjeNXTcC9rA=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
github.com/titanous/json5 v1.0.0 h1:hJf8Su1d9NuI/ffpxgxQfxh/UiBFZX7bMPid0rIL/7s=
github.com/titanous/json5 v1.0.0/go.mod h1:7JH1M8/LHKc6cyP5o5g3CSaRj+mBrIimTxzpvmckH8c=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0=
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,877 +0,0 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"hash/fnv"
"io"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"strings"
"sync"
"time"
"github.com/DataIntelligenceCrew/go-faiss"
"github.com/h2non/bimg"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"github.com/samber/lo"
"github.com/vmihailenco/msgpack"
"github.com/x448/float16"
"golang.org/x/sync/errgroup"
)
type Config struct {
ClipServer string `json:"clip_server"`
DbPath string `json:"db_path"`
Port int16 `json:"port"`
Files string `json:"files"`
EnableOCR bool `json:"enable_ocr"`
ThumbsPath string `json:"thumbs_path"`
EnableThumbnails bool `json:"enable_thumbs"`
}
type Index struct {
vectors *faiss.IndexImpl
filenames []string
formatCodes []int64
formatNames []string
}
var schema = `
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
embedding_time INTEGER,
ocr_time INTEGER,
thumbnail_time INTEGER,
embedding BLOB,
ocr TEXT,
raw_ocr_segments BLOB,
thumbnails BLOB
);
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
filename,
ocr,
tokenize='unicode61 remove_diacritics 2',
content='ocr'
);
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER UPDATE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
`
type FileRecord struct {
Filename string `db:"filename"`
EmbedTime int64 `db:"embedding_time"`
OcrTime int64 `db:"ocr_time"`
ThumbnailTime int64 `db:"thumbnail_time"`
Embedding []byte `db:"embedding"`
Ocr string `db:"ocr"`
RawOcrSegments []byte `db:"raw_ocr_segments"`
Thumbnails []byte `db:"thumbnails"`
}
type InferenceServerConfig struct {
BatchSize uint `msgpack:"batch"`
ImageSize []uint `msgpack:"image_size"`
EmbeddingSize uint `msgpack:"embedding_size"`
}
func decodeMsgpackFrom[O interface{}](resp *http.Response) (O, error) {
var result O
respData, err := io.ReadAll(resp.Body)
if err != nil {
return result, err
}
err = msgpack.Unmarshal(respData, &result)
return result, err
}
func queryClipServer[I interface{}, O interface{}](config Config, path string, data I) (O, error) {
var result O
b, err := msgpack.Marshal(data)
if err != nil {
return result, err
}
resp, err := http.Post(config.ClipServer+path, "application/msgpack", bytes.NewReader(b))
if err != nil {
return result, err
}
defer resp.Body.Close()
return decodeMsgpackFrom[O](resp)
}
type LoadedImage struct {
image *bimg.Image
filename string
originalSize int
}
type EmbeddingInput struct {
image []byte
filename string
}
type EmbeddingRequest struct {
Images [][]byte `msgpack:"images"`
Text []string `msgpack:"text"`
}
type EmbeddingResponse = [][]byte
func timestamp() int64 {
return time.Now().UnixMicro()
}
type ImageFormatConfig struct {
targetWidth int
targetFilesize int
quality int
format bimg.ImageType
extension string
}
func generateFilenameHash(filename string) string {
hasher := fnv.New128()
hasher.Write([]byte(filename))
hash := hasher.Sum(make([]byte, 0))
return base64.RawURLEncoding.EncodeToString(hash)
}
func generateThumbnailFilename(filename string, formatName string, formatConfig ImageFormatConfig) string {
return fmt.Sprintf("%s%s.%s", generateFilenameHash(filename), formatName, formatConfig.extension)
}
func initializeDatabase(config Config) (*sqlx.DB, error) {
db, err := sqlx.Connect("sqlite3", config.DbPath)
if err != nil {
return nil, err
}
_, err = db.Exec("PRAGMA busy_timeout = 2000; PRAGMA journal_mode = WAL")
if err != nil {
return nil, err
}
return db, nil
}
func imageFormats(config Config) map[string]ImageFormatConfig {
return map[string]ImageFormatConfig{
"jpegl": {
targetWidth: 800,
quality: 70,
format: bimg.JPEG,
extension: "jpg",
},
"jpegh": {
targetWidth: 1600,
quality: 80,
format: bimg.JPEG,
extension: "jpg",
},
"jpeg256kb": {
targetWidth: 500,
targetFilesize: 256000,
format: bimg.JPEG,
extension: "jpg",
},
"avifh": {
targetWidth: 1600,
quality: 80,
format: bimg.AVIF,
extension: "avif",
},
"avifl": {
targetWidth: 800,
quality: 30,
format: bimg.AVIF,
extension: "avif",
},
}
}
func ingestFiles(config Config, backend InferenceServerConfig) error {
var wg errgroup.Group
var iwg errgroup.Group
// We assume everything is either a modern browser (low-DPI or high-DPI), an ancient browser or a ComputerCraft machine abusing Extra Utilities 2 screens.
var formats = imageFormats(config)
db, err := initializeDatabase(config)
if err != nil {
return err
}
defer db.Close()
toProcess := make(chan FileRecord, 100)
toEmbed := make(chan EmbeddingInput, backend.BatchSize)
toThumbnail := make(chan LoadedImage, 30)
toOCR := make(chan LoadedImage, 30)
embedBatches := make(chan []EmbeddingInput, 1)
// image loading and preliminary resizing
for range runtime.NumCPU() {
iwg.Go(func() error {
for record := range toProcess {
path := filepath.Join(config.Files, record.Filename)
buffer, err := bimg.Read(path)
if err != nil {
log.Println("could not read ", record.Filename)
}
img := bimg.NewImage(buffer)
if record.Embedding == nil {
resized, err := img.Process(bimg.Options{
Width: int(backend.ImageSize[0]),
Height: int(backend.ImageSize[1]),
Force: true,
Type: bimg.PNG,
Interpretation: bimg.InterpretationSRGB,
})
if err != nil {
log.Println("resize failure", record.Filename, err)
} else {
toEmbed <- EmbeddingInput{
image: resized,
filename: record.Filename,
}
}
}
if record.Thumbnails == nil && config.EnableThumbnails {
toThumbnail <- LoadedImage{
image: img,
filename: record.Filename,
originalSize: len(buffer),
}
}
if record.RawOcrSegments == nil && config.EnableOCR {
toOCR <- LoadedImage{
image: img,
filename: record.Filename,
}
}
}
return nil
})
}
if config.EnableThumbnails {
for range runtime.NumCPU() {
wg.Go(func() error {
for image := range toThumbnail {
generatedFormats := make([]string, 0)
for formatName, formatConfig := range formats {
var err error
var resized []byte
if formatConfig.targetFilesize != 0 {
lb := 1
ub := 100
for {
quality := (lb + ub) / 2
resized, err = image.image.Process(bimg.Options{
Width: formatConfig.targetWidth,
Type: formatConfig.format,
Speed: 4,
Quality: quality,
StripMetadata: true,
Enlarge: false,
})
if len(resized) > image.originalSize {
ub = quality
} else {
lb = quality + 1
}
if lb >= ub {
break
}
}
} else {
resized, err = image.image.Process(bimg.Options{
Width: formatConfig.targetWidth,
Type: formatConfig.format,
Speed: 4,
Quality: formatConfig.quality,
StripMetadata: true,
Enlarge: false,
})
}
if err != nil {
log.Println("thumbnailing failure", image.filename, err)
continue
}
if len(resized) < image.originalSize {
generatedFormats = append(generatedFormats, formatName)
err = bimg.Write(filepath.Join(config.ThumbsPath, generateThumbnailFilename(image.filename, formatName, formatConfig)), resized)
if err != nil {
return err
}
}
}
formatsData, err := msgpack.Marshal(generatedFormats)
if err != nil {
return err
}
_, err = db.Exec("UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?", formatsData, timestamp(), image.filename)
if err != nil {
return err
}
}
return nil
})
}
}
if config.EnableOCR {
for range 100 {
wg.Go(func() error {
for image := range toOCR {
scan, err := scanImage(image.image)
if err != nil {
log.Println("OCR failure", image.filename, err)
continue
}
ocrText := ""
for _, segment := range scan {
ocrText += segment.text
ocrText += "\n"
}
ocrData, err := msgpack.Marshal(scan)
if err != nil {
return err
}
_, err = db.Exec("UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?", ocrText, ocrData, timestamp(), image.filename)
if err != nil {
return err
}
}
return nil
})
}
}
wg.Go(func() error {
buffer := make([]EmbeddingInput, 0, backend.BatchSize)
for input := range toEmbed {
buffer = append(buffer, input)
if len(buffer) == int(backend.BatchSize) {
embedBatches <- buffer
buffer = make([]EmbeddingInput, 0, backend.BatchSize)
}
}
if len(buffer) > 0 {
embedBatches <- buffer
}
close(embedBatches)
return nil
})
for range 3 {
wg.Go(func() error {
for batch := range embedBatches {
result, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "", EmbeddingRequest{
Images: lo.Map(batch, func(item EmbeddingInput, _ int) []byte { return item.image }),
})
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
for i, vector := range result {
_, err = tx.Exec("UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?", timestamp(), vector, batch[i].filename)
if err != nil {
return err
}
}
err = tx.Commit()
if err != nil {
return err
}
}
return nil
})
}
filenamesOnDisk := make(map[string]struct{})
err = filepath.WalkDir(config.Files, func(path string, d os.DirEntry, err error) error {
filename := strings.TrimPrefix(path, config.Files)
if err != nil {
return err
}
if d.IsDir() {
return nil
}
filenamesOnDisk[filename] = struct{}{}
records := []FileRecord{}
err = db.Select(&records, "SELECT * FROM files WHERE filename = ?", filename)
if err != nil {
return err
}
stat, err := d.Info()
if err != nil {
return err
}
modtime := stat.ModTime().UnixMicro()
if len(records) == 0 || modtime > records[0].EmbedTime || modtime > records[0].OcrTime || modtime > records[0].ThumbnailTime {
_, err = db.Exec("INSERT OR IGNORE INTO files VALUES (?, 0, 0, 0, '', '', '', '')", filename)
if err != nil {
return err
}
record := FileRecord{
Filename: filename,
}
if len(records) > 0 {
record = records[0]
}
if modtime > record.EmbedTime || len(record.Embedding) == 0 {
record.Embedding = nil
}
if modtime > record.OcrTime || len(record.RawOcrSegments) == 0 {
record.RawOcrSegments = nil
}
if modtime > record.ThumbnailTime || len(record.Thumbnails) == 0 {
record.Thumbnails = nil
}
toProcess <- record
}
return nil
})
if err != nil {
return err
}
close(toProcess)
err = iwg.Wait()
close(toEmbed)
close(toThumbnail)
if err != nil {
return err
}
err = wg.Wait()
if err != nil {
return err
}
rows, err := db.Queryx("SELECT filename FROM files")
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
for rows.Next() {
var filename string
err := rows.Scan(&filename)
if err != nil {
return err
}
if _, ok := filenamesOnDisk[filename]; !ok {
_, err = tx.Exec("DELETE FROM files WHERE filename = ?", filename)
if err != nil {
return err
}
}
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
const INDEX_ADD_BATCH = 512
func buildIndex(config Config, backend InferenceServerConfig) (Index, error) {
var index Index
db, err := initializeDatabase(config)
if err != nil {
return index, err
}
defer db.Close()
newFAISSIndex, err := faiss.IndexFactory(int(backend.EmbeddingSize), "SQfp16", faiss.MetricInnerProduct)
if err != nil {
return index, err
}
index.vectors = newFAISSIndex
var count int
err = db.Get(&count, "SELECT COUNT(*) FROM files")
if err != nil {
return index, err
}
index.filenames = make([]string, 0, count)
index.formatCodes = make([]int64, 0, count)
buffer := make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
index.formatNames = make([]string, 0, 5)
record := FileRecord{}
rows, err := db.Queryx("SELECT * FROM files")
if err != nil {
return index, err
}
for rows.Next() {
err := rows.StructScan(&record)
if err != nil {
return index, err
}
if len(record.Embedding) > 0 {
index.filenames = append(index.filenames, record.Filename)
for i := 0; i < len(record.Embedding); i += 2 {
buffer = append(buffer, float16.Frombits(uint16(record.Embedding[i])+uint16(record.Embedding[i+1])<<8).Float32())
}
if len(buffer) == cap(buffer) {
index.vectors.Add(buffer)
buffer = make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
}
formats := make([]string, 0, 5)
if len(record.Thumbnails) > 0 {
err := msgpack.Unmarshal(record.Thumbnails, &formats)
if err != nil {
return index, err
}
}
formatCode := int64(0)
for _, formatString := range formats {
found := false
for i, name := range index.formatNames {
if name == formatString {
formatCode |= 1 << i
found = true
break
}
}
if !found {
newIndex := len(index.formatNames)
formatCode |= 1 << newIndex
index.formatNames = append(index.formatNames, formatString)
}
}
index.formatCodes = append(index.formatCodes, formatCode)
}
}
if len(buffer) > 0 {
index.vectors.Add(buffer)
}
return index, nil
}
func decodeFP16Buffer(buf []byte) []float32 {
out := make([]float32, 0, len(buf)/2)
for i := 0; i < len(buf); i += 2 {
out = append(out, float16.Frombits(uint16(buf[i])+uint16(buf[i+1])<<8).Float32())
}
return out
}
type EmbeddingVector []float32
type QueryResult struct {
Matches [][]interface{} `json:"matches"`
Formats []string `json:"formats"`
Extensions map[string]string `json:"extensions"`
}
// this terrible language cannot express tagged unions
type QueryTerm struct {
Embedding *EmbeddingVector `json:"embedding"`
Image *string `json:"image"` // base64
Text *string `json:"text"`
Weight *float32 `json:"weight"`
}
type QueryRequest struct {
Terms []QueryTerm `json:"terms"`
K *int `json:"k"`
}
func queryIndex(index *Index, query EmbeddingVector, k int) (QueryResult, error) {
var qr QueryResult
distances, ids, err := index.vectors.Search(query, int64(k))
if err != nil {
return qr, err
}
items := lo.Map(lo.Zip2(distances, ids), func(x lo.Tuple2[float32, int64], i int) []interface{} {
return []interface{}{
x.A,
index.filenames[x.B],
generateFilenameHash(index.filenames[x.B]),
index.formatCodes[x.B],
}
})
return QueryResult{
Matches: items,
Formats: index.formatNames,
}, nil
}
func handleRequest(config Config, backendConfig InferenceServerConfig, index *Index, w http.ResponseWriter, req *http.Request) error {
if req.Body == nil {
io.WriteString(w, "OK") // health check
return nil
}
dec := json.NewDecoder(req.Body)
var qreq QueryRequest
err := dec.Decode(&qreq)
if err != nil {
return err
}
totalEmbedding := make(EmbeddingVector, backendConfig.EmbeddingSize)
imageBatch := make([][]byte, 0)
imageWeights := make([]float32, 0)
textBatch := make([]string, 0)
textWeights := make([]float32, 0)
for _, term := range qreq.Terms {
if term.Image != nil {
bytes, err := base64.StdEncoding.DecodeString(*term.Image)
if err != nil {
return err
}
loaded := bimg.NewImage(bytes)
resized, err := loaded.Process(bimg.Options{
Width: int(backendConfig.ImageSize[0]),
Height: int(backendConfig.ImageSize[1]),
Force: true,
Type: bimg.PNG,
Interpretation: bimg.InterpretationSRGB,
})
if err != nil {
return err
}
imageBatch = append(imageBatch, resized)
if term.Weight != nil {
imageWeights = append(imageWeights, *term.Weight)
} else {
imageWeights = append(imageWeights, 1)
}
}
if term.Text != nil {
textBatch = append(textBatch, *term.Text)
if term.Weight != nil {
textWeights = append(textWeights, *term.Weight)
} else {
textWeights = append(textWeights, 1)
}
}
if term.Embedding != nil {
weight := float32(1.0)
if term.Weight != nil {
weight = *term.Weight
}
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += (*term.Embedding)[i] * weight
}
}
}
if len(imageBatch) > 0 {
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
Images: imageBatch,
})
if err != nil {
return err
}
for j, emb := range embs {
embd := decodeFP16Buffer(emb)
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += embd[i] * imageWeights[j]
}
}
}
if len(textBatch) > 0 {
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
Text: textBatch,
})
if err != nil {
return err
}
for j, emb := range embs {
embd := decodeFP16Buffer(emb)
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += embd[i] * textWeights[j]
}
}
}
k := 1000
if qreq.K != nil {
k = *qreq.K
}
w.Header().Add("Content-Type", "application/json")
enc := json.NewEncoder(w)
qres, err := queryIndex(index, totalEmbedding, k)
qres.Extensions = make(map[string]string)
for k, v := range imageFormats(config) {
qres.Extensions[k] = v.extension
}
if err != nil {
return err
}
err = enc.Encode(qres)
if err != nil {
return err
}
return nil
}
func init() {
os.Setenv("VIPS_WARNING", "FALSE") // this does not actually work
bimg.VipsCacheSetMax(0)
bimg.VipsCacheSetMaxMem(0)
}
func main() {
content, err := os.ReadFile(os.Args[1])
if err != nil {
log.Fatal("config file unreadable ", err)
}
var config Config
err = json.Unmarshal(content, &config)
if err != nil {
log.Fatal("config file wrong ", err)
}
fmt.Println(config)
db, err := sqlx.Connect("sqlite3", config.DbPath)
if err != nil {
log.Fatal("DB connection failure ", db)
}
db.MustExec(schema)
var backend InferenceServerConfig
for {
resp, err := http.Get(config.ClipServer + "/config")
if err != nil {
log.Println("backend failed (fetch) ", err)
}
backend, err = decodeMsgpackFrom[InferenceServerConfig](resp)
resp.Body.Close()
if err != nil {
log.Println("backend failed (parse) ", err)
} else {
break
}
time.Sleep(time.Second)
}
requestIngest := make(chan struct{}, 1)
var index *Index
// maybe this ought to be mutexed?
var lastError *error
// there's not a neat way to reusably broadcast to multiple channels, but I *can* abuse WaitGroups probably
// this might cause horrible concurrency issues, but you brought me to this point, Go designers
var wg sync.WaitGroup
go func() {
for {
wg.Add(1)
log.Println("ingest running")
err := ingestFiles(config, backend)
if err != nil {
log.Println("ingest failed ", err)
lastError = &err
} else {
newIndex, err := buildIndex(config, backend)
if err != nil {
log.Println("index build failed ", err)
lastError = &err
} else {
lastError = nil
index = &newIndex
}
}
wg.Done()
<-requestIngest
}
}()
newIndex, err := buildIndex(config, backend)
index = &newIndex
if err != nil {
log.Fatal("index build failed ", err)
}
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Headers", "Content-Type")
if req.Method == "OPTIONS" {
w.WriteHeader(204)
return
}
err := handleRequest(config, backend, index, w, req)
if err != nil {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]string{
"error": err.Error(),
})
}
})
http.HandleFunc("/reload", func(w http.ResponseWriter, req *http.Request) {
if req.Method == "POST" {
log.Println("requesting index reload")
select {
case requestIngest <- struct{}{}:
default:
}
wg.Wait()
if lastError == nil {
w.Write([]byte("OK"))
} else {
w.WriteHeader(500)
w.Write([]byte((*lastError).Error()))
}
}
})
http.HandleFunc("/profile", func(w http.ResponseWriter, req *http.Request) {
f, err := os.Create("mem.pprof")
if err != nil {
log.Fatal("could not create memory profile: ", err)
}
defer f.Close()
var m runtime.MemStats
runtime.ReadMemStats(&m)
log.Printf("Memory usage: Alloc=%v, TotalAlloc=%v, Sys=%v", m.Alloc, m.TotalAlloc, m.Sys)
log.Println(bimg.VipsMemory())
bimg.VipsDebugInfo()
runtime.GC() // Trigger garbage collection
if err := pprof.WriteHeapProfile(f); err != nil {
log.Fatal("could not write memory profile: ", err)
}
})
log.Println("starting server")
http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)
}

View File

@ -1,264 +0,0 @@
package main
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"net/textproto"
"regexp"
"strings"
"time"
"github.com/h2non/bimg"
"github.com/samber/lo"
"github.com/titanous/json5"
)
const CALLBACK_REGEX string = ">AF_initDataCallback\\(({key: 'ds:1'.*?)\\);</script>"
type SegmentCoords struct {
x int
y int
w int
h int
}
type Segment struct {
coords SegmentCoords
text string
}
type ScanResult []Segment
// TODO coordinates are negative sometimes and I think they shouldn't be
func rationalizeCoordsFormat1(imageW float64, imageH float64, centerXFraction float64, centerYFraction float64, widthFraction float64, heightFraction float64) SegmentCoords {
return SegmentCoords{
x: int(math.Round((centerXFraction - widthFraction/2) * imageW)),
y: int(math.Round((centerYFraction - heightFraction/2) * imageH)),
w: int(math.Round(widthFraction * imageW)),
h: int(math.Round(heightFraction * imageH)),
}
}
func scanImageChunk(image []byte, imageWidth int, imageHeight int) (ScanResult, error) {
var result ScanResult
timestamp := time.Now().UnixMicro()
var b bytes.Buffer
w := multipart.NewWriter(&b)
defer w.Close()
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="encoded_image"; filename="ocr%d.png"`, timestamp))
h.Set("Content-Type", "image/png")
fw, err := w.CreatePart(h)
if err != nil {
return result, err
}
fw.Write(image)
w.Close()
req, err := http.NewRequest("POST", fmt.Sprintf("https://lens.google.com/v3/upload?stcs=%d", timestamp), &b)
if err != nil {
return result, err
}
req.Header.Add("User-Agent", "Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36")
req.AddCookie(&http.Cookie{
Name: "SOCS",
Value: "CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg",
})
req.Header.Set("Content-Type", w.FormDataContentType())
client := http.Client{}
res, err := client.Do(req)
if err != nil {
return result, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return result, err
}
re, _ := regexp.Compile(CALLBACK_REGEX)
matches := re.FindStringSubmatch(string(body[:]))
if len(matches) == 0 {
return result, fmt.Errorf("invalid API response")
}
match := matches[1]
var lensObject map[string]interface{}
err = json5.Unmarshal([]byte(match), &lensObject)
if err != nil {
return result, err
}
if _, ok := lensObject["errorHasStatus"]; ok {
return result, errors.New("lens failed")
}
root := lensObject["data"].([]interface{})
var textSegments []string
var textRegions []SegmentCoords
// I don't know why Google did this.
// Text segments are in one place and their locations are in another, using a very strange coordinate system.
// At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
// TODO: on a few images, this seems to not work for some reason.
defer func() {
if r := recover(); r != nil {
// https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
// In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
textSegments = []string{}
textRegions = []SegmentCoords{}
}
}()
textSegmentsRaw := root[3].([]interface{})[4].([]interface{})[0].([]interface{})[0].([]interface{})
textRegionsRaw := root[2].([]interface{})[3].([]interface{})[0].([]interface{})
for _, x := range textRegionsRaw {
if strings.HasPrefix(x.([]interface{})[11].(string), "text:") {
rawCoords := x.([]interface{})[1].([]interface{})
coords := rationalizeCoordsFormat1(float64(imageWidth), float64(imageHeight), rawCoords[0].(float64), rawCoords[1].(float64), rawCoords[2].(float64), rawCoords[3].(float64))
textRegions = append(textRegions, coords)
}
}
for _, x := range textSegmentsRaw {
textSegment := x.(string)
textSegments = append(textSegments, textSegment)
}
return lo.Map(lo.Zip2(textSegments, textRegions), func(x lo.Tuple2[string, SegmentCoords], _ int) Segment {
return Segment{
text: x.A,
coords: x.B,
}
}), nil
}
const MAX_DIM int = 1024
func scanImage(image *bimg.Image) (ScanResult, error) {
result := ScanResult{}
metadata, err := image.Metadata()
if err != nil {
return result, err
}
width := metadata.Size.Width
height := metadata.Size.Height
if width > MAX_DIM {
width = MAX_DIM
height = int(math.Round(float64(height) * (float64(width) / float64(metadata.Size.Width))))
}
for y := 0; y < height; y += MAX_DIM {
chunkHeight := MAX_DIM
if y+chunkHeight > height {
chunkHeight = height - y
}
chunk, err := image.Process(bimg.Options{
Height: height, // these are for overall image dimensions (resize then crop)
Width: width,
Top: y,
AreaHeight: chunkHeight,
AreaWidth: width,
Crop: true,
Type: bimg.PNG,
})
if err != nil {
return result, err
}
res, err := scanImageChunk(chunk, width, chunkHeight)
if err != nil {
return result, err
}
for _, segment := range res {
result = append(result, Segment{
text: segment.text,
coords: SegmentCoords{
y: segment.coords.y + y,
x: segment.coords.x,
w: segment.coords.w,
h: segment.coords.h,
},
})
}
}
return result, nil
}
/*
async def scan_image_chunk(sess, image):
# send data to inscrutable undocumented Google service
# https://github.com/AuroraWright/owocr/blob/master/owocr/ocr.py#L193
async with aiohttp.ClientSession() as sess:
data = aiohttp.FormData()
data.add_field(
"encoded_image",
encode_img(image),
filename="ocr" + str(timestamp) + ".png",
content_type="image/png"
)
async with sess.post(url, headers=headers, cookies=cookies, data=data, timeout=10) as res:
body = await res.text()
# I really worry about Google sometimes. This is not a sensible format.
match = CALLBACK_REGEX.search(body)
if match == None:
raise ValueError("Invalid callback")
lens_object = pyjson5.loads(match.group(1))
if "errorHasStatus" in lens_object:
raise RuntimeError("Lens failed")
text_segments = []
text_regions = []
root = lens_object["data"]
# I don't know why Google did this.
# Text segments are in one place and their locations are in another, using a very strange coordinate system.
# At least I don't need whatever is contained in the base64 partss (which I assume are protobufs).
# TODO: on a few images, this seems to not work for some reason.
try:
text_segments = root[3][4][0][0]
text_regions = [ rationalize_coords_format1(image.width, image.height, *x[1]) for x in root[2][3][0] if x[11].startswith("text:") ]
except (KeyError, IndexError):
# https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
# In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
return [], []
return text_segments, text_regions
MAX_SCAN_DIM = 1000 # not actually true but close enough
def chunk_image(image: Image):
chunks = []
# Cut image down in X axis (I'm assuming images aren't too wide to scan in downscaled form because merging text horizontally would be annoying)
if image.width > MAX_SCAN_DIM:
image = image.resize((MAX_SCAN_DIM, round(image.height * (image.width / MAX_SCAN_DIM))), Image.LANCZOS)
for y in range(0, image.height, MAX_SCAN_DIM):
chunks.append(image.crop((0, y, image.width, min(y + MAX_SCAN_DIM, image.height))))
return chunks
async def scan_chunks(sess: aiohttp.ClientSession, chunks: [Image]):
# If text happens to be split across the cut line it won't get read.
# This is because doing overlap read areas would be really annoying.
text = ""
regions = []
for chunk in chunks:
new_segments, new_regions = await scan_image_chunk(sess, chunk)
for segment in new_segments:
text += segment + "\n"
for i, (segment, region) in enumerate(zip(new_segments, new_regions)):
regions.append({ **region, "y": region["y"] + (MAX_SCAN_DIM * i), "text": segment })
return text, regions
async def scan_image(sess: aiohttp.ClientSession, image: Image):
return await scan_chunks(sess, chunk_image(image))
if __name__ == "__main__":
async def main():
async with aiohttp.ClientSession() as sess:
print(await scan_image(sess, Image.open("/data/public/memes-or-something/linear-algebra-chess.png")))
asyncio.run(main())
*/

View File

@ -1,891 +0,0 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"hash/fnv"
"io"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"strings"
"sync"
"time"
"github.com/DataIntelligenceCrew/go-faiss"
"github.com/davidbyttow/govips/v2/vips"
"github.com/h2non/bimg"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"github.com/samber/lo"
"github.com/vmihailenco/msgpack"
"github.com/x448/float16"
"golang.org/x/sync/errgroup"
)
type Config struct {
ClipServer string `json:"clip_server"`
DbPath string `json:"db_path"`
Port int16 `json:"port"`
Files string `json:"files"`
EnableOCR bool `json:"enable_ocr"`
ThumbsPath string `json:"thumbs_path"`
EnableThumbnails bool `json:"enable_thumbs"`
}
type Index struct {
vectors *faiss.IndexImpl
filenames []string
formatCodes []int64
formatNames []string
}
var schema = `
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
embedding_time INTEGER,
ocr_time INTEGER,
thumbnail_time INTEGER,
embedding BLOB,
ocr TEXT,
raw_ocr_segments BLOB,
thumbnails BLOB
);
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
filename,
ocr,
tokenize='unicode61 remove_diacritics 2',
content='ocr'
);
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER UPDATE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
`
type FileRecord struct {
Filename string `db:"filename"`
EmbedTime int64 `db:"embedding_time"`
OcrTime int64 `db:"ocr_time"`
ThumbnailTime int64 `db:"thumbnail_time"`
Embedding []byte `db:"embedding"`
Ocr string `db:"ocr"`
RawOcrSegments []byte `db:"raw_ocr_segments"`
Thumbnails []byte `db:"thumbnails"`
filesize int64
}
type InferenceServerConfig struct {
BatchSize uint `msgpack:"batch"`
ImageSize []uint `msgpack:"image_size"`
EmbeddingSize uint `msgpack:"embedding_size"`
}
func decodeMsgpackFrom[O interface{}](resp *http.Response) (O, error) {
var result O
respData, err := io.ReadAll(resp.Body)
if err != nil {
return result, err
}
err = msgpack.Unmarshal(respData, &result)
return result, err
}
func queryClipServer[I interface{}, O interface{}](config Config, path string, data I) (O, error) {
var result O
b, err := msgpack.Marshal(data)
if err != nil {
return result, err
}
resp, err := http.Post(config.ClipServer+path, "application/msgpack", bytes.NewReader(b))
if err != nil {
return result, err
}
defer resp.Body.Close()
return decodeMsgpackFrom[O](resp)
}
type LoadedImage struct {
image *vips.ImageRef
filename string
originalSize int
}
type EmbeddingInput struct {
image []byte
filename string
}
type EmbeddingRequest struct {
Images [][]byte `msgpack:"images"`
Text []string `msgpack:"text"`
}
type EmbeddingResponse = [][]byte
func timestamp() int64 {
return time.Now().UnixMicro()
}
type ImageFormatConfig struct {
targetWidth int
targetFilesize int
quality int
format vips.ImageType
extension string
}
func generateFilenameHash(filename string) string {
hasher := fnv.New128()
hasher.Write([]byte(filename))
hash := hasher.Sum(make([]byte, 0))
return base64.RawURLEncoding.EncodeToString(hash)
}
func generateThumbnailFilename(filename string, formatName string, formatConfig ImageFormatConfig) string {
return fmt.Sprintf("%s%s.%s", generateFilenameHash(filename), formatName, formatConfig.extension)
}
func initializeDatabase(config Config) (*sqlx.DB, error) {
db, err := sqlx.Connect("sqlite3", config.DbPath)
if err != nil {
return nil, err
}
_, err = db.Exec("PRAGMA busy_timeout = 2000; PRAGMA journal_mode = WAL")
if err != nil {
return nil, err
}
return db, nil
}
func imageFormats(config Config) map[string]ImageFormatConfig {
return map[string]ImageFormatConfig{
"jpegl": {
targetWidth: 800,
quality: 70,
format: vips.ImageTypeJPEG,
extension: "jpg",
},
"jpegh": {
targetWidth: 1600,
quality: 80,
format: vips.ImageTypeJPEG,
extension: "jpg",
},
"jpeg256kb": {
targetWidth: 500,
targetFilesize: 256000,
format: vips.ImageTypeJPEG,
extension: "jpg",
},
"avifh": {
targetWidth: 1600,
quality: 80,
format: vips.ImageTypeAVIF,
extension: "avif",
},
"avifl": {
targetWidth: 800,
quality: 30,
format: vips.ImageTypeAVIF,
extension: "avif",
},
}
}
func ingestFiles(config Config, backend InferenceServerConfig) error {
var wg errgroup.Group
var iwg errgroup.Group
// We assume everything is either a modern browser (low-DPI or high-DPI), an ancient browser or a ComputerCraft machine abusing Extra Utilities 2 screens.
var formats = imageFormats(config)
db, err := initializeDatabase(config)
if err != nil {
return err
}
defer db.Close()
toProcess := make(chan FileRecord, 100)
toEmbed := make(chan EmbeddingInput, backend.BatchSize)
toThumbnail := make(chan LoadedImage, 30)
toOCR := make(chan LoadedImage, 30)
embedBatches := make(chan []EmbeddingInput, 1)
// image loading and preliminary resizing
for range runtime.NumCPU() {
iwg.Go(func() error {
for record := range toProcess {
path := filepath.Join(config.Files, record.Filename)
img, err := vips.LoadImageFromFile(path, &vips.ImportParams{})
if err != nil {
log.Println("could not read", record.Filename)
continue
}
if record.Embedding == nil {
i, err := img.Copy() // TODO this is ugly, we should not need to do in-place operations
if err != nil {
return err
}
err = i.ResizeWithVScale(float64(backend.ImageSize[0])/float64(i.Width()), float64(backend.ImageSize[1])/float64(i.Height()), vips.KernelLanczos3)
if err != nil {
return err
}
resized, _, err := i.ExportPng(vips.NewPngExportParams())
if err != nil {
log.Println("resize failure", record.Filename, err)
} else {
toEmbed <- EmbeddingInput{
image: resized,
filename: record.Filename,
}
}
}
if record.Thumbnails == nil && config.EnableThumbnails {
toThumbnail <- LoadedImage{
image: img,
filename: record.Filename,
originalSize: int(record.filesize),
}
}
if record.RawOcrSegments == nil && config.EnableOCR {
toOCR <- LoadedImage{
image: img,
filename: record.Filename,
}
}
}
return nil
})
}
if config.EnableThumbnails {
for range runtime.NumCPU() {
wg.Go(func() error {
for image := range toThumbnail {
generatedFormats := make([]string, 0)
for formatName, formatConfig := range formats {
var err error
var resized []byte
if formatConfig.targetFilesize != 0 {
lb := 1
ub := 100
for {
quality := (lb + ub) / 2
i, err := image.image.Copy()
if err != nil {
return err
}
i.Resize(float64(formatConfig.targetWidth)/float64(i.Width()), vips.KernelLanczos3)
resized, _, err = i.Export(&vips.ExportParams{
Format: formatConfig.format,
Speed: 4,
Quality: quality,
StripMetadata: true,
})
if len(resized) > image.originalSize {
ub = quality
} else {
lb = quality + 1
}
if lb >= ub {
break
}
}
} else {
i, err := image.image.Copy()
if err != nil {
return err
}
i.Resize(float64(formatConfig.targetWidth)/float64(i.Width()), vips.KernelLanczos3)
resized, _, err = i.Export(&vips.ExportParams{
Format: formatConfig.format,
Speed: 4,
Quality: formatConfig.quality,
StripMetadata: true,
})
}
if err != nil {
log.Println("thumbnailing failure", image.filename, err)
continue
}
if len(resized) < image.originalSize {
generatedFormats = append(generatedFormats, formatName)
err = bimg.Write(filepath.Join(config.ThumbsPath, generateThumbnailFilename(image.filename, formatName, formatConfig)), resized)
if err != nil {
return err
}
}
}
formatsData, err := msgpack.Marshal(generatedFormats)
if err != nil {
return err
}
_, err = db.Exec("UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?", formatsData, timestamp(), image.filename)
if err != nil {
return err
}
}
return nil
})
}
}
if config.EnableOCR {
for range 100 {
wg.Go(func() error {
for image := range toOCR {
scan, err := scanImage(image.image)
if err != nil {
log.Println("OCR failure", image.filename, err)
continue
}
ocrText := ""
for _, segment := range scan {
ocrText += segment.text
ocrText += "\n"
}
ocrData, err := msgpack.Marshal(scan)
if err != nil {
return err
}
_, err = db.Exec("UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?", ocrText, ocrData, timestamp(), image.filename)
if err != nil {
return err
}
}
return nil
})
}
}
wg.Go(func() error {
buffer := make([]EmbeddingInput, 0, backend.BatchSize)
for input := range toEmbed {
buffer = append(buffer, input)
if len(buffer) == int(backend.BatchSize) {
embedBatches <- buffer
buffer = make([]EmbeddingInput, 0, backend.BatchSize)
}
}
if len(buffer) > 0 {
embedBatches <- buffer
}
close(embedBatches)
return nil
})
for range 3 {
wg.Go(func() error {
for batch := range embedBatches {
result, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "", EmbeddingRequest{
Images: lo.Map(batch, func(item EmbeddingInput, _ int) []byte { return item.image }),
})
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
for i, vector := range result {
_, err = tx.Exec("UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?", timestamp(), vector, batch[i].filename)
if err != nil {
return err
}
}
err = tx.Commit()
if err != nil {
return err
}
}
return nil
})
}
filenamesOnDisk := make(map[string]struct{})
err = filepath.WalkDir(config.Files, func(path string, d os.DirEntry, err error) error {
filename := strings.TrimPrefix(path, config.Files)
if err != nil {
return err
}
if d.IsDir() {
return nil
}
filenamesOnDisk[filename] = struct{}{}
records := []FileRecord{}
err = db.Select(&records, "SELECT * FROM files WHERE filename = ?", filename)
if err != nil {
return err
}
stat, err := d.Info()
if err != nil {
return err
}
modtime := stat.ModTime().UnixMicro()
if len(records) == 0 || modtime > records[0].EmbedTime || modtime > records[0].OcrTime || modtime > records[0].ThumbnailTime {
_, err = db.Exec("INSERT OR IGNORE INTO files VALUES (?, 0, 0, 0, '', '', '', '')", filename)
if err != nil {
return err
}
record := FileRecord{
Filename: filename,
filesize: stat.Size(),
}
if len(records) > 0 {
record = records[0]
}
if modtime > record.EmbedTime || len(record.Embedding) == 0 {
record.Embedding = nil
}
if modtime > record.OcrTime || len(record.RawOcrSegments) == 0 {
record.RawOcrSegments = nil
}
if modtime > record.ThumbnailTime || len(record.Thumbnails) == 0 {
record.Thumbnails = nil
}
toProcess <- record
}
return nil
})
if err != nil {
return err
}
close(toProcess)
err = iwg.Wait()
close(toEmbed)
close(toThumbnail)
if err != nil {
return err
}
err = wg.Wait()
if err != nil {
return err
}
rows, err := db.Queryx("SELECT filename FROM files")
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
for rows.Next() {
var filename string
err := rows.Scan(&filename)
if err != nil {
return err
}
if _, ok := filenamesOnDisk[filename]; !ok {
_, err = tx.Exec("DELETE FROM files WHERE filename = ?", filename)
if err != nil {
return err
}
}
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
const INDEX_ADD_BATCH = 512
func buildIndex(config Config, backend InferenceServerConfig) (Index, error) {
var index Index
db, err := initializeDatabase(config)
if err != nil {
return index, err
}
defer db.Close()
newFAISSIndex, err := faiss.IndexFactory(int(backend.EmbeddingSize), "SQfp16", faiss.MetricInnerProduct)
if err != nil {
return index, err
}
index.vectors = newFAISSIndex
var count int
err = db.Get(&count, "SELECT COUNT(*) FROM files")
if err != nil {
return index, err
}
index.filenames = make([]string, 0, count)
index.formatCodes = make([]int64, 0, count)
buffer := make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
index.formatNames = make([]string, 0, 5)
record := FileRecord{}
rows, err := db.Queryx("SELECT * FROM files")
if err != nil {
return index, err
}
for rows.Next() {
err := rows.StructScan(&record)
if err != nil {
return index, err
}
if len(record.Embedding) > 0 {
index.filenames = append(index.filenames, record.Filename)
for i := 0; i < len(record.Embedding); i += 2 {
buffer = append(buffer, float16.Frombits(uint16(record.Embedding[i])+uint16(record.Embedding[i+1])<<8).Float32())
}
if len(buffer) == cap(buffer) {
index.vectors.Add(buffer)
buffer = make([]float32, 0, INDEX_ADD_BATCH*backend.EmbeddingSize)
}
formats := make([]string, 0, 5)
if len(record.Thumbnails) > 0 {
err := msgpack.Unmarshal(record.Thumbnails, &formats)
if err != nil {
return index, err
}
}
formatCode := int64(0)
for _, formatString := range formats {
found := false
for i, name := range index.formatNames {
if name == formatString {
formatCode |= 1 << i
found = true
break
}
}
if !found {
newIndex := len(index.formatNames)
formatCode |= 1 << newIndex
index.formatNames = append(index.formatNames, formatString)
}
}
index.formatCodes = append(index.formatCodes, formatCode)
}
}
if len(buffer) > 0 {
index.vectors.Add(buffer)
}
return index, nil
}
func decodeFP16Buffer(buf []byte) []float32 {
out := make([]float32, 0, len(buf)/2)
for i := 0; i < len(buf); i += 2 {
out = append(out, float16.Frombits(uint16(buf[i])+uint16(buf[i+1])<<8).Float32())
}
return out
}
type EmbeddingVector []float32
type QueryResult struct {
Matches [][]interface{} `json:"matches"`
Formats []string `json:"formats"`
Extensions map[string]string `json:"extensions"`
}
// this terrible language cannot express tagged unions
type QueryTerm struct {
Embedding *EmbeddingVector `json:"embedding"`
Image *string `json:"image"` // base64
Text *string `json:"text"`
Weight *float32 `json:"weight"`
}
type QueryRequest struct {
Terms []QueryTerm `json:"terms"`
K *int `json:"k"`
}
func queryIndex(index *Index, query EmbeddingVector, k int) (QueryResult, error) {
var qr QueryResult
distances, ids, err := index.vectors.Search(query, int64(k))
if err != nil {
return qr, err
}
items := lo.Map(lo.Zip2(distances, ids), func(x lo.Tuple2[float32, int64], i int) []interface{} {
return []interface{}{
x.A,
index.filenames[x.B],
generateFilenameHash(index.filenames[x.B]),
index.formatCodes[x.B],
}
})
return QueryResult{
Matches: items,
Formats: index.formatNames,
}, nil
}
func handleRequest(config Config, backendConfig InferenceServerConfig, index *Index, w http.ResponseWriter, req *http.Request) error {
if req.Body == nil {
io.WriteString(w, "OK") // health check
return nil
}
dec := json.NewDecoder(req.Body)
var qreq QueryRequest
err := dec.Decode(&qreq)
if err != nil {
return err
}
totalEmbedding := make(EmbeddingVector, backendConfig.EmbeddingSize)
imageBatch := make([][]byte, 0)
imageWeights := make([]float32, 0)
textBatch := make([]string, 0)
textWeights := make([]float32, 0)
for _, term := range qreq.Terms {
if term.Image != nil {
bytes, err := base64.StdEncoding.DecodeString(*term.Image)
if err != nil {
return err
}
loaded := bimg.NewImage(bytes)
resized, err := loaded.Process(bimg.Options{
Width: int(backendConfig.ImageSize[0]),
Height: int(backendConfig.ImageSize[1]),
Force: true,
Type: bimg.PNG,
Interpretation: bimg.InterpretationSRGB,
})
if err != nil {
return err
}
imageBatch = append(imageBatch, resized)
if term.Weight != nil {
imageWeights = append(imageWeights, *term.Weight)
} else {
imageWeights = append(imageWeights, 1)
}
}
if term.Text != nil {
textBatch = append(textBatch, *term.Text)
if term.Weight != nil {
textWeights = append(textWeights, *term.Weight)
} else {
textWeights = append(textWeights, 1)
}
}
if term.Embedding != nil {
weight := float32(1.0)
if term.Weight != nil {
weight = *term.Weight
}
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += (*term.Embedding)[i] * weight
}
}
}
if len(imageBatch) > 0 {
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
Images: imageBatch,
})
if err != nil {
return err
}
for j, emb := range embs {
embd := decodeFP16Buffer(emb)
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += embd[i] * imageWeights[j]
}
}
}
if len(textBatch) > 0 {
embs, err := queryClipServer[EmbeddingRequest, EmbeddingResponse](config, "/", EmbeddingRequest{
Text: textBatch,
})
if err != nil {
return err
}
for j, emb := range embs {
embd := decodeFP16Buffer(emb)
for i := 0; i < int(backendConfig.EmbeddingSize); i += 1 {
totalEmbedding[i] += embd[i] * textWeights[j]
}
}
}
k := 1000
if qreq.K != nil {
k = *qreq.K
}
w.Header().Add("Content-Type", "application/json")
enc := json.NewEncoder(w)
qres, err := queryIndex(index, totalEmbedding, k)
qres.Extensions = make(map[string]string)
for k, v := range imageFormats(config) {
qres.Extensions[k] = v.extension
}
if err != nil {
return err
}
err = enc.Encode(qres)
if err != nil {
return err
}
return nil
}
func init() {
os.Setenv("VIPS_WARNING", "FALSE") // this does not actually work
bimg.VipsCacheSetMax(0)
bimg.VipsCacheSetMaxMem(0)
}
func main() {
vips.Startup(&vips.Config{})
defer vips.Shutdown()
content, err := os.ReadFile(os.Args[1])
if err != nil {
log.Fatal("config file unreadable ", err)
}
var config Config
err = json.Unmarshal(content, &config)
if err != nil {
log.Fatal("config file wrong ", err)
}
fmt.Println(config)
db, err := sqlx.Connect("sqlite3", config.DbPath)
if err != nil {
log.Fatal("DB connection failure ", db)
}
db.MustExec(schema)
var backend InferenceServerConfig
for {
resp, err := http.Get(config.ClipServer + "/config")
if err != nil {
log.Println("backend failed (fetch) ", err)
}
backend, err = decodeMsgpackFrom[InferenceServerConfig](resp)
resp.Body.Close()
if err != nil {
log.Println("backend failed (parse) ", err)
} else {
break
}
time.Sleep(time.Second)
}
requestIngest := make(chan struct{}, 1)
var index *Index
// maybe this ought to be mutexed?
var lastError *error
// there's not a neat way to reusably broadcast to multiple channels, but I *can* abuse WaitGroups probably
// this might cause horrible concurrency issues, but you brought me to this point, Go designers
var wg sync.WaitGroup
go func() {
for {
wg.Add(1)
log.Println("ingest running")
err := ingestFiles(config, backend)
if err != nil {
log.Println("ingest failed ", err)
lastError = &err
} else {
newIndex, err := buildIndex(config, backend)
if err != nil {
log.Println("index build failed ", err)
lastError = &err
} else {
lastError = nil
index = &newIndex
}
}
wg.Done()
<-requestIngest
}
}()
newIndex, err := buildIndex(config, backend)
index = &newIndex
if err != nil {
log.Fatal("index build failed ", err)
}
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Headers", "Content-Type")
if req.Method == "OPTIONS" {
w.WriteHeader(204)
return
}
err := handleRequest(config, backend, index, w, req)
if err != nil {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]string{
"error": err.Error(),
})
}
})
http.HandleFunc("/reload", func(w http.ResponseWriter, req *http.Request) {
if req.Method == "POST" {
log.Println("requesting index reload")
select {
case requestIngest <- struct{}{}:
default:
}
wg.Wait()
if lastError == nil {
w.Write([]byte("OK"))
} else {
w.WriteHeader(500)
w.Write([]byte((*lastError).Error()))
}
}
})
http.HandleFunc("/profile", func(w http.ResponseWriter, req *http.Request) {
f, err := os.Create("mem.pprof")
if err != nil {
log.Fatal("could not create memory profile: ", err)
}
defer f.Close()
var m runtime.MemStats
runtime.ReadMemStats(&m)
log.Printf("Memory usage: Alloc=%v, TotalAlloc=%v, Sys=%v", m.Alloc, m.TotalAlloc, m.Sys)
log.Println(bimg.VipsMemory())
bimg.VipsDebugInfo()
runtime.GC() // Trigger garbage collection
if err := pprof.WriteHeapProfile(f); err != nil {
log.Fatal("could not write memory profile: ", err)
}
})
log.Println("starting server")
http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)
}

View File

@ -1,265 +0,0 @@
package main
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"net/textproto"
"regexp"
"strings"
"time"
"github.com/davidbyttow/govips/v2/vips"
"github.com/samber/lo"
"github.com/titanous/json5"
)
const CALLBACK_REGEX string = ">AF_initDataCallback\\(({key: 'ds:1'.*?)\\);</script>"
type SegmentCoords struct {
x int
y int
w int
h int
}
type Segment struct {
coords SegmentCoords
text string
}
type ScanResult []Segment
// TODO coordinates are negative sometimes and I think they shouldn't be
func rationalizeCoordsFormat1(imageW float64, imageH float64, centerXFraction float64, centerYFraction float64, widthFraction float64, heightFraction float64) SegmentCoords {
return SegmentCoords{
x: int(math.Round((centerXFraction - widthFraction/2) * imageW)),
y: int(math.Round((centerYFraction - heightFraction/2) * imageH)),
w: int(math.Round(widthFraction * imageW)),
h: int(math.Round(heightFraction * imageH)),
}
}
func scanImageChunk(image []byte, imageWidth int, imageHeight int) (ScanResult, error) {
var result ScanResult
timestamp := time.Now().UnixMicro()
var b bytes.Buffer
w := multipart.NewWriter(&b)
defer w.Close()
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="encoded_image"; filename="ocr%d.png"`, timestamp))
h.Set("Content-Type", "image/png")
fw, err := w.CreatePart(h)
if err != nil {
return result, err
}
fw.Write(image)
w.Close()
req, err := http.NewRequest("POST", fmt.Sprintf("https://lens.google.com/v3/upload?stcs=%d", timestamp), &b)
if err != nil {
return result, err
}
req.Header.Add("User-Agent", "Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36")
req.AddCookie(&http.Cookie{
Name: "SOCS",
Value: "CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg",
})
req.Header.Set("Content-Type", w.FormDataContentType())
client := http.Client{}
res, err := client.Do(req)
if err != nil {
return result, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return result, err
}
re, _ := regexp.Compile(CALLBACK_REGEX)
matches := re.FindStringSubmatch(string(body[:]))
if len(matches) == 0 {
return result, fmt.Errorf("invalid API response")
}
match := matches[1]
var lensObject map[string]interface{}
err = json5.Unmarshal([]byte(match), &lensObject)
if err != nil {
return result, err
}
if _, ok := lensObject["errorHasStatus"]; ok {
return result, errors.New("lens failed")
}
root := lensObject["data"].([]interface{})
var textSegments []string
var textRegions []SegmentCoords
// I don't know why Google did this.
// Text segments are in one place and their locations are in another, using a very strange coordinate system.
// At least I don't need whatever is contained in the base64 parts (which I assume are protobufs).
// TODO: on a few images, this seems to not work for some reason.
defer func() {
if r := recover(); r != nil {
// https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
// In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
textSegments = []string{}
textRegions = []SegmentCoords{}
}
}()
textSegmentsRaw := root[3].([]interface{})[4].([]interface{})[0].([]interface{})[0].([]interface{})
textRegionsRaw := root[2].([]interface{})[3].([]interface{})[0].([]interface{})
for _, x := range textRegionsRaw {
if strings.HasPrefix(x.([]interface{})[11].(string), "text:") {
rawCoords := x.([]interface{})[1].([]interface{})
coords := rationalizeCoordsFormat1(float64(imageWidth), float64(imageHeight), rawCoords[0].(float64), rawCoords[1].(float64), rawCoords[2].(float64), rawCoords[3].(float64))
textRegions = append(textRegions, coords)
}
}
for _, x := range textSegmentsRaw {
textSegment := x.(string)
textSegments = append(textSegments, textSegment)
}
return lo.Map(lo.Zip2(textSegments, textRegions), func(x lo.Tuple2[string, SegmentCoords], _ int) Segment {
return Segment{
text: x.A,
coords: x.B,
}
}), nil
}
const MAX_DIM int = 1024
func scanImage(image *vips.ImageRef) (ScanResult, error) {
result := ScanResult{}
width := image.Width()
height := image.Height()
if width > MAX_DIM {
width = MAX_DIM
height = int(math.Round(float64(height) * (float64(width) / float64(image.Width()))))
}
downscaled, err := image.Copy()
if err != nil {
return result, err
}
downscaled.Resize(float64(width)/float64(image.Width()), vips.KernelLanczos3)
for y := 0; y < height; y += MAX_DIM {
chunkHeight := MAX_DIM
if y+chunkHeight > height {
chunkHeight = height - y
}
chunk, err := image.Copy() // TODO this really really should not be in-place
if err != nil {
return result, err
}
err = chunk.ExtractArea(0, y, width, height)
if err != nil {
return result, err
}
buf, _, err := chunk.ExportPng(&vips.PngExportParams{})
if err != nil {
return result, err
}
res, err := scanImageChunk(buf, width, chunkHeight)
if err != nil {
return result, err
}
for _, segment := range res {
result = append(result, Segment{
text: segment.text,
coords: SegmentCoords{
y: segment.coords.y + y,
x: segment.coords.x,
w: segment.coords.w,
h: segment.coords.h,
},
})
}
}
return result, nil
}
/*
async def scan_image_chunk(sess, image):
# send data to inscrutable undocumented Google service
# https://github.com/AuroraWright/owocr/blob/master/owocr/ocr.py#L193
async with aiohttp.ClientSession() as sess:
data = aiohttp.FormData()
data.add_field(
"encoded_image",
encode_img(image),
filename="ocr" + str(timestamp) + ".png",
content_type="image/png"
)
async with sess.post(url, headers=headers, cookies=cookies, data=data, timeout=10) as res:
body = await res.text()
# I really worry about Google sometimes. This is not a sensible format.
match = CALLBACK_REGEX.search(body)
if match == None:
raise ValueError("Invalid callback")
lens_object = pyjson5.loads(match.group(1))
if "errorHasStatus" in lens_object:
raise RuntimeError("Lens failed")
text_segments = []
text_regions = []
root = lens_object["data"]
# I don't know why Google did this.
# Text segments are in one place and their locations are in another, using a very strange coordinate system.
# At least I don't need whatever is contained in the base64 partss (which I assume are protobufs).
# TODO: on a few images, this seems to not work for some reason.
try:
text_segments = root[3][4][0][0]
text_regions = [ rationalize_coords_format1(image.width, image.height, *x[1]) for x in root[2][3][0] if x[11].startswith("text:") ]
except (KeyError, IndexError):
# https://github.com/dimdenGD/chrome-lens-ocr/blob/main/src/core.js#L316 has code for a fallback text segment read mode.
# In testing, this proved unnecessary (quirks of the HTTP request? I don't know), and this only happens on textless images.
return [], []
return text_segments, text_regions
MAX_SCAN_DIM = 1000 # not actually true but close enough
def chunk_image(image: Image):
chunks = []
# Cut image down in X axis (I'm assuming images aren't too wide to scan in downscaled form because merging text horizontally would be annoying)
if image.width > MAX_SCAN_DIM:
image = image.resize((MAX_SCAN_DIM, round(image.height * (image.width / MAX_SCAN_DIM))), Image.LANCZOS)
for y in range(0, image.height, MAX_SCAN_DIM):
chunks.append(image.crop((0, y, image.width, min(y + MAX_SCAN_DIM, image.height))))
return chunks
async def scan_chunks(sess: aiohttp.ClientSession, chunks: [Image]):
# If text happens to be split across the cut line it won't get read.
# This is because doing overlap read areas would be really annoying.
text = ""
regions = []
for chunk in chunks:
new_segments, new_regions = await scan_image_chunk(sess, chunk)
for segment in new_segments:
text += segment + "\n"
for i, (segment, region) in enumerate(zip(new_segments, new_regions)):
regions.append({ **region, "y": region["y"] + (MAX_SCAN_DIM * i), "text": segment })
return text, regions
async def scan_image(sess: aiohttp.ClientSession, image: Image):
return await scan_chunks(sess, chunk_image(image))
if __name__ == "__main__":
async def main():
async with aiohttp.ClientSession() as sess:
print(await scan_image(sess, Image.open("/data/public/memes-or-something/linear-algebra-chess.png")))
asyncio.run(main())
*/

View File

@ -1,212 +0,0 @@
from aiohttp import web
import aiohttp
import asyncio
import traceback
import umsgpack
from PIL import Image
import base64
import aiosqlite
import faiss
import numpy
import os
import aiohttp_cors
import json
import io
import sys
from concurrent.futures import ProcessPoolExecutor
with open(sys.argv[1], "r") as config_file:
CONFIG = json.load(config_file)
app = web.Application(client_max_size=32*1024**2)
routes = web.RouteTableDef()
async def clip_server(query, unpack_buffer=True):
async with aiohttp.ClientSession() as sess:
async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
response = umsgpack.loads(await res.read())
if res.status == 200:
if unpack_buffer:
response = [ numpy.frombuffer(x, dtype="float16") for x in response ]
return response
else:
raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))
@routes.post("/")
async def run_query(request):
data = await request.json()
embeddings = []
if images := data.get("images", []):
embeddings.extend(await clip_server({ "images": [ base64.b64decode(x) for x, w in images ] }))
if text := data.get("text", []):
embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
weights = [ w for x, w in images ] + [ w for x, w in text ]
embeddings = [ e * w for e, w in zip(embeddings, weights) ]
if not embeddings:
return web.json_response([])
return web.json_response(app["index"].search(sum(embeddings)))
@routes.get("/")
async def health_check(request):
return web.Response(text="OK")
@routes.post("/reload_index")
async def reload_index_route(request):
await request.app["index"].reload()
return web.json_response(True)
def load_image(path, image_size):
im = Image.open(path)
im.draft("RGB", image_size)
buf = io.BytesIO()
im.resize(image_size).convert("RGB").save(buf, format="BMP")
return buf.getvalue(), path
class Index:
def __init__(self, inference_server_config):
self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
self.associated_filenames = []
self.inference_server_config = inference_server_config
self.lock = asyncio.Lock()
def search(self, query):
distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
distances = distances[0]
indices = indices[0]
try:
indices = indices[:numpy.where(indices==-1)[0][0]]
except IndexError: pass
return [ { "score": float(distance), "file": self.associated_filenames[index] } for index, distance in zip(indices, distances) ]
async def reload(self):
async with self.lock:
with ProcessPoolExecutor(max_workers=12) as executor:
print("Indexing")
conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
conn.row_factory = aiosqlite.Row
await conn.executescript("""
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
modtime REAL NOT NULL,
embedding_vector BLOB NOT NULL
);
""")
try:
async with asyncio.TaskGroup() as tg:
batch_sem = asyncio.Semaphore(32)
modified = set()
async def do_batch(batch):
try:
query = { "images": [ arg[2] for arg in batch ] }
embeddings = await clip_server(query, False)
await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
(filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
])
await conn.commit()
for filename, _, _ in batch:
modified.add(filename)
sys.stdout.write(".")
finally:
batch_sem.release()
async def dispatch_batch(batch):
await batch_sem.acquire()
tg.create_task(do_batch(batch))
files = {}
for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
files[filename] = modtime
await conn.commit()
batch = []
for dirpath, _, filenames in os.walk(CONFIG["files"]):
paths = []
for file in filenames:
path = os.path.join(dirpath, file)
file = os.path.relpath(path, CONFIG["files"])
st = os.stat(path)
if st.st_mtime != files.get(file):
paths.append(path)
for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_image, path, self.inference_server_config["image_size"]) for path in paths ]):
try:
b, path = await task
st = os.stat(path)
file = os.path.relpath(path, CONFIG["files"])
except Exception as e:
print(file, "failed", e)
continue
batch.append((file, st.st_mtime, b))
if len(batch) == self.inference_server_config["batch"]:
await dispatch_batch(batch)
batch = []
if batch:
await dispatch_batch(batch)
remove_indices = []
for index, filename in enumerate(self.associated_filenames):
if filename not in files or filename in modified:
remove_indices.append(index)
self.associated_filenames[index] = None
if filename not in files:
await conn.execute("DELETE FROM files WHERE filename = ?", (filename,))
await conn.commit()
# TODO concurrency
# TODO understand what that comment meant
if remove_indices:
self.faiss_index.remove_ids(numpy.array(remove_indices))
self.associated_filenames = [ x for x in self.associated_filenames if x is not None ]
filenames_set = set(self.associated_filenames)
new_data = []
new_filenames = []
async with conn.execute("SELECT * FROM files") as csr:
while row := await csr.fetchone():
filename, modtime, embedding_vector = row
if filename not in filenames_set:
new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
new_filenames.append(filename)
new_data = numpy.array(new_data)
self.associated_filenames.extend(new_filenames)
self.faiss_index.add(new_data)
finally:
await conn.close()
app.router.add_routes(routes)
cors = aiohttp_cors.setup(app, defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=False,
expose_headers="*",
allow_headers="*",
)
})
for route in list(app.router.routes()):
cors.add(route)
async def main():
while True:
async with aiohttp.ClientSession() as sess:
try:
async with await sess.get(CONFIG["clip_server"] + "config") as res:
inference_server_config = umsgpack.unpackb(await res.read())
print("Backend config:", inference_server_config)
break
except:
traceback.print_exc()
await asyncio.sleep(1)
index = Index(inference_server_config)
app["index"] = index
await index.reload()
print("Ready")
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "", CONFIG["port"])
await site.start()
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())
loop.run_forever()

View File

@ -1,19 +0,0 @@
import numpy
import xgboost as xgb
import shared
trains, validations = shared.fetch_ratings()
ranker = xgb.XGBRanker(
tree_method="hist",
lambdarank_num_pair_per_sample=8,
objective="rank:ndcg",
lambdarank_pair_method="topk",
device="cuda"
)
flat_samples = [ sample for trainss in trains for sample in trainss ]
X = numpy.concatenate([ numpy.stack((meme1, meme2)) for meme1, meme2, rating in flat_samples ])
Y = numpy.concatenate([ numpy.stack((int(rating), int(1 - rating))) for meme1, meme2, rating in flat_samples ])
qid = numpy.concatenate([ numpy.stack((i, i)) for i in range(len(flat_samples)) ])
ranker.fit(X, Y, qid=qid, verbose=True)