Using GPT-2 to generate Pokémon anime episodes
A journey from data acquisition to hosting a web-service for something truly goofy.
Ludicolo was a salsa master, he would teach Ash how to move like a god. He would make fun of Ash for being unable to move so quickly, and would even attack him for being weak.
What follows is the story of someone using a ridiculously complex piece of technology just to make something goofy. The GPT-2 model, presented by OpenAI, was a game changer in AI-generated text. So much so that the team behind the model delayed its public release so people could prepare for a world where things like fake news could be generated effortlessly and without much human intervention. And YEAH, today I’m going to show how one could use this dangerous artifact to create Pokémon episodes.
You can check the end result here: http://pokegen.thiagolira.com.br/ and all the code can be found on my Github repository. If you just want to see some examples, skip to the end of the post.
I can’t promise the website will answer all requests at all times: the server can’t deal with more than 3 requests at the same time (I will explain why shortly), and I’m already past the free tier I was using to host it on AWS. But if the site yields an error, just wait a few seconds and try again :)
In this article, I’ll try my best to explain the main challenges of an end-to-end machine learning project, every step of the way.
The Data
Machine Learning methods are used to extract information and infer patterns from data. While classical statistical methods have many parameters and assumptions selected by the statistician in the modeling phase, Machine Learning methods let the data speak for itself. This is a well-known trade-off between explainable (classical) models and accurate (machine learning) ones. The predictive power of machine learning basically comes from having a lot of data and a model complex enough to capture highly subtle patterns from it. There is a “smoothness” assumption: the model is trained on a big enough sample of reality that it can infer (generalize) about what it doesn’t see directly, on the assumption that any new example is close to some example it has actually been trained on (hence the smoothness).
Recent NLP (Natural Language Processing) models are no different, and they need huge amounts of text and computing power to be trained. These new models start with zero knowledge of language and, in the end, become really good at gauging contextual information from sequences of words, to the point of understanding that the very same word has a different meaning in different places in a sentence, something classical NLP models are just not very good at achieving.
The GPT-2 model comes pre-trained on a huge corpus of web text scraped from pages linked on Reddit, among other sources. What I’ve done is fine-tune the model on a more specific set of texts from the internet: the subset that consists of summaries of Pokémon anime episodes, hosted on Bulbapedia. Shout out to Bulbapedia for being an excellent community-driven Pokémon website!
The crawler I wrote downloads around 400 episode summaries written by the community. Here is a sample from one episode:
Ash declares to himself and the Pokémon of the world that he will become a Pokémon Master. His speech, however, is interrupted by his mom who tells him to get to bed as he has a big day tomorrow. Ash protests that he’s too excited to sleep, so his mom tells him that if he won’t sleep then to at least get ready for the next day as she switches on a program hosted by the town’s Pokémon expert, Professor Oak. Ash watches as Oak explains that new Trainers get to pick one of three Pokémon to start their journey; the Grass-type Bulbasaur, the Fire-type Charmander or the Water-type Squirtle.
My crawler is in the file crawler_bulbapedia.py and, when run, will create a folder called data/pokeCorpusBulba where it stores every episode in a separate text file.
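The actual script is in the repository, but a minimal sketch of the idea looks something like this (the URL pattern, episode range, and selectors here are simplified assumptions, not the exact logic of crawler_bulbapedia.py):

```python
# Minimal sketch of a Bulbapedia episode crawler (simplified; the real
# crawler_bulbapedia.py handles more cases and keeps only the synopsis).
import os
import requests
from bs4 import BeautifulSoup

OUT_DIR = "data/pokeCorpusBulba"
os.makedirs(OUT_DIR, exist_ok=True)

for ep in range(1, 401):  # roughly 400 episode summaries
    ep_id = f"EP{ep:03d}"
    resp = requests.get(f"https://bulbapedia.bulbagarden.net/wiki/{ep_id}")
    if resp.status_code != 200:
        continue
    soup = BeautifulSoup(resp.text, "html.parser")
    paragraphs = [p.get_text() for p in soup.find_all("p")]
    with open(os.path.join(OUT_DIR, f"{ep_id}.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(paragraphs))
```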
The data is not yet ready to be fed to the model. Another script, prepare_corpus.py, cleans the text files and joins them all into a single file called train.txt, ready to be used with GPT-2.
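Again as a rough sketch (the cleaning rules below are placeholders for whatever the real script does):

```python
# Sketch of prepare_corpus.py: clean each episode file and concatenate
# everything into a single train.txt for fine-tuning.
import glob
import re

def clean(text):
    text = re.sub(r"\[\d+\]", "", text)  # drop wiki-style reference markers
    text = re.sub(r"\s+", " ", text)     # collapse whitespace
    return text.strip()

with open("train.txt", "w", encoding="utf-8") as out:
    for path in sorted(glob.glob("data/pokeCorpusBulba/*.txt")):
        with open(path, encoding="utf-8") as f:
            out.write(clean(f.read()) + "\n\n")
```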
The Model
GPT-2 is a Transformer-based model that uses a technique called self-attention to learn, in an astoundingly natural way, how words complete or continue sentences. I don’t think I can do a better job at explaining the mathematics and inner workings of this model than these multiple excellent sources. But I can offer some insights, from a purely programming perspective, on how to use this pre-trained model as if it were a text-generating API. For that I’ve found an excellent resource, the gpt-2-simple Python library, which makes all the TensorFlow complexity basically invisible and offers some very simple functions to download, fine-tune and sample from the GPT-2 model.
Basically, a language model tries to predict the next word of a sentence, and we can keep getting predictions from the model to generate new text, feeding the latest predictions back as new input to get more and more words. So, as an example, we can give our model the prefix input of “Ash and Pikachu were”:
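With gpt-2-simple, sampling from an already fine-tuned model is just a couple of calls (the length and temperature values below are arbitrary choices of mine, not necessarily what the site uses):

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)  # loads the fine-tuned weights from the checkpoint/ folder

text = gpt2.generate(sess,
                     prefix="Ash and Pikachu were",
                     length=100,
                     temperature=0.7,
                     return_as_list=True)[0]
print(text)
```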
What GPT-2 does, using an attention mechanism, is dynamically assess the importance of the previous words when predicting the next one. Something called a “transformer cell” inside the model computes an attention value for each word in the input sequence in relation to every other word. This is all passed along to generate an output, i.e. to predict the next word of the sentence.
As a somewhat simplified example, we can see from the strength of the attention values (the more purple, the more attention) that “Ash” and “Pikachu” are clearly relevant for determining what comes after “were”. This is something pretty nice about this model that classical “counting words” methods such as Naive Bayes just couldn’t do.
The fine-tuning works by having the model predict each successive word of the sentences in the corpus and nudging its weights towards predicting them correctly. In the end, we have a checkpoint folder, which is the only thing we need to generate text from this model. This folder, created by TensorFlow, contains the whole state of the model after fine-tuning on my Pokémon corpus, and the gpt-2-simple library will look for it when generating new text.
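The whole fine-tuning step with gpt-2-simple boils down to something like this (the model size and number of steps are assumptions for illustration, not necessarily what I used):

```python
import gpt_2_simple as gpt2

model_name = "124M"                        # the smallest released GPT-2 model
gpt2.download_gpt2(model_name=model_name)  # fetch the pre-trained weights

sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              dataset="train.txt",
              model_name=model_name,
              steps=1000,        # how long to fine-tune on the Pokémon corpus
              run_name="run1")   # checkpoint/run1 is what the server needs later
```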
The Server
This was BY FAR the most challenging part. Serving this model for inference on the internet is not a trivial task since the text generation is so memory-heavy.
Basically, the server answers GET requests pointed at port 5000. It has a single function to answer these requests, which gets the parameters (the user input), initializes the model, generates a fixed amount of text, and returns everything inside a JSON. The difficult part is that the model occupies a whopping 1 GB of memory to run an inference. So, before anything else, I had to have a server with a decent amount of RAM (goodbye AWS free tier! ’Twas nice meeting you). I ended up choosing an EC2 t2.medium instance on AWS and set it up with the good help of my friend João.
The following web-server structure, which goes inside the EC2 instance, has been thoroughly copied from my other friend Gabriela, from her popular Medium post.
The web server I chose to run on this EC2 instance is nginx, which listens for requests and then forwards them to a uWSGI application server that communicates with the Flask app via the WSGI protocol. Basically, we have this structure:
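In configuration terms that amounts to something like the snippets below (socket path and number of workers are illustrative, not the exact files in the repository):

```nginx
# nginx: accept HTTP requests and hand them to uWSGI over a Unix socket
server {
    listen 5000;
    location / {
        include uwsgi_params;
        uwsgi_pass unix:/tmp/pokegen.sock;
    }
}
```

```ini
; uwsgi.ini: run the Flask "app" object defined in app.py
[uwsgi]
module = app:app
socket = /tmp/pokegen.sock
processes = 3   ; each inference needs ~1 GB of RAM, so only a few workers fit
```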
The WSGI protocol exists to create a common interface for web applications written in Python. So I could, for example, change the application framework (Flask to Django) or the application server (uWSGI to Gunicorn) and that change would be essentially invisible to the other parts.
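The whole interface is just a Python callable; this toy example is all a server like uWSGI really needs from the application side:

```python
# A bare-bones WSGI application: anything exposing a callable with this
# signature (Flask and Django apps included) can sit behind uWSGI.
def application(environ, start_response):
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"Hello from a plain WSGI app"]
```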
Now, why don’t I just expose the uWSGI server directly to the web? Why use another layer, i.e. nginx? Well, the simple answer is that nginx abstracts away some issues that come with server load, which uWSGI alone is not well suited to deal with.
I, of course, had to pack all this software into a Docker container, because that’s just how all the cool kids are doing it nowadays. All the code is available on my Github repository, but for a really in-depth and well-made explanation of this configuration I suggest paying a visit to Gabriela’s post, since my setup is basically the same, with some small tweaks because my application is a bit different.
The Flask App
The Flask App — the place where the model is run on the server — has a single entry-point for requests, the generate function:
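The original code is on the repository; as a rough sketch, the endpoint looks something like this (the route name, parameter names and the reset_session cleanup are my assumptions about the shape of the code, not a copy of it):

```python
# Sketch of the Flask entry point: read the user's prefix, run the
# fine-tuned GPT-2 model, and return the generated text as JSON.
import gpt_2_simple as gpt2
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/generate", methods=["GET"])
def generate():
    prefix = request.args.get("prefix", "Ash and Pikachu were")
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess)              # ~1 GB of RAM per inference
    text = gpt2.generate(sess,
                         prefix=prefix,
                         length=200,
                         return_as_list=True)[0]
    gpt2.reset_session(sess)          # free the TensorFlow graph afterwards
    return jsonify({"text": text})
```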
Lessons Learned
- Data is very important if you want to make something new with Machine Learning. And cleaning it decently is just as important.
- The GPT-2 model is just impractical to use as an on-demand text-generating tool: it needs too much memory and CPU power to run. It is very expensive to keep a service that needs 1 GB of RAM to serve each request.
- docker system prune is your friend.
- The Python ecosystem of web-servers is not that hard to use, and there are plenty of examples out there.
Some output examples
The input fed to the model is in bold. Of course, there was some cherry-picking involved, but that’s just how generative models are.
Ash and Misty were dating when they fell in love. As they both recall their respective first experiences, ash’s first brush with pokémon is all he ever remembers, as he was just a child. later, after ash had made his first poké ball, he skipped lunch and pursued a friend and switched trainers. This ended in them both falling in love, leaving dawn and brock in tears. When they were out searching for ash’s bulbasaur, a wild gyarados swatted them off. (…)
Pikachu was tired of all this sh*t. He runs in fear of the grass, and runs in fear of the trainers, too. Jessie and james run outside and run outside. (…)
Ash wanted to be the very best and trained all of his pokémon to get there. He told his trainer the whole story and promised to be a great trainer. He told his parents and his friends that he would train them as best he could. They were surprised and were ready to give up on him, when his parents started crying. His mother told him to come back home and find his friends. They had no choice but to go with him.
Pikachu was being arrested for tax evasion. Once the trio get out of the olivine city pokémon center, they are immediately attacked by a former police officer, a police detective, a nurse joy, and a nurse joy’s glameow.