Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt DeePMD in Ascend platform (NPU) #2371

Open
wants to merge 22 commits into
base: ascend
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1a42f36
add convert_org_to_ascend in convert
invalid-email-address Aug 18, 2022
3947b65
add convert_org_to_ascend
invalid-email-address Aug 19, 2022
344f02e
change the variable is_transfer to is_ascend_transfer
invalid-email-address Aug 22, 2022
3f1d134
1. add a new dp argument transfer-to-ascend; 2. add dp test for ascen…
invalid-email-address Aug 27, 2022
93016e3
1. fix network bug; 2. fix prod_env_mat op register bug
invalid-email-address Aug 29, 2022
fcc909c
fix prod_env_mat_multi_device.cc miswriting
invalid-email-address Aug 30, 2022
05216ef
fix an Ascend incremental code bug
invalid-email-address Aug 30, 2022
3b9a794
fix mixed_prec bug
invalid-email-address Aug 30, 2022
ba22b1f
fix some details according to the Modification comments
invalid-email-address Aug 31, 2022
2fbc8fc
fix transfer spell
invalid-email-address Aug 31, 2022
4771c75
Second modified version according to the comments, we refactored the …
invalid-email-address Sep 6, 2022
98a68c9
Third modified version according to the comments
invalid-email-address Sep 14, 2022
9124fac
sync fork modifications
invalid-email-address Sep 15, 2022
a1f4a10
fix a bug in network.py
invalid-email-address Sep 16, 2022
d2ec06d
1. add a test unit for transfer-to-ascend interface; 2. fix the bugs …
invalid-email-address Sep 23, 2022
aeb1804
set DP_INTERFACE_PREC=ascend_mix for the test unit
invalid-email-address Sep 23, 2022
3eb6609
modify deeppot-2.pbtxt which has BatchMatMulV2
invalid-email-address Sep 23, 2022
563f312
set ascend_mix
invalid-email-address Sep 23, 2022
d0261d5
fix a cmakelist bug, find python first then find TF
invalid-email-address Dec 14, 2022
e86654a
1. add more explanation for ascend graph when natoms is a const op; 3…
invalid-email-address Mar 7, 2023
5d3c526
Add a new custom option to adapt to the newest Ascend CANN software
invalid-email-address Mar 24, 2023
107e6f6
fix a NPU option bug
invalid-email-address Jul 15, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd/entrypoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..infer.model_devi import make_model_devi
from .convert import convert
from .neighbor_stat import neighbor_stat
from .transfer_to_ascend import transfer_to_ascend

__all__ = [
"config",
Expand All @@ -25,4 +26,5 @@
"make_model_devi",
"convert",
"neighbor_stat",
"transfer_to_ascend",
]
4 changes: 2 additions & 2 deletions deepmd/entrypoints/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from deepmd.utils.convert import convert_012_to_21, convert_10_to_21, convert_20_to_21, convert_13_to_21, convert_12_to_21
from deepmd.utils.convert import convert_012_to_21, convert_10_to_21, convert_20_to_21, convert_13_to_21, convert_12_to_21

def convert(
*,
Expand All @@ -19,4 +19,4 @@ def convert(
elif FROM == '2.0':
convert_20_to_21(input_model, output_model)
else:
raise RuntimeError('unsupported model version ' + FROM)
raise RuntimeError('unsupported model version ' + FROM)
53 changes: 53 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
make_model_devi,
convert,
neighbor_stat,
transfer_to_ascend
)
from deepmd.loggers import set_log_handles

Expand Down Expand Up @@ -139,6 +140,14 @@ def main_parser() -> argparse.ArgumentParser:
type=str,
help="the model after passing parameters",
)
parser_transfer.add_argument(
"-a",
"--ascend-graph",
default="",
type=str,
help="the model with constant natoms input, which is onle used for Ascend platform",
)


# * config parser ******************************************************************
parser_train = subparsers.add_parser(
Expand Down Expand Up @@ -513,6 +522,48 @@ def main_parser() -> argparse.ArgumentParser:
choices=['s1', 's2'],
help="steps to train model of NVNMD: s1 (train CNN), s2 (train QNN)"
)

