# Source code for nlpaug.base_augmenter

import math
import random
import numpy as np
import pandas as pd
from multiprocessing.dummy import Pool as ThreadPool

from nlpaug.util import Action, Method, WarningException, WarningName, WarningCode, WarningMessage


class Augmenter:
    """Base class that routes an augmentation request to an action-specific hook.

    ``augment`` is the entry point. It validates the input, calls ``clean`` and
    then dispatches to one of ``insert``, ``substitute``, ``swap``, ``delete``,
    ``crop`` or ``split`` depending on ``self.action``. All of these hooks (and
    ``clean``) raise ``NotImplementedError`` here and are meant to be overridden.
    """

    def __init__(self, name, method, action, aug_min, aug_max, aug_p=0.1, device='cpu', 
        include_detail=False, verbose=0):
        """
        :param name: Name of the augmenter (only used by ``__str__`` in this class).
        :param method: Must be one of ``Method.getall()``.
        :param action: Must be one of ``Action.getall()``; selects the hook that
            ``augment`` calls.
        :param aug_min: Lower bound for the augmentation count (ignored when falsy).
        :param aug_max: Upper bound for the augmentation count (ignored when falsy).
        :param aug_p: Fraction of the input size used to derive the augmentation count.
        :param device: Stored as-is; not read by this class.
        :param include_detail: Stored as-is; not read by this class.
        :param verbose: When greater than 0, validation warnings are output.
        :raises ValueError: If ``method`` or ``action`` is not an allowed value.
        """
        self.name = name
        self.action = action
        self.method = method
        self.aug_min = aug_min
        self.aug_max = aug_max
        self.aug_p = aug_p
        self.device = device
        self.verbose = verbose
        self.include_detail = include_detail

        self.parent_change_seq = 0

        self._validate_augmenter(method, action)

    @classmethod
    def _validate_augmenter(cls, method, action):
        """Raise ``ValueError`` unless ``method`` and ``action`` are allowed values."""
        if method not in Method.getall():
            raise ValueError(
                'Method must be one of {} while {} is passed'.format(Method.getall(), method))

        if action not in Action.getall():
            raise ValueError(
                'Action must be one of {} while {} is passed'.format(Action.getall(), action))

    def augment(self, data, n=1, num_thread=1):
        """
        :param object/list data: Data for augmentation. It can be list of data (e.g. list 
            of string or numpy) or single element (e.g. string or numpy). Numpy format only
            supports audio or spectrogram data. For text data, only support string or
            list of string.
        :param int n: Default is 1. Number of augmented outputs requested. Will be forced
            to 1 if input is list of data
        :param int num_thread: Number of thread for data augmentation. Use this option 
            when you are using CPU and n is larger than 1
        :return: Augmented data. An empty list (or empty numpy array for numpy input) when
            the input is None or empty. When no output could be produced, the input itself
            is returned (wrapped in a list for a single element). For a single element the
            result is truncated to at most ``n`` items.

        >>> augmented_data = aug.augment(data)

        """
        max_retry_times = 3  # max loop times of n to generate expected number of outputs
        is_batch = isinstance(data, list)
        aug_num = 1 if is_batch else n
        expected_output_num = len(data) if is_batch else aug_num

        # TODO: Handle multiple exceptions
        for exception in self._validate_augment(data):
            if isinstance(exception, WarningException):
                if self.verbose > 0:
                    exception.output()

                # Return empty value per data type: numpy input gets an empty
                # array, everything else (str, list, other) an empty list.
                if isinstance(data, np.ndarray):
                    return np.array([])
                return []

        action_fx = None
        clean_data = self.clean(data)
        if self.action == Action.INSERT:
            action_fx = self.insert
        elif self.action == Action.SUBSTITUTE:
            action_fx = self.substitute
        elif self.action == Action.SWAP:
            action_fx = self.swap
        elif self.action == Action.DELETE:
            action_fx = self.delete
        elif self.action == Action.CROP:
            action_fx = self.crop
        elif self.action == Action.SPLIT:
            action_fx = self.split

        # Each attempt regenerates the results from scratch; stop as soon as an
        # attempt yields at least the expected number of outputs.
        for _attempt in range(max_retry_times + 1):
            augmented_results = []

            # By design, it is one-to-many
            if self.__class__.__name__ in ['LambadaAug']:
                augmented_results = action_fx(clean_data, n=n)
            # PyTorch's augmenter
            elif self.__class__.__name__ in ['AbstSummAug', 'BackTranslationAug', 'ContextualWordEmbsAug', 'ContextualWordEmbsForSentenceAug']:
                for _ in range(aug_num):
                    result = action_fx(clean_data)
                    if isinstance(result, list):
                        augmented_results.extend(result)
                    else:
                        augmented_results.append(result)
            # Multi inputs
            elif is_batch:
                # Single Thread
                if num_thread == 1:
                    augmented_results = [action_fx(d) for d in clean_data]

                # Multi Thread
                else:
                    # NOTE(review): this branch feeds the raw `data` items (not
                    # `clean_data`) through `self.augment`, which returns a list
                    # per item, so the result is a list of lists, unlike the
                    # single-thread branch. Confirm whether that is intended.
                    batch_data = [data[i:i+num_thread] for i in range(0, len(data), num_thread)]
                    for mini_batch_data in batch_data:
                        augmented_results.extend(self._parallel_augments(self.augment, mini_batch_data))

            # Single input with/without multiple input
            else:
                # Single Thread
                if num_thread == 1:
                    augmented_results = [action_fx(clean_data) for _ in range(n)]

                # Multi Thread
                else:
                    augmented_results = self._parallel_augment(action_fx, clean_data, n=n, num_thread=num_thread)

            if len(augmented_results) >= expected_output_num:
                break

        if len(augmented_results) == 0:
            # if not result, return itself
            return data if is_batch else [data]

        # DataFrame results and batch results are returned untouched; a single
        # input is trimmed to the requested number of outputs.
        if isinstance(augmented_results, pd.DataFrame) or is_batch:
            return augmented_results
        return augmented_results[:n]

    @classmethod
    def _validate_augment(cls, data):
        """Return a list with one length-is-zero warning if ``data`` is None or empty, else ``[]``."""
        if data is None or len(data) == 0:
            return [WarningException(name=WarningName.INPUT_VALIDATION_WARNING,
                                     code=WarningCode.WARNING_CODE_001, msg=WarningMessage.LENGTH_IS_ZERO)]

        return []

    @classmethod
    def _parallel_augment(cls, action_fx, data, n, num_thread=2):
        """Call ``action_fx(data)`` ``n`` times on a pool of ``num_thread`` threads.

        :return: List with the ``n`` results.
        """
        pool = ThreadPool(num_thread)
        try:
            return pool.map(action_fx, [data] * n)
        finally:
            # Always release the worker threads, even if action_fx raised.
            pool.close()
            pool.join()

    @classmethod
    def _parallel_augments(cls, action_fx, data):
        """Call ``action_fx`` once per element of ``data``, one thread per element.

        :return: List of results in the order of ``data``; ``[]`` for empty ``data``.
        """
        if len(data) == 0:
            # A pool cannot be created with zero workers.
            return []

        pool = ThreadPool(len(data))
        try:
            return pool.map(action_fx, data)
        finally:
            # Always release the worker threads, even if action_fx raised.
            pool.close()
            pool.join()

    # Action hooks dispatched by ``augment``; to be overridden.
    def insert(self, data):
        raise NotImplementedError

    def substitute(self, data):
        raise NotImplementedError

    def swap(self, data):
        raise NotImplementedError

    def delete(self, data):
        raise NotImplementedError

    def crop(self, data):
        raise NotImplementedError

    def split(self, data):
        raise NotImplementedError

    def tokenizer(self, tokens):
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError

    @classmethod
    def is_duplicate(cls, dataset, data):
        raise NotImplementedError

    @classmethod
    def prob(cls):
        """Return a random float drawn from ``np.random.random()``."""
        return np.random.random()

    @classmethod
    def sample(cls, x, num=None):
        """Draw a random sample.

        :param x: A list to sample from, or an int upper bound.
        :param num: Number of items to draw when ``x`` is a list.
        :return: For a list, ``num`` distinct items (``random.sample``). For an int,
            one value from ``np.random.randint(1, x-1)``. ``None`` for any other type.
        """
        if isinstance(x, list):
            return random.sample(x, num)
        if isinstance(x, int):
            # NOTE(review): the upper bound of np.random.randint is exclusive, so
            # this yields values in [1, x-2] and raises for x <= 2. Confirm intended.
            return np.random.randint(1, x-1)
        return None

    @classmethod
    def clean(cls, data):
        """Pre-process ``data`` before augmentation; to be overridden."""
        raise NotImplementedError

    def _generate_aug_cnt(self, size, aug_min, aug_max, aug_p=None):
        """Derive how many items to augment from ``size`` and a percentage.

        The percentage is ``aug_p`` if given, else ``self.aug_p`` if truthy, else 0.3.
        The count ``ceil(percent * size)`` is then bounded by ``aug_min`` and
        ``aug_max`` (each bound is skipped when falsy).
        """
        if aug_p is not None:
            percent = aug_p
        elif self.aug_p:
            percent = self.aug_p
        else:
            percent = 0.3
        cnt = int(math.ceil(percent * size))

        if aug_min and cnt < aug_min:
            return aug_min
        if aug_max and cnt > aug_max:
            return aug_max
        return cnt

    def generate_aug_cnt(self, size, aug_p=None):
        """Return the augmentation count for an input of ``size`` items (0 when ``size`` is 0)."""
        if size == 0:
            return 0
        return self._generate_aug_cnt(size, self.aug_min, self.aug_max, aug_p)

    def generate_aug_idxes(self, inputs):
        """Return randomly chosen, distinct indexes of ``inputs`` to augment.

        The number of indexes never exceeds ``len(inputs)``, so an ``aug_min``
        larger than the input cannot make the sampling fail.
        """
        size = len(inputs)
        aug_cnt = min(self.generate_aug_cnt(size), size)
        return self.sample(list(range(size)), aug_cnt)

    def _get_random_aug_idxes(self, data):
        """Return randomly chosen candidates from ``self.pre_skip_aug(data)``.

        ``pre_skip_aug`` is not defined in this class and must be supplied by
        the subclass. The sample size is capped at the number of candidates.
        """
        aug_cnt = self.generate_aug_cnt(len(data))
        idxes = self.pre_skip_aug(data)
        if len(idxes) < aug_cnt:
            aug_cnt = len(idxes)

        return self.sample(idxes, aug_cnt)

    def __str__(self):
        return 'Name:{}, Action:{}, Method:{}'.format(self.name, self.action, self.method)