Skip to content

Commit

Permalink
W.A. for heavy memory consumption eltwise nodes fusion for Didi vecto…
Browse files Browse the repository at this point in the history
…rnet model
  • Loading branch information
liubo-intel committed Jun 6, 2024
1 parent 63e043b commit c39fe9f
Showing 1 changed file with 35 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "openvino/op/reshape.hpp"

template <class T>
std::function<bool(ov::Output<ov::Node>)> value_is_equal_to(const std::vector<T>& ref_values) {
Expand Down Expand Up @@ -172,7 +173,38 @@ ov::pass::MVNFusionWithoutConstants::MVNFusionWithoutConstants() {
} else {
return false;
}
auto mvn = std::make_shared<ov::op::v6::MVN>(exp_input, axes_1_node, true, eps_value, mode);

auto org_exp_shape = exp_input.get_partial_shape();

std::shared_ptr<ov::Node> mvn, final_output;
if (axes_1_value.size() == 1 && org_exp_shape.rank().is_static() && org_exp_shape.size() > 3 &&
(static_cast<size_t>(axes_1_value[0]) == (org_exp_shape.size() - 1) || axes_1_value[0] == -1)) {

auto temp_shape = Shape(2, -1);
temp_shape[1] = org_exp_shape[1].get_length();

auto reshape = std::make_shared<ov::op::v1::Reshape>(
exp_input,
ov::op::v0::Constant::create(element::i64, Shape{2}, temp_shape), true);

mvn = std::make_shared<ov::op::v6::MVN>(reshape, axes_1_node, true, eps_value, mode);

auto org_shape = Shape(org_exp_shape.size(), -1);
for (size_t i = 0; i < org_exp_shape.size(); i++) {
org_shape[i] = org_exp_shape[i].is_dynamic() ? -1 : org_exp_shape[i].get_length();
}

auto reshape_2 = std::make_shared<ov::op::v1::Reshape>(
mvn,
ov::op::v0::Constant::create(element::i64, Shape{org_exp_shape.size()}, org_shape),
true);

final_output = reshape_2;

} else {
mvn = std::make_shared<ov::op::v6::MVN>(exp_input, axes_1_node, true, eps_value, mode);
final_output = mvn;
}

if (pattern_to_output.count(mean2) && pattern_to_output.count(sub2)) {
nodes_to_copy_info.push_back(pattern_to_output.at(mean2).get_node_shared_ptr());
Expand All @@ -192,8 +224,8 @@ ov::pass::MVNFusionWithoutConstants::MVNFusionWithoutConstants() {
}

mvn->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(nodes_to_copy_info, mvn);
ov::replace_node(m.get_match_root(), mvn);
ov::copy_runtime_info(nodes_to_copy_info, final_output);
ov::replace_node(m.get_match_root(), final_output);
return true;
};

Expand Down

0 comments on commit c39fe9f

Please sign in to comment.