TensoRF Explained: A Milestone in Novel View Synthesis

Aspiring Computer Vision engineer, eager to delve into the world of AI/ML, cloud, Computer Vision, TinyML and other programming stuff.
Novel View Synthesis has long been a prominent problem in 3D Computer Vision. Directly succeeding techniques like DeepVoxels and Scene Representation Networks came Neural Radiance Fields, or simply NeRFs.
Neural Radiance Fields use an MLP-based deep learning architecture that models both the color and the lighting of a scene to obtain a photorealistic representation. The technique was introduced in 2020 and has since matured enough to be used in real industrial pipelines.

This GIF is from the website of Instant-NGP, an improvement over NeRF. You can see the full-resolution GIF here.
But this progress was only possible by overcoming NeRF's biggest issue: training time. The original NeRF paper reports that optimization took 100-300k iterations to converge on a single NVIDIA V100 GPU, which took about 1-2 days.
Given the high potential of the method, many researchers proposed optimizations, resulting in much faster methods like Instant-NGP, DVGO, etc. One of them is TensoRF, which we discuss today. For starters, TensoRF reduced the training time to less than 15 minutes! That calls for a full-fledged analysis, which is what this article provides. But first, we need a decent understanding of the vanilla NeRF architecture and how it works, to see why the improvements make sense.
NeRF is a dense topic and takes a lot of time and attention for someone new to the field. The best way to learn it is to go through the official paper, but I will lay out the crux of the method here for a gentle entry and provide resources for further exploration.
Neural Radiance Fields
The basic idea is that NeRF uses a large MLP (8 layers) in an unconventional way. It takes a 5D input (a 3D location \(\mathbf{x} = (x, y, z)\) and a 2D viewing direction \(\mathbf{d} = (\theta, \phi)\)) and outputs the emitted color \(\mathbf{c} = (r, g, b)\) and the volume density \(\sigma\).
$$F_{\Theta} : (\mathbf{x}, \mathbf{d}) \rightarrow (\mathbf{c}, \sigma)$$
The density \(\sigma\) depends only on the location \(\mathbf{x}\) (geometry is view-independent), while the color \(\mathbf{c}\) depends on both \(\mathbf{x}\) and \(\mathbf{d}\) (allowing view-dependent lighting effects like specular reflections).

Volume Rendering with Ray Marching
To generate an image from a specific viewpoint, NeRF renders the scene by shooting rays through the image pixels into the volume.
Ray definition: A ray is defined as \(\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}\) where \(\mathbf{o}\) is the camera origin and \(\mathbf{d}\) is the direction.
Integral Equation: The expected color \(C(\mathbf{r})\) of camera ray \(\mathbf{r}\) is calculated by integrating along the ray within near \((t_n)\) and far \((t_f)\) bounds:
$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t) \sigma(\mathbf{r}(t)) \mathbf{c}(\mathbf{r}(t), \mathbf{d}) dt$$
Here, \(T(t)\) denotes transmittance, the probability that the ray travels from \(t_n\) to \(t\) without hitting any particle:
$$T(t) = \exp\left(-\int_{t_n}^t \sigma(\mathbf{r}(s)) ds\right)$$
- Numerical Estimation: In practice, this integral is estimated with stratified sampling (splitting \([t_n, t_f]\) into \(N\) evenly spaced bins and drawing one random point from each) and summing the results:
$$\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) \mathbf{c}_i, \qquad T_i = \exp\left(-\sum_{j=1}^{i-1} \sigma_j \delta_j\right)$$
where \(\delta_i = t_{i+1} - t_i\) is the distance between adjacent samples.
This set of equations is crucial. In fact, the volume rendering process stays almost the same in later improvements, regardless of their other architectural changes.
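To make the quadrature concrete, here is a small NumPy sketch for a single ray with made-up densities and colors. It computes \(T_i\) from the exponential form above and checks that it equals the product form \(\prod_{j<i}(1 - \alpha_j)\) with \(\alpha_j = 1 - \exp(-\sigma_j \delta_j)\), which is the form the rendering code uses later via a cumulative product. The function name render_ray is illustrative only.

```python
import numpy as np

def render_ray(sigmas, colors, deltas):
    """Quadrature estimate of C(r) for one ray.

    sigmas: (N,) densities, colors: (N, 3) RGB values, deltas: (N,) sample spacings.
    """
    alphas = 1.0 - np.exp(-sigmas * deltas)                    # 1 - exp(-sigma_i * delta_i)
    # T_i = exp(-sum_{j<i} sigma_j * delta_j): exclusive cumulative sum of optical depth
    optical_depth = np.concatenate([[0.0], np.cumsum(sigmas * deltas)[:-1]])
    T = np.exp(-optical_depth)
    C = np.sum((T * alphas)[:, None] * colors, axis=0)
    return C, T, alphas

rng = np.random.default_rng(0)
sigmas = rng.uniform(0.0, 5.0, size=8)        # made-up densities
colors = rng.uniform(0.0, 1.0, size=(8, 3))   # made-up colors
deltas = np.full(8, 0.1)
C, T, alphas = render_ray(sigmas, colors, deltas)

# Equivalent product form: T_i = prod_{j<i} (1 - alpha_j)
T_prod = np.concatenate([[1.0], np.cumprod(1.0 - alphas)[:-1]])
print(np.allclose(T, T_prod))  # True
```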
The Massive MLP

The network starts by taking only the 3D position \(\mathbf{x} = (x, y, z)\) as input. It passes this position through 8 fully-connected layers (each with 256 neurons and ReLU activation).
Skip Connection: To help the network remember the original input coordinate (which can get lost in deep networks), the input \(\mathbf{x}\) is concatenated back into the network at Layer 5.
After these 8 layers, the network effectively "knows" if there is an object at that location or just empty air. It outputs the Volume Density \(\sigma\).
Only after the density is decided does the network look at the Viewing Direction \((\mathbf{d})\). The 256-dimensional feature vector from the end of the 8th layer is concatenated with the viewing direction \(\mathbf{d}\).
This combined vector goes into a final hidden layer (Layer 9, 128 channels), which outputs the RGB color \(\mathbf{c}\).
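A compact NumPy sketch of this forward pass, with random weights and shape bookkeeping only. To stay short it assumes raw 3D inputs (no positional encoding) and a 3D unit-vector view direction, which is how NeRF represents \(\mathbf{d}\) in practice; following the NeRF paper's description, the skip connection concatenates \(\mathbf{x}\) to the fifth layer's output. The helper names are hypothetical.

```python
import numpy as np

def linear(rng, n_in, n_out):
    # Random weights for illustration only (not NeRF's initialisation scheme).
    return rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out)

def make_nerf_mlp(rng, x_dim=3, d_dim=3, width=256):
    trunk = []
    for i in range(8):
        n_in = x_dim if i == 0 else width
        if i == 5:                   # layer 6 sees [layer-5 output, x] (skip connection)
            n_in += x_dim
        trunk.append(linear(rng, n_in, width))
    return {"trunk": trunk,
            "sigma": linear(rng, width, 1),
            "feature": linear(rng, width, width),
            "color": linear(rng, width + d_dim, width // 2),   # 128-channel layer
            "rgb": linear(rng, width // 2, 3)}

def relu(z):
    return np.maximum(z, 0.0)

def nerf_forward(params, x, d):
    h = x
    for i, (W, b) in enumerate(params["trunk"]):
        if i == 5:
            h = np.concatenate([h, x], axis=-1)     # re-inject the input position
        h = relu(h @ W + b)
    W, b = params["sigma"]
    sigma = relu(h @ W + b)[..., 0]                 # depends on x only
    W, b = params["feature"]
    feat = h @ W + b
    W, b = params["color"]
    h = relu(np.concatenate([feat, d], axis=-1) @ W + b)   # direction enters here
    W, b = params["rgb"]
    rgb = 1.0 / (1.0 + np.exp(-(h @ W + b)))        # sigmoid keeps colors in (0, 1)
    return rgb, sigma
```

Calling nerf_forward with the same positions but a different direction changes only rgb, never sigma, which is exactly the split motivated next.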
Why split it this way?
This architecture enforces a physical rule: an object's shape doesn't change when you move your head, but its color might (due to reflections). Layers 1-8 handle pure geometry (density), answering the question "Is there a wall here?", while Layer 9 handles radiance: "How shiny does this wall look from this angle?"
The size of the MLP itself is a major cause of the long training time, one of the challenges that later methods tackle elegantly.
Beyond these two components, NeRF also uses positional encoding, hierarchical sampling and other techniques to boost performance, but we will not discuss them here, for the reasons given above.
The Root Causes (and Fixes)
Okay, I did say that the MLP was responsible, but it isn't that straightforward. In fact, similarly sized MLPs run efficiently elsewhere (as in StyleGAN and PointNet). The slowness comes mainly from how often the network has to be queried, along with a few other factors.
1. The "Query" Problem (Dense Compute for Empty Space)
To render one 800x800 image, NeRF casts 640,000 rays (one per pixel). Along each ray, it samples 192 points (64 coarse + 128 fine).
Total Queries: \(640{,}000 \times 192 = 122{,}880{,}000 \approx 123 \text{ million queries}\).
For every single one of these ~123 million points, the large 8-layer MLP must be executed.
However, most of a typical scene (often around 90% of the volume) is empty air. NeRF spends massive compute running a deep neural network just to output density = 0, and it has no easy way to "skip" empty space without querying the network first.
2. Global Entanglement (Global Weights vs. Local Detail)
In a standard MLP, every weight is "global." Changing one weight in Layer 1 potentially alters the geometry everywhere in the scene.
To learn a tiny detail (like a scratch on a wall), the optimizer has to carefully adjust global weights without breaking the rest of the scene. This creates a complex optimization landscape that needs hundreds of thousands of iterations to converge.
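A back-of-the-envelope sketch of the query problem. The per-query cost is a rough assumption (every trunk layer treated as 256×256, with the color head, the input layer's different width and positional encoding ignored), so only the order of magnitude matters:

```python
H, W = 800, 800
samples_per_ray = 192
width, trunk_layers = 256, 8

rays = H * W
queries = rays * samples_per_ray
# Rough multiply-accumulate (MAC) count per query for the 8-layer, 256-wide trunk only
macs_per_query = trunk_layers * width * width
total_macs = queries * macs_per_query

print(f"{rays:,} rays, {queries:,} queries")  # 640,000 rays, 122,880,000 queries
print(f"~{total_macs:.2e} MACs per image")
```

That works out to tens of trillions of multiply-accumulates for a single frame, a large share of it spent on empty space.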
How TensoRF Fixes This
We will discuss the processes below in detail later; for now, we focus only on what they achieve.
1. Shift from "Computation" to "Lookup" \(\mathbf{O}(\text{N}) \to \mathbf{O}(1)\):
It uses VM decomposition paired with trilinear interpolation, which together give a far cheaper way of storing and fetching scene features. The MLP in this version has only 2-3 layers and is used only at the very end, to decode the fetched features into a color. In other words, the heavy lifting moves from computation to memory access.
2. Update Locality
This is the biggest factor for convergence speed. Because TensoRF represents the scene as explicit tensor factors (vectors and matrices), the gradients are local, unlike in NeRF's global MLP.
When the loss function says "this pixel is wrong," the optimizer only updates the specific values in the tensor vectors/matrices that correspond to that specific spatial location. It doesn't have to re-balance the weights of a global neural network.
3. Coarse-to-Fine Grid Upsampling
You are building the image out of tiles (voxels). Coarse-to-Fine Upsampling is like starting with giant tiles (4x4 grid) to get the general colors right. Once those are set, you split them into smaller tiles (128x128), and then even smaller ones (300x300).
The grid capacity grows to capture fine detail only after the coarse structure is established. This serves a purpose similar to Positional Encoding in vanilla NeRF (recovering high-frequency detail), but more directly, because:
4. Feature-Grid vs. Coordinate-Based
NeRF suffers from "Spectral Bias" (it learns low frequencies easily but struggles with high frequencies/sharp edges). TensoRF bypasses this because it stores features directly in the grid (tensors).
Sharp edges are captured directly, up to the grid resolution, instead of relying on an MLP to approximate a step function.
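To make "update locality" concrete, here is a tiny NumPy sketch of one 1D line factor sampled with linear interpolation (the helper names are hypothetical). The gradient of the sampled value with respect to the factor is nonzero at only the two neighbouring entries (four for a bilinear plane lookup), whereas in an MLP essentially every weight receives a gradient from every sample.

```python
import numpy as np

def linear_interp(line, t):
    """Sample a 1D factor at continuous index t (0 <= t <= len(line) - 1)."""
    i0 = int(np.floor(t))
    i1 = min(i0 + 1, len(line) - 1)
    w1 = t - i0
    return (1.0 - w1) * line[i0] + w1 * line[i1]

def grad_wrt_line(n, t):
    """Gradient of linear_interp(line, t) with respect to every entry of line."""
    g = np.zeros(n)
    i0 = int(np.floor(t))
    i1 = min(i0 + 1, n - 1)
    w1 = t - i0
    g[i0] += 1.0 - w1
    g[i1] += w1
    return g

g = grad_wrt_line(128, 41.25)
print(np.nonzero(g)[0], g[41], g[42])  # [41 42] 0.75 0.25
```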
VM Decomposition
Before we proceed, this component deserves its own discussion. A radiance field stores density \(\sigma(x, y, z)\) and view-dependent appearance \(\mathbf{c}(x, y, z, \mathbf{d})\). These can be viewed as high-dimensional tensors:
Density → a 3D tensor: \(\mathcal{T}_\sigma \in \mathbb{R}^{X \times Y \times Z}\)
Color features → a 4D tensor: \(\mathcal{T}_{c} \in \mathbb{R}^{X \times Y \times Z \times D}\)
The idea is to break them down with a tensor decomposition into a sum of low-rank components. One option is CP decomposition, which the official TensoRF paper also evaluates; it is more compact and cheaper to compute, while VM relaxes its low-rank constraint to gain expressiveness. Here we discuss only the novel VM decomposition, which goes like this:
$$\mathcal{T} \approx \sum_{r=1}^{R} \left( \mathbf{M}^{r}_{xy} \otimes \mathbf{v}^{r}_{z} \;+\; \mathbf{M}^{r}_{yz} \otimes \mathbf{v}^{r}_{x} \;+\; \mathbf{M}^{r}_{zx} \otimes \mathbf{v}^{r}_{y} \right)$$
To break this down, consider the core idea: variation along the x axis interacts with the y and z axes (and likewise for the others). For a single term (say the one for the x axis), the approximation is:
- Storing what happens along x (a vector).
- Storing what happens across the yz plane (a matrix).
The outer product extends the plane along the remaining axis. This is also why decoding stays cheap: evaluating the tensor at a continuous point only needs bilinear interpolation on each plane and linear interpolation on each line, which together are equivalent to trilinear interpolation of the full tensor, so the full tensor never has to be built.

(This is a rough GIF conversion of the animation from an official video by one of the authors, referenced at the end.)
This drastically reduces memory and speeds up training, while keeping the representation expressive enough for view-dependent effects (specularities, glossiness). Why VM was chosen over CP decomposition, and the role of the mode rank, are covered separately in the Appendix.
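The equivalence behind that interpolation trick is easy to check numerically. For one term, trilinear interpolation of \(\mathbf{M}^{r}_{xy} \otimes \mathbf{v}^{r}_{z}\) factorizes into bilinear interpolation of the matrix times linear interpolation of the vector, because trilinear weights are products of 1D weights. The NumPy sketch below builds the dense tensor from random VM factors on a tiny 16³ grid (the sizes, names and rank R = 3 are arbitrary) and compares both evaluations at a continuous point:

```python
import numpy as np

def interp1(v, t):
    i0 = int(np.floor(t)); i1 = min(i0 + 1, len(v) - 1); w = t - i0
    return (1 - w) * v[i0] + w * v[i1]

def interp2(M, a, b):
    # Bilinear interpolation of M at continuous (row a, column b).
    a0 = int(np.floor(a)); a1 = min(a0 + 1, M.shape[0] - 1); wa = a - a0
    b0 = int(np.floor(b)); b1 = min(b0 + 1, M.shape[1] - 1); wb = b - b0
    return ((1 - wa) * (1 - wb) * M[a0, b0] + (1 - wa) * wb * M[a0, b1]
            + wa * (1 - wb) * M[a1, b0] + wa * wb * M[a1, b1])

def interp3(T, x, y, z):
    # Trilinear = linear interpolation along x between two bilinear yz-slices.
    x0 = int(np.floor(x)); x1 = min(x0 + 1, T.shape[0] - 1); wx = x - x0
    return (1 - wx) * interp2(T[x0], y, z) + wx * interp2(T[x1], y, z)

G, R = 16, 3
rng = np.random.default_rng(0)
M_xy, M_xz, M_yz = (rng.normal(size=(R, G, G)) for _ in range(3))
v_z, v_y, v_x = (rng.normal(size=(R, G)) for _ in range(3))

# Dense tensor T[x, y, z] built from the VM factors (only feasible for a tiny grid).
T = (np.einsum('rij,rk->ijk', M_xy, v_z)
     + np.einsum('rik,rj->ijk', M_xz, v_y)
     + np.einsum('rjk,ri->ijk', M_yz, v_x))

def vm_eval(x, y, z):
    # Factorised evaluation: interpolate each plane and line, multiply, sum over terms.
    return sum(interp2(M_xy[r], x, y) * interp1(v_z[r], z)
               + interp2(M_xz[r], x, z) * interp1(v_y[r], y)
               + interp2(M_yz[r], y, z) * interp1(v_x[r], x) for r in range(R))

p = (3.3, 7.8, 11.1)
print(np.isclose(interp3(T, *p), vm_eval(*p)))  # True
```

At realistic resolutions the dense tensor is never built; only the factorized path is evaluated.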
TensoRF (Core implementation methods)
We will discuss the major parts of the implementation with code, one by one. As usual, I will only discuss the important parts of the code; a link to the complete notebook is provided at the end.
Data Loading
Data loading is a very standard process. The NeRF-Synthetic dataset provides a transforms_train.json file for each scene (we only use the mic scene). This file holds all the necessary metadata and makes the images easy to load, so the json library is all we need. The main items are the images, the camera poses (4×4 camera-to-world transform matrices) and the horizontal field of view (camera_angle_x).
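A minimal sketch of that loading step, assuming the standard NeRF-Synthetic layout: a camera_angle_x field and a frames list whose entries carry a file_path without the .png extension (for example "./train/r_0") and a 4×4 transform_matrix. load_blender_meta is a hypothetical helper name, not the notebook's, and image decoding is left out.

```python
import json
import os

import numpy as np

def load_blender_meta(basedir, split="train"):
    """Read camera metadata from transforms_<split>.json (hypothetical helper).

    Returns the horizontal FoV in radians, (N, 4, 4) camera-to-world poses and
    the image paths; decoding the RGBA PNGs is left to an image library.
    """
    with open(os.path.join(basedir, f"transforms_{split}.json")) as f:
        meta = json.load(f)
    frames = meta["frames"]
    poses = np.array([fr["transform_matrix"] for fr in frames], dtype=np.float32)
    # file_path entries omit the extension, e.g. "./train/r_0"
    paths = [os.path.join(basedir, fr["file_path"] + ".png") for fr in frames]
    return float(meta["camera_angle_x"]), poses, paths
```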
One technique we use is pre-cropping (center cropping) for the first N iterations of ray sampling. The mic scene has the object in the center, surrounded by a white background, so pre-cropping lets the early iterations focus on rays through the center of the image. Later, we expand to rays from the rest of the image.
(All scenes in the NeRF-Synthetic dataset are like this.)

(In fact, this is an established technique, and it helps us reach about 32 PSNR on the mic scene.)
```python
# Pre-cropping during random sampling: for the first precrop_iters steps,
# pixel rows (js) and columns (is_) are drawn only from the centre crop.
img_ids = np.random.randint(0, self.N, size=total_batch_size)
if iter_step < precrop_iters:
    js = np.random.randint(self.h_start, self.h_start + self.center_crop_size, size=total_batch_size)
    is_ = np.random.randint(self.w_start, self.w_start + self.center_crop_size, size=total_batch_size)
else:
    js = np.random.randint(0, self.H, size=total_batch_size)
    is_ = np.random.randint(0, self.W, size=total_batch_size)
```
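The snippet above relies on self.h_start, self.w_start and self.center_crop_size, which are computed elsewhere in the notebook. One plausible way to derive them from a crop fraction such as precrop_frac = 0.5 is sketched below; the helper name and the square-crop assumption are illustrative, not taken from the notebook.

```python
def center_crop_bounds(H, W, precrop_frac=0.5):
    """Hypothetical helper: start offsets and side length of a centred square crop."""
    size = int(min(H, W) * precrop_frac)
    h_start = (H - size) // 2
    w_start = (W - size) // 2
    return h_start, w_start, size

print(center_crop_bounds(800, 800))  # (200, 200, 400)
```

For an 800×800 image this samples rows and columns from 200 to 599 during the warm-up phase.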
The rest of the procedure is as follows:
- Calculate the focal length from the image width \(W\) and the horizontal field of view \(\theta\) (camera_angle_x):
$$f = \frac{W/2}{\tan \left( \frac{\theta}{2} \right)}$$
```python
self.focal = 0.5 * W / np.tan(0.5 * self.camera_angle_x)
```
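- Then create camera-space ray directions for pixel column \(x\) and row \(y\):
$$\begin{bmatrix} (x - W/2)/f \\ -(y - H/2)/f \\ -1 \end{bmatrix}$$
```python
# The camera looks down -z; image rows grow downwards, hence the minus sign on y.
base_dirs = np.stack([
    (is_ - self.W * 0.5) / self.focal,
    -(js - self.H * 0.5) / self.focal,
    -np.ones(total_batch_size, dtype=np.float32)], axis=-1)
```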
- Next, use the image ids to look up each ray's origin and rotation from the poses loaded earlier. The rotation part of the camera-to-world matrix is needed because the base directions are in camera space and we need them in world space.
```python
# Batched rotation with einsum: rays_d[n] = R[n] @ base_dirs[n]
rays_o = self.poses[img_ids, :3, 3]
R = self.poses[img_ids, :3, :3]
rays_d = np.einsum('...i,...ji->...j', base_dirs, R)
```
- Compute the ground-truth color by compositing the RGBA pixels onto a white background with the alpha blending equation \(C = \alpha \cdot C_f + (1 - \alpha) \cdot C_b\):
```python
selected_px = self.imgs[img_ids, js, is_, :]
alpha = selected_px[..., 3:4]  # opacity channel
# Foreground RGB over a white background (C_b = 1)
gt_rgb = selected_px[..., :3] * alpha + np.array([1.0, 1.0, 1.0], dtype=np.float32)[None, :] * (1.0 - alpha)
```
- Finally, create a simple full-image version of the loader for rendering.
```python
def get_full_image_rays(self, idx):
    pose = self.poses[idx]
    i, j = np.meshgrid(np.arange(self.W, dtype=np.float32),
                       np.arange(self.H, dtype=np.float32), indexing='xy')
    dirs = np.stack([(i - self.W * 0.5) / self.focal,
                     -(j - self.H * 0.5) / self.focal,
                     -np.ones_like(i)], -1)
    # Rotate camera-space directions into world space (R @ dir for every pixel)
    rays_d = np.sum(dirs[..., np.newaxis, :] * pose[:3, :3], -1)
    rays_o = np.broadcast_to(pose[:3, 3], rays_d.shape)
    return rays_o, rays_d
```
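Both loaders build the same rays: the batched version rotates with einsum and the full-image version with a broadcasted sum. The NumPy sketch below (a hypothetical pixel_rays helper and a made-up pose) checks that the two formulations agree and that the image-centre pixel looks down the camera's -z axis:

```python
import numpy as np

def focal_from_fov(W, camera_angle_x):
    return 0.5 * W / np.tan(0.5 * camera_angle_x)

def pixel_rays(i, j, H, W, focal, c2w):
    """World-space rays through pixel columns i and rows j (hypothetical helper)."""
    dirs = np.stack([(i - 0.5 * W) / focal,
                     -(j - 0.5 * H) / focal,
                     -np.ones_like(i, dtype=np.float64)], axis=-1)   # camera space
    rays_d = dirs @ c2w[:3, :3].T      # rays_d = R @ dir for every pixel
    rays_o = np.broadcast_to(c2w[:3, 3], rays_d.shape)
    return rays_o, rays_d

# Made-up pose: 90 degree rotation about the y axis, camera placed at (0, 0, 4).
c2w = np.eye(4)
c2w[:3, :3] = [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]
c2w[:3, 3] = [0.0, 0.0, 4.0]
focal = focal_from_fov(800, np.pi / 2)          # 90 degree FoV gives focal = 400
i = np.array([400.0, 10.0])
j = np.array([400.0, 700.0])
rays_o, rays_d = pixel_rays(i, j, 800, 800, focal, c2w)

# The broadcasted-sum form from get_full_image_rays gives the same directions.
dirs = np.stack([(i - 400.0) / focal, -(j - 400.0) / focal, -np.ones(2)], axis=-1)
print(np.allclose(rays_d, np.sum(dirs[..., None, :] * c2w[:3, :3], -1)))  # True
```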
Model Class
Initialization
This is where we initialize the tensor components, basis matrix and MLP. There are a few things we have to take into consideration while doing this.
Grid Dimension and Bounding Box Limits: Neither of these should be learned. The grid dimension is marked static, because it only changes during coarse-to-fine upsampling and must not interfere with training; the bounding box limits are kept as arrays whose gradients are stopped (more on this under Implementation Techniques).
Density vs. Color Representation: The scene is split into density and color (appearance) factors, using 8 and 24 components per plane/line pair respectively (the defaults in this implementation). Density is relatively simple to represent (smoothly varying shapes), while color is more complex, containing textures and lighting variations.
Basis Matrix
The appearance grid produces 24 × 3 = 72 raw features per point. These are redundant and over-parameterized, so they are not fed directly to the view-dependent decoder.
Function: Multiplying by a basis matrix compresses and mixes these raw features into a compact 27-dimensional feature vector. This makes training easier, especially for tiny MLPs.
Dimensions: The basis matrix has shape (app_dim, 27), i.e. (72, 27). The specific choice of 27 (and of 54 for the MLP input size) is explained later.
```python
from typing import List

import equinox as eqx
import jax
import jax.numpy as jnp


class TensoRF(eqx.Module):
    grid_dim: int = eqx.field(static=True)   # changes only when the grid is upsampled
    bbox_min: jax.Array
    bbox_max: jax.Array
    den_planes: List[jax.Array]
    den_lines: List[jax.Array]
    app_planes: List[jax.Array]
    app_lines: List[jax.Array]
    basis_mat: jax.Array
    mlp_render: eqx.nn.MLP

    def __init__(self, key, grid_dim=128, n_comp_den=[8, 8, 8], n_comp_app=[24, 24, 24],
                 bbox_min=-1.5, bbox_max=1.5):
        # 6 keys for density factors, 6 for appearance factors, 1 for the basis, 1 for the MLP,
        # so that no two parameter groups share a random key.
        keys = jax.random.split(key, 14)
        self.bbox_min = jnp.array([bbox_min] * 3)
        self.bbox_max = jnp.array([bbox_max] * 3)
        self.grid_dim = grid_dim
        self.den_planes, self.den_lines = self.init_tensor_components(keys, n_comp_den, 0)
        self.app_planes, self.app_lines = self.init_tensor_components(keys, n_comp_app, 6)
        self.basis_mat, self.mlp_render = self.init_decoders(keys, sum(n_comp_app))

    def init_tensor_components(self, keys, n_components, key_offset):
        # Planes (C, G, G) for the xy, xz, yz planes; lines (C, G, 1) for the z, y, x axes.
        planes, lines = [], []
        for i, c in enumerate(n_components):
            p_shape = (c, self.grid_dim, self.grid_dim)
            l_shape = (c, self.grid_dim, 1)
            planes.append(jax.random.normal(keys[key_offset + i], p_shape) * 0.1)
            lines.append(jax.random.normal(keys[key_offset + i + 3], l_shape) * 0.1)
        return planes, lines

    def init_decoders(self, keys, app_dim):
        basis_matrix = jax.random.normal(keys[12], (app_dim, 27)) * 0.1
        mlp_render = eqx.nn.MLP(in_size=54, out_size=3, width_size=128, depth=2, key=keys[13])
        return basis_matrix, mlp_render
```
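To get a feel for how compact the model is, here is a sketch that counts the values implied by the shapes above (count_params is a hypothetical helper). It assumes that eqx.nn.MLP with depth=2 means three linear layers with biases: in -> 128 -> 128 -> out.

```python
def count_params(G, n_den=(8, 8, 8), n_app=(24, 24, 24), basis_out=27,
                 mlp_in=54, width=128, out=3):
    """Trainable values of the model above at grid resolution G (a sketch)."""
    # Each component has one G x G plane and one G x 1 line.
    planes_lines = sum(c * G * G + c * G for c in (*n_den, *n_app))
    basis = sum(n_app) * basis_out
    # Linear(in, w) -> Linear(w, w) -> Linear(w, out), each with a bias
    mlp = (mlp_in * width + width) + (width * width + width) + (width * out + out)
    return planes_lines + basis + mlp

print(f"{count_params(128):,}")  # 1,611,035
print(f"{count_params(512):,}")  # 25,240,859
```

For comparison, a dense 512³ density grid alone would hold 134,217,728 values, before any appearance channels.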
Interpolation
First, we normalize and scale the coordinates. We also flip (x, y) to (y, x), since y indexes height (rows) and x indexes width (columns); in other words, we go from Cartesian coordinates to matrix indexing.
```python
def normalize_coordinates(self, xyz):
    # The bounding box limits are arrays, but we stop their gradients so they are never learned
    min_b = jax.lax.stop_gradient(self.bbox_min)
    max_b = jax.lax.stop_gradient(self.bbox_max)
    return (xyz - min_b) / (max_b - min_b)   # maps the box to [0, 1]^3

def interpolate_tensor_components(self, xyz_normed, planes, lines):
    # Plane/line pairs: (xy, z), (xz, y), (yz, x)
    coordinate_plane = [xyz_normed[..., [0, 1]], xyz_normed[..., [0, 2]], xyz_normed[..., [1, 2]]]
    coordinate_line = [xyz_normed[..., [2]], xyz_normed[..., [1]], xyz_normed[..., [0]]]
    grid_dim = self.grid_dim
    # Scale [0, 1] to continuous grid indices in [0, grid_dim - 1]
    scaled_plane = [c * (grid_dim - 1) for c in coordinate_plane]
    scaled_line = [c * (grid_dim - 1) for c in coordinate_line]
    results = []
    for i in range(3):
        coords_p = scaled_plane[i].reshape(-1, 2).T
        coords_p = coords_p[::-1, :]  # flip to (y, x), i.e. (row, column)
        plane_val = manual_bilinear(planes[i], coords_p)
        coords_l = scaled_line[i].reshape(-1, 1).T
        # Treat the (C, G, 1) line as a thin image and sample it at (t, 0)
        coords_l_2d = jnp.stack([coords_l[0], jnp.zeros_like(coords_l[0])], axis=0)
        line_val = manual_bilinear(lines[i], coords_l_2d)
        results.append(plane_val * line_val)  # (C, N) for this plane/line pair
    return results
```
I am using a manual bilinear interpolation function because I ran into trouble with map_coordinates (jax.scipy.ndimage.map_coordinates). This part can certainly be optimized.
In the first block of the loop, the input scaled_plane[i] has shape (Batch, 2) and contains \((u, v)\) coordinates. .reshape(-1, 2) flattens it into a list of pairs, and .T transposes it to shape (2, Batch).
In the second block, we use a small trick. manual_bilinear is written for 2D grids, and each line factor is stored as a very thin (grid_dim × 1) matrix, so we only need to turn each 1D line coordinate into a 2D one with column index 0. The 1D coordinates are held in coords_l[0]:
- coords_l[0] extracts the coordinate along the line (the row index).
- jnp.zeros_like(...) creates a matching array of zeros (the column index).
- jnp.stack combines them into pairs \((t, 0)\), where \(t\) is the coordinate along the line.
As for manual_bilinear, the main function responsible for interpolation:
- Finding the Neighbors
Given a continuous point $(x,y)$, we find the four integer corners surrounding it.
$$\begin{gather} x_0 = \lfloor x \rfloor, \quad x_1 = x_0 + 1 \\ y_0 = \lfloor y \rfloor, \quad y_1 = y_0 + 1 \end{gather}$$
Top-Left: \((x_0, y_0)\), Top-Right: \((x_1, y_0)\), Bottom-Left: \((x_0, y_1)\), Bottom-Right: \((x_1, y_1)\)
- Calculating Weights (Areas)
The influence (weight) of a corner is determined by the area of the rectangle opposite to it.
- If the point is very close to the Top-Left, the area of the Bottom-Right rectangle formed by the split is largest.
The code calculates these 4 weights:
```python
wa = (x1 - x) * (y1 - y)  # weight for Top-Left (x0, y0)
wb = (x1 - x) * (y - y0)  # weight for Bottom-Left (x0, y1)
wc = (x - x0) * (y1 - y)  # weight for Top-Right (x1, y0)
wd = (x - x0) * (y - y0)  # weight for Bottom-Right (x1, y1)
```
- Fetching Corner Values
The code converts the continuous indices to integers and clips them to the grid bounds, so we never read outside the grid.
```python
# H, W are the grid's spatial size (W == 1 for a line factor)
x0_idx = jnp.clip(x0, 0, W - 1).astype(jnp.int32)
x1_idx = jnp.clip(x1, 0, W - 1).astype(jnp.int32)
y0_idx = jnp.clip(y0, 0, H - 1).astype(jnp.int32)
y1_idx = jnp.clip(y1, 0, H - 1).astype(jnp.int32)
# Gather the four corner values for every channel: each has shape (C, B)
Ia = grid[:, y0_idx, x0_idx]
Ib = grid[:, y1_idx, x0_idx]
Ic = grid[:, y0_idx, x1_idx]
Id = grid[:, y1_idx, x1_idx]
```
- Weighted Sum
Finally, we sum the four corner values multiplied by their respective weights.
$$\text{Value} = w_a I_a + w_b I_b + w_c I_c + w_d I_d$$
```python
return wa[None, :] * Ia + wb[None, :] * Ib + wc[None, :] * Ic + wd[None, :] * Id
```
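Putting those four steps together gives the self-contained NumPy version below (the JAX version uses jnp in place of np). It is a reconstruction from the snippets above, not the notebook's exact code; the (2, B) coordinate layout with rows first follows interpolate_tensor_components.

```python
import numpy as np

def manual_bilinear(grid, coords):
    """Bilinear interpolation of a (C, H, W) grid.

    coords: (2, B) continuous (row, column) = (y, x) positions in grid units.
    Returns (C, B). A line factor stored as (C, H, 1) is sampled with column 0.
    """
    _, H, W = grid.shape
    y, x = coords[0], coords[1]
    x0, y0 = np.floor(x), np.floor(y)
    x1, y1 = x0 + 1, y0 + 1
    # Each corner is weighted by the area of the rectangle opposite to it
    wa = (x1 - x) * (y1 - y)  # top-left
    wb = (x1 - x) * (y - y0)  # bottom-left
    wc = (x - x0) * (y1 - y)  # top-right
    wd = (x - x0) * (y - y0)  # bottom-right
    x0i = np.clip(x0, 0, W - 1).astype(np.int32)
    x1i = np.clip(x1, 0, W - 1).astype(np.int32)
    y0i = np.clip(y0, 0, H - 1).astype(np.int32)
    y1i = np.clip(y1, 0, H - 1).astype(np.int32)
    Ia = grid[:, y0i, x0i]
    Ib = grid[:, y1i, x0i]
    Ic = grid[:, y0i, x1i]
    Id = grid[:, y1i, x1i]
    return wa[None] * Ia + wb[None] * Ib + wc[None] * Ic + wd[None] * Id
```

A quick sanity check: bilinear interpolation reproduces a linear ramp exactly, so sampling a grid whose value is 2·row + 3·column at (1.5, 2.25) returns 9.75.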
This yields values that vary smoothly as the sampling point moves across grid cells. To extract the fruits of our labor, we have this method:
```python
def get_sigma_feat(self, xyz_normed):
    den_components = self.interpolate_tensor_components(xyz_normed, self.den_planes, self.den_lines)
    # Sum over components and over the three plane/line pairs -> (N,)
    sigma = sum(jnp.sum(comp, axis=0) for comp in den_components)
    sigma = jax.nn.softplus(sigma) * 5.0
    app_components = self.interpolate_tensor_components(xyz_normed, self.app_planes, self.app_lines)
    # Stack the 3 x 24 appearance components -> (N, 72)
    app_feats = jnp.concatenate(app_components, axis=0).T
    return sigma, app_feats
```
This gives us the density and appearance features we were trying to obtain.
Softplus, \(\log(1 + e^x)\), is a smooth approximation of ReLU that ensures positivity while avoiding the "dying ReLU" problem, where gradients vanish for negative inputs.
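A short NumPy sketch of that property. Softplus is computed stably as logaddexp(0, x), which avoids overflow in exp for large inputs, and its derivative is the sigmoid, which is small but never exactly zero for negative inputs:

```python
import numpy as np

def softplus(x):
    # log(1 + e^x), computed as logaddexp(0, x) to avoid overflow for large x
    return np.logaddexp(0.0, x)

def softplus_grad(x):
    # d/dx softplus(x) = sigmoid(x): positive everywhere, unlike ReLU's gradient
    return 1.0 / (1.0 + np.exp(-x))

x = np.array([-5.0, 0.0, 5.0, 1000.0])
print(softplus(x))       # approx [0.0067, 0.6931, 5.0067, 1000.0]
print(softplus_grad(x))  # approx [0.0067, 0.5, 0.9933, 1.0]
```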
Volumetric Rendering
In volume rendering, we don't just look at a surface; we look through the volume. To render a single pixel, we cast a ray from the camera and sample multiple points along its path.
If we always sampled at fixed depths (e.g. 1.0 m, 1.1 m, 1.2 m), the model would only ever be supervised at those specific distances, which produces banding (aliasing) artifacts. To prevent this, we add random noise (t_rand) to perturb each sample position within its interval. This turns structured artifacts into less noticeable noise, which is easier for the network to handle. This is called jittering (the stratified sampling described earlier).
```python
def sample_along_rays(self, rays_o, rays_d, n_samples, key=None):
    near, far = 0.2, 6.0
    z_vals = jnp.linspace(near, far, n_samples)
    z_vals = jnp.broadcast_to(z_vals, (rays_o.shape[0], n_samples))
    if key is not None:
        # Jitter: draw each sample uniformly within its bin (bounded by neighbouring midpoints)
        mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = jnp.concatenate([mids, z_vals[..., -1:]], -1)
        lower = jnp.concatenate([z_vals[..., :1], mids], -1)
        t_rand = jax.random.uniform(key, z_vals.shape)
        z_vals = lower + (upper - lower) * t_rand
    # r(t) = o + t d for every sample -> (N_rays, n_samples, 3)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals
```
Ray Equation: We convert depths $t$ (or z_vals) into 3D coordinates using the vector equation \(\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}\)
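A NumPy mirror of that depth logic makes the jitter easy to inspect (stratified_depths and its rng argument are stand-ins for the JAX function and its key). Each jittered depth stays inside its own bin, bounded by the midpoints to its neighbours, so the samples remain sorted along the ray and never move more than half a bin from their fixed positions.

```python
import numpy as np

def stratified_depths(n_rays, n_samples, near=0.2, far=6.0, rng=None):
    """NumPy mirror of the depth logic in sample_along_rays (a sketch)."""
    z = np.broadcast_to(np.linspace(near, far, n_samples), (n_rays, n_samples))
    if rng is None:                       # no key: fixed, evenly spaced depths
        return z.copy()
    mids = 0.5 * (z[:, 1:] + z[:, :-1])
    upper = np.concatenate([mids, z[:, -1:]], axis=-1)
    lower = np.concatenate([z[:, :1], mids], axis=-1)
    return lower + (upper - lower) * rng.uniform(size=z.shape)

z_fixed = stratified_depths(4, 192)
z_jit = stratified_depths(4, 192, rng=np.random.default_rng(0))

# Turn depths into points with r(t) = o + t d (all rays looking down -z here).
rays_o = np.zeros((4, 3))
rays_d = np.tile([0.0, 0.0, -1.0], (4, 1))
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_jit[:, :, None]
```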
Next comes the view direction encoding, for which we use the same sin-cos positional encoding as vanilla NeRF. This also explains why the basis matrix has 27 output columns.
As the code below shows, we return the encoded directions concatenated with the raw directions. The raw directions have shape (N, 3); with 4 frequencies applied to each of the 3 components, and a sine and a cosine for each, the encoding has (3 × 4) × 2 = 24 values per direction, i.e. shape (N, 24). Concatenating the raw directions gives 24 + 3 = 27 values, i.e. shape (N, 27).
The basis matrix projects the appearance features to the same size, 27, so that the projected features and the encoded view directions are concatenated with equal width. This is why the MLP input size is 54 (27 + 27): we want the MLP to treat both inputs equally. (Strictly, the two sizes do not have to match; the MLP input size just has to equal their sum.)
```python
def encode_view_directions(self, directions, num_freqs=4):
    freqs = 2.0 ** jnp.linspace(0, num_freqs - 1, num_freqs)   # [1, 2, 4, 8]
    dirs_enc = jnp.concatenate([jnp.sin(directions[..., None] * freqs),
                                jnp.cos(directions[..., None] * freqs)
                                ], -1)                          # (N, 3, 2 * num_freqs)
    dirs_enc = dirs_enc.reshape(directions.shape[0], -1)        # (N, 24)
    return jnp.concatenate([dirs_enc, directions], axis=-1)     # (N, 27)
```
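A NumPy sketch of the shape bookkeeping, with random stand-ins for the appearance features and the basis matrix. It assumes, as described above, that the projected features and the encoded directions are concatenated before the MLP, which is where the 54 comes from.

```python
import numpy as np

def encode_view_directions(directions, num_freqs=4):
    freqs = 2.0 ** np.linspace(0, num_freqs - 1, num_freqs)    # [1, 2, 4, 8]
    scaled = directions[..., None] * freqs                      # (N, 3, num_freqs)
    enc = np.concatenate([np.sin(scaled), np.cos(scaled)], -1)  # (N, 3, 2 * num_freqs)
    enc = enc.reshape(directions.shape[0], -1)                  # (N, 24)
    return np.concatenate([enc, directions], axis=-1)           # (N, 27)

rng = np.random.default_rng(0)
N = 5
app_feats = rng.normal(size=(N, 72))            # stand-in for 3 x 24 appearance components
basis_mat = rng.normal(size=(72, 27)) * 0.1     # stand-in for the learned basis matrix
dirs = rng.normal(size=(N, 3))
dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True)

mlp_in = np.concatenate([app_feats @ basis_mat, encode_view_directions(dirs)], axis=-1)
print(mlp_in.shape)  # (5, 54)
```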
Now for the final volumetric rendering, where the equations we discussed earlier make a return. First, alpha compositing:
$$\alpha_i = 1 - \exp(-\sigma_i \delta_i)$$
This \(\sigma\) comes from get_sigma_feat, and \(\delta_i = z_{i+1} - z_i\) is the physical distance between adjacent samples:
```python
dists = z_vals[..., 1:] - z_vals[..., :-1]
# The last interval is treated as (effectively) infinite, 1e10, to handle the background
dists = jnp.concatenate([dists, jnp.broadcast_to(1e10, dists[..., :1].shape)], -1)
# Multiply by norm(rays_d) to get physical distances for non-unit direction vectors
dists = dists * jnp.linalg.norm(rays_d[..., None, :], axis=-1)
# Alpha compositing main step: per-sample opacity
alpha = 1.0 - jnp.exp(-sigma * dists)
```
This \(\alpha\) is what tells us how opaque an object is. It relates volumetric density to the probability of light being occluded over distance \(\delta_i\). If \(\sigma \to 0\), \(\alpha \to 0\) (transparent) and if \(\sigma \to \infty\), \(\alpha \to 1\) (opaque).
Moving to the next equation :
$$T_i = \prod_{j=1}^{i-1} (1 - \alpha_j)$$
Transmittance \(T_i\) is the probability that a ray travels from the camera to sample \(i\) without hitting anything. The weight \(w_i = T_i \alpha_i\) is the contribution of sample \(i\) to the final pixel.
```python
# Inclusive running product of (1 - alpha); the 1e-10 keeps it numerically away from zero
transmittance = jnp.cumprod(1.0 - alpha + 1e-10, axis=-1)
# Shift right by one so that T_1 = 1, then w_i = T_i * alpha_i
weights = alpha * jnp.concatenate([jnp.ones((alpha.shape[0], 1)),
                                   transmittance[..., :-1]], -1)
```
jnp.cumprod calculates the running product. We shift it right by one (by concatenating ones at the start) because \(T_1\), the transmittance to the first point, is always 1.0.
Finally, we sum the weighted colors and depths.
$$\begin{gather} \hat{C}(\mathbf{r}) = \sum_{i=1}^{N} w_i \mathbf{c}_i \\ \hat{D}(\mathbf{r}) = \sum_{i=1}^{N} w_i z_i \end{gather}$$
```python
rgb_map = jnp.sum(weights[..., None] * rgb, axis=-2)  # expected color per ray
depth_map = jnp.sum(weights * z_vals, axis=-1)       # expected depth per ray
acc_map = jnp.sum(weights, -1)                       # accumulated opacity
rgb_out = rgb_map + (1. - acc_map[..., None]) * bg_color
```
acc_map (accumulated opacity) is \(\sum_i w_i\). If the ray does not become fully opaque \((\sum_i w_i < 1)\), the remaining weight is assigned to the background color (bg_color). This is known as background compositing.
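Collected into one function, the rendering steps look like the NumPy sketch below (a transcription of the JAX snippets with np in place of jnp; the function name and argument shapes are assumptions). The example checks the two limiting cases: an empty ray returns the background color, and a ray that hits one very dense sample returns that sample's color and depth.

```python
import numpy as np

def composite(sigma, rgb, z_vals, rays_d, bg_color):
    """sigma: (R, S), rgb: (R, S, 3), z_vals: (R, S), rays_d: (R, 3), bg_color: (3,)."""
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = np.concatenate([dists, np.full(dists[..., :1].shape, 1e10)], -1)
    dists = dists * np.linalg.norm(rays_d[..., None, :], axis=-1)
    alpha = 1.0 - np.exp(-sigma * dists)
    trans = np.cumprod(1.0 - alpha + 1e-10, axis=-1)
    weights = alpha * np.concatenate([np.ones_like(alpha[..., :1]), trans[..., :-1]], -1)
    rgb_map = np.sum(weights[..., None] * rgb, axis=-2)
    depth_map = np.sum(weights * z_vals, axis=-1)
    acc_map = np.sum(weights, axis=-1)
    return rgb_map + (1.0 - acc_map[..., None]) * bg_color, depth_map, acc_map

z = np.tile(np.linspace(2.0, 5.0, 4), (2, 1))     # depths [2, 3, 4, 5] on both rays
sigma = np.zeros((2, 4))
sigma[1, 1] = 1e4                                 # ray 1 hits a dense sample at z = 3
rgb = np.zeros((2, 4, 3))
rgb[1, 1] = [1.0, 0.0, 0.0]
rays_d = np.tile([0.0, 0.0, -1.0], (2, 1))
out, depth, acc = composite(sigma, rgb, z, rays_d, bg_color=np.ones(3))
```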
Culmination in the __call__ method
In Equinox models, the __call__ method acts as the "forward pass" of the model. It takes ray origins and directions and outputs the final pixel colors, tying together the entire pipeline so far:
- Ray Sampling: the model generates sample points along each ray using sample_along_rays.
Input shapes: rays_o \((N_{rays}, 3)\), rays_d \((N_{rays}, 3)\).
Output shapes: pts \((N_{rays}, N_{samples}, 3)\), the 3D coordinates of every sample point; z_vals \((N_{rays}, N_{samples})\), the depth of each sample along the ray.
- Coordinate Normalization and Feature Extraction: points outside the bounding box are masked so that their density is 0. Then get_sigma_feat retrieves the density \(\sigma\) and the appearance features from the VM decomposition.
- Appearance Projection and Decoding: the appearance features are projected with the basis matrix and fed to the MLP together with the encoded view direction:
$$\mathbf{c} = \text{MLP}\left(\left[\,\mathbf{a}\mathbf{B},\ \gamma(\mathbf{d})\,\right]\right)$$
where \(\mathbf{a}\) is the appearance feature vector, \(\mathbf{B}\) is basis_mat, \(\gamma\) is the positional encoding of the view direction, and \([\cdot, \cdot]\) denotes concatenation.
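The masking step is the one part of __call__ that is described but not shown in code. A minimal NumPy sketch, assuming the mask zeroes the density of any sample whose normalized coordinates fall outside \([0, 1]^3\) (normalize_and_mask is a hypothetical name):

```python
import numpy as np

def normalize_and_mask(pts, bbox_min, bbox_max):
    """Map points into [0, 1]^3 and flag the ones inside the bounding box."""
    xyz_normed = (pts - bbox_min) / (bbox_max - bbox_min)
    inside = np.all((xyz_normed >= 0.0) & (xyz_normed <= 1.0), axis=-1)
    return xyz_normed, inside

# One ray with two samples: the first inside the [-1.5, 1.5]^3 box, the second outside.
pts = np.array([[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]])
xyz_normed, inside = normalize_and_mask(pts, np.full(3, -1.5), np.full(3, 1.5))
raw_sigma = np.array([[4.0, 4.0]])          # stand-in for the densities from get_sigma_feat
sigma = np.where(inside, raw_sigma, 0.0)    # samples outside the box contribute nothing
```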
Loss Function
The original TensoRF paper optimizes a per-pixel L2 rendering loss, but the authors found that this factorization-based model can overfit or get stuck in local minima, especially in regions with few observed views, leading to outliers and noise.
Hence, they add an L1 norm loss on the factors to encourage sparsity and a Total Variation (TV) loss on the factors to encourage smoothness. If, like me, you haven't heard of the second one, it goes like this:
For a 2D image $I$ of dimensions \(H \times W\), the Total Variation is the sum of the absolute differences for neighboring pixel values in the horizontal and vertical directions.
$$L_{TV}(I) = \sum_{i,j} \left| I_{i+1, j} - I_{i, j} \right| + \left| I_{i, j+1} - I_{i, j} \right|$$
TV Loss penalizes the differences between neighbors. If pixel $A$ is very bright and its immediate neighbor pixel $B$ is dark, TV Loss generates a high penalty. This forces the model to smooth out these rapid changes, effectively "ironing out" the noise.
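A NumPy sketch of the 2D definition, applied to a flat image and to a checkerboard (total_variation is an illustrative name; compute_tv_loss in the implementation uses means instead of sums, which only rescales the penalty):

```python
import numpy as np

def total_variation(I):
    """Sum of absolute differences between vertical and horizontal neighbours."""
    return np.abs(np.diff(I, axis=0)).sum() + np.abs(np.diff(I, axis=1)).sum()

smooth = np.full((4, 4), 0.5)
checker = np.indices((4, 4)).sum(axis=0) % 2 * 1.0   # alternating 0/1 pattern
print(total_variation(smooth), total_variation(checker))  # 0.0 24.0
```

The flat image costs nothing, while the checkerboard pays 1 for each of its 12 vertical and 12 horizontal neighbour pairs.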
The paper notes that the L1 sparsity loss alone is sufficient for most datasets, but real-world datasets with very few input images or imperfect captures (like LLFF or Tanks and Temples) additionally need TV regularization.
In our implementation, the final loss is:
$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{MSE}} + \lambda_{\text{TV}} \left( \mathcal{L}^{\text{den}}_{\text{TV}} + 0.1 \times \mathcal{L}^{\text{app}}_{\text{TV}} \right) + \lambda_{\text{L1}} \mathcal{L}_{\text{L1}}$$
```python
import jax.numpy as jnp


def l1_on_factors(planes, lines):
    # Mean absolute value of every plane and line factor (encourages sparsity)
    l1 = 0.0
    for p in planes + lines:
        l1 += jnp.mean(jnp.abs(p))
    return l1


def compute_tv_loss(planes, lines):
    # Mean absolute difference between neighbouring entries (encourages smoothness)
    tv = 0.0
    for p in planes:  # (C, G, G): differences along both spatial axes
        tv += jnp.mean(jnp.abs(p[:, 1:, :] - p[:, :-1, :])) + \
              jnp.mean(jnp.abs(p[:, :, 1:] - p[:, :, :-1]))
    for l in lines:   # (C, G, 1): differences along the line only
        tv += jnp.mean(jnp.abs(l[:, 1:, :] - l[:, :-1, :]))
    return tv


def loss_fn(model, rays_o, rays_d, target_rgb, key, tv_weight, l1_weight, bg_color):
    pred_rgb, _, weights = model(rays_o, rays_d, key, bg_color)
    mse_loss = jnp.mean((pred_rgb - target_rgb) ** 2)
    tv_den = compute_tv_loss(model.den_planes, model.den_lines)
    tv_app = compute_tv_loss(model.app_planes, model.app_lines)
    all_planes = model.den_planes + model.app_planes
    all_lines = model.den_lines + model.app_lines
    l1_factors = l1_on_factors(all_planes, all_lines)
    total_loss = mse_loss + tv_weight * (tv_den + 0.1 * tv_app) + l1_weight * l1_factors
    # The MSE is returned as auxiliary output (e.g. for logging PSNR)
    return total_loss, mse_loss
```
Further down the code, we also use dynamic loss weighting: the Total Variation (TV) loss weight is not constant but decays exponentially. Strong TV regularization is needed early on to suppress noise while the grid is coarse, but it must be reduced as the grid is upsampled so that fine (high-frequency) details can be learned.
```python
# alpha is the training progress from 0 to 1 (unrelated to the opacity alpha in rendering);
# interpolating in log space gives an exponential decay from the start to the end weight.
current_tv_weight = np.exp((1 - alpha) * np.log(TV_START_WEIGHT) + alpha * np.log(TV_END_WEIGHT))
```
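Interpolating linearly in log space makes the decay geometric: the weight shrinks by the same factor every step, and halfway through training it equals the geometric mean of the two endpoints. A sketch, assuming alpha is step / total iterations and using the endpoints 0.5 and 0.01 and the 30,000 iterations from the hyperparameter list (tv_weight is a hypothetical name; the constant 1e-4 multiplier listed there would be applied on top):

```python
import numpy as np

TV_START_WEIGHT, TV_END_WEIGHT = 0.5, 0.01
TOTAL_ITERS = 30_000

def tv_weight(step):
    alpha = step / TOTAL_ITERS            # assumed training progress in [0, 1]
    return np.exp((1 - alpha) * np.log(TV_START_WEIGHT) + alpha * np.log(TV_END_WEIGHT))

print(tv_weight(0), tv_weight(15_000), tv_weight(30_000))  # approx 0.5 0.0707 0.01
```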
This completes most of the important discussion. I will also cover some minor implementation details and techniques I found important while implementing this in Equinox.
Implementation Techniques
In case you hadn't picked it up by now, I used Gemini 3 to help write this code. However, while analyzing and iterating on it later, I stumbled across a few things that seemed confusing, or might to some readers, since they are not seen often. I will do my best to explain them here. I will mostly avoid code, since most of these details are specific to Equinox and to optimizations for multi-core TPUs.
Stopping Gradients
- Grid Dimensions (static=True): We mark these as static so they are known at compile time and treated as Python values, not arrays. This is mandatory; leaving it out would cause an error.
- Bounding Box Limits (jax.lax.stop_gradient): We keep these as arrays so that coordinate math stays consistent (all operations on arrays) and so they can be saved to and loaded from checkpoint files (static fields are typically not saved with the model weights, whereas arrays are). Because they are arrays, we stop their gradients explicitly to prevent the model from changing them to fit the loss function.
Step Count Re-injection: When the model shape changes due to upsampling, the Adam optimizer must be re-initialized because the parameter shapes have changed. A plain re-initialization resets the optimizer's internal count (the step number), which would restart the learning rate schedule. The restore_step_count function injects the old step count into the new optimizer state, so the learning rate decay continues smoothly despite the mid-training architecture change.
Single Program Multiple Data (SPMD): Used to make use of the 8 cores of Kaggle's TPU v5. Equinox facilitates this with eqx.partition() and eqx.combine().
- Partitioning: params, static = eqx.partition(model, eqx.is_array) splits the model into trainable arrays (weights) and static configuration (e.g. grid dimensions, activation functions).
- Replication: params_rep = jax.device_put_replicated(params, devices) explicitly copies the model weights to every TPU core before the training loop.
- Recombination: Inside parallel_step_fn, model_local = eqx.combine(params, static) stitches the model back together so it can be called like a standard Python object.
Manual Gradient Scaling
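Instead of relying solely on the optimizer, the scale_mlp_grads function intervenes manually before the update. It multiplies the gradients of the MLP layers by 0.1 (or 0.05 in other cells). The intent is to give the MLP a learning rate 10x smaller than the grids' without a complex multi-optimizer setup. Note that this equivalence holds if the scaling is applied to the updates Adam produces; if it is applied to the raw gradients before the Adam step, Adam's normalization largely cancels it out (apart from its effect on global norm clipping and on eps).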
Hyperparameter Details and Results
1. Model Architecture & Grid Settings
Grid Resolution:
Initial Resolution: \(128^3\)
Final Resolution: \(512^3\)
Upsampling Schedule: The grid resolution is progressively increased at specific iteration steps:
Iter 2000 \(\rightarrow\) 150
Iter 3000 \(\rightarrow\) 200
Iter 4000 \(\rightarrow\) 300
Iter 5500 \(\rightarrow\) 400
Iter 7000 \(\rightarrow\) 512
Tensor Components:
Density Components (n_comp_den): [8, 8, 8] (8 components per plane/line pair)
Appearance Components (n_comp_app): [24, 24, 24]
Rendering MLP: A small MLP decodes features into RGB.
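Depth: 2 (depth=2 in eqx.nn.MLP, i.e. two hidden layers and three linear layers in total)
Hidden Width: 128
Input Size: 54 (27 projected appearance features + 27 encoded view-direction features) \(\rightarrow\) Output: 3 (RGB)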
Ray Sampling:
Samples per Ray: 192 (Uniform sampling with stratified jitter)
Near/Far Bounds: 0.2 to 6.0
Bounding Box: \([-1.5, 1.5]^3\)
View Encoding: Frequency encoding with num_freqs=4.
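At each step of the upsampling schedule (e.g. 128 → 150), every plane and line has to be resampled to the new size before training continues. The notebook's upsampling code is not shown in this article; the NumPy sketch below is one reasonable way to do it, assuming align-corners linear resampling, which matches the (grid_dim - 1) coordinate scaling in interpolate_tensor_components, so a given point maps to approximately the same feature value before and after upsampling. The helper names are hypothetical.

```python
import numpy as np

def _resample(a, new_res, axis):
    # Linear resampling along one axis with aligned corners (endpoints stay in place).
    old_x = np.linspace(0.0, 1.0, a.shape[axis])
    new_x = np.linspace(0.0, 1.0, new_res)
    return np.apply_along_axis(lambda v: np.interp(new_x, old_x, v), axis, a)

def upsample_line(line, new_res):
    """(C, G, 1) line factor -> (C, new_res, 1)."""
    return _resample(line, new_res, axis=1)

def upsample_plane(plane, new_res):
    """(C, G, G) plane factor -> (C, new_res, new_res); bilinear = two 1D passes."""
    return _resample(_resample(plane, new_res, axis=1), new_res, axis=2)
```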
2. Training Optimization
Iterations: Total of 30,000 steps.
Batch Size: 1,024 rays per device. With 8 TPU devices used, the Total Batch Size is 8,192 rays.
Optimizer: Adam (b1=0.9, b2=0.99) with global norm clipping at 1.0.
Learning Rate Schedule: Warmup Cosine Decay.
Start: 1e-4
Peak: 2e-2 (reached after 2,000 warmup steps)
End: 1e-3 (decaying until step 30,000)
MLP-Specific Scaling: Gradients for the MLP render head are scaled down by a factor of 0.1 relative to the grid parameters.
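The learning rate schedule can be written out as follows. This is a NumPy sketch of the shape described above (linear warmup, then cosine decay to the end value, with the 30,000 total steps including the warmup); optax's warmup_cosine_decay_schedule provides this shape, but the exact call used in the notebook is not shown, so the argument mapping here is an assumption.

```python
import numpy as np

def warmup_cosine_lr(step, init=1e-4, peak=2e-2, end=1e-3, warmup=2_000, total=30_000):
    """Linear warmup from init to peak, then cosine decay from peak to end."""
    if step < warmup:
        return init + (peak - init) * step / warmup
    progress = min((step - warmup) / (total - warmup), 1.0)
    return end + 0.5 * (peak - end) * (1.0 + np.cos(np.pi * progress))

print([round(float(warmup_cosine_lr(s)), 5) for s in (0, 2_000, 16_000, 30_000)])
# [0.0001, 0.02, 0.0105, 0.001]
```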
3. Loss Functions & Regularization
Reconstruction Loss: Mean Squared Error (MSE).
L1 Regularization: Weight 4e-5, applied to the tensor factors.
Total Variation (TV) Loss:
Weight Schedule: Decays exponentially from 0.5 to 0.01 over the course of training.
Scaling Factor: The scheduled weight is multiplied by a constant 1e-4.
Ratio: Appearance TV loss is scaled by 0.1 relative to density TV loss.
4. Data Loading Strategies
Pre-cropping: For the first 1,000 iterations, rays are sampled only from the center 50% of the image (precrop_frac=0.5) to encourage learning the central object first.
Resolution: Full-resolution \((800 \times 800)\) images were used (dataset: nerf_synthetic/mic).
Results and Training Logs Analysis
Training converges effectively, peaking at a PSNR of 32.10 dB.
The model starts with a PSNR of ~12.16 and rapidly improves to ~26.77 within the first 1,600 iterations, likely aided by the center-cropping strategy.
A distinctive pattern in the logs is the "Upsampling" events (e.g., at iter 2,000, 3,000, etc.), where the grid resolution increases. These events often result in a momentary stagnation or slight regression in loss as the optimization adapts to the new higher-resolution grid, followed by a surge in PSNR (e.g., jumping from 26.54 to 28.96 shortly after the upsample to resolution 300).
The validation images generated every 5,000 steps confirm the visual improvement, and the best model checkpoint was saved near the very end of training (iteration 29,800), suggesting that the cosine decay schedule kept refining the detailed texture of the "mic" object on the \(512^3\) grid.
Resources
For Vanilla NeRF:
For TensoRF:
Video Explanation of the method by Zexiang Xu (who is one of the authors of this paper)
Code: This is the link to my Kaggle notebook.



