taichi.ad
Module Contents
Functions
- grad_replaced(func) – A decorator for Python functions that customizes their gradients under Taichi's autodiff system.
- grad_for(primal) – Generates a decorator for primal's customized gradient function.
- taichi.ad.grad_replaced(func)
A decorator for a Python function that customizes its gradient under Taichi's autodiff system, e.g. ti.Tape() and kernel.grad(). It forces the autodiff system to use a user-defined gradient function for the decorated function instead of deriving one automatically. The customized gradient function must be decorated by grad_for().
- Parameters
func (Callable) – The Python function to be decorated.
- Returns
The decorated function.
- Return type
Callable
Example:

    import taichi as ti

    ti.init()

    x = ti.field(ti.float32, shape=4, needs_grad=True)
    y = ti.field(ti.float32, shape=4, needs_grad=True)

    @ti.kernel
    def multiply(a: ti.float32):
        for I in ti.grouped(x):
            y[I] = x[I] * a

    @ti.kernel
    def multiply_grad(a: ti.float32):
        for I in ti.grouped(x):
            # User-defined rule; the analytic gradient of y = x * a
            # with respect to x would be y.grad[I] * a.
            x.grad[I] = y.grad[I] / a

    @ti.ad.grad_replaced
    def foo(a):
        multiply(a)

    @ti.ad.grad_for(foo)
    def foo_grad(a):
        multiply_grad(a)
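The interaction between a tape, a grad-replaced function and its registered gradient can be pictured with a toy model in plain Python. This is a hypothetical sketch, not Taichi's implementation: `ToyTape`, `toy_grad_replaced`, `toy_grad_for` and the `grad` attribute used as a registration slot are names invented here. It assumes that a tape records each call to a grad-replaced function and, in the backward pass, invokes the registered gradient function with the same arguments, newest call first (the usual reverse-mode order), instead of differentiating the primal.

```python
from typing import Any, Callable, List, Optional, Tuple

# Toy model only: none of these names are part of Taichi.
_active_tape: Optional["ToyTape"] = None


class ToyTape:
    """Hypothetical stand-in for a tape such as ti.Tape()."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Callable[..., Any], tuple]] = []

    def __enter__(self) -> "ToyTape":
        global _active_tape
        _active_tape = self
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        global _active_tape
        _active_tape = None
        if exc_type is None:
            # Backward pass: newest call first, via the user-defined gradient.
            for fn, args in reversed(self.calls):
                fn.grad(*args)


def toy_grad_replaced(func: Callable[..., Any]) -> Callable[..., Any]:
    def decorated(*args: Any) -> Any:
        # Record the call so the tape can replay its custom gradient later.
        if _active_tape is not None:
            _active_tape.calls.append((decorated, args))
        return func(*args)

    decorated.grad = None  # slot filled in by toy_grad_for
    return decorated


def toy_grad_for(primal: Callable[..., Any]) -> Callable[..., Any]:
    def decorator(grad_fn: Callable[..., Any]) -> Callable[..., Any]:
        primal.grad = grad_fn  # register the customized gradient
        return grad_fn

    return decorator


# Scalar stand-ins for the fields x, y and their gradients.
state = {"x": 3.0, "y": 0.0, "x_grad": 0.0, "y_grad": 1.0}


@toy_grad_replaced
def foo(a: float) -> None:
    state["y"] = state["x"] * a


@toy_grad_for(foo)
def foo_grad(a: float) -> None:
    # Same user-defined rule as multiply_grad in the example above.
    state["x_grad"] = state["y_grad"] / a


with ToyTape():
    foo(2.0)

print(state["y"], state["x_grad"])  # 6.0 0.5
```

In this toy, calling foo() outside a tape only runs the primal; the custom gradient is replayed only for calls made while a tape is active.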
- taichi.ad.grad_for(primal)
Generates a decorator for primal's customized gradient function. See grad_replaced() for an example.
- Parameters
primal (Callable) – The primal function, which must be decorated by grad_replaced().
- Returns
The decorator used to decorate customized gradient function.
- Return type
Callable
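The precondition on primal can be illustrated with another toy sketch. It is hypothetical, not Taichi's code: the names `toy_grad_replaced` and `toy_grad_for`, the `grad` attribute used as the registration slot, and the choice of `TypeError` are assumptions made for illustration. It shows one way a gradient-registering decorator can refuse a primal that the replacing decorator never prepared; how Taichi itself reports this case is not described here.

```python
from typing import Any, Callable


def toy_grad_replaced(func: Callable[..., Any]) -> Callable[..., Any]:
    def decorated(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    decorated.grad = None  # marker and slot for the customized gradient
    return decorated


def toy_grad_for(primal: Callable[..., Any]) -> Callable[..., Any]:
    # Enforce the precondition before handing back the decorator.
    if not hasattr(primal, "grad"):
        raise TypeError(f"{primal.__name__} must be decorated by toy_grad_replaced")

    def decorator(grad_fn: Callable[..., Any]) -> Callable[..., Any]:
        primal.grad = grad_fn  # register the customized gradient
        return grad_fn

    return decorator


@toy_grad_replaced
def double(a: float) -> float:
    return 2.0 * a


@toy_grad_for(double)
def double_grad(a: float) -> float:
    return 2.0


def plain(a: float) -> float:
    return a


try:
    toy_grad_for(plain)  # plain was never passed through toy_grad_replaced
except TypeError as err:
    print(err)  # plain must be decorated by toy_grad_replaced
```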