Don't merge: tips for HV #157

Open · wants to merge 4 commits into base: main · Changes from 1 commit
src/endtoend_simple.py: 29 changes (23 additions, 6 deletions)
@@ -86,6 +86,7 @@ def __init__(self, sparse_mode=False):
 
         self.num_jet_features = 8
         self.num_pf_features = 36
+        self.num_decaymodes = 16
 
         self.nn_pf_initialembedding = ffn(self.num_pf_features, self.embedding_dim, self.width, self.act, self.dropout)
 
@@ -103,6 +104,7 @@ def __init__(self, sparse_mode=False):
 
         self.nn_pred_istau = ffn(self.num_jet_features + 3 * self.embedding_dim, 2, self.width, self.act, self.dropout)
         self.nn_pred_p4 = ffn(self.num_jet_features + 3 * self.embedding_dim, 4, self.width, self.act, self.dropout)
+        self.nn_pred_dm = ffn(self.num_jet_features + 3 * self.embedding_dim, self.num_decaymodes, self.width, self.act, self.dropout)
 
     # forward function for training with pytorch geometric
     def forward_sparse(self, inputs):
@@ -130,8 +132,10 @@ def forward_sparse(self, inputs):
         # run a per-jet NN for visible energy prediction
         jet_p4 = jet_features[:, :4]
         pred_p4 = jet_p4 * self.nn_pred_p4(jet_feats)
+
+        pred_dm = self.nn_pred_dm(jet_feats)
 
-        return pred_istau, pred_p4
+        return pred_istau, pred_p4, pred_dm
 
     # custom forward function for HLS4ML export, assuming a single 3D input
     def forward_3d(self, inputs):
@@ -211,16 +215,29 @@ def model_loop(model, ds_loader, optimizer, scheduler, is_train, dev, tensorboar
     for ibatch, batch in enumerate(tqdm.tqdm(ds_loader, total=len(ds_loader))):
         optimizer.zero_grad()
         batch = batch.to(device=dev)
-        pred_istau, pred_p4 = model((batch.jet_features, batch.jet_pf_features, batch.jet_pf_features_batch))
+        pred_istau, pred_p4, pred_dm = model((batch.jet_features, batch.jet_pf_features, batch.jet_pf_features_batch))
         true_p4 = batch.gen_tau_p4
-        true_istau = (batch.gen_tau_decaymode != -1).to(dtype=torch.float32)
+        true_istau_mask = batch.gen_tau_decaymode != -1
+        true_istau = true_istau_mask.to(dtype=torch.float32)
         pred_p4 = pred_p4 * true_istau.unsqueeze(-1)
         weights = batch.weight
 
-        loss_p4 = 1e5 * weighted_huber_loss(pred_p4, true_p4, weights)
-        loss_cls = 1e7 * weighted_bce_with_logits(pred_istau, true_istau, weights)
+        loss_p4 = weighted_huber_loss(pred_p4, true_p4, weights)
+        loss_cls = weighted_bce_with_logits(pred_istau, true_istau, weights)
+
+        # get the numerical decay mode values for the jets that actually came from a tau
+        true_dm_vals = batch.gen_tau_decaymode[true_istau_mask].to(torch.int64)
+
+        # convert the integer decay mode values (0..15) to one-hot encoded vectors [0, 0, ..., 1, ..., 0], where for value N the N-th bit is set
+        true_dm_onehot = torch.nn.functional.one_hot(true_dm_vals, model.num_decaymodes).to(torch.float32)
+
+        # compute the loss between the predicted and true decay modes, only for the jets that really came from a tau
+        loss_dm = torch.nn.functional.cross_entropy(pred_dm[true_istau_mask], true_dm_onehot)
+        print(loss_p4, loss_cls, loss_dm)
+
+        # sum all loss components: binary classification, momentum regression and decay mode prediction
+        loss = loss_cls + loss_p4 + loss_dm
 
-        loss = loss_cls + loss_p4
         if is_train:
             loss.backward()
             optimizer.step()
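
A side note on the decay-mode loss above: torch.nn.functional.cross_entropy accepts either integer class indices or class-probability targets (the latter since PyTorch 1.10), so the explicit one-hot conversion is equivalent to passing true_dm_vals directly. A minimal standalone sketch of that equivalence, with random tensors standing in for the real batch:

import torch
import torch.nn.functional as F

# Dummy stand-ins for the quantities in the diff: 5 tau-matched jets, 16 decay-mode classes.
num_jets, num_decaymodes = 5, 16
pred_dm = torch.randn(num_jets, num_decaymodes)               # logits from the decay-mode head, masked to true taus
true_dm_vals = torch.randint(0, num_decaymodes, (num_jets,))  # integer decay modes in 0..15

# Variant used in the diff: one-hot targets, interpreted by cross_entropy as class probabilities.
true_dm_onehot = F.one_hot(true_dm_vals, num_decaymodes).to(torch.float32)
loss_onehot = F.cross_entropy(pred_dm, true_dm_onehot)

# Equivalent shorter form: pass the integer class indices directly.
loss_indices = F.cross_entropy(pred_dm, true_dm_vals)

assert torch.allclose(loss_onehot, loss_indices)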
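
For completeness, the new nn_pred_dm head maps the concatenated per-jet features to num_decaymodes logits. A self-contained shape check, where the repo's ffn() helper is approximated by a plain two-layer MLP and embedding_dim is assumed to be 128 (both are assumptions, not taken from the diff):

import torch

num_jet_features, embedding_dim, num_decaymodes, width = 8, 128, 16, 128

# Hypothetical stand-in for ffn(num_jet_features + 3 * embedding_dim, num_decaymodes, ...).
nn_pred_dm = torch.nn.Sequential(
    torch.nn.Linear(num_jet_features + 3 * embedding_dim, width),
    torch.nn.ReLU(),
    torch.nn.Linear(width, num_decaymodes),
)

# A batch of 4 per-jet feature vectors -> 4 x 16 decay-mode logits.
jet_feats = torch.randn(4, num_jet_features + 3 * embedding_dim)
print(nn_pred_dm(jet_feats).shape)  # torch.Size([4, 16])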