Loops

The module aesara.scan provides the basic functionality needed to do loops in Aesara.

This module provides the Scan Op.

Scanning is a general form of recurrence, which can be used for looping. The idea is that you scan a function along some input sequence, producing an output at each time-step that can be seen (but not modified) by the function at the next time-step. (Technically, the function can see the previous K time-steps of your outputs and L time-steps, from the past and the future, of your inputs.)

For example, sum() could be computed by scanning the function z + x_i over a list, given an initial state of z = 0.

Special cases:

  • A reduce operation can be performed by using only the last output of a scan.

  • A map operation can be performed by applying a function that ignores previous steps of the outputs.

Often a for-loop or while-loop can be expressed as a scan() operation, and scan is the closest that Aesara comes to looping. The advantages of using scan over Python for-loops include:

  • it allows the number of iterations to be part of the symbolic graph

  • it allows computing gradients through the for loop

  • several graph optimizations can rewrite your loop so that it uses less memory and runs faster

The Scan Op should typically be used by calling any of the following functions: scan(), map(), reduce(), foldl(), foldr().

aesara.scan

aesara.scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False, return_list=False)

This function constructs and applies a Scan Op to the provided arguments.

Parameters:
  • fn

    fn is a function that describes the operations involved in one step of scan. fn should construct variables describing the output of one iteration step. It should expect as input Variables representing all the slices of the input sequences and previous values of the outputs, as well as all other arguments given to scan as non_sequences. The order in which scan passes these variables to fn is the following:

    • all time slices of the first sequence

    • all time slices of the second sequence

    • all time slices of the last sequence

    • all past slices of the first output

    • all past slices of the second output

    • all past slices of the last output

    • all other arguments (the list given as non_sequences to scan)

    The order of the sequences is the same as the one in the list sequences given to scan. The order of the outputs is the same as the order of outputs_info. For any sequence or output, the time slices are passed in the same order in which they were given as taps. For example, if one writes the following:

    scan(fn,
         sequences=[dict(input=Sequence1, taps=[-3, 2, -1]),
                    Sequence2,
                    dict(input=Sequence3, taps=3)],
         outputs_info=[dict(initial=Output1, taps=[-3, -5]),
                       dict(initial=Output2, taps=None),
                       Output3],
         non_sequences=[Argument1, Argument2])
    

    fn should expect the following arguments in this given order:

    1. sequence1[t-3]

    2. sequence1[t+2]

    3. sequence1[t-1]

    4. sequence2[t]

    5. sequence3[t+3]

    6. output1[t-3]

    7. output1[t-5]

    8. output3[t-1]

    9. argument1

    10. argument2

    The list of non_sequences can also contain shared variables used in the function, though scan is able to figure those out on its own, so they can be skipped. For clarity of the code, however, we recommend passing them to scan. To some extent, scan can also figure out other non-sequences (not shared) that are used by fn even if they are not passed to scan. A simple example of this would be:

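    import aesara
    import aesara.tensor as at
    
    W = at.matrix("W")
    W_2 = W**2
    
    def f(x):
        # W_2 is used here without being listed in non_sequences
        return at.dot(x, W_2)
    
    X = at.matrix("X")
    # scan finds W_2 (and hence W) in the graph built by f
    outputs, updates = aesara.scan(f, sequences=[X])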
    

    The function fn is expected to return two things. The first is a list of outputs ordered in the same order as outputs_info, with the difference that there should be only one output variable per output initial state (even if no tap value is used). The second is an update dictionary, which tells how to update any shared variable after each iteration step. The dictionary can optionally be given as a list of tuples. There is no constraint on the order of these two lists: fn can return either (outputs_list, update_dictionary) or (update_dictionary, outputs_list), or just one of the two (in case the other is empty).

    To use scan as a while loop, fn must also return a stopping condition, wrapped in an until object. The condition is returned after the outputs and the updates (as a third element when both are present), for example:

    ...
    return [y1_t, y2_t], {x:x+1}, until(x < 50)
    

    Note that a number of steps (here treated as the maximum number of steps) is still required even though a condition is passed. It is used to allocate memory if needed.

  • sequences

    sequences is the list of Variables or dicts describing the sequences scan has to iterate over. If a sequence is wrapped in a dict, optional information can be provided about it. The dict should have the following keys:

    • input (mandatory) – Variable representing the sequence.

    • taps – Temporal taps of the sequence required by fn. They are provided as a list of integers, where a value k implies that at iteration step t scan will pass to fn the slice t+k. The default value is [0].

    All Variables in the list sequences are automatically wrapped into a dict where taps is set to [0].

  • outputs_info

    outputs_info is the list of Variables or dicts describing the initial state of the outputs computed recurrently. When the initial states are given as dicts, optional information can be provided about the output corresponding to those initial states. The dict should have the following keys:

    • initial – A Variable that represents the initial state of a given output. In case the output is not computed recursively (e.g. a map-like function) and does not require an initial state, this field can be skipped. If fn uses only the previous time step of the output, the initial state should have the same shape as the output and should not involve a downcast of the data type of the output. If multiple time taps are used, the initial state should have one extra dimension that covers all the possible taps. For example, if we use -5, -2 and -1 as past taps, at step 0, fn will require (by an abuse of notation) output[-5], output[-2] and output[-1]. These will be given by the initial state, which in this case should have the shape (5,) + output.shape. If this Variable containing the initial state is called init_y, then init_y[0] corresponds to output[-5], init_y[1] to output[-4], init_y[2] to output[-3], init_y[3] to output[-2], and init_y[4] to output[-1]. While this order might seem strange, it comes naturally from splitting an array at a given point. Assume that we have an array x, and we choose k to be time step 0. Then our initial state would be x[:k], while the output will be x[k:]. Looking at this split, elements in x[:k] are ordered exactly like those in init_y.

    • taps – Temporal taps of the output that will be passed to fn. They are provided as a list of negative integers, where a value k implies that at iteration step t scan will pass to fn the slice t+k.

    scan will follow this logic if partial information is given:

    • If an output is not wrapped in a dict, scan will wrap it in one assuming that you use only the last step of the output (i.e. it makes your tap value list equal to [-1]).

    • If you wrap an output in a dict and you do not provide any taps but you provide an initial state it will assume that you are using only a tap value of -1.

    • If you wrap an output in a dict but you do not provide any initial state, it assumes that you are not using any form of taps.

    • If you provide None instead of a Variable, or an empty dict, scan assumes that you will not use any taps for this output (as, for example, in the case of a map).

    If outputs_info is an empty list or None, scan assumes that no tap is used for any of the outputs. If information is provided just for a subset of the outputs, an exception is raised, because there is no convention on how scan should map the provided information to the outputs of fn.

  • non_sequences – non_sequences is the list of arguments that are passed to fn at each step. One can choose to exclude variables used in fn from this list, as long as they are part of the computational graph, although, for clarity, this is not encouraged.

  • n_steps – n_steps is the number of steps to iterate, given as an int or a scalar Variable. If any of the input sequences do not have enough elements, scan will raise an error. If the value is 0, the outputs will have 0 rows. If n_steps is not provided, scan will figure out the number of steps it should run from its input sequences. n_steps < 0 is no longer supported.

  • truncate_gradient – truncate_gradient is the number of steps to use in truncated back-propagation through time (BPTT). If you compute gradients through a Scan Op, they are computed using BPTT. By providing a value other than -1, you choose to use truncated BPTT instead of classical BPTT, in which gradients are propagated only truncate_gradient steps back in time.

  • go_backwards – go_backwards is a flag indicating if scan should go backwards through the sequences. If you think of each sequence as indexed by time, setting this flag to True means that scan goes back in time, namely that for any sequence it starts from the end and goes towards 0.

  • name – When profiling scan, it is helpful to provide a name for any instance of scan. For example, the profiler will produce an overall profile of your code as well as profiles for the computation of one step of each instance of Scan. The name of the instance appears in those profiles and can greatly help to disambiguate information.

  • mode – The mode used to compile the inner graph. If you want the computations of one step of scan to be done differently than the entire function, you can use this parameter to describe how the computations in this loop are done (see aesara.function for details about possible values and their meaning).

  • profile – If True or a non-empty string, a profile object will be created and attached to the inner graph of Scan. When profile is True, the profiler results will use the name of the Scan instance, otherwise it will use the passed string. The profiler only collects and prints information when running the inner graph with the CVM Linker.

  • allow_gc

    Set the value of allow_gc for the internal graph of the Scan. If set to None, this will use the value of aesara.config.scan__allow_gc.

    The full Scan behavior related to allocation is determined by this value and the flag aesara.config.allow_gc. If the flag allow_gc is True (default) and this allow_gc is False (default), then Scan allocates all intermediate memory on the first iteration and does not garbage-collect it after that first iteration; this part is controlled by this allow_gc parameter. This can speed up allocation in the subsequent iterations. All those temporary allocations are freed at the end of all iterations; this is what the flag aesara.config.allow_gc controls.

  • strict – If True, all the shared variables used in fn must be provided as a part of non_sequences or sequences.

  • return_list – If True, will always return a list, even if there is only one output.

Returns:

tuple of the form (outputs, updates). outputs is either a Variable or a list of Variables representing the outputs in the same order as in outputs_info. updates is a subclass of dict specifying the update rules for all shared variables used in Scan. This dict should be passed to aesara.function when you compile your function.

Return type:

tuple

Other ways to create loops

aesara.scan() comes with many features that are not always necessary, which is why Aesara provides several other functions to create a Scan operator:

aesara.map(fn, sequences, non_sequences=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None)

Construct a Scan Op that functions like map.

Parameters:
  • fn – The function that map applies at each iteration step (see scan for more info).

  • sequences – List of sequences over which map iterates (see scan for more info).

  • non_sequences – List of arguments passed to fn. map will not iterate over these arguments (see scan for more info).

  • truncate_gradient – See scan.

  • go_backwards (bool) – Decides the direction of iteration. True means that sequences are parsed from the end towards the beginning, while False is the other way around.

  • mode – See scan.

  • name – See scan.

aesara.reduce(fn, sequences, outputs_info, non_sequences=None, go_backwards=False, mode=None, name=None)

Construct a Scan Op that functions like reduce.

Parameters:
  • fn – The function that reduce applies at each iteration step (see scan for more info).

  • sequences – List of sequences over which reduce iterates (see scan for more info).

  • outputs_info – List of dictionaries describing the outputs of reduce (see scan for more info).

  • non_sequences – List of arguments passed to fn. reduce will not iterate over these arguments (see scan for more info).

  • go_backwards (bool) – Decides the direction of iteration. True means that sequences are parsed from the end towards the beginning, while False is the other way around.

  • mode – See scan.

  • name – See scan.

aesara.foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)

Construct a Scan Op that functions like Haskell’s foldl.

Parameters:
  • fn – The function that foldl applies at each iteration step (see scan for more info).

  • sequences – List of sequences over which foldl iterates (see scan for more info).

  • outputs_info – List of dictionaries describing the outputs of foldl (see scan for more info).

  • non_sequences – List of arguments passed to fn. foldl will not iterate over these arguments (see scan for more info).

  • mode – See scan.

  • name – See scan.

aesara.foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)

Construct a Scan Op that functions like Haskell’s foldr.

Parameters:
  • fn – The function that foldr applies at each iteration step (see scan for more info).

  • sequences – List of sequences over which foldr iterates (see scan for more info).

  • outputs_info – List of dictionaries describing the outputs of foldr (see scan for more info).

  • non_sequences – List of arguments passed to fn. foldr will not iterate over these arguments (see scan for more info).

  • mode – See scan.

  • name – See scan.