This section of the documentation describes all the rewrites that can be applied during the compilation of an Aesara graph.

Tensor rewrites

These rewrites are implemented in the module tensor.rewriting.basic.

Tensor optimizations addressing the ops in

aesara.tensor.rewriting.basic.broadcast_like(value, template, fgraph, dtype=None)

Return a Variable with the same shape and dtype as the template, filled by broadcasting value through it. value will be cast as necessary.
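For intuition only, here is a NumPy analogue of the result that broadcast_like describes. The helper name broadcast_like_np is hypothetical. It works on concrete arrays, whereas the real function builds a symbolic Aesara graph in fgraph:

```python
import numpy as np


def broadcast_like_np(value, template, dtype=None):
    """Hypothetical NumPy analogue of broadcast_like, for concrete arrays."""
    # Cast the fill value to the requested dtype (default: the template's dtype).
    target_dtype = np.dtype(dtype) if dtype is not None else template.dtype
    cast_value = np.asarray(value).astype(target_dtype)
    # Broadcast the value through the template's shape (returns a read-only view).
    return np.broadcast_to(cast_value, template.shape)


template = np.zeros((2, 3), dtype="float32")
filled = broadcast_like_np(7, template)
print(filled.shape, filled.dtype)  # (2, 3) float32
```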

aesara.tensor.rewriting.basic.encompasses_broadcastable(b1, b2)

Parameters:

  • b1 – The broadcastable attribute of a tensor type.

  • b2 – The broadcastable attribute of a tensor type.

Returns:

True if the broadcastable patterns b1 and b2 are such that b2 is broadcast to b1’s shape and not the opposite.

Return type:

bool
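The rule can be sketched in plain Python. This is an illustrative reimplementation of the stated semantics, not Aesara's code. Patterns are tuples of booleans, with True meaning the dimension is broadcastable, and trailing dimensions are aligned as in NumPy broadcasting:

```python
from typing import Sequence


def encompasses_broadcastable_sketch(b1: Sequence[bool], b2: Sequence[bool]) -> bool:
    """Return True if pattern b2 can be broadcast to b1's shape (sketch)."""
    # b2 cannot be broadcast to a pattern with fewer dimensions.
    if len(b1) < len(b2):
        return False
    # Align trailing dimensions, as in NumPy broadcasting.
    aligned_b1 = tuple(b1)[len(b1) - len(b2):]
    # If b1 is broadcastable where b2 is not, b2 would enlarge b1 instead.
    return not any(v1 and not v2 for v1, v2 in zip(aligned_b1, b2))


print(encompasses_broadcastable_sketch((False, False), (True, False)))  # True
print(encompasses_broadcastable_sketch((True, False), (False, False)))  # False
```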


aesara.tensor.rewriting.basic.is_an_upcast(type1, type2)

Given two data types (as strings), check if converting to type2 from type1 constitutes an upcast. Differs from aesara.scalar.upcast.



Determine the axis at which an array index is applied.

This only works for take-like indices: e.g. x[:, :, idx, ...]. For the above example, get_advsubtensor_axis would return 2. If it encounters anything other than a set of indices containing full slices and an array/tensor index, it will return None.
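The following plain-Python sketch shows the take-like check. Lists or tuples stand in for array/tensor indices, and only a literal slice(None) counts as a full slice. Both simplifications, and the function names, are assumptions made for illustration:

```python
def is_full_slice_sketch(idx) -> bool:
    # Only a literal slice(None) counts here; Aesara also recognizes symbolic equivalents.
    return isinstance(idx, slice) and idx == slice(None)


def advsubtensor_axis_sketch(indices):
    """Return the position of the single array index, or None (sketch)."""
    axis = None
    for position, idx in enumerate(indices):
        if is_full_slice_sketch(idx):
            continue  # leading or trailing full slices are allowed
        if axis is not None or not isinstance(idx, (list, tuple)):
            return None  # a second array index, or a non-take-like entry
        axis = position
    return axis


print(advsubtensor_axis_sketch((slice(None), slice(None), [0, 2])))  # 2
```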


Determine if x is a slice(None) or a symbolic equivalent.

aesara.tensor.rewriting.subtensor.merge_two_slices(fgraph, slice1, len1, slice2, len2)

This function merges two slices into a single slice. The code works on the assumption that:

  1. slice1 is actually a slice and not an index, while slice2 can be just an index.

  2. the two slices have been applied consecutively to the same tensor.

The output slice is not in canonical form. It is simply a slice that, applied to the tensor, produces the same output as applying the two slices consecutively. len1 is the length of the tensor before applying the first slice, while len2 is the length after applying the first slice.
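The same idea can be checked with concrete numbers. The sketch below is not Aesara's symbolic implementation, and its name, merge_two_slices_concrete, is hypothetical. It assumes len1 and both slices are known Python integers. It relies on the fact that slicing a Python range yields another range, which maps back to a single slice:

```python
def merge_two_slices_concrete(len1, slice1, slice2):
    """Merge two consecutive slicings of a length-len1 sequence into one (sketch).

    slice1 must be a slice; slice2 may be a slice or a plain integer index.
    """
    merged = range(len1)[slice1][slice2]
    if isinstance(merged, int):
        # slice2 was an integer index: the result is one position in the original.
        return merged
    if len(merged) == 0:
        return slice(0, 0, 1)
    stop = merged.stop
    if merged.step < 0 and stop < 0:
        # A negative stop would count from the end; None means "run past index 0".
        stop = None
    return slice(merged.start, stop, merged.step)


data = list(range(10))
merged = merge_two_slices_concrete(len(data), slice(0, None, 3), slice(None, None, -1))
print(merged, data[merged])  # slice(9, None, -3) [9, 6, 3, 0]
```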

aesara.tensor.rewriting.subtensor.transform_take(a, indices, axis)

Transform arr[:,:,:,indices,...]-like operations into single-dimensional, vector index operations.

This effectively converts certain AdvancedSubtensor Ops into a combination of AdvancedSubtensor1, DimShuffle, and Reshape Ops, which can be more efficient.

Parameters:

  • a (TensorVariable) – The source array.

  • indices (TensorVariable, ndarray, list, tuple) – The indices of the values to extract.

  • axis (int) – The axis over which to select values. By default, the flattened input array is used.
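The underlying identity can be checked in NumPy. Moving the indexed axis to the front, applying a single vector index, then reshaping and moving the new axes back reproduces a[:, ..., indices, ...]. This sketch is not Aesara's code, and the function name is hypothetical. It uses NumPy calls in place of the DimShuffle, AdvancedSubtensor1 and Reshape Ops:

```python
import numpy as np


def take_via_vector_index(a, indices, axis):
    """Reproduce np.take(a, indices, axis) using only vector indexing (sketch)."""
    idx = np.asarray(indices)
    # Move the indexed axis to the front (the role DimShuffle plays).
    moved = np.moveaxis(a, axis, 0)
    # One flat vector index on the leading axis (the role of AdvancedSubtensor1).
    picked = moved[idx.ravel()]
    # Restore the shape of the index array (Reshape) ...
    picked = picked.reshape(idx.shape + moved.shape[1:])
    # ... and move the new index axes back to where `axis` was (DimShuffle again).
    return np.moveaxis(picked, list(range(idx.ndim)), list(range(axis, axis + idx.ndim)))


a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
idx = np.array([[3, 0, 1], [2, 2, 0]])
print(take_via_vector_index(a, idx, axis=2).shape)  # (2, 3, 2, 3, 5)
```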


class aesara.tensor.rewriting.shape.ShapeFeature

A Feature that tracks shape information in a graph.

This Feature aids in the replacement of all Shapes and Subtensors of Shapes with Shape_i and MakeVector Ops.

This Feature and its associated rewrites have several goals:

  1. to “lift” Shapes as close to the inputs as possible,

  2. to infer the shape of every node in the graph in terms of the input shapes, and

  3. to remove fill Ops (e.g. Second) from the graph.

Lifting shapes as close to the inputs as possible is important for canonicalization, because it is very bad form to have to compute something just to know how big it will be. First, computing such outputs wastes time. Second, and more importantly, these outputs should be removed as early as possible in the compilation process. The extra computations make many internal graph nodes appear to have multiple clients, and many rewrites refuse to work on nodes with multiple clients.

Lifting is done by using an Op.infer_shape() method if one is present, or else a conservative default.

Inferring the shape of internal nodes in the graph is important for doing size-driven rewrites. If we know how big various intermediate results will be, we can estimate the cost of many Ops accurately, and generate code that is specific (e.g. unrolled) to particular sizes.
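The following toy illustrates the second goal. It is not how ShapeFeature is implemented. Shapes are propagated through a graph from the input shapes alone, using one shape rule per Op in the role of Op.infer_shape. The op names, the rule table, and the use of concrete integers instead of symbolic shape variables are all assumptions of this sketch:

```python
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]


def dot_infer_shape(a: Shape, b: Shape) -> Shape:
    # (n, k) @ (k, m) -> (n, m); assumes the inner dimensions agree.
    return (a[0], b[1])


def add_infer_shape(a: Shape, b: Shape) -> Shape:
    # Elementwise addition of same-shaped inputs (no broadcasting in this toy).
    return a


# Hypothetical table playing the role of each Op's infer_shape method.
INFER_SHAPE: Dict[str, Callable[..., Shape]] = {"dot": dot_infer_shape, "add": add_infer_shape}


def infer_shapes(input_shapes: Dict[str, Shape],
                 nodes: List[Tuple[str, str, List[str]]]) -> Dict[str, Shape]:
    """Compute every output's shape without evaluating any op.

    nodes lists (output_name, op_name, input_names) in topological order.
    """
    shape_of = dict(input_shapes)
    for out, op, inputs in nodes:
        shape_of[out] = INFER_SHAPE[op](*(shape_of[name] for name in inputs))
    return shape_of


shapes = infer_shapes(
    {"x": (3, 4), "w": (4, 5), "b": (3, 5)},
    [("h", "dot", ["x", "w"]), ("y", "add", ["h", "b"])],
)
print(shapes["y"])  # (3, 5)
```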

In cases where ShapeFeature cannot figure out the shape, it raises a ShapeError.


By default, the shape of shared variables cannot be inferred automatically, because it can change during execution.

To use the shape information gathered by a FunctionGraph-attached ShapeFeature in rewrites, use the ShapeFeature.get_shape() method.


Create a clone that can be attached to a new FunctionGraph.

This default implementation returns self, which carries the assumption that the Feature is essentially stateless. If a subclass has state of its own that is in any way relative to a given FunctionGraph, this method should be overridden with an implementation that actually creates a fresh copy.

default_infer_shape(fgraph: FunctionGraph, node: Apply, i_shapes: InputShapesType) -> OutputShapesType

Return a list of shape tuples (or None) for the outputs of node.

This function is used for Ops that don’t implement infer_shape. Ops that do implement infer_shape should use the i_shapes parameter, but this default implementation ignores it.

get_shape(fgraph: FunctionGraph, var: Variable, idx: int) -> Variable

Get the shape of var at index idx.

It is better to call this than use ShapeFeature.shape_of[var][idx], since this method will update ShapeFeature.shape_of when needed.

TODO: So far, shape_of is not updated in all cases; it should be.

init_r(r: Variable) -> None

Register r’s shape in the shape_of dictionary.


Called by FunctionGraph.attach_feature, the method that attaches the feature to the FunctionGraph. Since this is called after the FunctionGraph is initially populated, this is where you should run checks on the initial contents of the FunctionGraph.

The on_attach method may raise the AlreadyThere exception to cancel the attach operation if it detects that another Feature instance implementing the same functionality is already attached to the FunctionGraph.

The feature has great freedom in what it can do with the fgraph: it may, for example, add methods to it dynamically.

on_change_input(fgraph, node, i, r, new_r, reason)

Called whenever node.inputs[i] is changed from r to new_r. At the moment the callback is done, the change has already taken place.

If you raise an exception in this function, the state of the graph might be broken for all intents and purposes.


Called by FunctionGraph.remove_feature. Should remove any dynamically-added functionality that it installed into the fgraph.

on_import(fgraph, node, reason)

Called whenever a node is imported into fgraph, which is just before the node is actually connected to the graph.

Note: this is not called when the graph is created. If you want to detect the first nodes imported into the graph, do so by implementing on_attach.

same_shape(x: Variable, y: Variable, dim_x: int | None = None, dim_y: int | None = None) -> bool

Return True if x and y have the same shape.

Parameters:

  • x – The Variable whose shape is compared with y’s shape.

  • y – The Variable whose shape is compared with x’s shape.

  • dim_x – If not None, compare only dimension dim_x of x.

  • dim_y – If not None, compare only dimension dim_y of y.

set_shape(r: Variable, s: Sequence[Variable] | None, override: bool = False) -> None

Assign the shape s to previously un-shaped variable r.

Parameters:

  • r

  • s

  • override – If False, r is a new, unseen term. If True, r is assumed to have already been seen, and its shape is overridden.

set_shape_i(r: Variable, i: int, s_i: Variable) -> None

Replace element i of shape_of[r] by s_i.

shape_ir(i: int, r: Variable) -> Variable

Return symbolic r.shape[i].

shape_tuple(r: Variable) -> Tuple[Variable, ...] | None

Return a tuple of symbolic shape vars for tensor variable r.

to_symbolic_int(s_i: int | float | integer | ArrayLike | Variable) -> Variable

Return a symbolic integer scalar for the shape element s_i.

TODO: Re-evaluate the need for this, since it’s effectively eager canonicalization.


Parameters:

s_i – The s_i argument is assumed to be produced by an Op.infer_shape().

update_shape(r: Variable, other_r: Variable) -> None

Replace shape of r by shape of other_r.

If, on some dimensions, the shape of other_r is not informative, keep the shape of r on those dimensions.

class aesara.tensor.rewriting.shape.ShapeOptimizer

Rewriter that adds ShapeFeature as a feature.


Add Features and other requirements to a FunctionGraph.


Apply the rewriter to a FunctionGraph.

It may use all the methods defined by the FunctionGraph. If the GraphRewriter needs to use a certain tool, such as an InstanceFinder, it can do so in its add_requirements method.

class aesara.tensor.rewriting.shape.UnShapeOptimizer

Rewriter that removes ShapeFeature as a feature.


Apply the rewriter to a FunctionGraph.

It may use all the methods defined by the FunctionGraph. If the GraphRewriter needs to use a certain tool, such as an InstanceFinder, it can do so in its add_requirements method.

Mathematical operations

Rewrites for the Ops in aesara.tensor.math.

class aesara.tensor.rewriting.math.AlgebraicCanonizer(main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True)

A Rewriter that rewrites algebraic expressions.

An instance of this class is a node rewriter. It is best used with a WalkingGraphRewriter in in-to-out order.

Usage: AlgebraicCanonizer(main, inverse, reciprocal, calculate)

Parameters:

  • main – A suitable Op class that is commutative, associative and takes one to an arbitrary number of inputs, e.g. add or mul.

  • inverse – An Op class such that inverse(main(x, y), y) == x (e.g. sub or true_divide).

  • reciprocal – A function such that main(x, reciprocal(y)) == inverse(x, y) (e.g. neg or reciprocal).

  • calculate – Function that takes a list of numpy.ndarray instances for the numerator, another list for the denominator, and calculates inverse(main(*num), main(*denum)). It takes a keyword argument, aslist. If True, the value should be returned as a list of one element, unless the value is such that value = main(). In that case, the return value should be an empty list.


>>> import numpy as np
>>> from aesara.tensor.math import add, mul, neg, reciprocal, sub, true_divide
>>> from aesara.tensor.rewriting.math import AlgebraicCanonizer
>>> add_canonizer = AlgebraicCanonizer(add, sub, neg,
...                                    lambda n, d: sum(n) - sum(d))
>>> mul_canonizer = AlgebraicCanonizer(mul, true_divide, reciprocal,
...                                    lambda n, d: np.prod(n) / np.prod(d))

Examples of rewrites mul_canonizer can perform:

x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
x * y * z -> Elemwise(mul){x,y,z}  # only one pass over the memory
          !-> Elemwise(mul){x,Elemwise(mul){y,z}}

This extracts two lists, num and denum, such that the input is self.inverse(self.main(*num), self.main(*denum)). It returns the two lists as a (num, denum) pair.

For example, for main, inverse and reciprocal = *, / and inv(),

input -> returned value (num, denum)
x*y -> ([x, y], [])
inv(x) -> ([], [x])
inv(x) * inv(y) -> ([], [x, y])
x*y/z -> ([x, y], [z])
log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
(((a / b) * c) / d) -> ([a, c], [b, d])
a / (b / c) -> ([a, c], [b])
log(x) -> ([log(x)], [])
x**y -> ([x**y], [])
x * y * z -> ([x, y, z], [])
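The splitting can be sketched in plain Python for main, inverse and reciprocal = *, / and inv(). Expressions are written as nested tuples such as ("div", "x", "y"), and anything else is treated as an opaque factor. This tuple encoding and the function name are assumptions of the sketch, not Aesara's data structures:

```python
def get_num_denum_sketch(expr):
    """Split a mul/div/inv expression into (numerator factors, denominator factors)."""
    if isinstance(expr, tuple):
        op, *args = expr
        if op == "mul":
            num, denum = [], []
            for arg in args:
                n, d = get_num_denum_sketch(arg)
                num += n
                denum += d
            return num, denum
        if op == "div":
            n1, d1 = get_num_denum_sketch(args[0])
            n2, d2 = get_num_denum_sketch(args[1])
            # a / b: b's numerator moves down, b's denominator moves up.
            return n1 + d2, d1 + n2
        if op == "inv":
            n, d = get_num_denum_sketch(args[0])
            return d, n
    # Any other expression (log(x), x**y, a plain variable) is a single factor.
    return [expr], []


print(get_num_denum_sketch(("div", "a", ("div", "b", "c"))))  # (['a', 'c'], ['b'])
```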

merge_num_denum(num, denum)

Utility function which takes two lists, num and denum, and returns something which is equivalent to inverse(main(*num), main(*denum)), but depends on the length of num and the length of denum (in order to minimize the number of operations).

Let n = len(num) and d = len(denum):

n=0, d=0: neutral element (given by self.calculate([], []))
(for example, this would be 0 if main is addition
and 1 if main is multiplication)
n=1, d=0: num[0]
n=0, d=1: reciprocal(denum[0])
n=1, d=1: inverse(num[0], denum[0])
n=0, d>1: reciprocal(main(*denum))
n>1, d=0: main(*num)
n=1, d>1: inverse(num[0], main(*denum))
n>1, d=1: inverse(main(*num), denum[0])
n>1, d>1: inverse(main(*num), main(*denum))

Given the values of n and d to which they are associated, all of the above are equivalent to: inverse(main(*num), main(*denum))
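The case table can be sketched for main = multiplication. Strings stand in for Variables so that the chosen form is visible. The function name and the string rendering are illustrative assumptions:

```python
def merge_num_denum_sketch(num, denum, neutral="1"):
    """Render inverse(main(*num), main(*denum)) with as few operations as possible.

    main is multiplication here; strings stand in for Variables.
    """
    def main(*args):
        return "(" + " * ".join(args) + ")"

    def inverse(a, b):
        return f"({a} / {b})"

    def reciprocal(a):
        return f"inv({a})"

    n, d = len(num), len(denum)
    if n == 0 and d == 0:
        return neutral  # neutral element of multiplication
    if d == 0:
        return num[0] if n == 1 else main(*num)
    # Only wrap a side in main(...) when it has more than one factor.
    bottom = denum[0] if d == 1 else main(*denum)
    if n == 0:
        return reciprocal(bottom)
    top = num[0] if n == 1 else main(*num)
    return inverse(top, bottom)


print(merge_num_denum_sketch(["x", "y"], ["z"]))  # ((x * y) / z)
```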

simplify(num, denum, out_type)

Shorthand for:

self.simplify_constants(*self.simplify_factors(num, denum))

simplify_constants(orig_num, orig_denum, out_type=None)

Find all constants and put them together into a single constant.

Finds all constants in orig_num and orig_denum (using get_constant) and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator.


Let main be multiplication:

[2, 3, x], [] -> [6, x], []
[x, y, 2], [4, z] -> [0.5, x, y], [z]
[x, 2, y], [z, 2] -> [x, y], [z]

simplify_factors(num, denum)

For any Variable r which is both in num and denum, removes it from both lists. Modifies the lists in place. Returns the modified lists. For example:

[x], [x] -> [], []
[x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d]
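The two simplification passes can be sketched in plain Python for main = multiplication. Strings stand in for Variables and Python numbers for constants. Membership here uses equality, whereas Aesara compares graph Variables, and the names are hypothetical:

```python
import math
from numbers import Number


def _is_const(v) -> bool:
    return isinstance(v, Number)


def simplify_factors_sketch(num, denum):
    """Cancel every factor present in both lists; the lists are modified in place."""
    for factor in list(num):
        if factor in denum:
            num.remove(factor)
            denum.remove(factor)
    return num, denum


def simplify_constants_sketch(num, denum):
    """Fold all numeric constants into one leading numerator constant."""
    const = math.prod(filter(_is_const, num)) / math.prod(filter(_is_const, denum))
    new_num = [v for v in num if not _is_const(v)]
    new_denum = [v for v in denum if not _is_const(v)]
    if const != 1:  # 1 is the neutral element of multiplication, so it is dropped
        new_num.insert(0, const)
    return new_num, new_denum


def simplify_sketch(num, denum):
    # Shorthand mirroring simplify: cancel common factors, then fold constants.
    return simplify_constants_sketch(*simplify_factors_sketch(num, denum))


print(simplify_sketch(["x", "y", 2], [4, "z", "y"]))  # ([0.5, 'x'], ['z'])
```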

Return the list of Op classes to which this rewrite applies.

Returns None when the rewrite applies to all nodes.

transform(fgraph, node)

Rewrite the sub-graph given by node.

Subclasses should implement this function so that it returns one of the following:

  • False to indicate that this rewrite cannot be applied to node

  • A list of Variables to use in place of the node’s current outputs

  • A dict mapping old Variables to Variables, or the key "remove" mapping to a list of Variables to be removed.

Parameters:

  • fgraph – A FunctionGraph containing node.

  • node – An Apply node to be rewritten.

aesara.tensor.rewriting.math.attempt_distribution(factor, num, denum, out_type)

Try to insert each num and each denum in the factor?


Returns:

If there are changes, new_num and new_denum contain all the numerators and denominators that could not be distributed in the factor.

Return type:

changes?, new_factor, new_num, new_denum

aesara.tensor.rewriting.math.check_for_x_over_absX(numerators, denominators)

Convert x/abs(x) into sign(x).


Compute the Variable that is the output of a multiplication tree.

This is the inverse of the operation performed by parse_mul_tree, i.e. compute_mul(parse_mul_tree(tree)) == tree.


Parameters:

tree – A multiplication tree (as output by parse_mul_tree).

Returns:

A Variable that computes the multiplication represented by the tree.



Returns:

A numeric constant if v is a Constant or is already a numeric constant; None if v is a plain Variable.


aesara.tensor.rewriting.math.is_1pexp(t, only_process_constants=True)

If t is of the form (1+exp(x)), return (False, x). Otherwise, return None.



Match a variable with either of the exp(x) or -exp(x) patterns.


Parameters:

var – The Variable to analyze.

Returns:

A pair (b, x) with b a boolean set to True if var is of the form -exp(x) and False if var is of the form exp(x). If var cannot be cast into either form, then return None.


aesara.tensor.rewriting.math.is_inverse_pair(node_op, prev_op, inv_pair)

Given two consecutive operations, check if they are the provided pair of inverse functions.


Match a variable with the x * y * z * ... pattern.

Parameters:

var – The Variable to analyze.

Returns:

A list [x, y, z, …] if var is of the form x * y * z * ..., or None if var cannot be cast into this form.



Match a variable with the -x pattern.


Parameters:

var – The Variable to analyze.

Returns:

x if var is of the form -x, or None otherwise.


aesara.tensor.rewriting.math.local_add_mul_fusion(fgraph, node)

Fuse consecutive add or mul in one such node with more inputs.

It is better to fuse add/mul this way than in a Composite node, as this makes the inner graph of the Composite smaller. This allows more computation to be put in a Composite before hitting the maximum recursion limit when pickling a Composite.


Parse a tree of multiplications starting at the given root.


Parameters:

root – The variable at the root of the tree.

Returns:

A tree where each non-leaf node corresponds to a multiplication in the computation of root, represented by the list of its inputs. Each input is a pair [n, x] with n a boolean value indicating whether sub-tree x should be negated.

Examples:

x * y               -> [False, [[False, x], [False, y]]]
-(x * y)            -> [True, [[False, x], [False, y]]]
-x * y              -> [False, [[True, x], [False, y]]]
-x                  -> [True, x]
(x * y) * -z        -> [False, [[False, [[False, x], [False, y]]],
                                [True, z]]]
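Both directions, parsing with parse_mul_tree and rebuilding with compute_mul, can be sketched in plain Python. Expressions are encoded as nested tuples ("mul", a, b, ...) and ("neg", a). The encoding and function names are assumptions for illustration; the real functions operate on Aesara Variables. The round trip reproduces the original expression as long as it contains no double negations, which parsing folds away:

```python
def parse_mul_tree_sketch(root):
    """Turn nested mul/neg tuples into [negated, sub_tree_or_leaf] lists."""
    if isinstance(root, tuple) and root[0] == "neg":
        # Recurse, flipping the negation flag.
        neg, sub_tree = parse_mul_tree_sketch(root[1])
        return [not neg, sub_tree]
    if isinstance(root, tuple) and root[0] == "mul":
        # A multiplication node becomes the list of its parsed inputs.
        return [False, [parse_mul_tree_sketch(arg) for arg in root[1:]]]
    # Anything else is kept as a leaf.
    return [False, root]


def compute_mul_sketch(tree):
    """Rebuild a mul/neg tuple expression from a parsed multiplication tree."""
    neg, inputs = tree
    if isinstance(inputs, list):
        expr = ("mul", *(compute_mul_sketch(sub) for sub in inputs))
    else:
        expr = inputs
    return ("neg", expr) if neg else expr


tree = parse_mul_tree_sketch(("mul", ("mul", "x", "y"), ("neg", "z")))
print(tree)  # [False, [[False, [[False, 'x'], [False, 'y']]], [True, 'z']]]
```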

aesara.tensor.rewriting.math.perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None, sigm_minus_x=None, parent=None, child_idx=None, full_tree=None)

Core processing of the local_sigm_times_exp rewrite.

This recursive function operates on a multiplication tree as output by parse_mul_tree. It walks through the tree and modifies it in-place by replacing matching pairs (exp, sigmoid) with the desired version.

Parameters:

  • tree – The sub-tree to operate on.

  • exp_x – List of arguments x so that exp(x) exists somewhere in the whole multiplication tree. Each argument is a pair (x, leaf) with x the argument of the exponential, and leaf the corresponding leaf in the multiplication tree (of the form [n, exp(x)] – see parse_mul_tree). If None, this argument is initialized to an empty list.

  • exp_minus_x – Similar to exp_x, but for exp(-x).

  • sigm_x – Similar to exp_x, but for sigmoid(x).

  • sigm_minus_x – Similar to exp_x, but for sigmoid(-x).

  • parent – Parent of tree (None if tree is the global root).

  • child_idx – Index of tree in its parent’s inputs (None if tree is the global root).

  • full_tree – The global multiplication tree (should not be set except by recursive calls to this function). Used for debugging only.


Returns:

True if a modification was performed somewhere in the whole multiplication tree, or False otherwise.


aesara.tensor.rewriting.math.replace_leaf(arg, leaves, new_leaves, op, neg)

Attempt to replace a leaf of a multiplication tree.

We search for a leaf in leaves whose argument is arg, and if we find one, we remove it from leaves and add to new_leaves a leaf with argument arg and variable op(arg).

Parameters:

  • arg – The argument of the leaf we are looking for.

  • leaves – List of leaves to look into. Each leaf should be a pair (x, l) with x the argument of the Op found in the leaf, and l the actual leaf as found in a multiplication tree output by parse_mul_tree (i.e. a pair [boolean, variable]).

  • new_leaves – If a replacement occurred, then the leaf is removed from leaves and added to the list new_leaves (after being modified by op).

  • op – A function that, when applied to arg, returns the Variable we want to replace the original leaf variable with.

  • neg (bool) – If True, then the boolean value associated with the leaf should be swapped. If False, then this value should remain unchanged.


Returns:

True if a replacement occurred, or False otherwise.


aesara.tensor.rewriting.math.scalarconsts_rest(inputs, elemwise=True, only_process_constants=False)

Partition a list of variables into two kinds: scalar constants, and the rest.


Simplify a multiplication tree.


Parameters:

tree – A multiplication tree (as output by parse_mul_tree).

Returns:

A multiplication tree computing the same output as tree but without useless multiplications by 1 or -1 (identified by leaves of the form [False, None] or [True, None] respectively). Useless multiplications (with fewer than two inputs) are also removed from the tree.
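Using the list-based tree format, this simplification can be sketched in plain Python. The function name is hypothetical, and None leaves stand for a bare factor of 1 as described above. Each [False, None] or [True, None] leaf is folded into its parent's negation flag, and nodes left with fewer than two inputs are collapsed:

```python
def simplify_mul_sketch(tree):
    """Remove multiplications by 1 or -1 and collapse trivial multiplications."""
    neg, inputs = tree
    if not isinstance(inputs, list):
        return [neg, inputs]  # a leaf is returned unchanged
    kept = []
    for sub_neg, sub_inputs in map(simplify_mul_sketch, inputs):
        if sub_inputs is None:
            # Multiplication by 1 or -1: only its sign survives.
            neg ^= sub_neg
        else:
            kept.append([sub_neg, sub_inputs])
    if not kept:
        return [neg, None]  # nothing left but a (possibly negated) 1
    if len(kept) == 1:
        # A single-input multiplication is replaced by its input.
        only_neg, only_inputs = kept[0]
        return [only_neg ^ neg, only_inputs]
    return [neg, kept]


print(simplify_mul_sketch([False, [[True, None], [False, "x"]]]))  # [True, 'x']
```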


class aesara.tensor.rewriting.elemwise.FusionOptimizer(node_rewriter)

Graph rewriter that simply runs node fusion operations.

TODO: This is basically an EquilibriumGraphRewriter; we should just use that.


Add Features and other requirements to a FunctionGraph.


Apply the rewriter to a FunctionGraph.

It may use all the methods defined by the FunctionGraph. If the GraphRewriter needs to use a certain tool, such as an InstanceFinder, it can do so in its add_requirements method.

class aesara.tensor.rewriting.elemwise.InplaceElemwiseOptimizer(OP)

This is parameterized so that it works for Elemwise Ops.


Add Features and other requirements to a FunctionGraph.


Attempts to replace all Elemwises by versions of them that operate inplace. It operates greedily: for each Elemwise that is encountered, for each output, it tries each input to see if it can operate inplace on that input. If so, it makes the change and goes to the next output or Elemwise.


x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)

print_summary(stream=sys.stdout, level=0, depth=-1)

Print a single-line, indented representation of the rewriter.

aesara.tensor.rewriting.elemwise.apply_local_dimshuffle_lift(fgraph, var)

Lift recursively.

aesara.tensor.rewriting.elemwise.is_dimshuffle_useless(new_order, input)

Check for two types of useless DimShuffles:

  1. DimShuffling all dimensions in their original order.

  2. DimShuffling a broadcastable dimension.
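A plain-Python sketch of these two checks follows. As a simplification, it takes the input's broadcastable pattern instead of a Variable. It treats new_order entries of "x" as new broadcastable dimensions, as DimShuffle does, and the function name is hypothetical:

```python
from typing import Sequence, Union


def is_dimshuffle_useless_sketch(new_order: Sequence[Union[int, str]],
                                 broadcastable: Sequence[bool]) -> bool:
    """Return True if a DimShuffle with new_order would leave the input unchanged."""
    if len(new_order) != len(broadcastable):
        return False  # adding or dropping dimensions is treated as not useless
    # Broadcastable input dimensions and new "x" dimensions are interchangeable.
    interchangeable = {i for i, b in enumerate(broadcastable) if b} | {"x"}
    for i, target in enumerate(new_order):
        if target == i:
            continue  # case 1: the dimension stays in its original position
        if i in interchangeable and target in interchangeable:
            continue  # case 2: one broadcastable dimension swapped for another
        return False
    return True


print(is_dimshuffle_useless_sketch(("x", 1), (True, False)))  # True
```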

aesara.tensor.rewriting.elemwise.local_elemwise_fusion(fgraph, node)

Fuse Elemwise Ops in a node.

As part of specialization, we fuse two consecutive Elemwise Ops of the same shape.

For mixed dtypes, the Composite Op performs the cast, which lets the C compiler do it.

The number of dimensions is validated at call time by Aesara itself.

aesara.tensor.rewriting.elemwise.local_elemwise_fusion_op(op_class, max_input_fct=<function <lambda>>, maker=None)

Create a recursive function that fuses Elemwise Ops.

The basic idea is that we loop through an Elemwise node’s inputs, find other Elemwise nodes, determine the scalar input types for all of the Elemwise Ops, construct a new scalar Op using the scalar input types and each Elemwise’s scalar Op, and use the composite scalar Op in a new “fused” Elemwise.

It’s parameterized in order to work for Elemwise Ops.

Parameters:

  • op_class (type) – Elemwise class (the one that we want to fuse).

  • max_input_fct (callable) – A function that returns the maximum number of inputs that this Elemwise can take. On the CPU we limit to 32 input variables since that is the maximum NumPy supports.

  • maker (callable) – A function with the signature (node, *args) that constructs an op_class instance (e.g. op_class(*args)).
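A minimal sketch of the fusion idea follows. Plain Python callables stand in for scalar Ops and lists stand in for tensors. The names fuse_scalar_ops and elemwise are hypothetical; Aesara instead builds a Composite scalar Op inside a new Elemwise. The max_inputs check mirrors the role of max_input_fct, with 32 as the CPU limit quoted above:

```python
import operator
from typing import Callable, List, Optional, Sequence


def fuse_scalar_ops(outer: Callable, outer_arity: int, inner: Callable, inner_arity: int,
                    max_inputs: int = 32) -> Optional[Callable]:
    """Fuse inner into outer's first input as one composite scalar function (sketch)."""
    # The fused node takes the inner inputs plus outer's remaining inputs.
    if inner_arity + outer_arity - 1 > max_inputs:
        return None  # too many inputs: refuse to fuse, as max_input_fct would

    def composite(*args):
        inner_args, rest = args[:inner_arity], args[inner_arity:]
        return outer(inner(*inner_args), *rest)

    return composite


def elemwise(scalar_fn: Callable, *arrays: Sequence[float]) -> List[float]:
    """Apply a scalar function elementwise in a single pass over the inputs."""
    return [scalar_fn(*values) for values in zip(*arrays)]


fused = fuse_scalar_ops(operator.add, 2, operator.mul, 2)  # (x * y) + z
print(elemwise(fused, [1, 2], [3, 4], [5, 6]))  # [8, 14]
```

Fusing lets one loop over the data replace two loops and the intermediate result of the multiplication.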

Random variables

aesara.tensor.random.rewriting.basic.is_rv_used_in_graph(base_rv, node, fgraph)

Determine whether or not base_rv is used by a node other than node in fgraph.

If a node uses Shape or Shape_i on base_rv, it is ignored, because those Ops don't rely on the actual sample values of base_rv.

TODO: We should apply all the shape rewrites before these rewrites, since that would properly remove the unnecessary dependencies on base_rv (when possible).