PyTorchDiscriminator
- class PyTorchDiscriminator(n_features=1, n_out=1)
Discriminator based on PyTorch.
- Parameters
n_features (int) – Dimension of input data vector.
n_out (int) – Dimension of the discriminator’s output vector.
- Raises
NameError – PyTorch is not installed.
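The `NameError` above points to an optional-dependency check at construction time. The following is a minimal sketch of that general pattern, not this class's actual implementation: `require_module` and `OptionalTorchComponent` are hypothetical names introduced only for illustration, and the `backend` argument exists solely so the check can be exercised without PyTorch.

```python
import importlib.util


def require_module(name):
    """Raise NameError if the top-level module `name` cannot be found."""
    # find_spec returns None when no importable module has this name.
    if importlib.util.find_spec(name) is None:
        raise NameError(f"{name} is not installed")


class OptionalTorchComponent:
    """Hypothetical stand-in for a class that needs an optional backend."""

    def __init__(self, n_features=1, n_out=1, backend="torch"):
        # Fail early, before any model is built, if the backend is missing.
        require_module(backend)
        self.n_features = n_features
        self.n_out = n_out
```

With a check of this kind, a caller can wrap construction in `try: ... except NameError:` and report a clear installation hint instead of failing later with an obscure error.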
Attributes
Get discriminator
Methods
PyTorchDiscriminator.get_label(x[, detach]) – Get data sample labels, i.e. true or fake.
Compute gradient penalty for discriminator optimization.
PyTorchDiscriminator.load_model(load_dir) – Load discriminator model.
PyTorchDiscriminator.loss(x, y[, weights]) – Loss function.
PyTorchDiscriminator.save_model(snapshot_dir) – Save discriminator model.
Set seed.
PyTorchDiscriminator.train(data, weights[, …]) – Perform one training step w.r.t.
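The `loss(x, y[, weights])` signature suggests a loss between discriminator outputs `x` and target labels `y`, optionally weighted per sample. This page does not state which loss is used, so the sketch below assumes binary cross-entropy, a common choice for a true/fake discriminator. It is written in plain Python rather than PyTorch so that it runs without extra dependencies, and `weighted_bce` is a hypothetical helper, not part of this class.

```python
import math


def weighted_bce(x, y, weights=None, eps=1e-12):
    """Binary cross-entropy between predictions x and targets y.

    Assumed behavior (not confirmed by the source): without weights the
    per-sample losses are averaged; with weights they are summed after
    multiplying each loss by its weight.
    """
    if not x or len(x) != len(y):
        raise ValueError("x and y must be non-empty and of equal length")
    per_sample = []
    for p, t in zip(x, y):
        # Clamp predictions away from 0 and 1 so log() stays finite.
        p = min(max(p, eps), 1.0 - eps)
        per_sample.append(-(t * math.log(p) + (1 - t) * math.log(1.0 - p)))
    if weights is None:
        return sum(per_sample) / len(per_sample)
    if len(weights) != len(per_sample):
        raise ValueError("weights must have one entry per sample")
    return sum(w * l for w, l in zip(weights, per_sample))
```

For example, `weighted_bce([0.5, 0.5], [1, 0])` evaluates to ln 2, the loss of a discriminator that cannot tell true from fake samples. Passing probabilities as `weights` turns the sum into an expectation over a discrete distribution.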