Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] add intervention methods for multi-target learning #49

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions tzrec/models/dc2vr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List

import torch
from torch import nn

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.multi_task_rank import MultiTaskRank
from tzrec.modules.mlp import MLP
from tzrec.modules.mmoe import MMoE as MMoEModule
from tzrec.modules.intervention import Intervention
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.protos.models import multi_task_rank_pb2
from tzrec.utils.config_util import config_to_kwargs



class DC2VR(MultiTaskRank):
    """Deconfounding Conversion Rate (DC2VR) model.

    A multi-task ranking model: grouped features pass through an optional
    shared bottom MLP and an optional MMoE, then per-task towers. Selected
    towers additionally apply a low-rank ``Intervention`` that mixes in the
    (averaged) representations of other towers before the final output layer.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
    """

    def __init__(
        self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
    ) -> None:
        super().__init__(model_config, features, labels)
        assert model_config.WhichOneof("model") == "dc2vr", (
            "invalid model config: %s" % self._model_config.WhichOneof("model")
        )
        assert isinstance(self._model_config, multi_task_rank_pb2.DC2VR)

        self._task_tower_cfgs = self._model_config.task_towers
        self.init_input()
        self.group_name = self.embedding_group.group_names()[0]
        feature_in = self.embedding_group.group_total_dim(self.group_name)

        # Optional shared bottom MLP applied to the grouped features.
        self.bottom_mlp = None
        if self._model_config.HasField("bottom_mlp"):
            self.bottom_mlp = MLP(
                feature_in, **config_to_kwargs(self._model_config.bottom_mlp)
            )
            feature_in = self.bottom_mlp.output_dim()

        # Optional MMoE producing one expert mixture per task.
        self.mmoe = None
        if self._model_config.HasField("expert_mlp"):
            self.mmoe = MMoEModule(
                in_features=feature_in,
                expert_mlp=config_to_kwargs(self._model_config.expert_mlp),
                num_expert=self._model_config.num_expert,
                num_task=len(self._task_tower_cfgs),
                gate_mlp=config_to_kwargs(self._model_config.gate_mlp)
                if self._model_config.HasField("gate_mlp")
                else None,
            )
            feature_in = self.mmoe.output_dim()

        # Per-task MLP towers (optional per tower).
        self.task_mlps = nn.ModuleDict()
        for task_tower_cfg in self._task_tower_cfgs:
            if task_tower_cfg.HasField("mlp"):
                self.task_mlps[task_tower_cfg.tower_name] = MLP(
                    feature_in, **config_to_kwargs(task_tower_cfg.mlp)
                )

        # Per-task intervention modules. Built in tower order so that a
        # tower may reference the (already intervened) output dim of an
        # earlier tower as one of its sources.
        self.intervention = nn.ModuleDict()
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            if not task_tower_cfg.HasField("low_rank_dim"):
                continue
            source_names = list(task_tower_cfg.intervention_tower_names)
            if not source_names:
                # low_rank_dim set but no sources configured: nothing to
                # intervene with, so no module is registered and predict()
                # will pass this tower through unchanged.
                continue
            base_dim = self._tower_dim(tower_name, feature_in)
            source_dims = {self._tower_dim(n, feature_in) for n in source_names}
            # Source representations are averaged in predict(), so they
            # must all share a single dimension.
            assert len(source_dims) == 1, (
                "intervention sources of tower %s must have equal dims, got %s"
                % (tower_name, sorted(source_dims))
            )
            self.intervention[tower_name] = Intervention(
                base_dim, source_dims.pop(), task_tower_cfg.low_rank_dim
            )

        # Final per-task linear output layers.
        self.task_outputs = nn.ModuleList()
        for task_tower_cfg in self._task_tower_cfgs:
            input_dim = self._tower_dim(task_tower_cfg.tower_name, feature_in)
            self.task_outputs.append(nn.Linear(input_dim, task_tower_cfg.num_class))

    def _tower_dim(self, tower_name: str, feature_in: int) -> int:
        """Output dim of a tower: intervention > task MLP > shared input."""
        if tower_name in self.intervention:
            return self.intervention[tower_name].output_dim()
        if tower_name in self.task_mlps:
            return self.task_mlps[tower_name].output_dim()
        return feature_in

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Return:
            predictions (dict): a dict of predicted result.
        """
        grouped_features = self.build_input(batch)

        net = grouped_features[self.group_name]
        if self.bottom_mlp is not None:
            net = self.bottom_mlp(net)

        if self.mmoe is not None:
            task_input_list = self.mmoe(net)
        else:
            task_input_list = [net] * len(self._task_tower_cfgs)

        # Per-task representations before intervention.
        task_net = {}
        for i, task_tower_cfg in enumerate(self._task_tower_cfgs):
            tower_name = task_tower_cfg.tower_name
            if tower_name in self.task_mlps:
                task_net[tower_name] = self.task_mlps[tower_name](task_input_list[i])
            else:
                task_net[tower_name] = task_input_list[i]

        # Apply interventions in tower order; later towers may consume the
        # intervened outputs of earlier ones. Gate on module registration
        # (not HasField) so a tower with low_rank_dim but no sources is
        # passed through instead of crashing on an empty stack.
        intervention = {}
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            if tower_name in self.intervention:
                source = torch.stack(
                    [
                        intervention[name]
                        for name in task_tower_cfg.intervention_tower_names
                    ],
                    dim=0,
                ).mean(0)
                intervention[tower_name] = self.intervention[tower_name](
                    task_net[tower_name], source
                )
            else:
                intervention[tower_name] = task_net[tower_name]

        tower_outputs = {}
        for i, task_tower_cfg in enumerate(self._task_tower_cfgs):
            tower_name = task_tower_cfg.tower_name
            tower_outputs[tower_name] = self.task_outputs[i](intervention[tower_name])

        return self._multi_task_output_to_prediction(tower_outputs)
36 changes: 36 additions & 0 deletions tzrec/modules/intervention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from torch import nn


class RotateLayer(nn.Module):
    """Linear projection of a representation onto a low-rank subspace.

    The weight is initialized orthogonally; ``Intervention`` may further
    constrain it to remain orthogonal during training.

    Args:
        base_dim (int): input dimension; must be greater than low_rank_dim.
        low_rank_dim (int): dimension of the projected subspace.

    Raises:
        ValueError: if ``base_dim <= low_rank_dim``.
    """

    def __init__(self, base_dim: int, low_rank_dim: int) -> None:
        super().__init__()
        # A proper low-rank projection requires base_dim > low_rank_dim.
        # Raise instead of assert so the check survives `python -O`.
        if base_dim <= low_rank_dim:
            raise ValueError(
                f"base_dim ({base_dim}) must be greater than "
                f"low_rank_dim ({low_rank_dim})"
            )
        self.weight = torch.nn.Parameter(
            torch.empty(base_dim, low_rank_dim), requires_grad=True
        )
        torch.nn.init.orthogonal_(self.weight)

    def forward(self, base: torch.Tensor) -> torch.Tensor:
        """Project `base` of shape (..., base_dim) to (..., low_rank_dim)."""
        return torch.matmul(base.to(self.weight.dtype), self.weight)


class Intervention(nn.Module):
    """Low-rank intervention of a base representation by a source one.

    Both representations are rotated into a shared low-rank subspace; the
    base's component in that subspace is shifted toward the (detached)
    source's component, and the difference is rotated back and added to
    the base as a residual.

    Args:
        base_dim (int): dimension of the base representation.
        source_dim (int): dimension of the source representation.
        low_rank_dim (int): rank of the intervention subspace.
        orth (bool): if True, keep the rotation weights orthogonal during
            training via parametrization. Default True.
        dropout_ratio (float): dropout rate applied to the output.
            Default 0.1.
    """

    def __init__(
        self,
        base_dim: int,
        source_dim: int,
        low_rank_dim: int,
        orth: bool = True,
        dropout_ratio: float = 0.1,
    ) -> None:
        super().__init__()
        self.base_dim = base_dim
        base_rotate_layer = RotateLayer(base_dim, low_rank_dim)
        source_rotate_layer = RotateLayer(source_dim, low_rank_dim)
        if orth:
            # Constrain the rotation weights to stay orthogonal while
            # training (previously `orth` was accepted but ignored).
            base_rotate_layer = torch.nn.utils.parametrizations.orthogonal(
                base_rotate_layer
            )
            source_rotate_layer = torch.nn.utils.parametrizations.orthogonal(
                source_rotate_layer
            )
        self.base_rotate_layer = base_rotate_layer
        self.source_rotate_layer = source_rotate_layer
        self.dropout = torch.nn.Dropout(dropout_ratio)

    def forward(self, base: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
        """Intervene `base` (..., base_dim) with `source` (..., source_dim)."""
        rotated_base = self.base_rotate_layer(base)
        # Detach the source: it acts as a fixed intervention signal and
        # should not receive gradients through this path.
        rotated_source = self.source_rotate_layer(source.detach())
        output = (
            torch.matmul(
                rotated_base - rotated_source, self.base_rotate_layer.weight.T
            )
            + base
        )
        return self.dropout(output.to(base.dtype))

    def output_dim(self) -> int:
        """Output dimension of the module (equal to base_dim)."""
        return self.base_dim
1 change: 1 addition & 0 deletions tzrec/protos/model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ message ModelConfig {
MMoE mmoe = 201;
DBMTL dbmtl = 202;
PLE ple = 203;
DC2VR dc2vr = 204;

DSSM dssm = 301;
DSSMV2 dssm_v2 = 302;
Expand Down
15 changes: 15 additions & 0 deletions tzrec/protos/models/multi_task_rank.proto
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ message DBMTL {
repeated BayesTaskTower task_towers = 5;
}

message DC2VR {
// shared bottom mlp layer
optional MLP bottom_mlp = 1;
// mmoe expert mlp layer definition
optional MLP expert_mlp = 2;
// mmoe gate module definition
optional MLP gate_mlp = 3;
// number of mmoe experts
optional uint32 num_expert = 4 [default=3];
// task tower
repeated InterventionTaskTower task_towers = 5;
}



message PLE {
// extraction network
repeated ExtractionNetwork extraction_networks = 1;
Expand Down
23 changes: 23 additions & 0 deletions tzrec/protos/tower.proto
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@ message BayesTaskTower {
optional MLP relation_mlp = 9;
};

message InterventionTaskTower {
// task name for the task tower
required string tower_name = 1;
// label for the task, default is label_fields by order
optional string label_name = 2;
// metrics for the task
repeated MetricConfig metrics = 3;
// loss for the task
repeated LossConfig losses = 4;
// num_class for multi-class classification loss
optional uint32 num_class = 5 [default = 1];
// task specific mlp
optional MLP mlp = 6;
// training loss weights
optional float weight = 7 [default = 1.0];

// intervention tower names
repeated string intervention_tower_names = 8;
// low_rank_dim
required uint32 low_rank_dim = 9;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for us to assign a distinct low_rank_dim to each tower?

};


message MultiWindowDINTower {
// time windows len
repeated uint32 windows_len = 1;
Expand Down
Loading