xam_grad_jax.py

View page source

Example and Test of jax Gradient

check_grad_det

See check_grad_det.py.

check_grad_ode

See check_grad_ode.py.

Source Code

import cmpad
import jax
from check_grad_det  import check_grad_det
from check_grad_ode  import check_grad_ode
from check_grad_llsq import check_grad_llsq

def xam_grad_jax() :
   """Run the det, ode, and llsq gradient checks on cmpad.jax.gradient.

   For each case an algorithm object is constructed, wrapped with
   cmpad.jax.gradient, and the wrapped object is handed to the matching
   check_grad_* function.

   Returns the results of the three checks combined with ``&``, starting
   from True. All three checks are always executed; a failing check does
   not skip the ones after it.
   """
   #
   # cases
   # Each entry pairs a factory for the algorithm with its checker. The
   # factories are lambdas so that an algorithm is only constructed right
   # before its own check runs.
   cases = (
      ( lambda : cmpad.det_by_minor() ,                 check_grad_det  ),
      ( lambda : cmpad.an_ode(cmpad.jax.like_numpy) ,   check_grad_ode  ),
      ( lambda : cmpad.llsq_obj(cmpad.jax.like_numpy) , check_grad_llsq ),
   )
   #
   # all_passed
   all_passed = True
   for make_algo, check_grad in cases :
      fun        = make_algo()
      grad_fun   = cmpad.jax.gradient( fun )
      # use &= (not 'and') so that every check is evaluated
      all_passed &= check_grad( grad_fun )
   #
   return all_passed
#
def test_xam_grad_jax() :
   """Test entry point: fail unless xam_grad_jax() compares equal to True."""
   #
   # result
   result = xam_grad_jax()
   #
   assert result == True