
Add option to save cached embeddings to disk for re-use #17

Open
thompsonmj opened this issue Jul 9, 2024 · 1 comment

Comments

@thompsonmj
Contributor

To follow up on #16 and the discussion in bioclip: add an option to save the text embeddings cache to disk, and an option to load it from a specified file path instead of recomputing it when the same category list is reused.
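One way to structure the requested option is a compute-or-load cache keyed on the category list. The sketch below is a minimal, hypothetical helper (`load_or_compute_embeddings` and `embed_fn` are illustrative names, not bioclip API). It stores the labels next to the embeddings in a single `.npz` file and reuses the cache only when the stored labels match the requested ones exactly, in order.

```python
from pathlib import Path
from typing import Callable, Sequence

import numpy as np


def load_or_compute_embeddings(
    labels: Sequence[str],
    embed_fn: Callable[[Sequence[str]], np.ndarray],
    cache_path: str,
) -> np.ndarray:
    """Return embeddings for `labels`, reusing an on-disk cache when it matches.

    Hypothetical helper: `embed_fn` stands in for whatever computes the text
    embeddings (the expensive step the cache is meant to skip).
    """
    # np.savez always writes a .npz file, so normalise the suffix up front.
    path = Path(cache_path).with_suffix(".npz")
    if path.exists():
        with np.load(path) as cached:
            # Reuse the cache only if it was built for exactly these labels, in order.
            if cached["labels"].tolist() == list(labels):
                return cached["embeddings"]
    # Cache missing or built for a different category list: recompute and overwrite.
    embeddings = np.asarray(embed_fn(labels))
    np.savez(path, labels=np.array(labels), embeddings=embeddings)
    return embeddings
```

Keeping the labels in the same file as the embeddings means the two cannot drift apart, and a changed category list simply triggers a recompute.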

@johnbradley
Collaborator

This can be done with the current code, though not elegantly. The CustomLabelsClassifier class has two attributes that together hold the embeddings: classes, a list of class names, and txt_features, the matching text embeddings. Both would need to be saved and restored.

Rough Code

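Embed the text and save "classes.json" and "txt_features.npy":

import json
import numpy as np
from bioclip import CustomLabelsClassifier
# Computing the text embeddings is the expensive step we want to avoid repeating.
classifier = CustomLabelsClassifier(cls_ary=["dog", "cat", "fish"])
# Save the class names; their order must match the order of the embeddings.
with open("classes.json", "w") as outfile:
    json.dump(classifier.classes, outfile)
# txt_features is a torch tensor; move it to the CPU before converting to NumPy.
np.save("txt_features.npy", classifier.txt_features.cpu().numpy())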

Load "classes.json" and "txt_features.npy" and make a prediction:

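import json
import numpy as np
import torch
from bioclip import CustomLabelsClassifier
# Construct with a placeholder label; its embedding is overwritten below.
classifier = CustomLabelsClassifier(cls_ary=[""])
with open("classes.json", "r") as infile:
    classifier.classes = json.load(infile)
# Restore the embeddings as a tensor on the same device as the placeholder features.
classifier.txt_features = torch.from_numpy(np.load("txt_features.npy")).to(classifier.txt_features.device)

print(classifier.predict("Ursus-arctos.jpeg"))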
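The load step above trusts that "classes.json" and "txt_features.npy" were written together. As a minimal sketch, the pair can be checked for consistency before it is attached to a classifier. The helper name `load_label_embeddings` is hypothetical, not part of bioclip, and which axis of txt_features indexes the classes is an assumption: `class_axis` defaults to the last axis and should be set to match the actual array shape.

```python
import json
from typing import List, Tuple

import numpy as np


def load_label_embeddings(
    classes_path: str, features_path: str, class_axis: int = -1
) -> Tuple[List[str], np.ndarray]:
    """Load a saved (classes, features) pair and check that they line up.

    Hypothetical helper. `class_axis` is the axis of the feature array that
    indexes classes; the default (-1) is an assumption, not bioclip's layout.
    """
    with open(classes_path, "r") as infile:
        classes = json.load(infile)
    features = np.load(features_path)
    # A mismatch means the two files came from different runs or label lists.
    if features.shape[class_axis] != len(classes):
        raise ValueError(
            f"{len(classes)} class names but {features.shape[class_axis]} "
            f"embeddings along axis {class_axis}"
        )
    return classes, features
```

If the check passes, the returned values can be assigned to classifier.classes and classifier.txt_features (after converting the array to a tensor) in place of the unchecked loads.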

johnbradley added a commit that referenced this issue Nov 26, 2024
Adds method to CustomLabelsClassifier to save class labels and
embeddings to a npy file. Adds embeddings_path parameter to
CustomLabelsClassifier to load the labels and embeddings.

This change needs to wait for this PR due to a field name change:
#64

Fixes #17
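The commit above moves this into the library by adding a save method and an embeddings_path parameter. The sketch below is not bioclip's implementation. It is a hypothetical class (`CachedLabelEmbeddings`, `embed_fn`, and `save_embeddings` are illustrative names) showing one way a save method and an `embeddings_path` constructor parameter can share a single .npy file that holds both the labels and the embeddings.

```python
from typing import Callable, Optional, Sequence

import numpy as np


class CachedLabelEmbeddings:
    """Hypothetical stand-in for the save/load behaviour described in the commit."""

    def __init__(
        self,
        labels: Optional[Sequence[str]] = None,
        embed_fn: Optional[Callable[[Sequence[str]], np.ndarray]] = None,
        embeddings_path: Optional[str] = None,
    ):
        if embeddings_path is not None:
            # Restore labels and embeddings from disk instead of recomputing them.
            # allow_pickle=True is required because the file holds a pickled dict;
            # only load files from a trusted source.
            data = np.load(embeddings_path, allow_pickle=True).item()
            self.labels = list(data["labels"])
            self.embeddings = data["embeddings"]
        elif labels is not None and embed_fn is not None:
            # Compute the embeddings once from the label list.
            self.labels = list(labels)
            self.embeddings = np.asarray(embed_fn(self.labels))
        else:
            raise ValueError("pass either embeddings_path or labels with embed_fn")

    def save_embeddings(self, path: str) -> None:
        # np.save wraps the dict in a 0-d object array, so both pieces share one file.
        np.save(path, {"labels": self.labels, "embeddings": self.embeddings})
```

Storing a pickled dict keeps everything in one .npy file, at the cost of needing allow_pickle=True on load. If a .npz file is acceptable, np.savez with two named arrays avoids pickle entirely.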