Skip to content

Integrators

Riemax provides simple implementations of several numerical integrators. Differentiating through these integrators, however, relies on standard automatic differentiation, i.e. backpropagating through every solver step. Riemax also provides an implementation of `odeint` with custom reverse-mode differentiation that computes the adjoint. If you want more options for adjoint-enabled integrators, Diffrax is a great place to start; we hope to add similar functionality here soon.

riemax.numerical.integrators

riemax.numerical.integrators.ParametersIVP

Bases: typing.NamedTuple

Parameters for the Initial Value Problem.

Parameters:

Name Type Description Default
differential_operator

function to compute dynamics

required
dt

step size for integration

1e-3
n_steps

total number of steps to integrate for

1000
Source code in src/riemax/numerical/integrators.py
class ParametersIVP[T](tp.NamedTuple):

    """Parameters for the Initial Value Problem.

    Parameters:
        differential_operator: function to compute dynamics
        dt: step size for integration
        n_steps: total number of steps to integrate for
    """

    differential_operator: tp.Callable[[T], T]

    dt: float = 1e-3
    n_steps: int = 1000

riemax.numerical.integrators.euler_integrator(ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]

Forward-Euler method for integration of the initial value problem.

Parameters:

Name Type Description Default
ivp_params riemax.numerical.integrators.ParametersIVP[T]

parameters for the initial value problem

required
initial_state T

state at t=0

required

Returns:

Name Type Description
final_state T

final state of the initial value problem

full_state T

entire solution of the initial value problem

Source code in src/riemax/numerical/integrators.py
@_adjoint_warning
def euler_integrator[T](ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]:
    """Forward-Euler method for integration of the initial value problem.

    Parameters:
        ivp_params: parameters for the initial value problem
        initial_state: state at t=0

    Returns:
        final_state: final state of the initial value problem
        full_state: entire solution of the initial value problem
    """

    def _single_step(state: T, _: None) -> tuple[T, T]:
        update = ivp_params.differential_operator(state)
        next_state = jtu.tree_map(lambda x, dxdt: x + dxdt * ivp_params.dt, state, update)

        return next_state, state

    final_state, preceding_states = jax.lax.scan(_single_step, initial_state, None, ivp_params.n_steps)
    full_state = _merge_states(preceding=preceding_states, final=final_state)

    return final_state, full_state

riemax.numerical.integrators.implicit_euler_integrator(ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]

Implicit-Euler method for integration of the initial value problem.

Parameters:

Name Type Description Default
ivp_params riemax.numerical.integrators.ParametersIVP[T]

parameters for the initial value problem

required
initial_state T

state at t=0 to integrate from

required

Returns:

Name Type Description
final_state T

final state of the initial value problem

full_state T

entire solution of the initial value problem

Source code in src/riemax/numerical/integrators.py
@_adjoint_warning
def implicit_euler_integrator[T](ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]:
    """implicit-Euler method for integration of the initial value problem.

    Parameters:
        ivp_params: parameters for the initial value problem
        initial_state: state at t=0 to integrate from

    Returns:
        final_state: final state of the initial value problem
        full_state: entire solution of the initial value problem
    """

    # TODO >> @danielkelshaw
    #         Should we consider passing nr_params as an argument?
    nr_params = NewtonRaphsonParams(max_steps=1000, target_residual=1e-9)

    def _residual(curr_state: T, state: T) -> T:
        update = ivp_params.differential_operator(state)
        return jtu.tree_map(lambda s, cs, u: s - cs - u * ivp_params.dt, state, curr_state, update)

    def _single_step(state: T, _: None) -> tuple[T, T]:
        # initial guess for the newton-raphson is the forward-Euler method
        update = ivp_params.differential_operator(state)
        nr_initial_state = jtu.tree_map(lambda x, dxdt: x + dxdt * ivp_params.dt, state, update)

        # create partial residual for current time-step
        p_residual = jtu.Partial(_residual, state)

        # compute optimised state
        next_state, _ = newton_raphson(p_residual, nr_initial_state, nr_params)  # type: ignore

        return next_state, state

    final_state, preceding_states = jax.lax.scan(_single_step, initial_state, None, ivp_params.n_steps)
    full_state = _merge_states(preceding=preceding_states, final=final_state)

    return final_state, full_state

riemax.numerical.integrators.rk4_integrator(ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]

Runge-Kutta (4th order) method for integration of the initial value problem.

Parameters:

Name Type Description Default
ivp_params riemax.numerical.integrators.ParametersIVP[T]

parameters for the initial value problem

required
initial_state T

state at t=0 to integrate from

required

Returns:

Name Type Description
final_state T

final state of the initial value problem

full_state T

entire solution of the initial value problem

Source code in src/riemax/numerical/integrators.py
@_adjoint_warning
def rk4_integrator[T](ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]:
    """Runge-Kutta (4th order) method for integration of the initial value problem.

    Parameters:
        ivp_params: parameters for the initial value problem
        initial_state: state at t=0 to integrate from

    Returns:
        final_state: final state of the initial value problem
        full_state: entire solution of the initial value problem
    """

    def _single_step(state: T, _: None) -> tuple[T, T]:
        k1 = ivp_params.differential_operator(state)
        k2 = ivp_params.differential_operator(jtu.tree_map(lambda s, k: s + ivp_params.dt * k / 2.0, state, k1))
        k3 = ivp_params.differential_operator(jtu.tree_map(lambda s, k: s + ivp_params.dt * k / 2.0, state, k2))
        k4 = ivp_params.differential_operator(jtu.tree_map(lambda s, k: s + ivp_params.dt * k, state, k3))

        update = jtu.tree_map(lambda k1, k2, k3, k4: (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0, k1, k2, k3, k4)

        next_state = jtu.tree_map(lambda x, dxdt: x + dxdt * ivp_params.dt, state, update)

        return next_state, state

    final_state, preceding_states = jax.lax.scan(_single_step, initial_state, None, ivp_params.n_steps)
    full_state = _merge_states(preceding=preceding_states, final=final_state)

    return final_state, full_state

riemax.numerical.integrators.odeint(ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]

Dormand–Prince (DOPRI, 4th/5th-order) method for integration of the initial value problem — adjoint compatible.

Parameters:

Name Type Description Default
ivp_params riemax.numerical.integrators.ParametersIVP[T]

parameters for the initial value problem

required
initial_state T

state at t=0 to integrate from

required

Returns:

Name Type Description
final_state T

final state of the initial value problem

full_state T

entire solution of the initial value problem

Source code in src/riemax/numerical/integrators.py
def odeint[T](ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]:
    """DOPRI (4,5th order) method for integration of initial value problem -- adjoint compatible.

    Parameters:
        ivp_params: parameters for the initial value problem
        initial_state: state at t=0 to integrate from

    Returns:
        final_state: final state of the initial value problem
        full_state: entire solution of the initial value problem
    """

    differential_operator = _timewrap(ivp_params.differential_operator)
    t_record = jnp.linspace(0.0, ivp_params.dt * ivp_params.n_steps, ivp_params.n_steps + 1)

    full_state = jode.odeint(differential_operator, initial_state, t_record)
    final_state = jtu.tree_map(lambda x: x[-1], full_state)

    return final_state, full_state