-
Notifications
You must be signed in to change notification settings - Fork 0
/
auto_train.py
131 lines (113 loc) · 4.35 KB
/
auto_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import subprocess
import redis
from contexts.terrain_context import TerrainContextSpace
import json
import wandb
import time
import os
import torch
from cfg.jackal_config import JackalCfg
def retrieve_context(r):
    """Fetch the shared terrain context stored in Redis.

    Args:
        r: Redis client; the value under the ``'terrain_context'`` key is
            read with ``r.get``.

    Returns:
        TerrainContextSpace: rebuilt with ``TerrainContextSpace.from_dict``
        from the JSON payload.

    Raises:
        ValueError: if the key is absent (``r.get`` returned ``None``).
    """
    raw_payload = r.get('terrain_context')
    if raw_payload is not None:
        return TerrainContextSpace.from_dict(json.loads(raw_payload))
    raise ValueError("No terrain context found in Redis!")
def load_from_file(file_path):
    """Load a terrain context previously saved as JSON on disk.

    Args:
        file_path: Path to a JSON file whose top-level value is passed to
            ``TerrainContextSpace.from_dict``.

    Returns:
        TerrainContextSpace: the deserialized context.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    # EAFP: open directly instead of os.path.exists() followed by open(),
    # which races if the file disappears in between. The message matches the
    # previous pre-check so callers see the same error text.
    try:
        f = open(file_path, 'r', encoding='utf-8')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"File '{file_path}' not found.") from err
    # JSON is UTF-8 by specification; an explicit encoding avoids depending on
    # the platform's locale default.
    with f:
        data = json.load(f)
    return TerrainContextSpace.from_dict(data)
def choose(choices, label='project'):
    """Interactively ask the user to pick one entry from ``choices``.

    The options are printed as a 1-based numbered list, and the prompt repeats
    until the user enters a number within range.

    Args:
        choices: Non-empty sequence of options to display.
        label: Noun used in the prompts. Defaults to ``'project'``, so the
            messages match the original wording.

    Returns:
        The selected element of ``choices``.

    Raises:
        ValueError: if ``choices`` is empty. No answer could ever be valid, so
            the prompt loop would otherwise never terminate.
    """
    if not choices:
        raise ValueError(f"No {label} options to choose from.")
    while True:
        print(f"Please choose a {label}:")
        for number, name in enumerate(choices, start=1):
            print(f"{number}. {name}")
        try:
            # Only the int() conversion can fail on bad user input.
            index = int(input(f"Enter the number corresponding to the {label}: ")) - 1
        except ValueError:
            print(f"Invalid input. Please enter a number corresponding to the {label}.")
            continue
        if 0 <= index < len(choices):
            selected = choices[index]
            print(f"Running with {label} name: {selected}")
            return selected
        print("Invalid choice. Please try again.")
# Menu options presented through choose() when the script starts.
# Trainer to run: 'teacher' launches rl_agents.train_teacher,
# 'student' launches bc_agents.train_student.
trainers = ['teacher', 'student']
# Terrain sampler; stored in context.environment.sampler.
environment_samplers = ['nature', 'procedural', 'ddpm']
# Terrain curriculum; stored in context.environment.curriculum.
environment_curricula = ['uniform', 'adaptive']
# --- Interactive configuration ----------------------------------------------
trainer = choose(trainers)
context_space = JackalCfg()
environment_cfg = context_space.context.environment
environment_cfg.sampler = choose(environment_samplers)
environment_cfg.curriculum = choose(environment_curricula)

# Optional path to a saved terrain-context JSON file. Blank input means "none".
# Whitespace is stripped so a stray space is not mistaken for a path.
file_path = input("Please enter the file path (press Enter to skip): ").strip() or None

project = 'Offroad Navigation ' + trainer
method = environment_cfg.sampler + ' ' + environment_cfg.curriculum

if trainer == 'teacher':
    # Load a saved context if a path was given; otherwise build a new one.
    if file_path is not None:
        terrain_context = load_from_file(file_path)
    else:
        terrain_context = TerrainContextSpace(context_space)
    terrain_context.start_time = time.time()
    # NOTE(review): "flag_evaluaton" looks like a misspelling of
    # "flag_evaluation"; confirm which name TerrainContextSpace actually
    # reads before renaming it.
    terrain_context.flag_evaluaton = False
    policy = 'rl_agents.train_teacher'
elif trainer == 'student':
    # Without an explicit path, use the latest.json logged for this
    # sampler/curriculum combination.
    if file_path is None:
        file_path = f"wandb/Log/{environment_cfg.sampler}/{environment_cfg.curriculum}/latest.json"
    terrain_context = load_from_file(file_path)
    # Switch the terrain grid to the distillation layout.
    terrain_context.num_rows = environment_cfg.num_rows_distill
    terrain_context.num_cols = environment_cfg.num_cols_distill
    terrain_context.num_terrains = terrain_context.num_rows * terrain_context.num_cols
    # Number of complete groups of num_terrains entries in terrain_context.data.
    num_distill_steps = len(terrain_context.data) // terrain_context.num_terrains
    policy = 'bc_agents.train_student'
# --- Training loop ------------------------------------------------------------
# One id is generated up front and passed to every child via --wandb_id
# (presumably so all steps attach to the same W&B run; confirm in the trainers).
run_id = wandb.util.generate_id()
# Redis carries the terrain context between this process and each child:
# it is written under 'terrain_context' before a child starts and read back after.
r = redis.Redis(host='localhost', port=6379, db=0)

for itr in range(context_space.rl.num_learning_steps):
    # The teacher samples with no slice index; the student passes the iteration.
    slice_start = None if trainer == 'teacher' else itr
    tif_file = terrain_context.sample(slice_start)
    backbone = terrain_context.get_backbone()
    r.set('terrain_context', json.dumps(terrain_context.to_dict()))
    # Release this process's cached, unused GPU memory before the child starts.
    torch.cuda.empty_cache()
    program = [
        'python3', '-m', policy,
        '--tif_name', tif_file,
        '--backbone', backbone,
        '--project', project,
        '--method', method,
        '--wandb', 'False',
        '--wandb_entity', 'Nullptr',
        '--wandb_id', run_id
    ]
    print(f"Executing: {program}")
    exit_code = subprocess.call(program)
    if exit_code != 0:
        # Previously ignored silently; the context read below may be stale.
        print(f"Warning: training subprocess exited with code {exit_code}; "
              "the terrain context read back from Redis may be stale.")
    # Re-read the context from Redis, since the child may have updated it.
    terrain_context = retrieve_context(r)
    if trainer == 'teacher':
        terrain_context.evaluate(r)
    terrain_context.save_to_file()
    # NOTE(review): the student check runs after the step, so iterations
    # 0..num_distill_steps all execute (num_distill_steps + 1 runs). The last
    # run uses slice index == num_distill_steps; confirm against sample().
    if trainer == 'teacher' and terrain_context.epoch >= context_space.rl.num_learning_steps:
        break
    elif trainer == 'student' and itr >= num_distill_steps:
        break