Skip to content

Commit

Permalink
test_fft_gtgram.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jun 12, 2024
1 parent 09fd0de commit 2be3286
Showing 1 changed file with 31 additions and 77 deletions.
108 changes: 31 additions & 77 deletions tests/test_fft_gtgram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,102 +5,56 @@
# BSD license: https://github.com/detly/gammatone/blob/master/COPYING

import numpy as np
import pytest
import scipy.io
from mock import patch
from pkg_resources import resource_stream

import gammatone.fftweight

REF_DATA_FILENAME = "data/test_fft_gammatonegram_data.mat"

INPUT_KEY = "fft_gammatonegram_inputs"
MOCK_KEY = "fft_gammatonegram_mocks"
RESULT_KEY = "fft_gammatonegram_results"

INPUT_COLS = ("name", "wave", "fs", "twin", "thop", "channels", "fmin")
MOCK_COLS = ("wts",)
RESULT_COLS = ("res", "window", "nfft", "nwin", "nhop")

with resource_stream(__name__, REF_DATA_FILENAME) as test_data:
DATA = scipy.io.loadmat(test_data, squeeze_me=False)
INPUTS_MOCKS_REFS_DICTS = [
(dict(zip(INPUT_COLS, inputs)), dict(zip(MOCK_COLS, mocks)), dict(zip(RESULT_COLS, refs)))
for inputs, mocks, refs in zip(DATA[INPUT_KEY], DATA[MOCK_KEY], DATA[RESULT_KEY])
]

def load_reference_data():
"""Load test data generated from the reference code"""
# Load test data
with resource_stream(__name__, REF_DATA_FILENAME) as test_data:
data = scipy.io.loadmat(test_data, squeeze_me=False)

zipped_data = zip(data[INPUT_KEY], data[MOCK_KEY], data[RESULT_KEY])
for inputs, mocks, refs in zipped_data:
input_dict = dict(zip(INPUT_COLS, inputs))
mock_dict = dict(zip(MOCK_COLS, mocks))
ref_dict = dict(zip(RESULT_COLS, refs))

yield (input_dict, mock_dict, ref_dict)


def test_fft_specgram_window():
for inputs, mocks, refs in load_reference_data():
args = (refs["nfft"], refs["nwin"])

expected = (refs["window"],)

yield FFTGtgramWindowTester(inputs["name"], args, expected)


class FFTGtgramWindowTester:
def __init__(self, name, args, expected):
self.nfft = args[0].squeeze()
self.nwin = args[1].squeeze()
self.expected = expected[0].squeeze()

self.description = "FFT gammatonegram window for nfft = {:f}, nwin = {:f}".format(
float(self.nfft), float(self.nwin)
)

def __call__(self):
result = gammatone.fftweight.specgram_window(self.nfft, self.nwin)
max_diff = np.max(np.abs(result - self.expected))
assert np.allclose(result, self.expected, rtol=1e-6, atol=2e-3), "Maximum difference: {:6e}".format(max_diff)


def test_fft_gtgram():
for inputs, mocks, refs in load_reference_data():
args = (
inputs["fs"],
inputs["twin"],
inputs["thop"],
inputs["channels"],
inputs["fmin"],
)
yield FFTGammatonegramTester(
inputs["name"][0],
args,
inputs["wave"],
mocks["wts"],
refs["window"],
refs["res"],
)

@pytest.mark.parametrize("inputs,mocks,refs", INPUTS_MOCKS_REFS_DICTS)
def test_fft_specgram_window(inputs, mocks, refs):
args = (refs["nfft"], refs["nwin"])
nfft = args[0].squeeze()
nwin = args[1].squeeze()
expected = refs["window"].squeeze()

class FFTGammatonegramTester:
"""Testing class for gammatonegram calculation"""
result = gammatone.fftweight.specgram_window(nfft, nwin)
max_diff = np.max(np.abs(result - expected))
assert np.allclose(result, expected, rtol=1e-6, atol=2e-3), "Maximum difference: {:6e}".format(max_diff)

def __init__(self, name, args, sig, fft_weights, window, expected):
self.signal = np.asarray(sig).squeeze()
self.expected = np.asarray(expected).squeeze()
self.fft_weights = np.asarray(fft_weights)
self.args = args
self.window = window.squeeze()

self.description = "FFT gammatonegram for {:s}".format(name)
@pytest.mark.parametrize("inputs,mocks,refs", INPUTS_MOCKS_REFS_DICTS)
def test_fft_gtgram(inputs, mocks, refs):
args = (inputs["fs"], inputs["twin"], inputs["thop"], inputs["channels"], inputs["fmin"])
signal = np.asarray(inputs["wave"]).squeeze()
expected = np.asarray(refs["res"]).squeeze()
fft_weights = np.asarray(mocks["wts"])
window = refs["window"].squeeze()

def __call__(self):
# Note that the second return value from fft_weights isn't actually used
with patch("gammatone.fftweight.fft_weights", return_value=(self.fft_weights, None)), patch(
"gammatone.fftweight.specgram_window", return_value=self.window
):
result = gammatone.fftweight.fft_gtgram(self.signal, *self.args)
# Note that the second return value from fft_weights isn't actually used
with patch("gammatone.fftweight.fft_weights", return_value=(fft_weights, None)), patch(
"gammatone.fftweight.specgram_window", return_value=window
):
result = gammatone.fftweight.fft_gtgram(signal, *args)

max_diff = np.max(np.abs(result - self.expected))
diagnostic = "Maximum difference: {:6e}".format(max_diff)
max_diff = np.max(np.abs(result - expected))
diagnostic = "Maximum difference: {:6e}".format(max_diff)

assert np.allclose(result, self.expected, rtol=1e-6, atol=1e-12), diagnostic
assert np.allclose(result, expected, rtol=1e-6, atol=1e-12), diagnostic

0 comments on commit 2be3286

Please sign in to comment.