The matrix multiplication is performed with tf.matmul in TensorFlow or K.dot in Keras:
# Keras backend
from keras import backend as K

a = K.ones((3, 4))
b = K.ones((4, 5))
c = K.dot(a, b)   # (3 x 4) . (4 x 5)
print(c.shape)

# TensorFlow
import tensorflow as tf

a = tf.ones((3, 4))
b = tf.ones((4, 5))
c = tf.matmul(a, b)
print(c.shape)
Both return a tensor of shape (3, 5). This is simple when tensors have no more than two dimensions, or even three. What we want, however, is one approach that works for any rank. When the rank is higher (higher-dimensional data plus batch data), we need batch matrix multiplication.
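To make "batch" concrete, the sketch below checks that a single batched product equals a loop of ordinary 2-D products. It uses NumPy's np.matmul as a stand-in for tf.matmul, on the assumption that both follow the same convention: the last two axes hold the matrices and every leading axis is a batch axis. The shapes are arbitrary.

```python
import numpy as np

# NumPy stand-in for tf.matmul: same "last two axes are the matrix" rule.
rng = np.random.default_rng(0)
a = rng.standard_normal((6, 3, 4))  # a batch of six 3x4 matrices
b = rng.standard_normal((6, 4, 5))  # a batch of six 4x5 matrices

batched = np.matmul(a, b)  # one call over the whole batch

# The same result, computed one 2-D product at a time.
looped = np.stack([a[i] @ b[i] for i in range(a.shape[0])])

print(batched.shape)                 # (6, 3, 5)
print(np.allclose(batched, looped))  # True
```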
Simple Batch Matrix Multiplication: tf.matmul or K.batch_dot
There is another operator, K.batch_dot, which in this case works the same as tf.matmul:
# Keras backend
from keras import backend as K

a = K.ones((9, 8, 7, 4, 2))
b = K.ones((9, 8, 7, 2, 5))
c = K.batch_dot(a, b)
print(c.shape)

# TensorFlow
import tensorflow as tf

a = tf.ones((9, 8, 7, 4, 2))
b = tf.ones((9, 8, 7, 2, 5))
c = tf.matmul(a, b)
print(c.shape)
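Both return a tensor of shape (9, 8, 7, 4, 5). For K.batch_dot, this result depends on the Keras version. Later tf.keras backends treat only the first axis as the batch axis, so for inputs of rank above 3 the result shape can differ. Check the shape on your version, or use tf.matmul.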
So here the multiplication treats (9, 8, 7) as batch (non-spatial) dimensions. The data is assumed to be in (B1, …, Bn, C, H, W) format, with the spatial dimensions in the last two indices. Here, the spatial part of tensor a is (4, 2) and that of tensor b is (2, 5).
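The sketch below checks the five-dimensional case in the same way. With (9, 8, 7) as batch axes, every slice c[i, j, k] should be an ordinary (4x2) @ (2x5) product. NumPy again stands in for TensorFlow. The einsum subscript "...ij,...jk->...ik" is shown only as an equivalent way to write the same contraction, where "..." covers all batch axes. The slice indices are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((9, 8, 7, 4, 2))
b = rng.standard_normal((9, 8, 7, 2, 5))

c = np.matmul(a, b)
print(c.shape)  # (9, 8, 7, 4, 5)

# (9, 8, 7) are batch axes: each c[i, j, k] is a plain 4x2 @ 2x5 product.
i, j, k = 3, 5, 6
print(np.allclose(c[i, j, k], a[i, j, k] @ b[i, j, k]))  # True

# Same contraction spelled with einsum; "..." stands for the batch axes.
c_einsum = np.einsum("...ij,...jk->...ik", a, b)
print(np.allclose(c, c_einsum))  # True
```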
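However, if the channel is the last dimension (the default channels_last data format), K.batch_dot multiplies the wrong pair of axes. The last two axes are then width and channel rather than height and width, so the spatial (height x width) matrix multiplication either fails with a shape error or silently computes the wrong product. Thus, we need K.permute_dimensions as a preprocessing step for batch matrix multiplication.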
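Here is a small illustration of that failure mode, using NumPy as a stand-in and made-up shapes that keep the channels-last layout of the example below. A batch matmul on channels-last data pairs up (width, channel) axes, which here raises a ValueError. For other shapes, for example square images whose width equals the channel count, the call could succeed and silently compute the wrong product. Moving the channels forward (np.transpose playing the role of K.permute_dimensions) puts height and width back in the last two axes.

```python
import numpy as np

# Channels-last stand-ins: (batch, height, width, channels).
a = np.ones((2, 6, 4, 3))  # 6 x 4 images, 3 channels
b = np.ones((2, 4, 6, 3))  # 4 x 6 images, 3 channels

# A batch matmul treats the last two axes as the matrices.
print(a.shape[-2:], b.shape[-2:])  # (4, 3) (6, 3): width x channels, not H x W

failed = False
try:
    np.matmul(a, b)
except ValueError:
    failed = True  # the axes matmul pairs up do not line up
print(failed)  # True

# Channels first: the last two axes are now (height, width).
a_t = np.transpose(a, (0, 3, 1, 2))  # (2, 3, 6, 4)
b_t = np.transpose(b, (0, 3, 1, 2))  # (2, 3, 4, 6)
print(np.matmul(a_t, b_t).shape)     # (2, 3, 6, 6)
```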
Batch Matrix Multiplication: K.permute_dimensions and K.batch_dot
Take multi-channel data as an example.
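from keras import backend as K

# a: [10, 512, 256, 3] -> batch of 10, 512 x 256 (height x width), 3 channels
# b: [10, 256, 512, 3] -> batch of 10, 256 x 512 (height x width), 3 channels
a = K.ones((10, 512, 256, 3))
b = K.ones((10, 256, 512, 3))

# Move channels in front of the spatial axes: (batch, channel, height, width).
a_t = K.permute_dimensions(a, (0, 3, 1, 2))  # K.int_shape(a_t) = (10, 3, 512, 256)
b_t = K.permute_dimensions(b, (0, 3, 1, 2))  # K.int_shape(b_t) = (10, 3, 256, 512)

# Contract the width of a_t (axis 3) with the height of b_t (axis 2).
# tf.matmul(a_t, b_t) computes the same (10, 3, 512, 512) product.
c_t = K.batch_dot(a_t, b_t, axes=(3, 2))  # K.int_shape(c_t) = (10, 3, 512, 512)

# Move channels back to the last axis.
c = K.permute_dimensions(c_t, (0, 2, 3, 1))  # K.int_shape(c) = (10, 512, 512, 3)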
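The permute, batch-multiply, permute-back pipeline can be checked against a single einsum call. This sketch uses NumPy as a stand-in (np.transpose for K.permute_dimensions, np.matmul for the batch product) with small made-up shapes. tf.einsum accepts the same subscript string, which would avoid the explicit permutes. That is an alternative technique, not the approach described above.

```python
import numpy as np

rng = np.random.default_rng(2)
a = rng.standard_normal((2, 6, 4, 3))  # (batch, H_a, W_a, C)
b = rng.standard_normal((2, 4, 5, 3))  # (batch, H_b = W_a, W_b, C)

# Route 1: permute -> batch matmul -> permute back.
a_t = np.transpose(a, (0, 3, 1, 2))  # (2, 3, 6, 4)
b_t = np.transpose(b, (0, 3, 1, 2))  # (2, 3, 4, 5)
c_t = np.matmul(a_t, b_t)            # (2, 3, 6, 5)
c = np.transpose(c_t, (0, 2, 3, 1))  # (2, 6, 5, 3)

# Route 2: one einsum, summing over the shared width/height axis "w"
# separately for each batch item "n" and channel "c".
c_einsum = np.einsum("nhwc,nwkc->nhkc", a, b)

print(c.shape)                   # (2, 6, 5, 3)
print(np.allclose(c, c_einsum))  # True
```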
This remains open. A better answer is welcome if anyone can shed more light on it.