From 19b0f2f1522195df33b2ac8aec77618550a7c3ec Mon Sep 17 00:00:00 2001 From: Donovan Hutchence Date: Fri, 6 Dec 2024 10:50:43 +0000 Subject: [PATCH] Add spherical harmonics to compressed.ply export (#313) --- src/asset-loader.ts | 4 +- src/camera.ts | 3 - src/sh-utils.ts | 41 +++--- src/shaders/splat-shader.ts | 9 +- src/splat-serialize.ts | 256 ++++++++++++++++++++++++++---------- 5 files changed, 214 insertions(+), 99 deletions(-) diff --git a/src/asset-loader.ts b/src/asset-loader.ts index dcaa3cef..4fb3dc8c 100644 --- a/src/asset-loader.ts +++ b/src/asset-loader.ts @@ -150,7 +150,9 @@ class AssetLoader { } }); - asset.on('error', (err: string) => reject(err)); + asset.on('error', (err: string) => { + reject(err); + }); this.registry.add(asset); this.registry.load(asset); diff --git a/src/camera.ts b/src/camera.ts index 4c7e0027..25749a9e 100644 --- a/src/camera.ts +++ b/src/camera.ts @@ -49,9 +49,6 @@ const ray = new Ray(); const vec = new Vec3(); const vecb = new Vec3(); const va = new Vec3(); -const vb = new Vec3(); -const vc = new Vec3(); -const v4 = new Vec4(); // modulo dealing with negative numbers const mod = (n: number, m: number) => ((n % m) + m) % m; diff --git a/src/sh-utils.ts b/src/sh-utils.ts index 4555645c..65eaf176 100644 --- a/src/sh-utils.ts +++ b/src/sh-utils.ts @@ -37,7 +37,7 @@ const dp = (n: number, start: number, a: number[] | Float32Array, b: number[] | return sum; }; -const coeffsIn = new Float32Array(16); +const coeffsIn = new Float32Array(15); // Rotate spherical harmonics up to band 3 based on https://github.com/andrewwillmott/sh-lib // @@ -155,38 +155,35 @@ class SHRotation { src = coeffsIn; } - // band 0 - result[0] = src[0]; - // band 1 - if (result.length < 4) { + if (result.length < 3) { return; } - result[1] = dp(3, 1, src, sh1[0]); - result[2] = dp(3, 1, src, sh1[1]); - result[3] = dp(3, 1, src, sh1[2]); + result[0] = dp(3, 0, src, sh1[0]); + result[1] = dp(3, 0, src, sh1[1]); + result[2] = dp(3, 0, src, sh1[2]); // band 2 - if (result.length < 9) { + if (result.length < 8) { return; } - result[4] = dp(5, 4, src, sh2[0]); - result[5] = dp(5, 4, src, sh2[1]); - result[6] = dp(5, 4, src, sh2[2]); - result[7] = dp(5, 4, src, sh2[3]); - result[8] = dp(5, 4, src, sh2[4]); + result[3] = dp(5, 3, src, sh2[0]); + result[4] = dp(5, 3, src, sh2[1]); + result[5] = dp(5, 3, src, sh2[2]); + result[6] = dp(5, 3, src, sh2[3]); + result[7] = dp(5, 3, src, sh2[4]); // band 3 - if (result.length < 16) { + if (result.length < 15) { return; } - result[9] = dp(7, 9, src, sh3[0]); - result[10] = dp(7, 9, src, sh3[1]); - result[11] = dp(7, 9, src, sh3[2]); - result[12] = dp(7, 9, src, sh3[3]); - result[13] = dp(7, 9, src, sh3[4]); - result[14] = dp(7, 9, src, sh3[5]); - result[15] = dp(7, 9, src, sh3[6]); + result[8] = dp(7, 8, src, sh3[0]); + result[9] = dp(7, 8, src, sh3[1]); + result[10] = dp(7, 8, src, sh3[2]); + result[11] = dp(7, 8, src, sh3[3]); + result[12] = dp(7, 8, src, sh3[4]); + result[13] = dp(7, 8, src, sh3[5]); + result[14] = dp(7, 8, src, sh3[6]); }; } } diff --git a/src/shaders/splat-shader.ts b/src/shaders/splat-shader.ts index fb6a6097..40025034 100644 --- a/src/shaders/splat-shader.ts +++ b/src/shaders/splat-shader.ts @@ -171,6 +171,12 @@ void main(void) { color.xyz = max(vec3(0.0), color.xyz + evalSH(state, projState)); #endif + // apply tint/brightness + color = color * clrScale + vec4(clrOffset, 0.0); + + // don't allow out-of-range alpha + color.a = clamp(color.a, 0.0, 1.0); + // apply locked/selected colors if ((vertexState & 2u) != 0u) { // locked @@ -178,9 +184,6 @@ void main(void) { } else if ((vertexState & 1u) != 0u) { // selected color.xyz = mix(color.xyz, selectedClr.xyz * 0.8, selectedClr.a); - } else { - // apply tint/brightness - color = color * clrScale + vec4(clrOffset, 0.0); } #endif diff --git a/src/splat-serialize.ts b/src/splat-serialize.ts index fef6846a..c3c6399a 100644 --- a/src/splat-serialize.ts +++ b/src/splat-serialize.ts @@ -210,7 +210,7 @@ const serializePly = async (splats: Splat[], write: WriteFunc) => { shData.push(splatData.getProp(`f_rest_${i}`)); } - shCoeffs = [0]; + shCoeffs = []; } for (let i = 0; i < splatData.numSplats; ++i) { @@ -246,13 +246,13 @@ const serializePly = async (splats: Splat[], write: WriteFunc) => { if (hasSH) { for (let c = 0; c < 3; ++c) { for (let d = 0; d < 15; ++d) { - shCoeffs[d + 1] = shData[c * 15 + d][i]; + shCoeffs[d] = shData[c * 15 + d][i]; } transformCache.getSHRot(i).apply(shCoeffs, shCoeffs); for (let d = 0; d < 15; ++d) { - splat[`f_rest_${c * 15 + d}`] = shCoeffs[d + 1]; + splat[`f_rest_${c * 15 + d}`] = shCoeffs[d]; } } } @@ -315,7 +315,8 @@ class SingleSplat { read(splats: Splat[], index: CompressedIndex) { const splat = splats[index.splatIndex]; const { splatData } = splat; - const val = (prop: string) => splatData.getProp(prop)[index.i]; + const { i } = index; + const val = (prop: string) => splatData.getProp(prop)[i]; [this.x, this.y, this.z] = [val('x'), val('y'), val('z')]; [this.scale_0, this.scale_1, this.scale_2] = [val('scale_0'), val('scale_1'), val('scale_2')]; [this.f_dc_0, this.f_dc_1, this.f_dc_2, this.opacity] = [val('f_dc_0'), val('f_dc_1'), val('f_dc_2'), val('opacity')]; @@ -416,6 +417,18 @@ class Chunk { const sy = calcMinMax(scale_1); const sz = calcMinMax(scale_2); + // convert f_dc_ to colors before calculating min/max and packaging + const SH_C0 = 0.28209479177387814; + for (let i = 0; i < f_dc_0.length; ++i) { + f_dc_0[i] = f_dc_0[i] * SH_C0 + 0.5; + f_dc_1[i] = f_dc_1[i] * SH_C0 + 0.5; + f_dc_2[i] = f_dc_2[i] * SH_C0 + 0.5; + } + + const cr = calcMinMax(f_dc_0); + const cg = calcMinMax(f_dc_1); + const cb = calcMinMax(f_dc_2); + const packUnorm = (value: number, bits: number) => { const t = (1 << bits) - 1; return Math.max(0, Math.min(t, Math.floor(value * t + 0.5))); @@ -458,16 +471,6 @@ class Chunk { return result; }; - const packColor = (r: number, g: number, b: number, a: number) => { - const SH_C0 = 0.28209479177387814; - return pack8888( - r * SH_C0 + 0.5, - g * SH_C0 + 0.5, - b * SH_C0 + 0.5, - 1 / (1 + Math.exp(-a)) - ); - }; - // pack for (let i = 0; i < this.size; ++i) { this.position[i] = pack111011( @@ -484,10 +487,15 @@ class Chunk { normalize(scale_2[i], sz.min, sz.max) ); - this.color[i] = packColor(f_dc_0[i], f_dc_1[i], f_dc_2[i], opacity[i]); + this.color[i] = pack8888( + normalize(f_dc_0[i], cr.min, cr.max), + normalize(f_dc_1[i], cg.min, cg.max), + normalize(f_dc_2[i], cb.min, cb.max), + 1 / (1 + Math.exp(-opacity[i])) + ); } - return { px, py, pz, sx, sy, sz }; + return { px, py, pz, sx, sy, sz, cr, cg, cb }; } } @@ -571,76 +579,166 @@ const sortSplats = (splats: Splat[], indices: CompressedIndex[]) => { indices.sort((a, b) => morton[a.globalIndex] - morton[b.globalIndex]); }; -const serializePlyCompressed = async (splats: Splat[], write: WriteFunc) => { - const chunkProps = [ - 'min_x', 'min_y', 'min_z', - 'max_x', 'max_y', 'max_z', - 'min_scale_x', 'min_scale_y', - 'min_scale_z', 'max_scale_x', - 'max_scale_y', 'max_scale_z' - ]; +// returns the number of spherical harmonic bands present on a splat scene +const getSHBands = (splat: Splat) => { + let coeffs = 0; + for (; coeffs < 45; ++coeffs) { + if (!splat.splatData.getProp(`f_rest_${coeffs}`)) break; + } - const vertexProps = [ - 'packed_position', - 'packed_rotation', - 'packed_scale', - 'packed_color' - ]; + if (coeffs === 9) { + return 1; + } else if (coeffs === 24) { + return 2; + } else if (coeffs === 45) { + return 3; + } + + return 0; +}; + +const quantizeSH = (splats: Splat[], indices: CompressedIndex[], transformCaches: SplatTransformCache[]) => { + // get the maximum number of bands in the scene + const numBands = Math.max(...splats.map(s => getSHBands(s))); + if (numBands === 0) { + return null; + } + const numCoeffs = [3, 8, 15][numBands - 1]; + const propNames = new Array(numCoeffs * 3).fill('').map((_, i) => `f_rest_${i}`); + const splatSHData = splats.map((splat) => { + return propNames.map(name => splat.splatData.getProp(name)); + }); + const splatTints = splats.map((splat) => { + const { blackPoint, whitePoint, tintClr } = splat; + const scale = 1 / (whitePoint - blackPoint); + return [tintClr.r * scale, tintClr.g * scale, tintClr.b * scale]; + }); + const coeffs = new Float32Array(numCoeffs); + const data = new Uint8Array(indices.length * numCoeffs * 3); + + for (let i = 0; i < indices.length; ++i) { + const index = indices[i]; + const splatIndex = index.splatIndex; + const shIndex = index.i; + const shData = splatSHData[splatIndex]; + const tint = splatTints[splatIndex]; + const shRot = transformCaches[splatIndex].getSHRot(shIndex); + + for (let j = 0; j < 3; ++j) { + // extract sh coefficients + for (let k = 0; k < numCoeffs; ++k) { + const src = shData[j * numCoeffs + k]; + coeffs[k] = src ? src[shIndex] : 0; + } + + // apply tint + for (let k = 0; k < numCoeffs; ++k) { + coeffs[k] *= tint[j]; + } + + // rotate + shRot.apply(coeffs, coeffs); + + // quantize + for (let k = 0; k < numCoeffs; ++k) { + const nvalue = coeffs[k] / 8 + 0.5; + data[(i * 3 + j) * numCoeffs + k] = Math.max(0, Math.min(255, Math.trunc(nvalue * 256))); + } + } + } + return [{ + name: 'sh', + length: indices.length, + properties: new Array(numCoeffs * 3).fill('').map((_, i) => { + return { + name: `f_rest_${i}`, + type: 'uchar' + }; + }), + data + }]; +}; + +const serializePlyCompressed = async (splats: Splat[], write: WriteFunc) => { // create a list of indices spanning all splats - const indices: CompressedIndex[] = splats.reduce((indices, splat, splatIndex) => { - const splatData = splat.splatData; + const indices: CompressedIndex[] = []; + for (let splatIndex = 0; splatIndex < splats.length; ++splatIndex) { + const splatData = splats[splatIndex].splatData; const state = splatData.getProp('state') as Uint8Array; for (let i = 0; i < splatData.numSplats; ++i) { if ((state[i] & State.deleted) === 0) { - indices.push({ - splatIndex, - i, - globalIndex: indices.length - }); + indices.push({ splatIndex, i, globalIndex: indices.length }); } } - return indices; - }, []); + } if (indices.length === 0) { console.error('nothing to export'); return; } + // sort splats into some kind of order (morton order rn) + sortSplats(splats, indices); + + // create a transform cache per splat + const transformCaches = splats.map(splat => new SplatTransformCache(splat)); + const numSplats = indices.length; const numChunks = Math.ceil(numSplats / 256); + const quantizedSHData = quantizeSH(splats, indices, transformCaches); + + const chunkProps = [ + 'min_x', 'min_y', 'min_z', + 'max_x', 'max_y', 'max_z', + 'min_scale_x', 'min_scale_y', 'min_scale_z', + 'max_scale_x', 'max_scale_y', 'max_scale_z', + 'min_r', 'min_g', 'min_b', + 'max_r', 'max_g', 'max_b' + ]; + + const vertexProps = [ + 'packed_position', + 'packed_rotation', + 'packed_scale', + 'packed_color' + ]; + + const shHeader = quantizedSHData?.map((element) => { + return [ + `element ${element.name} ${element.length}`, + element.properties.map(prop => `property ${prop.type} ${prop.name}`) + ]; + }).flat(2); + const headerText = [ - [ - 'ply', - 'format binary_little_endian 1.0', - `comment ${generatedByString}`, - `element chunk ${numChunks}` - ], + 'ply', + 'format binary_little_endian 1.0', + `comment ${generatedByString}`, + `element chunk ${numChunks}`, chunkProps.map(p => `property float ${p}`), - [ - `element vertex ${numSplats}` - ], + `element vertex ${numSplats}`, vertexProps.map(p => `property uint ${p}`), - [ - 'end_header\n' - ] + shHeader ?? [], + 'end_header\n' ].flat().join('\n'); const header = (new TextEncoder()).encode(headerText); - const result = new Uint8Array(header.byteLength + numChunks * chunkProps.length * 4 + numSplats * vertexProps.length * 4); + + const result = new Uint8Array( + header.byteLength + + numChunks * chunkProps.length * 4 + + numSplats * vertexProps.length * 4 + + (quantizedSHData ? quantizedSHData.reduce((acc, x) => acc + x.data.byteLength, 0) : 0) + ); const dataView = new DataView(result.buffer); result.set(header); const chunkOffset = header.byteLength; - const vertexOffset = chunkOffset + numChunks * 12 * 4; + const vertexOffset = chunkOffset + numChunks * chunkProps.length * 4; - // sort splats into some kind of order - sortSplats(splats, indices); - - const transformCaches = splats.map(splat => new SplatTransformCache(splat)); const chunk = new Chunk(); const singleSplat = new SingleSplat(); @@ -665,20 +763,29 @@ const serializePlyCompressed = async (splats: Splat[], write: WriteFunc) => { const result = chunk.pack(); + const off = chunkOffset + i * 18 * 4; + // write chunk data - dataView.setFloat32(chunkOffset + i * 12 * 4 + 0, result.px.min, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 4, result.py.min, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 8, result.pz.min, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 12, result.px.max, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 16, result.py.max, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 20, result.pz.max, true); - - dataView.setFloat32(chunkOffset + i * 12 * 4 + 24, result.sx.min, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 28, result.sy.min, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 32, result.sz.min, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 36, result.sx.max, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 40, result.sy.max, true); - dataView.setFloat32(chunkOffset + i * 12 * 4 + 44, result.sz.max, true); + dataView.setFloat32(off + 0, result.px.min, true); + dataView.setFloat32(off + 4, result.py.min, true); + dataView.setFloat32(off + 8, result.pz.min, true); + dataView.setFloat32(off + 12, result.px.max, true); + dataView.setFloat32(off + 16, result.py.max, true); + dataView.setFloat32(off + 20, result.pz.max, true); + + dataView.setFloat32(off + 24, result.sx.min, true); + dataView.setFloat32(off + 28, result.sy.min, true); + dataView.setFloat32(off + 32, result.sz.min, true); + dataView.setFloat32(off + 36, result.sx.max, true); + dataView.setFloat32(off + 40, result.sy.max, true); + dataView.setFloat32(off + 44, result.sz.max, true); + + dataView.setFloat32(off + 48, result.cr.min, true); + dataView.setFloat32(off + 52, result.cg.min, true); + dataView.setFloat32(off + 56, result.cb.min, true); + dataView.setFloat32(off + 60, result.cr.max, true); + dataView.setFloat32(off + 64, result.cg.max, true); + dataView.setFloat32(off + 68, result.cb.max, true); // write splat data const offset = vertexOffset + i * 256 * 4 * 4; @@ -691,6 +798,15 @@ const serializePlyCompressed = async (splats: Splat[], write: WriteFunc) => { } } + // write sh data + if (quantizedSHData) { + let offset = vertexOffset + numSplats * 4 * 4; + quantizedSHData.forEach((element) => { + result.set(new Uint8Array(element.data.buffer), offset); + offset += element.data.byteLength; + }); + } + await write(result, true); };