# jaxCAD
Differentiable CAD built on JAX and Signed Distance Functions (SDFs).
jaxCAD combines parametric geometry, geometric constraints, and automatic differentiation to enable gradient-based shape optimization.
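As a minimal illustration of why differentiable SDFs enable gradient-based shape optimization, the plain-JAX sketch below fits a sphere's radius by differentiating a loss through its SDF. It does not use jaxCAD: `sphere_sdf`, the target points, the learning rate, and the step count are hypothetical choices made for this example.

```python
import jax
import jax.numpy as jnp

def sphere_sdf(p, center, radius):
    # Signed distance: negative inside, zero on the surface, positive outside
    return jnp.linalg.norm(p - center) - radius

# Hypothetical target points that should lie on the surface (all 2 units from the origin)
targets = jnp.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]])
center = jnp.zeros(3)

def loss(radius):
    # Sum of squared signed distances: zero when every target sits on the surface
    d = jax.vmap(lambda p: sphere_sdf(p, center, radius))(targets)
    return jnp.sum(d ** 2)

grad_loss = jax.jit(jax.grad(loss))

# Plain gradient descent on the single shape parameter
radius = 0.5
for _ in range(100):
    radius = radius - 0.1 * grad_loss(radius)
```

After the loop, `radius` has converged to approximately 2.0, the radius that places every target on the surface.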
## Core concepts
| Layer | What it does |
|---|---|
| Geometry | Parametric `Vector` and `Scalar` values, each flagged as free or fixed |
| SDF | Primitive shapes, boolean operations, and transforms, all differentiable |
| Constraints | Geometric relationships (`DistanceConstraint`, etc.) that reduce the degrees of freedom (DOF) of free parameters |
| Extraction | Walk an SDF tree to collect its free and fixed parameters |
| Functionalize | Convert a parametric scene into a pure JAX function compatible with `jax.jit`, `jax.grad`, and `jax.vmap` |
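To show why a pure function of free and fixed parameters composes with JAX transforms, here is a plain-JAX sketch of a tiny scene: a translated sphere unioned with a box. It does not use jaxCAD's `extract_parameters` or `functionalize`; the names `scene_sdf`, `free`, and `fixed`, and the dict-based parameter split, are hypothetical stand-ins for the layers in the table.

```python
import jax
import jax.numpy as jnp

def sphere(x, radius):
    # Sphere of the given radius centered at the origin
    return jnp.linalg.norm(x) - radius

def box(x, half_size):
    # Exact SDF of an axis-aligned box centered at the origin
    q = jnp.abs(x) - half_size
    return jnp.linalg.norm(jnp.maximum(q, 0.0)) + jnp.minimum(jnp.max(q), 0.0)

def scene_sdf(x, free, fixed):
    # Translate: evaluating the sphere at x - offset moves it by `offset`
    s = sphere(x - free["offset"], fixed["radius"])
    b = box(x, fixed["half_size"])
    # Boolean union is the pointwise minimum of the two distances
    return jnp.minimum(s, b)

# Hypothetical free/fixed split: only the offset is meant to be solved or optimized
free = {"offset": jnp.array([2.0, 0.0, 0.0])}
fixed = {"radius": 0.5, "half_size": jnp.array([1.0, 1.0, 1.0])}

# A pure function of (x, free, fixed) composes with jit and vmap...
batched_sdf = jax.jit(jax.vmap(scene_sdf, in_axes=(0, None, None)))
points = jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [4.0, 0.0, 0.0]])
dists = batched_sdf(points, free, fixed)

# ...and with grad, taken here w.r.t. the free parameters only
grads = jax.grad(scene_sdf, argnums=1)(jnp.array([3.0, 0.0, 0.0]), free, fixed)
```

Here `dists` is `[-1.0, -0.5, 1.5]` (inside the box, at the sphere's center, outside both), and `grads["offset"]` is `[-1.0, 0.0, 0.0]`: moving the sphere toward the query point lowers its distance. Keeping fixed values in a separate argument means `jax.grad` never differentiates with respect to them.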
## Quick start
```python
import jax.numpy as jnp
from jaxcad.geometry.parameters import Vector
from jaxcad.sdf.primitives import Sphere
from jaxcad.sdf.transforms import Translate
from jaxcad.constraints import DistanceConstraint, solve_constraints
from jaxcad import extract_parameters, functionalize

# Unknown point, starting from a wrong initial guess
p = Vector(jnp.array([0.5, 0.5, 0.0]), free=True, name='p')
scene = Translate(Sphere(radius=0.5), offset=p)

# Fixed anchors
anchor_a = Vector(jnp.array([0.0, 0.0, 0.0]), free=False, name='a')
anchor_b = Vector(jnp.array([4.0, 0.0, 0.0]), free=False, name='b')
anchor_c = Vector(jnp.array([2.0, 3.0, 0.0]), free=False, name='c')
true_p = jnp.array([2.0, 1.0, 0.0])

# Three distance constraints, with targets measured from true_p, fully determine p
DistanceConstraint(p, anchor_a, float(jnp.linalg.norm(true_p - anchor_a.value)))
DistanceConstraint(p, anchor_b, float(jnp.linalg.norm(true_p - anchor_b.value)))
DistanceConstraint(p, anchor_c, float(jnp.linalg.norm(true_p - anchor_c.value)))

# Solve for the free parameters, then evaluate the solved scene
solved_params = solve_constraints(scene)
_, fixed_params = extract_parameters(scene)
sdf_fn = functionalize(scene)(solved_params, fixed_params)
print(sdf_fn(true_p))  # ≈ -0.5: true_p is the sphere's center, one radius inside the surface
```
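The internals of `solve_constraints` are not described here. As an assumption-laden sketch of one way to solve the same problem, the distance constraints can be treated as nonlinear least squares on their residuals and minimized with plain gradient descent in JAX. The names `residuals`, `loss`, and `step`, the step size 0.5, and the iteration count are hypothetical choices, not jaxCAD's solver.

```python
import jax
import jax.numpy as jnp

anchors = jnp.array([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [2.0, 3.0, 0.0]])
true_p = jnp.array([2.0, 1.0, 0.0])
# Target distances, measured from the known answer as in the quick start
targets = jnp.linalg.norm(true_p - anchors, axis=1)

def residuals(p):
    # One residual per distance constraint: actual distance minus target distance
    return jnp.linalg.norm(p - anchors, axis=1) - targets

def loss(p):
    # Least-squares objective; zero exactly when every constraint is satisfied
    return 0.5 * jnp.sum(residuals(p) ** 2)

@jax.jit
def step(p):
    # Plain gradient descent; the step size 0.5 is a hand-picked assumption
    return p - 0.5 * jax.grad(loss)(p)

p = jnp.array([0.5, 0.5, 0.0])  # same wrong initial guess as the quick start
for _ in range(200):
    p = step(p)
```

From this starting point `p` converges to approximately `[2.0, 1.0, 0.0]`. Gauss-Newton or Levenberg-Marquardt would typically need far fewer iterations; gradient descent is used only because it is the shortest correct illustration.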