[feat] add intervention methods for multi-target learning #49
base: master
Conversation
import torch
from torch import nn


class RotateLayer(nn.Module):
Add unit tests for these modules in intervention_test.py.
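A minimal sketch of what such a test could look like, assuming RotateLayer(base_dim, low_rank_dim) is a plain linear rotation that maps a (batch, base_dim) input to a (batch, low_rank_dim) output and stores its matrix in a weight parameter; the import path and shape expectations are assumptions, not part of this PR:

```python
# intervention_test.py -- minimal sketch, not part of this PR.
# Assumes RotateLayer(base_dim, low_rank_dim) holds its rotation in a
# `weight` parameter of shape (base_dim, low_rank_dim) and that forward()
# returns x @ weight.
import unittest

import torch

from intervention import RotateLayer  # hypothetical import path


class RotateLayerTest(unittest.TestCase):
    def test_output_shape(self):
        layer = RotateLayer(16, 4)
        x = torch.randn(8, 16)
        self.assertEqual(layer(x).shape, (8, 4))

    def test_orthogonal_parametrization(self):
        # After wrapping, the columns of `weight` should be orthonormal.
        layer = torch.nn.utils.parametrizations.orthogonal(RotateLayer(16, 4))
        w = layer.weight
        torch.testing.assert_close(w.T @ w, torch.eye(4), atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    unittest.main()
```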
class DC2VR(MultiTaskRank):
Add a model test in dc2vr_test.py.
class Intervention(nn.Module):
    def __init__(self,
                 base_dim, source_dim,
                 low_rank_dim, orth=True):
The orth param has no effect?
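If orth is meant to toggle the orthogonality constraint, a hedged sketch of how it could be wired into __init__ (attribute names follow the diff; treating orth=False as an unconstrained rotation is an assumption):

```python
# Sketch only: make `orth` actually control the orthogonality constraint.
# RotateLayer is the module from this PR; everything else mirrors the diff.
def __init__(self, base_dim, source_dim, low_rank_dim, orth=True):
    super().__init__()
    base_rotate_layer = RotateLayer(base_dim, low_rank_dim)
    source_rotate_layer = RotateLayer(source_dim, low_rank_dim)
    if orth:
        # constrain both rotations to have orthonormal columns
        base_rotate_layer = torch.nn.utils.parametrizations.orthogonal(base_rotate_layer)
        source_rotate_layer = torch.nn.utils.parametrizations.orthogonal(source_rotate_layer)
    self.base_rotate_layer = base_rotate_layer
    self.source_rotate_layer = source_rotate_layer
    self.dropout = torch.nn.Dropout(0.1)
```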
                 base_dim,
                 low_rank_dim):
        super().__init__()
        # n > m, i.e. base_dim > low_rank_dim
base_dim must be greater than low_rank_dim? Consider an assertion:
assert base_dim > low_rank_dim
// intervention tower names
repeated string intervention_tower_names = 8;
// low-rank dimension of the intervention
required uint32 low_rank_dim = 9;
Is it possible for us to assign a distinct low_rank_dim to each tower?
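One option would be to fold the two fields into a per-tower message, roughly as sketched below (the message and field names are made up for illustration; the real config may prefer a different layout):

```proto
// Hypothetical sketch: one entry per tower, each with its own rank.
message InterventionTower {
  required string tower_name = 1;
  required uint32 low_rank_dim = 2;
}

// would replace fields 8 and 9 above
repeated InterventionTower intervention_towers = 8;
```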
        self.base_rotate_layer = torch.nn.utils.parametrizations.orthogonal(base_rotate_layer)
        source_rotate_layer = RotateLayer(source_dim, low_rank_dim)
        self.source_rotate_layer = torch.nn.utils.parametrizations.orthogonal(source_rotate_layer)
        self.dropout = torch.nn.Dropout(0.1)
Can the dropout rate be made configurable?
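A small sketch of one way to do that, assuming the rate is threaded through the constructor (the dropout_ratio name and its default are made up):

```python
# Sketch only: expose the dropout rate instead of hard-coding 0.1.
def __init__(self, base_dim, source_dim, low_rank_dim,
             orth=True, dropout_ratio=0.1):
    super().__init__()
    ...
    self.dropout = torch.nn.Dropout(dropout_ratio)
```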