Skip to content

Commit

Permalink
Feat integrate node features (#142)
Browse files Browse the repository at this point in the history
* bug: dataset construction

* bug: dataset generation

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored May 11, 2024
1 parent 6c735b3 commit fbec0ae
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
13 changes: 7 additions & 6 deletions src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def smooth_branch(xyz, s=None):
Returns
-------
xyz : numpy.ndarray
numpy.ndarray
Smoothed points.
"""
Expand All @@ -199,7 +199,7 @@ def smooth_branch(xyz, s=None):
return xyz.astype(np.float32)


def fit_spline(xyz, s=None):
def fit_spline(xyz, k=3, s=None):
"""
Fits a cubic spline to an array containing xyz coordinates.
Expand All @@ -222,9 +222,9 @@ def fit_spline(xyz, s=None):
"""
s = xyz.shape[0] / 10 if not s else xyz.shape[0] / s
t = np.linspace(0, 1, xyz.shape[0])
spline_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3)
spline_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3)
spline_z = UnivariateSpline(t, xyz[:, 2], s=s, k=3)
spline_x = UnivariateSpline(t, xyz[:, 0], k=k, s=s)
spline_y = UnivariateSpline(t, xyz[:, 1], k=k, s=s)
spline_z = UnivariateSpline(t, xyz[:, 2], k=k, s=s)
return spline_x, spline_y, spline_z


Expand All @@ -245,8 +245,9 @@ def sample_curve(xyz_arr, n_pts):
Resampled points along curve.
"""
k = 1 if xyz_arr.shape[0] <= 3 else 3
t = np.linspace(0, 1, n_pts)
spline_x, spline_y, spline_z = fit_spline(xyz_arr, s=0)
spline_x, spline_y, spline_z = fit_spline(xyz_arr, k=k, s=0)
xyz = np.column_stack((spline_x(t), spline_y(t), spline_z(t)))
return xyz.astype(int)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ def init_edge_attrs(self, x_nodes):
"""
# Proposal edges
edge_type = ("proposal", "to", "proposal")
attrs = self.set_edge_attrs(x_nodes, edge_type, self.idxs_proposals)
# --> set attr
self.set_edge_attrs(x_nodes, edge_type, self.idxs_proposals)

# Branch edges
edge_type = ("branch", "to", "branch")
self.set_edge_attrs(x_nodes, edge_type, self.idxs_branches)

# Branch-Proposal edges
edge_type = ("branch", "to", "proposal")
Expand Down Expand Up @@ -270,9 +270,9 @@ def set_edge_attrs(self, x_nodes, edge_type, idx_mapping):
e1, e2 = self.data[edge_type][:, i]
v = node_intersection(idx_mapping, e1, e2)
attrs.append(x_nodes[v])
print(v)
print(attrs)
stop
arrs = torch.tensor(np.array(attrs), dtype=DTYPE)
self.data[edge_type].edge_attr = arrs


# -- utils --
def init_idxs(idxs):
Expand Down

0 comments on commit fbec0ae

Please sign in to comment.