forked from dbkgroup/prop_gen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_molecule.py
153 lines (141 loc) · 7.43 KB
/
run_molecule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/env python3
from mpi4py import MPI
from baselines.common import set_global_seeds
from baselines import logger
from tensorboardX import SummaryWriter
import os
import tensorflow as tf
import gym
from gym_molecule.envs.molecule import GraphEnv
def train(args, seed, writer=None):
    """Run PPO training of a GCN policy on the environment chosen by ``args.env``.

    Args:
        args: Parsed command-line namespace. This function reads ``env``,
            ``model_path``, ``dataset``, ``num_steps``, ``lr`` and the
            environment options forwarded to ``env.init``; the whole namespace
            is also handed to the policy and to the learner.
            ``args.model_path`` is rewritten in place to an absolute path when
            it is non-empty.
        seed: Base RNG seed. Each MPI worker uses ``seed + 10000 * rank``.
        writer: Optional summary writer, passed through to the learner.

    Raises:
        ValueError: If ``args.env`` is neither ``'molecule'`` nor ``'graph'``.
    """
    # Fail fast, before the heavy imports and the TensorFlow session are set
    # up. Previously an unknown env surfaced later as an unbound ``env``.
    if args.env not in ('molecule', 'graph'):
        raise ValueError(
            f"unknown env {args.env!r}; expected 'molecule' or 'graph'")

    from baselines.ppo1 import pposgd_simple_gcn, gcn_policy
    # Unused in this function, but kept because the original imported it here.
    # NOTE(review): confirm the import has no needed side effects, then drop.
    import baselines.common.tf_util as U  # noqa: F401

    rank = MPI.COMM_WORLD.Get_rank()

    # Cap this process at a small slice of GPU memory; soft placement lets
    # TensorFlow fall back to another device when an op cannot be placed.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 0.04
    sess = tf.Session(config=config)
    # Entered and intentionally never exited, so the session stays the default
    # one after this function returns (same as the original behaviour).
    sess.__enter__()

    # Only the first worker produces log output.
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    workerseed = seed + 10000 * rank
    set_global_seeds(workerseed)

    if args.model_path != '':
        args.model_path = os.path.abspath(args.model_path)

    # ``init`` must be called after the environment object has been created.
    if args.env == 'molecule':
        env = gym.make('molecule-v0')
        env.init(
            data_type=args.dataset,
            logp_ratio=args.logp_ratio,
            qed_ratio=args.qed_ratio,
            sa_ratio=args.sa_ratio,
            reward_step_total=args.reward_step_total,
            is_normalize=args.normalize_adj,
            reward_type=args.reward_type,
            reward_target=args.reward_target,
            has_feature=bool(args.has_feature),
            is_conditional=bool(args.is_conditional),
            conditional=args.conditional,
            max_action=args.max_action,
            min_action=args.min_action,
            model_path=args.model_path,
            model2_path=args.model2_path,
        )
    else:  # 'graph' (validated above)
        env = GraphEnv()
        env.init(
            reward_step_total=args.reward_step_total,
            is_normalize=args.normalize_adj,
            dataset=args.dataset,
        )

    try:
        print(env.observation_space)

        def policy_fn(name, ob_space, ac_space):
            return gcn_policy.GCNPolicy(
                name=name, ob_space=ob_space, ac_space=ac_space,
                atom_type_num=env.atom_type_num, args=args)

        env.seed(workerseed)
        pposgd_simple_gcn.learn(
            args, env, policy_fn,
            max_timesteps=args.num_steps,
            timesteps_per_actorbatch=256,
            clip_param=0.2, entcoeff=0.01,
            optim_epochs=8, optim_stepsize=args.lr, optim_batchsize=32,
            gamma=1, lam=0.95,
            schedule='linear', writer=writer,
        )
    finally:
        # Release the environment even when training raises.
        env.close()
def arg_parser():
    """Return a fresh ``argparse.ArgumentParser`` with no arguments registered.

    The parser uses ``ArgumentDefaultsHelpFormatter`` as its formatter class.
    """
    import argparse

    formatter = argparse.ArgumentDefaultsHelpFormatter
    empty_parser = argparse.ArgumentParser(formatter_class=formatter)
    return empty_parser
