Skip to content

Commit

Permalink
Merge branch 'develop' into feat/pyannote-audio-cli
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin authored Dec 9, 2024
2 parents fa81c6e + ffc43ee commit fe4ff34
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 89 deletions.
94 changes: 29 additions & 65 deletions pyannote/audio/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional, Text, Tuple, Union
from urllib.parse import urlparse

import pytorch_lightning as pl
import torch
Expand All @@ -39,7 +38,7 @@
from huggingface_hub.utils import RepositoryNotFoundError
from lightning_fabric.utilities.cloud_io import _load as pl_load
from pyannote.core import SlidingWindow
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary
from torch.utils.data import DataLoader

from pyannote.audio import __version__
Expand All @@ -57,8 +56,6 @@
"PYANNOTE_CACHE",
os.path.expanduser("~/.cache/torch/pyannote"),
)
HF_PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
HF_LIGHTNING_CONFIG_NAME = "config.yaml"


# NOTE: needed to backward compatibility to load models trained before pyannote.audio 3.x
Expand Down Expand Up @@ -86,6 +83,8 @@ class Model(pl.LightningModule):
Task addressed by the model.
"""

MODEL_CHECKPOINT = "pytorch_model.bin"

def __init__(
self,
sample_rate: int = 16000,
Expand Down Expand Up @@ -530,10 +529,9 @@ def from_pretrained(
cls,
checkpoint: Union[Path, Text],
map_location=None,
hparams_file: Union[Path, Text] = None,
strict: bool = True,
subfolder: Optional[str] = None,
use_auth_token: Union[Text, None] = None,
use_auth_token: Union[Text, None] = None, # todo: deprecate in favor of token
cache_dir: Union[Path, Text] = CACHE_DIR,
**kwargs,
) -> "Model":
Expand All @@ -542,21 +540,13 @@ def from_pretrained(
Parameters
----------
checkpoint : Path or str
Path to checkpoint, or a remote URL, or a model identifier from
the hf.co model hub.
Model checkpoint, provided as one of the following:
* path to a local `pytorch_model.bin` model checkpoint
* path to a local directory containing such a file
* identifier of a model on huggingface.co model hub
map_location: optional
Same role as in torch.load().
Defaults to `lambda storage, loc: storage`.
hparams_file : Path or str, optional
Path to a .yaml file with hierarchical structure as in this example:
drop_prob: 0.2
dataloader:
batch_size: 32
You most likely won’t need this since Lightning will always save the
hyperparameters to the checkpoint. However, if your checkpoint weights
do not have the hyperparameters saved, use this method to pass in a .yaml
file with the hparams you would like to use. These will be converted
into a dict and passed into your Model for use.
strict : bool, optional
Whether to strictly enforce that the keys in checkpoint match
the keys returned by this module’s state dict. Defaults to True.
Expand All @@ -583,21 +573,22 @@ def from_pretrained(
torch.load
"""

# pytorch-lightning expects str, not Path.
checkpoint = str(checkpoint)
if hparams_file is not None:
hparams_file = str(hparams_file)

# resolve the checkpoint to
# something that pl will handle
if os.path.isfile(checkpoint):
path_for_pl = checkpoint
elif urlparse(checkpoint).scheme in ("http", "https"):
path_for_pl = checkpoint
# if checkpoint is a directory, look for the model checkpoint
# inside this directory (or inside a subfolder if specified)
if os.path.isdir(checkpoint):
if subfolder:
path_to_model_checkpoint = (
Path(checkpoint) / subfolder / cls.MODEL_CHECKPOINT
)
else:
path_to_model_checkpoint = Path(checkpoint) / cls.MODEL_CHECKPOINT

# if checkpoint is a file, use it as is
elif os.path.isfile(checkpoint):
path_to_model_checkpoint = checkpoint

