Source code for empyrean.ephemeris.sensitivity

"""Per-row sensitivity tables for state-space and observation-space partials.

Two flat quivr Tables, one row per ``(orbit, epoch)`` (or
``(orbit, observer, epoch)`` for the observation side):

- :class:`StateSensitivities` — STM (and optional STT) at each output
  epoch of a propagation. ``stm`` / ``stt`` ride as row-major flattened
  ``LargeListColumn`` values (length 36 / 216).
- :class:`ObservationSensitivities` — observation Jacobian (and
  optional Hessian) at each ephemeris epoch. ``n_params`` column
  documents the inner shape (6 for state-only DC, 9 with non-grav);
  ``jacobian`` / ``hessian`` are variable-length flattened lists
  (length ``6·n_params`` / ``6·n_params²``).

Filter to one chain with the standard quivr pattern before calling
the per-chain accessors::

    chain = sens.select("orbit_id", oid)  # one orbit
    obs_chain = obs.select("orbit_id", oid).select("obs_code", "F51")

Helper accessors on the filtered table (``stms_array``,
``jacobians_array``, ``index_at``, ``propagate_covariance``,
``kappa``) reshape the flat lists back to numpy matrices on demand.
"""

from datetime import datetime
from typing import Literal

import numpy as np
import quivr as qv

# ``pyarrow.compute`` generates its comparison wrappers (``less_equal`` …)
# dynamically at import time, so they are invisible to the static type
# checker. ``call_function`` is the public, statically-typed entry point
# these wrappers delegate to; import it from its defining module so the
# checker can resolve it.
from pyarrow._compute import call_function

from empyrean.coordinates.epoch import Epochs

EpochLike = float | str | datetime | Epochs
"""Anything :func:`StateSensitivities.index_at` accepts — a scalar MJD
TDB, a length-1 :class:`Epochs`, an ISO-8601 string, or a ``datetime``."""


# ── State-space sensitivity ───────────────────────────────────────────


def _state_rows_to_array(col, width: int) -> np.ndarray | None:
    """Densify a nullable flattened-list column into an ``(n, width)`` array.

    ``col`` is whatever ``Table.column(...)`` returns; only ``null_count``,
    ``len()`` and ``to_pylist()`` are used. Returns ``None`` when every row
    is null (which includes the empty column); individual null rows are
    filled with NaN.
    """
    if col.null_count == len(col):
        return None
    rows = col.to_pylist()
    out = np.full((len(rows), width), np.nan, dtype=np.float64)
    for k, row in enumerate(rows):
        if row is not None:
            out[k, :] = row
    return out


