Parametric Derivatives¶
In the Forward Solver tutorial, we showed how to solve the eikonal equation with Eikonax for a given mesh and tensor field. In this tutorial, we build on this procedure to obtain parametric derivatives \(\frac{d\mathbf{u}(\mathbf{M})}{d\mathbf{M}}\) of the solution w.r.t. a given input tensor field. We further assume that the tensor field is defined via some parameter vector \(\mathbf{m}\in\mathbb{R}^M\), such that we can define the mapping \(\mathbf{M}: \mathbb{R}^M \to \mathbb{R}^{N_S\times d\times d}\), under the constraint that \(\mathbf{M}\) is pointwise s.p.d.
In the following, we consider the scenario of having some loss functional \(l: \mathbb{R}^{N_V} \to \mathbb{R},\ l = l(\mathbf{u})\), depending on the solution \(\mathbf{u}(\mathbf{M})\) of the eikonal equation for a specific tensor field \(\mathbf{M}(\mathbf{m})\). In various problem settings, such as minimization of the loss, it is essential to obtain the gradient \(\mathbf{g}\) w.r.t. the input parameter vector,
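\[
\mathbf{g}(\mathbf{m}) = \frac{dl}{d\mathbf{m}} = \left(\frac{d\mathbf{u}}{d\mathbf{m}}\right)^T \frac{dl}{d\mathbf{u}},
\qquad
\frac{d\mathbf{u}}{d\mathbf{m}} = \frac{d\mathbf{u}}{d\mathbf{M}} \frac{d\mathbf{M}}{d\mathbf{m}}.
\]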
This is the scenario we cover in this tutorial. Eikonax follows a discretize-then-optimize approach to computing the gradient. Moreover, it efficiently computes discrete adjoints by exploiting the causality in the forward solution of the eikonal equation. A detailed description of this procedure is given here.
Test Mesh Setup¶
We start by setting up the same square mesh as in the Forward Solver tutorial.
However, we rely on Eikonax's built-in create_test_mesh function
instead of using scipy. We also choose a much smaller mesh with \(3\times 3\) vertices, so that we can
efficiently compare derivatives against finite differences later.
from eikonax import corefunctions, preprocessing
vertices, simplices = preprocessing.create_test_mesh((0, 1), (0, 1), 3, 3)
mesh_data = preprocessing.MeshData(vertices, simplices)
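For orientation, we can inspect the generated arrays. Assuming create_test_mesh splits each of the four grid cells into two triangles, we expect nine vertices and eight simplices:
print(vertices.shape)   # expected: (9, 2), nine vertices in 2D
print(simplices.shape)  # expected: (8, 3), eight triangles with three vertex indices each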
Tensor Field Setup¶
In the Forward Solver tutorial, we constructed a specific tensor field instance as a simple
numpy array. To evaluate derivatives, however, we need to properly define a mapping \(\mathbf{M}(\mathbf{m})\),
and its derivative. Such a mapping is provided by the tensorfield module.
The tensor field module comprises interfaces and basic implementations for two separate components.
The SimplexTensor describes how, for a given simplex index
\(s\) and local parameter vector \(\mathbf{m}_s\), the tensor \(\mathbf{M}_s\) for that simplex is constructed.
The VectorToSimplicesMap, in turn, defines the
contributions to \(\mathbf{m}_s\) from the global parameter vector \(\mathbf{m}\) for a given simplex \(s\).
The actual TensorField object is created from these two components.
Info
The above procedure is quite low-level and requires some effort from the user side. On the other hand, it guarantees flexibility with respect to the employed type of tensor field. Through strict application of the composition-over-inheritance principle, we can mix different global-to-local mappings and tensor assemblies, which are then efficiently vectorized and differentiated by JAX.
In our example, we define local tensors with the built-in InvLinearScalarSimplexTensor. This assembles the local tensor from a scalar \(m_s > 0\) simply as \(\mathbf{M}_s = \frac{1}{m_s}\mathbf{I}\). We further employ the LinearScalarMap, which simply selects \(m_s = \mathbf{m}[s]\). In total, we create our tensor field as follows:
from eikonax import tensorfield
tensor_on_simplex = tensorfield.InvLinearScalarSimplexTensor(dimension=vertices.shape[1])
tensor_field_mapping = tensorfield.LinearScalarMap()
tensor_field_object = tensorfield.TensorField(
    num_simplices=simplices.shape[0],
    vector_to_simplices_map=tensor_field_mapping,
    simplex_tensor=tensor_on_simplex,
)
The tensor_field_object now provides the differentiable mapping \(\mathbf{M}(\mathbf{m})\) for any valid input
vector \(\mathbf{m}\). For demonstration purposes, we simply create a random input vector and build the
tensor field with the assemble_field method,
import numpy as np
rng = np.random.default_rng(seed=0)
parameter_vector = rng.uniform(0.5, 1.5, simplices.shape[0])
tensor_field_instance = tensor_field_object.assemble_field(parameter_vector)
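As a quick sanity check, and assuming assemble_field returns an array of shape \(N_S \times d \times d\), each local tensor should equal the scaled identity \(\frac{1}{m_s}\mathbf{I}\):
print(tensor_field_instance.shape)  # expected: (8, 2, 2)
print(np.allclose(tensor_field_instance[0], np.eye(2) / parameter_vector[0]))  # expected: True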
Solver Setup and Run¶
We now have all components to conduct a forward solver run with Eikonax, analogously to the one described in the Forward Solver tutorial.
from eikonax import solver
solver_data = solver.SolverData(
    tolerance=1e-8,
    max_num_iterations=1000,
    loop_type="jitted_while",
    max_value=1000,
    use_soft_update=False,
    softminmax_order=10,
    softminmax_cutoff=0.01,
)
initial_sites = preprocessing.InitialSites(inds=(0,), values=(0,))
eikonax_solver = solver.Solver(mesh_data, solver_data, initial_sites)
solution = eikonax_solver.run(tensor_field_instance)
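The returned solution object exposes the vertex-wise arrival times through its values attribute, which we reuse below:
print(solution.values.shape)  # expected: (9,), one arrival time per vertex
print(solution.values[0])     # expected: 0.0 at the initial site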
Partial Derivatives¶
Evaluating the gradient \(\mathbf{g}(\mathbf{m})\) is a two-step procedure in Eikonax. First, we evaluate, for a
given parameter vector \(\mathbf{m}\) and associated solution \(\mathbf{u}\), the
partial derivatives \(\mathbf{G}_u\) and \(\mathbf{G}_M\) of the global update operator in the
iterative solver. This can be done with the
PartialDerivator object. Its configuration in
PartialDerivatorData is analogous to that of the
forward solver.
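Written out, with \(\mathbf{G}\) denoting the global update operator of the iterative scheme, these are
\[
\mathbf{G}_u = \frac{\partial \mathbf{G}}{\partial \mathbf{u}}(\mathbf{u}, \mathbf{M}), \qquad
\mathbf{G}_M = \frac{\partial \mathbf{G}}{\partial \mathbf{M}}(\mathbf{u}, \mathbf{M}),
\]
evaluated at the previously computed solution \(\mathbf{u}\) and tensor field \(\mathbf{M}(\mathbf{m})\).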
Note
First derivatives do not require a differentiable transformation for the update parameter \(\lambda\). Second derivatives do, however.
from eikonax import derivator
derivator_data = derivator.PartialDerivatorData(
    use_soft_update=False,
    softminmax_order=None,
    softminmax_cutoff=None,
)
eikonax_derivator = derivator.PartialDerivator(mesh_data, derivator_data, initial_sites)
Exploiting the locality of the dependencies, we can obtain sparse representations of the partial
derivatives via the compute_partial_derivatives
method,
output_partial_solution, output_partial_tensor = eikonax_derivator.compute_partial_derivatives(
    solution.values, tensor_field_instance
)
\(\mathbf{G}_u\) and \(\mathbf{G}_M\) are returned using Eikonax's custom sparse data structures from the
linalg module. These structures are designed to efficiently represent the sparse
derivatives while maintaining JAX compatibility.
\(\mathbf{G}_u\) is returned as an EikonaxSparseMatrix with shape
\(N_V \times N_V\). This is a COO-style sparse matrix representation using JAX arrays that stores row
indices, column indices, and values separately.
\(\mathbf{G}_M\) is returned as a DerivatorSparseTensor, which
stores the partial derivatives with respect to the metric tensor field. For each vertex, it contains
derivative contributions from adjacent simplices in a dense local format (shape
\(N_V \times \text{max\_neighbors} \times d \times d\)) along with adjacency data mapping to global
simplex indices.
To compute the total derivative \(\mathbf{G}_m\) according to the chain rule, we require the tensor field
Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\). This is provided by the tensor field component through the
assemble_jacobian method, evaluated at the current parameter vector,
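# Tensor field Jacobian dM/dm; the call signature (passing the parameter vector) is assumed here
tensor_partial_parameter = tensor_field_object.assemble_jacobian(parameter_vector)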
The Jacobian is returned as a TensorfieldSparseTensor, which stores derivative values
of shape \(N_S \times d \times d \times \text{num\_params\_mapped}\) together with parameter indices
indicating which global parameters each simplex depends on.
We can now obtain \(\mathbf{G}_m\) via the tensor contraction function provided by Eikonax,
from eikonax import linalg
output_partial_parameter = linalg.contract_derivative_tensors(
    output_partial_tensor, tensor_partial_parameter
)
The result is again an EikonaxSparseMatrix, now with shape \(N_V \times M\).
Derivative Solver¶
With the partial derivatives, we can now set up a sparse, triangular equation system for computing discrete adjoints, and subsequently the gradient \(\mathbf{g}\). The rationale behind this procedure is explained in more detail here.
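Schematically, the converged solution satisfies the fixed-point relation \(\mathbf{u} = \mathbf{G}(\mathbf{u}, \mathbf{M})\), so that the adjoint \(\mathbf{v}\) and the gradient are obtained from
\[
(\mathbf{I} - \mathbf{G}_u)^T \mathbf{v} = \frac{dl}{d\mathbf{u}}, \qquad
\mathbf{g} = \mathbf{G}_m^T \mathbf{v},
\]
where the sparse system matrix is triangular due to the causality of the eikonal solution.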
The DerivativeSolver requires the partial derivative
\(\mathbf{G}_u\) in SciPy sparse format. We convert the EikonaxSparseMatrix
using the convert_to_scipy_sparse utility function,
output_partial_solution_scipy = linalg.convert_to_scipy_sparse(output_partial_solution)
derivative_solver = derivator.DerivativeSolver(solution.values, output_partial_solution_scipy)
We can now evaluate the discrete adjoint for a given loss gradient \(\frac{dl}{d\mathbf{u}}\).
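For demonstration, we take a vector of ones as the loss gradient and compute the adjoint with the derivative solver (the name of the solve method below is assumed),
# Hypothetical loss gradient dl/du, one entry per vertex
loss_grad = np.ones(solution.values.size)
# Solve the triangular adjoint system for the discrete adjoint (method name assumed)
adjoint = derivative_solver.solve(loss_grad)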
We then obtain the gradient by multiplying the adjoint with \(\mathbf{G}_m^T\). Since
\(\mathbf{G}_m\) is an EikonaxSparseMatrix, we convert it to
SciPy format for the matrix-vector product,
output_partial_parameter_scipy = linalg.convert_to_scipy_sparse(output_partial_parameter)
total_grad = output_partial_parameter_scipy.T @ adjoint
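The gradient has one entry per parameter, i.e., per simplex in this example:
print(total_grad.shape)  # expected: (8,)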
Reusability of the gradient computation
Together, the derivative solver and \(\mathbf{G}_m\) constitute a sparse representation of the entire parametric Jacobian \(\mathbf{J}\) at the point \(\mathbf{m}\). Once assembled, arbitrary portions of the Jacobian can be evaluated at negligible cost, even for large systems.
Comparison to Finite Differences¶
As a proof-of-concept, we compare the Jacobian \(\mathbf{J}\) at the point \(\mathbf{m}\) produced by Eikonax
to forward finite differences. With Eikonax, we simply evaluate the gradient for all unit vectors
\(\mathbf{e}_i,\ i=1,\ldots,N_V\) as "loss gradient" \(\frac{dl}{d\mathbf{u}}\). Each such computation
yields a row \(\mathbf{J}_i\) of the Jacobian.
To this end, Eikonax provides the utility function compute_eikonax_jacobian,
eikonax_jacobian = derivator.compute_eikonax_jacobian(
    derivative_solver, output_partial_parameter_scipy
)
With finite differences, we evaluate a column \(\mathbf{J}_j\) of the Jacobian as
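\[
\mathbf{J}_j = \frac{\mathbf{u}(\mathbf{m} + h\mathbf{e}_j) - \mathbf{u}(\mathbf{m})}{h},
\]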
where \(j=1,\ldots,N_S\) and \(h\) denotes the step width of the finite difference scheme. Again,
we use an Eikonax utility module, finitediff,
from eikonax import finitediff

step_widths = np.logspace(-5, -1, 101)
errors = []
for step_width in step_widths:
    finite_diff_jacobian = finitediff.compute_fd_jacobian(
        eikonax_solver=eikonax_solver,
        tensor_field=tensor_field_object,
        stencil=finitediff.finite_diff_1_forward,
        eval_point=parameter_vector,
        step_width=step_width,
    )
    error = np.linalg.norm(finite_diff_jacobian - eikonax_jacobian)
    errors.append(error)
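A minimal matplotlib sketch (not part of the Eikonax API) to reproduce a figure like the one shown below:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.loglog(step_widths, errors)           # error vs. step width on log-log axes
ax.set_xlabel("step width h")
ax.set_ylabel("Frobenius norm error")
plt.show()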
The figure below shows the difference of the Jacobian matrices in the Frobenius norm for \(h\in[1\mathrm{e}{-5}, 1\mathrm{e}{-1}]\). As expected, the error decreases linearly with decreasing \(h\), until \(h\) reaches roughly the square root of the floating point precision (single precision, 32 bit, in this case). Below this threshold, the error increases again due to round-off errors.
