Source code for nlpaug.augmenter.word.word_embs

"""
    Augmenter that applies operations to textual input based on word embeddings.
"""

from nlpaug.augmenter.word import WordAugmenter
from nlpaug.util import Action, Doc
import nlpaug.model.word_embs as nmw


# Module-level cache of loaded embedding models, keyed by model type.
WORD_EMBS_MODELS = {}
model_types = ['word2vec', 'glove', 'fasttext']


def init_word_embs_model(model_path, model_type, force_reload=False, top_k=None, skip_check=False):
    global WORD_EMBS_MODELS

    # Models are cached per model_type; a cache hit only refreshes top_k.
    if model_type in WORD_EMBS_MODELS and not force_reload:
        WORD_EMBS_MODELS[model_type].top_k = top_k
        return WORD_EMBS_MODELS[model_type]

    if model_type == 'word2vec':
        model = nmw.Word2vec(top_k=top_k, skip_check=skip_check)
        model.read(model_path)
    elif model_type == 'glove':
        model = nmw.GloVe(top_k=top_k, skip_check=skip_check)
        model.read(model_path)
    elif model_type == 'fasttext':
        model = nmw.Fasttext(top_k=top_k, skip_check=skip_check)
        model.read(model_path)
    else:
        raise ValueError('Model type value is unexpected. Expected values include {}'.format(model_types))

    WORD_EMBS_MODELS[model_type] = model
    return model
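
# A minimal sketch of the cache behaviour above, assuming a hypothetical local
# embedding file './wiki.en.vec' (not shipped with nlpaug): a repeated call
# with the same model_type returns the cached instance, with only its top_k
# refreshed.
#
#   m1 = init_word_embs_model('./wiki.en.vec', 'fasttext', top_k=100)
#   m2 = init_word_embs_model('./wiki.en.vec', 'fasttext', top_k=20)
#   assert m1 is m2 and m2.top_k == 20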


class WordEmbsAug(WordAugmenter):
    # https://aclweb.org/anthology/D15-1306, https://arxiv.org/pdf/1804.07998.pdf, https://arxiv.org/pdf/1509.01626.pdf
    # https://arxiv.org/ftp/arxiv/papers/1812/1812.04718.pdf
    """
    Augmenter that leverages word embeddings to find the top n similar words for augmentation.

    :param str model_type: Model type of word embeddings. Expected values include 'word2vec', 'glove' and 'fasttext'.
    :param str model_path: Downloaded model directory. Either model_path or model must be provided.
    :param obj model: Pre-loaded model (e.g. an instance of nlpaug.model.word_embs.Word2vec,
        nlpaug.model.word_embs.GloVe or nlpaug.model.word_embs.Fasttext)
    :param str action: Either 'insert' or 'substitute'. If value is 'insert', a new word will be injected at a random
        position according to the word embeddings calculation. If value is 'substitute', a word will be replaced
        according to the word embeddings calculation.
    :param int top_k: Controls the lucky-draw pool. The top k scored tokens will be used for augmentation; the larger
        k is, the more tokens can be used. Default value is 100. If value is None, all possible tokens are used.
        This attribute is ignored when using the 'insert' action.
    :param float aug_p: Percentage of words that will be augmented.
    :param int aug_min: Minimum number of words that will be augmented.
    :param int aug_max: Maximum number of words that will be augmented. If None is passed, the number of augmented
        words is calculated via aug_p. If the result calculated from aug_p is smaller than aug_max, the aug_p result
        is used; otherwise aug_max is used.
    :param list stopwords: List of words which will be skipped from augment operation.
    :param str stopwords_regex: Regular expression for matching words which will be skipped from augment operation.
    :param func tokenizer: Customize tokenization process.
    :param func reverse_tokenizer: Customize reverse of tokenization process.
    :param bool force_reload: If True, the model is reloaded every time, which makes initialization take longer.
    :param bool skip_check: Default is False. If True, the size of the vocabulary embedding is not validated.
    :param str name: Name of this augmenter.

    >>> import nlpaug.augmenter.word as naw
    >>> aug = naw.WordEmbsAug(model_type='word2vec', model_path='.')
    """

    def __init__(self, model_type, model_path='.', model=None, action=Action.SUBSTITUTE,
                 name='WordEmbs_Aug', aug_min=1, aug_max=10, aug_p=0.3, top_k=100, n_gram_separator='_',
                 stopwords=None, tokenizer=None, reverse_tokenizer=None, force_reload=False,
                 stopwords_regex=None, verbose=0, skip_check=False):
        super().__init__(
            action=action, name=name, aug_p=aug_p, aug_min=aug_min, aug_max=aug_max, stopwords=stopwords,
            tokenizer=tokenizer, reverse_tokenizer=reverse_tokenizer, device='cpu', verbose=verbose,
            stopwords_regex=stopwords_regex, include_detail=False)

        self.model_type = model_type
        self.model_path = model_path
        self.top_k = top_k
        self.n_gram_separator = n_gram_separator

        self.pre_validate()

        if model is None:
            self.model = self.get_model(model_path=model_path, model_type=model_type, force_reload=force_reload,
                                        top_k=self.top_k, skip_check=skip_check)
        else:
            self.model = model

    def pre_validate(self):
        if self.model_type not in model_types:
            raise ValueError('Model type value is unexpected. Expected values include {}'.format(model_types))

    @classmethod
    def get_model(cls, model_path, model_type, force_reload=False, top_k=100, skip_check=False):
        return init_word_embs_model(model_path, model_type, force_reload, top_k=top_k, skip_check=skip_check)

    def skip_aug(self, token_idxes, tokens):
        results = []
        for token_idx in token_idxes:
            # Words that do not come with a vector are excluded from the lucky draw.
            word = tokens[token_idx]
            if word in self.model.get_vocab():
                results.append(token_idx)

        return results

    def insert(self, data):
        if not data or not data.strip():
            return data

        change_seq = 0
        doc = Doc(data, self.tokenizer(data))

        aug_idxes = self._get_random_aug_idxes(doc.get_original_tokens())
        if not aug_idxes:
            if self.include_detail:
                return data, []
            return data
        aug_idxes.sort(reverse=True)

        for aug_idx in aug_idxes:
            new_token = self.sample(self.model.get_vocab(), 1)[0]
            if self.n_gram_separator in new_token:
                new_token = new_token.split(self.n_gram_separator)[0]

            change_seq += 1
            doc.add_token(aug_idx, token=new_token, action=Action.INSERT,
                          change_seq=self.parent_change_seq + change_seq)

        if self.include_detail:
            return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
        else:
            return self.reverse_tokenizer(doc.get_augmented_tokens())

    def substitute(self, data):
        if not data or not data.strip():
            return data

        change_seq = 0
        doc = Doc(data, self.tokenizer(data))

        aug_idxes = self._get_aug_idxes(doc.get_original_tokens())
        if aug_idxes is None or len(aug_idxes) == 0:
            if self.include_detail:
                return data, []
            return data

        for aug_idx in aug_idxes:
            original_token = doc.get_token(aug_idx).get_latest_token().token

            candidate_tokens = self.model.predict(original_token, n=1)
            substitute_token = self.sample(candidate_tokens, 1)[0]
            if aug_idx == 0:
                substitute_token = self.align_capitalization(original_token, substitute_token)

            change_seq += 1
            doc.add_change_log(aug_idx, new_token=substitute_token, action=Action.SUBSTITUTE,
                               change_seq=self.parent_change_seq + change_seq)

        if self.include_detail:
            return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
        else:
            return self.reverse_tokenizer(doc.get_augmented_tokens())
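

# A minimal usage sketch, assuming a pre-trained word2vec binary at the
# hypothetical local path below (any file readable by the chosen model_type
# works). Depending on the nlpaug release, augment() may return a string or a
# list of augmented texts.
if __name__ == '__main__':
    aug = WordEmbsAug(
        model_type='word2vec',
        model_path='./GoogleNews-vectors-negative300.bin',  # hypothetical path
        action=Action.SUBSTITUTE, top_k=50)
    print(aug.augment('The quick brown fox jumps over the lazy dog'))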