class StateSensitivities(qv.Table):
    """Per-``(orbit, epoch)`` state-transition matrices and tensors.

    One row per output epoch; rows are grouped contiguously by
    ``orbit_id`` (matching propagation's orbit-major output). Filter to
    one chain with quivr's standard ``select`` before calling the
    per-chain accessors::

        chain = sens.select("orbit_id", "2020 CD3")
        phi = chain.stms_array()[chain.index_at(60750.0)]

    Notes
    -----
    Matrices are stored row-major flattened in ``LargeListColumn`` s:

    - ``stm`` is the 6×6 STM Φ flattened to length 36
      (``stm[6·r + c] = Φ[r, c]``). ``None`` per row when STMs were not
      computed for that row.
    - ``stt`` is the 6×6×6 STT Ψ flattened to length 216
      (``stt[36·k + 6·a + b] = Ψ[k, a, b]``). ``None`` when the
      propagation method did not carry STTs (anything other than
      ``UncertaintyMethod.SECOND_ORDER``).

    The per-chain accessors (``stms_array``, ``stts_array``, ``index_at``
    and, through them, ``propagate_covariance`` and ``kappa``) raise
    :class:`ValueError` when the table holds more than one unique
    ``orbit_id``. ``orbit_ids_unique`` and ``up_to`` work on any table.
    """

    orbit_id = qv.LargeStringColumn()
    """Orbit primary key (matches the input ``Orbits.orbit_id``)."""

    object_id = qv.LargeStringColumn(nullable=True)
    """Object metadata label, if carried on the input orbit."""

    epoch_mjd_tdb = qv.Float64Column()
    """Output epoch (MJD TDB)."""

    stm = qv.LargeListColumn(qv.Float64Column(), nullable=True)
    """Row-major flattened 6×6 STM (length 36 per row), or ``None``."""

    stt = qv.LargeListColumn(qv.Float64Column(), nullable=True)
    """Row-major flattened 6×6×6 STT (length 216 per row), or ``None``."""

    resolved_kind = qv.LargeStringColumn(nullable=True)
    """Resolved covariance kind at this output epoch
    (:class:`~empyrean.propagation.tagged_covariance.CovarianceKind` value:
    ``linear`` / ``second_order`` / …), or ``None`` if the propagation did
    not resolve a covariance for this row."""

    # ── Introspection ─────────────────────────────────────────

    def orbit_ids_unique(self) -> list[str]:
        """Unique ``orbit_id`` values, in first-seen order."""
        # dict preserves insertion order, so this de-duplicates while
        # keeping the first occurrence of each id.
        return list(dict.fromkeys(self.orbit_id.to_pylist()))

    # ── Matrix reshaping ──────────────────────────────────────

    def stms_array(self) -> np.ndarray | None:
        """Reshape ``stm`` to ``(n_t, 6, 6)``.

        Returns ``None`` when every row has a null STM. Null rows within
        an otherwise-populated chain are filled with NaN.

        Raises :class:`ValueError` if the table holds more than one unique
        ``orbit_id`` — filter via ``select`` first.
        """
        _require_single_state_chain(self, "stms_array")
        flat = _state_rows_to_array(self.column("stm"), 36)
        return None if flat is None else flat.reshape(-1, 6, 6)

    def stts_array(self) -> np.ndarray | None:
        """Reshape ``stt`` to ``(n_t, 6, 6, 6)``.

        Returns ``None`` when every row has a null STT. Null rows within
        an otherwise-populated chain are filled with NaN.

        Raises :class:`ValueError` if the table holds more than one unique
        ``orbit_id`` — filter via ``select`` first.
        """
        _require_single_state_chain(self, "stts_array")
        flat = _state_rows_to_array(self.column("stt"), 216)
        return None if flat is None else flat.reshape(-1, 6, 6, 6)

    # ── Epoch lookup ──────────────────────────────────────────

    def index_at(self, epoch: EpochLike, *, atol: float = 1e-9) -> int:
        """Row index at the given epoch.

        ``epoch`` is converted to MJD TDB and matched against the nearest
        row within ``atol``.

        Raises :class:`ValueError` if no row matches (including an empty
        table or a NaN epoch), or if the table holds more than one unique
        ``orbit_id`` — filter via ``select`` first.
        """
        _require_single_state_chain(self, "index_at")
        target = _to_mjd_tdb(epoch)
        mjd = self.column("epoch_mjd_tdb").to_numpy(zero_copy_only=False)
        if mjd.size == 0:
            # np.argmin would raise an opaque "empty sequence" error here.
            raise ValueError(f"epoch MJD TDB {target} not found (table has no rows)")
        diffs = np.abs(mjd - target)
        i = int(np.argmin(diffs))
        # Written as ``not <=`` so a NaN difference (NaN target or NaN row)
        # is rejected; ``diffs[i] > atol`` is False for NaN and would
        # silently report a match.
        if not diffs[i] <= atol:
            raise ValueError(
                f"epoch MJD TDB {target} not found "
                f"(nearest row {i} at MJD TDB {mjd[i]}, Δ={diffs[i]:.3e} > "
                f"atol={atol:.3e})"
            )
        return i

    def up_to(self, epoch: EpochLike) -> "StateSensitivities":
        """Subset including rows with ``epoch_mjd_tdb ≤`` the target."""
        target = _to_mjd_tdb(epoch)
        mask = call_function("less_equal", [self.column("epoch_mjd_tdb"), target])
        return self.apply_mask(mask)

    # ── Covariance propagation ────────────────────────────────

    def propagate_covariance(
        self,
        cov_in: np.ndarray,
        *,
        i: int | None = None,
        order: Literal[1, 2, "auto"] = "auto",
    ) -> tuple[np.ndarray, np.ndarray]:
        """Forward-propagate a covariance through the chain.

        Filter to a single chain via ``select("orbit_id", oid)`` first —
        this method assumes the chain's STMs share a common t₀.

        Parameters
        ----------
        cov_in : np.ndarray
            Input 6×6 covariance at the chain's start epoch.
        i : int, optional
            Row index to evaluate at; negative values count from the end.
            ``None`` (default) returns the covariance at every chain
            epoch, shape ``(n_t, 6, 6)``.
        order : {1, 2, "auto"}
            ``1``: linear (``Σ = Φ Σ_0 Φᵀ``, ``Δμ = 0``).
            ``2``: Jet2 second-order Gaussian correction (requires STTs).
            ``"auto"``: order 2 when STTs are present, else order 1.

        Returns
        -------
        (cov_out, delta_mu) : (np.ndarray, np.ndarray)
            ``(6, 6)`` / ``(6,)`` for scalar ``i``;
            ``(n_t, 6, 6)`` / ``(n_t, 6)`` for ``i=None``.

        Raises
        ------
        ValueError
            If the chain has no STMs, ``cov_in`` is not 6×6, ``order`` is
            not one of the accepted values, or ``order=2`` is requested
            without STTs.
        IndexError
            If ``i`` is outside the chain.
        """
        stms = self.stms_array()
        if stms is None:
            raise ValueError(
                "chain has no STMs — propagation method did not compute them "
                "(Monte Carlo / SigmaPoint, or FirstOrder without input covariance)"
            )
        cov_in = np.asarray(cov_in, dtype=np.float64)
        if cov_in.shape != (6, 6):
            raise ValueError(f"cov_in must be (6, 6), got {cov_in.shape}")
        if order not in (1, 2, "auto"):
            raise ValueError(f"order must be 1, 2, or 'auto'; got {order!r}")

        # The STT column is only densified when it can actually be used.
        stts = None if order == 1 else self.stts_array()
        if order == "auto":
            order = 2 if stts is not None else 1
        elif order == 2 and stts is None:
            raise ValueError(
                "order=2 requires STTs; chain has none "
                "(run propagate with uncertainty_method=UncertaintyMethod.SECOND_ORDER)"
            )
        # From here on ``stts`` is non-None exactly when ``order == 2``.

        if i is None:
            return _propagate_cov_batch(stms, stts, cov_in, order=order)

        # Integer indexing (rather than ``arr[i : i + 1]``) so that negative
        # indices select a row instead of producing an empty slice.
        row_stms = stms[i][np.newaxis]
        row_stts = None if stts is None else stts[i][np.newaxis]
        cov_out, delta_mu = _propagate_cov_batch(row_stms, row_stts, cov_in, order=order)
        return cov_out[0], delta_mu[0]

    def kappa(
        self,
        cov_in: np.ndarray,
        *,
        i: int | None = None,
    ) -> float | np.ndarray:
        """Jet2 nonlinearity diagnostic κ.

        Approximates the departure of the true distribution from a
        Gaussian centered at the nominal state. Small κ (≲ 0.1) means
        first-order covariance is adequate; larger κ warrants Jet2 SOG or
        Gaussian-mixture splitting.

        Requires STTs. Filter to one chain first.

        Parameters
        ----------
        cov_in : np.ndarray
            Input 6×6 covariance at the chain's start epoch.
        i : int, optional
            Row index to evaluate at; negative values count from the end.
            ``None`` (default) returns κ at every epoch, shape ``(n_t,)``.

        Raises
        ------
        ValueError
            If the chain has no STTs or ``cov_in`` is not 6×6.
        IndexError
            If ``i`` is outside the chain.
        """
        stts = self.stts_array()
        if stts is None:
            raise ValueError(
                "kappa requires STTs; chain has none "
                "(run propagate with uncertainty_method=UncertaintyMethod.SECOND_ORDER)"
            )
        cov_in = np.asarray(cov_in, dtype=np.float64)
        # einsum would broadcast size-1 axes, so a mis-shaped covariance
        # must be rejected explicitly.
        if cov_in.shape != (6, 6):
            raise ValueError(f"cov_in must be (6, 6), got {cov_in.shape}")
        if i is None:
            return _kappa_batch(stts, cov_in)
        # Integer indexing so negative ``i`` selects a row (see
        # ``propagate_covariance``).
        return float(_kappa_batch(stts[i][np.newaxis], cov_in)[0])
