jax_gradient.py
Calculate Gradient Using JAX
Syntax
grad = cmpad.jax.gradient( algo )
grad.setup( option )
g = grad( x )
Purpose
This uses JAX to implement a py_fun_obj that computes the gradient of the last component of the vector computed by algo.
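The core idea can be sketched directly with jax.grad. The function toy_algo below is a hypothetical stand-in for algo and is not part of cmpad. jax.grad only accepts functions with a scalar output, which is why the gradient is taken of one component (the last) rather than of the whole vector.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)


def toy_algo(x):
    # Hypothetical stand-in for algo: maps a length-2 vector to a length-2 vector.
    return jnp.stack([x[0] * x[1], x[0] + x[1] ** 2])


def last_component(x):
    # Scalar function whose gradient is wanted: the last entry of toy_algo(x).
    return toy_algo(x)[-1]


x = jnp.array([3.0, 4.0], dtype=float)

# The gradient of x[0] + x[1]**2 is [1, 2 * x[1]], i.e. [1.0, 8.0] at this x.
g = jax.grad(last_component)(x)

# jax.grad rejects vector-valued functions, so differentiating toy_algo
# itself (without selecting one component) raises a TypeError.
try:
    jax.grad(toy_algo)(x)
    vector_output_rejected = False
except TypeError:
    vector_output_rejected = True
```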
algo
This is a py_fun_obj whose input and output vectors
are jax.numpy arrays with float elements.
The last component of the vector computed by algo
defines the scalar function whose gradient is computed.
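A minimal algo can be sketched as a class with the methods the gradient class relies on: setup, domain, and the call operator (range is included to mirror the gradient class itself). The class name toy_algo and its formulas are hypothetical, and the exact method set of a py_fun_obj is assumed here from how the source code below uses algo.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)


class toy_algo:
    # Hypothetical py_fun_obj; not part of cmpad.

    def setup(self, option):
        # Record the problem size from the option dictionary.
        assert isinstance(option, dict) and 'n_arg' in option
        self.n_arg = option['n_arg']

    def domain(self):
        # Length of the argument vector.
        return self.n_arg

    def range(self):
        # Length of the result vector.
        return 2

    def __call__(self, x):
        # First component: sum of x.  Last component: sum of squares of x.
        # Only the last component defines the scalar function for the gradient.
        return jnp.stack([jnp.sum(x), jnp.sum(x * x)])


algo = toy_algo()
algo.setup({'n_arg': 3})
y = algo(jnp.array([1.0, 2.0, 3.0], dtype=float))
```

With this algo, the scalar function is the sum of squares, so its gradient at x is 2 x.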
grad
This is a py_fun_obj whose input and output vectors
have elements of type float.
x
This is a numpy vector of float with length option['n_arg'].
It is the argument value at which to compute the gradient.
g
This is a numpy vector of float with length option['n_arg'].
It is the value of the gradient at x.
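One way to confirm that g really is the gradient at x is to compare it with central differences. The function cubic_sum below is a hypothetical stand-in for the last component of algo, and the step size h is an arbitrary choice for this sketch, not something cmpad prescribes.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)


def cubic_sum(x):
    # Hypothetical scalar function standing in for algo(x)[-1]:
    # f(x) = sum_i x_i**3, so the exact gradient is 3 * x_i**2.
    return jnp.sum(x ** 3)


def central_difference(f, x, h=1e-6):
    # Approximate each partial derivative by (f(x + h e_j) - f(x - h e_j)) / (2 h).
    g = []
    for j in range(x.shape[0]):
        e = jnp.zeros_like(x).at[j].set(h)
        g.append((f(x + e) - f(x - e)) / (2.0 * h))
    return jnp.array(g)


x = jnp.array([1.0, -2.0, 0.5], dtype=float)
g = jax.grad(cubic_sum)(x)          # exact: [3.0, 12.0, 0.75]
g_fd = central_difference(cubic_sum, x)
```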
Example
The file xam_grad_jax.py contains an example and test using this class.
Source Code
#
# imports
import jax
import jax.numpy
jax.config.update('jax_enable_x64', True)
#
# gradient
class gradient :
    #
    def __init__(self, algo) :
        # algo: py_fun_obj whose last range component is differentiated
        self.algo    = algo
        # option dictionary passed to setup
        self._option = None
    #
    def option(self) :
        return self._option
    #
    def domain(self) :
        return self._option['n_arg']
    #
    def range(self) :
        # the gradient has one component per argument
        return self._option['n_arg']
    #
    def func(self, x) :
        # scalar function: last component of the vector computed by algo
        v = self.algo(x)
        return v[-1]
    #
    def setup(self, option) :
        assert isinstance(option, dict)
        assert 'n_arg' in option
        #
        # self._option
        self._option = option
        #
        # self.algo
        self.algo.setup(option)
        #
        # self.n_arg
        self.n_arg = self.algo.domain()
        assert self.n_arg == option['n_arg']
        #
        # self.grad: JAX function that returns the gradient of func
        self.grad = jax.grad(self.func)
    #
    # call
    def __call__(self, x) :
        # convert x to a JAX array of float64 (64-bit mode is enabled above)
        x = jax.numpy.array(x, dtype=float)
        g = self.grad(x)
        return g
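The call jax.config.update('jax_enable_x64', True) near the top of the source matters because JAX uses 32-bit floats by default; without it, the dtype=float conversion in the call operator would yield float32 values. The sketch below checks the resulting dtype. It is safest to make this setting at program start, before any JAX arrays are created.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floating point; do this at startup, before creating arrays.
jax.config.update('jax_enable_x64', True)

# With x64 enabled, dtype=float (the Python float type) maps to float64.
x = jnp.array([1.0, 2.0], dtype=float)
print(x.dtype)  # float64
```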