Skip to content

Commit

Permalink
Fixed cache default value, moved ValueError to the right place, added …
Browse files Browse the repository at this point in the history
…to docstring.
  • Loading branch information
mahenning committed Apr 23, 2024
1 parent 7d8da3b commit 55a77b8
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions scraibe/diarisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _save_token(token):
def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG,
use_auth_token: str = None,
cache_token: bool = True,
cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None,
device: str = None,
Expand All @@ -196,11 +196,12 @@ def load_model(cls,

"""
Loads a pretrained model from pyannote.audio,
either from a local cache or online repository.
either from a local cache or some online repository.
Args:
model: Path or identifier for the pyannote model.
default: /models/pyannote/speaker_diarization/config.yaml
default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
token: Optional HUGGINGFACE_TOKEN for authenticated access.
cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models.
Expand Down Expand Up @@ -261,8 +262,8 @@ def load_model(cls,
model = _model
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)
if not os.path.exists(model) and use_auth_token is None:

if use_auth_token is None:
use_auth_token = cls._get_token()
else:
raise FileNotFoundError(f'No local model or directory found at {model}.')
Expand All @@ -271,18 +272,17 @@ def load_model(cls,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
hparams_file=hparams_file,)

# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"

_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict

if _model is None:
raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \
'or your local model path')


# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"

_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict

return cls(_model)

@staticmethod
Expand Down

0 comments on commit 55a77b8

Please sign in to comment.