diff --git a/clip_server.py b/clip_server.py index c937a54..6a2ea7e 100644 --- a/clip_server.py +++ b/clip_server.py @@ -23,7 +23,7 @@ import numpy with open(sys.argv[1], "r") as config_file: CONFIG = json.load(config_file) -DEVICE = "cuda:0" +DEVICE = CONFIG["device"] # So400m/14@384 with init_empty_weights():