Skip to content

Commit

Permalink
PEP 8 compliance for evaluate.py.
Browse files Browse the repository at this point in the history
Changes:
-Made evaluate.py PEP 8 compliant
-Fixed a typo in get_stats
  • Loading branch information
Marlin Schäfer committed May 26, 2022
1 parent 00d3a3b commit bb36080
Showing 1 changed file with 39 additions and 30 deletions.
69 changes: 39 additions & 30 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import h5py
import os
import logging
from pycbc.conversions import distance_from_chirp_distance_mchirp


def find_injection_times(fgfiles, injfile, padding_start=0, padding_end=0):
"""Determine injections which are contained in the file.
Expand Down Expand Up @@ -61,6 +61,7 @@ def find_injection_times(fgfiles, injfile, padding_start=0, padding_end=0):

return duration, np.any(ret, axis=0)


def find_closest_index(array, value, assume_sorted=False):
"""Find the index of the closest element in the array for the given
value(s).
Expand Down Expand Up @@ -89,15 +90,18 @@ def find_closest_index(array, value, assume_sorted=False):
ridxs = np.searchsorted(array, value, side='right')
lidxs = np.maximum(ridxs - 1, 0)
comp = np.fabs(array[lidxs] - value) < \
np.fabs(array[np.minimum(ridxs, len(array) - 1)] - value)
np.fabs(array[np.minimum(ridxs, len(array) - 1)] - value) # noqa: E127, E501
lisbetter = np.logical_or((ridxs == len(array)), comp)
ridxs[lisbetter] -= 1
return ridxs


def mchirp(mass1, mass2):
    """Return the chirp mass of a pair of component masses.

    The value is computed as
    ``(mass1 * mass2) ** (3 / 5) / (mass1 + mass2) ** (1 / 5)``.

    Arguments
    ---------
    mass1 : float or array-like supporting ``*``, ``+``, ``**`` and ``/``
        First component mass.
    mass2 : float or array-like supporting ``*``, ``+``, ``**`` and ``/``
        Second component mass.

    Returns
    -------
    Chirp mass, of the same kind as the result of the arithmetic on
    the inputs.
    """
    mass_product = mass1 * mass2
    mass_total = mass1 + mass2
    numerator = mass_product ** (3. / 5.)
    denominator = mass_total ** (1. / 5.)
    return numerator / denominator

def get_stats(fgevents, bgevents, injparams, duration=None, chirp_distance=False):

