aesara.tensor.batched_tensordot


aesara.tensor.batched_tensordot(x, y, axes=2)

Compute a batched tensordot product.

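A hybrid of batched_dot and tensordot. This function treats the first dimension of x and y as a batch dimension and, for each index along it, computes the tensordot product of the corresponding slices of x and y. Conceptually, it performs a sequence of tensordots and stacks the results along the leading (batch) dimension of the output.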
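To make the "sequence of tensordots" concrete, here is a minimal NumPy reference sketch of the result this function computes. It assumes NumPy arrays instead of TensorVariables and assumes axes is already given as a pair of axis lists indexed against the full tensors (batch axis 0 included, but never summed over). The helper name batched_tensordot_ref is hypothetical, and the loop only illustrates the semantics; it is not how Aesara computes the product internally.

```python
import numpy as np


def batched_tensordot_ref(x, y, axes):
    # Hypothetical reference helper: illustrates the result only.
    # `axes` is a pair of axis lists indexed against the full tensors, so
    # axis 0 is the batch axis and must not appear in either list.
    x_axes, y_axes = axes
    # Slicing out one batch element drops axis 0, so shift every index down by one.
    x_axes = [a - 1 for a in x_axes]
    y_axes = [a - 1 for a in y_axes]
    # One ordinary tensordot per batch element, stacked along a new axis 0.
    return np.stack(
        [np.tensordot(x[b], y[b], axes=(x_axes, y_axes)) for b in range(x.shape[0])]
    )


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 5))  # (dim1, dim3, dim2)
y = rng.standard_normal((4, 5, 6))  # (dim1, dim2, dim4)
z = batched_tensordot_ref(x, y, axes=[[2], [1]])
print(z.shape)  # (4, 3, 6)
```

For 3-D inputs contracted over dim2 like this, the result is the same as np.matmul(x, y), an ordinary batched matrix product.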

Parameters:
  • x (TensorVariable) – The first tensor; its first dimension is the batch dimension. For example, a 3-D x could have shape (dim1, dim3, dim2).

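  • y (TensorVariable) – The second tensor; its first dimension is the batch dimension and must match that of x. For example, a 3-D y could have shape (dim1, dim2, dim4), so that summing over dim2 (e.g. axes=[[2], [1]]) yields a result of shape (dim1, dim3, dim4).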

  • axes (int or array-like of length 2) –

    If an integer, the number of axes to sum over. If an array, it must have two array elements containing the axes to sum over in each tensor.

    If an integer i, it is converted to an array containing the last i dimensions of x and dimensions 1 through i of y (the first, batch, dimension is excluded):

    axes = [list(range(x.ndim - i, x.ndim)), list(range(1, i + 1))]

    If an array, its two elements must contain compatible axes of the two tensors. For example, [[1, 2], [2, 4]] means sum over the 2nd and 3rd axes of x and the 3rd and 5th axes of y. (Remember that axes are zero-indexed!) The 2nd axis of x and the 3rd axis of y must have the same length; the same is true for the 3rd axis of x and the 5th axis of y.

Like tensordot, this function uses a series of dimshuffles and reshapes to reduce the tensor dot product to a matrix or vector dot product. Finally, it calls batched_dot to compute the result.
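The integer-to-array conversion of axes and the dimshuffle/reshape strategy described above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not Aesara's code: it works on NumPy arrays rather than TensorVariables, np.transpose plays the role of dimshuffle, np.matmul on 3-D arrays stands in for batched_dot, and the helper names int_axes_to_pair and batched_tensordot_via_matmul are hypothetical.

```python
import numpy as np


def int_axes_to_pair(x_ndim, i):
    # Integer axes i: the last i axes of x and axes 1..i of y (axis 0 is the batch axis).
    return [list(range(x_ndim - i, x_ndim)), list(range(1, i + 1))]


def batched_tensordot_via_matmul(x, y, axes):
    # Hypothetical NumPy analogue of the strategy; not Aesara's implementation.
    if isinstance(axes, (int, np.integer)):
        axes = int_axes_to_pair(x.ndim, int(axes))
    x_axes, y_axes = list(axes[0]), list(axes[1])
    # Axes that survive into the output (the batch axis 0 is handled separately).
    x_free = [a for a in range(1, x.ndim) if a not in x_axes]
    y_free = [a for a in range(1, y.ndim) if a not in y_axes]
    # "dimshuffle": x -> (batch, free..., summed...), y -> (batch, summed..., free...).
    xt = np.transpose(x, [0] + x_free + x_axes)
    yt = np.transpose(y, [0] + y_axes + y_free)
    batch = x.shape[0]
    m = int(np.prod([x.shape[a] for a in x_free]))
    k = int(np.prod([x.shape[a] for a in x_axes]))
    n = int(np.prod([y.shape[a] for a in y_free]))
    # "reshape": (batch, m, k) @ (batch, k, n) is a batched matrix product.
    out = np.matmul(xt.reshape(batch, m, k), yt.reshape(batch, k, n))
    # Unflatten: batch axis, then free axes of x, then free axes of y.
    out_shape = [batch] + [x.shape[a] for a in x_free] + [y.shape[a] for a in y_free]
    return out.reshape(out_shape)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 5))
y = rng.standard_normal((2, 4, 5, 6))
print(int_axes_to_pair(x.ndim, 2))  # [[2, 3], [1, 2]]
z = batched_tensordot_via_matmul(x, y, axes=2)
print(z.shape)  # (2, 3, 6)
```

Collapsing both operands to 3-D keeps the sketch to a single batched matmul call; when x or y has no free axes, the corresponding size m or n is simply 1, which covers the matrix-vector case.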