diff --git a/src/bioclip/__init__.py b/src/bioclip/__init__.py index 2c50324..1f54038 100644 --- a/src/bioclip/__init__.py +++ b/src/bioclip/__init__.py @@ -1,6 +1,16 @@ # SPDX-FileCopyrightText: 2024-present John Bradley # # SPDX-License-Identifier: MIT -from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier __all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier", "CustomLabelsBinningClassifier"] + +def __getattr__(name): + if name in __all__: + from bioclip.predict import ( + TreeOfLifeClassifier, + Rank, + CustomLabelsClassifier, + CustomLabelsBinningClassifier, + ) + return locals()[name] + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/bioclip/__main__.py b/src/bioclip/__main__.py index 899fef2..cde1c60 100644 --- a/src/bioclip/__main__.py +++ b/src/bioclip/__main__.py @@ -1,6 +1,3 @@ -from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier -from .predict import BIOCLIP_MODEL_STR -import open_clip as oc import os import json import sys @@ -8,6 +5,7 @@ import pandas as pd import argparse +DEFAULT_MODEL_STR = "hf-hub:imageomics/bioclip" def write_results(data, format, output): df = pd.DataFrame(data) @@ -45,10 +43,16 @@ def predict(image_file: list[str], format: str, output: str, cls_str: str, - rank: Rank, + rank: "Rank", bins_path: str, k: int, **kwargs): + from bioclip.predict import ( + TreeOfLifeClassifier, + CustomLabelsClassifier, + CustomLabelsBinningClassifier, + Rank + ) if cls_str: classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs) predictions = classifier.predict(images=image_file, k=k) @@ -65,6 +69,7 @@ def predict(image_file: list[str], def embed(image_file: list[str], output: str, **kwargs): + from bioclip.predict import TreeOfLifeClassifier classifier = TreeOfLifeClassifier(**kwargs) images_dict = {} data = { @@ -87,7 +92,7 @@ def create_parser(): device_arg = {'default':'cpu', 'help': 'device to use (cpu or cuda or mps), default: cpu'} output_arg = {'default': 'stdout', 'help': 'print output to file, default: stdout'} - model_arg = {'help': f'model identifier (see command list-models); default: {BIOCLIP_MODEL_STR}'} + model_arg = {'help': f'model identifier (see command list-models); default: {DEFAULT_MODEL_STR}'} pretrained_arg = {'help': 'pretrained model checkpoint as tag or file, depends on model; ' 'needed only if more than one is available (see command list-models)'} @@ -98,7 +103,7 @@ def create_parser(): predict_parser.add_argument('--output', **output_arg) cls_group = predict_parser.add_mutually_exclusive_group(required=False) cls_group.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'], - help='rank of the classification, default: species (when)') + help='rank of the classification, default: species') cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank and --bins arguments are not allowed." cls_group.add_argument('--cls', help=cls_help) cls_group.add_argument('--bins', help='path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls argument is not allowed.') @@ -123,7 +128,7 @@ def create_parser(): 'Note that this will only list models known to open_clip; ' 'any model identifier loadable by open_clip, such as from hf-hub, file, etc ' 'should also be usable for --model in the embed and predict commands. ' - f'(The default model {BIOCLIP_MODEL_STR} is one example.)') + f'(The default model {DEFAULT_MODEL_STR} is one example.)') list_parser.add_argument('--model', help='list available tags for pretrained model checkpoint(s) for specified model') return parser @@ -136,8 +141,12 @@ def parse_args(input_args=None): # tree of life class list mode if args.model or args.pretrained: raise ValueError("Custom model or checkpoints currently not supported for Tree-of-Life prediction") - if not args.rank: - args.rank = 'species' + + # Set default rank if not provided + args.rank = args.rank or 'species' + + from bioclip.predict import Rank + args.rank = Rank[args.rank.upper()] if not args.k: args.k = 5 @@ -174,6 +183,7 @@ def main(): model_str=args.model, pretrained_str=args.pretrained) elif args.command == 'list-models': + import open_clip as oc if args.model: for tag in oc.list_pretrained_tags_by_model(args.model): print(tag) diff --git a/tests/test_main.py b/tests/test_main.py index 5d8ac24..af55c59 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,11 +1,164 @@ import unittest -from unittest.mock import mock_open, patch +from unittest.mock import mock_open, patch, MagicMock import argparse import pandas as pd -from bioclip.__main__ import parse_args, Rank, create_classes_str, main, parse_bins_csv +from enum import Enum +from bioclip.predict import Rank +from bioclip.__main__ import parse_args, create_classes_str, main, parse_bins_csv class TestParser(unittest.TestCase): + + def test_parse_args_lazy_import(self): + """Test that Rank is only imported when needed""" + # Should not import Rank + with patch('bioclip.predict.Rank') as mock_rank: + args = parse_args(['embed', 'image.jpg']) + mock_rank.assert_not_called() + + # Should not import Rank when using --cls + with patch('bioclip.predict.Rank') as mock_rank: + args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2']) + mock_rank.assert_not_called() + + # Should not import Rank when using --bins + with patch('bioclip.predict.Rank') as mock_rank: + args = parse_args(['predict', 'image.jpg', '--bins', 'bins.csv']) + mock_rank.assert_not_called() + + # Should import Rank for tree-of-life prediction + with patch('bioclip.predict.Rank', Rank) as mock_rank: + args = parse_args(['predict', 'image.jpg']) + self.assertEqual(args.rank, Rank.SPECIES) + +def test_list_models_lazy_import(self): + """Test that open_clip is only imported when the list-models command is used""" + # Should import open_clip for list-models + with patch('bioclip.__main__.open_clip', create=True) as mock_oc: + mock_parse_args = MagicMock(return_value=argparse.Namespace( + command='list-models', + model=None + )) + with patch('bioclip.__main__.parse_args', mock_parse_args), \ + patch('builtins.print'): # prevent actual printing + main() + mock_oc.list_models.assert_called_once() + + # Should call list_pretrained_tags when model specified + with patch('bioclip.__main__.open_clip', create=True) as mock_oc: + mock_parse_args = MagicMock(return_value=argparse.Namespace( + command='list-models', + model='somemodel' + )) + with patch('bioclip.__main__.parse_args', mock_parse_args), \ + patch('builtins.print'): # prevent actual printing + main() + mock_oc.list_pretrained_tags_by_model.assert_called_once_with('somemodel') + + def test_predict_lazy_imports(self): + """Test that classifier classes are only imported when needed""" + # For cls_str path + with patch('bioclip.predict.TreeOfLifeClassifier') as mock_tree, \ + patch('bioclip.predict.CustomLabelsClassifier') as mock_custom, \ + patch('bioclip.predict.CustomLabelsBinningClassifier') as mock_binning: + mock_parse_args = MagicMock(return_value=argparse.Namespace( + command='predict', + image_file=['image.jpg'], + format='csv', + output='stdout', + cls='cat,dog', + bins=None, + device='cpu', + model=None, + pretrained=None, + k=5, + rank=None + )) + with patch('bioclip.__main__.parse_args', mock_parse_args): + with patch('bioclip.__main__.write_results'): # Prevent actual write + main() + mock_custom.assert_called() + mock_tree.assert_not_called() + mock_binning.assert_not_called() + + # For bins path + with patch('bioclip.predict.TreeOfLifeClassifier') as mock_tree, \ + patch('bioclip.predict.CustomLabelsClassifier') as mock_custom, \ + patch('bioclip.predict.CustomLabelsBinningClassifier') as mock_binning: + mock_parse_args = MagicMock(return_value=argparse.Namespace( + command='predict', + image_file=['image.jpg'], + format='csv', + output='stdout', + cls=None, + bins='bins.csv', + device='cpu', + model=None, + pretrained=None, + k=5, + rank=None + )) + with patch('bioclip.__main__.parse_args', mock_parse_args), \ + patch('bioclip.__main__.parse_bins_csv', return_value={}), \ + patch('bioclip.__main__.write_results'): + main() + mock_binning.assert_called() + mock_tree.assert_not_called() + mock_custom.assert_not_called() + + # For default (TreeOfLifeClassifier) path + with patch('bioclip.predict.TreeOfLifeClassifier') as mock_tree, \ + patch('bioclip.predict.CustomLabelsClassifier') as mock_custom, \ + patch('bioclip.predict.CustomLabelsBinningClassifier') as mock_binning: + mock_parse_args = MagicMock(return_value=argparse.Namespace( + command='predict', + image_file=['image.jpg'], + format='csv', + output='stdout', + cls=None, + bins=None, + device='cpu', + model=None, + pretrained=None, + k=5, + rank=Rank.SPECIES + )) + with patch('bioclip.__main__.parse_args', mock_parse_args), \ + patch('bioclip.__main__.write_results'): + main() + mock_tree.assert_called() + mock_custom.assert_not_called() + mock_binning.assert_not_called() + + def test_embed_lazy_imports(self): + """Test that TreeOfLifeClassifier is only imported for embed command""" + class MockTensor: + def tolist(self): + return [1.0, 2.0, 3.0] + + with patch('bioclip.predict.TreeOfLifeClassifier') as mock_clf: + # Mock the classifier instance + mock_clf_instance = MagicMock() + mock_clf.return_value = mock_clf_instance + + # Make create_image_features_for_image return our mock tensor + mock_clf_instance.create_image_features_for_image.return_value = MockTensor() + mock_clf_instance.model_str = "test-model" + + mock_parse_args = MagicMock(return_value=argparse.Namespace( + command='embed', + image_file=['image.jpg'], + output='stdout', + device='cpu', + model=None, + pretrained=None + )) + with patch('bioclip.__main__.parse_args', mock_parse_args), \ + patch('builtins.print'): # prevent actual printing to stdout + main() + mock_clf.assert_called_once() + mock_clf_instance.create_image_features_for_image.assert_called_once() + def test_parse_args(self): args = parse_args(['predict', 'image.jpg'])