Backpropagation through Combinatorial Algorithms: Identity with Projection Works

By Subham S. Sahoo*, Anselm Paulus*, Marin Vlastelica, Vít Musil, Volodymyr Kuleshov, Georg Martius

Embedding discrete solvers as differentiable layers equips deep learning models with combinatorial expressivity and discrete reasoning. However, these solvers have zero or undefined derivatives, necessitating a meaningful gradient replacement for effective learning. Prior methods smooth or relax solvers, or interpolate loss landscapes, often requiring extra solver calls, hyperparameters, or sacrificing performance. We propose a principled approach leveraging the geometry of the discrete solution space, treating the solver as a negative identity on the backward pass, with theoretical justification. Experiments show our hyperparameter-free method competes with complex alternatives across tasks like discrete sampling, graph matching, and image retrieval. We also introduce a generic regularization that prevents cost collapse and enhances robustness, replacing previous margin-based approaches.
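Concretely, the backward rule can be sketched in a few lines of NumPy (a minimal illustration of the idea, not the paper's implementation):

```python
import numpy as np

def solver_forward(scores, k):
    # A discrete "solver": hard top-k selection. The map is piecewise
    # constant, so its true Jacobian is zero almost everywhere.
    thresh = np.sort(scores, axis=1)[:, -k][:, None]
    return (scores >= thresh).astype(np.float32)

def solver_backward(upstream_grad):
    # Gradient replacement: treat the solver as an identity on the
    # backward pass (a negative identity when the solver minimizes a
    # cost), passing the upstream gradient straight through.
    return upstream_grad

scores = np.array([[0.1, 2.0, -1.0, 0.5]])
y = solver_forward(scores, k=2)              # -> [[0., 1., 0., 1.]]
g = solver_backward(np.ones_like(y))         # upstream gradient, unchanged
```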

Link to the paper on arXiv.

The `IdentitySubsetLayer` class implements backpropagation through the sampling of $k$-hot vectors using our proposed identity method.

import tensorflow as tf


def sample_gumbel_k(shape, eps=1e-20):
    # Standard Gumbel(0, 1) noise for the perturbed top-k sampling below.
    u = tf.random.uniform(shape, minval=0, maxval=1)
    return -tf.math.log(-tf.math.log(u + eps) + eps)


class IdentitySubsetLayer(tf.keras.layers.Layer):

    def __init__(self, k, **kwargs):
        super().__init__(**kwargs)
        self.k = k  # number of "hot" entries to select

    @tf.custom_gradient
    def identity_layer(self, logits_with_noise):
        # Forward pass: hard k-hot vector via top-k thresholding.
        threshold = tf.expand_dims(
            tf.nn.top_k(logits_with_noise, self.k, sorted=True)[0][:, -1], -1)
        y_map = tf.cast(tf.greater_equal(logits_with_noise, threshold), tf.float32)

        def custom_grad(dy):
            # Backward pass: identity -- pass the upstream gradient through
            # unchanged instead of the solver's (zero) Jacobian.
            return dy

        return y_map, custom_grad

    def call(self, logits):
        logits = logits + sample_gumbel_k(tf.shape(logits))

        # Project onto the zero-mean hyperplane
        logits = logits - tf.reduce_mean(logits, 1)[:, None]

        # Project onto the unit sphere
        logits = logits / tf.norm(logits, axis=1)[:, None]

        return self.identity_layer(logits)
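As a library-free illustration of what the forward pass computes, here is a NumPy sketch of the same pipeline (a minimal reimplementation with hypothetical helper names, not the repository's TensorFlow layer):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_gumbel(shape, eps=1e-20):
    # Standard Gumbel(0, 1) noise.
    u = rng.uniform(size=shape)
    return -np.log(-np.log(u + eps) + eps)

def identity_subset(logits, k):
    # Perturb the logits, project onto the zero-mean hyperplane and the
    # unit sphere, then take a hard top-k -- mirroring the layer above.
    z = logits + sample_gumbel(logits.shape)
    z = z - z.mean(axis=1, keepdims=True)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    thresh = np.sort(z, axis=1)[:, -k][:, None]
    return (z >= thresh).astype(np.float32)

y = identity_subset(rng.normal(size=(2, 5)), k=3)
# Each row of y is a 3-hot vector sampled from the perturbed logits.
```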

<sup>This notebook was adapted from here.</sup>

Citation

@inproceedings{
sahoo2023backpropagation,
title={Backpropagation through Combinatorial Algorithms: Identity with Projection Works},
author={Subham Sekhar Sahoo and Anselm Paulus and Marin Vlastelica and V{\'\i}t Musil and Volodymyr Kuleshov and Georg Martius},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=JZMR727O29}
}