From 40eb0e52d527e7edaaabf15e8e03741665ef1fca Mon Sep 17 00:00:00 2001 From: John Finberg Date: Mon, 11 Nov 2024 17:32:56 -0500 Subject: [PATCH] fix: correcting audio resampling in realtime streamer --- .../resource/passive_sound_localization | 0 .../localization.py | 109 +++++++++++++----- .../passive_sound_localization/main.py | 38 +++--- .../realtime_audio_streamer.py | 36 +++--- .../realtime_openai_websocket.py | 97 ++++++++++------ 5 files changed, 184 insertions(+), 96 deletions(-) delete mode 100644 packages/movement_library/resource/passive_sound_localization diff --git a/packages/movement_library/resource/passive_sound_localization b/packages/movement_library/resource/passive_sound_localization deleted file mode 100644 index e69de29..0000000 diff --git a/packages/passive_sound_localization/passive_sound_localization/localization.py b/packages/passive_sound_localization/passive_sound_localization/localization.py index cd65a07..22bbda9 100644 --- a/packages/passive_sound_localization/passive_sound_localization/localization.py +++ b/packages/passive_sound_localization/passive_sound_localization/localization.py @@ -1,6 +1,7 @@ from typing import List, Iterator, Optional, Tuple -# from passive_sound_localization.models.configs.localization import LocalizationConfig -from models.configs.localization import LocalizationConfig # Only needed to run with `realtime_audio.py` +from passive_sound_localization.models.configs.localization import LocalizationConfig + +# from models.configs.localization import LocalizationConfig # Only needed to run with `realtime_audio.py` from dataclasses import dataclass import numpy as np import logging @@ -10,34 +11,54 @@ class NoMicrophonePositionsError(Exception): """Exception raised when there are no microphone positions""" + def __init__(self) -> None: super().__init__("No microphone positions were configured") + class TooFewMicrophonePositionsError(Exception): """Exception raised when there are less than 2 microphone positions""" + def __init__(self, num_mic_positions: int) -> None: - super().__init__(f"There should be at least 2 microphone positions. Currently only {num_mic_positions} microphone position(s) were configured") + super().__init__( + f"There should be at least 2 microphone positions. Currently only {num_mic_positions} microphone position(s) were configured" + ) + class MicrophonePositionShapeError(Exception): """Exception raised when microphone positions don't match (x,y) pairs""" + def __init__(self, mic_position_shape: int) -> None: - super().__init__(f"The microphone positions should be in (x,y) pairs. Currently the microphone positions come in pairs of shape {mic_position_shape}") + super().__init__( + f"The microphone positions should be in (x,y) pairs. Currently the microphone positions come in pairs of shape {mic_position_shape}" + ) + class NoMicrophoneStreamsError(Exception): """Exception raised when there are no microphone streams""" + def __init__(self) -> None: super().__init__("No microphone streams were passed for localization") + pass + class TooFewMicrophoneStreamsError(Exception): """Exception raised when there are less than 2 microphone streams""" + def __init__(self, num_mic_streams: int) -> None: - super().__init__(f"There should be at least 2 microphone streams. Currently there are only {num_mic_streams} microphone streams") + super().__init__( + f"There should be at least 2 microphone streams. Currently there are only {num_mic_streams} microphone streams" + ) + class MicrophoneStreamSizeMismatchError(Exception): """Exception raised when the number of microphone streams doesn't match the number of microphone positions""" - def __init__(self, num_mics:int, num_mic_streams) -> None: - super().__init__(f"The number of microphone streams should match the number of microphone positions. Currently there are {num_mic_streams} microphone streams and {num_mics} microphone positions") + + def __init__(self, num_mics: int, num_mic_streams) -> None: + super().__init__( + f"The number of microphone streams should match the number of microphone positions. Currently there are {num_mic_streams} microphone streams and {num_mics} microphone positions" + ) @dataclass(frozen=True) @@ -45,6 +66,7 @@ class LocalizationResult: distance: float # Estimated distance to the sound source in meters angle: float # Estimated angle to the sound source in degrees + class SoundLocalizer: def __init__(self, config: LocalizationConfig): # TODO: Make sure that mic position ordering matches the ordering of the microphone streams/device indices @@ -53,18 +75,23 @@ def __init__(self, config: LocalizationConfig): ) # Get mic positions from config if self.mic_positions.shape[0] == 0: raise NoMicrophonePositionsError() - + if self.mic_positions.shape[0] < 2: - raise TooFewMicrophonePositionsError(num_mic_positions=self.mic_positions.size) - + raise TooFewMicrophonePositionsError( + num_mic_positions=self.mic_positions.size + ) + if self.mic_positions.shape[1] < 2: - raise MicrophonePositionShapeError(mic_position_shape=self.mic_positions.shape[1]) + raise MicrophonePositionShapeError( + mic_position_shape=self.mic_positions.shape[1] + ) - - self.speed_of_sound:float = config.speed_of_sound - self.sample_rate:int = config.sample_rate - self.fft_size:int = config.fft_size - self.num_mics:int = self.mic_positions.shape[0] # To be set when data is received + self.speed_of_sound: float = config.speed_of_sound + self.sample_rate: int = config.sample_rate + self.fft_size: int = config.fft_size + self.num_mics: int = self.mic_positions.shape[ + 0 + ] # To be set when data is received # Generate circular plane of grid points for direction searching self.grid_points = self._generate_circular_grid() @@ -75,7 +102,7 @@ def __init__(self, config: LocalizationConfig): self.phase_shifts = self._compute_all_phase_shifts(self.freqs) # Initialize buffer for streaming - self.buffer:Optional[np.ndarray[np.float32]] = None + self.buffer: Optional[np.ndarray[np.float32]] = None def localize_stream( self, multi_channel_stream: List[bytes], num_sources: int = 1 @@ -99,10 +126,12 @@ def localize_stream( if num_mic_streams < 2: logger.error("At least two microphones are required for localization.") raise TooFewMicrophoneStreamsError() - + if self.num_mics != num_mic_streams: - raise MicrophoneStreamSizeMismatchError(num_mics=self.num_mics, num_mic_streams=num_mic_streams) - + raise MicrophoneStreamSizeMismatchError( + num_mics=self.num_mics, num_mic_streams=num_mic_streams + ) + # Convert buffers into numpy arrays multi_channel_data = [ np.frombuffer(data, dtype=np.float32) for data in multi_channel_stream @@ -157,7 +186,9 @@ def localize_stream( logger.info("Real-time sound source localization completed.") - def _compute_cross_spectrum(self, mic_signals:np.ndarray[np.float32], fft_size:int=1024) -> np.ndarray[np.complex128]: + def _compute_cross_spectrum( + self, mic_signals: np.ndarray[np.float32], fft_size: int = 1024 + ) -> np.ndarray[np.complex128]: """Compute the cross-power spectrum between microphone pairs.""" # Correct shape: (num_mics, num_mics, fft_size // 2 + 1) for the rfft result @@ -170,10 +201,14 @@ def _compute_cross_spectrum(self, mic_signals:np.ndarray[np.float32], fft_size:i cross_spectrum = mic_fft[:, np.newaxis, :] * np.conj(mic_fft[np.newaxis, :, :]) return cross_spectrum - + def _generate_circular_grid( - self, offset:float=0.45, radius:float=1.0, num_points_radial:int=50, num_points_angular:int=360 - )-> np.ndarray[np.float32]: + self, + offset: float = 0.45, + radius: float = 1.0, + num_points_radial: int = 50, + num_points_angular: int = 360, + ) -> np.ndarray[np.float32]: """Generate a grid of points on a circular plane, optimized for speed.""" # Create radial distances from 0 to the specified radius r = np.linspace(offset, radius + offset, num_points_radial, dtype=np.float32) @@ -189,16 +224,22 @@ def _generate_circular_grid( # Return the points stacked as (x, y) pairs return np.column_stack((x.ravel(), y.ravel())) - def _search_best_direction(self, cross_spectrum: np.ndarray[np.complex128]) -> Tuple[np.ndarray, np.float32, int]: + def _search_best_direction( + self, cross_spectrum: np.ndarray[np.complex128] + ) -> Tuple[np.ndarray, np.float32, int]: """Search the circular grid for the direction with maximum beamformer output.""" energies = self._compute_beamformer_energies(cross_spectrum) best_direction_idx = np.argmax(energies) best_direction = self.grid_points[best_direction_idx] estimated_distance = np.min(self.distances_to_mics[best_direction_idx]) - print(f"Position of closest mic: {self.mic_positions[np.argmin(self.distances_to_mics[best_direction_idx])]}") + print( + f"Position of closest mic: {self.mic_positions[np.argmin(self.distances_to_mics[best_direction_idx])]}" + ) return best_direction, estimated_distance, best_direction_idx - def _compute_all_phase_shifts(self, freqs: np.ndarray[np.float32]) -> np.ndarray[np.complex128]: + def _compute_all_phase_shifts( + self, freqs: np.ndarray[np.float32] + ) -> np.ndarray[np.complex128]: """ Precompute phase shifts for all grid points, microphone pairs, and frequency bins. """ @@ -231,11 +272,15 @@ def _compute_all_delays(self) -> Tuple[np.ndarray, np.ndarray]: ) # Shape: (num_grid_points, num_mics) # Compute delays: distances divided by speed of sound - delays = distances_to_mics / self.speed_of_sound # Shape: (num_grid_points, num_mics) + delays = ( + distances_to_mics / self.speed_of_sound + ) # Shape: (num_grid_points, num_mics) return distances_to_mics, delays - def _compute_beamformer_energies(self, cross_spectrum: np.ndarray[np.complex128]) -> np.ndarray: + def _compute_beamformer_energies( + self, cross_spectrum: np.ndarray[np.complex128] + ) -> np.ndarray: """Compute the beamformer energy given the cross-spectrum and delays.""" cross_spectrum_expanded = cross_spectrum[np.newaxis, :, :, :] # Multiply and sum over mics and frequency bins @@ -245,7 +290,9 @@ def _compute_beamformer_energies(self, cross_spectrum: np.ndarray[np.complex128] energies = np.abs(np.sum(product, axis=(1, 2, 3))) # Shape: (num_grid_points,) return energies - def _remove_source_contribution(self, cross_spectrum: np.ndarray[np.complex128], source_idx: int) -> np.ndarray[np.complex128]: + def _remove_source_contribution( + self, cross_spectrum: np.ndarray[np.complex128], source_idx: int + ) -> np.ndarray[np.complex128]: """ Remove the contribution of a localized source using vectorized operations. """ @@ -264,4 +311,4 @@ def compute_cartesian_coordinates(self, distance, angle): """ x = distance * np.cos(np.radians(angle)) y = distance * np.sin(np.radians(angle)) - return x, y \ No newline at end of file + return x, y diff --git a/packages/passive_sound_localization/passive_sound_localization/main.py b/packages/passive_sound_localization/passive_sound_localization/main.py index f0d267f..8e12b85 100644 --- a/packages/passive_sound_localization/passive_sound_localization/main.py +++ b/packages/passive_sound_localization/passive_sound_localization/main.py @@ -17,8 +17,9 @@ commands = Queue(maxsize=10) locations = Queue(maxsize=10) -def send_audio_continuously(client, single_channel_generator): - print("Threading...") + +def send_audio_continuously(client, single_channel_generator, logger): + logger.info("Sending audio to OpenAI") for single_channel_audio in single_channel_generator: client.send_audio(single_channel_audio) @@ -29,16 +30,15 @@ def receive_text_messages(client, logger): try: command = client.receive_text_response() if command and command.strip() == "MOVE_TO": - print(command) logger.info(f"Received command: {command}") - + if commands.full(): commands.get() commands.task_done() commands.put(command) except Exception as e: - print(f"Error receiving response: {e}") - break # Exit loop if server disconnects + logger.error(f"Error receiving response: {e}") + def realtime_localization(multi_channel_stream, localizer, logger): logger.info("Localization: Listening to audio stream") @@ -53,16 +53,17 @@ def realtime_localization(multi_channel_stream, localizer, logger): for localization_results in localization_stream: if locations.full(): locations.get() - + locations.put(localization_results) - + if did_get: locations.task_done() did_get = False - + except Exception as e: print(f"Realtime Localization error: {e}") + def command_executor(publisher, logger): logger.info("Executor: listening for command") while True: @@ -73,7 +74,7 @@ def command_executor(publisher, logger): commands.task_done() if locations.qsize() > 0: publisher(locations.get()) - + locations.task_done() except Exception as e: print(f"Command executor error: {e}") @@ -100,7 +101,7 @@ def __init__(self): ("realtime_streamer.sample_rate", 24000), ("realtime_streamer.channels", 1), ("realtime_streamer.chunk", 1024), - ("realtime_streamer.device_indices", [2, 3, 4, 5]) + ("realtime_streamer.device_indices", [2, 3, 4, 5]), ], ) @@ -133,10 +134,19 @@ def process_audio(self): with ThreadPoolExecutor(max_workers=4) as executor: self.logger.info("Threading log") - executor.submit(send_audio_continuously, client, single_channel_stream) + executor.submit( + send_audio_continuously, client, single_channel_stream, self.logger + ) executor.submit(receive_text_messages, client, self.logger) - executor.submit(realtime_localization, multi_channel_stream, self.localizer, self.logger) - executor.submit(command_executor, lambda x: self.publish_results(x), self.logger) + executor.submit( + realtime_localization, + multi_channel_stream, + self.localizer, + self.logger, + ) + executor.submit( + command_executor, lambda x: self.publish_results(x), self.logger + ) def publish_results(self, localization_results): # Publish results to ROS topic diff --git a/packages/passive_sound_localization/passive_sound_localization/realtime_audio_streamer.py b/packages/passive_sound_localization/passive_sound_localization/realtime_audio_streamer.py index 33d973e..d81fef1 100644 --- a/packages/passive_sound_localization/passive_sound_localization/realtime_audio_streamer.py +++ b/packages/passive_sound_localization/passive_sound_localization/realtime_audio_streamer.py @@ -7,13 +7,13 @@ from io import BytesIO -# from passive_sound_localization.models.configs.realtime_streamer import ( -# RealtimeAudioStreamerConfig, -# ) - -from models.configs.realtime_streamer import ( +from passive_sound_localization.models.configs.realtime_streamer import ( RealtimeAudioStreamerConfig, -) # Only needed to run with `realtime_audio.py` +) + +# from models.configs.realtime_streamer import ( +# RealtimeAudioStreamerConfig, +# ) # Only needed to run with `realtime_audio.py` logger = logging.getLogger(__name__) @@ -87,19 +87,27 @@ def multi_channel_gen(self) -> Generator[Optional[Dict[int, bytes]], None, None] def merge_streams(self, streams: List[np.ndarray]) -> np.ndarray: return np.sum(streams, axis=0) / len(streams) - - def resample_stream(self, stream: bytes, target_sample_rate: int = 24000, sample_width: int=2) -> bytes: - try: - audio = AudioSegment.from_file(BytesIO(stream)) - # Resample to 24kHz mono pcm16 - return audio.set_frame_rate(target_sample_rate).set_channels(self.channels).set_sample_width(sample_width).raw_data + def resample_stream( + self, stream: bytes, target_sample_rate: int = 24000, sample_width: int = 2 + ) -> bytes: + try: + audio_data_int16 = (stream * 32767).astype(np.int16) + audio_segment = AudioSegment( + audio_data_int16.tobytes(), + frame_rate=self.sample_rate, + sample_width=audio_data_int16.dtype.itemsize, + channels=self.channels, + ) + + # Resample the audio to 24000 Hz + audio_segment_resampled = audio_segment.set_frame_rate(target_sample_rate) + return audio_segment_resampled.get_array_of_samples().tobytes() except Exception as e: print(f"Error in resample_stream: {e}") return b"" - def single_channel_gen(self) -> Generator[Optional[bytes], None, None]: try: while self.streaming: @@ -114,4 +122,4 @@ def single_channel_gen(self) -> Generator[Optional[bytes], None, None]: else: yield None except Exception as e: - print(f"Error in single_channel_gen: {e}") \ No newline at end of file + print(f"Error in single_channel_gen: {e}") diff --git a/packages/passive_sound_localization/passive_sound_localization/realtime_openai_websocket.py b/packages/passive_sound_localization/passive_sound_localization/realtime_openai_websocket.py index aba6d21..13713ae 100644 --- a/packages/passive_sound_localization/passive_sound_localization/realtime_openai_websocket.py +++ b/packages/passive_sound_localization/passive_sound_localization/realtime_openai_websocket.py @@ -5,34 +5,49 @@ import logging from typing import Optional -# from passive_sound_localization.models.configs.openai_websocket import OpenAIWebsocketConfig -from models.configs.openai_websocket import OpenAIWebsocketConfig # Only needed to run with `realtime_audio.py` +from passive_sound_localization.models.configs.openai_websocket import ( + OpenAIWebsocketConfig, +) + +# from models.configs.openai_websocket import ( +# OpenAIWebsocketConfig, +# ) # Only needed to run with `realtime_audio.py` logger = logging.getLogger(__name__) + class InvalidWebsocketURIError(Exception): def __init__(self, websocket_url: str) -> None: super().__init__(f"Invalid Websocker URI was passed: {websocket_url}") + class SessionNotCreatedError(Exception): def __init__(self) -> None: super().__init__("Session was not created") + class SessionNotUpdatedError(Exception): def __init__(self) -> None: super().__init__("Session was not updated") + class OpenAIWebsocketError(Exception): def __init__(self, error_code: str, error_message: str) -> None: - super().__init__(f"OpenAI websocket erred with error type `{error_code}`: {error_message}") + super().__init__( + f"OpenAI websocket erred with error type `{error_code}`: {error_message}" + ) + class OpenAIRateLimitError(Exception): def __init__(self) -> None: super().__init__("Hit OpenAI Realtime API rate limit") + class OpenAITimeoutError(Exception): def __init__(self, timeout: float) -> None: - super().__init__(f"OpenAI websocket timed out because it did not receive a message in {timeout} seconds") + super().__init__( + f"OpenAI websocket timed out because it did not receive a message in {timeout} seconds" + ) INSTRUCTIONS = """ @@ -48,7 +63,7 @@ def __init__(self, timeout: float) -> None: # TODO: Make it take in Hydra config class OpenAIWebsocketClient: - def __init__(self, config: OpenAIWebsocketConfig): + def __init__(self, config: OpenAIWebsocketConfig): self.api_key: str = config.api_key self.websocket_url: str = config.websocket_url self.session_id: Optional[str] = None @@ -60,7 +75,7 @@ def __enter__(self): self._configure_session() print("Connected websocket...") return self - + def __exit__(self): self._close() @@ -69,8 +84,8 @@ def _connect(self) -> None: uri=self.websocket_url, additional_headers={ "Authorization": f"Bearer {self.api_key}", - "OpenAI-Beta": "realtime=v1" - } + "OpenAI-Beta": "realtime=v1", + }, ) message = json.loads(self.ws.recv()) @@ -79,57 +94,65 @@ def _connect(self) -> None: raise SessionNotCreatedError() def _configure_session(self) -> None: - self.ws.send(json.dumps({ - "type": "session.update", - "session": { - "modalities": ["text"], - "instructions": self.instructions, - "input_audio_format": "pcm16", - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 300, - "silence_duration_ms": 500 - }, - "temperature": 0.8, - "max_response_output_tokens": 4096 - } - })) + self.ws.send( + json.dumps( + { + "type": "session.update", + "session": { + "modalities": ["text"], + "instructions": self.instructions, + "input_audio_format": "pcm16", + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + }, + "temperature": 0.8, + "max_response_output_tokens": 4096, + }, + } + ) + ) message = json.loads(self.ws.recv()) if message["type"] != "session.updated": raise SessionNotUpdatedError() - - def send_audio(self, audio_chunk: bytes) -> None: # Audio needs to be encoded in Base64 before being sent to the OpenAI Realtime API audio_b64 = base64.b64encode(audio_chunk).decode() - self.ws.send(json.dumps({ - "type": "input_audio_buffer.append", - "audio": audio_b64 - })) + self.ws.send( + json.dumps({"type": "input_audio_buffer.append", "audio": audio_b64}) + ) - def receive_text_response(self, timeout:float=5.0) -> str: + def receive_text_response(self, timeout: Optional[float] = None) -> str: try: # Tries to receive the next message (in a blocking manner) from the OpenAI websocket # If the message doesn't arrive in 300ms, then it raises a TimeoutError message = json.loads(self.ws.recv(timeout=timeout)) except TimeoutError: raise OpenAITimeoutError(timeout=timeout) - + # Print message just to see what we're receiving # print(message) # Checks to see any general errors if message["type"] == "error": - raise OpenAIWebsocketError(error_code=message["error"]["code"], error_message=["error"]["message"]) - + raise OpenAIWebsocketError( + error_code=message["error"]["code"], + error_message=message["error"]["message"], + ) + # Checks to see whether OpenAI is specifically rate limiting our responses - if message["type"] == "response.done" and message["response"]["status_details"]["error"]["code"] == "rate_limit_exceeded": + if ( + message["type"] == "response.done" + and message["response"]["status_details"]["error"]["code"] + == "rate_limit_exceeded" + ): raise OpenAIRateLimitError() - + # Checks to see if an actual text response was sent, and returns the text if message["type"] == "response.text.done": return message["text"] @@ -137,4 +160,4 @@ def receive_text_response(self, timeout:float=5.0) -> str: def _close(self) -> None: if self.ws: self.ws.close() - self.ws = None \ No newline at end of file + self.ws = None