Skip to content

Commit

Permalink
almost have multiple coherence - need to do one more vectorized 2x2 @…
Browse files Browse the repository at this point in the history
… 2x1 (H.H@E)
  • Loading branch information
kkappler committed Jan 31, 2024
1 parent 432e0c4 commit 6ad7d70
Showing 1 changed file with 195 additions and 70 deletions.
265 changes: 195 additions & 70 deletions aurora/transfer_function/weights/coherence_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,99 @@ def solve_single_time_window(Y, X, R=None):
z = np.linalg.solve(a, b)
return z

def estimate_time_series_of_impedances(band, output_ch="ex", use_remote=True):
"""
solve:
[<rx,ex>] = [<rx,hx>, <rx,hy>] [Zxx]
[<ry,ex>] [<ry,hx>, <ry,hy>] [Zyx]
[a, b]
[c, d]
determinant where det(A) = 1/(ad-bc)
:param band: band dataset: xarray, will be spectrogram in future
:return:
Requires some nomenclature setup... for now just hard code:/
TODO: Note that cross powers can be computed once only by using Spectrogram class
and spectrogram.cross_power("CH1", "CH2")
which returns self._ch1_ch2
which initialized to None and is written when requested
:param band_dataset:
:return:
"""

def cross_power_series(ch1, ch2):
"""<ch1.H ch2> summed along frequnecy"""
return (ch1.conjugate().transpose() * ch2).sum(dim="frequency")

# Start by computing relevant cross powers
if use_remote:
rx = band["rx"]
ry = band["ry"]
else:
rx = band["hx"]
ry = band["hy"]
rxex = cross_power_series(rx, band["ex"])
ryex = cross_power_series(ry, band["ex"])
rxhx = cross_power_series(rx, band["hx"])
ryhx = cross_power_series(ry, band["hx"])
rxhy = cross_power_series(rx, band["hy"])
ryhy = cross_power_series(ry, band["hy"])

N = len(rxex)
# Compute determiniants (one per time window)
det = rxhx * ryhy - rxhy * ryhx
# determinanjt is the sum of the autopowers minus the sum of the cross powers
det = np.real(det)
# det_inv = 1.0 / det # might get nan here ...

# Inverse matrix (2 x 2 x nTime)
inverse_matrices = np.zeros((2, 2, N), dtype=np.complex128)
inverse_matrices[0, 0, :] = ryhy / det
inverse_matrices[1, 1, :] = rxhx / det
inverse_matrices[0, 1, :] = -rxhy / det
inverse_matrices[1, 0, :] = -ryhx / det

# multiply inverse (on the left) against LHS
# bb = HH @ E
# bb1 = inverse_matrices[0, 0, :] * rxex + inverse_matrices[1, 0, :] * ryex
# Zxy = inverse_matrices[1, 0, :] * rxex + inverse_matrices[0, 1, :] * ryex

print(
"THE PROBLEM IS RIGHT HERE -- need to multiply the inverse by b=HH@E, NOT E alone"
)
Zxx = inverse_matrices[0, 0, :] * rxex + inverse_matrices[1, 0, :] * ryex
Zxy = inverse_matrices[1, 0, :] * rxex + inverse_matrices[0, 1, :] * ryex

# set up simple system and check Z1, z2 OK
# ex1 = [hx1 hy1] [z1
# ex2 = [hx2 hy2] z2]
# ex3 = [hx2 hy2]
idx = 0
E = band["ex"][idx, :]
H = band[["hx", "hy"]].to_array()[:, idx].T
HH = H.conj().transpose()
a = HH.data @ H.data
b = HH.data @ E.data
inv_a = np.linalg.inv(a)
zz0 = inv_a @ b
print(zz0)
zz = solve_single_time_window(b, a, None)
direct = (np.abs(E - H[:, 0] * zz[0] - H[:, 1] * zz[1]) ** 2).sum() / (
np.abs(E) ** 2
).sum()
print("direct", direct)
tricky = (np.abs(E - H[:, 0] * Zxx[0] - H[:, 1] * Zxy[1]) ** 2).sum() / (
np.abs(E) ** 2
).sum()
print("tricky", tricky)
return Zxx, Zxy

# cutoff_type = "threshold"
cutoffs = {}
cutoffs["local"] = {}
Expand All @@ -253,11 +346,35 @@ def solve_single_time_window(Y, X, R=None):
# band = frequency_band

# Extract the FCs for band
band_datasets = {}
band_datasets["local"] = local_stft.extract_band(band)
band_datasets["remote"] = remote_stft.extract_band(band)
n_obs = band_datasets["local"].time.shape[0]
local_dataset = local_stft.extract_band(band)
remote_dataset = remote_stft.extract_band(band, channels=["hx", "hy"])
remote_dataset = remote_dataset.rename({"hx": "rx", "hy": "ry"})
band_dataset = local_dataset.merge(remote_dataset)
n_obs = band_dataset.time.shape[0]
import time

