diff --git a/src/dolfinx_adjoint/__init__.py b/src/dolfinx_adjoint/__init__.py index 00e96e7..7a0a688 100644 --- a/src/dolfinx_adjoint/__init__.py +++ b/src/dolfinx_adjoint/__init__.py @@ -7,7 +7,7 @@ from .assembly import assemble_scalar, error_norm from .solvers import LinearProblem, NonlinearProblem -from .types import Constant, Function +from .types import Constant, Function, dirichletbc from .types.function import assign meta = metadata("dolfinx_adjoint") @@ -24,6 +24,7 @@ __all__ = [ "Constant", "Function", + "dirichletbc", "LinearProblem", "NonlinearProblem", "assemble_scalar", diff --git a/src/dolfinx_adjoint/blocks/dirichletbc.py b/src/dolfinx_adjoint/blocks/dirichletbc.py new file mode 100644 index 0000000..0fd8b48 --- /dev/null +++ b/src/dolfinx_adjoint/blocks/dirichletbc.py @@ -0,0 +1,24 @@ +import dolfinx +import numpy as np +import numpy.typing as npt +from pyadjoint.block import Block + + +class DirichletBCBlock(Block): + def __init__( + self, + value: dolfinx.fem.Function | dolfinx.fem.Constant, + dofs: npt.NDArray[np.int32], + V: dolfinx.fem.FunctionSpace | None = None, + ad_block_tag=None, + ): + super().__init__(ad_block_tag=ad_block_tag) + self.dofs = dofs + self.V = V + self.add_dependency(value) + + def prepare_recompute_component(self, inputs, relevant_outputs): + return inputs[0] if inputs else None + + def recompute_component(self, inputs, block_variable, idx, prepared): + return block_variable.saved_output diff --git a/src/dolfinx_adjoint/blocks/function_assigner.py b/src/dolfinx_adjoint/blocks/function_assigner.py index 7407de3..fca04b5 100644 --- a/src/dolfinx_adjoint/blocks/function_assigner.py +++ b/src/dolfinx_adjoint/blocks/function_assigner.py @@ -178,15 +178,18 @@ def prepare_recompute_component(self, inputs, relevant_outputs): def recompute_component(self, inputs, block_variable, idx, prepared): if self.expr is None: prepared = inputs[0] - output = dolfinx.fem.Function( - block_variable.output.function_space, name="f{block_variable.output.name}_AssignBlockRecompute" - ) + + # We should return the exact object instance to maintain C++ memory bindings + # (especially for DirichletBCs), updating it in-place. + output = block_variable.saved_output + try: if output.function_space == prepared.function_space: output.x.array[:] = prepared.x.array[:] except AttributeError: # Handling float value output.x.array[:] = prepared + return output def __str__(self): diff --git a/src/dolfinx_adjoint/blocks/solvers.py b/src/dolfinx_adjoint/blocks/solvers.py index b5db2e3..a55de85 100644 --- a/src/dolfinx_adjoint/blocks/solvers.py +++ b/src/dolfinx_adjoint/blocks/solvers.py @@ -64,6 +64,7 @@ def __init__( self.add_dependency(c, no_duplicates=True) for c in self._rhs.coefficients(): # type: ignore self.add_dependency(c, no_duplicates=True) + except AttributeError: raise NotImplementedError("Blocked systems not implemented yet.") self._compiled_lhs = dolfinx.fem.form( @@ -86,6 +87,13 @@ def __init__( self._petsc_options = petsc_options if petsc_options is not None else {} self._petsc_options_prefix = petsc_options_prefix self._bcs = bcs if bcs is not None else [] + + # Add dependencies from the boundary conditions + if self._bcs is not None: + for bc in self._bcs: + if hasattr(bc, "block_variable"): + self.add_dependency(bc, no_duplicates=True) + # Solver for recomputing the linear problem self._forward_solver = dolfinx.fem.petsc.LinearProblem( a=self._lhs, @@ -162,16 +170,6 @@ def prepare_recompute_component(self, inputs, relevant_outputs): else: initial_guess = [dolfinx.fem.Function(u.function_space, name=u.name + "_initial_guess") for u in self._u] - # Replace values in the DirichletBC if it is dependent on a control - # NOTE: Currently assume that BCS are control independent. - bcs = self._bcs - # for block_variable in self.get_dependencies(): - # c = block_variable.output - # c_rep = block_variable.saved_output - - # if isinstance(c, dolfinx.fem.DirichletBC): - # bcs.append(c_rep) - # Replace form coefficients with checkpointed values. # Loop through the dependencies of the lhs and rhs, check if they are in the respective form lhs = self._replace_coefficients_in_form(self._lhs) @@ -206,7 +204,7 @@ def prepare_recompute_component(self, inputs, relevant_outputs): self._forward_solver._a = compiled_lhs self._forward_solver._L = compiled_rhs self._forward_solver._P = compiled_preconditioner - self._forward_solver.bcs = bcs + self._forward_solver.bcs = self._bcs self._forward_solver._u = initial_guess def recompute_component( diff --git a/src/dolfinx_adjoint/types/__init__.py b/src/dolfinx_adjoint/types/__init__.py index c1a57a0..a781b83 100644 --- a/src/dolfinx_adjoint/types/__init__.py +++ b/src/dolfinx_adjoint/types/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["Function", "Constant"] +__all__ = ["Function", "Constant", "dirichletbc"] +from .dirichletbc import dirichletbc from .function import Constant, Function diff --git a/src/dolfinx_adjoint/types/dirichletbc.py b/src/dolfinx_adjoint/types/dirichletbc.py new file mode 100644 index 0000000..547bf15 --- /dev/null +++ b/src/dolfinx_adjoint/types/dirichletbc.py @@ -0,0 +1,63 @@ +import dolfinx +import numpy as np +import numpy.typing as npt +import pyadjoint +from pyadjoint.overloaded_type import FloatingType + +from ..blocks.dirichletbc import DirichletBCBlock +from .function import Function + + +class DirichletBC(dolfinx.fem.DirichletBC, FloatingType): + """A class overloading `dolfinx.fem.DirichletBC` to support it being used as a control variable + in the adjoint framework. + + Args: + g: The value of the Dirichlet BC. + dofs: An array of degree-of-freedom indices in `V` where the BC should be applied. + **kwargs: Additional keyword arguments to pass to the `pyadjoint.overloaded_type.FloatingType` constructor. + + """ + + def __init__(self, g: Function, dofs: npt.NDArray[np.int32], **kwargs): + dtype = g.dtype + if np.issubdtype(dtype, np.float32): + bctype = dolfinx.cpp.fem.DirichletBC_float32 + elif np.issubdtype(dtype, np.float64): + bctype = dolfinx.cpp.fem.DirichletBC_float64 + elif np.issubdtype(dtype, np.complex64): + bctype = dolfinx.cpp.fem.DirichletBC_complex64 + elif np.issubdtype(dtype, np.complex128): + bctype = dolfinx.cpp.fem.DirichletBC_complex128 + else: + raise NotImplementedError(f"Type {dtype} not supported.") + + super().__init__(bctype(g._cpp_object, dofs)) + + annotate = kwargs.pop("annotate", True) + annotate = annotate and pyadjoint.annotate_tape() + + FloatingType.__init__( + self, + g, + dtype=dtype, + block_class=kwargs.pop("block_class", DirichletBCBlock), + _ad_floating_active=False, + _ad_args=kwargs.pop("_ad_args", (g, dofs)), + annotate=annotate, + **kwargs, + ) + + if annotate: + self._ad_annotate_block() + + def _ad_create_checkpoint(self): + return self + + def _ad_restore_at_checkpoint(self, checkpoint): + return self + + +def dirichletbc(value: Function, dofs: npt.NDArray[np.int32], **kwargs) -> DirichletBC: + """Overloaded DirichletBC constructor that creates an adjoint-aware DirichletBC""" + return DirichletBC(value, dofs, **kwargs) diff --git a/tests/test_dirichlet_bc.py b/tests/test_dirichlet_bc.py new file mode 100644 index 0000000..571553f --- /dev/null +++ b/tests/test_dirichlet_bc.py @@ -0,0 +1,137 @@ +from mpi4py import MPI + +import dolfinx +import numpy as np +import pyadjoint +import ufl +from pyadjoint.overloaded_type import Weakref + +from dolfinx_adjoint import Function, LinearProblem, assemble_scalar, assign, dirichletbc +from dolfinx_adjoint.blocks.dirichletbc import DirichletBCBlock + + +def test_dirichletbc_recording(): + """Test that creating an overloaded dirichletbc correctly registers a block and dependency on the tape.""" + pyadjoint.get_working_tape().clear_tape() + mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: x[0]) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + bc = dirichletbc(c, dofs) + + tape = pyadjoint.get_working_tape() + blocks = tape.get_blocks() + + # The tape should have 1 block: DirichletBCBlock + assert len(blocks) == 1 + assert isinstance(blocks[0], DirichletBCBlock) + + # The block should have exactly 1 dependency (the function 'c') + assert len(blocks[0].get_dependencies()) == 1 + assert blocks[0].get_dependencies()[0].output is c + + # The returned BC object should now possess the injected block_variable + assert hasattr(bc, "block_variable") + + +def test_dirichletbc_no_annotate(): + """Test that setting annotate=False bypasses tape recording entirely.""" + + pyadjoint.get_working_tape().clear_tape() + mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: x[0]) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + + # Run with annotation off + bc = dirichletbc(c, dofs, annotate=False) + + tape = pyadjoint.get_working_tape() + + assert len(tape.get_blocks()) == 0 + # FIX: Check the underlying weak reference rather than invoking the property + assert getattr(bc, "_block_variable", Weakref())() is None + + +def test_dirichletbc_recompute(): + """Test the PyAdjoint internal recompute logic specifically for the DirichletBCBlock.""" + pyadjoint.get_working_tape().clear_tape() + mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: np.full_like(x[0], 5.0)) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + bc = dirichletbc(c, dofs) + + tape = pyadjoint.get_working_tape() + block = tape.get_blocks()[0] + + # Simulate an optimizer changing the function value + c.interpolate(lambda x: np.full_like(x[0], 15.0)) + + # Replay the PyAdjoint mechanics manually + prepared = block.prepare_recompute_component([c], None) + new_bc = block.recompute_component([c], bc.block_variable, 0, prepared) + + # Assert that the re-instantiated C++ object captured the updated control value + assert isinstance(new_bc, dolfinx.fem.bcs.DirichletBC) + assert np.isclose(new_bc.g.x.array[0], 15.0) + + +def test_time_dependent_bc_replay(): + pyadjoint.get_working_tape().clear_tape() + + mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 8, 8) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + dt = 0.1 + num_steps = 3 + + m = Function(V, name="control") + m.interpolate(lambda x: np.sin(x[0] * np.pi)) + + u = ufl.TrialFunction(V) + v = ufl.TestFunction(V) + + uh = Function(V, name="state") + assign(0.0, uh) + + u_prev = Function(V, name="state_prev") + assign(0.0, u_prev) + + F = (u - u_prev) / dt * v * ufl.dx + ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx - m * v * ufl.dx + a, L = ufl.system(F) + + bc_func = Function(V, name="bc_func") + mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim) + boundary_facets = dolfinx.mesh.exterior_facet_indices(mesh.topology) + boundary_dofs = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, boundary_facets) + + # Use native dolfinx here! PyAdjoint traces the bc_func inside it. + bc = dirichletbc(bc_func, boundary_dofs) + + problem = LinearProblem(a, L, bcs=[bc], u=uh) + + J = 0.0 + + for i in range(num_steps): + assign(float(i + 1), bc_func) + problem.solve() + J += assemble_scalar(0.5 * ufl.inner(uh, uh) * ufl.dx) + assign(uh, u_prev) + + J_forward = float(J) + + control = pyadjoint.Control(m) + Jhat = pyadjoint.ReducedFunctional(J, control) + J_replay = Jhat(m) + + assert np.isclose(J_replay, J_forward, atol=1e-10, rtol=1e-10)