maps

The riemax.manifold.maps module provides factories for the exponential and log maps on a manifold. Both maps are built on the geodesic dynamics defined in riemax.manifold.geodesic.

Types:

type ExponentialMap = tp.Callable[[TangentSpace[jax.Array]], tuple[M[jax.Array], TangentSpace[jax.Array]]]
type LogMap[*Ts] = tp.Callable[[M[jax.Array], M[jax.Array], *Ts], tuple[TangentSpace[jax.Array], bool]]

Exponential Map

Suppose we have a smooth Riemannian manifold \(M\). Given a point \(p \in M\) and a tangent vector \(v \in T_p M\), there is a unique geodesic \(\gamma_v\) satisfying \(\gamma_v(0) = p\), \(\dot{\gamma}_v(0) = v\). Provided \(\gamma_v\) is defined on all of \([0, 1]\), the exponential map \(\exp_p : T_p M \rightarrow M\) is given by \(\exp_p(v) = \gamma_v(1)\).
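As a concrete illustration of this definition, the sketch below approximates \(\exp_p(v)\) on the unit sphere by integrating the geodesic equation \(\ddot{\gamma}^k + \Gamma^k_{ij} \dot{\gamma}^i \dot{\gamma}^j = 0\) from \(t = 0\) to \(t = 1\). It is plain Python rather than riemax code: the chart \((\theta, \phi)\), the hand-written Christoffel terms, the RK4 stepper and the names geodesic_rhs, rk4_step and exp_map are all assumptions made for the example.

```python
import math


def geodesic_rhs(state):
    """Geodesic equation on the unit sphere, metric diag(1, sin^2(theta))."""
    theta, phi, d_theta, d_phi = state
    dd_theta = math.sin(theta) * math.cos(theta) * d_phi**2
    dd_phi = -2.0 * (math.cos(theta) / math.sin(theta)) * d_theta * d_phi
    return (d_theta, d_phi, dd_theta, dd_phi)


def rk4_step(state, dt):
    """One classical Runge-Kutta step for the first-order geodesic system."""

    def shifted(k, scale):
        return tuple(s + scale * ki for s, ki in zip(state, k))

    k1 = geodesic_rhs(state)
    k2 = geodesic_rhs(shifted(k1, 0.5 * dt))
    k3 = geodesic_rhs(shifted(k2, 0.5 * dt))
    k4 = geodesic_rhs(shifted(k3, dt))
    return tuple(
        s + dt / 6.0 * (a + 2.0 * b + 2.0 * c + d)
        for s, a, b, c, d in zip(state, k1, k2, k3, k4)
    )


def exp_map(p, v, n_steps=100):
    """Approximate exp_p(v) by integrating the geodesic from t = 0 to t = 1."""
    state = (*p, *v)
    dt = 1.0 / n_steps
    for _ in range(n_steps):
        state = rk4_step(state, dt)
    # end point on the manifold, and the velocity carried along the geodesic
    return state[:2], state[2:]


p = (math.pi / 2.0, 0.0)  # a point on the equator
v = (0.0, 1.0)  # unit tangent vector pointing along the equator
q, v_end = exp_map(p, v)
print(q)
```

Starting on the equator with a unit vector along it, the end point stays on the equator and advances by one radian in \(\phi\), up to integration error.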

Log Map

Given two points \(p, q \in M\), the \(\log\) map provides the tangent vector which, upon application of the exponential map, transports one point to the other. The log map is the natural inverse of the exponential map, defined as \(\log_p(q) = v\) such that \(\exp_p(\log_p(q)) = q\). This mapping is not unique in general, as several tangent vectors may connect the same two points \(p, q\).
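The unit sphere embedded in \(\mathbb{R}^3\) has closed-form exponential and log maps, which makes both the inverse relationship and the non-uniqueness easy to check. The sketch below is plain Python and independent of riemax; the function names sphere_exp and sphere_log are hypothetical.

```python
import math


def norm(x):
    return math.sqrt(sum(xi * xi for xi in x))


def sphere_exp(p, v):
    """exp_p(v) on the unit sphere: follow the great circle through p along v."""
    r = norm(v)
    if r == 0.0:
        return p
    return tuple(math.cos(r) * a + math.sin(r) * b / r for a, b in zip(p, v))


def sphere_log(p, q):
    """Shortest tangent vector v at p with exp_p(v) = q (q not antipodal to p)."""
    cos_angle = max(-1.0, min(1.0, sum(a * b for a, b in zip(p, q))))
    angle = math.acos(cos_angle)
    # the component of q orthogonal to p gives the direction of travel
    direction = tuple(b - cos_angle * a for a, b in zip(p, q))
    d = norm(direction)
    if d == 0.0:
        return (0.0, 0.0, 0.0)
    return tuple(angle * di / d for di in direction)


p = (0.0, 0.0, 1.0)
q = (math.sin(0.5), 0.0, math.cos(0.5))

v = sphere_log(p, q)
q_back = sphere_exp(p, v)  # round trip: exp_p(log_p(q))

# A longer vector that wraps once more around the great circle reaches the same q.
scale = (norm(v) + 2.0 * math.pi) / norm(v)
v_long = tuple(scale * vi for vi in v)
q_long = sphere_exp(p, v_long)
```

Both v and v_long are valid answers to "which tangent vector at p reaches q"; only v corresponds to the shortest geodesic.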

Shooting Solver

Ordinarily, we would consider computing the \(\log\) map a two-point boundary value problem. We can approach this using a shooting method, posing the problem: find \(v \in T_p M\) such that \(\exp_p (v) = q\). We define the residual

\[ r(v) = \exp_p(v) - q, \]

and use a root-finding technique, such as Newton-Raphson, to obtain a solution. This is a principled approach, but it relies on a good initial guess for the solution.
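A minimal sketch of this shooting procedure, assuming the unit sphere in \((\theta, \phi)\) coordinates, an RK4-integrated exponential map and a forward-difference Jacobian. None of these choices are taken from riemax, and the names exp_map, residual and shooting_log are hypothetical. The residual is formed component-wise in the chart, as in the expression for \(r(v)\) above.

```python
import math


def geodesic_rhs(s):
    theta, phi, d_theta, d_phi = s
    return (
        d_theta,
        d_phi,
        math.sin(theta) * math.cos(theta) * d_phi**2,
        -2.0 * math.cos(theta) / math.sin(theta) * d_theta * d_phi,
    )


def exp_map(p, v, n_steps=100):
    """exp_p(v) on the unit sphere in (theta, phi) coordinates, via RK4."""
    s, dt = (*p, *v), 1.0 / n_steps
    for _ in range(n_steps):
        k1 = geodesic_rhs(s)
        k2 = geodesic_rhs(tuple(a + 0.5 * dt * b for a, b in zip(s, k1)))
        k3 = geodesic_rhs(tuple(a + 0.5 * dt * b for a, b in zip(s, k2)))
        k4 = geodesic_rhs(tuple(a + dt * b for a, b in zip(s, k3)))
        s = tuple(
            a + dt / 6.0 * (b + 2.0 * c + 2.0 * d + e)
            for a, b, c, d, e in zip(s, k1, k2, k3, k4)
        )
    return s[:2]


def residual(p, q, v):
    """r(v) = exp_p(v) - q, taken component-wise in the chart."""
    end = exp_map(p, v)
    return (end[0] - q[0], end[1] - q[1])


def shooting_log(p, q, tol=1e-10, max_iter=20, eps=1e-6):
    v = (q[0] - p[0], q[1] - p[1])  # initial guess: straight line in the chart
    for _ in range(max_iter):
        r = residual(p, q, v)
        if math.hypot(*r) < tol:
            return v, True
        # forward-difference Jacobian, J[i][j] = d r_i / d v_j
        r0 = residual(p, q, (v[0] + eps, v[1]))
        r1 = residual(p, q, (v[0], v[1] + eps))
        j00, j10 = (r0[0] - r[0]) / eps, (r0[1] - r[1]) / eps
        j01, j11 = (r1[0] - r[0]) / eps, (r1[1] - r[1]) / eps
        det = j00 * j11 - j01 * j10
        # Newton update v <- v - J^{-1} r, using the explicit 2x2 inverse
        v = (
            v[0] - (j11 * r[0] - j01 * r[1]) / det,
            v[1] - (-j10 * r[0] + j00 * r[1]) / det,
        )
    return v, False


p = (math.pi / 2.0, 0.0)
q = (math.pi / 2.0 - 0.3, 0.4)
v, converged = shooting_log(p, q)
```

The two points here are close together, so the chart difference \(q - p\) is a reasonable initial guess. For distant points the same iteration can diverge or settle on a longer geodesic.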

riemax.manifold.maps

riemax.manifold.maps.exponential_map_factory(integrator: Integrator[TangentSpace[jax.Array]], dt: float, metric: MetricFn, n_steps: int | None = None) -> ExponentialMap

Produce an exponential map, \(\exp: TM \rightarrow M\).

Example:

# ...

exp_map = exponential_map_factory(riemax.numerical.integrators.odeint, dt=1e-3, metric=fn_metric)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| integrator | riemax.numerical.integrators.Integrator[riemax.manifold.types.TangentSpace[jax.Array]] | choice of integrator used to propagate dynamics | required |
| dt | float | time-step for the integration | required |
| metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
| n_steps | int \| None | number of steps to integrate for | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| exp_map | ExponentialMap | function for computing exponential map |

Source code in src/riemax/manifold/maps.py
def exponential_map_factory(
    integrator: Integrator[TangentSpace[jax.Array]], dt: float, metric: MetricFn, n_steps: int | None = None
) -> ExponentialMap:
    r"""Produce an exponential map, $\exp: TM \rightarrow M$.

    !!! note "Example:"

        ```python
        # ...

        exp_map = exponential_map_factory(riemax.numerical.integrators.odeint, dt=1e-3, metric=fn_metric)
        ```

    Parameters:
        integrator: choice of integrator used to propagate dynamics
        dt: time-step for the integration
        metric: function defining the metric tensor on the manifold
        n_steps: number of steps to integrate for

    Returns:
        exp_map: function for computing exponential map
    """

    if not n_steps:
        n_steps = int(1.0 // dt)

    dynamics = jtu.Partial(geodesic_dynamics, metric=metric)
    ivp_params = ParametersIVP(differential_operator=dynamics, dt=dt, n_steps=n_steps)

    @_integrator_to_exp
    def exp_map(state: TangentSpace[jax.Array]) -> tuple[TangentSpace[jax.Array], TangentSpace[jax.Array]]:
        return integrator(ivp_params, state)

    return exp_map
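
When n_steps is omitted, the source above sets it to int(1.0 // dt). Floor division on floats can come out one step short when dt is not exactly representable in binary, in which case the integration stops just before \(t = 1\). The sketch below reproduces that default next to a rounding-based alternative. The helper names are hypothetical and the alternative is only a suggestion, not what riemax does. Passing n_steps explicitly avoids the issue.

```python
def default_n_steps(dt: float) -> int:
    """Mirror of the default used when n_steps is not given."""
    return int(1.0 // dt)


def rounded_n_steps(dt: float) -> int:
    """Hypothetical alternative: round the quotient to the nearest integer."""
    return round(1.0 / dt)


for dt in (0.25, 0.1, 1e-3):
    n_floor = default_n_steps(dt)
    n_round = rounded_n_steps(dt)
    # final integration time reached with each choice of step count
    print(dt, n_floor, n_floor * dt, n_round, n_round * dt)
```

With dt = 0.1, for instance, the floor-division default yields 9 steps, so the geodesic is integrated only up to about \(t = 0.9\).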

riemax.manifold.maps.symplectic_exponential_map_factory(integrator: LagrangianSymplecticIntegrator, dt: float, omega: float, metric: MetricFn, n_steps: int | None = None) -> ExponentialMap

Produce an exponential map, \(\exp: TM \rightarrow M\), using symplectic dynamics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| integrator | riemax.manifold.symplectic.LagrangianSymplecticIntegrator | choice of Lagrangian symplectic integrator used to propagate dynamics | required |
| dt | float | time-step for the integration | required |
| omega | float | strength of the constraint between the phase-split copies | required |
| metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
| n_steps | int \| None | number of steps to integrate for | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| exp_map | ExponentialMap | function for computing exponential map |

Source code in src/riemax/manifold/maps.py
def symplectic_exponential_map_factory(
    integrator: LagrangianSymplecticIntegrator, dt: float, omega: float, metric: MetricFn, n_steps: int | None = None
) -> ExponentialMap:
    r"""Produce an exponential map, $\exp: TM \rightarrow M$, using symplectic dynamics.

    Parameters:
        integrator: choice of Lagrangian symplectic integrator used to propagate dynamics
        dt: time-step for the integration
        omega: strength of the constraint between the phase-split copies
        metric: function defining the metric tensor on the manifold
        n_steps: number of steps to integrate for

    Returns:
        exp_map: function for computing exponential map
    """

    if not n_steps:
        n_steps = int(1.0 // dt)

    symplectic_params = SymplecticParams(metric=metric, dt=dt, omega=omega, n_steps=n_steps)

    @_integrator_to_exp
    def exp_map(state: TangentSpace[jax.Array]) -> tuple[TangentSpace[jax.Array], TangentSpace[jax.Array]]:
        return integrator(symplectic_params, state)

    return exp_map

riemax.manifold.maps.shooting_log_map_factory(exp_map: ExponentialMap, nr_parameters: NewtonRaphsonParams | None = None) -> LogMap

Produce log map, computed using a shooting method.

Efficacy of Shooting Solvers:

Shooting solvers typically require a good initial guess. If the initial guess for the velocity vector is too far from a true solution, the solver may fail to converge. Note also that this does not guarantee the velocity vector of the globally length-minimising geodesic, only that of a valid geodesic connecting the two points.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| exp_map | ExponentialMap | function used to compute the exponential map | required |
| nr_parameters | riemax.numerical.newton_raphson.NewtonRaphsonParams \| None | parameters used in the Newton-Raphson optimisation | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| log_map | LogMap | function to compute the log map between \(p, q \in M\) |

Source code in src/riemax/manifold/maps.py
def shooting_log_map_factory(exp_map: ExponentialMap, nr_parameters: NewtonRaphsonParams | None = None) -> LogMap:
    r"""Produce log map, computed using a shooting method.

    !!! note "Efficacy of Shooting Solvers:"

        Shooting solvers typically require a good initial guess. If the initial guess for the velocity vector is too far
        from a true solution, the solver may fail to converge. Note also that this does not guarantee the velocity
        vector of the globally length-minimising geodesic, only that of a valid geodesic connecting the two points.

    Parameters:
        exp_map: function used to compute the exponential map
        nr_parameters: parameters used in the Newton-Raphson optimisation

    Returns:
        log_map: function to compute the log map between $p, q \in M$
    """

    def log_map(
        p: TangentSpace[jax.Array] | M[jax.Array],
        q: M[jax.Array],
    ) -> tuple[TangentSpace[jax.Array], bool]:
        """Compute the log map between points p and q.

        Parameters:
            p: origin point on the manifold; if a TangentSpace is given, its vector is used as the initial guess
                for the tangent vector, otherwise q - p is used
            q: destination point on the manifold

        Returns:
            state which, when the exponential map is taken at p, yields q
        """

        if not isinstance(p, TangentSpace):
            p = TangentSpace(point=p, vector=(q - p))

        def shooting_residual(p_dot: TpM[jax.Array]) -> TpM[jax.Array]:
            initial_state = TangentSpace[jax.Array](point=p.point, vector=p_dot)
            point, _ = exp_map(initial_state)

            return point - q

        # root-finding for shooting residual
        p_dot, newton_convergence_state = newton_raphson(
            shooting_residual, initial_guess=p.vector, nr_parameters=nr_parameters
        )

        initial_condition = TangentSpace[jax.Array](point=p.point, vector=p_dot)

        return initial_condition, newton_convergence_state.converged

    return log_map

riemax.manifold.maps.minimising_log_map_factory(metric: MetricFn, optimiser: optax.GradientTransformation, num_nodes: int = 20, n_collocation: int = 100, iterations: int = 100, tol: float = 0.0001) -> LogMap

Produce a log-map using an energy-minimising approach.
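
For context, the quantity usually minimised in approaches of this kind is the curve energy below. That riemax discretises exactly this functional, using the cubic spline and collocation points listed under Parameters, is an assumption rather than something stated on this page.

\[ E[\gamma] = \frac{1}{2} \int_0^1 g_{\gamma(t)}\big(\dot{\gamma}(t), \dot{\gamma}(t)\big) \, \mathrm{d}t, \qquad \gamma(0) = p, \quad \gamma(1) = q. \]

Critical points of \(E\) are geodesics parameterised with constant speed, and the log map is then read off from the initial velocity, \(\log_p(q) = \dot{\gamma}(0)\).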

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
| optimiser | optax.GradientTransformation | optimiser to use to minimise energy of the curve | required |
| num_nodes | int | number of nodes to use to parameterise cubic spline | 20 |
| n_collocation | int | number of points at which to evaluate energy along the curve | 100 |
| iterations | int | number of iterations to optimise for | 100 |
| tol | float | tolerance for measuring convergence | 0.0001 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| log_map | LogMap | function to compute the log map between \(p, q \in M\) |

Source code in src/riemax/manifold/maps.py
def minimising_log_map_factory(
    metric: MetricFn,
    optimiser: optax.GradientTransformation,
    num_nodes: int = 20,
    n_collocation: int = 100,
    iterations: int = 100,
    tol: float = 1e-4,
) -> LogMap:
    r"""Produce a log-map using an energy-minimising approach.

    Parameters:
        metric: function defining the metric tensor on the manifold
        optimiser: optimiser to use to minimise energy of the curve
        num_nodes: number of nodes to use to parameterise cubic spline
        n_collocation: number of points at which to evaluate energy along the curve
        iterations: number of iterations to optimise for
        tol: tolerance for measuring convergence

    Returns:
        log_map: function to compute the log map between $p, q \in M$
    """

    def log_map(p: TangentSpace[jax.Array] | M[jax.Array], q: M[jax.Array]) -> tuple[TangentSpace[jax.Array], bool]:
        """Compute the log map between points p and q.

        Parameters:
            p: origin point on the manifold
            q: destination point on the manifold

        Returns:
            state which, when the exponential map is taken at p, yields q
        """

        if isinstance(p, TangentSpace):
            p = p.point

        geodesic, converged = minimising_geodesic(
            p=p,
            q=q,
            metric=metric,
            optimiser=optimiser,
            num_nodes=num_nodes,
            n_collocation=n_collocation,
            iterations=iterations,
            tol=tol,
        )

        return TangentSpace(point=geodesic.point[0], vector=geodesic.vector[0]), converged

    return log_map

riemax.manifold.maps.scipy_bvp_log_map_factory(metric: MetricFn, n_collocation: int = 100, explicit_jacobian: bool = False, tol: float = 0.0001) -> LogMap

Produce a log-map using SciPy's solve_bvp.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
| n_collocation | int | number of points at which to evaluate energy along the curve | 100 |
| explicit_jacobian | bool | whether to use the Jacobian computed by JAX | False |
| tol | float | tolerance for measuring convergence | 0.0001 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| log_map | LogMap | function to compute the log map between \(p, q \in M\) |

Source code in src/riemax/manifold/maps.py
def scipy_bvp_log_map_factory(
    metric: MetricFn, n_collocation: int = 100, explicit_jacobian: bool = False, tol: float = 1e-4
) -> LogMap:
    r"""Produce a log-map using SciPy's solve_bvp.

    Parameters:
        metric: function defining the metric tensor on the manifold
        n_collocation: number of points at which to evaluate energy along the curve
        explicit_jacobian: whether to use the Jacobian computed by JAX
        tol: tolerance for measuring convergence

    Returns:
        log_map: function to compute the log map between $p, q \in M$
    """

    def log_map(p: TangentSpace[jax.Array] | M[jax.Array], q: M[jax.Array]) -> tuple[TangentSpace[jax.Array], bool]:
        """Compute the log map between points p and q.

        Parameters:
            p: origin point on the manifold
            q: destination point on the manifold

        Returns:
            state which, when the exponential map is taken at p, yields q
        """

        if isinstance(p, TangentSpace):
            p = p.point

        geodesic, converged = scipy_bvp_geodesic(
            p=p, q=q, metric=metric, n_collocation=n_collocation, explicit_jacobian=explicit_jacobian, tol=tol
        )

        return TangentSpace(point=geodesic.point[0], vector=geodesic.vector[0]), converged

    return log_map