Towards Conflict-Free Training [ICLR 2025 Spotlight]

Introduction

Optimizing multiple loss terms is a common challenge in various deep learning applications. This issue arises in Physics-Informed Neural Networks (PINNs), Multi-Task Learning (MTL), and Continual Learning (CL), where different loss gradients often conflict with each other. Such conflicts can lead to suboptimal convergence or even complete training failure.

Currently, most mainstream methods attempt to mitigate conflicts by adjusting loss weights, but they lack a unified theoretical framework, making them unstable in practical applications. To address this problem, we propose ConFIG (Conflict-Free Inverse Gradients), a novel optimization strategy that ensures stable and efficient multi-loss optimization.

ConFIG: Conflict-Free Inverse Gradients Method

Let’s consider an optimization procedure with a set of \(m\) individual loss functions, i.e., \(\{\mathcal{L}_1,\mathcal{L}_2,\cdots,\mathcal{L}_m\}\). Let \(\{\mathbf{g}_1,\mathbf{g}_2, \cdots, \mathbf{g}_m\}\) denote the individual gradients corresponding to each of the loss functions. A gradient-descent step with gradient \(\mathbf{g}_c\) will conflict with the decrease of \(\mathcal{L}_i\) if \(\mathbf{g}_i^\top \mathbf{g}_c\) is negative. Thus, to ensure that all losses decrease simultaneously along \(\mathbf{g}_c\), all \(m\) components of \([\mathbf{g}_1,\mathbf{g}_2,\cdots, \mathbf{g}_m]^\top\mathbf{g}_c\) should be positive. This condition is fulfilled by setting \(\mathbf{g}_c = [\mathbf{g}_1,\mathbf{g}_2,\cdots, \mathbf{g}_m]^{-\top} \mathbf{w}\), where \(\mathbf{w}=[w_1,w_2,\cdots,w_m]^\top\) is a vector with \(m\) positive components and \(M^{-\top}\) denotes the pseudoinverse of the transposed matrix \(M^{\top}\).
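To make this concrete, here is a minimal numerical sketch (our illustration, not the official implementation); the sizes, the random gradients, and the choice \(\mathbf{w}=\mathbf{1}\) are arbitrary assumptions for demonstration:

```python
import torch

torch.manual_seed(0)
m, n = 3, 8                          # 3 losses, 8 model parameters
G = torch.randn(n, m)                # columns are the gradients g_1, ..., g_m
w = torch.ones(m)                    # any vector with positive components

# g_c = [g_1, ..., g_m]^{-T} w, using the pseudoinverse of the transpose
g_c = torch.linalg.pinv(G.T) @ w

# All m inner products g_i^T g_c are positive (here exactly w, since random
# gradients with m < n are linearly independent), so a small step along g_c
# decreases every loss simultaneously to first order.
print(G.T @ g_c)
```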

Although a positive \(\mathbf{w}\) vector guarantees a conflict-free update direction for all losses, the specific value of \(w_i\) further influences the exact direction of \(\mathbf{g}_c\). To facilitate determining \(\mathbf{w}\), we reformulate \(\mathbf{g}_c\) as \(\mathbf{g}_c=k[\mathcal{U}(\mathbf{g}_1),\mathcal{U}(\mathbf{g}_2),\cdots, \mathcal{U}(\mathbf{g}_m)]^{-\top} \mathbf{\hat{w}}\), where \(\mathcal{U}(\mathbf{g}_i)=\mathbf{g}_i/(\|\mathbf{g}_i\|+\varepsilon)\) is a normalization operator and \(k>0\). Now, \(k\) controls the length of \(\mathbf{g}_c\), and the ratio of \(\mathbf{\hat{w}}\)'s components corresponds to the ratio of \(\mathbf{g}_c\)'s projections onto each loss-specific \(\mathbf{g}_i\), i.e., \(\|\mathbf{g}_c\|\mathcal{S}_c(\mathbf{g}_c,\mathbf{g}_i)\), where \(\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_j)=\mathbf{g}_i^\top\mathbf{g}_j/(\|\mathbf{g}_i\|\|\mathbf{g}_j\|+\varepsilon)\) is the cosine-similarity operator:

\[\frac{ \|\mathbf{g}_c\|\mathcal{S}_c(\mathbf{g}_c,\mathbf{g}_i) }{ \|\mathbf{g}_c\|\mathcal{S}_c(\mathbf{g}_c,\mathbf{g}_j) } = \frac{ \mathcal{S}_c(\mathbf{g}_c,\mathbf{g}_i) }{ \mathcal{S}_c(\mathbf{g}_c,\mathbf{g}_j) } = \frac{ \mathcal{S}_c(\mathbf{g}_c,k\mathcal{U}(\mathbf{g}_i)) }{ \mathcal{S}_c(\mathbf{g}_c,k\mathcal{U}(\mathbf{g}_j)) } = \frac{ [k\mathcal{U}(\mathbf{g}_i)]^\top \mathbf{g}_c }{ [k\mathcal{U}(\mathbf{g}_j)]^\top \mathbf{g}_c } = \frac{\hat{w}_i }{ \hat{w}_j } \quad \forall i,j \in [1,m].\]

We call \(\mathbf{\hat{w}}\) the direction weight. The projection length of \(\mathbf{g}_c\) onto each loss-specific gradient serves as an effective "learning rate" for each loss. Here, we choose \(\hat{w}_i=\hat{w}_j \ \forall i,j \in [1,m]\) to ensure a uniform decrease rate across all losses, as this choice was shown to yield a weak form of Pareto optimality for multi-task learning.

Meanwhile, we introduce an adaptive strategy for the length of \(\mathbf{g}_c\) rather than directly setting a fixed value of \(k\). We notice that the length of \(\mathbf{g}_c\) should increase when all loss-specific gradients point nearly in the same direction since it indicates a favorable direction for optimization. Conversely, when loss-specific gradients are close to opposing each other, the magnitude of \(\mathbf{g}_c\) should decrease. We realize this by rescaling the length of \(\mathbf{g}_c\) to the sum of the projection lengths of each loss-specific gradient on it, i.e., \(\|\mathbf{g}_c\|=\sum_{i=1}^m\|\mathbf{g}_i\|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)\).
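Since the projection length of each gradient onto \(\mathbf{g}_c\) depends only on the direction of \(\mathbf{g}_c\) (neglecting \(\varepsilon\)), this rescaling can be written directly in terms of the unit vector \(\mathcal{U}(\mathbf{g}_c)\):

\[\|\mathbf{g}_i\|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) = \|\mathbf{g}_i\|\frac{\mathbf{g}_i^\top\mathbf{g}_c}{\|\mathbf{g}_i\|\|\mathbf{g}_c\|} = \mathbf{g}_i^\top\mathcal{U}(\mathbf{g}_c), \qquad \text{so} \qquad \|\mathbf{g}_c\|=\sum_{i=1}^m \mathbf{g}_i^\top\mathcal{U}(\mathbf{g}_c),\]

which is exactly the scalar prefactor that appears in the final operator below.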

The procedures above are summarized in the Conflict-Free Inverse Gradients (ConFIG) operator \(\mathcal{G}\), and we correspondingly denote the final update gradient \(\mathbf{g}_c\) with \(\mathbf{g}_{\text{ConFIG}}\):

\[\mathbf{g}_u = \mathcal{U}\left[ [\mathcal{U}(\mathbf{g}_1),\mathcal{U}(\mathbf{g}_2),\cdots, \mathcal{U}(\mathbf{g}_m)]^{-\top} \mathbf{1}_m\right],\] \[\mathbf{g}_{\text{ConFIG}} =\mathcal{G}(\mathbf{g}_1,\mathbf{g}_2,\cdots,\mathbf{g}_m) = \left(\sum_{i=1}^m \mathbf{g}_i^\top \mathbf{g}_u \right)\mathbf{g}_u.\]

Here, \(\mathbf{1}_m=[1,1,\cdots,1]^\top\) is the all-ones vector with \(m\) components, corresponding to the uniform direction weights \(\hat{w}_i\) chosen above.
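The operator translates almost line-for-line into PyTorch. Below is a minimal sketch under our own naming and tensor-layout conventions (gradients stacked as rows); see the official repository for the reference implementation:

```python
import torch

def config_update(grads: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Minimal ConFIG sketch: `grads` stacks g_1, ..., g_m as rows (m, n)."""
    # U(g_i) = g_i / (||g_i|| + eps), applied row-wise
    units = grads / (grads.norm(dim=1, keepdim=True) + eps)
    # Conflict-free direction with uniform direction weights w_hat = 1_m:
    # g_u = U([U(g_1), ..., U(g_m)]^{-T} 1_m)
    direction = torch.linalg.pinv(units) @ torch.ones(grads.shape[0])
    g_u = direction / (direction.norm() + eps)
    # Adaptive magnitude: ||g_c|| = sum_i g_i^T g_u
    return (grads @ g_u).sum() * g_u

# Toy usage: two conflicting gradients still give a conflict-free update
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-0.5, 1.0])
g_config = config_update(torch.stack([g1, g2]))
print(g1 @ g_config, g2 @ g_config)  # both positive
```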

M-ConFIG: Momentum-Based Acceleration

In ConFIG and other gradient-based methods, computing separate gradients for each loss requires multiple backpropagation steps, increasing computational cost. To address this, we introduce M-ConFIG, which accelerates optimization by:

  • Replacing gradients with exponential moving average (EMA) gradients (momentums).
  • Updating only a subset of loss gradients per iteration while retaining momentum from previous updates (see the sketch after this list).
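Assuming the config_update function from the sketch above, one M-ConFIG step could look as follows; the function name, the round-robin schedule over single losses, and the hyperparameters are our illustrative choices, not the official API:

```python
import torch

# Assumes config_update() from the ConFIG sketch above.
def m_config_step(theta, loss_fns, momenta, step, beta=0.9):
    """One M-ConFIG step (sketch): backpropagate a single loss per iteration
    and feed per-loss EMA momenta, not fresh gradients, to ConFIG."""
    i = step % len(loss_fns)                  # round-robin over the losses
    (g_i,) = torch.autograd.grad(loss_fns[i](theta), theta)
    momenta[i] = beta * momenta[i] + (1 - beta) * g_i   # EMA ("momentum")
    return config_update(torch.stack(momenta))

# Toy usage: two quadratic losses pulling toward different targets
theta = torch.zeros(2, requires_grad=True)
loss_fns = [lambda t: ((t - torch.tensor([1.0, 0.0])) ** 2).sum(),
            lambda t: ((t - torch.tensor([0.0, 1.0])) ** 2).sum()]
momenta = [torch.zeros(2), torch.zeros(2)]
for step in range(200):
    g = m_config_step(theta, loss_fns, momenta, step)
    with torch.no_grad():
        theta -= 0.05 * g                     # one backward pass per step
```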

Experiments show that M-ConFIG significantly reduces computational overhead while maintaining—if not improving—training performance.

Results: Faster Convergence, Better Predictions

We tested ConFIG on multiple PDE tasks for Physics-Informed Neural Networks (PINNs), including:

  • Burgers' equation (1D fluid dynamics)
  • Schrödinger equation (quantum mechanics)
  • Navier-Stokes equations (2D and 3D fluid flow simulations)

Experiments show that ConFIG significantly reduces test errors and converges faster than existing methods across all tests. M-ConFIG further cuts computation time by around 50%, greatly improving the feasibility of practical applications.

Figure: Relative improvements of PINNs trained with different methods under the same wall time.
Figure: Test MSE of the Adam baseline, ConFIG, and M-ConFIG as functions of wall time.

We also evaluated ConFIG on Multi-Task Learning (MTL) using the CelebA dataset (40 tasks) and compared it against state-of-the-art MTL methods. ConFIG achieves the best or second-best performance across metrics, while M-ConFIG improves training efficiency, reducing computational cost while preserving accuracy. For larger task sets, M-ConFIG benefits from tuning its update frequency: increasing the number of momentum updates per iteration allows it to match or even surpass standard ConFIG while remaining efficient.

Figure: Test evaluation for the CelebA experiments.

Conclusion

ConFIG offers a mathematically rigorous and computationally efficient solution to gradient conflicts in multi-loss optimization. It is applicable to PINNs, MTL, and other deep learning tasks, enhancing both training stability and optimization efficiency. Our code is open source on GitHub, featuring runnable notebook examples and practical suggestions.