Commit 8dcae605 by Steven Bird

Merge branch 'model' of https://github.com/Copper-Head/nltk into model

parents 617c2c4a 96156590
@@ -36,7 +36,6 @@ class NgramModel(ModelI):
     A processing interface for assigning a probability to the next word.
     """

-    # add cutoff
     def __init__(self, n, train, pad_left=True, pad_right=False,
                  estimator=None, *estimator_args, **estimator_kwargs):
         """
@@ -90,42 +89,39 @@ class NgramModel(ModelI):
         # make sure n is greater than zero, otherwise print it
         assert (n > 0), n

-        self._unigram_model = (n == 1)
+        # For explicitness save the check whether this is a unigram model
+        self.is_unigram_model = (n == 1)
+        # save the ngram order number
         self._n = n
+        # save left and right padding
+        self._lpad = ('',) * (n - 1) if pad_left else ()
+        self._rpad = ('',) * (n - 1) if pad_right else ()
         if estimator is None:
             estimator = _estimator

         cfd = ConditionalFreqDist()
+        # set read-only ngrams set (see property declaration below to reconfigure)
         self._ngrams = set()

         # If given a list of strings instead of a list of lists, create enclosing list
         if (train is not None) and isinstance(train[0], compat.string_types):
             train = [train]

-        # we need to keep track of the number of word types we encounter
-        words = set()
         for sent in train:
-            for ngram in ngrams(sent, n, pad_left, pad_right, pad_symbol=''):
+            raw_ngrams = ngrams(sent, n, pad_left, pad_right, pad_symbol='')
+            for ngram in raw_ngrams:
                 self._ngrams.add(ngram)
                 context = tuple(ngram[:-1])
                 token = ngram[-1]
-                cfd[context][token] += 1
-                words.add(token)
+                cfd[(context, token)] += 1

-        # unless number of bins is explicitly passed, we should use the number
-        # of word types encountered during training as the bins value
-        if 'bins' not in estimator_kwargs:
-            estimator_kwargs['bins'] = len(words)
-        missed_words = (1 - int(pad_left) - int(pad_right)) * (n - 1)
-        estimator_kwargs['override_N'] = cfd.N() + missed_words
-        self._model = ConditionalProbDist(cfd, estimator, *estimator_args, **estimator_kwargs)
+        self._probdist = estimator(cfd, *estimator_args, **estimator_kwargs)
         # recursively construct the lower-order models
-        if not self._unigram_model:
+        if not self.is_unigram_model:
             self._backoff = NgramModel(n-1, train,
                                        pad_left, pad_right,
                                        estimator,
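
As background on the training loop in the hunk above: each sentence is padded, sliced into n-grams, and the counts are accumulated under a flat (context, token) key, which is what the `cfd[(context, token)] += 1` line does before the whole distribution is handed to the estimator. Below is a minimal sketch of that bookkeeping, with plain Python standing in for `nltk.util.ngrams` and `ConditionalFreqDist`; the helper and the two-sentence corpus are invented for the example.

```python
# Toy illustration of the counting done in the training loop above, with plain
# Python standing in for nltk.util.ngrams and ConditionalFreqDist.  The helper
# and the two-sentence corpus are made up for the example.
from collections import Counter

def padded_ngrams(tokens, n, pad_left=True, pad_right=False, pad_symbol=''):
    if pad_left:
        tokens = [pad_symbol] * (n - 1) + tokens
    if pad_right:
        tokens = tokens + [pad_symbol] * (n - 1)
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

cfd = Counter()  # counts keyed by (context, token), like cfd[(context, token)] += 1
for sent in [['the', 'cat', 'sat'], ['the', 'dog', 'sat']]:
    for ngram in padded_ngrams(sent, 2):
        context, token = tuple(ngram[:-1]), ngram[-1]
        cfd[(context, token)] += 1

print(cfd[(('',), 'the')])     # 2 -- both sentences start with 'the' after the left pad
print(cfd[(('the',), 'cat')])  # 1
```

In the class itself, the analogous `ConditionalFreqDist` is what the chosen estimator receives on the `self._probdist = estimator(cfd, ...)` line.
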
@@ -135,31 +131,38 @@ class NgramModel(ModelI):
             self._backoff_alphas = dict()
             # For each condition (or context)
             for ctxt in cfd.conditions():
-                prdist = self._model[ctxt] # prob dist for this context
                 backoff_ctxt = ctxt[1:]
                 backoff_total_pr = 0.0
                 total_observed_pr = 0.0
-                for word in cfd[ctxt]:
-                    # this is the subset of words that we OBSERVED
-                    # following this context
-                    total_observed_pr += prdist.prob(word)
-                    # we normalize it by the total (n-1)-gram probability of
-                    # words that were observed in this n-gram context
+                # this is the subset of words that we OBSERVED following
+                # this context.
+                # i.e. Count(word | context) > 0
+                for word in self._words_following(ctxt, cfd):
+                    total_observed_pr += self.prob(word, ctxt)
+                    # we also need the total (n-1)-gram probability of
+                    # words observed in this n-gram context
                     backoff_total_pr += self._backoff.prob(word, backoff_ctxt)
-                assert (0 < total_observed_pr <= 1), total_observed_pr
+                assert (0 <= total_observed_pr <= 1), total_observed_pr
                 # beta is the remaining probability weight after we factor out
-                # the probability of observed words
+                # the probability of observed words.
+                # As a sanity check, both total_observed_pr and backoff_total_pr
+                # must be GE 0, since probabilities are never negative
                 beta = 1.0 - total_observed_pr
                 # backoff total has to be less than one, otherwise we get
-                # ZeroDivision error when we try subtracting it from 1 below
-                assert (0 < backoff_total_pr < 1), backoff_total_pr
+                # an error when we try subtracting it from 1 in the denominator
+                assert (0 <= backoff_total_pr < 1), backoff_total_pr
                 alpha_ctxt = beta / (1.0 - backoff_total_pr)
                 self._backoff_alphas[ctxt] = alpha_ctxt

+    def _words_following(self, context, cond_freq_dist):
+        for ctxt, word in cond_freq_dist.iterkeys():
+            if ctxt == context:
+                yield word
+
     def prob(self, word, context):
         """
         Evaluate the probability of this word in this context using Katz Backoff.
@@ -170,15 +173,17 @@ class NgramModel(ModelI):
         :type context: list(str)
         """
         context = tuple(context)
-        if (context + (word,) in self._ngrams) or (self._unigram_model):
-            return self[context].prob(word)
+        if (context + (word,) in self._ngrams) or (self.is_unigram_model):
+            return self._probdist.prob((context, word))
         else:
             return self._alpha(context) * self._backoff.prob(word, context[1:])

-    # Updated _alpha function, discarded the _beta function
     def _alpha(self, context):
         """Get the backoff alpha value for the given context
         """
+        error_message = "Alphas and backoff are not defined for unigram models"
+        assert not self.is_unigram_model, error_message
         if context in self._backoff_alphas:
             return self._backoff_alphas[context]
         else:
@@ -193,9 +198,20 @@ class NgramModel(ModelI):
         :param context: the context the word is in
         :type context: list(str)
         """
         return -log(self.prob(word, context), 2)

+    @property
+    def ngrams(self):
+        return self._ngrams
+
+    @property
+    def backoff(self):
+        return self._backoff
+
+    @property
+    def probdist(self):
+        return self._probdist
+
     def choose_random_word(self, context):
         '''
         Randomly select a word that is likely to appear in this context.
@@ -224,8 +240,7 @@ class NgramModel(ModelI):
         return text

     def _generate_one(self, context):
-        context = (self._lpad + tuple(context))[-self._n+1:]
-        # print "Context (%d): <%s>" % (self._n, ','.join(context))
+        context = (self._lpad + tuple(context))[- self._n + 1:]
         if context in self:
             return self[context].generate()
         elif self._n > 1:
@@ -245,11 +260,11 @@ class NgramModel(ModelI):
         e = 0.0
         text = list(self._lpad) + text + list(self._rpad)
-        for i in range(self._n-1, len(text)):
-            context = tuple(text[i-self._n+1:i])
+        for i in range(self._n - 1, len(text)):
+            context = tuple(text[i - self._n + 1:i])
             token = text[i]
             e += self.logprob(token, context)
-        return e / float(len(text) - (self._n-1))
+        return e / float(len(text) - (self._n - 1))

     def perplexity(self, text):
         """
@@ -263,10 +278,10 @@ class NgramModel(ModelI):
         return pow(2.0, self.entropy(text))

     def __contains__(self, item):
-        return tuple(item) in self._model
+        return tuple(item) in self._probdist.freqdist

     def __getitem__(self, item):
-        return self._model[tuple(item)]
+        return self._probdist[tuple(item)]

     def __repr__(self):
         return '<NgramModel with %d %d-grams>' % (len(self._ngrams), self._n)
...
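
The backoff weights set up in `__init__` follow the standard Katz construction: for each context seen in training, `alpha_ctxt = beta / (1.0 - backoff_total_pr)` with `beta = 1.0 - total_observed_pr`, and a word that never followed that context is scored as alpha times its probability under the (n-1)-gram backoff model, which is what `prob()` returns in its `else` branch. Below is a self-contained sketch of that arithmetic using plain dicts and add-one smoothing; the vocabulary, counts, and unigram probabilities are invented for illustration and none of this uses the NLTK classes touched here.

```python
# Illustrative sketch of the Katz backoff weights computed in __init__ above.
# Plain dicts and add-one smoothing stand in for ConditionalFreqDist and the
# probability estimators; the words, counts, and unigram probabilities are
# invented for the example.

def add_one_probs(counts, vocab):
    """Laplace-smoothed P(word | context) over the whole vocabulary."""
    total = float(sum(counts.values()) + len(vocab))
    return dict((w, (counts.get(w, 0) + 1) / total) for w in vocab)

def katz_alpha(observed_probs, backoff_probs, observed_words):
    """alpha(context) = leftover probability mass / leftover backoff mass."""
    beta = 1.0 - sum(observed_probs[w] for w in observed_words)
    backoff_seen = sum(backoff_probs[w] for w in observed_words)
    return beta / (1.0 - backoff_seen)

vocab = set(['cat', 'dog', 'fish'])

# bigram counts following the context ('the',); 'fish' was never seen here
bigram_counts = {'cat': 3, 'dog': 1}
observed_words = set(bigram_counts)                  # Count(word | context) > 0

p_bigram = add_one_probs(bigram_counts, vocab)       # smoothed P(word | 'the')
p_unigram = {'cat': 0.3, 'dog': 0.2, 'fish': 0.5}    # lower-order (backoff) model

alpha = katz_alpha(p_bigram, p_unigram, observed_words)

def katz_prob(word):
    if word in observed_words:
        return p_bigram[word]            # seen: use the higher-order estimate
    return alpha * p_unigram[word]       # unseen: back off, scaled by alpha

print(katz_prob('cat'))    # 4/7, straight from the smoothed bigram distribution
print(katz_prob('fish'))   # alpha * 0.5 = 1/7
assert abs(sum(katz_prob(w) for w in vocab) - 1.0) < 1e-9
```

With these numbers the three Katz probabilities again sum to one, which is the point of normalizing the leftover mass by `1 - backoff_total_pr` in the alpha computation.
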