Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compose variants as a tuple #2147

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions validphys2/src/validphys/commondataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,16 @@ def check(self):
def apply_variant(self, variant_name):
"""Return a new instance of this class with the variant applied

This class also defines how the variant is applied to the commondata
This class also defines how the variant is applied to the commondata.
If more than one variant is being used, this function will be called recursively
until all variants are applied.
"""
if not isinstance(variant_name, str):
observable = self
for single_variant in variant_name:
observable = observable.apply_variant(single_variant)
return observable

try:
variant = self.variants[variant_name]
except KeyError as e:
Expand All @@ -487,15 +495,20 @@ def apply_variant(self, variant_name):

# This section should only be used for the purposes of reproducibility
# of legacy data, no new data should use these

if variant.experiment is not None:
new_nnpdf_metadata = dict(self._parent.nnpdf_metadata.items())
new_nnpdf_metadata["experiment"] = variant.experiment
setmetadata_copy = dataclasses.replace(self._parent, nnpdf_metadata=new_nnpdf_metadata)
variant_replacement["_parent"] = setmetadata_copy
variant_replacement["plotting"] = dataclasses.replace(self.plotting)

return dataclasses.replace(self, applied_variant=variant_name, **variant_replacement)
# Keep track of applied variants:
if self.applied_variant is None:
varname = variant_name
else:
varname = f"{self.applied_variant}_{variant_name}"

return dataclasses.replace(self, applied_variant=varname, **variant_replacement)

@property
def is_positivity(self):
Expand Down Expand Up @@ -843,7 +856,7 @@ def parse_new_metadata(metadata_file, observable_name, variant=None):
# Select one observable from the entire metadata
metadata = set_metadata.select_observable(observable_name)

# And apply variant if given
# And apply variant or variants if given
if variant is not None:
metadata = metadata.apply_variant(variant)

Expand Down
5 changes: 4 additions & 1 deletion validphys2/src/validphys/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,14 +460,17 @@ def parse_dataset_input(self, dataset: Mapping):
if variant is None or map_variant == "legacy_dw":
variant = map_variant

if sysnum is not None:
log.warning("The key 'sys' is deprecated and will soon be removed")

return DataSetInput(
name=name,
sys=sysnum,
cfac=cfac,
frac=frac,
weight=weight,
custom_group=custom_group,
variant=variant,
sys=sysnum,
)

def parse_use_fitcommondata(self, do_use: bool):
Expand Down
28 changes: 25 additions & 3 deletions validphys2/src/validphys/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,17 +368,39 @@ def plot_kinlabels(self):

class DataSetInput(TupleComp):
"""Represents whatever the user enters in the YAML to specify a
dataset."""
dataset.

name: str
name of the dataset_inputs
cfac: tuple
cfactors to apply to the final predictions (default: ())
frac: float
fraction of the data to be used during training (default: 1.0)
weight: float
extra weight to apply to the dataset (default: 1.0)
variant: str or tuple[str]
variant or variants to apply (default: None)
sys: int
deprecated, systematic file to load for the dataset
"""

def __init__(self, *, name, sys, cfac, frac, weight, custom_group, variant):
def __init__(self, *, name, cfac, frac, weight, custom_group, variant, sys=None):
self.name = name
self.sys = sys
self.cfac = cfac
self.frac = frac
self.weight = weight
self.custom_group = custom_group

# Parse the variant if introduced as a string
if isinstance(variant, str):
variant = (variant,)

# Make sure that variant is stored as a tuple rather than a list
if isinstance(variant, list):
variant = tuple(variant)
self.variant = variant
super().__init__(name, sys, cfac, frac, weight, custom_group, variant)
super().__init__(name, cfac, frac, weight, custom_group, variant, sys)

def __str__(self):
return self.name
Expand Down
10 changes: 8 additions & 2 deletions validphys2/src/validphys/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def check_commondata(
self, setname, sysnum=None, use_fitcommondata=False, fit=None, variant=None
):
"""Prepare the commondata files to be loaded.
A commondata is defined by its name (``setname``) and the variant (``variant``)
A commondata is defined by its name (``setname``) and the variant(s) (``variant``)

At the moment both old-format and new-format commondata can be utilized and loaded
however old-format commondata are deprecated and will be removed in future releases.
Expand Down Expand Up @@ -423,7 +423,12 @@ def check_commondata(
)
break
# try new commondata format
old_path = fit.path / "filter" / legacy_name / f"filtered_uncertainties_{legacy_name}.yaml"
old_path = (
fit.path
/ "filter"
/ legacy_name
/ f"filtered_uncertainties_{legacy_name}.yaml"
)
if old_path.exists():
data_path = old_path.with_name(f"filtered_data_{legacy_name}.yaml")
unc_path = old_path.with_name(f"filtered_uncertainties_{legacy_name}.yaml")
Expand Down Expand Up @@ -481,6 +486,7 @@ def get_commondata(self, setname, sysnum):
"""Get a Commondata from the set name and number."""
# TODO: check where this is used
# as this might ignore cfactors or variants
raise Exception("Not used")
cd = self.check_commondata(setname, sysnum)
return cd.load()

Expand Down
14 changes: 4 additions & 10 deletions validphys2/src/validphys/overfit_metric.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
overfit_metric.py

This module contains the functions used to calculate the overfit metric and
This module contains the functions used to calculate the overfit metric and
produce the corresponding tables and figures.
"""

Expand Down Expand Up @@ -59,7 +59,7 @@ def calculate_chi2s_per_replica(
preds : list[pd.core.frame.DataFrame]
List of pandas dataframes, each containing the predictions of the pdf
replicas for a dataset_input
dataset_inputs : list[DatasetInput]
dataset_inputs : list[DataSetInput]
groups_covmat_no_table : pd.core.frame.DataFrame

Returns
Expand Down Expand Up @@ -112,10 +112,7 @@ def calculate_chi2s_per_replica(


def array_expected_overfitting(
calculate_chi2s_per_replica,
replica_data,
number_of_resamples=1000,
resampling_fraction=0.95,
calculate_chi2s_per_replica, replica_data, number_of_resamples=1000, resampling_fraction=0.95
):
"""Calculates the expected difference in chi2 between:
1. The chi2 of a PDF replica calculated using the corresponding pseudodata
Expand Down Expand Up @@ -181,10 +178,7 @@ def plot_overfitting_histogram(fit, array_expected_overfitting):
ax.hist(array_expected_overfitting, bins=50, density=True)
ax.axvline(x=mean, color="black")
ax.axvline(x=0, color="black", linestyle="--")
xrange = [
array_expected_overfitting.min(),
array_expected_overfitting.max(),
]
xrange = [array_expected_overfitting.min(), array_expected_overfitting.max()]
xgrid = np.linspace(xrange[0], xrange[1], num=100)
ax.plot(xgrid, stats.norm.pdf(xgrid, mean, std))
ax.set_xlabel(r"$\mathcal{R}_O$")
Expand Down
Loading