# ── Observation-space sensitivity ─────────────────────────────────────
def _obs_rows_to_array(col, width: int) -> np.ndarray | None:
    """Densify a nullable flattened-list column into an ``(n, width)`` array.

    ``col`` is whatever ``Table.column(...)`` returns; only ``null_count``,
    ``len()`` and ``to_pylist()`` are used. Returns ``None`` when every row
    is null (which includes the empty column); individual null rows are
    filled with NaN.
    """
    if col.null_count == len(col):
        return None
    rows = col.to_pylist()
    out = np.full((len(rows), width), np.nan, dtype=np.float64)
    for k, row in enumerate(rows):
        if row is not None:
            out[k, :] = row
    return out


class ObservationSensitivities(qv.Table):
    """Per-``(orbit, observer, epoch)`` observation Jacobians and Hessians.

    Holds ∂h/∂x₀ at every ephemeris epoch for each ``(orbit, observer)``
    pair, plus the observation Hessians when the underlying propagation
    carried STTs. Filter to one chain via two ``select`` calls before
    using the per-chain accessors::

        chain = obs.select("orbit_id", oid).select("obs_code", "F51")
        H = chain.jacobians_array()[chain.index_at(60750.0)]

    Notes
    -----
    Matrix payloads are row-major flattened ``LargeListColumn`` values:

    - ``jacobian`` is ``(6, n_params)`` flattened to length
      ``6·n_params`` (``jacobian[n_params·r + c] = J[r, c]``).
    - ``hessian`` is ``(6, n_params, n_params)`` flattened to length
      ``6·n_params²``.

    The ``n_params`` column documents which: ``6`` for a state-only DC,
    ``9`` when non-gravitational parameters (A1, A2, A3) are also free
    variables. All rows within a single chain share the same
    ``n_params``.

    The per-chain accessors (``jacobians_array``, ``hessians_array``,
    ``index_at`` and, through them, ``propagate_covariance``) raise
    :class:`ValueError` when the table holds more than one unique
    ``(orbit_id, obs_code)`` pair.
    """

    orbit_id = qv.LargeStringColumn()
    """Orbit primary key."""

    object_id = qv.LargeStringColumn(nullable=True)
    """Object metadata label."""

    obs_code = qv.LargeStringColumn()
    """MPC observatory code."""

    epoch_mjd_tdb = qv.Float64Column()
    """Observation epoch (MJD TDB)."""

    n_params = qv.UInt8Column()
    """Last-axis dimension of Jacobian / Hessian. 6 for state-only DC,
    9 when non-grav A1/A2/A3 are free. Constant within a chain."""

    jacobian = qv.LargeListColumn(qv.Float64Column(), nullable=True)
    """Row-major flattened (6, n_params) Jacobian."""

    hessian = qv.LargeListColumn(qv.Float64Column(), nullable=True)
    """Row-major flattened (6, n_params, n_params) Hessian."""

    # ── Introspection ─────────────────────────────────────────

    def chain_keys(self) -> list[tuple[str, str]]:
        """Unique ``(orbit_id, obs_code)`` pairs, in first-seen order."""
        pairs = zip(self.orbit_id.to_pylist(), self.obs_code.to_pylist(), strict=False)
        # dict preserves insertion order, so this de-duplicates while
        # keeping the first occurrence of each pair.
        return list(dict.fromkeys(pairs))

    # ── Matrix reshaping ──────────────────────────────────────

    def jacobians_array(self) -> np.ndarray | None:
        """Reshape ``jacobian`` to ``(n_t, 6, n_params)``.

        Returns ``None`` when every row has a null Jacobian. Null rows
        within an otherwise-populated chain are filled with NaN.
        ``n_params`` is read from the first row of the chain.

        Raises :class:`ValueError` if the table holds more than one unique
        ``(orbit_id, obs_code)`` pair — filter via ``select`` first.
        """
        _require_single_obs_chain(self, "jacobians_array")
        col = self.column("jacobian")
        if col.null_count == len(col):
            # Checked before touching ``n_params`` so an empty table
            # returns None instead of failing on row 0.
            return None
        n_p = int(self.column("n_params")[0].as_py())
        flat = _obs_rows_to_array(col, 6 * n_p)
        return None if flat is None else flat.reshape(-1, 6, n_p)

    def hessians_array(self) -> np.ndarray | None:
        """Reshape ``hessian`` to ``(n_t, 6, n_params, n_params)``.

        Returns ``None`` when every row has a null Hessian. Null rows
        within an otherwise-populated chain are filled with NaN.
        ``n_params`` is read from the first row of the chain.

        Raises :class:`ValueError` if the table holds more than one unique
        ``(orbit_id, obs_code)`` pair — filter via ``select`` first.
        """
        _require_single_obs_chain(self, "hessians_array")
        col = self.column("hessian")
        if col.null_count == len(col):
            return None
        n_p = int(self.column("n_params")[0].as_py())
        flat = _obs_rows_to_array(col, 6 * n_p * n_p)
        return None if flat is None else flat.reshape(-1, 6, n_p, n_p)

    # ── Epoch lookup ──────────────────────────────────────────

    def index_at(self, epoch: EpochLike, *, atol: float = 1e-9) -> int:
        """Row index at the given epoch.

        ``epoch`` is converted to MJD TDB and matched against the nearest
        row within ``atol``.

        Raises :class:`ValueError` if no row matches (including an empty
        table or a NaN epoch), or if the table holds more than one unique
        ``(orbit_id, obs_code)`` pair — filter via two ``select`` calls
        first.
        """
        _require_single_obs_chain(self, "index_at")
        target = _to_mjd_tdb(epoch)
        mjd = self.column("epoch_mjd_tdb").to_numpy(zero_copy_only=False)
        if mjd.size == 0:
            # np.argmin would raise an opaque "empty sequence" error here.
            raise ValueError(f"epoch MJD TDB {target} not found (table has no rows)")
        diffs = np.abs(mjd - target)
        i = int(np.argmin(diffs))
        # Written as ``not <=`` so a NaN difference (NaN target or NaN row)
        # is rejected; ``diffs[i] > atol`` is False for NaN and would
        # silently report a match.
        if not diffs[i] <= atol:
            raise ValueError(
                f"epoch MJD TDB {target} not found "
                f"(nearest row {i} at MJD TDB {mjd[i]}, Δ={diffs[i]:.3e} > "
                f"atol={atol:.3e})"
            )
        return i

    def up_to(self, epoch: EpochLike) -> "ObservationSensitivities":
        """Subset including rows with ``epoch_mjd_tdb ≤`` the target."""
        target = _to_mjd_tdb(epoch)
        mask = call_function("less_equal", [self.column("epoch_mjd_tdb"), target])
        return self.apply_mask(mask)

    # ── Covariance propagation ────────────────────────────────

    def propagate_covariance(
        self,
        cov_in: np.ndarray,
        *,
        i: int | None = None,
        order: Literal[1, 2, "auto"] = "auto",
    ) -> tuple[np.ndarray, np.ndarray]:
        """Map an initial-state covariance into observation-frame covariance.

        Filter to a single ``(orbit_id, obs_code)`` chain first.

        Parameters
        ----------
        cov_in : np.ndarray
            Input ``(n_params, n_params)`` covariance, where ``n_params``
            is the last axis of this chain's Jacobians.
        i : int, optional
            Row index to evaluate at; negative values count from the end.
            ``None`` (default) evaluates every epoch.
        order : {1, 2, "auto"}
            ``1``: Jacobian-only mapping with ``Δμ = 0``.
            ``2``: adds the Hessian terms (requires Hessians).
            ``"auto"``: order 2 when Hessians are present, else order 1.

        Returns
        -------
        (Σ_obs, Δμ_obs) : (np.ndarray, np.ndarray)
            ``i=None``: ``(n_t, 6, 6)``, ``(n_t, 6)``
            ``i=int``:  ``(6, 6)``, ``(6,)``

        Raises
        ------
        ValueError
            If the chain has no Jacobians, ``cov_in`` has the wrong shape,
            ``order`` is not one of the accepted values, or ``order=2`` is
            requested without Hessians.
        IndexError
            If ``i`` is outside the chain.
        """
        jacs = self.jacobians_array()
        if jacs is None:
            raise ValueError(
                "chain has no Jacobians — ephemeris generation did not carry "
                "observation partials (likely no input covariance)"
            )
        n_p = jacs.shape[-1]
        cov_in = np.asarray(cov_in, dtype=np.float64)
        if cov_in.shape != (n_p, n_p):
            raise ValueError(f"cov_in must be ({n_p}, {n_p}) for this chain, got {cov_in.shape}")
        # Mirrors StateSensitivities.propagate_covariance: an unknown order
        # is an error, not a silent fall-back to the linear mapping.
        if order not in (1, 2, "auto"):
            raise ValueError(f"order must be 1, 2, or 'auto'; got {order!r}")

        # The Hessian column is only densified when it can actually be used.
        hess = None if order == 1 else self.hessians_array()
        if order == "auto":
            order = 2 if hess is not None else 1
        elif order == 2 and hess is None:
            raise ValueError(
                "order=2 requires Hessians; chain has none (propagation wasn't SECOND_ORDER)"
            )
        # From here on ``hess`` is non-None exactly when ``order == 2``.

        if i is None:
            return _propagate_obs_cov_batch(jacs, hess, cov_in, order=order)

        # Integer indexing (rather than ``arr[i : i + 1]``) so that negative
        # indices select a row instead of producing an empty slice.
        row_jacs = jacs[i][np.newaxis]
        row_hess = None if hess is None else hess[i][np.newaxis]
        cov_out, delta_mu = _propagate_obs_cov_batch(row_jacs, row_hess, cov_in, order=order)
        return cov_out[0], delta_mu[0]
