Source code for nlpaug.augmenter.sentence.lambada

"""
    Augmenter that apply operation (sentence level) to textual input based on abstractive summarization.
"""

import os
import json

from nlpaug.augmenter.sentence import SentenceAugmenter
import nlpaug.model.lang_models as nml
from nlpaug.util import Action, Doc

LAMBADA_MODELS = {}

def init_lambada_model(model_dir, threshold, min_length, max_length, batch_size, 
    temperature, top_k, top_p, repetition_penalty, device, force_reload):
    global LAMBADA_MODELS

    model_name = '_'.join([os.path.basename(model_dir), str(device)])
    if model_name in LAMBADA_MODELS and not force_reload:
        LAMBADA_MODELS[model_name].threshold = threshold
        LAMBADA_MODELS[model_name].min_length = min_length
        LAMBADA_MODELS[model_name].max_length = max_length
        LAMBADA_MODELS[model_name].batch_size = batch_size
        LAMBADA_MODELS[model_name].temperature = temperature
        LAMBADA_MODELS[model_name].top_k = top_k
        LAMBADA_MODELS[model_name].top_p = top_p
        LAMBADA_MODELS[model_name].repetition_penalty = repetition_penalty
        return LAMBADA_MODELS[model_name]

    model = nml.Lambada(
        cls_model_dir=os.path.join(model_dir, 'cls'), gen_model_dir=os.path.join(model_dir, 'gen'), 
        threshold=threshold, max_length=max_length, min_length=min_length, temperature=temperature, 
        top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, batch_size=batch_size,
        device=device)

    LAMBADA_MODELS[model_name] = model
    return model


[docs]class LambadaAug(SentenceAugmenter): """ Augmenter that leverage contextual word embeddings to find top n similar word for augmentation. :param str model_dir: Directory of model. It is generated from train_lambada.sh under scritps folders.n :param float threshold: The threshold of classification probabilty for accpeting generated text. Return all result if threshold is None. :param int batch_size: Batch size. :param int min_length: The min length of output text. :param int max_length: The max length of output text. :param float temperature: The value used to module the next token probabilities. :param int top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. :param float top_p: If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. :param float repetition_penalty : The parameter for repetition penalty. 1.0 means no penalty. :param str device: Default value is CPU. If value is CPU, it uses CPU for processing. If value is CUDA, it uses GPU for processing. Possible values include 'cuda' and 'cpu'. :param bool force_reload: Force reload the contextual word embeddings model to memory when initialize the class. Default value is False and suggesting to keep it as False if performance is the consideration. :param str name: Name of this augmenter >>> import nlpaug.augmenter.sentence as nas >>> aug = nas.LambadaAug() """ def __init__(self, model_dir, threshold=None, min_length=100, max_length=300, batch_size=16, temperature=1.0, top_k=50, top_p=0.9, repetition_penalty=1.0, name='Lambada_Aug', device='cpu', force_reload=False, verbose=0): super().__init__( action=Action.INSERT, name=name, tokenizer=None, stopwords=None, device=device, include_detail=False, verbose=verbose) self.model_dir = model_dir with open(os.path.join(model_dir, 'cls', 'label_encoder.json')) as json_file: self.label2id = json.load(json_file) self.model = self.get_model( model_dir=model_dir, threshold=threshold, max_length=max_length, min_length=min_length, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, batch_size=batch_size, device=device, force_reload=force_reload) self.device = self.model.get_device() def insert(self, data, n=10): if not data: return data if isinstance(data, list): all_data = data else: if data.strip() == '': return data all_data = [data] for d in all_data: if d not in self.label2id: raise Exception('Label {} does not exist. Possible labels are {}'.format(d, self.label2id.keys())) return self.model.predict(all_data, n=n) @classmethod def get_model(cls, model_dir, threshold, min_length, max_length, batch_size, temperature, top_k, top_p, repetition_penalty, device='cuda', force_reload=False): return init_lambada_model(model_dir, threshold, min_length, max_length, batch_size, temperature, top_k, top_p, repetition_penalty, device, force_reload)