1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-04-06 18:56:57 +00:00

model path argument

This commit is contained in:
osmarks 2025-03-26 11:02:40 +00:00
parent ec27deddbf
commit 435a9812dc

View File

@ -20,7 +20,7 @@ with open(sys.argv[1], "r") as config_file:
CONFIG = json.load(config_file)
device = torch.device(CONFIG["device"])
model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")
model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=CONFIG.get("model_path", dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16"))
model.eval()
tokenizer = open_clip.get_tokenizer(CONFIG["model"])
print("Model loaded")