
euclidean

src.riemax.euclidean

src.riemax.euclidean.metric_tensor(x: jax.Array) -> jax.Array

Defines the metric tensor for Euclidean space.

In Euclidean space, the metric tensor is defined as the identity matrix

\[ g_{ij} = \delta_{ij}. \]
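As a short worked example of what this choice implies, the inner product and norm that the metric induces on tangent vectors \(u, v \in \mathbb{R}^n\) reduce to the ordinary dot product and L2 norm (Einstein summation over repeated indices is assumed here as notation):

\[ \langle u, v \rangle_g = g_{ij} u^i v^j = \delta_{ij} u^i v^j = \sum_{i=1}^{n} u^i v^i = u^\top v, \qquad \lVert v \rVert_g = \sqrt{\langle v, v \rangle_g} = \lVert v \rVert_2. \]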

Warning

Defined in this manner, the Euclidean metric is not differentiable, which may cause problems wherever derivatives of the metric tensor are required.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | jax.Array | position \(x \in \mathbb{R}^n\) at which to evaluate the metric tensor | required |

Returns:

| Type | Description |
| --- | --- |
| jax.Array | metric tensor in Euclidean space, i.e. the \(n \times n\) identity matrix |

Source code in src/riemax/manifold/euclidean.py
import jax
import jax.numpy as jnp


def metric_tensor(x: jax.Array) -> jax.Array:
    r"""Defines the metric tensor for Euclidean space.

    In Euclidean space, the metric tensor is defined as the identity matrix

    $$
    g_{ij} = \delta_{ij}.
    $$

    !!! warning

        Defined in this manner, the Euclidean metric is not differentiable, which may cause problems wherever derivatives of the metric tensor are required.

    Parameters:
        x: position $x \in \mathbb{R}^n$ at which to evaluate the metric tensor

    Returns:
        metric tensor in Euclidean space -- the identity matrix
    """

    return jnp.eye(N=x.shape[-1])
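A minimal usage sketch of the metric above. To keep it self-contained it re-defines `metric_tensor` locally instead of importing it from riemax (the exact import path is an assumption we avoid), and `inner_product` is a hypothetical helper that evaluates \(g_{ij} u^i v^j\) at a point.

```python
import jax
import jax.numpy as jnp


# Local copy of the metric shown above so this sketch runs on its own.
def metric_tensor(x: jax.Array) -> jax.Array:
    return jnp.eye(N=x.shape[-1])


def inner_product(x: jax.Array, u: jax.Array, v: jax.Array) -> jax.Array:
    """Hypothetical helper: <u, v>_g = g_ij u^i v^j evaluated at x."""
    g = metric_tensor(x)
    return jnp.einsum('ij,i,j->', g, u, v)


x = jnp.array([0.5, -1.0, 2.0])
u = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([4.0, 5.0, 6.0])

print(metric_tensor(x).shape)  # (3, 3)
print(inner_product(x, u, v))  # 32.0, same as jnp.dot(u, v)
```

Note that only `x.shape[-1]` is used: the values of `x` are ignored, and a batched `x` of shape `(B, n)` still yields a single `(n, n)` matrix rather than one matrix per point.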

src.riemax.euclidean.distance(p: jax.Array, q: jax.Array) -> jax.Array

Compute Euclidean distance between points.

The Euclidean distance is simply defined by the L2 norm:

\[ d_E(p, q) = \lVert p - q \rVert_2. \]
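The L2 norm is exactly the length that the identity metric assigns to the displacement \(p - q\), i.e. \(d_E(p, q) = \sqrt{g_{ij}(p - q)^i (p - q)^j}\) with \(g_{ij} = \delta_{ij}\). This holds here because the metric is constant and Euclidean geodesics are straight lines. The sketch below checks the identity numerically; `metric_distance` is a hypothetical helper, not part of riemax.

```python
import jax.numpy as jnp


def metric_distance(p, q):
    # Hypothetical helper: length of the displacement p - q measured with
    # the identity metric g_ij = delta_ij; leading batch axes are kept.
    delta = p - q
    g = jnp.eye(delta.shape[-1])
    return jnp.sqrt(jnp.einsum('...i,ij,...j->...', delta, g, delta))


p = jnp.array([1.0, 2.0, 2.0])
q = jnp.array([0.0, 0.0, 0.0])

print(metric_distance(p, q))   # 3.0
print(jnp.linalg.norm(p - q))  # 3.0
```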

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | jax.Array | position \(p \in \mathbb{R}^n\) of the first point | required |
| q | jax.Array | position \(q \in \mathbb{R}^n\) of the second point | required |

Returns:

| Type | Description |
| --- | --- |
| jax.Array | Euclidean distance between \(p\) and \(q\) |

Source code in src/riemax/manifold/euclidean.py
import jax
import jax.numpy as jnp


def distance(p: jax.Array, q: jax.Array) -> jax.Array:
    r"""Compute Euclidean distance between points.

    The Euclidean distance is simply defined by the L2 norm:

    $$
    d_E(p, q) = \lVert p - q \rVert_2.
    $$

    Parameters:
        p: position $p \in \mathbb{R}^n$ of the first point
        q: position $q \in \mathbb{R}^n$ of the second point

    Returns:
        Euclidean distance between $p$ and $q$
    """

    return jnp.sqrt(jnp.einsum('...i -> ...', (p - q) ** 2))
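Because the einsum `'...i -> ...'` sums only over the last (coordinate) axis, `distance` broadcasts over any leading batch axes. The sketch below re-defines it locally so it runs on its own, and uses that broadcasting to compute pairwise distances between two point sets; the point sets are made up for illustration.

```python
import jax
import jax.numpy as jnp


# Local copy of `distance` from above so the sketch is self-contained.
def distance(p: jax.Array, q: jax.Array) -> jax.Array:
    # Sum squared differences over the last (coordinate) axis only;
    # any leading batch axes are kept.
    return jnp.sqrt(jnp.einsum('...i -> ...', (p - q) ** 2))


# Single pair of points in R^2.
print(distance(jnp.array([0.0, 0.0]), jnp.array([3.0, 4.0])))  # 5.0

# Pairwise distances between two point sets via broadcasting:
# (3, 1, 2) - (1, 2, 2) -> (3, 2, 2), reduced over the last axis -> (3, 2).
ps = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
qs = jnp.array([[0.0, 0.0], [3.0, 4.0]])
pairwise = distance(ps[:, None, :], qs[None, :, :])
print(pairwise.shape)  # (3, 2)
```

One caveat when differentiating this expression: at \(p = q\) the gradient of the square root is unbounded, so `jax.grad` of `distance` evaluated at coincident points returns NaN.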