Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default path to pyannote model with fallback option. #71

Merged
merged 4 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions scraibe/diarisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from torch import Tensor
from torch import device as torch_device
from torch.cuda import is_available, current_device
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation')
Expand Down Expand Up @@ -183,7 +185,7 @@ def _save_token(token):

@classmethod
def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG,
model: str = PYANNOTE_DEFAULT_CONFIG,
use_auth_token: str = None,
cache_token: bool = True,
JSchmie marked this conversation as resolved.
Show resolved Hide resolved
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
Expand All @@ -210,28 +212,20 @@ def load_model(cls,
Returns:
mahenning marked this conversation as resolved.
Show resolved Hide resolved
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
"""


if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)

if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()

elif os.path.exists(model) and not use_auth_token:
if isinstance(model, str) and os.path.exists(model):
# check if model can be found locally nearby the config file
with open(model, 'r') as file:
config = yaml.safe_load(file)

path_to_model = config['pipeline']['params']['segmentation']

if not os.path.exists(path_to_model):
mahenning marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(f"Model not found at {path_to_model}. "\
"Trying to find it nearby the config file.")
warnings.warn(f"Model not found at {path_to_model}. "
"Trying to find it nearby the config file.")

pwd = model.split("/")[:-1]
pwd = "/".join(pwd)

path_to_model = os.path.join(pwd, "pytorch_model.bin")

if not os.path.exists(path_to_model):
Expand All @@ -245,18 +239,38 @@ def load_model(cls,
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
"or setup a huggingface token.")

raise FileNotFoundError

warnings.warn(f"Found model at {path_to_model} overwriting config file.")

config['pipeline']['params']['segmentation'] = path_to_model

with open(model, 'w') as file:
yaml.dump(config, file)

_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,
cache_dir = cache_dir,
hparams_file = hparams_file,)
elif isinstance(model, tuple):
JSchmie marked this conversation as resolved.
Show resolved Hide resolved
try:
_model = model[0]
HfApi().model_info(_model)
model = _model
use_auth_token = None
except RepositoryNotFoundError:
JSchmie marked this conversation as resolved.
Show resolved Hide resolved
print(f'{model[0]} not found on Huggingface, \
trying {model[1]}')
_model = model[1]
HfApi().model_info(_model)
model = _model
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)

if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()
else:
raise FileNotFoundError(f'No local model or directory found at {model}.')

_model = Pipeline.from_pretrained(model,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
hparams_file=hparams_file,)

# try to move the model to the device
if device is None:
Expand Down
2 changes: 1 addition & 1 deletion scraibe/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else 'pyannote/speaker-diarization-3.1'
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
JSchmie marked this conversation as resolved.
Show resolved Hide resolved

def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file.
Expand Down
Loading