render.functionalize.functionalize_render
`render.functionalize.functionalize_render(scene, max_steps=32, max_dist=15.0, shadow_steps=12, shadow_hardness=6.0, gamma=2.2, fd_normals=False, normal_eps=0.0001, reflect_steps=0)`

Compile a `Scene` to a differentiable render function.
Fixed geometry parameters are extracted once, when `functionalize_render` is called, and baked into the returned function. That function therefore needs only `free_params`, the dict that changes on each optimisation step:

```python
render_fn = functionalize_render(scene)
image = render_fn(free_params, resolution=(64, 96))
```
`image` is a JAX `float32` array of shape `(H, W, 3)`, fully differentiable with respect to `free_params` via `jax.grad`.
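The closure pattern described above can be sketched with a toy 2D renderer. This is not the library's implementation: `make_render_fn`, the `"center"` key, and the sigmoid-edged disc are hypothetical stand-ins, chosen only to show a fixed value being baked in at build time while `jax.grad` flows through `free_params`.

```python
import jax
import jax.numpy as jnp


def make_render_fn(radius):
    """Hypothetical stand-in for functionalize_render.

    `radius` plays the role of fixed geometry: it is converted once, here,
    and captured by the closure. Only `free_params` is a traced input.
    """
    fixed_radius = jnp.float32(radius)  # baked in at build time

    def render_fn(free_params, resolution=(4, 6)):
        h, w = resolution
        ys, xs = jnp.meshgrid(jnp.linspace(-1.0, 1.0, h),
                              jnp.linspace(-1.0, 1.0, w), indexing="ij")
        cx = free_params["center"][0]
        cy = free_params["center"][1]
        # Signed distance to a disc; the tiny epsilon avoids a NaN gradient
        # of sqrt if a pixel lands exactly on the centre.
        d = jnp.sqrt((xs - cx) ** 2 + (ys - cy) ** 2 + 1e-12) - fixed_radius
        # Sigmoid gives soft edges, so gradients w.r.t. the centre are non-zero.
        coverage = jax.nn.sigmoid(-d / 0.1)
        return jnp.stack([coverage] * 3, axis=-1)  # (H, W, 3) float32

    return render_fn


render_fn = make_render_fn(0.5)
free_params = {"center": jnp.array([0.0, 0.0])}
target = render_fn({"center": jnp.array([0.3, 0.0])})


def loss_fn(params):
    return jnp.mean((render_fn(params) - target) ** 2)


grads = jax.grad(loss_fn)(free_params)  # same dict structure as free_params
```

Because the fixed radius lives in the closure rather than in `free_params`, `jax.grad` only produces gradients for the entries that actually change between optimisation steps.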
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| `scene` | `Scene` | Scene with geometry, camera, and lighting. | *required* |
| `max_steps` | `int` | Sphere-tracing iterations per primary ray. | `32` |
| `max_dist` | `float` | Maximum ray distance; a ray that travels beyond it without hitting a surface counts as a miss. | `15.0` |
| `shadow_steps` | `int` | Iterations per soft-shadow ray. | `12` |
| `shadow_hardness` | `float` | Shadow edge sharpness; larger values give harder shadow edges. | `6.0` |
| `gamma` | `float` | Gamma-correction exponent applied to the final image. | `2.2` |
| `fd_normals` | `bool` | Use central finite differences for surface normals instead of `jax.grad`. Set to `True` when calling inside `jax.grad(loss_fn)` to avoid the overhead of second-order AD. | `False` |
| `normal_eps` | `float` | Step size for finite-difference normal estimation. | `0.0001` |
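For intuition about `max_steps`, `max_dist`, `shadow_steps`, and `shadow_hardness`, here is a minimal single-ray sketch. It assumes standard fixed-iteration sphere tracing and the common soft-shadow estimate `min(k * d / t)` with `k = shadow_hardness`; the function names, `t_min`, `t_max`, and `hit_eps` are hypothetical, and the library's own tracer may differ.

```python
import jax
import jax.numpy as jnp


def unit_sphere(p):
    """Hypothetical scene: signed distance to a unit sphere at the origin."""
    return jnp.linalg.norm(p) - 1.0


def sphere_trace(origin, direction, sdf, max_steps=32, max_dist=15.0,
                 hit_eps=1e-3):
    """Fixed-iteration sphere tracing of one ray; returns (t, hit)."""
    def step(_, t):
        d = sdf(origin + t * direction)
        # March by the SDF value, clamped so a missed ray parks at max_dist.
        return jnp.minimum(t + d, max_dist)

    t = jax.lax.fori_loop(0, max_steps, step, jnp.float32(0.0))
    hit = (t < max_dist) & (sdf(origin + t * direction) < hit_eps)
    return t, hit


def soft_shadow(point, light_dir, sdf, shadow_steps=12, shadow_hardness=6.0,
                t_min=0.02, t_max=10.0):
    """Soft-shadow factor in [0, 1]: 1 = fully lit, 0 = fully occluded."""
    def step(_, carry):
        t, res = carry
        d = sdf(point + t * light_dir)
        # Near-misses (small d relative to t) darken the result; a larger
        # shadow_hardness makes the penumbra narrower.
        res = jnp.minimum(res, shadow_hardness * d / t)
        t = jnp.minimum(t + jnp.maximum(d, 0.01), t_max)
        return t, res

    init = (jnp.float32(t_min), jnp.float32(1.0))
    _, res = jax.lax.fori_loop(0, shadow_steps, step, init)
    return jnp.clip(res, 0.0, 1.0)
```

Using a fixed iteration count rather than a `while` loop that stops on convergence keeps the loop length static, which is what allows `jax.lax.fori_loop` to be differentiated in reverse mode.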
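The trade-off behind `fd_normals` can be illustrated with a minimal sketch that compares normals from `jax.grad` of an SDF against central differences with step `normal_eps`. The SDF and the `lambert` shading helper are hypothetical; the point is that differentiating `lambert` with respect to the sphere position nests one `jax.grad` inside another in the autodiff variant, while the finite-difference variant needs only first-order AD.

```python
import jax
import jax.numpy as jnp


def normal_autodiff(sdf, p):
    """Normal as the normalised gradient of the SDF (one jax.grad call)."""
    g = jax.grad(sdf)(p)
    return g / jnp.linalg.norm(g)


def normal_fd(sdf, p, normal_eps=1e-4):
    """Normal from central differences: 6 SDF evaluations, no nested AD."""
    offsets = jnp.eye(3, dtype=p.dtype) * normal_eps
    diffs = jax.vmap(lambda e: sdf(p + e) - sdf(p - e))(offsets)
    g = diffs / (2.0 * normal_eps)
    return g / jnp.linalg.norm(g)


def lambert(center, use_fd):
    """Hypothetical shading of a fixed surface point on a sphere at `center`."""
    def sdf(q):
        return jnp.linalg.norm(q - center) - 1.0

    p = jnp.array([0.6, 0.8, 0.0])      # on the surface when center = 0
    light = jnp.array([0.0, 1.0, 0.0])  # directional light from +y
    n = normal_fd(sdf, p) if use_fd else normal_autodiff(sdf, p)
    return jnp.maximum(jnp.dot(n, light), 0.0)


# Gradient of the shading w.r.t. the sphere position. The autodiff variant
# differentiates through jax.grad (second order); the FD variant does not.
g_ad = jax.grad(lambda c: lambert(c, use_fd=False))(jnp.zeros(3))
g_fd = jax.grad(lambda c: lambert(c, use_fd=True))(jnp.zeros(3))
```

In `float32`, a very small step makes the two SDF values nearly equal, so the finite-difference normal carries some rounding noise; it agrees with the autodiff normal closely but not exactly.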
Returns
| Type | Description |
|---|---|
| `Callable` | `(free_params, resolution=(H, W)) -> image`, where `image` is a JAX array of shape `(H, W, 3)`. |