# otherwise, assume that the checkpoint is hosted on HF model hub
else:
# Finally, let's try to find it on Hugging Face model hub
# e.g. julien-c/voice-activity-detection is a valid model id
# and julien-c/voice-activity-detection@main supports specifying a commit/branch/tag.
if "@" in checkpoint:
model_id = checkpoint.split("@")[0]
revision = checkpoint.split("@")[1]
Expand All @@ -606,16 +597,16 @@ def from_pretrained(
revision = None

try:
path_for_pl = hf_hub_download(
path_to_model_checkpoint = hf_hub_download(
model_id,
HF_PYTORCH_WEIGHTS_NAME,
cls.MODEL_CHECKPOINT,
subfolder=subfolder,
repo_type="model",
revision=revision,
library_name="pyannote",
library_version=__version__,
cache_dir=cache_dir,
use_auth_token=use_auth_token,
token=use_auth_token,
)
except RepositoryNotFoundError:
print(
Expand All @@ -633,31 +624,6 @@ def from_pretrained(
)
return None

# HACK Huggingface download counters rely on config.yaml
# HACK Therefore we download config.yaml even though we
# HACK do not use it. Fails silently in case model does not
# HACK have a config.yaml file.
try:
_ = hf_hub_download(
model_id,
HF_LIGHTNING_CONFIG_NAME,
repo_type="model",
revision=revision,
library_name="pyannote",
library_version=__version__,
cache_dir=cache_dir,
# force_download=False,
# proxies=None,
# etag_timeout=10,
# resume_download=False,
use_auth_token=use_auth_token,
# local_files_only=False,
# legacy_cache_layout=False,
)

except Exception:
pass

if map_location is None:

def default_map_location(storage, loc):
Expand All @@ -666,17 +632,16 @@ def default_map_location(storage, loc):
map_location = default_map_location

# obtain model class from the checkpoint
loaded_checkpoint = pl_load(path_for_pl, map_location=map_location)
loaded_checkpoint = pl_load(path_to_model_checkpoint, map_location=map_location)
module_name: str = loaded_checkpoint["pyannote.audio"]["architecture"]["module"]
module = import_module(module_name)
class_name: str = loaded_checkpoint["pyannote.audio"]["architecture"]["class"]
Klass = getattr(module, class_name)

try:
model = Klass.load_from_checkpoint(
path_for_pl,
path_to_model_checkpoint,
map_location=map_location,
hparams_file=hparams_file,
strict=strict,
**kwargs,
)
Expand All @@ -689,9 +654,8 @@ def default_map_location(storage, loc):
)
warnings.warn(msg)
model = Klass.load_from_checkpoint(
path_for_pl,
path_to_model_checkpoint,
map_location=map_location,
hparams_file=hparams_file,
strict=False,
**kwargs,
)
Expand Down
61 changes: 37 additions & 24 deletions pyannote/audio/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,68 +44,73 @@
from pyannote.audio.utils.reproducibility import fix_reproducibility
from pyannote.audio.utils.version import check_version

PIPELINE_PARAMS_NAME = "config.yaml"


def expand_subfolders(
config, hf_model_id, use_auth_token: Optional[Text] = None
config,
model_id: Optional[Text] = None,
use_auth_token: Optional[Text] = None,
) -> None:
"""Expand $model subfolders in config
Processes `config` dictionary recursively and replaces "$model/{subfolder}"
values with {"checkpoint": hf_model_id,
values with {"checkpoint": model_id,
"subfolder": {subfolder},
"use_auth_token": use_auth_token}
Parameters
----------
config : dict
hf_model_id : str
Parent Huggingface model identifier
model_id : str, optional
Model identifier when loading from the huggingface.co model hub.
use_auth_token : str, optional
Token used for loading from the root folder.
"""

if isinstance(config, dict):
for key, value in config.items():
if isinstance(value, str) and value.startswith("$model/"):
subfolder = "/".join(value.split("/")[1:])
config[key] = {
"checkpoint": hf_model_id,
"checkpoint": model_id,
"subfolder": subfolder,
"use_auth_token": use_auth_token,
}
else:
expand_subfolders(value, hf_model_id, use_auth_token=use_auth_token)
expand_subfolders(value, model_id, use_auth_token=use_auth_token)

elif isinstance(config, list):
for idx, value in enumerate(config):
if isinstance(value, str) and value.startswith("$model/"):
subfolder = "/".join(value.split("/")[1:])
config[idx] = {
"checkpoint": hf_model_id,
"checkpoint": model_id,
"subfolder": subfolder,
"use_auth_token": use_auth_token,
}
else:
expand_subfolders(value, hf_model_id, use_auth_token=use_auth_token)
expand_subfolders(value, model_id, use_auth_token=use_auth_token)


class Pipeline(_Pipeline):
PIPELINE_CHECKPOINT = "config.yaml"

@classmethod
def from_pretrained(
cls,
checkpoint_path: Union[Text, Path],
checkpoint: Union[Text, Path],
hparams_file: Union[Text, Path] = None,
use_auth_token: Union[Text, None] = None,
use_auth_token: Union[Text, None] = None, # todo: deprecate in favor of token
cache_dir: Union[Path, Text] = CACHE_DIR,
) -> "Pipeline":
"""Load pretrained pipeline
Parameters
----------
checkpoint_path : Path or str
Path to pipeline checkpoint, or a remote URL,
or a pipeline identifier from the huggingface.co model hub.
checkpoint : str
Pipeline checkpoint, provided as one of the following:
* path to a local `config.yaml` pipeline checkpoint
* path to a local directory containing such a file
* identifier of a pipeline on huggingface.co model hub
hparams_file: Path or str, optional
use_auth_token : str, optional
When loading a private huggingface.co pipeline, set `use_auth_token`
Expand All @@ -116,29 +121,36 @@ def from_pretrained(
environment variable, or "~/.cache/torch/pyannote" when unset.
"""

checkpoint_path = str(checkpoint_path)
# if checkpoint is a directory, look for the pipeline checkpoint
# inside this directory
if os.path.isdir(checkpoint):
model_id = Path(checkpoint)
config_yml = model_id / cls.PIPELINE_CHECKPOINT

if os.path.isfile(checkpoint_path):
config_yml = checkpoint_path
# if checkpoint is a file, assume it is the pipeline checkpoint
elif os.path.isfile(checkpoint):
model_id = Path(checkpoint).parent
config_yml = checkpoint

# otherwise, assume that the checkpoint is hosted on HF model hub
else:
if "@" in checkpoint_path:
model_id = checkpoint_path.split("@")[0]
revision = checkpoint_path.split("@")[1]
if "@" in checkpoint:
model_id = checkpoint.split("@")[0]
revision = checkpoint.split("@")[1]
else:
model_id = checkpoint_path
model_id = checkpoint
revision = None

try:
config_yml = hf_hub_download(
model_id,
PIPELINE_PARAMS_NAME,
cls.PIPELINE_CHECKPOINT,
repo_type="model",
revision=revision,
library_name="pyannote",
library_version=__version__,
cache_dir=cache_dir,
use_auth_token=use_auth_token,
token=use_auth_token,
)

except RepositoryNotFoundError:
Expand All @@ -159,6 +171,7 @@ def from_pretrained(

with open(config_yml, "r") as fp:
config = yaml.load(fp, Loader=yaml.SafeLoader)

expand_subfolders(config, model_id, use_auth_token=use_auth_token)

if "version" in config:
Expand Down

0 comments on commit fe4ff34

Please sign in to comment.