Skip to content

utils.identification

This module contains functions to identify so_rpy models from data.

sys_id_rotation(data, data_validation=None, verbose=0, plot=False)

Identify the rotational part of the so_rpy model from data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Array]` | Training data containing the time and the SVF values of rpy [rad] and cmd_rpy [rad]. | *required* |
| `data_validation` | `dict[str, Array] \| None` | Optional validation data containing the same fields as `data`. | `None` |
| `verbose` | `int` | Verbosity level of the optimizer, from 0 to 2. | `0` |
| `plot` | `bool` | Whether to plot the results. | `False` |

Returns: Identified model parameters.

Source code in drone_models/utils/identification.py
def sys_id_rotation(
    data: dict[str, Array],
    data_validation: dict[str, Array] | None = None,
    verbose: int = 0,
    plot: bool = False,
) -> dict[str, Array]:
    """Identify the rotational part of the so_rpy model from data.

    Roll and pitch share one set of coefficients while yaw gets its own. The coefficients are
    fitted by nonlinear least squares (``method="trf"``) on the training data and, if provided,
    evaluated on a separate validation set.

    Args:
        data: Training data containing time, and the SVF values of rpy [rad], cmd_rpy [rad],
            read from the keys ``"time"``, ``"SVF_rpy"`` and ``"SVF_cmd_rpy"``.
        data_validation: Optional validation data containing the same fields as data.
        verbose: Verbosity level for the optimizer from 0 to 2.
        plot: Whether to plot the results. The validation column of the figure is only drawn
            when ``data_validation`` is given.

    Returns:
        Identified model parameters with keys ``"rpy_coef"``, ``"rpy_rates_coef"`` and
        ``"cmd_rpy_coef"``; each is a length-3 array ordered (roll, pitch, yaw), where the roll
        and pitch entries are identical.
    """
    axis_names = ("roll", "pitch", "yaw")
    # theta = [rpy (roll/pitch), rpy (yaw), rpy_rates (roll/pitch), rpy_rates (yaw),
    #          cmd_rpy (roll/pitch), cmd_rpy (yaw)]
    theta0 = np.array([-10.0, -10.0, -1.0, -1.0, 10.0, 10.0])
    method = "trf"
    xtol, ftol, gtol = 1e-10, 1e-10, 1e-10
    t = jnp.array(data["time"])
    rpy = jnp.array(data["SVF_rpy"])
    cmd_rpy = jnp.array(data["SVF_cmd_rpy"])
    if data_validation is not None:
        t_valid = jnp.array(data_validation["time"])
        rpy_valid = jnp.array(data_validation["SVF_rpy"])
        cmd_rpy_valid = jnp.array(data_validation["SVF_cmd_rpy"])

    # Identification
    residual_fun_rot, residual_fun_rot_jac = _build_residuals_fun_rotation()
    res = least_squares(
        residual_fun_rot,
        x0=theta0,
        jac=residual_fun_rot_jac,
        args=(cmd_rpy, t, rpy),
        method=method,
        xtol=xtol,
        ftol=ftol,
        gtol=gtol,
        verbose=verbose,
    )
    theta = res.x

    # Expand the shared roll/pitch coefficient into explicit (roll, pitch, yaw) triples.
    rpy_coef = np.array([theta[0], theta[0], theta[1]])
    rpy_rates_coef = np.array([theta[2], theta[2], theta[3]])
    cmd_rpy_coef = np.array([theta[4], theta[4], theta[5]])
    params = {"rpy_coef": rpy_coef, "rpy_rates_coef": rpy_rates_coef, "cmd_rpy_coef": cmd_rpy_coef}

    rpy_pred = _simulate_system_rotation(cmd_rpy, t, theta)
    if data_validation is not None:
        rpy_pred_valid = _simulate_system_rotation(cmd_rpy_valid, t_valid, theta)

    # Report
    txt = "\n=== Stats roll & pitch ==="
    txt += f"\nEstimated:  {rpy_coef=}, {rpy_rates_coef=}, {cmd_rpy_coef=}"
    txt += f"\nTraining success={res.success}, results:"
    txt += f"\nRMSE={_rmse(rpy, rpy_pred):.6f}"
    txt += f"\nR^2={_r2(rpy, rpy_pred):.4f}"

    if data_validation is not None:
        txt += "\nValidation results:"
        # Report every identified axis, including yaw, which has its own coefficients.
        for i, name in enumerate(axis_names):
            txt += f"\nRMSE {name}={_rmse(rpy_valid[..., i], rpy_pred_valid[..., i]):.6f}"
        for i, name in enumerate(axis_names):
            txt += f"\nR^2 {name}={_r2(rpy_valid[..., i], rpy_pred_valid[..., i]):.4f}"
    logger.info(txt)

    # Plotting
    if plot:
        # Only add the validation column when validation data exists; the validation arrays
        # are unbound otherwise. squeeze=False keeps axs 2-D for either column count.
        n_cols = 2 if data_validation is not None else 1
        fig, axs = plt.subplots(3, n_cols, figsize=(10 * n_cols, 12), squeeze=False)
        plt.suptitle("RPY dynamics fit")

        for i, name in enumerate(axis_names):
            ylabel = f"{name.capitalize()} [rad]"
            axs[i, 0].plot(t, rpy[..., i], label=f"Measured {name}")
            axs[i, 0].plot(t, rpy_pred[..., i], "--", label=f"Predicted {name}")
            axs[i, 0].set_ylabel(ylabel)
            if data_validation is not None:
                axs[i, 1].plot(t_valid, rpy_valid[..., i], label=f"Measured {name} (valid)")
                axs[i, 1].plot(
                    t_valid, rpy_pred_valid[..., i], "--", label=f"Predicted {name} (valid)"
                )
                axs[i, 1].set_ylabel(ylabel)

        for ax in axs[-1]:
            ax.set_xlabel("Time [s]")

        for ax in axs.flat:
            ax.grid(True)
            ax.legend()

        plt.tight_layout()
        plt.show()

    return params

sys_id_translation(model, mass, data, data_validation=None, gravity=np.array([0, 0, -9.81]), verbose=0, plot=False)

Identify the translational part of the so_rpy model from data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Literal['so_rpy', 'so_rpy_rotor', 'so_rpy_rotor_drag']` | Model type to identify. | *required* |
| `mass` | `float` | Mass of the drone. | *required* |
| `data` | `dict[str, Array]` | Training data containing the time and the SVF values of vel, acc, quat, and cmd_f. | *required* |
| `data_validation` | `dict[str, Array] \| None` | Optional validation data containing the same fields as `data`. | `None` |
| `gravity` | `Array` | Gravity vector in the world frame, e.g., `[0, 0, -9.81]`. | `array([0, 0, -9.81])` |
| `verbose` | `int` | Verbosity level of the optimizer, from 0 to 2. | `0` |
| `plot` | `bool` | Whether to plot the results. | `False` |

