Adding skip connections to a Keras model normally means giving up the Sequential model in favor of the functional API. A custom SkipConnection layer avoids this: it wraps both paths in a single layer that can be used inside the easy-to-use Sequential model.
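To illustrate what moving away from the Sequential model looks like, here is a minimal sketch of a skip connection written with the Keras functional API. The layer sizes mirror the examples further down this page, and the variable names (`inputs`, `main`, `functional_model`) are chosen only for illustration.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Functional API: intermediate tensors are named so they can be reused later.
inputs = tf.keras.Input(shape=(50,))
x = layers.Dense(30)(inputs)

# Main path: two dense layers that return to the original width (30).
main = layers.Dense(15)(x)
main = layers.Dense(30)(main)

# Skip path: the untouched tensor `x` is added to the main-path output.
outputs = layers.Add()([main, x])

functional_model = tf.keras.Model(inputs, outputs)
```

The tensor `x` has to be referenced twice, which a plain stack of layers in a Sequential model cannot express on its own.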
SkipConnection
SkipConnection (main_path, skip_path=None, how='add', **kwargs)
Skip connection layer to easily introduce this architecture without moving away from the Sequential model.
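main_path: Layer (or set of layers) to apply to the input through the main path.
skip_path (NoneType, default None): Layer (or set of layers) to apply to the input through the skip path. When left as None, the input is carried over unchanged, which is consistent with the examples that follow.
how (str, default "add"): How to combine the two paths. Can be either "add" or "concat".
kwargs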
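The parameters above suggest how such a layer could work internally. The following is a minimal sketch, not the library's actual implementation: the class name `SketchSkipConnection` is hypothetical, and treating `skip_path=None` as the identity is an assumption based on the examples on this page. The built-in `Add` and `Concatenate` layers handle the combination.

```python
import tensorflow as tf
from tensorflow.keras import layers


class SketchSkipConnection(layers.Layer):
    """Illustrative stand-in; not the library's actual implementation."""

    def __init__(self, main_path, skip_path=None, how="add", **kwargs):
        super().__init__(**kwargs)
        if how not in ("add", "concat"):
            raise ValueError(f"how must be 'add' or 'concat', got {how!r}")
        self.main_path = main_path
        self.skip_path = skip_path  # assumption: None means identity
        self.how = how
        # Built-in merge layers do the actual combination.
        self.combine = layers.Add() if how == "add" else layers.Concatenate(axis=-1)

    def call(self, inputs):
        main = self.main_path(inputs)
        skip = inputs if self.skip_path is None else self.skip_path(inputs)
        return self.combine([main, skip])


# Usage: a projection on the skip path lets "add" work even when the
# main path changes the feature width (30 -> 10 here).
block = SketchSkipConnection(
    main_path=tf.keras.Sequential([layers.Dense(15), layers.Dense(10)]),
    skip_path=layers.Dense(10),
    how="add",
)
projected = block(tf.zeros((4, 30)))
```

With "add", both paths must produce the same shape, so a non-identity skip path is mainly useful when the main path changes the width. With "concat", the feature dimensions of the two paths are simply stacked.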
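import tensorflow as tf
from tensorflow.keras import layers

# Default how="add": the main path must return to the input width (30).
model = tf.keras.Sequential([
    layers.Dense(30, input_shape=(50,)),
    SkipConnection(main_path=tf.keras.Sequential([layers.Dense(15), layers.Dense(30)])),
])
assert model.output_shape[-1] == 30
model.summary()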
Model: "sequential_21"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense_29 (Dense)            (None, 30)                1530
 skip_connection_9 (SkipConn  (None, 30)               945
 ection)
=================================================================
Total params: 2,475
Trainable params: 2,475
Non-trainable params: 0
_________________________________________________________________
sample_input = tf.random.normal(shape=(32, 50))
sample_output = model.predict(sample_input, verbose=0)
assert sample_output.shape == (32, 30)
# how="concat": the 30 main-path features are joined with the 30 input features.
model = tf.keras.Sequential([
    layers.Dense(30, input_shape=(50,)),
    SkipConnection(main_path=tf.keras.Sequential([layers.Dense(15), layers.Dense(30)]), how="concat"),
])
assert model.output_shape[-1] == 60
model.summary()
Model: "sequential_23"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense_32 (Dense)            (None, 30)                1530
 skip_connection_10 (SkipCon  (None, 60)               945
 nection)
=================================================================
Total params: 2,475
Trainable params: 2,475
Non-trainable params: 0
_________________________________________________________________
sample_input = tf.random.normal(shape=(32, 50))
sample_output = model.predict(sample_input, verbose=0)
assert sample_output.shape == (32, 60)