
curves

We provide an implementation of cubic splines parameterised by the null space of the equations that define them. The cubic spline is constrained at two end-points, and a basis of the null space is used to parameterise the curve itself.

Rationale for Parameterisation:

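In short, the equations defining a cubic spline, together with its end-point constraints, form a linear system

\[ \mathbf{A} \mathbf{x} = \mathbf{b} \]

We also know the solution set can be described as

\[ \{ \mathbf{p} + \mathbf{v} : \mathbf{v} \text{ is any solution to } \mathbf{A}\mathbf{x} = \mathbf{0} \}, \]

where \(\mathbf{p}\) is any particular solution of \(\mathbf{A}\mathbf{x} = \mathbf{b}\) (not the start-point `p` of the curve). If we take any element \(\boldsymbol{\varphi} \in N(\mathbf{A})\) of the null space, then for any solution \(\mathbf{x}\) we see

\[ \mathbf{A}\left( \mathbf{x} + \boldsymbol{\varphi} \right) = \mathbf{A}\mathbf{x} + \mathbf{A}\boldsymbol{\varphi} = \mathbf{b} + \mathbf{0} = \mathbf{b}, \]

so by parameterising the cubic spline as a particular solution plus a linear combination of the null-space basis vectors, we ensure that the equations defining the cubic spline are satisfied for every choice of parameters.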
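The argument can be checked numerically. The following JAX sketch builds a constraint matrix for a one-dimensional piecewise cubic, takes one particular solution with the pseudo-inverse, computes a null-space basis from the SVD, and confirms that adding any combination of basis vectors keeps the equations satisfied. The constraint set (two end-point values plus value, first- and second-derivative continuity at interior nodes) is an assumption for illustration, and `constraint_system` and `null_space_basis` are hypothetical names, not riemax functions.

```python
import jax
import jax.numpy as jnp

# Double precision keeps the rank decision and the residual check clean.
jax.config.update("jax_enable_x64", True)


def constraint_system(num_edges: int, p: float, q: float) -> tuple[jax.Array, jax.Array]:
    """Hypothetical A and b for a 1-D piecewise cubic with fixed end-points.

    Edge i owns columns 4i..4i+3 holding (a, b, c, d), where
    y_i(s) = a + b*s + c*s**2 + d*s**3 for a local parameter s in [0, 1].
    """
    num_coeffs = 4 * num_edges
    rows: list[list[float]] = []
    rhs: list[float] = []

    def add(entries: list[tuple[int, float]], value: float) -> None:
        row = [0.0] * num_coeffs
        for col, coeff in entries:
            row[col] = coeff
        rows.append(row)
        rhs.append(value)

    # End-point constraints: y_0(0) = p and y_{n-1}(1) = q.
    add([(0, 1.0)], p)
    last = 4 * (num_edges - 1)
    add([(last + k, 1.0) for k in range(4)], q)

    # Value, first- and second-derivative continuity at each interior node.
    for i in range(num_edges - 1):
        left, right = 4 * i, 4 * (i + 1)
        add([(left, 1.0), (left + 1, 1.0), (left + 2, 1.0), (left + 3, 1.0), (right, -1.0)], 0.0)
        add([(left + 1, 1.0), (left + 2, 2.0), (left + 3, 3.0), (right + 1, -1.0)], 0.0)
        add([(left + 2, 2.0), (left + 3, 6.0), (right + 2, -2.0)], 0.0)

    return jnp.array(rows), jnp.array(rhs)


def null_space_basis(A: jax.Array) -> jax.Array:
    """Columns form an orthonormal basis of N(A), computed from the SVD."""
    _, s, vt = jnp.linalg.svd(A, full_matrices=True)
    tol = s.max() * max(A.shape) * jnp.finfo(A.dtype).eps
    rank = int(jnp.sum(s > tol))
    # Right-singular vectors beyond the numerical rank span the null space.
    return vt[rank:].T


A, b = constraint_system(num_edges=3, p=0.0, q=1.0)
x_p = jnp.linalg.pinv(A) @ b  # one particular solution of A x = b
basis = null_space_basis(A)  # shape (12, 4): num_edges + 1 free directions

# Any parameter vector gives coefficients that still satisfy A x = b.
params = jax.random.normal(jax.random.PRNGKey(0), (basis.shape[1],))
x = x_p + basis @ params
print(basis.shape, float(jnp.max(jnp.abs(A @ x - b))))
```

The residual stays at round-off level whatever `params` is, which is what lets an optimiser move freely in parameter space without ever violating the spline equations.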

riemax.numerical.curves.CubicSpline

Bases: typing.NamedTuple

Cubic spline parameterised by a basis of the null space.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `p` | `jax.Array` | position of the curve start-point | required |
| `q` | `jax.Array` | position of the curve end-point | required |
| `num_nodes` | `int` | number of nodes used to represent the curve | required |
| `num_edges` | `int` | number of edges in the representation | required |
| `basis` | `jax.Array` | computed basis of the null-space | required |
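A dimension count suggests how many columns `basis` should have. Assuming the spline equations are the two end-point constraints plus continuity of value, first and second derivative at each of the \(n_e - 1\) interior nodes (this page does not spell out the exact constraint set), and that each of the \(n_e\) = `num_edges` edges carries four cubic coefficients per spatial dimension:

\[ \underbrace{4\,n_e}_{\text{coefficients}} - \underbrace{2}_{\text{end-points}} - \underbrace{3\,(n_e - 1)}_{C^0,\ C^1,\ C^2 \text{ at interior nodes}} = n_e + 1 \]

If these constraints are independent, this leaves \(n_e + 1\) free directions per dimension. Since `from_nodes` sets `num_edges = num_nodes - 1`, that count equals `num_nodes`, which appears to match the `num_nodes` rows of the array returned by `init_params`.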
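Source code in src/riemax/numerical/curves.py

```python
from __future__ import annotations

import typing as tp

import jax
import jax.numpy as jnp

# TangentSpace, _compute_basis, _curve and _curve_t are defined elsewhere
# in riemax and are not shown on this page.


class CubicSpline(tp.NamedTuple):

    """Cubic spline parameterised by a basis of the null space.

    Parameters:
        p: position of the curve start-point
        q: position of the curve end-point
        num_nodes: number of nodes used to represent the curve
        num_edges: number of edges in the representation
        basis: computed basis of the null-space
    """

    p: jax.Array
    q: jax.Array

    num_nodes: int
    num_edges: int

    basis: jax.Array

    @classmethod
    def from_nodes(cls, p: jax.Array, q: jax.Array, num_nodes: int) -> CubicSpline:
        # A curve through num_nodes nodes has num_nodes - 1 edges; the
        # null-space basis is computed for that number of edges.
        basis = _compute_basis(num_nodes - 1)
        return cls(p=p, q=q, basis=basis, num_nodes=num_nodes, num_edges=(num_nodes - 1))

    def init_params(self) -> jax.Array:
        # All-zero parameters: one row per node, one column per spatial dimension of p.
        return jnp.zeros((self.num_nodes, self.p.shape[0]))

    def evaluate(self, t: jax.Array, params: jax.Array) -> TangentSpace[jax.Array]:
        # Point on the curve at t, and its derivative with respect to t.
        y = _curve(t, params, self)
        y_dot = _curve_t(t, params, self)

        return TangentSpace(point=y, vector=y_dot)

    def __call__(self, t: jax.Array, params: jax.Array) -> TangentSpace[jax.Array]:
        # Calling the spline is shorthand for evaluate.
        return self.evaluate(t, params)
```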
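`evaluate` pairs the point on the curve with its derivative in a `TangentSpace`. The helpers `_curve` and `_curve_t` are not shown on this page, so the following is only a hypothetical sketch of how a point and tangent could be computed from per-edge cubic coefficients. `PointAndVector` and `evaluate_piecewise_cubic` are illustrative names, and the uniform mapping of the global parameter `t` onto equally sized edges is an assumption.

```python
import typing as tp

import jax
import jax.numpy as jnp


class PointAndVector(tp.NamedTuple):
    """Hypothetical stand-in for riemax's TangentSpace(point=..., vector=...)."""

    point: jax.Array
    vector: jax.Array


def evaluate_piecewise_cubic(t: jax.Array, coeffs: jax.Array) -> PointAndVector:
    """Point and velocity of a piecewise cubic at a global parameter t in [0, 1].

    coeffs has shape (num_edges, 4, dim): row k of edge i multiplies s**k,
    where s in [0, 1] is the local parameter on that edge.
    """
    num_edges = coeffs.shape[0]
    scaled = t * num_edges
    # Edge index; t == 1 is clipped onto the last edge, where s == 1.
    edge = jnp.clip(jnp.floor(scaled).astype(jnp.int32), 0, num_edges - 1)
    s = scaled - edge
    local = coeffs[edge]
    a, b, c, d = local[0], local[1], local[2], local[3]
    # Horner form of y(s) = a + b s + c s^2 + d s^3.
    point = a + s * (b + s * (c + s * d))
    # Chain rule: dy/dt = dy/ds * ds/dt with ds/dt = num_edges.
    vector = (b + s * (2.0 * c + 3.0 * s * d)) * num_edges
    return PointAndVector(point=point, vector=vector)


coeffs = jax.random.normal(jax.random.PRNGKey(1), (3, 4, 2))
t0 = jnp.asarray(0.37)
out = evaluate_piecewise_cubic(t0, coeffs)

# Forward-mode AD of the point should agree with the analytic vector away from nodes.
_, jvp_vector = jax.jvp(
    lambda t: evaluate_piecewise_cubic(t, coeffs).point, (t0,), (jnp.ones_like(t0),)
)

# Evaluate many parameters at once; coeffs is shared across the batch.
ts = jnp.linspace(0.0, 1.0, 50)
batch = jax.vmap(evaluate_piecewise_cubic, in_axes=(0, None))(ts, coeffs)
```

Computing the vector analytically avoids a second differentiation pass, while `jax.jvp` offers an independent check of it; `jax.vmap` with `in_axes=(0, None)` evaluates a whole grid of `t` values against the same coefficients.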