Source code for nlpaug.augmenter.word.back_translation

    Augmenter that apply operation (word level) to textual input based on back translation.

import string
import os

from nlpaug.augmenter.word import WordAugmenter
import nlpaug.model.lang_models as nml


def init_back_translation_model(from_model_name, to_model_name, device, force_reload=False,
                                batch_size=32, max_length=None):

    model_name = '_'.join([from_model_name, to_model_name, str(device)])
    if model_name in BACK_TRANSLATION_MODELS and not force_reload:
        BACK_TRANSLATION_MODELS[model_name].batch_size = batch_size
        BACK_TRANSLATION_MODELS[model_name].max_length = max_length

        return BACK_TRANSLATION_MODELS[model_name]

    model = nml.MtTransformers(src_model_name=from_model_name, tgt_model_name=to_model_name, 
        device=device, batch_size=batch_size, max_length=max_length)

    BACK_TRANSLATION_MODELS[model_name] = model
    return model

[docs]class BackTranslationAug(WordAugmenter): # """ Augmenter that leverage two translation models for augmentation. For example, the source is English. This augmenter translate source to German and translating it back to English. For detail, you may visit :param str from_model_name: Any model from As long as from_model_name is pair with to_model_name. For example, from_model_name is English to Japanese, then to_model_name should be Japanese to English. :param str to_model_name: Any model from :param str device: Default value is CPU. If value is CPU, it uses CPU for processing. If value is CUDA, it uses GPU for processing. Possible values include 'cuda' and 'cpu'. (May able to use other options) :param bool force_reload: Force reload the contextual word embeddings model to memory when initialize the class. Default value is False and suggesting to keep it as False if performance is the consideration. :param int batch_size: Batch size. :param int max_length: The max length of output text. :param str name: Name of this augmenter >>> import nlpaug.augmenter.word as naw >>> aug = naw.BackTranslationAug() """ def __init__(self, from_model_name='facebook/wmt19-en-de', to_model_name='facebook/wmt19-de-en', name='BackTranslationAug', device='cpu', batch_size=32, max_length=300, force_reload=False, verbose=0): super().__init__( action='substitute', name=name, aug_p=None, aug_min=None, aug_max=None, tokenizer=None, device=device, verbose=verbose, include_detail=False) self.model = self.get_model(from_model_name=from_model_name, to_model_name=to_model_name, device=device, batch_size=batch_size, max_length=max_length ) self.device = self.model.device def substitute(self, data, n=1): if not data: return data augmented_text = self.model.predict(data) return augmented_text @classmethod def get_model(cls, from_model_name, to_model_name, device='cuda', force_reload=False, batch_size=32, max_length=None): return init_back_translation_model(from_model_name, to_model_name, device, force_reload, batch_size, max_length) @classmethod def clear_cache(cls): global BACK_TRANSLATION_MODELS BACK_TRANSLATION_MODELS = {}