Stephen's Blog

Spelling Correction with The Pretrained BERT Model

 

Stephen Cheng

Intro

BERT (Bidirectional Encoder Representations from Transformers) was published by researchers at Google AI Language. It caused a stir in the Machine Learning community by presenting state-of-the-art results in a wide variety of NLP tasks, including Question Answering, Natural Language Inference, and others. BERT’s key technical innovation is applying bidirectional training of the Transformer, a popular attention model, to language modelling. This is in contrast to previous efforts, which looked at a text sequence either from left to right or as a combination of left-to-right and right-to-left training. The results show that a language model which is trained bidirectionally can have a deeper sense of language context and flow than single-direction language models.

BERT makes use of the Transformer, an attention mechanism that learns contextual relations between words (or sub-words) in a text. In its original form, the Transformer includes two separate mechanisms: an encoder that reads the text input and a decoder that produces a prediction for the task. Since BERT’s goal is to generate a language model, only the encoder is needed. As opposed to directional models, which read the text input sequentially (left-to-right or right-to-left), the Transformer encoder reads the entire sequence of words at once. It is therefore considered bidirectional, though it would be more accurate to say that it’s non-directional. This characteristic allows the model to learn the context of a word from all of its surroundings (both left and right of the word).
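To make the difference between a directional and a non-directional model concrete, here is a minimal sketch of attention weights over a toy sequence. It is not BERT's actual implementation: `attention_weights` and the uniform scores are assumptions made up for illustration. With the left-to-right mask each token only sees itself and what came before it; without the mask every token sees the whole sequence.

```python
import math

def attention_weights(scores, causal=False):
    """Row-wise softmax over raw attention scores (hypothetical helper).

    With causal=True, token i may only attend to positions 0..i,
    which mimics a left-to-right model.
    """
    weights = []
    for i, row in enumerate(scores):
        visible = [j for j in range(len(row)) if not causal or j <= i]
        top = max(row[j] for j in visible)  # subtract the max for numerical stability
        exps = {j: math.exp(row[j] - top) for j in visible}
        total = sum(exps.values())
        # Hidden positions get weight 0.
        weights.append([exps.get(j, 0.0) / total for j in range(len(row))])
    return weights

# Toy input: 4 tokens, all raw scores equal.
scores = [[0.0] * 4 for _ in range(4)]
bidirectional = attention_weights(scores)
left_to_right = attention_weights(scores, causal=True)

print(bidirectional[0])   # first token attends to all four positions
print(left_to_right[0])   # first token can only attend to itself
```

In the bidirectional case the first token already spreads its attention over the whole sentence, which is what lets BERT use the words to the right of a masked position as well as those to the left.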

Demo

  • Import necessary libraries
import re
import nltk
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pytesseract import image_to_string
from enchant.checker import SpellChecker
from difflib import SequenceMatcher
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
  • Process the image with OCR
imagename = '1.png'
pil_img = Image.open(imagename)
text = image_to_string(pil_img)
text_original = str(text)

print(text)
plt.figure(figsize = (12,4))
plt.imshow(np.asarray(pil_img))

Output:

national economy gained momentum in recent weeks as con@gmer spending
Strengthened, manufacturing activity cont@™ed to rise, and producers
scheduled more investment in plant and equipment.
  • Process text and mask incorrect words
# text cleanup: pad punctuation with spaces and expand common contractions
rep = {'\n': ' ',
       '\\': ' ',
       '"': ' " ',
       '-': ' ',
       ',': ' , ',
       '.': ' . ',
       '!': ' ! ',
       '?': ' ? ',
       "n't": " not",
       "'ll": " will",
       '*': ' * ',
       '(': ' ( ',
       ')': ' ) ',
       "s'": "s '"}

rep = dict((re.escape(k), v) for k, v in rep.items())
pattern = re.compile("|".join(rep.keys()))
text = pattern.sub(lambda m: rep[re.escape(m.group(0))], text)

def get_personslist(text):
    # collect person names so they are not flagged as misspellings
    # (needs the NLTK tokenizer, POS tagger, and NE chunker data to be downloaded)
    personslist = []
    for sent in nltk.sent_tokenize(text):
        for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(sent))):
            if isinstance(chunk, nltk.tree.Tree) and chunk.label() == 'PERSON':
                personslist.append(chunk.leaves()[0][0])
    return list(set(personslist))

personslist = get_personslist(text)
ignorewords = personslist + ["!", ",", ".", "\"", "?", '(', ')', '*', "''"]

# use SpellChecker to find incorrect words
d = SpellChecker("en_US")
words = text.split()
incorrectwords = [w for w in words if not d.check(w) and w not in ignorewords]

# use SpellChecker to get suggested replacements
suggestedwords = [d.suggest(w) for w in incorrectwords]

# replace incorrect words with [MASK]
for w in incorrectwords:
    text = text.replace(w, '[MASK]')

print(text)

Output:

national economy gained momentum in recent weeks as [MASK] spending Strengthened ,  manufacturing activity [MASK] to rise ,  and producers  scheduled more investment in plant and equipment .
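One thing to watch in the masking step: `str.replace` works on substrings, so a flagged word that also appears inside a longer, correct word would be masked there too. The sketch below shows a token-level alternative. It is only an illustration: `mask_tokens` is a hypothetical helper, and the toy `vocab` set stands in for the enchant dictionary so that the example runs on its own.

```python
def mask_tokens(text, is_known, ignore=()):
    """Replace each whitespace-delimited token that fails is_known with [MASK].

    Returns the masked text and the flagged tokens, one entry per [MASK],
    in the order they appear.
    """
    masked, flagged = [], []
    for token in text.split():
        if token in ignore or is_known(token):
            masked.append(token)
        else:
            masked.append('[MASK]')
            flagged.append(token)
    return ' '.join(masked), flagged

# Toy vocabulary (an assumption; the post uses enchant's SpellChecker instead).
vocab = {'economy', 'of', 'tehran'}
sample = 'teh economy of tehran'

masked_text, flagged = mask_tokens(sample, lambda w: w.lower() in vocab)
naive = sample.replace('teh', '[MASK]')

print(masked_text)  # [MASK] economy of tehran
print(flagged)      # ['teh']
print(naive)        # [MASK] economy of [MASK]ran
```

Because the flagged list has exactly one entry per `[MASK]`, it also stays aligned with the mask positions, which matters later when suggestions are looked up by mask index.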

  • Use the pretrained BERT model to predict words

# Tokenize text
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# positions of the [MASK] tokens in the tokenized text
MASKIDS = [i for i, e in enumerate(tokenized_text) if e == '[MASK]']

# Create the segments tensors (one segment id per sentence, split on ".")
segs = [i for i, e in enumerate(tokenized_text) if e == "."]
segments_ids = []
prev = -1
for k, s in enumerate(segs):
    segments_ids = segments_ids + [k] * (s - prev)
    prev = s
segments_ids = segments_ids + [len(segs)] * (len(tokenized_text) - len(segments_ids))
segments_tensors = torch.tensor([segments_ids])

# prepare Torch inputs
tokens_tensors = torch.tensor([indexed_tokens])

# Load pre-trained model
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()  # switch off dropout so the predictions are deterministic

# Predict all tokens
with torch.no_grad():
    predictions = model(tokens_tensors, segments_tensors)
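The segment-id loop above gives every sentence its own id (0, 1, 2, ...). That works for the single-sentence demo, but as far as I know the pretrained BERT checkpoints only have embeddings for segment ids 0 and 1, so longer texts would likely need the ids capped. Here is a minimal sketch of the same logic as a function, with an optional cap. `sentence_segment_ids` and its `max_segment` parameter are hypothetical names, not part of any library.

```python
def sentence_segment_ids(tokens, max_segment=None):
    """Assign a segment id to each token, starting a new segment after every '.'.

    If max_segment is given, ids are capped at that value.
    """
    ids, current = [], 0
    for tok in tokens:
        ids.append(current if max_segment is None else min(current, max_segment))
        if tok == '.':
            current += 1  # the next token starts a new sentence
    return ids

tokens = ['it', 'rains', '.', 'we', 'stay', '.', 'ok']
uncapped = sentence_segment_ids(tokens)
capped = sentence_segment_ids(tokens, max_segment=1)

print(uncapped)  # [0, 0, 0, 1, 1, 1, 2]
print(capped)    # [0, 0, 0, 1, 1, 1, 1]
```

Whether capping is the right choice depends on the text. For a single block of OCR output, using 0 for every token is another reasonable option.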
  • Match with proposals from SpellChecker
def predict_word(text_original, predictions, MASKIDS):
    # relies on the global tokenizer and suggestedwords defined above
    for i in range(len(MASKIDS)):
        # top 50 BERT candidates for this [MASK] position
        preds = torch.topk(predictions[0, MASKIDS[i]], k=50)
        indices = preds.indices.tolist()
        pred_list = tokenizer.convert_ids_to_tokens(indices)
        sugg_list = suggestedwords[i]
        sim_max = 0
        predicted_token = ''
        # keep the BERT candidate most similar to any SpellChecker suggestion
        for word1 in pred_list:
            for word2 in sugg_list:
                s = SequenceMatcher(None, word1, word2).ratio()
                if s > sim_max:
                    sim_max = s
                    predicted_token = word1
        # fill in the masks one at a time, left to right
        text_original = text_original.replace('[MASK]', predicted_token, 1)
    return text_original

text_refined = predict_word(text, predictions, MASKIDS)
print(text_refined)

Output:

national economy gained momentum in recent weeks as consumer spending Strengthened ,  manufacturing activity continued to rise ,  and producers  scheduled more investment in plant and equipment .
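The matching step can be tried in isolation, without loading BERT. The sketch below runs the same similarity search on two small hand-written lists. The candidate and suggestion lists are made up for illustration (they are not real model or SpellChecker output), and `best_match` is a hypothetical helper name.

```python
from difflib import SequenceMatcher

def best_match(bert_candidates, spell_suggestions):
    """Return the BERT candidate most similar to any spell-checker suggestion."""
    best_word, best_score = '', 0.0
    for cand in bert_candidates:
        for sugg in spell_suggestions:
            score = SequenceMatcher(None, cand, sugg).ratio()
            if score > best_score:
                best_word, best_score = cand, score
    return best_word, best_score

# Hypothetical inputs for the first [MASK] in the demo sentence.
bert_candidates = ['consumer', 'government', 'public']
spell_suggestions = ['conger', 'consumer', 'conjure']

word, score = best_match(bert_candidates, spell_suggestions)
print(word, score)  # consumer 1.0
```

The idea is that BERT supplies words that fit the context while the spell checker supplies words that look like the OCR output, and the word chosen is the one both sources come closest to agreeing on.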

Jul 18, 2020
