parametrization.compute_param_scales
parametrization.compute_param_scales(metadata, scene_scale=1.0)

Compute per-parameter normalization scales.
Pure Python: reads `metadata[k].bounds` and returns a dict of JAX arrays, each with the same shape as the corresponding parameter value. Call it once, before applying `jax.jit`; inside the jitted code the returned scales act as static constants.
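Because the scales are ordinary arrays computed outside any trace, a jitted function can simply close over them, and JAX embeds them in the compiled program as constants. The sketch below uses a hand-written `scales` dict in place of a real `compute_param_scales` call, and assumes, for illustration only, that normalization means dividing an unconstrained value by its scale; the parameter names and the `normalize` function are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical scales standing in for the output of compute_param_scales().
# They are built in plain Python, outside any jit trace.
scales = {
    "position": jnp.full((3,), 2.0),  # e.g. an unbounded parameter with scene_scale=2.0
    "opacity": jnp.ones(()),          # e.g. a fully bounded parameter: scale 1.0
}


@jax.jit
def normalize(unconstrained):
    # `scales` is captured by closure, so it is a compile-time constant
    # rather than a traced argument. (Dividing by the scale is an assumed
    # use of the scales, not necessarily what this module does.)
    return {k: v / scales[k] for k, v in unconstrained.items()}


params = {"position": jnp.array([1.0, -2.0, 4.0]), "opacity": jnp.array(0.5)}
out = normalize(params)
```

Passing `scales` as a jit argument instead would also work, but then each array would be traced rather than baked in.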
Scale assignment:

- `[lo, hi]`, fully bounded: a sigmoid maps to `(0, 1)` and the logit maps back to `(-∞, +∞)`; the resulting values are already O(1–4), so scale = `1.0`.
- `[lo, ∞)`, lower-bounded, or unbounded: after the inverse softplus (lower-bounded) or the identity (unbounded), values are O(`scene_scale`), so scale = `scene_scale`.
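The assignment rule can be sketched as below. The fully bounded case gets `1.0` because logit(p) = ln(p / (1 − p)) stays within about ±4 for p in [0.02, 0.98]. This is a minimal sketch under stated assumptions: `Parameter` is a hypothetical stand-in with only `value` and `bounds` fields, `bounds` is assumed to be a `(lo, hi)` tuple where `None` (or an infinite value) marks a missing side, and `compute_param_scales_sketch` is an illustrative name, not this module's implementation.

```python
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import jax.numpy as jnp


@dataclass
class Parameter:
    # Hypothetical stand-in for the real Parameter type: only the two
    # fields this sketch reads.
    value: jnp.ndarray
    bounds: Tuple[Optional[float], Optional[float]]  # None = unbounded on that side


def _finite(x):
    return x is not None and math.isfinite(x)


def compute_param_scales_sketch(metadata, scene_scale=1.0):
    scales = {}
    for name, p in metadata.items():
        lo, hi = p.bounds
        if _finite(lo) and _finite(hi):
            s = 1.0  # [lo, hi]: logit-space values are already O(1-4)
        else:
            # [lo, inf) or unbounded. Note: an upper-bounded-only parameter
            # also lands here; the docs above do not cover that case.
            s = scene_scale
        # One scale array per parameter, same shape as its value.
        scales[name] = jnp.full(jnp.shape(p.value), s, dtype=jnp.float32)
    return scales


meta = {
    "opacity": Parameter(jnp.zeros(()), (0.0, 1.0)),       # fully bounded
    "scale": Parameter(jnp.zeros((3,)), (0.0, None)),      # lower-bounded
    "position": Parameter(jnp.zeros((4, 3)), (None, None)),  # unbounded
}
scales = compute_param_scales_sketch(meta, scene_scale=2.0)
```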
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `metadata` | `dict[str, Parameter]` | Parameter metadata, as returned by `extract_parameters`. | *required* |
| `scene_scale` | `float` | Characteristic length of the scene. Set it to the typical coordinate magnitude, e.g. `2.0` for a scene where objects span ±2 units. | `1.0` |
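One way to pick `scene_scale`, assuming the scene is available as an `(N, 3)` array of coordinates, is to take the largest absolute coordinate, which matches the "objects span ±2 units" example. `estimate_scene_scale` is a hypothetical helper, not part of this module.

```python
import jax.numpy as jnp


def estimate_scene_scale(points):
    # Hypothetical helper: largest absolute coordinate over all points
    # and axes, used as the characteristic length of the scene.
    return float(jnp.max(jnp.abs(points)))


points = jnp.array([[-2.0, 0.5, 1.0], [1.5, -1.0, 2.0]])
scene_scale = estimate_scene_scale(points)  # 2.0
```

If a few outlying points dominate the maximum, a high quantile such as `jnp.quantile(jnp.abs(points), 0.95)` gives a more robust "typical" magnitude.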
Returns

| Name | Type | Description |
|---|---|---|
| | `dict[str, Array]` | Mapping from each parameter name to its scale array. |