diff --git a/clip_server.py b/clip_server.py index a228189..02b1ac2 100644 --- a/clip_server.py +++ b/clip_server.py @@ -20,7 +20,7 @@ with open(sys.argv[1], "r") as config_file: CONFIG = json.load(config_file) device = torch.device(CONFIG["device"]) -model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16") +model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=CONFIG.get("model_path", dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")) model.eval() tokenizer = open_clip.get_tokenizer(CONFIG["model"]) print("Model loaded")