def get_stats(fgevents, bgevents, injparams, duration=None,
chirp_distance=False):
"""Calculate the false-alarm rate and sensitivity of a search
algorithm.
Expand Down Expand Up @@ -142,7 +146,7 @@ def get_stats(fgevents, bgevents, injparams, duration=None, chirp_distance=False
if chirp_distance:
massc = mchirp(injparams['mass1'], injparams['mass2'])
if duration is None:
duration = injtime.max() - injtimes.min()
duration = injtimes.max() - injtimes.min()
logging.info('Sorting foreground event times')
sidxs = fgevents[0].argsort()
fgevents = fgevents.T[sidxs].T
Expand Down Expand Up @@ -170,25 +174,24 @@ def get_stats(fgevents, bgevents, injparams, duration=None, chirp_distance=False
ret['true-positives'] = tpevents
ret['false-positives'] = fpevents

#Calculate foreground FAR
# Calculate foreground FAR
logging.info('Calculating foreground FAR')
noise_stats = fpevents[1].copy()
noise_stats.sort()
fgfar = len(noise_stats) - np.arange(len(noise_stats)) - 1
fgfar = fgfar / duration
ret['fg-far'] = fgfar
sfaridxs = fgfar.argsort()

#Calculate background FAR
# Calculate background FAR
logging.info('Calculating background FAR')
noise_stats = bgevents[1].copy()
noise_stats.sort()
far = len(noise_stats) - np.arange(len(noise_stats)) - 1
far = far / duration
ret['far'] = far

#Calculate sensitivity
#CARE! THIS APPLIES ONLY WHEN THE DISTRIBUTION IS CHOSEN CORRECTLY
# Calculate sensitivity
# CARE! THIS APPLIES ONLY WHEN THE DISTRIBUTION IS CHOSEN CORRECTLY
logging.info('Calculating sensitivity')
sidxs = tpevents[1].argsort()
tp_sort = tpevents[1][sidxs]
Expand All @@ -208,14 +211,14 @@ def get_stats(fgevents, bgevents, injparams, duration=None, chirp_distance=False
nfound = len(tp_sort) - np.searchsorted(tp_sort, noise_stats,
side='right')
if chirp_distance:
#Get found chirp-mass indices for given threshold
# Get found chirp-mass indices for given threshold
fidxs = np.searchsorted(tp_sort, noise_stats, side='right')
found_mchirp_total = np.flip(found_mchirp_total)

#Calculate sum(found_mchirp ** (5/2))
#with found_mchirp = found_mchirp_total[i:]
#and i looped over fidxs
#Code below is a vectorized form of that
# Calculate sum(found_mchirp ** (5/2))
# with found_mchirp = found_mchirp_total[i:]
# and i looped over fidxs
# Code below is a vectorized form of that
cumsum = np.flip(np.cumsum(found_mchirp_total ** (5./2.)))
cumsum = np.concatenate([cumsum, np.zeros(1)])
mc_sum = cumsum[fidxs]
Expand All @@ -224,7 +227,8 @@ def get_stats(fgevents, bgevents, injparams, duration=None, chirp_distance=False
cumsumsq = np.flip(np.cumsum(found_mchirp_total ** 5))
cumsumsq = np.concatenate([cumsumsq, np.zeros(1)])
sample_variance_prefactor = cumsumsq[fidxs]
sample_variance = sample_variance_prefactor / Ninj - (mc_sum / Ninj) ** 2
sample_variance = sample_variance_prefactor / Ninj\
- (mc_sum / Ninj) ** 2 # noqa: E127
else:
mc_sum = nfound
sample_variance = nfound / Ninj - (nfound / Ninj) ** 2
Expand All @@ -239,23 +243,27 @@ def get_stats(fgevents, bgevents, injparams, duration=None, chirp_distance=False

return ret


def main(doc):
parser = argparse.ArgumentParser(description=doc)

parser.add_argument('--injection-file', type=str, required=True,
help=("Path to the file containing information "
"on the injections. (The file returned by"
"`generate_data.py --output-injection-file`"))
parser.add_argument('--foreground-events', type=str, nargs='+', required=True,
parser.add_argument('--foreground-events', type=str, nargs='+',
required=True,
help=("Path to the file containing the events "
"returned by the search on the foreground "
"data set as returned by "
"`generate_data.py --output-foreground-file`."))
parser.add_argument('--foreground-files', type=str, nargs='+', required=True,
parser.add_argument('--foreground-files', type=str, nargs='+',
required=True,
help=("Path to the file containing the analyzed "
"foreground data output by"
"`generate_data.py --output-foreground-file`."))
parser.add_argument('--background-events', type=str, required=True, nargs='+',
parser.add_argument('--background-events', type=str, nargs='+',
required=True,
help=("Path to the file containing the events "
"returned by the search on the background"
"data set as returned by "
Expand All @@ -271,34 +279,35 @@ def main(doc):

args = parser.parse_args()

#Setup logging
# Setup logging
log_level = logging.INFO if args.verbose else logging.WARN
logging.basicConfig(format='%(levelname)s | %(asctime)s: %(message)s',
level=log_level, datefmt='%d-%m-%Y %H:%M:%S')

#Sanity check arguments here
# Sanity check arguments here
if os.path.splitext(args.output_file)[1] != '.hdf':
raise ValueError(f'The output file must have the extension `.hdf`.')
raise ValueError('The output file must have the extension `.hdf`.')

if os.path.isfile(args.output_file) and not args.force:
raise IOError(f'The file {args.output_file} already exists. Set the flag `force` to overwrite it.')
raise IOError(f'The file {args.output_file} already exists. '
'Set the flag `force` to overwrite it.')

#Find indices contained in foreground file
logging.info(f'Finding injections contained in data')
# Find indices contained in foreground file
logging.info('Finding injections contained in data')
padding_start, padding_end = 30, 30
dur, idxs = find_injection_times(args.foreground_files,
args.injection_file,
padding_start=padding_start,
padding_end=padding_end)
if np.sum(idxs) == 0:
msg = 'The foreground data contains no injections! '
msg = 'The foreground data contains no injections! '
msg += 'Probably a too small section of data was generated. '
msg += 'Please make sure to generate at least {} seconds of data. '
msg += 'Otherwise a sensitive distance cannot be calculated.'
msg = msg.format(padding_start + padding_end + 24)
raise RuntimeError(msg)

#Read injection parameters
# Read injection parameters
logging.info(f'Reading injections from {args.injection_file}')
injparams = {}
with h5py.File(args.injection_file, 'r') as fp:
Expand All @@ -308,7 +317,7 @@ def main(doc):
injparams['mass2'] = fp['mass2'][()][idxs]
use_chirp_distance = 'chirp_distance' in fp.keys()

#Read foreground events
# Read foreground events
logging.info(f'Reading foreground events from {args.foreground_events}')
fg_events = []
for fpath in args.foreground_events:
Expand All @@ -318,7 +327,7 @@ def main(doc):
fp['var']]))
fg_events = np.concatenate(fg_events, axis=-1)

#Read background events
# Read background events
logging.info(f'Reading background events from {args.background_events}')
bg_events = []
for fpath in args.background_events:
Expand All @@ -332,14 +341,14 @@ def main(doc):
duration=dur,
chirp_distance=use_chirp_distance)


#Store results
# Store results
logging.info(f'Writing output to {args.output_file}')
mode = 'w' if args.force else 'x'
with h5py.File(args.output_file, mode) as fp:
for key, val in stats.items():
fp.create_dataset(key, data=np.array(val))
return


if __name__ == "__main__":
main(__doc__)

0 comments on commit bb36080

Please sign in to comment.