Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default path to pyannote model with fallback option. #71

Merged
merged 4 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 52 additions & 34 deletions scraibe/diarisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from torch import Tensor
from torch import device as torch_device
from torch.cuda import is_available, current_device
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation')
Expand Down Expand Up @@ -183,9 +185,9 @@ def _save_token(token):

@classmethod
def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG,
model: str = PYANNOTE_DEFAULT_CONFIG,
use_auth_token: str = None,
cache_token: bool = True,
cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None,
device: str = None,
Expand All @@ -194,11 +196,12 @@ def load_model(cls,

"""
Loads a pretrained model from pyannote.audio,
either from a local cache or online repository.
either from a local cache or some online repository.

Args:
model: Path or identifier for the pyannote model.
default: /models/pyannote/speaker_diarization/config.yaml
default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
token: Optional HUGGINGFACE_TOKEN for authenticated access.
cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models.
Expand All @@ -210,33 +213,29 @@ def load_model(cls,
Returns:
mahenning marked this conversation as resolved.
Show resolved Hide resolved
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
"""


if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)

if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()

elif os.path.exists(model) and not use_auth_token:
if isinstance(model, str) and os.path.exists(model):
# check if model can be found locally nearby the config file
with open(model, 'r') as file:
config = yaml.safe_load(file)

path_to_model = config['pipeline']['params']['segmentation']

if not os.path.exists(path_to_model):
mahenning marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(f"Model not found at {path_to_model}. "\
"Trying to find it nearby the config file.")
warnings.warn(f"Model not found at {path_to_model}. "
"Trying to find it nearby the config file.")

pwd = model.split("/")[:-1]
pwd = "/".join(pwd)

path_to_model = os.path.join(pwd, "pytorch_model.bin")

if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
warnings.warn(
'Searching for nearby files in a folder path is '
'deprecated and will be removed in future versions.',
category=DeprecationWarning)
# list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
if len(bin_files) == 1:
Expand All @@ -245,30 +244,49 @@ def load_model(cls,
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
"or setup a huggingface token.")

raise FileNotFoundError

warnings.warn(f"Found model at {path_to_model} overwriting config file.")

config['pipeline']['params']['segmentation'] = path_to_model

with open(model, 'w') as file:
yaml.dump(config, file)

_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,
cache_dir = cache_dir,
hparams_file = hparams_file,)

# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"

_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict

elif isinstance(model, tuple):
JSchmie marked this conversation as resolved.
Show resolved Hide resolved
try:
_model = model[0]
HfApi().model_info(_model)
model = _model
use_auth_token = None
except RepositoryNotFoundError:
JSchmie marked this conversation as resolved.
Show resolved Hide resolved
print(f'{model[0]} not found on Huggingface, \
trying {model[1]}')
_model = model[1]
HfApi().model_info(_model)
model = _model
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)

if use_auth_token is None:
use_auth_token = cls._get_token()
else:
raise FileNotFoundError(f'No local model or directory found at {model}.')

_model = Pipeline.from_pretrained(model,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
hparams_file=hparams_file,)
if _model is None:
raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \
'or your local model path')


# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"

_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict

return cls(_model)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion scraibe/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else 'pyannote/speaker-diarization-3.1'
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
JSchmie marked this conversation as resolved.
Show resolved Hide resolved

def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file.
Expand Down
Loading