-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] add intervention methods for multi-target learning #49
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# Copyright (c) 2024, Alibaba Group; | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Dict, List | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from tzrec.datasets.utils import Batch | ||
from tzrec.features.feature import BaseFeature | ||
from tzrec.models.multi_task_rank import MultiTaskRank | ||
from tzrec.modules.mlp import MLP | ||
from tzrec.modules.mmoe import MMoE as MMoEModule | ||
from tzrec.modules.intervention import Intervention | ||
from tzrec.protos.model_pb2 import ModelConfig | ||
from tzrec.protos.models import multi_task_rank_pb2 | ||
from tzrec.utils.config_util import config_to_kwargs | ||
|
||
|
||
|
||
class DC2VR(MultiTaskRank):
    """Deconfounding Conversion Rate model.

    Each task tower may apply a low-rank "intervention" that mixes its own
    representation with (detached) representations of other towers, so a
    downstream target (e.g. CVR) is deconfounded by upstream ones.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
    """

    def __init__(
        self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
    ) -> None:
        super().__init__(model_config, features, labels)
        assert model_config.WhichOneof("model") == "dc2vr", (
            "invalid model config: %s" % self._model_config.WhichOneof("model")
        )
        assert isinstance(self._model_config, multi_task_rank_pb2.DC2VR)

        self._task_tower_cfgs = self._model_config.task_towers
        self.init_input()
        self.group_name = self.embedding_group.group_names()[0]
        in_dim = self.embedding_group.group_total_dim(self.group_name)

        # Optional shared bottom MLP applied before any task-specific module.
        self.bottom_mlp = None
        if self._model_config.HasField("bottom_mlp"):
            self.bottom_mlp = MLP(
                in_dim, **config_to_kwargs(self._model_config.bottom_mlp)
            )
            in_dim = self.bottom_mlp.output_dim()

        # Optional MMoE that yields one expert mixture per task tower.
        self.mmoe = None
        if self._model_config.HasField("expert_mlp"):
            self.mmoe = MMoEModule(
                in_features=in_dim,
                expert_mlp=config_to_kwargs(self._model_config.expert_mlp),
                num_expert=self._model_config.num_expert,
                num_task=len(self._task_tower_cfgs),
                gate_mlp=config_to_kwargs(self._model_config.gate_mlp)
                if self._model_config.HasField("gate_mlp")
                else None,
            )
            in_dim = self.mmoe.output_dim()

        # Optional per-task tower MLPs, keyed by tower name.
        self.task_mlps = nn.ModuleDict()
        for cfg in self._task_tower_cfgs:
            if cfg.HasField("mlp"):
                self.task_mlps[cfg.tower_name] = MLP(
                    in_dim, **config_to_kwargs(cfg.mlp)
                )

        def _dim_of(name: str) -> int:
            # Dimension of a tower's representation when it is consumed:
            # prefer an already-built intervention output, then the tower's
            # MLP output, falling back to the shared feature dimension.
            if name in self.intervention:
                return self.intervention[name].output_dim()
            if name in self.task_mlps:
                return self.task_mlps[name].output_dim()
            return in_dim

        # Build interventions in config order so earlier towers' intervened
        # outputs can be sources for later towers.
        self.intervention = nn.ModuleDict()
        for cfg in self._task_tower_cfgs:
            if not cfg.HasField("low_rank_dim"):
                continue
            name = cfg.tower_name
            if name in self.task_mlps:
                base_dim = self.task_mlps[name].output_dim()
            else:
                base_dim = in_dim
            # NOTE: with several source names the module is re-created per
            # name and the last one wins — mirrors the original behavior.
            for src_name in cfg.intervention_tower_names:
                self.intervention[name] = Intervention(
                    base_dim, _dim_of(src_name), cfg.low_rank_dim
                )

        # One linear output head per task, sized to that tower's final dim.
        self.task_outputs = nn.ModuleList()
        for cfg in self._task_tower_cfgs:
            self.task_outputs.append(
                nn.Linear(_dim_of(cfg.tower_name), cfg.num_class)
            )

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Return:
            predictions (dict): a dict of predicted result.
        """
        grouped_features = self.build_input(batch)

        net = grouped_features[self.group_name]
        if self.bottom_mlp is not None:
            net = self.bottom_mlp(net)

        if self.mmoe is not None:
            task_inputs = self.mmoe(net)
        else:
            task_inputs = [net] * len(self._task_tower_cfgs)

        # Per-tower representation: tower MLP output when configured,
        # otherwise the shared task input.
        task_net = {}
        for idx, cfg in enumerate(self._task_tower_cfgs):
            name = cfg.tower_name
            if name in self.task_mlps:
                task_net[name] = self.task_mlps[name](task_inputs[idx])
            else:
                task_net[name] = task_inputs[idx]

        # Apply interventions in config order; source towers must be listed
        # before the towers that consume them (their intervened outputs are
        # looked up here).
        intervened = {}
        for cfg in self._task_tower_cfgs:
            name = cfg.tower_name
            if cfg.HasField("low_rank_dim"):
                sources = [
                    intervened[src] for src in cfg.intervention_tower_names
                ]
                # Average all source representations into one tensor.
                merged_source = torch.stack(sources, dim=0).mean(0)
                intervened[name] = self.intervention[name](
                    task_net[name], merged_source
                )
            else:
                intervened[name] = task_net[name]

        tower_outputs = {}
        for idx, cfg in enumerate(self._task_tower_cfgs):
            tower_outputs[cfg.tower_name] = self.task_outputs[idx](
                intervened[cfg.tower_name]
            )

        return self._multi_task_output_to_prediction(tower_outputs)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
from torch import nn | ||
|
||
class RotateLayer(nn.Module):
    """Learnable projection of a representation into a low-rank subspace.

    Holds a ``(base_dim, low_rank_dim)`` weight initialized orthogonally;
    base_dim is expected to exceed low_rank_dim (per the original "n > m"
    note) so the columns form an orthonormal basis — TODO confirm callers.
    """

    def __init__(self, base_dim, low_rank_dim):
        super().__init__()
        weight = torch.empty(base_dim, low_rank_dim)
        self.weight = torch.nn.Parameter(weight, requires_grad=True)
        # Orthonormal columns give an isometric low-rank projection.
        torch.nn.init.orthogonal_(self.weight)

    def forward(self, base):
        # Cast the input to the weight dtype so mixed-precision inputs work.
        return torch.matmul(base.to(self.weight.dtype), self.weight)
|
||
class Intervention(nn.Module):
    """Low-rank intervention between a base and a source representation.

    Rotates both representations into a shared ``low_rank_dim`` subspace,
    takes the difference of the two rotated views, maps it back to the base
    space with the transposed base rotation, and adds it to ``base`` as a
    residual update.

    Args:
        base_dim (int): dimension of the base representation.
        source_dim (int): dimension of the source representation.
        low_rank_dim (int): rank of the shared intervention subspace.
        orth (bool): keep both rotation matrices orthogonal during training
            via parametrization. Previously this flag was accepted but
            ignored (parametrization was always applied); it is now honored.
            The default ``True`` preserves the old behavior.
        dropout_ratio (float): dropout applied to the output. Previously a
            hard-coded 0.1; now configurable with the same default.
    """

    def __init__(self,
                 base_dim, source_dim,
                 low_rank_dim, orth=True,
                 dropout_ratio=0.1):
        super().__init__()
        self.base_dim = base_dim
        base_rotate_layer = RotateLayer(base_dim, low_rank_dim)
        source_rotate_layer = RotateLayer(source_dim, low_rank_dim)
        if orth:
            # Constrain the rotations to stay orthogonal while training.
            base_rotate_layer = torch.nn.utils.parametrizations.orthogonal(
                base_rotate_layer
            )
            source_rotate_layer = torch.nn.utils.parametrizations.orthogonal(
                source_rotate_layer
            )
        self.base_rotate_layer = base_rotate_layer
        self.source_rotate_layer = source_rotate_layer
        self.dropout = torch.nn.Dropout(dropout_ratio)

    def forward(self, base, source):
        """Intervene on ``base`` with ``source``; returns a base-shaped tensor."""
        rotated_base = self.base_rotate_layer(base)
        # Detach the source so gradients do not flow into the source tower.
        rotated_source = self.source_rotate_layer(source.detach())
        # NOTE(review): classic interchange interventions add
        # (source - base) in the subspace; here the sign is reversed
        # (base - source) — confirm this is intended.
        output = (
            torch.matmul(
                rotated_base - rotated_source, self.base_rotate_layer.weight.T
            )
            + base
        )
        return self.dropout(output.to(base.dtype))

    def output_dim(self) -> int:
        """Output dimension of the module (same as ``base_dim``)."""
        return self.base_dim
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,29 @@ message BayesTaskTower { | |
optional MLP relation_mlp = 9; | ||
}; | ||
|
||
// Task-tower config for DC2VR-style intervention towers.
message InterventionTaskTower {
  // task name for the task tower
  required string tower_name = 1;
  // label for the task, default is label_fields by order
  optional string label_name = 2;
  // metrics for the task
  repeated MetricConfig metrics = 3;
  // loss for the task
  repeated LossConfig losses = 4;
  // num_class for multi-class classification loss
  optional uint32 num_class = 5 [default = 1];
  // task specific mlp
  optional MLP mlp = 6;
  // training loss weights
  optional float weight = 7 [default = 1.0];

  // names of the towers whose representations are used as intervention
  // sources for this tower; those towers should be configured before this
  // one so their outputs exist when this tower consumes them
  repeated string intervention_tower_names = 8;
  // rank of the low-dimensional intervention subspace shared between this
  // tower and its sources
  required uint32 low_rank_dim = 9;
};
|
||
|
||
message MultiWindowDINTower { | ||
// time windows len | ||
repeated uint32 windows_len = 1; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add model test in dc2vr_test.py