# * transfer to ascend models ***********************************************************
parser_trans_to_ascend = subparsers.add_parser(
'transfer-to-ascend',
parents=[parser_log, parser_mpi_log],
help='transfer original model to ascend NPU supported mix precision version',
)
parser_trans_to_ascend.add_argument(
'TO',
type = str,
default = 'mix_precision',
choices = ['mix_precision'],
help="The transfer type of transfer-to-ascend module",
)
parser_trans_to_ascend.add_argument(
'-i',
"--input-model",
default = "model.pb",
type=str,
help = "the input model",
)
parser_trans_to_ascend.add_argument(
"-o",
"--output-model",
default = "Ascend_transfer.pb",
type=str,
help='the output model',
)
parser_trans_to_ascend.add_argument(
"-c",
"--checkpoint-folder",
default = "model-transfer",
type=str,
help='path to checkpoint folder',
)
parser_trans_to_ascend.add_argument(
"-t",
"--training-script",
type=str,
default=None,
help="The training script of the input frozen model",
)
return parser


Expand Down Expand Up @@ -580,6 +631,8 @@ def main():
neighbor_stat(**dict_args)
elif args.command == "train-nvnmd": # nvnmd
train_nvnmd(**dict_args)
elif args.command == "transfer-to-ascend":
transfer_to_ascend(**dict_args)
elif args.command is None:
pass
else:
Expand Down
10 changes: 8 additions & 2 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,14 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal

# setup data modifier
modifier = get_modifier(jdata["model"].get("modifier", None))

# get transfer info
is_ascend_transfer = jdata["model"].get("transfered_from_model", None)

# decouple the training data from the model compress process
train_data = None
valid_data = None
if not is_compress:
if not is_compress and not is_ascend_transfer:
# init data
train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier)
train_data.print_summary("training")
Expand All @@ -162,7 +165,10 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal
stop_batch = j_must_have(jdata["training"], "numb_steps")
model.build(train_data, stop_batch)

if not is_compress:
if is_ascend_transfer:
model.save_transfered()
log.info("finished transfering")
elif not is_compress:
# train the model with the provided systems in a cyclic way
start_time = time.time()
model.train(train_data, valid_data)
Expand Down
53 changes: 45 additions & 8 deletions deepmd/entrypoints/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import numpy as np
import logging
import os

__all__ = ["transfer"]

Expand Down Expand Up @@ -67,6 +68,12 @@ def transfer(*, old_model: str, raw_model: str, output: str, **kwargs):
with tf.gfile.GFile(output, mode="wb") as f:
f.write(new_graph_def.SerializeToString())
log.info("the output model is saved in " + output)
dp_float_prec = os.environ.get("DP_INTERFACE_PREC", "high").lower()
if kwargs['ascend_graph']:
const_graph_def = modify_const_op(new_graph_def)
with tf.gfile.GFile(kwargs['ascend_graph'], mode="wb") as f:
f.write(const_graph_def.SerializeToString())
log.info("the dp test model is saved in " + kwargs['ascend_graph'])


def load_graph(graph_name: str) -> tf.Graph:
Expand All @@ -89,6 +96,35 @@ def load_graph(graph_name: str) -> tf.Graph:
tf.import_graph_def(graph_def, name="")
return graph

def modify_const_op(new_graph_def: tf.GraphDef) -> tf.GraphDef:
    """Turn every ``t_natoms`` placeholder node into a constant node.

    The constant values are read from the file ``natoms_val.txt`` in the
    current working directory. When that file is missing, a warning is
    logged and the node is left untouched, so the graph stays valid.

    Parameters
    ----------
    new_graph_def : tf.GraphDef
        Graph definition whose nodes are inspected. Matching nodes are
        modified in place.

    Returns
    -------
    tf.GraphDef
        The same ``new_graph_def`` object that was passed in.

    Raises
    ------
    ValueError
        If the number of values in ``natoms_val.txt`` does not match the
        first dimension of the node's ``shape`` attribute.
    """
    natoms_file = "natoms_val.txt"
    for node in new_graph_def.node:
        if "t_natoms" not in node.name:
            continue

        # Check for the value file *before* touching the node: rewriting the
        # op to "Const" without supplying a value would corrupt the graph.
        if not os.path.exists(natoms_file):
            explanation = "natoms_val.txt file is not exist, one shold put padding natoms in it. Values are separated by SPACE. For example, in WATER exple one can excute 'echo 210 4000 70 140 > natoms_val.txt && dp transfer-to-ascend mix_precision -i model.pb'."
            log.warning(explanation)
            continue

        shape_val = [dim.size for dim in node.attr["shape"].shape.dim]
        # np.loadtxt yields a 0-d array for a single value; force 1-d so
        # that len() and iteration below always work.
        natoms_list = np.atleast_1d(np.loadtxt(natoms_file))
        if shape_val[0] != len(natoms_list):
            raise ValueError(
                f"{natoms_file} holds {len(natoms_list)} values but "
                f"{node.name} expects {shape_val[0]}"
            )

        node.op = "Const"
        # A Const node carries its data in "value", not in "shape".
        del node.attr["shape"]
        node.attr["value"].CopyFrom(
            tf.AttrValue(
                tensor=tf.make_tensor_proto(
                    [int(i) for i in natoms_list], tf.int32, [shape_val[0]]
                )
            )
        )
        log.info(f"{node.name} is passed from a placeholder to a const")

    return new_graph_def

