Skip to content

Commit

Permalink
added update functions for transcriber and diariser + adding some typ…
Browse files Browse the repository at this point in the history
…e hints
  • Loading branch information
JSchmie committed Apr 24, 2024
1 parent 8f58aeb commit 0666548
Showing 1 changed file with 52 additions and 5 deletions.
57 changes: 52 additions & 5 deletions scraibe/autotranscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self,
elif isinstance(dia_model, str):
self.diariser = Diariser.load_model(dia_model, **kwargs)
else:
self.diariser = dia_model
self.diariser : Diariser = dia_model

if kwargs.get("verbose"):
print("Scraibe initialized all models successfully loaded.")
Expand Down Expand Up @@ -133,7 +133,7 @@ def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
if kwargs.get("verbose"):
self.verbose = kwargs.get("verbose")
# Get audio file as an AudioProcessor object
audio_file = self.get_audio_file(audio_file)
audio_file : AudioProcessor = self.get_audio_file(audio_file)

# Prepare waveform and sample rate for diarization
dia_audio = {
Expand Down Expand Up @@ -203,7 +203,7 @@ def diarization(self, audio_file : Union[str, torch.Tensor, ndarray],
"""

# Get audio file as an AudioProcessor object
audio_file = self.get_audio_file(audio_file)
audio_file : AudioProcessor = self.get_audio_file(audio_file)

# Prepare waveform and sample rate for diarization
dia_audio = {
Expand Down Expand Up @@ -232,9 +232,56 @@ def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
str:
The transcribed text from the audio source.
"""
audio_file = self.get_audio_file(audio_file)
audio_file : AudioProcessor = self.get_audio_file(audio_file)

return self.transcriber.transcribe(audio_file.waveform, **kwargs)

def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None:
"""
Update the transcriber model.
Args:
whisper_model (Union[str, whisper]):
The new whisper model to use for transcription.
**kwargs:
Additional keyword arguments for the transcriber model.
Returns:
None
"""
_old_model = self.transcriber.model_name

if isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
elif isinstance(whisper_model, Transcriber):
self.transcriber = whisper_model
else:
warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)

return None

def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None:
"""
Update the diariser model.
Args:
dia_model (Union[str, DiarisationType]):
The new diariser model to use for diarization.
**kwargs:
Additional keyword arguments for the diariser model.
Returns:
None
"""
if isinstance(dia_model, str):
self.diariser = Diariser.load_model(dia_model, **kwargs)
elif isinstance(dia_model, Diariser):
self.diariser = dia_model
else:
warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)

return None

@staticmethod
def remove_audio_file(audio_file : str,
shred : bool = False) -> None:
Expand Down Expand Up @@ -269,7 +316,6 @@ def remove_audio_file(audio_file : str,
print(f"Audiofile {audio_file} removed.")



@staticmethod
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor:
Expand Down Expand Up @@ -298,6 +344,7 @@ def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
if not isinstance(audio_file, AudioProcessor):
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
f'not {type(audio_file)}')

return audio_file

def __repr__(self):
Expand Down

0 comments on commit 0666548

Please sign in to comment.