[WIP] Add latest Gaussian Splatting techniques to 2DGS #151

Open
wants to merge 84 commits into
base: main
84 commits
6c298b0
add supportment for non-centralized pinhole cameras data
yulunwu0108 Jun 12, 2024
15b3aa8
feat: Add normal loader
hugoycj Jul 23, 2024
ee41a99
Merge branch 'main' into leonwu0108/main
hugoycj Jul 23, 2024
9c49968
feat: Add ranking normal loss from Wonder3D
hugoycj Aug 30, 2024
79a475b
feat: Add normal gradient loss
hugoycj Aug 30, 2024
da4d038
feat: Add low contribution gaussian pruning
hugoycj Sep 1, 2024
340792a
fix: Update transmittance recording for prune
hugoycj Sep 1, 2024
da727ef
fix: Update normal_prior, pruning config
hugoycj Sep 1, 2024
33095da
feat: Add appearance network model and update optimization params
hugoycj Sep 4, 2024
c0d1c27
feat: Add edge-aware normal gradient loss and appearance model
hugoycj Sep 4, 2024
0df56b5
fix: Update submodule
hugoycj Sep 4, 2024
25ab52a
Merge branch 'main' of https://github.com/hugoycj/2d-gaussian-splatti…
hugoycj Sep 4, 2024
69fd664
fix: Update default resolution
hugoycj Sep 4, 2024
99dba7b
fix: Update diff-surfel-rasterization and appearance model
hugoycj Sep 8, 2024
b48e332
fix: Tune hyperparameters
hugoycj Sep 11, 2024
d799e51
update
Sep 12, 2024
a56dbba
feat: Enhance propagation and normal rendering in patchmatch
hugoycj Sep 16, 2024
4fd89f0
feat: Add depth rendering and camera intrinsic calculation
hugoycj Sep 16, 2024
3273705
feat: Add normal to rotation conversion and improve depth propagation
hugoycj Sep 16, 2024
6d9c326
feat: Enhance PatchMatch with normal initialization and improved plan…
hugoycj Sep 16, 2024
e1926b0
fix: Clean egg
hugoycj Sep 16, 2024
9ed89df
feat: Add pixelgs densification
hugoycj Sep 22, 2024
51661a9
Merge branch 'feat/propagation' into main
hugoycj Sep 22, 2024
a883585
Merge branch 'feat/pixelgs' into main
hugoycj Sep 22, 2024
e9f0577
feat: Add big points splitting
hugoycj Sep 22, 2024
dd521f1
feat: Improve Gaussian model pruning and splitting logic
hugoycj Sep 23, 2024
464bdce
fix: Replace distortion loss with modified normal consistency loss
hugoycj Sep 23, 2024
963dca3
Merge pull request #2 from hugoycj/feat/split_big_points
hugoycj Sep 23, 2024
7ca5c22
fix: Tune normal prior and trimgs hyperparameters
hugoycj Sep 23, 2024
484303a
feat: Add progressive training
hugoycj Sep 25, 2024
106ec91
fix: update config
hugoycj Sep 25, 2024
455204d
feat: Remove opencv dependency
hugoycj Sep 27, 2024
e65c3f2
fix: Add depth, depth_mask for depth loss
hugoycj Sep 27, 2024
33a137c
fix: Update diff-surfel-rasterization
hugoycj Sep 27, 2024
2419d68
fix: Update normal initialization
hugoycj Sep 27, 2024
949e8c7
fix: Improve initialization speed by fusion operation
hugoycj Sep 27, 2024
5243b9f
feat: Enhance depth handling and bug fixes
hugoycj Sep 27, 2024
4fc87f4
Merge pull request #3 from hugoycj/feat/torch_propagation
hugoycj Sep 27, 2024
0818608
fix: Update diff-surfel-rasterization
hugoycj Sep 29, 2024
a4e9312
feat: Add normal guided 2dgs init
hugoycj Sep 29, 2024
90e3687
fix: Update normal-guided init
hugoycj Sep 30, 2024
d02d891
Merge branch 'hbb1:main' into main
hugoycj Oct 10, 2024
7fcf600
fix: Update train_progressive
hugoycj Oct 13, 2024
8b3f0b7
fix: Update fetchPly to load point cloud without normals
hugoycj Oct 13, 2024
465ba80
feat: Add background gaussian
hugoycj Oct 13, 2024
0b28dce
feat: Update BgGaussianModel training setup
hugoycj Oct 14, 2024
66feecf
feat: Add bg_gaussians export and import
hugoycj Oct 14, 2024
f6a348d
feat: Add bg_gaussians pretraining and finetune
hugoycj Oct 14, 2024
b2b4447
feat: Add ms_l1_loss
hugoycj Oct 14, 2024
1c0fd9d
feat: Add skip_geometric flag to improve efficiency
hugoycj Oct 14, 2024
3192088
clean: Remove bg gaussian code from train.py
hugoycj Oct 16, 2024
2b8e156
fix: Update skip_geometric flag
hugoycj Oct 16, 2024
7f37cbb
feat: Optimize Gaussian model training
hugoycj Oct 16, 2024
0faddba
Merge pull request #4 from hugoycj/feat/bg_gaussians
hugoycj Oct 16, 2024
3c0e0fc
Merge branch 'hbb1:main' into main
hugoycj Oct 16, 2024
dbcd181
feat: Add w_mask flag to load extra mask
hugoycj Oct 16, 2024
6b63529
feat: Add alpha loss
hugoycj Oct 16, 2024
ab2cadc
fix: Update nerf evaluation scripts
hugoycj Oct 16, 2024
04323e9
Merge pull request #5 from hugoycj/feat/alpha_loss
hugoycj Oct 16, 2024
ff9149e
fix: Update lambda_mask
hugoycj Oct 16, 2024
7bf6498
Merge branch 'feat/alpha_loss' into main
hugoycj Oct 16, 2024
fb8cd9c
fix: Add default principal_point_ndc
hugoycj Oct 16, 2024
608afe0
fix: Fix masked normal loss
hugoycj Oct 16, 2024
a474861
feat: Enable normals and depth normals processing
hugoycj Oct 16, 2024
e5641a3
fix: Update train_progressive
hugoycj Oct 16, 2024
4e414e8
feat: Add training scripts
hugoycj Oct 16, 2024
965fc37
fix: Add background model init
hugoycj Oct 16, 2024
4615fa5
fix: Add init_normal scripts
hugoycj Oct 16, 2024
e7b4bbb
fix: Update normal_prior_loss
hugoycj Oct 16, 2024
de9d8fe
fix: Update normal loading
hugoycj Oct 16, 2024
b194160
docs: Add wnormal training scripts
hugoycj Oct 16, 2024
163372d
fix: Update wmask scripts
hugoycj Oct 16, 2024
a339fb3
fix: Update diff-surfel-rasterization
hugoycj Oct 17, 2024
83ded47
fix: Improve mask handling and prevent NaN in normals
hugoycj Oct 17, 2024
eb55716
fix: Add edge_aware_curvature_loss from atomg
hugoycj Oct 17, 2024
1604bf3
fix: Update curvature loss
hugoycj Oct 17, 2024
2301269
fix: Update training scripts
hugoycj Oct 17, 2024
c403e24
fix: Update normal supervision
hugoycj Oct 23, 2024
51bf7db
Update train_wmask_wnormal.sh
hugoycj Oct 25, 2024
9eea484
feat: Add 2dgs fast scripts
hugoycj Oct 27, 2024
de77878
Merge remote-tracking branch 'git/main' into main
hugoycj Oct 27, 2024
515588a
fix: Update scripts
hugoycj Oct 29, 2024
d4b3dcb
feat: Add fused-ssim for acceleration
hugoycj Nov 5, 2024
323ad80
fix: Update normal scripts to latest version and fix some bugs
hugoycj Nov 11, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -11,4 +11,5 @@ data
 *.out
 eval
 *.npz
-**/tmp
+**/tmp
+eval_dtu
9 changes: 6 additions & 3 deletions .gitmodules
@@ -1,6 +1,9 @@
-[submodule "submodules/diff-surfel-rasterization"]
-	path = submodules/diff-surfel-rasterization
-	url = https://github.com/hbb1/diff-surfel-rasterization.git
 [submodule "submodules/simple-knn"]
 	path = submodules/simple-knn
 	url = https://gitlab.inria.fr/bkerbl/simple-knn.git
+[submodule "submodules/diff-surfel-rasterization"]
+	path = submodules/diff-surfel-rasterization
+	url = https://github.com/hugoycj/diff-surfel-rasterization-MCMC
+[submodule "submodules/fused-ssim"]
+	path = submodules/fused-ssim
+	url = https://github.com/rahul-goel/fused-ssim.git
29 changes: 27 additions & 2 deletions arguments/__init__.py
@@ -50,11 +50,14 @@ def __init__(self, parser, sentinel=False):
         self._source_path = ""
         self._model_path = ""
         self._images = "images"
-        self._resolution = -1
+        self._resolution = 1
         self._white_background = False
         self.data_device = "cuda"
         self.eval = False
         self.render_items = ['RGB', 'Alpha', 'Normal', 'Depth', 'Edge', 'Curvature']
+        self.w_normal_prior = ""
+        self.w_mask = ""
+        self.use_decoupled_appearance = False
         super().__init__(parser, "Loading Parameters", sentinel)

     def extract(self, args):
@@ -81,17 +84,39 @@ def __init__(self, parser):
         self.opacity_lr = 0.05
         self.scaling_lr = 0.005
         self.rotation_lr = 0.001
+        self.appearance_embeddings_lr = 0.001
+        self.appearance_network_lr = 0.001
         self.percent_dense = 0.01
         self.lambda_dssim = 0.2
-        self.lambda_dist = 0.0
+        self.lambda_dist = 0.
+        self.lambda_depth = 0.1
         self.lambda_normal = 0.05
+        self.lambda_mask = 0.
+        self.lambda_normal_prior = 0.05
+        self.lambda_normal_gradient = 0.01
         self.opacity_cull = 0.05

+        self.split_interval = 500
+        self.max_screen_size = 20
+
         self.densification_interval = 100
         self.opacity_reset_interval = 3000
         self.densify_from_iter = 500
         self.densify_until_iter = 15_000
         self.densify_grad_threshold = 0.0002
+
+        self.propagation_interval = 20
+        self.depth_error_min_threshold = 0.8
+        self.depth_error_max_threshold = 1.0
+        self.propagation_begin = 9000
+        self.propagation_after = 15000
+        self.patch_size = 11
+
+        self.pixel_dense_from_iter = 30000
+
+        self.contribution_prune_from_iter = 500
+        self.contribution_prune_interval = 300
+        self.contribution_prune_ratio = 0.1
         super().__init__(parser, "Optimization Parameters")

def get_combined_args(parser : ArgumentParser):
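The new interval and start/stop options in OptimizationParams suggest operations that are gated by the training iteration. The training loop itself is not part of the diff shown here, so the following is only a sketch of how such gating might look. The function `scheduled_ops`, the operation names, and the exact comparison rules (including reading `propagation_after` as an end bound) are assumptions made for illustration; only the parameter names and default values are taken from the diff.

```python
from types import SimpleNamespace

# Default values copied from the OptimizationParams diff above.
opt = SimpleNamespace(
    densification_interval=100,
    densify_from_iter=500,
    densify_until_iter=15_000,
    contribution_prune_from_iter=500,
    contribution_prune_interval=300,
    propagation_interval=20,
    propagation_begin=9000,
    propagation_after=15000,
)


def scheduled_ops(iteration, opt):
    """Return the (hypothetical) operations that would fire at this iteration."""
    ops = []
    # Assumed rule: densify every N iterations inside the densification window.
    in_densify_window = opt.densify_from_iter < iteration < opt.densify_until_iter
    if in_densify_window and iteration % opt.densification_interval == 0:
        ops.append("densify")
    # Assumed rule: prune low-contribution Gaussians periodically after a warm-up.
    if (iteration > opt.contribution_prune_from_iter
            and iteration % opt.contribution_prune_interval == 0):
        ops.append("contribution_prune")
    # Assumed rule: propagation runs between propagation_begin and propagation_after.
    in_propagation_window = opt.propagation_begin <= iteration < opt.propagation_after
    if in_propagation_window and iteration % opt.propagation_interval == 0:
        ops.append("propagate")
    return ops


print(scheduled_ops(600, opt))
print(scheduled_ops(9000, opt))
```

Under these assumed rules, iteration 600 would trigger densification and contribution pruning, while iteration 9000 would additionally trigger propagation.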
99 changes: 44 additions & 55 deletions gaussian_renderer/__init__.py
@@ -16,15 +16,29 @@
 from utils.sh_utils import eval_sh
 from utils.point_utils import depth_to_normal

-def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
+def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0,
+           override_color = None, record_transmittance=False, bg_gaussians=None, skip_geometric=False):
     """
     Render the scene.

     Background tensor (bg_color) must be on GPU!
     """
-
+    if bg_gaussians is None:
+        means3D = pc.get_xyz
+        opacity = pc.get_opacity
+        scales = pc.get_scaling
+        rotations = pc.get_rotation
+        shs = pc.get_features
+    else:
+        means3D = torch.cat([pc.get_xyz, bg_gaussians.get_xyz])
+        opacity = torch.cat([pc.get_opacity, bg_gaussians.get_opacity])
+        scales = torch.cat([pc.get_scaling, bg_gaussians.get_scaling])
+        rotations = torch.cat([pc.get_rotation, bg_gaussians.get_rotation])
+        shs = torch.cat([pc.get_features, bg_gaussians.get_features])
+    num_fg_points = pc.get_xyz.shape[0]
+
     # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
-    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
+    screenspace_points = torch.zeros((means3D.shape[0], 4), dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
     try:
         screenspace_points.retain_grad()
     except:
@@ -46,74 +60,43 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
         sh_degree=pc.active_sh_degree,
         campos=viewpoint_camera.camera_center,
         prefiltered=False,
+        record_transmittance=record_transmittance,
         debug=False,
         # pipe.debug
     )

     rasterizer = GaussianRasterizer(raster_settings=raster_settings)

-    means3D = pc.get_xyz
     means2D = screenspace_points
-    opacity = pc.get_opacity
-
-    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
-    # scaling / rotation by the rasterizer.
-    scales = None
-    rotations = None
-    cov3D_precomp = None
-    if pipe.compute_cov3D_python:
-        # currently don't support normal consistency loss if use precomputed covariance
-        splat2world = pc.get_covariance(scaling_modifier)
-        W, H = viewpoint_camera.image_width, viewpoint_camera.image_height
-        near, far = viewpoint_camera.znear, viewpoint_camera.zfar
-        ndc2pix = torch.tensor([
-            [W / 2, 0, 0, (W-1) / 2],
-            [0, H / 2, 0, (H-1) / 2],
-            [0, 0, far-near, near],
-            [0, 0, 0, 1]]).float().cuda().T
-        world2pix = viewpoint_camera.full_proj_transform @ ndc2pix
-        cov3D_precomp = (splat2world[:, [0,1,3]] @ world2pix[:,[0,1,3]]).permute(0,2,1).reshape(-1, 9) # column major
-    else:
-        scales = pc.get_scaling
-        rotations = pc.get_rotation
-
-    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
-    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
-    pipe.convert_SHs_python = False
-    shs = None
-    colors_precomp = None
-    if override_color is None:
-        if pipe.convert_SHs_python:
-            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
-            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
-            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
-            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
-            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
-        else:
-            shs = pc.get_features
-    else:
-        colors_precomp = override_color
-
-    rendered_image, radii, allmap = rasterizer(
+    output = rasterizer(
         means3D = means3D,
         means2D = means2D,
         shs = shs,
-        colors_precomp = colors_precomp,
+        colors_precomp = None,
         opacities = opacity,
         scales = scales,
         rotations = rotations,
-        cov3D_precomp = cov3D_precomp
-    )
-
+        cov3D_precomp = None)
+
+    if record_transmittance:
+        rendered_image, radii, allmap, transmittance_avg, num_covered_pixels = output
+        transmittance_avg = transmittance_avg[:num_fg_points]
+        num_covered_pixels = num_covered_pixels[:num_fg_points]
+    else:
+        rendered_image, radii, allmap = output
+        transmittance_avg = num_covered_pixels = None
+    radii = radii[:num_fg_points]
     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
     # They will be excluded from value updates used in the splitting criteria.
     rets = {"render": rendered_image,
             "viewspace_points": means2D,
             "visibility_filter" : radii > 0,
             "radii": radii,
+            "pixels_num":num_covered_pixels,
+            "transmittance_avg": transmittance_avg
     }


     # additional regularizations
     render_alpha = allmap[1:2]
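When bg_gaussians is supplied, render() concatenates foreground and background attributes before rasterization and then slices the per-Gaussian outputs back to the first num_fg_points entries, so downstream densification statistics only see foreground Gaussians. A toy illustration of that bookkeeping follows, using plain lists in place of tensors. `render_stub` and its pass-through "rasterizer" are hypothetical stand-ins and not the CUDA rasterizer; the sketch assumes the rasterizer returns one value per input Gaussian in input order.

```python
def render_stub(fg_radii, bg_radii=None):
    """Toy stand-in for the concatenate-then-slice pattern in render()."""
    # The foreground count is needed whether or not a background model is given.
    num_fg_points = len(fg_radii)

    # Foreground first, background appended (mirrors the torch.cat calls).
    if bg_radii is None:
        all_radii = list(fg_radii)
    else:
        all_radii = list(fg_radii) + list(bg_radii)

    # Hypothetical rasterizer: returns one radius per Gaussian, same order.
    radii = all_radii

    # Keep only the foreground part of the per-Gaussian outputs.
    radii = radii[:num_fg_points]
    visibility_filter = [r > 0 for r in radii]
    return radii, visibility_filter


print(render_stub([3, 0, 5], bg_radii=[7, 9]))
print(render_stub([3, 0]))
```

Because the foreground entries come first in the concatenation, a prefix slice is enough to recover them; this would not hold if the two sets were interleaved.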

@@ -140,18 +123,24 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
     # for unbounded scene, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing.
     surf_depth = render_depth_expected * (1-pipe.depth_ratio) + (pipe.depth_ratio) * render_depth_median

-    # assume the depth points form the 'surface' and generate pseudo surface normal for regularizations.
-    surf_normal = depth_to_normal(viewpoint_camera, surf_depth)
-    surf_normal = surf_normal.permute(2,0,1)
-    # remember to multiply with accum_alpha since render_normal is unnormalized.
-    surf_normal = surf_normal * (render_alpha).detach()
-
+    if skip_geometric:
+        # assume the depth points form the 'surface' and generate pseudo surface normal for regularizations.
+        surf_normal_expected = depth_to_normal(viewpoint_camera, render_depth_expected).permute(2,0,1)
+        surf_normal = depth_to_normal(viewpoint_camera, render_depth_median).permute(2,0,1)
+        # remember to multiply with accum_alpha since render_normal is unnormalized.
+        surf_normal_expected = surf_normal_expected * (render_alpha).detach()
+        surf_normal = surf_normal * (render_alpha).detach()
+    else:
+        surf_normal_expected = render_normal
+        surf_normal = render_normal
+
     rets.update({
         'rend_alpha': render_alpha,
         'rend_normal': render_normal,
         'rend_depth': render_depth_expected,
         'rend_dist': render_dist,
         'surf_depth': surf_depth,
+        'surf_normal_expected': surf_normal_expected,
         'surf_normal': surf_normal,
     })
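The surf_depth line in this hunk is a linear blend between the expected depth and the median depth, controlled by pipe.depth_ratio. A scalar sketch of the same formula is shown here for a single pixel; the function name `blend_surface_depth` is introduced only for illustration, and the real code applies the expression to whole depth maps.

```python
def blend_surface_depth(depth_expected, depth_median, depth_ratio):
    """Same expression as surf_depth in render(), applied to scalars."""
    return depth_expected * (1 - depth_ratio) + depth_ratio * depth_median


# depth_ratio = 0 selects the expected depth, 1 selects the median depth.
print(blend_surface_depth(2.0, 4.0, 0.0))
print(blend_surface_depth(2.0, 4.0, 1.0))
print(blend_surface_depth(2.0, 4.0, 0.5))
```

The two endpoints correspond to the choice described in the code comment: depth_ratio = 0 uses the expected depth for unbounded scenes, and depth_ratio = 1 uses the median depth.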

15 changes: 13 additions & 2 deletions scene/__init__.py
@@ -14,21 +14,23 @@
 import json
 from utils.system_utils import searchForMaxIteration
 from scene.dataset_readers import sceneLoadTypeCallbacks
-from scene.gaussian_model import GaussianModel
+from scene.gaussian_model import GaussianModel, BgGaussianModel
+from scene.appearance_model import AppearanceModel
 from arguments import ModelParams
 from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON

 class Scene:

     gaussians : GaussianModel

-    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
+    def __init__(self, args : ModelParams, gaussians : GaussianModel, bg_gaussians: BgGaussianModel = None, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
         """
         :param path: Path to colmap scene main folder.
         """
         self.model_path = args.model_path
         self.loaded_iter = None
         self.gaussians = gaussians
+        self.bg_gaussians = bg_gaussians

         if load_iteration:
             if load_iteration == -1:
@@ -79,12 +81,21 @@ def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration
                                                  "point_cloud",
                                                  "iteration_" + str(self.loaded_iter),
                                                  "point_cloud.ply"))
+            if self.bg_gaussians is not None:
+                self.bg_gaussians.load_ply(os.path.join(self.model_path,
+                                                        "point_cloud",
+                                                        "iteration_" + str(self.loaded_iter),
+                                                        "bg_point_cloud.ply"))
         else:
             self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
+            if self.bg_gaussians is not None:
+                self.bg_gaussians.load_ply('assets/background_gs.ply')

     def save(self, iteration):
         point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
         self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
+        if self.bg_gaussians is not None:
+            self.bg_gaussians.save_ply(os.path.join(point_cloud_path, "bg_point_cloud.ply"))

     def getTrainCameras(self, scale=1.0):
         return self.train_cameras[scale]
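With this change, each saved iteration folder can hold two files: point_cloud.ply for the foreground model and, when a background model is present, bg_point_cloud.ply next to it. A small sketch of the resulting layout follows. The helper `checkpoint_paths` and the example model path are hypothetical, and posixpath is used instead of os.path only so the printed separators are the same on every platform.

```python
import posixpath


def checkpoint_paths(model_path, iteration, has_background=False):
    """Build the file paths that Scene.save() would write for one iteration."""
    folder = posixpath.join(model_path, "point_cloud", "iteration_{}".format(iteration))
    paths = {"foreground": posixpath.join(folder, "point_cloud.ply")}
    if has_background:
        # Written only when a background Gaussian model was passed to Scene.
        paths["background"] = posixpath.join(folder, "bg_point_cloud.ply")
    return paths


# "output/scene" is a made-up model path used only for this example.
for name, path in checkpoint_paths("output/scene", 30000, has_background=True).items():
    print(name, path)
```

Loading mirrors this layout: when resuming from an iteration, the background file is read from the same folder, while a fresh run reads the background model from 'assets/background_gs.ply' as shown in the diff.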
86 changes: 86 additions & 0 deletions scene/appearance_model.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class UpsampleBlock(nn.Module):
+    def __init__(self, num_input_channels, num_output_channels):
+        super(UpsampleBlock, self).__init__()
+        self.pixel_shuffle = nn.PixelShuffle(2)
+        self.conv = nn.Conv2d(num_input_channels // (2 * 2), num_output_channels, 3, stride=1, padding=1)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x = self.pixel_shuffle(x)
+        x = self.conv(x)
+        x = self.relu(x)
+        return x
+
+class AppearanceNetwork(nn.Module):
+    def __init__(self, num_input_channels, num_output_channels):
+        super(AppearanceNetwork, self).__init__()
+
+        self.conv1 = nn.Conv2d(num_input_channels, 256, 3, stride=1, padding=1)
+        self.up1 = UpsampleBlock(256, 128)
+        self.up2 = UpsampleBlock(128, 64)
+        self.up3 = UpsampleBlock(64, 32)
+        self.up4 = UpsampleBlock(32, 16)
+
+        self.conv2 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
+        self.conv3 = nn.Conv2d(16, num_output_channels, 3, stride=1, padding=1)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.up1(x)
+        x = self.up2(x)
+        x = self.up3(x)
+        x = self.up4(x)
+        # bilinear interpolation
+        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.conv3(x)
+        x = self.sigmoid(x)
+        return x
+
+class AppearanceModel:
+    def __init__(self, num_embeddings, num_input_channels=67, num_output_channels=3):
+        self.appearance_network = AppearanceNetwork(num_input_channels, num_output_channels).cuda()
+
+        std = 1e-4
+        self._appearance_embeddings = nn.Parameter(torch.empty(num_embeddings, 64).cuda())
+        self._appearance_embeddings.data.normal_(0, std)
+
+    def get_embedding(self, idx):
+        return self._appearance_embeddings[idx]
+
+    def training_setup(self, training_args):
+        params = [
+            {'params': [self._appearance_embeddings], 'lr': training_args.appearance_embeddings_lr, "name": "appearance_embeddings"},
+            {'params': self.appearance_network.parameters(), 'lr': training_args.appearance_network_lr, "name": "appearance_network"}
+        ]
+        self.optimizer = torch.optim.Adam(params, lr=0.0, eps=1e-15)
+
+    def load_state_dict(self, state_dict):
+        self._appearance_embeddings = state_dict["_appearance_embeddings"]
+        self.appearance_network.load_state_dict(state_dict["appearance_network"])
+
+    def state_dict(self):
+        return {
+            "_appearance_embeddings": self._appearance_embeddings,
+            "appearance_network": self.appearance_network.state_dict()
+        }
+
+
+
+if __name__ == "__main__":
+    H, W = 1200//32, 1600//32
+    input_channels = 3 + 64
+    output_channels = 3
+    input = torch.randn(1, input_channels, H, W).cuda()
+    model = AppearanceNetwork(input_channels, output_channels).cuda()
+
+    output = model(input)
+    print(output.shape)
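The network upsamples its input by a total factor of 32: four PixelShuffle(2) stages give 16x, and the final bilinear interpolation doubles that. The shape bookkeeping can be traced without a GPU, as in the sketch below. The helper `trace_shapes` is introduced here only for illustration and tracks (channels, height, width) through the layers defined above; it does not run the network.

```python
def trace_shapes(h, w, out_channels=3):
    """Track (layer, channels, height, width) through AppearanceNetwork.forward."""
    shapes = []
    c = 256  # conv1 maps the input channels to 256 and keeps H x W (3x3, padding 1)
    shapes.append(("conv1", c, h, w))
    for name, c_out in (("up1", 128), ("up2", 64), ("up3", 32), ("up4", 16)):
        # PixelShuffle(2) divides channels by 4 and doubles H and W;
        # the following 3x3 conv then maps c // 4 channels to c_out.
        assert c % 4 == 0
        h, w = h * 2, w * 2
        c = c_out
        shapes.append((name, c, h, w))
    # Bilinear interpolation with scale_factor=2 doubles H and W once more.
    h, w = h * 2, w * 2
    shapes.append(("interpolate", c, h, w))
    # conv2 keeps 16 channels; conv3 maps to the output channels.
    shapes.append(("conv3", out_channels, h, w))
    return shapes


# Same input size as the __main__ block: a 1200 x 1600 image downsampled by 32.
for row in trace_shapes(1200 // 32, 1600 // 32):
    print(row)
```

For the 1200 x 1600 example, 1200 // 32 is 37, so the output height is 37 * 32 = 1184 rather than 1200. Any caller that needs the output to match the original image size would have to resize or crop; how the PR handles this is not visible in the diff shown here.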