Getting Started

Installation

pip install jaxcad

For development:

git clone https://github.com/andrinr/jaxcad
cd jaxcad
uv sync

Core concepts

jaxCAD has two building blocks: parametric geometry (SDF trees with tagged parameters) and geometric constraints (relationships between parameters). Gradient-based optimization ties them together.
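Here is one way to picture how tagged parameters and gradient-based optimization fit together. In this hypothetical plain-JAX sketch, only free parameters receive gradient updates. The `Param` dataclass and `step` helper are made up for illustration. They are not jaxCAD's classes, and the library may treat fixed parameters differently.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Param:
    """Hypothetical stand-in for a tagged parameter (not jaxcad's Vector)."""
    value: jnp.ndarray
    free: bool


def loss(values):
    # Toy objective: squared distance between the two points.
    a, b = values
    return jnp.sum((a - b) ** 2)


def step(params, lr=0.1):
    values = [p.value for p in params]
    grads = jax.grad(loss)(values)  # one gradient per parameter (a list pytree)
    # Only free parameters move; fixed ones keep their value.
    return [
        Param(p.value - lr * g, p.free) if p.free else p
        for p, g in zip(params, grads)
    ]


anchor = Param(jnp.array([0.0, 0.0, 0.0]), free=False)
p = Param(jnp.array([2.0, 0.0, 0.0]), free=True)
anchor, p = step([anchor, p])
print(anchor.value, p.value)
```

The anchor is fixed, so only p moves. One step takes it from (2, 0, 0) to (1.6, 0, 0).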

Build a constrained scene

Parameters are Vector or Scalar objects tagged as free=True (differentiable) or free=False (fixed). Constraints link parameters together.

import jax.numpy as jnp
from jaxcad.constraints import DistanceConstraint
from jaxcad.geometry import Vector
from jaxcad.sdf import Sphere, Translate

# A small sphere whose center p must stay at distance 2 from a fixed anchor at the origin
anchor = Vector(jnp.array([0.0, 0.0, 0.0]), free=False, name="anchor")
p = Vector(jnp.array([2.0, 0.0, 0.0]), free=True, name="p")
DistanceConstraint(p, anchor, 2.0)

scene = Translate(Sphere(radius=0.3), offset=p)
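The scene above boils down to two plain functions of p: the SDF of the translated sphere and the residual of the distance constraint. The sketch below writes both out in pure JAX. It assumes the usual SDF conventions: a sphere is `||x - c|| - r`, and translating a shape means evaluating the child at `x - offset`. The names `sphere_sdf`, `translate`, and `distance_residual` are illustrative and are not jaxcad functions.

```python
import jax.numpy as jnp


def sphere_sdf(x, radius):
    # Signed distance to a sphere centered at the origin: negative inside.
    return jnp.linalg.norm(x) - radius


def translate(sdf, offset):
    # Moving a shape by `offset` means querying the child at x - offset.
    return lambda x: sdf(x - offset)


def distance_residual(p, anchor, d):
    # Zero exactly when ||p - anchor|| == d.
    return jnp.linalg.norm(p - anchor) - d


anchor = jnp.array([0.0, 0.0, 0.0])
p = jnp.array([2.0, 0.0, 0.0])

scene = translate(lambda x: sphere_sdf(x, 0.3), p)
print(scene(jnp.array([2.0, 0.0, 0.0])))   # the sphere's center: 0.3 inside
print(scene(jnp.array([3.0, 0.0, 0.0])))   # one unit from the center
print(distance_residual(p, anchor, 2.0))   # constraint satisfied
```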

Optimize on the constraint manifold

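Because every SDF is a pure JAX function, jax.grad works directly. After each gradient step, project back onto the manifold to keep the constraint satisfied. For the distance constraint above, with the anchor at the origin, projection is simply a rescaling of p to length 2: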

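import jax
import jax.numpy as jnp

target = jnp.array([1.0, 1.5, 0.0])

def obj(q):
    # Squared distance from the candidate center to the target point
    return jnp.sum((q - target) ** 2)

grad_obj = jax.grad(obj)  # build the gradient function once, outside the loop

lr = 0.1
p_current = jnp.array([2.0, 0.0, 0.0])
for _ in range(50):
    p_new = p_current - lr * grad_obj(p_current)       # unconstrained gradient step
    p_current = 2.0 * p_new / jnp.linalg.norm(p_new)  # project back onto ||p|| = 2

print(p_current)  # approx. [1.109 1.664 0.]: the point on the sphere closest to target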

The constraint stays satisfied throughout: after every projection, ||p|| = 2.0 up to floating-point rounding.
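The loop above minimizes a plain distance to a target. Because the SDF is itself a JAX function, the objective can also be a geometric property of the scene. In the pure-JAX sketch below, the sphere's surface, not its center, is pulled through the target point. This is not jaxCAD API: `surface_sdf` and `project` are illustrative helpers. `project` generalizes the rescaling above to an anchor that need not sit at the origin.

```python
import jax
import jax.numpy as jnp

RADIUS = 0.3                          # radius of the moving sphere
anchor = jnp.array([0.0, 0.0, 0.0])
dist = 2.0                            # constrained value of ||p - anchor||
target = jnp.array([1.0, 1.5, 0.0])


def surface_sdf(x, p):
    # SDF of a sphere of radius RADIUS translated to p.
    return jnp.linalg.norm(x - p) - RADIUS


def project(p):
    # Closest point to p on the constraint manifold ||p - anchor|| = dist.
    v = p - anchor
    return anchor + dist * v / jnp.linalg.norm(v)


def obj(p):
    # Zero when the sphere's surface passes exactly through `target`.
    return surface_sdf(target, p) ** 2


grad_fn = jax.jit(jax.grad(obj))  # differentiate through the SDF w.r.t. p

lr = 0.1
p = jnp.array([2.0, 0.0, 0.0])
for _ in range(200):
    p = project(p - lr * grad_fn(p))  # gradient step, then projection

print(p, surface_sdf(target, p), jnp.linalg.norm(p - anchor))
```

Here the valid positions form a circle, where the constraint sphere meets a sphere of radius 0.3 around the target. The point the loop settles on therefore depends on where it starts.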



Rendering

jaxCAD includes a sphere-tracing raymarcher that supports physically inspired shading, soft shadows, anti-aliasing, transparency, and glass refraction.
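Sphere tracing marches along the ray by the SDF value at the current point. That step is safe because no surface can be closer than the SDF value. Below is a minimal NumPy sketch for a single ray against a unit sphere. The function names and the `max_steps`, `eps`, and `t_max` defaults are illustrative and are not necessarily jaxCAD's.

```python
import numpy as np


def sphere_sdf(x, radius=1.0):
    return np.linalg.norm(x) - radius


def sphere_trace(sdf, origin, direction, max_steps=80, eps=1e-4, t_max=100.0):
    """Return the hit distance t along the ray, or None on a miss."""
    direction = direction / np.linalg.norm(direction)
    t = 0.0
    for _ in range(max_steps):
        d = sdf(origin + t * direction)
        if d < eps:          # close enough to the surface: report a hit
            return t
        t += d               # safe step: no surface is closer than d
        if t > t_max:        # ray has left the scene
            break
    return None


origin = np.array([0.0, 0.0, 5.0])
print(sphere_trace(sphere_sdf, origin, np.array([0.0, 0.0, -1.0])))  # hits at t = 4
print(sphere_trace(sphere_sdf, origin, np.array([0.0, 2.0, -1.0])))  # misses
```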

Assign materials

Attach a Material to any primitive. The material_at method is queried per hit point during rendering; Union blends materials smoothly at boundaries.

from jaxcad.render import Material
from jaxcad.sdf.primitives import Sphere

sphere = Sphere(
    radius=1.0,
    material=Material(
        color=[0.22, 0.50, 0.95],  # blue
        roughness=0.35,
        metallic=0.0,
        opacity=1.0,
        ior=1.0,                   # index of refraction
    ),
)
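The text above says Union blends materials smoothly at boundaries. A common way to do this, assumed here rather than taken from jaxCAD's source, is the polynomial smooth minimum. Its blend weight `h` mixes the two distances, and the same weight can mix the two colours. The `smooth_union` function is illustrative. Its `k` plays the role that Union's `smoothness` argument presumably controls, and `k = 0` falls back to an ordinary hard union.

```python
import numpy as np


def smooth_union(d_a, d_b, col_a, col_b, k):
    """Polynomial smooth-min of two distances plus a matching colour blend.

    h -> 1 where shape A is clearly nearer, h -> 0 where B is; within
    roughly k of the seam both the distance and the colour are mixed.
    """
    if k <= 0.0:                       # hard union: pick the nearer shape
        return (d_a, col_a) if d_a <= d_b else (d_b, col_b)
    h = np.clip(0.5 + 0.5 * (d_b - d_a) / k, 0.0, 1.0)
    d = d_b * (1.0 - h) + d_a * h - k * h * (1.0 - h)
    col = col_b * (1.0 - h) + col_a * h
    return d, col


blue = np.array([0.22, 0.50, 0.95])
red = np.array([0.93, 0.26, 0.22])

# Far from the seam the nearer shape wins outright.
print(smooth_union(0.1, 2.0, blue, red, k=0.5))
# On the seam (equal distances) the colours are mixed 50/50.
print(smooth_union(0.2, 0.2, blue, red, k=0.5))
```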

Basic render

raymarch() returns a (H, W, 3) float32 NumPy array. render_raymarched() wraps it and displays the result on a matplotlib axis.

import jax.numpy as jnp
from jaxcad.render import render_raymarched

render_raymarched(
    sphere,
    camera_pos=jnp.array([3.0, 2.0, 3.0]),
    light_dirs=jnp.array([[1.0, 1.5, 1.5], [-1.5, 0.5, -0.5]]),
    light_colors=jnp.array([[1.0, 0.90, 0.70], [0.20, 0.35, 0.60]]),
    resolution=(400, 400),
    aa_samples=2,
)
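Each row of light_dirs pairs with the row of light_colors at the same index, so each pair describes one directional light. The NumPy sketch below shows only a plain Lambertian diffuse term for such lights. It assumes each direction points from the surface toward its light, and it leaves out the roughness and metallic terms of jaxCAD's shading. The surface normal is the normalized central-difference gradient of the SDF. `sdf_normal` and `lambert` are illustrative names.

```python
import numpy as np


def sphere_sdf(x, radius=1.0):
    return np.linalg.norm(x) - radius


def sdf_normal(sdf, x, h=1e-4):
    # Central-difference gradient of the SDF, normalized to unit length.
    grad = np.array([sdf(x + h * e) - sdf(x - h * e) for e in np.eye(3)])
    return grad / np.linalg.norm(grad)


def lambert(albedo, normal, light_dirs, light_colors):
    # Sum of max(n . l, 0) * light colour over all directional lights.
    dirs = light_dirs / np.linalg.norm(light_dirs, axis=1, keepdims=True)
    ndotl = np.clip(dirs @ normal, 0.0, None)          # shape (num_lights,)
    return albedo * (ndotl[:, None] * light_colors).sum(axis=0)


light_dirs = np.array([[1.0, 1.5, 1.5], [-1.5, 0.5, -0.5]])
light_colors = np.array([[1.0, 0.90, 0.70], [0.20, 0.35, 0.60]])
albedo = np.array([0.22, 0.50, 0.95])

hit = np.array([0.0, 1.0, 0.0])          # top of the unit sphere
n = sdf_normal(sphere_sdf, hit)
print(n)                                  # approximately [0, 1, 0]
print(lambert(albedo, n, light_dirs, light_colors))
```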

Background colour

Set background_color to control the colour returned for rays that miss all geometry. It is also used as the target for the smooth edge fade.

render_raymarched(
    sphere,
    background_color=jnp.array([0.78, 0.91, 1.0]),  # sky blue
)
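Conceptually, a miss selects the background colour for that pixel. The NumPy sketch below applies this to a whole image through a boolean hit mask. The array names are illustrative, and the smooth edge fade mentioned above is not modelled.

```python
import numpy as np

H, W = 4, 4
background = np.array([0.78, 0.91, 1.0])              # sky blue
shaded = np.full((H, W, 3), 0.5, dtype=np.float32)    # stand-in surface colours
hit = np.zeros((H, W), dtype=bool)
hit[1:3, 1:3] = True                                   # rays that found geometry

# Broadcast the (3,) background over every pixel whose ray missed.
image = np.where(hit[..., None], shaded, background).astype(np.float32)
print(image.shape, image.dtype)
```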

Transparency and glass refraction

Set opacity < 1 on a material for transparency. With refract_steps=0 (default) the surface simply fades to the background colour. With refract_steps > 0 the renderer performs two-bounce Snell’s-law refraction:

  1. The primary ray hits the front face and bends into the material.
  2. An interior march finds the back face (using −sdf as the distance field).
  3. The ray bends back into air and continues through the rest of the scene.
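The three steps can be sketched for a single ray and a single glass sphere. The refraction uses the standard vector form of Snell's law, the same formula as GLSL's `refract`. The interior march steps by -sdf, as described above. This is a NumPy illustration with made-up helper names, not jaxCAD's renderer. It treats the sphere as the whole scene and skips the final march through the rest of the scene.

```python
import numpy as np

IOR = 1.5     # glass
EPS = 1e-5    # hit tolerance for the march


def sphere_sdf(x, radius=1.0):
    return np.linalg.norm(x) - radius


def refract(i, n, eta):
    """Vector form of Snell's law; i, n unit vectors with dot(n, i) < 0.

    eta = n1 / n2. Returns None on total internal reflection.
    """
    cos_i = -np.dot(n, i)
    k = 1.0 - eta**2 * (1.0 - cos_i**2)
    if k < 0.0:
        return None
    return eta * i + (eta * cos_i - np.sqrt(k)) * n


def march(sdf, origin, direction, max_steps=200, t_max=100.0):
    # Sphere tracing: return the first point where sdf < EPS, or None.
    t = 0.0
    for _ in range(max_steps):
        p = origin + t * direction
        d = sdf(p)
        if d < EPS:
            return p
        t += d
        if t > t_max:
            break
    return None


def trace_glass(origin, direction):
    """Two-bounce refraction through a unit glass sphere at the origin.

    Returns (exit_point, exit_direction), or None if the ray misses.
    """
    d = direction / np.linalg.norm(direction)
    # 1. Primary ray hits the front face and bends into the material.
    front = march(sphere_sdf, origin, d)
    if front is None:
        return None
    n = front / np.linalg.norm(front)        # outward normal (sphere-specific)
    d = refract(d, n, 1.0 / IOR)             # air -> glass never totally reflects
    # 2. Interior march on -sdf, starting just inside the front face.
    back = march(lambda x: -sphere_sdf(x), front + 4 * EPS * d, d)
    n = back / np.linalg.norm(back)
    # 3. Bend back into air; the normal passed to refract must face the
    #    incoming ray, hence -n. (A sphere never totally reflects here.)
    out = refract(d, -n, IOR)
    return back, out


print(trace_glass(np.array([0.0, 0.5, 5.0]), np.array([0.0, 0.0, -1.0])))
```

A ray through the center passes straight through. An off-axis ray bends toward the axis at both faces, which is why a glass sphere acts as a lens.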

The Fresnel effect (Schlick approximation) adds realistic edge highlights: the rim of a glass sphere brightens at grazing angles.
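Schlick's approximation writes Fresnel reflectance as R(θ) = R0 + (1 - R0)(1 - cos θ)^5, with R0 = ((n1 - n2) / (n1 + n2))^2. For glass in air (n2 = 1.5), R0 = 0.04: about 4% of the light is reflected head-on. The reflectance rises toward 1 at grazing angles, and that rise is the rim brightening described above. The sketch below is a NumPy illustration of the formula, not jaxCAD's code.

```python
import numpy as np


def schlick(cos_theta, n1=1.0, n2=1.5):
    """Schlick's approximation of Fresnel reflectance.

    cos_theta is the cosine of the angle between the view ray and the normal.
    """
    r0 = ((n1 - n2) / (n1 + n2)) ** 2     # reflectance at normal incidence
    return r0 + (1.0 - r0) * (1.0 - cos_theta) ** 5


for deg in (0, 45, 75, 89):
    c = np.cos(np.radians(deg))
    print(f"{deg:2d} deg: R = {schlick(c):.3f}")
```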

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxcad.render import raymarch, Material
from jaxcad.sdf.boolean import Union
from jaxcad.sdf.primitives import Sphere
from jaxcad.sdf.transforms import Translate

# Nearly transparent glass sphere at the origin (ior=1.5)
glass = Sphere(
    radius=1.0,
    material=Material(color=[0.92, 0.97, 1.0], roughness=0.05, opacity=0.04, ior=1.5),
)
# Opaque red ball placed behind the glass sphere
red_ball = Translate(
    Sphere(radius=0.65, material=Material(color=[0.93, 0.26, 0.22])),
    offset=jnp.array([-1.1, 0.5, -3.0]),
)
scene = Union((glass, red_ball), smoothness=0.0)  # hard union, no smooth blending

image = raymarch(
    scene,
    camera_pos=jnp.array([0.0, 0.5, 5.5]),
    resolution=(400, 400),
    background_color=jnp.array([0.07, 0.09, 0.16]),
    refract_steps=48,   # enable two-bounce refraction
    max_steps=80,
    aa_samples=2,
)

# image is an (H, W, 3) float array; clip to [0, 1] before display
plt.imshow(image.clip(0.0, 1.0))
plt.axis("off")
plt.show()

IOR reference values:

Material          IOR
air / disabled    1.0
water             1.33
glass             1.5
diamond           2.42
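Two quantities follow directly from the table. The first is the normal-incidence reflectance R0 = ((n - 1) / (n + 1))^2, which sets the strength of the Fresnel term. The second is the critical angle arcsin(1 / n), beyond which light inside the material is totally internally reflected. For glass, these come to R0 = 0.04 and a critical angle of about 41.8°. A quick NumPy check:

```python
import numpy as np

materials = {"water": 1.33, "glass": 1.5, "diamond": 2.42}

for name, n in materials.items():
    r0 = ((n - 1.0) / (n + 1.0)) ** 2                 # head-on reflectance in air
    critical = np.degrees(np.arcsin(1.0 / n))          # total internal reflection threshold
    print(f"{name:8s} R0 = {r0:.3f}  critical angle = {critical:.1f} deg")
```

Higher IOR means a stronger reflection and a smaller critical angle. That combination is why diamond sparkles more than glass.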

Next steps