# A Brief Talk Through Matrix Multiplication in Keras with TensorFlow as Backend

### Matrix Multiplication

Matrix multiplication is performed with `tf.matmul` in TensorFlow or `K.dot` in Keras:

```
from keras import backend as K
a = K.ones((3,4))
b = K.ones((4,5))
c = K.dot(a, b)
print(c.shape)
```

or

```
import tensorflow as tf
a = tf.ones((3,4))
b = tf.ones((4,5))
c = tf.matmul(a, b)
print(c.shape)
```

Both snippets return a tensor of shape (3, 5). Things stay simple as long as the tensor dimension is no greater than 2, or even 3. However, what we pursue is an approach that also works for higher-dimensional data. Once higher-dimensional and batched data come into play, batch matrix multiplication is what we need.
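To make "batch matrix multiplication" concrete before moving on: it is simply one independent 2-D matrix multiplication per batch element. A minimal sketch, using NumPy purely for illustration (NumPy's `matmul` follows the same batch semantics as `tf.matmul`):

```
import numpy as np

a = np.random.rand(6, 4, 2)   # a batch of 6 matrices, each 4x2
b = np.random.rand(6, 2, 5)   # a batch of 6 matrices, each 2x5

# Batched multiply in one call: shape (6, 4, 5).
c = np.matmul(a, b)

# Same result with an explicit loop over the batch dimension.
c_loop = np.stack([a[i] @ b[i] for i in range(6)])
print(c.shape, np.allclose(c, c_loop))  # (6, 4, 5) True
```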

### Simple Batch Matrix Multiplication : tf.matmul or K.batch_dot

There is another operator, `K.batch_dot`, that works the same as `tf.matmul`:

```
from keras import backend as K
a = K.ones((9, 8, 7, 4, 2))
b = K.ones((9, 8, 7, 2, 5))
c = K.batch_dot(a, b)
print(c.shape)
```

or

```
import tensorflow as tf
a = tf.ones((9, 8, 7, 4, 2))
b = tf.ones((9, 8, 7, 2, 5))
c = tf.matmul(a, b)
print(c.shape)
```

Both return a tensor of shape (9, 8, 7, 4, 5).

Here the multiplication treats (9, 8, 7) as the batch, i.e. non-spatial, dimensions: the data is considered to be in (B1, …, Bn, C, H, W) format, with the spatial dimensions of the tensor in the last two indices. The spatial dimensions of `a` are (4, 2) and those of `b` are (2, 5).
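As a cross-check on this "leading axes are batch axes" rule, the same batched product can be written with `einsum` (sketched here in NumPy for illustration; `tf.einsum` accepts the same subscript string): `...ij,...jk->...ik` contracts the last axis of the first operand with the second-to-last axis of the second, broadcasting over all leading axes.

```
import numpy as np

a = np.ones((9, 8, 7, 4, 2))
b = np.ones((9, 8, 7, 2, 5))

# Batched matrix multiply: (9, 8, 7) are batch dimensions,
# the last two axes are the matrices being multiplied.
c = np.matmul(a, b)
print(c.shape)  # (9, 8, 7, 4, 5)

# Equivalent einsum form: contract the last axis of a with the
# second-to-last axis of b, treating leading axes as batch axes.
c2 = np.einsum('...ij,...jk->...ik', a, b)
print(np.allclose(c, c2))  # True
```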

However, if the channel is the last dimension (the default, channels-last data format), `K.batch_dot` would contract the wrong axes, causing an error in the spatial (height × width) matrix multiplication. Thus, we need `K.permute_dimensions` as a preprocessing step for batch matrix multiplication.

### Batch Matrix Multiplication: K.permute_dimensions and K.batch_dot

Take multi-channel data as an example.

```
from keras import backend as K

# a: (10, 512, 256, 3) -- batch of 10, height x width = 512x256, 3 channels
# b: (10, 256, 512, 3) -- batch of 10, height x width = 256x512, 3 channels
a = K.ones((10, 512, 256, 3))
b = K.ones((10, 256, 512, 3))

a_t = K.permute_dimensions(a, (0, 3, 1, 2))  # K.int_shape(a_t) = (10, 3, 512, 256)
b_t = K.permute_dimensions(b, (0, 3, 1, 2))  # K.int_shape(b_t) = (10, 3, 256, 512)
c_t = K.batch_dot(a_t, b_t, axes=(3, 2))     # K.int_shape(c_t) = (10, 3, 512, 512)
c = K.permute_dimensions(c_t, (0, 2, 3, 1))  # K.int_shape(c) = (10, 512, 512, 3)
```
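The same permute → batched multiply → permute-back workflow can be mirrored in NumPy to verify the shapes (a sketch with small stand-in dimensions so it runs quickly; `np.transpose` plays the role of `K.permute_dimensions` and `np.matmul` that of `K.batch_dot`):

```
import numpy as np

# Small channels-last stand-ins for the batches above.
a = np.ones((10, 5, 2, 3))   # (batch, H, W, C)
b = np.ones((10, 2, 5, 3))   # (batch, H, W, C)

a_t = np.transpose(a, (0, 3, 1, 2))   # (10, 3, 5, 2), channels-first
b_t = np.transpose(b, (0, 3, 1, 2))   # (10, 3, 2, 5), channels-first
c_t = np.matmul(a_t, b_t)             # (10, 3, 5, 5), batched spatial multiply
c = np.transpose(c_t, (0, 2, 3, 1))   # (10, 5, 5, 3), back to channels-last
print(c.shape)  # (10, 5, 5, 3)
```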

This case remains open for a better answer; anyone who can shed more light on it is welcome to do so.