Skip to content

Commit

Permalink
Merge pull request #564 from isl-org/dev_to_master_0.16
Browse files Browse the repository at this point in the history
Dev to master 0.16
  • Loading branch information
ssheorey authored Oct 14, 2022
2 parents 5148228 + 7c692d2 commit 761df64
Show file tree
Hide file tree
Showing 26 changed files with 461 additions and 196 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ jobs:
steps:
- name: Checkout source code
uses: actions/checkout@v2
with:
submodules: true
- name: Setup cache
uses: actions/cache@v2
with:
Expand All @@ -35,7 +33,7 @@ jobs:
- name: Set up Python version
uses: actions/setup-python@v2
with:
python-version: 3.6
python-version: "3.10"
# Pre-installed 18.04 packages: https://git.io/JfHmW
- name: Install ccache
run: |
Expand Down
52 changes: 30 additions & 22 deletions ci/run_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
# The following environment variables are required:
# - NPROC
#
TENSORFLOW_VER="2.5.2"
TORCH_GLNX_VER="1.8.2+cpu"
TENSORFLOW_VER="2.8.2"
TORCH_GLNX_VER="1.12.0+cpu"
# OPENVINO_DEV_VER="2021.4.2" # Numpy version conflict with TF 2.8.2
PIP_VER="21.1.1"
WHEEL_VER="0.37.1"
STOOLS_VER="50.3.2"
YAPF_VER="0.30.0"
PYTEST_VER="6.0.1"
PYTEST_VER="7.1.2"
PYTEST_RANDOMLY_VER="3.8.0"

set -euo pipefail
Expand All @@ -16,7 +20,14 @@ echo
export PATH_TO_OPEN3D_ML=$(pwd)
# the build system of the main repo expects a master branch. make sure master exists
git checkout -b master || true
pip install -r requirements.txt
python -m pip install -U pip==$PIP_VER \
wheel=="$WHEEL_VER" \
setuptools=="$STOOLS_VER" \
yapf=="$YAPF_VER" \
pytest=="$PYTEST_VER" \
pytest-randomly=="$PYTEST_RANDOMLY_VER"

python -m pip install -r requirements.txt
echo $PATH_TO_OPEN3D_ML
cd ..
python -m pip install -U Cython
Expand All @@ -26,28 +37,25 @@ echo
git clone --recursive --branch master https://github.com/isl-org/Open3D.git

./Open3D/util/install_deps_ubuntu.sh assume-yes
python -m pip install -U tensorflow-cpu==$TENSORFLOW_VER
python -m pip install -U torch==${TORCH_GLNX_VER} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
python -m pip install -U pytest=="$PYTEST_VER" \
pytest-randomly=="$PYTEST_RANDOMLY_VER"
python -m pip install -U yapf=="$YAPF_VER"
python -m pip install -U openvino-dev==2021.4.2
python -m pip install -U tensorflow-cpu==$TENSORFLOW_VER \
torch==${TORCH_GLNX_VER} --extra-index-url https://download.pytorch.org/whl/cpu/
# openvino-dev=="$OPENVINO_DEV_VER"

echo 3. Configure for bundling the Open3D-ML part
echo
mkdir Open3D/build
pushd Open3D/build
cmake -DBUNDLE_OPEN3D_ML=ON \
-DOPEN3D_ML_ROOT=$PATH_TO_OPEN3D_ML \
-DGLIBCXX_USE_CXX11_ABI=OFF \
-DBUILD_TENSORFLOW_OPS=ON \
-DBUILD_PYTORCH_OPS=ON \
-DBUILD_GUI=OFF \
-DBUILD_RPC_INTERFACE=OFF \
-DBUILD_UNIT_TESTS=OFF \
-DBUILD_BENCHMARKS=OFF \
-DBUILD_EXAMPLES=OFF \
..
-DOPEN3D_ML_ROOT=$PATH_TO_OPEN3D_ML \
-DGLIBCXX_USE_CXX11_ABI=OFF \
-DBUILD_TENSORFLOW_OPS=ON \
-DBUILD_PYTORCH_OPS=ON \
-DBUILD_GUI=OFF \
-DBUILD_RPC_INTERFACE=OFF \
-DBUILD_UNIT_TESTS=OFF \
-DBUILD_BENCHMARKS=OFF \
-DBUILD_EXAMPLES=OFF \
..

echo 4. Build and install wheel
echo
Expand All @@ -60,12 +68,12 @@ popd
mkdir test_workdir
pushd test_workdir
mv $PATH_TO_OPEN3D_ML/tests .
echo Add --rondomly-seed=SEED to the test command to reproduce test order.
echo Add --randomly-seed=SEED to the test command to reproduce test order.
python -m pytest tests

echo ... now do the same but in dev mode by setting OPEN3D_ML_ROOT
export OPEN3D_ML_ROOT=$PATH_TO_OPEN3D_ML
echo Add --rondomly-seed=SEED to the test command to reproduce test order.
echo Add --randomly-seed=SEED to the test command to reproduce test order.
python -m pytest tests
unset OPEN3D_ML_ROOT

