forked from pierotofy/OpenSplat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
opensplat.cpp
144 lines (122 loc) · 7.58 KB
/
opensplat.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include <filesystem>
#include "vendor/json/json.hpp"
#include "opensplat.hpp"
#include "point_io.hpp"
#include "utils.hpp"
#include "cv_utils.hpp"
#include "vendor/cxxopts.hpp"
namespace fs = std::filesystem;
using namespace torch::indexing;
int main(int argc, char *argv[]){
cxxopts::Options options("opensplat", "Open Source 3D Gaussian Splats generator");
options.add_options()
("i,input", "Path to nerfstudio project", cxxopts::value<std::string>())
("o,output", "Path where to save output scene", cxxopts::value<std::string>()->default_value("splat.ply"))
("s,save-every", "Save output scene every these many steps (set to -1 to disable)", cxxopts::value<int>()->default_value("-1"))
("val", "Withhold a camera shot for validating the scene loss")
("val-image", "Filename of the image to withhold for validating scene loss", cxxopts::value<std::string>()->default_value("random"))
("n,num-iters", "Number of iterations to run", cxxopts::value<int>()->default_value("30000"))
("d,downscale-factor", "Scale input images by this factor.", cxxopts::value<float>()->default_value("1"))
("num-downscales", "Number of images downscales to use. After being scaled by [downscale-factor], images are initially scaled by a further (2^[num-downscales]) and the scale is increased every [resolution-schedule]", cxxopts::value<int>()->default_value("3"))
("resolution-schedule", "Double the image resolution every these many steps", cxxopts::value<int>()->default_value("250"))
("sh-degree", "Maximum spherical harmonics degree (must be > 0)", cxxopts::value<int>()->default_value("3"))
("sh-degree-interval", "Increase the number of spherical harmonics degree after these many steps (will not exceed [sh-degree])", cxxopts::value<int>()->default_value("1000"))
("ssim-weight", "Weight to apply to the structural similarity loss. Set to zero to use least absolute deviation (L1) loss only", cxxopts::value<float>()->default_value("0.2"))
("refine-every", "Split/duplicate/prune gaussians every these many steps", cxxopts::value<int>()->default_value("100"))
("warmup-length", "Split/duplicate/prune gaussians only after these many steps", cxxopts::value<int>()->default_value("500"))
("reset-alpha-every", "Reset the opacity values of gaussians after these many refinements (not steps)", cxxopts::value<int>()->default_value("30"))
("stop-split-at", "Stop splitting/duplicating gaussians after these many steps", cxxopts::value<int>()->default_value("15000"))
("densify-grad-thresh", "Threshold of the positional gradient norm (magnitude of the loss function) which when exceeded leads to a gaussian split/duplication", cxxopts::value<float>()->default_value("0.0002"))
("densify-size-thresh", "Gaussians' scales below this threshold are duplicated, otherwise split", cxxopts::value<float>()->default_value("0.01"))
("stop-screen-size-at", "Stop splitting gaussians that are larger than [split-screen-size] after these many steps", cxxopts::value<int>()->default_value("4000"))
("split-screen-size", "Split gaussians that are larger than this percentage of screen space", cxxopts::value<float>()->default_value("0.05"))
("h,help", "Print usage")
;
options.parse_positional({ "input" });
options.positional_help("[nerfstudio project path]");
cxxopts::ParseResult result;
try {
result = options.parse(argc, argv);
}
catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
std::cerr << options.help() << std::endl;
return EXIT_FAILURE;
}
if (result.count("help") || !result.count("input")) {
std::cout << options.help() << std::endl;
return EXIT_SUCCESS;
}
const std::string projectRoot = result["input"].as<std::string>();
const std::string outputScene = result["output"].as<std::string>();
const int saveEvery = result["save-every"].as<int>();
const bool validate = result.count("val") > 0;
const std::string valImage = result["val-image"].as<std::string>();
const float downScaleFactor = (std::max)(result["downscale-factor"].as<float>(), 1.0f);
const int numIters = result["num-iters"].as<int>();
const int numDownscales = result["num-downscales"].as<int>();
const int resolutionSchedule = result["resolution-schedule"].as<int>();
const int shDegree = result["sh-degree"].as<int>();
const int shDegreeInterval = result["sh-degree-interval"].as<int>();
const float ssimWeight = result["ssim-weight"].as<float>();
const int refineEvery = result["refine-every"].as<int>();
const int warmupLength = result["warmup-length"].as<int>();
const int resetAlphaEvery = result["reset-alpha-every"].as<int>();
const int stopSplitAt = result["stop-split-at"].as<int>();
const float densifyGradThresh = result["densify-grad-thresh"].as<float>();
const float densifySizeThresh = result["densify-size-thresh"].as<float>();
const int stopScreenSizeAt = result["stop-screen-size-at"].as<int>();
const float splitScreenSize = result["split-screen-size"].as<float>();
torch::Device device = torch::kCPU;
if (torch::cuda::is_available()) {
std::cout << "Using CUDA" << std::endl;
device = torch::kCUDA;
}else{
std::cout << "Using CPU" << std::endl;
}
try{
ns::InputData inputData = ns::inputDataFromNerfStudio(projectRoot);
for (ns::Camera &cam : inputData.cameras){
cam.loadImage(downScaleFactor);
}
// Withhold a validation camera if necessary
auto t = inputData.getCameras(validate, valImage);
std::vector<ns::Camera> cams = std::get<0>(t);
ns::Camera *valCam = std::get<1>(t);
ns::Model model(inputData.points,
cams.size(),
numDownscales, resolutionSchedule, shDegree, shDegreeInterval,
refineEvery, warmupLength, resetAlphaEvery, stopSplitAt, densifyGradThresh, densifySizeThresh, stopScreenSizeAt, splitScreenSize,
numIters,
device);
InfiniteRandomIterator<ns::Camera> camsIter(cams);
int imageSize = -1;
for (size_t step = 1; step <= numIters; step++){
ns::Camera cam = camsIter.next();
model.optimizersZeroGrad();
torch::Tensor rgb = model.forward(cam, step);
torch::Tensor gt = cam.getImage(model.getDownscaleFactor(step));
gt = gt.to(device);
torch::Tensor mainLoss = model.mainLoss(rgb, gt, ssimWeight);
mainLoss.backward();
if (step % 10 == 0) std::cout << "Step " << step << ": " << mainLoss.item<float>() << std::endl;
model.optimizersStep();
model.schedulersStep(step);
model.afterTrain(step);
if (saveEvery > 0 && step % saveEvery == 0){
fs::path p(outputScene);
model.savePlySplat((p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string()));
}
}
model.savePlySplat(outputScene);
// Validate
if (valCam != nullptr){
torch::Tensor rgb = model.forward(*valCam, numIters);
torch::Tensor gt = valCam->getImage(model.getDownscaleFactor(numIters)).to(device);
std::cout << valCam->filePath << " validation loss: " << model.mainLoss(rgb, gt, ssimWeight).item<float>() << std::endl;
}
}catch(const std::exception &e){
std::cerr << e.what() << std::endl;
exit(1);
}
}