
symplectic

The geodesic equation is usually solved in its classic Lagrangian form; however, it also admits a Hamiltonian formulation. This Hamiltonian is non-separable: the metric couples position and momentum in the kinetic term. As a result, standard explicit symplectic integrators, which rely on splitting the Hamiltonian into parts that depend on position or momentum alone, cannot be applied directly. A more recent approach integrates non-separable Hamiltonians by doubling the phase space and evolving two copies of the state in parallel [1, 2, 3]. Riemax implements this approach for geodesic dynamics up to arbitrary even orders of integration.

Additional documentation required.

This documentation requires further explanation of the method for integration of non-separable Hamiltonians. This page will be updated in future iterations of the documentation to make the process as clear as possible.
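For reference, the construction from Tao [2] that the phase-doubling approach builds on can be summarised as follows. A second copy (x, y) of the phase space is introduced, and the extended Hamiltonian is split into three pieces, each of which can be integrated exactly. H_A leaves q and y fixed, H_B leaves x and p fixed, and H_C, scaled by the coupling strength omega, rotates the differences between the two copies. Higher even orders follow from Yoshida's triple jump [3]. The equations below are taken from those papers. The order in which second_order_dynamics and construct_nth_order_dynamics apply their maps matches this scheme, but the step-size split inside riemax's private phase maps is not shown on this page.

```latex
\begin{aligned}
H(q, p) &= \tfrac{1}{2}\, g^{ij}(q)\, p_i\, p_j, \\
\bar{H}(q, p, x, y) &= H_A + H_B + \omega H_C, \qquad H_A = H(q, y), \quad H_B = H(x, p), \\
H_C &= \tfrac{1}{2}\left( \lVert q - x \rVert_2^2 + \lVert p - y \rVert_2^2 \right), \\
\Phi^{(2)}_{\delta} &= \phi^{\delta/2}_{H_A} \circ \phi^{\delta/2}_{H_B} \circ \phi^{\delta}_{\omega H_C} \circ \phi^{\delta/2}_{H_B} \circ \phi^{\delta/2}_{H_A}, \\
\Phi^{(2k+2)}_{\delta} &= \Phi^{(2k)}_{z_1 \delta} \circ \Phi^{(2k)}_{z_0 \delta} \circ \Phi^{(2k)}_{z_1 \delta}, \qquad
z_1 = \frac{1}{2 - 2^{1/(2k+1)}}, \quad z_0 = -\frac{2^{1/(2k+1)}}{2 - 2^{1/(2k+1)}}.
\end{aligned}
```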

riemax.manifold.symplectic

riemax.manifold.symplectic.SymplecticGeodesicState

Bases: typing.NamedTuple

PyTree for Symplectic Geodesic State.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | jax.Array | position on the geodesic | required |
| p | jax.Array | conjugate momenta on the cotangent space | required |
Source code in src/riemax/manifold/symplectic.py
class SymplecticGeodesicState(tp.NamedTuple):

    """PyTree for Symplectic Geodesic State.

    Parameters:
        q: position on the geodesic
        p: conjugate momenta on the co-tangent space
    """

    q: jax.Array
    p: jax.Array

    @classmethod
    def from_lagrangian(cls, state: TangentSpace[jax.Array], metric: MetricFn) -> SymplecticGeodesicState:
        """Build Hamiltonian symplectic state from given Lagrangian state.

        Parameters:
            state: Geodesic state in Lagrangian coordinates.
            metric: Function used to evaluate the metric.

        Returns:
            Symplectic state in Hamiltonian coordinates.
        """

        conjugate_momenta = jnp.einsum('...ij, ...j -> ...i', metric(state.point), state.vector)
        return cls(q=state.point, p=conjugate_momenta)

    def to_lagrangian(self, metric: MetricFn) -> TangentSpace[jax.Array]:
        """Convert intrinsic Hamiltonian coordinates to Lagrangian coordinates.

        Parameters:
            metric: Function used to evaluate the metric.

        Returns:
            Geodesic state in Lagrangian coordinates.
        """

        velocity = jnp.einsum('...ij, ...j -> ...i', contravariant_metric_tensor(self.q, metric), self.p)
        return TangentSpace(point=self.q, vector=velocity)
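The two conversions above are the Legendre map and its inverse. from_lagrangian lowers the index of the velocity with the metric, p_i = g_ij v^j, and to_lagrangian raises it again with the contravariant (inverse) metric, v^i = g^ij p_j. The standalone sketch below reproduces both einsum contractions outside riemax so that the round trip can be checked. polar_metric, lower_index and raise_index are hypothetical names, and jnp.linalg.inv stands in for riemax's contravariant_metric_tensor.

```python
import jax.numpy as jnp


def polar_metric(q):
    # hypothetical example metric: the flat plane in polar coordinates (r, theta),
    # g = diag(1, r^2)
    r = q[0]
    return jnp.diag(jnp.stack([jnp.ones_like(r), r ** 2]))


def lower_index(q, v, metric):
    # Lagrangian -> Hamiltonian: p_i = g_ij(q) v^j (same contraction as from_lagrangian)
    return jnp.einsum('...ij, ...j -> ...i', metric(q), v)


def raise_index(q, p, metric):
    # Hamiltonian -> Lagrangian: v^i = g^ij(q) p_j, with g^ij taken as the matrix inverse
    return jnp.einsum('...ij, ...j -> ...i', jnp.linalg.inv(metric(q)), p)


q = jnp.array([2.0, 0.3])  # r = 2, theta = 0.3
v = jnp.array([0.5, 1.0])  # (dr/dt, dtheta/dt)

p = lower_index(q, v, polar_metric)       # (0.5, 4.0), since g = diag(1, 4) at r = 2
v_back = raise_index(q, p, polar_metric)  # recovers v up to floating-point error
```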
from_lagrangian(state: TangentSpace[jax.Array], metric: MetricFn) -> SymplecticGeodesicState (classmethod)

Build Hamiltonian symplectic state from given Lagrangian state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | riemax.manifold.types.TangentSpace[jax.Array] | Geodesic state in Lagrangian coordinates. | required |
| metric | riemax.manifold.types.MetricFn | Function used to evaluate the metric. | required |

Returns:

| Type | Description |
| --- | --- |
| riemax.manifold.symplectic.SymplecticGeodesicState | Symplectic state in Hamiltonian coordinates. |

to_lagrangian(metric: MetricFn) -> TangentSpace[jax.Array]

Convert intrinsic Hamiltonian coordinates to Lagrangian coordinates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metric | riemax.manifold.types.MetricFn | Function used to evaluate the metric. | required |

Returns:

| Type | Description |
| --- | --- |
| riemax.manifold.types.TangentSpace[jax.Array] | Geodesic state in Lagrangian coordinates. |


riemax.manifold.symplectic.PhaseDoubledSymplecticGeodesicState

Bases: typing.NamedTuple

PyTree for the phase-doubled Symplectic Geodesic State.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | jax.Array | position on the geodesic | required |
| p | jax.Array | conjugate momenta on the cotangent space | required |
| x | jax.Array | phase-doubled position on the geodesic | required |
| y | jax.Array | phase-doubled conjugate momenta on the cotangent space | required |
Source code in src/riemax/manifold/symplectic.py
class PhaseDoubledSymplecticGeodesicState(tp.NamedTuple):

    """PyTree for the phase-doubled Symplectic Geodesic State

    Parameters:
        q: position on the geodesic
        p: conjugate momenta on the co-tangent space
        x: phase-doubled position on the geodesic
        y: phase-doubled conjugate momenta on the co-tangent space
    """

    q: jax.Array
    p: jax.Array
    x: jax.Array
    y: jax.Array

    @classmethod
    def from_symplectic(cls, state: SymplecticGeodesicState) -> PhaseDoubledSymplecticGeodesicState:
        """Build phase-doubled symplectic state from given symplectic state.

        Parameters:
            state: Symplectic state in a single phase-space.

        Returns:
            Phase-doubled symplectic state, replicating state in new phase-space.
        """

        return cls(q=state.q, p=state.p, x=state.q, y=state.p)

    def to_symplectic(self) -> SymplecticGeodesicState:
        """Transform phase-doubled symplectic state to a single-phase symplectic state.

        Returns:
            Single-phase symplectic state -- removing phase-doubling.
        """

        return SymplecticGeodesicState(q=self.q, p=self.p)
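Phase doubling starts the second copy (x, y) as an exact replica of (q, p), and to_symplectic simply discards it. When tuning omega, which sets the strength of the constraint between the phase-split copies, it can help to monitor how far the two copies drift apart. The sketch below uses a stand-in NamedTuple with the same four fields and a hypothetical copy_separation helper; neither is part of riemax.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DoubledState(NamedTuple):
    # stand-in with the same fields as PhaseDoubledSymplecticGeodesicState
    q: jax.Array
    p: jax.Array
    x: jax.Array
    y: jax.Array


def from_symplectic(q, p):
    # phase doubling: the second copy (x, y) starts as an exact replica of (q, p)
    return DoubledState(q=q, p=p, x=q, y=p)


def copy_separation(state):
    # hypothetical diagnostic: Euclidean distance between (q, p) and (x, y)
    return jnp.sqrt(jnp.sum((state.q - state.x) ** 2) + jnp.sum((state.p - state.y) ** 2))


doubled = from_symplectic(jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
fresh_gap = copy_separation(doubled)  # 0.0: the copies coincide at the start

drifted = doubled._replace(x=doubled.x + jnp.array([0.3, 0.4]))
drifted_gap = copy_separation(drifted)  # 0.5
```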
from_symplectic(state: SymplecticGeodesicState) -> PhaseDoubledSymplecticGeodesicState (classmethod)

Build phase-doubled symplectic state from given symplectic state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | riemax.manifold.symplectic.SymplecticGeodesicState | Symplectic state in a single phase-space. | required |

Returns:

| Type | Description |
| --- | --- |
| riemax.manifold.symplectic.PhaseDoubledSymplecticGeodesicState | Phase-doubled symplectic state, replicating state in new phase-space. |

to_symplectic() -> SymplecticGeodesicState

Transform phase-doubled symplectic state to a single-phase symplectic state.

Returns:

| Type | Description |
| --- | --- |
| riemax.manifold.symplectic.SymplecticGeodesicState | Single-phase symplectic state, with phase-doubling removed. |


riemax.manifold.symplectic.SymplecticParams

Bases: typing.NamedTuple

Container for parameters of symplectic integration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
| dt | float | time-step used for the integration | 1e-3 |
| omega | float | strength of the constraint between the phase-split copies | 1e-2 |
| n_steps | int | number of steps to integrate for | 1000 |
Source code in src/riemax/manifold/symplectic.py
class SymplecticParams(tp.NamedTuple):

    """Container for parameters of symplectic integration.

    Parameters:
        metric: function defining the metric tensor on the manifold
        dt: time-step used for the integration
        omega: strength of the constraint between the phase-split copies
        n_steps: number of steps to integrate for
    """

    metric: MetricFn
    dt: float = 1e-3
    omega: float = 1e-2
    n_steps: int = int(1e3)
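Because SymplecticParams is a NamedTuple, its defaults can be overridden with the standard _replace method. Assuming each update advances the state by one time-step dt, the total integration time is dt * n_steps, which is 1.0 with the defaults. The sketch below uses a stand-in NamedTuple with the same fields and defaults so that it runs without riemax; span and refine are hypothetical helpers.

```python
from typing import Callable, NamedTuple


class Params(NamedTuple):
    # stand-in with the same fields and defaults as SymplecticParams
    metric: Callable
    dt: float = 1e-3
    omega: float = 1e-2
    n_steps: int = int(1e3)


def span(params: Params) -> float:
    # total integrated time, assuming each update advances the state by dt
    return params.dt * params.n_steps


def refine(params: Params, factor: int) -> Params:
    # hypothetical helper: shrink dt by `factor` while keeping the same end time
    return params._replace(dt=params.dt / factor, n_steps=params.n_steps * factor)


# placeholder 2D Euclidean metric; it is never evaluated in this sketch
defaults = Params(metric=lambda q: [[1.0, 0.0], [0.0, 1.0]])
finer = refine(defaults, 4)  # dt = 2.5e-4, n_steps = 4000, omega unchanged
```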

riemax.manifold.symplectic.hamiltonian(q: jax.Array, conjugate_momenta: jax.Array, metric: MetricFn) -> jax.Array

Computes the Hamiltonian of the state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | jax.Array | position on the geodesic | required |
| conjugate_momenta | jax.Array | conjugate momenta of the geodesic path | required |
| metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |

Returns:

| Type | Description |
| --- | --- |
| jax.Array | Hamiltonian of the geodesic |

Source code in src/riemax/manifold/symplectic.py
def hamiltonian(q: jax.Array, conjugate_momenta: jax.Array, metric: MetricFn) -> jax.Array:
    """Computes the Hamiltonian of the state.

    Parameters:
        q: position on the geodesic
        conjugate_momenta: conjugate momenta of the geodesic path
        metric: function defining the metric tensor on the manifold

    Returns:
        Hamiltonian of the geodesic
    """

    fn_contra_gx = jtu.Partial(contravariant_metric_tensor, metric=metric)
    return 0.5 * jnp.einsum('ij, i, j -> ', fn_contra_gx(q), conjugate_momenta, conjugate_momenta)
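As a quick check of the formula above, the Hamiltonian evaluated at p_i = g_ij v^j equals the kinetic energy 1/2 g_ij v^i v^j. Its gradient with respect to the momenta also recovers the velocity, which is Hamilton's equation dq/dt = dH/dp. The sketch below is standalone: polar_metric is a hypothetical example metric, and jnp.linalg.inv is used in place of riemax's contravariant_metric_tensor.

```python
import jax
import jax.numpy as jnp


def polar_metric(q):
    # hypothetical example metric: flat plane in polar coordinates, g = diag(1, r^2)
    r = q[0]
    return jnp.diag(jnp.stack([jnp.ones_like(r), r ** 2]))


def hamiltonian(q, conjugate_momenta, metric):
    # H = 1/2 g^{ij} p_i p_j, with g^{ij} taken as the matrix inverse of g_{ij}
    contra_g = jnp.linalg.inv(metric(q))
    return 0.5 * jnp.einsum('ij, i, j -> ', contra_g, conjugate_momenta, conjugate_momenta)


q = jnp.array([2.0, 0.3])
v = jnp.array([0.5, 1.0])
p = polar_metric(q) @ v  # Legendre map p_i = g_ij v^j

energy = hamiltonian(q, p, polar_metric)  # 2.125
kinetic = 0.5 * v @ polar_metric(q) @ v   # 2.125, the Lagrangian kinetic energy
# dH/dp = g^{ij} p_j, which should give back the velocity (0.5, 1.0)
velocity = jax.grad(lambda mom: hamiltonian(q, mom, polar_metric))(p)
```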

riemax.manifold.symplectic.second_order_dynamics(state: PhaseDoubledSymplecticGeodesicState, dt: float, omega: float, metric: MetricFn) -> PhaseDoubledSymplecticGeodesicState

Conduct time-step using second-order dynamics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | riemax.manifold.symplectic.PhaseDoubledSymplecticGeodesicState | current state of the symplectic integrator | required |
| dt | float | time-step used for the integration | required |
| omega | float | strength of the constraint between the phase-split copies | required |
| metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| state | riemax.manifold.symplectic.PhaseDoubledSymplecticGeodesicState | time-stepped, phase-doubled symplectic geodesic state |

Source code in src/riemax/manifold/symplectic.py
def second_order_dynamics(
    state: PhaseDoubledSymplecticGeodesicState, dt: float, omega: float, metric: MetricFn
) -> PhaseDoubledSymplecticGeodesicState:
    """Conduct time-step using second-order dynamics.

    Parameters:
        state: current state of the symplectic integrator
        dt: time-step used for the integration
        omega: strength of the constraint between the phase-split copies
        metric: function defining the metric tensor on the manifold

    Returns:
        state: time-stepped, phase-doubled symplectic geodesic state
    """

    fn_phi_ha = jtu.Partial(_phi_ha, dt=dt, metric=metric)
    fn_phi_hb = jtu.Partial(_phi_hb, dt=dt, metric=metric)
    fn_phi_hc = jtu.Partial(_phi_hc, dt=dt, omega=omega)

    state = fn_phi_ha(state=state)
    state = fn_phi_hb(state=state)
    state = fn_phi_hc(state=state)
    state = fn_phi_hb(state=state)
    state = fn_phi_ha(state=state)

    return state
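The composition above follows the second-order scheme of Tao [2]: the H_A, H_B, omega * H_C, H_B and H_A flows are applied in sequence. The phase maps _phi_ha, _phi_hb and _phi_hc are private and not shown on this page, so the sketch below is a standalone reimplementation of the flows from Tao's paper rather than riemax's code. It assumes that the four outer maps each advance by dt / 2 and the coupling map by dt. phi_ha, phi_hb, phi_hc, second_order_step and polar_metric are hypothetical names. The usage integrates a geodesic of the flat plane in polar coordinates from (r, theta) = (1, 0) with velocity (0, 1). This is the straight line from (1, 0) to (1, 1) in Cartesian coordinates, so at t = 1 the state should sit near r = sqrt(2), theta = pi / 4.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    # stand-in for PhaseDoubledSymplecticGeodesicState
    q: jax.Array
    p: jax.Array
    x: jax.Array
    y: jax.Array


def polar_metric(q):
    # hypothetical example metric: flat plane in polar coordinates (r, theta)
    r = q[0]
    return jnp.diag(jnp.stack([jnp.ones_like(r), r ** 2]))


def hamiltonian(q, p):
    # H = 1/2 g^{ij} p_i p_j, using a linear solve instead of an explicit inverse
    return 0.5 * p @ jnp.linalg.solve(polar_metric(q), p)


dH_dq = jax.grad(hamiltonian, argnums=0)
dH_dp = jax.grad(hamiltonian, argnums=1)


def phi_ha(s, delta):
    # exact flow of H_A = H(q, y): q and y are frozen, p and x move
    return s._replace(p=s.p - delta * dH_dq(s.q, s.y), x=s.x + delta * dH_dp(s.q, s.y))


def phi_hb(s, delta):
    # exact flow of H_B = H(x, p): x and p are frozen, q and y move
    return s._replace(q=s.q + delta * dH_dp(s.x, s.p), y=s.y - delta * dH_dq(s.x, s.p))


def phi_hc(s, delta, omega):
    # exact flow of omega * H_C: sums are conserved, differences rotate by 2 * omega * delta
    c, sn = jnp.cos(2.0 * omega * delta), jnp.sin(2.0 * omega * delta)
    dq, dp = s.q - s.x, s.p - s.y
    dq, dp = c * dq + sn * dp, -sn * dq + c * dp
    sq, sp = s.q + s.x, s.p + s.y
    return State(q=(sq + dq) / 2, p=(sp + dp) / 2, x=(sq - dq) / 2, y=(sp - dp) / 2)


def second_order_step(s, dt, omega):
    # symmetric composition A, B, C, B, A with half steps on the outer maps
    s = phi_ha(s, dt / 2)
    s = phi_hb(s, dt / 2)
    s = phi_hc(s, dt, omega)
    s = phi_hb(s, dt / 2)
    return phi_ha(s, dt / 2)


# start at r = 1, theta = 0 with velocity (dr, dtheta) = (0, 1), so p = g v = (0, 1)
q0 = jnp.array([1.0, 0.0])
p0 = jnp.array([0.0, 1.0])
state = State(q=q0, p=p0, x=q0, y=p0)

step = jax.jit(functools.partial(second_order_step, dt=0.01, omega=0.01))
for _ in range(100):  # integrate to t = 1
    state = step(state)

final_state = state
```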

riemax.manifold.symplectic.construct_nth_order_dynamics(n: int)

Construct nth order symplectic dynamics.

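Recursive definition:

The nth-order map is built recursively. It composes three copies of the (n-2)th-order map with time-steps z1 * dt, z0 * dt and z1 * dt (Yoshida's triple-jump coefficients), bottoming out at second_order_dynamics when n = 2. One nth-order step therefore makes 3^((n-2)/2) calls to second_order_dynamics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n | int | order of integration to produce; must be even and at least 2 | required |

Returns:

| Type | Description |
| --- | --- |
|  | function to compute nth-order dynamics, with the same signature as second_order_dynamics |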

Source code in src/riemax/manifold/symplectic.py
def construct_nth_order_dynamics(n: int):
    """Construct nth order symplectic dynamics.

    !!! note "Recursive definition:"
        Function works recursively, producing additional phase-maps as required.

    Parameters:
        n: order of integration to produce

    Returns:
        function to compute nth order dynamics
    """

    # n <= 0 would never reach the n == 2 base case and recurse without end
    if n < 2 or n % 2 != 0:
        raise ValueError('n must be an even integer of at least 2.')

    @ft.lru_cache()
    def _construct(n: int):
        if n == 2:
            return second_order_dynamics

        def nth_order_dynamics(state, dt, omega, metric):
            _n = (n - 2) // 2
            z0, z1 = _yoshida_triple_jump_constants(n=_n)

            nmt_dynamics = _construct(n - 2)
            fn_phi_a = jtu.Partial(nmt_dynamics, dt=(z1 * dt), omega=omega, metric=metric)
            fn_phi_b = jtu.Partial(nmt_dynamics, dt=(z0 * dt), omega=omega, metric=metric)

            state = fn_phi_a(state=state)
            state = fn_phi_b(state=state)
            state = fn_phi_a(state=state)

            return state

        return nth_order_dynamics

    return _construct(n)
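The recursion can be followed with plain arithmetic. Yoshida's triple jump [3] raises a symmetric map from order 2k to order 2k + 2 by composing three steps of sizes z1 * dt, z0 * dt and z1 * dt, where z1 = 1 / (2 - 2^(1/(2k+1))) and z0 = -2^(1/(2k+1)) / (2 - 2^(1/(2k+1))). Because z0 is negative, the middle sub-step runs backwards in time, and the three sizes sum to dt. The private _yoshida_triple_jump_constants is not shown on this page. The sketch below assumes that its argument _n = (n - 2) // 2 plays the role of k. yoshida_triple_jump_constants and second_order_substeps are hypothetical re-derivations from Yoshida's paper that list the sub-step sizes eventually handed to second_order_dynamics.

```python
def yoshida_triple_jump_constants(k: int) -> tuple[float, float]:
    # coefficients raising a symmetric order-2k map to order 2k + 2 (Yoshida, 1990)
    root = 2.0 ** (1.0 / (2 * k + 1))
    z1 = 1.0 / (2.0 - root)
    z0 = -root / (2.0 - root)
    return z0, z1


def second_order_substeps(n: int, dt: float) -> list[float]:
    # time-step sizes passed to the order-2 map, in call order, for one order-n step
    if n == 2:
        return [dt]
    z0, z1 = yoshida_triple_jump_constants((n - 2) // 2)
    outer = second_order_substeps(n - 2, z1 * dt)
    inner = second_order_substeps(n - 2, z0 * dt)
    return outer + inner + outer


fourth = second_order_substeps(4, 0.1)  # [z1 * 0.1, z0 * 0.1, z1 * 0.1]
sixth = second_order_substeps(6, 0.1)   # 3 ** 2 = 9 sub-steps
```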

riemax.manifold.symplectic.construct_nth_order_symplectic_integrator(n: int) -> PhaseDoubledSymplecticIntegrator

Construct symplectic integrator of the nth order.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n | int | order of integration required | required |

Returns:

| Type | Description |
| --- | --- |
| riemax.manifold.symplectic.PhaseDoubledSymplecticIntegrator | integrator for phase-doubled symplectic state |

Source code in src/riemax/manifold/symplectic.py
def construct_nth_order_symplectic_integrator(n: int) -> PhaseDoubledSymplecticIntegrator:
    """Construct symplectic integrator of the nth order.

    Parameters:
        n: order of integration required

    Returns:
        integrator for phase-doubled symplectic state
    """

    def _nth_order_symplectic_integrator(
        symplectic_params: SymplecticParams, initial_state: PhaseDoubledSymplecticGeodesicState
    ) -> tuple[PhaseDoubledSymplecticGeodesicState, PhaseDoubledSymplecticGeodesicState]:
        """Integrator using nth order symplectic dynamics.

        Parameters:
            symplectic_params: parameters for the symplectic integration
            initial_state: state at t=0
        """

        nth_order_dynamics = construct_nth_order_dynamics(n)
        fn_updator = jtu.Partial(
            nth_order_dynamics, dt=symplectic_params.dt, omega=symplectic_params.omega, metric=symplectic_params.metric
        )

        def _single_step(
            state: PhaseDoubledSymplecticGeodesicState, _: None
        ) -> tuple[PhaseDoubledSymplecticGeodesicState, PhaseDoubledSymplecticGeodesicState]:
            return fn_updator(state=state), state

        final_state, preceding_states = jax.lax.scan(_single_step, initial_state, None, symplectic_params.n_steps)
        full_state = _merge_states(preceding=preceding_states, final=final_state)

        return final_state, full_state

    return _nth_order_symplectic_integrator
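The integrator drives the update with jax.lax.scan. Each step carries the updated state forward and emits the state from before the update, so the stacked history holds n_steps entries. _merge_states is private and not shown here. The merge_states below assumes that it appends the final state to that history along the leading (time) axis, which would give n_steps + 1 states. integrate and merge_states are hypothetical names, and the usage drives a toy free-particle update rather than the symplectic dynamics.

```python
import jax
import jax.numpy as jnp


def merge_states(preceding, final):
    # hypothetical stand-in for the private `_merge_states`: append the final state
    # to the stacked history along the leading (time) axis of every leaf
    return jax.tree_util.tree_map(
        lambda history, last: jnp.concatenate([history, last[None]], axis=0), preceding, final
    )


def integrate(update_fn, initial_state, n_steps):
    def single_step(state, _):
        # carry the updated state forward, emit the state from before the update
        return update_fn(state), state

    final_state, preceding = jax.lax.scan(single_step, initial_state, None, length=n_steps)
    return final_state, merge_states(preceding, final_state)


# toy free-particle update on a (position, velocity) pytree, not the symplectic dynamics
dt = 0.1
final, trajectory = integrate(
    lambda s: (s[0] + dt * s[1], s[1]), (jnp.zeros(2), jnp.ones(2)), n_steps=10
)
```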

  1. Christian, Pierre, and Chi-kwan Chan. ‘FANTASY: User-Friendly Symplectic Geodesic Integrator for Arbitrary Metrics with Automatic Differentiation’. The Astrophysical Journal 909, 2021. https://doi.org/10.3847/1538-4357/abdc28 

  2. Tao, Molei. ‘Explicit Symplectic Approximation of Nonseparable Hamiltonians: Algorithm and Long Time Performance’. Physical Review E 94, 2016. https://doi.org/10.1103/PhysRevE.94.043303 

  3. Yoshida, Haruo. ‘Construction of Higher Order Symplectic Integrators’. Physics Letters A 150, 1990. https://doi.org/10.1016/0375-9601(90)90092-3