Commit 852378a6 by Alex Rudnick

Merge pull request #346 from alexrudnick/master

Fix for decision trees with binary=True
parents 53ac93c4 e3ca040f
@@ -114,7 +114,7 @@ class DecisionTreeClassifier(ClassifierI):
         if self._default is not None:
             if len(self._decisions) == 1:
                 s += '%sif %s != %r: '% (prefix, self._fname,
-                                         self._decisions.keys()[0])
+                                         list(self._decisions.keys())[0])
             else:
                 s += '%selse: ' % (prefix,)
             if self._default._fname is not None and depth>1:
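The change above (and the matching one in best_binary_stump further down) is a Python 3 compatibility fix: under Python 3, dict.keys() returns a view object that cannot be indexed, so it must be wrapped in list() before taking element [0]. A minimal illustration, not part of the commit:

    decisions = {'sunny': 'play'}
    # decisions.keys()[0]               # Python 3: TypeError, dict_keys is not subscriptable
    first = list(decisions.keys())[0]   # works on both Python 2 and Python 3
    print(first)                        # 'sunny'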
@@ -131,7 +131,7 @@ class DecisionTreeClassifier(ClassifierI):
              support_cutoff=10, binary=False, feature_values=None,
              verbose=False):
         """
-        :param binary: If true, then treat all feature/value pairs a
+        :param binary: If true, then treat all feature/value pairs as
             individual binary features, rather than using a single n-way
             branch for each feature.
         """
@@ -242,8 +242,15 @@ class DecisionTreeClassifier(ClassifierI):
             else:
                 neg_fdist.inc(label)

-        decisions = {feature_value: DecisionTreeClassifier(pos_fdist.max())}
-        default = DecisionTreeClassifier(neg_fdist.max())
+        decisions = {}
+        default = label
+        # But hopefully we have observations!
+        if pos_fdist.N() > 0:
+            decisions = {feature_value: DecisionTreeClassifier(pos_fdist.max())}
+        if neg_fdist.N() > 0:
+            default = DecisionTreeClassifier(neg_fdist.max())
+
         return DecisionTreeClassifier(label, feature_name, decisions, default)

     @staticmethod
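This guard is the core of the fix: FreqDist.max() raises an error when the distribution is empty, which happens whenever all (or none) of the training examples match feature_value, so the old code crashed on such splits. A rough sketch of the failure mode, assuming NLTK's FreqDist behavior:

    from nltk.probability import FreqDist

    pos_fdist = FreqDist()    # no examples matched feature_name == feature_value
    # pos_fdist.max()         # raises ValueError: no samples observed yet
    if pos_fdist.N() > 0:     # the commit's guard: only call max() when samples exist
        best_label = pos_fdist.max()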
@@ -261,7 +268,7 @@ class DecisionTreeClassifier(ClassifierI):
                     best_stump = stump
         if best_stump._decisions:
             descr = '%s=%s' % (best_stump._fname,
-                               best_stump._decisions.keys()[0])
+                               list(best_stump._decisions.keys())[0])
         else:
             descr = '(default)'
         if verbose: