Source code for qiskit.aqua.components.neural_networks.discriminative_network

# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2019, 2020.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

""" Discriminative Quantum or Classical Neural Networks."""

from abc import ABC, abstractmethod


class DiscriminativeNetwork(ABC):
    """
    Base class for discriminative Quantum or Classical Neural Networks.

    Every method declared here is abstract, so the class cannot be
    instantiated directly; a concrete network has to override all of them.
    """

    @abstractmethod
    def __init__(self) -> None:
        """
        Set the shared attributes to their empty defaults.

        Implementations should initialize the module but raise an exception
        if a required component of the module is not available.
        """
        super().__init__()
        # Placeholder values only; presumably overwritten by the concrete
        # network once its structure is known -- confirm against subclasses.
        self._num_parameters = 0
        self._num_qubits = 0
        self._bounds = []

    @abstractmethod
    def set_seed(self, seed):
        """
        Set seed.

        Args:
            seed (int): seed

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()

    @abstractmethod
    def get_label(self, x):
        """
        Apply quantum/classical neural network to the given input sample and
        compute the respective data label.

        Args:
            x (Discriminator): input, i.e. data sample.

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()

    @abstractmethod
    def loss(self, x, y, weights=None):
        """
        Loss function used for optimization.

        Args:
            x (Discriminator): output.
            y (Label): the data point
            weights (numpy.ndarray): Data weights.

        Returns:
            Loss w.r.t. the generated data points.

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()

    @abstractmethod
    def train(self, data, weights, penalty=False, quantum_instance=None, shots=None):
        """
        Perform one training step w.r.t. the discriminator's parameters.

        Args:
            data (numpy.ndarray): Data batch.
            weights (numpy.ndarray): Data sample weights.
            penalty (bool): Indicate whether or not penalty function
                is applied to the loss function. Ignored if no penalty
                function defined.
            quantum_instance (QuantumInstance): used to run Quantum network.
                Ignored for a classical network.
            shots (int): Number of shots for hardware or qasm execution.
                Ignored for classical network

        Returns:
            dict: with Discriminator loss and updated parameters.

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()