aesara.tensor.batched_dot



aesara.tensor.batched_dot(a, b)

Compute the batched dot product of two variables.

That is, for each index i along the leading (batch) axis:

batched_dot(a, b)[i] = dot(a[i], b[i])
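The definition above can be illustrated with a small NumPy reference. This is a sketch of the semantics only, not Aesara's implementation: it loops over the batch axis, applies np.dot to each pair of items, and stacks the results. The helper name batched_dot_reference is hypothetical. For 3-D inputs the result coincides with a batched matrix product, which the einsum check confirms.

```python
import numpy as np

def batched_dot_reference(a, b):
    # Hypothetical helper: apply np.dot to each pair of matching items along
    # the leading (batch) axis, then stack the per-item results.
    if a.shape[0] != b.shape[0]:
        raise ValueError("a and b must have the same batch size")
    return np.stack([np.dot(a[i], b[i]) for i in range(a.shape[0])])

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 2, 3))  # 4 matrices of shape (2, 3)
b = rng.standard_normal((4, 3, 5))  # 4 matrices of shape (3, 5)
out = batched_dot_reference(a, b)
print(out.shape)  # (4, 2, 5)
# For 3-D inputs this is a batched matrix-matrix product:
print(np.allclose(out, np.einsum("bij,bjk->bik", a, b)))  # True
```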

Depending on the number of dimensions of its inputs, batched_dot does one of three things, checked in the following order:

  1. If either a or b is a vector, it returns the batched elementwise product without calling the Aesara BatchedDot op.

  2. If both a and b have either 2 or 3 dimensions, it calls Aesara’s BatchedDot op on a and b.

  3. If either a or b has more than 3 dimensions, it calls Aesara’s batched_tensordot function with appropriate axes. Because batched_tensordot expresses high-dimensional batched dot products in terms of batched matrix-matrix dot products, this path may leave room for further performance optimization.
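The three-way dispatch can be mimicked in NumPy to show why each branch yields the same values as a per-item dot product. This is a hedged sketch, not Aesara's code: the function name batched_dot_sketch is hypothetical, and the contraction axis used in case 3 (the last axis of a with axis max(1, b.ndim - 2) of b) is an assumption chosen to match what np.dot does on each batch item. Case 3 reshapes every item into a matrix so a single batched matrix-matrix product (np.matmul) does the work, mirroring the idea behind batched_tensordot.

```python
import numpy as np

def batched_dot_sketch(a, b):
    """Hypothetical NumPy sketch of batched_dot's dispatch (not Aesara's code)."""
    a, b = np.asarray(a), np.asarray(b)
    if a.ndim == 0 or b.ndim == 0:
        raise TypeError("both inputs need a leading batch axis")
    batch = a.shape[0]
    # Case 1: a vector holds one scalar per batch item, so the result is a
    # broadcast elementwise product rather than a dot product.
    if a.ndim == 1:
        return a.reshape((batch,) + (1,) * (b.ndim - 1)) * b
    if b.ndim == 1:
        return a * b.reshape((batch,) + (1,) * (a.ndim - 1))
    # Case 3: more than 3 dims. Assumed axes: contract a's last axis with the
    # axis np.dot would use on b[i], then reduce to one batched matmul.
    if a.ndim > 3 or b.ndim > 3:
        b_axis = max(1, b.ndim - 2)
        k = a.shape[-1]
        b_moved = np.moveaxis(b, b_axis, 1)   # (batch, k, rest...)
        rest_shape = b_moved.shape[2:]
        a_mat = a.reshape(batch, -1, k)       # (batch, m, k)
        b_mat = b_moved.reshape(batch, k, -1) # (batch, k, n)
        out = np.matmul(a_mat, b_mat)         # (batch, m, n)
        return out.reshape((batch,) + a.shape[1:-1] + rest_shape)
    # Case 2: both inputs have 2 or 3 dims; a per-item dot is enough here.
    return np.stack([np.dot(a[i], b[i]) for i in range(batch)])

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 2, 6, 3))  # 4-D input triggers case 3
b = rng.standard_normal((4, 3, 5))
print(batched_dot_sketch(a, b).shape)  # (4, 2, 6, 5)
```

Every branch should agree with stacking np.dot(a[i], b[i]) over the batch; the reshaping in case 3 only changes how the work is organized, not the result.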