Jigsaw Puzzle

Simple implementation of the Jigsaw Puzzle self-supervised task for TensorFlow. It is intended to be used with any dataset as a plug-and-play feature.

This could be implemented as a function that relies on variables from an outer scope, but to keep everything self-contained we define a class that stores the possible tile permutations and the methods needed to invert them, so that a puzzled image can be de-puzzled.
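A minimal sketch of the kind of lookup table such a class can store, assuming each label is simply an index into the list of all orderings of the tile positions. The names build_permutation_table and permutations2labels are illustrative; labels2permutations mirrors the attribute used in the examples, but this page does not show how JigsawPuzzle actually builds it. With itertools.permutations ordering, label 5 maps to (0, 3, 2, 1), which is consistent with the label/permutation pair printed in the example output on this page.

```python
import itertools
from typing import Dict, List, Tuple

def build_permutation_table(n_tiles: int) -> Tuple[List[Tuple[int, ...]], Dict[Tuple[int, ...], int]]:
    """Enumerate every ordering of n_tiles tile positions and index it by an integer label."""
    # labels2permutations[label] -> permutation tuple (hypothetical reconstruction)
    labels2permutations = list(itertools.permutations(range(n_tiles)))
    # Reverse lookup: permutation tuple -> label
    permutations2labels = {perm: label for label, perm in enumerate(labels2permutations)}
    return labels2permutations, permutations2labels

labels2permutations, permutations2labels = build_permutation_table(4)
print(len(labels2permutations))  # 24 (= 4!)
print(labels2permutations[5])    # (0, 3, 2, 1)
```

The table has n_tiles! entries, so it grows quickly: 9 tiles already give 362,880 labels.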



JigsawPuzzle

 JigsawPuzzle (n_tiles)

Class that implements all the logic needed for the jigsaw puzzle self-supervised task.

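Parameters
n_tiles: Number of tiles per puzzle. The examples on this page use n_tiles=4, which splits a 32x32 image into a 2x2 grid of 16x16 tiles.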
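A minimal NumPy sketch of how an image can be cut into n_tiles tiles, assuming n_tiles is a perfect square and the image height and width are divisible by the grid size. Tiles come out in row-major order (top-left, top-right, bottom-left, bottom-right for a 2x2 grid). split_into_tiles is a hypothetical helper, not a method of JigsawPuzzle.

```python
import math
import numpy as np

def split_into_tiles(img: np.ndarray, n_tiles: int) -> np.ndarray:
    """Split an (H, W, C) image into n_tiles tiles of shape (H/g, W/g, C), ordered row by row."""
    grid = math.isqrt(n_tiles)  # assumes a square grid, e.g. 4 tiles -> 2x2
    if grid * grid != n_tiles:
        raise ValueError("n_tiles must be a perfect square")
    h, w, c = img.shape
    if h % grid or w % grid:
        raise ValueError("image height and width must be divisible by the grid size")
    th, tw = h // grid, w // grid
    # (grid, th, grid, tw, C) -> (grid, grid, th, tw, C) -> (n_tiles, th, tw, C)
    tiles = img.reshape(grid, th, grid, tw, c).transpose(0, 2, 1, 3, 4)
    return tiles.reshape(n_tiles, th, tw, c)

img = np.arange(32 * 32 * 3, dtype=np.float32).reshape(32, 32, 3)
tiles = split_into_tiles(img, 4)
print(tiles.shape)  # (4, 16, 16, 3)
```

A reshape/transpose avoids Python loops over tiles; the same idea carries over to TensorFlow with tf.reshape and tf.transpose.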
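import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Build a 32x32 test image from four solid-colour 16x16 quadrants.
# Colour values are floats in [0, 1], the range imshow expects for float RGB data.
red = np.ones(shape=(16,16,3))*np.array([1.,0,0])[None,None,:]
green = np.ones(shape=(16,16,3))*np.array([0,1.,0])[None,None,:]
blue = np.ones(shape=(16,16,3))*np.array([0,0,1.])[None,None,:]
white = np.ones(shape=(16,16,3))*np.array([1.,1.,1.])[None,None,:]
red_green = np.concatenate([red, green], axis=1)    # top row
blue_white = np.concatenate([blue, white], axis=1)  # bottom row
img = np.concatenate([red_green, blue_white], axis=0)
red_green.shape, blue_white.shape, img.shape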
((16, 32, 3), (16, 32, 3), (32, 32, 3))
plt.imshow(img)
plt.show()

puzler = JigsawPuzzle(n_tiles=4)
img_puzzle, label = puzler.make_puzzle(img)
permutation = puzler.labels2permutations[label.numpy()]
img_puzzle.shape, label, permutation
(TensorShape([4, 16, 16, 3]),
 <tf.Tensor: shape=(), dtype=int32, numpy=5>,
 (0, 3, 2, 1))
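A minimal sketch of the shuffling step, under two assumptions this page does not confirm: labels index into itertools.permutations order, and puzzle position i holds original tile perm[i]. The example above cannot tell this convention apart from its inverse, because (0, 3, 2, 1) is its own inverse. PERMUTATIONS and shuffle_tiles are hypothetical names, not part of JigsawPuzzle.

```python
import itertools
import numpy as np

PERMUTATIONS = list(itertools.permutations(range(4)))  # label -> permutation (assumed ordering)

def shuffle_tiles(tiles: np.ndarray, rng: np.random.Generator):
    """Pick a random label and reorder the tiles with the matching permutation."""
    label = int(rng.integers(len(PERMUTATIONS)))
    perm = PERMUTATIONS[label]
    # Assumed convention: puzzle position i holds original tile perm[i]
    return tiles[list(perm)], label

# Toy tiles: tile k is a 1x1x1 patch filled with k, so the resulting order is easy to read.
tiles = np.arange(4).reshape(4, 1, 1, 1)
rng = np.random.default_rng(0)
puzzle, label = shuffle_tiles(tiles, rng)
print(label, PERMUTATIONS[label], puzzle.ravel())
```

The label, not the permutation, is what a network would be trained to predict, which turns de-puzzling into an n_tiles!-way classification problem.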
fig, axes = plt.subplots(2,2)
for tile, ax in zip(img_puzzle, axes.ravel()):
    ax.imshow(tile)
plt.suptitle(permutation)
plt.show()

inverted_img, inverted_permutation = puzler.invert_puzzle(img_puzzle, permutation, return_full_image=True, return_permutation=True)
inverted_img.shape, inverted_permutation
(TensorShape([32, 32, 3]), array([0, 3, 2, 1]))
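Inverting a permutation p means finding q with q[p[i]] = i, which np.argsort(p) computes directly. The permutation (0, 3, 2, 1) only swaps positions 1 and 3, so it is its own inverse; that is why the inverted permutation printed above equals the original one. A sketch with a permutation that is not self-inverse (invert_permutation is an illustrative helper, not part of JigsawPuzzle):

```python
import numpy as np

def invert_permutation(perm):
    """Return q such that q[perm[i]] == i for every i, i.e. the inverse permutation."""
    return np.argsort(perm)

print(invert_permutation((0, 3, 2, 1)))  # [0 3 2 1], self-inverse
print(invert_permutation((1, 2, 3, 0)))  # [3 0 1 2]

# Undo a shuffle: if puzzle = tiles[perm], then puzzle[inverse] == tiles.
tiles = np.array(["A", "B", "C", "D"])
perm = (1, 2, 3, 0)
puzzle = tiles[list(perm)]                   # ['B' 'C' 'D' 'A']
restored = puzzle[invert_permutation(perm)]  # ['A' 'B' 'C' 'D']
```

Assuming the puzzle[i] = tiles[perm[i]] convention, indexing with the inverse is what tf.gather(img_puzzle, inverted_permutation) does in the next cell.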
fig, axes = plt.subplots(2,2)
for tile, ax in zip(tf.gather(img_puzzle, inverted_permutation), axes.ravel()):
    ax.imshow(tile)
plt.suptitle(permutation)
plt.show()

plt.imshow(inverted_img)
plt.title(permutation)
plt.show()
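Stitching tiles back into one image (as return_full_image=True and assemble_puzzle appear to do) amounts to reversing the tile split. A minimal NumPy sketch, assuming a square grid and row-major tile order; assemble_tiles is a hypothetical helper, not the class's method.

```python
import math
import numpy as np

def assemble_tiles(tiles: np.ndarray) -> np.ndarray:
    """Stitch (n_tiles, th, tw, C) tiles, given in row-major order, back into one image."""
    n, th, tw, c = tiles.shape
    grid = math.isqrt(n)  # assumes a square grid, as in the 2x2 example
    # (grid*grid, th, tw, C) -> (grid, grid, th, tw, C) -> (grid, th, grid, tw, C)
    img = tiles.reshape(grid, grid, th, tw, c).transpose(0, 2, 1, 3, 4)
    return img.reshape(grid * th, grid * tw, c)

# Four solid 16x16 tiles with values 0..3 give a 32x32 image with one value per quadrant.
tiles = np.stack([np.full((16, 16, 3), k, dtype=np.float32) for k in range(4)])
full = assemble_tiles(tiles)
print(full.shape)                                                      # (32, 32, 3)
print(full[0, 0, 0], full[0, 31, 0], full[31, 0, 0], full[31, 31, 0])  # 0.0 1.0 2.0 3.0
```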

Using it with a tf.data.Dataset

def sample_dataset():
    # Yield the same 32x32 four-quadrant test image 20 times (colours in [0, 1])
    red = np.ones(shape=(16,16,3))*np.array([1.,0,0])[None,None,:]
    green = np.ones(shape=(16,16,3))*np.array([0,1.,0])[None,None,:]
    blue = np.ones(shape=(16,16,3))*np.array([0,0,1.])[None,None,:]
    white = np.ones(shape=(16,16,3))*np.array([1.,1.,1.])[None,None,:]
    red_green = np.concatenate([red, green], axis=1)
    blue_white = np.concatenate([blue, white], axis=1)
    img = np.concatenate([red_green, blue_white], axis=0)
    for _ in range(20):
        yield img
dst = tf.data.Dataset.from_generator(sample_dataset,
                                     output_signature=(
                                        tf.TensorSpec(shape=(32,32,3), dtype=tf.float32)
                                     ))
puzler = JigsawPuzzle(n_tiles=4)
dst_puzzle = dst.map(puzler.make_puzzle)
for img in dst:
    print(img.shape)
    break
plt.imshow(img)
plt.show()
(32, 32, 3)

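for img, label in dst_puzzle:
    print(img.shape)
    break
# Map the integer label back to its permutation and undo the shuffle
perm = puzler.labels2permutations[label.numpy()]
img_ = puzler.invert_puzzle(img, perm)
# Shuffled tiles stitched into a single image
plt.imshow(puzler.assemble_puzzle(img))
plt.title(perm)
plt.show()
# De-puzzled image
plt.imshow(img_)
plt.show()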
(4, 16, 16, 3)
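The loop above checks de-puzzling on a single random label. A small property check can instead exercise every label: shuffle with each permutation, undo it with the inverse permutation, stitch the tiles, and compare with the input. This NumPy sketch reimplements the steps itself rather than calling JigsawPuzzle, assuming a square grid, row-major tiles, and that puzzle position i holds original tile perm[i]; split_tiles, assemble_tiles and round_trip_ok are hypothetical names.

```python
import itertools
import numpy as np

def split_tiles(img: np.ndarray, grid: int) -> np.ndarray:
    """(H, W, C) -> (grid*grid, H//grid, W//grid, C), tiles in row-major order."""
    h, w, c = img.shape
    th, tw = h // grid, w // grid
    return img.reshape(grid, th, grid, tw, c).transpose(0, 2, 1, 3, 4).reshape(-1, th, tw, c)

def assemble_tiles(tiles: np.ndarray, grid: int) -> np.ndarray:
    """Inverse of split_tiles: row-major tiles back to one (H, W, C) image."""
    _, th, tw, c = tiles.shape
    return tiles.reshape(grid, grid, th, tw, c).transpose(0, 2, 1, 3, 4).reshape(grid * th, grid * tw, c)

def round_trip_ok(img: np.ndarray, grid: int = 2) -> bool:
    """Shuffle with every permutation, undo it, and check that the original image comes back."""
    tiles = split_tiles(img, grid)
    for perm in itertools.permutations(range(grid * grid)):
        puzzle = tiles[list(perm)]                                  # position i holds tile perm[i]
        restored = assemble_tiles(puzzle[np.argsort(perm)], grid)  # invert, then stitch
        if not np.array_equal(restored, img):
            return False
    return True

rng = np.random.default_rng(0)
img = rng.random((32, 32, 3)).astype(np.float32)
print(round_trip_ok(img))  # True
```

A random image is used so that a mistake in tile order cannot be hidden by identical tiles.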
