tensor.rewriting.math – Tensor Rewrites for Math Operations

Rewrites for the Ops in aesara.tensor.math.

class aesara.tensor.rewriting.math.AlgebraicCanonizer(main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True)

A Rewriter that rewrites algebraic expressions.

An AlgebraicCanonizer instance is a node rewriter. It is best used with a WalkingGraphRewriter applied in in-to-out order.

Usage: AlgebraicCanonizer(main, inverse, reciprocal, calculate)

  • main – A suitable Op class that is commutative, associative and takes one to an arbitrary number of inputs, e.g. add or mul
  • inverse – An Op class such that inverse(main(x, y), y) == x (e.g. sub or true_div).
  • reciprocal – A function such that main(x, reciprocal(y)) == inverse(x, y) (e.g. neg or reciprocal).
  • calculate – A function that takes a list of numpy.ndarray instances for the numerator and another list for the denominator, and computes inverse(main(*num), main(*denum)). It takes a keyword argument, aslist. If True, the value should be returned as a list of one element, unless the value equals main() (the neutral element), in which case an empty list should be returned.


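>>> import math
>>> import numpy as np
>>> import aesara.tensor as at
>>> from aesara.tensor.rewriting.math import AlgebraicCanonizer
>>> def fold(value, neutral, aslist):
...     # With aslist=True, return [] for the neutral element and [value] otherwise.
...     return ([] if np.all(value == neutral) else [value]) if aslist else value
>>> add_canonizer = AlgebraicCanonizer(at.add, at.sub, at.neg,
...     lambda n, d, aslist=False, out_type=None: fold(sum(n) - sum(d), 0, aslist))
>>> mul_canonizer = AlgebraicCanonizer(at.mul, at.true_div, at.reciprocal,
...     lambda n, d, aslist=False, out_type=None: fold(math.prod(n) / math.prod(d), 1, aslist))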

Examples of rewrites mul_canonizer can perform:

x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
x * y * z -> Elemwise(mul){x,y,z}  # a single pass over memory,
             not Elemwise(mul){x,Elemwise(mul){y,z}}

Extracts two lists, num and denum, such that the input is equivalent to self.inverse(self.main(*num), self.main(*denum)), and returns them as a (num, denum) pair.

For example, with main, inverse and reciprocal being *, / and inv():

input -> returned value (num, denum)
x*y -> ([x, y], [])
inv(x) -> ([], [x])
inv(x) * inv(y) -> ([], [x, y])
x*y/z -> ([x, y], [z])
log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
(((a / b) * c) / d) -> ([a, c], [b, d])
a / (b / c) -> ([a, c], [b])
log(x) -> ([log(x)], [])
x**y -> ([x**y], [])
x * y * z -> ([x, y, z], [])
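
The extraction above can be illustrated with a minimal sketch for main = *, inverse = / and reciprocal = inv(). It works on a hypothetical toy expression format (strings are leaves; tuples such as ("mul", a, b), ("div", a, b) and ("inv", a) are operations) rather than on Aesara Variables, so it shows the bookkeeping only, not the library's implementation:

```python
def get_num_denum(expr):
    """Split a toy expression into (num, denum) for main=mul, inverse=div, reciprocal=inv.

    Toy format (an assumption of this sketch): strings are leaves, and tuples
    ("mul", a, b, ...), ("div", a, b) and ("inv", a) are operations.
    """
    if isinstance(expr, tuple):
        op, *args = expr
        if op == "mul":
            num, denum = [], []
            for arg in args:
                n, d = get_num_denum(arg)
                num += n
                denum += d
            return num, denum
        if op == "div":
            n1, d1 = get_num_denum(args[0])
            n2, d2 = get_num_denum(args[1])
            # Dividing by n2/d2 multiplies by d2/n2, so the divisor's lists swap.
            return n1 + d2, d1 + n2
        if op == "inv":
            n, d = get_num_denum(args[0])
            return d, n
    # Leaves and any other operation (log, add, pow, ...) are opaque factors.
    return [expr], []


print(get_num_denum(("div", "a", ("div", "b", "c"))))  # (['a', 'c'], ['b'])
```

Opaque sub-expressions such as log(x) or z + x end up as single factors, which matches the rows for log(x) and x**y in the table.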

merge_num_denum(num, denum)

Utility function that takes two lists, num and denum, and returns an expression equivalent to inverse(main(*num), main(*denum)) whose form depends on the lengths of num and denum, so as to minimize the number of operations.

Let n = len(num) and d = len(denum):

n=0, d=0: neutral element (given by self.calculate([], [])); e.g. 0 if main is addition and 1 if main is multiplication
n=1, d=0: num[0]
n=0, d=1: reciprocal(denum[0])
n=1, d=1: inverse(num[0], denum[0])
n=0, d>1: reciprocal(main(*denum))
n>1, d=0: main(*num)
n=1, d>1: inverse(num[0], main(*denum))
n>1, d=1: inverse(main(*num), denum[0])
n>1, d>1: inverse(main(*num), main(*denum))

Given the values of n and d to which they are associated, all of the above are equivalent to: inverse(main(*num), main(*denum))
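
The case table can be sketched with plain callables standing in for main, inverse and reciprocal. This is a minimal illustration under that assumption; the neutral parameter replaces the self.calculate([], []) call, and the string-building helpers exist only to make the shape of the result visible:

```python
def merge_num_denum(num, denum, main, inverse, reciprocal, neutral):
    """Build inverse(main(*num), main(*denum)) using as few operations as possible.

    main, inverse and reciprocal are plain callables and neutral is the value
    returned for n = d = 0; all four are parameters of this sketch only.
    """
    def combine(terms):
        # A single term needs no main(...) call at all.
        return terms[0] if len(terms) == 1 else main(*terms)

    if not num and not denum:
        return neutral                      # n=0, d=0
    if not denum:
        return combine(num)                 # n>=1, d=0
    if not num:
        return reciprocal(combine(denum))   # n=0, d>=1
    return inverse(combine(num), combine(denum))  # n>=1, d>=1


# Hypothetical string-building operations that show the structure of the result.
def mul(*xs):
    return "(" + " * ".join(xs) + ")"


def div(a, b):
    return f"{a} / {b}"


def recip(a):
    return f"1 / {a}"


print(merge_num_denum(["x", "y"], ["z"], mul, div, recip, "1"))  # (x * y) / z
```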

simplify(num, denum, out_type)

Shorthand for:

self.simplify_constants(*self.simplify_factors(num, denum), out_type=out_type)

simplify_constants(orig_num, orig_denum, out_type=None)

Find all constants and put them together into a single constant.

Finds all constants in orig_num and orig_denum (using get_constant) and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator.


Let main be multiplication:

[2, 3, x], [] -> [6, x], []
[x, y, 2], [4, z] -> [0.5, x, y], [z]
[x, 2, y], [z, 2] -> [x, y], [z]
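
For main = multiplication, the folding can be sketched in plain Python. This assumes numeric constants are Python numbers and every other item is an opaque symbol; the real method works on Aesara Variables through get_constant and the calculate function, neither of which this sketch uses:

```python
from numbers import Number


def simplify_constants(num, denum):
    """Fold all numeric constants into one leading numerator constant (main = *)."""
    const = 1
    new_num, new_denum = [], []
    for v in num:
        if isinstance(v, Number):
            const *= v
        else:
            new_num.append(v)
    for v in denum:
        if isinstance(v, Number):
            const /= v
        else:
            new_denum.append(v)
    # The neutral element of multiplication (1) is dropped instead of inserted.
    if const != 1:
        new_num.insert(0, const)
    return new_num, new_denum


print(simplify_constants(["x", "y", 2], [4, "z"]))  # ([0.5, 'x', 'y'], ['z'])
```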

simplify_factors(num, denum)

For any Variable r which is in both num and denum, removes it from both lists. Modifies the lists in place and returns them. For example:

[x], [x] -> [], []
[x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d]
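
Treating num and denum as multisets, the cancellation can be sketched as follows. Each match removes one occurrence from each list, so repeated factors cancel pairwise; comparing items with == rather than by Variable identity is an assumption of this sketch:

```python
def simplify_factors(num, denum):
    """Cancel factors that appear in both lists, modifying them in place."""
    for v in list(num):  # iterate over a copy because num shrinks
        # Cancel one occurrence in each list per match.
        if v in denum:
            num.remove(v)
            denum.remove(v)
    return num, denum


print(simplify_factors(["x", "x", "y"], ["x"]))  # (['x', 'y'], [])
```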

Return the list of Op classes to which this rewrite applies.

Returns None when the rewrite applies to all nodes.

transform(fgraph, node)

Rewrite the sub-graph given by node.

Subclasses should implement this function so that it returns one of the following:

  • False to indicate that this rewrite cannot be applied to node
  • A list of Variables to use in place of the node’s current outputs
  • A dict mapping old Variables to new Variables, or mapping the key "remove" to a list of Variables to be removed

Parameters:
  • fgraph – A FunctionGraph containing node.
  • node – An Apply node to be rewritten.

aesara.tensor.rewriting.math.attempt_distribution(factor, num, denum, out_type)

Try to distribute (insert) each element of num and each element of denum into factor.

Returns: If there are changes, new_num and new_denum contain all the numerators and denominators that could not be distributed into the factor.
Return type: tuple (changes, new_factor, new_num, new_denum)

aesara.tensor.rewriting.math.check_for_x_over_absX(numerators, denominators)

Convert x/abs(x) into sign(x).
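
A minimal sketch of the pattern, assuming a hypothetical toy representation in which abs(x) and sign(x) are the tuples ("abs", x) and ("sign", x). It scans the denominators for abs(x) whose argument also appears among the numerators and replaces the pair with sign(x); unlike a real graph rewrite, it performs no type checks:

```python
def check_for_x_over_absX(numerators, denominators):
    """Toy version: rewrite x / abs(x) into sign(x), modifying the lists in place."""
    for den in list(denominators):
        # Look for a denominator abs(x) whose x is also a numerator.
        if isinstance(den, tuple) and den[0] == "abs" and den[1] in numerators:
            denominators.remove(den)
            numerators.remove(den[1])
            numerators.append(("sign", den[1]))
    return numerators, denominators


print(check_for_x_over_absX(["x", "y"], [("abs", "x")]))  # (['y', ('sign', 'x')], [])
```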


Compute the Variable that is the output of a multiplication tree.

This is the inverse of the operation performed by parse_mul_tree, i.e. compute_mul(parse_mul_tree(root)) computes the same value as root.

Parameters: tree – A multiplication tree (as output by parse_mul_tree).
Returns: A Variable that computes the multiplication represented by the tree.
Return type: object

Returns: A numeric constant if v is a Constant or is itself a numeric constant. If v is a plain Variable, returns None.
Return type: object

aesara.tensor.rewriting.math.is_1pexp(t, only_process_constants=True)
Returns: If t is of the form 1 + exp(x), return (False, x); otherwise return None.
Return type: object

Match a variable with either of the exp(x) or -exp(x) patterns.

Parameters: var – The Variable to analyze.
Returns: A pair (b, x) with b a boolean set to True if var is of the form -exp(x) and False if var is of the form exp(x). If var cannot be cast into either form, then return None.
Return type: tuple
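
The matching logic can be sketched on a hypothetical toy representation where exp(x) is ("exp", x) and -y is ("neg", y). The boolean flags a leading negation, mirroring the (b, x) pair described above:

```python
def is_exp(var):
    """Toy matcher: (False, x) for exp(x), (True, x) for -exp(x), else None."""
    if isinstance(var, tuple):
        if var[0] == "exp":
            return False, var[1]
        # -exp(x): a negation whose operand is an exponential.
        if var[0] == "neg" and isinstance(var[1], tuple) and var[1][0] == "exp":
            return True, var[1][1]
    return None


print(is_exp(("neg", ("exp", "x"))))  # (True, 'x')
```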

aesara.tensor.rewriting.math.is_inverse_pair(node_op, prev_op, inv_pair)

Given two consecutive operations, check if they are the provided pair of inverse functions.
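
A minimal sketch of such a check, assuming the ops are compared by class with isinstance and that the pair is accepted in either order. Deg2Rad and Rad2Deg below are empty stand-in classes for illustration, not Aesara's scalar Ops:

```python
class Deg2Rad:
    """Stand-in op class (hypothetical, for illustration only)."""


class Rad2Deg:
    """Stand-in op class (hypothetical, for illustration only)."""


def is_inverse_pair(node_op, prev_op, inv_pair):
    """Return True if prev_op followed by node_op is inv_pair, in either order."""
    first, second = inv_pair
    return (isinstance(node_op, first) and isinstance(prev_op, second)) or (
        isinstance(node_op, second) and isinstance(prev_op, first)
    )


print(is_inverse_pair(Rad2Deg(), Deg2Rad(), (Deg2Rad, Rad2Deg)))  # True
```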


Match a variable with the x * y * z * ... pattern.

Parameters: var – The Variable to analyze.
Returns: A list [x, y, z, …] if var is of the form x * y * z * ..., or None if var cannot be cast into this form.
Return type: object

Match a variable with the -x pattern.

Parameters: var – The Variable to analyze.
Returns: x if var is of the form -x, or None otherwise.
Return type: object
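
Both matchers can be sketched on a hypothetical toy representation where a product is ("mul", a, b, ...) and a negation is ("neg", x). In this sketch only the direct inputs of a product are returned; nested products are not flattened:

```python
def is_mul(var):
    """Return the direct inputs of a toy ("mul", ...) node, or None."""
    if isinstance(var, tuple) and var[0] == "mul":
        return list(var[1:])
    return None


def is_neg(var):
    """Return x for a toy ("neg", x) node, or None."""
    if isinstance(var, tuple) and var[0] == "neg":
        return var[1]
    return None


print(is_mul(("mul", "x", "y", "z")))  # ['x', 'y', 'z']
print(is_neg(("neg", "x")))            # x
```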

aesara.tensor.rewriting.math.local_add_mul_fusion(fgraph, node)

Fuse consecutive add or mul nodes into a single node of the same kind with more inputs.

It is better to fuse add/mul this way than in a Composite node, as it keeps the inner graph of the Composite smaller. This allows more computation to fit in a Composite before hitting the maximum recursion limit when pickling the Composite.
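
The effect of this fusion can be sketched on a hypothetical toy expression format where ("add", a, b, ...) and ("mul", a, b, ...) are n-ary nodes. The hypothetical fuse function inlines children that use the same associative op; it ignores practical limits a real graph rewrite must also consider, such as inner nodes shared by other consumers or a maximum number of inputs:

```python
def fuse(expr):
    """Toy sketch: flatten nested add/mul tuples into one node with more inputs."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [fuse(arg) for arg in args]  # fuse bottom-up first
    if op in ("add", "mul"):
        flat = []
        for arg in args:
            # Inline a child that applies the same associative op.
            if isinstance(arg, tuple) and arg[0] == op:
                flat.extend(arg[1:])
            else:
                flat.append(arg)
        args = flat
    return (op, *args)


print(fuse(("add", ("add", "a", "b"), "c")))  # ('add', 'a', 'b', 'c')
```

A mixed expression such as a + b * c is left alone, since the inner product is a different op.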


Parse a tree of multiplications starting at the given root.

Parameters: root – The variable at the root of the tree.
Returns: A tree where each non-leaf node corresponds to a multiplication in the computation of root, represented by the list of its inputs. Each input is a pair [n, x] with n a boolean value indicating whether sub-tree x should be negated.
Return type: object


x * y               -> [False, [[False, x], [False, y]]]
-(x * y)            -> [True, [[False, x], [False, y]]]
-x * y              -> [False, [[True, x], [False, y]]]
-x                  -> [True, x]
(x * y) * -z        -> [False, [[False, [[False, x], [False, y]]],
                                [True, z]]]
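
To make the tree format concrete, here is a minimal sketch that builds the same trees from a hypothetical toy expression format (strings are leaves; ("mul", a, b, ...) and ("neg", x) are operations) and, in the spirit of compute_mul, evaluates a tree numerically instead of rebuilding a Variable. Both function names are hypothetical:

```python
def parse_mul_tree(root):
    """Toy parse: root is a string leaf or a ("mul", ...) / ("neg", x) tuple."""
    if isinstance(root, tuple) and root[0] == "mul":
        return [False, [parse_mul_tree(arg) for arg in root[1:]]]
    if isinstance(root, tuple) and root[0] == "neg":
        # Parse the operand, then flip its negation flag.
        neg, sub_tree = parse_mul_tree(root[1])
        return [not neg, sub_tree]
    return [False, root]


def evaluate_mul_tree(tree, env):
    """Numerically evaluate a tree, looking leaf values up in env."""
    neg, sub = tree
    if sub is None:
        value = 1.0  # [False, None] / [True, None] stand for 1 / -1
    elif isinstance(sub, list):
        value = 1.0
        for child in sub:
            value *= evaluate_mul_tree(child, env)
    else:
        value = env[sub]
    return -value if neg else value


tree = parse_mul_tree(("mul", ("mul", "x", "y"), ("neg", "z")))
print(tree)  # [False, [[False, [[False, 'x'], [False, 'y']]], [True, 'z']]]
print(evaluate_mul_tree(tree, {"x": 2.0, "y": 3.0, "z": 5.0}))  # -30.0
```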

aesara.tensor.rewriting.math.perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None, sigm_minus_x=None, parent=None, child_idx=None, full_tree=None)

Core processing of the local_sigm_times_exp rewrite.

This recursive function operates on a multiplication tree as output by parse_mul_tree. It walks through the tree and modifies it in-place by replacing matching pairs (exp, sigmoid) with the desired version.

Parameters:
  • tree – The sub-tree to operate on.
  • exp_x – List of arguments x such that exp(x) exists somewhere in the whole multiplication tree. Each argument is a pair (x, leaf) with x the argument of the exponential, and leaf the corresponding leaf in the multiplication tree (of the form [n, exp(x)] – see parse_mul_tree). If None, this argument is initialized to an empty list.
  • exp_minus_x – Similar to exp_x, but for exp(-x).
  • sigm_x – Similar to exp_x, but for sigmoid(x).
  • sigm_minus_x – Similar to exp_x, but for sigmoid(-x).
  • parent – Parent of tree (None if tree is the global root).
  • child_idx – Index of tree in its parent’s inputs (None if tree is the global root).
  • full_tree – The global multiplication tree (should not be set except by recursive calls to this function). Used for debugging only.

Returns: True if a modification was performed somewhere in the whole multiplication tree, or False otherwise.
Return type: bool
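
Pairing exponentials with sigmoids in this way is motivated by identities such as exp(x) * sigmoid(-x) = sigmoid(x) and exp(-x) * sigmoid(x) = sigmoid(-x). A quick numerical check of these two identities, using a plain-Python sigmoid defined here only for illustration:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# exp(x) * sigmoid(-x) = e^x / (1 + e^x) = 1 / (e^-x + 1) = sigmoid(x);
# replacing x with -x gives the second identity.
for x in (-3.0, -0.5, 0.0, 1.0, 4.0):
    assert math.isclose(math.exp(x) * sigmoid(-x), sigmoid(x))
    assert math.isclose(math.exp(-x) * sigmoid(x), sigmoid(-x))
print("identities hold")
```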


aesara.tensor.rewriting.math.replace_leaf(arg, leaves, new_leaves, op, neg)

Attempt to replace a leaf of a multiplication tree.

We search for a leaf in leaves whose argument is arg, and if we find one, we remove it from leaves and add to new_leaves a leaf with argument arg and variable op(arg).

Parameters:
  • arg – The argument of the leaf we are looking for.
  • leaves – List of leaves to look into. Each leaf should be a pair (x, l) with x the argument of the Op found in the leaf, and l the actual leaf as found in a multiplication tree output by parse_mul_tree (i.e. a pair [boolean, variable]).
  • new_leaves – If a replacement occurred, then the leaf is removed from leaves and added to the list new_leaves (after being modified by op).
  • op – A function that, when applied to arg, returns the Variable we want to replace the original leaf variable with.
  • neg (bool) – If True, then the boolean value associated with the leaf should be swapped. If False, then this value should remain unchanged.

Returns: True if a replacement occurred, or False otherwise.
Return type: bool
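
A minimal sketch of the replacement, assuming each leaf is a mutable [negated, variable] list that is shared with the enclosing multiplication tree, so editing it in place also updates the tree. Toy tuples such as ("exp", "x") stand in for Variables:

```python
def replace_leaf(arg, leaves, new_leaves, op, neg):
    """Toy sketch: replace the first leaf whose argument is arg.

    Each entry of leaves is a pair (x, leaf), where leaf is a mutable
    [negated, variable] list that also lives inside a multiplication tree.
    """
    for idx, (x, leaf) in enumerate(leaves):
        if x == arg:
            leaf[0] ^= neg      # swap the negation flag when neg is True
            leaf[1] = op(arg)   # the leaf now holds op(arg)
            new_leaves.append(leaves.pop(idx))
            return True
    return False


leaf = [False, ("exp", "x")]
leaves, new_leaves = [("x", leaf)], []
print(replace_leaf("x", leaves, new_leaves, lambda a: ("sigmoid", a), True))  # True
print(leaf)  # [True, ('sigmoid', 'x')]
```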


aesara.tensor.rewriting.math.scalarconsts_rest(inputs, elemwise=True, only_process_constants=False)

Partition a list of variables into two kinds: scalar constants, and the rest.
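
The partition itself can be sketched with Python numbers standing in for scalar constants. The hypothetical split_scalar_constants below ignores the elemwise and only_process_constants parameters and does not reproduce the real function's return layout; it only illustrates the order-preserving split:

```python
from numbers import Number


def split_scalar_constants(inputs):
    """Partition inputs into (scalar constants, rest), preserving order."""
    consts, rest = [], []
    for v in inputs:
        # Python numbers stand in for scalar constants in this sketch.
        if isinstance(v, Number):
            consts.append(v)
        else:
            rest.append(v)
    return consts, rest


print(split_scalar_constants([2, "x", 0.5, "y"]))  # ([2, 0.5], ['x', 'y'])
```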


Simplify a multiplication tree.

Parameters: tree – A multiplication tree (as output by parse_mul_tree).
Returns: A multiplication tree computing the same output as tree but without useless multiplications by 1 or -1 (identified by leaves of the form [False, None] or [True, None], respectively). Useless multiplications (with fewer than two inputs) are also removed from the tree.
Return type: object
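
The simplification can be sketched directly on the [negated, sub_tree] format described above: factors of 1 or -1 are folded into the parent's sign, and products left with one input collapse into that input. This is a minimal sketch that returns new lists for collapsed nodes instead of editing the input tree:

```python
def simplify_mul(tree):
    """Drop multiplications by 1 / -1 ([False, None] / [True, None] leaves)."""
    neg, inputs = tree
    if not isinstance(inputs, list):
        return tree  # a leaf: nothing to simplify
    kept = []
    for child in map(simplify_mul, inputs):
        if child[1] is None:
            neg ^= child[0]  # fold a factor of 1 or -1 into this node's sign
        else:
            kept.append(child)
    if not kept:
        return [neg, None]  # the whole product is 1 or -1
    if len(kept) == 1:
        # A product of a single input is that input, with the signs combined.
        return [kept[0][0] ^ neg, kept[0][1]]
    return [neg, kept]


print(simplify_mul([False, [[True, None], [False, "x"]]]))  # [True, 'x']
```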