# ── Internal helpers ────────────────────────────────────────────────── def _to_mjd_tdb(epoch: EpochLike) -> float: """Coerce any accepted epoch input to a scalar MJD TDB.""" if isinstance(epoch, Epochs): if len(epoch) != 1: raise ValueError(f"expected a single-row Epochs, got length {len(epoch)}") return float(epoch.to_tdb().mjd.to_numpy(zero_copy_only=False)[0]) if isinstance(epoch, str): return float(Epochs.from_iso([epoch]).to_tdb().mjd.to_numpy(zero_copy_only=False)[0]) if isinstance(epoch, datetime): return float( Epochs.from_iso([epoch.isoformat()]).to_tdb().mjd.to_numpy(zero_copy_only=False)[0] ) return float(epoch) def _require_single_state_chain(table: StateSensitivities, method: str) -> None: """Guard for per-chain methods on :class:`StateSensitivities`. Raises :class:`ValueError` if the table contains more than one unique ``orbit_id``, with a hint to filter via ``select`` first. """ oids = table.orbit_ids_unique() if len(oids) > 1: preview = ", ".join(repr(o) for o in oids[:3]) more = f" (+{len(oids) - 3} more)" if len(oids) > 3 else "" raise ValueError( f"{method}() requires a single chain but got {len(oids)} unique " f"orbit_ids: {preview}{more}. Filter to one chain first: " f'sens.select("orbit_id", "<orbit_id>").{method}(...)' ) def _require_single_obs_chain(table: ObservationSensitivities, method: str) -> None: """Guard for per-chain methods on :class:`ObservationSensitivities`. Raises :class:`ValueError` if the table contains more than one unique ``(orbit_id, obs_code)`` pair, with a hint to filter via chained ``select`` calls first. """ keys = table.chain_keys() if len(keys) > 1: preview = ", ".join(repr(k) for k in keys[:3]) more = f" (+{len(keys) - 3} more)" if len(keys) > 3 else "" raise ValueError( f"{method}() requires a single chain but got {len(keys)} unique " f"(orbit_id, obs_code) pairs: {preview}{more}. 
Filter to one " f"chain first: " f'obs.select("orbit_id", "<oid>").select("obs_code", "<code>").{method}(...)' ) def _propagate_cov_batch( stms: np.ndarray, # (k, 6, 6) stts: np.ndarray | None, # (k, 6, 6, 6) or None cov_in: np.ndarray, # (6, 6) *, order: int, ) -> tuple[np.ndarray, np.ndarray]: """Vectorized state-space covariance propagation. Returns ``(cov_out, delta_mu)`` with shapes ``(k, 6, 6)`` and ``(k, 6)``. """ cov_out = np.einsum("tij,jk,tlk->til", stms, cov_in, stms) k = stms.shape[0] delta_mu = np.zeros((k, 6), dtype=np.float64) if order == 2 and stts is not None: term1 = np.einsum("tkab,tlcd,ac,bd->tkl", stts, stts, cov_in, cov_in) term2 = np.einsum("tkab,tlcd,ad,bc->tkl", stts, stts, cov_in, cov_in) cov_out = cov_out + 0.5 * (term1 + term2) delta_mu = 0.5 * np.einsum("tkab,ab->tk", stts, cov_in) return cov_out, delta_mu def _kappa_batch( stts: np.ndarray, # (k, 6, 6, 6) cov_in: np.ndarray, # (6, 6) ) -> np.ndarray: """Per-epoch κ_t = ‖Ψ_t · Σ_0‖_F / 6.""" quad = np.einsum("tkab,ab->tk", stts, cov_in) return np.asarray(np.linalg.norm(quad, axis=1) / 6.0, dtype=np.float64) def _propagate_obs_cov_batch( jacs: np.ndarray, # (k, 6, n) hess: np.ndarray | None, # (k, 6, n, n) or None cov_in: np.ndarray, # (n, n) *, order: int, ) -> tuple[np.ndarray, np.ndarray]: """Vectorized observation-space covariance propagation.""" cov_out = np.einsum("tij,jk,tlk->til", jacs, cov_in, jacs) k = jacs.shape[0] delta_mu = np.zeros((k, 6), dtype=np.float64) if order == 2 and hess is not None: term1 = np.einsum("tkab,tlcd,ac,bd->tkl", hess, hess, cov_in, cov_in) term2 = np.einsum("tkab,tlcd,ad,bc->tkl", hess, hess, cov_in, cov_in) cov_out = cov_out + 0.5 * (term1 + term2) delta_mu = 0.5 * np.einsum("tkab,ab->tk", hess, cov_in) return cov_out, delta_mu