t0 = time.time()
component = "ex"
Zxx, Zxy = estimate_time_series_of_impedances(
band_dataset, output_ch=component, use_remote=False
)
print(time.time() - t0)

predicted = Zxx * band_dataset["hx"] + Zxy * band_dataset["hy"]
predicted_energy = np.real(
(predicted.conjugate().transpose() * predicted).sum("frequency")
)
# residual = band_dataset[component] - estimate_output
# residual_energy = (residual.conjugate().transpose() * residual).real()
component_energy = np.real(
(band_dataset[component].conjugate().transpose() * band_dataset[component]).sum(
"frequency"
)
)
# component_energy = (band_dataset[component].conjugate().transpose() * band_dataset[component]).real()
# estimate_energy = (predicted_energy * estimate_output).real()
mulitple_coherence = predicted_energy / component_energy
print(mulitple_coherence[0])
# initialize a dict to hold the weights
# in this case there will be only two sets of weights, one set
# from the ex equation and another from the ey
Expand All @@ -267,11 +384,6 @@ def solve_single_time_window(Y, X, R=None):
weights[component] = np.ones(n_obs)
# cumulative_weights = np.ones(n_obs)

# Estimate Time Series of Impedance Tensors:
H = band_datasets["local"][["hx", "hy"]]
HH = band_datasets[local_or_remote]
HH = HH[["hx", "hy"]].conj()

# The notation in the following could be cleaned up, but hopefully should make not too murky what
# is happening here. We wish to estimate the multiple coherence for each time window independently
# For each time index in the spectrograms we will solve the following three equations
Expand Down Expand Up @@ -318,8 +430,8 @@ def solve_single_time_window(Y, X, R=None):
# We could iterate over the time intervals, forming the equation:
# R E = R H Z
# =
# [<rx,ex>] = [<rx,hx>, <ry,hx>] [Zxx]
# [<ry,ex>] [<rx,hy>, <ry,hy>] [Zyx]
# [<rx,ex>] = [<rx,hx>, <rx,hy] [Zxx]
# [<ry,ex>] [<ry,hx>, <ry,hy>] [Zyx]
#
# and then simply inverting the 2x2, to get our esimates.
# This is OK, but will be slow in python.
Expand Down Expand Up @@ -366,71 +478,84 @@ def solve_single_time_window(Y, X, R=None):
# - Compute Auto Power

# So treating this as a linalg .solve is not going to cut it

# 20230130: Inverting the large matrix is impractical, it turns out to be a sparse
# matrix with little nFC x 2 blocks along its diagonal, and can easily get to be 100k elt's on a side
# i.e. simple vctorization is not an option.
# one could call C (or maybe even Julia) to speed it up, but another approach would be to use
# a cross-power formulation for the Z estimates.
#
# Above it was noted that we need to solve, many times (onve per time window)
# An equation like this: where the rx, ry, are in gereral the two chanels in the
# Hermitian transpose matrix used in the classic solution, i.e. its G* in: d = Gm <--> G*d = G*Gm
#
# R E = R H Z
# =
# [<rx,ex>] = [<rx,hx>, <ry,hx>] [Zxx]
# [<ry,ex>] [<rx,hy>, <ry,hy>] [Zyx]
#
# But the crosspowers can be computed by just taking (rx*ex).sum(axis=freq) so we can easily set up
# all terms in the cross powers. Now instead of np.solving the 2x2, we can instead use the explicit solution:
# A= [a b] ^-1 = # 1/det(A) [ d -b] where det(A) = 1/(ad-bc)
# [c d] [-c a]
#
# it will take a bit of wrangling to get a clean way to compute the dets of the matrices vectorially, but
# it should be fast
# We know that
# for computing Zxx, Zxy we need:
# local Ex, Hx, Hy, and then either local Hx, Hy, for the hermitian conjugate, or remote Hx, Hy
# it actually can be any two linear independent channels but normally those are the pairings.
import time

