
Python API

This section is generated with mkdocstrings (Python handler). It documents the public API of the crazyflow package.

crazyflow

Classes

Control

Bases: str, Enum

Control type of the simulated onboard controller.

Attributes
attitude = 'attitude' class-attribute instance-attribute

Attitude control takes [roll, pitch, yaw, collective thrust].

Note

Recommended frequency is >=100 Hz.

force_torque = 'force_torque' class-attribute instance-attribute

Force and torque control takes [fx, fy, fz, tx, ty, tz].

Note

Recommended frequency is >=500 Hz.

state = 'state' class-attribute instance-attribute

State control takes [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate].

Note

Recommended frequency is >=20 Hz.

Warning

Currently, we only use positions, velocities, and yaw. The rest of the state is ignored. This is subject to change in the future.
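The 13-element state command layout can be sketched as follows. The helper fills positions, velocities, and yaw, and leaves the remaining entries at zero because the current controller ignores them. The helper name `make_state_cmd` and the use of NumPy are assumptions for illustration and are not part of crazyflow.

```python
import numpy as np

# Index layout of a state command, as listed for Control.state:
# [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]
POS = slice(0, 3)   # x, y, z
VEL = slice(3, 6)   # vx, vy, vz
YAW = 9             # indices 6:9 (accelerations) and 10:13 (body rates) stay zero
STATE_CMD_DIM = 13


def make_state_cmd(pos, vel=None, yaw=None):
    """Build a (..., 13) state command from (..., 3) position targets.

    Only positions, velocities, and yaw are filled in, matching the fields the
    current state controller uses; all other entries stay zero.
    """
    pos = np.asarray(pos, dtype=np.float32)
    cmd = np.zeros(pos.shape[:-1] + (STATE_CMD_DIM,), dtype=np.float32)
    cmd[..., POS] = pos
    if vel is not None:
        cmd[..., VEL] = vel
    if yaw is not None:
        cmd[..., YAW] = yaw
    return cmd
```

With a configured `Sim`, the result would be passed as `sim.state_control(make_state_cmd(targets))`, where `targets` has shape `(n_worlds, n_drones, 3)`. `state_control` asserts a trailing dimension of 13.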

Physics

Bases: str, Enum

Physics mode for the simulation.

Sim(n_worlds=1, n_drones=1, drone_model='cf2x_L250', physics=Physics.default, control=Control.default, integrator=Integrator.default, freq=500, state_freq=100, attitude_freq=500, force_torque_freq=500, device='cpu', xml_path=None, rng_key=0)

Source code in crazyflow/sim/sim.py
def __init__(
    self,
    n_worlds: int = 1,
    n_drones: int = 1,
    drone_model: str = "cf2x_L250",
    physics: Physics = Physics.default,
    control: Control = Control.default,
    integrator: Integrator = Integrator.default,
    freq: int = 500,
    state_freq: int = 100,
    attitude_freq: int = 500,
    force_torque_freq: int = 500,
    device: str = "cpu",
    xml_path: Path | None = None,
    rng_key: int = 0,
):
    assert Physics(physics) in Physics, f"Physics mode {physics} not implemented"
    assert Control(control) in Control, f"Control mode {control} not implemented"
    if physics != Physics.first_principles and control == Control.force_torque:
        raise ConfigError("Force-torque control requires first principles physics")
    if freq > 10_000 and not jax.config.jax_enable_x64:
        raise ConfigError("High frequency simulations require double precision mode")
    self.physics = physics
    self.control = control
    self.drone_model = drone_model
    self.integrator = integrator
    self.device = jax.devices(device)[0]
    self.n_worlds = n_worlds
    self.n_drones = n_drones
    self.freq = freq
    self.max_visual_geom = 1000

    # Initialize MuJoCo world and data
    self._xml_path = xml_path or Path(__file__).parents[1] / "scene.xml"
    self.drone_path = Path(drone_models.__file__).parent / "data" / f"{drone_model}.xml"
    self.spec = self.build_mjx_spec()
    self.mj_model, self.mj_data, self.mjx_model, self.mjx_data = self.build_mjx_model(self.spec)
    self.viewer: MujocoRenderer | None = None

    self.data = self.init_data(state_freq, attitude_freq, force_torque_freq, rng_key)
    self.default_data: SimData = self.build_default_data()

    # Build the simulation pipeline and overwrite the default _step implementation with it
    self.reset_pipeline: tuple[Callable[[SimData, Array[bool] | None], SimData], ...] = tuple()
    self.step_pipeline: tuple[Callable[[SimData], SimData], ...] = tuple()
    # The ``select_xxx_fn`` methods return functions, not the results of calling those
    # functions. They act as factories that produce building blocks for the construction of our
    # simulation pipeline.
    self.step_pipeline += build_control_fns(self.control, self.physics)
    physics_fn = select_physics_fn(self.physics)
    self.step_pipeline += (select_integrate_fn(self.integrator, physics_fn),)
    self.step_pipeline += (increment_steps,)
    # Clip positions so drones never drop below z = -0.001 (they cannot pass through the floor).
    # Using -0.001 instead of 0 keeps z slightly negative after floor contact, so checks on the
    # sign of z can still detect it.
    self.step_pipeline += (clip_floor_pos,)

    self._reset = self.build_reset_fn()
    self._step = self.build_step_fn()
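The constructor rejects two invalid configurations before it builds any MuJoCo structures. The stand-alone sketch below mirrors those checks so they can be exercised without MuJoCo or JAX. `check_config`, the local `ConfigError` class, and the reduced `Physics` enum (only the `first_principles` and `sys_id` modes named on this page) are hypothetical stand-ins.

```python
from enum import Enum


class ConfigError(Exception):
    """Stand-in for crazyflow's ConfigError (assumed to be a plain exception)."""


class Physics(str, Enum):
    # Hypothetical subset: only the modes mentioned on this page.
    first_principles = "first_principles"
    sys_id = "sys_id"


class Control(str, Enum):
    state = "state"
    attitude = "attitude"
    force_torque = "force_torque"


def check_config(physics, control, freq, x64_enabled):
    """Mirror the validation at the top of Sim.__init__."""
    # The Enum constructor already raises ValueError for unknown mode strings.
    physics, control = Physics(physics), Control(control)
    if physics != Physics.first_principles and control == Control.force_torque:
        raise ConfigError("Force-torque control requires first principles physics")
    if freq > 10_000 and not x64_enabled:
        raise ConfigError("High frequency simulations require double precision mode")
    return physics, control
```

In the real constructor, `x64_enabled` corresponds to `jax.config.jax_enable_x64`, which can be switched on with `jax.config.update("jax_enable_x64", True)` before creating a `Sim` with `freq` above 10 000 Hz.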
Attributes
controllable property

Boolean array of shape (n_worlds,) that indicates which worlds are controllable.

A world is controllable if the last control step was more than 1/control_freq seconds ago. Desired controls get stashed in the staged control buffers and are applied in step as soon as the controller frequency allows for an update. Successive control updates that happen before the staged buffers are applied overwrite the desired values.
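The controllability test reduces to comparing elapsed simulation steps with the control period. The following is a minimal vectorized sketch. It assumes per-world integer step counters (`steps` and `last_ctrl_steps` are hypothetical names) and treats the boundary case as controllable (`>=`), which may differ from the strict "more than" wording above.

```python
import numpy as np


def controllable(steps, last_ctrl_steps, sim_freq, ctrl_freq):
    """Boolean mask of shape (n_worlds,) for worlds whose controller may update.

    A world is controllable once (steps - last) / sim_freq >= 1 / ctrl_freq.
    Multiplying both sides by sim_freq * ctrl_freq keeps the test in integers.
    """
    elapsed = np.asarray(steps) - np.asarray(last_ctrl_steps)
    return elapsed * ctrl_freq >= sim_freq
```

The integer form avoids floating-point round-off that a direct comparison of `elapsed / sim_freq` with `1 / ctrl_freq` could introduce.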

Functions
attitude_control(controls)

Set the desired attitude for all drones in all worlds.

We need to stage the attitude controls because the sys_id physics mode operates directly on the attitude controls. If we were to directly update the controls, this would effectively bypass the control frequency and run the attitude controller at the physics update rate. By staging the controls, we ensure that the physics module sees the old controls until the controller updates at its correct frequency.

Source code in crazyflow/sim/sim.py
def attitude_control(self, controls: Array):
    """Set the desired attitude for all drones in all worlds.

    We need to stage the attitude controls because the sys_id physics mode operates directly on
    the attitude controls. If we were to directly update the controls, this would effectively
    bypass the control frequency and run the attitude controller at the physics update rate. By
    staging the controls, we ensure that the physics module sees the old controls until the
    controller updates at its correct frequency.
    """
    assert controls.shape == (self.n_worlds, self.n_drones, 4), "controls shape mismatch"
    assert self.control == Control.attitude, "Attitude control is not enabled by the sim config"
    controls = to_device(controls, self.device)
    self.data = self.data.replace(
        controls=self.data.controls.replace(
            attitude=self.data.controls.attitude.replace(staged_cmd=controls)
        )
    )
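The staging pattern can be illustrated with frozen dataclasses, where `dataclasses.replace` plays the role of the nested `.replace(...)` calls in `attitude_control`. `ControlBuffer`, `Data`, `stage`, and `apply_if_controllable` are hypothetical names, and the per-world mask is simplified to a single boolean. The sketch only shows that repeated staging overwrites the pending command while the physics keeps seeing the active one.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ControlBuffer:
    cmd: tuple          # command the physics currently sees
    staged_cmd: tuple   # latest desired command, not yet applied


@dataclass(frozen=True)
class Data:
    attitude: ControlBuffer


def stage(data, cmd):
    """Overwrite the staged command; the active command is untouched."""
    return replace(data, attitude=replace(data.attitude, staged_cmd=cmd))


def apply_if_controllable(data, controllable):
    """Promote the staged command to the active one when the controller ticks."""
    if not controllable:
        return data
    return replace(data, attitude=replace(data.attitude, cmd=data.attitude.staged_cmd))
```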
build_data()

Build the simulation data for the current configuration.

Note

This function re-initializes the simulation data according to the current configuration. It also returns the constructed data for use with pure functions.

Returns:

Type Description
SimData

The simulation data as a single PyTree that can be passed to the pure simulation functions for stepping and resetting.

Source code in crazyflow/sim/sim.py
def build_data(self) -> SimData:
    """Build the simulation data for the current configuration.

    Note:
        This function re-initializes the simulation data according to the current configuration.
        It also returns the constructed data for use with pure functions.

    Returns:
        The simulation data as a single PyTree that can be passed to the pure simulation
        functions for stepping and resetting.
    """
    state_freq = self.data.controls.state.freq if self.data.controls.state is not None else 0
    attitude_freq = (
        self.data.controls.attitude.freq if self.data.controls.attitude is not None else 0
    )
    force_torque_freq = self.data.controls.force_torque.freq
    self.data = self.init_data(
        state_freq, attitude_freq, force_torque_freq, self.data.core.rng_key
    )
    return self.data
build_default_data()

Initialize the default data for the simulation.

Note

This function initializes the default data used as a reference in the reset function to reset the simulation to. It also returns the constructed data for use with pure functions.

Returns:

Type Description
SimData

The default simulation data used as a reference in the reset function to reset the simulation to.

Source code in crazyflow/sim/sim.py
def build_default_data(self) -> SimData:
    """Initialize the default data for the simulation.

    Note:
        This function initializes the default data used as a reference in the reset function to
        reset the simulation to. It also returns the constructed data for use with pure
        functions.

    Returns:
        The default simulation data used as a reference in the reset function to reset the
        simulation to.
    """
    self.default_data = self.data.replace()
    return self.default_data
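Every simulation update returns new data instead of mutating arrays in place, so a shallow copy is enough to serve as a stable reset reference. The toy below uses `dataclasses.replace` in place of the `SimData.replace()` method seen in the listing. The `SimData` fields and the `step` function here are hypothetical.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SimData:
    """Toy stand-in for crazyflow's SimData PyTree (hypothetical fields)."""
    steps: int
    pos: tuple


def step(data: SimData) -> SimData:
    """Pure update: returns a new SimData and never mutates its input."""
    x, y, z = data.pos
    return replace(data, steps=data.steps + 1, pos=(x, y, z + 0.1))


data = SimData(steps=0, pos=(0.0, 0.0, 0.0))
# Snapshot analogous to `self.default_data = self.data.replace()`.
default_data = replace(data)
for _ in range(3):
    data = step(data)
```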
build_mjx_model(spec)

Build the MuJoCo model and data structures for the simulation.

Source code in crazyflow/sim/sim.py
def build_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]:
    """Build the MuJoCo model and data structures for the simulation."""
    mj_model = spec.compile()
    mj_data = mujoco.MjData(mj_model)
    mjx_model = mjx.put_model(mj_model, device=self.device)
    mjx_data = mjx.put_data(mj_model, mj_data, device=self.device)
    mjx_data = jax.vmap(lambda _: mjx_data)(jnp.arange(self.n_worlds))
    return mj_model, mj_data, mjx_model, mjx_data
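The line `jax.vmap(lambda _: mjx_data)(jnp.arange(self.n_worlds))` replicates one MJX data structure across all worlds. The mapped function ignores its input, and `vmap` broadcasts such unbatched outputs along the new leading axis. A NumPy analog over a nested dict, using the hypothetical helper `tile_worlds`, shows the resulting shapes.

```python
import numpy as np


def tile_worlds(tree, n_worlds):
    """Give every array leaf of a nested dict a leading axis of length n_worlds.

    Each world receives an identical copy of the leaf, matching what the vmap
    trick produces for data that does not depend on the mapped input.
    """
    if isinstance(tree, dict):
        return {k: tile_worlds(v, n_worlds) for k, v in tree.items()}
    leaf = np.asarray(tree)
    # broadcast_to returns a read-only view; copy so worlds can diverge later.
    return np.broadcast_to(leaf, (n_worlds,) + leaf.shape).copy()
```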
build_mjx_spec()

Build the MuJoCo model specification for the simulation.

Source code in crazyflow/sim/sim.py
def build_mjx_spec(self) -> mujoco.MjSpec:
    """Build the MuJoCo model specification for the simulation."""
    assert self._xml_path.exists(), f"Model file {self._xml_path} does not exist"
    spec = mujoco.MjSpec.from_file(str(self._xml_path))
    spec.option.timestep = 1 / self.freq
    spec.copy_during_attach = True
    drone_spec = mujoco.MjSpec.from_file(str(self.drone_path))
    frame = spec.worldbody.add_frame(name="world")
    # Add drones and their actuators
    for i in range(self.n_drones):
        drone_body = drone_spec.body("drone")
        if drone_body is None:
            raise ValueError("Drone body not found in drone spec")
        drone = frame.attach_body(drone_body, "", f":{i}")
        drone.add_freejoint()
    return spec
build_reset_fn()

Build the reset function for the current simulation configuration.

Note

This function both changes the underlying implementation of Sim.reset() in-place to the current pipeline and returns the function for pure functional style programming.

Returns:

Type Description
Callable[[SimData, SimData, Array | None], SimData]

The pure JAX function that resets simulation data. It takes the current SimData, default SimData, and an optional mask for worlds to reset, returning the updated SimData.

Source code in crazyflow/sim/sim.py
def build_reset_fn(self) -> Callable[[SimData, SimData, Array | None], SimData]:
    """Build the reset function for the current simulation configuration.

    Note:
        This function both changes the underlying implementation of Sim.reset() in-place to the
        current pipeline and returns the function for pure functional style programming.

    Returns:
        The pure JAX function that resets simulation data. It takes the current SimData, default
        SimData, and an optional mask for worlds to reset, returning the updated SimData.
    """
    pipeline = self.reset_pipeline

    @jax.jit
    def reset(data: SimData, default_data: SimData, mask: Array | None = None) -> SimData:
        data = pytree_replace(data, default_data, mask)  # Does not overwrite rng_key
        for fn in pipeline:
            data = fn(data, mask)
        data = data.replace(core=data.core.replace(mjx_synced=False))  # Flag mjx data as stale
        return data

    self._reset = reset
    return reset
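The reset relies on `pytree_replace`, whose source is not shown here. Assuming it swaps in default values only for the worlds selected by the mask, a NumPy sketch over a dict of arrays could look as follows. `masked_replace` is a hypothetical name, and unlike the helper in the listing this sketch does not preserve `rng_key`.

```python
import numpy as np


def masked_replace(data, default, mask=None):
    """Reset the worlds selected by `mask` to their values in `default`.

    data, default: dicts of arrays whose leading axis is n_worlds.
    mask: boolean array of shape (n_worlds,), or None to reset every world.
    """
    if mask is None:
        return {k: np.array(v, copy=True) for k, v in default.items()}
    mask = np.asarray(mask, dtype=bool)
    out = {}
    for k, v in data.items():
        # Append singleton axes so the (n_worlds,) mask broadcasts over each leaf.
        m = mask.reshape(mask.shape + (1,) * (v.ndim - 1))
        out[k] = np.where(m, default[k], v)
    return out
```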
build_step_fn()

Set up the chain of functions that are called in Sim.step().

We know all the functions that are called in succession since the simulation is configured at initialization time. Instead of branching through options at runtime, we construct a step function at initialization that selects the correct functions based on the settings.

Note

This function both changes the underlying implementation of Sim.step() in-place to the current pipeline and returns the function for pure functional style programming.

Warning

If any settings change, the pipeline of functions needs to be reconstructed.

Returns:

Type Description
Callable[[SimData, int], SimData]

The pure JAX function that steps through the simulation. It takes the current SimData and the number of steps to simulate, and returns the updated SimData.

Source code in crazyflow/sim/sim.py
def build_step_fn(self) -> Callable[[SimData, int], SimData]:
    """Setup the chain of functions that are called in Sim.step().

    We know all the functions that are called in succession since the simulation is configured
    at initialization time. Instead of branching through options at runtime, we construct a step
    function at initialization that selects the correct functions based on the settings.

    Note:
        This function both changes the underlying implementation of Sim.step() in-place to the
        current pipeline and returns the function for pure functional style programming.

    Warning:
        If any settings change, the pipeline of functions needs to be reconstructed.

    Returns:
        The pure JAX function that steps through the simulation. It takes the current SimData
        and the number of steps to simulate, and returns the updated SimData.
    """
    pipeline = self.step_pipeline

    # jax.lax.scan calls the body with (carry, x) and expects (carry, y) back. Without xs, x is
    # None, and we return None as y because no per-step outputs are collected.
    def single_step(data: SimData, _: None) -> tuple[SimData, None]:
        for fn in pipeline:
            data = fn(data)
        return data, None

    # ``scan`` gives us control over loop unrolling, from a single WhileOp (shorter compilation
    # times) to complete unrolling (fused iterations that give XLA maximum freedom to reorder
    # operations and jointly optimize the pipeline). This is especially relevant for the common
    # use case of running multiple sim steps in an outer loop, e.g. in gym environments.
    # Having n_steps as a static argument is fine, since patterns with n_steps > 1 will almost
    # always use the same n_steps value for successive calls.
    @partial(jax.jit, static_argnames="n_steps")
    def step(data: SimData, n_steps: int = 1) -> SimData:
        data, _ = jax.lax.scan(single_step, data, length=n_steps, unroll=1)
        data = data.replace(core=data.core.replace(mjx_synced=False))  # Flag mjx data as stale
        return data

    self._step = step
    return step
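The structure of the step function can be reproduced in plain Python to see how the pipeline composes: each pipeline entry is applied in order once per step, and a scan-like loop repeats that `n_steps` times. Here `scan` is a minimal stand-in for `jax.lax.scan` without `xs`, `fall` is a toy physics step, and `clip_floor_pos` is a hypothetical reimplementation of the floor clip described in the constructor.

```python
def scan(f, carry, length):
    """Pure-Python stand-in for jax.lax.scan with xs=None: apply f `length` times."""
    ys = []
    for _ in range(length):
        carry, y = f(carry, None)
        ys.append(y)
    return carry, ys


def fall(data):
    # Toy "physics": the drone sinks 0.5 m per step.
    return {**data, "z": data["z"] - 0.5}


def increment_steps(data):
    return {**data, "steps": data["steps"] + 1}


def clip_floor_pos(data):
    # Hypothetical floor clip: z never drops below -0.001.
    return {**data, "z": max(data["z"], -0.001)}


pipeline = (fall, increment_steps, clip_floor_pos)


def single_step(data, _):
    for fn in pipeline:
        data = fn(data)
    return data, None


def step(data, n_steps=1):
    data, _ = scan(single_step, data, length=n_steps)
    return data
```

Because `n_steps` is a static argument of the jitted `step`, each distinct value triggers a separate compilation, which is why the listing expects the same `n_steps` to be reused across calls.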
contacts(body=None)

Get contact information from the simulation.

Parameters:

Name Type Description Default
body str | None

Optional body name to filter contacts for. If None, returns flags for all bodies.

None

Returns:

Type Description
Array

A boolean array of shape (n_worlds,) that is True if any contact is present.

Source code in crazyflow/sim/sim.py
@requires_mujoco_sync
def contacts(self, body: str | None = None) -> Array:
    """Get contact information from the simulation.

    Args:
        body: Optional body name to filter contacts for. If None, returns flags for all bodies.

    Returns:
        A boolean array of shape (n_worlds,) that is True if any contact is present.
    """
    if body is None:
        return self.mjx_data._impl.contact.dist < 0
    body_id = self.mj_model.body(body).id
    geom_start = self.mj_model.body_geomadr[body_id]
    geom_count = self.mj_model.body_geomnum[body_id]
    return contacts(geom_start, geom_count, self.mjx_data)
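The body filter uses the geometry range `[geom_start, geom_start + geom_count)` from the MuJoCo model. The module-level `contacts` helper is not shown, so the following is only a plausible sketch of such a filter. It assumes each contact is described by its two geom ids and a signed distance; `body_in_contact`, `geom_pairs`, and `dist` are hypothetical names, not MJX field names.

```python
import numpy as np


def body_in_contact(geom_pairs, dist, geom_start, geom_count):
    """Per-world flag: does any penetrating contact touch the body's geoms?

    geom_pairs: int array (n_worlds, n_contacts, 2) with the two geom ids.
    dist:       float array (n_worlds, n_contacts); dist < 0 means penetration.
    The body owns geoms geom_start .. geom_start + geom_count - 1.
    """
    in_body = (geom_pairs >= geom_start) & (geom_pairs < geom_start + geom_count)
    touching = in_body.any(axis=-1) & (dist < 0)
    return touching.any(axis=-1)
```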
force_torque_control(cmd)

Set the desired force and torque for all drones in all worlds.

Source code in crazyflow/sim/sim.py
def force_torque_control(self, cmd: Array):
    """Set the desired force and torque for all drones in all worlds."""
    assert cmd.shape == (self.n_worlds, self.n_drones, 4), "Command shape mismatch"
    assert self.control == Control.force_torque, (
        "Force-torque control is not enabled by the sim config"
    )
    controls = to_device(cmd, self.device)
    self.data = self.data.replace(
        controls=self.data.controls.replace(
            force_torque=self.data.controls.force_torque.replace(staged_cmd=controls)
        )
    )
init_data(state_freq, attitude_freq, force_torque_freq, rng_key)

Initialize the simulation data.

Source code in crazyflow/sim/sim.py
def init_data(
    self, state_freq: int, attitude_freq: int, force_torque_freq: int, rng_key: Array
) -> SimData:
    """Initialize the simulation data."""
    drone_ids = [self.mj_model.body(f"drone:{i}").id for i in range(self.n_drones)]
    N, D = self.n_worlds, self.n_drones
    data = SimData(
        states=SimState.create(N, D, self.device),
        states_deriv=SimStateDeriv.create(N, D, self.device),
        controls=SimControls.create(
            N,
            D,
            self.control,
            self.drone_model,
            state_freq,
            attitude_freq,
            force_torque_freq,
            self.device,
        ),
        params=SimParams.create(N, D, self.physics, self.drone_model, self.device),
        core=SimCore.create(self.freq, N, D, drone_ids, rng_key, self.device),
    )
    if D > 1:  # If multiple drones, arrange them in a grid
        grid = grid_2d(D)
        states = data.states.replace(pos=data.states.pos.at[..., :2].set(grid))
        data = data.replace(states=states)
    return data
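`grid_2d` is not documented on this page. The sketch below assumes it returns an `(n_drones, 2)` array of xy offsets on a near-square grid centred on the origin; the spacing of 0.25 is an arbitrary placeholder. It also shows the NumPy equivalent of `pos.at[..., :2].set(grid)`, where the grid broadcasts over the world axis so every world gets the same layout.

```python
import math

import numpy as np


def grid_2d(n, spacing=0.25):
    """Hypothetical stand-in for crazyflow's grid_2d: (n, 2) xy positions on a
    near-square grid centred on the origin."""
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    idx = np.arange(n)
    xy = np.stack([idx % cols, idx // cols], axis=-1).astype(float)
    xy -= [(cols - 1) / 2, (rows - 1) / 2]  # centre the grid on the origin
    return xy * spacing


n_worlds, n_drones = 2, 4
pos = np.zeros((n_worlds, n_drones, 3))
# NumPy equivalent of pos.at[..., :2].set(grid): the (D, 2) grid broadcasts
# over the leading world axis.
pos[..., :2] = grid_2d(n_drones)
```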
reset(mask=None)

Reset the simulation to the initial state.

Parameters:

Name Type Description Default
mask Array | None

Boolean array of shape (n_worlds, ) that indicates which worlds to reset. If None, all worlds are reset.

None
Source code in crazyflow/sim/sim.py
def reset(self, mask: Array | None = None):
    """Reset the simulation to the initial state.

    Args:
        mask: Boolean array of shape (n_worlds, ) that indicates which worlds to reset. If None,
            all worlds are reset.
    """
    assert mask is None or mask.shape == (self.n_worlds,), f"Mask shape mismatch {mask.shape}"
    self.data = self._reset(self.data, self.default_data, mask)
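A common reason for a partial reset is a crashed drone. Since positions are clipped at z = -0.001, a negative z indicates floor contact, which gives a simple way to build the `(n_worlds,)` mask. `crashed_worlds` and its threshold are assumptions for illustration.

```python
import numpy as np


def crashed_worlds(pos, z_min=0.0):
    """Mask of shape (n_worlds,) that is True where any drone is below z_min.

    pos has shape (n_worlds, n_drones, 3).
    """
    return (np.asarray(pos)[..., 2] < z_min).any(axis=-1)
```

With a live simulation, `sim.reset(crashed_worlds(sim.data.states.pos))` would then reset only the affected worlds.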
seed(seed)

Set the JAX rng key for the simulation.

Parameters:

Name Type Description Default
seed int

The seed for the JAX rng.

required
Source code in crazyflow/sim/sim.py
def seed(self, seed: int):
    """Set the JAX rng key for the simulation.

    Args:
        seed: The seed for the JAX rng.
    """
    self.data: SimData = seed_sim(self.data, seed, self.device)
state_control(controls)

Set the desired state for all drones in all worlds.

Source code in crazyflow/sim/sim.py
def state_control(self, controls: Array):
    """Set the desired state for all drones in all worlds."""
    assert controls.shape == (self.n_worlds, self.n_drones, 13), "controls shape mismatch"
    assert self.control == Control.state, "State control is not enabled by the sim config"
    controls = to_device(controls, self.device)
    self.data = self.data.replace(
        controls=self.data.controls.replace(
            state=self.data.controls.state.replace(staged_cmd=controls)
        )
    )
step(n_steps=1)

Simulate all drones in all worlds for n_steps time steps.

Source code in crazyflow/sim/sim.py
def step(self, n_steps: int = 1):
    """Simulate all drones in all worlds for n time steps."""
    assert n_steps > 0, "Number of steps must be positive"
    self.data = self._step(self.data, n_steps=n_steps)
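The number of steps to run between two controller updates follows from the simulation and controller frequencies. With the defaults `freq=500` and `state_freq=100` that is 5 steps, so an outer loop would call `sim.state_control(cmd)` followed by `sim.step(5)` once per update. The helper below (a hypothetical name) computes that value. Requiring an integer ratio is an assumption of this sketch, made so that updates are evenly spaced.

```python
def steps_per_control(sim_freq: int, ctrl_freq: int) -> int:
    """Number of simulation steps between two controller updates."""
    if ctrl_freq <= 0:
        raise ValueError("Controller frequency must be positive")
    if sim_freq % ctrl_freq:
        raise ValueError(f"{sim_freq=} is not a multiple of {ctrl_freq=}")
    return sim_freq // ctrl_freq
```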