# Source code for qiskit.aqua.components.multiclass_extensions.multiclass_extension

# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2018, 2020.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

""" Base class for multiclass extension """

from typing import Optional, List, Callable
from abc import ABC, abstractmethod
from .estimator import Estimator


class MulticlassExtension(ABC):
    """Base class for multiclass extension.

    Subclasses must implement :meth:`__init__`, :meth:`train`, :meth:`test`
    and :meth:`predict`. The estimator class and the parameters used with it
    are stored on the instance through :meth:`set_estimator`.
    """

    @abstractmethod
    def __init__(self) -> None:
        """Initialize the extension with no estimator configured.

        This method should initialize the module and use an exception if a
        component of the module is not available.
        """
        super().__init__()
        # Both attributes are placeholders until set_estimator() is called.
        self.estimator_cls = None
        self.params = []

    def set_estimator(self,
                      estimator_cls: Callable[[List], 'Estimator'],
                      params: Optional[List] = None) -> None:
        """Called internally to set :class:`Estimator` and parameters.

        Args:
            estimator_cls: An :class:`Estimator` class
            params: Parameters for the estimator; ``None`` is stored as an
                empty list
        """
        # 'Estimator' is quoted in the signature so the annotation is not
        # evaluated while the class body is being executed.
        self.estimator_cls = estimator_cls
        # A fresh list per call avoids sharing one default between instances;
        # a caller-supplied list is stored as-is (not copied).
        self.params = params if params is not None else []

    @abstractmethod
    def train(self, x, y):
        """Train multiple estimators, each for distinguishing a pair of classes.

        Args:
            x (numpy.ndarray): input points
            y (numpy.ndarray): input labels

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    @abstractmethod
    def test(self, x, y):
        """Test multiple estimators, each for distinguishing a pair of classes.

        Args:
            x (numpy.ndarray): input points
            y (numpy.ndarray): input labels

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    @abstractmethod
    def predict(self, x):
        """Apply multiple estimators for prediction.

        Args:
            x (numpy.ndarray): input points

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError()