Skip to content

Commit

Permalink
Move some LinearSolver::Tags
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsvu committed Jan 6, 2025
1 parent 742c8c0 commit 5cb83f7
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 28 deletions.
21 changes: 11 additions & 10 deletions src/ParallelAlgorithms/LinearSolver/Multigrid/ElementActions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ struct InitializeElement : tt::ConformsTo<amr::protocols::Projector> {
Tags::ParentMesh<Dim>,
observers::Tags::ObservationKey<Tags::MultigridLevel>,
observers::Tags::ObservationKey<Tags::IsFinestGrid>,
Tags::ObservationId<OptionsGroup>,
LinearSolver::Tags::ObservationId<OptionsGroup>,
Tags::VolumeDataForOutput<OptionsGroup, FieldsTag>>;
using compute_tags = tmpl::list<>;
using const_global_cache_tags =
tmpl::list<Tags::MaxLevels<OptionsGroup>,
Tags::OutputVolumeData<OptionsGroup>>;
LinearSolver::Tags::OutputVolumeData<OptionsGroup>>;

template <typename DbTagsList, typename... InboxTags, typename Metavariables,
typename ActionList, typename ParallelComponent>
Expand All @@ -91,7 +91,7 @@ struct InitializeElement : tt::ConformsTo<amr::protocols::Projector> {
using argument_tags =
tmpl::list<domain::Tags::Mesh<Dim>, domain::Tags::Element<Dim>,
domain::Tags::InitialRefinementLevels<Dim>,
Tags::OutputVolumeData<OptionsGroup>>;
LinearSolver::Tags::OutputVolumeData<OptionsGroup>>;
using return_tags = tmpl::append<simple_tags, simple_tags_from_options>;

template <typename... AmrData>
Expand Down Expand Up @@ -147,11 +147,12 @@ struct InitializeElement : tt::ConformsTo<amr::protocols::Projector> {
// Preserve state of observation ID
if constexpr (tt::is_a_v<tuples::TaggedTuple, AmrData...>) {
// h-refinement: copy from the parent
*observation_id = get<Tags::ObservationId<OptionsGroup>>(amr_data...);
*observation_id =
get<LinearSolver::Tags::ObservationId<OptionsGroup>>(amr_data...);
} else if constexpr (tt::is_a_v<std::unordered_map, AmrData...>) {
// h-coarsening: copy from one of the children (doesn't matter which)
*observation_id =
get<Tags::ObservationId<OptionsGroup>>(amr_data.begin()->second...);
*observation_id = get<LinearSolver::Tags::ObservationId<OptionsGroup>>(
amr_data.begin()->second...);
} else {
(void)observation_id;
}
Expand Down Expand Up @@ -235,7 +236,7 @@ struct PreparePreSmoothing {
}

// Record pre-smoothing initial fields and source
if (db::get<Tags::OutputVolumeData<OptionsGroup>>(box)) {
if (db::get<LinearSolver::Tags::OutputVolumeData<OptionsGroup>>(box)) {
db::mutate<Tags::VolumeDataForOutput<OptionsGroup, FieldsTag>>(
[](const auto volume_data, const auto& initial_fields,
const auto& source) {
Expand Down Expand Up @@ -293,7 +294,7 @@ struct SkipPostSmoothingAtBottom {
not db::get<Tags::ParentId<Dim>>(box).has_value();

// Record pre-smoothing result fields and residual
if (db::get<Tags::OutputVolumeData<OptionsGroup>>(box)) {
if (db::get<LinearSolver::Tags::OutputVolumeData<OptionsGroup>>(box)) {
db::mutate<Tags::VolumeDataForOutput<OptionsGroup, FieldsTag>>(
[](const auto volume_data, const auto& result_fields,
const auto& residuals) {
Expand Down Expand Up @@ -362,7 +363,7 @@ struct SendCorrectionToFinerGrid {
const auto& child_ids = db::get<Tags::ChildIds<Dim>>(box);

// Record post-smoothing result fields and residual
if (db::get<Tags::OutputVolumeData<OptionsGroup>>(box)) {
if (db::get<LinearSolver::Tags::OutputVolumeData<OptionsGroup>>(box)) {
db::mutate<Tags::VolumeDataForOutput<OptionsGroup, FieldsTag>>(
[](const auto volume_data, const auto& result_fields,
const auto& residuals) {
Expand Down Expand Up @@ -475,7 +476,7 @@ struct ReceiveCorrectionFromCoarserGrid {
make_not_null(&box));

// Record post-smoothing initial fields and source
if (db::get<Tags::OutputVolumeData<OptionsGroup>>(box)) {
if (db::get<LinearSolver::Tags::OutputVolumeData<OptionsGroup>>(box)) {
db::mutate<Tags::VolumeDataForOutput<OptionsGroup, FieldsTag>>(
[](const auto volume_data, const auto& initial_fields,
const auto& source) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ struct ObserveVolumeData {
Parallel::GlobalCache<Metavariables>& cache,
const ElementId<Dim>& element_id, const ActionList /*meta*/,
const ParallelComponent* const /*meta*/) {
if (not db::get<Tags::OutputVolumeData<OptionsGroup>>(box)) {
if (not db::get<LinearSolver::Tags::OutputVolumeData<OptionsGroup>>(box)) {
return {Parallel::AlgorithmExecution::Continue, std::nullopt};
}
const auto& volume_data = db::get<volume_data_tag>(box);
const auto& observation_id =
db::get<Tags::ObservationId<OptionsGroup>>(box);
db::get<LinearSolver::Tags::ObservationId<OptionsGroup>>(box);
const auto& mesh = db::get<domain::Tags::Mesh<Dim>>(box);
const auto& inertial_coords =
db::get<domain::Tags::Coordinates<Dim, Frame::Inertial>>(box);
Expand Down Expand Up @@ -124,7 +124,7 @@ struct ObserveVolumeData {
ElementVolumeData{element_id, std::move(components), mesh});

// Increment observation ID
db::mutate<Tags::ObservationId<OptionsGroup>>(
db::mutate<LinearSolver::Tags::ObservationId<OptionsGroup>>(
[](const auto local_observation_id) { ++(*local_observation_id); },
make_not_null(&box));
return {Parallel::AlgorithmExecution::Continue, std::nullopt};
Expand Down
13 changes: 2 additions & 11 deletions src/ParallelAlgorithms/LinearSolver/Multigrid/Tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ template <typename OptionsGroup>
struct OutputVolumeData : db::SimpleTag {
using type = bool;
static constexpr bool pass_metavariables = false;
using option_tags = tmpl::list<OptionTags::OutputVolumeData<OptionsGroup>>;
using option_tags =
tmpl::list<LinearSolver::OptionTags::OutputVolumeData<OptionsGroup>>;
static type create_from_options(const type value) { return value; };
static std::string name() {
return "OutputVolumeData(" + pretty_type::name<OptionsGroup>() + ")";
Expand Down Expand Up @@ -189,16 +190,6 @@ struct ParentMesh : db::SimpleTag {
using type = std::optional<Mesh<Dim>>;
};

// The following tags are related to volume data output

/// Continuously incrementing ID for volume observations
template <typename OptionsGroup>
struct ObservationId : db::SimpleTag {
using type = size_t;
static std::string name() {
return "ObservationId(" + pretty_type::name<OptionsGroup>() + ")";
}
};
/// @{
/// Prefix tag for recording volume data in
/// `LinearSolver::multigrid::Tags::VolumeDataForOutput`
Expand Down
36 changes: 36 additions & 0 deletions src/ParallelAlgorithms/LinearSolver/Tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
#include "DataStructures/DataBox/Tag.hpp"
#include "DataStructures/DataBox/TagName.hpp"
#include "DataStructures/DynamicMatrix.hpp"
#include "Options/String.hpp"
#include "Utilities/Gsl.hpp"
#include "Utilities/TMPL.hpp"
#include "Utilities/TypeTraits/GetFundamentalType.hpp"

/*!
Expand All @@ -25,6 +27,19 @@
*/
namespace LinearSolver {

namespace OptionTags {

template <typename OptionsGroup>
struct OutputVolumeData {
using type = bool;
static constexpr Options::String help =
"Record volume data for debugging purposes.";
using group = OptionsGroup;
static bool suggested_value() { return false; }
};

} // namespace OptionTags

/*!
* \ingroup LinearSolverGroup
* \brief The \ref DataBoxGroup tags associated with the linear solver
Expand Down Expand Up @@ -174,5 +189,26 @@ struct Preconditioned : db::PrefixTag, db::SimpleTag {
using tag = Tag;
};

/// Whether or not volume data should be recorded for debugging purposes
template <typename OptionsGroup>
struct OutputVolumeData : db::SimpleTag {
using type = bool;
static constexpr bool pass_metavariables = false;
using option_tags = tmpl::list<OptionTags::OutputVolumeData<OptionsGroup>>;
static type create_from_options(const type value) { return value; };
static std::string name() {
return "OutputVolumeData(" + pretty_type::name<OptionsGroup>() + ")";
}
};

/// Continuously incrementing ID for volume observations
template <typename OptionsGroup>
struct ObservationId : db::SimpleTag {
using type = size_t;
static std::string name() {
return "ObservationId(" + pretty_type::name<OptionsGroup>() + ")";
}
};

} // namespace Tags
} // namespace LinearSolver
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,11 @@ SPECTRE_TEST_CASE("Unit.ParallelAlgorithms.LinearSolver.Multigrid.Tags",
"ParentRefinementLevels");
TestHelpers::db::test_simple_tag<Tags::MaxLevels<TestSolver>>(
"MaxLevels(TestSolver)");
TestHelpers::db::test_simple_tag<Tags::OutputVolumeData<TestSolver>>(
"OutputVolumeData(TestSolver)");
TestHelpers::db::test_simple_tag<Tags::MultigridLevel>("MultigridLevel");
TestHelpers::db::test_simple_tag<Tags::IsFinestGrid>("IsFinestGrid");
TestHelpers::db::test_simple_tag<Tags::ParentId<1>>("ParentId");
TestHelpers::db::test_simple_tag<Tags::ChildIds<1>>("ChildIds");
TestHelpers::db::test_simple_tag<Tags::ParentMesh<1>>("ParentMesh");
TestHelpers::db::test_simple_tag<Tags::ObservationId<TestSolver>>(
"ObservationId(TestSolver)");
TestHelpers::db::test_prefix_tag<Tags::PreSmoothingInitial<Tag>>(
"PreSmoothingInitial(Tag)");
TestHelpers::db::test_prefix_tag<Tags::PreSmoothingSource<Tag>>(
Expand Down
7 changes: 7 additions & 0 deletions tests/Unit/ParallelAlgorithms/LinearSolver/Test_Tags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct Tag : db::SimpleTag {
struct TestOptionsGroup {
static std::string name() { return "TestLinearSolver"; }
};
struct TestSolver {};
} // namespace

SPECTRE_TEST_CASE("Unit.ParallelAlgorithms.LinearSolver.Tags",
Expand All @@ -42,6 +43,12 @@ SPECTRE_TEST_CASE("Unit.ParallelAlgorithms.LinearSolver.Tags",
LinearSolver::Tags::KrylovSubspaceBasis<Tag>>("KrylovSubspaceBasis(Tag)");
TestHelpers::db::test_prefix_tag<LinearSolver::Tags::Preconditioned<Tag>>(
"Preconditioned(Tag)");
TestHelpers::db::test_simple_tag<
LinearSolver::Tags::OutputVolumeData<TestSolver>>(
"OutputVolumeData(TestSolver)");
TestHelpers::db::test_simple_tag<
LinearSolver::Tags::ObservationId<TestSolver>>(
"ObservationId(TestSolver)");

{
INFO("ResidualCompute");
Expand Down

0 comments on commit 5cb83f7

Please sign in to comment.