constraints.solve.make_manifold_projection

constraints.solve.make_manifold_projection(metadata, *, steps=1)

Return an optax transform that projects params onto the constraint manifold.

Chain it after an optimizer to project the parameters back onto the manifold after each gradient step:

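import optax
from constraints.solve import make_manifold_projection

# metadata: name-keyed Parameter dict; grads: loss gradients w.r.t. free_params
optimizer = optax.chain(optax.adam(0.05), make_manifold_projection(metadata))
state = optimizer.init(free_params)
updates, state = optimizer.update(grads, state, free_params)  # params are required here
free_params = optax.apply_updates(free_params, updates)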

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| metadata | dict[str, Parameter] | Name-keyed Parameter objects carrying the constraint information. | required |
| steps | int | Number of Newton correction steps to apply (keyword-only). | 1 |
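This page does not say which Newton correction is applied. One common choice, given here as an assumption rather than as the library's documented method, is the minimum-norm Newton (Gauss-Newton) step toward the constraint set c(x) = 0, with J the Jacobian of c; under that reading, steps is the number of times this update is repeated:

```latex
x_{k+1} = x_k - J(x_k)^{\top}\bigl(J(x_k)\,J(x_k)^{\top}\bigr)^{-1} c(x_k),
\qquad k = 0, \dots, \mathrm{steps} - 1
```

For example, with the unit-circle constraint c(x) = ||x||^2 - 1 the Jacobian is J = 2x^T, so each step rescales x radially and the radius follows r ← (r^2 + 1) / (2r). Starting from r = 1.2, one step gives r ≈ 1.0167 and two give r ≈ 1.00014, which is why a single correction can leave a small constraint violation.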

Returns

| Type | Description |
|---|---|
| optax.GradientTransformationExtraArgs | A transform that adjusts the incoming updates so that the updated params lie on the constraint manifold. Requires params to be passed to update(). |