Returns: Identified model parameters.

Source code in drone_models/utils/identification.py
def sys_id_translation(
    model: Literal["so_rpy", "so_rpy_rotor", "so_rpy_rotor_drag"],
    mass: float,
    data: dict[str, Array],
    data_validation: dict[str, Array] | None = None,
    gravity: Array = np.array([0, 0, -9.81]),
    verbose: int = 0,
    plot: bool = False,
) -> dict[str, Array]:
    """Identify the translational part of the so_rpy model from data.

    The four coefficients ``[cmd_f_coef, thrust_time_coef, drag_xy_coef, drag_z_coef]`` are
    fitted by nonlinear least squares (``method="trf"``). Coefficients that the chosen
    ``model`` does not use are set to zero before the fitted model is simulated.

    Args:
        model: Model type to identify.
        mass: Mass of the drone.
        data: Training data containing time, and the SVF values of vel, acc, quat, cmd_f, read
            from the keys ``"time"``, ``"SVF_vel"``, ``"SVF_acc"``, ``"SVF_quat"`` and
            ``"SVF_cmd_f"``.
        data_validation: Optional validation data containing the same fields as data.
        gravity: Gravity vector in world frame, e.g., [0, 0, -9.81].
        verbose: Verbosity level for the optimizer from 0 to 2.
        plot: Whether to plot the results. The validation subplot is only drawn when
            ``data_validation`` is given.

    Returns:
        Identified model parameters. ``"cmd_f_coef"`` is always present,
        ``"thrust_time_coef"`` is added when ``"rotor"`` is in ``model``, and
        ``"drag_xy_coef"`` / ``"drag_z_coef"`` are added when ``"drag"`` is in ``model``.
    """
    # Initial guess: [cmd_f_coef, thrust_time_coef, drag_xy_coef, drag_z_coef]
    theta0 = [1.0, 1.0, 0.0, 0.0]
    method = "trf"
    xtol, ftol, gtol = 1e-10, 1e-10, 1e-10
    constants = {"mass": mass, "gravity_vec": gravity}
    # Convert the data to jnp arrays for use with jax
    t = jnp.array(data["time"])
    vel = jnp.array(data["SVF_vel"])
    acc = jnp.array(data["SVF_acc"])
    quat = jnp.array(data["SVF_quat"])
    cmd_f = jnp.array(data["SVF_cmd_f"])

    # Identification
    residual_fun_trans, residual_fun_trans_jac = _build_residuals_fun_translation(model)
    res = least_squares(
        residual_fun_trans,
        x0=theta0,
        jac=residual_fun_trans_jac,
        args=(quat, vel, cmd_f, t, constants, acc),
        method=method,
        xtol=xtol,
        ftol=ftol,
        gtol=gtol,
        verbose=verbose,
    )

    theta = res.x
    params = {"cmd_f_coef": theta[0]}
    # Coefficients that are not part of the selected model are zeroed so the simulation
    # below only uses the terms the model actually contains.
    if "rotor" in model:
        params["thrust_time_coef"] = theta[1]
    else:
        theta[1] = 0.0
    if "drag" in model:
        params["drag_xy_coef"] = theta[2]
        params["drag_z_coef"] = theta[3]
    else:
        theta[2] = 0.0
        theta[3] = 0.0

    acc_pred = _simulate_system_translation(quat, vel, cmd_f, t, theta, constants)
    if data_validation is not None:
        t_valid = jnp.array(data_validation["time"])
        vel_valid = jnp.array(data_validation["SVF_vel"])
        acc_valid = jnp.array(data_validation["SVF_acc"])
        quat_valid = jnp.array(data_validation["SVF_quat"])
        cmd_f_valid = jnp.array(data_validation["SVF_cmd_f"])
        acc_pred_valid = _simulate_system_translation(
            quat_valid, vel_valid, cmd_f_valid, t_valid, theta, constants
        )

    # Report
    txt = f"\n=== Stats {model} ==="
    txt += f"\nParameters: {params=}"
    txt += f"\nTraining success={res.success}, results:"
    txt += f"\nRMSE={_rmse(acc, acc_pred):.6f}"
    txt += f"\nR^2={_r2(acc, acc_pred):.4f}"
    if data_validation is not None:
        txt += "\nValidation results:"
        txt += f"\nRMSE={_rmse(acc_valid, acc_pred_valid):.6f}"
        txt += f"\nR^2={_r2(acc_valid, acc_pred_valid):.4f}"
    logger.info(txt)

    # Plotting
    if plot:
        # Plot acceleration. The validation row is only created when validation data exists,
        # so no empty axes (with an empty legend) are produced.
        n_rows = 2 if data_validation is not None else 1
        fig, axs = plt.subplots(n_rows, 1, figsize=(12, 5), squeeze=False)
        axs = axs[:, 0]

        # Training data subplot
        axs[0].plot(t, acc, label="Measured acc")
        axs[0].plot(t, acc_pred, "--", label="Predicted acc")
        axs[0].set_xlabel("Time [s]")
        axs[0].set_ylabel("Output")

        # Validation data subplot
        if data_validation is not None:
            axs[1].plot(t_valid, acc_valid, label="Measured acc (valid)")
            axs[1].plot(t_valid, acc_pred_valid, "--", label="Predicted acc (valid)")
            axs[1].set_xlabel("Time [s]")
            axs[1].set_ylabel("Output")

        for ax in axs:
            ax.grid(True)
            ax.legend()

        plt.tight_layout()
        plt.show()

        # Plot commanded thrust vs actual thrust, where the actual thrust magnitude is
        # reconstructed as |(acc - g) * m|.
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))

        thrust_meas = np.linalg.norm(
            (acc - constants["gravity_vec"]) * constants["mass"], axis=-1
        )
        ax.scatter(cmd_f, thrust_meas, label="Measured")
        cmd_thrust_lin = np.linspace(np.min(cmd_f) * 0.9, np.max(cmd_f) * 1.1, 1000)
        ax.plot(cmd_thrust_lin, theta[0] * cmd_thrust_lin, label="Fit")
        ax.set_xlabel("Commanded Thrust [N]")
        ax.set_ylabel("Actual Thrust [N]")
        # Axis limits are left to autoscaling; a fixed window would clip any data whose
        # thrust lies outside it.
        ax.grid(True)
        ax.legend()

        plt.tight_layout()
        plt.show()

    return params