46#ifndef MUELU_REBALANCETRANSFERFACTORY_DEF_HPP
47#define MUELU_REBALANCETRANSFERFACTORY_DEF_HPP
50#include <Teuchos_Tuple.hpp>
52#include "Xpetra_MultiVector.hpp"
53#include "Xpetra_MultiVectorFactory.hpp"
54#include "Xpetra_Vector.hpp"
55#include "Xpetra_VectorFactory.hpp"
56#include <Xpetra_Matrix.hpp>
57#include <Xpetra_MapFactory.hpp>
58#include <Xpetra_MatrixFactory.hpp>
59#include <Xpetra_Import.hpp>
60#include <Xpetra_ImportFactory.hpp>
61#include <Xpetra_IO.hpp>
68#include "MueLu_PerfUtils.hpp"
// GetValidParameterList: builds and returns the ParameterList of options this
// factory accepts (name, default value, documentation string, optional validator).
// NOTE(review): the function signature itself (original line 73) is not visible
// in this listing; only the template header and body fragments are.
72 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
74 RCP<ParameterList> validParamList = rcp(
new ParameterList());
// Helper macro: copy an entry (default value + docs) from MasterList into the list.
76#define SET_VALID_ENTRY(name) validParamList->setEntry(name, MasterList::getEntry(name))
// NOTE(review): original lines 77-83 are not shown; presumably they apply
// SET_VALID_ENTRY to the "repartition: ..." and "transpose: use implicit"
// options that DeclareInput/Build read with pL.get<bool> -- confirm.
// "type" is restricted by the validator to "Interpolation" or "Restriction".
84 typedef Teuchos::StringToIntegralParameterEntryValidator<int> validatorType;
85 RCP<validatorType> typeValidator = rcp (
new validatorType(Teuchos::tuple<std::string>(
"Interpolation",
"Restriction"),
"type"));
// Default transfer type is "Interpolation".
86 validParamList->set(
"type",
"Interpolation",
"Type of the transfer operator that need to be rebalanced (Interpolation or Restriction)", typeValidator);
// Generating factories; all default to null.
89 validParamList->set< RCP<const FactoryBase> >(
"P", null,
"Factory of the prolongation operator that need to be rebalanced (only used if type=Interpolation)");
90 validParamList->set< RCP<const FactoryBase> >(
"R", null,
"Factory of the restriction operator that need to be rebalanced (only used if type=Restriction)");
91 validParamList->set< RCP<const FactoryBase> >(
"Nullspace", null,
"Factory of the nullspace that need to be rebalanced (only used if type=Interpolation)");
// A null Coordinates/BlockNumber factory disables requesting/rebalancing that data.
92 validParamList->set< RCP<const FactoryBase> >(
"Coordinates", null,
"Factory of the coordinates that need to be rebalanced (only used if type=Interpolation)");
93 validParamList->set< RCP<const FactoryBase> >(
"BlockNumber", null,
"Factory of the block ids that need to be rebalanced (only used if type=Interpolation)");
94 validParamList->set< RCP<const FactoryBase> >(
"Importer", null,
"Factory of the importer object used for the rebalancing");
// Level-ID range for file dumps in Build; -1/-1 by default. Build uses the same
// bounds for the BlockNumber dumps as well as the coordinate dumps.
95 validParamList->set<
int > (
"write start", -1,
"First level at which coordinates should be written to file");
96 validParamList->set<
int > (
"write end", -1,
"Last level at which coordinates should be written to file");
102 return validParamList;
// DeclareInput: requests, on the coarse level, the data that Build will Get().
// NOTE(review): the function signature (original line 106) is not visible here.
105 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
107 const ParameterList& pL = GetParameterList();
// type=Interpolation: the prolongator plus optional nullspace/coordinates/block ids.
109 if (pL.get<std::string>(
"type") ==
"Interpolation") {
110 Input(coarseLevel,
"P");
// The nullspace is only requested when nullspace rebalancing is enabled.
111 if (pL.get<
bool>(
"repartition: rebalance Nullspace"))
112 Input(coarseLevel,
"Nullspace");
// Coordinates/BlockNumber are only requested when a generating factory was given.
113 if (pL.get< RCP<const FactoryBase> >(
"Coordinates") != Teuchos::null)
114 Input(coarseLevel,
"Coordinates");
115 if (pL.get< RCP<const FactoryBase> >(
"BlockNumber") != Teuchos::null)
116 Input(coarseLevel,
"BlockNumber");
// NOTE(review): original lines 117-118 are not shown; presumably they close the
// Interpolation branch and open the Restriction ("else") branch -- confirm.
// R is only requested when it is stored explicitly (no implicit transpose).
119 if (pL.get<
bool>(
"transpose: use implicit") ==
false)
120 Input(coarseLevel,
"R");
// The importer appears to be requested regardless of type (lines 121-122 not shown).
123 Input(coarseLevel,
"Importer");
// Build: rebalances the transfer operator selected by "type" (P or R), and for
// Interpolation also the nullspace, coordinates and block ids, using the
// "Importer" object requested in DeclareInput.
// NOTE(review): the function signature (original lines 127-128) is not visible here.
126 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Coordinate multivector type, templated on the magnitude type of Scalar.
129 typedef Xpetra::MultiVector<typename Teuchos::ScalarTraits<Scalar>::magnitudeType, LO, GO, NO> xdMV;
131 const ParameterList& pL = GetParameterList();
// NOTE(review): P is fetched here unconditionally, although DeclareInput only
// requests it for type=Interpolation -- confirm this is intended for Restriction.
133 RCP<Matrix> originalP = Get< RCP<Matrix> >(coarseLevel,
"P");
// No prolongator on this level: forward the null P unchanged.
// NOTE(review): original lines 138-139 are not shown; presumably an early
// `return;` and the closing brace -- confirm.
136 if (originalP == Teuchos::null) {
137 Set(coarseLevel,
"P", originalP);
// implicit != 0 means "repartition: rebalance P and R" is false (stored as int).
140 int implicit = !pL.get<
bool>(
"repartition: rebalance P and R");
// Inclusive level-ID range for which data is written via Xpetra::IO::Write.
141 int writeStart = pL.get<
int> (
"write start");
142 int writeEnd = pL.get<
int> (
"write end");
// Finest level only: optionally dump the original (not rebalanced) coordinates.
144 if (writeStart == 0 && fineLevel.
GetLevelID() == 0 && writeStart <= writeEnd && IsAvailable(fineLevel,
"Coordinates")) {
145 std::string fileName =
"coordinates_level_0.m";
146 RCP<xdMV> fineCoords = fineLevel.
Get< RCP<xdMV> >(
"Coordinates");
147 if (fineCoords != Teuchos::null)
148 Xpetra::IO<typename Teuchos::ScalarTraits<Scalar>::magnitudeType,LO,GO,NO>::Write(fileName, *fineCoords);
// Same dump for the fine-level block ids.
151 if (writeStart == 0 && fineLevel.
GetLevelID() == 0 && writeStart <= writeEnd && IsAvailable(fineLevel,
"BlockNumber")) {
152 std::string fileName =
"BlockNumber_level_0.m";
153 RCP<LocalOrdinalVector> fineBlockNumber = fineLevel.
Get< RCP<LocalOrdinalVector> >(
"BlockNumber");
154 if (fineBlockNumber != Teuchos::null)
155 Xpetra::IO<LO,LO,GO,NO>::Write(fileName, *fineBlockNumber);
// Redistribution object; below, a null importer means "keep the original data".
158 RCP<const Import> importer = Get<RCP<const Import> >(coarseLevel,
"Importer");
// Options for matrix statistics output.
// NOTE(review): `params` is not used in the visible lines; presumably it is
// passed to PerfUtils::PrintMatrixInfo in omitted lines -- confirm.
164 RCP<ParameterList> params = rcp(
new ParameterList());
166 params->set(
"printLoadBalancingInfo",
true);
167 params->set(
"printCommInfo",
true);
170 std::string transferType = pL.get<std::string>(
"type");
171 if (transferType ==
"Interpolation") {
// NOTE(review): originalP was already fetched at the top of Build; this re-fetch looks redundant.
172 originalP = Get< RCP<Matrix> >(coarseLevel,
"P");
// Rebalancing disabled or no importer: keep P as is.
178 if (implicit || importer.is_null()) {
179 GetOStream(
Runtime0) <<
"Using original prolongator" << std::endl;
180 Set(coarseLevel,
"P", originalP);
// NOTE(review): original lines 181-197 are not shown; presumably the closing
// brace and the "else" branch that performs the rebalancing -- confirm.
// Fast map replacement: instead of importing the matrix, only P's domain map
// and column importer are replaced. rebalancedP is the same RCP as originalP,
// so the underlying CrsMatrix is modified in place.
198 RCP<Matrix> rebalancedP = originalP;
199 RCP<const CrsMatrixWrap> crsOp = rcp_dynamic_cast<const CrsMatrixWrap>(originalP);
200 TEUCHOS_TEST_FOR_EXCEPTION(crsOp == Teuchos::null,
Exceptions::BadCast,
"Cast from Xpetra::Matrix to Xpetra::CrsMatrixWrap failed");
202 RCP<CrsMatrix> rebalancedP2 = crsOp->getCrsMatrix();
203 TEUCHOS_TEST_FOR_EXCEPTION(rebalancedP2 == Teuchos::null, std::runtime_error,
"Xpetra::CrsMatrixWrap doesn't have a CrsMatrix");
206 SubFactoryMonitor subM(*
this,
"Rebalancing prolongator -- fast map replacement", coarseLevel);
208 RCP<const Import> newImporter;
// New column importer: from the importer's target map (the new domain map) to P's column map.
211 newImporter = ImportFactory::Build(importer->getTargetMap(), rebalancedP->getColMap());
213 rebalancedP2->replaceDomainMapAndImporter(importer->getTargetMap(), newImporter);
// Label the operator "P_<levelID>".
222 if(!rebalancedP.is_null()) {std::ostringstream oss; oss <<
"P_" << coarseLevel.GetLevelID(); rebalancedP->setObjectLabel(oss.str());}
223 Set(coarseLevel,
"P", rebalancedP);
// No importer on this level: re-Set the optional nullspace, coordinates and
// block ids unchanged (coordinates/block ids only if a factory was configured).
230 if (importer.is_null()) {
231 if (IsAvailable(coarseLevel,
"Nullspace"))
232 Set(coarseLevel,
"Nullspace", Get<RCP<MultiVector> >(coarseLevel,
"Nullspace"));
234 if (pL.isParameter(
"Coordinates") && pL.get< RCP<const FactoryBase> >(
"Coordinates") != Teuchos::null)
235 if (IsAvailable(coarseLevel,
"Coordinates"))
236 Set(coarseLevel,
"Coordinates", Get< RCP<xdMV> >(coarseLevel,
"Coordinates"));
238 if (pL.isParameter(
"BlockNumber") && pL.get< RCP<const FactoryBase> >(
"BlockNumber") != Teuchos::null)
239 if (IsAvailable(coarseLevel,
"BlockNumber"))
240 Set(coarseLevel,
"BlockNumber", Get< RCP<LocalOrdinalVector> >(coarseLevel,
"BlockNumber"));
// Rebalance coordinates when a Coordinates factory is configured and data is present.
245 if (pL.isParameter(
"Coordinates") &&
246 pL.get< RCP<const FactoryBase> >(
"Coordinates") != Teuchos::null &&
247 IsAvailable(coarseLevel,
"Coordinates")) {
248 RCP<xdMV> coords = Get<RCP<xdMV> >(coarseLevel,
"Coordinates");
// blkSize = (local entries of the importer's source map) / (local coordinate
// entries), maxed over all ranks -- presumably the number of DOFs per node; confirm.
253 LO nodeNumElts = coords->getMap()->getLocalNumElements();
256 LO myBlkSize = 0, blkSize = 0;
// NOTE(review): original line 257 is not shown; presumably it guards this
// division against nodeNumElts == 0 -- confirm.
258 myBlkSize = importer->getSourceMap()->getLocalNumElements() / nodeNumElts;
259 MueLu_maxAll(coords->getMap()->getComm(), myBlkSize, blkSize);
261 RCP<const Import> coordImporter;
// NOTE(review): original lines 262 and 264-268 are not shown; presumably
// `if (blkSize == 1)` selects this reuse and an else branch builds the
// node-level importer below -- confirm.
263 coordImporter = importer;
// Node-level target map: take every blkSize-th GID of the importer's target map
// and convert it to a node GID via (gid - indexBase)/blkSize + indexBase.
269 RCP<const Map> origMap = coords->getMap();
270 GO indexBase = origMap->getIndexBase();
272 ArrayView<const GO> OEntries = importer->getTargetMap()->getLocalElementList();
273 LO numEntries = OEntries.size()/blkSize;
274 ArrayRCP<GO> Entries(numEntries);
275 for (LO i = 0; i < numEntries; i++)
276 Entries[i] = (OEntries[i*blkSize]-indexBase)/blkSize + indexBase;
278 RCP<const Map> targetMap = MapFactory::Build(origMap->lib(), origMap->getGlobalNumElements(), Entries(), indexBase, origMap->getComm());
279 coordImporter = ImportFactory::Build(origMap, targetMap);
// Redistribute the coordinates onto the coordinate importer's target map.
282 RCP<xdMV> permutedCoords = Xpetra::MultiVectorFactory<typename Teuchos::ScalarTraits<Scalar>::magnitudeType,LO,GO,NO>::Build(coordImporter->getTargetMap(), coords->getNumVectors());
283 permutedCoords->doImport(*coords, *coordImporter, Xpetra::INSERT);
// With subcommunicators enabled, drop ranks that own no entries from the map.
285 if (pL.isParameter(
"repartition: use subcommunicators") ==
true && pL.get<
bool>(
"repartition: use subcommunicators") ==
true)
286 permutedCoords->replaceMap(permutedCoords->getMap()->removeEmptyProcesses());
// If the map became null (as removeEmptyProcesses may produce on excluded
// ranks), store a null RCP instead.
288 if (permutedCoords->getMap() == Teuchos::null)
289 permutedCoords = Teuchos::null;
291 Set(coarseLevel,
"Coordinates", permutedCoords);
293 std::string fileName =
"rebalanced_coordinates_level_" +
toString(coarseLevel.GetLevelID()) +
".m";
294 if (writeStart <= coarseLevel.GetLevelID() && coarseLevel.GetLevelID() <= writeEnd && permutedCoords->getMap() != Teuchos::null)
295 Xpetra::IO<typename Teuchos::ScalarTraits<Scalar>::magnitudeType,LO,GO,NO>::Write(fileName, *permutedCoords);
// Rebalance block ids when a BlockNumber factory is configured and data is present.
// Unlike the coordinates, the matrix importer is used directly (no DOF-to-node conversion).
298 if (pL.isParameter(
"BlockNumber") &&
299 pL.get< RCP<const FactoryBase> >(
"BlockNumber") != Teuchos::null &&
300 IsAvailable(coarseLevel,
"BlockNumber")) {
301 RCP<LocalOrdinalVector> BlockNumber = Get<RCP<LocalOrdinalVector> >(coarseLevel,
"BlockNumber");
// The second argument `false` is presumably Xpetra's zeroOut flag (skip zero-init) -- confirm.
306 RCP<LocalOrdinalVector> permutedBlockNumber = LocalOrdinalVectorFactory::Build(importer->getTargetMap(),
false);
307 permutedBlockNumber->doImport(*BlockNumber, *importer, Xpetra::INSERT);
// With subcommunicators enabled, drop ranks that own no entries from the map.
309 if (pL.isParameter(
"repartition: use subcommunicators") ==
true && pL.get<
bool>(
"repartition: use subcommunicators") ==
true)
310 permutedBlockNumber->replaceMap(permutedBlockNumber->getMap()->removeEmptyProcesses());
// A null map after removal is stored as a null RCP.
312 if (permutedBlockNumber->getMap() == Teuchos::null)
313 permutedBlockNumber = Teuchos::null;
315 Set(coarseLevel,
"BlockNumber", permutedBlockNumber);
317 std::string fileName =
"rebalanced_BlockNumber_level_" +
toString(coarseLevel.GetLevelID()) +
".m";
318 if (writeStart <= coarseLevel.GetLevelID() && coarseLevel.GetLevelID() <= writeEnd && permutedBlockNumber->getMap() != Teuchos::null)
319 Xpetra::IO<LO,LO,GO,NO>::Write(fileName, *permutedBlockNumber);
// Rebalance the nullspace (if present) with the matrix importer.
322 if (IsAvailable(coarseLevel,
"Nullspace")) {
323 RCP<MultiVector> nullspace = Get< RCP<MultiVector> >(coarseLevel,
"Nullspace");
328 RCP<MultiVector> permutedNullspace = MultiVectorFactory::Build(importer->getTargetMap(), nullspace->getNumVectors());
329 permutedNullspace->doImport(*nullspace, *importer, Xpetra::INSERT);
// NOTE(review): unlike the Coordinates/BlockNumber branches, this get<bool> has
// no isParameter() guard; it relies on the option being registered in the
// valid parameter list -- confirm it is.
331 if (pL.get<
bool>(
"repartition: use subcommunicators") ==
true)
332 permutedNullspace->replaceMap(permutedNullspace->getMap()->removeEmptyProcesses());
// A null map after removal is stored as a null RCP.
334 if (permutedNullspace->getMap() == Teuchos::null)
335 permutedNullspace = Teuchos::null;
337 Set(coarseLevel,
"Nullspace", permutedNullspace);
// Restriction: rebalance R only when it is stored explicitly.
// NOTE(review): original lines 338-340 are not shown; presumably they close the
// Interpolation branch and open the type=Restriction ("else") branch -- confirm.
341 if (pL.get<
bool>(
"transpose: use implicit") ==
false) {
342 RCP<Matrix> originalR = Get< RCP<Matrix> >(coarseLevel,
"R");
// Rebalancing disabled or no importer: keep R as is.
346 if (implicit || importer.is_null()) {
347 GetOStream(
Runtime0) <<
"Using original restrictor" << std::endl;
348 Set(coarseLevel,
"R", originalR);
// NOTE(review): original lines 349-350 are not shown; presumably "} else {".
351 RCP<Matrix> rebalancedR;
353 SubFactoryMonitor subM(*
this,
"Rebalancing restriction -- fusedImport", coarseLevel);
// Timer label "MueLu::RebalanceR-<levelID>" forwarded to the matrix build.
356 Teuchos::ParameterList listLabel;
357 listLabel.set(
"Timer Label",
"MueLu::RebalanceR-" + Teuchos::toString(coarseLevel.GetLevelID()));
// Import R with the importer and pass the importer's target map as an
// additional map argument. NOTE(review): `dummy` is declared in lines not
// shown -- presumably a null map meaning "keep R's domain map", with the
// target map used as the new range map; confirm against Xpetra MatrixFactory::Build.
// rcp(&listLabel, false) is a non-owning RCP to a stack object; this assumes
// Build does not retain it beyond this call.
358 rebalancedR = MatrixFactory::Build(originalR, *importer, dummy, importer->getTargetMap(),Teuchos::rcp(&listLabel,
false));
// Label the operator "R_<levelID>".
360 if(!rebalancedR.is_null()) {std::ostringstream oss; oss <<
"R_" << coarseLevel.GetLevelID(); rebalancedR->setObjectLabel(oss.str());}
361 Set(coarseLevel,
"R", rebalancedR);
#define SET_VALID_ENTRY(name)
#define MueLu_maxAll(rcpComm, in, out)
Exception indicating invalid cast attempted.
Timer to be used in factories. Similar to Monitor but with additional timers.
Class that holds all level-specific information.
int GetLevelID() const
Return level number.
T & Get(const std::string &ename, const FactoryBase *factory=NoFactory::get())
Get data without decrementing associated storage counter (i.e., read-only access)....
void Set(const std::string &ename, const T &entry, const FactoryBase *factory=NoFactory::get())
static const NoFactory * get()
static std::string PrintMatrixInfo(const Matrix &A, const std::string &msgTag, RCP< const Teuchos::ParameterList > params=Teuchos::null)
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
void DeclareInput(Level &fineLevel, Level &coarseLevel) const
Specifies the data that this class needs, and the factories that generate that data.
void Build(Level &fineLevel, Level &coarseLevel) const
Build an object with this factory.
Timer to be used in factories. Similar to SubMonitor but adds a timer level by level.
Namespace for MueLu classes and methods.
@ Statistics2
Print even more statistics.
@ Runtime0
One-liner description of what is happening.
std::string toString(const T &what)
Little helper function to convert non-string types to strings.