Gradients
The aesara.gradient module allows the gradient of an Aesara graph to be computed.
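As a quick illustration before the reference entry below, here is a minimal sketch using the top-level aesara.grad entry point; the variable names (x, y, gy, f) are illustrative only:

```python
import aesara
import aesara.tensor as at

x = at.dscalar("x")           # symbolic scalar input
y = x ** 2                    # expression to differentiate
gy = aesara.grad(y, wrt=x)    # symbolic gradient dy/dx = 2*x

f = aesara.function([x], gy)  # compile the gradient graph
print(f(4.0))                 # expected to print 8.0
```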
- aesara.gradient.grad(cost: Variable | None, wrt: Variable | Sequence[Variable], consider_constant: Sequence[Variable] | None = None, disconnected_inputs: Literal['ignore', 'warn', 'raise'] = 'raise', add_names: bool = True, known_grads: Mapping[Variable, Variable] | None = None, return_disconnected: Literal['none', 'zero', 'disconnected'] = 'zero', null_gradients: Literal['raise', 'return'] = 'raise') → Variable | None | Sequence[Variable | None]
Return symbolic gradients of one cost with respect to one or more variables.
For more information about how automatic differentiation works in Aesara, see gradient. For information on how to implement the gradient of a certain Op, see grad().
- Parameters:
cost – Value that we are differentiating (i.e. for which we want the gradient). May be None if known_grads is provided.
wrt – The term(s) with respect to which we want gradients.
consider_constant – Expressions not to backpropagate through.
disconnected_inputs ({'ignore', 'warn', 'raise'}) – Defines the behaviour if some of the variables in wrt are not part of the computational graph computing cost (or if all links are non-differentiable). The possible values are:
  - 'ignore': considers that the gradient on these parameters is zero
  - 'warn': considers the gradient zero and prints a warning
  - 'raise': raises DisconnectedInputError
add_names – If True, variables generated by grad will be named (d<cost.name>/d<wrt.name>) provided that both cost and wrt have names.
known_grads – An ordered dictionary mapping variables to their gradients. This is useful in the case where you know the gradients of some variables but do not know the original cost.
return_disconnected –
  - 'zero': If wrt[i] is disconnected, return value i will be wrt[i].zeros_like()
  - 'none': If wrt[i] is disconnected, return value i will be None
  - 'disconnected': returns variables of type DisconnectedType
null_gradients – Defines the behaviour when some of the variables in wrt have a null gradient. The possible values are:
  - 'raise': raises a NullTypeGradError exception
  - 'return': returns the null gradients
- Returns:
A symbolic expression for the gradient of cost with respect to each of the wrt terms. If an element of wrt is not differentiable with respect to the output, then a zero variable is returned.
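As a sketch of how some of these parameters interact (the variable names are illustrative, not part of the API), the snippet below first passes an already-known gradient of an intermediate variable through known_grads with cost=None, and then requests gradients for an input that is disconnected from the cost:

```python
import aesara
import aesara.tensor as at

x = at.dvector("x")
z = at.dscalar("z")           # never used below, so it is disconnected
y = at.sum(x ** 2)

# No explicit cost: we supply d(cost)/dy directly and let grad()
# backpropagate it through the graph to x.
g_x = aesara.grad(cost=None, wrt=x, known_grads={y: at.ones_like(y)})

# z does not appear in y's graph; with disconnected_inputs="ignore" and the
# default return_disconnected="zero", g_z is z.zeros_like() rather than an error.
g_x2, g_z = aesara.grad(y, wrt=[x, z], disconnected_inputs="ignore")
```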
This section of the documentation is organized as follows:
- Derivatives in Aesara gives a hands-on introduction to how to build gradient graphs in Aesara.
- Gradients API is an API reference for the aesara.gradient module.