mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-04-06 18:56:57 +00:00
model path argument
This commit is contained in:
parent
ec27deddbf
commit
435a9812dc
@ -20,7 +20,7 @@ with open(sys.argv[1], "r") as config_file:
|
||||
CONFIG = json.load(config_file)
|
||||
|
||||
device = torch.device(CONFIG["device"])
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=CONFIG.get("model_path", dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16"))
|
||||
model.eval()
|
||||
tokenizer = open_clip.get_tokenizer(CONFIG["model"])
|
||||
print("Model loaded")
|
||||
|
Loading…
x
Reference in New Issue
Block a user