feat: Added smooth SRP and initial generalized sidelobe canceller #3

Open · wants to merge 9 commits into base: main
171 changes: 171 additions & 0 deletions new_localization/example_svrp.py
@@ -0,0 +1,171 @@
import numpy as np
import torch
import torchaudio
from abc import ABC, abstractmethod

from xsrp_project.xsrp.grids import Grid


class XSrp(ABC):
def __init__(self, fs: float, mic_positions=None, room_dims=None, c=343.0):
self.fs = fs
self.mic_positions = mic_positions
self.room_dims = room_dims
self.c = c

        self.n_mics = len(mic_positions) if mic_positions is not None else 0

# 0. Create the initial grid of candidate positions
self.candidate_grid = self.create_initial_candidate_grid(room_dims)

# ---- Load Silero VAD model here ----
self.vad_model, self.vad_utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", source="github"
)
        # The first element of the Silero VAD utils tuple is get_speech_timestamps.
        self.get_speech_timestamps = self.vad_utils[0]
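        # get_speech_timestamps(audio, model, sampling_rate=...) returns a list of
        # {"start": ..., "end": ...} dicts (sample indices) for detected speech regions.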

def smooth_srp_map(self, srp_map, window_size: int = 5):
"""
Apply a moving average to the SRP map to smooth it.
Since srp_map might be 1D or a flattened 2D, adjust as needed.
"""
window = np.ones(window_size) / window_size
return np.convolve(srp_map, window, "valid")

def apply_gsc(self, data, peaks, sidelobe_reduction=0.5):
"""
Apply Generalized Sidelobe Canceller to boost signal at peaks and reduce elsewhere.
"""
mask = np.zeros_like(data)
for peak in peaks:
mask[peak] = 1

mask = self.smooth_srp_map(
mask, window_size=10
) # Smooth the mask to create transition regions
return data * (1 + mask * sidelobe_reduction)

def classify_candidates_with_vad(
self, mic_signals, top_candidates, segment_duration=0.5
):
"""
Classify each candidate as speech or non-speech using silero-vad.

        This is a placeholder method: for simplicity it analyses a segment of
        the first microphone's signal, centred on the strongest activity, for
        every candidate. In a real scenario you would beamform towards each
        candidate and run VAD on a segment of the beamformed output.
        """
vad_labels = []
for candidate in top_candidates:
# Use the signal from the first microphone for simplicity
audio_data = torch.from_numpy(mic_signals[0].astype(np.float32))

            # Centre the segment on the region of strongest signal amplitude
            signal_threshold = 0.1 * np.max(np.abs(audio_data.numpy()))  # Example threshold
            significant_indices = np.where(np.abs(audio_data.numpy()) > signal_threshold)[0]
if len(significant_indices) > 0:
center_index = significant_indices[len(significant_indices) // 2]
else:
center_index = len(audio_data) // 2 # Fallback to center of the signal

# Determine the number of samples for the segment
half_segment_samples = int((segment_duration * self.fs) / 2)
start = max(0, center_index - half_segment_samples)
end = min(len(audio_data), center_index + half_segment_samples)
segment = audio_data[start:end]

            # Apply VAD. Silero VAD expects 8 kHz or 16 kHz input, so resample
            # the segment first if the capture rate differs (e.g. 44.1 kHz).
            vad_rate = int(self.fs)
            if vad_rate not in (8000, 16000):
                segment = torchaudio.functional.resample(segment, vad_rate, 16000)
                vad_rate = 16000
            speech_timestamps = self.get_speech_timestamps(
                segment, self.vad_model, sampling_rate=vad_rate
            )

# Label candidate as speech or non-speech based on VAD
is_speech = 1 if len(speech_timestamps) > 0 else 0
vad_labels.append(is_speech)

return vad_labels

def forward(
self, mic_signals, mic_positions=None, room_dims=None, n_best: int = 4
    ) -> tuple[np.ndarray, np.ndarray, Grid]:
if mic_positions is None:
mic_positions = self.mic_positions
if room_dims is None:
room_dims = self.room_dims

        if mic_positions is None or room_dims is None:
            raise ValueError(
                "mic_positions and room_dims must be specified either in the constructor or in the forward method"
            )

candidate_grid = self.candidate_grid

estimated_positions = np.array([])

# 1. Compute the signal features (e.g., GCC-PHAT)
signal_features = self.compute_signal_features(mic_signals)

while True:
# 2. Create the SRP map
srp_map = self.create_srp_map(
mic_positions, candidate_grid, signal_features
)

# Find top n candidates
top_indices = np.argsort(srp_map)[-n_best:]
top_candidates = candidate_grid[top_indices]

# Smooth SRP map
srp_map = self.smooth_srp_map(srp_map=srp_map)

# Apply GSC
srp_map = self.apply_gsc(srp_map, top_indices)

# ---- Integrate VAD Classification ----
# Classify the top candidates as speech or non-speech (1 as speech, 0 as non speech)
vad_labels = self.classify_candidates_with_vad(
mic_signals, top_candidates, segment_duration=0.5
)

            # Use vad_labels to filter or down-weight non-speech candidates.
            # As an example, scale down the SRP score of every candidate
            # classified as non-speech:
for idx, label in enumerate(vad_labels):
if label == 0:
# Reduce SRP score for non-speech candidates
srp_map[top_indices[idx]] *= 0.1

# 3. Grid search step (refine candidate grid)
estimated_positions, new_candidate_grid, signal_features = self.grid_search(
candidate_grid, srp_map, estimated_positions, signal_features
)

# 4. Update candidate grid
if len(new_candidate_grid) == 0:
# If no new candidates, we're done
break
else:
candidate_grid = new_candidate_grid

return estimated_positions, srp_map, candidate_grid

@abstractmethod
def compute_signal_features(self, mic_signals):
pass

@abstractmethod
def create_initial_candidate_grid(self, room_dims):
pass

@abstractmethod
def create_srp_map(
        self, mic_positions: np.ndarray, candidate_grid: Grid, signal_features: np.ndarray
):
pass

@abstractmethod
def grid_search(
self, candidate_grid, srp_map, estimated_positions, signal_features
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
pass
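For orientation, here is a minimal sketch of how a concrete subclass could plug into this base class. It is illustrative only: the class name, the fixed 25 x 25 grid, and the single-pass grid_search below are assumptions, not part of this PR, and create_srp_map is left as a stub.

import numpy as np

from example_svrp import XSrp  # assumed import path for the class above


class ToySrp(XSrp):
    """Illustrative subclass; not part of this PR."""

    def compute_signal_features(self, mic_signals):
        # PHAT-weighted cross-spectra between mic 0 and every other mic.
        specs = np.fft.rfft(mic_signals, axis=-1)
        cross = specs[0] * np.conj(specs[1:])
        return cross / (np.abs(cross) + 1e-12)

    def create_initial_candidate_grid(self, room_dims):
        # Regular 25 x 25 grid of (x, y) candidate positions over the room.
        x = np.linspace(0, room_dims[0], 25)
        y = np.linspace(0, room_dims[1], 25)
        xx, yy = np.meshgrid(x, y)
        return np.stack([xx.ravel(), yy.ravel()], axis=-1)

    def create_srp_map(self, mic_positions, candidate_grid, signal_features):
        # Stub: a real implementation would steer the features towards each
        # candidate and accumulate power; random values keep the sketch runnable.
        return np.random.rand(len(candidate_grid))

    def grid_search(self, candidate_grid, srp_map, estimated_positions, signal_features):
        # Single-pass search: keep the best candidate and stop refining.
        best = candidate_grid[np.argmax(srp_map)]
        return np.array([best]), np.array([]), signal_features

With that in place, something like ToySrp(fs=16000, mic_positions=mics, room_dims=[5, 5]).forward(signals) runs one pass of the smoothing, GSC and VAD steps defined above (instantiation triggers the torch.hub download of Silero VAD).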
132 changes: 110 additions & 22 deletions new_localization/main.py
@@ -1,46 +1,107 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sounddevice as sd
import noisereduce as nr
from xsrp_project.xsrp.conventional_srp import ConventionalSrp
from scipy.signal import butter, filtfilt


def butter_bandpass(lowcut, highcut, fs, order=5):
nyq = 0.5 * fs
low = lowcut / nyq
high = highcut / nyq
b, a = butter(order, [low, high], btype="band")
return b, a


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
b, a = butter_bandpass(lowcut, highcut, fs, order=order)
y = filtfilt(b, a, data)
return y


# Define the microphone positions
# Ensure channels are in the correct order
print(
    "Using the hard-coded microphone positions below, in the format [x1, y1], [x2, y2], [x3, y3], [x4, y4]"
)
# mic_positions = np.array(
# [
# [0.0000, 0.4500, 0.41],
# [-0.4500, 0.0000, 0.41],
# [0.0000, -0.4500, 0.41],
# [0.4500, 0.0000, 0.41],
# ]
# )

# Four microphones (x, y) in metres, matching `channels` below
mic_positions = np.array(
    [
        [0.0000, 0.4500],
        [0.4500, 0.0000],
        [0.0000, -0.4500],
        [-0.4500, 0.0000],
    ]
)

print(f"Microphone positions shape: {mic_positions.shape}")

fs = 44100 # Sampling rate
frame_size = 1024
# frame_size = 1024
frame_size = 512
hop_size = 512 # Overlap may be used, depending on your STFT
channels = 4 # Number of mics

# Initialize the ConventionalSrp object
srp_func = ConventionalSrp(
fs,
grid_type="doa_2D",
n_grid_cells=200,
grid_type="2D",
n_grid_cells=50,
room_dims=[10, 10],
mic_positions=mic_positions,
interpolation=False,
mode="gcc_phat_freq",
n_average_samples=5,
freq_cutoff_in_hz=None,
freq_cutoff_low_in_hz=None,
freq_cutoff_high_in_hz=None,
)

# Set up a matplotlib figure for live updates
fig, ax = plt.subplots()
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_title("Real-time Sound Localization")
scat = ax.scatter([], [], c="red", s=50, label="Estimated Source")
ax.legend()
fig = plt.figure()

ax1 = fig.add_subplot(1,2,1)
ax1.set_xlim(-10, 10)
ax1.set_ylim(-10, 10)
ax1.set_title("Real-time Sound Localization")
ax1.scatter(
mic_positions[:, 0],
mic_positions[:, 1],
c="blue",
s=50,
label="Microphone Positions",
)
scat = ax1.scatter([], [], c="red", marker="*", s=50, label="Estimated Source")
ax1.legend()

ax2 = fig.add_subplot(1,2,2, projection="3d")
ax2.set_title("SRC Map")
ax2.set_xlabel("x (cm)")
ax2.set_ylabel("y (cm)")
ax2.set_zlabel("power")

# Initialize the ConventionalSrp object
# srp_func = ConventionalSrp(
# fs,
# grid_type="doa_2D",
# n_grid_cells=200,
# mic_positions=mic_positions,
# interpolation=False,
# mode="gcc_phat_freq",
# n_average_samples=5,
# freq_cutoff_in_hz=None,
# )



# A buffer to store audio data
audio_buffer = np.zeros((channels, frame_size))
@@ -51,38 +112,65 @@ def audio_callback(indata, frames, time, status):
# indata: shape (frames, channels)
global audio_buffer, srp_func, scat

# Move data into buffer (for simplicity, assume frames == frame_size)
audio_buffer = indata.T # shape: (channels, frame_size)
audio_buffer = butter_bandpass_filter(audio_buffer, 90, 200, fs)
audio_buffer = nr.reduce_noise(y=audio_buffer, y_noise=audio_buffer, sr=fs)


def update_plot(frame):
global audio_buffer, srp_func, scat

# Process with ConventionalSrp
# forward expects signals shape: (n_mics, n_samples)
# If you need STFT, do it here. For simplicity, let's assume direct time-domain input.
estimated_positions, srp_map, candidate_grid = srp_func.forward(audio_buffer)
print(f"SRP shape: {srp_map.shape}")
print(f"Candidate grid shape: {candidate_grid.asarray().shape}")
print(f"Candidate grid X: {candidate_grid.asarray()[:, 0].shape}")
print(f"Reshape shape: {srp_map.reshape((50, 50)).shape}")

X = candidate_grid.asarray()[:, 0].reshape((50, 50))
Y = candidate_grid.asarray()[:, 1].reshape((50, 50))
Z = srp_map.reshape((50, 50))

# estimated_positions might have one or more sources; take the first source if present
if estimated_positions is not None and len(estimated_positions) > 0:
est_pos = estimated_positions[0]
# Update the scatter plot with the estimated position
scat.set_offsets([est_pos[0], est_pos[1]])
print(f"Estimated position: {estimated_positions}")
# ax2.plot_surface(candidate_grid.asarray()[:, 0], candidate_grid.asarray()[:, 1], srp_map.reshape((50, 50)), cmap='jet', edgecolor='none')
ax2.clear()
ax2.plot_surface(X, Y, Z, cmap='jet', edgecolor='none')
scat.set_offsets([estimated_positions[0], estimated_positions[1]])

# Force matplotlib to redraw
plt.pause(0.001)
return (scat,)


ani = animation.FuncAnimation(fig, update_plot, frames=range(100), blit=True)


# Print out connected audio devices
print(sd.query_devices(), "\n")

device_index = input("Please enter the device index: ")
# device_index = input("Please enter the device index: ")
device_index = 13

# Configure the input stream
stream = sd.InputStream(
device=int(device_index), # Please specify the device index
channels=channels,
samplerate=fs,
# samplerate=16000,
blocksize=frame_size,
callback=audio_callback,
)

# Get multiple device indices from the user
# device_indices = input("Please enter the device indices (comma-separated): ")

# # Split the input by comma and convert each index to an integer
# device_indices = [int(index) for index in device_indices.split(',')]

# print(f"Device indices: {device_indices}")

# Create a list to store the InputStream instances
streams = []

# Start the stream
with stream:
print("Press Ctrl+C to stop")
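For completeness, a small offline sketch of the same per-frame chain (bandpass filter, noise reduction, one SRP forward pass) driven by a synthetic signal instead of the sounddevice stream. It assumes butter_bandpass_filter and srp_func from this file are in scope; the 150 Hz tone and the noise level are arbitrary test values, not part of the PR.

import numpy as np
import noisereduce as nr


def process_frame_offline(frame_size=512, fs=44100, n_mics=4):
    # Synthetic test input: a 150 Hz tone plus noise on every channel,
    # shaped (n_mics, frame_size) like the transposed stream buffer.
    t = np.arange(frame_size) / fs
    tone = np.sin(2 * np.pi * 150 * t)
    signals = np.tile(tone, (n_mics, 1)) + 0.05 * np.random.randn(n_mics, frame_size)

    # Same chain as audio_callback: bandpass, then spectral noise reduction.
    filtered = butter_bandpass_filter(signals, 90, 200, fs)
    denoised = nr.reduce_noise(y=filtered, y_noise=filtered, sr=fs)

    # One SRP forward pass, as update_plot does for each animation frame.
    estimated_positions, srp_map, candidate_grid = srp_func.forward(denoised)
    return estimated_positions, srp_map


estimated, _ = process_frame_offline()
print("Estimated source position:", estimated)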