Parameter Tensor Field¶
tensorfield
¶
Composable and differentiable parameter tensor fields.
This module provides ABCs and implementations for the creation of differentiable parameter fields used in Eikonax. Recall that for the Eikonax solver, and particularly parametric derivatives, we require an input tensor field \(\mathbf{M}: \mathbb{R}^M \times \mathbb{N}_0 \to \mathbb{S}_+^{d\times d}\). This means that the tensor field is a mapping \(\mathbf{M}(\mathbf{m},s)\) that assigns, given a global parameter vector \(\mathbf{m}\), an s.p.d tensor to every simplex \(s\) in the mesh. To allow for sufficient flexibility in the choice of tensor field, we implement it as a composition of two main components.
AbstractVectorToSimplicesMapprovides the interface for a mapping from the global parameter vector \(\mathbf{m}\) to the local parameter values \(\mathbf{m}_s\) required to assemble the tensor \(\mathbf{M}_s\) for simplex \(s\).AbstractSimplexTensorprovides the interface for the assembly of the local tensor \(\mathbf{M}_s\), given the local contributions \(\mathbf{m}_s\) and a simplex s.
Concrete implementations of both components are used to initialize the
TensorField object, which vectorizes and differentiates them
using JAX, to provide the mapping \(\mathbf{M}(\mathbf{m})\) and its Jacobian tensor
\(\frac{d \mathbf{M}}{d \mathbf{m}}\).
Classes:
| Name | Description |
|---|---|
AbstractVectorToSimplicesMap |
ABC interface contract for vector-to-simplices maps |
LinearScalarMap |
Simple one-to-one map from global to simplex parameters |
AbstractSimplexTensor |
ABC interface contract for assembly of the tensor field |
LinearScalarSimplexTensor |
SimplexTensor implementation relying on one parameter per simplex |
InvLinearScalarSimplexTensor |
SimplexTensor implementation relying on one parameter per simplex |
TensorField |
Tensor field component |
eikonax.tensorfield.AbstractVectorToSimplicesMap
¶
Bases: eqx.Module
ABC interface contract for vector-to-simplices maps.
Every component derived from this class needs to implement the map and derivative methods.
The map method is responsible for returning the relevant parameters for a given simplex from
the global parameter vector. The derivative method computes the derivative of the mapping with
respect to the global parameters.
Methods:
| Name | Description |
|---|---|
map |
Interface for vector-to-simplex mapping |
derivative |
Interface for vector-to-simplex mapping derivative |
map
abstractmethod
¶
map(
simplex_ind: jtInt[jax.Array, ""], parameters: jtReal[jax.Array, num_parameters]
) -> jtReal[jax.Array, num_parameters_local]
Interface for vector-to-simplex mapping.
For the given simplex_ind, return those parameters from the global parameter vector that
are relevant for the simplex. This method needs to be broadcastable over simplex_ind by
JAX (with vmap).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
simplex_ind
|
int
|
Index of the simplex under consideration. |
required |
parameters
|
jax.Array
|
Global parameter vector. |
required |
Raises:
| Type | Description |
|---|---|
NotImplementedError
|
ABC error indicating that the method needs to be implemented in subclasses. |
Returns:
| Type | Description |
|---|---|
jtReal[jax.Array, num_parameters_local]
|
jax.Array: Relevant parameters for the simplex. |
derivative
abstractmethod
¶
derivative(
simplex_ind: jtInt[jax.Array, ""], parameters: jtReal[jax.Array, num_parameters]
) -> tuple[
jtReal[jax.Array, "num_parameters_local num_parameters_mapped"],
jtInt[jax.Array, num_parameters_mapped],
]
Interface for vector-to-simplex mapping derivative.
The derivative of the simplex mapping for a given simplex_ind is a num_local_parameters
\(\times\) num_parameters Jacobian matrix, where the only non-zero entries are in the
columns corresponding to the values in the global parameter that are used for the current
simplex. The corresponding entries are typically identical to one. This method returns
the Jacobian matrix in a sparse representation: one array with matrix values and one
array with the corresponding column indices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
simplex_ind
|
int
|
Index of the simplex under consideration. |
required |
parameters
|
jax.Array
|
Global parameter vector. |
required |
Raises:
| Type | Description |
|---|---|
NotImplementedError
|
ABC error indicating that the method needs to be implemented in subclasses. |
Returns:
| Type | Description |
|---|---|
tuple[jtReal[jax.Array, 'num_parameters_local num_parameters_mapped'], jtInt[jax.Array, num_parameters_mapped]]
|
tuple[jax.Array, jax.Array]: Jacobian matrix values and corresponding global parameter indices. First array has shape (num_parameters_local, num_parameters_mapped), second array has shape (num_parameters_mapped,). |
eikonax.tensorfield.LinearScalarMap
¶
Bases: AbstractVectorToSimplicesMap
Simple one-to-one map from global to simplex parameters.
Every simplex takes exactly one parameter \(m_s\), which is sorted in the global parameter in the same order as the simplices, meaning that \(m_s = \mathbf{m}[s]\).
map
¶
map(
simplex_ind: jtInt[jax.Array, ""], parameters: jtReal[jax.Array, num_parameters]
) -> jtReal[jax.Array, num_parameters_local]
Return relevant parameters for a given simplex.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
simplex_ind
|
int
|
Index of the simplex under consideration. |
required |
parameters
|
jax.Array
|
Global parameter vector. |
required |
Returns:
| Type | Description |
|---|---|
jtReal[jax.Array, num_parameters_local]
|
jax.Array: Relevant parameter (only one) with shape (1,). |
derivative
¶
derivative(
simplex_ind: jtInt[jax.Array, ""], _parameters: jtReal[jax.Array, num_parameters]
) -> tuple[
jtReal[jax.Array, "num_parameters_local num_parameters_mapped"],
jtInt[jax.Array, num_parameters_mapped],
]
Return sparse representation of Jacobi matrix, in this case two arrays of size one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
simplex_ind
|
int
|
Index of the simplex under consideration. |
required |
_parameters
|
jax.Array
|
Global parameter vector (not used). |
required |
Returns:
| Type | Description |
|---|---|
tuple[jtReal[jax.Array, 'num_parameters_local num_parameters_mapped'], jtInt[jax.Array, num_parameters_mapped]]
|
tuple[jax.Array, jax.Array]: Jacobian matrix values and corresponding global parameter indices. Both arrays have shape (1, 1) and (1,) respectively. |
eikonax.tensorfield.AbstractSimplexTensor
¶
Bases: eqx.Module
ABC interface contract for assembly of the tensor field.
SimplexTensor components assemble the tensor field for a given simplex and a set of parameters
for that simplex. The relevant parameters are provided by the VectorToSimplicesMap component
from the global parameter vector.
Note
This class provides the metric tensor as used in the inner product for the update stencil of the eikonal equation. This is the inverse of the conductivity tensor, which is the actual tensor field in the eikonal equation.
Methods:
| Name | Description |
|---|---|
assemble |
Assemble the tensor field for a given simplex and parameters |
derivative |
Parametric derivative of the |
assemble
abstractmethod
¶
assemble(
simplex_ind: jtInt[jax.Array, ""], parameters: jtFloat[jax.Array, num_parameters_local]
) -> jtFloat[jax.Array, "dim dim"]
Assemble the tensor field for given simplex and parameters.
Given a parameter array of size \(m_s\), the method returns a tensor of size \(d\times d\).
The method needs to be broadcastable over simplex_ind by JAX (with vmap).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
simplex_ind
|
int
|
Index of the simplex under consideration. |
required |
parameters
|
jax.Array
|
Parameters for the simplex. |
required |
Raises:
| Type | Description |
|---|---|
NotImplementedError
|
ABC error indicating that the method needs to be implemented in subclasses. |
Returns:
| Type | Description |
|---|---|
jtFloat[jax.Array, 'dim dim']
|
jax.Array: Tensor field for the simplex under consideration with shape (dim, dim). |
derivative
abstractmethod
¶
derivative(
simplex_ind: jtInt[jax.Array, ""], parameters: jtFloat[jax.Array, num_parameters_local]
) -> jtFloat[jax.Array, "dim dim num_local_parameters"]
Parametric derivative of the assemble method.
Given a parameter array of size \(m_s\), the method returns a Jacobian tensor of size
\(d\times d\times m_s\). The method needs to be broadcastable over simplex_ind by JAX
(with vmap).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
simplex_ind
|
int
|
Index of the simplex under consideration. |
required |
parameters
|
jax.Array
|
Parameters for the simplex. |
required |
Raises:
| Type | Description |
|---|---|
NotImplementedError
|
ABC error indicating that the method needs to be implemented in subclasses. |
Returns:
| Type | Description |
|---|---|
jtFloat[jax.Array, 'dim dim num_local_parameters']
|
jax.Array: Jacobian tensor for the simplex under consideration with shape (dim, dim, num_local_parameters). |
eikonax.tensorfield.LinearScalarSimplexTensor
¶
Bases: AbstractSimplexTensor
SimplexTensor implementation relying on one parameter per simplex.
Given a scalar parameter \(m_s\), the tensor field is assembled as \(m_s \cdot \mathbf{I}\), where \(\mathbf{I}\) is the identity matrix.
Methods:
| Name | Description |
|---|---|
assemble |
Assemble the tensor field for a parameter vector |
derivative |
Parametric derivative of the |
assemble
¶
assemble(
_simplex_ind: jtInt[jax.Array, ""], parameters: jtFloat[jax.Array, num_parameters_local]
) -> jtFloat[jax.Array, "dim dim"]
Assemble tensor for given simplex.
the parameters argument is a scalar here, and _simplex_ind is not used.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
_simplex_ind
|
int
|
Index of simplex under consideration (not used). |
required |
parameters
|
jax.Array
|
Parameter (scalar) for tensor assembly with shape (1,). |
required |
Returns:
| Type | Description |
|---|---|
jtFloat[jax.Array, 'dim dim']
|
jax.Array: Tensor for the simplex with shape (dim, dim). |
derivative
¶
derivative(
_simplex_ind: jtInt[jax.Array, ""], _parameters: jtFloat[jax.Array, num_parameters_local]
) -> jtFloat[jax.Array, "dim dim num_local_parameters"]
Parametric derivative of the assemble method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
_simplex_ind
|
int
|
Index of simplex under consideration (not used). |
required |
_parameters
|
jax.Array
|
Parameter (scalar) for tensor assembly (not used). |
required |
Returns:
| Type | Description |
|---|---|
jtFloat[jax.Array, 'dim dim num_local_parameters']
|
jax.Array: Jacobian tensor for the simplex under consideration with shape (dim, dim, 1). |
eikonax.tensorfield.InvLinearScalarSimplexTensor
¶
Bases: AbstractSimplexTensor
SimplexTensor implementation relying on one parameter per simplex.
Given a scalar parameter \(m_s\), the tensor field is assembled as \(\frac{1}{m_s} \cdot \mathbf{I}\), where \(\mathbf{I}\) is the identity matrix.
Methods:
| Name | Description |
|---|---|
assemble |
Assemble the tensor field for a parameter vector |
derivative |
Parametric derivative of the |
assemble
¶
assemble(
_simplex_ind: jtInt[jax.Array, ""], parameters: jtFloat[jax.Array, num_parameters_local]
) -> jtFloat[jax.Array, "dim dim"]
Assemble tensor for given simplex.
The parameters argument is a scalar here, and _simplex_ind is not used.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
_simplex_ind
|
int
|
Index of simplex under consideration (not used). |
required |
parameters
|
jax.Array
|
Parameter (scalar) for tensor assembly with shape (1,). |
required |
Returns:
| Type | Description |
|---|---|
jtFloat[jax.Array, 'dim dim']
|
jax.Array: Tensor for the simplex with shape (dim, dim). |
derivative
¶
derivative(
_simplex_ind: jtInt[jax.Array, ""], parameters: jtFloat[jax.Array, num_local_parameters]
) -> jtFloat[jax.Array, "dim dim num_local_parameters"]
Parametric derivative of the assemble method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
_simplex_ind
|
int
|
Index of simplex under consideration (not used). |
required |
parameters
|
jax.Array
|
Parameter (scalar) for tensor assembly with shape (1,). |
required |
Returns:
| Type | Description |
|---|---|
jtFloat[jax.Array, 'dim dim num_local_parameters']
|
jax.Array: Jacobian tensor for the simplex under consideration with shape (dim, dim, 1). |
eikonax.tensorfield.TensorField
¶
Bases: eqx.Module
Tensor field component.
Tensor fields combine the functionality of vector-to-simplices maps and simplex tensors according to the composition over inheritance principle. They constitute the full mapping \(\mathbf{M}(\mathbf{m})\) from the global parameter vector to the tensor field over all mesh faces (simplices). In addition, they provide the parametric derivative \(\frac{d\mathbf{M}}{\mathbf{m}}\) of that mapping. Tensor fields are completely independent from the Eikonax solver and derivator, but the output of these two components can be used to compute the partial derivative \(\mathbf{G}_m = \frac{du}{d\mathbf{M}}\frac{d\mathbf{M}}{\mathbf{m}}\).
Methods:
| Name | Description |
|---|---|
assemble_field |
Assemble the tensor field for the given parameter vector |
assemble_jacobian |
Assemble the parametric derivative \(\frac{d\mathbf{M}}{\mathbf{m}}\) of the tensor field |
__init__
¶
__init__(
num_simplices: int,
vector_to_simplices_map: AbstractVectorToSimplicesMap,
simplex_tensor: AbstractSimplexTensor,
) -> None
Constructor.
Takes information about the mesh simplices, a vector-to-simplices map, and a simplex tensor map.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_simplices
|
int
|
Number of simplices in the mesh. |
required |
vector_to_simplices_map
|
AbstractVectorToSimplicesMap
|
Mapping from global to simplex parameters. |
required |
simplex_tensor
|
AbstractSimplexTensor
|
Tensor field assembly for a given simplex. |
required |
assemble_field
¶
assemble_field(
parameter_vector: jtFloat[jax.Array | npt.NDArray, num_parameters],
) -> jtFloat[jax.Array, "num_simplices dim dim"]
Assemble global tensor field from global parameter vector.
This method simply chains calls to the vector-to-simplices map and the simplex tensor objects, vectorized over all simplices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parameter_vector
|
jax.Array | npt.NDArray
|
Global parameter vector. |
required |
Returns:
| Type | Description |
|---|---|
jtFloat[jax.Array, 'num_simplices dim dim']
|
jax.Array: Global tensor field with shape (num_simplices, dim, dim). |
assemble_jacobian
¶
assemble_jacobian(
parameter_vector: jtFloat[jax.Array | npt.NDArray, num_parameters],
) -> linalg.TensorfieldSparseTensor
Assemble the Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\).
The assembly of the Jacobian matrix works via a local chaining of the derivative calls
to the simplex tensor and the vector-to-simplices map. The result is packaged into the
project's TensorfieldSparseTensor data
structure, which pairs derivative values with parameter indices and can be contracted
with other derivative tensors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parameter_vector
|
jax.Array | npt.NDArray
|
Global parameter vector. |
required |
Returns:
| Type | Description |
|---|---|
linalg.TensorfieldSparseTensor
|
eikonax.linalg.TensorfieldSparseTensor: Sparse tensor holding the Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\) with derivative values of shape (num_simplices, dim, dim, num_parameters_mapped) and associated parameter indices. |
_assemble_jacobian_global
¶
_assemble_jacobian_global(
parameter_vector: jtFloat[jax.Array, num_parameters],
) -> tuple[
jtFloat[jax.Array, "num_simplices dim dim num_parameters_mapped"],
jtInt[jax.Array, "num_simplices num_parameters_mapped"],
]
Intermediate call for vectorization with vmap in jit context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parameter_vector
|
jax.Array
|
Global parameter vector. |
required |
Returns:
| Type | Description |
|---|---|
tuple[jtFloat[jax.Array, 'num_simplices dim dim num_parameters_mapped'], jtInt[jax.Array, 'num_simplices num_parameters_mapped']]
|
tuple[jax.Array, jax.Array]: Derivative values with shape (num_simplices, dim, dim, num_parameters_mapped) and global parameter indices with shape (num_simplices, num_parameters_mapped). |
_assemble_jacobian_local
¶
_assemble_jacobian_local(
simplex_ind: jtInt[jax.Array, ""], parameter_vector: jtFloat[jax.Array, num_parameters]
) -> tuple[
jtFloat[jax.Array, "dim dim num_parameters_mapped"], jtInt[jax.Array, num_parameters_mapped]
]
Assembly of sparse jacobian representation for single simplex.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
simplex_ind
|
jax.Array
|
Index of the simplex (scalar). |
required |
parameter_vector
|
jax.Array
|
Global parameter vector. |
required |
Returns:
| Type | Description |
|---|---|
tuple[jtFloat[jax.Array, 'dim dim num_parameters_mapped'], jtInt[jax.Array, num_parameters_mapped]]
|
tuple[jax.Array, jax.Array]: Derivative values with shape (dim, dim, num_parameters_mapped) and global parameter indices with shape (num_parameters_mapped,). |