def molecule_arg_parser():
    """Build the command-line parser for a training run.

    Starts from the parser returned by ``arg_parser()`` and registers every
    option as a ``(flag, type, default)`` row, in a fixed order so the
    generated ``--help`` listing keeps the same layout.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    # Help text exists for only a handful of options; the rest have none.
    help_text = {
        '--env': 'environment name: molecule; graph',
        '--seed': 'RNG seed',
        '--dataset': 'caveman; grid; ba; zinc; gdb; dopamine',
        '--reward_type': 'logppen;logp_target;qed;qedsa;qed_target;mw_target;gan;pki',
        '--reward_target': 'target reward value',
    }

    option_rows = (
        # run identity / data
        ('--env', str, 'molecule'),
        ('--seed', int, 666),
        ('--num_steps', int, int(5e7)),
        ('--name', str, 'test_conditional'),
        ('--name_load', str, '0new_concatno_mean_layer3_expert1500'),
        ('--dataset', str, 'zinc'),
        ('--dataset_load', str, 'zinc'),
        # reward configuration
        ('--reward_type', str, 'logppen'),
        ('--reward_target', float, 0.5),
        ('--logp_ratio', float, 1),
        ('--qed_ratio', float, 1),
        ('--sa_ratio', float, 1),
        ('--gan_step_ratio', float, 1),
        ('--gan_final_ratio', float, 1),
        ('--reward_step_total', float, 0.5),
        ('--lr', float, 1e-3),
        # training-phase switches and schedule
        ('--has_d_step', int, 1),
        ('--has_d_final', int, 1),
        ('--has_ppo', int, 1),
        ('--rl_start', int, 250),
        ('--rl_end', int, int(1e6)),
        ('--expert_start', int, 0),
        ('--expert_end', int, int(1e6)),
        ('--save_every', int, 200),
        ('--load', int, 0),
        ('--load_step', int, 250),
        ('--curriculum', int, 0),
        ('--curriculum_num', int, 6),
        ('--curriculum_step', int, 200),
        ('--supervise_time', int, 4),
        # network architecture
        ('--normalize_adj', int, 0),
        ('--layer_num_g', int, 3),
        ('--layer_num_d', int, 3),
        ('--graph_emb', int, 0),
        ('--stop_shift', int, -3),
        ('--has_residual', int, 0),
        ('--has_concat', int, 0),
        ('--has_feature', int, 0),
        ('--emb_size', int, 128),
        ('--gcn_aggregate', str, 'mean'),   # choices: sum, mean, concat
        ('--gan_type', str, 'normal'),      # choices: normal, recommend, wgan
        ('--gate_sum_d', int, 0),
        ('--mask_null', int, 0),
        # conditional generation
        ('--is_conditional', int, 0),
        ('--conditional', str, 'low'),
        ('--max_action', int, 128),  # TODO: try experimenting with this value
        ('--min_action', int, 20),
        ('--bn', int, 0),
        # derived names and model paths
        ('--name_full', str, ''),
        ('--name_full_load', str, ''),
        ('--model_path', str, ''),
        ('--model2_path', str, ''),
    )

    cli = arg_parser()
    for flag, value_type, default_value in option_rows:
        cli.add_argument(
            flag,
            type=value_type,
            default=default_value,
            help=help_text.get(flag),
        )
    return cli
def main():
    """Parse the command line, prepare output directories and start training.

    Fills in ``args.name_full`` and ``args.name_full_load`` from the other
    options, makes sure the ``molecule_gen`` and ``ckpt`` directories exist,
    creates a ``SummaryWriter`` on MPI rank 0 only, and calls ``train``.
    """
    args = molecule_arg_parser().parse_args()
    print(args)
    args.name_full = args.env + '_' + args.dataset + '_' + args.name
    args.name_full_load = (args.env + '_' + args.dataset_load + '_'
                           + args.name_load + '_' + str(args.load_step))

    # exist_ok avoids the check-then-create race when several MPI workers
    # start in the same working directory at once.
    for out_dir in ('molecule_gen', 'ckpt'):
        os.makedirs(out_dir, exist_ok=True)

    # Only keep the first worker's results in TensorBoard.
    if MPI.COMM_WORLD.Get_rank() == 0:
        writer = SummaryWriter(comment='_' + args.dataset + '_' + args.name)
    else:
        writer = None

    try:
        train(args, seed=args.seed, writer=writer)
    finally:
        # Close the writer even if training fails part-way.
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    main()