constraints.solve.make_manifold_projection
`constraints.solve.make_manifold_projection(metadata, *, steps=1)`

Return an optax transform that projects params onto the constraint manifold.
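Chain it after an optimizer to enforce the constraints after each gradient step:

```python
import optax
from constraints.solve import make_manifold_projection

# metadata, free_params and grads are assumed to be defined already.
optimizer = optax.chain(optax.adam(0.05), make_manifold_projection(metadata))
state = optimizer.init(free_params)
# params must be passed to update() so the projection can see the current point.
updates, state = optimizer.update(grads, state, free_params)
free_params = optax.apply_updates(free_params, updates)
```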
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| metadata | dict[str, Parameter] | Mapping from parameter name to its `Parameter` object; the constraint information is read from these objects. | required |
| steps | int | Number of Newton corrections to apply. | 1 |
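
This page does not spell out what a single Newton correction computes. One standard choice for equality constraints $c(x) = 0$ with Jacobian $J(x)$ is the minimum-norm Newton step below, repeated `steps` times. Treat this as an assumption about the method, not a description of this library's implementation:

$$
x_{k+1} = x_k - J(x_k)^\top \left( J(x_k)\, J(x_k)^\top \right)^{-1} c(x_k), \qquad k = 0, \dots, \texttt{steps} - 1
$$

For a single scalar constraint, $J(x) = \nabla c(x)^\top$ and the step reduces to $x_{k+1} = x_k - c(x_k)\, \nabla c(x_k) / \lVert \nabla c(x_k) \rVert^2$. For example, with the unit circle $c(x) = \lVert x \rVert^2 - 1$ and $x_0 = (2, 0)$, one step gives $x_1 = (1.25, 0)$ and a second gives $x_2 = (1.025, 0)$, so a few corrections usually suffice when the starting point is already close to the manifold.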
Returns
| Type | Description |
|---|---|
| optax.GradientTransformationExtraArgs | A transform that projects updates onto the constraint manifold. `params` must be passed to `update`. |