qiskit.aqua.components.neural_networks.PyTorchDiscriminator
class PyTorchDiscriminator(n_features=1, n_out=1) [source]
Discriminator based on PyTorch.
- Parameters
n_features (int) – Dimension of the input data vector.
n_out (int) – Dimension of the discriminator’s output vector.
- Raises
MissingOptionalLibraryError – PyTorch is not installed.
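The constructor raises MissingOptionalLibraryError when PyTorch is missing, so it can help to check for the dependency before building the discriminator. The sketch below uses only the standard library. The helper name `torch_available` is hypothetical and is not part of Qiskit Aqua.

```python
import importlib.util


def torch_available() -> bool:
    """Return True if the optional PyTorch dependency can be imported (hypothetical helper)."""
    # find_spec locates the package without importing it.
    return importlib.util.find_spec("torch") is not None


if torch_available():
    print("PyTorch found; PyTorchDiscriminator can be constructed.")
else:
    print("PyTorch missing; install it before constructing PyTorchDiscriminator.")
```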
Methods
__init__([n_features, n_out]) – Initialize the discriminator (n_features: int).
get_label(x[, detach]) – Get data sample labels, i.e. true or fake.
gradient_penalty(x[, lambda_, k, c]) – Compute the gradient penalty for discriminator optimization.
load_model(load_dir) – Load the discriminator model.
loss(x, y[, weights]) – Loss function.
save_model(snapshot_dir) – Save the discriminator model.
set_seed(seed) – Set the random seed.
train(data, weights[, penalty, …]) – Perform one training step w.r.t. the discriminator’s parameters.
Attributes
discriminator_net – Get the discriminator.
property discriminator_net
Get the discriminator.
- Returns
Discriminator object.
- Return type
object
get_label(x, detach=False) [source]
Get data sample labels, i.e. true or fake.
- Parameters
x (Union(numpy.ndarray, torch.Tensor)) – Discriminator input, i.e. data sample.
detach (bool) – If True, detach the output from the torch tensor variable (optional).
- Returns
Discriminator output, i.e. data label.
- Return type
torch.Tensor
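This numpy sketch mirrors only the shape contract of get_label: a batch of shape (batch, n_features) goes in and labels of shape (batch, n_out) come out. It assumes a sigmoid output, so each label lies in (0, 1) and values near 1 read as "real". The class `ToyDiscriminator` and its single linear layer are hypothetical stand-ins, not the network PyTorchDiscriminator actually builds.

```python
import numpy as np


class ToyDiscriminator:
    """Hypothetical stand-in that mirrors the (n_features, n_out) contract."""

    def __init__(self, n_features=1, n_out=1, seed=0):
        rng = np.random.default_rng(seed)
        self.n_features = n_features
        self.n_out = n_out
        # A single linear layer; the real class builds its own PyTorch network.
        self.w = rng.normal(scale=0.1, size=(n_features, n_out))
        self.b = np.zeros(n_out)

    def get_label(self, x):
        # Accept lists or arrays and reshape to (batch, n_features).
        x = np.asarray(x, dtype=float).reshape(-1, self.n_features)
        # Sigmoid squashes each logit into (0, 1); near 1 means "real".
        return 1.0 / (1.0 + np.exp(-(x @ self.w + self.b)))


disc = ToyDiscriminator(n_features=2, n_out=1)
labels = disc.get_label([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
print(labels.shape)  # (3, 1)
```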
gradient_penalty(x, lambda_=5.0, k=0.01, c=1.0) [source]
Compute the gradient penalty for discriminator optimization.
- Parameters
x (numpy.ndarray) – Generated data sample.
lambda_ (float) – Gradient penalty coefficient 1.
k (float) – Gradient penalty coefficient 2.
c (float) – Gradient penalty coefficient 3.
- Returns
Gradient penalty.
- Return type
torch.Tensor
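This page names three coefficients but does not say how they combine. One common form consistent with them is lambda_ · mean((‖∇_z D(z)‖₂ − k)²), evaluated at perturbed points z = x + c·u with u drawn uniformly from [0, 1). Treat this as an assumption and check the source for the exact formula. The numpy sketch below uses a hypothetical linear-sigmoid discriminator D so that the input gradient has a closed form.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def gradient_penalty(x, w, b, lambda_=5.0, k=0.01, c=1.0, rng=None):
    """Sketch: lambda_ * mean((||grad_z D(z)||_2 - k)**2), z = x + c * U[0, 1).

    D(z) = sigmoid(z @ w + b) is a hypothetical linear-sigmoid discriminator.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    # Perturb each sample by uniform noise scaled by c.
    z = x + c * rng.random(x.shape)
    s = sigmoid(z @ w + b)  # D(z), shape (batch,)
    # Closed-form input gradient of the sigmoid: s * (1 - s) * w.
    grad = (s * (1.0 - s))[:, None] * w  # shape (batch, n_features)
    norms = np.linalg.norm(grad, axis=1)
    # Penalize deviations of the gradient norm from k.
    return lambda_ * np.mean((norms - k) ** 2)


x = np.array([[0.2, 0.4], [0.6, 0.8]])
gp = gradient_penalty(x, w=np.array([0.5, -0.3]), b=0.1, rng=np.random.default_rng(0))
print(gp)
```

With all-zero weights the gradient vanishes, so the penalty reduces to lambda_ · k².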
load_model(load_dir) [source]
Load the discriminator model.
- Parameters
load_dir (str) – File with the stored PyTorch discriminator model to be loaded.
loss(x, y, weights=None) [source]
Loss function.
- Parameters
x (torch.Tensor) – Discriminator output.
y (torch.Tensor) – Label of the data point.
weights (torch.Tensor) – Data weights.
- Returns
Loss w.r.t. the generated data points.
- Return type
torch.Tensor
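Here x holds discriminator outputs in (0, 1) and y holds true/fake labels, so one natural reading of this loss is binary cross-entropy with the weights scaling each sample. The numpy sketch below makes two assumptions that this page does not confirm. With weights, the weighted terms are summed. Without weights, the per-sample terms are averaged.

```python
import numpy as np


def bce_loss(x, y, weights=None, eps=1e-12):
    """Binary cross-entropy between outputs x in (0, 1) and labels y in {0, 1}."""
    x = np.clip(np.asarray(x, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    y = np.asarray(y, dtype=float)
    per_sample = -(y * np.log(x) + (1.0 - y) * np.log(1.0 - x))
    if weights is None:
        return per_sample.mean()  # assumed: plain average without weights
    # Assumed: weighted terms are summed, so weights act like probabilities.
    return np.sum(np.asarray(weights, dtype=float) * per_sample)


print(bce_loss([0.9, 0.2], [1.0, 0.0]))
print(bce_loss([0.9, 0.2], [1.0, 0.0], weights=[0.5, 0.5]))
```

Equal weights that sum to 1 turn the weighted sum into the plain mean, so the two calls agree.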
save_model(snapshot_dir) [source]
Save the discriminator model.
- Parameters
snapshot_dir (str) – Directory path for saving the model.
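save_model takes a directory, whereas load_model takes the path of a stored model file. A round trip therefore saves into a directory and then loads the file written inside it. The numpy sketch below illustrates that pattern with np.savez. The helpers `save_params` and `load_params` and the file name `discriminator.npz` are hypothetical. The real methods store a PyTorch model, not numpy arrays.

```python
import os
import tempfile

import numpy as np


def save_params(params, snapshot_dir, filename="discriminator.npz"):
    """Write parameters into a file inside snapshot_dir and return its path."""
    os.makedirs(snapshot_dir, exist_ok=True)
    path = os.path.join(snapshot_dir, filename)
    np.savez(path, **params)
    return path


def load_params(path):
    """Read parameters back from the file written by save_params."""
    with np.load(path) as data:
        return {name: data[name] for name in data.files}


with tempfile.TemporaryDirectory() as tmp:
    params = {"w": np.array([0.5, -0.3]), "b": np.array([0.1])}
    path = save_params(params, tmp)  # directory in, file path out
    restored = load_params(path)     # file path in
    print(np.allclose(restored["w"], params["w"]))  # True
```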
set_seed(seed) [source]
Set the random seed.
- Parameters
seed (int) – Seed value.
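Fixing the seed makes randomly initialized parameters reproducible across runs. In the numpy sketch below, `init_weights` is a hypothetical helper. With PyTorch, the analogous global call is `torch.manual_seed(seed)`, although this page does not state which call set_seed uses internally.

```python
import numpy as np


def init_weights(n_features, n_out, seed):
    """Hypothetical initializer: the same seed gives the same starting parameters."""
    rng = np.random.default_rng(seed)
    return rng.normal(scale=0.1, size=(n_features, n_out))


a = init_weights(2, 1, seed=42)
b = init_weights(2, 1, seed=42)
c = init_weights(2, 1, seed=7)
print(np.array_equal(a, b))  # True
print(np.array_equal(a, c))  # False
```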
train(data, weights, penalty=False, quantum_instance=None, shots=None) [source]
Perform one training step w.r.t. the discriminator’s parameters.
- Parameters
data (Iterable) – Data batch.
weights (Iterable) – Data sample weights.
penalty (bool) – Whether to apply the penalty function to the loss function. Ignored if no penalty function is defined.
quantum_instance (QuantumInstance) – Used to run a quantum network. Ignored for a classical network.
shots (Optional[int]) – Number of shots for hardware or QASM execution. Ignored for a classical network.
- Returns
Dictionary with the discriminator loss and updated parameters.
- Return type
dict
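The sketch below shows one discriminator training step, and every layout in it is an assumption. It takes data and weights to be pairs: (real batch, generated batch) and (real weights, generated weights). It labels real samples 1 and generated samples 0, and it returns a dictionary with the keys "loss" and "params". The linear-sigmoid model, plain gradient descent and the learning rate `lr` are also hypothetical. The actual class uses its own PyTorch network and optimizer.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def train_step(params, data, weights, lr=0.1, eps=1e-12):
    """One gradient-descent step for a hypothetical linear-sigmoid discriminator.

    data = (real_batch, generated_batch); weights = (real_probs, generated_probs).
    """
    w, b = params
    real, fake = (np.asarray(d, dtype=float).reshape(len(d), -1) for d in data)
    p_real, p_fake = (np.asarray(p, dtype=float) for p in weights)
    n_real = len(real)

    x = np.vstack([real, fake])
    y = np.concatenate([np.ones(n_real), np.zeros(len(fake))])  # real -> 1, generated -> 0
    sw = np.concatenate([p_real, p_fake])

    # Weighted binary cross-entropy, summed separately over real and generated samples.
    s = sigmoid(x @ w + b)
    per_sample = -(y * np.log(s + eps) + (1.0 - y) * np.log(1.0 - s + eps))
    err_real = np.sum(sw[:n_real] * per_sample[:n_real])
    err_fake = np.sum(sw[n_real:] * per_sample[n_real:])

    # Gradient of err_real + err_fake w.r.t. the logits is sw * (s - y).
    g = sw * (s - y)
    w = w - lr * (x.T @ g)
    b = b - lr * np.sum(g)

    # Report the loss at the pre-update parameters, averaged over both parts.
    return {"loss": 0.5 * (err_real + err_fake), "params": [w, b]}


data = ([0.9, 0.8], [0.1, 0.2])
weights = ([0.5, 0.5], [0.5, 0.5])
ret = train_step([np.zeros(1), 0.0], data, weights)
ret = train_step(ret["params"], data, weights)
print(ret["loss"])
```

Averaging the real and generated errors keeps the reported loss on the same scale as a single binary cross-entropy term.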