PyTorchDiscriminator

class PyTorchDiscriminator(n_features=1, n_out=1)

Discriminator implemented with PyTorch.

Parameters
  • n_features (int) – Dimension of input data vector.

  • n_out (int) – Dimension of the discriminator’s output vector.
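The two parameters fix the shape of the mapping: a length-n_features input vector goes in, and n_out values come out. The sketch below shows that role in plain Python. It assumes a single linear layer followed by a sigmoid, so each output lies in (0, 1) and can be read as the probability that the sample is real. `TinyDiscriminator` is a hypothetical stand-in. The actual class wraps a PyTorch network whose architecture is not documented on this page.

```python
import math
import random


def sigmoid(z: float) -> float:
    """Logistic function mapping a real score into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


class TinyDiscriminator:
    """Hypothetical stand-in: one linear layer + sigmoid, n_features -> n_out."""

    def __init__(self, n_features: int = 1, n_out: int = 1, seed: int = 0):
        rng = random.Random(seed)
        self.n_features = n_features
        self.n_out = n_out
        # One weight row and one bias per output unit.
        self.weights = [[rng.uniform(-0.1, 0.1) for _ in range(n_features)]
                        for _ in range(n_out)]
        self.biases = [0.0] * n_out

    def forward(self, x):
        """Map a length-n_features vector to n_out values in (0, 1)."""
        if len(x) != self.n_features:
            raise ValueError(f"expected {self.n_features} features, got {len(x)}")
        return [sigmoid(sum(w_i * x_i for w_i, x_i in zip(row, x)) + b)
                for row, b in zip(self.weights, self.biases)]


d = TinyDiscriminator(n_features=3, n_out=1)
print(d.forward([0.2, -0.5, 1.0]))  # a single value in (0, 1)
```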

Raises

NameError – PyTorch is not installed.

Attributes

PyTorchDiscriminator.discriminator_net

Get the discriminator network.

Methods

PyTorchDiscriminator.get_label(x[, detach])

Get data sample labels, i.e. true or fake.
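One way to read "true or fake" labels is to score each sample with the discriminator and compare that score to a cut-off. The sketch below does this for a batch. It assumes a logistic scoring function and a 0.5 threshold. `get_labels`, `w`, `b` and `threshold` are hypothetical names. It does not model the `detach` flag, which in PyTorch presumably controls whether the returned tensor stays attached to the autograd graph.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def get_labels(batch, w, b, threshold=0.5):
    """Score each sample (hypothetical logistic model) and map scores to labels."""
    probs = [sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) for x in batch]
    hard = ["true" if p >= threshold else "fake" for p in probs]
    return probs, hard


probs, hard = get_labels([[2.0], [-2.0]], w=[1.0], b=0.0)
print(hard)  # ['true', 'fake']
```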

PyTorchDiscriminator.gradient_penalty(x[, …])

Compute the gradient penalty for discriminator optimization.
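The `[, …]` in the signature hides further parameters that are not listed here, so `lam`, `k` and `c` below are hypothetical names. A common form of gradient penalty, used in WGAN-GP and DRAGAN, penalizes how far the norm of the discriminator's input gradient is from a target k:

lam * mean((||grad_x D(x_hat)||_2 - k)^2)

Here x_hat is a perturbed sample. This sketch assumes a DRAGAN-style uniform perturbation and a logistic discriminator D(x) = sigmoid(w.x + b), which lets it write the gradient analytically. A PyTorch version would get the gradient from `torch.autograd.grad` instead. Whether this class uses exactly this form is an assumption.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gradient_penalty(batch, w, b, lam=5.0, k=1.0, c=0.0, rng=None):
    """Hypothetical gradient penalty for D(x) = sigmoid(w.x + b)."""
    rng = rng or random.Random(0)
    total = 0.0
    for x in batch:
        # Perturb each component by U[0, c]; c = 0 disables the perturbation.
        x_hat = [xi + rng.uniform(0.0, c) for xi in x]
        d = sigmoid(sum(wi * xi for wi, xi in zip(w, x_hat)) + b)
        # Input gradient of sigmoid(w.x + b) is d * (1 - d) * w.
        grad = [d * (1.0 - d) * wi for wi in w]
        norm = math.sqrt(sum(g * g for g in grad))
        total += (norm - k) ** 2
    return lam * total / len(batch)


# At x = 0 with w = [3, 4]: d = 0.5, gradient = [0.75, 1.0], norm = 1.25.
print(gradient_penalty([[0.0, 0.0]], w=[3.0, 4.0], b=0.0, lam=1.0, k=0.0))  # 1.5625
```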

PyTorchDiscriminator.load_model(load_dir)

Load the discriminator model.

PyTorchDiscriminator.loss(x, y[, weights])

Loss function.
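A natural loss for a discriminator whose outputs x are probabilities and whose targets y are 0/1 labels is weighted binary cross-entropy:

-sum_i w_i * [y_i * log(x_i) + (1 - y_i) * log(1 - x_i)]

The sketch below implements that formula in plain Python. Two points are assumptions, not confirmed by this page: that the class uses this loss, and that it sums rather than averages over samples. `weighted_bce` and `eps` are hypothetical names.

```python
import math


def weighted_bce(x, y, weights=None, eps=1e-12):
    """Hypothetical weighted binary cross-entropy, summed over samples."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    if weights is None:
        weights = [1.0] * len(x)
    total = 0.0
    for p, t, w in zip(x, y, weights):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -w * (t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total


print(weighted_bce([0.5], [1.0]))  # log(2), about 0.693
```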

PyTorchDiscriminator.save_model(snapshot_dir)

Save the discriminator model.
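`save_model(snapshot_dir)` and `load_model(load_dir)` both take a directory, not a file path. The sketch below shows that directory-based round trip with the standard library. It stores parameters as JSON under the hypothetical file name `discriminator.json`. A PyTorch implementation would more typically call `torch.save(model.state_dict(), path)` and `model.load_state_dict(torch.load(path))`. The file name and format this class actually writes are not documented here.

```python
import json
import tempfile
from pathlib import Path

FILENAME = "discriminator.json"  # hypothetical file name


def save_model(params, snapshot_dir):
    """Write parameters into snapshot_dir, creating it if needed."""
    path = Path(snapshot_dir)
    path.mkdir(parents=True, exist_ok=True)
    (path / FILENAME).write_text(json.dumps(params))
    return path / FILENAME


def load_model(load_dir):
    """Read parameters back from a directory written by save_model."""
    return json.loads((Path(load_dir) / FILENAME).read_text())


with tempfile.TemporaryDirectory() as tmp:
    params = {"weights": [[0.1, -0.2]], "biases": [0.0]}
    save_model(params, tmp)
    restored = load_model(tmp)

print(restored == params)  # True
```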

PyTorchDiscriminator.set_seed(seed)

Set the random seed.
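Seeding makes runs reproducible: random initialization and sampling repeat exactly when the seed is the same. This stdlib sketch uses a hypothetical `init_weights` helper to illustrate that. The real method presumably seeds PyTorch's generator, for instance with `torch.manual_seed`, but that is an assumption.

```python
import random


def init_weights(n_features, seed):
    """Hypothetical initializer using a private, explicitly seeded generator."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 0.1) for _ in range(n_features)]


a = init_weights(4, seed=42)
b = init_weights(4, seed=42)
c = init_weights(4, seed=7)
print(a == b)  # True: same seed, same weights
```

The sketch uses a private `random.Random` instance rather than the module-level generator, so that seeding does not affect other code in the same process.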

PyTorchDiscriminator.train(data, weights[, …])

Perform one training step with respect to the discriminator's parameters.
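The sketch below shows one such training step. It assumes `data` is a pair (real batch, generated batch) and `weights` a matching pair of per-sample weights. Real samples are labeled 1, generated ones 0, and the step is one gradient-descent update on weighted binary cross-entropy for a logistic discriminator. The data/weights layout, `train_step` and the learning rate `lr` are hypothetical. The optional parameters hidden by `[, …]` are not modeled.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def train_step(data, weights, w, b, lr=0.1):
    """One gradient-descent step; returns (new_w, new_b, loss before the step)."""
    real, fake = data
    real_w, fake_w = weights
    # Real samples get target 1, generated samples target 0.
    samples = ([(x, 1.0, sw) for x, sw in zip(real, real_w)]
               + [(x, 0.0, sw) for x, sw in zip(fake, fake_w)])
    grad_w = [0.0] * len(w)
    grad_b = 0.0
    loss = 0.0
    for x, t, sw in samples:
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
        loss += -sw * (t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
        err = sw * (p - t)  # derivative of weighted BCE w.r.t. the logit
        grad_w = [g + err * xi for g, xi in zip(grad_w, x)]
        grad_b += err
    new_w = [wi - lr * g for wi, g in zip(w, grad_w)]
    new_b = b - lr * grad_b
    return new_w, new_b, loss


data = ([[1.0]], [[-1.0]])
w1, b1, loss0 = train_step(data, ([1.0], [1.0]), w=[0.0], b=0.0)
_, _, loss1 = train_step(data, ([1.0], [1.0]), w=w1, b=b1)
print(w1, b1, loss1 < loss0)  # [0.1] 0.0 True
```

The second call reports a lower loss than the first. That is the expected effect of a single step with a small enough learning rate on this separable toy data.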