diff --git a/.gitignore b/.gitignore index d9929e19..7fecbd5d 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ data *.out eval *.npz -**/tmp \ No newline at end of file +**/tmp +eval_dtu \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 4d2d1242..058b5395 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,9 @@ -[submodule "submodules/diff-surfel-rasterization"] - path = submodules/diff-surfel-rasterization - url = https://github.com/hbb1/diff-surfel-rasterization.git [submodule "submodules/simple-knn"] path = submodules/simple-knn url = https://gitlab.inria.fr/bkerbl/simple-knn.git +[submodule "submodules/diff-surfel-rasterization"] + path = submodules/diff-surfel-rasterization + url = https://github.com/hugoycj/diff-surfel-rasterization-MCMC +[submodule "submodules/fused-ssim"] + path = submodules/fused-ssim + url = https://github.com/rahul-goel/fused-ssim.git diff --git a/arguments/__init__.py b/arguments/__init__.py index 2977a888..ed77c3ba 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -50,11 +50,14 @@ def __init__(self, parser, sentinel=False): self._source_path = "" self._model_path = "" self._images = "images" - self._resolution = -1 + self._resolution = 1 self._white_background = False self.data_device = "cuda" self.eval = False self.render_items = ['RGB', 'Alpha', 'Normal', 'Depth', 'Edge', 'Curvature'] + self.w_normal_prior = "" + self.w_mask = "" + self.use_decoupled_appearance = False super().__init__(parser, "Loading Parameters", sentinel) def extract(self, args): @@ -81,17 +84,39 @@ def __init__(self, parser): self.opacity_lr = 0.05 self.scaling_lr = 0.005 self.rotation_lr = 0.001 + self.appearance_embeddings_lr = 0.001 + self.appearance_network_lr = 0.001 self.percent_dense = 0.01 self.lambda_dssim = 0.2 - self.lambda_dist = 0.0 + self.lambda_dist = 0. + self.lambda_depth = 0.1 self.lambda_normal = 0.05 + self.lambda_mask = 0. + self.lambda_normal_prior = 0.05 + self.lambda_normal_gradient = 0.01 self.opacity_cull = 0.05 + self.split_interval = 500 + self.max_screen_size = 20 + self.densification_interval = 100 self.opacity_reset_interval = 3000 self.densify_from_iter = 500 self.densify_until_iter = 15_000 self.densify_grad_threshold = 0.0002 + + self.propagation_interval = 20 + self.depth_error_min_threshold = 0.8 + self.depth_error_max_threshold = 1.0 + self.propagation_begin = 9000 + self.propagation_after = 15000 + self.patch_size = 11 + + self.pixel_dense_from_iter = 30000 + + self.contribution_prune_from_iter = 500 + self.contribution_prune_interval = 300 + self.contribution_prune_ratio = 0.1 super().__init__(parser, "Optimization Parameters") def get_combined_args(parser : ArgumentParser): diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index 3b217802..7d15df21 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -16,15 +16,29 @@ from utils.sh_utils import eval_sh from utils.point_utils import depth_to_normal -def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): +def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, + override_color = None, record_transmittance=False, bg_gaussians=None, skip_geometric=False): """ Render the scene. Background tensor (bg_color) must be on GPU! 
""" - + if bg_gaussians is None: + means3D = pc.get_xyz + opacity = pc.get_opacity + scales = pc.get_scaling + rotations = pc.get_rotation + shs = pc.get_features + else: + means3D = torch.cat([pc.get_xyz, bg_gaussians.get_xyz]) + opacity = torch.cat([pc.get_opacity, bg_gaussians.get_opacity]) + scales = torch.cat([pc.get_scaling, bg_gaussians.get_scaling]) + rotations = torch.cat([pc.get_rotation, bg_gaussians.get_rotation]) + shs = torch.cat([pc.get_features, bg_gaussians.get_features]) + num_fg_points = pc.get_xyz.shape[0] + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means - screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + screenspace_points = torch.zeros((means3D.shape[0], 4), dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 try: screenspace_points.retain_grad() except: @@ -46,74 +60,43 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, sh_degree=pc.active_sh_degree, campos=viewpoint_camera.camera_center, prefiltered=False, + record_transmittance=record_transmittance, debug=False, # pipe.debug ) rasterizer = GaussianRasterizer(raster_settings=raster_settings) - means3D = pc.get_xyz means2D = screenspace_points - opacity = pc.get_opacity - # If precomputed 3d covariance is provided, use it. If not, then it will be computed from - # scaling / rotation by the rasterizer. - scales = None - rotations = None - cov3D_precomp = None - if pipe.compute_cov3D_python: - # currently don't support normal consistency loss if use precomputed covariance - splat2world = pc.get_covariance(scaling_modifier) - W, H = viewpoint_camera.image_width, viewpoint_camera.image_height - near, far = viewpoint_camera.znear, viewpoint_camera.zfar - ndc2pix = torch.tensor([ - [W / 2, 0, 0, (W-1) / 2], - [0, H / 2, 0, (H-1) / 2], - [0, 0, far-near, near], - [0, 0, 0, 1]]).float().cuda().T - world2pix = viewpoint_camera.full_proj_transform @ ndc2pix - cov3D_precomp = (splat2world[:, [0,1,3]] @ world2pix[:,[0,1,3]]).permute(0,2,1).reshape(-1, 9) # column major - else: - scales = pc.get_scaling - rotations = pc.get_rotation - - # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors - # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
- pipe.convert_SHs_python = False - shs = None - colors_precomp = None - if override_color is None: - if pipe.convert_SHs_python: - shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) - dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) - dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) - sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) - colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) - else: - shs = pc.get_features - else: - colors_precomp = override_color - - rendered_image, radii, allmap = rasterizer( + output = rasterizer( means3D = means3D, means2D = means2D, shs = shs, - colors_precomp = colors_precomp, + colors_precomp = None, opacities = opacity, scales = scales, rotations = rotations, - cov3D_precomp = cov3D_precomp - ) - + cov3D_precomp = None) + + if record_transmittance: + rendered_image, radii, allmap, transmittance_avg, num_covered_pixels = output + transmittance_avg = transmittance_avg[:num_fg_points] + num_covered_pixels = num_covered_pixels[:num_fg_points] + else: + rendered_image, radii, allmap = output + transmittance_avg = num_covered_pixels = None + radii = radii[:num_fg_points] # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. rets = {"render": rendered_image, "viewspace_points": means2D, "visibility_filter" : radii > 0, "radii": radii, + "pixels_num":num_covered_pixels, + "transmittance_avg": transmittance_avg } - # additional regularizations render_alpha = allmap[1:2] @@ -140,18 +123,24 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, # for unbounded scene, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing. surf_depth = render_depth_expected * (1-pipe.depth_ratio) + (pipe.depth_ratio) * render_depth_median - # assume the depth points form the 'surface' and generate pseudo surface normal for regularizations. - surf_normal = depth_to_normal(viewpoint_camera, surf_depth) - surf_normal = surf_normal.permute(2,0,1) - # remember to multiply with accum_alpha since render_normal is unnormalized. - surf_normal = surf_normal * (render_alpha).detach() - + if skip_geometric: + # assume the depth points form the 'surface' and generate pseudo surface normal for regularizations. + surf_normal_expected = depth_to_normal(viewpoint_camera, render_depth_expected).permute(2,0,1) + surf_normal = depth_to_normal(viewpoint_camera, render_depth_median).permute(2,0,1) + # remember to multiply with accum_alpha since render_normal is unnormalized.
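depth_to_normal is imported from utils.point_utils and its body is not part of this diff; the snippet below is only a rough sketch of what such a helper typically does (back-project the depth map through the intrinsics, then take the cross product of finite differences), written for a plain pinhole camera with a 3x3 intrinsics matrix K. The function name and conventions here are assumptions; the repository's implementation works from the viewpoint_camera object and may differ in world-space handling and masking.

import torch
import torch.nn.functional as F

def depth_to_normal_sketch(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Pseudo surface normals from an (H, W) depth map via finite differences.

    Assumes a pinhole camera with 3x3 intrinsics K; returns an (H, W, 3)
    normal map in camera space (zeros on the one-pixel border).
    """
    H, W = depth.shape
    y, x = torch.meshgrid(torch.arange(H, device=depth.device),
                          torch.arange(W, device=depth.device), indexing="ij")
    pix = torch.stack([x, y, torch.ones_like(x)], dim=-1).float()   # (H, W, 3) homogeneous pixels
    points = (pix @ K.inverse().T) * depth.unsqueeze(-1)            # back-projected 3D points
    dx = points[2:, 1:-1] - points[:-2, 1:-1]                       # finite difference along rows
    dy = points[1:-1, 2:] - points[1:-1, :-2]                       # finite difference along columns
    normal = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
    out = torch.zeros_like(points)
    out[1:-1, 1:-1] = normal
    return out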
+ surf_normal_expected = surf_normal_expected * (render_alpha).detach() + surf_normal = surf_normal * (render_alpha).detach() + else: + surf_normal_expected = render_normal + surf_normal = render_normal rets.update({ 'rend_alpha': render_alpha, 'rend_normal': render_normal, + 'rend_depth': render_depth_expected, 'rend_dist': render_dist, 'surf_depth': surf_depth, + 'surf_normal_expected': surf_normal_expected, 'surf_normal': surf_normal, }) diff --git a/scene/__init__.py b/scene/__init__.py index cbd196d0..3aff1c55 100644 --- a/scene/__init__.py +++ b/scene/__init__.py @@ -14,7 +14,8 @@ import json from utils.system_utils import searchForMaxIteration from scene.dataset_readers import sceneLoadTypeCallbacks -from scene.gaussian_model import GaussianModel +from scene.gaussian_model import GaussianModel, BgGaussianModel +from scene.appearance_model import AppearanceModel from arguments import ModelParams from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON @@ -22,13 +23,14 @@ class Scene: gaussians : GaussianModel - def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): + def __init__(self, args : ModelParams, gaussians : GaussianModel, bg_gaussians: BgGaussianModel = None, load_iteration=None, shuffle=True, resolution_scales=[1.0]): """b :param path: Path to colmap scene main folder. """ self.model_path = args.model_path self.loaded_iter = None self.gaussians = gaussians + self.bg_gaussians = bg_gaussians if load_iteration: if load_iteration == -1: @@ -79,12 +81,21 @@ def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply")) + if self.bg_gaussians is not None: + self.bg_gaussians.load_ply(os.path.join(self.model_path, + "point_cloud", + "iteration_" + str(self.loaded_iter), + "bg_point_cloud.ply")) else: self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) + if self.bg_gaussians is not None: + self.bg_gaussians.load_ply('assets/background_gs.ply') def save(self, iteration): point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) + if self.bg_gaussians is not None: + self.bg_gaussians.save_ply(os.path.join(point_cloud_path, "bg_point_cloud.ply")) def getTrainCameras(self, scale=1.0): return self.train_cameras[scale] diff --git a/scene/appearance_model.py b/scene/appearance_model.py new file mode 100644 index 00000000..f1bf8dcd --- /dev/null +++ b/scene/appearance_model.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class UpsampleBlock(nn.Module): + def __init__(self, num_input_channels, num_output_channels): + super(UpsampleBlock, self).__init__() + self.pixel_shuffle = nn.PixelShuffle(2) + self.conv = nn.Conv2d(num_input_channels // (2 * 2), num_output_channels, 3, stride=1, padding=1) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.pixel_shuffle(x) + x = self.conv(x) + x = self.relu(x) + return x + +class AppearanceNetwork(nn.Module): + def __init__(self, num_input_channels, num_output_channels): + super(AppearanceNetwork, self).__init__() + + self.conv1 = nn.Conv2d(num_input_channels, 256, 3, stride=1, padding=1) + self.up1 = UpsampleBlock(256, 128) + self.up2 = UpsampleBlock(128, 64) + self.up3 = UpsampleBlock(64, 32) + self.up4 = UpsampleBlock(32, 16) + + self.conv2 = nn.Conv2d(16, 16, 3, stride=1, padding=1) + 
self.conv3 = nn.Conv2d(16, num_output_channels, 3, stride=1, padding=1) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.up1(x) + x = self.up2(x) + x = self.up3(x) + x = self.up4(x) + # bilinear interpolation + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + x = self.conv2(x) + x = self.relu(x) + x = self.conv3(x) + x = self.sigmoid(x) + return x + +class AppearanceModel: + def __init__(self, num_embeddings, num_input_channels=67, num_output_channels=3): + self.appearance_network = AppearanceNetwork(num_input_channels, num_output_channels).cuda() + + std = 1e-4 + self._appearance_embeddings = nn.Parameter(torch.empty(num_embeddings, 64).cuda()) + self._appearance_embeddings.data.normal_(0, std) + + def get_embedding(self, idx): + return self._appearance_embeddings[idx] + + def training_setup(self, training_args): + params = [ + {'params': [self._appearance_embeddings], 'lr': training_args.appearance_embeddings_lr, "name": "appearance_embeddings"}, + {'params': self.appearance_network.parameters(), 'lr': training_args.appearance_network_lr, "name": "appearance_network"} + ] + self.optimizer = torch.optim.Adam(params, lr=0.0, eps=1e-15) + + def load_state_dict(self, state_dict): + self._appearance_embeddings = state_dict["_appearance_embeddings"] + self.appearance_network.load_state_dict(state_dict["appearance_network"]) + + def state_dict(self): + return { + "_appearance_embeddings": self._appearance_embeddings, + "appearance_network": self.appearance_network.state_dict() + } + + + +if __name__ == "__main__": + H, W = 1200//32, 1600//32 + input_channels = 3 + 64 + output_channels = 3 + input = torch.randn(1, input_channels, H, W).cuda() + model = AppearanceNetwork(input_channels, output_channels).cuda() + + output = model(input) + print(output.shape) \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py index 4a5d84a8..48ef2d81 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -3,7 +3,7 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. 
# # For inquiries contact george.drettakis@inria.fr @@ -12,11 +12,11 @@ import torch from torch import nn import numpy as np -from utils.graphics_utils import getWorld2View2, getProjectionMatrix +from utils.graphics_utils import getWorld2View2, getProjectionMatrixShift, generate_K class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, - image_name, uid, + image_name, uid, principal_point_ndc, normal=None, trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" ): super(Camera, self).__init__() @@ -28,6 +28,8 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.FoVx = FoVx self.FoVy = FoVy self.image_name = image_name + self.depth_prior = None + self.depth_mask = None try: self.data_device = torch.device(data_device) @@ -37,16 +39,26 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.data_device = torch.device("cuda") self.original_image = image.clamp(0.0, 1.0).to(self.data_device) + if normal is not None: + self.normal_prior = normal.to(self.data_device) + normal_norm = torch.norm(self.normal_prior, dim=0, keepdim=True) + self.normal_mask = ~((normal_norm > 1.1) | (normal_norm < 0.9)) + self.normal_prior = self.normal_prior / normal_norm + else: + self.normal_prior = None + self.normal_mask = None self.image_width = self.original_image.shape[2] self.image_height = self.original_image.shape[1] if gt_alpha_mask is not None: - # self.original_image *= gt_alpha_mask.to(self.data_device) + self.original_image *= gt_alpha_mask.to(self.data_device) self.gt_alpha_mask = gt_alpha_mask.to(self.data_device) else: self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) - self.gt_alpha_mask = None - + self.gt_alpha_mask = torch.ones((1, self.image_height, self.image_width), device=self.data_device) + + self.K = generate_K(fovX=self.FoVx, fovY=self.FoVy, width=self.image_width, height=self.image_height, principal_point_ndc=principal_point_ndc).to(self.data_device).to(torch.float32) + self.zfar = 100.0 self.znear = 0.01 @@ -54,14 +66,14 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.scale = scale self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() - self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() + self.projection_matrix = getProjectionMatrixShift(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy, width=self.image_width, height=self.image_height, principal_point_ndc=principal_point_ndc).transpose(0,1).cuda() self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) self.camera_center = self.world_view_transform.inverse()[3, :3] class MiniCam: def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): self.image_width = width - self.image_height = height + self.image_height = height self.FoVy = fovy self.FoVx = fovx self.znear = znear diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py index 2a6f904a..ffeed45f 100644 --- a/scene/dataset_readers.py +++ b/scene/dataset_readers.py @@ -3,7 +3,7 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. 
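The Camera class above now builds a pixel-space intrinsics matrix with generate_K and a principal-point-aware projection with getProjectionMatrixShift, both defined in utils/graphics_utils.py and not shown in this diff. As a rough sketch of the convention the call sites imply (focal lengths recovered from the fields of view, principal point given as a fraction of the image size, so (0.5, 0.5) is the image centre), a generate_K-style helper could look like the following; treat it as an assumption, not the repository's definition.

import math
import torch

def generate_K_sketch(fovX: float, fovY: float, width: int, height: int,
                      principal_point_ndc=(0.5, 0.5)) -> torch.Tensor:
    """Assumed behaviour of a generate_K-style helper: 3x3 pinhole intrinsics.

    fov2focal(fov, pixels) = pixels / (2 * tan(fov / 2)); the principal point
    is stored as a fraction of the image size, matching the cx/width, cy/height
    values computed in readColmapCameras.
    """
    fx = width / (2.0 * math.tan(fovX / 2.0))
    fy = height / (2.0 * math.tan(fovY / 2.0))
    cx = principal_point_ndc[0] * width
    cy = principal_point_ndc[1] * height
    return torch.tensor([[fx, 0.0, cx],
                         [0.0, fy, cy],
                         [0.0, 0.0, 1.0]])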
# # For inquiries contact george.drettakis@inria.fr @@ -18,6 +18,7 @@ from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal import numpy as np import json +import open3d as o3d from pathlib import Path from plyfile import PlyData, PlyElement from utils.sh_utils import SH2RGB @@ -34,6 +35,7 @@ class CameraInfo(NamedTuple): image_name: str width: int height: int + principal_point_ndc: np.array = np.array([1/2, 1/2]) class SceneInfo(NamedTuple): point_cloud: BasicPointCloud @@ -84,22 +86,29 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): if intr.model=="SIMPLE_PINHOLE": focal_length_x = intr.params[0] + cx = intr.params[1] + cy = intr.params[2] FovY = focal2fov(focal_length_x, height) FovX = focal2fov(focal_length_x, width) - elif intr.model=="PINHOLE": + elif intr.model=="PINHOLE" or intr.model=='OPENCV': focal_length_x = intr.params[0] focal_length_y = intr.params[1] + cx = intr.params[2] + cy = intr.params[3] FovY = focal2fov(focal_length_y, height) FovX = focal2fov(focal_length_x, width) else: assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" + principal_point_ndc = np.array([cx / width, cy / height]) + image_path = os.path.join(images_folder, os.path.basename(extr.name)) image_name = os.path.basename(image_path).split(".")[0] image = Image.open(image_path) cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, - image_path=image_path, image_name=image_name, width=width, height=height) + image_path=image_path, image_name=image_name, width=width, height=height, + principal_point_ndc=principal_point_ndc) cam_infos.append(cam_info) sys.stdout.write('\n') return cam_infos @@ -109,7 +118,12 @@ def fetchPly(path): vertices = plydata['vertex'] positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 - normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + try: + normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + except: + normals = np.zeros_like(positions) + if np.all(normals == 0): + normals = None return BasicPointCloud(points=positions, colors=colors, normals=normals) def storePly(path, xyz, rgb): @@ -117,7 +131,7 @@ def storePly(path, xyz, rgb): dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] - + normals = np.zeros_like(xyz) elements = np.empty(xyz.shape[0], dtype=dtype) @@ -164,10 +178,8 @@ def readColmapSceneInfo(path, images, eval, llffhold=8): except: xyz, rgb, _ = read_points3D_text(txt_path) storePly(ply_path, xyz, rgb) - try: - pcd = fetchPly(ply_path) - except: - pcd = None + + pcd = fetchPly(ply_path) scene_info = SceneInfo(point_cloud=pcd, train_cameras=train_cam_infos, @@ -210,12 +222,12 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension= image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) - FovY = fovy + FovY = fovy FovX = fovx cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) - + return cam_infos def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): @@ -223,7 +235,7 @@ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): train_cam_infos = 
readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) print("Reading Test Transforms") test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) - + if not eval: train_cam_infos.extend(test_cam_infos) test_cam_infos = [] @@ -235,7 +247,7 @@ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): # Since this data set has no colmap data, we start with random points num_pts = 100_000 print(f"Generating random point cloud ({num_pts})...") - + # We create random points inside the bounds of the synthetic Blender scenes xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 shs = np.random.random((num_pts, 3)) / 255.0 diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index 9d05b2cf..051b2805 100644 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -21,6 +21,37 @@ from utils.graphics_utils import BasicPointCloud from utils.general_utils import strip_symmetric, build_scaling_rotation +# adopted from https://github.com/turandai/gaussian_surfels/blob/main/utils/general_utils.py +def normal2rotation(n): + # construct a random rotation matrix from normal + # it would better be positive definite and orthogonal + n = torch.nn.functional.normalize(n) + w0 = torch.tensor([[1, 0, 0]]).expand(n.shape).to(n.device) + R0 = w0 - torch.sum(w0 * n, -1, True) * n + R0 *= torch.sign(R0[:, :1]) + R0 = torch.nn.functional.normalize(R0) + R1 = torch.cross(n, R0) + + R1 *= torch.sign(R1[:, 1:2]) * torch.sign(n[:, 2:]) + R = torch.stack([R0, R1, n], -1) + q = rotmat2quaternion(R) + + return q + +def rotmat2quaternion(R, normalize=False): + tr = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] + 1e-6 + r = torch.sqrt(1 + tr) / 2 + # print(torch.sum(torch.isnan(r))) + q = torch.stack([ + r, + (R[:, 2, 1] - R[:, 1, 2]) / (4 * r), + (R[:, 0, 2] - R[:, 2, 0]) / (4 * r), + (R[:, 1, 0] - R[:, 0, 1]) / (4 * r) + ], -1) + if normalize: + q = torch.nn.functional.normalize(q, dim=-1) + return q + class GaussianModel: def setup_functions(self): @@ -133,7 +164,12 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 2) - rots = torch.rand((fused_point_cloud.shape[0], 4), device="cuda") + + # calculate normal + if pcd.normals is None: + rots = torch.rand((fused_point_cloud.shape[0], 4), device="cuda") + else: + rots = normal2rotation(torch.from_numpy(np.asarray(pcd.normals)).float().cuda()) opacities = self.inverse_opacity_activation(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) @@ -238,6 +274,7 @@ def load_ply(self, path): scales = np.zeros((xyz.shape[0], len(scale_names))) for idx, attr_name in enumerate(scale_names): scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + scales = scales[:, :2] rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) @@ -333,6 +370,7 @@ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new "scaling" : new_scaling, "rotation" : new_rotation} + self.max_radii2D = torch.cat([self.max_radii2D, torch.zeros((new_xyz.shape[0]), device="cuda")], dim=0) optimizable_tensors = self.cat_tensors_to_optimizer(d) self._xyz = optimizable_tensors["xyz"] self._features_dc = optimizable_tensors["f_dc"] @@ -343,7 +381,6 @@ def densification_postfix(self, new_xyz, 
new_features_dc, new_features_rest, new self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") - self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): n_init_points = self.get_xyz.shape[0] @@ -386,7 +423,7 @@ def densify_and_clone(self, grads, grad_threshold, scene_extent): self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) - def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): + def densify_and_prune(self, max_grad, min_opacity, extent, prune_big_points=False): grads = self.xyz_gradient_accum / self.denom grads[grads.isnan()] = 0.0 @@ -394,14 +431,160 @@ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): self.densify_and_split(grads, max_grad, extent) prune_mask = (self.get_opacity < min_opacity).squeeze() - if max_screen_size: - big_points_vs = self.max_radii2D > max_screen_size + if prune_big_points: big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent - prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + prune_mask = torch.logical_or(prune_mask, big_points_ws) + self.prune_points(prune_mask) - torch.cuda.empty_cache() + + def split_big_points(self, max_screen_size): + big_points_mask = self.max_radii2D > max_screen_size + big_point_indices = torch.where(big_points_mask)[0] + + if big_point_indices.numel() == 0: + return # No points to split + + # Calculate split numbers based on max_radii2D + split_numbers = torch.ceil(self.max_radii2D[big_points_mask] / max_screen_size).long() + total_new_points = split_numbers.sum().item() + print(f"Generating {total_new_points} new points") + + # Create prune filter + prune_filter = torch.zeros(self.get_xyz.shape[0] + total_new_points, device="cuda", dtype=bool) + prune_filter[:self.get_xyz.shape[0]] = big_points_mask + + index_list = torch.arange(split_numbers.size(0), device="cuda").repeat_interleave(split_numbers) + + stds = self.get_scaling[big_point_indices[index_list]] + stds = torch.cat([stds, torch.zeros_like(stds[:, :1])], dim=-1) + means = torch.zeros_like(stds) + samples = torch.normal(mean=means, std=stds) + + rots = build_rotation(self._rotation[big_point_indices[index_list]]) + + new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[big_point_indices[index_list]] + new_scaling = self.scaling_inverse_activation(self.get_scaling[big_point_indices[index_list]] / (0.8 * split_numbers[index_list].unsqueeze(1))) + new_rotation = self._rotation[big_point_indices[index_list]] + new_features_dc = self._features_dc[big_point_indices[index_list]] + new_features_rest = self._features_rest[big_point_indices[index_list]] + new_opacity = self._opacity[big_point_indices[index_list]] + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) + + self.prune_points(prune_filter) + + # Reset max_radii2D for the new points + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def add_densification_stats(self, viewspace_point_tensor, update_filter, pixels): + if pixels is not None: + self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[:len(update_filter)][update_filter], dim=-1, keepdim=True) * pixels[update_filter].unsqueeze(-1) + self.denom[update_filter] +=
pixels[update_filter].unsqueeze(-1) + else: + self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[:len(update_filter)][update_filter], dim=-1, keepdim=True) + self.denom[update_filter] += 1 + + + def densify_from_depth_propagation(self, viewpoint_cam, propagated_depth, propagated_normal, filter_mask): + # inverse project pixels into 3D scenes + K = viewpoint_cam.K + cam2world = viewpoint_cam.world_view_transform.transpose(0, 1).inverse() + + # Get the shape of the depth image + height, width = propagated_depth.shape + # Create a grid of 2D pixel coordinates + y, x = torch.meshgrid(torch.arange(0, height), torch.arange(0, width)) + # Stack the 2D and depth coordinates to create 3D homogeneous coordinates + coordinates = torch.stack([x.to(propagated_depth.device), y.to(propagated_depth.device), torch.ones_like(propagated_depth)], dim=-1) + # Reshape the coordinates to (height * width, 3) + coordinates = coordinates.view(-1, 3).to(K.device).to(torch.float32) + # Reproject the 2D coordinates to 3D coordinates + coordinates_3D = (K.inverse() @ coordinates.T).T + + # Multiply by depth + coordinates_3D *= propagated_depth.view(-1, 1) + + # convert to the world coordinate + world_coordinates_3D = (cam2world[:3, :3] @ coordinates_3D.T).T + cam2world[:3, 3] + + #mask the points below the confidence threshold + #downsample the pixels; 1/4 + world_coordinates_3D = world_coordinates_3D.view(height, width, 3) + world_coordinates_3D_downsampled = world_coordinates_3D[::4, ::4] + filter_mask_downsampled = filter_mask[::4, ::4] + gt_image = viewpoint_cam.original_image.cuda() + gt_image_downsampled = gt_image.permute(1, 2, 0)[::4, ::4] + + world_coordinates_3D_downsampled = world_coordinates_3D_downsampled[filter_mask_downsampled] + color_downsampled = gt_image_downsampled[filter_mask_downsampled] + + # Compute world scale + fx, fy = K[0, 0] / 4, K[1, 1] / 4 # Assuming K is the camera intrinsic matrix + world_scale = propagated_depth[::4, ::4][filter_mask_downsampled] / ((fx + fy) / 2) + world_normal = propagated_normal[::4, ::4][filter_mask_downsampled] + + # initialize gaussians + fused_point_cloud = world_coordinates_3D_downsampled + fused_color = RGB2SH(color_downsampled) + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).to(fused_color.device) + features[:, :3, 0 ] = fused_color + features[:, 3:, 1:] = 0.0 + + scales = torch.log(world_scale)[..., None].repeat(1, 2) + rots = normal2rotation(world_normal).to(scales.device) + opacities = inverse_sigmoid(0.5 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + + new_xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + new_features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) + new_features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) + new_scaling = nn.Parameter(scales.requires_grad_(True)) + new_rotation = nn.Parameter(rots.requires_grad_(True)) + new_opacity = nn.Parameter(opacities.requires_grad_(True)) + + #update gaussians + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) + +class BgGaussianModel(GaussianModel): + def __init__(self, sh_degree: int): + self.active_sh_degree = 3 + self.max_sh_degree = sh_degree + self._features_dc = torch.empty(0) + self._features_rest = torch.empty(0) + self._scaling = torch.empty(0) + self._rotation = torch.empty(0) + self._opacity = torch.empty(0) + self.optimizer = None + 
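In densify_from_depth_propagation above, each newly spawned Gaussian gets an initial scale of depth / ((fx + fy) / 2), i.e. roughly the world-space footprint of one pixel at that depth. A quick numerical sanity check of that heuristic, with made-up focal length and depth values:

# Pixel-footprint heuristic used above: a pixel with focal length f (in pixels)
# seen at depth z covers roughly z / f world units.
fx, fy = 1000.0, 1000.0   # hypothetical focal lengths (pixels)
depth = 2.0               # hypothetical propagated depth (scene units)
world_scale = depth / ((fx + fy) / 2)
print(world_scale)        # 0.002 scene units -> the new surfel starts about one pixel wide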
self.setup_functions() + + def capture(self): + return ( + self.active_sh_degree, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.optimizer.state_dict(), + ) + + def restore(self, model_args): + ( + self.active_sh_degree, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + opt_dict, + ) = model_args + self.optimizer.load_state_dict(opt_dict) + + def training_setup(self, training_args): + l = [ + {'params': [self._xyz], 'lr': training_args.position_lr_init, "name": "xyz"}, + {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, + {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + ] - def add_densification_stats(self, viewspace_point_tensor, update_filter): - self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter], dim=-1, keepdim=True) - self.denom[update_filter] += 1 \ No newline at end of file + self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) \ No newline at end of file diff --git a/scripts/init_bg_gs.py b/scripts/init_bg_gs.py new file mode 100644 index 00000000..7911f789 --- /dev/null +++ b/scripts/init_bg_gs.py @@ -0,0 +1,7 @@ +from gaustudio.pipelines import initializers, optimizers +from gaustudio import models + +bg_gaussians_coarse = models.make("vanilla_pcd") +bg_initializer = initializers.make({"name": "gaussiansky", "radius": 100, "resolution": 500}) +bg_initializer(bg_gaussians_coarse) +bg_gaussians_coarse.export("assets/background_gs.ply") \ No newline at end of file diff --git a/scripts/init_normal.py b/scripts/init_normal.py new file mode 100644 index 00000000..46630143 --- /dev/null +++ b/scripts/init_normal.py @@ -0,0 +1,36 @@ +import click +import torch +import glob +from PIL import Image +import os +from tqdm import tqdm +import numpy as np +from transformers import AutoModelForImageSegmentation +import torchvision.transforms as transforms + +@click.command() +@click.option('--source_path', '-s', required=True, help='Path to the dataset') +def main(source_path: str) -> None: + # Load StableNormal model + normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal", + trust_repo=True, yoso_version='yoso-normal-v1-5', + local_cache_dir='/workspace/code/InverseRendering/StableNormal_git/weights') + + output_normal_dir = os.path.join(source_path, "normals") + os.makedirs(output_normal_dir, exist_ok=True) + + for image_path in tqdm(glob.glob(f"{source_path}/images/*.jpg", recursive=True)[::5] + \ + glob.glob(f"{source_path}/images/*.jpeg", recursive=True)[::5] + \ + glob.glob(f"{source_path}/images/*.png", recursive=True)[::5]): + image_name = os.path.basename(image_path.split("/")[-1]).split(".")[0] + + input_image = Image.open(image_path) + output_normal_path = os.path.join(output_normal_dir, image_name+'.png') + + # Generate normal map if it doesn't exist + if not os.path.exists(output_normal_path): + normal_image = normal_predictor(input_image, data_type='outdoor') + normal_image.save(output_normal_path) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/nerf_eval.py b/scripts/nerf_eval.py index a95c9920..9f0a546d 100644 --- a/scripts/nerf_eval.py +++ b/scripts/nerf_eval.py @@ -23,7 +23,7 @@ jobs = list(itertools.product(scenes, factors)) def train_scene(gpu, scene, factor): - cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python train.py -s {dataset_dir}/{scene} -m {output_dir}/{scene} --eval --white_background 
--lambda_normal 0.0 --port {6209+int(gpu)}" + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python train.py -s {dataset_dir}/{scene} -m {output_dir}/{scene} --eval --white_background --lambda_mask 0.1 --lambda_normal 0.0 --port {6209+int(gpu)}" print(cmd) if not dry_run: os.system(cmd) diff --git a/scripts/train_fast.sh b/scripts/train_fast.sh new file mode 100644 index 00000000..dd463507 --- /dev/null +++ b/scripts/train_fast.sh @@ -0,0 +1,5 @@ +data=$1 +postix=gaustudio_fast +gs-init -s $data -o ${data}_${postix} +rm -r ${data}_${postix}/result_2 +python train_fast.py -s ${data}_${postix} -r 1 -m ${data}_${postix}/result_2 --iteration 7000 \ No newline at end of file diff --git a/scripts/train_progressive.sh b/scripts/train_progressive.sh new file mode 100644 index 00000000..690fd029 --- /dev/null +++ b/scripts/train_progressive.sh @@ -0,0 +1,5 @@ +data=$1 +postix=gaustudio_progressive +gs-init -s $data -o ${data}_${postix} +rm -r ${data}_${postix}/result_2 +python train_progressive.py -s ${data}_${postix} -r 1 -m ${data}_${postix}/result_2 \ No newline at end of file diff --git a/scripts/train_with_background.sh b/scripts/train_with_background.sh new file mode 100644 index 00000000..e73c85e1 --- /dev/null +++ b/scripts/train_with_background.sh @@ -0,0 +1,6 @@ +data=$1 + +python scripts/init_bg_gs.py +gs-init -s $data -o ${data}_gaustudio --overwrite --pcd combined +python train_with_bg.py -s ${data}_gaustudio -r 2 --lambda_dist 100 -m ${data}_gaustudio/result +gs-extract-pcd -m ${data}_gaustudio/result -o ${data}_gaustudio/result/fusion --meshing nksr \ No newline at end of file diff --git a/scripts/train_with_mask.sh b/scripts/train_with_mask.sh new file mode 100644 index 00000000..1d1bee05 --- /dev/null +++ b/scripts/train_with_mask.sh @@ -0,0 +1,9 @@ +data=$1 +postix=gaustudio_wmask +# gs-init -s $data -o ${data}_${postix} --w_mask mask --overwrite +rm -r ${data}_${postix}/result_2 +python train_progressive.py -s ${data}_${postix} -r 2 --lambda_dist 1000 \ + --w_mask masks --lambda_mask 0.1 --max_screen_size 5 \ + -m ${data}_${postix}/result_2 --iteration 20000 +gs-extract-pcd -m ${data}_${postix}/result_2 -o ${data}_${postix}/result_2/fusion_2 \ + --meshing sap --config 2dgs \ No newline at end of file diff --git a/scripts/train_wmask_wnormal.sh b/scripts/train_wmask_wnormal.sh new file mode 100644 index 00000000..48054cd1 --- /dev/null +++ b/scripts/train_wmask_wnormal.sh @@ -0,0 +1,10 @@ +data=$1 +postix=gaustudio_wmask_wnormal +gs-init -s $data -o ${data}_${postix} --w_mask mask +python scripts/init_normal.py -s ${data}_${postix} +rm -r ${data}_${postix}/result_2 +python train_progressive.py -s ${data}_${postix} -r 1 --lambda_dist 1000 \ + --w_mask masks --lambda_mask 0.1 --w_normal_prior normals \ + --max_screen_size 20 -m ${data}_${postix}/result_2 --iteration 7000 +gs-extract-pcd -m ${data}_${postix}/result_2 -o ${data}_${postix}/result_2/fusion_2 \ + --meshing sap --config 2dgs diff --git a/scripts/train_wmask_wnormal_fast.sh b/scripts/train_wmask_wnormal_fast.sh new file mode 100644 index 00000000..b28ce8dc --- /dev/null +++ b/scripts/train_wmask_wnormal_fast.sh @@ -0,0 +1,11 @@ +data=$1 +postix=gaustudio_wmask_wnormal_fast +# gs-init -s $data -o ${data}_${postix} --w_mask mask +python scripts/init_normal.py -s ${data}_${postix} +rm -r ${data}_${postix}/result_2 +python train_fast.py -s ${data}_${postix} -r 1 --lambda_dist 1000 --lambda_dssim 1 \ + --w_mask masks --lambda_mask 0.1 --w_normal_prior normals \ + --max_screen_size 20 -m ${data}_${postix}/result_2 --iteration
7000 + +gs-extract-pcd -m ${data}_${postix}/result_2 -o ${data}_${postix}/result_2/fusion_2 \ + --meshing sap --config 2dgs \ No newline at end of file diff --git a/scripts/train_wnormal.sh b/scripts/train_wnormal.sh new file mode 100644 index 00000000..84f492bf --- /dev/null +++ b/scripts/train_wnormal.sh @@ -0,0 +1,9 @@ +data=$1 +postix=gaustudio_wnormal +# gs-init -s $data -o ${data}_${postix} --pcd combined +# python scripts/init_normal.py -s ${data}_${postix} +rm -r ${data}_${postix}/result_2 +python train.py -s ${data}_${postix} -r 1 --contribution_prune_ratio 0.5 \ + --lambda_normal_prior 1 --lambda_dist 10 \ + --densify_until_iter 3000 --iteration 7000 \ + -m ${data}_${postix}/result_2 --w_normal_prior normals diff --git a/submodules/Propagation/PatchMatch.cpp b/submodules/Propagation/PatchMatch.cpp new file mode 100644 index 00000000..e9e2f582 --- /dev/null +++ b/submodules/Propagation/PatchMatch.cpp @@ -0,0 +1,463 @@ +#include "PatchMatch.h" +#include +#include + +#include + +void StringAppendV(std::string* dst, const char* format, va_list ap) { + // First try with a small fixed size buffer. + static const int kFixedBufferSize = 1024; + char fixed_buffer[kFixedBufferSize]; + + // It is possible for methods that use a va_list to invalidate + // the data in it upon use. The fix is to make a copy + // of the structure before using it and use that copy instead. + va_list backup_ap; + va_copy(backup_ap, ap); + int result = vsnprintf(fixed_buffer, kFixedBufferSize, format, backup_ap); + va_end(backup_ap); + + if (result < kFixedBufferSize) { + if (result >= 0) { + // Normal case - everything fits. + dst->append(fixed_buffer, result); + return; + } + +#ifdef _MSC_VER + // Error or MSVC running out of space. MSVC 8.0 and higher + // can be asked about space needed with the special idiom below: + va_copy(backup_ap, ap); + result = vsnprintf(nullptr, 0, format, backup_ap); + va_end(backup_ap); +#endif + + if (result < 0) { + // Just an error. + return; + } + } + + // Increase the buffer size to the size requested by vsnprintf, + // plus one for the closing \0. + const int variable_buffer_size = result + 1; + std::unique_ptr variable_buffer(new char[variable_buffer_size]); + + // Restore the va_list before we use it again. + va_copy(backup_ap, ap); + result = + vsnprintf(variable_buffer.get(), variable_buffer_size, format, backup_ap); + va_end(backup_ap); + + if (result >= 0 && result < variable_buffer_size) { + dst->append(variable_buffer.get(), result); + } +} + +std::string StringPrintf(const char* format, ...) { + va_list ap; + va_start(ap, format); + std::string result; + StringAppendV(&result, format, ap); + va_end(ap); + return result; +} + +void CudaSafeCall(const cudaError_t error, const std::string& file, + const int line) { + if (error != cudaSuccess) { + std::cerr << StringPrintf("%s in %s at line %i", cudaGetErrorString(error), + file.c_str(), line) + << std::endl; + exit(EXIT_FAILURE); + } +} + +void CudaCheckError(const char* file, const int line) { + cudaError error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cerr << StringPrintf("cudaCheckError() failed at %s:%i : %s", file, + line, cudaGetErrorString(error)) + << std::endl; + exit(EXIT_FAILURE); + } + + // More careful checking. However, this will affect performance. + // Comment away if needed. 
+ error = cudaDeviceSynchronize(); + if (cudaSuccess != error) { + std::cerr << StringPrintf("cudaCheckError() with sync failed at %s:%i : %s", + file, line, cudaGetErrorString(error)) + << std::endl; + std::cerr + << "This error is likely caused by the graphics card timeout " + "detection mechanism of your operating system. Please refer to " + "the FAQ in the documentation on how to solve this problem." + << std::endl; + exit(EXIT_FAILURE); + } +} + +PatchMatch::PatchMatch() {} + +PatchMatch::~PatchMatch() +{ + delete[] plane_hypotheses_host; + delete[] costs_host; + + for (int i = 0; i < num_images; ++i) { + cudaDestroyTextureObject(texture_objects_host.images[i]); + cudaFreeArray(cuArray[i]); + } + cudaFree(texture_objects_cuda); + cudaFree(cameras_cuda); + cudaFree(plane_hypotheses_cuda); + cudaFree(costs_cuda); + cudaFree(rand_states_cuda); + cudaFree(selected_views_cuda); + cudaFree(depths_cuda); + + if (params.geom_consistency) { + for (int i = 0; i < num_images; ++i) { + cudaDestroyTextureObject(texture_depths_host.images[i]); + cudaFreeArray(cuDepthArray[i]); + } + cudaFree(texture_depths_cuda); + } +} + +Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor depth_interval) +{ + Camera camera; + + for (int i = 0; i < 3; ++i) { + camera.R[3 * i + 0] = pose[i][0].item(); + camera.R[3 * i + 1] = pose[i][1].item(); + camera.R[3 * i + 2] = pose[i][2].item(); + camera.t[i] = pose[i][3].item(); + } + + for (int i = 0; i < 3; ++i) { + camera.K[3 * i + 0] = intrinsic[i][0].item(); + camera.K[3 * i + 1] = intrinsic[i][1].item(); + camera.K[3 * i + 2] = intrinsic[i][2].item(); + } + + camera.depth_min = depth_interval[0].item(); + camera.depth_max = depth_interval[3].item(); + + return camera; +} + +void RescaleImageAndCamera(torch::Tensor &src, torch::Tensor &dst, torch::Tensor &depth, Camera &camera) +{ + const int cols = depth.size(1); + const int rows = depth.size(0); + + if (cols == src.size(1) && rows == src.size(0)) { + dst = src.clone(); + return; + } + + const float scale_x = cols / static_cast(src.size(1)); + const float scale_y = rows / static_cast(src.size(0)); + dst = torch::nn::functional::interpolate(src.unsqueeze(0), torch::nn::functional::InterpolateFuncOptions().size(std::vector({rows, cols})).mode(torch::kBilinear)).squeeze(0); + + camera.K[0] *= scale_x; + camera.K[2] *= scale_x; + camera.K[4] *= scale_y; + camera.K[5] *= scale_y; + camera.width = cols; + camera.height = rows; +} + +float3 Get3DPointonWorld(const int x, const int y, const float depth, const Camera camera) +{ + float3 pointX; + float3 tmpX; + // Reprojection + pointX.x = depth * (x - camera.K[2]) / camera.K[0]; + pointX.y = depth * (y - camera.K[5]) / camera.K[4]; + pointX.z = depth; + + // Rotation + tmpX.x = camera.R[0] * pointX.x + camera.R[3] * pointX.y + camera.R[6] * pointX.z; + tmpX.y = camera.R[1] * pointX.x + camera.R[4] * pointX.y + camera.R[7] * pointX.z; + tmpX.z = camera.R[2] * pointX.x + camera.R[5] * pointX.y + camera.R[8] * pointX.z; + + // Transformation + float3 C; + C.x = -(camera.R[0] * camera.t[0] + camera.R[3] * camera.t[1] + camera.R[6] * camera.t[2]); + C.y = -(camera.R[1] * camera.t[0] + camera.R[4] * camera.t[1] + camera.R[7] * camera.t[2]); + C.z = -(camera.R[2] * camera.t[0] + camera.R[5] * camera.t[1] + camera.R[8] * camera.t[2]); + pointX.x = tmpX.x + C.x; + pointX.y = tmpX.y + C.y; + pointX.z = tmpX.z + C.z; + + return pointX; +} + +void ProjectonCamera(const float3 PointX, const Camera camera, float2 &point, float &depth) +{ + float3 tmp; + tmp.x = 
camera.R[0] * PointX.x + camera.R[1] * PointX.y + camera.R[2] * PointX.z + camera.t[0]; + tmp.y = camera.R[3] * PointX.x + camera.R[4] * PointX.y + camera.R[5] * PointX.z + camera.t[1]; + tmp.z = camera.R[6] * PointX.x + camera.R[7] * PointX.y + camera.R[8] * PointX.z + camera.t[2]; + + depth = camera.K[6] * tmp.x + camera.K[7] * tmp.y + camera.K[8] * tmp.z; + point.x = (camera.K[0] * tmp.x + camera.K[1] * tmp.y + camera.K[2] * tmp.z) / depth; + point.y = (camera.K[3] * tmp.x + camera.K[4] * tmp.y + camera.K[5] * tmp.z) / depth; +} + +float GetAngle(const torch::Tensor &v1, const torch::Tensor &v2) +{ + float dot_product = v1[0].item() * v2[0].item() + v1[1].item() * v2[1].item() + v1[2].item() * v2[2].item(); + float angle = acosf(dot_product); + //if angle is not a number the dot product was 1 and thus the two vectors should be identical --> return 0 + if ( angle != angle ) + return 0.0f; + + return angle; +} + +void StoreColorPlyFileBinaryPointCloud (const std::string &plyFilePath, const std::vector &pc) +{ + std::cout << "store 3D points to ply file" << std::endl; + + FILE *outputPly; + outputPly=fopen(plyFilePath.c_str(), "wb"); + + /*write header*/ + fprintf(outputPly, "ply\n"); + fprintf(outputPly, "format binary_little_endian 1.0\n"); + fprintf(outputPly, "element vertex %d\n",pc.size()); + fprintf(outputPly, "property float x\n"); + fprintf(outputPly, "property float y\n"); + fprintf(outputPly, "property float z\n"); + fprintf(outputPly, "property float nx\n"); + fprintf(outputPly, "property float ny\n"); + fprintf(outputPly, "property float nz\n"); + fprintf(outputPly, "property uchar red\n"); + fprintf(outputPly, "property uchar green\n"); + fprintf(outputPly, "property uchar blue\n"); + fprintf(outputPly, "end_header\n"); + + //write data +#pragma omp parallel for + for(size_t i = 0; i < pc.size(); i++) { + const PointList &p = pc[i]; + float3 X = p.coord; + const float3 normal = p.normal; + const float3 color = p.color; + const char b_color = (int)color.x; + const char g_color = (int)color.y; + const char r_color = (int)color.z; + + if(!(X.x < FLT_MAX && X.x > -FLT_MAX) || !(X.y < FLT_MAX && X.y > -FLT_MAX) || !(X.z < FLT_MAX && X.z >= -FLT_MAX)){ + X.x = 0.0f; + X.y = 0.0f; + X.z = 0.0f; + } +#pragma omp critical + { + fwrite(&X.x, sizeof(X.x), 1, outputPly); + fwrite(&X.y, sizeof(X.y), 1, outputPly); + fwrite(&X.z, sizeof(X.z), 1, outputPly); + fwrite(&normal.x, sizeof(normal.x), 1, outputPly); + fwrite(&normal.y, sizeof(normal.y), 1, outputPly); + fwrite(&normal.z, sizeof(normal.z), 1, outputPly); + fwrite(&r_color, sizeof(char), 1, outputPly); + fwrite(&g_color, sizeof(char), 1, outputPly); + fwrite(&b_color, sizeof(char), 1, outputPly); + } + + } + fclose(outputPly); +} + +static float GetDisparity(const Camera &camera, const int2 &p, const float &depth) +{ + float point3D[3]; + point3D[0] = depth * (p.x - camera.K[2]) / camera.K[0]; + point3D[1] = depth * (p.y - camera.K[5]) / camera.K[4]; + point3D[2] = depth; + + return std::sqrt(point3D[0] * point3D[0] + point3D[1] * point3D[1] + point3D[2] * point3D[2]); +} + +void PatchMatch::SetGeomConsistencyParams() +{ + params.geom_consistency = true; + params.max_iterations = 2; +} + +void PatchMatch::InuputInitialization(torch::Tensor images_cuda, torch::Tensor intrinsics_cuda, torch::Tensor poses_cuda, + torch::Tensor depth_cuda, torch::Tensor normal_cuda, torch::Tensor depth_intervals) +{ + images.clear(); + cameras.clear(); + + torch::Tensor image_color = images_cuda[0]; + torch::Tensor image_float = 
torch::mean(image_color, /*dim=*/2, /*keepdim=*/true).squeeze(); + image_float = image_float.to(torch::kFloat32); + images.push_back(image_float); + + Camera camera = ReadCamera(intrinsics_cuda[0], poses_cuda[0], depth_intervals[0]); + camera.height = image_float.size(0); + camera.width = image_float.size(1); + cameras.push_back(camera); + + torch::Tensor ref_depth = depth_cuda; + depths.push_back(ref_depth); + + torch::Tensor ref_normal = normal_cuda; + normals.push_back(ref_normal); + + int num_src_images = images_cuda.size(0); + for (int i = 1; i < num_src_images; ++i) { + torch::Tensor src_image_color = images_cuda[i]; + torch::Tensor src_image_float = torch::mean(src_image_color, /*dim=*/2, /*keepdim=*/true).squeeze(); + src_image_float = src_image_float.to(torch::kFloat32); + images.push_back(src_image_float); + + Camera camera = ReadCamera(intrinsics_cuda[i], poses_cuda[i], depth_intervals[i]); + camera.height = src_image_float.size(0); + camera.width = src_image_float.size(1); + cameras.push_back(camera); + } + + // Scale cameras and images + for (size_t i = 0; i < images.size(); ++i) { + if (images[i].size(1) <= params.max_image_size && images[i].size(0) <= params.max_image_size) { + continue; + } + + const float factor_x = static_cast(params.max_image_size) / images[i].size(1); + const float factor_y = static_cast(params.max_image_size) / images[i].size(0); + const float factor = std::min(factor_x, factor_y); + + const int new_cols = std::round(images[i].size(1) * factor); + const int new_rows = std::round(images[i].size(0) * factor); + + const float scale_x = new_cols / static_cast(images[i].size(1)); + const float scale_y = new_rows / static_cast(images[i].size(0)); + + torch::Tensor scaled_image_float = torch::nn::functional::interpolate(images[i].unsqueeze(0), torch::nn::functional::InterpolateFuncOptions().size(std::vector({new_rows, new_cols})).mode(torch::kBilinear)).squeeze(0); + images[i] = scaled_image_float.clone(); + + cameras[i].K[0] *= scale_x; + cameras[i].K[2] *= scale_x; + cameras[i].K[4] *= scale_y; + cameras[i].K[5] *= scale_y; + cameras[i].height = scaled_image_float.size(0); + cameras[i].width = scaled_image_float.size(1); + } + + params.depth_min = cameras[0].depth_min * 0.6f; + params.depth_max = cameras[0].depth_max * 1.2f; + params.num_images = (int)images.size(); + params.disparity_min = cameras[0].K[0] * params.baseline / params.depth_max; + params.disparity_max = cameras[0].K[0] * params.baseline / params.depth_min; + +} + +void PatchMatch::CudaSpaceInitialization() +{ + num_images = (int)images.size(); + + for (int i = 0; i < num_images; ++i) { + int rows = images[i].size(0); + int cols = images[i].size(1); + + cudaChannelFormatDesc channelDesc = cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); + cudaMallocArray(&cuArray[i], &channelDesc, cols, rows); + + cudaMemcpy2DToArray(cuArray[i], 0, 0, images[i].data_ptr(), images[i].stride(0) * sizeof(float), cols * sizeof(float), rows, cudaMemcpyHostToDevice); + + struct cudaResourceDesc resDesc; + memset(&resDesc, 0, sizeof(cudaResourceDesc)); + resDesc.resType = cudaResourceTypeArray; + resDesc.res.array.array = cuArray[i]; + + struct cudaTextureDesc texDesc; + memset(&texDesc, 0, sizeof(cudaTextureDesc)); + texDesc.addressMode[0] = cudaAddressModeWrap; + texDesc.addressMode[1] = cudaAddressModeWrap; + texDesc.filterMode = cudaFilterModeLinear; + texDesc.readMode = cudaReadModeElementType; + texDesc.normalizedCoords = 0; + + cudaCreateTextureObject(&(texture_objects_host.images[i]), &resDesc, 
&texDesc, NULL); + } + + cudaMalloc((void**)&texture_objects_cuda, sizeof(cudaTextureObjects)); + cudaMemcpy(texture_objects_cuda, &texture_objects_host, sizeof(cudaTextureObjects), cudaMemcpyHostToDevice); + + cudaMalloc((void**)&cameras_cuda, sizeof(Camera) * (num_images)); + cudaMemcpy(cameras_cuda, &cameras[0], sizeof(Camera) * (num_images), cudaMemcpyHostToDevice); + + int total_pixels = cameras[0].height * cameras[0].width; + // Concatenate normals and depths into a single tensor + torch::Tensor plane_hypotheses_tensor = torch::cat({ + normals[0].view({total_pixels, 3}), + depths[0].view({total_pixels, 1}) + }, 1); + + // TODO: Fix initialization bug + // auto plane_hypotheses_float4 = plane_hypotheses_tensor.to(torch::kFloat32).view({-1, 4}); + // plane_hypotheses_host = reinterpret_cast(plane_hypotheses_float4.data_ptr()); + plane_hypotheses_host = new float4[total_pixels]; + cudaMalloc((void**)&plane_hypotheses_cuda, sizeof(float4) * total_pixels); + cudaMemcpy(plane_hypotheses_cuda, plane_hypotheses_host, sizeof(float4) * total_pixels, cudaMemcpyHostToDevice); + + costs_host = new float[cameras[0].height * cameras[0].width]; + cudaMalloc((void**)&costs_cuda, sizeof(float) * (cameras[0].height * cameras[0].width)); + + cudaMalloc((void**)&rand_states_cuda, sizeof(curandState) * (cameras[0].height * cameras[0].width)); + cudaMalloc((void**)&selected_views_cuda, sizeof(unsigned int) * (cameras[0].height * cameras[0].width)); + + cudaMalloc((void**)&depths_cuda, sizeof(float) * (cameras[0].height * cameras[0].width)); + cudaMemcpy(depths_cuda, depths[0].data_ptr(), sizeof(float) * cameras[0].height * cameras[0].width, cudaMemcpyHostToDevice); +} + +int PatchMatch::GetReferenceImageWidth() +{ + return cameras[0].width; +} + +int PatchMatch::GetReferenceImageHeight() +{ + return cameras[0].height; +} + +torch::Tensor PatchMatch::GetReferenceImage() +{ + return images[0]; +} + +float4 PatchMatch::GetPlaneHypothesis(const int index) +{ + return plane_hypotheses_host[index]; +} + +float4* PatchMatch::GetPlaneHypotheses() +{ + return plane_hypotheses_host; +} + +float PatchMatch::GetCost(const int index) +{ + return costs_host[index]; +} + +void PatchMatch::SetPatchSize(int patch_size) +{ + params.patch_size = patch_size; +} + +int PatchMatch::GetPatchSize() +{ + return params.patch_size; +} + + diff --git a/submodules/Propagation/PatchMatch.h b/submodules/Propagation/PatchMatch.h new file mode 100644 index 00000000..f10fb363 --- /dev/null +++ b/submodules/Propagation/PatchMatch.h @@ -0,0 +1,84 @@ +#ifndef _PatchMatch_H_ +#define _PatchMatch_H_ + +#include "main.h" +#include + +Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor depth_interval); +void RescaleImageAndCamera(torch::Tensor &src, torch::Tensor &dst, torch::Tensor &depth, Camera &camera); +float3 Get3DPointonWorld(const int x, const int y, const float depth, const Camera camera); +void ProjectonCamera(const float3 PointX, const Camera camera, float2 &point, float &depth); +float GetAngle(const torch::Tensor &v1, const torch::Tensor &v2); +void StoreColorPlyFileBinaryPointCloud(const std::string &plyFilePath, const std::vector &pc); + +#define CUDA_SAFE_CALL(error) CudaSafeCall(error, __FILE__, __LINE__) +#define CUDA_CHECK_ERROR() CudaCheckError(__FILE__, __LINE__) + +void CudaSafeCall(const cudaError_t error, const std::string& file, const int line); +void CudaCheckError(const char* file, const int line); + +struct cudaTextureObjects { + cudaTextureObject_t images[MAX_IMAGES]; +}; + +struct 
PatchMatchParams { + int max_iterations = 6; + int patch_size = 11; + int num_images = 5; + int max_image_size=3200; + int radius_increment = 2; + float sigma_spatial = 5.0f; + float sigma_color = 3.0f; + int top_k = 4; + float baseline = 0.54f; + float depth_min = 0.0f; + float depth_max = 1.0f; + float disparity_min = 0.0f; + float disparity_max = 1.0f; + bool geom_consistency = false; +}; + +class PatchMatch { +public: + PatchMatch(); + ~PatchMatch(); + + void InuputInitialization(torch::Tensor images_cuda, torch::Tensor intrinsics_cuda, torch::Tensor poses_cuda, torch::Tensor depth_cuda, torch::Tensor normal_cuda, torch::Tensor depth_intervals); + void Colmap2MVS(const std::string &dense_folder, std::vector &problems); + void CudaSpaceInitialization(); + void RunPatchMatch(); + void SetGeomConsistencyParams(); + void SetPatchSize(int patch_size); + int GetPatchSize(); + int GetReferenceImageWidth(); + int GetReferenceImageHeight(); + torch::Tensor GetReferenceImage(); + float4 GetPlaneHypothesis(const int index); + float GetCost(const int index); + float4* GetPlaneHypotheses(); + +private: + int num_images; + std::vector images; + std::vector depths; + std::vector normals; + std::vector cameras; + cudaTextureObjects texture_objects_host; + cudaTextureObjects texture_depths_host; + float4 *plane_hypotheses_host; + float *costs_host; + PatchMatchParams params; + + Camera *cameras_cuda; + cudaArray *cuArray[MAX_IMAGES]; + cudaArray *cuDepthArray[MAX_IMAGES]; + cudaTextureObjects *texture_objects_cuda; + cudaTextureObjects *texture_depths_cuda; + float4 *plane_hypotheses_cuda; + float *costs_cuda; + curandState *rand_states_cuda; + unsigned int *selected_views_cuda; + float *depths_cuda; +}; + +#endif // _PatchMatch_H_ diff --git a/submodules/Propagation/Propagation.cu b/submodules/Propagation/Propagation.cu new file mode 100644 index 00000000..f8573ad4 --- /dev/null +++ b/submodules/Propagation/Propagation.cu @@ -0,0 +1,1170 @@ +#include "PatchMatch.h" +#include +#include + +__device__ void sort_small(float *d, const int n) +{ + int j; + for (int i = 1; i < n; i++) { + float tmp = d[i]; + for (j = i; j >= 1 && tmp < d[j-1]; j--) + d[j] = d[j-1]; + d[j] = tmp; + } +} + +__device__ void sort_small_weighted(float *d, float *w, int n) +{ + int j; + for (int i = 1; i < n; i++) { + float tmp = d[i]; + float tmp_w = w[i]; + for (j = i; j >= 1 && tmp < d[j - 1]; j--) { + d[j] = d[j - 1]; + w[j] = w[j - 1]; + } + d[j] = tmp; + w[j] = tmp_w; + } +} + +__device__ int FindMinCostIndex(const float *costs, const int n) +{ + float min_cost = costs[0]; + int min_cost_idx = 0; + for (int idx = 1; idx < n; ++idx) { + if (costs[idx] <= min_cost) { + min_cost = costs[idx]; + min_cost_idx = idx; + } + } + return min_cost_idx; +} + +__device__ void setBit(unsigned int &input, const unsigned int n) +{ + input |= (unsigned int)(1 << n); +} + +__device__ int isSet(unsigned int input, const unsigned int n) +{ + return (input >> n) & 1; +} + +__device__ void Mat33DotVec3(const float mat[9], const float4 vec, float4 *result) +{ + result->x = mat[0] * vec.x + mat[1] * vec.y + mat[2] * vec.z; + result->y = mat[3] * vec.x + mat[4] * vec.y + mat[5] * vec.z; + result->z = mat[6] * vec.x + mat[7] * vec.y + mat[8] * vec.z; +} + +__device__ float Vec3DotVec3(const float4 vec1, const float4 vec2) +{ + return vec1.x * vec2.x + vec1.y * vec2.y + vec1.z * vec2.z; +} + +__device__ void NormalizeVec3 (float4 *vec) +{ + const float normSquared = vec->x * vec->x + vec->y * vec->y + vec->z * vec->z; + const float inverse_sqrt = 
rsqrtf (normSquared); + vec->x *= inverse_sqrt; + vec->y *= inverse_sqrt; + vec->z *= inverse_sqrt; +} + +__device__ void TransformPDFToCDF(float* probs, const int num_probs) +{ + float prob_sum = 0.0f; + for (int i = 0; i < num_probs; ++i) { + prob_sum += probs[i]; + } + const float inv_prob_sum = 1.0f / prob_sum; + + float cum_prob = 0.0f; + for (int i = 0; i < num_probs; ++i) { + const float prob = probs[i] * inv_prob_sum; + cum_prob += prob; + probs[i] = cum_prob; + } +} + +__device__ void Get3DPoint(const Camera camera, const int2 p, const float depth, float *X) +{ + X[0] = depth * (p.x - camera.K[2]) / camera.K[0]; + X[1] = depth * (p.y - camera.K[5]) / camera.K[4]; + X[2] = depth; +} + +__device__ float4 GetViewDirection(const Camera camera, const int2 p, const float depth) +{ + float X[3]; + Get3DPoint(camera, p, depth, X); + float norm = sqrt(X[0] * X[0] + X[1] * X[1] + X[2] * X[2]); + + float4 view_direction; + view_direction.x = X[0] / norm; + view_direction.y = X[1] / norm; + view_direction.z = X[2] / norm; + view_direction.w = 0; + return view_direction; +} + +__device__ float GetDistance2Origin(const Camera camera, const int2 p, const float depth, const float4 normal) +{ + float X[3]; + Get3DPoint(camera, p, depth, X); + return -(normal.x * X[0] + normal.y * X[1] + normal.z * X[2]); +} + +__device__ float ComputeDepthfromPlaneHypothesis(const Camera camera, const float4 plane_hypothesis, const int2 p) +{ + return -plane_hypothesis.w * camera.K[0] / ((p.x - camera.K[2]) * plane_hypothesis.x + (camera.K[0] / camera.K[4]) * (p.y - camera.K[5]) * plane_hypothesis.y + camera.K[0] * plane_hypothesis.z); +} + +__device__ float4 GenerateRandomNormal(const Camera camera, const int2 p, curandState *rand_state, const float depth) +{ + float4 normal; + float q1 = 1.0f; + float q2 = 1.0f; + float s = 2.0f; + while (s >= 1.0f) { + q1 = 2.0f * curand_uniform(rand_state) -1.0f; + q2 = 2.0f * curand_uniform(rand_state) - 1.0f; + s = q1 * q1 + q2 * q2; + } + const float sq = sqrt(1.0f - s); + normal.x = 2.0f * q1 * sq; + normal.y = 2.0f * q2 * sq; + normal.z = 1.0f - 2.0f * s; + normal.w = 0; + + float4 view_direction = GetViewDirection(camera, p, depth); + float dot_product = normal.x * view_direction.x + normal.y * view_direction.y + normal.z * view_direction.z; + if (dot_product > 0.0f) { + normal.x = -normal.x; + normal.y = -normal.y; + normal.z = - normal.z; + } + NormalizeVec3(&normal); + return normal; +} + +__device__ float4 GeneratePerturbedNormal(const Camera camera, const int2 p, const float4 normal, curandState *rand_state, const float perturbation) +{ + float4 view_direction = GetViewDirection(camera, p, 1.0f); + + const float a1 = (curand_uniform(rand_state) - 0.5f) * perturbation; + const float a2 = (curand_uniform(rand_state) - 0.5f) * perturbation; + const float a3 = (curand_uniform(rand_state) - 0.5f) * perturbation; + + const float sin_a1 = sin(a1); + const float sin_a2 = sin(a2); + const float sin_a3 = sin(a3); + const float cos_a1 = cos(a1); + const float cos_a2 = cos(a2); + const float cos_a3 = cos(a3); + + float R[9]; + R[0] = cos_a2 * cos_a3; + R[1] = cos_a3 * sin_a1 * sin_a2 - cos_a1 * sin_a3; + R[2] = sin_a1 * sin_a3 + cos_a1 * cos_a3 * sin_a2; + R[3] = cos_a2 * sin_a3; + R[4] = cos_a1 * cos_a3 + sin_a1 * sin_a2 * sin_a3; + R[5] = cos_a1 * sin_a2 * sin_a3 - cos_a3 * sin_a1; + R[6] = -sin_a2; + R[7] = cos_a2 * sin_a1; + R[8] = cos_a1 * cos_a2; + + float4 normal_perturbed; + Mat33DotVec3(R, normal, &normal_perturbed); + + if (Vec3DotVec3(normal_perturbed, 
view_direction) >= 0.0f) { + normal_perturbed = normal; + } + + NormalizeVec3(&normal_perturbed); + return normal_perturbed; +} + +__device__ float4 GenerateRandomPlaneHypothesis(const Camera camera, const int2 p, curandState *rand_state, const float depth_min, const float depth_max, float init_depth) +{ + float depth = init_depth; + if (depth <= 0){ + depth = curand_uniform(rand_state) * (depth_max - depth_min) + depth_min; + } + // printf("initdepth: %f\n", init_depth); + + float4 plane_hypothesis = GenerateRandomNormal(camera, p, rand_state, depth); + plane_hypothesis.w = GetDistance2Origin(camera, p, depth, plane_hypothesis); + return plane_hypothesis; +} + +__device__ float4 GeneratePertubedPlaneHypothesis(const Camera camera, const int2 p, curandState *rand_state, const float perturbation, const float4 plane_hypothesis_now, const float depth_now, const float depth_min, const float depth_max) +{ + float depth_perturbed = depth_now; + + float dist_perturbed = plane_hypothesis_now.w; + const float dist_min_perturbed = (1 - perturbation) * dist_perturbed; + const float dist_max_perturbed = (1 + perturbation) * dist_perturbed; + float4 plane_hypothesis_temp = plane_hypothesis_now; + do { + dist_perturbed = curand_uniform(rand_state) * (dist_max_perturbed - dist_min_perturbed) + dist_min_perturbed; + plane_hypothesis_temp.w = dist_perturbed; + depth_perturbed = ComputeDepthfromPlaneHypothesis(camera, plane_hypothesis_temp, p); + } while (depth_perturbed < depth_min && depth_perturbed > depth_max); + + float4 plane_hypothesis = GeneratePerturbedNormal(camera, p, plane_hypothesis_now, rand_state, perturbation * M_PI); + plane_hypothesis.w = dist_perturbed; + return plane_hypothesis; +} + +__device__ void ComputeHomography(const Camera ref_camera, const Camera src_camera, const float4 plane_hypothesis, float *H) +{ + float ref_C[3]; + float src_C[3]; + ref_C[0] = -(ref_camera.R[0] * ref_camera.t[0] + ref_camera.R[3] * ref_camera.t[1] + ref_camera.R[6] * ref_camera.t[2]); + ref_C[1] = -(ref_camera.R[1] * ref_camera.t[0] + ref_camera.R[4] * ref_camera.t[1] + ref_camera.R[7] * ref_camera.t[2]); + ref_C[2] = -(ref_camera.R[2] * ref_camera.t[0] + ref_camera.R[5] * ref_camera.t[1] + ref_camera.R[8] * ref_camera.t[2]); + src_C[0] = -(src_camera.R[0] * src_camera.t[0] + src_camera.R[3] * src_camera.t[1] + src_camera.R[6] * src_camera.t[2]); + src_C[1] = -(src_camera.R[1] * src_camera.t[0] + src_camera.R[4] * src_camera.t[1] + src_camera.R[7] * src_camera.t[2]); + src_C[2] = -(src_camera.R[2] * src_camera.t[0] + src_camera.R[5] * src_camera.t[1] + src_camera.R[8] * src_camera.t[2]); + + float R_relative[9]; + float C_relative[3]; + float t_relative[3]; + R_relative[0] = src_camera.R[0] * ref_camera.R[0] + src_camera.R[1] * ref_camera.R[1] + src_camera.R[2] *ref_camera.R[2]; + R_relative[1] = src_camera.R[0] * ref_camera.R[3] + src_camera.R[1] * ref_camera.R[4] + src_camera.R[2] *ref_camera.R[5]; + R_relative[2] = src_camera.R[0] * ref_camera.R[6] + src_camera.R[1] * ref_camera.R[7] + src_camera.R[2] *ref_camera.R[8]; + R_relative[3] = src_camera.R[3] * ref_camera.R[0] + src_camera.R[4] * ref_camera.R[1] + src_camera.R[5] *ref_camera.R[2]; + R_relative[4] = src_camera.R[3] * ref_camera.R[3] + src_camera.R[4] * ref_camera.R[4] + src_camera.R[5] *ref_camera.R[5]; + R_relative[5] = src_camera.R[3] * ref_camera.R[6] + src_camera.R[4] * ref_camera.R[7] + src_camera.R[5] *ref_camera.R[8]; + R_relative[6] = src_camera.R[6] * ref_camera.R[0] + src_camera.R[7] * ref_camera.R[1] + src_camera.R[8] 
*ref_camera.R[2]; + R_relative[7] = src_camera.R[6] * ref_camera.R[3] + src_camera.R[7] * ref_camera.R[4] + src_camera.R[8] *ref_camera.R[5]; + R_relative[8] = src_camera.R[6] * ref_camera.R[6] + src_camera.R[7] * ref_camera.R[7] + src_camera.R[8] *ref_camera.R[8]; + C_relative[0] = (ref_C[0] - src_C[0]); + C_relative[1] = (ref_C[1] - src_C[1]); + C_relative[2] = (ref_C[2] - src_C[2]); + t_relative[0] = src_camera.R[0] * C_relative[0] + src_camera.R[1] * C_relative[1] + src_camera.R[2] * C_relative[2]; + t_relative[1] = src_camera.R[3] * C_relative[0] + src_camera.R[4] * C_relative[1] + src_camera.R[5] * C_relative[2]; + t_relative[2] = src_camera.R[6] * C_relative[0] + src_camera.R[7] * C_relative[1] + src_camera.R[8] * C_relative[2]; + + H[0] = R_relative[0] - t_relative[0] * plane_hypothesis.x / plane_hypothesis.w; + H[1] = R_relative[1] - t_relative[0] * plane_hypothesis.y / plane_hypothesis.w; + H[2] = R_relative[2] - t_relative[0] * plane_hypothesis.z / plane_hypothesis.w; + H[3] = R_relative[3] - t_relative[1] * plane_hypothesis.x / plane_hypothesis.w; + H[4] = R_relative[4] - t_relative[1] * plane_hypothesis.y / plane_hypothesis.w; + H[5] = R_relative[5] - t_relative[1] * plane_hypothesis.z / plane_hypothesis.w; + H[6] = R_relative[6] - t_relative[2] * plane_hypothesis.x / plane_hypothesis.w; + H[7] = R_relative[7] - t_relative[2] * plane_hypothesis.y / plane_hypothesis.w; + H[8] = R_relative[8] - t_relative[2] * plane_hypothesis.z / plane_hypothesis.w; + + float tmp[9]; + tmp[0] = H[0] / ref_camera.K[0]; + tmp[1] = H[1] / ref_camera.K[4]; + tmp[2] = -H[0] * ref_camera.K[2] / ref_camera.K[0] - H[1] * ref_camera.K[5] / ref_camera.K[4] + H[2]; + tmp[3] = H[3] / ref_camera.K[0]; + tmp[4] = H[4] / ref_camera.K[4]; + tmp[5] = -H[3] * ref_camera.K[2] / ref_camera.K[0] - H[4] * ref_camera.K[5] / ref_camera.K[4] + H[5]; + tmp[6] = H[6] / ref_camera.K[0]; + tmp[7] = H[7] / ref_camera.K[4]; + tmp[8] = -H[6] * ref_camera.K[2] / ref_camera.K[0] - H[7] * ref_camera.K[5] / ref_camera.K[4] + H[8]; + + H[0] = src_camera.K[0] * tmp[0] + src_camera.K[2] * tmp[6]; + H[1] = src_camera.K[0] * tmp[1] + src_camera.K[2] * tmp[7]; + H[2] = src_camera.K[0] * tmp[2] + src_camera.K[2] * tmp[8]; + H[3] = src_camera.K[4] * tmp[3] + src_camera.K[5] * tmp[6]; + H[4] = src_camera.K[4] * tmp[4] + src_camera.K[5] * tmp[7]; + H[5] = src_camera.K[4] * tmp[5] + src_camera.K[5] * tmp[8]; + H[6] = src_camera.K[8] * tmp[6]; + H[7] = src_camera.K[8] * tmp[7]; + H[8] = src_camera.K[8] * tmp[8]; +} + +__device__ float2 ComputeCorrespondingPoint(const float *H, const int2 p) +{ + float3 pt; + pt.x = H[0] * p.x + H[1] * p.y + H[2]; + pt.y = H[3] * p.x + H[4] * p.y + H[5]; + pt.z = H[6] * p.x + H[7] * p.y + H[8]; + return make_float2(pt.x / pt.z, pt.y / pt.z); +} + +__device__ float4 TransformNormal(const Camera camera, float4 plane_hypothesis) +{ + float4 transformed_normal; + transformed_normal.x = camera.R[0] * plane_hypothesis.x + camera.R[3] * plane_hypothesis.y + camera.R[6] * plane_hypothesis.z; + transformed_normal.y = camera.R[1] * plane_hypothesis.x + camera.R[4] * plane_hypothesis.y + camera.R[7] * plane_hypothesis.z; + transformed_normal.z = camera.R[2] * plane_hypothesis.x + camera.R[5] * plane_hypothesis.y + camera.R[8] * plane_hypothesis.z; + transformed_normal.w = plane_hypothesis.w; + return transformed_normal; +} + +__device__ float4 TransformNormal2RefCam(const Camera camera, float4 plane_hypothesis) +{ + float4 transformed_normal; + transformed_normal.x = camera.R[0] * plane_hypothesis.x + camera.R[1] * 
plane_hypothesis.y + camera.R[2] * plane_hypothesis.z; + transformed_normal.y = camera.R[3] * plane_hypothesis.x + camera.R[4] * plane_hypothesis.y + camera.R[5] * plane_hypothesis.z; + transformed_normal.z = camera.R[6] * plane_hypothesis.x + camera.R[7] * plane_hypothesis.y + camera.R[8] * plane_hypothesis.z; + transformed_normal.w = plane_hypothesis.w; + return transformed_normal; +} + +__device__ float ComputeBilateralWeight(const float x_dist, const float y_dist, const float pix, const float center_pix, const float sigma_spatial, const float sigma_color) +{ + const float spatial_dist = sqrt(x_dist * x_dist + y_dist * y_dist); + const float color_dist = fabs(pix - center_pix); + return exp(-spatial_dist / (2.0f * sigma_spatial* sigma_spatial) - color_dist / (2.0f * sigma_color * sigma_color)); +} + +__device__ float ComputeBilateralNCC(const cudaTextureObject_t ref_image, const Camera ref_camera, const cudaTextureObject_t src_image, const Camera src_camera, const int2 p, const float4 plane_hypothesis, const PatchMatchParams params) +{ + const float cost_max = 2.0f; + int radius = params.patch_size / 2; + + float H[9]; + ComputeHomography(ref_camera, src_camera, plane_hypothesis, H); + float2 pt = ComputeCorrespondingPoint(H, p); + if (pt.x >= src_camera.width || pt.x < 0.0f || pt.y >= src_camera.height || pt.y < 0.0f) { + return cost_max; + } + + float cost = 0.0f; + { + float sum_ref = 0.0f; + float sum_ref_ref = 0.0f; + float sum_src = 0.0f; + float sum_src_src = 0.0f; + float sum_ref_src = 0.0f; + float bilateral_weight_sum = 0.0f; + const float ref_center_pix = tex2D(ref_image, p.x + 0.5f, p.y + 0.5f); + + for (int i = -radius; i < radius + 1; i += params.radius_increment) { + float sum_ref_row = 0.0f; + float sum_src_row = 0.0f; + float sum_ref_ref_row = 0.0f; + float sum_src_src_row = 0.0f; + float sum_ref_src_row = 0.0f; + float bilateral_weight_sum_row = 0.0f; + + for (int j = -radius; j < radius + 1; j += params.radius_increment) { + const int2 ref_pt = make_int2(p.x + i, p.y + j); + const float ref_pix = tex2D(ref_image, ref_pt.x + 0.5f, ref_pt.y + 0.5f); + float2 src_pt = ComputeCorrespondingPoint(H, ref_pt); + const float src_pix = tex2D(src_image, src_pt.x + 0.5f, src_pt.y + 0.5f); + + float weight = ComputeBilateralWeight(i, j, ref_pix, ref_center_pix, params.sigma_spatial, params.sigma_color); + + sum_ref_row += weight * ref_pix; + sum_ref_ref_row += weight * ref_pix * ref_pix; + sum_src_row += weight * src_pix; + sum_src_src_row += weight * src_pix * src_pix; + sum_ref_src_row += weight * ref_pix * src_pix; + bilateral_weight_sum_row += weight; + } + + sum_ref += sum_ref_row; + sum_ref_ref += sum_ref_ref_row; + sum_src += sum_src_row; + sum_src_src += sum_src_src_row; + sum_ref_src += sum_ref_src_row; + bilateral_weight_sum += bilateral_weight_sum_row; + } + const float inv_bilateral_weight_sum = 1.0f / bilateral_weight_sum; + sum_ref *= inv_bilateral_weight_sum; + sum_ref_ref *= inv_bilateral_weight_sum; + sum_src *= inv_bilateral_weight_sum; + sum_src_src *= inv_bilateral_weight_sum; + sum_ref_src *= inv_bilateral_weight_sum; + + const float var_ref = sum_ref_ref - sum_ref * sum_ref; + const float var_src = sum_src_src - sum_src * sum_src; + + const float kMinVar = 1e-5f; + if (var_ref < kMinVar || var_src < kMinVar) { + return cost = cost_max; + } else { + const float covar_src_ref = sum_ref_src - sum_ref * sum_src; + const float var_ref_src = sqrt(var_ref * var_src); + return cost = max(0.0f, min(cost_max, 1.0f - covar_src_ref / var_ref_src)); + } + } +} + 
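+// Note: the cost computed above is a bilaterally weighted zero-mean NCC. Reference-patch samples
+// are warped into the source view through the plane-induced homography H, each sample is weighted
+// by its spatial and color proximity to the patch center, and the result is clamped to [0, cost_max]:
+//   cost = 1 - covar(ref, src) / sqrt(var(ref) * var(src)).
+// Near-textureless patches (variance below kMinVar) are rejected with cost_max, since NCC is
+// ill-defined there.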
+__device__ float ComputeMultiViewInitialCostandSelectedViews(const cudaTextureObject_t *images, const Camera *cameras, const int2 p, const float4 plane_hypothesis, unsigned int *selected_views, const PatchMatchParams params) +{ + float cost_max = 2.0f; + float cost_vector[32] = {2.0f}; + float cost_vector_copy[32] = {2.0f}; + int cost_count = 0; + int num_valid_views = 0; + + for (int i = 1; i < params.num_images; ++i) { + float c = ComputeBilateralNCC(images[0], cameras[0], images[i], cameras[i], p, plane_hypothesis, params); + cost_vector[i - 1] = c; + cost_vector_copy[i - 1] = c; + cost_count++; + if (c < cost_max) { + num_valid_views++; + } + } + + sort_small(cost_vector, cost_count); + *selected_views = 0; + + int top_k = min(num_valid_views, params.top_k); + if (top_k > 0) { + float cost = 0.0f; + for (int i = 0; i < top_k; ++i) { + cost += cost_vector[i]; + } + float cost_threshold = cost_vector[top_k - 1]; + for (int i = 0; i < params.num_images - 1; ++i) { + if (cost_vector_copy[i] <= cost_threshold) { + setBit(*selected_views, i); + } + } + return cost / top_k; + } else { + return cost_max; + } +} + +__device__ void ComputeMultiViewCostVector(const cudaTextureObject_t *images, const Camera *cameras, const int2 p, const float4 plane_hypothesis, float *cost_vector, const PatchMatchParams params) +{ + for (int i = 1; i < params.num_images; ++i) { + cost_vector[i - 1] = ComputeBilateralNCC(images[0], cameras[0], images[i], cameras[i], p, plane_hypothesis, params); + } +} + +__device__ float3 Get3DPointonWorld_cu(const float x, const float y, const float depth, const Camera camera) +{ + float3 pointX; + float3 tmpX; + // Reprojection + pointX.x = depth * (x - camera.K[2]) / camera.K[0]; + pointX.y = depth * (y - camera.K[5]) / camera.K[4]; + pointX.z = depth; + + // Rotation + tmpX.x = camera.R[0] * pointX.x + camera.R[3] * pointX.y + camera.R[6] * pointX.z; + tmpX.y = camera.R[1] * pointX.x + camera.R[4] * pointX.y + camera.R[7] * pointX.z; + tmpX.z = camera.R[2] * pointX.x + camera.R[5] * pointX.y + camera.R[8] * pointX.z; + + // Transformation + float3 C; + C.x = -(camera.R[0] * camera.t[0] + camera.R[3] * camera.t[1] + camera.R[6] * camera.t[2]); + C.y = -(camera.R[1] * camera.t[0] + camera.R[4] * camera.t[1] + camera.R[7] * camera.t[2]); + C.z = -(camera.R[2] * camera.t[0] + camera.R[5] * camera.t[1] + camera.R[8] * camera.t[2]); + pointX.x = tmpX.x + C.x; + pointX.y = tmpX.y + C.y; + pointX.z = tmpX.z + C.z; + + return pointX; +} + +__device__ void ProjectonCamera_cu(const float3 PointX, const Camera camera, float2 &point, float &depth) +{ + float3 tmp; + tmp.x = camera.R[0] * PointX.x + camera.R[1] * PointX.y + camera.R[2] * PointX.z + camera.t[0]; + tmp.y = camera.R[3] * PointX.x + camera.R[4] * PointX.y + camera.R[5] * PointX.z + camera.t[1]; + tmp.z = camera.R[6] * PointX.x + camera.R[7] * PointX.y + camera.R[8] * PointX.z + camera.t[2]; + + depth = camera.K[6] * tmp.x + camera.K[7] * tmp.y + camera.K[8] * tmp.z; + point.x = (camera.K[0] * tmp.x + camera.K[1] * tmp.y + camera.K[2] * tmp.z) / depth; + point.y = (camera.K[3] * tmp.x + camera.K[4] * tmp.y + camera.K[5] * tmp.z) / depth; +} + +__device__ float ComputeGeomConsistencyCost(const cudaTextureObject_t depth_image, const Camera ref_camera, const Camera src_camera, const float4 plane_hypothesis, const int2 p) +{ + const float max_cost = 5.0f; + + float depth = ComputeDepthfromPlaneHypothesis(ref_camera, plane_hypothesis, p); + float3 forward_point = Get3DPointonWorld_cu(p.x, p.y, depth, ref_camera); + + float2 src_pt; 
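+    // Forward-backward reprojection check: project the reference-view 3D point into the source view,
+    // read the source depth at that pixel, lift it back to 3D, and reproject into the reference view.
+    // The resulting pixel reprojection error is the geometric consistency cost, clamped to max_cost.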
+ float src_d; + ProjectonCamera_cu(forward_point, src_camera, src_pt, src_d); + const float src_depth = tex2D(depth_image, (int)src_pt.x + 0.5f, (int)src_pt.y + 0.5f); + + if (src_depth == 0.0f) { + return max_cost; + } + + float3 src_3D_pt = Get3DPointonWorld_cu(src_pt.x, src_pt.y, src_depth, src_camera); + + float2 backward_point; + float ref_d; + ProjectonCamera_cu(src_3D_pt, ref_camera, backward_point, ref_d); + + const float diff_col = p.x - backward_point.x; + const float diff_row = p.y - backward_point.y; + return min(max_cost, sqrt(diff_col * diff_col + diff_row * diff_row)); +} + +__global__ void RandomInitialization(cudaTextureObjects *texture_objects, Camera *cameras, float4 *plane_hypotheses, + float *costs, curandState *rand_states, unsigned int *selected_views, + const PatchMatchParams params, float *depths) +{ + const int2 p = make_int2(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y); + int width = cameras[0].width; + int height = cameras[0].height; + + if (p.x >= width || p.y >= height) { + return; + } + + const int center = p.y * width + p.x; + curand_init(clock64(), p.y, p.x, &rand_states[center]); + + float4 plane_hypothesis = plane_hypotheses[center]; + plane_hypothesis = TransformNormal2RefCam(cameras[0], plane_hypothesis); + float depth = plane_hypothesis.w; + plane_hypothesis.w = GetDistance2Origin(cameras[0], p, depth, plane_hypothesis); + plane_hypotheses[center] = plane_hypothesis; + costs[center] = ComputeMultiViewInitialCostandSelectedViews(texture_objects[0].images, cameras, p, plane_hypotheses[center], &selected_views[center], params); +} + +__device__ void PlaneHypothesisRefinement(const cudaTextureObject_t *images, const cudaTextureObject_t *depth_images, const Camera *cameras, float4 *plane_hypothesis, float *depth, float *cost, curandState *rand_state, const float *view_weights, const float weight_norm, const int2 p, const PatchMatchParams params) +{ + float perturbation = 0.02f; + + float depth_rand = curand_uniform(rand_state) * (params.depth_max - params.depth_min) + params.depth_min; + float depth_perturbed = *depth; + const float depth_min_perturbed = (1 - perturbation) * depth_perturbed; + const float depth_max_perturbed = (1 + perturbation) * depth_perturbed; + do { + depth_perturbed = curand_uniform(rand_state) * (depth_max_perturbed - depth_min_perturbed) + depth_min_perturbed; + } while (depth_perturbed < params.depth_min && depth_perturbed > params.depth_max); + + const int num_planes = 2; // Reduced from 5 to 2 as we're only testing current and perturbed depth + float depths[num_planes] = {*depth, depth_perturbed}; + float4 normal = make_float4(plane_hypothesis->x, plane_hypothesis->y, plane_hypothesis->z, 0); // Keep the normal fixed + + for (int i = 0; i < num_planes; ++i) { + float cost_vector[32] = {2.0f}; + float4 temp_plane_hypothesis = normal; + temp_plane_hypothesis.w = GetDistance2Origin(cameras[0], p, depths[i], temp_plane_hypothesis); + ComputeMultiViewCostVector(images, cameras, p, temp_plane_hypothesis, cost_vector, params); + + float temp_cost = 0.0f; + for (int j = 0; j < params.num_images - 1; ++j) { + if (view_weights[j] > 0) { + if (params.geom_consistency) { + temp_cost += view_weights[j] * (cost_vector[j] + 0.1f * ComputeGeomConsistencyCost(depth_images[j+1], cameras[0], cameras[j+1], temp_plane_hypothesis, p)); + } + else { + temp_cost += view_weights[j] * cost_vector[j]; + } + } + } + temp_cost /= weight_norm; + + if (temp_cost < *cost) { + *depth = depths[i]; + plane_hypothesis->w = 
temp_plane_hypothesis.w; + *cost = temp_cost; + } + } +} + +__device__ void CheckerboardPropagation(const cudaTextureObject_t *images, const cudaTextureObject_t *depths, const Camera *cameras, + float4 *plane_hypotheses, float *costs, curandState *rand_states, unsigned int *selected_views, + const int2 p, const PatchMatchParams params, const int iter) +{ + int width = cameras[0].width; + int height = cameras[0].height; + if (p.x >= width || p.y >= height) { + return; + } + + const int center = p.y * width + p.x; + int left_near = center - 1; + int left_far = center - 3; + int right_near = center + 1; + int right_far = center + 3; + int up_near = center - width; + int up_far = center - 3 * width; + int down_near = center + width; + int down_far = center + 3 * width; + + // Adaptive Checkerboard Sampling + float cost_array[8][32] = {2.0f}; + // 0 -- up_near, 1 -- up_far, 2 -- down_near, 3 -- down_far, 4 -- left_near, 5 -- left_far, 6 -- right_near, 7 -- right_far + bool flag[8] = {false}; + int num_valid_pixels = 0; + + float costMin; + int costMinPoint; + + // up_far + if (p.y > 2) { + flag[1] = true; + num_valid_pixels++; + costMin = costs[up_far]; + costMinPoint = up_far; + for (int i = 1; i < 11; ++i) { + if (p.y > 2 + 2 * i) { + int pointTemp = up_far - 2 * i * width; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + } + up_far = costMinPoint; + ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[up_far], cost_array[1], params); + } + + // dwon_far + if (p.y < height - 3) { + flag[3] = true; + num_valid_pixels++; + costMin = costs[down_far]; + costMinPoint = down_far; + for (int i = 1; i < 11; ++i) { + if (p.y < height - 3 - 2 * i) { + int pointTemp = down_far + 2 * i * width; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + } + down_far = costMinPoint; + ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[down_far], cost_array[3], params); + } + + // left_far + if (p.x > 2) { + flag[5] = true; + num_valid_pixels++; + costMin = costs[left_far]; + costMinPoint = left_far; + for (int i = 1; i < 11; ++i) { + if (p.x > 2 + 2 * i) { + int pointTemp = left_far - 2 * i; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + } + left_far = costMinPoint; + ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[left_far], cost_array[5], params); + } + + // right_far + if (p.x < width - 3) { + flag[7] = true; + num_valid_pixels++; + costMin = costs[right_far]; + costMinPoint = right_far; + for (int i = 1; i < 11; ++i) { + if (p.x < width - 3 - 2 * i) { + int pointTemp = right_far + 2 * i; + if (costMin < costs[pointTemp]) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + } + right_far = costMinPoint; + ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[right_far], cost_array[7], params); + } + + // up_near + if (p.y > 0) { + flag[0] = true; + num_valid_pixels++; + costMin = costs[up_near]; + costMinPoint = up_near; + for (int i = 0; i < 3; ++i) { + if (p.y > 1 + i && p.x > i) { + int pointTemp = up_near - (1 + i) * width - i; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + if (p.y > 1 + i && p.x < width - 1 - i) { + int pointTemp = up_near - (1 + i) * width + i; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + } + up_near = costMinPoint; + 
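+        // Re-evaluate the lowest-cost up-near candidate's plane hypothesis against every source view.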
ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[up_near], cost_array[0], params); + } + + // down_near + if (p.y < height - 1) { + flag[2] = true; + num_valid_pixels++; + costMin = costs[down_near]; + costMinPoint = down_near; + for (int i = 0; i < 3; ++i) { + if (p.y < height - 2 - i && p.x > i) { + int pointTemp = down_near + (1 + i) * width - i; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + if (p.y < height - 2 - i && p.x < width - 1 - i) { + int pointTemp = down_near + (1 + i) * width + i; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + } + down_near = costMinPoint; + ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[down_near], cost_array[2], params); + } + + // left_near + if (p.x > 0) { + flag[4] = true; + num_valid_pixels++; + costMin = costs[left_near]; + costMinPoint = left_near; + for (int i = 0; i < 3; ++i) { + if (p.x > 1 + i && p.y > i) { + int pointTemp = left_near - (1 + i) - i * width; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + if (p.x > 1 + i && p.y < height - 1 - i) { + int pointTemp = left_near - (1 + i) + i * width; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + } + left_near = costMinPoint; + ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[left_near], cost_array[4], params); + } + + // right_near + if (p.x < width - 1) { + flag[6] = true; + num_valid_pixels++; + costMin = costs[right_near]; + costMinPoint = right_near; + for (int i = 0; i < 3; ++i) { + if (p.x < width - 2 - i && p.y > i) { + int pointTemp = right_near + (1 + i) - i * width; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + if (p.x < width - 2 - i && p.y < height - 1- i) { + int pointTemp = right_near + (1 + i) + i * width; + if (costs[pointTemp] < costMin) { + costMin = costs[pointTemp]; + costMinPoint = pointTemp; + } + } + } + right_near = costMinPoint; + ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[right_near], cost_array[6], params); + } + const int positions[8] = {up_near, up_far, down_near, down_far, left_near, left_far, right_near, right_far}; + + // Multi-hypothesis Joint View Selection + float view_weights[32] = {0.0f}; + float view_selection_priors[32] = {0.0f}; + int neighbor_positions[4] = {center - width, center + width, center - 1, center + 1}; + for (int i = 0; i < 4; ++i) { + if (flag[2 * i]) { + for (int j = 0; j < params.num_images - 1; ++j) { + if (isSet(selected_views[neighbor_positions[i]], j) == 1) { + view_selection_priors[j] += 0.9f; + } else { + view_selection_priors[j] += 0.1f; + } + } + } + } + + float sampling_probs[32] = {0.0f}; + float cost_threshold = 0.8 * expf((iter) * (iter) / (-90.0f)); + for (int i = 0; i < params.num_images - 1; i++) { + float count = 0; + int count_false = 0; + float tmpw = 0; + for (int j = 0; j < 8; j++) { + if (cost_array[j][i] < cost_threshold) { + tmpw += expf(cost_array[j][i] * cost_array[j][i] / (-0.18f)); + count++; + } + if (cost_array[j][i] > 1.2f) { + count_false++; + } + } + if (count > 2 && count_false < 3) { + sampling_probs[i] = tmpw / count; + } + else if (count_false < 3) { + sampling_probs[i] = expf(cost_threshold * cost_threshold / (-0.32f)); + } + sampling_probs[i] = sampling_probs[i] * view_selection_priors[i]; + } + + TransformPDFToCDF(sampling_probs, params.num_images - 1); + for 
(int sample = 0; sample < 15; ++sample) { + const float rand_prob = curand_uniform(&rand_states[center]) - FLT_EPSILON; + + for (int image_id = 0; image_id < params.num_images - 1; ++image_id) { + const float prob = sampling_probs[image_id]; + if (prob > rand_prob) { + view_weights[image_id] += 1.0f; + break; + } + } + } + + unsigned int temp_selected_views = 0; + int num_selected_view = 0; + float weight_norm = 0; + for (int i = 0; i < params.num_images - 1; ++i) { + if (view_weights[i] > 0) { + setBit(temp_selected_views, i); + weight_norm += view_weights[i]; + num_selected_view++; + } + } + + float final_costs[8] = {0.0f}; + for (int i = 0; i < 8; ++i) { + if (flag[i]) { + float4 temp_plane_hypothesis = plane_hypotheses[center]; + temp_plane_hypothesis.w = GetDistance2Origin(cameras[0], p, ComputeDepthfromPlaneHypothesis(cameras[0], plane_hypotheses[positions[i]], p), temp_plane_hypothesis); + + for (int j = 0; j < params.num_images - 1; ++j) { + if (view_weights[j] > 0) { + if (params.geom_consistency) { + final_costs[i] += view_weights[j] * (cost_array[i][j] + 0.1f * ComputeGeomConsistencyCost(depths[j+1], cameras[0], cameras[j+1], temp_plane_hypothesis, p)); + } + else { + final_costs[i] += view_weights[j] * cost_array[i][j]; + } + } + } + final_costs[i] /= weight_norm; + } + else { + final_costs[i] = FLT_MAX; + } + } + + const int min_cost_idx = FindMinCostIndex(final_costs, 8); + + float cost_vector_now[32] = {2.0f}; + ComputeMultiViewCostVector(images, cameras, p, plane_hypotheses[center], cost_vector_now, params); + float cost_now = 0.0f; + for (int i = 0; i < params.num_images - 1; ++i) { + if (params.geom_consistency) { + cost_now += view_weights[i] * (cost_vector_now[i] + 0.1f * ComputeGeomConsistencyCost(depths[i+1], cameras[0], cameras[i+1], plane_hypotheses[center], p)); + } + else { + cost_now += view_weights[i] * cost_vector_now[i]; + } + } + cost_now /= weight_norm; + costs[center] = cost_now; + float depth_now = ComputeDepthfromPlaneHypothesis(cameras[0], plane_hypotheses[center], p); + + if (flag[min_cost_idx]) { + float depth_before = ComputeDepthfromPlaneHypothesis(cameras[0], plane_hypotheses[positions[min_cost_idx]], p); + + if (depth_before >= params.depth_min && depth_before <= params.depth_max && final_costs[min_cost_idx] < cost_now) { + depth_now = depth_before; + plane_hypotheses[center] = plane_hypotheses[positions[min_cost_idx]]; + costs[center] = final_costs[min_cost_idx]; + selected_views[center] = temp_selected_views; + } + } + + PlaneHypothesisRefinement(images, depths, cameras, &plane_hypotheses[center], &depth_now, &costs[center], &rand_states[center], view_weights, weight_norm, p, params); +} + +__global__ void BlackPixelUpdate(cudaTextureObjects *texture_objects, cudaTextureObjects *texture_depths, + Camera *cameras, float4 *plane_hypotheses, float *costs, curandState *rand_states, + unsigned int *selected_views, const PatchMatchParams params, const int iter) +{ + int2 p = make_int2(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y); + if (threadIdx.x % 2 == 0) { + p.y = p.y * 2; + } else { + p.y = p.y * 2 + 1; + } + + CheckerboardPropagation(texture_objects[0].images, texture_depths[0].images, cameras, plane_hypotheses, costs, rand_states, selected_views, p, params, iter); +} + +__global__ void RedPixelUpdate(cudaTextureObjects *texture_objects, cudaTextureObjects *texture_depths, + Camera *cameras, float4 *plane_hypotheses, float *costs, curandState *rand_states, + unsigned int *selected_views, const PatchMatchParams 
params, const int iter) +{ + int2 p = make_int2(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y); + if (threadIdx.x % 2 == 0) { + p.y = p.y * 2 + 1; + } else { + p.y = p.y * 2; + } + + CheckerboardPropagation(texture_objects[0].images, texture_depths[0].images, cameras, plane_hypotheses, costs, rand_states, selected_views, p, params, iter); +} + +__global__ void GetDepthandNormal(Camera *cameras, float4 *plane_hypotheses, const PatchMatchParams params) +{ + const int2 p = make_int2(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y); + const int width = cameras[0].width; + const int height = cameras[0].height; + + if (p.x >= width || p.y >= height) { + return; + } + + const int center = p.y * width + p.x; + plane_hypotheses[center].w = ComputeDepthfromPlaneHypothesis(cameras[0], plane_hypotheses[center], p); + plane_hypotheses[center] = TransformNormal(cameras[0], plane_hypotheses[center]); +} + +__device__ void CheckerboardFilter(const Camera *cameras, float4 *plane_hypotheses, float *costs, const int2 p) +{ + int width = cameras[0].width; + int height = cameras[0].height; + if (p.x >= width || p.y >= height) { + return; + } + + const int center = p.y * width + p.x; + + float filter[21]; + int index = 0; + + filter[index++] = plane_hypotheses[center].w; + + // Left + const int left = center - 1; + const int leftleft = center - 3; + + // Up + const int up = center - width; + const int upup = center - 3 * width; + + // Down + const int down = center + width; + const int downdown = center + 3 * width; + + // Right + const int right = center + 1; + const int rightright = center + 3; + + if (costs[center] < 0.001f) { + return; + } + + if (p.y>0) { + filter[index++] = plane_hypotheses[up].w; + } + if (p.y>2) { + filter[index++] = plane_hypotheses[upup].w; + } + if (p.y>4) { + filter[index++] = plane_hypotheses[upup-width*2].w; + } + if (p.y0) { + filter[index++] = plane_hypotheses[left].w; + } + if (p.x>2) { + filter[index++] = plane_hypotheses[leftleft].w; + } + if (p.x>4) { + filter[index++] = plane_hypotheses[leftleft-2].w; + } + if (p.x0 && + p.x0 && + p.x>1) + { + filter[index++] = plane_hypotheses[up-2].w; + } + if (p.y1) { + filter[index++] = plane_hypotheses[down-2].w; + } + if (p.x>0 && + p.y>2) + { + filter[index++] = plane_hypotheses[left - width*2].w; + } + if (p.x2) + { + filter[index++] = plane_hypotheses[right - width*2].w; + } + if (p.x>0 && + p.y>>(texture_objects_cuda, cameras_cuda, plane_hypotheses_cuda, costs_cuda, rand_states_cuda, selected_views_cuda, params, depths_cuda); + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + + for (int i = 0; i < max_iterations; ++i) { + + BlackPixelUpdate<<>>(texture_objects_cuda, texture_depths_cuda, cameras_cuda, plane_hypotheses_cuda, costs_cuda, rand_states_cuda, selected_views_cuda, params, i); + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + + RedPixelUpdate<<>>(texture_objects_cuda, texture_depths_cuda, cameras_cuda, plane_hypotheses_cuda, costs_cuda, rand_states_cuda, selected_views_cuda, params, i); + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + // printf("iteration: %d\n", i); + } + + GetDepthandNormal<<>>(cameras_cuda, plane_hypotheses_cuda, params); + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + + BlackPixelFilter<<>>(cameras_cuda, plane_hypotheses_cuda, costs_cuda); + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + RedPixelFilter<<>>(cameras_cuda, plane_hypotheses_cuda, costs_cuda); + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + + cudaMemcpy(plane_hypotheses_host, plane_hypotheses_cuda, 
sizeof(float4) * width * height, cudaMemcpyDeviceToHost); + cudaMemcpy(costs_host, costs_cuda, sizeof(float) * width * height, cudaMemcpyDeviceToHost); + CUDA_SAFE_CALL(cudaDeviceSynchronize()); +} + +torch::Tensor propagate_cuda(torch::Tensor images, torch::Tensor intrinsics, torch::Tensor poses, + torch::Tensor depth, torch::Tensor normal, torch::Tensor depth_intervals, int patch_size) +{ + cudaSetDevice(0); + + PatchMatch pm; + pm.SetPatchSize(patch_size); + + images = images.to(torch::kFloat); + + pm.InuputInitialization(images, intrinsics, poses, depth, normal, depth_intervals); + + pm.CudaSpaceInitialization(); + pm.RunPatchMatch(); + + const int width = pm.GetReferenceImageWidth(); + const int height = pm.GetReferenceImageHeight(); + + torch::Tensor depths = torch::zeros({height, width}, torch::kFloat); + torch::Tensor normals = torch::zeros({height, width, 3}, torch::kFloat); + + int numPixels = width * height; + + float4* plane_hypotheses = pm.GetPlaneHypotheses(); + + torch::Tensor planeHypothesisTensor = torch::from_blob(plane_hypotheses, {numPixels, 4}, torch::kFloat); + + torch::Tensor propagated_depth = planeHypothesisTensor.index({torch::indexing::Slice(), 3}).reshape({height, width}).unsqueeze(0); + + torch::Tensor propagated_normal = planeHypothesisTensor.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)}).reshape({height, width, 3}).permute({2, 0, 1}); + + torch::Tensor results = torch::cat({propagated_depth, propagated_normal}, 0); + + return results; +} \ No newline at end of file diff --git a/submodules/Propagation/main.h b/submodules/Propagation/main.h new file mode 100644 index 00000000..c1b230bd --- /dev/null +++ b/submodules/Propagation/main.h @@ -0,0 +1,48 @@ +#ifndef _MAIN_H_ +#define _MAIN_H_ + +// Includes CUDA +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include "iomanip" + +#include // mkdir +#include // mkdir + +#define MAX_IMAGES 256 + +struct Camera { + float K[9]; + float R[9]; + float t[3]; + int height; + int width; + float depth_min; + float depth_max; +}; + +struct Problem { + int ref_image_id; + std::vector src_image_ids; +}; + +struct PointList { + float3 coord; + float3 normal; + float3 color; +}; + +#endif // _MAIN_H_ diff --git a/submodules/Propagation/pro.cpp b/submodules/Propagation/pro.cpp new file mode 100644 index 00000000..d52db3c8 --- /dev/null +++ b/submodules/Propagation/pro.cpp @@ -0,0 +1,29 @@ +#include +#include + +torch::Tensor propagate_cuda( + torch::Tensor images, + torch::Tensor intrinsics, + torch::Tensor poses, + torch::Tensor depth, + torch::Tensor normal, + torch::Tensor depth_intervals, + int patch_size); + +torch::Tensor propagate( + torch::Tensor images, + torch::Tensor intrinsics, + torch::Tensor poses, + torch::Tensor depth, + torch::Tensor normal, + torch::Tensor depth_intervals, + int patch_size) { + + return propagate_cuda(images, intrinsics, poses, depth, normal, depth_intervals, patch_size); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // bundle adjustment kernels + m.def("propagate", &propagate, "plane propagation"); +} \ No newline at end of file diff --git a/submodules/Propagation/setup.py b/submodules/Propagation/setup.py new file mode 100644 index 00000000..d89ef64d --- /dev/null +++ b/submodules/Propagation/setup.py @@ -0,0 +1,24 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +import os.path as osp +ROOT = 
osp.dirname(osp.abspath(__file__)) + +setup( + name='gaussianpro', + ext_modules=[ + CUDAExtension('gaussianpro', + sources=[ + 'PatchMatch.cpp', + 'Propagation.cu', + 'pro.cpp' + ], + extra_compile_args={ + 'cxx': ['-O3'], + 'nvcc': ['-O3', '-gencode=arch=compute_80,code=sm_80', + ] + }, + ), + ], + cmdclass={ 'build_ext' : BuildExtension } +) \ No newline at end of file diff --git a/submodules/diff-surfel-rasterization b/submodules/diff-surfel-rasterization index 7bdbd515..94f6d305 160000 --- a/submodules/diff-surfel-rasterization +++ b/submodules/diff-surfel-rasterization @@ -1 +1 @@ -Subproject commit 7bdbd5157fe5667b5cfbc5a51ab402c957f66e22 +Subproject commit 94f6d305556968a35651ba08cee3216ba4a74959 diff --git a/submodules/fused-ssim b/submodules/fused-ssim new file mode 160000 index 00000000..d99e3d27 --- /dev/null +++ b/submodules/fused-ssim @@ -0,0 +1 @@ +Subproject commit d99e3d27513fa3563d98f74fcd40fd429e9e9b0e diff --git a/train.py b/train.py index 614573eb..1e96369d 100644 --- a/train.py +++ b/train.py @@ -12,11 +12,13 @@ import os import torch from random import randint -from utils.loss_utils import l1_loss, ssim +from utils.loss_utils import l1_loss_appearance, ssim, l1_loss from gaussian_renderer import render, network_gui import sys -from scene import Scene, GaussianModel +import torch.nn.functional as F +from scene import Scene, GaussianModel, AppearanceModel from utils.general_utils import safe_state +from utils.patchmatch import process_propagation import uuid from tqdm import tqdm from utils.image_utils import psnr, render_net_image @@ -28,12 +30,98 @@ except ImportError: TENSORBOARD_FOUND = False +def prune_low_contribution_gaussians(gaussians, cameras, pipe, bg, K=5, prune_ratio=0.1): + top_list = [None, ] * K + for i, cam in enumerate(cameras): + trans = render(cam, gaussians, pipe, bg, record_transmittance=True, skip_geometric=True)["transmittance_avg"] + if top_list[0] is not None: + m = trans > top_list[0] + if m.any(): + for i in range(K - 1): + top_list[K - 1 - i][m] = top_list[K - 2 - i][m] + top_list[0][m] = trans[m] + else: + top_list = [trans.clone() for _ in range(K)] + + contribution = torch.stack(top_list, dim=-1).mean(-1) + tile = torch.quantile(contribution, prune_ratio) + prune_mask = contribution < tile + gaussians.prune_points(prune_mask) + torch.cuda.empty_cache() + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None , type='mean'): + error, indices = torch.sort(error) + # only sum relatively small errors + s_error = torch.index_select(error, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + if extra_weights is not None: + weights = torch.index_select(extra_weights, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + s_error = s_error * weights + + if type == 'mean': + return torch.mean(s_error) + elif type == 'sum': + return torch.sum(s_error) + +def normal_gradient_loss(rend_normal, gt_normal): + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 4 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 4 + + rend_grad_x = F.conv2d(rend_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + rend_grad_y = F.conv2d(rend_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + gt_grad_x = F.conv2d(gt_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + gt_grad_y = F.conv2d(gt_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + loss_x = F.mse_loss(rend_grad_x, 
gt_grad_x) + loss_y = F.mse_loss(rend_grad_y, gt_grad_y) + + return loss_x + loss_y + +def edge_aware_normal_gradient_loss(gt_image, rend_normal, gt_normal, prior_normal_mask, edge_threshold=1): + # Define Sobel filters + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 8 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 8 + + # Compute gradients of rendered and ground truth normals + rend_grad_x = F.conv2d(rend_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + rend_grad_y = F.conv2d(rend_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + gt_grad_x = F.conv2d(gt_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + gt_grad_y = F.conv2d(gt_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + # Compute gradients of gt_image for edge detection + dI_dx = torch.cat([F.conv2d(gt_image[i].unsqueeze(0), sobel_x, padding=1) for i in range(gt_image.shape[0])]) + dI_dx = torch.mean(torch.abs(dI_dx), 1, keepdim=True) + dI_dy = torch.cat([F.conv2d(gt_image[i].unsqueeze(0), sobel_y, padding=1) for i in range(gt_image.shape[0])]) + dI_dy = torch.mean(torch.abs(dI_dy), 1, keepdim=True) + + # Compute edge strength + edge_strength = dI_dx + dI_dy + + # Create non-edge mask + non_edge_mask = (edge_strength < edge_threshold).float() + + # Compute loss for gradients + loss_x = F.mse_loss(rend_grad_x, gt_grad_x) + loss_y = F.mse_loss(rend_grad_y, gt_grad_y) + loss = loss_x + loss_y + + # Apply non-edge mask and prior_normal_mask + masked_loss = loss * non_edge_mask * prior_normal_mask + + # Normalize by the number of non-edge pixels + num_non_edge_pixels = torch.sum(non_edge_mask * prior_normal_mask) + 1e-6 + normalized_loss = torch.sum(masked_loss) / num_non_edge_pixels + + return normalized_loss + def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint): first_iter = 0 tb_writer = prepare_output_and_logger(dataset) gaussians = GaussianModel(dataset.sh_degree) scene = Scene(dataset, gaussians) gaussians.training_setup(opt) + if checkpoint: (model_params, first_iter) = torch.load(checkpoint) gaussians.restore(model_params, opt) @@ -46,11 +134,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi viewpoint_stack = None ema_loss_for_log = 0.0 - ema_dist_for_log = 0.0 + ema_depth_for_log = 0.0 ema_normal_for_log = 0.0 progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") first_iter += 1 + all_cameras = scene.getTrainCameras() + if dataset.use_decoupled_appearance: + appearances = AppearanceModel(len(all_cameras)) + appearances.training_setup(opt) + else: + appearances = None + for iteration in range(first_iter, opt.iterations + 1): iter_start.record() @@ -64,28 +159,77 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi # Pick a random Camera if not viewpoint_stack: viewpoint_stack = scene.getTrainCameras().copy() - viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) - - render_pkg = render(viewpoint_cam, gaussians, pipe, background) + viewpoint_idx = randint(0, len(all_cameras)-1) + viewpoint_cam = all_cameras[viewpoint_idx] + # Set intervals for patch match + # intervals = [-2, -1, 1, 2] + # src_idxs = [viewpoint_idx+itv for itv in intervals if ((itv + viewpoint_idx > 0) and (itv + viewpoint_idx < len(viewpoint_stack)))] + # process_propagation(viewpoint_stack, viewpoint_cam, 
gaussians, pipe, background, iteration, opt, src_idxs) + render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=(iteration < opt.densify_until_iter)) image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] gt_image = viewpoint_cam.original_image.cuda() - Ll1 = l1_loss(image, gt_image) + Ll1 = l1_loss_appearance(image, gt_image, appearances, viewpoint_idx) # use L1 loss for the transformed image if using decoupled appearance loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + # alpha loss + if opt.lambda_mask > 0: + opacity = 1 - render_pkg["rend_alpha"].clamp(1e-6, 1-1e-6) + bg = 1 - viewpoint_cam.gt_alpha_mask + mask_error = (- bg * torch.log(opacity)).mean() + loss += opt.lambda_mask * mask_error + # regularization - lambda_normal = opt.lambda_normal if iteration > 7000 else 0.0 + lambda_normal = opt.lambda_normal if iteration > 15000 else 0.0 + lambda_depth = opt.propagation_begin if iteration > opt.propagation_begin else 0.0 lambda_dist = opt.lambda_dist if iteration > 3000 else 0.0 - + lambda_normal_prior = opt.lambda_normal_prior * (7000 - iteration) / 7000 if iteration < 7000 else opt.lambda_normal_prior + lambda_normal_gradient = opt.lambda_normal_gradient if iteration > 15000 else 0.0 + + depth_loss = torch.tensor(0.).to("cuda") + normal_loss = torch.tensor(0.).to("cuda") + normal_prior_loss = torch.tensor(0.).to("cuda") + rend_dist = render_pkg["rend_dist"] - rend_normal = render_pkg['rend_normal'] - surf_normal = render_pkg['surf_normal'] - normal_error = (1 - (rend_normal * surf_normal).sum(dim=0))[None] - normal_loss = lambda_normal * (normal_error).mean() + rend_depth = render_pkg["rend_depth"] + surf_depth = render_pkg["surf_depth"] dist_loss = lambda_dist * (rend_dist).mean() + if lambda_depth > 0 and viewpoint_cam.depth_prior is not None: + depth_error = 0.6 * (surf_depth - viewpoint_cam.depth_prior).abs() + \ + 0.4 * (rend_depth - viewpoint_cam.depth_prior).abs() + depth_mask = viewpoint_cam.depth_mask.unsqueeze(0) & viewpoint_cam.gt_alpha_mask + valid_depth_sum = depth_mask.sum() + 1e-5 + depth_loss += lambda_depth * (depth_error[depth_mask & ~torch.isnan(depth_error)].sum() / valid_depth_sum) + + rend_normal = render_pkg['rend_normal'] + surf_normal_median = render_pkg['surf_normal'] + surf_normal_expected = render_pkg['surf_normal_expected'] + rend_alpha = render_pkg['rend_alpha'] + + if lambda_normal > 0.0: + normal_error = 0.6 * (1 - F.cosine_similarity(rend_normal, surf_normal_median, dim=0)) + \ + 0.4 * (1 - F.cosine_similarity(rend_normal, surf_normal_expected, dim=0)) + normal_error = normal_error * viewpoint_cam.gt_alpha_mask.mean(dim=0) + normal_error = ranking_loss(normal_error.view(-1), penalize_ratio=0.7, type='mean') + normal_loss += lambda_normal * normal_error + + if lambda_normal_prior > 0 and viewpoint_cam.normal_prior is not None: + prior_normal = viewpoint_cam.normal_prior * (rend_alpha).detach() + prior_normal_mask = viewpoint_cam.normal_mask[0] + + normal_prior_error = (1 - F.cosine_similarity(prior_normal, rend_normal, dim=0)) + \ + (1 - F.cosine_similarity(prior_normal, surf_normal_expected, dim=0)) + normal_prior_error = normal_prior_error * viewpoint_cam.gt_alpha_mask.mean(dim=0) + normal_prior_error = ranking_loss(normal_prior_error[prior_normal_mask], + penalize_ratio=1.0, type='mean') + + normal_prior_loss = lambda_normal_prior * normal_prior_error + if 
lambda_normal_gradient > 0.0: + normal_prior_loss += lambda_normal_gradient * normal_gradient_loss(surf_normal_median, prior_normal) + # loss - total_loss = loss + dist_loss + normal_loss + total_loss = loss + dist_loss + depth_loss + normal_loss + normal_prior_loss total_loss.backward() @@ -94,14 +238,14 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi with torch.no_grad(): # Progress bar ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log - ema_dist_for_log = 0.4 * dist_loss.item() + 0.6 * ema_dist_for_log - ema_normal_for_log = 0.4 * normal_loss.item() + 0.6 * ema_normal_for_log + ema_depth_for_log = 0.4 * (depth_loss.item() + dist_loss.item()) + 0.6 * ema_depth_for_log + ema_normal_for_log = 0.4 * (normal_loss.item() + normal_prior_loss.item()) + 0.6 * ema_normal_for_log if iteration % 10 == 0: loss_dict = { "Loss": f"{ema_loss_for_log:.{5}f}", - "distort": f"{ema_dist_for_log:.{5}f}", + "depth": f"{ema_depth_for_log:.{5}f}", "normal": f"{ema_normal_for_log:.{5}f}", "Points": f"{len(gaussians.get_xyz)}" } @@ -113,10 +257,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi # Log and save if tb_writer is not None: - tb_writer.add_scalar('train_loss_patches/dist_loss', ema_dist_for_log, iteration) + tb_writer.add_scalar('train_loss_patches/dist_loss', ema_depth_for_log, iteration) tb_writer.add_scalar('train_loss_patches/normal_loss', ema_normal_for_log, iteration) - training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) + # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) if (iteration in saving_iterations): print("\n[ITER {}] Saving Gaussians".format(iteration)) scene.save(iteration) @@ -124,20 +268,37 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi # Densification if iteration < opt.densify_until_iter: - gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) - gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], + radii[visibility_filter] * (render_pkg["transmittance_avg"][visibility_filter] > 0.01)) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter, None) if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: - size_threshold = 20 if iteration > opt.opacity_reset_interval else None - gaussians.densify_and_prune(opt.densify_grad_threshold, opt.opacity_cull, scene.cameras_extent, size_threshold) + prune_big_points = True if iteration > opt.opacity_reset_interval else False + gaussians.densify_and_prune(opt.densify_grad_threshold, opt.opacity_cull, scene.cameras_extent, prune_big_points) + if iteration > opt.densify_from_iter and iteration % opt.split_interval == 0: + gaussians.split_big_points(opt.max_screen_size) + + if iteration > opt.contribution_prune_from_iter and iteration % opt.contribution_prune_interval == 0: + if iteration % opt.opacity_reset_interval == opt.contribution_prune_interval: + print("Skipped Pruning for", iteration) + continue + prune_low_contribution_gaussians(gaussians, all_cameras, pipe, background, + K=1, prune_ratio=opt.contribution_prune_ratio) + print(f'Num gs after contribution prune: {len(gaussians.get_xyz)}') + 
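+                # Periodic opacity reset (plus a one-off reset at densify_from_iter for
+                # white-background scenes), after which opacities are re-optimized from a low value.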
if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): gaussians.reset_opacity() # Optimizer step if iteration < opt.iterations: + # visible = radii > 0 + # gaussians.optimizer.step(visible, radii.shape[0]) gaussians.optimizer.step() gaussians.optimizer.zero_grad(set_to_none = True) + if appearances is not None: + appearances.optimizer.step() + appearances.optimizer.zero_grad(set_to_none = True) if (iteration in checkpoint_iterations): print("\n[ITER {}] Saving Checkpoint".format(iteration)) @@ -258,8 +419,8 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i parser.add_argument('--ip', type=str, default="127.0.0.1") parser.add_argument('--port', type=int, default=6009) parser.add_argument('--detect_anomaly', action='store_true', default=False) - parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) - parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[1, 7_000, 20_000, 30_000]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[500, 7_000, 20_000, 30_000]) parser.add_argument("--quiet", action="store_true") parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) parser.add_argument("--start_checkpoint", type=str, default = None) diff --git a/train_fast.py b/train_fast.py new file mode 100644 index 00000000..d9298452 --- /dev/null +++ b/train_fast.py @@ -0,0 +1,427 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +from random import randint +from utils.loss_utils import l1_loss_appearance, ssim, l1_loss +from gaussian_renderer import render, network_gui +import sys +import torch.nn.functional as F +from scene import Scene, GaussianModel, AppearanceModel +from utils.general_utils import safe_state +from utils.patchmatch import process_propagation +import uuid +from tqdm import tqdm +from utils.image_utils import psnr, render_net_image +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +def prune_low_contribution_gaussians(gaussians, cameras, pipe, bg, K=5, prune_ratio=0.1): + top_list = [None, ] * K + for i, cam in enumerate(cameras): + trans = render(cam, gaussians, pipe, bg, record_transmittance=True, skip_geometric=True)["transmittance_avg"] + if top_list[0] is not None: + m = trans > top_list[0] + if m.any(): + for i in range(K - 1): + top_list[K - 1 - i][m] = top_list[K - 2 - i][m] + top_list[0][m] = trans[m] + else: + top_list = [trans.clone() for _ in range(K)] + + contribution = torch.stack(top_list, dim=-1).mean(-1) + tile = torch.quantile(contribution, prune_ratio) + prune_mask = contribution < tile + gaussians.prune_points(prune_mask) + torch.cuda.empty_cache() + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None , type='mean'): + error, indices = torch.sort(error) + # only sum relatively small errors + s_error = torch.index_select(error, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + if extra_weights is not None: + weights = torch.index_select(extra_weights, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + s_error = s_error * weights + + if type == 'mean': + return torch.mean(s_error) + elif type == 'sum': + return torch.sum(s_error) + +def normal_gradient_loss(rend_normal, gt_normal): + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 4 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 4 + + rend_grad_x = F.conv2d(rend_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + rend_grad_y = F.conv2d(rend_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + gt_grad_x = F.conv2d(gt_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + gt_grad_y = F.conv2d(gt_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + loss_x = F.mse_loss(rend_grad_x, gt_grad_x) + loss_y = F.mse_loss(rend_grad_y, gt_grad_y) + + return loss_x + loss_y + +def edge_aware_normal_gradient_loss(gt_image, rend_normal, gt_normal, prior_normal_mask, edge_threshold=1): + # Define Sobel filters + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 8 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 8 + + # Compute gradients of rendered and ground truth normals + rend_grad_x = F.conv2d(rend_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + rend_grad_y = F.conv2d(rend_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + gt_grad_x = F.conv2d(gt_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + gt_grad_y = F.conv2d(gt_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + # 
Compute gradients of gt_image for edge detection + dI_dx = torch.cat([F.conv2d(gt_image[i].unsqueeze(0), sobel_x, padding=1) for i in range(gt_image.shape[0])]) + dI_dx = torch.mean(torch.abs(dI_dx), 1, keepdim=True) + dI_dy = torch.cat([F.conv2d(gt_image[i].unsqueeze(0), sobel_y, padding=1) for i in range(gt_image.shape[0])]) + dI_dy = torch.mean(torch.abs(dI_dy), 1, keepdim=True) + + # Compute edge strength + edge_strength = dI_dx + dI_dy + + # Create non-edge mask + non_edge_mask = (edge_strength < edge_threshold).float() + + # Compute loss for gradients + loss_x = F.mse_loss(rend_grad_x, gt_grad_x) + loss_y = F.mse_loss(rend_grad_y, gt_grad_y) + loss = loss_x + loss_y + + # Apply non-edge mask and prior_normal_mask + masked_loss = loss * non_edge_mask * prior_normal_mask + + # Normalize by the number of non-edge pixels + num_non_edge_pixels = torch.sum(non_edge_mask * prior_normal_mask) + 1e-6 + normalized_loss = torch.sum(masked_loss) / num_non_edge_pixels + + return normalized_loss + +def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) + + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + ema_depth_for_log = 0.0 + ema_normal_for_log = 0.0 + + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + all_cameras = scene.getTrainCameras() + if dataset.use_decoupled_appearance: + appearances = AppearanceModel(len(all_cameras)) + appearances.training_setup(opt) + else: + appearances = None + + for iteration in range(first_iter, opt.iterations + 1): + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_idx = randint(0, len(all_cameras)-1) + viewpoint_cam = all_cameras[viewpoint_idx] + render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=(iteration < opt.densify_until_iter)) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss_appearance(image, gt_image, appearances, viewpoint_idx) # use L1 loss for the transformed image if using decoupled appearance + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + + # alpha loss + if opt.lambda_mask > 0: + opacity = 1 - render_pkg["rend_alpha"].clamp(1e-6, 1-1e-6) + bg = 1 - viewpoint_cam.gt_alpha_mask + mask_error = (- bg * torch.log(opacity)).mean() + loss += opt.lambda_mask * mask_error + + # regularization + lambda_normal = opt.lambda_normal + lambda_dist = opt.lambda_dist + lambda_normal_prior = opt.lambda_normal_prior + lambda_normal_gradient = opt.lambda_normal_gradient + + depth_loss = 
torch.tensor(0.).to("cuda") + normal_loss = torch.tensor(0.).to("cuda") + normal_prior_loss = torch.tensor(0.).to("cuda") + + rend_dist = render_pkg["rend_dist"] + dist_loss = lambda_dist * (rend_dist).mean() + + rend_normal = render_pkg['rend_normal'] + surf_normal_median = render_pkg['surf_normal'] + surf_normal_expected = render_pkg['surf_normal_expected'] + rend_alpha = render_pkg['rend_alpha'] + + if lambda_normal > 0.0: + normal_error = 0.6 * (1 - F.cosine_similarity(rend_normal, surf_normal_median, dim=0)) + \ + 0.4 * (1 - F.cosine_similarity(rend_normal, surf_normal_expected, dim=0)) + normal_error = normal_error * viewpoint_cam.gt_alpha_mask.mean(dim=0) + normal_error = ranking_loss(normal_error.view(-1), penalize_ratio=1.0, type='mean') + normal_loss += lambda_normal * normal_error + + if lambda_normal_prior > 0 and viewpoint_cam.normal_prior is not None: + prior_normal = viewpoint_cam.normal_prior * (rend_alpha).detach() + prior_normal_mask = viewpoint_cam.normal_mask[0] + + normal_prior_error = (1 - F.cosine_similarity(prior_normal, rend_normal, dim=0)) + \ + (1 - F.cosine_similarity(prior_normal, surf_normal_expected, dim=0)) + normal_prior_error = normal_prior_error * viewpoint_cam.gt_alpha_mask.mean(dim=0) + normal_prior_error = ranking_loss(normal_prior_error[prior_normal_mask], + penalize_ratio=1.0, type='mean') + + normal_prior_loss = lambda_normal_prior * normal_prior_error + if lambda_normal_gradient > 0.0: + normal_prior_loss += lambda_normal_gradient * normal_gradient_loss(surf_normal_median, prior_normal) + + # loss + total_loss = loss + dist_loss + depth_loss + normal_loss + normal_prior_loss + + total_loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + ema_depth_for_log = 0.4 * (depth_loss.item() + dist_loss.item()) + 0.6 * ema_depth_for_log + ema_normal_for_log = 0.4 * (normal_loss.item() + normal_prior_loss.item()) + 0.6 * ema_normal_for_log + + + if iteration % 10 == 0: + loss_dict = { + "Loss": f"{ema_loss_for_log:.{5}f}", + "depth": f"{ema_depth_for_log:.{5}f}", + "normal": f"{ema_normal_for_log:.{5}f}", + "Points": f"{len(gaussians.get_xyz)}" + } + progress_bar.set_postfix(loss_dict) + + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + if tb_writer is not None: + tb_writer.add_scalar('train_loss_patches/dist_loss', ema_depth_for_log, iteration) + tb_writer.add_scalar('train_loss_patches/normal_loss', ema_normal_for_log, iteration) + + # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + + # Densification + if iteration < opt.densify_until_iter: + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], + radii[visibility_filter] * (render_pkg["transmittance_avg"][visibility_filter] > 0.01)) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter, None) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + prune_big_points = True if iteration > opt.opacity_reset_interval else False + gaussians.densify_and_prune(opt.densify_grad_threshold, opt.opacity_cull, scene.cameras_extent, prune_big_points) + + if iteration > opt.densify_from_iter and iteration % opt.split_interval == 0: + 
gaussians.split_big_points(opt.max_screen_size) + + if iteration > opt.contribution_prune_from_iter and iteration % opt.contribution_prune_interval == 0: + if iteration % opt.opacity_reset_interval == opt.contribution_prune_interval: + print("Skipped Pruning for", iteration) + continue + prune_low_contribution_gaussians(gaussians, all_cameras, pipe, background, + K=1, prune_ratio=opt.contribution_prune_ratio) + print(f'Num gs after contribution prune: {len(gaussians.get_xyz)}') + + if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + # visible = radii > 0 + # gaussians.optimizer.step(visible, radii.shape[0]) + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none = True) + if appearances is not None: + appearances.optimizer.step() + appearances.optimizer.zero_grad(set_to_none = True) + + if (iteration in checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") + + with torch.no_grad(): + if network_gui.conn == None: + network_gui.try_connect(dataset.render_items) + while network_gui.conn != None: + try: + net_image_bytes = None + custom_cam, do_training, keep_alive, scaling_modifer, render_mode = network_gui.receive() + if custom_cam != None: + render_pkg = render(custom_cam, gaussians, pipe, background, scaling_modifer) + net_image = render_net_image(render_pkg, dataset.render_items, render_mode, custom_cam) + net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) + metrics_dict = { + "#": gaussians.get_opacity.shape[0], + "loss": ema_loss_for_log + # Add more metrics as needed + } + # Send the data + network_gui.send(net_image_bytes, dataset.source_path, metrics_dict) + if do_training and ((iteration < int(opt.iterations)) or not keep_alive): + break + except Exception as e: + # raise e + network_gui.conn = None + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +@torch.no_grad() +def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): + if tb_writer: + tb_writer.add_scalar('train_loss_patches/reg_loss', Ll1.item(), iteration) + tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) + tb_writer.add_scalar('iter_time', elapsed, iteration) + tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, + {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] 
for idx in range(5, 30, 5)]}) + + for config in validation_configs: + if config['cameras'] and len(config['cameras']) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config['cameras']): + render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs) + image = torch.clamp(render_pkg["render"], 0.0, 1.0) + gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) + if tb_writer and (idx < 5): + from utils.general_utils import colormap + depth = render_pkg["surf_depth"] + norm = depth.max() + depth = depth / norm + depth = colormap(depth.cpu().numpy()[0], cmap='turbo') + tb_writer.add_images(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) + + try: + rend_alpha = render_pkg['rend_alpha'] + rend_normal = render_pkg["rend_normal"] * 0.5 + 0.5 + surf_normal = render_pkg["surf_normal"] * 0.5 + 0.5 + tb_writer.add_images(config['name'] + "_view_{}/rend_normal".format(viewpoint.image_name), rend_normal[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/surf_normal".format(viewpoint.image_name), surf_normal[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/rend_alpha".format(viewpoint.image_name), rend_alpha[None], global_step=iteration) + + rend_dist = render_pkg["rend_dist"] + rend_dist = colormap(rend_dist.cpu().numpy()[0]) + tb_writer.add_images(config['name'] + "_view_{}/rend_dist".format(viewpoint.image_name), rend_dist[None], global_step=iteration) + except: + pass + + if iteration == testing_iterations[0]: + tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) + + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + + psnr_test /= len(config['cameras']) + l1_test /= len(config['cameras']) + print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) + if tb_writer: + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) + + torch.cuda.empty_cache() + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[1, 7_000, 20_000, 30_000]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[500, 7_000, 20_000, 30_000]) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), 
op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint) + + # All done + print("\nTraining complete.") \ No newline at end of file diff --git a/train_progressive.py b/train_progressive.py new file mode 100644 index 00000000..c8093593 --- /dev/null +++ b/train_progressive.py @@ -0,0 +1,461 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +from random import randint +from fused_ssim import fused_ssim +from utils.loss_utils import l1_loss_appearance, l1_loss, edge_aware_curvature_loss +from gaussian_renderer import render, network_gui +import sys +import torch.nn.functional as F +from scene import Scene, GaussianModel, AppearanceModel +from utils.general_utils import safe_state +from utils.patchmatch import process_propagation +import uuid +from tqdm import tqdm +from utils.image_utils import psnr, render_net_image +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +def prune_low_contribution_gaussians(gaussians, cameras, pipe, bg, K=5, prune_ratio=0.1): + top_list = [None, ] * K + for i, cam in enumerate(cameras): + trans = render(cam, gaussians, pipe, bg, record_transmittance=True, skip_geometric=True)["transmittance_avg"] + if top_list[0] is not None: + m = trans > top_list[0] + if m.any(): + for i in range(K - 1): + top_list[K - 1 - i][m] = top_list[K - 2 - i][m] + top_list[0][m] = trans[m] + else: + top_list = [trans.clone() for _ in range(K)] + + contribution = torch.stack(top_list, dim=-1).mean(-1) + tile = torch.quantile(contribution, prune_ratio) + prune_mask = contribution < tile + gaussians.prune_points(prune_mask) + torch.cuda.empty_cache() + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None , type='mean'): + error, indices = torch.sort(error) + # only sum relatively small errors + s_error = torch.index_select(error, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + if extra_weights is not None: + weights = torch.index_select(extra_weights, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + s_error = s_error * weights + + if type == 'mean': + return torch.mean(s_error) + elif type == 'sum': + return torch.sum(s_error) + +def normal_gradient_loss(rend_normal, gt_normal): + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 4 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 4 + + rend_grad_x = F.conv2d(rend_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + rend_grad_y = F.conv2d(rend_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + gt_grad_x = F.conv2d(gt_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + gt_grad_y = F.conv2d(gt_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + loss_x = F.mse_loss(rend_grad_x, gt_grad_x) + loss_y = F.mse_loss(rend_grad_y, gt_grad_y) + + return loss_x + loss_y + +def edge_aware_normal_gradient_loss(gt_image, rend_normal, gt_normal, prior_normal_mask, edge_threshold=1): + # Define Sobel filters + 
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 8 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 8 + + # Compute gradients of rendered and ground truth normals + rend_grad_x = F.conv2d(rend_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + rend_grad_y = F.conv2d(rend_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + gt_grad_x = F.conv2d(gt_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + gt_grad_y = F.conv2d(gt_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + # Compute gradients of gt_image for edge detection + dI_dx = torch.cat([F.conv2d(gt_image[i].unsqueeze(0), sobel_x, padding=1) for i in range(gt_image.shape[0])]) + dI_dx = torch.mean(torch.abs(dI_dx), 1, keepdim=True) + dI_dy = torch.cat([F.conv2d(gt_image[i].unsqueeze(0), sobel_y, padding=1) for i in range(gt_image.shape[0])]) + dI_dy = torch.mean(torch.abs(dI_dy), 1, keepdim=True) + + # Compute edge strength + edge_strength = dI_dx + dI_dy + + # Create non-edge mask + non_edge_mask = (edge_strength < edge_threshold).float() + + # Compute loss for gradients + loss_x = F.mse_loss(rend_grad_x, gt_grad_x) + loss_y = F.mse_loss(rend_grad_y, gt_grad_y) + loss = loss_x + loss_y + + # Apply non-edge mask and prior_normal_mask + masked_loss = loss * non_edge_mask * prior_normal_mask + + # Normalize by the number of non-edge pixels + num_non_edge_pixels = torch.sum(non_edge_mask * prior_normal_mask) + 1e-6 + normalized_loss = torch.sum(masked_loss) / num_non_edge_pixels + + return normalized_loss + +def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, resolution_scales=[4, 2, 1]) + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + ema_depth_for_log = 0.0 + ema_normal_for_log = 0.0 + + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + all_cameras = scene.getTrainCameras() + if dataset.use_decoupled_appearance: + appearances = AppearanceModel(len(all_cameras)) + appearances.training_setup(opt) + else: + appearances = None + + for iteration in range(first_iter, opt.iterations + 1): + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + if iteration - 1 == 0: + scale = 4 + viewpoint_stack = scene.getTrainCameras(scale=scale) + elif iteration - 1 == 5000: + scale = 2 + viewpoint_stack = scene.getTrainCameras(scale=scale) + elif iteration - 1 >= 10000: + scale = 1 + viewpoint_stack = scene.getTrainCameras(scale=scale) + + viewpoint_idx = randint(0, len(viewpoint_stack)-1) + viewpoint_cam = viewpoint_stack[viewpoint_idx] + + # Set intervals for patch match + # intervals = [-2, -1, 1, 2] + # src_idxs = [viewpoint_idx+itv for itv in intervals if ((itv + viewpoint_idx > 0) and (itv + 
viewpoint_idx < len(viewpoint_stack)))] + # depth_loss = process_propagation(viewpoint_stack, viewpoint_cam, gaussians, pipe, background, iteration, opt, src_idxs) + render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=(iteration < opt.densify_until_iter)) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss_appearance(image, gt_image, appearances, viewpoint_idx) # use L1 loss for the transformed image if using decoupled appearance + ssim_value = fused_ssim(image.unsqueeze(0), gt_image.unsqueeze(0)) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value) + + # alpha loss + if opt.lambda_mask > 0: + opacity = 1 - render_pkg["rend_alpha"].clamp(1e-6, 1-1e-6) + bg = 1 - viewpoint_cam.gt_alpha_mask + mask_error = (- bg * torch.log(opacity)).mean() + loss += opt.lambda_mask * mask_error + + # regularization + lambda_normal = opt.lambda_normal if iteration > 7000 else 0.0 + lambda_depth = opt.propagation_begin if iteration > opt.propagation_begin else 0.0 + lambda_dist = opt.lambda_dist if iteration > 3000 else 0.0 + lambda_normal_prior = opt.lambda_normal_prior * (7000 - iteration) / 7000 if iteration < 7000 else opt.lambda_normal_prior + lambda_normal_gradient = opt.lambda_normal_gradient if iteration > 7000 else 0.0 + + depth_loss = torch.tensor(0.).to("cuda") + normal_loss = torch.tensor(0.).to("cuda") + normal_prior_loss = torch.tensor(0.).to("cuda") + + rend_dist = render_pkg["rend_dist"] + rend_depth = render_pkg["rend_depth"] + surf_depth = render_pkg["surf_depth"] + rend_alpha = render_pkg['rend_alpha'] + gt_mask = viewpoint_cam.gt_alpha_mask.mean(dim=0) + valid_pixel_count = gt_mask.sum() + + dist_error = rend_dist * gt_mask + dist_loss = lambda_dist * (dist_error.sum() / valid_pixel_count) + + if lambda_depth > 0 and viewpoint_cam.depth_prior is not None: + depth_error = 0.6 * (surf_depth - viewpoint_cam.depth_prior).abs() + \ + 0.4 * (rend_depth - viewpoint_cam.depth_prior).abs() + depth_mask = viewpoint_cam.depth_mask.unsqueeze(0) & viewpoint_cam.gt_alpha_mask + valid_depth_sum = depth_mask.sum() + 1e-5 + depth_loss += lambda_depth * (depth_error[depth_mask & ~torch.isnan(depth_error)].sum() / valid_depth_sum) + + # fix normal + rend_normal = render_pkg['rend_normal'] / rend_alpha.detach() + rend_normal = torch.nan_to_num(rend_normal, 0, 0) + surf_normal_expected = render_pkg['surf_normal_expected'] / rend_alpha.detach() + surf_normal_expected = torch.nan_to_num(surf_normal_expected, 0, 0) + surf_normal_median = render_pkg['surf_normal'] + + if lambda_normal > 0.0: + normal_error = 0.6 * (1 - F.cosine_similarity(rend_normal, surf_normal_median, dim=0)) + \ + 0.4 * (1 - F.cosine_similarity(rend_normal, surf_normal_expected, dim=0)) + normal_error = normal_error * gt_mask + normal_loss = lambda_normal * (normal_error.sum() / valid_pixel_count) + if lambda_normal_gradient > 0.0: + curvature_error = 0.6 *edge_aware_curvature_loss(gt_image, surf_normal_median, gt_mask) + \ + 0.4 * edge_aware_curvature_loss(gt_image, surf_normal_expected, gt_mask) + normal_loss += lambda_normal_gradient * curvature_error + + if lambda_normal_prior > 0 and viewpoint_cam.normal_prior is not None: + prior_normal = viewpoint_cam.normal_prior * (rend_alpha).detach() + prior_normal_mask = viewpoint_cam.normal_mask[0] + + normal_prior_error = (1 - 
F.cosine_similarity(prior_normal, rend_normal, dim=0)) + \ + (1 - F.cosine_similarity(prior_normal, surf_normal_expected, dim=0)) + normal_prior_error = normal_prior_error * gt_mask + normal_prior_error = ranking_loss(normal_prior_error[prior_normal_mask], + penalize_ratio=1.0, type='mean') + + normal_prior_loss = lambda_normal_prior * normal_prior_error + if lambda_normal_gradient > 0.0: + normal_prior_loss += lambda_normal_gradient * normal_gradient_loss(surf_normal_median, prior_normal) + + # loss + total_loss = loss + dist_loss + depth_loss + normal_loss + normal_prior_loss + + total_loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + ema_depth_for_log = 0.4 * (depth_loss.item() + dist_loss.item()) + 0.6 * ema_depth_for_log + ema_normal_for_log = 0.4 * (normal_loss.item() + normal_prior_loss.item()) + 0.6 * ema_normal_for_log + + + if iteration % 10 == 0: + loss_dict = { + "Loss": f"{ema_loss_for_log:.{5}f}", + "depth": f"{ema_depth_for_log:.{5}f}", + "normal": f"{ema_normal_for_log:.{5}f}", + "Points": f"{len(gaussians.get_xyz)}" + } + progress_bar.set_postfix(loss_dict) + + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + if tb_writer is not None: + tb_writer.add_scalar('train_loss_patches/dist_loss', ema_depth_for_log, iteration) + tb_writer.add_scalar('train_loss_patches/normal_loss', ema_normal_for_log, iteration) + + # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + + # Densification + if iteration < opt.densify_until_iter: + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], + radii[visibility_filter] * (render_pkg["transmittance_avg"][visibility_filter] > 0.01)) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter, None) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + prune_big_points = True if iteration > opt.opacity_reset_interval else False + gaussians.densify_and_prune(opt.densify_grad_threshold, opt.opacity_cull, scene.cameras_extent, prune_big_points) + print(f'Num gs after opacity prune: {len(gaussians.get_xyz)}') + + if iteration > opt.contribution_prune_from_iter and iteration % opt.contribution_prune_interval == 0: + if iteration % opt.opacity_reset_interval == opt.contribution_prune_interval: + print("Skipped Pruning for", iteration) + continue + prune_low_contribution_gaussians(gaussians, all_cameras, pipe, background, + K=1, prune_ratio=opt.contribution_prune_ratio) + print(f'Num gs after contribution prune: {len(gaussians.get_xyz)}') + + if iteration > opt.densify_from_iter and iteration % opt.split_interval == 0: + gaussians.split_big_points(opt.max_screen_size) + + if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + # visible = radii > 0 + # gaussians.optimizer.step(visible, radii.shape[0]) + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none = True) + if appearances is not None: + appearances.optimizer.step() + appearances.optimizer.zero_grad(set_to_none = True) + + if (iteration in checkpoint_iterations): + 
print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") + + with torch.no_grad(): + if network_gui.conn == None: + network_gui.try_connect(dataset.render_items) + while network_gui.conn != None: + try: + net_image_bytes = None + custom_cam, do_training, keep_alive, scaling_modifer, render_mode = network_gui.receive() + if custom_cam != None: + render_pkg = render(custom_cam, gaussians, pipe, background, scaling_modifer) + net_image = render_net_image(render_pkg, dataset.render_items, render_mode, custom_cam) + net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) + metrics_dict = { + "#": gaussians.get_opacity.shape[0], + "loss": ema_loss_for_log + # Add more metrics as needed + } + # Send the data + network_gui.send(net_image_bytes, dataset.source_path, metrics_dict) + if do_training and ((iteration < int(opt.iterations)) or not keep_alive): + break + except Exception as e: + # raise e + network_gui.conn = None + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +@torch.no_grad() +def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): + if tb_writer: + tb_writer.add_scalar('train_loss_patches/reg_loss', Ll1.item(), iteration) + tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) + tb_writer.add_scalar('iter_time', elapsed, iteration) + tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, + {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) + + for config in validation_configs: + if config['cameras'] and len(config['cameras']) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config['cameras']): + render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs) + image = torch.clamp(render_pkg["render"], 0.0, 1.0) + gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) + if tb_writer and (idx < 5): + from utils.general_utils import colormap + depth = render_pkg["surf_depth"] + norm = depth.max() + depth = depth / norm + depth = colormap(depth.cpu().numpy()[0], cmap='turbo') + tb_writer.add_images(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) + + try: + rend_alpha = render_pkg['rend_alpha'] + rend_normal = render_pkg["rend_normal"] * 0.5 + 0.5 + surf_normal = render_pkg["surf_normal"] * 
0.5 + 0.5 + tb_writer.add_images(config['name'] + "_view_{}/rend_normal".format(viewpoint.image_name), rend_normal[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/surf_normal".format(viewpoint.image_name), surf_normal[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/rend_alpha".format(viewpoint.image_name), rend_alpha[None], global_step=iteration) + + rend_dist = render_pkg["rend_dist"] + rend_dist = colormap(rend_dist.cpu().numpy()[0]) + tb_writer.add_images(config['name'] + "_view_{}/rend_dist".format(viewpoint.image_name), rend_dist[None], global_step=iteration) + except: + pass + + if iteration == testing_iterations[0]: + tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) + + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + + psnr_test /= len(config['cameras']) + l1_test /= len(config['cameras']) + print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) + if tb_writer: + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) + + torch.cuda.empty_cache() + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[1,7_000, 20_000, 30_000]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[1, 7_000, 20_000, 30_000]) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint) + + # All done + print("\nTraining complete.") \ No newline at end of file diff --git a/train_with_bg.py b/train_with_bg.py new file mode 100644 index 00000000..5a20feb4 --- /dev/null +++ b/train_with_bg.py @@ -0,0 +1,454 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
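train_progressive.py above trains coarse-to-fine by switching the camera set between resolution scales 4, 2 and 1 as optimization proceeds. A compact equivalent of the iteration checks in its training loop, using a hypothetical helper name:

def resolution_scale_for(iteration):
    # Schedule used by train_progressive.py: quarter resolution for the first
    # 5000 iterations, half resolution up to iteration 10000, full resolution after.
    if iteration <= 5000:
        return 4
    elif iteration <= 10000:
        return 2
    return 1

# viewpoint_stack = scene.getTrainCameras(scale=resolution_scale_for(iteration))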
+# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +from random import randint +from utils.loss_utils import l1_loss_appearance, ssim, l1_loss, ms_l1_loss +from gaussian_renderer import render, network_gui +import sys +import torch.nn.functional as F +from scene import Scene, GaussianModel, BgGaussianModel, AppearanceModel +from utils.general_utils import safe_state +from utils.patchmatch import process_propagation +import uuid +from tqdm import tqdm +from utils.image_utils import psnr, render_net_image +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +def prune_low_contribution_gaussians(gaussians, cameras, pipe, bg, K=5, prune_ratio=0.1): + top_list = [None, ] * K + for i, cam in enumerate(cameras): + trans = render(cam, gaussians, pipe, bg, record_transmittance=True, skip_geometric=True)["transmittance_avg"] + if top_list[0] is not None: + m = trans > top_list[0] + if m.any(): + for i in range(K - 1): + top_list[K - 1 - i][m] = top_list[K - 2 - i][m] + top_list[0][m] = trans[m] + else: + top_list = [trans.clone() for _ in range(K)] + + contribution = torch.stack(top_list, dim=-1).mean(-1) + tile = torch.quantile(contribution, prune_ratio) + prune_mask = contribution < tile + gaussians.prune_points(prune_mask) + torch.cuda.empty_cache() + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None , type='mean'): + error, indices = torch.sort(error) + # only sum relatively small errors + s_error = torch.index_select(error, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + if extra_weights is not None: + weights = torch.index_select(extra_weights, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + s_error = s_error * weights + + if type == 'mean': + return torch.mean(s_error) + elif type == 'sum': + return torch.sum(s_error) + +def normal_gradient_loss(rend_normal, gt_normal): + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 4 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 4 + + rend_grad_x = F.conv2d(rend_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + rend_grad_y = F.conv2d(rend_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + gt_grad_x = F.conv2d(gt_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + gt_grad_y = F.conv2d(gt_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + loss_x = F.mse_loss(rend_grad_x, gt_grad_x) + loss_y = F.mse_loss(rend_grad_y, gt_grad_y) + + return loss_x + loss_y + +def edge_aware_normal_gradient_loss(gt_image, rend_normal, gt_normal, prior_normal_mask, edge_threshold=1): + # Define Sobel filters + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 8 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(rend_normal.device) / 8 + + # Compute gradients of rendered and ground truth normals + rend_grad_x = F.conv2d(rend_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + rend_grad_y = F.conv2d(rend_normal, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) + + gt_grad_x = F.conv2d(gt_normal, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) + gt_grad_y = F.conv2d(gt_normal, sobel_y.repeat(3, 1, 1, 1), 
padding=1, groups=3) + + # Compute gradients of gt_image for edge detection + dI_dx = torch.cat([F.conv2d(gt_image[i].unsqueeze(0), sobel_x, padding=1) for i in range(gt_image.shape[0])]) + dI_dx = torch.mean(torch.abs(dI_dx), 1, keepdim=True) + dI_dy = torch.cat([F.conv2d(gt_image[i].unsqueeze(0), sobel_y, padding=1) for i in range(gt_image.shape[0])]) + dI_dy = torch.mean(torch.abs(dI_dy), 1, keepdim=True) + + # Compute edge strength + edge_strength = dI_dx + dI_dy + + # Create non-edge mask + non_edge_mask = (edge_strength < edge_threshold).float() + + # Compute loss for gradients + loss_x = F.mse_loss(rend_grad_x, gt_grad_x) + loss_y = F.mse_loss(rend_grad_y, gt_grad_y) + loss = loss_x + loss_y + + # Apply non-edge mask and prior_normal_mask + masked_loss = loss * non_edge_mask * prior_normal_mask + + # Normalize by the number of non-edge pixels + num_non_edge_pixels = torch.sum(non_edge_mask * prior_normal_mask) + 1e-6 + normalized_loss = torch.sum(masked_loss) / num_non_edge_pixels + + return normalized_loss + +def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + bg_gaussians = BgGaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, bg_gaussians) + all_cameras = scene.getTrainCameras() + + bg_gaussians.training_setup(opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + pbar = tqdm(range(5000), desc="Training Background", unit="iteration") + viewpoint_stack = None + for iteration in pbar: + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=False, + bg_gaussians=bg_gaussians, skip_geometric=True) + total_loss = ms_l1_loss(render_pkg["render"][None], viewpoint_cam.original_image.cuda()[None]) + total_loss.backward() + bg_gaussians.optimizer.step() + bg_gaussians.optimizer.zero_grad(set_to_none = True) + pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) + bg_gaussians.optimizer = None + # TODO: Trim unused background point + # prune_low_contribution_gaussians(bg_gaussians, all_cameras, pipe, background, + + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + ema_loss_for_log = 0.0 + ema_depth_for_log = 0.0 + ema_normal_for_log = 0.0 + + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + if dataset.use_decoupled_appearance: + appearances = AppearanceModel(len(all_cameras)) + appearances.training_setup(opt) + else: + appearances = None + + for iteration in range(first_iter, opt.iterations + 1): + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + bg_gaussians.oneupSHdegree() + + viewpoint_idx = randint(0, len(all_cameras)-1) + viewpoint_cam = all_cameras[viewpoint_idx] + render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=False, bg_gaussians=bg_gaussians, + skip_geometric=True) + # render_pkg = 
render(viewpoint_cam, gaussians, pipe, background, record_transmittance=(iteration < opt.densify_until_iter)) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss_appearance(image, gt_image, appearances, viewpoint_idx) # use L1 loss for the transformed image if using decoupled appearance + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + + rend_dist = render_pkg["rend_dist"] + lambda_dist = opt.lambda_dist if iteration > 3000 else 0.0 + dist_loss = lambda_dist * (rend_dist).mean() + + # regularization + if iteration > 15000: + render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=False) + lambda_normal = opt.lambda_normal if iteration > 15000 else 0.0 + lambda_depth = opt.propagation_begin if iteration > opt.propagation_begin else 0.0 + lambda_normal_prior = opt.lambda_normal_prior if iteration > 15000 else 0.0 + lambda_normal_gradient = opt.lambda_normal_gradient if iteration > 15000 else 0.0 + + depth_loss = torch.tensor(0.).to("cuda") + normal_loss = torch.tensor(0.).to("cuda") + normal_prior_loss = torch.tensor(0.).to("cuda") + + rend_depth = render_pkg["rend_depth"] + surf_depth = render_pkg["surf_depth"] + if lambda_depth > 0 and viewpoint_cam.depth_prior is not None: + depth_error = 0.6 * (surf_depth - viewpoint_cam.depth_prior).abs() + \ + 0.4 * (rend_depth - viewpoint_cam.depth_prior).abs() + depth_mask = viewpoint_cam.depth_mask.unsqueeze(0) + valid_depth_sum = depth_mask.sum() + 1e-5 + depth_loss += lambda_depth * (depth_error[depth_mask & ~torch.isnan(depth_error)].sum() / valid_depth_sum) + + rend_normal = render_pkg['rend_normal'] + surf_normal_median = render_pkg['surf_normal'] + surf_normal_expected = render_pkg['surf_normal_expected'] + rend_alpha = render_pkg['rend_alpha'] + + if lambda_normal > 0.0: + normal_error = 0.6 * (1 - F.cosine_similarity(rend_normal, surf_normal_median, dim=0)) + \ + 0.4 * (1 - F.cosine_similarity(rend_normal, surf_normal_expected, dim=0)) + normal_error = ranking_loss(normal_error.view(-1), penalize_ratio=1.0, type='mean') + normal_loss += lambda_normal * normal_error + + if lambda_normal_prior > 0 and dataset.w_normal_prior: + prior_normal = viewpoint_cam.normal_prior * (rend_alpha).detach() + prior_normal_mask = viewpoint_cam.normal_mask[0] + + normal_prior_error = 0.6 * (1 - F.cosine_similarity(prior_normal, rend_normal, dim=0)) + \ + 0.4 * (1 - F.cosine_similarity(prior_normal, surf_normal_expected, dim=0)) + normal_prior_error = ranking_loss(normal_prior_error[prior_normal_mask], + penalize_ratio=1.0, type='mean') + + normal_prior_loss = lambda_normal_prior * normal_prior_error + if lambda_normal_gradient > 0.0: + normal_prior_loss += lambda_normal_gradient * normal_gradient_loss(surf_normal_median, prior_normal) + else: + depth_loss = torch.tensor(0.).to("cuda") + normal_loss = torch.tensor(0.).to("cuda") + normal_prior_loss = torch.tensor(0.).to("cuda") + # loss + total_loss = loss + dist_loss + depth_loss + normal_loss + normal_prior_loss + + total_loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + ema_depth_for_log = 0.4 * depth_loss.item() + 0.6 * ema_depth_for_log + ema_normal_for_log = 0.4 * normal_loss.item() + 0.6 * ema_normal_for_log + + + if iteration % 10 == 0: + loss_dict = { + 
"Loss": f"{ema_loss_for_log:.{5}f}", + "depth": f"{ema_depth_for_log:.{5}f}", + "normal": f"{ema_normal_for_log:.{5}f}", + "Points": f"{len(gaussians.get_xyz)}" + } + progress_bar.set_postfix(loss_dict) + + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + if tb_writer is not None: + tb_writer.add_scalar('train_loss_patches/dist_loss', ema_depth_for_log, iteration) + tb_writer.add_scalar('train_loss_patches/normal_loss', ema_normal_for_log, iteration) + + # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + + # Densification + if iteration < opt.densify_until_iter: + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter, None) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + prune_big_points = True if iteration > opt.opacity_reset_interval else False + gaussians.densify_and_prune(opt.densify_grad_threshold, opt.opacity_cull, scene.cameras_extent, prune_big_points) + + # if render_pkg["transmittance_avg"] is not None: + # gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], + # radii[visibility_filter] * (render_pkg["transmittance_avg"][visibility_filter] > 0.01)) + # if iteration > 7000 and iteration % opt.split_interval == 0: + # gaussians.split_big_points(opt.max_screen_size) + + if iteration > opt.contribution_prune_from_iter and iteration % opt.contribution_prune_interval == 0: + if iteration % opt.opacity_reset_interval == opt.contribution_prune_interval or \ + iteration % opt.opacity_reset_interval == opt.split_interval: + print("Skipped Pruning for", iteration) + continue + prune_low_contribution_gaussians(gaussians, all_cameras, pipe, background, + K=1, prune_ratio=opt.contribution_prune_ratio) + print(f'Num gs after contribution prune: {len(gaussians.get_xyz)}') + + if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + # visible = radii > 0 + # gaussians.optimizer.step(visible, radii.shape[0]) + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none = True) + if appearances is not None: + appearances.optimizer.step() + appearances.optimizer.zero_grad(set_to_none = True) + + if (iteration in checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") + + with torch.no_grad(): + if network_gui.conn == None: + network_gui.try_connect(dataset.render_items) + while network_gui.conn != None: + try: + net_image_bytes = None + custom_cam, do_training, keep_alive, scaling_modifer, render_mode = network_gui.receive() + if custom_cam != None: + render_pkg = render(custom_cam, gaussians, pipe, background, scaling_modifer) + net_image = render_net_image(render_pkg, dataset.render_items, render_mode, custom_cam) + net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) + metrics_dict = { + "#": gaussians.get_opacity.shape[0], + "loss": ema_loss_for_log + # Add more metrics as needed + } + # Send the data + network_gui.send(net_image_bytes, dataset.source_path, 
metrics_dict) + if do_training and ((iteration < int(opt.iterations)) or not keep_alive): + break + except Exception as e: + # raise e + network_gui.conn = None + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +@torch.no_grad() +def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): + if tb_writer: + tb_writer.add_scalar('train_loss_patches/reg_loss', Ll1.item(), iteration) + tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) + tb_writer.add_scalar('iter_time', elapsed, iteration) + tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, + {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) + + for config in validation_configs: + if config['cameras'] and len(config['cameras']) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config['cameras']): + render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs) + image = torch.clamp(render_pkg["render"], 0.0, 1.0) + gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) + if tb_writer and (idx < 5): + from utils.general_utils import colormap + depth = render_pkg["surf_depth"] + norm = depth.max() + depth = depth / norm + depth = colormap(depth.cpu().numpy()[0], cmap='turbo') + tb_writer.add_images(config['name'] + "_view_{}/depth".format(viewpoint.image_name), depth[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) + + try: + rend_alpha = render_pkg['rend_alpha'] + rend_normal = render_pkg["rend_normal"] * 0.5 + 0.5 + surf_normal = render_pkg["surf_normal"] * 0.5 + 0.5 + tb_writer.add_images(config['name'] + "_view_{}/rend_normal".format(viewpoint.image_name), rend_normal[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/surf_normal".format(viewpoint.image_name), surf_normal[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/rend_alpha".format(viewpoint.image_name), rend_alpha[None], global_step=iteration) + + rend_dist = render_pkg["rend_dist"] + rend_dist = colormap(rend_dist.cpu().numpy()[0]) + tb_writer.add_images(config['name'] + "_view_{}/rend_dist".format(viewpoint.image_name), rend_dist[None], global_step=iteration) + except: + pass + + if iteration == testing_iterations[0]: + tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) + + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + + 
psnr_test /= len(config['cameras']) + l1_test /= len(config['cameras']) + print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) + if tb_writer: + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) + + torch.cuda.empty_cache() + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[1, 7_000, 20_000, 30_000]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[1, 7_000, 20_000, 30_000]) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint) + + # All done + print("\nTraining complete.") \ No newline at end of file diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 0af952a5..f462f441 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -3,7 +3,7 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. 
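The loadCam changes below attach optional foreground masks (args.w_mask) and monocular normal priors (args.w_normal_prior) to each camera. A minimal sketch of the normal-prior convention they use, with npy_path standing in for the actual prior file: maps stored in [0, 1] are remapped to [-1, 1] with a sign flip and rotated by the inverse of the camera rotation before being passed to the Camera object (resizing to the target resolution is omitted here).

import numpy as np
import torch

def load_normal_prior(npy_path, R):
    # npy_path: placeholder path to a [3, H, W] normal map saved in [0, 1];
    # R: the 3x3 rotation taken from cam_info.R.
    n = torch.tensor(np.load(npy_path)).float()            # [3, H, W], values in [0, 1]
    n = -(n * 2.0 - 1.0)                                   # to [-1, 1], flipped sign convention
    n = n.permute(1, 2, 0) @ torch.tensor(np.linalg.inv(R)).float()
    return n.permute(2, 0, 1)                              # back to [3, H, W]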
# # For inquiries contact george.drettakis@inria.fr @@ -13,7 +13,9 @@ import numpy as np from utils.general_utils import PILtoTorch from utils.graphics_utils import fov2focal - +from PIL import Image +import os +import torch.nn.functional as F WARNED = False def loadCam(args, id, cam_info, resolution_scale): @@ -45,13 +47,49 @@ def loadCam(args, id, cam_info, resolution_scale): gt_image = resized_image_rgb else: resized_image_rgb = PILtoTorch(cam_info.image, resolution) - loaded_mask = None + if args.w_mask: + mask_dir = os.path.join(os.path.dirname(os.path.dirname(cam_info.image_path)), args.w_mask) + if not os.path.isdir(mask_dir): + exit(f"Cannot find mask dir {mask_dir}") + mask_path = os.path.join(mask_dir, os.path.basename(cam_info.image_name) + '.png') + loaded_mask = Image.open(mask_path) + loaded_mask = PILtoTorch(loaded_mask, resolution) + else: + loaded_mask = None gt_image = resized_image_rgb - return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, - FoVx=cam_info.FovX, FoVy=cam_info.FovY, - image=gt_image, gt_alpha_mask=loaded_mask, - image_name=cam_info.image_name, uid=id, data_device=args.data_device) + if args.w_normal_prior: + import torch + # normal_path = cam_info.image_path.replace('images_4', args.w_normal_prior) + normal_path = os.path.join(os.path.dirname(os.path.dirname(cam_info.image_path)), args.w_normal_prior, os.path.basename(cam_info.image_path).split('.')[0]) + if os.path.exists(normal_path+ '.npy'): + _normal = torch.tensor(np.load(normal_path+ '.npy')) + _normal = - (_normal * 2 - 1) + resized_normal = F.interpolate(_normal.unsqueeze(0), size=resolution[::-1], mode='bicubic') + _normal = resized_normal.squeeze(0) + # normalize normal + _normal = _normal.permute(1, 2, 0) @ (torch.tensor(np.linalg.inv(cam_info.R)).float()) + _normal = _normal.permute(2, 0, 1) + elif os.path.exists(normal_path+ '.png'): + _normal = Image.open(normal_path+ '.png') + resized_normal = PILtoTorch(_normal, resolution) + resized_normal = resized_normal[:3] + _normal = - (resized_normal * 2 - 1) + # normalize normal + _normal = _normal.permute(1, 2, 0) @ (torch.tensor(np.linalg.inv(cam_info.R)).float()) + _normal = _normal.permute(2, 0, 1) + else: + print(f"Cannot find normal {normal_path}.png") + _normal = None + else: + _normal = None + + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + image=gt_image, normal=_normal, gt_alpha_mask=loaded_mask, + image_name=cam_info.image_name, uid=id, + principal_point_ndc=cam_info.principal_point_ndc, + data_device=args.data_device) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py index b4627d83..78a65568 100644 --- a/utils/graphics_utils.py +++ b/utils/graphics_utils.py @@ -3,7 +3,7 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. 
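The graphics_utils.py additions below (generate_K, getProjectionMatrixShift) build pixel-space intrinsics and a shifted view frustum from a principal point given in normalized image coordinates. A small worked example of the conversion they rely on: a 90 degree horizontal FoV at 800 px width yields focal_x = 400, and a centered principal point of (0.5, 0.5) gives zero offset, so the shifted projection reduces to the standard one.

import math

width, height = 800, 600
fovX = math.radians(90.0)
focal_x = width / (2 * math.tan(fovX / 2))   # fov2focal(fovX, width) -> 400.0
cx = width * 0.5                             # principal_point_ndc[0] = 0.5
offset_x = (cx - width / 2) / focal_x        # 0.0 -> no frustum shift (scaled by znear in the full helper)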
# # For inquiries contact george.drettakis@inria.fr @@ -13,6 +13,13 @@ import math import numpy as np from typing import NamedTuple +try: + from gaussianpro import propagate + propagation_installed = True +except Exception: + propagation_installed = False + print("gaussianpro not installed; depth/normal propagation is disabled") + class BasicPointCloud(NamedTuple): points : np.array @@ -70,8 +77,190 @@ def getProjectionMatrix(znear, zfar, fovX, fovY): P[2, 3] = -(zfar * znear) / (zfar - znear) return P +def generate_K(width, height, fovX, fovY, principal_point_ndc): + # Calculate focal lengths + focal_x = fov2focal(fovX, width) + focal_y = fov2focal(fovY, height) + + # Calculate principal point + cx = width * principal_point_ndc[0] + cy = height * principal_point_ndc[1] + + # Create the K matrix + K = torch.zeros(3, 3) + K[0, 0] = focal_x + K[1, 1] = focal_y + K[0, 2] = cx + K[1, 2] = cy + K[2, 2] = 1.0 + + return K + +def getProjectionMatrixShift(znear, zfar, fovX, fovY, width, height, principal_point_ndc): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + # the origin is at the center of the image plane + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + # shift the frame window due to the non-zero principal point offsets + cx = width * principal_point_ndc[0] + cy = height * principal_point_ndc[1] + focal_x = fov2focal(fovX, width) + focal_y = fov2focal(fovY, height) + offset_x = cx - (width / 2) + offset_x = (offset_x / focal_x) * znear + offset_y = cy - (height / 2) + offset_y = (offset_y / focal_y) * znear + + top = top + offset_y + left = left + offset_x + right = right + offset_x + bottom = bottom + offset_y + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + def fov2focal(fov, pixels): return pixels / (2 * math.tan(fov / 2)) def focal2fov(focal, pixels): - return 2*math.atan(pixels/(2*focal)) \ No newline at end of file + return 2*math.atan(pixels/(2*focal)) + + +def patchmatch_propagation(viewpoint_cam, rendered_depth, rendered_normal, viewpoint_stack, src_idxs, patch_size): + depth_min = 0.1 + depth_max = 80 + + images = list() + intrinsics = list() + poses = list() + depth_intervals = list() + + images.append((viewpoint_cam.original_image * 255).permute((1, 2, 0)).to(torch.uint8)) + intrinsics.append(viewpoint_cam.K) + poses.append(viewpoint_cam.world_view_transform.transpose(0, 1)) + depth_interval = torch.tensor([depth_min, (depth_max-depth_min)/192.0, 192.0, depth_max]) + depth_intervals.append(depth_interval) + + depth = rendered_depth.unsqueeze(-1) + depth = depth.squeeze(0) + normal = rendered_normal.permute([1, 2, 0]) + + + for idx, src_idx in enumerate(src_idxs): + src_viewpoint = viewpoint_stack[src_idx] + images.append((src_viewpoint.original_image * 255).permute((1, 2, 0)).to(torch.uint8)) + intrinsics.append(src_viewpoint.K) + poses.append(src_viewpoint.world_view_transform.transpose(0, 1)) + depth_intervals.append(depth_interval) + + images = torch.stack(images) + intrinsics = torch.stack(intrinsics) + poses = torch.stack(poses) + depth_intervals = torch.stack(depth_intervals) + + results = propagate(images, intrinsics, poses, depth, normal, depth_intervals, patch_size) + propagated_depth = results[0].to(rendered_depth.device) + propagated_normal = 
results[1:4].to(rendered_depth.device).permute(1, 2, 0) + + return propagated_depth, propagated_normal + + +def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src, thre1=1, thre2=0.01): + batch, height, width = depth_ref.shape + y_ref, x_ref = torch.meshgrid(torch.arange(0, height).to(depth_ref.device), torch.arange(0, width).to(depth_ref.device)) + x_ref = x_ref.unsqueeze(0).repeat(batch, 1, 1) + y_ref = y_ref.unsqueeze(0).repeat(batch, 1, 1) + inputs = [depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src] + outputs = reproject_with_depth(*inputs) + depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = outputs + # check |p_reproj-p_1| < 1 + dist = torch.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) + + # check |d_reproj-d_1| / d_1 < 0.01 + depth_diff = torch.abs(depth_reprojected - depth_ref) + relative_depth_diff = depth_diff / depth_ref + + mask = torch.logical_and(dist < thre1, relative_depth_diff < thre2) + depth_reprojected[~mask] = 0 + + return mask, depth_reprojected, x2d_src, y2d_src, relative_depth_diff + + +def bilinear_sampler(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = torch.nn.functional.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +# project the reference point cloud into the source view, then project back +def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): + batch, height, width = depth_ref.shape + + ## step1. project reference pixels to the source view + # reference view x, y + y_ref, x_ref = torch.meshgrid(torch.arange(0, height).to(depth_ref.device), torch.arange(0, width).to(depth_ref.device)) + x_ref = x_ref.unsqueeze(0).repeat(batch, 1, 1) + y_ref = y_ref.unsqueeze(0).repeat(batch, 1, 1) + x_ref, y_ref = x_ref.reshape(batch, -1), y_ref.reshape(batch, -1) + # reference 3D space + + A = torch.inverse(intrinsics_ref) + B = torch.stack((x_ref, y_ref, torch.ones_like(x_ref).to(x_ref.device)), dim=1) * depth_ref.reshape(batch, 1, -1) + xyz_ref = torch.matmul(A, B) + + # source 3D space + xyz_src = torch.matmul(torch.matmul(torch.inverse(extrinsics_src), extrinsics_ref), + torch.cat((xyz_ref, torch.ones_like(x_ref).to(x_ref.device).unsqueeze(1)), dim=1))[:, :3] + # source view x, y + K_xyz_src = torch.matmul(intrinsics_src, xyz_src) + xy_src = K_xyz_src[:, :2] / K_xyz_src[:, 2:3] + + ## step2. 
reproject the source view points with source view depth estimation + # find the depth estimation of the source view + x_src = xy_src[:, 0].reshape([batch, height, width]).float() + y_src = xy_src[:, 1].reshape([batch, height, width]).float() + + # print(x_src, y_src) + sampled_depth_src = bilinear_sampler(depth_src.view(batch, 1, height, width), torch.stack((x_src, y_src), dim=-1).view(batch, height, width, 2)) + + # source 3D space + # NOTE that we should use sampled source-view depth_here to project back + xyz_src = torch.matmul(torch.inverse(intrinsics_src), + torch.cat((xy_src, torch.ones_like(x_ref).to(x_ref.device).unsqueeze(1)), dim=1) * sampled_depth_src.reshape(batch, 1, -1)) + # reference 3D space + xyz_reprojected = torch.matmul(torch.matmul(torch.inverse(extrinsics_ref), extrinsics_src), + torch.cat((xyz_src, torch.ones_like(x_ref).to(x_ref.device).unsqueeze(1)), dim=1))[:, :3] + # source view x, y, depth + depth_reprojected = xyz_reprojected[:, 2].reshape([batch, height, width]).float() + K_xyz_reprojected = torch.matmul(intrinsics_ref, xyz_reprojected) + xy_reprojected = K_xyz_reprojected[:, :2] / K_xyz_reprojected[:, 2:3] + x_reprojected = xy_reprojected[:, 0].reshape([batch, height, width]).float() + y_reprojected = xy_reprojected[:, 1].reshape([batch, height, width]).float() + + return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src \ No newline at end of file diff --git a/utils/loss_utils.py b/utils/loss_utils.py index 7ef1d77d..f98cbd83 100644 --- a/utils/loss_utils.py +++ b/utils/loss_utils.py @@ -14,12 +14,93 @@ from torch.autograd import Variable from math import exp +def edge_aware_curvature_loss(I, D, mask=None): + # Define Sobel kernels + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0).to(I.device) / 4 + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).float().unsqueeze(0).unsqueeze(0).to(I.device) / 4 + + # Compute derivatives of D + dD_dx = torch.cat([F.conv2d(D[i].unsqueeze(0), sobel_x, padding=1) for i in range(D.shape[0])]) + dD_dy = torch.cat([F.conv2d(D[i].unsqueeze(0), sobel_y, padding=1) for i in range(D.shape[0])]) + + # Compute derivatives of I + dI_dx = torch.cat([F.conv2d(I[i].unsqueeze(0), sobel_x, padding=1) for i in range(I.shape[0])]) + dI_dx = torch.mean(torch.abs(dI_dx), 0, keepdim=True) + dI_dy = torch.cat([F.conv2d(I[i].unsqueeze(0), sobel_y, padding=1) for i in range(I.shape[0])]) + dI_dy = torch.mean(torch.abs(dI_dy), 0, keepdim=True) + + # Compute weights + weights_x = (dI_dx - 1) ** 500 + weights_y = (dI_dy - 1) ** 500 + + # Compute losses + loss_x = torch.abs(dD_dx) * weights_x + loss_y = torch.abs(dD_dy) * weights_y + + # Apply mask to losses + if mask is not None: + # Ensure mask is on the correct device and has correct dimensions + mask = mask.to(I.device) + loss_x = loss_x * mask + loss_y = loss_y * mask + + # Count valid pixels + valid_pixel_count = mask.sum() + + # Compute the mean loss only over valid pixels + if valid_pixel_count.item() > 0: + return (loss_x.sum() + loss_y.sum()) / valid_pixel_count + else: + # Handle the case where no valid pixels exist + return torch.tensor(0.0, device=I.device, requires_grad=True) + else: + # If no mask is provided, calculate the mean over all pixels + return (loss_x + loss_y).mean() + def l1_loss(network_output, gt): return torch.abs((network_output - gt)).mean() +def ms_l1_loss(network_output, gt, scales=[1, 2, 4]): + total_loss = 0 + weights = [1.0, 0.5, 0.25] # Weights for different scales, adjust as needed + + for 
scale, weight in zip(scales, weights): + if scale == 1: + # Original resolution + total_loss += weight * l1_loss(network_output, gt) + else: + # Downsampled resolution + scaled_output = F.interpolate(network_output, scale_factor=1/scale, mode='bilinear', align_corners=False) + scaled_gt = F.interpolate(gt, scale_factor=1/scale, mode='bilinear', align_corners=False) + total_loss += weight * l1_loss(scaled_output, scaled_gt) + + return total_loss + def l2_loss(network_output, gt): return ((network_output - gt) ** 2).mean() +def l1_loss_appearance(image, gt_image, appearances, view_idx): + if appearances is None: + return l1_loss(image, gt_image) + else: + appearance_embedding = appearances.get_embedding(view_idx) + # center crop the image + origH, origW = image.shape[1:] + H = origH // 32 * 32 + W = origW // 32 * 32 + left = origW // 2 - W // 2 + top = origH // 2 - H // 2 + crop_image = image[:, top:top+H, left:left+W] + crop_gt_image = gt_image[:, top:top+H, left:left+W] + + # down sample the image + crop_image_down = torch.nn.functional.interpolate(crop_image[None], size=(H//32, W//32), mode="bilinear", align_corners=True)[0] + + crop_image_down = torch.cat([crop_image_down, appearance_embedding[None].repeat(H//32, W//32, 1).permute(2, 0, 1)], dim=0)[None] + mapping_image = appearances.appearance_network(crop_image_down) + transformed_image = mapping_image * crop_image + return l1_loss(transformed_image, crop_gt_image) + def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) return gauss / gauss.sum() @@ -64,6 +145,7 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): C1 = 0.01 ** 2 C2 = 0.03 ** 2 + # C1 = C2 = 0.01 ** 2 ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) diff --git a/utils/mesh_utils.py b/utils/mesh_utils.py index e9b1524f..69a2690b 100644 --- a/utils/mesh_utils.py +++ b/utils/mesh_utils.py @@ -92,8 +92,8 @@ def clean(self): self.depthmaps = [] # self.alphamaps = [] self.rgbmaps = [] - # self.normals = [] - # self.depth_normals = [] + self.normals = [] + self.depth_normals = [] self.viewpoint_stack = [] @torch.no_grad() @@ -113,8 +113,8 @@ def reconstruction(self, viewpoint_stack): self.rgbmaps.append(rgb.cpu()) self.depthmaps.append(depth.cpu()) # self.alphamaps.append(alpha.cpu()) - # self.normals.append(normal.cpu()) - # self.depth_normals.append(depth_normal.cpu()) + self.normals.append(normal.cpu()) + self.depth_normals.append(depth_normal.cpu()) # self.rgbmaps = torch.stack(self.rgbmaps, dim=0) # self.depthmaps = torch.stack(self.depthmaps, dim=0) @@ -165,11 +165,11 @@ def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, # if we have mask provided, use it if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None): - depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0 + depth[(self.viewpoint_stack[i].gt_alpha_mask.mean(dim=0)[None] < 0.5)] = 0 # make open3d rgbd rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( - o3d.geometry.Image(np.asarray(np.clip(rgb.permute(1,2,0).cpu().numpy(), 0.0, 1.0) * 255, order="C", dtype=np.uint8)), + o3d.geometry.Image(np.asarray(rgb.permute(1,2,0).cpu().numpy() * 255, order="C", dtype=np.uint8)), o3d.geometry.Image(np.asarray(depth.permute(1,2,0).cpu().numpy(), order="C")), depth_trunc = depth_trunc, convert_rgb_to_intensity=False, depth_scale = 1.0 @@ -290,6 +290,6 @@ def export_image(self, path): gt = 
viewpoint_cam.original_image[0:3, :, :] save_img_u8(gt.permute(1,2,0).cpu().numpy(), os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) save_img_u8(self.rgbmaps[idx].permute(1,2,0).cpu().numpy(), os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) - save_img_f32(self.depthmaps[idx][0].cpu().numpy(), os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff")) - # save_img_u8(self.normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + ".png")) - # save_img_u8(self.depth_normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png")) + # save_img_f32(self.depthmaps[idx][0].cpu().numpy(), os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff")) + save_img_u8(self.normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + ".png")) + save_img_u8(self.depth_normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png")) diff --git a/utils/patchmatch.py b/utils/patchmatch.py new file mode 100644 index 00000000..dd1a65a5 --- /dev/null +++ b/utils/patchmatch.py @@ -0,0 +1,67 @@ +import torch +from gaussian_renderer import render +from utils.graphics_utils import patchmatch_propagation, check_geometric_consistency, propagation_installed + + +def process_propagation(viewpoint_stack, viewpoint_cam, gaussians, pipe, background, iteration, opt, src_idxs): + if not propagation_installed: + return + + with torch.no_grad(): + if iteration > opt.propagation_begin and iteration < opt.propagation_after and (iteration % opt.propagation_interval == 0): + render_pkg = render(viewpoint_cam, gaussians, pipe, background) + projected_depth = render_pkg["rend_depth"] / render_pkg['rend_alpha'] + rendered_normal = render_pkg["rend_normal"] / render_pkg['rend_alpha'] if viewpoint_cam.normal_prior is None else viewpoint_cam.normal_prior + R_w2c = torch.tensor(viewpoint_cam.R.T).cuda().to(torch.float32) + rendered_normal_cam = (R_w2c @ rendered_normal.view(3, -1)).view(3, viewpoint_cam.image_height, viewpoint_cam.image_width) + + # get the propagated depth + propagated_depth, propagated_normal = patchmatch_propagation(viewpoint_cam, projected_depth, rendered_normal_cam, viewpoint_stack, src_idxs, opt.patch_size) + propagated_normal = propagated_normal.permute(2, 0, 1) + valid_mask = propagated_depth != 300 + + # calculate the abs rel depth error of the propagated depth and rendered depth + abs_rel_error = torch.abs(propagated_depth - projected_depth) / propagated_depth + abs_rel_error_threshold = opt.depth_error_max_threshold - (opt.depth_error_max_threshold - opt.depth_error_min_threshold) * (iteration - opt.propagation_begin) / (opt.propagation_after - opt.propagation_begin) + + #for waymo, quantile 0.6 + error_mask = (abs_rel_error > abs_rel_error_threshold) + + # calculate the geometric consistency + ref_K = viewpoint_cam.K + ref_pose = viewpoint_cam.world_view_transform.transpose(0, 1).inverse() + geometric_counts = None + for idx, src_idx in enumerate(src_idxs): + src_viewpoint = viewpoint_stack[src_idx] + #c2w + src_pose = src_viewpoint.world_view_transform.transpose(0, 1).inverse() + src_K = src_viewpoint.K + src_render_pkg = render(src_viewpoint, gaussians, pipe, background) + src_projected_depth = src_render_pkg["rend_depth"] / src_render_pkg['rend_alpha'] + src_rendered_normal = src_render_pkg["rend_normal"] / src_render_pkg['rend_alpha'] if src_viewpoint.normal_prior is 
None else src_viewpoint.normal_prior + R_w2c = torch.tensor(src_viewpoint.R.T).cuda().to(torch.float32) + src_rendered_normal_cam = (R_w2c @ src_rendered_normal.view(3, -1)).view(3, src_viewpoint.image_height, src_viewpoint.image_width) + + src_depth, _ = patchmatch_propagation(src_viewpoint, src_projected_depth, src_rendered_normal_cam, viewpoint_stack, src_idxs, opt.patch_size) + mask, depth_reprojected, x2d_src, y2d_src, relative_depth_diff = check_geometric_consistency(propagated_depth.unsqueeze(0), ref_K.unsqueeze(0), + ref_pose.unsqueeze(0), src_depth.unsqueeze(0), + src_K.unsqueeze(0), src_pose.unsqueeze(0), thre1=5, thre2=0.05) + if geometric_counts is None: + geometric_counts = mask.to(torch.uint8) + else: + geometric_counts += mask.to(torch.uint8) + + cost = geometric_counts.squeeze() + cost_mask = cost >= 1 + + propagated_mask = valid_mask & error_mask & cost_mask + propagated_mask = propagated_mask.squeeze(0) + depth_mask = (valid_mask & cost_mask).squeeze(0) + projected_depth = projected_depth.squeeze(0) + + viewpoint_cam.depth_prior = projected_depth + viewpoint_cam.depth_mask = depth_mask + + if propagated_mask.sum() > 100: + gaussians.densify_from_depth_propagation(viewpoint_cam, propagated_depth, rendered_normal.permute(1, 2, 0), propagated_mask) + # gaussians.densify_from_depth_propagation(viewpoint_cam, propagated_depth, propagated_normal, propagated_mask) \ No newline at end of file
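
The generate_K and getProjectionMatrixShift helpers added to utils/graphics_utils.py extend the standard centered frustum to cameras whose principal point is off the image center (principal_point_ndc is the principal point expressed as a fraction of width and height). With a centered principal point the shifted matrix should coincide with getProjectionMatrix. A minimal sanity-check sketch, assuming the functions exactly as in the hunk above; the camera numbers are made up:

    import math
    import torch
    from utils.graphics_utils import (generate_K, getProjectionMatrix,
                                      getProjectionMatrixShift, fov2focal)

    width, height = 800, 600                      # hypothetical image size
    fovX, fovY = math.radians(60), math.radians(47)
    znear, zfar = 0.01, 100.0

    # Centered principal point: the shifted frustum reduces to the standard one.
    P_ref = getProjectionMatrix(znear, zfar, fovX, fovY)
    P_centered = getProjectionMatrixShift(znear, zfar, fovX, fovY, width, height, (0.5, 0.5))
    assert torch.allclose(P_ref, P_centered, atol=1e-6)

    # Off-center principal point: K picks up the shifted cx/cy directly.
    K = generate_K(width, height, fovX, fovY, (0.52, 0.49))
    assert abs(K[0, 2].item() - 0.52 * width) < 1e-3    # cx = 416
    assert abs(K[1, 2].item() - 0.49 * height) < 1e-3   # cy = 294
    assert abs(K[0, 0].item() - fov2focal(fovX, width)) < 1e-3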
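
check_geometric_consistency and reproject_with_depth implement the usual MVS forward-backward reprojection test: a reference pixel is lifted to 3D with its depth, projected into the source view, re-lifted with the sampled source depth, and projected back; the pixel is kept only if it lands within thre1 pixels of where it started and its reprojected depth is within a relative error of thre2. A tiny CPU sanity check under the simplest possible assumption (two identical, made-up cameras seeing the same constant depth), which should pass everywhere:

    import torch
    from utils.graphics_utils import check_geometric_consistency

    B, H, W = 1, 8, 10
    depth = torch.full((B, H, W), 2.0)                # constant 2 m depth
    K = torch.tensor([[[100.0, 0.0, 5.0],
                       [0.0, 100.0, 4.0],
                       [0.0, 0.0, 1.0]]])             # hypothetical intrinsics
    pose = torch.eye(4).unsqueeze(0)                  # camera frame == world frame

    mask, depth_reproj, _, _, rel_diff = check_geometric_consistency(
        depth, K, pose, depth.clone(), K, pose, thre1=1, thre2=0.01)

    assert mask.all()                                 # every pixel is consistent
    assert torch.allclose(depth_reproj, depth, atol=1e-3)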
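
In edge_aware_curvature_loss the Sobel responses are divided by 4, so for images in [0, 1] the mean absolute image gradient also lies in [0, 1], and the (dI - 1) ** 500 factor (with dI that mean absolute response) acts as a very sharp edge-stopping weight: close to 1 on textureless regions and effectively 0 wherever the image has any appreciable gradient, so the smoothness penalty on D is only enforced away from image edges. A tiny numeric illustration of that weight:

    import torch

    g = torch.tensor([0.0, 0.02, 0.05, 0.2])   # mean |Sobel(I)| per pixel
    w = (g - 1) ** 500                          # same weight as in the loss
    print(w)                                    # roughly [1.0, 4e-05, 7e-12, 0.0]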
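
l1_loss_appearance only transforms a center crop whose sides are multiples of 32, because the appearance network consumes a 32x downsampled copy of the render concatenated with the per-view embedding and predicts a per-pixel multiplicative correction that is applied to the crop before the L1 against the ground truth. The crop arithmetic, with a hypothetical input resolution:

    origH, origW = 545, 980                       # hypothetical render size
    H, W = origH // 32 * 32, origW // 32 * 32     # 544, 960
    top, left = origH // 2 - H // 2, origW // 2 - W // 2
    print(H, W, top, left)                        # 544 960 0 10
    print(H // 32, W // 32)                       # 17 30  (grid fed to the appearance network)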
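
In process_propagation the gate on which pixels may spawn new Gaussians loosens over training: the relative-depth-error threshold starts at opt.depth_error_max_threshold when iteration equals opt.propagation_begin and decays linearly to opt.depth_error_min_threshold at opt.propagation_after, so early on only pixels where the propagated and rendered depths disagree badly are densified, and later smaller disagreements qualify as well. The same schedule as a standalone function (the numbers below are purely illustrative):

    def abs_rel_threshold(iteration, begin, after, t_min, t_max):
        # Linear decay from t_max (at iteration == begin) to t_min (at iteration == after).
        frac = (iteration - begin) / (after - begin)
        return t_max - (t_max - t_min) * frac

    # Illustrative values only; the real ones come from the OptimizationParams flags.
    for it in (9000, 12000, 15000):
        print(it, abs_rel_threshold(it, begin=9000, after=15000, t_min=0.8, t_max=1.0))
    # 9000 -> 1.0, 12000 -> 0.9, 15000 -> 0.8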
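
process_propagation is meant to be called from the training loop once per iteration; it is a no-op unless gaussianpro is importable and the iteration falls on a propagation_interval step inside (propagation_begin, propagation_after). A hedged sketch of a call site (the real wiring lives in train.py); nearest_ids, per-camera source-view indices taken e.g. from COLMAP covisibility, is an assumption here, not something this patch defines:

    from utils.patchmatch import process_propagation

    def run_propagation_step(viewpoint_stack, viewpoint_cam, nearest_ids,
                             gaussians, pipe, background, iteration, opt, k=4):
        # A few covisible source views drive both PatchMatch propagation and
        # the multi-view geometric-consistency filter.
        src_idxs = nearest_ids[:k]
        process_propagation(viewpoint_stack, viewpoint_cam, gaussians, pipe,
                            background, iteration, opt, src_idxs)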