Augmenter that applies a sentence-level operation to textual input based on contextual word embeddings.
ContextualWordEmbsForSentenceAug(model_path='distilgpt2', temperature=1.0, top_k=100, top_p=None, name='ContextualWordEmbsForSentence_Aug', device='cpu', force_reload=False, optimize=None, verbose=0, silence=True)
Augmenter that leverages contextual word embeddings to find the top n similar words for augmentation.
- model_path (str) – Model name or model path. The model is loaded with transformers. Tested with ‘xlnet-base-cased’, ‘gpt2’ and ‘distilgpt2’. To reduce inference time, consider distilgpt2.
- temperature (float) – Controls randomness. Default value is 1.0; a lower temperature results in less random behavior.
- top_k (int) – Controls the sampling pool. Only the k highest-scoring tokens are candidates for augmentation; a larger k allows more tokens. Default value is 100. If the value is None, all possible tokens are used.
- top_p (float) – Controls the sampling pool. Tokens outside the top p of cumulative probability are removed; a larger p allows more tokens. Default value is None, which means all possible tokens are used.
- device (str) – Default value is ‘cpu’. If the value is ‘cpu’, the CPU is used for processing; if it is ‘cuda’, the GPU is used. Possible values include ‘cuda’ and ‘cpu’. (Other options may also work.)
- force_reload (bool) – Force the contextual word embeddings model to be reloaded into memory when the class is initialized. Default value is False; keeping it False is suggested if performance is a consideration.
- optimize (obj) –
Configuration for the optimized process. external_memory: Persists previously computed results for the next prediction. Extra memory is used in order to shorten inference time. `gpt2` and `distilgpt2` are supported.
- silence (bool) – Default is True. The transformers library prints warning messages when leveraging a pre-trained model. Set to True to suppress these expected warnings.
- name (str) – Name of this augmenter
>>> import nlpaug.augmenter.sentence as nas
>>> aug = nas.ContextualWordEmbsForSentenceAug()
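The temperature, top_k and top_p parameters interact, so a small standalone illustration may help. The sketch below is not nlpaug's implementation. It is a minimal pure-Python version of the usual temperature scaling followed by top-k and top-p (nucleus) filtering, and the names `candidate_pool` and `sample_token` are hypothetical helpers introduced only for this example.

```python
import math
import random


def candidate_pool(logits, temperature=1.0, top_k=100, top_p=None):
    """Return [(token, probability)] left after temperature, top-k and top-p filtering.

    `logits` is a dict mapping token -> raw score (an assumption of this sketch).
    """
    # Temperature scaling: values below 1 sharpen the distribution.
    scaled = {tok: score / temperature for tok, score in logits.items()}
    # Softmax, shifted by the maximum for numerical stability.
    peak = max(scaled.values())
    exps = {tok: math.exp(score - peak) for tok, score in scaled.items()}
    total = sum(exps.values())
    ranked = sorted(
        ((tok, value / total) for tok, value in exps.items()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    # top-k: keep only the k most probable tokens (None keeps all of them).
    if top_k is not None:
        ranked = ranked[:top_k]
    # top-p: keep the smallest prefix whose cumulative probability reaches p.
    if top_p is not None:
        kept, cumulative = [], 0.0
        for tok, prob in ranked:
            kept.append((tok, prob))
            cumulative += prob
            if cumulative >= top_p:
                break
        ranked = kept
    return ranked


def sample_token(logits, rng=None, **kwargs):
    """Draw one token from the filtered pool, weighted by probability."""
    rng = rng or random.Random()
    pool = candidate_pool(logits, **kwargs)
    tokens = [tok for tok, _ in pool]
    weights = [prob for _, prob in pool]
    return rng.choices(tokens, weights=weights, k=1)[0]
```

With scores such as `{"cat": 5.0, "dog": 1.0, "fish": 0.5}`, passing `top_k=1` leaves only the highest-scoring token in the pool, while raising `temperature` flattens the probabilities so lower-ranked tokens are drawn more often.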
augment(data, n=1, num_thread=1)
- data (object/list) – Data for augmentation. It can be a list of data (e.g. a list of strings or numpy arrays) or a single element (e.g. a string or numpy array)
- n (int) – Default is 1. Number of unique augmented outputs. Forced to 1 if the input is a list of data.
- num_thread (int) – Number of threads for data augmentation. Use this option when running on CPU and n is larger than 1.
>>> augmented_data = aug.augment(data)
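Because n is forced to 1 when the input is a list, one possible way to collect several variants per item is to call augment once per element. The sketch below assumes only the documented `augment(data, n=1, num_thread=1)` signature. `augment_each` is a hypothetical helper, and `EchoAug` is a hypothetical stand-in so the example runs without downloading a model.

```python
def augment_each(aug, texts, n=1):
    """Return a list with up to n augmented variants for every input text."""
    results = []
    for text in texts:
        # A single element is passed each time, so n is not forced to 1.
        out = aug.augment(text, n=n)
        # Normalize the result: depending on the nlpaug version, a single
        # string may be returned instead of a list.
        results.append([out] if isinstance(out, str) else list(out))
    return results


class EchoAug:
    """Hypothetical stand-in exposing the same augment() signature."""

    def augment(self, data, n=1, num_thread=1):
        return [f"{data} (variant {i})" for i in range(n)]


variants = augment_each(EchoAug(), ["first text", "second text"], n=2)
```

In real use, `EchoAug()` would be replaced by an augmenter instance such as the `aug` object created above.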