Skip to content

Commit

Permalink
accept the possibility of composing variants ; this will allow the po…
Browse files Browse the repository at this point in the history
…ssibility of studying the effect of new theories -which might be differently architectured in some ways- in legacy datasets
  • Loading branch information
scarlehoff committed Nov 23, 2024
1 parent a329696 commit deb8c60
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 20 deletions.
21 changes: 17 additions & 4 deletions validphys2/src/validphys/commondataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,16 @@ def check(self):
def apply_variant(self, variant_name):
"""Return a new instance of this class with the variant applied
This class also defines how the variant is applied to the commondata
This class also defines how the variant is applied to the commondata.
If more than one variant is being used, this function will be called recursively
until all variants are applied.
"""
if not isinstance(variant_name, str):
observable = self
for single_variant in variant_name:
observable = observable.apply_variant(single_variant)
return observable

try:
variant = self.variants[variant_name]
except KeyError as e:
Expand All @@ -484,15 +492,20 @@ def apply_variant(self, variant_name):

# This section should only be used for the purposes of reproducibility
# of legacy data, no new data should use these

if variant.experiment is not None:
new_nnpdf_metadata = dict(self._parent.nnpdf_metadata.items())
new_nnpdf_metadata["experiment"] = variant.experiment
setmetadata_copy = dataclasses.replace(self._parent, nnpdf_metadata=new_nnpdf_metadata)
variant_replacement["_parent"] = setmetadata_copy
variant_replacement["plotting"] = dataclasses.replace(self.plotting)

return dataclasses.replace(self, applied_variant=variant_name, **variant_replacement)
# Keep track of applied variants:
if self.applied_variant is None:
varname = variant_name
else:
varname = f"{self.applied_variant}_{variant_name}"

return dataclasses.replace(self, applied_variant=varname, **variant_replacement)

@property
def is_positivity(self):
Expand Down Expand Up @@ -840,7 +853,7 @@ def parse_new_metadata(metadata_file, observable_name, variant=None):
# Select one observable from the entire metadata
metadata = set_metadata.select_observable(observable_name)

# And apply variant if given
# And apply variant or variants if given
if variant is not None:
metadata = metadata.apply_variant(variant)

Expand Down
5 changes: 4 additions & 1 deletion validphys2/src/validphys/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,14 +460,17 @@ def parse_dataset_input(self, dataset: Mapping):
if variant is None or map_variant == "legacy_dw":
variant = map_variant

if sysnum is not None:
log.warning("The key 'sys' is deprecated and will soon be removed")

return DataSetInput(
name=name,
sys=sysnum,
cfac=cfac,
frac=frac,
weight=weight,
custom_group=custom_group,
variant=variant,
sys=sysnum,
)

def parse_use_fitcommondata(self, do_use: bool):
Expand Down
28 changes: 25 additions & 3 deletions validphys2/src/validphys/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,17 +368,39 @@ def plot_kinlabels(self):

class DataSetInput(TupleComp):
"""Represents whatever the user enters in the YAML to specify a
dataset."""
dataset.
name: str
name of the dataset input
cfac: tuple
cfactors to apply to the final predictions (default: ())
frac: float
fraction of the data to be used during training (default: 1.0)
weight: float
extra weight to apply to the dataset (default: 1.0)
variant: str or tuple[str]
variant or variants to apply (default: None)
sys: int
deprecated, systematic file to load for the dataset
"""

def __init__(self, *, name, sys, cfac, frac, weight, custom_group, variant):
def __init__(self, *, name, cfac, frac, weight, custom_group, variant, sys=None):
self.name = name
self.sys = sys
self.cfac = cfac
self.frac = frac
self.weight = weight
self.custom_group = custom_group

# Parse the variant if introduced as a string
if isinstance(variant, str):
variant = (variant,)

# Make sure that variant is stored as a tuple rather than a list
if isinstance(variant, list):
variant = tuple(variant)
self.variant = variant
super().__init__(name, sys, cfac, frac, weight, custom_group, variant)
super().__init__(name, cfac, frac, weight, custom_group, variant, sys)

def __str__(self):
return self.name
Expand Down
10 changes: 8 additions & 2 deletions validphys2/src/validphys/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def check_commondata(
self, setname, sysnum=None, use_fitcommondata=False, fit=None, variant=None
):
"""Prepare the commondata files to be loaded.
A commondata is defined by its name (``setname``) and the variant (``variant``)
A commondata is defined by its name (``setname``) and the variant(s) (``variant``)
At the moment both old-format and new-format commondata can be utilized and loaded
however old-format commondata are deprecated and will be removed in future releases.
Expand Down Expand Up @@ -423,7 +423,12 @@ def check_commondata(
)
break
# try new commondata format
old_path = fit.path / "filter" / legacy_name / f"filtered_uncertainties_{legacy_name}.yaml"
old_path = (
fit.path
/ "filter"
/ legacy_name
/ f"filtered_uncertainties_{legacy_name}.yaml"
)
if old_path.exists():
data_path = old_path.with_name(f"filtered_data_{legacy_name}.yaml")
unc_path = old_path.with_name(f"filtered_uncertainties_{legacy_name}.yaml")
Expand Down Expand Up @@ -481,6 +486,7 @@ def get_commondata(self, setname, sysnum):
"""Get a Commondata from the set name and number."""
# TODO: check where this is used
# as this might ignore cfactors or variants
raise Exception("Not used")
cd = self.check_commondata(setname, sysnum)
return cd.load()

Expand Down
14 changes: 4 additions & 10 deletions validphys2/src/validphys/overfit_metric.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
overfit_metric.py
This module contains the functions used to calculate the overfit metric and
This module contains the functions used to calculate the overfit metric and
produce the corresponding tables and figures.
"""

Expand Down Expand Up @@ -59,7 +59,7 @@ def calculate_chi2s_per_replica(
preds : list[pd.core.frame.DataFrame]
List of pandas dataframes, each containing the predictions of the pdf
replicas for a dataset_input
dataset_inputs : list[DatasetInput]
dataset_inputs : list[DataSetInput]
groups_covmat_no_table : pdf.core.frame.DataFrame
Returns
Expand Down Expand Up @@ -112,10 +112,7 @@ def calculate_chi2s_per_replica(


def array_expected_overfitting(
calculate_chi2s_per_replica,
replica_data,
number_of_resamples=1000,
resampling_fraction=0.95,
calculate_chi2s_per_replica, replica_data, number_of_resamples=1000, resampling_fraction=0.95
):
"""Calculates the expected difference in chi2 between:
1. The chi2 of a PDF replica calculated using the corresponding pseudodata
Expand Down Expand Up @@ -181,10 +178,7 @@ def plot_overfitting_histogram(fit, array_expected_overfitting):
ax.hist(array_expected_overfitting, bins=50, density=True)
ax.axvline(x=mean, color="black")
ax.axvline(x=0, color="black", linestyle="--")
xrange = [
array_expected_overfitting.min(),
array_expected_overfitting.max(),
]
xrange = [array_expected_overfitting.min(), array_expected_overfitting.max()]
xgrid = np.linspace(xrange[0], xrange[1], num=100)
ax.plot(xgrid, stats.norm.pdf(xgrid, mean, std))
ax.set_xlabel(r"$\mathcal{R}_O$")
Expand Down

0 comments on commit deb8c60

Please sign in to comment.