Skip to content

Commit

Permalink
Merge pull request #138 from JSchmie/fix_torch_threads
Browse files Browse the repository at this point in the history
Adding support for setting number of threads to faster-whisper
  • Loading branch information
mahenning authored Nov 25, 2024
2 parents 9528468 + de883bc commit d00ec2d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
7 changes: 3 additions & 4 deletions scraibe/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from torch.cuda import is_available
from torch import set_num_threads
from .autotranscript import Scraibe
from .misc import set_threads

def cli():
"""
Expand Down Expand Up @@ -55,7 +55,7 @@ def str2bool(string):
default="cuda" if is_available() else "cpu",
help="Device to use for PyTorch inference.")

parser.add_argument("--num-threads", type=int, default=0,
parser.add_argument("--num-threads", type=int, default=None,
help="Number of threads used by torch for CPU inference; '\
'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")

Expand Down Expand Up @@ -94,8 +94,7 @@ def str2bool(string):

task = arg_dict.pop("task")

if args.num_threads > 0:
set_num_threads(arg_dict.pop("num_threads"))
set_threads(arg_dict.pop("num_threads"))

class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
'whisper_type':arg_dict.pop("whisper_type"),
Expand Down
24 changes: 24 additions & 0 deletions scraibe/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from argparse import Action
from ast import literal_eval
from torch.cuda import is_available
from torch import get_num_threads, set_num_threads

CACHE_DIR = os.getenv(
"AUTOT_CACHE",
Expand All @@ -21,6 +22,8 @@

SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")

# Default number of CPU threads. Taken from the SCRAIBE_NUM_THREADS environment
# variable when it is set, otherwise torch's current thread count capped at 8.
# os.getenv returns a *string* whenever the variable is set, so the value is
# coerced to int here; a non-numeric value raises ValueError at import time
# instead of handing a str to consumers that use this as a thread count.
SCRAIBE_NUM_THREADS = int(os.getenv("SCRAIBE_NUM_THREADS", min(8, get_num_threads())))

def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file.
Expand Down Expand Up @@ -49,6 +52,27 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
yaml.dump(yml, stream)


def _validate_thread_count(value, option_name):
    """Return *value* unchanged if it is a positive int, else raise ValueError.

    Args:
        value: Candidate thread count.
        option_name: Name of the option shown in the type-error message
            (e.g. ``--num-threads`` for the CLI flag).

    Raises:
        ValueError: If *value* is not an int (bools are rejected too) or is
            smaller than 1.
    """
    # bool is a subclass of int, so it has to be excluded explicitly;
    # otherwise True/False would slip through as a thread count.
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError(f"Type of {option_name} must be int, but the type is {type(value)}")
    if value < 1:
        raise ValueError(f"Number of threads must be a positive integer, {value} was given")
    return value


def set_threads(parse_threads=None,
                yaml_threads=None):
    """Set the torch thread count and update ``SCRAIBE_NUM_THREADS``.

    The command-line value takes precedence: ``yaml_threads`` is only
    considered when ``parse_threads`` is None. When both are None nothing
    is changed.

    Args:
        parse_threads: Thread count from the ``--num-threads`` CLI option,
            or None if it was not given.
        yaml_threads: Thread count from a ``num_threads`` config entry,
            or None if it was not given.

    Raises:
        ValueError: If the selected value is not a positive int.

    NOTE(review): this rebinds the module-level ``SCRAIBE_NUM_THREADS``.
    Modules that did ``from .misc import SCRAIBE_NUM_THREADS`` keep the
    binding made at import time and will not see the new value; verify
    that consumers read it from this module at call time.
    """
    global SCRAIBE_NUM_THREADS
    if parse_threads is not None:
        threads = _validate_thread_count(parse_threads, "--num-threads")
    elif yaml_threads is not None:
        threads = _validate_thread_count(yaml_threads, "num_threads")
    else:
        # Neither source supplied a value: keep the current configuration.
        return
    set_num_threads(threads)
    SCRAIBE_NUM_THREADS = threads

class ParseKwargs(Action):
"""
Custom argparse action to parse keyword arguments.
Expand Down
5 changes: 3 additions & 2 deletions scraibe/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from abc import abstractmethod
import warnings

from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE, SCRAIBE_NUM_THREADS
whisper = TypeVar('whisper')


Expand Down Expand Up @@ -348,7 +348,8 @@ def load_model(cls,
f'device {device}! Changing compute type to int8.')
compute_type = 'int8'
_model = FasterWhisperModel(model, download_root=download_root,
device=device, compute_type=compute_type)
device=device, compute_type=compute_type,
cpu_threads=SCRAIBE_NUM_THREADS)

return cls(_model, model_name=model)

Expand Down

0 comments on commit d00ec2d

Please sign in to comment.