From daf6c4faee4e4272b94faa9188dbdc77eb493952 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E9=AB=98=E9=A3=9E?=
Date: Thu, 28 Nov 2024 15:42:10 +0800
Subject: [PATCH] intervention method

---
 tzrec/models/dc2vr.py                     | 162 ++++++++++++++++++++++
 tzrec/modules/intervention.py             |  36 +++++
 tzrec/protos/model.proto                  |   1 +
 tzrec/protos/models/multi_task_rank.proto |  15 ++
 tzrec/protos/tower.proto                  |  23 +++
 5 files changed, 237 insertions(+)
 create mode 100644 tzrec/models/dc2vr.py
 create mode 100644 tzrec/modules/intervention.py

diff --git a/tzrec/models/dc2vr.py b/tzrec/models/dc2vr.py
new file mode 100644
index 0000000..79a97b6
--- /dev/null
+++ b/tzrec/models/dc2vr.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2024, Alibaba Group;
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List
+
+import torch
+from torch import nn
+
+from tzrec.datasets.utils import Batch
+from tzrec.features.feature import BaseFeature
+from tzrec.models.multi_task_rank import MultiTaskRank
+from tzrec.modules.intervention import Intervention
+from tzrec.modules.mlp import MLP
+from tzrec.modules.mmoe import MMoE as MMoEModule
+from tzrec.protos.model_pb2 import ModelConfig
+from tzrec.protos.models import multi_task_rank_pb2
+from tzrec.utils.config_util import config_to_kwargs
+
+
+class DC2VR(MultiTaskRank):
+    """De-Confounding Conversion Rate model.
+
+    Args:
+        model_config (ModelConfig): an instance of ModelConfig.
+        features (list): list of features.
+        labels (list): list of label names.
+
+    """
+
+    def __init__(
+        self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
+    ) -> None:
+        super().__init__(model_config, features, labels)
+        assert model_config.WhichOneof("model") == "dc2vr", (
+            "invalid model config: %s" % self._model_config.WhichOneof("model")
+        )
+        assert isinstance(self._model_config, multi_task_rank_pb2.DC2VR)
+
+        self._task_tower_cfgs = self._model_config.task_towers
+        self.init_input()
+        self.group_name = self.embedding_group.group_names()[0]
+        feature_in = self.embedding_group.group_total_dim(self.group_name)
+
+        self.bottom_mlp = None
+        if self._model_config.HasField("bottom_mlp"):
+            self.bottom_mlp = MLP(
+                feature_in, **config_to_kwargs(self._model_config.bottom_mlp)
+            )
+            feature_in = self.bottom_mlp.output_dim()
+
+        self.mmoe = None
+        if self._model_config.HasField("expert_mlp"):
+            self.mmoe = MMoEModule(
+                in_features=feature_in,
+                expert_mlp=config_to_kwargs(self._model_config.expert_mlp),
+                num_expert=self._model_config.num_expert,
+                num_task=len(self._task_tower_cfgs),
+                gate_mlp=config_to_kwargs(self._model_config.gate_mlp)
+                if self._model_config.HasField("gate_mlp")
+                else None,
+            )
+            feature_in = self.mmoe.output_dim()
+
+        self.task_mlps = nn.ModuleDict()
+        for task_tower_cfg in self._task_tower_cfgs:
+            if task_tower_cfg.HasField("mlp"):
+                tower_mlp = MLP(feature_in, **config_to_kwargs(task_tower_cfg.mlp))
+                self.task_mlps[task_tower_cfg.tower_name] = tower_mlp
+
+        self.intervention = nn.ModuleDict()
+        for task_tower_cfg in self._task_tower_cfgs:
+            tower_name = task_tower_cfg.tower_name
+            if task_tower_cfg.HasField("low_rank_dim"):
+                if tower_name in self.task_mlps:
+                    base_intervention_dim = self.task_mlps[tower_name].output_dim()
+                else:
+                    base_intervention_dim = feature_in
+                for intervention_tower_name in task_tower_cfg.intervention_tower_names:
+                    if intervention_tower_name in self.intervention:
+                        source_intervention_dim = self.intervention[
+                            intervention_tower_name
+                        ].output_dim()
+                    elif intervention_tower_name in self.task_mlps:
+                        source_intervention_dim = self.task_mlps[
+                            intervention_tower_name
+                        ].output_dim()
+                    else:
+                        source_intervention_dim = feature_in
+                intervention = Intervention(
+                    base_intervention_dim, source_intervention_dim, task_tower_cfg.low_rank_dim
+                )
+                self.intervention[tower_name] = intervention
+
+        self.task_outputs = nn.ModuleList()
+        for task_tower_cfg in self._task_tower_cfgs:
+            tower_name = task_tower_cfg.tower_name
+            if tower_name in self.intervention:
+                input_dim = self.intervention[tower_name].output_dim()
+            elif tower_name in self.task_mlps:
+                input_dim = self.task_mlps[tower_name].output_dim()
+            else:
+                input_dim = feature_in
+            self.task_outputs.append(nn.Linear(input_dim, task_tower_cfg.num_class))
+
+    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
+        """Forward the model.
+
+        Args:
+            batch (Batch): input batch data.
+
+        Return:
+            predictions (dict): a dict of predicted result.
+        """
+        grouped_features = self.build_input(batch)
+
+        net = grouped_features[self.group_name]
+        if self.bottom_mlp is not None:
+            net = self.bottom_mlp(net)
+
+        if self.mmoe is not None:
+            task_input_list = self.mmoe(net)
+        else:
+            task_input_list = [net] * len(self._task_tower_cfgs)
+
+        task_net = {}
+        for i, task_tower_cfg in enumerate(self._task_tower_cfgs):
+            tower_name = task_tower_cfg.tower_name
+            if tower_name in self.task_mlps.keys():
+                task_net[tower_name] = self.task_mlps[tower_name](task_input_list[i])
+            else:
+                task_net[tower_name] = task_input_list[i]
+
+        intervention = {}
+        for task_tower_cfg in self._task_tower_cfgs:
+            tower_name = task_tower_cfg.tower_name
+            if task_tower_cfg.HasField("low_rank_dim"):
+                intervention_base = task_net[tower_name]
+                intervention_source = []
+                for intervention_tower_name in task_tower_cfg.intervention_tower_names:
+                    intervention_source.append(intervention[intervention_tower_name])
+                intervention_source = torch.stack(intervention_source, dim=0).mean(0)
+                intervention[tower_name] = self.intervention[tower_name](
+                    intervention_base, intervention_source
+                )
+            else:
+                intervention[tower_name] = task_net[tower_name]
+
+        tower_outputs = {}
+        for i, task_tower_cfg in enumerate(self._task_tower_cfgs):
+            tower_name = task_tower_cfg.tower_name
+            tower_output = self.task_outputs[i](intervention[tower_name])
+            tower_outputs[tower_name] = tower_output
+
+        return self._multi_task_output_to_prediction(tower_outputs)
diff --git a/tzrec/modules/intervention.py b/tzrec/modules/intervention.py
new file mode 100644
index 0000000..f563690
--- /dev/null
+++ b/tzrec/modules/intervention.py
@@ -0,0 +1,36 @@
+import torch
+from torch import nn
+
+class RotateLayer(nn.Module):
+    def __init__(self,
+                 base_dim,
+                 low_rank_dim):
+        super().__init__()
+        # n > m
+        self.weight = torch.nn.Parameter(torch.empty(base_dim, low_rank_dim), requires_grad=True)
+        torch.nn.init.orthogonal_(self.weight)
+
+    def forward(self, base):
+        return torch.matmul(base.to(self.weight.dtype), self.weight)
+
+class Intervention(nn.Module):
+    def __init__(self,
+                 base_dim, source_dim,
+                 low_rank_dim, orth=True):
+        super().__init__()
+        self.base_dim = base_dim
+        base_rotate_layer = RotateLayer(base_dim, low_rank_dim)
+        self.base_rotate_layer = torch.nn.utils.parametrizations.orthogonal(base_rotate_layer) if orth else base_rotate_layer
+        source_rotate_layer = RotateLayer(source_dim, low_rank_dim)
+        self.source_rotate_layer = torch.nn.utils.parametrizations.orthogonal(source_rotate_layer) if orth else source_rotate_layer
+        self.dropout = torch.nn.Dropout(0.1)
+
+    def forward(self, base, source):
+        rotated_base = self.base_rotate_layer(base)
+        rotated_source = self.source_rotate_layer(source.detach())
+        output = torch.matmul(rotated_base-rotated_source, self.base_rotate_layer.weight.T) + base
+        return self.dropout(output.to(base.dtype))
+
+    def output_dim(self) -> int:
+        """Output dimension of the module."""
+        return self.base_dim
\ No newline at end of file
diff --git a/tzrec/protos/model.proto b/tzrec/protos/model.proto
index f3603cd..9ede437 100644
--- a/tzrec/protos/model.proto
+++ b/tzrec/protos/model.proto
@@ -43,6 +43,7 @@ message ModelConfig {
         MMoE mmoe = 201;
         DBMTL dbmtl = 202;
         PLE ple = 203;
+        DC2VR dc2vr = 204;
 
         DSSM dssm = 301;
         DSSMV2 dssm_v2 = 302;
diff --git a/tzrec/protos/models/multi_task_rank.proto b/tzrec/protos/models/multi_task_rank.proto
index 8db2422..86d8cd0 100644
--- a/tzrec/protos/models/multi_task_rank.proto
+++ b/tzrec/protos/models/multi_task_rank.proto
@@ -34,6 +34,21 @@ message DBMTL {
     repeated BayesTaskTower task_towers = 5;
 }
 
+message DC2VR {
+    // shared bottom mlp layer
+    optional MLP bottom_mlp = 1;
+    // mmoe expert mlp layer definition
+    optional MLP expert_mlp = 2;
+    // mmoe gate module definition
+    optional MLP gate_mlp = 3;
+    // number of mmoe experts
+    optional uint32 num_expert = 4 [default=3];
+    // task tower
+    repeated InterventionTaskTower task_towers = 5;
+}
+
+
 message PLE {
     // extraction network
     repeated ExtractionNetwork extraction_networks = 1;
diff --git a/tzrec/protos/tower.proto b/tzrec/protos/tower.proto
index 629ab6b..422b702 100644
--- a/tzrec/protos/tower.proto
+++ b/tzrec/protos/tower.proto
@@ -58,6 +58,29 @@ message BayesTaskTower {
     optional MLP relation_mlp = 9;
 };
 
+message InterventionTaskTower {
+    // task name for the task tower
+    required string tower_name = 1;
+    // label for the task, default is label_fields by order
+    optional string label_name = 2;
+    // metrics for the task
+    repeated MetricConfig metrics = 3;
+    // loss for the task
+    repeated LossConfig losses = 4;
+    // num_class for multi-class classification loss
+    optional uint32 num_class = 5 [default = 1];
+    // task specific mlp
+    optional MLP mlp = 6;
+    // training loss weights
+    optional float weight = 7 [default = 1.0];
+
+    // intervention tower names
+    repeated string intervention_tower_names = 8;
+    // low_rank_dim
+    required uint32 low_rank_dim = 9;
+};
+
+
 message MultiWindowDINTower {
     // time windows len
    repeated uint32 windows_len = 1;