AI-generated Movie Reviews
In this post, we’ll fine-tune a pretrained language model so that it can generate its own movie reviews.
This post continues my previous post, Classifying movie reviews using Sentiment Analysis and ULMFit. Read that one first if you want a deeper explanation of the methodology used here.
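The dataset we’ll be using is the IMDb Large Movie Review Dataset, which contains 25,000 highly polarized movie reviews for training, 25,000 for testing, and another 50,000 unlabeled reviews. A language model needs only raw text, not labels, so we can train it on all three sets.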
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
from fastbook import *
from fastai.text.all import *
path = untar_data(URLs.IMDB)
Path.BASE_PATH = path
path.ls()
We’ll grab the text files using get_text_files, which recursively collects all the text files under a path. Passing folders restricts the search to a particular list of subfolders.
files = get_text_files(path, folders=['train', 'test', 'unsup'])
txt = files[0].open().read()
txt
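For intuition, here is a rough standard-library sketch of what a call like get_text_files(path, folders=[...]) does: walk the chosen top-level subfolders recursively and collect every .txt file. find_text_files is a hypothetical name, not part of fastai, and it ignores options the real function supports.

```python
import tempfile
from pathlib import Path


def find_text_files(path, folders=None):
    """Hypothetical stand-in for fastai's get_text_files: recursively collect
    .txt files under `path`, optionally only inside the named top-level folders."""
    path = Path(path)
    roots = [path / f for f in folders] if folders else [path]
    files = []
    for root in roots:
        if root.is_dir():
            files.extend(sorted(p for p in root.rglob('*.txt') if p.is_file()))
    return files


# Build a tiny fake dataset to try it on; 'other' should be skipped.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in ['train/pos/0_9.txt', 'test/neg/1_2.txt', 'unsup/5_0.txt', 'other/x.txt']:
        f = root / rel
        f.parent.mkdir(parents=True, exist_ok=True)
        f.write_text('A sample review.')
    found = find_text_files(root, folders=['train', 'test', 'unsup'])
    names = sorted(p.name for p in found)

print(names)  # ['0_9.txt', '1_2.txt', '5_0.txt']
```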
get_imdb = partial(get_text_files, folders=['train', 'test', 'unsup'])
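dls_lm = DataBlock(
    blocks=TextBlock.from_folder(path, is_lm=True),  # tokenize + numericalize; targets are the next tokens
    get_items=get_imdb,
    splitter=RandomSplitter(0.1),  # hold out a random 10% of the texts for validation
).dataloaders(path, path=path, bs=128, seq_len=72)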
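RandomSplitter(0.1) holds out a random 10% of the files as a validation set. The idea can be sketched in plain Python as below; random_split is a hypothetical name, and fastai’s own splitter differs in details such as how it seeds its random number generator.

```python
import random


def random_split(n_items, valid_pct=0.1, seed=None):
    """Hypothetical helper illustrating the idea behind RandomSplitter(0.1):
    shuffle the item indices and hold out valid_pct of them for validation."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]  # (train indices, validation indices)


# 100,000 = 25,000 train + 25,000 test + 50,000 unlabeled IMDb reviews
train_idx, valid_idx = random_split(100_000, valid_pct=0.1, seed=42)
print(len(train_idx), len(valid_idx))  # 90000 10000
```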
dls_lm.show_batch(max_n=2)
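With is_lm=True, the inputs and targets both come from the text itself: the target at each position is simply the next token. The simplified sketch below shows how one long stream of token IDs can be cut into bs parallel rows and then into seq_len-sized batches, as with bs=128 and seq_len=72 above. lm_batches is hypothetical; the real loader works on tensors and handles details this sketch ignores. Because each row is contiguous, batch k+1 continues exactly where batch k stopped, which is what lets a recurrent model carry its hidden state from one batch to the next.

```python
def lm_batches(stream, bs, seq_len):
    """Split one long token stream into `bs` contiguous rows, then walk across
    the rows `seq_len` tokens at a time. Each target is the input shifted by one."""
    row_len = (len(stream) - 1) // bs  # reserve one extra token for the last target
    rows = [stream[i * row_len:(i + 1) * row_len + 1] for i in range(bs)]
    batches = []
    for start in range(0, row_len, seq_len):
        end = min(start + seq_len, row_len)
        x = [row[start:end] for row in rows]          # inputs
        y = [row[start + 1:end + 1] for row in rows]  # next-token targets
        batches.append((x, y))
    return batches


# Toy stream of 21 token IDs, split into 2 rows and sequences of 4 tokens.
# First batch: [[0, 1, 2, 3], [10, 11, 12, 13]] -> [[1, 2, 3, 4], [11, 12, 13, 14]]
for x, y in lm_batches(list(range(21)), bs=2, seq_len=4):
    print(x, '->', y)
```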
Now that our data is ready, we can fine-tune the pretrained language model.
Fine-tuning the Language Model
To convert the integer word indices into activations that we can use for our neural network, we will use embeddings. We’ll feed those embeddings into a recurrent neural network (RNN), using an architecture called AWD-LSTM.
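To make these two steps concrete, the sketch below looks up an embedding vector for each token ID and feeds the vectors, one at a time, through a recurrent cell that carries a hidden state forward. It deliberately uses a plain tanh RNN cell with tiny hand-picked weights rather than AWD-LSTM, which uses LSTM cells plus several kinds of dropout; all names and numbers here are illustrative.

```python
import math


def embed(token_ids, embedding_matrix):
    # An embedding layer is a table lookup: row i holds the vector for token i.
    return [embedding_matrix[i] for i in token_ids]


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def rnn_step(x, h, W_x, W_h, b):
    # Plain RNN update: h_new = tanh(W_x @ x + W_h @ h + b)
    return [math.tanh(a + c + bias)
            for a, c, bias in zip(matvec(W_x, x), matvec(W_h, h), b)]


def run_rnn(token_ids, emb, W_x, W_h, b):
    h = [0.0] * len(b)  # initial hidden state
    states = []
    for x in embed(token_ids, emb):
        h = rnn_step(x, h, W_x, W_h, b)  # new state depends on the input and the old state
        states.append(h)
    return states


# Toy vocabulary of 3 tokens, 2-dimensional embeddings, 2-unit hidden state
emb = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
W_x = [[1.0, 0.0], [0.0, 1.0]]
W_h = [[0.5, 0.0], [0.0, 0.5]]
b = [0.0, 0.0]
states = run_rnn([0, 2], emb, W_x, W_h, b)
print(states)
```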
The embeddings in the pretrained model are merged with random embeddings added for words that weren’t in the pretraining vocabulary. This is handled automatically inside language_model_learner.
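The merging step can be sketched as follows: build a new embedding matrix indexed by the fine-tuning vocabulary, copy the pretrained row for every word the old vocabulary already knew, and initialize the rest. merge_embeddings is a hypothetical helper, not fastai’s code, and filling unknown words with the mean of the pretrained vectors is an assumption of this sketch; random vectors would also work, since the model trains these rows during fine-tuning.

```python
def merge_embeddings(old_vocab, old_emb, new_vocab):
    """Build an embedding matrix for new_vocab, reusing pretrained rows where possible.
    Words missing from old_vocab get the mean pretrained vector (a sketch assumption)."""
    old_index = {w: i for i, w in enumerate(old_vocab)}
    dim = len(old_emb[0])
    mean = [sum(row[d] for row in old_emb) / len(old_emb) for d in range(dim)]
    new_emb = []
    for w in new_vocab:
        i = old_index.get(w)
        new_emb.append(list(old_emb[i]) if i is not None else list(mean))
    return new_emb


old_vocab = ['the', 'movie', 'good']
old_emb = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
new_vocab = ['the', 'movie', 'spielberg', 'good']  # 'spielberg' is new
print(merge_embeddings(old_vocab, old_emb, new_vocab))
# [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]]
```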
learn = language_model_learner(
    dls_lm, AWD_LSTM, drop_mult=0.3,
    metrics=[accuracy, Perplexity()]
).to_fp16()
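The two metrics passed to the learner are easy to state in plain Python. Accuracy is the fraction of positions where the model’s most likely next token is the actual next token, and perplexity is the exponential of the average cross-entropy loss. The helpers below are illustrative sketches working on plain lists, not fastai’s implementations.

```python
import math


def accuracy(pred_ids, target_ids):
    # Fraction of positions where the predicted token equals the true next token
    return sum(p == t for p, t in zip(pred_ids, target_ids)) / len(target_ids)


def cross_entropy(target_probs):
    # Average negative log-probability the model gave to each true next token
    return -sum(math.log(p) for p in target_probs) / len(target_probs)


def perplexity(target_probs):
    return math.exp(cross_entropy(target_probs))


print(accuracy([4, 7, 7, 1], [4, 7, 2, 1]))  # 0.75
print(perplexity([0.25, 0.25, 0.25, 0.25]))  # approximately 4.0
```

A perplexity of 4 means the model is, on average, as uncertain as if it were choosing uniformly among 4 tokens, so lower is better.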
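# The learner starts frozen, so this first stage trains only the last layer group,
# which includes the embeddings (and therefore the newly added words)
learn.fit_one_cycle(3, 2e-2)
# Then unfreeze and fine-tune the whole model with a lower learning rate
learn.unfreeze()
learn.fit_one_cycle(10, 2e-3)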
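fit_one_cycle varies the learning rate over training rather than keeping it fixed: it warms up from a small value to the maximum you pass (2e-2 or 2e-3 here) and then anneals down to a much smaller one. A sketch of that shape follows. The cosine curves and the defaults div=25, div_final=1e5, and pct_start=0.25 are assumed to match fastai’s, and one_cycle_lr is a hypothetical name. fastai also cycles momentum in the opposite direction, which this sketch leaves out.

```python
import math


def cos_anneal(start, end, pos):
    # Cosine interpolation: returns start at pos=0 and end at pos=1.
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(step, total_steps, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at `step` of a one-cycle schedule (sketch).
    Warm up from lr_max/div to lr_max, then anneal down to lr_max/div_final."""
    pos = step / max(1, total_steps - 1)  # fraction of training completed, 0..1
    if pos < pct_start:
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    return cos_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))


# Starts at 2e-3/25, peaks at 2e-3 at step 25, and ends at 2e-3/1e5 at step 100.
for step in (0, 10, 25, 60, 100):
    print(step, one_cycle_lr(step, 101, lr_max=2e-3))
```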
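TEXT = 'This movie is terrible'
N_WORDS = 70
N_SENTENCES = 5
# Generate N_SENTENCES continuations of TEXT, each extended by N_WORDS tokens;
# a temperature below 1 makes the sampling somewhat less random
preds = [learn.predict(TEXT, N_WORDS, temperature=0.75)
         for _ in range(N_SENTENCES)]
print('\n\n'.join(preds))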
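The temperature argument controls how adventurous the sampling is. A temperature below 1 sharpens the next-token distribution toward the most likely words, while a value above 1 flattens it. The sketch below shows the general technique on a toy model: raise each probability to the power 1/T and renormalize (equivalent to dividing the logits by T before the softmax), then sample one token at a time and feed it back in. toy_model and generate are hypothetical stand-ins for the fine-tuned learner, not fastai’s code.

```python
import math
import random


def apply_temperature(probs, temperature):
    # p_i ** (1/T), renormalized; equivalent to softmax(logits / T)
    scaled = [p ** (1.0 / temperature) for p in probs]
    total = sum(scaled)
    return [s / total for s in scaled]


def toy_model(tokens):
    # Hypothetical model: always returns the same next-token distribution
    return ['bad', 'boring', 'good'], [0.6, 0.3, 0.1]


def generate(model, prompt, n_words, temperature=0.75, seed=None):
    rng = random.Random(seed)
    tokens = prompt.split()
    for _ in range(n_words):
        vocab, probs = model(tokens)
        probs = apply_temperature(probs, temperature)
        tokens.append(rng.choices(vocab, weights=probs, k=1)[0])  # sample, then feed back in
    return ' '.join(tokens)


print(apply_temperature([0.6, 0.3, 0.1], 0.75))  # first entry grows above 0.6
print(generate(toy_model, 'This movie is terrible', 5, seed=0))
```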