Skip connection

An easily composable skip-connection layer. Skip connections help gradients flow through the network during backpropagation, and they are used in many modern architectures, such as ResNet.

Adding skip connections to a Keras model normally means leaving the Sequential model for the Functional API, because the same input has to feed two branches. A custom SkipConnection layer hides that branching inside a single layer, so it can be dropped straight into an easy-to-use Sequential model.
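For comparison, this is roughly what a residual ("add") skip connection looks like when written by hand with the Keras Functional API: the intermediate tensor is reused by two branches, which a plain Sequential stack cannot express. This is an illustrative sketch; the layer sizes are chosen to mirror the first example further down this page.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Functional-API residual block: the tensor `x` feeds both the main path
# and the skip path, so the graph is not a simple linear stack.
inputs = tf.keras.Input(shape=(50,))
x = layers.Dense(30)(inputs)          # features entering the skip block
main = layers.Dense(15)(x)            # main path: two stacked Dense layers
main = layers.Dense(30)(main)
outputs = layers.Add()([main, x])     # merge the paths: main(x) + x
model = tf.keras.Model(inputs, outputs)

sample = tf.random.normal((32, 50))
print(model(sample).shape)  # (32, 30)
```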



SkipConnection

 SkipConnection (main_path, skip_path=None, how='add', **kwargs)

Skip-connection layer that makes it easy to add a skip connection without moving away from the Sequential model.

           Type      Default  Details
main_path                     Layer (or set of layers) to apply to the input through the main path.
skip_path  NoneType  None     Layer (or set of layers) to apply to the input through the skip path. If None, the input is passed through unchanged.
how        str       add      How to combine the two paths: either "add" or "concat".
kwargs
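To make the parameters above concrete, here is a minimal sketch of how a layer with this signature could be written as a Keras Layer subclass. It is a hypothetical re-implementation (the name SkipConnectionSketch is ours), not the library's source. Treating skip_path=None as an identity shortcut and placing the main-path output first when concatenating are assumptions. The usage part passes a Dense layer as skip_path, the usual "projection shortcut": when the main path changes the feature width, the skip path must map the input to the same width before the two can be added.

```python
import tensorflow as tf
from tensorflow.keras import layers


class SkipConnectionSketch(layers.Layer):
    """Hypothetical skip-connection layer, for illustration only."""

    def __init__(self, main_path, skip_path=None, how="add", **kwargs):
        super().__init__(**kwargs)
        if how not in ("add", "concat"):
            raise ValueError(f'how must be "add" or "concat", got {how!r}')
        self.main_path = main_path
        # Assumption: skip_path=None means an identity shortcut.
        self.skip_path = skip_path
        self.how = how

    def call(self, inputs):
        main_out = self.main_path(inputs)
        skip_out = inputs if self.skip_path is None else self.skip_path(inputs)
        if self.how == "add":
            # Element-wise sum: both paths must produce the same shape.
            return main_out + skip_out
        # Assumption: main-path output first, joined on the feature axis.
        return tf.concat([main_out, skip_out], axis=-1)


# Usage: the main path widens 30 -> 40 features, so the skip path
# projects the input to 40 features as well to allow addition.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(50,)),
    layers.Dense(30),
    SkipConnectionSketch(
        main_path=tf.keras.Sequential([layers.Dense(15), layers.Dense(40)]),
        skip_path=layers.Dense(40),
    ),
])
print(model.output_shape)  # (None, 40)
```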
import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    layers.Dense(30, input_shape=(50,)),
    SkipConnection(main_path=tf.keras.Sequential([layers.Dense(15), layers.Dense(30)]))
])
assert model.output_shape[-1] == 30
model.summary()
Model: "sequential_21"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_29 (Dense)            (None, 30)                1530      
                                                                 
 skip_connection_9 (SkipConn  (None, 30)               945       
 ection)                                                         
                                                                 
=================================================================
Total params: 2,475
Trainable params: 2,475
Non-trainable params: 0
_________________________________________________________________
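The parameter counts in the summary above can be checked by hand: a Dense layer with n inputs and m units has n·m weights plus m biases. With skip_path=None and how="add", the skip path and the addition contribute no weights, which is consistent with the SkipConnection row showing exactly the main path's parameters.

```latex
\begin{aligned}
\text{Dense}(50 \to 30) &: 50 \cdot 30 + 30 = 1530 \\
\text{main path} &: (30 \cdot 15 + 15) + (15 \cdot 30 + 30) = 465 + 480 = 945 \\
\text{total} &: 1530 + 945 = 2475
\end{aligned}
```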
sample_input = tf.random.normal(shape=(32,50))
sample_output = model.predict(sample_input, verbose=0)
assert sample_output.shape == (32,30)
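With how="concat", the outputs of the two paths are concatenated along the feature axis instead of added, so the 30-feature input and the 30-feature main path give 60 output features. Concatenation adds no weights, so the parameter count is unchanged.

model = tf.keras.Sequential([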
    layers.Dense(30, input_shape=(50,)),
    SkipConnection(main_path=tf.keras.Sequential([layers.Dense(15), layers.Dense(30)]), how="concat")
])
assert model.output_shape[-1] == 60
model.summary()
Model: "sequential_23"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_32 (Dense)            (None, 30)                1530      
                                                                 
 skip_connection_10 (SkipCon  (None, 60)               945       
 nection)                                                        
                                                                 
=================================================================
Total params: 2,475
Trainable params: 2,475
Non-trainable params: 0
_________________________________________________________________
sample_input = tf.random.normal(shape=(32,50))
sample_output = model.predict(sample_input, verbose=0)
assert sample_output.shape == (32,60)