Skip to content

Commit

Permalink
Fix the import of global variables, make them non-static
Browse files Browse the repository at this point in the history
  • Loading branch information
mahenning committed Nov 27, 2024
1 parent 2238150 commit 1e40e4c
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions scraibe/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from abc import abstractmethod
import warnings

from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE, SCRAIBE_NUM_THREADS
import scraibe.misc
whisper = TypeVar('whisper')


Expand Down Expand Up @@ -122,8 +122,8 @@ def save_transcript(transcript: str, save_path: str) -> None:
def load_model(cls,
model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> None:
Expand Down Expand Up @@ -204,8 +204,8 @@ def transcribe(self, audio: Union[str, Tensor, ndarray],
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> 'WhisperTranscriber':
Expand Down Expand Up @@ -303,8 +303,8 @@ def transcribe(self, audio: Union[str, Tensor, ndarray],
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE,
*args, **kwargs
) -> 'FasterWhisperModel':
"""
Expand Down Expand Up @@ -349,7 +349,7 @@ def load_model(cls,
compute_type = 'int8'
_model = FasterWhisperModel(model, download_root=download_root,
device=device, compute_type=compute_type,
cpu_threads=SCRAIBE_NUM_THREADS)
cpu_threads=scraibe.misc.SCRAIBE_NUM_THREADS)

return cls(_model, model_name=model)

Expand Down Expand Up @@ -411,8 +411,8 @@ def __repr__(self) -> str:

def load_transcriber(model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
download_root: str = scraibe.misc.WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = scraibe.misc.SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
Expand Down

0 comments on commit 1e40e4c

Please sign in to comment.