Skip to content

Commit

Permalink
Merge pull request #23 from nicolasperez19/localization-cleanup
Browse files Browse the repository at this point in the history
chore: Just cleaned up localization code
  • Loading branch information
nicolasperez19 authored Nov 10, 2024
2 parents a6d5606 + 4b1e3e8 commit 9a18c81
Showing 1 changed file with 18 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,34 @@

class NoMicrophonePositionsError(Exception):
"""Exception raised when there are no microphone positions"""
pass
def __init__(self) -> None:
super().__init__("No microphone positions were configured")

class TooFewMicrophonePositionsError(Exception):
"""Exception raised when there are less than 2 microphone positions"""
pass
def __init__(self, num_mic_positions: int) -> None:
super().__init__(f"There should be at least 2 microphone positions. Currently only {num_mic_positions} microphone position(s) were configured")

class MicrophonePositionShapeError(Exception):
"""Exception raised when microphone positions don't match (x,y) pairs"""
pass
def __init__(self, mic_position_shape: int) -> None:
super().__init__(f"The microphone positions should be in (x,y) pairs. Currently the microphone positions come in pairs of shape {mic_position_shape}")

class NoMicrophoneStreamsError(Exception):
"""Exception raised when there are no microphone streams"""
def __init__(self) -> None:
super().__init__("No microphone streams were passed for localization")
pass

class TooFewMicrophoneStreamsError(Exception):
"""Exception raised when there are less than 2 microphone streams"""
pass
def __init__(self, num_mic_streams: int) -> None:
super().__init__(f"There should be at least 2 microphone streams. Currently there are only {num_mic_streams} microphone streams")

class MicrophoneStreamSizeMismatchError(Exception):
"""Exception raised when the number of microphone streams doesn't match the number of microphone positions"""
pass
def __init__(self, num_mics:int, num_mic_streams) -> None:
super().__init__(f"The number of microphone streams should match the number of microphone positions. Currently there are {num_mic_streams} microphone streams and {num_mics} microphone positions")


@dataclass(frozen=True)
Expand All @@ -40,19 +47,18 @@ class LocalizationResult:

class SoundLocalizer:
def __init__(self, config: LocalizationConfig):
# TODO: Do more granular error checking in init
# TODO: Make sure that mic position ordering matches the ordering of the microphone streams/device indices
self.mic_positions = np.array(
config.mic_positions, dtype=np.float32
) # Get mic positions from config
if self.mic_positions.shape[0] == 0:
raise NoMicrophonePositionsError("No microphone positions were configured")
raise NoMicrophonePositionsError()

if self.mic_positions.shape[0] < 2:
raise TooFewMicrophonePositionsError(f"There should be at least 2 microphone positions. Currently only {self.mic_positions.size} microphone position(s) were configured")
raise TooFewMicrophonePositionsError(num_mic_positions=self.mic_positions.size)

if self.mic_positions.shape[1] < 2:
raise MicrophonePositionShapeError(f"The microphone positions should be in (x,y) pairs. Currently the microphone positions come in pairs of shape {self.mic_positions.shape[1]}")
raise MicrophonePositionShapeError(mic_position_shape=self.mic_positions.shape[1])


self.speed_of_sound:float = config.speed_of_sound
Expand Down Expand Up @@ -89,23 +95,20 @@ def localize_stream(
num_mic_streams = len(multi_channel_stream)

if num_mic_streams == 0:
raise NoMicrophoneStreamsError("No microphone streams were passed for localization")
raise NoMicrophoneStreamsError()
if num_mic_streams < 2:
logger.error("At least two microphones are required for localization.")
raise TooFewMicrophoneStreamsError(f"There should be at least 2 microphone streams. Currently there are only {num_mic_streams} microphone streams")
raise TooFewMicrophoneStreamsError()

if self.num_mics != num_mic_streams:
raise MicrophoneStreamSizeMismatchError(f"The number of microphone streams should match the number of microphone positions. Currently there are {num_mic_streams} microphone streams and {self.num_mics} microphone positions")
raise MicrophoneStreamSizeMismatchError(num_mics=self.num_mics, num_mic_streams=num_mic_streams)

# Convert buffers into numpy arrays
multi_channel_data = [
np.frombuffer(data, dtype=np.float32) for data in multi_channel_stream
]

# TODO: Refactor buffer processing into its own method for testing

# Stack the multi-channel data into a 2D array (num_mics x num_samples) and replace any na values with zeroes
# data = np.vstack(multi_channel_data)
data = np.nan_to_num(np.vstack(multi_channel_data))

# Initialize buffer if it's the first chunk
Expand Down Expand Up @@ -157,23 +160,12 @@ def localize_stream(
def _compute_cross_spectrum(self, mic_signals:np.ndarray[np.float32], fft_size:int=1024) -> np.ndarray[np.complex128]:
"""Compute the cross-power spectrum between microphone pairs."""
# Correct shape: (num_mics, num_mics, fft_size // 2 + 1) for the rfft result
# cross_spectrum = np.zeros(
# (self.num_mics, self.num_mics, fft_size // 2 + 1), dtype=np.complex64
# )

# print(f"Does the mic_signals have any na values or inf values {np.isnan(mic_signals).any() or np.isinf(mic_signals).any()}")

mic_signals = mic_signals.astype(np.float64)

# Compute the FFT of each microphone signal
mic_fft = np.fft.rfft(mic_signals, fft_size)

# Compute the cross-power spectrum for each microphone pair
# for i in range(self.num_mics):
# for j in range(i, self.num_mics):
# cross_spectrum[i, j] = mic_fft[i] * np.conj(mic_fft[j])
# cross_spectrum[j, i] = np.conj(cross_spectrum[i, j])

# Compute the cross-power spectrum for each microphone pair using broadcasting
cross_spectrum = mic_fft[:, np.newaxis, :] * np.conj(mic_fft[np.newaxis, :, :])

Expand Down

0 comments on commit 9a18c81

Please sign in to comment.