From 4207e8a0cb5b4685ecf334fb639e79f990acf66e Mon Sep 17 00:00:00 2001 From: Laurits Date: Fri, 5 Jul 2024 11:43:21 +0300 Subject: [PATCH 1/2] minor changes for figures in the paper --- enreg/config/metrics/regression.yaml | 2 +- enreg/tools/visualization/plot_tau_decay.py | 274 ++++++++++++++++++++ notebooks/DM_CM.ipynb | 38 ++- 3 files changed, 300 insertions(+), 14 deletions(-) create mode 100644 enreg/tools/visualization/plot_tau_decay.py diff --git a/enreg/config/metrics/regression.yaml b/enreg/config/metrics/regression.yaml index dde7e4a..9228983 100644 --- a/enreg/config/metrics/regression.yaml +++ b/enreg/config/metrics/regression.yaml @@ -25,7 +25,7 @@ regression: marker: "v" hatch: "//" color: "tab:green" - SimpleDNN: + DeepSet: ntuples_dir: /home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/jet_regression/SimpleDNN/ json_metrics_path: plotting_data.json load_from_json: False diff --git a/enreg/tools/visualization/plot_tau_decay.py b/enreg/tools/visualization/plot_tau_decay.py new file mode 100644 index 0000000..5282ced --- /dev/null +++ b/enreg/tools/visualization/plot_tau_decay.py @@ -0,0 +1,274 @@ +import matplotlib.pyplot as plt +from matplotlib import patches +from matplotlib import text as mtext +import numpy as np +import math + + +class CurvedText(mtext.Text): + """ + A text object that follows an arbitrary curve. 
+ """ + def __init__(self, x, y, text, axes, **kwargs): + super(CurvedText, self).__init__(x[0],y[0],' ', **kwargs) + + axes.add_artist(self) + + ##saving the curve: + self.__x = x + self.__y = y + self.__zorder = self.get_zorder() + + ##creating the text objects + self.__Characters = [] + for c in text: + if c == ' ': + ##make this an invisible 'a': + t = mtext.Text(0,0,'a') + t.set_alpha(0.0) + else: + t = mtext.Text(0,0,c, fontsize=13, **kwargs) + + #resetting unnecessary arguments + t.set_ha('center') + t.set_rotation(0) + t.set_zorder(self.__zorder +1) + + self.__Characters.append((c,t)) + axes.add_artist(t) + + + ##overloading some member functions, to ensure correct functionality + ##on update + def set_zorder(self, zorder): + super(CurvedText, self).set_zorder(zorder) + self.__zorder = self.get_zorder() + for c,t in self.__Characters: + t.set_zorder(self.__zorder+1) + + def draw(self, renderer, *args, **kwargs): + """ + Overload of the Text.draw() function. Do not do + any drawing, but update the positions and rotation + angles of self.__Characters. + """ + self.update_positions(renderer) + + def update_positions(self,renderer): + """ + Update positions and rotations of the individual text elements. 
+ """ + + #preparations + + ##determining the aspect ratio: + ##from https://stackoverflow.com/a/42014041/2454357 + + ##data limits + xlim = self.axes.get_xlim() + ylim = self.axes.get_ylim() + ## Axis size on figure + figW, figH = self.axes.get_figure().get_size_inches() + ## Ratio of display units + _, _, w, h = self.axes.get_position().bounds + ##final aspect ratio + aspect = ((figW * w)/(figH * h))*(ylim[1]-ylim[0])/(xlim[1]-xlim[0]) + + #points of the curve in figure coordinates: + x_fig,y_fig = ( + np.array(l) for l in zip(*self.axes.transData.transform([ + (i,j) for i,j in zip(self.__x,self.__y) + ])) + ) + + #point distances in figure coordinates + x_fig_dist = (x_fig[1:]-x_fig[:-1]) + y_fig_dist = (y_fig[1:]-y_fig[:-1]) + r_fig_dist = np.sqrt(x_fig_dist**2+y_fig_dist**2) + + #arc length in figure coordinates + l_fig = np.insert(np.cumsum(r_fig_dist),0,0) + + #angles in figure coordinates + rads = np.arctan2((y_fig[1:] - y_fig[:-1]),(x_fig[1:] - x_fig[:-1])) + degs = np.rad2deg(rads) + + + rel_pos = 10 + for c,t in self.__Characters: + #finding the width of c: + t.set_rotation(0) + t.set_va('center') + bbox1 = t.get_window_extent(renderer=renderer) + w = bbox1.width + h = bbox1.height + + #ignore all letters that don't fit: + if rel_pos+w/2 > l_fig[-1]: + t.set_alpha(0.0) + rel_pos += w + continue + + elif c != ' ': + t.set_alpha(1.0) + + #finding the two data points between which the horizontal + #center point of the character will be situated + #left and right indices: + il = np.where(rel_pos+w/2 >= l_fig)[0][-1] + ir = np.where(rel_pos+w/2 <= l_fig)[0][0] + + #if we exactly hit a data point: + if ir == il: + ir += 1 + + #how much of the letter width was needed to find il: + used = l_fig[il]-rel_pos + rel_pos = l_fig[il] + + #relative distance between il and ir where the center + #of the character will be + fraction = (w/2-used)/r_fig_dist[il] + + ##setting the character position in data coordinates: + ##interpolate between the two points: + x = 
self.__x[il]+fraction*(self.__x[ir]-self.__x[il]) + y = self.__y[il]+fraction*(self.__y[ir]-self.__y[il]) + + #getting the offset when setting correct vertical alignment + #in data coordinates + t.set_va(self.get_va()) + bbox2 = t.get_window_extent(renderer=renderer) + + bbox1d = self.axes.transData.inverted().transform(bbox1) + bbox2d = self.axes.transData.inverted().transform(bbox2) + dr = np.array(bbox2d[0]-bbox1d[0]) + + #the rotation/stretch matrix + rad = rads[il] + rot_mat = np.array([ + [math.cos(rad), math.sin(rad)*aspect], + [-math.sin(rad)/aspect, math.cos(rad)] + ]) + + ##computing the offset vector of the rotated character + drp = np.dot(dr,rot_mat) + + #setting final position and rotation: + t.set_position(np.array([x,y])+drp) + t.set_rotation(degs[il]) + + t.set_va('center') + t.set_ha('center') + + #updating rel_pos to right edge of character + rel_pos += w-used + + + +# def curveText(text, height, minTheta, maxTheta, ax): +# interval = np.arange(minTheta, maxTheta, .022) +# if( maxTheta <= np.pi): +# progression = interval[::-1] +# rotation = interval[::-1] - np.arctan(np.tan(np.pi/2)) +# else: +# progression = interval +# rotation = interval - np.arctan(np.tan(np.pi/2)) - np.pi + +# ## Render each letter individually +# for i, rot, t in zip(progression, rotation, text): +# ax.text(i, height, t, fontsize=11,rotation=np.degrees(rot), ha='center', va='center') + + +hadronic_decays = [1.463, 11.51, 25.93, 10.81, 9.80, 4.76 + 0.517] +hadronic_explode = [0.1] * len(hadronic_decays) +hadronic_labels = [ + "Rare", + r"$h^{\pm} \nu_{\tau}$", # h + r"$h^{\pm} \pi_{0} \nu_{\tau}$", # h + pi0 + r"$h^{\pm} \geq 2 \pi_{0} + \nu_{\tau}$", # h + 2 pi0 -> Should replace with h + >= 2pi0 + r"$h^{\pm} h^{\mp} h^{\pm} \nu_{\tau}$", # h h h + r"$h^{\pm} h^{\mp} h^{\pm} \pi_{0} \nu_{\tau}$", # hhh + pi0 + # r"$h^{\pm} h^{\mp} h^{\pm} \geq 2\pi_{0} \nu_{\tau}$", # hhh + >= 2pi0 +] +hadronic_colors = [ + plt.cm.Greens(0.5), + plt.cm.Blues(0.4), plt.cm.Blues(0.5), 
plt.cm.Blues(0.6), + plt.cm.Reds(0.4), plt.cm.Reds(0.6), +] + + +leptonic_decays = [17.82, 17.39] +leptonic_explode = [0.0] * len(leptonic_decays) +leptonic_labels = [ + r"$e^{-} \bar{\nu_{e}} \nu_{\tau}$", + r"$\mu^{-} \bar{\nu_{\mu}} \nu_{\tau}$", +] +leptonic_colors = [plt.cm.Greys(0.5), plt.cm.Greys(0.6)] + + + +subgroup_names = hadronic_labels + leptonic_labels +subgroup_size = hadronic_decays + leptonic_decays +subgroup_explode = hadronic_explode + leptonic_explode +subgroup_colors = hadronic_colors + leptonic_colors + +group_names = ["Hadronic decays", "Leptonic decays"] +group_size = [sum(hadronic_decays), sum(leptonic_decays)] +group_explode = [0.0, 0.0] +group_colors = [plt.cm.Blues(0.8), plt.cm.Greys(0.8)] + +fig, ax = plt.subplots() +# ax.axis('equal') + +inner_circle = ax.pie( + group_size, + autopct='%1.1f%%', + pctdistance=0.8, + radius=1.0, + # labels=group_names, + explode=group_explode, + labeldistance=0.7, + colors=group_colors, + textprops={'fontsize': 13, "color":'black'} +)[0] +plt.setp( + inner_circle, + width=0.4, + edgecolor='black' +) + +outer_circle = ax.pie( + subgroup_size, + autopct='%1.1f%%', + pctdistance=0.87, + radius=1.4, + labels=subgroup_names, + explode=subgroup_explode, + colors=subgroup_colors, + textprops={'fontsize': 13, "color":'black'} +)[0] + +plt.setp( + outer_circle, + width=0.4, + edgecolor='black' +) + +plt.margins(5,5) +N = 100 + +x_h = 0.4 * np.cos(np.linspace(1.2 *np.pi, 0, N)) +y_h = 0.4 * np.sin(np.linspace(1.2 *np.pi, 0, N)) + + +x_l = 0.5 * np.cos(np.linspace(1.33 * np.pi, 2*np.pi, N)) +y_l = 0.5 * np.sin(np.linspace(1.33 * np.pi, 2*np.pi, N)) + + +CurvedText(x_h, y_h, text="Hadronic decays", va='bottom', axes=ax, color="black") +CurvedText(x_l, y_l, text="Leptonic decays", va='bottom', axes=ax, color="black") +plt.text(-0.08, -0.08, r"$\tau$", fontsize=30) + + +plt.savefig("tau_decays.pdf", format="pdf", bbox_inches="tight", transparent=True) diff --git a/notebooks/DM_CM.ipynb b/notebooks/DM_CM.ipynb index 
498796f..2161db8 100644 --- a/notebooks/DM_CM.ipynb +++ b/notebooks/DM_CM.ipynb @@ -11,6 +11,7 @@ "import json\n", "import hydra\n", "import numpy as np\n", + "from math import isclose\n", "import enreg.tools.general as g\n", "import mplhep as hep\n", "import awkward as ak\n", @@ -126,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 23, "id": "7445c1ea-7e5d-42c8-b8dd-bdbc4a75e3d3", "metadata": {}, "outputs": [], @@ -280,18 +281,29 @@ " pr = get_precision(\"Z\")\n", "\n", " pr_sdnn, pr_ln, pr_pt = pr[0], pr[1], pr[2]\n", + " pr_sdnn = list(pr_sdnn)\n", + " pr_ln = list(pr_ln)\n", + " pr_pt = list(pr_pt)\n", + " labels = [r'$h^\\pm$', r'$h^\\pm+\\pi^0$', r'$h^\\pm+\\geq2\\pi^0$', r'$h^\\pm h^\\mp h^\\pm$', r'$h^\\pm h^\\mp h^\\pm$' '\\n' r'$+\\geq\\pi^0$', 'Rare', 'Overall']\n", + " PDG_ratios = np.array([0.1777, 0.4002, 0.1668, 0.1513, 0.0816, 0.0224])\n", " \n", - " labels = [r'$h^\\pm$', r'$h^\\pm+\\pi^0$', r'$h^\\pm+\\geq2\\pi^0$', r'$h^\\pm h^\\mp h^\\pm$', r'$h^\\pm h^\\mp h^\\pm$' '\\n' r'$+\\geq\\pi^0$', 'Rare']\n", - " \n", + " pr_sdnn += [np.sum(np.array(pr_sdnn) * PDG_ratios)]\n", + " pr_ln += [np.sum(np.array(pr_ln) * PDG_ratios)]\n", + " pr_pt += [np.sum(np.array(pr_pt) * PDG_ratios)]\n", + "\n", + "\n", " # Create a mapping from labels to their positions on the x-axis\n", " x = range(len(labels))\n", " \n", " fig, ax = plt.subplots()\n", + "\n", + " # Define small offsets for each dataset\n", + " offsets = [-0.2, 0, 0.2]\n", " \n", " # Plotting the data as points\n", - " ax.scatter(x, pr_sdnn, label='DeepSet', color='#ff5b5b', marker='o', s=100)\n", - " ax.scatter(x, pr_ln, label='LorentzNet', color='#ffc140', marker='o', s=100)\n", - " ax.scatter(x, pr_pt, label='ParticleTransformer', color='#89cded', marker='o', s=100)\n", + " ax.scatter(np.array(range(len(labels))) + offsets[0], pr_sdnn, label='DeepSet', color='#ff5b5b', marker='o', s=100)\n", + " ax.scatter(np.array(range(len(labels))) + offsets[1], pr_ln, 
label='LorentzNet', color='#ffc140', marker='o', s=100)\n", + " ax.scatter(np.array(range(len(labels))) + offsets[2], pr_pt, label='ParticleTransformer', color='#89cded', marker='o', s=100)\n", " \n", " ax.set_xlabel('Decay Modes')\n", " ax.set_ylabel('Precision', x=1.05)\n", @@ -309,13 +321,13 @@ "\n", " # Function to add labels to the datapoints\n", " def annotate_points(x, y, offset):\n", - " for (i, j) in zip(x, y):\n", - " ax.annotate(f'{j:.2f}', (i, j), textcoords=\"offset points\", xytext=offset, ha='center', va='bottom', fontsize=9)\n", + " for ii, (i, j) in enumerate(zip(x, y)):\n", + " ax.annotate(f'{j:.3f}', (i, j), textcoords=\"offset points\", xytext=offset, ha='center', va='bottom', fontsize=9)\n", "\n", " # Annotate the datapoints with different alignments\n", - " annotate_points(x, pr_sdnn, (-16, -4))\n", - " annotate_points(x, pr_ln, (16, -4))\n", - " annotate_points(x, pr_pt, (-16, -4))\n", + " annotate_points(np.array(range(len(labels))) + offsets[0], pr_sdnn, (-18, -4))\n", + " annotate_points(np.array(range(len(labels))) + offsets[1], pr_ln, (-18, -4))\n", + " annotate_points(np.array(range(len(labels))) + offsets[2], pr_pt, (-18, -4))\n", "\n", " legend = ax.legend(loc='lower left', shadow=True, fancybox=True, framealpha=1, borderpad=1)\n", " plt.tight_layout()\n", @@ -348,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 24, "id": "7c34850b-b7f9-4117-95c3-034c885e3296", "metadata": {}, "outputs": [], @@ -360,7 +372,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab8fb87c-045d-4f8f-ba8a-2fb0db4c2e24", + "id": "c6eff306-4fec-4b76-88ce-faa22e61add9", "metadata": {}, "outputs": [], "source": [] From e4e6ea0107f14948efb4243fd3bcca94d10b4335 Mon Sep 17 00:00:00 2001 From: Laurits Date: Fri, 5 Jul 2024 11:44:53 +0300 Subject: [PATCH 2/2] reset nb --- notebooks/DM_CM.ipynb | 63 +++++++++---------------------------------- 1 file changed, 13 insertions(+), 50 deletions(-) diff --git 
a/notebooks/DM_CM.ipynb b/notebooks/DM_CM.ipynb index 2161db8..ce0e4d5 100644 --- a/notebooks/DM_CM.ipynb +++ b/notebooks/DM_CM.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "038cc922-a195-4941-bce5-dbc4aa8e0694", "metadata": {}, "outputs": [], @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "e3580249-c747-4234-8267-c298f826d53d", "metadata": {}, "outputs": [], @@ -36,21 +36,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "4bc3194a-8d4e-48d9-b010-eeca4adf6273", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1/1] Loading from /scratch/persistent/joosep/ml-tau/20240701_lowered_ptcut_merged/zh_test.parquet\n", - "Input data loaded\n", - "[1/1] Loading from /scratch/persistent/joosep/ml-tau/20240701_lowered_ptcut_merged/z_test.parquet\n", - "Input data loaded\n" - ] - } - ], + "outputs": [], "source": [ "data_zh = g.load_all_data([\"/scratch/persistent/joosep/ml-tau/20240701_lowered_ptcut_merged/zh_test.parquet\"])\n", "data_z = g.load_all_data([\"/scratch/persistent/joosep/ml-tau/20240701_lowered_ptcut_merged/z_test.parquet\"])\n" @@ -58,23 +47,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "8fb490e4-cf9b-4a74-bb3d-a02a6e0f8012", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1/1] Loading from /home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/dm_multiclass/ParticleTransformer/zh_test.parquet\n", - "Input data loaded\n", - "[1/1] Loading from /home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/dm_multiclass/LorentzNet/zh_test.parquet\n", - "Input data loaded\n", - "[1/1] Loading from /home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/dm_multiclass/SimpleDNN/zh_test.parquet\n", - "Input data 
loaded\n" - ] - } - ], + "outputs": [], "source": [ "paths_zh_model = {\n", " \"ParticleTransformer\": \"/home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/dm_multiclass/ParticleTransformer/zh_test.parquet\",\n", @@ -87,23 +63,10 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "8530cf6d-cfea-4e5c-91ff-2d76ef326607", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1/1] Loading from /home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/dm_multiclass/ParticleTransformer/z_test.parquet\n", - "Input data loaded\n", - "[1/1] Loading from /home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/dm_multiclass/LorentzNet/z_test.parquet\n", - "Input data loaded\n", - "[1/1] Loading from /home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/dm_multiclass/SimpleDNN/z_test.parquet\n", - "Input data loaded\n" - ] - } - ], + "outputs": [], "source": [ "paths_z_model = {\n", " \"ParticleTransformer\": \"/home/laurits/ml-tau-en-reg/training-outputs/20240701_lowered_ptcut_merged/v1/dm_multiclass/ParticleTransformer/z_test.parquet\",\n", @@ -116,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "e09ef30a-a66e-4806-a129-605219ae7385", "metadata": {}, "outputs": [], @@ -127,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "7445c1ea-7e5d-42c8-b8dd-bdbc4a75e3d3", "metadata": {}, "outputs": [], @@ -337,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "58c2de54-97b7-4d6f-95d8-4c5c2da7c57e", "metadata": {}, "outputs": [], @@ -349,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "8a2ee92d-ace3-4dc9-b57a-3a3d5ca38821", "metadata": {}, "outputs": [], @@ -360,7 +323,7 @@ }, { "cell_type": "code", - "execution_count": 24, + 
"execution_count": null, "id": "7c34850b-b7f9-4117-95c3-034c885e3296", "metadata": {}, "outputs": [],