fix: correcting audio resampling in realtime streamer
John2360 committed Nov 11, 2024
1 parent fc51c1d commit 40eb0e5
Showing 5 changed files with 184 additions and 96 deletions.
Empty file.
@@ -1,6 +1,7 @@
from typing import List, Iterator, Optional, Tuple
# from passive_sound_localization.models.configs.localization import LocalizationConfig
from models.configs.localization import LocalizationConfig # Only needed to run with `realtime_audio.py`
from passive_sound_localization.models.configs.localization import LocalizationConfig

# from models.configs.localization import LocalizationConfig # Only needed to run with `realtime_audio.py`
from dataclasses import dataclass
import numpy as np
import logging
@@ -10,41 +11,62 @@

class NoMicrophonePositionsError(Exception):
"""Exception raised when there are no microphone positions"""

def __init__(self) -> None:
super().__init__("No microphone positions were configured")


class TooFewMicrophonePositionsError(Exception):
"""Exception raised when there are less than 2 microphone positions"""

def __init__(self, num_mic_positions: int) -> None:
super().__init__(f"There should be at least 2 microphone positions. Currently only {num_mic_positions} microphone position(s) were configured")
super().__init__(
f"There should be at least 2 microphone positions. Currently only {num_mic_positions} microphone position(s) were configured"
)


class MicrophonePositionShapeError(Exception):
"""Exception raised when microphone positions don't match (x,y) pairs"""

def __init__(self, mic_position_shape: int) -> None:
super().__init__(f"The microphone positions should be in (x,y) pairs. Currently the microphone positions come in pairs of shape {mic_position_shape}")
super().__init__(
f"The microphone positions should be in (x,y) pairs. Currently the microphone positions come in pairs of shape {mic_position_shape}"
)


class NoMicrophoneStreamsError(Exception):
"""Exception raised when there are no microphone streams"""

def __init__(self) -> None:
super().__init__("No microphone streams were passed for localization")

pass


class TooFewMicrophoneStreamsError(Exception):
"""Exception raised when there are less than 2 microphone streams"""

def __init__(self, num_mic_streams: int) -> None:
super().__init__(f"There should be at least 2 microphone streams. Currently there are only {num_mic_streams} microphone streams")
super().__init__(
f"There should be at least 2 microphone streams. Currently there are only {num_mic_streams} microphone streams"
)


class MicrophoneStreamSizeMismatchError(Exception):
"""Exception raised when the number of microphone streams doesn't match the number of microphone positions"""
def __init__(self, num_mics:int, num_mic_streams) -> None:
super().__init__(f"The number of microphone streams should match the number of microphone positions. Currently there are {num_mic_streams} microphone streams and {num_mics} microphone positions")

def __init__(self, num_mics: int, num_mic_streams) -> None:
super().__init__(
f"The number of microphone streams should match the number of microphone positions. Currently there are {num_mic_streams} microphone streams and {num_mics} microphone positions"
)


@dataclass(frozen=True)
class LocalizationResult:
distance: float # Estimated distance to the sound source in meters
angle: float # Estimated angle to the sound source in degrees


class SoundLocalizer:
def __init__(self, config: LocalizationConfig):
# TODO: Make sure that mic position ordering matches the ordering of the microphone streams/device indices
@@ -53,18 +75,23 @@ def __init__(self, config: LocalizationConfig):
) # Get mic positions from config
if self.mic_positions.shape[0] == 0:
raise NoMicrophonePositionsError()

if self.mic_positions.shape[0] < 2:
raise TooFewMicrophonePositionsError(num_mic_positions=self.mic_positions.size)

raise TooFewMicrophonePositionsError(
num_mic_positions=self.mic_positions.size
)

if self.mic_positions.shape[1] < 2:
raise MicrophonePositionShapeError(mic_position_shape=self.mic_positions.shape[1])
raise MicrophonePositionShapeError(
mic_position_shape=self.mic_positions.shape[1]
)


self.speed_of_sound:float = config.speed_of_sound
self.sample_rate:int = config.sample_rate
self.fft_size:int = config.fft_size
self.num_mics:int = self.mic_positions.shape[0] # To be set when data is received
self.speed_of_sound: float = config.speed_of_sound
self.sample_rate: int = config.sample_rate
self.fft_size: int = config.fft_size
self.num_mics: int = self.mic_positions.shape[
0
] # To be set when data is received

# Generate circular plane of grid points for direction searching
self.grid_points = self._generate_circular_grid()
@@ -75,7 +102,7 @@ def __init__(self, config: LocalizationConfig):
self.phase_shifts = self._compute_all_phase_shifts(self.freqs)

# Initialize buffer for streaming
self.buffer:Optional[np.ndarray[np.float32]] = None
self.buffer: Optional[np.ndarray[np.float32]] = None

def localize_stream(
self, multi_channel_stream: List[bytes], num_sources: int = 1
@@ -99,10 +126,12 @@ def localize_stream(
if num_mic_streams < 2:
logger.error("At least two microphones are required for localization.")
raise TooFewMicrophoneStreamsError(num_mic_streams=num_mic_streams)

if self.num_mics != num_mic_streams:
raise MicrophoneStreamSizeMismatchError(num_mics=self.num_mics, num_mic_streams=num_mic_streams)

raise MicrophoneStreamSizeMismatchError(
num_mics=self.num_mics, num_mic_streams=num_mic_streams
)

# Convert buffers into numpy arrays
multi_channel_data = [
np.frombuffer(data, dtype=np.float32) for data in multi_channel_stream
@@ -157,7 +186,9 @@ def localize_stream(

logger.info("Real-time sound source localization completed.")

def _compute_cross_spectrum(self, mic_signals:np.ndarray[np.float32], fft_size:int=1024) -> np.ndarray[np.complex128]:
def _compute_cross_spectrum(
self, mic_signals: np.ndarray[np.float32], fft_size: int = 1024
) -> np.ndarray[np.complex128]:
"""Compute the cross-power spectrum between microphone pairs."""
# Correct shape: (num_mics, num_mics, fft_size // 2 + 1) for the rfft result

@@ -170,10 +201,14 @@ def _compute_cross_spectrum(self, mic_signals:np.ndarray[np.float32], fft_size:int=1024) -> np.ndarray[np.complex128]:
cross_spectrum = mic_fft[:, np.newaxis, :] * np.conj(mic_fft[np.newaxis, :, :])

return cross_spectrum

def _generate_circular_grid(
self, offset:float=0.45, radius:float=1.0, num_points_radial:int=50, num_points_angular:int=360
)-> np.ndarray[np.float32]:
self,
offset: float = 0.45,
radius: float = 1.0,
num_points_radial: int = 50,
num_points_angular: int = 360,
) -> np.ndarray[np.float32]:
"""Generate a grid of points on a circular plane, optimized for speed."""
# Create radial distances from 0 to the specified radius
r = np.linspace(offset, radius + offset, num_points_radial, dtype=np.float32)
@@ -189,16 +224,22 @@ def _generate_circular_grid(
# Return the points stacked as (x, y) pairs
return np.column_stack((x.ravel(), y.ravel()))

def _search_best_direction(self, cross_spectrum: np.ndarray[np.complex128]) -> Tuple[np.ndarray, np.float32, int]:
def _search_best_direction(
self, cross_spectrum: np.ndarray[np.complex128]
) -> Tuple[np.ndarray, np.float32, int]:
"""Search the circular grid for the direction with maximum beamformer output."""
energies = self._compute_beamformer_energies(cross_spectrum)
best_direction_idx = np.argmax(energies)
best_direction = self.grid_points[best_direction_idx]
estimated_distance = np.min(self.distances_to_mics[best_direction_idx])
print(f"Position of closest mic: {self.mic_positions[np.argmin(self.distances_to_mics[best_direction_idx])]}")
print(
f"Position of closest mic: {self.mic_positions[np.argmin(self.distances_to_mics[best_direction_idx])]}"
)
return best_direction, estimated_distance, best_direction_idx

def _compute_all_phase_shifts(self, freqs: np.ndarray[np.float32]) -> np.ndarray[np.complex128]:
def _compute_all_phase_shifts(
self, freqs: np.ndarray[np.float32]
) -> np.ndarray[np.complex128]:
"""
Precompute phase shifts for all grid points, microphone pairs, and frequency bins.
"""
@@ -231,11 +272,15 @@ def _compute_all_delays(self) -> Tuple[np.ndarray, np.ndarray]:
) # Shape: (num_grid_points, num_mics)

# Compute delays: distances divided by speed of sound
delays = distances_to_mics / self.speed_of_sound # Shape: (num_grid_points, num_mics)
delays = (
distances_to_mics / self.speed_of_sound
) # Shape: (num_grid_points, num_mics)

return distances_to_mics, delays

def _compute_beamformer_energies(self, cross_spectrum: np.ndarray[np.complex128]) -> np.ndarray:
def _compute_beamformer_energies(
self, cross_spectrum: np.ndarray[np.complex128]
) -> np.ndarray:
"""Compute the beamformer energy given the cross-spectrum and delays."""
cross_spectrum_expanded = cross_spectrum[np.newaxis, :, :, :]
# Multiply and sum over mics and frequency bins
@@ -245,7 +290,9 @@ def _compute_beamformer_energies(self, cross_spectrum: np.ndarray[np.complex128]) -> np.ndarray:
energies = np.abs(np.sum(product, axis=(1, 2, 3))) # Shape: (num_grid_points,)
return energies

def _remove_source_contribution(self, cross_spectrum: np.ndarray[np.complex128], source_idx: int) -> np.ndarray[np.complex128]:
def _remove_source_contribution(
self, cross_spectrum: np.ndarray[np.complex128], source_idx: int
) -> np.ndarray[np.complex128]:
"""
Remove the contribution of a localized source using vectorized operations.
"""
@@ -264,4 +311,4 @@ def compute_cartesian_coordinates(self, distance, angle):
"""
x = distance * np.cos(np.radians(angle))
y = distance * np.sin(np.radians(angle))
return x, y
return x, y
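
The grid search in this file is essentially steered-response-power beamforming: precompute per-grid-point phase shifts from the microphone geometry, multiply them against the cross-power spectrum of the incoming frames, and take the grid point with the highest energy. A minimal, self-contained sketch of that idea follows; the two-microphone array, grid points, and random test frame are hypothetical stand-ins, not values from this repository.

import numpy as np

speed_of_sound = 343.0   # m/s (assumed)
sample_rate = 24000      # Hz, the streamer's target rate
fft_size = 1024

mic_positions = np.array([[0.0, 0.0], [0.05, 0.0]], dtype=np.float32)  # hypothetical 2-mic array
grid_points = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)     # two candidate source positions

freqs = np.fft.rfftfreq(fft_size, d=1.0 / sample_rate)                 # (num_bins,)

# Distances and propagation delays from every grid point to every mic: (num_points, num_mics)
distances = np.linalg.norm(grid_points[:, None, :] - mic_positions[None, :, :], axis=2)
delays = distances / speed_of_sound

# Pairwise delay differences -> steering phase shifts: (num_points, num_mics, num_mics, num_bins)
delay_diffs = delays[:, :, None] - delays[:, None, :]
phase_shifts = np.exp(-2j * np.pi * freqs[None, None, None, :] * delay_diffs[..., None])

# Cross-power spectrum of one frame of mic signals: (num_mics, num_mics, num_bins)
frame = np.random.randn(2, fft_size).astype(np.float32)                # stand-in for real audio
mic_fft = np.fft.rfft(frame, n=fft_size, axis=1)
cross_spectrum = mic_fft[:, None, :] * np.conj(mic_fft[None, :, :])

# Beamformer energy per grid point; the argmax is the estimated direction
energies = np.abs(np.sum(phase_shifts * cross_spectrum[None, ...], axis=(1, 2, 3)))
best_point = grid_points[np.argmax(energies)]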
@@ -17,8 +17,9 @@
commands = Queue(maxsize=10)
locations = Queue(maxsize=10)

def send_audio_continuously(client, single_channel_generator):
print("Threading...")

def send_audio_continuously(client, single_channel_generator, logger):
logger.info("Sending audio to OpenAI")
for single_channel_audio in single_channel_generator:
client.send_audio(single_channel_audio)

@@ -29,16 +30,15 @@ def receive_text_messages(client, logger):
try:
command = client.receive_text_response()
if command and command.strip() == "MOVE_TO":
print(command)
logger.info(f"Received command: {command}")

if commands.full():
commands.get()
commands.task_done()
commands.put(command)
except Exception as e:
print(f"Error receiving response: {e}")
break # Exit loop if server disconnects
logger.error(f"Error receiving response: {e}")


def realtime_localization(multi_channel_stream, localizer, logger):
logger.info("Localization: Listening to audio stream")
@@ -53,16 +53,17 @@ def realtime_localization(multi_channel_stream, localizer, logger):
for localization_results in localization_stream:
if locations.full():
locations.get()

locations.put(localization_results)

if did_get:
locations.task_done()
did_get = False

except Exception as e:
print(f"Realtime Localization error: {e}")


def command_executor(publisher, logger):
logger.info("Executor: listening for command")
while True:
@@ -73,7 +74,7 @@ def command_executor(publisher, logger):
commands.task_done()
if locations.qsize() > 0:
publisher(locations.get())

locations.task_done()
except Exception as e:
print(f"Command executor error: {e}")
@@ -100,7 +101,7 @@ def __init__(self):
("realtime_streamer.sample_rate", 24000),
("realtime_streamer.channels", 1),
("realtime_streamer.chunk", 1024),
("realtime_streamer.device_indices", [2, 3, 4, 5])
("realtime_streamer.device_indices", [2, 3, 4, 5]),
],
)

@@ -133,10 +134,19 @@ def process_audio(self):

with ThreadPoolExecutor(max_workers=4) as executor:
self.logger.info("Threading log")
executor.submit(send_audio_continuously, client, single_channel_stream)
executor.submit(
send_audio_continuously, client, single_channel_stream, self.logger
)
executor.submit(receive_text_messages, client, self.logger)
executor.submit(realtime_localization, multi_channel_stream, self.localizer, self.logger)
executor.submit(command_executor, lambda x: self.publish_results(x), self.logger)
executor.submit(
realtime_localization,
multi_channel_stream,
self.localizer,
self.logger,
)
executor.submit(
command_executor, lambda x: self.publish_results(x), self.logger
)

def publish_results(self, localization_results):
# Publish results to ROS topic
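
Both the commands and locations queues in this node follow a keep-only-the-newest pattern: when the queue is full, the oldest entry is discarded before the new one is enqueued, so the executor always acts on fresh data. A small sketch of that pattern under the same assumptions (bounded Queue, single producer) follows; put_latest is an illustrative helper, not a function from this repository.

from queue import Queue

def put_latest(q: Queue, item) -> None:
    # Drop the oldest entry when the queue is full so consumers only see recent items
    if q.full():
        q.get()
        q.task_done()
    q.put(item)

commands = Queue(maxsize=10)
put_latest(commands, "MOVE_TO")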
@@ -7,13 +7,13 @@
from io import BytesIO


# from passive_sound_localization.models.configs.realtime_streamer import (
# RealtimeAudioStreamerConfig,
# )

from models.configs.realtime_streamer import (
from passive_sound_localization.models.configs.realtime_streamer import (
RealtimeAudioStreamerConfig,
) # Only needed to run with `realtime_audio.py`
)

# from models.configs.realtime_streamer import (
# RealtimeAudioStreamerConfig,
# ) # Only needed to run with `realtime_audio.py`

logger = logging.getLogger(__name__)

@@ -87,19 +87,27 @@ def multi_channel_gen(self) -> Generator[Optional[Dict[int, bytes]], None, None]:

def merge_streams(self, streams: List[np.ndarray]) -> np.ndarray:
return np.sum(streams, axis=0) / len(streams)

def resample_stream(self, stream: bytes, target_sample_rate: int = 24000, sample_width: int=2) -> bytes:
try:
audio = AudioSegment.from_file(BytesIO(stream))

# Resample to 24kHz mono pcm16
return audio.set_frame_rate(target_sample_rate).set_channels(self.channels).set_sample_width(sample_width).raw_data
def resample_stream(
self, stream: bytes, target_sample_rate: int = 24000, sample_width: int = 2
) -> bytes:
try:
audio_data_int16 = (stream * 32767).astype(np.int16)
audio_segment = AudioSegment(
audio_data_int16.tobytes(),
frame_rate=self.sample_rate,
sample_width=audio_data_int16.dtype.itemsize,
channels=self.channels,
)

# Resample the audio to 24000 Hz
audio_segment_resampled = audio_segment.set_frame_rate(target_sample_rate)
return audio_segment_resampled.get_array_of_samples().tobytes()

except Exception as e:
print(f"Error in resample_stream: {e}")
return b""


def single_channel_gen(self) -> Generator[Optional[bytes], None, None]:
try:
while self.streaming:
Expand All @@ -114,4 +122,4 @@ def single_channel_gen(self) -> Generator[Optional[bytes], None, None]:
else:
yield None
except Exception as e:
print(f"Error in single_channel_gen: {e}")
print(f"Error in single_channel_gen: {e}")