Manifold

src.riemax.Manifold

Convenience class for creating a manifold.

Note

Riemax remains a functional library, and all functionality is defined as pure functions. The Manifold class ties the relevant functionality together, automatically creating partial applications so that you need not pass the metric function to every call.

This class is defined dynamically, using decorators to mark relevant functions throughout the library.
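
The marking mechanism can be pictured with a minimal sketch. The source below only shows that `manifold_marker` can be iterated to yield `(fn, jittable)` pairs; the `Marker` class, its `mark` method, the `Bundle` class, and the toy function `scaled_norm` are hypothetical stand-ins, not Riemax's actual implementation. Plain `functools.partial` is used in place of `jtu.Partial` so the sketch runs without JAX.

```python
import functools
import typing as tp


class Marker:
    """Hypothetical registry: records functions and whether they may be jitted."""

    def __init__(self) -> None:
        self._registry: list[tuple[tp.Callable, bool]] = []

    def mark(self, *, jittable: bool = True) -> tp.Callable[[tp.Callable], tp.Callable]:
        def decorator(fn: tp.Callable) -> tp.Callable:
            self._registry.append((fn, jittable))
            return fn  # the function itself is returned unchanged

        return decorator

    def __iter__(self) -> tp.Iterator[tuple[tp.Callable, bool]]:
        return iter(self._registry)


marker = Marker()


@marker.mark(jittable=True)
def scaled_norm(v, metric):
    # Toy "geometric quantity": here `metric` returns a scalar weight.
    return metric(v) * sum(x * x for x in v)


class Bundle:
    """Attaches every marked function to the instance, with `metric` bound."""

    def __init__(self, metric: tp.Callable) -> None:
        for fn, _jittable in marker:
            setattr(self, fn.__name__, functools.partial(fn, metric=metric))


bundle = Bundle(metric=lambda v: 2.0)
result = bundle.scaled_norm([3.0, 4.0])  # the metric is no longer passed
```

Because the functions are only registered, not wrapped, they stay usable as ordinary pure functions elsewhere in the library.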

Instantiation

A default __init__ is provided, which takes a function that computes the metric tensor on the manifold. Alternatively, you can instantiate a Manifold from a function transformation, which defines an induced metric:

manifold = riemax.Manifold.from_fn_transformation(...)
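
To make "induced metric" concrete, here is a small worked sketch. It assumes the induced metric is the pullback of the Euclidean metric through the transformation, g = JᵀJ with J the Jacobian; this page does not spell that out, so treat it as an assumption. The names `sphere_embedding`, `jacobian_columns`, and `induced_metric` are hypothetical, and a finite-difference Jacobian stands in for automatic differentiation so the example runs with the standard library alone.

```python
import math


def sphere_embedding(p):
    """Hypothetical transformation: (theta, phi) on the unit sphere -> (x, y, z)."""
    theta, phi = p
    return [
        math.sin(theta) * math.cos(phi),
        math.sin(theta) * math.sin(phi),
        math.cos(theta),
    ]


def jacobian_columns(fn, p, eps=1e-6):
    """Central finite differences; entry [i][k] approximates d fn^k / d p^i."""
    columns = []
    for i in range(len(p)):
        forward, backward = list(p), list(p)
        forward[i] += eps
        backward[i] -= eps
        columns.append(
            [(a - b) / (2 * eps) for a, b in zip(fn(forward), fn(backward))]
        )
    return columns


def induced_metric(fn, p):
    """Assumed pullback metric: g_ij = sum_k (d fn^k / d p^i) * (d fn^k / d p^j)."""
    cols = jacobian_columns(fn, p)
    return [[sum(a * b for a, b in zip(ci, cj)) for cj in cols] for ci in cols]


theta = math.pi / 3
g = induced_metric(sphere_embedding, [theta, 0.5])
# Analytically, the round metric on the unit sphere is diag(1, sin(theta)**2).
```

Under this assumption, a function like `sphere_embedding` is the kind of argument one would pass as `fn_transformation`.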

Methods

Once the manifold is defined, you can compute quantities without passing in the metric each time. Most of the time you work on a single manifold, and creating partial functions by hand becomes tiresome. The Manifold class handles the partial applications and applies jax.jit to functions marked as jittable (unless jit=False), so you can call functions naturally.

For example, you can compute geometric quantities at a point p:

metric_p = manifold.metric_tensor(p)
riemann_tensor = manifold.sk_riemann_tensor(p)
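
For comparison, the two calling styles can be set side by side in a toy sketch. The functions `euclidean_metric` and `inner_product` are hypothetical and written in plain Python; they are not Riemax functions, and `functools.partial` stands in for the binding that the class performs.

```python
import functools


def euclidean_metric(p):
    """Toy metric function: the identity matrix, independent of the point p."""
    n = len(p)
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def inner_product(p, v, w, metric):
    """Computes g_ij(p) v^i w^j, with the metric function passed explicitly."""
    g = metric(p)
    return sum(g[i][j] * v[i] * w[j] for i in range(len(v)) for j in range(len(w)))


p, v, w = [0.0, 0.0], [1.0, 2.0], [3.0, 4.0]

# Functional style: the metric travels with every call.
explicit = inner_product(p, v, w, metric=euclidean_metric)

# Bound style: fix the metric once, then call without it.
bound_inner_product = functools.partial(inner_product, metric=euclidean_metric)
bound = bound_inner_product(p, v, w)
```

Both calls return the same value; the bound form simply removes the repeated argument.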

Source code in src/riemax/manifold/_manifold.py

class Manifold:

    """Convenience class for creating a manifold.

    !!! note

        Riemax remains a functional library, and all functionality is defined as pure functions. The `Manifold` class
        ties the relevant functionality together, automatically creating partial applications so that you need not pass
        the metric function to every call.

        This class is defined dynamically, using decorators to mark relevant functions throughout the library.

    ### Instantiation

    A default `__init__` is provided, which takes a function that computes the metric tensor on the manifold.
    Alternatively, you can instantiate a `Manifold` from a function transformation, which defines an induced metric:

    ```python
    manifold = riemax.Manifold.from_fn_transformation(...)
    ```

    ### Methods

    Once the manifold is defined, you can compute quantities without passing in the metric each time. Most of the
    time you work on a single manifold, and creating partial functions by hand becomes tiresome. The `Manifold` class
    handles the partial applications and applies `jax.jit` to functions marked as jittable (unless `jit=False`), so
    you can call functions naturally.

    For example, you can compute geometric quantities at a point `p`:

    ```python
    metric_p = manifold.metric_tensor(p)
    riemann_tensor = manifold.sk_riemann_tensor(p)
    ```
    """

    def __new__(cls, metric: MetricFn, *, jit: bool = True) -> Manifold:
        obj = super().__new__(cls)
        for fn, jittable in manifold_marker:
            p_fn = jtu.Partial(fn, metric=metric)

            if jittable and jit:
                p_fn = jax.jit(p_fn)

            p_fn.__doc__ = fn.__doc__

            setattr(obj, fn.__name__, p_fn)

        return obj

    def __init__(self, metric: MetricFn, *, jit: bool = True) -> None:
        """Instantiate a Manifold.

        Parameters:
            metric: function defining the metric tensor on the manifold.
            jit: whether to apply `jax.jit` to the functions.
        """

        self.metric_tensor = metric
        self.jit = jit

    def __repr__(self) -> str:
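        # The metric may be a partial object (see `from_fn_transformation`), which may lack `__name__`.
        name = getattr(self.metric_tensor, '__name__', repr(self.metric_tensor))
        return f'Manifold(metric={name})'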

    @classmethod
    def from_fn_transformation(
        cls, fn_transformation: tp.Callable[[M[jax.Array]], jax.Array], *, jit: bool = True
    ) -> Manifold:
        """Instantiates an induced Manifold from a function transformation.

        Parameters:
            fn_transformation: function used to define the induced metric.
            jit: whether to apply `jax.jit` to the functions.
        """

        p_metric = jtu.Partial(metric_tensor, fn_transformation=fn_transformation)
        p_metric.__doc__ = metric_tensor.__doc__

        return cls(metric=p_metric, jit=jit)