From 663675c7b2d8be21ec3f86d452f18dd8e4c40637 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Tue, 19 Nov 2024 17:19:02 +0100 Subject: [PATCH 1/2] Add support for setting the number of threads used by faster-whisper on CPU, read from the CLI, YAML, or an environment variable. --- scraibe/cli.py | 7 +++---- scraibe/misc.py | 25 +++++++++++++++++++++++++ scraibe/transcriber.py | 5 +++-- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/scraibe/cli.py b/scraibe/cli.py index df73d1b..e4eeaad 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -9,8 +9,8 @@ from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from torch.cuda import is_available -from torch import set_num_threads from .autotranscript import Scraibe +from .misc import set_threads def cli(): """ @@ -55,7 +55,7 @@ def str2bool(string): default="cuda" if is_available() else "cpu", help="Device to use for PyTorch inference.") - parser.add_argument("--num-threads", type=int, default=0, + parser.add_argument("--num-threads", type=int, default=None, help="Number of threads used by torch for CPU inference; '\ 'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") @@ -94,8 +94,7 @@ def str2bool(string): task = arg_dict.pop("task") - if args.num_threads > 0: - set_num_threads(arg_dict.pop("num_threads")) + set_threads(arg_dict.pop("num_threads")) class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"), 'whisper_type':arg_dict.pop("whisper_type"), diff --git a/scraibe/misc.py b/scraibe/misc.py index 4f5ab1a..c8e89a1 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -3,6 +3,7 @@ from argparse import Action from ast import literal_eval from torch.cuda import is_available +from torch import get_num_threads, set_num_threads CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -21,6 +22,8 @@ SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu") +SCRAIBE_NUM_THREADS = os.getenv("SCRAIBE_NUM_THREADS", min(8, get_num_threads())) + 
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. @@ -49,6 +52,28 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> yaml.dump(yml, stream) +def set_threads(parse_threads=None, + yaml_threads=None, + env_var_threads=None): + global SCRAIBE_NUM_THREADS + if parse_threads is not None: + if not isinstance(parse_threads, int): + # probably covered with int type of parser arg + raise ValueError(f"Type of --num-threads must be int, but the type is {type(parse_threads)}") + elif parse_threads < 1: + raise ValueError(f"Number of threads must be a positive integer, {parse_threads} was given") + else: + set_num_threads(parse_threads) + SCRAIBE_NUM_THREADS = parse_threads + elif yaml_threads is not None: + if not isinstance(yaml_threads, int): + raise ValueError(f"Type of num_threads must be int, but the type is {type(yaml_threads)}") + elif yaml_threads < 1: + raise ValueError(f"Number of threads must be a positive integer, {yaml_threads} was given") + else: + set_num_threads(yaml_threads) + SCRAIBE_NUM_THREADS = yaml_threads + class ParseKwargs(Action): """ Custom argparse action to parse keyword arguments. diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 040b79d..bc341dc 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -36,7 +36,7 @@ from abc import abstractmethod import warnings -from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE +from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE, SCRAIBE_NUM_THREADS whisper = TypeVar('whisper') @@ -348,7 +348,8 @@ def load_model(cls, f'device {device}! 
Changing compute type to int8.') compute_type = 'int8' _model = FasterWhisperModel(model, download_root=download_root, - device=device, compute_type=compute_type) + device=device, compute_type=compute_type, + cpu_threads=SCRAIBE_NUM_THREADS) return cls(_model, model_name=model) From de883bc06246d14abc2a706e4acfdd63aaf024ad Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Mon, 25 Nov 2024 13:39:18 +0100 Subject: [PATCH 2/2] Remove unused env_var_threads parameter from set_threads --- scraibe/misc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scraibe/misc.py b/scraibe/misc.py index c8e89a1..f5d2bfe 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -53,8 +53,7 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> def set_threads(parse_threads=None, - yaml_threads=None, - env_var_threads=None): + yaml_threads=None): global SCRAIBE_NUM_THREADS if parse_threads is not None: if not isinstance(parse_threads, int):