geodesic
riemax.manifold.geodesic
riemax.manifold.geodesic.geodesic_dynamics(state: TangentSpace[jax.Array], metric: MetricFn) -> TangentSpace[jax.Array]
Compute update step for the geodesic dynamics.
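The geodesic equation

$$
\ddot{\gamma}^k + \Gamma^k_{ij}\,\dot{\gamma}^i\,\dot{\gamma}^j = 0
$$

is a second-order ordinary differential equation. We take the conventional approach of splitting it into two first-order ordinary differential equations,

$$
\dot{\gamma}^k = v^k, \qquad \dot{v}^k = -\Gamma^k_{ij}\,v^i\,v^j.
$$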
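As a rough illustration of the split system, the sketch below evaluates the right-hand side $(\dot{\gamma}^k, \dot{v}^k) = (v^k, -\Gamma^k_{ij} v^i v^j)$ for a user-supplied metric, with the Christoffel symbols taken from the standard formula $\Gamma^k_{ij} = \tfrac{1}{2} g^{kl}(\partial_i g_{jl} + \partial_j g_{il} - \partial_l g_{ij})$. It is not riemax's implementation: it assumes NumPy with central finite differences where riemax would use JAX autodiff, takes plain arrays instead of a TangentSpace, and the helper names (metric_jacobian, christoffel, geodesic_dynamics_sketch, polar_metric) are hypothetical.

```python
import numpy as np


def metric_jacobian(metric, x, eps=1e-6):
    """Central differences: dg[a, b, c] = d g_ab / d x^c (stand-in for JAX autodiff)."""
    d = x.shape[0]
    dg = np.zeros((d, d, d))
    for c in range(d):
        step = np.zeros(d)
        step[c] = eps
        dg[:, :, c] = (metric(x + step) - metric(x - step)) / (2.0 * eps)
    return dg


def christoffel(metric, x):
    """Gamma[k, i, j] = 1/2 g^{kl} (d_i g_jl + d_j g_il - d_l g_ij)."""
    g_inv = np.linalg.inv(metric(x))
    dg = metric_jacobian(metric, x)
    lowered = 0.5 * (
        np.einsum("jli->lij", dg)    # d_i g_jl
        + np.einsum("ilj->lij", dg)  # d_j g_il
        - np.einsum("ijl->lij", dg)  # d_l g_ij
    )
    return np.einsum("kl,lij->kij", g_inv, lowered)


def geodesic_dynamics_sketch(x, v, metric):
    """Right-hand side of the split system: dx/dt = v, dv^k/dt = -Gamma^k_ij v^i v^j."""
    gamma = christoffel(metric, x)
    return v, -np.einsum("kij,i,j->k", gamma, v, v)


def polar_metric(x):
    """Euclidean plane in polar coordinates (r, theta): g = diag(1, r^2)."""
    return np.diag([1.0, x[0] ** 2])


# At r = 2 with purely angular velocity, a straight line picks up radial acceleration r * theta_dot^2.
dx, dv = geodesic_dynamics_sketch(np.array([2.0, 0.0]), np.array([0.0, 1.0]), polar_metric)
```

Written in $(r, \theta)$, a straight line through $r = 2$ moving purely in the angular direction has radial acceleration $r\dot{\theta}^2 = 2$, which is what the sketch returns. Feeding this right-hand side to any ODE integrator (e.g. RK4) traces out the geodesic.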
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | riemax.manifold.types.TangentSpace[jax.Array] | current state of the geodesic integration | required |
metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
Returns:
Type | Description |
---|---|
riemax.manifold.types.TangentSpace[jax.Array] | derivatives used to compute update for the state |
Source code in src/riemax/manifold/geodesic.py
riemax.manifold.geodesic.alternative_geodesic_dynamics(state: TangentSpace[jax.Array], metric: MetricFn) -> TangentSpace[jax.Array]
Compute geodesic dynamics, as per (Arvanitidis, G., Hansen, L. K., Hauberg, S., 2018) [1].
Latent Space Oddity Approach
The paper 'Latent Space Oddity' provides a different formulation of the geodesic equation. It is not clear why this formulation is useful or necessary, and it obscures the computation of the Christoffel symbols; nevertheless, an implementation is provided below.
Interestingly, the seminal papers 'A Geometric take on Metric Learning' and 'Metrics for Probabilistic Geometries' use a similar approach but are missing a term. It appears that these derivations are incorrect and should be revised accordingly.
While the mathematical specification of the alternative dynamics makes use of the $\operatorname{vec}$ operator (which stacks the columns of the metric tensor into a single vector), I avoid it so that the Jacobian of the metric tensor only has to be computed once.
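To make the relationship concrete, the following NumPy sketch implements our transcription of the paper's vec/Kronecker formula, $\ddot{\gamma} = -\tfrac{1}{2} G^{-1}\big[2(I_d \otimes \dot{\gamma}^\top)\,\tfrac{\partial \operatorname{vec}[G]}{\partial \gamma}\,\dot{\gamma} - \tfrac{\partial \operatorname{vec}[G]}{\partial \gamma}^{\top}(\dot{\gamma} \otimes \dot{\gamma})\big]$, next to the Christoffel form, with both reusing a single finite-difference Jacobian. The function names are hypothetical, finite differences stand in for JAX autodiff, and the comparison itself is an illustration rather than part of riemax.

```python
import numpy as np


def vec_jacobian(metric, x, eps=1e-6):
    """d vec[G] / dx with shape (d*d, d); vec stacks columns (Fortran order)."""
    cols = []
    for step in eps * np.eye(x.shape[0]):
        diff = (metric(x + step) - metric(x - step)) / (2.0 * eps)
        cols.append(diff.reshape(-1, order="F"))
    return np.stack(cols, axis=1)


def lso_acceleration(metric, x, v):
    """vec/Kronecker form of the geodesic acceleration."""
    d = x.shape[0]
    jac = vec_jacobian(metric, x)
    term_a = 2.0 * np.kron(np.eye(d), v[None, :]) @ jac @ v  # 2 (I kron v^T) dvec[G]/dx v
    term_b = jac.T @ np.kron(v, v)                           # dvec[G]/dx^T (v kron v)
    return -0.5 * np.linalg.solve(metric(x), term_a - term_b)


def christoffel_acceleration(metric, x, v):
    """-G^{-1} Gamma_{kij} v^i v^j, built from the same Jacobian reshaped to dg[a, b, c] = d_c g_ab."""
    d = x.shape[0]
    dg = vec_jacobian(metric, x).reshape(d, d, d, order="F")
    lowered = np.einsum("kji,i,j->k", dg, v, v) - 0.5 * np.einsum("ijk,i,j->k", dg, v, v)
    return -np.linalg.solve(metric(x), lowered)


def warped_metric(x):
    """An arbitrary position-dependent SPD metric: G(x) = I + x x^T."""
    return np.eye(x.shape[0]) + np.outer(x, x)


x0, v0 = np.array([0.3, -0.7]), np.array([1.2, 0.5])
a_lso = lso_acceleration(warped_metric, x0, v0)
a_christoffel = christoffel_acceleration(warped_metric, x0, v0)
```

For a symmetric metric the two expressions coincide: half of the bracket in the vec form is exactly the lowered Christoffel contraction $\Gamma_{kij}\dot{\gamma}^i\dot{\gamma}^j$.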
[1] Arvanitidis, Georgios, Lars Kai Hansen, and Søren Hauberg. 'Latent Space Oddity: On the Curvature of Deep Generative Models'. arXiv, 2021. http://arxiv.org/abs/1710.11379
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | riemax.manifold.types.TangentSpace[jax.Array] | current state of the geodesic integration | required |
metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
Returns:
Type | Description |
---|---|
riemax.manifold.types.TangentSpace[jax.Array] | derivatives used to compute update for the state |
riemax.manifold.geodesic.compute_geodesic_length(geodesic: TangentSpace[jax.Array], dt: float, metric: MetricFn, integral_approximator: IntegralApproximationFn = mean_integration) -> jax.Array
Compute length of the geodesic.
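The length of a geodesic $\gamma : [0, T] \to \mathcal{M}$ is defined as

$$
L(\gamma) = \int_0^T \sqrt{g_{ij}(\gamma(t))\,\dot{\gamma}^i(t)\,\dot{\gamma}^j(t)}\,\mathrm{d}t.
$$

We note that this is not necessarily equal to the geodesic distance between its two end-points: the two coincide only when the geodesic is length-minimising.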
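The following NumPy sketch approximates this integral on a uniformly sampled curve with the trapezoidal rule; riemax's default mean_integration may weight the samples differently, and curve_length and sphere_metric are hypothetical names. The example also illustrates the caveat about geodesic distance: the equator of the unit sphere traversed three quarters of the way round is a geodesic whose length exceeds the distance between its end-points.

```python
import numpy as np


def curve_length(points, velocities, dt, metric):
    """Trapezoidal estimate of L = integral of sqrt(g_ij v^i v^j) dt on a uniform time grid."""
    speeds = np.array([np.sqrt(v @ metric(x) @ v) for x, v in zip(points, velocities)])
    return dt * (0.5 * speeds[0] + speeds[1:-1].sum() + 0.5 * speeds[-1])


def sphere_metric(x):
    """Unit sphere in (polar angle, azimuth) coordinates: g = diag(1, sin^2 theta)."""
    return np.diag([1.0, np.sin(x[0]) ** 2])


# The equator traversed "the long way": azimuth 0 -> 3*pi/2 over t in [0, 1].
t = np.linspace(0.0, 1.0, 101)
points = np.stack([np.full_like(t, np.pi / 2), 1.5 * np.pi * t], axis=1)
velocities = np.tile([0.0, 1.5 * np.pi], (t.size, 1))
length = curve_length(points, velocities, t[1] - t[0], sphere_metric)
```

The estimate is $3\pi/2$, whereas the distance between the two end-points is only $\pi/2$, obtained by going the other way round the equator.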
Parameters:
Name | Type | Description | Default |
---|---|---|---|
geodesic | riemax.manifold.types.TangentSpace[jax.Array] | point on the geodesic | required |
metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
Returns:
Type | Description |
---|---|
jax.Array | length of the geodesic |
riemax.manifold.geodesic.compute_geodesic_energy(geodesic: TangentSpace[jax.Array], dt: float, metric: MetricFn, integral_approximator: IntegralApproximationFn = mean_integration) -> jax.Array
Compute energy of the geodesic.
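The energy of a geodesic $\gamma : [0, T] \to \mathcal{M}$ is defined as

$$
E(\gamma) = \frac{1}{2} \int_0^T g_{ij}(\gamma(t))\,\dot{\gamma}^i(t)\,\dot{\gamma}^j(t)\,\mathrm{d}t.
$$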
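A companion NumPy sketch for the energy, assuming the common convention $E = \tfrac{1}{2}\int g_{ij}\dot{\gamma}^i\dot{\gamma}^j\,\mathrm{d}t$ and again the trapezoidal rule on a uniform grid (curve_energy is a hypothetical helper). Comparing two parametrisations of the same unit segment shows why energy is a convenient objective: on $[0, 1]$ the Cauchy-Schwarz inequality gives $L^2 \le 2E$, with equality only at constant speed, so energy penalises uneven parametrisations that length ignores.

```python
import numpy as np


def curve_energy(points, velocities, dt, metric):
    """Trapezoidal estimate of E = 1/2 integral of g_ij v^i v^j dt (assumes the 1/2 convention)."""
    sq_speeds = np.array([v @ metric(x) @ v for x, v in zip(points, velocities)])
    return 0.5 * dt * (0.5 * sq_speeds[0] + sq_speeds[1:-1].sum() + 0.5 * sq_speeds[-1])


def euclidean(x):
    return np.eye(2)


t = np.linspace(0.0, 1.0, 101)
dt = t[1] - t[0]
direction = np.array([1.0, 0.0])

# The same unit-length segment, traversed at constant speed and with speed 2t.
uniform = curve_energy(t[:, None] * direction, np.tile(direction, (t.size, 1)), dt, euclidean)
accelerating = curve_energy((t ** 2)[:, None] * direction, (2 * t)[:, None] * direction, dt, euclidean)
```

Both curves have length 1, but the constant-speed one has energy 0.5 while the accelerating one has roughly 2/3.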
Parameters:
Name | Type | Description | Default |
---|---|---|---|
geodesic | riemax.manifold.types.TangentSpace[jax.Array] | point on the geodesic | required |
metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
Returns:
Type | Description |
---|---|
jax.Array | energy of the geodesic |
riemax.manifold.geodesic.minimising_geodesic(p: M[jax.Array], q: M[jax.Array], metric: MetricFn, optimiser: optax.GradientTransformation, num_nodes: int = 20, n_collocation: int = 100, iterations: int = 100, tol: float = 0.0001) -> tuple[TangentSpace[jax.Array], bool]
Obtain energy-minimising geodesics between two points.
This implementation models the geodesic as a cubic spline constrained at the two end-points. An optimisation problem is solved to obtain the parameters of the cubic spline that minimise the energy of the resulting curve. Since curves that minimise energy between fixed end-points are also length-minimising (and traversed at constant speed), this ideally yields the length-minimising geodesic between the two points.
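The sketch below conveys the idea of energy minimisation with pinned end-points under simplifying assumptions: it swaps the cubic spline for a piecewise-linear curve whose interior nodes are optimised, evaluates the metric at segment midpoints, and uses SciPy's L-BFGS-B with finite-difference gradients instead of an optax optimiser with JAX gradients. minimise_energy, discrete_energy and polar_metric are hypothetical names.

```python
import numpy as np
from scipy.optimize import minimize


def discrete_energy(nodes, h, metric):
    """Midpoint-rule energy of a piecewise-linear curve: 1/2 sum dx^T G(mid) dx / h."""
    deltas = np.diff(nodes, axis=0)
    mids = 0.5 * (nodes[1:] + nodes[:-1])
    return 0.5 / h * sum(dx @ metric(m) @ dx for dx, m in zip(deltas, mids))


def minimise_energy(p, q, metric, num_segments=16):
    """Optimise the interior nodes of a polyline whose end-points are pinned to p and q."""
    d = p.shape[0]
    h = 1.0 / num_segments
    t = np.linspace(0.0, 1.0, num_segments + 1)[:, None]
    initial = (1.0 - t) * p + t * q  # straight line in coordinate space

    def objective(flat):
        nodes = np.vstack([p, flat.reshape(-1, d), q])
        return discrete_energy(nodes, h, metric)

    # L-BFGS-B with finite-difference gradients (SciPy's behaviour when jac is omitted).
    result = minimize(objective, initial[1:-1].ravel(), method="L-BFGS-B")
    return np.vstack([p, result.x.reshape(-1, d), q]), result.success


def polar_metric(x):
    """Euclidean plane in polar coordinates (r, theta): g = diag(1, r^2)."""
    return np.diag([1.0, x[0] ** 2])


# Quarter turn on the unit circle; the true geodesic is the straight chord.
nodes, converged = minimise_energy(np.array([1.0, 0.0]), np.array([1.0, np.pi / 2]), polar_metric)
```

Starting from the circular arc between $(r, \theta) = (1, 0)$ and $(1, \pi/2)$, the optimised interior nodes move inwards towards the straight chord, whose midpoint lies at $r = \sqrt{2}/2$, $\theta = \pi/4$.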
Parameters:
Name | Type | Description | Default |
---|---|---|---|
p | riemax.manifold.types.M[jax.Array] | first end-point of the geodesic | required |
q | riemax.manifold.types.M[jax.Array] | second end-point of the geodesic | required |
metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
optimiser | optax.GradientTransformation | optimiser to use for the optimisation procedure | required |
num_nodes | int | number of nodes to parameterise the cubic spline by | 20 |
n_collocation | int | number of points to evaluate energy at | 100 |
iterations | int | number of iterations to optimise for | 100 |
tol | float | tolerance for gradients of updates | 0.0001 |
Returns:
Type | Description |
---|---|
riemax.manifold.types.TangentSpace[jax.Array] | optimised geodesic, connecting the two points |
bool | whether the optimisation procedure converged |
riemax.manifold.geodesic.scipy_bvp_geodesic(p: jax.Array, q: jax.Array, metric: MetricFn, n_collocation: int = 100, explicit_jacobian: bool = False, tol: float = 0.0001) -> tuple[TangentSpace[jax.Array], bool]
Obtain a geodesic connecting two points using scipy.integrate.solve_bvp.
This method mirrors minimising_geodesic, as scipy uses a similar scheme to solve boundary value problems: its solution is likewise represented as a C1-continuous cubic spline. The scipy implementation does not treat the end-points as fixed, though; instead, a separate set of boundary conditions must be satisfied by driving their residual to zero. While the scipy implementation is more complete, external calls are slower and cannot be jitted, and the need to minimise a boundary-condition residual is a further consideration.
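For comparison, here is a hedged sketch of calling scipy.integrate.solve_bvp directly on the split geodesic system, with the end-point residuals as boundary conditions and the velocities left free. It is not riemax's scipy_bvp_geodesic: the metric is assumed to be a plain NumPy function, Christoffel symbols come from central differences rather than JAX, explicit_jacobian has no counterpart (solve_bvp estimates the Jacobian by finite differences when fun_jac is omitted), and the names bvp_geodesic, christoffel and polar_metric are hypothetical.

```python
import numpy as np
from scipy.integrate import solve_bvp


def christoffel(metric, x, eps=1e-6):
    """Gamma[k, i, j] via central differences of the metric (stand-in for JAX autodiff)."""
    d = x.shape[0]
    dg = np.stack([(metric(x + s) - metric(x - s)) / (2.0 * eps) for s in eps * np.eye(d)], axis=-1)
    lowered = 0.5 * (np.einsum("jli->lij", dg) + np.einsum("ilj->lij", dg) - np.einsum("ijl->lij", dg))
    return np.einsum("kl,lij->kij", np.linalg.inv(metric(x)), lowered)


def bvp_geodesic(p, q, metric, n_collocation=100, tol=1e-4):
    """Solve x'' = -Gamma(x)[v, v] with x(0) = p, x(1) = q as a first-order BVP."""
    d = p.shape[0]
    t = np.linspace(0.0, 1.0, n_collocation)

    def rhs(_t, y):
        # y has shape (2d, m): rows [:d] are positions, rows [d:] are velocities.
        x, v = y[:d], y[d:]
        acc = np.stack(
            [-np.einsum("kij,i,j->k", christoffel(metric, x[:, m]), v[:, m], v[:, m]) for m in range(y.shape[1])],
            axis=1,
        )
        return np.vstack([v, acc])

    def bc(ya, yb):
        # Residuals of the end-point conditions only.
        return np.concatenate([ya[:d] - p, yb[:d] - q])

    # Initial guess: straight line in coordinates with constant velocity q - p.
    guess = np.vstack([p[:, None] + t * (q - p)[:, None], np.repeat((q - p)[:, None], t.size, axis=1)])
    solution = solve_bvp(rhs, bc, t, guess, tol=tol)
    return solution, solution.success


def polar_metric(x):
    """Euclidean plane in polar coordinates (r, theta): g = diag(1, r^2)."""
    return np.diag([1.0, x[0] ** 2])


sol, ok = bvp_geodesic(np.array([1.0, 0.0]), np.array([1.0, np.pi / 2]), polar_metric)
```

solve_bvp returns a callable spline in sol.sol, so the geodesic can be resampled at any $t$; for this quarter-turn example its midpoint should land close to $r = \sqrt{2}/2$, $\theta = \pi/4$.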
Parameters:
Name | Type | Description | Default |
---|---|---|---|
p | jax.Array | first end-point of the geodesic | required |
q | jax.Array | second end-point of the geodesic | required |
metric | riemax.manifold.types.MetricFn | function defining the metric tensor on the manifold | required |
n_collocation | int | number of points to evaluate energy at | 100 |
explicit_jacobian | bool | whether to use the Jacobian computed by jax | False |
tol | float | tolerance for gradients of updates | 0.0001 |
Returns:
Type | Description |
---|---|
riemax.manifold.types.TangentSpace[jax.Array] | optimised geodesic, connecting the two points |
bool | whether the optimisation procedure converged |