t0 = time.time()
for component in components:
E = band_datasets["local"][component]
H = band_datasets["local"][["hx", "hy"]]
# R = band_datasets["remote"][["hx", "hy"]]
logger.info(f"Looping {component} over time")
# 30000 Looper: 150s (3ch)
# for i in range(E.time.shape[0]):
# Y = E.data[i:i+1, :].T # (1,3)
# X = H[["hx", "hy"]].to_array()[:, i:i+1, :].data.squeeze().T
# XR = R[["hx", "hy"]].to_array()[:, i:i+1, :].data.squeeze().T
# z0 = solve_single_time_window(Y,X,XR)
# 15000 Looper (71s 3ch)
for i in range(int(E.time.shape[0] / 2)):
Y = np.reshape(E.data[2 * i : 2 * i + 2, :], (6, 1))
X = np.zeros((6, 4), dtype=np.complex128)
X[0:3, 0:2] = (
H[["hx", "hy"]].to_array()[:, 2 * i : 2 * i + 1, :].data.squeeze().T
)
X[3:6, 2:4] = (
H[["hx", "hy"]].to_array()[:, 2 * i + 1 : 2 * i + 2, :].data.squeeze().T
)
# X = H[["hx", "hy"]].to_array()[:, i:i + 1, :].data.squeeze().T
# XR = R[["hx", "hy"]].to_array()[:, i:i + 1, :].data.squeeze().T
z0 = solve_single_time_window(Y, X, None)
print(f"z0 {z0}")
# 15000 Looper
# for i in range(int(E.time.shape[0] / 2)):
# Y = np.reshape(E.data[2 * i:2 * i + 2, :], (6, 1))
# X = np.zeros((6, 4), dtype=np.complex128)
# X[0:3, 0:2] = H[["hx", "hy"]].to_array()[:, 2 * i:2 * i + 1, :].data.squeeze().T
# X[3:6, 2:4] = H[["hx", "hy"]].to_array()[:, 2 * i + 1:2 * i + 2, :].data.squeeze().T
# # X = H[["hx", "hy"]].to_array()[:, i:i + 1, :].data.squeeze().T
# # XR = R[["hx", "hy"]].to_array()[:, i:i + 1, :].data.squeeze().T
# z0 = solve_single_time_window(Y, X, None)
# print(f"z0 {z0}")
# print("Now loop over time")
# i_time = 0;
# Y = E.data[0:1, :].T #(1,3)
# X = H[["hx","hy"]].to_array()[:,0:1,:].data.squeeze().T
# R = R[["hx", "hy"]].to_array()[:, 0:1, :].data.squeeze().T
# xH = X.conjugate().transpose()
#
# a = xH @ X
# b = xH @ Y
# z0 = np.linalg.solve(a, b)
# print(f"b.shape {b.shape}")
# print(f"a.shape {a.shape}")
# z0 = np.linalg.solve(a,b)
# print(z0)
msg = f"{time.time()-t0}"
print(msg)
print(msg)
print(msg)
print(msg)
print(msg)
print(msg)
print(msg)
# raise(ValueError(msg))
# for component in components:
# E = band_datasets["local"][component]
# H = band_datasets["local"][["hx", "hy"]]
# R = band_datasets["remote"][["hx", "hy"]]
# logger.info(f"Looping {component} over time")
# 30000 Looper: 150s (3ch)
# for i in range(E.time.shape[0]):
# Y = E.data[i:i+1, :].T # (1,3)
# X = H[["hx", "hy"]].to_array()[:, i:i+1, :].data.squeeze().T
# XR = R[["hx", "hy"]].to_array()[:, i:i+1, :].data.squeeze().T
# z0 = solve_single_time_window(Y,X,XR)
# 15000 Looper (71s 3ch)
# for i in range(int(E.time.shape[0] / 2)):
# Y = np.reshape(E.data[2 * i : 2 * i + 2, :], (6, 1))
# X = np.zeros((6, 4), dtype=np.complex128)
# X[0:3, 0:2] = (
# H[["hx", "hy"]].to_array()[:, 2 * i : 2 * i + 1, :].data.squeeze().T
# )
# X[3:6, 2:4] = (
# H[["hx", "hy"]].to_array()[:, 2 * i + 1 : 2 * i + 2, :].data.squeeze().T
# )
# # X = H[["hx", "hy"]].to_array()[:, i:i + 1, :].data.squeeze().T
# # XR = R[["hx", "hy"]].to_array()[:, i:i + 1, :].data.squeeze().T
# z0 = solve_single_time_window(Y, X, None)
# print(f"z0 {z0}")
# # 15000 Looper
# for i in range(int(E.time.shape[0] / 2)):
# Y = np.reshape(E.data[2 * i:2 * i + 2, :], (6, 1))
# X = np.zeros((6, 4), dtype=np.complex128)
# X[0:3, 0:2] = H[["hx", "hy"]].to_array()[:, 2 * i:2 * i + 1, :].data.squeeze().T
# X[3:6, 2:4] = H[["hx", "hy"]].to_array()[:, 2 * i + 1:2 * i + 2, :].data.squeeze().T
# # X = H[["hx", "hy"]].to_array()[:, i:i + 1, :].data.squeeze().T
# # XR = R[["hx", "hy"]].to_array()[:, i:i + 1, :].data.squeeze().T
# z0 = solve_single_time_window(Y, X, None)
# print(f"z0 {z0}")
# print("Now loop over time")
# i_time = 0;
# Y = E.data[0:1, :].T #(1,3)
# X = H[["hx","hy"]].to_array()[:,0:1,:].data.squeeze().T
# R = R[["hx", "hy"]].to_array()[:, 0:1, :].data.squeeze().T
# xH = X.conjugate().transpose()
#
# a = xH @ X
# b = xH @ Y
# z0 = np.linalg.solve(a, b)
# print(f"b.shape {b.shape}")
# print(f"a.shape {a.shape}")
# z0 = np.linalg.solve(a,b)
# print(z0)


def estimate_simple_coherence(X, Y):
Expand Down

0 comments on commit 6ad7d70

Please sign in to comment.