Loops
The module aesara.scan provides the basic functionality needed to do loops in Aesara. This module provides the Scan Op.
Scanning is a general form of recurrence, which can be used for looping. The idea is that you scan a function along some input sequence, producing an output at each time-step that can be seen (but not modified) by the function at the next time-step. (Technically, the function can see the previous K time-steps of your outputs and L time-steps (from past and future) of your inputs.)
So for example, sum() could be computed by scanning the z + x_i function over a list, given an initial state of z = 0.
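The sum() example can be sketched in plain Python. This is only a model of the idea, not Aesara code: scan_model is a hypothetical helper, and it works on ordinary lists rather than symbolic variables.

```python
def scan_model(fn, sequence, initial):
    """Pure-Python stand-in for a scan with one sequence and one recurrent output."""
    outputs = []
    previous = initial
    for x_i in sequence:
        # fn sees the current slice and the previous output, like z + x_i
        previous = fn(x_i, previous)
        outputs.append(previous)
    return outputs


# Scanning z + x_i over a list, starting from z = 0, yields the partial sums.
partial_sums = scan_model(lambda x_i, z: z + x_i, [1, 2, 3, 4], initial=0)
total = partial_sums[-1]  # the last output is the sum of the list
```

Every intermediate value is kept, which is what distinguishes a scan from a plain accumulation loop.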
Special cases:
- A reduce operation can be performed by using only the last output of a scan.
- A map operation can be performed by applying a function that ignores previous steps of the outputs.
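Both special cases can be illustrated with the standard library, assuming itertools.accumulate as a stand-in for a scan over a plain list (it is not part of Aesara). The names xs, step, running_max and squares are made up for this sketch.

```python
from itertools import accumulate

xs = [3, 1, 4, 1, 5]

# A scan keeps every intermediate output...
running_max = list(accumulate(xs, max))
# ...and a reduce is simply the last of them.
reduced_max = running_max[-1]


# A map is a scan whose step function never looks at previous outputs.
def step(x_i, _previous_output=None):
    return x_i * x_i


squares = [step(x_i) for x_i in xs]
```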
Often a for-loop or while-loop can be expressed as a scan() operation, and scan is the closest that Aesara comes to looping. The advantages of using scan over for loops in Python (among others) are:
- it allows the number of iterations to be part of the symbolic graph
- it allows computing gradients through the for loop
- a number of graph optimizations can rewrite the loop so that it uses less memory and runs faster
The Scan Op should typically be used by calling any of the following functions: scan(), map(), reduce(), foldl(), foldr().
aesara.scan
- aesara.scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False, return_list=False)
This function constructs and applies a Scan Op to the provided arguments.
Parameters:
fn –
fn is a function that describes the operations involved in one step of scan. fn should construct variables describing the output of one iteration step. It should expect as input Variables representing all the slices of the input sequences and previous values of the outputs, as well as all other arguments given to scan as non_sequences. The order in which scan passes these variables to fn is the following:
- all time slices of the first sequence
- all time slices of the second sequence
- …
- all time slices of the last sequence
- all past slices of the first output
- all past slices of the second output
- …
- all past slices of the last output
- all other arguments (the list given as non_sequences to scan)
The order of the sequences is the same as the one in the list sequences given to scan. The order of the outputs is the same as the order of outputs_info. For any sequence or output the order of the time slices is the same as the one in which they have been given as taps. For example, if one writes the following:
    scan(fn,
         sequences=[dict(input=Sequence1, taps=[-3, 2, -1]),
                    Sequence2,
                    dict(input=Sequence3, taps=3)],
         outputs_info=[dict(initial=Output1, taps=[-3, -5]),
                       dict(initial=Output2, taps=None),
                       Output3],
         non_sequences=[Argument1, Argument2])
fn should expect the following arguments in this given order:
- sequence1[t-3]
- sequence1[t+2]
- sequence1[t-1]
- sequence2[t]
- sequence3[t+3]
- output1[t-3]
- output1[t-5]
- output3[t-1]
- argument1
- argument2
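The ordering rule can be replayed in plain Python. The helper fn_argument_order below is hypothetical and only produces labels; it does not call Aesara. Taps are written out already normalized, the way the listing shows them: sequence2 uses tap 0, sequence3 uses tap 3, output3 uses tap -1, and output2 contributes no argument.

```python
def fn_argument_order(sequences, outputs_info, non_sequences):
    """Hypothetical helper: list the arguments fn receives, in order."""
    names = []
    # First every sequence, each with its taps in the order they were given.
    for name, taps in sequences:
        for k in taps:
            names.append(f"{name}[t{k:+d}]" if k else f"{name}[t]")
    # Then every output, again with taps in the order they were given.
    for name, taps in outputs_info:
        for k in taps:
            names.append(f"{name}[t{k:+d}]")
    # Finally the non-sequences, unchanged.
    names.extend(non_sequences)
    return names


order = fn_argument_order(
    sequences=[("sequence1", [-3, 2, -1]), ("sequence2", [0]), ("sequence3", [3])],
    outputs_info=[("output1", [-3, -5]), ("output2", []), ("output3", [-1])],
    non_sequences=["argument1", "argument2"],
)
```

Note that taps are not sorted: sequence1[t+2] comes before sequence1[t-1] because that is how the taps list was written.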
The list of non_sequences can also contain shared variables used in the function, though scan is able to figure those out on its own so they can be skipped. For the clarity of the code we nevertheless recommend providing them to scan. To some extent scan can also figure out other non-sequences (not shared) even if not passed to scan (but used by fn). A simple example of this would be:
    import aesara.tensor as at

    W = at.matrix()
    W_2 = W**2

    def f(x):
        # W_2 is used here without being passed to scan explicitly
        return at.dot(x, W_2)
The function fn is expected to return two things. One is a list of outputs ordered in the same order as outputs_info, with the difference that there should be only one output variable per output initial state (even if no tap value is used). Secondly, fn should return an update dictionary (that tells how to update any shared variable after each iteration step). The dictionary can optionally be given as a list of tuples. There is no constraint on the order of these two lists; fn can return either (outputs_list, update_dictionary) or (update_dictionary, outputs_list) or just one of the two (in case the other is empty).
To use scan as a while loop, the user needs to change the function fn such that a stopping condition is also returned. To do so, one needs to wrap the condition in an until class. The condition should be returned as a third element, for example:
    ...
    return [y1_t, y2_t], {x: x + 1}, until(x < 50)
Note that a number of steps (considered here as the maximum number of steps) is still required even though a condition is passed. It is used to allocate memory if needed.
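The interplay between the stopping condition and the maximum number of steps can be modeled in plain Python. This is a sketch under assumptions, not Aesara's implementation: scan_until_model is a hypothetical helper, and it assumes that the output of the step on which the condition becomes true is still kept.

```python
def scan_until_model(step, initial, n_steps):
    """Pure-Python model of scan used as a while loop.

    step returns (new_value, stop); n_steps is the maximum number of iterations.
    """
    outputs = []
    value = initial
    for _ in range(n_steps):
        value, stop = step(value)
        outputs.append(value)  # assumption: the step that raises the condition is kept
        if stop:
            break
    return outputs


def double_until_above_45(v):
    new_v = v * 2
    return new_v, new_v > 45


# The condition ends the loop well before the 100-step limit...
powers = scan_until_model(double_until_above_45, initial=1, n_steps=100)
# ...but a small n_steps ends it first if the condition is never met.
capped = scan_until_model(double_until_above_45, initial=1, n_steps=3)
```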
sequences –
sequences is the list of Variables or dicts describing the sequences scan has to iterate over. If a sequence is given wrapped in a dict, then a set of optional information can be provided about the sequence. The dict should have the following keys:
- input (mandatory) – Variable representing the sequence.
- taps – Temporal taps of the sequence required by fn. They are provided as a list of integers, where a value k implies that at iteration step t scan will pass to fn the slice t+k. Default value is [0].
All Variables in the list sequences are automatically wrapped into a dict where taps is set to [0].
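A small plain-Python model may help to picture sequence taps. The helper sequence_slices is hypothetical, and the sketch assumes that t only ranges over positions where every requested index t + k exists; how Aesara numbers its first step internally is not modeled here.

```python
def sequence_slices(sequence, taps):
    """For each step t, collect sequence[t + k] for every tap k, in tap order."""
    lowest = min(min(taps), 0)
    highest = max(max(taps), 0)
    steps = []
    # Assumption of this sketch: skip positions where some t + k would fall
    # outside the sequence.
    for t in range(-lowest, len(sequence) - highest):
        steps.append([sequence[t + k] for k in taps])
    return steps


# Taps [-1, 0, 1] give fn a sliding window of three neighbouring elements.
windows = sequence_slices([10, 20, 30, 40, 50], taps=[-1, 0, 1])
# The default taps [0] hand fn one element per step.
default = sequence_slices([10, 20, 30], taps=[0])
```

With past and future taps the number of usable steps is smaller than the length of the sequence, which is worth keeping in mind when n_steps is given explicitly.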
outputs_info –
outputs_info is the list of Variables or dicts describing the initial state of the outputs computed recurrently. When the initial states are given as dicts, optional information can be provided about the output corresponding to those initial states. The dict should have the following keys:
- initial – A Variable that represents the initial state of a given output. In case the output is not computed recursively (e.g. a map-like function) and does not require an initial state, this field can be skipped. Given that only the previous time step of the output is used by fn, the initial state should have the same shape as the output and should not involve a downcast of the data type of the output. If multiple time taps are used, the initial state should have one extra dimension that covers all the possible taps. For example, if we use -5, -2 and -1 as past taps, at step 0, fn will require (by an abuse of notation) output[-5], output[-2] and output[-1]. This will be given by the initial state, which in this case should have the shape (5,) + output.shape. If this Variable containing the initial state is called init_y, then init_y[0] corresponds to output[-5], init_y[1] corresponds to output[-4], init_y[2] corresponds to output[-3], init_y[3] corresponds to output[-2], and init_y[4] corresponds to output[-1]. While this order might seem strange, it comes naturally from splitting an array at a given point. Assume that we have an array x, and we choose k to be time step 0. Then our initial state would be x[:k], while the output will be x[k:]. Looking at this split, elements in x[:k] are ordered exactly like those in init_y.
- taps – Temporal taps of the output that will be passed to fn. They are provided as a list of negative integers, where a value k implies that at iteration step t scan will pass to fn the slice t+k.
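The layout of init_y for the taps -5, -2 and -1 can be checked with a plain-Python model of the x[:k] / x[k:] split. The helper recur_with_taps and the step function are made up for this sketch, and scalars stand in for Variables.

```python
def recur_with_taps(step, init_y, taps, n_steps):
    """Pure-Python model of one recurrent output with several past taps."""
    depth = -min(taps)      # taps [-5, -2, -1] need 5 initial values
    assert len(init_y) == depth
    buffer = list(init_y)   # buffer[:depth] plays the role of x[:k]
    for t in range(n_steps):
        # output[t + k] is stored at buffer[depth + t + k]
        past = [buffer[depth + t + k] for k in taps]
        buffer.append(step(*past))
    return buffer[depth:]   # the outputs proper, x[k:]


init_y = [1, 2, 3, 4, 5]    # output[-5], output[-4], ..., output[-1]
outputs = recur_with_taps(
    lambda y_tm5, y_tm2, y_tm1: y_tm5 + y_tm2 + y_tm1,
    init_y,
    taps=[-5, -2, -1],
    n_steps=3,
)
```

At step 0 the function receives init_y[0], init_y[3] and init_y[4], that is output[-5], output[-2] and output[-1], even though init_y[1] and init_y[2] are never read on that step.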
scan will follow this logic if partial information is given:
- If an output is not wrapped in a dict, scan will wrap it in one assuming that you use only the last step of the output (i.e. it makes your tap value list equal to [-1]).
- If you wrap an output in a dict and you do not provide any taps but you provide an initial state, it will assume that you are using only a tap value of -1.
- If you wrap an output in a dict but you do not provide any initial state, it assumes that you are not using any form of taps.
- If you provide a None instead of a Variable or an empty dict, scan assumes that you will not use any taps for this output (as, for example, in the case of a map).
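These defaulting rules can be written down as a small normalization function. This is a hypothetical sketch that mirrors the list above, not Aesara's own code; plain Python values stand in for Variables, and rejecting taps that come without an initial state is a choice made for the sketch.

```python
def normalize_output_info(info):
    """Hypothetical helper mirroring the defaulting rules for one output."""
    if info is None:
        return {}                                   # no taps, e.g. a map
    if not isinstance(info, dict):
        return {"initial": info, "taps": [-1]}      # bare initial state
    if info.get("initial") is None:
        if info.get("taps"):
            # Sketch-only choice: taps make no sense without an initial state.
            raise ValueError("taps given without an initial state")
        return {}                                   # dict without initial state
    if info.get("taps") is None:
        return {"initial": info["initial"], "taps": [-1]}
    return dict(info)                               # fully specified


bare = normalize_output_info(0.0)
no_taps = normalize_output_info({"initial": 0.0})
map_like = normalize_output_info(None)
explicit = normalize_output_info({"initial": [1.0, 2.0], "taps": [-2, -1]})
```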
If outputs_info is an empty list or None, scan assumes that no tap is used for any of the outputs. If information is provided just for a subset of the outputs, an exception is raised, because there is no convention on how scan should map the provided information to the outputs of fn.
non_sequences –
non_sequences is the list of arguments that are passed to fn at each step. One can choose to exclude variables used in fn from this list, as long as they are part of the computational graph, although for clarity this is not encouraged.
n_steps –
n_steps is the number of steps to iterate, given as an int or a scalar Variable. If any of the input sequences do not have enough elements, scan will raise an error. If the value is 0, the outputs will have 0 rows. If n_steps is not provided, scan will figure out the number of steps it should run given its input sequences. n_steps < 0 is no longer supported.
truncate_gradient –
truncate_gradient is the number of steps to use in truncated back-propagation through time (BPTT). If you compute gradients through a Scan Op, they are computed using BPTT. By providing a value other than -1, you choose to use truncated BPTT instead of classical BPTT, where you go back in time only truncate_gradient steps.
go_backwards –
go_backwards is a flag indicating if scan should go backwards through the sequences. If you think of each sequence as indexed by time, making this flag True would mean that scan goes back in time, namely that for any sequence it starts from the end and goes towards 0.
name – When profiling scan, it is helpful to provide a name for any instance of scan. For example, the profiler will produce an overall profile of your code as well as profiles for the computation of one step of each instance of Scan. The name of the instance appears in those profiles and can greatly help to disambiguate information.
mode – The mode used to compile the inner-graph. If you prefer the computations of one step of scan to be done differently than the entire function, you can use this parameter to describe how the computations in this loop are done (see aesara.function for details about possible values and their meaning).
profile – If True or a non-empty string, a profile object will be created and attached to the inner graph of Scan. When profile is True, the profiler results will use the name of the Scan instance, otherwise it will use the passed string. The profiler only collects and prints information when running the inner graph with the CVM Linker.
allow_gc –
Set the value of allow_gc for the internal graph of the Scan. If set to None, this will use the value of aesara.config.scan__allow_gc.
The full Scan behavior related to allocation is determined by this value and the flag aesara.config.allow_gc. If the flag allow_gc is True (default) and this allow_gc is False (default), then we let Scan allocate all intermediate memory on the first iteration, and it is not garbage collected after that first iteration; this is determined by allow_gc. This can speed up allocation of the subsequent iterations. All those temporary allocations are freed at the end of all iterations; this is what the flag aesara.config.allow_gc means.
strict – If True, all the shared variables used in fn must be provided as a part of non_sequences or sequences.
return_list – If True, will always return a list, even if there is only one output.
Returns:
tuple of the form (outputs, updates). outputs is either a Variable or a list of Variables representing the outputs in the same order as in outputs_info. updates is a subclass of dict specifying the update rules for all shared variables used in Scan. This dict should be passed to aesara.function when you compile your function.
Return type:
tuple
Other ways to create loops
aesara.scan() comes with bells and whistles that are not always all necessary, which is why Aesara provides several other functions to create a Scan operator:
- aesara.map(fn, sequences, non_sequences=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None)
Construct a Scan Op that functions like map.
Parameters:
fn – The function that map applies at each iteration step (see scan for more info).
sequences – List of sequences over which map iterates (see scan for more info).
non_sequences – List of arguments passed to fn. map will not iterate over these arguments (see scan for more info).
truncate_gradient – See scan.
go_backwards (bool) – Decides the direction of iteration. True means that sequences are parsed from the end towards the beginning, while False is the other way around.
mode – See scan.
name – See scan.
- aesara.reduce(fn, sequences, outputs_info, non_sequences=None, go_backwards=False, mode=None, name=None)
Construct a Scan Op that functions like reduce.
Parameters:
fn – The function that reduce applies at each iteration step (see scan for more info).
sequences – List of sequences over which reduce iterates (see scan for more info).
outputs_info – List of dictionaries describing the outputs of reduce (see scan for more info).
non_sequences – List of arguments passed to fn. reduce will not iterate over these arguments (see scan for more info).
go_backwards (bool) – Decides the direction of iteration. True means that sequences are parsed from the end towards the beginning, while False is the other way around.
mode – See scan.
name – See scan.
- aesara.foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
Construct a Scan Op that functions like Haskell’s foldl.
Parameters:
fn – The function that foldl applies at each iteration step (see scan for more info).
sequences – List of sequences over which foldl iterates (see scan for more info).
outputs_info – List of dictionaries describing the outputs of foldl (see scan for more info).
non_sequences – List of arguments passed to fn. foldl will not iterate over these arguments (see scan for more info).
mode – See scan.
name – See scan.
- aesara.foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
Construct a Scan Op that functions like Haskell’s foldr.
Parameters:
fn – The function that foldr applies at each iteration step (see scan for more info).
sequences – List of sequences over which foldr iterates (see scan for more info).
outputs_info – List of dictionaries describing the outputs of foldr (see scan for more info).
non_sequences – List of arguments passed to fn. foldr will not iterate over these arguments (see scan for more info).
mode – See scan.
name – See scan.
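The difference between the two folds can be seen in a plain-Python model. The helpers reduce_model, foldl_model and foldr_model are hypothetical and operate on ordinary lists; the sketch assumes that a right fold can be modeled as a reduce that walks the sequence backwards, and that fn receives the sequence slice first and the previous output second, as in the ordering rule described for scan.

```python
def reduce_model(fn, sequence, initial, go_backwards=False):
    """Pure-Python model of reduce: run the recurrence, keep only the last output."""
    items = reversed(sequence) if go_backwards else sequence
    accumulator = initial
    for x_i in items:
        # Sequence slice first, previous output second.
        accumulator = fn(x_i, accumulator)
    return accumulator


def foldl_model(fn, sequence, initial):
    return reduce_model(fn, sequence, initial, go_backwards=False)


def foldr_model(fn, sequence, initial):
    # Assumption: a right fold is a reduce over the sequence read from the end.
    return reduce_model(fn, sequence, initial, go_backwards=True)


def append(x_i, acc):
    return acc + x_i


left = foldl_model(append, ["a", "b", "c"], "")
right = foldr_model(append, ["a", "b", "c"], "")
```

A non-commutative step such as string concatenation makes the direction visible; for a commutative and associative step like addition, both folds return the same value.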