diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml new file mode 100644 index 0000000..75bd418 --- /dev/null +++ b/.github/workflows/docker.yaml @@ -0,0 +1,95 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# GitHub recommends pinning actions to a commit SHA. +# To get a newer version, you will need to update the SHA. +# You can also reference a tag or branch, but the action may change without warning. + +name: Publish Docker image + +on: + + push: + tags: + - v* + + workflow_dispatch: + +env: + image: hadr0n/scraibe + +jobs: + push_to_registry: + name: Push Docker image to Docker Hub + runs-on: ubuntu-latest + permissions: + packages: write + contents: read + security-events: write + steps: + - name: Check out the repo + uses: actions/checkout@v4 + with: + fetch-tags: true + fetch-depth: 0 + + - name: Get Version Tag + id: version + run: | + echo "tag=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT + + - name: Overwrite label tag + run: sed -i 's/LABEL version=".*"/LABEL version="'${{ steps.version.outputs.tag }}'"/' Dockerfile + + - name: Test name and tag + run: | + echo "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}" + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push Docker image + id: push + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile + push: true + tags: "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}" + + - name: SBOM Generation + uses: anchore/sbom-action@v0 + with: + image: ${{ env.image }}:latest + + - name: Scan image + id: scan + uses: anchore/scan-action@v3 + with: + image: ${{ env.image }}:latest + fail-build: false + + - name: upload Anchore scan SARIF report + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: ${{ steps.scan.outputs.sarif }} + + # - name: Inspect action SARIF report + # run: cat ${{ steps.scan.outputs.sarif }} + + - uses: actions/upload-artifact@v4 + with: + name: SARIF report + path: ${{ steps.scan.outputs.sarif }} + + # - name: Generate artifact attestation + # uses: actions/attest-build-provenance@v1 + # with: + # subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}} + # subject-digest: ${{ steps.push.outputs.digest }} + # push-to-registry: false diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index a2641ee..1c75489 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -1,18 +1,14 @@ name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI on: - pull_request_target: - branches: - - develop - types: - - closed - paths: - - scraibe/** - - pyproject.toml - push: tags: - 'v*.*.*' + branches: + - "develop" + paths: + - "scraibe/**" + - "pyproject.toml" workflow_dispatch: inputs: @@ -27,13 +23,7 @@ on: jobs: Build-and-publish-to-Test-PyPI: - if: | - (github.event_name == 'workflow_dispatch' && - github.event.inputs.test == 'true') || - (github.event_name == 'pull_request_target' && - github.event.pull_request.merged && - contains(github.event.pull_request.labels.*.name, 'release')) || - (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')) + if: github.event_name != 'workflow_dispatch' || github.event.inputs.test == 'true' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ 
-72,28 +62,16 @@ jobs: needs: Test-PyPi-install runs-on: ubuntu-latest if: | - always() && - (( needs.Build-and-publish-to-Test-PyPI.result != 'failure' && - needs.Test-PyPi-install.result != 'failure' ) && - ((github.event_name == 'workflow_dispatch' && - github.event.inputs.publish_to_pypi == 'true') || - (github.event_name == 'pull_request_target' && - github.event.pull_request.merged && - contains(github.event.pull_request.labels.*.name, 'release')) || - (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')))) + always() && + (( needs.Build-and-publish-to-Test-PyPI.result != 'failure' && + needs.Test-PyPi-install.result != 'failure' ) || + ((github.event_name == 'workflow_dispatch' && + github.event.inputs.publish_to_pypi == 'true'))) steps: - - name: Checkout Repository Tags - uses: actions/checkout@v4 - if: github.ref == 'refs/heads/main' - with: - fetch-depth: '0' - branch: 'main' - name: Checkout Repository (Develop) uses: actions/checkout@v4 - if: github.ref == 'refs/heads/develop' with: fetch-depth: '0' - branch: 'develop' - name: Set up Poetry 📦 uses: JRubics/poetry-publish@v1.16 with: diff --git a/Dockerfile b/Dockerfile index 95db093..a9feb8e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ #pytorch Image -FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime +FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime # Labels @@ -14,33 +14,31 @@ LABEL url="https://github.com/JSchmie/ScrAIbe" # Install dependencies WORKDIR /app -ARG model_name=medium -#Enviorment Dependncies -ENV TRANSFORMERS_CACHE /app/models -ENV HF_HOME /app/models -ENV AUTOT_CACHE /app/models -ENV PYANNOTE_CACHE /app/models/pyannote +#Environment dependencies +ENV TRANSFORMERS_CACHE=/app/models +ENV HF_HOME=/app/models +ENV AUTOT_CACHE=/app/models +ENV PYANNOTE_CACHE=/app/models/pyannote #Copy all necessary files COPY requirements.txt /app/requirements.txt COPY README.md /app/README.md -COPY models /app/models COPY scraibe /app/scraibe -COPY setup.py /app/setup.py -#Installing all necessary Dependencies and Running the Application with a personalised Hugging-Face-Token -RUN apt update && apt-get install -y libsm6 libxrender1 libfontconfig1 -RUN conda update --all +#Installing all necessary system dependencies +RUN apt update -y && apt upgrade -y && \ + apt install -y libsm6 libxrender1 libfontconfig1 && \ + apt clean && \ + rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -RUN conda install pip -RUN conda install -y ffmpeg -RUN conda install -c conda-forge libsndfile -RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html -RUN pip install -r requirements.txt -RUN pip install markupsafe==2.0.1 --force-reinstall +RUN conda update --all && \ + # conda install -y pip ffmpeg && \ + conda install -c conda-forge libsndfile && \ + conda clean --all -y +# RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html +RUN pip install --no-cache-dir -r requirements.txt -RUN python3 -m 'scraibe.cli' --whisper-model-name $model_name # Expose port EXPOSE 7860 # Run the application -ENTRYPOINT ["python3", "-m", "scraibe.cli" ,"--whisper-model-name", "$model_name"] \ No newline at end of file diff --git a/environment.yml b/environment.yml deleted file mode 100644 index 7913480..0000000 --- a/environment.yml +++ /dev/null @@ -1,256 +0,0 @@ -channels: - - pytorch - - defaults -dependencies: - - 
_libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - blas=1.0=mkl - - brotlipy=0.7.0=py39h27cfd23_1003 - - bzip2=1.0.8=h7b6447c_0 - - ca-certificates=2023.05.30=h06a4308_0 - - certifi=2023.5.7=py39h06a4308_0 - - cffi=1.15.1=py39h5eee18b_3 - - cryptography=39.0.1=py39h9ce1e76_2 - - cudatoolkit=11.3.1=h2bc3f7f_2 - - ffmpeg=4.2.2=h20bf706_0 - - flit-core=3.8.0=py39h06a4308_0 - - freetype=2.12.1=h4a9f257_0 - - giflib=5.2.1=h5eee18b_3 - - gmp=6.2.1=h295c915_3 - - gnutls=3.6.15=he1e5248_0 - - idna=3.4=py39h06a4308_0 - - intel-openmp=2021.4.0=h06a4308_3561 - - jpeg=9e=h5eee18b_1 - - lame=3.100=h7b6447c_0 - - lcms2=2.12=h3be6417_0 - - ld_impl_linux-64=2.38=h1181459_1 - - lerc=3.0=h295c915_0 - - libdeflate=1.17=h5eee18b_0 - - libffi=3.4.2=h6a678d5_6 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libidn2=2.3.2=h7f8727e_0 - - libopus=1.3.1=h7b6447c_0 - - libpng=1.6.39=h5eee18b_0 - - libstdcxx-ng=11.2.0=h1234567_1 - - libtasn1=4.16.0=h27cfd23_0 - - libtiff=4.5.0=h6a678d5_2 - - libunistring=0.9.10=h27cfd23_0 - - libuv=1.44.2=h5eee18b_0 - - libvpx=1.7.0=h439df22_0 - - libwebp=1.2.4=h11a3e52_1 - - libwebp-base=1.2.4=h5eee18b_1 - - lz4-c=1.9.4=h6a678d5_0 - - mkl=2021.4.0=h06a4308_640 - - mkl-service=2.4.0=py39h7f8727e_0 - - mkl_fft=1.3.1=py39hd3c417c_0 - - mkl_random=1.2.2=py39h51133e4_0 - - ncurses=6.4=h6a678d5_0 - - nettle=3.7.3=hbbd107a_1 - - numpy=1.23.5=py39h14f4228_0 - - numpy-base=1.23.5=py39h31eccc5_0 - - openh264=2.1.1=h4ff587b_0 - - openssl=3.0.9=h7f8727e_0 - - pillow=9.4.0=py39h6a678d5_0 - - pip=23.0.1=py39h06a4308_0 - - pycparser=2.21=pyhd3eb1b0_0 - - pyopenssl=23.0.0=py39h06a4308_0 - - pysocks=1.7.1=py39h06a4308_0 - - python=3.9.16=h955ad1f_3 - - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 - - pytorch-mutex=1.0=cuda - - readline=8.2=h5eee18b_0 - - requests=2.28.1=py39h06a4308_1 - - setuptools=65.6.3=py39h06a4308_0 - - six=1.16.0=pyhd3eb1b0_1 - - sqlite=3.41.2=h5eee18b_0 - - tk=8.6.12=h1ccaba5_0 - - torchaudio=0.11.0=py39_cu113 - - torchvision=0.12.0=py39_cu113 - - tzdata=2023c=h04d1e81_0 - - wheel=0.38.4=py39h06a4308_0 - - x264=1!157.20191217=h7b6447c_0 - - xz=5.4.2=h5eee18b_0 - - zlib=1.2.13=h5eee18b_0 - - zstd=1.5.4=hc292b87_0 - - pip: - - absl-py==1.3.0 - - aiofiles==23.1.0 - - aiohttp==3.8.3 - - aiosignal==1.3.1 - - alembic==1.9.1 - - altair==5.0.1 - - annotated-types==0.5.0 - - ansi2html==1.8.0 - - antlr4-python3-runtime==4.9.3 - - anyio==3.7.1 - - appdirs==1.4.4 - - asteroid-filterbanks==0.4.0 - - async-timeout==4.0.2 - - attrs==22.2.0 - - audioread==3.0.0 - - autopage==0.5.1 - - backports-cached-property==1.0.2 - - cachetools==5.2.0 - - charset-normalizer==2.1.1 - - click==8.1.3 - - cliff==4.1.0 - - cmaes==0.9.0 - - cmake==3.26.4 - - cmd2==2.4.2 - - colorama==0.4.6 - - colorlog==6.7.0 - - commonmark==0.9.1 - - contourpy==1.0.6 - - cycler==0.11.0 - - dash==2.12.1 - - dash-core-components==2.0.0 - - dash-html-components==2.0.0 - - dash-table==5.0.0 - - decorator==4.4.2 - - docopt==0.6.2 - - einops==0.3.2 - - exceptiongroup==1.1.1 - - fastapi==0.100.0 - - ffmpeg-python==0.2.0 - - ffmpy==0.3.0 - - filelock==3.8.0 - - flask==2.2.5 - - fonttools==4.38.0 - - frozenlist==1.3.3 - - fsspec==2022.11.0 - - future==0.18.2 - - google-auth==2.15.0 - - google-auth-oauthlib==0.4.6 - - gradio==3.36.1 - - gradio-client==0.2.7 - - greenlet==2.0.1 - - grpcio==1.51.1 - - h11==0.14.0 - - hmmlearn==0.2.8 - - httpcore==0.17.3 - - httpx==0.24.1 - - huggingface-hub==0.16.4 - - humanize==4.7.0 - - hyperpyyaml==1.1.0 - - imageio==2.23.0 - - imageio-ffmpeg==0.4.7 - - importlib-metadata==4.13.0 - 
- importlib-resources==5.12.0 - - iniconfig==2.0.0 - - itsdangerous==2.1.2 - - jinja2==3.1.2 - - joblib==1.2.0 - - jsonschema==4.18.0 - - jsonschema-specifications==2023.6.1 - - julius==0.2.7 - - kiwisolver==1.4.4 - - librosa==0.9.2 - - linkify-it-py==2.0.2 - - lit==16.0.5.post0 - - llvmlite==0.39.1 - - mako==1.2.4 - - markdown==3.4.1 - - markdown-it-py==2.2.0 - - markupsafe==2.1.1 - - matplotlib==3.7.1 - - mdit-py-plugins==0.3.3 - - mdurl==0.1.2 - - more-itertools==9.0.0 - - moviepy==1.0.3 - - mpmath==1.2.1 - - multidict==6.0.4 - - nest-asyncio==1.5.7 - - networkx==2.8.8 - - numba==0.56.4 - - oauthlib==3.2.2 - - omegaconf==2.3.0 - - openai-whisper==20230314 - - optuna==3.0.5 - - orjson==3.9.2 - - packaging==21.3 - - pandas==1.5.2 - - pbr==5.11.0 - - plotly==5.15.0 - - pluggy==1.0.0 - - pooch==1.6.0 - - prettytable==3.5.0 - - primepy==1.3 - - proglog==0.1.10 - - protobuf==3.20.1 - - pyannote-audio==2.1.1 - - pyannote-core==4.5 - - pyannote-database==4.1.3 - - pyannote-metrics==3.2.1 - - pyannote-pipeline==2.3 - - pyasn1==0.4.8 - - pyasn1-modules==0.2.8 - - pydantic==2.0.2 - - pydantic-core==2.1.2 - - pydeprecate==0.3.2 - - pydub==0.25.1 - - pygments==2.13.0 - - pyparsing==3.0.9 - - pyperclip==1.8.2 - - pytest==7.3.1 - - python-dateutil==2.8.2 - - python-multipart==0.0.6 - - pytorch-lightning==1.6.5 - - pytorch-metric-learning==1.6.3 - - pytz==2022.7 - - pyyaml==6.0 - - qtfaststart==1.8 - - referencing==0.29.1 - - regex==2022.10.31 - - requests-oauthlib==1.3.1 - - resampy==0.4.2 - - retrying==1.3.4 - - rich==12.6.0 - - rpds-py==0.8.10 - - rsa==4.9 - - ruamel-yaml==0.17.21 - - ruamel-yaml-clib==0.2.7 - - ruff==0.0.272 - - scikit-learn==1.2.0 - - scipy==1.8.1 - - semantic-version==2.10.0 - - semver==2.13.0 - - sentencepiece==0.1.97 - - setuptools-rust==1.5.2 - - shellingham==1.5.0 - - simplejson==3.18.0 - - singledispatchmethod==1.0 - - sniffio==1.3.0 - - sortedcontainers==2.4.0 - - soundfile==0.10.3.post1 - - speechbrain==0.5.14 - - sqlalchemy==1.4.45 - - starlette==0.27.0 - - stevedore==4.1.1 - - sympy==1.11.1 - - tabulate==0.9.0 - - tenacity==8.2.2 - - tensorboard==2.11.0 - - tensorboard-data-server==0.6.1 - - tensorboard-plugin-wit==1.8.1 - - threadpoolctl==3.1.0 - - tiktoken==0.3.1 - - tokenizers==0.13.2 - - tomli==2.0.1 - - toolz==0.12.0 - - torch-audiomentations==0.11.0 - - torch-pitch-shift==1.2.2 - - torchmetrics==0.11.0 - - tqdm==4.64.1 - - transformers==4.24.0 - - triton==2.0.0 - - typer==0.7.0 - - typing-extensions==4.7.1 - - uc-micro-py==1.0.2 - - urllib3==1.26.12 - - uvicorn==0.22.0 - - wcwidth==0.2.5 - - websockets==11.0.3 - - werkzeug==2.2.2 - - yarl==1.8.2 - - zipp==3.11.0 diff --git a/pyproject.toml b/pyproject.toml index 8c46bdb..c113502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,12 +31,12 @@ exclude =[ ] [tool.poetry.dependencies] python = "^3.9" -tqdm = "^4.66.4" +tqdm = "^4.66.5" numpy = "^1.26.4" -openai-whisper = "^20231117" -whisperx = "^3.1.3" -"pyannote.audio" = "^3.1.1" -torch = "^2.3.0" +openai-whisper = ">=20231117,<20240931" +faster-whisper = "^1.0.3" +"pyannote.audio" = "^3.3.1" +torch = "^2.1.2" [tool.poetry.group.dev.dependencies] pytest = "^8.1.1" @@ -57,7 +57,7 @@ format-jinja = """ [tool.poetry.group.docs.dependencies] sphinx = "^7.3.7" -sphinx-rtd-theme = "^2.0.0" +sphinx-rtd-theme = ">=2,<4" markdown-it-py = {version = "~3.0.0", extras = ["plugins"]} myst-parser = "^3.0.1" mdit-py-plugins = "^0.4.1" diff --git a/requirements.txt b/requirements.txt index f08e2e6..8786d84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,14 @@ 
-tqdm>=4.65.0 +tqdm>=4.66.5 numpy>=1.26.4 openai-whisper==20231117 -whisperx~=3.1.3 +faster-whisper~=1.0.3 -pyannote.audio~=3.1.1 +pyannote.audio~=3.3.1 pyannote.core~=5.0.0 pyannote.database~=5.0.1 pyannote.metrics~=3.2.1 pyannote.pipeline~=3.0.1 -torch>=2.0.0 +torchaudio>=2.1.2 diff --git a/scraibe/audio.py b/scraibe/audio.py index 7fbc6fb..4e5dd0f 100644 --- a/scraibe/audio.py +++ b/scraibe/audio.py @@ -41,26 +41,20 @@ class AudioProcessor: The sample rate of the audio. """ - def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, - *args, **kwargs) -> None: + def __init__(self, waveform: torch.Tensor, + sr: int = SAMPLE_RATE) -> None: """ Initialize the AudioProcessor object. Args: waveform (torch.Tensor): The audio waveform tensor. sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. - args: Additional arguments. - kwargs: Additional keyword arguments, e.g., device to use for processing. - If CUDA is available, it defaults to CUDA. Raises: ValueError: If the provided sample rate is not of type int. """ - device = kwargs.get( - "device", "cuda" if torch.cuda.is_available() else "cpu") - - self.waveform = waveform.to(device) + self.waveform = waveform self.sr = sr if not isinstance(self.sr, int): @@ -147,6 +141,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): np.float32) / NORMALIZATION_FACTOR return out, sr - + def __repr__(self) -> str: return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 7391f1a..9023107 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -40,6 +40,7 @@ from .diarisation import Diariser from .transcriber import Transcriber, load_transcriber, whisper from .transcript_exporter import Transcript +from .misc import SCRAIBE_TORCH_DEVICE DiarisationType = TypeVar('DiarisationType') @@ -74,7 +75,7 @@ def __init__(self, whisper_model (Union[bool, str, whisper], optional): Path to whisper model or whisper model itself. whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". diarisation_model (Union[bool, str, DiarisationType], optional): Path to pyannote diarization model or model itself. **kwargs: Additional keyword arguments for whisper @@ -115,6 +116,9 @@ def __init__(self, **kwargs) else: self.params = {} + + self.device = kwargs.get( + "device", SCRAIBE_TORCH_DEVICE) def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], remove_original: bool = False, @@ -141,10 +145,10 @@ def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } - + if self.verbose: print("Starting diarisation.") @@ -165,8 +169,6 @@ def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], if self.verbose: print("Diarisation finished. 
Starting transcription.") - audio_file.sr = torch.Tensor([audio_file.sr]).to( - audio_file.waveform.device) # Transcribe each segment and store the results final_transcript = dict() @@ -213,7 +215,7 @@ def diarization(self, audio_file: Union[str, torch.Tensor, ndarray], # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } @@ -323,8 +325,7 @@ def remove_audio_file(audio_file: str, print(f"Audiofile {audio_file} removed.") @staticmethod - def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], - *args, **kwargs) -> AudioProcessor: + def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor: """Gets an audio file as TorchAudioProcessor. Args: diff --git a/scraibe/cli.py b/scraibe/cli.py index ee40c8b..df73d1b 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -36,8 +36,8 @@ def str2bool(string): help="List of audio files to transcribe.") parser.add_argument("--whisper-type", type=str, default="whisper", - choices=["whisper", "whisperx"], - help="Type of Whisper model to use ('whisper' or 'whisperx').") + choices=["whisper", "faster-whisper"], + help="Type of Whisper model to use ('whisper' or 'faster-whisper').") parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") @@ -79,6 +79,8 @@ def str2bool(string): choices=sorted( LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="Language spoken in the audio. Specify None to perform language detection.") + parser.add_argument("--num-speakers", type=int, default=2, + help="Number of speakers in the audio.") args = parser.parse_args() @@ -117,8 +119,13 @@ def str2bool(string): else: task = "transcribe" - out = model.autotranscribe(audio, task=task, language=arg_dict.pop( - "language"), verbose=arg_dict.pop("verbose_output")) + out = model.autotranscribe( + audio, + task=task, + language=arg_dict.pop("language"), + verbose=arg_dict.pop("verbose_output"), + num_speakers=arg_dict.pop("num_speakers") + ) basename = audio.split("/")[-1].split(".")[0] print(f'Saving {basename}.{out_format} to {out_folder}') out.save(os.path.join( diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index d70df99..eeef135 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -37,11 +37,11 @@ from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor from torch import device as torch_device -from torch.cuda import is_available + from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -190,8 +190,7 @@ def load_model(cls, cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, - device: str = None, - *args, **kwargs + device: str = SCRAIBE_TORCH_DEVICE, ) -> Pipeline: """ Loads a pretrained model from pyannote.audio, @@ -283,10 +282,6 @@ def load_model(cls, 'or from huggingface.co models. 
Please check your token' 'or your local model path') - # try to move the model to the device - if device is None: - device = "cuda" if is_available() else "cpu" - # torch_device is renamed from torch.device to avoid name conflict _model = _model.to(torch_device(device)) diff --git a/scraibe/misc.py b/scraibe/misc.py index f12335f..4f5ab1a 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -1,23 +1,25 @@ import os import yaml -from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR from argparse import Action from ast import literal_eval +from torch.cuda import is_available CACHE_DIR = os.getenv( "AUTOT_CACHE", os.path.expanduser("~/.cache/torch/models"), ) - -if CACHE_DIR != PYANNOTE_CACHE_DIR: - os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote") +os.environ["PYANNOTE_CACHE"] = os.getenv( + "PYANNOTE_CACHE", + os.path.join(CACHE_DIR, "pyannote"), +) WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') + else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') +SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu") def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 0301955..040b79d 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -26,17 +26,17 @@ from whisper import Whisper from whisper import load_model as whisper_load_model -from whisperx.asr import WhisperModel -from whisperx import load_model as whisperx_load_model +from whisper.tokenizer import TO_LANGUAGE_CODE +from faster_whisper import WhisperModel as FasterWhisperModel +from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES from typing import TypeVar, Union, Optional from torch import Tensor, device -from torch.cuda import is_available as cuda_is_available from numpy import ndarray from inspect import signature from abc import abstractmethod import warnings -from .misc import WHISPER_DEFAULT_PATH +from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE whisper = TypeVar('whisper') @@ -123,7 +123,7 @@ def load_model(cls, model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> None: @@ -145,7 +145,7 @@ def load_model(cls, - 'large-v3' - 'large' whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. 
device (Optional[Union[str, torch.device]], optional): @@ -205,7 +205,7 @@ def transcribe(self, audio: Union[str, Tensor, ndarray], def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> 'WhisperTranscriber': @@ -272,7 +272,7 @@ def __repr__(self) -> str: return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})" -class WhisperXTranscriber(Transcriber): +class FasterWhisperTranscriber(Transcriber): def __init__(self, model: whisper, model_name: str) -> None: super().__init__(model, model_name) @@ -294,19 +294,19 @@ def transcribe(self, audio: Union[str, Tensor, ndarray], if isinstance(audio, Tensor): audio = audio.cpu().numpy() - result = self.model.transcribe(audio, *args, **kwargs) + result, _ = self.model.transcribe(audio, *args, **kwargs) text = "" - for seg in result['segments']: - text += seg['text'] + for seg in result: + text += seg.text return text @classmethod def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, *args, **kwargs - ) -> 'WhisperXTranscriber': + ) -> 'FasterWhisperTranscriber': """ Load whisper model. Args: model (str): Whisper model. Available models include: - 'tiny.en' - 'tiny' - 'base.en' - 'base' - 'small.en' - 'small' - 'medium.en' - 'medium' - 'large-v1' - 'large-v2' - 'large-v3' - 'large' download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors. Returns: Transcriber: A Transcriber object initialized with the specified model. """ - if device is None: - device = "cuda" if cuda_is_available() else "cpu" + if not isinstance(device, str): device = str(device) + compute_type = kwargs.get('compute_type', 'float16') if device == 'cpu' and compute_type == 'float16': warnings.warn(f'Compute type {compute_type} not compatible with ' f'device {device}! Changing compute type to int8.') compute_type = 'int8' - _model = whisperx_load_model(model, download_root=download_root, - device=device, compute_type=compute_type) + _model = FasterWhisperModel(model, download_root=download_root, + device=device, compute_type=compute_type) return cls(_model, model_name=model) @@ -361,7 +361,7 @@ def _get_whisper_kwargs(**kwargs) -> dict: dict: Keyword arguments for whisper model. """ # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames - _possible_kwargs = signature(WhisperModel.transcribe).parameters.keys() + _possible_kwargs = signature(FasterWhisperModel.transcribe).parameters.keys() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} @@ -370,21 +370,51 @@ def _get_whisper_kwargs(**kwargs) -> dict: whisper_kwargs["task"] = task if (language := kwargs.get("language")): + language = FasterWhisperTranscriber.convert_to_language_code(language) whisper_kwargs["language"] = language return whisper_kwargs + @staticmethod + def convert_to_language_code(lang: str) -> str: + """ + Convert a language name or language code into a faster-whisper language code. 
+ + Args: + lang (str): Language name or language code. + + Returns: + str: The corresponding faster-whisper language code. + """ + + # If the input is already in FASTER_WHISPER_LANGUAGE_CODES, return it directly + if lang in FASTER_WHISPER_LANGUAGE_CODES: + return lang + + # Normalize the input to lowercase + lang = lang.lower() + + # Check if the language name is in the TO_LANGUAGE_CODE mapping + if lang in TO_LANGUAGE_CODE: + return TO_LANGUAGE_CODE[lang] + + # If the language is not recognized, raise a ValueError with the available options + available_codes = ', '.join(FASTER_WHISPER_LANGUAGE_CODES) + raise ValueError(f"Language '{lang}' is not a valid language code or name. " + f"Available language codes are: {available_codes}.") + def __repr__(self) -> str: - return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})" + return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})" + def load_transcriber(model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs - ) -> Union[WhisperTranscriber, WhisperXTranscriber]: + ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: """ Load whisper model. Args: model (str): Whisper model. Available models include: - 'tiny.en' - 'tiny' - 'base.en' - 'base' - 'small.en' - 'small' - 'medium.en' - 'medium' - 'large-v1' - 'large-v2' - 'large-v3' - 'large' whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. - device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. - in_memory (bool, optional): Whether to load model in memory. + device (Optional[Union[str, torch.device]], optional): + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. + in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors. Returns: - Union[WhisperTranscriber, WhisperXTranscriber]: + Union[WhisperTranscriber, FasterWhisperTranscriber]: One of the Whisper variants as Transcriber object initialized with the specified model. 
""" if whisper_type.lower() == 'whisper': _model = WhisperTranscriber.load_model( model, download_root, device, in_memory, *args, **kwargs) return _model - elif whisper_type.lower() == 'whisperx': - _model = WhisperXTranscriber.load_model( + elif whisper_type.lower() == 'faster-whisper': + _model = FasterWhisperTranscriber.load_model( model, download_root, device, *args, **kwargs) return _model else: raise ValueError(f'Model type not recognized, expected "whisper" ' - f'or "whisperx", got {whisper_type}.') + f'or "faster-whisper", got {whisper_type}.') diff --git a/test/audio_test_1.mp4 b/tests/audio_test_1.mp4 similarity index 100% rename from test/audio_test_1.mp4 rename to tests/audio_test_1.mp4 diff --git a/test/audio_test_2.mp4 b/tests/audio_test_2.mp4 similarity index 100% rename from test/audio_test_2.mp4 rename to tests/audio_test_2.mp4 diff --git a/test/test_audio.py b/tests/test_audio.py similarity index 100% rename from test/test_audio.py rename to tests/test_audio.py diff --git a/test/test_autotranscript.py b/tests/test_autotranscript.py similarity index 83% rename from test/test_autotranscript.py rename to tests/test_autotranscript.py index 78442b3..fbf18ab 100644 --- a/test/test_autotranscript.py +++ b/tests/test_autotranscript.py @@ -6,7 +6,7 @@ @pytest.fixture def create_scraibe_instance(): if "HF_TOKEN" in os.environ: - return Scraibe(use_auth_token=os.environ["HF_TOKEN"]) + return Scraibe(use_auth_token=os.environ["HF_TOKEN"], whisper_model="tiny") else: return Scraibe() @@ -19,19 +19,19 @@ def test_scraibe_init(create_scraibe_instance): def test_scraibe_autotranscribe(create_scraibe_instance): model = create_scraibe_instance - transcript = model.autotranscribe('test/audio_test_2.mp4') + transcript = model.autotranscribe('tests/audio_test_2.mp4') assert isinstance(transcript, Transcript) def test_scraibe_diarization(create_scraibe_instance): model = create_scraibe_instance - diarisation_result = model.diarization('test/audio_test_2.mp4') + diarisation_result = model.diarization('tests/audio_test_2.mp4') assert isinstance(diarisation_result, dict) def test_scraibe_transcribe(create_scraibe_instance): model = create_scraibe_instance - transcription_result = model.transcribe('test/audio_test_2.mp4') + transcription_result = model.transcribe('tests/audio_test_2.mp4') assert isinstance(transcription_result, str) diff --git a/test/test_diarisation.py b/tests/test_diarisation.py similarity index 100% rename from test/test_diarisation.py rename to tests/test_diarisation.py diff --git a/test/test_transcriber.py b/tests/test_transcriber.py similarity index 69% rename from test/test_transcriber.py rename to tests/test_transcriber.py index 31765f6..ba9d99a 100644 --- a/test/test_transcriber.py +++ b/tests/test_transcriber.py @@ -1,6 +1,6 @@ import pytest from scraibe import (Transcriber, WhisperTranscriber, - WhisperXTranscriber, load_transcriber) + FasterWhisperTranscriber, load_transcriber) import torch @@ -31,33 +31,33 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription): @pytest.fixture def whisper_instance(): - return load_transcriber('medium', whisper_type='whisper') + return load_transcriber('tiny', whisper_type='whisper') @pytest.fixture -def whisperx_instance(): - return load_transcriber('medium', whisper_type='whisperx') +def faster_whisper_instance(): + return load_transcriber('tiny', whisper_type='faster-whisper') def test_whisper_base_initialization(whisper_instance): assert isinstance(whisper_instance, Transcriber) -def 
test_whisperx_base_initialization(whisperx_instance): - assert isinstance(whisperx_instance, Transcriber) +def test_faster_whisper_base_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, Transcriber) def test_whisper_transcriber_initialization(whisper_instance): assert isinstance(whisper_instance, WhisperTranscriber) -def test_whisperx_transcriber_initialization(whisperx_instance): - assert isinstance(whisperx_instance, WhisperXTranscriber) +def test_faster_whisper_transcriber_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, FasterWhisperTranscriber) def test_wrong_transcriber_initialization(): with pytest.raises(ValueError): - load_transcriber('medium', whisper_type='wrong_whisper') + load_transcriber('tiny', whisper_type='wrong_whisper') def test_get_whisper_kwargs(): @@ -69,12 +69,12 @@ def test_get_whisper_kwargs(): def test_whisper_transcribe(whisper_instance): model = whisper_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) - transcript = model.transcribe('test/audio_test_2.mp4') + transcript = model.transcribe('tests/audio_test_2.mp4') assert isinstance(transcript, str) -def test_whisperx_transcribe(whisperx_instance): - model = whisperx_instance +def test_faster_whisper_transcribe(faster_whisper_instance): + model = faster_whisper_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) - transcript = model.transcribe('test/audio_test_2.mp4') + transcript = model.transcribe('tests/audio_test_2.mp4') assert isinstance(transcript, str)
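
Usage note for the device handling introduced above: SCRAIBE_TORCH_DEVICE is resolved once in scraibe/misc.py when the package is first imported, so the environment variable must be set beforehand. Below is a minimal sketch of the intended flow, assuming only the signatures shown in this diff; the "tiny" model size and the tests/audio_test_2.mp4 path are borrowed from the test changes and are purely illustrative.

    import os

    # SCRAIBE_TORCH_DEVICE is read at import time in scraibe/misc.py,
    # so override it before the first scraibe import to force CPU inference.
    os.environ["SCRAIBE_TORCH_DEVICE"] = "cpu"

    from scraibe.transcriber import load_transcriber

    # whisper_type="faster-whisper" selects FasterWhisperTranscriber; on CPU
    # the default compute type "float16" is downgraded to "int8" with a warning.
    transcriber = load_transcriber("tiny", whisper_type="faster-whisper")
    print(transcriber.transcribe("tests/audio_test_2.mp4"))

Passing a torch.device instead of a string also works: FasterWhisperTranscriber.load_model stringifies non-string devices before choosing the compute type.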
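
The language handling changed as well: _get_whisper_kwargs now routes any language keyword through FasterWhisperTranscriber.convert_to_language_code, which accepts either a faster-whisper language code or a full language name from whisper's TO_LANGUAGE_CODE mapping. A behavioural sketch of the method shown above ("de" and "German" are illustrative inputs):

    from scraibe.transcriber import FasterWhisperTranscriber

    # A known faster-whisper language code passes through unchanged.
    assert FasterWhisperTranscriber.convert_to_language_code("de") == "de"

    # A language name is lowercased and resolved via TO_LANGUAGE_CODE.
    assert FasterWhisperTranscriber.convert_to_language_code("German") == "de"

    # Anything else raises a ValueError listing the available codes.
    try:
        FasterWhisperTranscriber.convert_to_language_code("klingon")
    except ValueError as err:
        print(err)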
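
Finally, the new --num-speakers CLI flag is popped from the parsed arguments and forwarded as num_speakers through Scraibe.autotranscribe to the pyannote diarisation pipeline. An end-to-end sketch, assuming the constructor keywords used in the updated test fixture; the output file name is arbitrary:

    from scraibe import Scraibe

    # whisper_model follows the updated test fixture; whisper_type selects
    # the faster-whisper backend described in this diff.
    model = Scraibe(whisper_model="tiny", whisper_type="faster-whisper")

    # num_speakers is handed through to the diarisation pipeline.
    transcript = model.autotranscribe("tests/audio_test_2.mp4", num_speakers=2)
    transcript.save("audio_test_2.txt")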