render.functionalize.functionalize_render

render.functionalize.functionalize_render(
    scene,
    max_steps=32,
    max_dist=15.0,
    shadow_steps=12,
    shadow_hardness=6.0,
    gamma=2.2,
    fd_normals=False,
    normal_eps=0.0001,
    reflect_steps=0,
)

Compile a Scene to a differentiable render function.

Fixed geometry parameters are extracted once, when functionalize_render is called, and baked into the returned function. The returned function therefore needs only free_params, the dict of parameters that change at each optimisation step:

from render.functionalize import functionalize_render
render_fn = functionalize_render(scene)              # compile once
image = render_fn(free_params, resolution=(64, 96))  # H=64, W=96

image is a JAX float32 array of shape (H, W, 3) and is fully differentiable with respect to free_params via jax.grad.
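The baking step described above can be pictured as an ordinary Python closure. The sketch below is a hypothetical stand-in, not this library's implementation: make_render_fn, the single-sphere SDF, the fixed camera origin, the light direction and the free_params keys center, radius and albedo are all assumptions for illustration, and applying gamma as a 1/gamma power is likewise assumed. It sphere-traces each pixel for max_steps iterations, counts rays that end beyond max_dist (or have not converged) as misses, shades hits with a Lambertian term, and stays differentiable with jax.grad.

```python
import jax
import jax.numpy as jnp


def make_render_fn(fixed, max_steps=32, max_dist=15.0, gamma=2.2):
    # Hypothetical stand-in for functionalize_render: everything in `fixed`
    # is converted once here and captured by the closure below.
    cam_origin = jnp.asarray(fixed["cam_origin"], dtype=jnp.float32)
    light_dir = jnp.array([0.5, 0.5, 1.0])
    light_dir = light_dir / jnp.linalg.norm(light_dir)
    background = jnp.full((3,), 0.05)

    def sdf(p, free_params):
        # Signed distance to a sphere whose centre and radius are free.
        return jnp.linalg.norm(p - free_params["center"]) - free_params["radius"]

    def render_fn(free_params, resolution=(64, 96)):
        h, w = resolution  # Python ints, so array shapes stay static
        # Pinhole camera at cam_origin looking down -z (focal length 2).
        xs = jnp.linspace(-1.0, 1.0, w) * (w / h)
        ys = jnp.linspace(1.0, -1.0, h)
        gx, gy = jnp.meshgrid(xs, ys)  # each of shape (h, w)
        dirs = jnp.stack([gx, gy, jnp.full_like(gx, -2.0)], axis=-1)
        dirs = (dirs / jnp.linalg.norm(dirs, axis=-1, keepdims=True)).reshape(-1, 3)

        def shade(d):
            # Sphere tracing: advance t by the SDF value for max_steps steps.
            def step(t, _):
                return t + sdf(cam_origin + t * d, free_params), None

            t, _ = jax.lax.scan(step, jnp.zeros((), jnp.float32), None, length=max_steps)
            p = cam_origin + t * d
            hit = (t < max_dist) & (jnp.abs(sdf(p, free_params)) < 1e-3)
            n = jax.grad(sdf)(p, free_params)  # autodiff normal (nested AD)
            lambert = jnp.maximum(jnp.dot(n, light_dir), 0.0)
            color = free_params["albedo"] * (0.1 + 0.9 * lambert)  # ambient + diffuse
            return jnp.where(hit, color, background)

        linear = jax.vmap(shade)(dirs).reshape(h, w, 3)
        # Gamma-encode; the small floor keeps the power's derivative finite.
        return jnp.maximum(linear, 1e-6) ** (1.0 / gamma)

    return render_fn
```

In this sketch the normal comes from jax.grad(sdf) inside render_fn, so taking jax.grad of a loss over the image differentiates twice; that is the second-order cost the fd_normals option is described as avoiding.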

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| scene | Scene | Scene with geometry, camera, and lighting. | required |
| max_steps | int | Sphere-tracing iterations per primary ray. | 32 |
| max_dist | float | Miss threshold: a ray whose distance exceeds this value is treated as a miss. | 15.0 |
| shadow_steps | int | Iterations per soft-shadow ray. | 12 |
| shadow_hardness | float | Shadow edge sharpness; larger values give harder shadow edges. | 6.0 |
| gamma | float | Gamma-correction exponent applied to the final image. | 2.2 |
| fd_normals | bool | Compute surface normals with central finite differences instead of jax.grad. Set to True when the render function is called inside jax.grad(loss_fn), to avoid the overhead of second-order automatic differentiation. | False |
| normal_eps | float | Step size for the finite-difference normal estimate (used when fd_normals=True). | 0.0001 |
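The two normal estimators that fd_normals selects between can be sketched as follows. The sphere SDF and the helper names sdf_sphere, normal_autodiff and normal_central_fd are hypothetical, and eps plays the role of normal_eps. The central-difference version needs six SDF evaluations but no jax.grad, so differentiating a render built on it only requires first-order derivatives of the SDF.

```python
import jax
import jax.numpy as jnp


def sdf_sphere(p, radius=1.0):
    # Hypothetical example SDF: sphere of the given radius at the origin.
    return jnp.linalg.norm(p) - radius


def normal_autodiff(sdf, p):
    # Normal = normalised SDF gradient, computed with reverse-mode AD.
    g = jax.grad(sdf)(p)
    return g / jnp.linalg.norm(g)


def normal_central_fd(sdf, p, eps=1e-4):
    # Central differences: one +eps / -eps pair of SDF evaluations per axis.
    offsets = eps * jnp.eye(3, dtype=p.dtype)
    g = jax.vmap(lambda e: (sdf(p + e) - sdf(p - e)) / (2.0 * eps))(offsets)
    return g / jnp.linalg.norm(g)
```

In float32, pushing eps far below the default mostly adds rounding error, because sdf(p + e) and sdf(p - e) nearly cancel, so a smaller eps is not automatically more accurate.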

Returns

| Type | Description |
|---|---|
| Callable | A function (free_params, resolution=(H, W)) -> JAX image of shape (H, W, 3). |
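A sketch of how the returned callable might be used for gradient-based fitting. toy_render_fn is a hypothetical stand-in that only mimics the documented signature (free_params, resolution=(H, W)) -> (H, W, 3); with the real library one would presumably pass the function returned by functionalize_render(scene, fd_normals=True), following the fd_normals advice in the parameter table. The loss, plain gradient descent, learning rate and step count are illustrative choices.

```python
import jax
import jax.numpy as jnp


def toy_render_fn(free_params, resolution=(32, 48)):
    # Hypothetical stand-in with the same call signature as the returned
    # Callable: a soft coloured blob on a black background.
    h, w = resolution
    ys, xs = jnp.meshgrid(jnp.linspace(-1.0, 1.0, h), jnp.linspace(-1.0, 1.0, w), indexing="ij")
    d2 = (xs - free_params["center"][0]) ** 2 + (ys - free_params["center"][1]) ** 2
    mask = jnp.exp(-d2 / 0.1)                       # (H, W)
    return mask[..., None] * free_params["color"]   # (H, W, 3)


def fit(render_fn, init_params, target, steps=100, lr=1.0):
    resolution = target.shape[:2]  # tuple of Python ints, closed over below

    def loss_fn(free_params):
        image = render_fn(free_params, resolution=resolution)
        return jnp.mean((image - target) ** 2)

    # Compile the loss-and-gradient computation once, then reuse it each step.
    value_and_grad = jax.jit(jax.value_and_grad(loss_fn))
    params = init_params
    for _ in range(steps):
        loss, grads = value_and_grad(params)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss  # loss is evaluated at the start of the final step
```

resolution is read from target.shape and closed over as a tuple of Python ints, so jax.jit treats it as a compile-time constant; passing it as a traced argument would fail, because it determines array shapes.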