parametrization.compute_param_scales

parametrization.compute_param_scales(metadata, scene_scale=1.0)

Compute per-parameter normalization scales.

Pure Python: reads `metadata[k].bounds` for every parameter `k` and returns a dict of JAX arrays, each with the same shape as the corresponding parameter value. Call it once, before applying `jax.jit`; the returned scales are then static constants.
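The "call once before `jax.jit`" pattern can be sketched as follows. The `scales` dict is hand-built to mirror what the rules below would give for `scene_scale=2.0`. The convention that the optimizer works on `z = u / scale` (so `u = z * scale`) is a hypothetical choice for illustration, not something this page specifies. Because `scales` is closed over rather than passed as an argument, `jax.jit` embeds it as constants.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for compute_param_scales(metadata, scene_scale=2.0):
# a fully bounded parameter (scale 1.0) and an unbounded one (scale = scene_scale).
scales = {
    "opacity": jnp.ones((4,)),
    "position": jnp.full((4, 3), 2.0),
}


def loss(normalized_params):
    # Assumed convention: recover unconstrained values as z * scale.
    # `scales` is captured from the enclosing scope, so under jit it is
    # baked into the compiled function as a constant, not a traced input.
    params = jax.tree_util.tree_map(lambda z, s: z * s, normalized_params, scales)
    return jnp.sum(params["position"] ** 2) + jnp.sum(params["opacity"] ** 2)


z0 = {"opacity": jnp.zeros((4,)), "position": jnp.ones((4, 3))}
value, grads = jax.jit(jax.value_and_grad(loss))(z0)
```

Since the scales never change during optimization, computing them once outside the traced function avoids recomputing them on every step.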

Scale assignment:

  • Fully bounded [lo, hi]: the sigmoid maps the unconstrained value to (0, 1), and its inverse (the logit) maps back to (-∞, +∞). In practice these logits are already O(1–4), so scale = 1.0.
  • Lower-bounded [lo, ∞) or unbounded: after the inverse transform (softplus⁻¹ or identity, respectively), values are O(scene_scale), so scale = scene_scale.
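The rules above can be sketched as a small pure-Python function. This is not the library's implementation. `Parameter` here is a hypothetical stand-in with a `value` array and a `(lo, hi)` bounds tuple, where `None` or an infinite value means "unbounded on that side". The page does not cover upper-bounded-only parameters, so the sketch assumes they fall through to `scene_scale`. As a sanity check on the O(1–4) claim: for values between 2% and 98% of the range, the logit stays within about ±3.9 (logit(0.98) = ln 49 ≈ 3.89).

```python
import math
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class Parameter:
    # Hypothetical stand-in: a value plus (lo, hi) bounds;
    # None or +/-inf means unbounded on that side.
    value: jnp.ndarray
    bounds: tuple


def _finite(b):
    return b is not None and math.isfinite(b)


def compute_param_scales_sketch(metadata, scene_scale=1.0):
    scales = {}
    for name, p in metadata.items():
        lo, hi = p.bounds
        if _finite(lo) and _finite(hi):
            # [lo, hi]: logit of the normalized value is already O(1-4).
            s = 1.0
        else:
            # [lo, inf) or unbounded (and, by assumption, anything else):
            # softplus^-1 / identity leave values at O(scene_scale).
            s = scene_scale
        # One scale entry per element, matching the parameter's shape.
        scales[name] = jnp.full(jnp.shape(p.value), s)
    return scales


metadata = {
    "opacity": Parameter(jnp.zeros(5), (0.0, 1.0)),         # fully bounded
    "scale": Parameter(jnp.ones((5, 3)), (0.0, None)),      # lower-bounded
    "position": Parameter(jnp.zeros((5, 3)), (None, None)), # unbounded
}
scales = compute_param_scales_sketch(metadata, scene_scale=2.0)
```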

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `metadata` | `dict[str, Parameter]` | Parameter metadata, as returned by `extract_parameters`. | required |
| `scene_scale` | `float` | Characteristic length of the scene. Set it to the typical coordinate magnitude, e.g. `2.0` for a scene where objects span ±2 units. | `1.0` |
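One way to pick `scene_scale` from data is to take a high quantile of the absolute coordinate values, so that a few outliers do not inflate it. `estimate_scene_scale` below is a hypothetical helper, not part of this module, and the 0.95 quantile is an arbitrary choice.

```python
import jax.numpy as jnp


def estimate_scene_scale(positions, quantile=0.95):
    # Hypothetical helper: typical |coordinate| magnitude of the scene.
    mags = jnp.abs(jnp.asarray(positions)).ravel()
    if mags.size == 0:
        return 1.0  # empty scene: fall back to the documented default
    s = float(jnp.quantile(mags, quantile))
    return s if s > 0.0 else 1.0  # degenerate scene (all zeros): fall back too


# Objects spanning roughly ±2 units.
pts = jnp.array([[-2.0, 0.5, 1.0], [2.0, -1.5, 0.0]])
scene_scale = estimate_scene_scale(pts)
```

The result would then be passed as `compute_param_scales(metadata, scene_scale=scene_scale)`.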

Returns

| Type | Description |
|---|---|
| `dict[str, Array]` | Mapping from each parameter name to its scale array. |