From ba058c3e021f08d22cf4c1ed0701bcca3d6c9071 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Tue, 18 Jun 2024 17:21:34 +0200 Subject: [PATCH 01/30] Implemented faster-whisper, removed WhisperX --- pyproject.toml | 2 +- requirements.txt | 2 +- scraibe/autotranscript.py | 2 +- scraibe/cli.py | 4 ++-- scraibe/misc.py | 2 +- scraibe/transcriber.py | 39 +++++++++++++++++++-------------------- 6 files changed, 25 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c46bdb..caf02a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ python = "^3.9" tqdm = "^4.66.4" numpy = "^1.26.4" openai-whisper = "^20231117" -whisperx = "^3.1.3" +faster-whisper = "^1.0.1" "pyannote.audio" = "^3.1.1" torch = "^2.3.0" diff --git a/requirements.txt b/requirements.txt index f08e2e6..94ee85a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ tqdm>=4.65.0 numpy>=1.26.4 openai-whisper==20231117 -whisperx~=3.1.3 +faster-whisper~=1.0.1 pyannote.audio~=3.1.1 pyannote.core~=5.0.0 diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 7391f1a..43dedc2 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -74,7 +74,7 @@ def __init__(self, whisper_model (Union[bool, str, whisper], optional): Path to whisper model or whisper model itself. whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". diarisation_model (Union[bool, str, DiarisationType], optional): Path to pyannote diarization model or model itself. **kwargs: Additional keyword arguments for whisper diff --git a/scraibe/cli.py b/scraibe/cli.py index ee40c8b..a234132 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -36,8 +36,8 @@ def str2bool(string): help="List of audio files to transcribe.") parser.add_argument("--whisper-type", type=str, default="whisper", - choices=["whisper", "whisperx"], - help="Type of Whisper model to use ('whisper' or 'whisperx').") + choices=["whisper", "faster-whisper"], + help="Type of Whisper model to use ('whisper' or 'faster-whisper').") parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") diff --git a/scraibe/misc.py b/scraibe/misc.py index f12335f..56e9f3a 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -16,7 +16,7 @@ PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') + else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 0301955..cea7274 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -26,8 +26,7 @@ from whisper import Whisper from whisper import load_model as whisper_load_model -from whisperx.asr import WhisperModel -from whisperx import load_model as whisperx_load_model +from faster_whisper import WhisperModel as FasterWhisperModel from typing import TypeVar, Union, Optional from torch import Tensor, device from torch.cuda import is_available as cuda_is_available @@ -145,7 +144,7 @@ def load_model(cls, - 'large-v3' - 'large' whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): @@ -272,7 +271,7 @@ def __repr__(self) -> str: return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})" -class WhisperXTranscriber(Transcriber): +class FasterWhisperTranscriber(Transcriber): def __init__(self, model: whisper, model_name: str) -> None: super().__init__(model, model_name) @@ -294,10 +293,10 @@ def transcribe(self, audio: Union[str, Tensor, ndarray], if isinstance(audio, Tensor): audio = audio.cpu().numpy() - result = self.model.transcribe(audio, *args, **kwargs) + result, _ = self.model.transcribe(audio, *args, **kwargs) text = "" - for seg in result['segments']: - text += seg['text'] + for seg in result: + text += seg.text return text @classmethod @@ -306,7 +305,7 @@ def load_model(cls, download_root: str = WHISPER_DEFAULT_PATH, device: Optional[Union[str, device]] = None, *args, **kwargs - ) -> 'WhisperXTranscriber': + ) -> 'FasterWhisperModel': """ Load whisper model. @@ -347,8 +346,8 @@ def load_model(cls, warnings.warn(f'Compute type {compute_type} not compatible with ' f'device {device}! Changing compute type to int8.') compute_type = 'int8' - _model = whisperx_load_model(model, download_root=download_root, - device=device, compute_type=compute_type) + _model = FasterWhisperModel(model, download_root=download_root, + device=device, compute_type=compute_type) return cls(_model, model_name=model) @@ -361,7 +360,7 @@ def _get_whisper_kwargs(**kwargs) -> dict: dict: Keyword arguments for whisper model. """ # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames - _possible_kwargs = signature(WhisperModel.transcribe).parameters.keys() + _possible_kwargs = signature(FasterWhisperModel.transcribe).parameters.keys() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} @@ -375,7 +374,7 @@ def _get_whisper_kwargs(**kwargs) -> dict: return whisper_kwargs def __repr__(self) -> str: - return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})" + return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})" def load_transcriber(model: str = "medium", @@ -384,7 +383,7 @@ def load_transcriber(model: str = "medium", device: Optional[Union[str, device]] = None, in_memory: bool = False, *args, **kwargs - ) -> Union[WhisperTranscriber, WhisperXTranscriber]: + ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: """ Load whisper model. @@ -403,28 +402,28 @@ def load_transcriber(model: str = "medium", - 'large-v3' - 'large' whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. - device (Optional[Union[str, torch.device]], optional): + device (Optional[Union[str, torch.device]], optional): Device to load model on. Defaults to None. - in_memory (bool, optional): Whether to load model in memory. + in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors. Returns: - Union[WhisperTranscriber, WhisperXTranscriber]: + Union[WhisperTranscriber, FasterWhisperTranscriber]: One of the Whisper variants as Transcrbier object initialized with the specified model. """ if whisper_type.lower() == 'whisper': _model = WhisperTranscriber.load_model( model, download_root, device, in_memory, *args, **kwargs) return _model - elif whisper_type.lower() == 'whisperx': - _model = WhisperXTranscriber.load_model( + elif whisper_type.lower() == 'faster-whisper': + _model = FasterWhisperTranscriber.load_model( model, download_root, device, *args, **kwargs) return _model else: raise ValueError(f'Model type not recognized, exptected "whisper" ' - f'or "whisperx", got {whisper_type}.') + f'or "faster-whisper", got {whisper_type}.') From f5ef26432bc50b70c159d081421d8fd8c7f706f8 Mon Sep 17 00:00:00 2001 From: Till Hanke Date: Mon, 1 Jul 2024 15:20:17 +0200 Subject: [PATCH 02/30] add num-speakers as cmdline option to scraibe --- scraibe/cli.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/scraibe/cli.py b/scraibe/cli.py index ee40c8b..c85e985 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -79,6 +79,8 @@ def str2bool(string): choices=sorted( LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="Language spoken in the audio. Specify None to perform language detection.") + parser.add_argument("--num-speakers", type=int, default=2, + help="Number of speakers in the audio.") args = parser.parse_args() @@ -117,8 +119,13 @@ def str2bool(string): else: task = "transcribe" - out = model.autotranscribe(audio, task=task, language=arg_dict.pop( - "language"), verbose=arg_dict.pop("verbose_output")) + out = model.autotranscribe( + audio, + task=task, + language=arg_dict.pop("language"), + verbose=arg_dict.pop("verbose_output"), + num_speakers=arg_dict.pop("num_speakers") + ) basename = audio.split("/")[-1].split(".")[0] print(f'Saving {basename}.{out_format} to {out_folder}') out.save(os.path.join( From cf63ac8e2e33ea8f9d80f6825716b0d0d0cd13b6 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 9 Sep 2024 09:50:43 +0000 Subject: [PATCH 03/30] update dependencies --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c46bdb..86867a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,11 +31,11 @@ exclude =[ ] [tool.poetry.dependencies] python = "^3.9" -tqdm = "^4.66.4" +tqdm = "^4.66.5" numpy = "^1.26.4" openai-whisper = "^20231117" -whisperx = "^3.1.3" -"pyannote.audio" = "^3.1.1" +whisperx = "^3.1.5" +"pyannote.audio" = "^3.3.1" torch = "^2.3.0" [tool.poetry.group.dev.dependencies] From 533b199f4c884de87e9bb8320dd967c60ffb77c3 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 9 Sep 2024 09:57:45 +0000 Subject: [PATCH 04/30] downgraded to pyannote.audio==3.1.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 86867a0..9309239 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ python = "^3.9" tqdm = "^4.66.5" numpy = "^1.26.4" openai-whisper = "^20231117" -whisperx = "^3.1.5" +whisperx = "^3.1.1" "pyannote.audio" = "^3.3.1" torch = "^2.3.0" From d25fda58020402ee55bdf793efb14103c7336b76 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 9 Sep 2024 10:05:16 +0000 Subject: [PATCH 05/30] downgraded to 3.1.1 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9309239..098b025 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ python = "^3.9" tqdm = "^4.66.5" numpy = "^1.26.4" openai-whisper = "^20231117" -whisperx = "^3.1.1" -"pyannote.audio" = "^3.3.1" +whisperx = "^3.1.5" +"pyannote.audio" = "^3.1.1" torch = "^2.3.0" [tool.poetry.group.dev.dependencies] From 129f0ce39090ed7b567b4750085fac4adefcb37c Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 9 Sep 2024 10:06:13 +0000 Subject: [PATCH 06/30] updated versions --- requirements.txt | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index f08e2e6..f43514f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,10 @@ -tqdm>=4.65.0 +tqdm>=4.66.5 numpy>=1.26.4 openai-whisper==20231117 -whisperx~=3.1.3 +whisperx~=3.1.5 pyannote.audio~=3.1.1 -pyannote.core~=5.0.0 -pyannote.database~=5.0.1 -pyannote.metrics~=3.2.1 -pyannote.pipeline~=3.0.1 torch>=2.0.0 From 53e57a06d70263a08467ecea9063d44738b9c0c7 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Mon, 9 Sep 2024 12:25:14 +0200 Subject: [PATCH 07/30] Added tests for faster-whisper --- test/test_transcriber.py | 18 +++++++++--------- tests/test_diarization.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 9 deletions(-) create mode 100644 tests/test_diarization.py diff --git a/test/test_transcriber.py b/test/test_transcriber.py index 31765f6..bd1e9f5 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -1,6 +1,6 @@ import pytest from scraibe import (Transcriber, WhisperTranscriber, - WhisperXTranscriber, load_transcriber) + FasterWhisperTranscriber, load_transcriber) import torch @@ -35,24 +35,24 @@ def whisper_instance(): @pytest.fixture -def whisperx_instance(): - return load_transcriber('medium', whisper_type='whisperx') +def faster_whisper_instance(): + return load_transcriber('medium', whisper_type='faster-whisper') def test_whisper_base_initialization(whisper_instance): assert isinstance(whisper_instance, Transcriber) -def test_whisperx_base_initialization(whisperx_instance): - assert isinstance(whisperx_instance, Transcriber) +def test_faster_whisper_base_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, Transcriber) def test_whisper_transcriber_initialization(whisper_instance): assert isinstance(whisper_instance, WhisperTranscriber) -def test_whisperx_transcriber_initialization(whisperx_instance): - assert isinstance(whisperx_instance, WhisperXTranscriber) +def test_faster_whisper_transcriber_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, FasterWhisperTranscriber) def test_wrong_transcriber_initialization(): @@ -73,8 +73,8 @@ def test_whisper_transcribe(whisper_instance): assert isinstance(transcript, str) -def test_whisperx_transcribe(whisperx_instance): - model = whisperx_instance +def test_faster_whisper_transcribe(faster_whisper_instance): + model = faster_whisper_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) transcript = model.transcribe('test/audio_test_2.mp4') assert isinstance(transcript, str) diff --git a/tests/test_diarization.py b/tests/test_diarization.py new file mode 100644 index 0000000..f9e81a5 --- /dev/null +++ b/tests/test_diarization.py @@ -0,0 +1,10 @@ +from os import environ + +environ["AUTOT_CACHE"] = "/mnt/disk1/Projekte/ScrAIbe/tests" +# environ["PYANNOTE_CACHE"] = "/mnt/disk1/Projekte/ScrAIbe/tests/pyannote" +# environ["TORCH_HOME"] = "/mnt/disk1/Projekte/ScrAIbe/tests/torch" + +from scraibe import Scraibe + +scraibe = Scraibe(whisper_type = "faster-whisper", whisper_model = "tiny") +print(scraibe.autotranscribe('/mnt/disk1/Projekte/ScrAIbe/test/audio_test_1.mp4')) \ No newline at end of file From 5b56b54da26d03b1bf3578cdbaaee4acfe5e9e67 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Tue, 10 Sep 2024 09:37:46 +0200 Subject: [PATCH 08/30] Fixed pyannote env var import, now can use `PYANNOTE_CACHE` --- scraibe/misc.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/scraibe/misc.py b/scraibe/misc.py index f12335f..7b7cbf4 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -1,6 +1,5 @@ import os import yaml -from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR from argparse import Action from ast import literal_eval @@ -8,9 +7,7 @@ "AUTOT_CACHE", os.path.expanduser("~/.cache/torch/models"), ) - -if CACHE_DIR != PYANNOTE_CACHE_DIR: - os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote") +os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote") WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") From de9c81b3136652012cf92b345fe6b9621a670798 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 10 Sep 2024 09:01:59 +0000 Subject: [PATCH 09/30] added language to code support for faster whisper --- scraibe/transcriber.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index cea7274..abf1ace 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -26,7 +26,9 @@ from whisper import Whisper from whisper import load_model as whisper_load_model +from whisper.tokenizer import TO_LANGUAGE_CODE from faster_whisper import WhisperModel as FasterWhisperModel +from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES from typing import TypeVar, Union, Optional from torch import Tensor, device from torch.cuda import is_available as cuda_is_available @@ -369,14 +371,44 @@ def _get_whisper_kwargs(**kwargs) -> dict: whisper_kwargs["task"] = task if (language := kwargs.get("language")): + language = FasterWhisperTranscriber.convert_to_language_code(language) whisper_kwargs["language"] = language return whisper_kwargs + @staticmethod + def convert_to_language_code(lang : str) -> str: + """ + Load whisper model. + + Args: + lang (str): language as code or language name + + Returns: + language (str) code of language + """ + + # If the input is already in FASTER_WHISPER_LANGUAGE_CODES, return it directly + if lang in FASTER_WHISPER_LANGUAGE_CODES: + return lang + + # Normalize the input to lowercase + lang = lang.lower() + + # Check if the language name is in the TO_LANGUAGE_CODE mapping + if lang in TO_LANGUAGE_CODE: + return TO_LANGUAGE_CODE[lang] + + # If the language is not recognized, raise a ValueError with the available options + available_codes = ', '.join(FASTER_WHISPER_LANGUAGE_CODES) + raise ValueError(f"Language '{lang}' is not a valid language code or name. " + f"Available language codes are: {available_codes}.") + def __repr__(self) -> str: return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})" + def load_transcriber(model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, From 885d0c864ebd33f75416dd3300eca88f79c527c7 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Tue, 10 Sep 2024 13:50:28 +0200 Subject: [PATCH 10/30] Fixed pyannote_cache, now looks it up before overwrite. --- scraibe/misc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scraibe/misc.py b/scraibe/misc.py index 7b7cbf4..21099fb 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -7,7 +7,10 @@ "AUTOT_CACHE", os.path.expanduser("~/.cache/torch/models"), ) -os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote") +os.getenv( + "PYANNOTE_CACHE", + os.path.join(CACHE_DIR, "pyannote"), +) WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") From ae1bae750fad871bb853dbcefcc9a332f3ae1b2d Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 10 Sep 2024 12:08:26 +0000 Subject: [PATCH 11/30] fixed pypi.yaml to run on push on main --- .github/workflows/pypi.yml | 44 ++++++++++---------------------------- 1 file changed, 11 insertions(+), 33 deletions(-) diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index a2641ee..1c75489 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -1,18 +1,14 @@ name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI on: - pull_request_target: - branches: - - develop - types: - - closed - paths: - - scraibe/** - - pyproject.toml - push: tags: - 'v*.*.*' + branches: + - "develop" + paths: + - "scraibe/**" + - "pyproject.toml" workflow_dispatch: inputs: @@ -27,13 +23,7 @@ on: jobs: Build-and-publish-to-Test-PyPI: - if: | - (github.event_name == 'workflow_dispatch' && - github.event.inputs.test == 'true') || - (github.event_name == 'pull_request_target' && - github.event.pull_request.merged && - contains(github.event.pull_request.labels.*.name, 'release')) || - (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')) + if: github.event_name != 'workflow_dispatch' || github.event.inputs.test == 'true' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -72,28 +62,16 @@ jobs: needs: Test-PyPi-install runs-on: ubuntu-latest if: | - always() && - (( needs.Build-and-publish-to-Test-PyPI.result != 'failure' && - needs.Test-PyPi-install.result != 'failure' ) && - ((github.event_name == 'workflow_dispatch' && - github.event.inputs.publish_to_pypi == 'true') || - (github.event_name == 'pull_request_target' && - github.event.pull_request.merged && - contains(github.event.pull_request.labels.*.name, 'release')) || - (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')))) + always() && + (( needs.Build-and-publish-to-Test-PyPI.result != 'failure' && + needs.Test-PyPi-install.result != 'failure' ) || + ((github.event_name == 'workflow_dispatch' && + github.event.inputs.publish_to_pypi == 'true'))) steps: - - name: Checkout Repository Tags - uses: actions/checkout@v4 - if: github.ref == 'refs/heads/main' - with: - fetch-depth: '0' - branch: 'main' - name: Checkout Repository (Develop) uses: actions/checkout@v4 - if: github.ref == 'refs/heads/develop' with: fetch-depth: '0' - branch: 'develop' - name: Set up Poetry 📦 uses: JRubics/poetry-publish@v1.16 with: From 5c0386edaca8e22d141b7eb9b94f60498cb7f7fe Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 10 Sep 2024 15:07:57 +0000 Subject: [PATCH 12/30] define new Versions of pyannote and faster-whisper --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index caf02a2..2c346a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ python = "^3.9" tqdm = "^4.66.4" numpy = "^1.26.4" openai-whisper = "^20231117" -faster-whisper = "^1.0.1" -"pyannote.audio" = "^3.1.1" +faster-whisper = "^1.0.3" +"pyannote.audio" = "^3.3.1" torch = "^2.3.0" [tool.poetry.group.dev.dependencies] From 51bf211d27469735aef86fd8c5ff78ec492042d8 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 10 Sep 2024 15:09:35 +0000 Subject: [PATCH 13/30] updated deps --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 94ee85a..66d7857 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,9 +2,9 @@ tqdm>=4.65.0 numpy>=1.26.4 openai-whisper==20231117 -faster-whisper~=1.0.1 +faster-whisper~=1.0.3 -pyannote.audio~=3.1.1 +pyannote.audio~=3.3.1 pyannote.core~=5.0.0 pyannote.database~=5.0.1 pyannote.metrics~=3.2.1 From ab7b43ac489cef8967137b05162c7382d7247169 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 10 Sep 2024 15:22:18 +0000 Subject: [PATCH 14/30] set test whisper model to tiny --- test/test_transcriber.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_transcriber.py b/test/test_transcriber.py index bd1e9f5..5bfe3cf 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -31,12 +31,12 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription): @pytest.fixture def whisper_instance(): - return load_transcriber('medium', whisper_type='whisper') + return load_transcriber('tiny', whisper_type='whisper') @pytest.fixture def faster_whisper_instance(): - return load_transcriber('medium', whisper_type='faster-whisper') + return load_transcriber('tiny', whisper_type='faster-whisper') def test_whisper_base_initialization(whisper_instance): @@ -57,7 +57,7 @@ def test_faster_whisper_transcriber_initialization(faster_whisper_instance): def test_wrong_transcriber_initialization(): with pytest.raises(ValueError): - load_transcriber('medium', whisper_type='wrong_whisper') + load_transcriber('tiny', whisper_type='wrong_whisper') def test_get_whisper_kwargs(): From 9df05033dadffbd1642e7362751f604f53be3cd4 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Fri, 27 Sep 2024 13:22:12 +0000 Subject: [PATCH 15/30] align minimal torch version --- environment.yml | 256 ----------------------------------------------- pyproject.toml | 2 +- requirements.txt | 2 +- 3 files changed, 2 insertions(+), 258 deletions(-) delete mode 100644 environment.yml diff --git a/environment.yml b/environment.yml deleted file mode 100644 index 7913480..0000000 --- a/environment.yml +++ /dev/null @@ -1,256 +0,0 @@ -channels: - - pytorch - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - blas=1.0=mkl - - brotlipy=0.7.0=py39h27cfd23_1003 - - bzip2=1.0.8=h7b6447c_0 - - ca-certificates=2023.05.30=h06a4308_0 - - certifi=2023.5.7=py39h06a4308_0 - - cffi=1.15.1=py39h5eee18b_3 - - cryptography=39.0.1=py39h9ce1e76_2 - - cudatoolkit=11.3.1=h2bc3f7f_2 - - ffmpeg=4.2.2=h20bf706_0 - - flit-core=3.8.0=py39h06a4308_0 - - freetype=2.12.1=h4a9f257_0 - - giflib=5.2.1=h5eee18b_3 - - gmp=6.2.1=h295c915_3 - - gnutls=3.6.15=he1e5248_0 - - idna=3.4=py39h06a4308_0 - - intel-openmp=2021.4.0=h06a4308_3561 - - jpeg=9e=h5eee18b_1 - - lame=3.100=h7b6447c_0 - - lcms2=2.12=h3be6417_0 - - ld_impl_linux-64=2.38=h1181459_1 - - lerc=3.0=h295c915_0 - - libdeflate=1.17=h5eee18b_0 - - libffi=3.4.2=h6a678d5_6 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libidn2=2.3.2=h7f8727e_0 - - libopus=1.3.1=h7b6447c_0 - - libpng=1.6.39=h5eee18b_0 - - libstdcxx-ng=11.2.0=h1234567_1 - - libtasn1=4.16.0=h27cfd23_0 - - libtiff=4.5.0=h6a678d5_2 - - libunistring=0.9.10=h27cfd23_0 - - libuv=1.44.2=h5eee18b_0 - - libvpx=1.7.0=h439df22_0 - - libwebp=1.2.4=h11a3e52_1 - - libwebp-base=1.2.4=h5eee18b_1 - - lz4-c=1.9.4=h6a678d5_0 - - mkl=2021.4.0=h06a4308_640 - - mkl-service=2.4.0=py39h7f8727e_0 - - mkl_fft=1.3.1=py39hd3c417c_0 - - mkl_random=1.2.2=py39h51133e4_0 - - ncurses=6.4=h6a678d5_0 - - nettle=3.7.3=hbbd107a_1 - - numpy=1.23.5=py39h14f4228_0 - - numpy-base=1.23.5=py39h31eccc5_0 - - openh264=2.1.1=h4ff587b_0 - - openssl=3.0.9=h7f8727e_0 - - pillow=9.4.0=py39h6a678d5_0 - - pip=23.0.1=py39h06a4308_0 - - pycparser=2.21=pyhd3eb1b0_0 - - pyopenssl=23.0.0=py39h06a4308_0 - - pysocks=1.7.1=py39h06a4308_0 - - python=3.9.16=h955ad1f_3 - - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 - - pytorch-mutex=1.0=cuda - - readline=8.2=h5eee18b_0 - - requests=2.28.1=py39h06a4308_1 - - setuptools=65.6.3=py39h06a4308_0 - - six=1.16.0=pyhd3eb1b0_1 - - sqlite=3.41.2=h5eee18b_0 - - tk=8.6.12=h1ccaba5_0 - - torchaudio=0.11.0=py39_cu113 - - torchvision=0.12.0=py39_cu113 - - tzdata=2023c=h04d1e81_0 - - wheel=0.38.4=py39h06a4308_0 - - x264=1!157.20191217=h7b6447c_0 - - xz=5.4.2=h5eee18b_0 - - zlib=1.2.13=h5eee18b_0 - - zstd=1.5.4=hc292b87_0 - - pip: - - absl-py==1.3.0 - - aiofiles==23.1.0 - - aiohttp==3.8.3 - - aiosignal==1.3.1 - - alembic==1.9.1 - - altair==5.0.1 - - annotated-types==0.5.0 - - ansi2html==1.8.0 - - antlr4-python3-runtime==4.9.3 - - anyio==3.7.1 - - appdirs==1.4.4 - - asteroid-filterbanks==0.4.0 - - async-timeout==4.0.2 - - attrs==22.2.0 - - audioread==3.0.0 - - autopage==0.5.1 - - backports-cached-property==1.0.2 - - cachetools==5.2.0 - - charset-normalizer==2.1.1 - - click==8.1.3 - - cliff==4.1.0 - - cmaes==0.9.0 - - cmake==3.26.4 - - cmd2==2.4.2 - - colorama==0.4.6 - - colorlog==6.7.0 - - commonmark==0.9.1 - - contourpy==1.0.6 - - cycler==0.11.0 - - dash==2.12.1 - - dash-core-components==2.0.0 - - dash-html-components==2.0.0 - - dash-table==5.0.0 - - decorator==4.4.2 - - docopt==0.6.2 - - einops==0.3.2 - - exceptiongroup==1.1.1 - - fastapi==0.100.0 - - ffmpeg-python==0.2.0 - - ffmpy==0.3.0 - - filelock==3.8.0 - - flask==2.2.5 - - fonttools==4.38.0 - - frozenlist==1.3.3 - - fsspec==2022.11.0 - - future==0.18.2 - - google-auth==2.15.0 - - google-auth-oauthlib==0.4.6 - - gradio==3.36.1 - - gradio-client==0.2.7 - - greenlet==2.0.1 - - grpcio==1.51.1 - - h11==0.14.0 - - hmmlearn==0.2.8 - - httpcore==0.17.3 - - httpx==0.24.1 - - huggingface-hub==0.16.4 - - humanize==4.7.0 - - hyperpyyaml==1.1.0 - - imageio==2.23.0 - - imageio-ffmpeg==0.4.7 - - importlib-metadata==4.13.0 - - importlib-resources==5.12.0 - - iniconfig==2.0.0 - - itsdangerous==2.1.2 - - jinja2==3.1.2 - - joblib==1.2.0 - - jsonschema==4.18.0 - - jsonschema-specifications==2023.6.1 - - julius==0.2.7 - - kiwisolver==1.4.4 - - librosa==0.9.2 - - linkify-it-py==2.0.2 - - lit==16.0.5.post0 - - llvmlite==0.39.1 - - mako==1.2.4 - - markdown==3.4.1 - - markdown-it-py==2.2.0 - - markupsafe==2.1.1 - - matplotlib==3.7.1 - - mdit-py-plugins==0.3.3 - - mdurl==0.1.2 - - more-itertools==9.0.0 - - moviepy==1.0.3 - - mpmath==1.2.1 - - multidict==6.0.4 - - nest-asyncio==1.5.7 - - networkx==2.8.8 - - numba==0.56.4 - - oauthlib==3.2.2 - - omegaconf==2.3.0 - - openai-whisper==20230314 - - optuna==3.0.5 - - orjson==3.9.2 - - packaging==21.3 - - pandas==1.5.2 - - pbr==5.11.0 - - plotly==5.15.0 - - pluggy==1.0.0 - - pooch==1.6.0 - - prettytable==3.5.0 - - primepy==1.3 - - proglog==0.1.10 - - protobuf==3.20.1 - - pyannote-audio==2.1.1 - - pyannote-core==4.5 - - pyannote-database==4.1.3 - - pyannote-metrics==3.2.1 - - pyannote-pipeline==2.3 - - pyasn1==0.4.8 - - pyasn1-modules==0.2.8 - - pydantic==2.0.2 - - pydantic-core==2.1.2 - - pydeprecate==0.3.2 - - pydub==0.25.1 - - pygments==2.13.0 - - pyparsing==3.0.9 - - pyperclip==1.8.2 - - pytest==7.3.1 - - python-dateutil==2.8.2 - - python-multipart==0.0.6 - - pytorch-lightning==1.6.5 - - pytorch-metric-learning==1.6.3 - - pytz==2022.7 - - pyyaml==6.0 - - qtfaststart==1.8 - - referencing==0.29.1 - - regex==2022.10.31 - - requests-oauthlib==1.3.1 - - resampy==0.4.2 - - retrying==1.3.4 - - rich==12.6.0 - - rpds-py==0.8.10 - - rsa==4.9 - - ruamel-yaml==0.17.21 - - ruamel-yaml-clib==0.2.7 - - ruff==0.0.272 - - scikit-learn==1.2.0 - - scipy==1.8.1 - - semantic-version==2.10.0 - - semver==2.13.0 - - sentencepiece==0.1.97 - - setuptools-rust==1.5.2 - - shellingham==1.5.0 - - simplejson==3.18.0 - - singledispatchmethod==1.0 - - sniffio==1.3.0 - - sortedcontainers==2.4.0 - - soundfile==0.10.3.post1 - - speechbrain==0.5.14 - - sqlalchemy==1.4.45 - - starlette==0.27.0 - - stevedore==4.1.1 - - sympy==1.11.1 - - tabulate==0.9.0 - - tenacity==8.2.2 - - tensorboard==2.11.0 - - tensorboard-data-server==0.6.1 - - tensorboard-plugin-wit==1.8.1 - - threadpoolctl==3.1.0 - - tiktoken==0.3.1 - - tokenizers==0.13.2 - - tomli==2.0.1 - - toolz==0.12.0 - - torch-audiomentations==0.11.0 - - torch-pitch-shift==1.2.2 - - torchmetrics==0.11.0 - - tqdm==4.64.1 - - transformers==4.24.0 - - triton==2.0.0 - - typer==0.7.0 - - typing-extensions==4.7.1 - - uc-micro-py==1.0.2 - - urllib3==1.26.12 - - uvicorn==0.22.0 - - wcwidth==0.2.5 - - websockets==11.0.3 - - werkzeug==2.2.2 - - yarl==1.8.2 - - zipp==3.11.0 diff --git a/pyproject.toml b/pyproject.toml index e82881d..5d7b584 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ numpy = "^1.26.4" openai-whisper = "^20231117" faster-whisper = "^1.0.3" "pyannote.audio" = "^3.3.1" -torch = "^2.3.0" +torch = "^2.1.2" [tool.poetry.group.dev.dependencies] pytest = "^8.1.1" diff --git a/requirements.txt b/requirements.txt index 6e95c81..8786d84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,5 +10,5 @@ pyannote.database~=5.0.1 pyannote.metrics~=3.2.1 pyannote.pipeline~=3.0.1 -torch>=2.0.0 +torchaudio>=2.1.2 From 2adbfaef515caaa1ecb870665171849859b89f0c Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 30 Sep 2024 12:27:41 +0000 Subject: [PATCH 16/30] added relative Path --- test/test_autotranscript.py | 8 ++++---- test/test_transcriber.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_autotranscript.py b/test/test_autotranscript.py index 78442b3..9d04c9b 100644 --- a/test/test_autotranscript.py +++ b/test/test_autotranscript.py @@ -6,7 +6,7 @@ @pytest.fixture def create_scraibe_instance(): if "HF_TOKEN" in os.environ: - return Scraibe(use_auth_token=os.environ["HF_TOKEN"]) + return Scraibe(use_auth_token=os.environ["HF_TOKEN"], whisper_model= "tiny") else: return Scraibe() @@ -19,19 +19,19 @@ def test_scraibe_init(create_scraibe_instance): def test_scraibe_autotranscribe(create_scraibe_instance): model = create_scraibe_instance - transcript = model.autotranscribe('test/audio_test_2.mp4') + transcript = model.autotranscribe('./test/audio_test_2.mp4') assert isinstance(transcript, Transcript) def test_scraibe_diarization(create_scraibe_instance): model = create_scraibe_instance - diarisation_result = model.diarization('test/audio_test_2.mp4') + diarisation_result = model.diarization('./test/audio_test_2.mp4') assert isinstance(diarisation_result, dict) def test_scraibe_transcribe(create_scraibe_instance): model = create_scraibe_instance - transcription_result = model.transcribe('test/audio_test_2.mp4') + transcription_result = model.transcribe('./test/audio_test_2.mp4') assert isinstance(transcription_result, str) diff --git a/test/test_transcriber.py b/test/test_transcriber.py index 5bfe3cf..80f79d2 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -69,12 +69,12 @@ def test_get_whisper_kwargs(): def test_whisper_transcribe(whisper_instance): model = whisper_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) - transcript = model.transcribe('test/audio_test_2.mp4') + transcript = model.transcribe('./test/audio_test_2.mp4') assert isinstance(transcript, str) def test_faster_whisper_transcribe(faster_whisper_instance): model = faster_whisper_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) - transcript = model.transcribe('test/audio_test_2.mp4') + transcript = model.transcribe('./test/audio_test_2.mp4') assert isinstance(transcript, str) From 5f6f681edfbc495a475ba48641391dc7e76562b2 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 30 Sep 2024 12:39:37 +0000 Subject: [PATCH 17/30] fix paths --- test/test_autotranscript.py | 6 +++--- test/test_transcriber.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_autotranscript.py b/test/test_autotranscript.py index 9d04c9b..865f507 100644 --- a/test/test_autotranscript.py +++ b/test/test_autotranscript.py @@ -19,19 +19,19 @@ def test_scraibe_init(create_scraibe_instance): def test_scraibe_autotranscribe(create_scraibe_instance): model = create_scraibe_instance - transcript = model.autotranscribe('./test/audio_test_2.mp4') + transcript = model.autotranscribe('./audio_test_2.mp4') assert isinstance(transcript, Transcript) def test_scraibe_diarization(create_scraibe_instance): model = create_scraibe_instance - diarisation_result = model.diarization('./test/audio_test_2.mp4') + diarisation_result = model.diarization('./audio_test_2.mp4') assert isinstance(diarisation_result, dict) def test_scraibe_transcribe(create_scraibe_instance): model = create_scraibe_instance - transcription_result = model.transcribe('./test/audio_test_2.mp4') + transcription_result = model.transcribe('./audio_test_2.mp4') assert isinstance(transcription_result, str) diff --git a/test/test_transcriber.py b/test/test_transcriber.py index 80f79d2..a805868 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -69,12 +69,12 @@ def test_get_whisper_kwargs(): def test_whisper_transcribe(whisper_instance): model = whisper_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) - transcript = model.transcribe('./test/audio_test_2.mp4') + transcript = model.transcribe('./audio_test_2.mp4') assert isinstance(transcript, str) def test_faster_whisper_transcribe(faster_whisper_instance): model = faster_whisper_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) - transcript = model.transcribe('./test/audio_test_2.mp4') + transcript = model.transcribe('./audio_test_2.mp4') assert isinstance(transcript, str) From 6326d0f15677e00909cfb76a737cec8efeec9e22 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 30 Sep 2024 16:03:27 +0000 Subject: [PATCH 18/30] added tests folder --- {test => tests}/audio_test_1.mp4 | Bin {test => tests}/audio_test_2.mp4 | Bin tests/test_audio.py | 96 +++++++++++++++++++++++++ {test => tests}/test_autotranscript.py | 6 +- tests/test_diarisation.py | 32 +++++++++ tests/test_transcriber.py | 80 +++++++++++++++++++++ 6 files changed, 211 insertions(+), 3 deletions(-) rename {test => tests}/audio_test_1.mp4 (100%) rename {test => tests}/audio_test_2.mp4 (100%) create mode 100644 tests/test_audio.py rename {test => tests}/test_autotranscript.py (88%) create mode 100644 tests/test_diarisation.py create mode 100644 tests/test_transcriber.py diff --git a/test/audio_test_1.mp4 b/tests/audio_test_1.mp4 similarity index 100% rename from test/audio_test_1.mp4 rename to tests/audio_test_1.mp4 diff --git a/test/audio_test_2.mp4 b/tests/audio_test_2.mp4 similarity index 100% rename from test/audio_test_2.mp4 rename to tests/audio_test_2.mp4 diff --git a/tests/test_audio.py b/tests/test_audio.py new file mode 100644 index 0000000..aee6cb3 --- /dev/null +++ b/tests/test_audio.py @@ -0,0 +1,96 @@ +import pytest +from scraibe.audio import AudioProcessor +import torch + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE) +TEST_SR = 16000 +SAMPLE_RATE = 16000 +NORMALIZATION_FACTOR = 32768 + + +@pytest.fixture +def probe_audio_processor(): + """Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate. + + This fixture is used to create an instance of the AudioProcessor class with a predfined test waveform and sample rate (TEST_SR). It returns the instantiated AudioProcessor , which can bes used as a + dependency in other test functions. + + + Returns: + AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate. + """ + return AudioProcessor(TEST_WAVEFORM, TEST_SR) + + +def test_AudioProcessor_init(probe_audio_processor): + """ + Test the initialization of the AudioProcessor class. + + This test verifies that the AUdioProcessor class is correctly initialized with the provided waveform and sample rate. It checks whether the instantiated AhdioProcessor object has the correct attributes + and whether the waveform and sample rate match the expected values. + + Args: + probe_audio_processor (obj): An instance of the AudioProcessor class to be tested. + + + Returns: + None + + + + + """ + assert isinstance(probe_audio_processor, AudioProcessor) + assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device + assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM) + assert probe_audio_processor.sr == TEST_SR + + +def test_cut(probe_audio_processor): + """Test the cut function of the AudioProcessor class. + + This test verifies that the cut function correctly extracts a segment of audio data from + the waveform, given start and end indices. It checks whether the size of the extracted segment matches + the expected size based on the provided start and end indices and the sample rate. + + Returns: + None + + + """ + + start = 4 + end = 7 + trimmed_waveform = probe_audio_processor.cut(start, end) + expected_size = int((end - start) * TEST_SR) + real_size = trimmed_waveform.size(0) + assert real_size == expected_size + # assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR) + + +def test_audio_processor_invalid_sr(): + """Test the behavior of AudioProcessor when an invalid smaple rate is provided. + + This test verifies that the AudioProcessor constructor raises a ValueError when an invalid sample rate is provided. It uses the pytest.raises context manager to check if the ValueError is raised when initializing an + AudioProcessor object with an invalid sample rate. + + Returns: + None + """ + with pytest.raises(ValueError): + AudioProcessor(TEST_WAVEFORM, [44100, 48000]) + + +def test_audio_processor_SAMPLE_RATE(): + """Test the default sample rate of the AudioProcessor class. + + This test verifies that the default sample rate of the AudioProcessor class matches the expected value defined by the constant SAMPLE_RATE. It instantiates an AudioProcessor object with a test waveform + and checks whether the sample rate attribute (sr) of the AudioProcessor object equals the predefined constant SAMPLE_RATE. + + Returns: + None + """ + probe_audio_processor = AudioProcessor(TEST_WAVEFORM) + assert probe_audio_processor.sr == SAMPLE_RATE diff --git a/test/test_autotranscript.py b/tests/test_autotranscript.py similarity index 88% rename from test/test_autotranscript.py rename to tests/test_autotranscript.py index 865f507..fbf18ab 100644 --- a/test/test_autotranscript.py +++ b/tests/test_autotranscript.py @@ -19,19 +19,19 @@ def test_scraibe_init(create_scraibe_instance): def test_scraibe_autotranscribe(create_scraibe_instance): model = create_scraibe_instance - transcript = model.autotranscribe('./audio_test_2.mp4') + transcript = model.autotranscribe('tests/audio_test_2.mp4') assert isinstance(transcript, Transcript) def test_scraibe_diarization(create_scraibe_instance): model = create_scraibe_instance - diarisation_result = model.diarization('./audio_test_2.mp4') + diarisation_result = model.diarization('tests/audio_test_2.mp4') assert isinstance(diarisation_result, dict) def test_scraibe_transcribe(create_scraibe_instance): model = create_scraibe_instance - transcription_result = model.transcribe('./audio_test_2.mp4') + transcription_result = model.transcribe('tests/audio_test_2.mp4') assert isinstance(transcription_result, str) diff --git a/tests/test_diarisation.py b/tests/test_diarisation.py new file mode 100644 index 0000000..01431be --- /dev/null +++ b/tests/test_diarisation.py @@ -0,0 +1,32 @@ +import pytest +from scraibe import Diariser + + +@pytest.fixture +def diariser_instance(): + """Fixture for creating an instance of the Diariser class with mocked token. + + This fixture is used to create an instance of the the Diariser class with a mocked token returned by the _get_token method. It patches the _get_token method of the Diariser class + using unit.test.mock.patch.object, ensuring that it returns a predetrmined value ('personal Hugging-Face token'). The mocked Diariser object is retunrned and can be used as a dependency in otehr tests. + + Returns: + Diariser(Obj): An instance of the Diariser class with a mocked token. + """ + # with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ): + return Diariser('pyannote') + + +def test_Diariser_init(diariser_instance): + """Test the initialization of the Diariser class. + + This test verifies that the Diariser class is correctly initialized with the specified model. + It checks whether the 'model' attribute of the instantiated Diariser object equals 'pyannote'. + + + Args: + diariser_instance (obj): instance of the Diariser class + + Returns: + None + """ + assert diariser_instance.model == 'pyannote' diff --git a/tests/test_transcriber.py b/tests/test_transcriber.py new file mode 100644 index 0000000..ba9d99a --- /dev/null +++ b/tests/test_transcriber.py @@ -0,0 +1,80 @@ +import pytest +from scraibe import (Transcriber, WhisperTranscriber, + FasterWhisperTranscriber, load_transcriber) +import torch + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +TEST_WAVEFORM = "Hello World" + +""" +@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] ) +@patch("scraibe.Transcriber.load_model") + +def test_transcriber(mock_load_model, audio_file, expected_transcription): + + + Args: + mock_load_model (_type_): _description_ + audio_file (_type_): _description_ + expected_transcription (_type_): _description_ + + mock_model = mock_load_model.return_value + mock_model.transcribe.return_value ={"text": expected_transcription} + + transcriber = Transcriber.load_model(model="medium") + + transcription_result = transcriber.transcribe(audio=audio_file) + + assert transcription_result == expected_transcription """ + + +@pytest.fixture +def whisper_instance(): + return load_transcriber('tiny', whisper_type='whisper') + + +@pytest.fixture +def faster_whisper_instance(): + return load_transcriber('tiny', whisper_type='faster-whisper') + + +def test_whisper_base_initialization(whisper_instance): + assert isinstance(whisper_instance, Transcriber) + + +def test_faster_whisper_base_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, Transcriber) + + +def test_whisper_transcriber_initialization(whisper_instance): + assert isinstance(whisper_instance, WhisperTranscriber) + + +def test_faster_whisper_transcriber_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, FasterWhisperTranscriber) + + +def test_wrong_transcriber_initialization(): + with pytest.raises(ValueError): + load_transcriber('tiny', whisper_type='wrong_whisper') + + +def test_get_whisper_kwargs(): + kwargs = {"arg1": 1, "arg3": 3} + valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs) + assert not valid_kwargs == {"arg1": 1, "arg3": 3} + + +def test_whisper_transcribe(whisper_instance): + model = whisper_instance + # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) + transcript = model.transcribe('tests/audio_test_2.mp4') + assert isinstance(transcript, str) + + +def test_faster_whisper_transcribe(faster_whisper_instance): + model = faster_whisper_instance + # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) + transcript = model.transcribe('tests/audio_test_2.mp4') + assert isinstance(transcript, str) From 81fb9af4618e06e417d2ba8d87c0705104fcefbd Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 30 Sep 2024 16:04:45 +0000 Subject: [PATCH 19/30] removed old test files --- test/test_audio.py | 96 --------------------------------------- test/test_diarisation.py | 32 ------------- test/test_transcriber.py | 80 -------------------------------- tests/test_diarization.py | 10 ---- 4 files changed, 218 deletions(-) delete mode 100644 test/test_audio.py delete mode 100644 test/test_diarisation.py delete mode 100644 test/test_transcriber.py delete mode 100644 tests/test_diarization.py diff --git a/test/test_audio.py b/test/test_audio.py deleted file mode 100644 index aee6cb3..0000000 --- a/test/test_audio.py +++ /dev/null @@ -1,96 +0,0 @@ -import pytest -from scraibe.audio import AudioProcessor -import torch - - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE) -TEST_SR = 16000 -SAMPLE_RATE = 16000 -NORMALIZATION_FACTOR = 32768 - - -@pytest.fixture -def probe_audio_processor(): - """Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate. - - This fixture is used to create an instance of the AudioProcessor class with a predfined test waveform and sample rate (TEST_SR). It returns the instantiated AudioProcessor , which can bes used as a - dependency in other test functions. - - - Returns: - AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate. - """ - return AudioProcessor(TEST_WAVEFORM, TEST_SR) - - -def test_AudioProcessor_init(probe_audio_processor): - """ - Test the initialization of the AudioProcessor class. - - This test verifies that the AUdioProcessor class is correctly initialized with the provided waveform and sample rate. It checks whether the instantiated AhdioProcessor object has the correct attributes - and whether the waveform and sample rate match the expected values. - - Args: - probe_audio_processor (obj): An instance of the AudioProcessor class to be tested. - - - Returns: - None - - - - - """ - assert isinstance(probe_audio_processor, AudioProcessor) - assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device - assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM) - assert probe_audio_processor.sr == TEST_SR - - -def test_cut(probe_audio_processor): - """Test the cut function of the AudioProcessor class. - - This test verifies that the cut function correctly extracts a segment of audio data from - the waveform, given start and end indices. It checks whether the size of the extracted segment matches - the expected size based on the provided start and end indices and the sample rate. - - Returns: - None - - - """ - - start = 4 - end = 7 - trimmed_waveform = probe_audio_processor.cut(start, end) - expected_size = int((end - start) * TEST_SR) - real_size = trimmed_waveform.size(0) - assert real_size == expected_size - # assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR) - - -def test_audio_processor_invalid_sr(): - """Test the behavior of AudioProcessor when an invalid smaple rate is provided. - - This test verifies that the AudioProcessor constructor raises a ValueError when an invalid sample rate is provided. It uses the pytest.raises context manager to check if the ValueError is raised when initializing an - AudioProcessor object with an invalid sample rate. - - Returns: - None - """ - with pytest.raises(ValueError): - AudioProcessor(TEST_WAVEFORM, [44100, 48000]) - - -def test_audio_processor_SAMPLE_RATE(): - """Test the default sample rate of the AudioProcessor class. - - This test verifies that the default sample rate of the AudioProcessor class matches the expected value defined by the constant SAMPLE_RATE. It instantiates an AudioProcessor object with a test waveform - and checks whether the sample rate attribute (sr) of the AudioProcessor object equals the predefined constant SAMPLE_RATE. - - Returns: - None - """ - probe_audio_processor = AudioProcessor(TEST_WAVEFORM) - assert probe_audio_processor.sr == SAMPLE_RATE diff --git a/test/test_diarisation.py b/test/test_diarisation.py deleted file mode 100644 index 01431be..0000000 --- a/test/test_diarisation.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest -from scraibe import Diariser - - -@pytest.fixture -def diariser_instance(): - """Fixture for creating an instance of the Diariser class with mocked token. - - This fixture is used to create an instance of the the Diariser class with a mocked token returned by the _get_token method. It patches the _get_token method of the Diariser class - using unit.test.mock.patch.object, ensuring that it returns a predetrmined value ('personal Hugging-Face token'). The mocked Diariser object is retunrned and can be used as a dependency in otehr tests. - - Returns: - Diariser(Obj): An instance of the Diariser class with a mocked token. - """ - # with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ): - return Diariser('pyannote') - - -def test_Diariser_init(diariser_instance): - """Test the initialization of the Diariser class. - - This test verifies that the Diariser class is correctly initialized with the specified model. - It checks whether the 'model' attribute of the instantiated Diariser object equals 'pyannote'. - - - Args: - diariser_instance (obj): instance of the Diariser class - - Returns: - None - """ - assert diariser_instance.model == 'pyannote' diff --git a/test/test_transcriber.py b/test/test_transcriber.py deleted file mode 100644 index a805868..0000000 --- a/test/test_transcriber.py +++ /dev/null @@ -1,80 +0,0 @@ -import pytest -from scraibe import (Transcriber, WhisperTranscriber, - FasterWhisperTranscriber, load_transcriber) -import torch - - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -TEST_WAVEFORM = "Hello World" - -""" -@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] ) -@patch("scraibe.Transcriber.load_model") - -def test_transcriber(mock_load_model, audio_file, expected_transcription): - - - Args: - mock_load_model (_type_): _description_ - audio_file (_type_): _description_ - expected_transcription (_type_): _description_ - - mock_model = mock_load_model.return_value - mock_model.transcribe.return_value ={"text": expected_transcription} - - transcriber = Transcriber.load_model(model="medium") - - transcription_result = transcriber.transcribe(audio=audio_file) - - assert transcription_result == expected_transcription """ - - -@pytest.fixture -def whisper_instance(): - return load_transcriber('tiny', whisper_type='whisper') - - -@pytest.fixture -def faster_whisper_instance(): - return load_transcriber('tiny', whisper_type='faster-whisper') - - -def test_whisper_base_initialization(whisper_instance): - assert isinstance(whisper_instance, Transcriber) - - -def test_faster_whisper_base_initialization(faster_whisper_instance): - assert isinstance(faster_whisper_instance, Transcriber) - - -def test_whisper_transcriber_initialization(whisper_instance): - assert isinstance(whisper_instance, WhisperTranscriber) - - -def test_faster_whisper_transcriber_initialization(faster_whisper_instance): - assert isinstance(faster_whisper_instance, FasterWhisperTranscriber) - - -def test_wrong_transcriber_initialization(): - with pytest.raises(ValueError): - load_transcriber('tiny', whisper_type='wrong_whisper') - - -def test_get_whisper_kwargs(): - kwargs = {"arg1": 1, "arg3": 3} - valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs) - assert not valid_kwargs == {"arg1": 1, "arg3": 3} - - -def test_whisper_transcribe(whisper_instance): - model = whisper_instance - # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) - transcript = model.transcribe('./audio_test_2.mp4') - assert isinstance(transcript, str) - - -def test_faster_whisper_transcribe(faster_whisper_instance): - model = faster_whisper_instance - # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) - transcript = model.transcribe('./audio_test_2.mp4') - assert isinstance(transcript, str) diff --git a/tests/test_diarization.py b/tests/test_diarization.py deleted file mode 100644 index f9e81a5..0000000 --- a/tests/test_diarization.py +++ /dev/null @@ -1,10 +0,0 @@ -from os import environ - -environ["AUTOT_CACHE"] = "/mnt/disk1/Projekte/ScrAIbe/tests" -# environ["PYANNOTE_CACHE"] = "/mnt/disk1/Projekte/ScrAIbe/tests/pyannote" -# environ["TORCH_HOME"] = "/mnt/disk1/Projekte/ScrAIbe/tests/torch" - -from scraibe import Scraibe - -scraibe = Scraibe(whisper_type = "faster-whisper", whisper_model = "tiny") -print(scraibe.autotranscribe('/mnt/disk1/Projekte/ScrAIbe/test/audio_test_1.mp4')) \ No newline at end of file From 575a8de48d39b6756ce671842abc3f066a79d5c5 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Tue, 1 Oct 2024 17:47:55 +0200 Subject: [PATCH 20/30] Re-fixed PYANNOTE_CACHE, now writes the variable. --- scraibe/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scraibe/misc.py b/scraibe/misc.py index 106b9e1..e865f52 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -7,7 +7,7 @@ "AUTOT_CACHE", os.path.expanduser("~/.cache/torch/models"), ) -os.getenv( +os.environ["PYANNOTE_CACHE"] = os.getenv( "PYANNOTE_CACHE", os.path.join(CACHE_DIR, "pyannote"), ) From b07f593fab7ee7ef54d349640563b3005d051133 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:58:32 +0000 Subject: [PATCH 21/30] Bump openai-whisper from 20231117 to 20240930 Bumps [openai-whisper](https://github.com/openai/whisper) from 20231117 to 20240930. - [Release notes](https://github.com/openai/whisper/releases) - [Changelog](https://github.com/openai/whisper/blob/main/CHANGELOG.md) - [Commits](https://github.com/openai/whisper/compare/v20231117...v20240930) --- updated-dependencies: - dependency-name: openai-whisper dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8c46bdb..6aabcd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ exclude =[ python = "^3.9" tqdm = "^4.66.4" numpy = "^1.26.4" -openai-whisper = "^20231117" +openai-whisper = ">=20231117,<20240931" whisperx = "^3.1.3" "pyannote.audio" = "^3.1.1" torch = "^2.3.0" From fa1dad69d1427e5d4389bfab2a9aac6ee073a909 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 15:00:49 +0000 Subject: [PATCH 22/30] Update sphinx-rtd-theme requirement from ^2.0.0 to >=2,<4 Updates the requirements on [sphinx-rtd-theme](https://github.com/readthedocs/sphinx_rtd_theme) to permit the latest version. - [Changelog](https://github.com/readthedocs/sphinx_rtd_theme/blob/master/docs/changelog.rst) - [Commits](https://github.com/readthedocs/sphinx_rtd_theme/compare/2.0.0...3.0.0) --- updated-dependencies: - dependency-name: sphinx-rtd-theme dependency-type: direct:development ... Signed-off-by: dependabot[bot] --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8c46bdb..6a667d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ format-jinja = """ [tool.poetry.group.docs.dependencies] sphinx = "^7.3.7" -sphinx-rtd-theme = "^2.0.0" +sphinx-rtd-theme = ">=2,<4" markdown-it-py = {version = "~3.0.0", extras = ["plugins"]} myst-parser = "^3.0.1" mdit-py-plugins = "^0.4.1" From 6fadf3d851c06ffc130bfd4d6e758d7da5850830 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 8 Oct 2024 12:01:36 +0000 Subject: [PATCH 23/30] removed torch device from AudioProcessor class --- scraibe/audio.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/scraibe/audio.py b/scraibe/audio.py index 7fbc6fb..4e5dd0f 100644 --- a/scraibe/audio.py +++ b/scraibe/audio.py @@ -41,26 +41,20 @@ class AudioProcessor: The sample rate of the audio. """ - def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, - *args, **kwargs) -> None: + def __init__(self, waveform: torch.Tensor, + sr: int = SAMPLE_RATE) -> None: """ Initialize the AudioProcessor object. Args: waveform (torch.Tensor): The audio waveform tensor. sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. - args: Additional arguments. - kwargs: Additional keyword arguments, e.g., device to use for processing. - If CUDA is available, it defaults to CUDA. Raises: ValueError: If the provided sample rate is not of type int. """ - device = kwargs.get( - "device", "cuda" if torch.cuda.is_available() else "cpu") - - self.waveform = waveform.to(device) + self.waveform = waveform self.sr = sr if not isinstance(self.sr, int): @@ -147,6 +141,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): np.float32) / NORMALIZATION_FACTOR return out, sr - + def __repr__(self) -> str: return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' From 8813662d4df9cbea51940a82530c2782c8f22f28 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 8 Oct 2024 12:02:08 +0000 Subject: [PATCH 24/30] added SCRAIBE_TORCH_DEVICE to Scraibe Class to handle torch device setting --- scraibe/autotranscript.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 43dedc2..9023107 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -40,6 +40,7 @@ from .diarisation import Diariser from .transcriber import Transcriber, load_transcriber, whisper from .transcript_exporter import Transcript +from .misc import SCRAIBE_TORCH_DEVICE DiarisationType = TypeVar('DiarisationType') @@ -115,6 +116,9 @@ def __init__(self, **kwargs) else: self.params = {} + + self.device = kwargs.get( + "device", SCRAIBE_TORCH_DEVICE) def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], remove_original: bool = False, @@ -141,10 +145,10 @@ def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } - + if self.verbose: print("Starting diarisation.") @@ -165,8 +169,6 @@ def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], if self.verbose: print("Diarisation finished. Starting transcription.") - audio_file.sr = torch.Tensor([audio_file.sr]).to( - audio_file.waveform.device) # Transcribe each segment and store the results final_transcript = dict() @@ -213,7 +215,7 @@ def diarization(self, audio_file: Union[str, torch.Tensor, ndarray], # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } @@ -323,8 +325,7 @@ def remove_audio_file(audio_file: str, print(f"Audiofile {audio_file} removed.") @staticmethod - def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], - *args, **kwargs) -> AudioProcessor: + def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor: """Gets an audio file as TorchAudioProcessor. Args: From 44ff678e06aa99b0fdced7dd2b5675ec2165e495 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 8 Oct 2024 12:02:30 +0000 Subject: [PATCH 25/30] added SCRAIBE_TORCH_DEVICE Variable --- scraibe/misc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scraibe/misc.py b/scraibe/misc.py index 106b9e1..4a3de57 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -2,6 +2,7 @@ import yaml from argparse import Action from ast import literal_eval +from torch.cuda import is_available CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -18,6 +19,7 @@ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') +SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu") def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. From af99a655e593093494600eb25353b82f4a44dcd6 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Thu, 10 Oct 2024 09:22:34 +0000 Subject: [PATCH 26/30] added SCRAIBE_TORCH_DEVICE to Diariser class --- scraibe/diarisation.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index d70df99..6e6d6b9 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -41,7 +41,7 @@ from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -190,8 +190,7 @@ def load_model(cls, cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, - device: str = None, - *args, **kwargs + device: str = SCRAIBE_TORCH_DEVICE, ) -> Pipeline: """ Loads a pretrained model from pyannote.audio, @@ -283,10 +282,6 @@ def load_model(cls, 'or from huggingface.co models. Please check your token' 'or your local model path') - # try to move the model to the device - if device is None: - device = "cuda" if is_available() else "cpu" - # torch_device is renamed from torch.device to avoid name conflict _model = _model.to(torch_device(device)) From e7c1a5a2b01263acddb80c781d1c26292fc6210a Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Thu, 10 Oct 2024 09:25:49 +0000 Subject: [PATCH 27/30] added SCRAIBE_TORCH_DEVICE to transcriber class --- scraibe/transcriber.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index abf1ace..9c891f6 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -37,7 +37,7 @@ from abc import abstractmethod import warnings -from .misc import WHISPER_DEFAULT_PATH +from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE whisper = TypeVar('whisper') @@ -124,7 +124,7 @@ def load_model(cls, model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> None: @@ -206,7 +206,7 @@ def transcribe(self, audio: Union[str, Tensor, ndarray], def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> 'WhisperTranscriber': @@ -305,7 +305,7 @@ def transcribe(self, audio: Union[str, Tensor, ndarray], def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, *args, **kwargs ) -> 'FasterWhisperModel': """ @@ -330,7 +330,7 @@ def load_model(cls, Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. @@ -339,10 +339,10 @@ def load_model(cls, Returns: Transcriber: A Transcriber object initialized with the specified model. """ - if device is None: - device = "cuda" if cuda_is_available() else "cpu" + if not isinstance(device, str): device = str(device) + compute_type = kwargs.get('compute_type', 'float16') if device == 'cpu' and compute_type == 'float16': warnings.warn(f'Compute type {compute_type} not compatible with ' @@ -412,7 +412,7 @@ def __repr__(self) -> str: def load_transcriber(model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: @@ -438,7 +438,7 @@ def load_transcriber(model: str = "medium", download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. From 101e913f849ce450dea400a4681095d6c39d455f Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Thu, 10 Oct 2024 09:29:48 +0000 Subject: [PATCH 28/30] make ruff happy --- scraibe/diarisation.py | 2 +- scraibe/transcriber.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 6e6d6b9..eeef135 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -37,7 +37,7 @@ from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor from torch import device as torch_device -from torch.cuda import is_available + from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 9c891f6..040b79d 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -31,7 +31,6 @@ from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES from typing import TypeVar, Union, Optional from torch import Tensor, device -from torch.cuda import is_available as cuda_is_available from numpy import ndarray from inspect import signature from abc import abstractmethod From 08f14883e25391c03e29c61e659bdcab547d7a81 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Thu, 10 Oct 2024 14:07:36 +0200 Subject: [PATCH 29/30] Bring dockerfile up to date --- Dockerfile | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/Dockerfile b/Dockerfile index 95db093..a9feb8e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ #pytorch Image -FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime +FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime # Labels @@ -14,33 +14,31 @@ LABEL url="https://github.com/JSchmie/ScrAIbe" # Install dependencies WORKDIR /app -ARG model_name=medium -#Enviorment Dependncies -ENV TRANSFORMERS_CACHE /app/models -ENV HF_HOME /app/models -ENV AUTOT_CACHE /app/models -ENV PYANNOTE_CACHE /app/models/pyannote +#Enviorment dependencies +ENV TRANSFORMERS_CACHE=/app/models +ENV HF_HOME=/app/models +ENV AUTOT_CACHE=/app/models +ENV PYANNOTE_CACHE=/app/models/pyannote #Copy all necessary files COPY requirements.txt /app/requirements.txt COPY README.md /app/README.md -COPY models /app/models COPY scraibe /app/scraibe -COPY setup.py /app/setup.py -#Installing all necessary Dependencies and Running the Application with a personalised Hugging-Face-Token -RUN apt update && apt-get install -y libsm6 libxrender1 libfontconfig1 -RUN conda update --all +#Installing all necessary dependencies and running the application with a personalised Hugging-Face-Token +RUN apt update -y && apt upgrade -y && \ + apt install -y libsm6 libxrender1 libfontconfig1 && \ + apt clean && \ + rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -RUN conda install pip -RUN conda install -y ffmpeg -RUN conda install -c conda-forge libsndfile -RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html -RUN pip install -r requirements.txt -RUN pip install markupsafe==2.0.1 --force-reinstall +RUN conda update --all && \ + # conda install -y pip ffmpeg && \ + conda install -c conda-forge libsndfile && \ + conda clean --all -y +# RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html +RUN pip install --no-cache-dir -r requirements.txt -RUN python3 -m 'scraibe.cli' --whisper-model-name $model_name # Expose port EXPOSE 7860 # Run the application -ENTRYPOINT ["python3", "-m", "scraibe.cli" ,"--whisper-model-name", "$model_name"] \ No newline at end of file +ENTRYPOINT ["python3", "-m", "scraibe.cli"] \ No newline at end of file From de9071762e69db7eafacb819c2926d94b1a1f163 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Thu, 10 Oct 2024 14:33:14 +0200 Subject: [PATCH 30/30] Docker image generation workflow from webui --- .github/workflows/docker.yaml | 95 +++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 .github/workflows/docker.yaml diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml new file mode 100644 index 0000000..75bd418 --- /dev/null +++ b/.github/workflows/docker.yaml @@ -0,0 +1,95 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# GitHub recommends pinning actions to a commit SHA. +# To get a newer version, you will need to update the SHA. +# You can also reference a tag or branch, but the action may change without warning. + +name: Publish Docker image + +on: + + push: + tags: + - v* + + workflow_dispatch: + +env: + image: hadr0n/scraibe + +jobs: + push_to_registry: + name: Push Docker image to Docker Hub + runs-on: ubuntu-latest + permissions: + packages: write + contents: read + security-events: write + steps: + - name: Check out the repo + uses: actions/checkout@v4 + with: + fetch-tags: true + fetch-depth: 0 + + - name: Get Version Tag + id: version + run: | + echo "tag=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT + + - name: Overwrite label tag + run: sed -i 's/LABEL version=".*"/LABEL version="'${{ steps.version.outputs.tag }}'"/' Dockerfile + + - name: Test name and tag + run: | + echo "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}" + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push Docker image + id: push + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile + push: true + tags: "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}" + + - name: SBOM Generation + uses: anchore/sbom-action@v0 + with: + image: ${{ env.image }}:latest + + - name: Scan image + id: scan + uses: anchore/scan-action@v3 + with: + image: ${{ env.image }}:latest + fail-build: false + + - name: upload Anchore scan SARIF report + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: ${{ steps.scan.outputs.sarif }} + + # - name: Inspect action SARIF report + # run: cat ${{ steps.scan.outputs.sarif }} + + - uses: actions/upload-artifact@v4 + with: + name: SARIF report + path: ${{ steps.scan.outputs.sarif }} + + # - name: Generate artifact attestation + # uses: actions/attest-build-provenance@v1 + # with: + # subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}} + # subject-digest: ${{ steps.push.outputs.digest }} + # push-to-registry: false