scan-context/src2/python/make_sc_example.py

173 lines
4.8 KiB
Python
Raw Normal View History

2021-10-19 00:52:28 +00:00
import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from open3d import *
class kitti_vlp_database:
    """Index of the scan files found in a directory.

    Attributes:
        bin_dir: The directory path exactly as passed in.
        bin_files: Names of all entries in ``bin_dir``, sorted ascending.
        num_bins: Number of entries in ``bin_files``.
    """

    def __init__(self, bin_dir):
        """List ``bin_dir`` and remember its entries in sorted order."""
        entries = sorted(os.listdir(bin_dir))
        self.bin_dir = bin_dir
        self.bin_files = entries
        self.num_bins = len(entries)
class ScanContext:
    """Scan Context descriptors for one LiDAR scan stored as a binary file.

    Constructing an instance loads the scan, voxel-downsamples it and builds
    one descriptor per configured resolution (see ``sector_res`` and
    ``ring_res``). Each descriptor is a ``(num_ring, num_sector)`` array in
    which every polar cell holds the largest offset-corrected z value of the
    points that fell into it; cells with no point above zero hold 0.

    Attributes:
        bin_dir: Directory containing the scan file.
        bin_file_name: File name of the scan inside ``bin_dir``.
        bin_path: ``bin_dir`` joined with ``bin_file_name``.
        SCs: List of descriptors, one per configured resolution.
        scancontexts: The same list as ``SCs``.
    """

    # Class-wide configuration shared by every instance.
    viz = 0  # truthy -> open a viewer on the downsampled cloud
    downcell_size = 0.5  # voxel size handed to the downsampler
    kitti_lidar_height = 2.0  # offset added to every point's z value
    # Alternative multi-resolution settings:
    # sector_res = np.array([45, 90, 180, 360, 720])
    # ring_res = np.array([10, 20, 40, 80, 160])
    sector_res = np.array([60])  # number of angular sectors per resolution
    ring_res = np.array([20])  # number of radial rings per resolution
    max_length = 80  # radial extent covered by the rings

    def __init__(self, bin_dir, bin_file_name):
        """Load the scan at ``bin_dir``/``bin_file_name`` and build its descriptors.

        Args:
            bin_dir: Directory containing the scan file.
            bin_file_name: Name of the scan file inside ``bin_dir``.
        """
        self.bin_dir = bin_dir
        self.bin_file_name = bin_file_name
        # os.path.join yields the same path as plain concatenation when
        # bin_dir ends with a separator, and also works when it does not.
        self.bin_path = os.path.join(bin_dir, bin_file_name)
        self.scancontexts = self.genSCs()

    def load_velo_scan(self):
        """Read the scan file and return its points as an ``(N, 3)`` array.

        The file is read as a flat float32 buffer and viewed as rows of four
        values; the fourth column is dropped (presumably the reflectance
        channel of the KITTI Velodyne format -- confirm against the data).
        """
        scan = np.fromfile(self.bin_path, dtype=np.float32)
        scan = scan.reshape((-1, 4))
        return scan[:, :3]

    def xy2theta(self, x, y):
        """Return the angle of ``(x, y)`` in degrees, within ``[0, 360]``.

        The angle is measured counter-clockwise from the +x axis. A point
        lying a vanishing distance below the +x axis can round to exactly
        360.0; callers that bin the angle must therefore clamp the bin index.
        """
        # arctan2 resolves the quadrant itself and is defined for x == 0,
        # so no per-quadrant branching or division is needed.
        theta = np.degrees(np.arctan2(y, x))
        if theta < 0:
            theta += 360.0
        return theta

    def pt2rs(self, point, gap_ring, gap_sector, num_ring, num_sector):
        """Map a point to its ``(ring, sector)`` cell indices.

        Args:
            point: Sequence whose first two entries are the x and y coordinates.
            gap_ring: Radial width of one ring.
            gap_sector: Angular width of one sector, in degrees.
            num_ring: Number of rings; farther points go to the last ring.
            num_sector: Number of sectors.

        Returns:
            Tuple ``(idx_ring, idx_sector)`` of Python ints, each a valid
            index for an array of shape ``(num_ring, num_sector)``.
        """
        x = point[0]
        y = point[1]

        # Exact zeros are nudged off the axes, as the original code did, so
        # such points keep landing in the same cell as before.
        if x == 0.0:
            x = 0.001
        if y == 0.0:
            y = 0.001

        theta = self.xy2theta(x, y)
        faraway = np.sqrt(x * x + y * y)

        idx_ring = int(faraway // gap_ring)
        idx_sector = int(theta // gap_sector)

        # Indices run from 0 to N-1. Points beyond the radial extent collapse
        # into the outermost ring, and an angle that rounded up to 360.0
        # belongs to the last sector rather than one past the end.
        idx_ring = min(idx_ring, int(num_ring) - 1)
        idx_sector = min(idx_sector, int(num_sector) - 1)

        return idx_ring, idx_sector

    def ptcloud2sc(self, ptcloud, num_sector, num_ring, max_length):
        """Build one descriptor from a point cloud.

        Args:
            ptcloud: Array of shape ``(N, 3)`` with x, y, z per row.
            num_sector: Number of angular sectors (descriptor columns).
            num_ring: Number of radial rings (descriptor rows).
            max_length: Radial extent that the rings divide evenly.

        Returns:
            Array of shape ``(num_ring, num_sector)``. Each cell holds the
            maximum of ``z + kitti_lidar_height`` over the points in that
            cell, or 0 when no point in the cell exceeds zero.
        """
        gap_ring = max_length / num_ring
        gap_sector = 360 / num_sector

        # A running per-cell maximum replaces buffering up to 1000 heights per
        # cell: the zero initial value reproduces the "empty slot" zeros of
        # that buffer, and no point is discarded when a cell is crowded.
        sc = np.zeros([num_ring, num_sector])
        for point in ptcloud:
            idx_ring, idx_sector = self.pt2rs(
                point, gap_ring, gap_sector, num_ring, num_sector
            )
            point_height = point[2] + ScanContext.kitti_lidar_height
            if point_height > sc[idx_ring, idx_sector]:
                sc[idx_ring, idx_sector] = point_height

        return sc

    def genSCs(self):
        """Load, downsample and describe the scan at every configured resolution.

        Prints the point-array shapes before and after downsampling, shows
        the downsampled cloud when ``ScanContext.viz`` is truthy, stores the
        descriptors in ``self.SCs`` and returns that list.
        """
        ptcloud_xyz = self.load_velo_scan()
        print("The number of original points: " + str(ptcloud_xyz.shape))

        # NOTE(review): PointCloud, Vector3dVector, voxel_down_sample and
        # draw_geometries are expected to come from `from open3d import *`;
        # confirm the installed Open3D release still exports them unqualified.
        pcd = PointCloud()
        pcd.points = Vector3dVector(ptcloud_xyz)
        downpcd = voxel_down_sample(pcd, voxel_size=ScanContext.downcell_size)
        ptcloud_xyz_downed = np.asarray(downpcd.points)
        print("The number of downsampled points: " + str(ptcloud_xyz_downed.shape))

        if ScanContext.viz:
            draw_geometries([downpcd])

        self.SCs = [
            self.ptcloud2sc(
                ptcloud_xyz_downed, num_sector, num_ring, ScanContext.max_length
            )
            for num_sector, num_ring in zip(
                ScanContext.sector_res, ScanContext.ring_res
            )
        ]
        return self.SCs

    def plot_multiple_sc(self, fig_idx=1):
        """Show every descriptor in ``self.SCs`` as an image, one per row.

        Args:
            fig_idx: Accepted for backward compatibility; currently unused.
        """
        num_res = len(ScanContext.sector_res)

        # squeeze=False keeps `axes` an array even for a single row; without
        # it a one-resolution configuration yields a bare Axes object and
        # indexing it fails.
        _, axes = plt.subplots(nrows=num_res, squeeze=False)
        axes = axes.ravel()

        axes[0].set_title('Scan Contexts with multiple resolutions', fontsize=14)
        for ax, sc in zip(axes, self.SCs):
            ax.imshow(sc)

        plt.show()
if __name__ == "__main__":
    # Build the descriptors for every scan file in the data directory and
    # print, per scan, how many descriptors were produced and the shape of
    # the first one.
    bin_dir = './data/'
    bin_db = kitti_vlp_database(bin_dir)

    for bin_file_name in bin_db.bin_files:
        sc = ScanContext(bin_dir, bin_file_name)

        # To inspect the descriptors visually, call sc.plot_multiple_sc().
        print(len(sc.SCs))
        print(sc.SCs[0].shape)