def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph:
"""Trasform old graph into new.
Expand Down Expand Up @@ -131,15 +167,16 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph:
if old_graph_dtype == np.float64 or old_graph_dtype == np.float32:
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor = np.frombuffer(old_node.tensor_content, dtype = old_graph_dtype)
tensor = tensor.astype(raw_graph_dtype)
cp_attr.from_str(tensor)
else:
tensor = load_tensor(old_node, old_graph_dtype, raw_graph_dtype)
cp_attr.from_array(tensor, tf.float16, [1])

elif old_graph_dtype[1] == "float16":
tensor = convertMatrix(np.array(old_node.half_val), tensor_shape)
cp_attr.from_array(tensor, raw_graph_dtype)
tensor = tf.make_tensor_proto(tensor, raw_graph_dtype, tensor_shape)
for i in range(len(tensor.half_val)):
raw_node.half_val[i] = tensor.half_val[i]

elif old_graph_dtype == np.float16:
for i in range(len(old_node.half_val)):
raw_node.half_val[i] = old_node.half_val[i]

elif raw_graph_dtype == np.float64 or raw_graph_dtype == np.float32:
if old_graph_dtype == np.float64 or old_graph_dtype == np.float32:
Expand All @@ -153,10 +190,10 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph:

elif old_graph_dtype == np.float16:
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor = convertMatrix(np.array(old_node.half_val), tensor_shape).astype(raw_graph_dtype)
tensor = convert_matrix(np.array(old_node.half_val), tensor_shape).astype(raw_graph_dtype)
cp_attr.from_str(tensor)
else:
tensor = convertMatrix(np.array(old_node.half_val), tensor_shape).astype(raw_graph_dtype)
tensor = convert_matrix(np.array(old_node.half_val), tensor_shape).astype(raw_graph_dtype)
cp_attr.from_array(tensor, raw_graph_dtype)

return raw_graph_def
Expand Down
13 changes: 13 additions & 0 deletions deepmd/entrypoints/transfer_to_ascend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from deepmd.utils.transfer_to_ascend import mix_precision

def transfer_to_ascend(
    *,
    TO: str,
    input_model: str,
    output_model: str,
    **kwargs,
):
    """Dispatch a model transfer for the Ascend platform by transfer type.

    Parameters
    ----------
    TO : str
        The transfer type. Only ``'mix_precision'`` is handled.
    input_model : str
        Passed to ``mix_precision`` as its first positional argument.
    output_model : str
        Passed to ``mix_precision`` as its second positional argument.
    **kwargs
        Remaining keyword arguments, forwarded unchanged to
        ``mix_precision``.

    Raises
    ------
    RuntimeError
        If ``TO`` is not a supported transfer type.
    """
    if TO == 'mix_precision':
        mix_precision(input_model, output_model, **kwargs)
    else:
        # Report the offending value itself; the only name in scope for it
        # is ``TO`` (there is no ``FROM`` in this function).
        raise RuntimeError('unsupported transfering type ' + TO)
8 changes: 7 additions & 1 deletion deepmd/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"GLOBAL_TF_FLOAT_PRECISION",
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
"GLOBAL_ASCEND_OUT_PRECISION",
"global_float_prec",
"global_cvt_2_tf_float",
"global_cvt_2_ener_float",
Expand Down Expand Up @@ -367,10 +368,15 @@ def _get_package_constants(
GLOBAL_NP_FLOAT_PRECISION = np.float32
GLOBAL_ENER_FLOAT_PRECISION = np.float64
global_float_prec = "float"
elif dp_float_prec == "ascend_mix":
GLOBAL_TF_FLOAT_PRECISION = tf.float32
GLOBAL_NP_FLOAT_PRECISION = np.float32
GLOBAL_ENER_FLOAT_PRECISION = np.float64
global_float_prec = "float"
else:
raise RuntimeError(
"Unsupported float precision option: %s. Supported: high,"
"low. Please set precision with environmental variable "
"low and ascend_mix. Please set precision with environmental variable "
"DP_INTERFACE_PREC." % dp_float_prec
)

Expand Down
Loading