import collections.abc
import glob
import os
import random

# Image extensions searched when the caller does not pass ``pattern``.
# Kept as a tuple so the default argument of find_files is immutable.
_DEFAULT_PATTERNS = ("png", "jpg", "jpeg", "gif", "webp", "avif", "tiff")


def iterable(arg):
    """Return True if ``arg`` is iterable and is not a string.

    Strings are excluded so that a single pattern or an invalid ``limit``
    given as text is not mistaken for a collection of values.
    """
    return isinstance(arg, collections.abc.Iterable) and not isinstance(arg, str)


def _match_pattern(path, pattern, recursive):
    """Return all files below ``path`` whose extension is ``pattern``.

    Helper for find_files that handles a single extension.

    Args:
        path (str): Directory in which to search.
        pattern (str): File extension, written either as 'ext' or '.ext'.
        recursive (bool): If True, subdirectories are searched as well.

    Returns:
        list: The matching file names, each prefixed with ``path``.
    """
    extension = pattern[1:] if pattern.startswith(".") else pattern
    if recursive:
        # '**' matches zero or more directories, so files that sit directly
        # in ``path`` are found too.
        search_path = f"{path}/**/*.{extension}"
    else:
        search_path = f"{path}/*.{extension}"
    return glob.glob(search_path, recursive=recursive)


def _limit_results(results, limit):
    """Truncate or slice ``results`` according to ``limit``.

    Helper for find_files.

    Args:
        results (list): The file names to restrict.
        limit (int|iterable|None): None or -1 keeps everything, any other
            integer ``n >= 0`` keeps the first ``n`` entries, and an iterable
            with exactly two items ``(start, stop)`` keeps
            ``results[start:stop]``.

    Returns:
        list: The restricted list of file names.

    Raises:
        ValueError: If ``limit`` is an integer below -1, an iterable whose
            length is not 2, or any other unsupported value.
    """
    if limit is None:
        return results

    if isinstance(limit, int):
        if limit == -1:
            return results
        if limit < -1:
            raise ValueError("limit must be an integer greater than 0 or equal to -1")
        return results[:limit]

    if iterable(limit):
        # Materialise first so that iterables without len() or indexing
        # (generators, sets, ...) are validated instead of raising TypeError.
        bounds = tuple(limit)
        if len(bounds) == 2:
            start, stop = bounds
            return results[start:stop]

    raise ValueError(
        f"limit must be an integer or a tuple of length 2, but is {limit}"
    )


def find_files(
    path: str = None,
    pattern=_DEFAULT_PATTERNS,
    recursive: bool = True,
    limit=20,
    random_seed: int = None,
) -> list:
    """Find image files on the file system.

    Args:
        path (str, optional): The base directory where we are looking for the images.
            Defaults to None, which uses the XDG data directory if set or the
            current working directory otherwise.
        pattern (str|list, optional): The file extension(s) that the filename
            should match. Use either '.ext' or just 'ext'.
            Defaults to ["png", "jpg", "jpeg", "gif", "webp", "avif", "tiff"].
        recursive (bool, optional): Whether to recurse into subdirectories.
            Default is set to True.
        limit (int|list, optional): The maximum number of images to be found.
            Provide a list or tuple of length 2 to batch the images.
            Defaults to 20. To return all images, set to None or -1.
        random_seed (int, optional): The random seed to use for shuffling the
            images. If None is provided the data will not be shuffled.
            Defaults to None.

    Returns:
        list: A list with all filenames including the path.

    Raises:
        FileNotFoundError: If no file matches any of the patterns.
        ValueError: If ``limit`` is not None, an integer >= -1, or an
            iterable of length 2.
    """
    if path is None:
        path = os.environ.get("XDG_DATA_HOME", ".")

    patterns = [pattern] if isinstance(pattern, str) else list(pattern)

    matches = []
    for extension in patterns:
        matches.extend(_match_pattern(path, extension, recursive=recursive))

    # Overlapping patterns such as "png" and ".png" would otherwise list the
    # same file twice; dict.fromkeys removes repeats and keeps the order.
    results = list(dict.fromkeys(matches))

    if not results:
        raise FileNotFoundError(f"No files found in {path} with pattern '{patterns}'")

    if random_seed is not None:
        # glob returns files in an unspecified order, so sort first to make
        # the shuffle depend only on the seed. A private Random instance
        # leaves the state of the module-level generator untouched.
        results.sort()
        random.Random(random_seed).shuffle(results)

    return _limit_results(results, limit)