
Wasserstein GAN

Article snapshot taken from Wikipedia with Creative Commons Attribution-ShareAlike license.
Generative adversarial network variant

The Wasserstein Generative Adversarial Network (WGAN) is a variant of generative adversarial network (GAN) proposed in 2017 that aims to "improve the stability of learning, get rid of problems like mode collapse, and provide meaningful learning curves useful for debugging and hyperparameter searches".

Compared with the original GAN discriminator, the Wasserstein GAN discriminator provides a better learning signal to the generator. This makes training more stable when the generator is learning distributions in very high-dimensional spaces.

Motivation

The GAN game

The original GAN method is based on the GAN game, a zero-sum game with 2 players: generator and discriminator. The game is defined over a probability space $(\Omega, \mathcal{B}, \mu_{ref})$. The generator's strategy set is the set of all probability measures $\mu_G$ on $(\Omega, \mathcal{B})$, and the discriminator's strategy set is the set of measurable functions $D: \Omega \to [0, 1]$.

The objective of the game is

$$L(\mu_G, D) := \mathbb{E}_{x\sim\mu_{ref}}[\ln D(x)] + \mathbb{E}_{x\sim\mu_G}[\ln(1 - D(x))].$$

The generator aims to minimize it, and the discriminator aims to maximize it.

A basic theorem of the GAN game states that

Theorem (the optimal discriminator computes the Jensen–Shannon divergence): For any fixed generator strategy $\mu_G$, let the optimal reply be $D^* = \arg\max_D L(\mu_G, D)$. Then

$$D^*(x) = \frac{d\mu_{ref}}{d(\mu_{ref} + \mu_G)}, \qquad L(\mu_G, D^*) = 2D_{JS}(\mu_{ref}; \mu_G) - 2\ln 2,$$

where the derivative is the Radon–Nikodym derivative, and $D_{JS}$ is the Jensen–Shannon divergence.
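As a quick numerical sanity check of the theorem, the following Python sketch (assuming NumPy, and a hypothetical four-point space $\Omega$ so that the Radon–Nikodym derivative reduces to a ratio of probability masses) evaluates the objective at $D^*$ and compares it with $2D_{JS} - 2\ln 2$. It also confirms that randomly chosen discriminators never exceed that value. The distributions `p_ref` and `p_g` are made-up examples.

```python
import numpy as np

def gan_objective(p_ref, p_g, d):
    """L(mu_G, D) = E_{mu_ref}[ln D(x)] + E_{mu_G}[ln(1 - D(x))] on a finite space."""
    return float(np.sum(p_ref * np.log(d)) + np.sum(p_g * np.log(1.0 - d)))

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two strictly positive pmfs."""
    m = 0.5 * (p + q)
    return float(0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m)))

# Hypothetical example distributions on a 4-point space (strictly positive by assumption).
p_ref = np.array([0.1, 0.2, 0.3, 0.4])
p_g = np.array([0.25, 0.25, 0.25, 0.25])

# On a finite space, d mu_ref / d(mu_ref + mu_G) is just a ratio of masses.
d_star = p_ref / (p_ref + p_g)
bound = 2 * js_divergence(p_ref, p_g) - 2 * np.log(2)

print(gan_objective(p_ref, p_g, d_star), bound)  # agree up to rounding

# Any other discriminator does no better than D*.
rng = np.random.default_rng(0)
others = [gan_objective(p_ref, p_g, rng.uniform(0.01, 0.99, size=4)) for _ in range(1000)]
print(max(others) <= bound)
```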

Repeat the GAN game many times, each time with the generator moving first and the discriminator moving second. Each time the generator $\mu_G$ changes, the discriminator must adapt by approaching the ideal $D^*(x) = \frac{d\mu_{ref}}{d(\mu_{ref} + \mu_G)}$. Since we are really interested in $\mu_{ref}$, the discriminator function $D$ is by itself rather uninteresting: it merely keeps track of the likelihood ratio between the generator distribution and the reference distribution. At equilibrium, the discriminator just outputs $\frac{1}{2}$ constantly, having given up trying to perceive any difference.

Concretely, in the GAN game, let us fix a generator $\mu_G$ and improve the discriminator step by step, with $\mu_{D,t}$ being the discriminator at step $t$. Then we (ideally) have

$$L(\mu_G, \mu_{D,1}) \le L(\mu_G, \mu_{D,2}) \le \cdots \le \max_{\mu_D} L(\mu_G, \mu_D) = 2D_{JS}(\mu_{ref} \| \mu_G) - 2\ln 2,$$

so we see that the discriminator is actually lower-bounding $D_{JS}(\mu_{ref} \| \mu_G)$.
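The step-by-step improvement can be imitated on a toy problem. The Python sketch below is an illustration under assumptions, not the procedure used in practice: it parametrizes a discriminator on a hypothetical four-point space by one logit per point, $D = \sigma(s)$, runs plain gradient ascent on $L$, and records the objective at each step. Because $L$ is concave in the logits and the step size `lr` is small relative to the curvature, the recorded values increase monotonically and approach, but never exceed, $2D_{JS}(\mu_{ref}\|\mu_G) - 2\ln 2$.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def gan_objective(p_ref, p_g, d):
    """L(mu_G, D) on a finite space, for discriminator values d = D(x)."""
    return float(np.sum(p_ref * np.log(d)) + np.sum(p_g * np.log(1.0 - d)))

def improve_discriminator(p_ref, p_g, steps=2000, lr=1.0):
    """Gradient ascent on per-point logits s with D = sigmoid(s).

    Returns the objective L(mu_G, D_t) recorded before each update.
    """
    s = np.zeros_like(p_ref)  # D_1 = 1/2 everywhere
    history = []
    for _ in range(steps):
        d = sigmoid(s)
        history.append(gan_objective(p_ref, p_g, d))
        # dL/ds = p_ref * (1 - D) - p_G * D, computed pointwise
        s = s + lr * (p_ref * (1.0 - d) - p_g * d)
    return history

# Hypothetical strictly positive distributions on a 4-point space.
p_ref = np.array([0.1, 0.2, 0.3, 0.4])
p_g = np.array([0.25, 0.25, 0.25, 0.25])
m = 0.5 * (p_ref + p_g)
js = 0.5 * np.sum(p_ref * np.log(p_ref / m)) + 0.5 * np.sum(p_g * np.log(p_g / m))
bound = 2 * js - 2 * np.log(2)  # max_D L(mu_G, D) according to the theorem

history = improve_discriminator(p_ref, p_g)
print(history[0], history[10], history[-1], bound)
```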

Wasserstein distance

Thus, we see that the point of the discriminator is mainly as a critic to provide feedback for the generator, about "how far it is from perfection", where "far" is defined as Jensen–Shannon divergence.

Naturally, this raises the possibility of using a different criterion of farness. There are many possible divergences to choose from, such as the f-divergence family, which would give the f-GAN.

The Wasserstein GAN is obtained by using the Wasserstein metric, which satisfies a "dual representation theorem" that renders it highly efficient to compute:

Theorem (Kantorovich–Rubinstein duality): When the probability space $\Omega$ is a metric space, then for any fixed $K > 0$,

$$W_1(\mu, \nu) = \frac{1}{K} \sup_{\|f\|_L \le K} \left( \mathbb{E}_{x\sim\mu}[f(x)] - \mathbb{E}_{y\sim\nu}[f(y)] \right),$$

where $\|\cdot\|_L$ is the Lipschitz norm.

A proof can be found in the main page on Wasserstein metric.
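In one dimension, $W_1$ between two empirical distributions with the same number of samples can be computed exactly by matching sorted samples, which makes the duality easy to probe. The Python sketch below (NumPy assumed; the Gaussian samples and the particular critics are arbitrary choices for illustration) computes $W_1$ this way and checks that several 1-Lipschitz functions $f$ (the case $K = 1$) each give a value $\mathbb{E}_{\mu}[f] - \mathbb{E}_{\nu}[f]$ no larger than $W_1$.

```python
import numpy as np

def wasserstein1_1d(xs, ys):
    """W_1 between two equal-size empirical distributions on the real line.

    In 1D the optimal coupling matches sorted samples, so W_1 = mean_i |x_(i) - y_(i)|.
    """
    if xs.shape != ys.shape:
        raise ValueError("this shortcut assumes the same number of samples")
    return float(np.mean(np.abs(np.sort(xs) - np.sort(ys))))

def dual_value(f, xs, ys):
    """E_{x~mu}[f(x)] - E_{y~nu}[f(y)] for a candidate critic f."""
    return float(np.mean(f(xs)) - np.mean(f(ys)))

rng = np.random.default_rng(0)
xs = rng.normal(0.0, 1.0, size=1000)  # samples from mu (arbitrary choice)
ys = rng.normal(2.0, 0.5, size=1000)  # samples from nu (arbitrary choice)
w1 = wasserstein1_1d(xs, ys)

# A few 1-Lipschitz critics (K = 1); by duality each one lower-bounds W_1.
critics = {
    "-x": lambda x: -x,
    "sin": np.sin,
    "|x|": np.abs,
    "-tanh": lambda x: -np.tanh(x),
}
for name, f in critics.items():
    print(f"{name:>6}: {dual_value(f, xs, ys):.4f} <= {w1:.4f}")
```

For a pure shift `ys = xs + c` with `c > 0`, the critic $f(x) = -x$ attains the supremum, and both sides equal `c`.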

Definition

By the Kantorovich–Rubinstein duality, the definition of Wasserstein GAN is clear:

A Wasserstein GAN game is defined by a probability space $(\Omega, \mathcal{B}, \mu_{ref})$, where $\Omega$ is a metric space, and a constant $K > 0$.

There are 2 players: generator and discriminator (also called "critic").

The generator's strategy set is the set of all probability measures $\mu_G$ on $(\Omega, \mathcal{B})$.

The discriminator's strategy set is the set of measurable functions of type $D: \Omega \to \mathbb{R}$ with bounded Lipschitz norm: $\|D\|_L \le K$.

The Wasserstein GAN game is a zero-sum game, with objective function

$$L_{WGAN}(\mu_G, D) := \mathbb{E}_{x\sim\mu_G}[D(x)] - \mathbb{E}_{x\sim\mu_{ref}}[D(x)].$$

The generator goes first, and the discriminator goes second. The generator aims to minimize the objective, and the discriminator aims to maximize it:

$$\min_{\mu_G} \max_D L_{WGAN}(\mu_G, D).$$

By the Kantorovich–Rubinstein duality, for any generator strategy $\mu_G$, the optimal reply by the discriminator is a $D^*$ such that

$$L_{WGAN}(\mu_G, D^*) = K \cdot W_1(\mu_G, \mu_{ref}).$$

Consequently, if the discriminator is good, the generator is constantly pushed to minimize $W_1(\mu_G, \mu_{ref})$, and the optimal strategy for the generator is just $\mu_G = \mu_{ref}$, as it should be.
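For a concrete instance of this statement, take (as an assumption for illustration) $\mu_{ref} = \mathcal{N}(0, 1)$ and $\mu_G = \mathcal{N}(\theta, 1)$ with $\theta > 0$. A pure shift by $\theta$ has $W_1 = \theta$, and the $K$-Lipschitz critic $D^*(x) = Kx$ attains $L_{WGAN} = K\theta$, so it is an optimal reply. The Python sketch below (NumPy assumed; `K`, `theta`, and the sample size are arbitrary) estimates $L_{WGAN}$ by Monte Carlo for this critic and for a weaker $K$-Lipschitz critic, $K\tanh(x)$.

```python
import numpy as np

def wgan_objective(critic, fake, real):
    """Monte Carlo estimate of L_WGAN(mu_G, D) = E_{mu_G}[D(x)] - E_{mu_ref}[D(x)]."""
    return float(np.mean(critic(fake)) - np.mean(critic(real)))

K, theta, n = 2.0, 1.5, 200_000
rng = np.random.default_rng(1)
real = rng.normal(0.0, 1.0, size=n)    # samples from mu_ref = N(0, 1) (toy assumption)
fake = rng.normal(theta, 1.0, size=n)  # samples from mu_G = N(theta, 1) (toy assumption)

def optimal_critic(x):
    return K * x              # K-Lipschitz; optimal here because theta > 0

def weaker_critic(x):
    return K * np.tanh(x)     # also K-Lipschitz, but not optimal

print(wgan_objective(optimal_critic, fake, real), K * theta)  # close, up to sampling error
print(wgan_objective(weaker_critic, fake, real))
```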

Comparison with GAN

In the Wasserstein GAN game, the discriminator provides a better gradient than in the GAN game.

Consider for example a game on the real line where both $\mu_G$ and $\mu_{ref}$ are Gaussian. Then the optimal Wasserstein critic $D_{WGAN}$ and the optimal GAN discriminator $D$ are plotted below:

[Figure: The optimal Wasserstein critic $D_{WGAN}$ and the optimal GAN discriminator $D$ for a fixed reference distribution $\mu_{ref}$ and generator distribution $\mu_G$. Both $D_{WGAN}$ and $D$ are scaled down to fit the plot.]

For a fixed discriminator, the generator needs to minimize the following objectives:

  • For GAN, $\mathbb{E}_{x\sim\mu_G}[\ln(1 - D(x))]$.
  • For Wasserstein GAN, $\mathbb{E}_{x\sim\mu_G}[D_{WGAN}(x)]$.

Let $\mu_G$ be parametrized by $\theta$, with density $\rho_{\mu_G}$. Then we can perform stochastic gradient descent by using two unbiased estimators of the gradient:

$$\nabla_\theta \mathbb{E}_{x\sim\mu_G}[\ln(1 - D(x))] = \mathbb{E}_{x\sim\mu_G}[\ln(1 - D(x)) \cdot \nabla_\theta \ln \rho_{\mu_G}(x)],$$

$$\nabla_\theta \mathbb{E}_{x\sim\mu_G}[D_{WGAN}(x)] = \mathbb{E}_{x\sim\mu_G}[D_{WGAN}(x) \cdot \nabla_\theta \ln \rho_{\mu_G}(x)],$$

where we used the log-derivative trick (the score-function estimator).
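The two estimators rely on the identity $\nabla_\theta \mathbb{E}_{x\sim\mu_G}[f(x)] = \mathbb{E}_{x\sim\mu_G}[f(x)\,\nabla_\theta \ln \rho_{\mu_G}(x)]$, usually called the log-derivative or score-function trick. The Python sketch below (NumPy assumed) checks it for a hypothetical generator $\mu_G = \mathcal{N}(\theta, 1)$, whose score is $x - \theta$, using $f(x) = x^2$ so that the exact gradient is $2\theta$. For contrast it also computes the pathwise (reparameterization) estimator, which writes $x = \theta + z$ and differentiates $f$ directly.

```python
import numpy as np

def score_function_grad(f, theta, n, rng):
    """Estimate d/dtheta E_{x~N(theta,1)}[f(x)] as the mean of f(x) * (x - theta).

    (x - theta) is d/dtheta ln rho(x) for a unit-variance Gaussian density rho.
    """
    x = theta + rng.normal(size=n)
    return float(np.mean(f(x) * (x - theta)))

def reparam_grad(f_prime, theta, n, rng):
    """Estimate the same gradient by writing x = theta + z and differentiating f."""
    z = rng.normal(size=n)
    return float(np.mean(f_prime(theta + z)))

rng = np.random.default_rng(0)
theta, n = 1.5, 200_000

# f(x) = x^2, so E[f(x)] = theta^2 + 1 and the exact gradient is 2 * theta = 3.0.
g_score = score_function_grad(lambda x: x ** 2, theta, n, rng)
g_path = reparam_grad(lambda x: 2 * x, theta, n, rng)
print(g_score, g_path)
```

Both estimates should land near 3.0; the score-function form needs only $f$ and the score of the density, while the pathwise form needs the derivative of $f$.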

[Figure: The same plot, but with the GAN discriminator $D$ replaced by $\ln(1 - D)$ (and scaled down to fit the plot).]

As shown, the generator in GAN is motivated to let its $\mu_G$ "slide down the peak" of $\ln(1 - D(x))$. The same holds for the generator in Wasserstein GAN.

For Wasserstein GAN, $D_{WGAN}$ has gradient of magnitude 1 almost everywhere, while for GAN, $\ln(1 - D)$ is flat in the middle and steep elsewhere. As a result, the variance of the estimator in GAN is usually much larger than that in Wasserstein GAN. See also Figure 3 of.
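The difference in gradients can be made quantitative in the Gaussian example, under assumptions chosen for tractability: $\mu_{ref} = \mathcal{N}(0,1)$, $\mu_G = \mathcal{N}(\theta,1)$ with $\theta > 0$, both discriminators held fixed at their optimal values, and the generator using the objectives listed above. Then $D^*(x) = \sigma(\theta^2/2 - \theta x)$, with $\sigma$ the logistic function, and $D_{WGAN}(x) = x$ is an optimal 1-Lipschitz critic. The Python sketch below (NumPy assumed) computes $\frac{d}{d\theta}\mathbb{E}_{x\sim\mu_G}[\ln(1 - D^*(x))]$ by numerical integration and compares it with $\frac{d}{d\theta}\mathbb{E}_{x\sim\mu_G}[D_{WGAN}(x)] = 1$. As $\theta$ grows, the GAN gradient decays toward zero while the Wasserstein one stays at 1.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gan_generator_grad(theta, num=20001):
    """d/dtheta of E_{x~N(theta,1)}[ln(1 - D*(x))], with D* held fixed.

    Toy assumption: mu_ref = N(0, 1), mu_G = N(theta, 1), theta > 0, and
    D*(x) = sigmoid(theta**2 / 2 - theta * x) is the optimal GAN discriminator.
    Writing x = theta + z, the gradient is E_z[(d/dx) ln(1 - D*(x))], and
    (d/dx) ln(1 - D*(x)) = theta * D*(x). The expectation over z ~ N(0, 1)
    is approximated by a Riemann sum on a fine grid.
    """
    z = np.linspace(-12.0, 12.0, num)
    dz = z[1] - z[0]
    pdf = np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)
    x = theta + z
    d_star = sigmoid(0.5 * theta ** 2 - theta * x)
    return float(np.sum(theta * d_star * pdf) * dz)

def wgan_generator_grad(theta):
    """d/dtheta of E_{x~N(theta,1)}[x] for the critic D_WGAN(x) = x; independent of theta."""
    return 1.0

for theta in [0.5, 2.0, 5.0, 10.0]:
    print(f"theta={theta:5.1f}  GAN: {gan_generator_grad(theta):.2e}  WGAN: {wgan_generator_grad(theta):.1f}")
```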

The problem with $D_{JS}$ is much more severe in actual machine learning situations. Consider training a GAN to generate ImageNet, a collection of photos of size 256-by-256. The space of all such photos is $\mathbb{R}^{256^2}$, and the distribution of ImageNet pictures, $\mu_{ref}$, concentrates on a manifold of much lower dimension in it. Consequently, any generator strategy $\mu_G$ would almost surely be entirely disjoint from $\mu_{ref}$, making $D_{JS}(\mu_G \| \mu_{ref}) = \ln 2$, its maximal value, no matter how close the two distributions are. Thus, a good discriminator can almost perfectly distinguish $\mu_{ref}$ from $\mu_G$, as well as from any $\mu_G'$ close to $\mu_G$. Thus, the gradient $\nabla_{\mu_G} L(\mu_G, D) \approx 0$, creating no learning signal for the generator.

Detailed theorems can be found in.
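The effect of disjoint supports is easiest to see in the smallest possible case, used here purely as an illustration: two point masses at $0$ and at $\delta$. Their supports are disjoint for every $\delta \neq 0$, so the Jensen–Shannon divergence stays at its maximal value $\ln 2$ no matter how small $\delta$ is, whereas $W_1 = |\delta|$ shrinks smoothly as the distributions approach each other. The Python sketch below (NumPy assumed) computes both on a finite support, using $W_1 = \int |F_p - F_q|\,dx$ for distributions on the real line.

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between pmfs on a common finite support."""
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # terms with a = 0 contribute nothing
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein1_discrete(support, p, q):
    """W_1 on the real line: integral of |F_p - F_q| between consecutive support points."""
    order = np.argsort(support)
    x, p, q = support[order], p[order], q[order]
    cdf_gap = np.abs(np.cumsum(p) - np.cumsum(q))[:-1]
    return float(np.sum(cdf_gap * np.diff(x)))

for shift in [0.01, 0.5, 3.0]:
    support = np.array([0.0, shift])
    p = np.array([1.0, 0.0])  # point mass at 0 (stand-in for mu_ref)
    q = np.array([0.0, 1.0])  # point mass at `shift` (stand-in for mu_G)
    print(shift, js_divergence(p, q), wasserstein1_discrete(support, p, q))
```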

Training Wasserstein GANs

Training the generator in Wasserstein GAN is just gradient descent, the same as in GAN (or most deep learning methods), but training the discriminator is different, as the discriminator is now restricted to have bounded Lipschitz norm. There are several methods for this.

Upper-bounding the Lipschitz norm

Let the discriminator function $D$ be implemented by a multilayer perceptron:

$$D = D_n \circ D_{n-1} \circ \cdots \circ D_1,$$

where $D_i(x) = h(W_i x)$, and $h: \mathbb{R} \to \mathbb{R}$ is a fixed activation function (applied elementwise) with $\sup_x |h'(x)| \le 1$. For example, the hyperbolic tangent function $h = \tanh$ satisfies the requirement.

Then, for any $x$, let $x_i = (D_i \circ D_{i-1} \circ \cdots \circ D_1)(x)$, with $x_0 = x$. By the chain rule,

$$dD(x) = \mathrm{diag}(h'(W_n x_{n-1})) \cdot W_n \cdot \mathrm{diag}(h'(W_{n-1} x_{n-2})) \cdot W_{n-1} \cdots \mathrm{diag}(h'(W_1 x)) \cdot W_1 \cdot dx.$$

Thus, the Lipschitz norm of $D$ is upper-bounded by

$$\|D\|_L \le \sup_x \left\| \mathrm{diag}(h'(W_n x_{n-1})) \cdot W_n \cdot \mathrm{diag}(h'(W_{n-1} x_{n-2})) \cdot W_{n-1} \cdots \mathrm{diag}(h'(W_1 x)) \cdot W_1 \right\|_s,$$

where $\|\cdot\|_s$ is the operator norm of the matrix, that is, its largest singular value (also called the spectral norm). For a general matrix this is not the same as the spectral radius; the two coincide for normal matrices, such as symmetric ones.

Since $\sup_x |h'(x)| \le 1$, we have $\|\mathrm{diag}(h'(W_i x_{i-1}))\|_s = \max_j |h'((W_i x_{i-1})_j)| \le 1$, and consequently the upper bound

$$\|D\|_L \le \prod_{i=1}^{n} \|W_i\|_s.$$

Thus, if we can upper-bound the operator norms $\|W_i\|_s$ of each matrix, we can upper-bound the Lipschitz norm of $D$.
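The bound can be checked numerically. The Python sketch below builds a hypothetical three-layer network with $h = \tanh$ and random Gaussian weights (no biases, matching the form $D_i(x) = h(W_i x)$), evaluates the chain-rule Jacobian at many random inputs, and confirms that its largest singular value never exceeds $\prod_i \|W_i\|_s$. NumPy's `np.linalg.norm(W, 2)` returns the largest singular value of a 2-D array; the layer sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical network R^4 -> R^8 -> R^8 -> R^1 with layers D_i(x) = tanh(W_i x).
weights = [rng.normal(size=(8, 4)), rng.normal(size=(8, 8)), rng.normal(size=(1, 8))]

def forward(x):
    """Evaluate D(x) = D_n(...D_1(x)) with h = tanh."""
    for w in weights:
        x = np.tanh(w @ x)
    return x

def jacobian(x):
    """Chain rule: diag(h'(W_n x_{n-1})) W_n ... diag(h'(W_1 x)) W_1, with h' = 1 - tanh^2."""
    jac = np.eye(x.shape[0])
    for w in weights:
        pre = w @ x
        jac = np.diag(1.0 - np.tanh(pre) ** 2) @ w @ jac
        x = np.tanh(pre)
    return jac

# Product of the spectral norms (largest singular values) of the weight matrices.
bound = float(np.prod([np.linalg.norm(w, 2) for w in weights]))

# Largest Jacobian operator norm seen over random inputs; it should not exceed the bound.
worst = max(np.linalg.norm(jacobian(rng.normal(size=4)), 2) for _ in range(1000))
print(worst, "<=", bound)
```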

Weight clipping

For any $m \times l$ matrix $W$, let $c = \max_{i,j} |W_{i,j}|$. Then

$$\|W\|_s^2 = \sup_{\|x\|_2 = 1} \|Wx\|_2^2 = \sup_{\|x\|_2 = 1} \sum_i \Big(\sum_j W_{i,j} x_j\Big)^2 \le \sup_{\|x\|_2 = 1} \sum_i \Big(\sum_j W_{i,j}^2\Big)\Big(\sum_j x_j^2\Big) \le c^2 m l,$$

where the first inequality is the Cauchy–Schwarz inequality. So by clipping all entries of $W$ to within some interval $[-c, c]$, we can bound $\|W\|_s \le c\sqrt{ml}$.

This is the weight clipping method, proposed by the original paper.
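A minimal illustration, assuming NumPy and arbitrary sizes: after clipping a random $m \times l$ matrix entrywise to $[-c, c]$, its spectral norm is at most its Frobenius norm, which is at most $c\sqrt{ml}$. The value `c = 0.01` is only an example here.

```python
import numpy as np

def clip_weights(w, c):
    """Weight clipping: force every entry of w into the interval [-c, c]."""
    return np.clip(w, -c, c)

rng = np.random.default_rng(0)
m, l, c = 64, 32, 0.01  # example sizes and clipping threshold (arbitrary)
w = clip_weights(rng.normal(size=(m, l)), c)

spectral = np.linalg.norm(w, 2)        # largest singular value
frobenius = np.linalg.norm(w, "fro")
print(spectral, "<=", frobenius, "<=", c * np.sqrt(m * l))
```

In a training loop, the clipping would typically be applied to every discriminator weight matrix after each optimizer step.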

Spectral normalization

The spectral norm (largest singular value) can be efficiently estimated by the following power-iteration algorithm:

INPUT matrix $W$ and initial guess $x$

Iterate $x \mapsto \frac{W^T W x}{\|W^T W x\|_2}$ to convergence $x^*$. This is the eigenvector of $W^T W$ with eigenvalue $\|W\|_s^2$, that is, the top right singular vector of $W$.

RETURN $x^*, \|W x^*\|_2$

By reassigning $W_i \leftarrow \frac{W_i}{\|W_i\|_s}$ after each update of the discriminator, we can upper bound $\|W_i\|_s \le 1$, and thus upper bound $\|D\|_L$.

The algorithm can be further accelerated by memoization: at step $t$, store $x_i^*(t)$. Then at step $t+1$, use $x_i^*(t)$ as the initial guess for the algorithm. Since $W_i(t+1)$ is very close to $W_i(t)$, $x_i^*(t)$ is also close to $x_i^*(t+1)$, which allows rapid convergence.

This is the spectral normalization method.
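A minimal sketch of spectral normalization in NumPy follows. It is a simplification under stated assumptions, not the exact procedure of the spectral normalization paper: it runs power iteration on $W^T W$, which converges to the top right singular vector even when $W$ is not square, divides $W$ by the resulting estimate, and then reuses the stored vector as a warm start after a small, made-up perturbation of $W$, mimicking the memoization trick. The helper name `spectral_norm_estimate` and the iteration counts are illustrative choices.

```python
import numpy as np

def spectral_norm_estimate(w, x, iters):
    """Power iteration on W^T W from the initial guess x.

    Returns (x_star, ||W x_star||_2); x_star approximates the top right singular
    vector of W, so ||W x_star||_2 approximates ||W||_s.
    """
    for _ in range(iters):
        y = w.T @ (w @ x)
        x = y / np.linalg.norm(y)
    return x, float(np.linalg.norm(w @ x))

rng = np.random.default_rng(0)
w = rng.normal(size=(16, 8))

# Cold start from a random guess.
x_star, sigma = spectral_norm_estimate(w, rng.normal(size=8), iters=500)
w_sn = w / sigma  # spectrally normalized weights: ||w_sn||_s is (approximately) 1
print(sigma, np.linalg.norm(w, 2))

# Warm start: after a small (hypothetical) update, a few iterations from x_star suffice.
w_next = w + 1e-3 * rng.normal(size=w.shape)
_, sigma_next = spectral_norm_estimate(w_next, x_star, iters=5)
print(sigma_next, np.linalg.norm(w_next, 2))
```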

Gradient penalty

Instead of strictly bounding $\|D\|_L$, we can simply add a "gradient penalty" term for the discriminator, of the form

$$\mathbb{E}_{x\sim\hat\mu}\left[(\|\nabla D(x)\|_2 - a)^2\right],$$

where $\hat\mu$ is a fixed distribution used to estimate how much the discriminator has violated the Lipschitz norm requirement. The discriminator, in attempting to minimize the new loss function, would naturally bring $\|\nabla D(x)\|_2$ close to $a$ wherever $\hat\mu$ places its mass, thus making $\|D\|_L \approx a$.

This is the gradient penalty method.
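The penalty only needs input gradients of the critic. The NumPy sketch below assumes a hypothetical one-hidden-layer critic $D(x) = v \cdot \tanh(Wx)$ whose input gradient can be written by hand, takes $a = 1$, and, as one concrete choice of $\hat\mu$, samples points on segments between real and generated samples, which is the choice made in the WGAN-GP paper (Gulrajani et al.). The penalty weight `lam` and the toy data are likewise assumed values. The resulting loss is $-L_{WGAN}$ plus the weighted penalty, so minimizing it maximizes $L_{WGAN}$ while discouraging gradient norms away from $a$.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical critic D(x) = v . tanh(W x) on R^2 with random parameters.
W = rng.normal(size=(16, 2))
v = rng.normal(size=16)

def critic(x):
    """Critic values for a batch x of shape (n, 2)."""
    return np.tanh(x @ W.T) @ v

def critic_input_grad(x):
    """Rows are grad_x D(x) = W^T (v * (1 - tanh(W x)^2)) for each row of x."""
    return ((1.0 - np.tanh(x @ W.T) ** 2) * v) @ W

def gradient_penalty(x_hat, a=1.0):
    """Monte Carlo estimate of E_{x~mu_hat}[(||grad D(x)||_2 - a)^2]."""
    norms = np.linalg.norm(critic_input_grad(x_hat), axis=1)
    return float(np.mean((norms - a) ** 2))

# Toy real and generated batches (assumed); mu_hat = random points on segments between them.
real = rng.normal(0.0, 1.0, size=(256, 2))
fake = rng.normal(3.0, 1.0, size=(256, 2))
eps = rng.uniform(size=(256, 1))
x_hat = eps * real + (1.0 - eps) * fake

lam = 10.0  # penalty weight (assumed value)
l_wgan = np.mean(critic(fake)) - np.mean(critic(real))  # the critic wants this large
critic_loss = -l_wgan + lam * gradient_penalty(x_hat)   # so it minimizes -L_WGAN + penalty
print(critic_loss)
```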


References

  1. Arjovsky, Martin; Chintala, Soumith; Bottou, Léon (2017-07-17). "Wasserstein Generative Adversarial Networks". International Conference on Machine Learning. PMLR: 214–223.
  2. Weng, Lilian (2019-04-18). "From GAN to WGAN". arXiv:1904.08994.
  3. Nowozin, Sebastian; Cseke, Botond; Tomioka, Ryota (2016). "f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization". Advances in Neural Information Processing Systems. 29. Curran Associates, Inc. arXiv:1606.00709.
  4. Arjovsky, Martin; Bottou, Léon (2017-01-01). "Towards Principled Methods for Training Generative Adversarial Networks". arXiv:1701.04862.
  5. Miyato, Takeru; Kataoka, Toshiki; Koyama, Masanori; Yoshida, Yuichi (2018-02-16). "Spectral Normalization for Generative Adversarial Networks". arXiv:1802.05957.
  6. Gulrajani, Ishaan; Ahmed, Faruk; Arjovsky, Martin; Dumoulin, Vincent; Courville, Aaron C (2017). "Improved Training of Wasserstein GANs". Advances in Neural Information Processing Systems. 30. Curran Associates, Inc.

Notes

  1. In practice, the generator would never be able to reach perfect imitation, and so the discriminator would have motivation for perceiving the difference, which allows it to be used for other tasks, such as performing ImageNet classification without supervision.
  2. This is not how it is really done in practice, since $\nabla_\theta \ln \rho_{\mu_G}(x)$ is in general intractable, but it is theoretically illuminating.