jaxCAD

Differentiable CAD built on JAX and Signed Distance Functions (SDFs).

jaxCAD combines parametric geometry, geometric constraints, and automatic differentiation to enable gradient-based shape optimization.
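Here is a minimal illustration of gradient-based shape optimization in plain JAX. It is not jaxCAD's API: `sphere_sdf`, `targets`, the learning rate, and the step count are all hypothetical. The sketch fits a sphere's radius so that its surface passes through a set of target points. It does this by differentiating a loss built from SDF values.

```python
import jax
import jax.numpy as jnp

# Signed distance to a sphere with centre `c` and radius `r`:
# negative inside, zero on the surface, positive outside.
def sphere_sdf(x, c, r):
    return jnp.linalg.norm(x - c) - r

# Hypothetical target points that should lie on the sphere's surface.
targets = jnp.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [-1.0, 0.0, 0.0],
    [0.0, -1.0, 0.0],
])

def loss(r, c):
    # Mean squared SDF value at the targets; zero when every target
    # sits exactly on the surface.
    d = jax.vmap(sphere_sdf, in_axes=(0, None, None))(targets, c, r)
    return jnp.mean(d ** 2)

# Gradient with respect to the first argument (the radius) only.
grad_loss = jax.jit(jax.grad(loss))

r = 0.3                # deliberately wrong initial radius
c = jnp.zeros(3)       # centre held fixed in this toy example
for _ in range(200):
    r = r - 0.1 * grad_loss(r, c)   # plain gradient descent

print(r)  # approaches 1.0
```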

Core concepts

| Layer | What it does |
|---|---|
| Geometry | Parametric `Vector` and `Scalar` values, each flagged as free or fixed |
| SDF | Differentiable primitive shapes, boolean operations, and transforms |
| Constraints | Geometric relationships (e.g. `DistanceConstraint`) that reduce the degrees of freedom (DOF) |
| Extraction | Walks an SDF tree to collect its free and fixed parameters |
| Functionalize | Converts a parametric scene into a pure JAX function for `jit`, `grad`, and `vmap` |
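The SDF and Functionalize layers rest on a few standard ideas:

- Primitives are distance functions.
- Boolean operations are `min`/`max` of distances.
- Transforms map the query point into the shape's local frame.
- A scene flattened into a pure function of its parameters composes with `jit`, `grad`, and `vmap`.

The sketch below illustrates these ideas in plain JAX. The function names (`sphere`, `box`, `union`, `intersection`, `difference`, `translate`, `scene`) are hypothetical and are not jaxCAD's classes. The box uses the standard exact SDF formula.

```python
import jax
import jax.numpy as jnp

# Primitives: signed distance from a point x (shape (3,)) to the surface.
def sphere(x, radius):
    return jnp.linalg.norm(x) - radius

def box(x, half_size):
    # Exact SDF of an axis-aligned box centred at the origin.
    q = jnp.abs(x) - half_size
    return jnp.linalg.norm(jnp.maximum(q, 0.0)) + jnp.minimum(jnp.max(q), 0.0)

# Boolean operations act on SDF values.
def union(a, b):
    return jnp.minimum(a, b)

def intersection(a, b):
    return jnp.maximum(a, b)

def difference(a, b):
    return jnp.maximum(a, -b)

# Transform: evaluate the child SDF in the shape's local frame.
def translate(sdf, offset):
    return lambda x: sdf(x - offset)

# A "functionalized" scene: a pure function of (params, x) with no hidden state.
# It carves a translated sphere out of a unit box.
def scene(params, x):
    s = translate(lambda y: sphere(y, params["radius"]), params["offset"])
    b = lambda y: box(y, jnp.array([1.0, 1.0, 1.0]))
    return difference(b(x), s(x))

params = {"radius": 0.5, "offset": jnp.array([1.0, 0.0, 0.0])}
points = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])

# vmap over points and jit the batch; grad differentiates w.r.t. the params pytree.
batched = jax.jit(jax.vmap(scene, in_axes=(None, 0)))
dists = batched(params, points)
dparams = jax.grad(lambda p: scene(p, points[0]))(params)
```

Here `dists` is `[-0.5, 0.5, 2.0]`:

- The origin lies inside the box, 0.5 from the carved-out sphere.
- `(1, 0, 0)` sits inside the cavity.
- `(3, 0, 0)` is 2 units outside the box.

At the origin, the gradient with respect to `params["radius"]` is 1, because growing the cavity moves its surface toward the origin.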

Quick start

```python
import jax.numpy as jnp
from jaxcad.geometry.parameters import Vector
from jaxcad.sdf.primitives import Sphere
from jaxcad.sdf.transforms import Translate
from jaxcad.constraints import DistanceConstraint, solve_constraints
from jaxcad import extract_parameters, functionalize

# Unknown point p: free to move, with a deliberately wrong initial guess
p = Vector(jnp.array([0.5, 0.5, 0.0]), free=True, name='p')
scene = Translate(Sphere(radius=0.5), offset=p)

# Fixed anchors
anchor_a = Vector(jnp.array([0.0, 0.0, 0.0]), free=False, name='a')
anchor_b = Vector(jnp.array([4.0, 0.0, 0.0]), free=False, name='b')
anchor_c = Vector(jnp.array([2.0, 3.0, 0.0]), free=False, name='c')
true_p   = jnp.array([2.0, 1.0, 0.0])

# Three distance constraints fully determine p
# (the anchors and true_p all lie in the z = 0 plane, so the solution is unique)
DistanceConstraint(p, anchor_a, float(jnp.linalg.norm(true_p - anchor_a.value)))
DistanceConstraint(p, anchor_b, float(jnp.linalg.norm(true_p - anchor_b.value)))
DistanceConstraint(p, anchor_c, float(jnp.linalg.norm(true_p - anchor_c.value)))

# Solve for the free parameters, then build a pure SDF function
solved_params = solve_constraints(scene)
_, fixed_params = extract_parameters(scene)
sdf_fn = functionalize(scene)(solved_params, fixed_params)

# true_p is the solved sphere centre, so the SDF there is -radius
print(sdf_fn(true_p))  # ≈ -0.5 (inside sphere)
```
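This README does not describe how `solve_constraints` works internally. One common approach, sketched below in plain JAX under that assumption, treats each `DistanceConstraint` as a residual: the actual distance minus the target distance. It then minimizes the sum of squared residuals by gradient descent, starting from the same wrong guess as the Quick start. The names `residuals`, `loss`, `step`, and the step size are hypothetical.

```python
import jax
import jax.numpy as jnp

# Fixed anchors and target distances, matching the Quick start example.
anchors = jnp.array([[0.0, 0.0, 0.0],
                     [4.0, 0.0, 0.0],
                     [2.0, 3.0, 0.0]])
true_p = jnp.array([2.0, 1.0, 0.0])
targets = jnp.linalg.norm(true_p - anchors, axis=1)

LR = 0.1  # hypothetical step size

def residuals(p):
    # One residual per distance constraint: actual distance minus target.
    return jnp.linalg.norm(p - anchors, axis=1) - targets

def loss(p):
    return jnp.sum(residuals(p) ** 2)

@jax.jit
def step(p):
    # One gradient-descent step on the squared-residual loss.
    return p - LR * jax.grad(loss)(p)

p = jnp.array([0.5, 0.5, 0.0])  # same wrong initial guess
for _ in range(500):
    p = step(p)

# A radius-0.5 sphere SDF centred at the solved point.
sphere_sdf = lambda x: jnp.linalg.norm(x - p) - 0.5
print(p)                   # close to [2, 1, 0]
print(sphere_sdf(true_p))  # close to -0.5
```

The anchors and `true_p` all lie in the z = 0 plane, so the z-component of the gradient is zero throughout. As a result, `p` stays in that plane while it converges. A dedicated least-squares solver (Gauss-Newton or Levenberg-Marquardt) would typically need far fewer iterations than plain gradient descent.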