Parametric Derivatives¶
derivator
¶
Components for computing derivatives of the Eikonax solver.
This module contains two main components. Firstly, the
PartialDerivator evaluates the partial derivatives of the
global Eikonax update operator \(\mathbf{G}\) w.r.t. the parameter tensor field \(\mathbf{M}\) and the
corresponding solution vector \(\mathbf{u}\) obtained from a forward solve. The
DerivativeSolver component makes use of the fixed
point/adjoint property of the Eikonax solver to evaluate total parametric derivatives.
Classes:
| Name | Description |
|---|---|
| PartialDerivatorData | Settings for initialization of partial derivator |
| PartialDerivator | Component for computing partial derivatives of the Godunov Update operator |
| DerivativeSolver | Main component for obtaining gradients from partial derivatives |
eikonax.derivator.PartialDerivatorData
dataclass
¶
Settings for initialization of partial derivator.
See the Forward Solver documentation for more detailed explanations.
Attributes:
| Name | Type | Description |
|---|---|---|
| softminmax_order | int | Order of the soft minmax function for differentiable transformation of the update parameters |
| softminmax_cutoff | Real | Cut-off for the soft minmax transformation, beyond which zero sensitivity is assumed. |
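The exact soft minmax formula is not reproduced on this page. As a rough illustration of how an order and a cut-off can interact, the sketch below smooths a clamp to the interval [0, 1]. The function `soft_clamp` and its softplus-based form are assumptions made for this example only, not the Eikonax implementation.

```python
import math


def soft_clamp(value: float, order: int, cutoff: float) -> float:
    """Smoothly restrict `value` to [0, 1]; hypothetical stand-in, not the Eikonax formula."""

    def soft_positive_part(x: float) -> float:
        scaled = order * x
        # Beyond the cutoff, fall back to the hard limits (constant slope 0 or 1),
        # which also avoids overflow in exp().
        if scaled > cutoff:
            return x
        if scaled < -cutoff:
            return 0.0
        return math.log1p(math.exp(scaled)) / order

    # clip(x, 0, 1) == max(x, 0) - max(x - 1, 0); both maxima are smoothed.
    return soft_positive_part(value) - soft_positive_part(value - 1.0)


inside = soft_clamp(0.5, order=20, cutoff=10.0)  # close to the unclamped value
above = soft_clamp(5.0, order=20, cutoff=10.0)   # saturates at the upper bound
below = soft_clamp(-3.0, order=20, cutoff=10.0)  # saturates at the lower bound
```

In this sketch, a larger order brings the smooth function closer to the hard clamp, while the cut-off decides where the output becomes exactly constant and therefore has zero derivative.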
eikonax.derivator.PartialDerivator
¶
Bases: eqx.Module
Component for computing partial derivatives of the Godunov Update operator.
Given a tensor field \(\mathbf{M}\) and a solution vector \(\mathbf{u}\), the partial derivator computes the partial derivatives of the global Eikonax update operator with respect to the solution vector, \(\mathbf{G}_u(\mathbf{u}, \mathbf{M})\), and the tensor field, \(\mathbf{G}_M(\mathbf{u}, \mathbf{M})\). All derivatives are computed on the vertex level, exploiting the locality of interactions in the update operator (only adjacent simplices are considered). Therefore, we can assemble the complete derivative operators as sparse data structures, not just Jacobian-vector or vector-Jacobian products, within a single pass over the computational mesh. Atomic functions on the vertex level are differentiated with JAX.
Info
For the computation of the derivatives, so-called 'self-updates' are disabled. These updates occur when a vertex does not receive a lower update value from any direction than the value it currently has. At the correct solution point, this case cannot occur due to the causality of the update stencil.
Methods:
| Name | Description |
|---|---|
| compute_partial_derivatives | Compute the partial derivatives of the Godunov update operator with respect to the solution vector and the parameter tensor field, given a state for both variables |
__init__
¶
__init__(
mesh_data: preprocessing.MeshData,
derivator_data: PartialDerivatorData,
initial_sites: preprocessing.InitialSites,
) -> None
Constructor for the partial derivator object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mesh_data | preprocessing.MeshData | Mesh data object also utilized for the Eikonax solver, contains adjacency data for every vertex. | required |
| derivator_data | PartialDerivatorData | Settings for initialization of the derivator. | required |
| initial_sites | preprocessing.InitialSites | Locations and values at source points | required |
compute_partial_derivatives
¶
compute_partial_derivatives(
solution_vector: jtFloat[jax.Array | npt.NDArray, num_vertices],
tensor_field: jtFloat[jax.Array | npt.NDArray, "num_simplices dim dim"],
) -> tuple[linalg.EikonaxSparseMatrix, linalg.DerivatorSparseTensor]
Compute the partial derivatives of the Godunov update operator.
This method provides the main interface for computing the partial derivatives of the global Eikonax update operator with respect to the solution vector and the parameter tensor field. The derivatives are computed locally for each vertex, such that the resulting data structures are sparse. Subsequently, further zero entries are removed to reduce the memory footprint. The derivatives computed in this component can be utilized to compute the total parametric derivative via a fixed point equation, given that the provided solution vector is that fixed point. The computation of partial derivatives is possible with a single pass over the mesh, since the solution of the eikonal equation, and therefore causality within the Godunov update scheme, is known.
Note
The derivator expects the metric tensor field as used in the inner product for the
update stencil of the eikonal equation. This is the INVERSE of the conductivity
tensor, which is the actual tensor field in the eikonal equation. The
TensorField component provides the inverse tensor
field.
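When tensors are assembled by hand rather than through the TensorField component, the inversion mentioned in the note could look like the following NumPy sketch. The conductivity values are made up for illustration, and the array layout `(num_simplices, dim, dim)` is taken from the type annotation above.

```python
import numpy as np

# Hypothetical conductivity tensors for two simplices in 2D (symmetric positive definite).
conductivity_field = np.array(
    [
        [[2.0, 0.0], [0.0, 0.5]],
        [[1.0, 0.2], [0.2, 1.0]],
    ]
)

# np.linalg.inv inverts each trailing (dim, dim) block independently.
metric_field = np.linalg.inv(conductivity_field)
```

The resulting `metric_field` has the same shape as the input and is the kind of array the note asks for as `tensor_field`.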
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solution_vector | jax.Array | Current solution | required |
| tensor_field | jax.Array | Parameter field | required |
Returns:
| Type | Description |
|---|---|
| tuple[linalg.EikonaxSparseMatrix, linalg.DerivatorSparseTensor] | tuple[eikonax.linalg.EikonaxSparseMatrix, eikonax.linalg.DerivatorSparseTensor]: The partial derivatives G_u and G_M packaged in the project's sparse data structures. |
_process_partial_derivative_solution
¶
_process_partial_derivative_solution(
partial_derivative_solution: jtFloat[jax.Array, "num_vertices max_num_adjacent_simplices 2"],
) -> linalg.EikonaxSparseMatrix
Convert dense partial derivatives into the project's sparse matrix representation.
This method transforms the dense vertex-wise derivative array into an
EikonaxSparseMatrix by extracting row indices,
column indices, and values. Derivatives at initial/boundary sites are zeroed out since
these vertices have fixed values.
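The conversion can be pictured with a small NumPy sketch. It assumes an adjacency array of shape `(num_vertices, max_num_adjacent_simplices, 4)` whose columns 1 and 2 hold the two neighboring vertex indices, with `-1` as padding. That layout, the function name `dense_to_triplets`, and all values are assumptions for illustration, not the library's internals.

```python
import numpy as np


def dense_to_triplets(partial, adjacency, initial_sites):
    """Turn a dense (num_vertices, max_adj, 2) array into COO-style triplets."""
    num_vertices, max_adj, _ = partial.shape
    # Each vertex owns max_adj * 2 consecutive entries of the flattened array.
    rows = np.repeat(np.arange(num_vertices), max_adj * 2)
    # Assumed layout: columns 1 and 2 are the neighboring vertex indices.
    cols = adjacency[:, :, 1:3].reshape(-1)
    values = partial.copy()
    values[initial_sites] = 0.0  # fixed values at initial sites carry no sensitivity
    values = values.reshape(-1)
    keep = (cols >= 0) & (values != 0.0)  # drop padding and explicit zeros
    return rows[keep], cols[keep], values[keep]


padding = [-1, -1, -1, -1]
adjacency = np.array(
    [[[0, 1, 2, 0], padding], [[1, 0, 2, 0], padding], [[2, 0, 1, 0], padding]]
)
partial = np.array(
    [[[0.5, 0.5], [0.0, 0.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.3, 0.7], [0.0, 0.0]]]
)
rows, cols, values = dense_to_triplets(partial, adjacency, initial_sites=np.array([0]))
```

Triplets of this kind map directly onto a coordinate-format sparse matrix of shape `(num_vertices, num_vertices)`.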
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| partial_derivative_solution | jax.Array | Dense array of shape (num_vertices, max_num_adjacent_simplices, 2). The last axis contains the derivative contributions for the two neighboring vertices participating in each local Godunov update stencil. | required |
Returns:
| Type | Description |
|---|---|
| linalg.EikonaxSparseMatrix | eikonax.linalg.EikonaxSparseMatrix: Sparse container representing the partial derivative matrix G_u with shape (num_vertices, num_vertices). Can be converted to a SciPy sparse array via |
_process_partial_derivative_parameter
¶
_process_partial_derivative_parameter(
partial_derivative_parameter: jtFloat[
jax.Array, "num_vertices max_num_adjacent_simplices dim dim"
],
) -> linalg.DerivatorSparseTensor
Package the parameter derivative tensor into DerivatorSparseTensor.
This method wraps the dense tensor of parameter derivatives (with respect to the metric tensor field) into the project's sparse tensor data structure, which pairs the derivative values with simplex adjacency information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| partial_derivative_parameter | jax.Array | Dense array of shape (num_vertices, max_num_adjacent_simplices, dim, dim). Contains the partial derivatives of the update operator with respect to the metric tensor components of each adjacent simplex. | required |
Returns:
| Type | Description |
|---|---|
| linalg.DerivatorSparseTensor | eikonax.linalg.DerivatorSparseTensor: Sparse tensor wrapper holding derivative values and simplex connectivity metadata from |
_compute_global_partial_derivatives
¶
_compute_global_partial_derivatives(
solution_vector: jtFloat[jax.Array, num_vertices],
tensor_field: jtFloat[jax.Array, "num_simplices dim dim"],
) -> tuple[
jtFloat[jax.Array, "num_vertices max_num_adjacent_simplices 2"],
jtFloat[jax.Array, "num_vertices max_num_adjacent_simplices dim dim"],
]
Compute partial derivatives of the global update operator.
The method is a jitted and vectorized call to the
_compute_vertex_partial_derivatives
method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solution_vector | jax.Array | Global solution vector | required |
| tensor_field | jax.Array | Global parameter tensor field | required |
Returns:
| Type | Description |
|---|---|
| tuple[jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices dim dim']] | tuple[jax.Array, jax.Array]: Raw data for partial derivatives, with shapes (N, num_adjacent_simplices, 2) and (N, num_adjacent_simplices, dim, dim) |
_compute_vertex_partial_derivatives
¶
_compute_vertex_partial_derivatives(
solution_vector: jtFloat[jax.Array, num_vertices],
tensor_field: jtFloat[jax.Array, "num_simplices dim dim"],
adjacency_data: jtInt[jax.Array, "max_num_adjacent_simplices 4"],
) -> tuple[
jtFloat[jax.Array, "max_num_adjacent_simplices 2"],
jtFloat[jax.Array, "max_num_adjacent_simplices dim dim"],
]
Compute partial derivatives for the update of a single vertex.
The method computes candidates for all respective subterms through calls to further methods. These candidates are filtered for feasibility by means of JAX filters. The softmin function (and its gradient) is applied to the directions of all optimal updates to ensure differentiability, other contributions are discarded. Lastly, the evaluated contributions are combined according to the form of the "total differential" for the partial derivatives.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solution_vector | jax.Array | Global solution vector | required |
| tensor_field | jax.Array | Global parameter tensor field | required |
| adjacency_data | jax.Array | Adjacency data for the vertex under consideration | required |
Returns:
| Type | Description |
|---|---|
| tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices dim dim']] | tuple[jax.Array, jax.Array]: Partial derivatives for the given vertex |
_compute_vertex_partial_derivative_candidates
¶
_compute_vertex_partial_derivative_candidates(
solution_vector: jtFloat[jax.Array, num_vertices],
tensor_field: jtFloat[jax.Array, "num_simplices dim dim"],
adjacency_data: jtInt[jax.Array, "max_num_adjacent_simplices 4"],
) -> tuple[
jtFloat[jax.Array, "max_num_adjacent_simplices 4 2"],
jtFloat[jax.Array, "max_num_adjacent_simplices 4 dim dim"],
]
Compute partial derivatives corresponding to potential update candidates for a vertex.
Update candidates and corresponding derivatives are computed for all adjacent simplices, and for all possible update parameters per simplex.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solution_vector | jax.Array | Global solution vector | required |
| tensor_field | jax.Array | Global parameter field | required |
| adjacency_data | jax.Array | Adjacency data for the given vertex | required |
Returns:
| Type | Description |
|---|---|
| tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']] | tuple[jax.Array, jax.Array]: Candidates for partial derivatives |
_compute_partial_derivative_candidates_from_adjacent_simplex
¶
_compute_partial_derivative_candidates_from_adjacent_simplex(
solution_vector: jtFloat[jax.Array, num_vertices],
tensor_field: jtFloat[jax.Array, "num_simplices dim dim"],
adjacency_data: jtInt[jax.Array, 4],
) -> tuple[jtFloat[jax.Array, "4 2"], jtFloat[jax.Array, "4 dim dim"]]
Compute partial derivatives for all update candidates within an adjacent simplex.
The update candidates are evaluated according to the different candidates for the optimization parameters \(\lambda\). Contributions are combined to the form of the involved total differentials.
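To make the structure of such a candidate concrete, the sketch below uses a generic first-order Godunov-type update on a triangle for one fixed value of \(\lambda\). The functional form, the names `local_update` and `local_update_grads`, and all numbers are assumptions for illustration. The \(\lambda\)-dependent terms of the total differential and the actual Eikonax expressions are not reproduced here.

```python
import numpy as np


def local_update(u1, u2, lam, metric, e1, e2):
    """Hypothetical local update: interpolated arrival time plus travel time."""
    d = (1.0 - lam) * e1 + lam * e2  # vector from the point on the opposite edge
    return (1.0 - lam) * u1 + lam * u2 + np.sqrt(d @ metric @ d)


def local_update_grads(lam, metric, e1, e2):
    """Partial derivatives at fixed lam w.r.t. (u1, u2) and the metric tensor."""
    d = (1.0 - lam) * e1 + lam * e2
    grad_u = np.array([1.0 - lam, lam])
    grad_metric = np.outer(d, d) / (2.0 * np.sqrt(d @ metric @ d))
    return grad_u, grad_metric


e1, e2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
metric = np.array([[2.0, 0.3], [0.3, 1.0]])
lam = 0.25
grad_u, grad_metric = local_update_grads(lam, metric, e1, e2)

# Central finite differences over the metric entries as an independent check.
eps = 1e-6
finite_difference = np.zeros((2, 2))
for i in range(2):
    for j in range(2):
        shift = np.zeros((2, 2))
        shift[i, j] = eps
        upper = local_update(0.2, 0.5, lam, metric + shift, e1, e2)
        lower = local_update(0.2, 0.5, lam, metric - shift, e1, e2)
        finite_difference[i, j] = (upper - lower) / (2.0 * eps)
```

The output shapes `(2,)` and `(dim, dim)` per candidate mirror the trailing axes of the return annotations above.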
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solution_vector | jax.Array | Global solution vector | required |
| tensor_field | jax.Array | Global parameter field | required |
| adjacency_data | jax.Array | Adjacency data for the given vertex and simplex | required |
Returns:
| Type | Description |
|---|---|
| tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']] | tuple[jax.Array, jax.Array]: Derivative candidates from the given simplex |
_filter_candidates
staticmethod
¶
_filter_candidates(
vertex_update_candidates: jtFloat[jax.Array, "max_num_adjacent_simplices 4"],
grad_update_solution_candidates: jtFloat[jax.Array, "max_num_adjacent_simplices 4 2"],
grad_update_parameter_candidates: jtFloat[jax.Array, "max_num_adjacent_simplices 4 dim dim"],
) -> tuple[
jtFloat[jax.Array, ""],
jtFloat[jax.Array, "max_num_adjacent_simplices 4 2"],
jtFloat[jax.Array, "max_num_adjacent_simplices 4 dim dim"],
]
Mask irrelevant derivative candidates so that they are discarded later.
Values are masked by setting them to zero or infinity, depending on the routine in which they are utilized later. Partial derivatives are only relevant if the corresponding update corresponds to an optimal path.
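A minimal NumPy sketch of this kind of masking is given below. It assumes that infeasible update candidates have already been set to infinity and that optimality is decided by a small tolerance around the minimum. The function name `mask_non_optimal`, the tolerance, and the sample data are illustrative assumptions.

```python
import numpy as np


def mask_non_optimal(update_candidates, grad_candidates, tolerance=1e-12):
    """Keep derivative candidates only where the update attains the minimum."""
    min_value = np.min(update_candidates)
    is_optimal = np.abs(update_candidates - min_value) <= tolerance
    # Broadcast the mask over the trailing derivative axis and zero the rest.
    masked_grads = np.where(is_optimal[..., None], grad_candidates, 0.0)
    return min_value, masked_grads


# Two adjacent simplices with four candidates each; np.inf marks infeasible ones.
update_candidates = np.array([[1.5, np.inf, 1.2, np.inf], [1.2, 2.0, np.inf, np.inf]])
grad_candidates = np.ones((2, 4, 2))
min_value, masked_grads = mask_non_optimal(update_candidates, grad_candidates)
```

Here two candidates share the minimal value, so both keep their derivative entries while all others are zeroed.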
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vertex_update_candidates | jax.Array | Update candidates for a given vertex | required |
| grad_update_solution_candidates | jax.Array | Partial derivative candidates w.r.t. the solution vector | required |
| grad_update_parameter_candidates | jax.Array | Partial derivative candidates w.r.t. the parameter field | required |
Returns:
| Type | Description |
|---|---|
| tuple[jtFloat[jax.Array, ''], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']] | tuple[jax.Array, jax.Array, jax.Array]: Optimal update value, masked partial derivatives |
_compute_lambda_grad
¶
_compute_lambda_grad(
solution_values: jtFloat[jax.Array, 2],
parameter_tensor: jtFloat[jax.Array, "dim dim"],
edges: tuple[jtFloat[jax.Array, dim], jtFloat[jax.Array, dim], jtFloat[jax.Array, dim]],
) -> tuple[jtFloat[jax.Array, "4 2"], jtFloat[jax.Array, "4 dim dim"]]
Compute the partial derivatives of update parameters for a single vertex.
This method evaluates the partial derivatives of the update parameters with respect to the current solution vector and the given parameter field, for a single triangle.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solution_values | jax.Array | Current solution values at the opposite vertices of the considered triangle | required |
| parameter_tensor | jax.Array | Parameter tensor for the given triangle | required |
| edges | tuple[jax.Array, jax.Array, jax.Array] | Edges of the considered triangle | required |
Returns:
| Type | Description |
|---|---|
| tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']] | tuple[jax.Array, jax.Array]: Jacobians of the update parameters w.r.t. the solution vector and the parameter tensor |
eikonax.derivator.DerivativeSolver
¶
Main component for obtaining gradients from partial derivatives.
The Eikonax PartialDerivator computes partial
derivatives of the global update operator with respect
to the solution vector, \(\mathbf{G}_u\), and the parameter tensor field, \(\mathbf{G}_M\).
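We exploit the fact that the obtained solution candidate from a forward solve
\(\mathbf{u}\in\mathbb{R}^{N_V}\) is, up to a given accuracy, a fixed point of the
global update operator. We further consider the scenario of \(\mathbf{M}(\mathbf{m})\) being
dependent on some parameter \(\mathbf{m}\in\mathbb{R}^M\). This means we can write \(\mathbf{u}\) as
a function of \(\mathbf{m}\), obeying the relation
\[\mathbf{u}(\mathbf{m}) = \mathbf{G}(\mathbf{u}(\mathbf{m}), \mathbf{M}(\mathbf{m})).\]
To obtain the Jacobian \(\mathbf{J} = \frac{d\mathbf{u}}{d\mathbf{m}}\in\mathbb{R}^{N_V\times M}\), we simply differentiate the fixed point relation,
\[\mathbf{J} = \mathbf{G}_u\mathbf{J} + \mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}} \quad\Leftrightarrow\quad \mathbf{J} = (\mathbf{I} - \mathbf{G}_u)^{-1}\mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}}.\]
\(\mathbf{G}_u\) and \(\mathbf{G}_M\) are provided by the
PartialDerivator, whereas
\(\frac{d\mathbf{M}}{d\mathbf{m}}\) is computed as the Jacobian of the
TensorField component.
We are typically not interested in the full Jacobian, but rather in the gradient of some cost functional \(l:\mathbb{R}^{N_V}\to\mathbb{R},\ l=l(\mathbf{u}(\mathbf{m}))\) with respect to \(\mathbf{m}\). With \(\mathbf{l}_u = \frac{dl}{d\mathbf{u}}\), the gradient is given as
\[\mathbf{g}(\mathbf{m})^T = \mathbf{l}_u^T\mathbf{J} = \mathbf{l}_u^T(\mathbf{I} - \mathbf{G}_u)^{-1}\mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}} = \mathbf{v}^T\mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}}.\]
We can identify \(\mathbf{v}\) as the adjoint variable, which is obtained by solving the linear discrete adjoint equation,
\[(\mathbf{I} - \mathbf{G}_u)^T\mathbf{v} = \mathbf{l}_u.\]
Now comes the catch: Through the strict causality of the Godunov update operator, we can find a unique and consistent ordering of vertex indices, such that the solution at a vertex \(i\) is only informed by the solution at a vertex \(j\) if \(j\) occurs before \(i\) in that ordering. The sparsity pattern of \(\mathbf{G}_u\) therefore corresponds to a directed acyclic graph. This means that there is an orthogonal permutation matrix \(\mathbf{P}\) such that for \(\bar{\mathbf{G}}_u = \mathbf{P}\mathbf{G}_u\mathbf{P}^T\) an entry \((\bar{\mathbf{G}}_u)_{ij}\) is only non-zero if \(i > j\). In total, we can write
\[(\mathbf{I} - \mathbf{G}_u)^T\mathbf{v} = \mathbf{P}^T(\mathbf{I} - \bar{\mathbf{G}}_u)^T\mathbf{P}\mathbf{v} = \mathbf{P}^T\bar{\mathbf{A}}\bar{\mathbf{v}} = \mathbf{l}_u \quad\Leftrightarrow\quad \bar{\mathbf{A}}\bar{\mathbf{v}} = \bar{\mathbf{l}}_u,\]
where \(\bar{\mathbf{v}} = \mathbf{P}\mathbf{v}\), \(\bar{\mathbf{l}}_u = \mathbf{P}\mathbf{l}_u\), and \(\bar{\mathbf{A}} = (\mathbf{I} - \bar{\mathbf{G}}_u)^T\) is an upper triangular matrix with unit diagonal. Hence, the system can be solved through simple back-substitution.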
The DerivativeSolver component does exactly this: It sets up the matrices \(\mathbf{P}\) and
\(\bar{\mathbf{A}}\), permutes inputs/outputs, and solves the sparse linear system through
back-substitution.
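The assembly of \(\mathbf{P}\) and \(\bar{\mathbf{A}}\) can be sketched with dense NumPy arrays on a four-vertex toy problem. The sketch assumes that sorting vertices by ascending solution value yields a valid causal ordering. The matrix entries and all variable names are made up for illustration, and a real implementation would use sparse matrices.

```python
import numpy as np

# Toy solution values; vertex 0 plays the role of the initial site.
solution = np.array([0.0, 1.0, 0.6, 1.4])

# Toy partial derivative matrix: row i is nonzero only in columns j
# with solution[j] < solution[i] (causality).
partial_solution = np.zeros((4, 4))
partial_solution[2, 0] = 1.0
partial_solution[1, [0, 2]] = [0.3, 0.7]
partial_solution[3, [1, 2]] = [0.4, 0.6]

# Assumed ordering: ascending solution values.
order = np.argsort(solution)
permutation = np.eye(4)[order]  # row k selects original vertex order[k]

permuted_partial = permutation @ partial_solution @ permutation.T
system_matrix = (np.eye(4) - permuted_partial).T
```

In the permuted ordering, `permuted_partial` is strictly lower triangular, so `system_matrix` is upper triangular with a unit diagonal, as stated above.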
Speedy gradients
Given a solution vector \(\mathbf{u}\), Eikonax computes derivatives with linear complexity. Moreover, for a given evaluation point, we can evaluate an arbitrary number of gradients through simple back-substitution. All matrices need to be assembled only once.
Change in tooling
In the DerivativeSolver, we leave JAX and fall back to the numpy/scipy stack. While
the sequential solver operation should not be much slower on the CPU, we have to transfer
the data back from the offloading device. We plan to implement a GPU-compatible solver with
CuPy in a future version, or in JAX as soon as it offers the necessary
linear algebra tools.
Methods:
| Name | Description |
|---|---|
| solve | Solve the linear system for the adjoint variable |
sparse_system_matrix
property
¶
Get system matrix \(\bar{\mathbf{A}}\in\mathbb{R}^{{N_V}\times {N_V}}\).
sparse_permutation_matrix
property
¶
Get permutation matrix \(\mathbf{P}\in\mathbb{R}^{{N_V}\times {N_V}}\).
__init__
¶
__init__(
solution: jtFloat[jax.Array | npt.NDArray, num_vertices],
sparse_partial_update_solution: sp.sparray,
) -> None
Constructor for the derivative solver.
Initializes the causality-inspired permutation matrix \(\mathbf{P}\), and afterwards the permuted system matrix \(\bar{\mathbf{A}}\), which is triangular.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solution | jax.Array \| npt.NDArray | Obtained solution of the Eikonal equation. | required |
| sparse_partial_update_solution | scipy.sparray | Sparse matrix representing \(\mathbf{G}_u\). Any SciPy sparse format is accepted as long as it supports conversion to CSC via | required |
solve
¶
solve(
right_hand_side: jtFloat[jax.Array | npt.NDArray, num_vertices],
) -> jtFloat[npt.NDArray, num_parameters]
Solve the linear system for the parametric gradient.
Following the notation from the class docstring, this method solves the linear system for the adjoint variable \(\mathbf{v}\). Given a right-hand side \(\mathbf{l}_u\), this is a three-step process:
- Permute the right-hand side: \(\bar{\mathbf{l}}_u = \mathbf{P}\mathbf{l}_u\)
- Solve the linear system \(\bar{\mathbf{A}}\bar{\mathbf{v}} = \bar{\mathbf{l}}_u\)
- Permute the solution back to the original ordering: \(\mathbf{v} = \mathbf{P}^T\bar{\mathbf{v}}\)
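The three steps above can be sketched with dense NumPy arrays and a hand-written back-substitution. The helper names `back_substitute` and `solve_adjoint` and the three-vertex data are assumptions for illustration. The actual component works on sparse matrices.

```python
import numpy as np


def back_substitute(upper, rhs):
    """Solve upper @ x = rhs for an upper triangular matrix, last row first."""
    x = np.zeros(rhs.size)
    for i in range(rhs.size - 1, -1, -1):
        x[i] = (rhs[i] - upper[i, i + 1:] @ x[i + 1:]) / upper[i, i]
    return x


def solve_adjoint(permutation, system_matrix, rhs):
    permuted_rhs = permutation @ rhs                                 # step 1
    permuted_adjoint = back_substitute(system_matrix, permuted_rhs)  # step 2
    return permutation.T @ permuted_adjoint                          # step 3


# Toy data: vertex 1 is the initial site, ordering by ascending solution value.
solution = np.array([0.5, 0.0, 1.0])
partial_solution = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.6, 0.4, 0.0]])
permutation = np.eye(3)[np.argsort(solution)]
system_matrix = (np.eye(3) - permutation @ partial_solution @ permutation.T).T

rhs = np.array([1.0, 2.0, 3.0])
adjoint = solve_adjoint(permutation, system_matrix, rhs)

# Reference: direct dense solve of the unpermuted adjoint equation.
reference = np.linalg.solve((np.eye(3) - partial_solution).T, rhs)
```

Because only `rhs` changes between calls, `permutation` and `system_matrix` can be reused for any number of right-hand sides.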
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| right_hand_side | jax.Array \| npt.NDArray | RHS for the linear system solve | required |
Returns:
| Type | Description |
|---|---|
| jtFloat[npt.NDArray, num_parameters] | np.ndarray: Solution of the linear system solve, corresponding to the adjoint in an optimization context. |
_assemble_permutation_matrix
¶
Construct permutation matrix \(\mathbf{P}\) for index ordering.
_assemble_system_matrix
¶
_assemble_system_matrix(
sparse_partial_update_solution: sp.sparray, num_points: int
) -> sp.csc_matrix
Assemble system matrix \(\bar{\mathbf{A}}\) for gradient solver.
Before invoking this method, the permutation matrix \(\mathbf{P}\) must be initialized.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sparse_partial_update_solution | scipy.sparse matrix-like | Sparse matrix representing \(\mathbf{G}_u\). Will be converted to CSC internally using | required |
| num_points | int | Number of points (vertices) in the system; used to build the identity matrix and shape the final system matrix. | required |
eikonax.derivator.compute_eikonax_jacobian
¶
compute_eikonax_jacobian(
derivative_solver: DerivativeSolver, partial_derivative_parameter: sp.sparray
) -> npt.NDArray
Compute Jacobian from concatenation of gradients, computed with unit vector RHS.
Warning
This method should only be used for small problems.
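The idea of building the Jacobian row by row from unit-vector right-hand sides can be sketched as follows. The callable `solve_adjoint` is a dense stand-in for `DerivativeSolver.solve`, and `partial_parameter` stands for the product \(\mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}}\) as an \(N_V \times M\) matrix. Both, together with the numbers, are assumptions for illustration.

```python
import numpy as np


def jacobian_from_unit_vectors(solve_adjoint, partial_parameter):
    """Assemble the dense Jacobian from one adjoint solve per vertex."""
    num_vertices = partial_parameter.shape[0]
    rows = []
    for i in range(num_vertices):
        unit_vector = np.zeros(num_vertices)
        unit_vector[i] = 1.0
        adjoint = solve_adjoint(unit_vector)
        # Row i of the Jacobian is adjoint^T @ partial_parameter.
        rows.append(partial_parameter.T @ adjoint)
    return np.vstack(rows)


partial_solution = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.6, 0.4, 0.0]])
partial_parameter = np.array([[0.5, 0.0], [0.0, 0.0], [0.1, 0.9]])


def solve_adjoint(rhs):
    # Dense stand-in for the sparse triangular adjoint solve.
    return np.linalg.solve((np.eye(3) - partial_solution).T, rhs)


jacobian = jacobian_from_unit_vectors(solve_adjoint, partial_parameter)
reference = np.linalg.solve(np.eye(3) - partial_solution, partial_parameter)
```

The loop performs one solve per vertex and stores a dense result, which is why the approach is only reasonable for small problems.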
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| derivative_solver | DerivativeSolver | Initialized derivative solver object. | required |
| partial_derivative_parameter | scipy.sparse matrix-like or numpy.ndarray | Partial derivative of the global update operator with respect to the parameter vector/parameters. Should support matrix-vector multiplication or transposed multiplication with the adjoint vector; typical inputs are SciPy sparse matrices/arrays (e.g. | required |
Returns:
| Type | Description |
|---|---|
| npt.NDArray | npt.NDArray: Dense Jacobian matrix with shape (N_V, M). Note: this routine builds the full dense Jacobian and is intended only for small problems. |
eikonax.derivator.compute_eikonax_hessian
¶
Compute Hessian matrix.
Not implemented yet