diff --git a/clip_server.py b/clip_server.py index bc3bc04..b974b2b 100644 --- a/clip_server.py +++ b/clip_server.py @@ -169,12 +169,16 @@ async def run_inference(request): print(results[1]) return web.Response(body=msgpack.dumps(body_data), status=status, content_type="application/msgpack") +image_size = [ t for t in preprocess.transforms if isinstance(t, transforms.Resize) ][0].size +if not isinstance(image_size, tuple): + image_size = (image_size, image_size) + @routes.get("/config") async def config(request): return web.Response(body=msgpack.dumps({ "model": CONFIG["model"], "batch": BS, - "image_size": [ t for t in preprocess.transforms if isinstance(t, transforms.Resize) ][0].size, + "image_size": image_size, "embedding_size": model.text.text_projection.out_features }), status=200, content_type="application/msgpack")