Rewrites#
This section of the documentation lists all the rewrites that can be applied during the compilation of an Aesara graph.
Tensor rewrites#
These rewrites are implemented in the module `tensor.rewriting.basic`.
Tensor optimizations addressing the `Op`s in `basic.py`.
- aesara.tensor.rewriting.basic.broadcast_like(value, template, fgraph, dtype=None)[source]#
Return a `Variable` with the same shape and dtype as the template, filled by broadcasting `value` through it. `value` will be cast as necessary.
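In NumPy terms (an analogy, not the rewrite's actual implementation), the behavior can be sketched as:

```python
import numpy as np

# Conceptual NumPy analogue of broadcast_like (an illustration, not the
# actual Aesara code): broadcast `value` through the template's shape and
# cast it to the template's dtype.
value = np.float32(2.0)
template = np.zeros((3, 4), dtype=np.float64)

result = np.broadcast_to(value, template.shape).astype(template.dtype)

assert result.shape == template.shape
assert result.dtype == template.dtype
```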
- aesara.tensor.rewriting.basic.encompasses_broadcastable(b1, b2)[source]#
- Parameters:
b1 – The broadcastable attribute of a tensor type.
b2 – The broadcastable attribute of a tensor type.
- Returns:
`True` if the broadcastable patterns `b1` and `b2` are such that `b2` can be broadcast to `b1`'s shape, and not the opposite.
- Return type:
bool
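A pure-Python sketch of this check (an assumption about the implementation, shown for illustration): align the trailing dimensions and verify that every dimension broadcastable in `b1` is also broadcastable in `b2`.

```python
# Hedged sketch of the broadcastable-pattern check; True marks a
# broadcastable (length-1) dimension, False a non-broadcastable one.
def encompasses_broadcastable(b1, b2):
    if len(b1) < len(b2):
        return False
    b1 = b1[-len(b2):]  # align the trailing dimensions
    return not any(v1 and not v2 for v1, v2 in zip(b1, b2))

# b2's broadcastable trailing dimension can expand to b1's shape...
assert encompasses_broadcastable((False, False), (False, True))
# ...but not the opposite.
assert not encompasses_broadcastable((False, True), (False, False))
```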
Indexing#
- aesara.tensor.rewriting.subtensor.get_advsubtensor_axis(indices)[source]#
Determine the axis at which an array index is applied.
This only works for `take`-like indices: e.g. `x[:, :, idx, ...]`. For the above example, `get_advsubtensor_axis` would return `2`. If it encounters anything other than a set of `indices` containing full slices and an array/tensor index, it will return `None`.
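The detection logic can be sketched in pure Python (the helper name `find_advsubtensor_axis` is illustrative, not Aesara's API, and `Ellipsis` handling is omitted):

```python
# Hedged sketch: return the position of the single non-slice index when
# every other entry is a full slice; return None otherwise.
def find_advsubtensor_axis(indices):
    axis = None
    for pos, idx in enumerate(indices):
        if idx == slice(None):  # a full slice, i.e. `:`
            continue
        if axis is not None:
            return None  # more than one array/tensor index
        axis = pos
    return axis

# x[:, :, idx, ...] applies the array index at axis 2
assert find_advsubtensor_axis((slice(None), slice(None), [0, 2])) == 2
```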
- aesara.tensor.rewriting.subtensor.is_full_slice(x)[source]#
Determine if `x` is a `slice(None)` or a symbolic equivalent.
- aesara.tensor.rewriting.subtensor.merge_two_slices(fgraph, slice1, len1, slice2, len2)[source]#
This function merges two slices into a single slice. The code works on the assumption that:

- `slice1` is actually a slice and not an index, while `slice2` can be just an index.
- the two slices have been applied consecutively on the same tensor.

The output slice is not in canonical form, but is simply a slice that can be applied to a tensor to produce the same output as applying the two consecutive slices. `len1` is the length of the tensor before applying the first slice, while `len2` is the length after applying the first slice.
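The property such a merged slice must preserve can be demonstrated with NumPy (the merged slice below was worked out by hand for this example):

```python
import numpy as np

# Applying two slices consecutively must equal applying one merged slice.
x = np.arange(10)
two_step = x[2:8][1:4]   # slice1 = 2:8 (len1 = 10), slice2 = 1:4 (len2 = 6)
merged = slice(3, 6)     # hand-merged equivalent of the two slices above

assert np.array_equal(two_step, x[merged])
```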
- aesara.tensor.rewriting.subtensor.transform_take(a, indices, axis)[source]#
Transform `arr[:, :, :, indices, ...]`-like operations into single-dimensional, vector index operations.

This effectively converts certain `AdvancedSubtensor` `Op`s into a combination of `AdvancedSubtensor1`, `DimShuffle`, and `Reshape` `Op`s, which can be more efficient.

- Parameters:
a (TensorVariable) – The source array.
indices (TensorVariable, ndarray, list, tuple) – The indices of the values to extract.
axis (int) – The axis over which to select values. By default, the flattened input array is used.
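The equivalence this transform relies on, stated in NumPy terms: indexing a single axis with an integer array is the same as `np.take` along that axis.

```python
import numpy as np

# a[:, :, idx] selects along axis 2, which is exactly np.take(a, idx, axis=2)
a = np.arange(24).reshape(2, 3, 4)
idx = np.array([2, 0])

assert np.array_equal(a[:, :, idx], np.take(a, idx, axis=2))
```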
Shape#
- class aesara.tensor.rewriting.shape.ShapeFeature[source]#
A `Feature` that tracks shape information in a graph.

This `Feature` aids in the replacement of all `Shape`s and `Subtensor`s of `Shape`s with `Shape_i` and `MakeVector` `Op`s.

This `Feature` and its associated rewrites have several goals:

- to "lift" `Shape`s as close to the inputs as possible,
- to infer the shape of every node in the graph in terms of the input shapes, and
- to remove fill `Op`s (e.g. `Second`) from the graph.
Lifting shapes as close to the inputs as possible is important for canonicalization because it is very bad form to have to compute something just to know how big it will be. Firstly, it is a waste of time to compute such outputs. But it is important to get rid of these outputs as early as possible in the compilation process because the extra computations make it appear as if many internal graph nodes have multiple clients. Many rewrites refuse to work on nodes with multiple clients.
Lifting is done by using an `Op.infer_shape()` method if one is present, or else by using a conservative default.

Inferring the shape of internal nodes in the graph is important for doing size-driven rewrites. If we know how big various intermediate results will be, we can estimate the cost of many `Op`s accurately, and generate code that is specific (e.g. unrolled) to particular sizes.

In cases where `ShapeFeature` cannot figure out the shape, it raises a `ShapeError`.

Note

We can't automatically infer the shape of shared variables because, by default, their shape can change during execution.

To use the shape information gathered by a `FunctionGraph`-attached `ShapeFeature` in rewrites, use the `ShapeFeature.get_shape()` method.

- clone()[source]#
Create a clone that can be attached to a new `FunctionGraph`.

This default implementation returns `self`, which carries the assumption that the `Feature` is essentially stateless. If a subclass has state of its own that is in any way specific to a given `FunctionGraph`, this method should be overridden with an implementation that actually creates a fresh copy.
- default_infer_shape(fgraph: FunctionGraph, node: Apply, i_shapes: InputShapesType) OutputShapesType [source]#
Return a list of shape tuples (or `None`) for the outputs of `node`.

This default is used for `Op`s that don't implement `infer_shape`; it ignores the `i_shapes` parameter.
- get_shape(fgraph: FunctionGraph, var: Variable, idx: int) Variable [source]#
Get the shape of `var` at index `idx`.

It is better to call this than use `ShapeFeature.shape_of[var][idx]`, since this method will update `ShapeFeature.shape_of` when needed.

TODO: Up to now, we don't update it in all cases. Update in all cases.
- on_attach(fgraph)[source]#
Called by `FunctionGraph.attach_feature`, the method that attaches the feature to the `FunctionGraph`. Since this is called after the `FunctionGraph` is initially populated, this is where you should run checks on the initial contents of the `FunctionGraph`.

The `on_attach` method may raise the `AlreadyThere` exception to cancel the attach operation if it detects that another `Feature` instance implementing the same functionality is already attached to the `FunctionGraph`.

The feature has great freedom in what it can do with the `fgraph`: it may, for example, add methods to it dynamically.
- on_change_input(fgraph, node, i, r, new_r, reason)[source]#
Called whenever `node.inputs[i]` is changed from `r` to `new_r`. At the moment the callback is done, the change has already taken place.

If you raise an exception in this function, the state of the graph might be broken for all intents and purposes.
- on_detach(fgraph)[source]#
Called by `FunctionGraph.remove_feature`. Should remove any dynamically-added functionality that it installed into the `fgraph`.
- on_import(fgraph, node, reason)[source]#
Called whenever a node is imported into `fgraph`, which is just before the node is actually connected to the graph.

Note: this is not called when the graph is created. If you want to detect the first nodes to be imported into the graph, you should do this by implementing `on_attach`.
- same_shape(x: Variable, y: Variable, dim_x: int | None = None, dim_y: int | None = None) bool [source]#
Return `True` if `x` and `y` have the same shape.

- Parameters:
x – The `Variable` whose shape is to be compared with `y`'s shape.
y – The `Variable` whose shape is to be compared with `x`'s shape.
dim_x – If non-`None`, compare only the dimension of `x` equal to `dim_x`.
dim_y – If non-`None`, compare only the dimension of `y` equal to `dim_y`.
- set_shape(r: Variable, s: Sequence[Variable] | None, override: bool = False) None [source]#
Assign the shape `s` to the previously un-shaped variable `r`.

- Parameters:
r –
s –
override – If `False`, it means `r` is a new, unseen term. If `True`, it means `r` is assumed to have already been seen and we want to override its shape.
- set_shape_i(r: Variable, i: int, s_i: Variable) None [source]#
Replace element `i` of `shape_of[r]` by `s_i`.
- shape_tuple(r: Variable) Tuple[Variable, ...] | None [source]#
Return a tuple of symbolic shape variables for the tensor variable `r`.
- to_symbolic_int(s_i: int | float | integer | ArrayLike | Variable) Variable [source]#
Return a symbolic integer scalar for the shape element `s_i`.

TODO: Re-evaluate the need for this, since it's effectively eager canonicalization.

- Parameters:
s_i – The `s_i` argument is assumed to be produced by an `Op.infer_shape()`.
- class aesara.tensor.rewriting.shape.ShapeOptimizer[source]#
Rewriter that adds `ShapeFeature` as a feature.

- apply(fgraph)[source]#
Apply the rewriter to a `FunctionGraph`.

It may use all the methods defined by the `FunctionGraph`. If the `GraphRewriter` needs to use a certain tool, such as an `InstanceFinder`, it can do so in its `add_requirements` method.
- class aesara.tensor.rewriting.shape.UnShapeOptimizer[source]#
Rewriter that removes `ShapeFeature` as a feature.

- apply(fgraph)[source]#
Apply the rewriter to a `FunctionGraph`.

It may use all the methods defined by the `FunctionGraph`. If the `GraphRewriter` needs to use a certain tool, such as an `InstanceFinder`, it can do so in its `add_requirements` method.
Mathematical operations#
Rewrites for the `Op`s in `aesara.tensor.math`.
- class aesara.tensor.rewriting.math.AlgebraicCanonizer(main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True)[source]#
A `Rewriter` that rewrites algebraic expressions.

The variable is a `node_rewriter`. It is best used with a `WalkingGraphRewriter` in in-to-out order.

Usage: `AlgebraicCanonizer(main, inverse, reciprocal, calculate)`

- Parameters:
main – A suitable `Op` class that is commutative and associative and takes one to an arbitrary number of inputs, e.g. `add` or `mul`.
inverse – An `Op` class such that `inverse(main(x, y), y) == x` (e.g. `sub` or `true_divide`).
reciprocal – A function such that `main(x, reciprocal(y)) == inverse(x, y)` (e.g. `neg` or `reciprocal`).
calculate – A function that takes a list of `numpy.ndarray` instances for the numerator, another list for the denominator, and calculates `inverse(main(*num), main(*denum))`. It takes a keyword argument, `aslist`. If `True`, the value should be returned as a list of one element, unless the value is such that `value = main()`. In that case, the return value should be an empty list.
Examples
>>> import aesara.tensor as at
>>> from aesara.tensor.rewriting.math import AlgebraicCanonizer
>>> add_canonizer = AlgebraicCanonizer(add, sub, neg,
...                                    lambda n, d: sum(n) - sum(d))
>>> mul_canonizer = AlgebraicCanonizer(mul, true_divide, inv,
...                                    lambda n, d: prod(n) / prod(d))
Examples of rewrites `mul_canonizer` can perform:

x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
x * y * z -> Elemwise(mul){x,y,z}, i.e. only one pass over the memory, not Elemwise(mul){x,Elemwise(mul){y,z}}

- get_num_denum(inp)[source]#
This extracts two lists, `num` and `denum`, such that the input is: `self.inverse(self.main(*num), self.main(*denum))`. It returns the two lists in a `(num, denum)` pair.

For example, for `main`, `inverse` and `reciprocal` equal to `*`, `/` and `inv()`:

input -> returned value (num, denum)

x * y -> ([x, y], [])
inv(x) -> ([], [x])
inv(x) * inv(y) -> ([], [x, y])
x * y / z -> ([x, y], [z])
log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
(((a / b) * c) / d) -> ([a, c], [b, d])
a / (b / c) -> ([a, c], [b])
log(x) -> ([log(x)], [])
x**y -> ([x**y], [])
x * y * z -> ([x, y, z], [])
- merge_num_denum(num, denum)[source]#
Utility function which takes two lists, `num` and `denum`, and returns something equivalent to `inverse(main(*num), main(*denum))`, but which depends on the lengths of `num` and `denum` (in order to minimize the number of operations).

Let n = len(num) and d = len(denum):

n=0, d=0: neutral element (given by self.calculate([], [])); for example, this would be 0 if main is addition and 1 if main is multiplication
n=1, d=0: num[0]
n=0, d=1: reciprocal(denum[0])
n=1, d=1: inverse(num[0], denum[0])
n=0, d>1: reciprocal(main(*denum))
n>1, d=0: main(*num)
n=1, d>1: inverse(num[0], main(*denum))
n>1, d=1: inverse(main(*num), denum[0])
n>1, d>1: inverse(main(*num), main(*denum))

Given the values of n and d to which they are associated, all of the above are equivalent to: inverse(main(*num), main(*denum))
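The case analysis above can be modeled numerically with plain floats standing in for symbolic variables (an illustration, not Aesara's code), with main = mul, inverse = div, and reciprocal = 1/x:

```python
from math import prod

# Toy numeric model of the merge_num_denum cases for multiplication.
def merge_num_denum(num, denum):
    if not num and not denum:
        return 1.0                      # neutral element of multiplication
    if not denum:
        return prod(num)                # n>=1, d=0
    if not num:
        return 1.0 / prod(denum)        # n=0, d>=1
    return prod(num) / prod(denum)      # general case

assert merge_num_denum([6.0], [2.0, 3.0]) == 1.0
assert merge_num_denum([], []) == 1.0
```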
- simplify(num, denum, out_type)[source]#
Shorthand for: `self.simplify_constants(*self.simplify_factors(num, denum))`
- simplify_constants(orig_num, orig_denum, out_type=None)[source]#
Find all constants and put them together into a single constant.
Finds all constants in `orig_num` and `orig_denum` (using `get_constant`) and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator.

Examples

Let `main` be multiplication:

[2, 3, x], [] -> [6, x], []
[x, y, 2], [4, z] -> [0.5, x, y], [z]
[x, 2, y], [z, 2] -> [x, y], [z]
- simplify_factors(num, denum)[source]#
For any `Variable` `r` which is in both `num` and `denum`, removes it from both lists. Modifies the lists in place. Returns the modified lists. For example:

[x], [x] -> [], []
[x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d]
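A plausible pure-Python sketch of the cancellation (an assumption about the implementation, with strings standing in for variables): any value present in both lists is removed from both.

```python
# Cancel common factors between numerator and denominator lists in place.
def simplify_factors(num, denum):
    for v in list(num):
        if v in denum:
            num.remove(v)
            denum.remove(v)
    return num, denum

assert simplify_factors(["x"], ["x"]) == ([], [])
assert simplify_factors(["x", "y"], ["x"]) == (["y"], [])
assert simplify_factors(["a", "b"], ["c", "d"]) == (["a", "b"], ["c", "d"])
```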
- tracks()[source]#
Return the list of `Op` classes to which this rewrite applies.

Returns `None` when the rewrite applies to all nodes.
- transform(fgraph, node)[source]#
Rewrite the sub-graph given by `node`.

Subclasses should implement this function so that it returns one of the following:

- `False` to indicate that this rewrite cannot be applied to `node`
- a list of `Variable`s to use in place of the `node`'s current outputs
- a `dict` mapping old `Variable`s to `Variable`s, or the key `"remove"` mapping to a list of `Variable`s to be removed

- Parameters:
fgraph – A `FunctionGraph` containing `node`.
node – An `Apply` node to be rewritten.
- aesara.tensor.rewriting.math.attempt_distribution(factor, num, denum, out_type)[source]#
Try to insert each `num` and each `denum` into the factor.

- Returns:
If there are changes, `new_num` and `new_denum` contain all the numerators and denominators that could not be distributed into the factor.
- Return type:
changes (bool), new_factor, new_num, new_denum
- aesara.tensor.rewriting.math.check_for_x_over_absX(numerators, denominators)[source]#
Convert x/abs(x) into sign(x).
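The identity behind this rewrite, checked numerically with NumPy: for nonzero x, x / abs(x) equals sign(x).

```python
import numpy as np

# x / |x| is -1 for negative x and +1 for positive x, i.e. sign(x).
x = np.array([-3.0, 2.0, -0.5])

assert np.array_equal(x / np.abs(x), np.sign(x))
```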
- aesara.tensor.rewriting.math.compute_mul(tree)[source]#
Compute the Variable that is the output of a multiplication tree.
This is the inverse of the operation performed by `parse_mul_tree`, i.e. `compute_mul(parse_mul_tree(tree)) == tree`.

- Parameters:
tree – A multiplication tree (as output by `parse_mul_tree`).
- Returns:
A Variable that computes the multiplication represented by the tree.
- Return type:
object
- aesara.tensor.rewriting.math.get_constant(v)[source]#
- Returns:
A numeric constant if `v` is a `Constant` or already a numeric constant; `None` if `v` is a plain `Variable`.
- Return type:
object
- aesara.tensor.rewriting.math.is_1pexp(t, only_process_constants=True)[source]#
- Returns:
If `t` is of the form `1 + exp(x)`, return `(False, x)`; otherwise return `None`.
- Return type:
object
- aesara.tensor.rewriting.math.is_exp(var)[source]#
Match a variable with either of the `exp(x)` or `-exp(x)` patterns.

- Parameters:
var – The `Variable` to analyze.
- Returns:
A pair `(b, x)` with `b` a boolean set to `True` if `var` is of the form `-exp(x)` and `False` if `var` is of the form `exp(x)`. If `var` cannot be cast into either form, return `None`.
- Return type:
tuple
- aesara.tensor.rewriting.math.is_inverse_pair(node_op, prev_op, inv_pair)[source]#
Given two consecutive operations, check if they are the provided pair of inverse functions.
- aesara.tensor.rewriting.math.is_mul(var)[source]#
Match a variable with `x * y * z * ...`.
.- Parameters:
var – The Variable to analyze.
- Returns:
A list `[x, y, z, ...]` if `var` is of the form `x * y * z * ...`, or `None` if `var` cannot be cast into this form.
- Return type:
object
- aesara.tensor.rewriting.math.is_neg(var)[source]#
Match a variable with the `-x` pattern.

- Parameters:
var – The Variable to analyze.
- Returns:
`x` if `var` is of the form `-x`, or `None` otherwise.
- Return type:
object
- aesara.tensor.rewriting.math.local_add_mul_fusion(fgraph, node)[source]#
Fuse consecutive `add` or `mul` `Op`s into one such node with more inputs.

It is better to fuse `add`/`mul` this way than in a `Composite` node, as this makes the inner graph of the `Composite` smaller. This allows more computation to be put in a `Composite` before hitting the maximum recursion limit when pickling a `Composite`.
- aesara.tensor.rewriting.math.parse_mul_tree(root)[source]#
Parse a tree of multiplications starting at the given root.
- Parameters:
root – The variable at the root of the tree.
- Returns:
A tree where each non-leaf node corresponds to a multiplication in the computation of `root`, represented by the list of its inputs. Each input is a pair `[n, x]` with `n` a boolean value indicating whether sub-tree `x` should be negated.
- Return type:
object
Examples
x * y -> [False, [[False, x], [False, y]]]
-(x * y) -> [True, [[False, x], [False, y]]]
-x * y -> [False, [[True, x], [False, y]]]
-x -> [True, x]
(x * y) * -z -> [False, [[False, [[False, x], [False, y]]], [True, z]]]
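A toy evaluator for this `[negated, subtree]` encoding (an illustration of the representation, not Aesara's `compute_mul`), with numbers standing in for symbolic variables:

```python
from math import prod

# Recursively evaluate a [negated, node] pair: a list node multiplies its
# children; a leaf node is the value itself; the flag negates the result.
def eval_tree(tree):
    neg, node = tree
    if isinstance(node, list):
        value = prod(eval_tree(child) for child in node)
    else:
        value = node
    return -value if neg else value

# (x * y) * -z with x=2, y=3, z=4
tree = [False, [[False, [[False, 2], [False, 3]]], [True, 4]]]
assert eval_tree(tree) == -24
```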
- aesara.tensor.rewriting.math.perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None, sigm_minus_x=None, parent=None, child_idx=None, full_tree=None)[source]#
Core processing of the `local_sigm_times_exp` rewrite.

This recursive function operates on a multiplication tree as output by `parse_mul_tree`. It walks through the tree and modifies it in place by replacing matching pairs (exp, sigmoid) with the desired version.

- Parameters:
tree – The sub-tree to operate on.
exp_x – List of arguments `x` such that `exp(x)` exists somewhere in the whole multiplication tree. Each argument is a pair `(x, leaf)` with `x` the argument of the exponential, and `leaf` the corresponding leaf in the multiplication tree (of the form `[n, exp(x)]` – see `parse_mul_tree`). If `None`, this argument is initialized to an empty list.
exp_minus_x – Similar to `exp_x`, but for `exp(-x)`.
sigm_x – Similar to `exp_x`, but for `sigmoid(x)`.
sigm_minus_x – Similar to `exp_x`, but for `sigmoid(-x)`.
parent – Parent of `tree` (`None` if `tree` is the global root).
child_idx – Index of `tree` in its parent's inputs (`None` if `tree` is the global root).
full_tree – The global multiplication tree (should not be set except by recursive calls to this function). Used for debugging only.

- Returns:
`True` if a modification was performed somewhere in the whole multiplication tree, or `False` otherwise.
- Return type:
bool
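The algebraic identity this rewrite exploits can be checked numerically: sigmoid(x) * exp(-x) == sigmoid(-x), since 1/(1 + e^(-x)) * e^(-x) = 1/(e^x + 1).

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# sigmoid(x) * exp(-x) collapses to a single sigmoid evaluation
x = 1.3
assert abs(sigmoid(x) * math.exp(-x) - sigmoid(-x)) < 1e-12
```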
- aesara.tensor.rewriting.math.replace_leaf(arg, leaves, new_leaves, op, neg)[source]#
Attempt to replace a leaf of a multiplication tree.
We search for a leaf in `leaves` whose argument is `arg`, and if we find one, we remove it from `leaves` and add to `new_leaves` a leaf with argument `arg` and variable `op(arg)`.

- Parameters:
arg – The argument of the leaf we are looking for.
leaves – List of leaves to look into. Each leaf should be a pair `(x, l)` with `x` the argument of the `Op` found in the leaf, and `l` the actual leaf as found in a multiplication tree output by `parse_mul_tree` (i.e. a pair `[boolean, variable]`).
new_leaves – If a replacement occurred, then the leaf is removed from `leaves` and added to the list `new_leaves` (after being modified by `op`).
op – A function that, when applied to `arg`, returns the `Variable` we want to replace the original leaf variable with.
neg (bool) – If `True`, then the boolean value associated to the leaf should be swapped. If `False`, then this value should remain unchanged.
- Returns:
True if a replacement occurred, or False otherwise.
- Return type:
bool
- aesara.tensor.rewriting.math.scalarconsts_rest(inputs, elemwise=True, only_process_constants=False)[source]#
Partition a list of variables into two kinds: scalar constants, and the rest.
- aesara.tensor.rewriting.math.simplify_mul(tree)[source]#
Simplify a multiplication tree.
- Parameters:
tree – A multiplication tree (as output by `parse_mul_tree`).
- Returns:
A multiplication tree computing the same output as `tree`, but without useless multiplications by 1 or -1 (identified by leaves of the form `[False, None]` or `[True, None]` respectively). Useless multiplications (with fewer than two inputs) are also removed from the tree.
object
- class aesara.tensor.rewriting.elemwise.FusionOptimizer(node_rewriter)[source]#
Graph rewriter that simply runs node fusion operations.
TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.

- apply(fgraph)[source]#
Apply the rewriter to a `FunctionGraph`.

It may use all the methods defined by the `FunctionGraph`. If the `GraphRewriter` needs to use a certain tool, such as an `InstanceFinder`, it can do so in its `add_requirements` method.
- class aesara.tensor.rewriting.elemwise.InplaceElemwiseOptimizer(OP)[source]#
This is parameterized so that it works for `Elemwise` `Op`s.

- apply(fgraph)[source]#
Attempts to replace all `Elemwise`s by versions of them that operate in place. It operates greedily: for each `Elemwise` that is encountered, for each output, it tries each input to see if it can operate in place on that input. If so, it makes the change and goes to the next output or `Elemwise`.

Examples

x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
- aesara.tensor.rewriting.elemwise.is_dimshuffle_useless(new_order, input)[source]#
- Checks for two types of useless dimshuffles:
1. dimshuffle all dimensions in order.
2. dimshuffle a broadcastable dimension.
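The first case, stated in NumPy terms: a dimshuffle that keeps every dimension in its original order is the identity and can be removed from the graph.

```python
import numpy as np

# transpose with the identity order returns an array equal to the input
a = np.arange(6).reshape(2, 3)

assert np.array_equal(a.transpose(0, 1), a)
```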
- aesara.tensor.rewriting.elemwise.local_elemwise_fusion(fgraph, node)[source]#
Fuse `Elemwise` `Op`s in a node.

As part of specialization, we fuse two consecutive `Elemwise` `Op`s of the same shape.

For mixed dtypes, we let the `Composite` `Op` do the cast, which lets the C compiler do the cast.

The number of dimensions is validated at call time by Aesara itself.
- aesara.tensor.rewriting.elemwise.local_elemwise_fusion_op(op_class, max_input_fct=<function <lambda>>, maker=None)[source]#
Create a recursive function that fuses `Elemwise` `Op`s.

The basic idea is that we loop through an `Elemwise` node's inputs, find other `Elemwise` nodes, determine the scalar input types for all of the `Elemwise` `Op`s, construct a new scalar `Op` using the scalar input types and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a new "fused" `Elemwise`.

It's parameterized in order to work for `Elemwise` `Op`s.

- Parameters:
op_class (type) – The `Elemwise` class (the one that we want to fuse).
max_input_fct (callable) – A function that returns the maximum number of inputs that this `Elemwise` can take. On the CPU we limit it to 32 input variables, since that is the maximum that NumPy supports.
maker (callable) – A function with the signature `(node, *args)` that constructs an `op_class` instance (e.g. `op_class(*args)`).
Random variables#
- aesara.tensor.random.rewriting.basic.is_rv_used_in_graph(base_rv, node, fgraph)[source]#
Determine whether or not `base_rv` is used by a node other than `node` in `fgraph`.

If a node uses `Shape` or `Shape_i` on the `base_rv`, we ignore it, because those `Op`s don't rely on the actual sample values of `base_rv`.

TODO: We should apply all the shape rewrites before these rewrites, since that would properly remove the unnecessary dependencies on `base_rv` (when possible).