\(\newcommand{\B}[1]{ {\bf #1} }\) \(\newcommand{\R}[1]{ {\rm #1} }\)
xad_gradient.hpp¶
Calculate Gradient Using XAD¶
Syntax¶
# include <cmpad/xad/gradient.hpp>
cmpad::xad::gradient< Algo > grad
grad.setup( option )
g = grad( x )
Purpose¶
This implements the cpp_gradient interface using XAD.
Algo¶
see Algo for the base class.
ADVector¶
The type cmpad::vector< ::xad::adj<double>::active_type > is the
ADVector type for this gradient; the algorithm is evaluated as Algo<ADVector>.
vector_type¶
see vector_type for the base class.
scalar_type¶
see scalar_type for the base class.
setup¶
see the gradient setup for the base class.
option¶
This option_t object is used to specify the setup options.
Example¶
The file xam_gradient_xad.cpp contains an example and test using this class.
Source Code¶
# if CMPAD_HAS_XAD
/*
2DO: Under Construction: This gradient does not yet pass its tests
*/
# include <cassert>
# include <XAD/XAD.hpp>
# include <cmpad/gradient.hpp>
namespace cmpad { namespace xad { // BEGIN cmpad::xad namespace
// adj_tape_instance
// Returns the one XAD adjoint tape used by every cmpad::xad::gradient object
// (gradient::operator() clears, records on, and computes adjoints with it).
// NOTE(review): a single shared tape is presumably used because XAD permits
// only one active tape per thread -- confirm against the XAD documentation.
//
// The tape is a function-local static of an inline function, so every
// translation unit that includes this header refers to the same object.
// A plain (non-inline) namespace-scope definition in a header produces a
// multiple-definition link error when the header is included in two or more
// translation units.
inline ::xad::adj<double>::tape_type& adj_tape_instance(void)
{  static ::xad::adj<double>::tape_type tape;
   return tape;
}
// adj_tape
// Internal-linkage reference (one per translation unit); every copy is bound
// to the single tape returned by adj_tape_instance, so existing uses of
// adj_tape are unchanged.
static ::xad::adj<double>::tape_type& adj_tape = adj_tape_instance();
// cmpad::xad::gradient
// Implements the cmpad::gradient interface using XAD adjoint mode.
// grad(x) returns the derivative of the last range component of Algo,
// evaluated at x, with respect to every domain component.
// All recording is done on the shared tape adj_tape.
template < template<class ADVector> class Algo > class gradient
: public cmpad::gradient {
private:
   // mode
   // XAD adjoint mode with double as the underlying value type
   typedef ::xad::adj<double> mode;
   //
   // ADScalar, ADVector
   typedef mode::active_type       ADScalar;
   typedef cmpad::vector<ADScalar> ADVector;
   //
   // option_
   // options passed to the most recent call to setup
   option_t option_;
   //
   // algo_
   // the algorithm, evaluated using XAD active scalars
   Algo<ADVector> algo_;
   //
   // g_
   // gradient computed by the most recent call to operator();
   // operator() returns a reference to this vector
   cmpad::vector<double> g_;
   //
public:
   //
   // scalar_type
   typedef double scalar_type;
   //
   // vector_type
   typedef cmpad::vector<double> vector_type;
   //
   // option
   const option_t& option(void) const override
   {  return option_; }
   //
   // setup
   // Stores the options, sets up the algorithm, and sizes the result vector.
   void setup(const option_t& option) override
   {  option_ = option;
      algo_.setup(option);
      g_.resize( algo_.domain() );
   }
   //
   // domain
   size_t domain(void) const override
   {  return algo_.domain(); }
   //
   // operator
   // x      : argument value; must have size domain().
   // return : g_, where g_[j] is the partial of the last component of
   //          algo_(x) with respect to x[j].
   const cmpad::vector<double>& operator()(
      const cmpad::vector<double>& x
   ) override
   {  assert( x.size() == algo_.domain() );
      //
      // n, m
      // size_t (not int) so the loop comparisons below are sign-consistent
      size_t n = algo_.domain();
      size_t m = algo_.range();
      assert( m > 0 ); // the dependent variable below is ay[m-1]
      //
      // adj_tape
      // Start from an empty tape. The active variables used below are all
      // created after this reset. Active variables kept from a previous call
      // would still carry tape slots from the old recording, and those slots
      // could collide with the ones assigned in the new recording.
      // NOTE(review): active values stored inside algo_ itself are not
      // covered by this reset -- confirm Algo does not keep any.
      adj_tape.clearAll();
      //
      // ax
      // independent variables, registered as tape inputs
      ADVector ax;
      ax.resize(n);
      for(size_t j = 0; j < n; ++j)
      {  ax[j] = x[j];
         adj_tape.registerInput( ax[j] );
      }
      adj_tape.newRecording();
      //
      // ay, az
      // record the algorithm; the dependent variable is its last component
      ADVector ay = algo_(ax);
      ADScalar az = ay[m-1];
      adj_tape.registerOutput(az);
      //
      // seed the output adjoint and propagate back to the inputs
      derivative(az) = 1.0;
      adj_tape.computeAdjoints();
      //
      // g_
      for(size_t j = 0; j < n; ++j)
         g_[j] = derivative( ax[j] );
      //
      return g_;
   }
};
} } // END cmpad::xad namespace
# endif // CMPAD_HAS_XAD