Expand Down
13 changes: 12 additions & 1 deletion docs/tensorboard.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ order.

Now you can visualize the data in TensorBoard as before. The web interface
allows showing and hiding points with different classes, changing their colors,
and exploring predictions and intermediate network features. Scalar network
and exploring predictions and intermediate network features. Scalar network
features can be visualized with custom user editable colormaps, and 3D features
can be visualized as RGB colors. Here is a video showing the different ways in
which semantic segmentation summary data can be visualized in TensorBoard.
Expand Down Expand Up @@ -376,3 +376,14 @@ for step in range(len(val_split)): # one pointcloud per step
step,
label_to_names=dset.get_label_to_names())
```

Troubleshooting
---------------

If you cannot interact with the 3D model, or use controls in the WebRTC widget,
make sure that Allow Autoplay is enabled for the Tensorboard web site and reload.

<img src=https://user-images.githubusercontent.com/41028320/180485249-5233b65e-11b1-44ff-bfc4-35f390ef51f2.png
title="Allow Autoplay for correct behavior."
alt="Allow Autoplay for correct behavior."
style="width:80%;display:block;margin:auto"></img>
19 changes: 10 additions & 9 deletions ml3d/configs/pointpillars_waymo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ dataset:
name: Waymo
dataset_path: # path/to/your/dataset
cache_dir: ./logs/cache
steps_per_epoch_train: 5000
steps_per_epoch_train: 4000

model:
name: PointPillars
Expand Down Expand Up @@ -31,7 +31,7 @@ model:
max_voxels: [32000, 32000]

voxel_encoder:
in_channels: 5
in_channels: 4
feat_channels: [64]
voxel_size: *vsize

Expand All @@ -43,7 +43,7 @@ model:
in_channels: 64
out_channels: [64, 128, 256]
layer_nums: [3, 5, 5]
layer_strides: [2, 2, 2]
layer_strides: [1, 2, 2]

neck:
in_channels: [64, 128, 256]
Expand All @@ -62,17 +62,18 @@ model:
[-74.88, -74.88, 0, 74.88, 74.88, 0],
]
sizes: [
[2.08, 4.73, 1.77], # car
[0.84, 1.81, 1.77], # cyclist
[0.84, 0.91, 1.74] # pedestrian
[2.08, 4.73, 1.77], # VEHICLE
[0.84, 1.81, 1.77], # CYCLIST
[0.84, 0.91, 1.74] # PEDESTRIAN
]
dir_offset: 0.7854
rotations: [0, 1.57]
iou_thr: [[0.4, 0.55], [0.3, 0.5], [0.3, 0.5]]

augment:
PointShuffle: True
ObjectRangeFilter: True
ObjectRangeFilter:
point_cloud_range: [-74.88, -74.88, -2, 74.88, 74.88, 4]
ObjectSample:
min_points_dict:
VEHICLE: 5
Expand All @@ -88,7 +89,7 @@ pipeline:
name: ObjectDetection
test_compute_metric: true
batch_size: 6
val_batch_size: 1
val_batch_size: 6
test_batch_size: 1
save_ckpt_freq: 5
max_epoch: 200
Expand All @@ -102,7 +103,7 @@ pipeline:
weight_decay: 0.01

# evaluation properties
overlaps: [0.5, 0.5, 0.7]
overlaps: [0.5, 0.5, 0.5]
difficulties: [0, 1, 2]
summary:
record_for: []
Expand Down
2 changes: 1 addition & 1 deletion ml3d/datasets/augment/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def ObjectSample(self, data, db_boxes_dict, sample_dict):
sampled_points = np.concatenate(
[box.points_inside_box for box in sampled], axis=0)
points = remove_points_in_boxes(points, sampled)
points = np.concatenate([sampled_points, points], axis=0)
points = np.concatenate([sampled_points[:, :4], points], axis=0)

return {
'point': points,
Expand Down
2 changes: 1 addition & 1 deletion ml3d/datasets/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
from scipy.spatial import ConvexHull

from ...metrics import iou_bev
from open3d.ml.contrib import iou_bev_cpu as iou_bev


def create_3D_rotations(axis, angle):
Expand Down
59 changes: 30 additions & 29 deletions ml3d/datasets/waymo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self,
name='Waymo',
cache_dir='./logs/cache',
use_cache=False,
val_split=3,
**kwargs):
"""Initialize the function by passing the dataset and other details.
Expand All @@ -34,7 +33,6 @@ def __init__(self,
name: The name of the dataset (Waymo in this case).
cache_dir: The directory where the cache is stored.
use_cache: Indicates if the dataset should be cached.
val_split: The split value to get a set of images for training, validation, for testing.
Returns:
class: The corresponding class.
Expand All @@ -43,7 +41,6 @@ def __init__(self,
name=name,
cache_dir=cache_dir,
use_cache=use_cache,
val_split=val_split,
**kwargs)

cfg = self.cfg
Expand All @@ -52,22 +49,27 @@ def __init__(self,
self.dataset_path = cfg.dataset_path
self.num_classes = 4
self.label_to_names = self.get_label_to_names()
self.shuffle = kwargs.get('shuffle', False)

self.all_files = sorted(
glob(join(cfg.dataset_path, 'velodyne', '*.bin')))
self.train_files = []
self.val_files = []
self.test_files = []

for f in self.all_files:
idx = Path(f).name.replace('.bin', '')[:3]
idx = int(idx)
if idx < cfg.val_split:
if 'train' in f:
self.train_files.append(f)
else:
elif 'val' in f:
self.val_files.append(f)

self.test_files = glob(
join(cfg.dataset_path, 'testing', 'velodyne', '*.bin'))
elif 'test' in f:
self.test_files.append(f)
else:
log.warning(
f"Skipping {f}, prefix must be one of train, test or val.")
if self.shuffle:
log.info("Shuffling training files...")
self.rng.shuffle(self.train_files)

@staticmethod
def get_label_to_names():
Expand All @@ -90,18 +92,21 @@ def read_lidar(path):
"""Reads lidar data from the path provided.
Returns:
A data object with lidar information.
pc: pointcloud data with shape [N, 6], where
the format is xyzRGB.
"""
assert Path(path).exists()

return np.fromfile(path, dtype=np.float32).reshape(-1, 6)

@staticmethod
def read_label(path, calib):
"""Reads labels of bound boxes.
"""Reads labels of bounding boxes.
Args:
path: The path to the label file.
calib: Calibration as returned by read_calib().
Returns:
The data objects with bound boxes information.
The data objects with bounding boxes information.
"""
if not Path(path).exists():
return None
Expand Down Expand Up @@ -131,24 +136,22 @@ def read_calib(path):
Returns:
The camera and the camera image used in calibration.
"""
assert Path(path).exists()

with open(path, 'r') as f:
lines = f.readlines()
obj = lines[0].strip().split(' ')[1:]
P0 = np.array(obj, dtype=np.float32)
unused_P0 = np.array(obj, dtype=np.float32)

obj = lines[1].strip().split(' ')[1:]
P1 = np.array(obj, dtype=np.float32)
unused_P1 = np.array(obj, dtype=np.float32)

obj = lines[2].strip().split(' ')[1:]
P2 = np.array(obj, dtype=np.float32)

obj = lines[3].strip().split(' ')[1:]
P3 = np.array(obj, dtype=np.float32)
unused_P3 = np.array(obj, dtype=np.float32)

obj = lines[4].strip().split(' ')[1:]
P4 = np.array(obj, dtype=np.float32)
unused_P4 = np.array(obj, dtype=np.float32)

obj = lines[5].strip().split(' ')[1:]
R0 = np.array(obj, dtype=np.float32).reshape(3, 3)
Expand All @@ -162,7 +165,7 @@ def read_calib(path):
Tr_velo_to_cam = Waymo._extend_matrix(Tr_velo_to_cam)

world_cam = np.transpose(rect_4x4 @ Tr_velo_to_cam)
cam_img = np.transpose(P2)
cam_img = np.transpose(np.vstack((P2.reshape(3, 4), [0, 0, 0, 1])))

return {'world_cam': world_cam, 'cam_img': cam_img}

Expand Down Expand Up @@ -209,7 +212,7 @@ def get_split_list(self, split):
else:
raise ValueError("Invalid split {}".format(split))

def is_tested():
def is_tested(attr):
"""Checks if a datum in the dataset has been tested.
Args:
Expand All @@ -219,16 +222,16 @@ def is_tested():
If the datum attribute is tested, then return the path where the
attribute is stored; else, returns false.
"""
pass
raise NotImplementedError()

def save_test_result():
def save_test_result(results, attr):
"""Saves the output of a model.
Args:
results: The output of a model for the datum associated with the attribute passed.
attr: The attributes that correspond to the outputs passed in results.
"""
pass
raise NotImplementedError()


class WaymoSplit():
Expand Down Expand Up @@ -273,11 +276,9 @@ def get_attr(self, idx):


class Object3d(BEVBox3D):
"""The class stores details that are object-specific, such as bounding box
coordinates, occlusion and so on.
"""

def __init__(self, center, size, label, calib):
# ground truth files doesn't have confidence value.
confidence = float(label[15]) if label.__len__() == 16 else -1.0

world_cam = calib['world_cam']
Expand Down
2 changes: 1 addition & 1 deletion ml3d/tf/pipelines/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def run_train(self):

self.save_logs(writer, epoch)

if epoch % cfg.save_ckpt_freq == 0:
if epoch % cfg.save_ckpt_freq == 0 or epoch == cfg.max_epoch:
self.save_ckpt(epoch)

def get_3d_summary(self,
Expand Down
Loading

0 comments on commit 761df64

Please sign in to comment.