46#ifndef MUELU_LOWPRECISIONFACTORY_DEF_HPP
47#define MUELU_LOWPRECISIONFACTORY_DEF_HPP
49#include <Xpetra_Matrix.hpp>
50#include <Xpetra_Operator.hpp>
51#include <Xpetra_TpetraOperator.hpp>
52#include <Tpetra_CrsMatrixMultiplyOp.hpp>
56#include "MueLu_FactoryManager.hpp"
// Valid-parameter list for the primary (generic-Scalar) LowPrecisionFactory
// template. Entries declared:
//   "matrix key"  - std::string, default "A"; the DeclareInput/Build code in
//                   this file reads it to pick the Level entry to request and
//                   re-store.
//   "R", "A", "P" - RCP<const FactoryBase>, default Teuchos::null; generating
//                   factories (per their description strings).
// Returns the newly allocated list.
// NOTE(review): this listing lacks the signature line and closing brace
// (embedded line numbers jump 64 -> 66); presumably this is
// GetValidParameterList() const -- TODO confirm against the real source.
64 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
66 RCP<ParameterList> validParamList = rcp(
new ParameterList());
68 validParamList->set<std::string>(
"matrix key",
"A",
"");
// NOTE(review): the "R" and "P" descriptions are copy-pasted from "A" and
// still say "matrix A"; they presumably should name R and P respectively.
69 validParamList->set< RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
70 validParamList->set< RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
71 validParamList->set< RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
73 return validParamList;
// Declares the single input this factory needs on currentLevel: the entry
// whose name is the "matrix key" parameter (the valid-parameter list in this
// file defaults it to "A").
// NOTE(review): signature line and closing brace are not present in this
// listing; presumably DeclareInput(Level& currentLevel) const -- TODO confirm.
76 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
79 const ParameterList& pL = GetParameterList();
80 std::string matrixKey = pL.get<std::string>(
"matrix key");
81 Input(currentLevel, matrixKey);
// Build for the primary (generic-Scalar) template: no precision conversion is
// attempted here. It logs a warning and stores the matrix back on
// currentLevel under the same "matrix key", unchanged.
// NOTE(review): signature line and closing brace are not present in this
// listing; presumably Build(Level& currentLevel) const -- TODO confirm.
84 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
86 using Teuchos::ParameterList;
88 const ParameterList& pL = GetParameterList();
89 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Factory monitor/timer for this build step; lives until the end of the function.
91 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
93 RCP<Matrix> A = Get< RCP<Matrix> >(currentLevel, matrixKey);
// Fallback: warn, then pass the original matrix through under the same key.
95 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
96 Set(currentLevel, matrixKey, A);
100#if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
// Valid-parameter list for the (presumably Scalar = double) specialization,
// compiled only under HAVE_TPETRA_INST_DOUBLE && HAVE_TPETRA_INST_FLOAT.
// Entries declared:
//   "matrix key"  - std::string, default "A"; selects the Level entry handled.
//   "R", "A", "P" - RCP<const FactoryBase>, default Teuchos::null; generating
//                   factories (per their description strings).
// Returns the newly allocated list.
// NOTE(review): signature line and closing brace are not present in this
// listing (line numbers jump 101 -> 103) -- TODO confirm against real source.
101 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
103 RCP<ParameterList> validParamList = rcp(
new ParameterList());
105 validParamList->set<std::string>(
"matrix key",
"A",
"");
// NOTE(review): the "R" and "P" descriptions are copy-pasted from "A" and
// still say "matrix A"; they presumably should name R and P respectively.
106 validParamList->set< RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
107 validParamList->set< RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
108 validParamList->set< RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
110 return validParamList;
// Declares the single input needed on currentLevel for the (presumably
// Scalar = double) specialization: the entry named by "matrix key".
// NOTE(review): signature lines (114-115) and closing brace are missing from
// this listing; presumably DeclareInput(Level& currentLevel) const.
113 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
116 const ParameterList& pL = GetParameterList();
117 std::string matrixKey = pL.get<std::string>(
"matrix key");
118 Input(currentLevel, matrixKey);
// Build for the (presumably Scalar = double) specialization.
// If A's row map reports the Tpetra backend, the underlying Tpetra CrsMatrix
// is converted with convert<HalfScalar>(), where HalfScalar is
// Teuchos::ScalarTraits<Scalar>::halfPrecision. The result is wrapped in a
// Tpetra::CrsMatrixMultiplyOp<Scalar, HalfScalar, ...>, then in an Xpetra
// TpetraOperator. That Operator is stored on currentLevel under
// "matrix key", replacing the matrix.
// Otherwise a warning is logged and A is stored back unchanged.
// NOTE(review): the signature line, and original lines 140-142 between the
// successful Set(...) and the warning, are missing from this listing. Confirm
// the real source returns (or uses else) there; as shown, the warning path
// would also run after a successful conversion.
121 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
123 using Teuchos::ParameterList;
124 using HalfScalar =
typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
126 const ParameterList& pL = GetParameterList();
127 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Factory monitor/timer for this build step; lives until the end of the function.
129 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
131 RCP<Matrix> A = Get< RCP<Matrix> >(currentLevel, matrixKey);
// NOTE(review): Scalar is presumably double inside this specialization, so the
// is_same test would always be true; the effective check is the Tpetra backend.
133 if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, double>::value) {
// Unwrap A down to the Tpetra CrsMatrix. NOTE(review): only the outer cast
// passes throw_on_fail = true. If A is not a CrsMatrixWrap, the inner cast
// yields a null RCP that ->getCrsMatrix() then dereferences; consider
// passing true there as well.
134 auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(),
true)->getTpetra_CrsMatrix();
// New matrix holding HalfScalar entries.
135 auto tpLowA = tpA->template convert<HalfScalar>();
// Wrap the HalfScalar matrix as a Tpetra operator (template arguments
// Scalar, HalfScalar), then expose it through Xpetra as a generic Operator.
136 auto tpLowOpA = rcp(
new Tpetra::CrsMatrixMultiplyOp<Scalar,HalfScalar,LocalOrdinal,GlobalOrdinal,Node>(tpLowA));
137 auto xpTpLowOpA = rcp(
new TpetraOperator(tpLowOpA));
138 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
139 Set(currentLevel, matrixKey, xpLowOpA);
// Fallback when no conversion was done: warn and pass A through unchanged.
143 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
144 Set(currentLevel, matrixKey, A);
149#if defined(HAVE_TPETRA_INST_COMPLEX_DOUBLE) && defined(HAVE_TPETRA_INST_COMPLEX_FLOAT)
// Valid-parameter list for the (presumably Scalar = std::complex<double>)
// specialization, compiled only under HAVE_TPETRA_INST_COMPLEX_DOUBLE &&
// HAVE_TPETRA_INST_COMPLEX_FLOAT. Entries declared:
//   "matrix key"  - std::string, default "A"; selects the Level entry handled.
//   "R", "A", "P" - RCP<const FactoryBase>, default Teuchos::null; generating
//                   factories (per their description strings).
// Returns the newly allocated list.
// NOTE(review): signature line and closing brace are not present in this
// listing (line numbers jump 150 -> 152) -- TODO confirm against real source.
150 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
152 RCP<ParameterList> validParamList = rcp(
new ParameterList());
154 validParamList->set<std::string>(
"matrix key",
"A",
"");
// NOTE(review): the "R" and "P" descriptions are copy-pasted from "A" and
// still say "matrix A"; they presumably should name R and P respectively.
155 validParamList->set< RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
156 validParamList->set< RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
157 validParamList->set< RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
159 return validParamList;
// Declares the single input needed on currentLevel for the (presumably
// Scalar = std::complex<double>) specialization: the entry named by
// "matrix key".
// NOTE(review): signature lines (163-164) and closing brace are missing from
// this listing; presumably DeclareInput(Level& currentLevel) const.
162 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
165 const ParameterList& pL = GetParameterList();
166 std::string matrixKey = pL.get<std::string>(
"matrix key");
167 Input(currentLevel, matrixKey);
// Build for the (presumably Scalar = std::complex<double>) specialization.
// If A's row map reports the Tpetra backend, the underlying Tpetra CrsMatrix
// is converted with convert<HalfScalar>(), where HalfScalar is
// Teuchos::ScalarTraits<Scalar>::halfPrecision. The result is wrapped in a
// Tpetra::CrsMatrixMultiplyOp<Scalar, HalfScalar, ...>, then in an Xpetra
// TpetraOperator. That Operator is stored on currentLevel under
// "matrix key", replacing the matrix.
// Otherwise a warning is logged and A is stored back unchanged.
// NOTE(review): the signature line, and original lines 189-191 between the
// successful Set(...) and the warning, are missing from this listing. Confirm
// the real source returns (or uses else) there; as shown, the warning path
// would also run after a successful conversion.
170 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
172 using Teuchos::ParameterList;
173 using HalfScalar =
typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
175 const ParameterList& pL = GetParameterList();
176 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Factory monitor/timer for this build step; lives until the end of the function.
178 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
180 RCP<Matrix> A = Get< RCP<Matrix> >(currentLevel, matrixKey);
// NOTE(review): Scalar is presumably std::complex<double> inside this
// specialization, so the is_same test would always be true; the effective
// check is the Tpetra backend.
182 if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, std::complex<double> >::value) {
// Unwrap A down to the Tpetra CrsMatrix. NOTE(review): only the outer cast
// passes throw_on_fail = true. If A is not a CrsMatrixWrap, the inner cast
// yields a null RCP that ->getCrsMatrix() then dereferences; consider
// passing true there as well.
183 auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(),
true)->getTpetra_CrsMatrix();
// New matrix holding HalfScalar entries.
184 auto tpLowA = tpA->template convert<HalfScalar>();
// Wrap the HalfScalar matrix as a Tpetra operator (template arguments
// Scalar, HalfScalar), then expose it through Xpetra as a generic Operator.
185 auto tpLowOpA = rcp(
new Tpetra::CrsMatrixMultiplyOp<Scalar,HalfScalar,LocalOrdinal,GlobalOrdinal,Node>(tpLowA));
186 auto xpTpLowOpA = rcp(
new TpetraOperator(tpLowOpA));
187 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
188 Set(currentLevel, matrixKey, xpLowOpA);
// Fallback when no conversion was done: warn and pass A through unchanged.
192 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
193 Set(currentLevel, matrixKey, A);
MueLu::DefaultLocalOrdinal LocalOrdinal
MueLu::DefaultGlobalOrdinal GlobalOrdinal
Timer to be used in factories. Similar to Monitor but with additional timers.
Class that holds all level-specific information.
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
void DeclareInput(Level &currentLevel) const
Input.
void Build(Level &currentLevel) const
Build method.
Namespace for MueLu classes and methods.
@ Warnings
Print all warning messages.