Commit fb6cd01d by Steven Bird

Merge pull request #208 from heatherleaf/trees-are-not-lists

Trees are not lists
parents 27de61ac 1c841d33
...@@ -683,7 +683,7 @@ operations: ...@@ -683,7 +683,7 @@ operations:
>>> ptree.pop(-100) >>> ptree.pop(-100)
Traceback (most recent call last): Traceback (most recent call last):
. . . . . .
IndexError: index out of range IndexError: list index out of range
**remove()** **remove()**
...@@ -697,13 +697,13 @@ operations: ...@@ -697,13 +697,13 @@ operations:
>>> ptree[0,0].remove(make_ptree('(Q p)')) >>> ptree[0,0].remove(make_ptree('(Q p)'))
Traceback (most recent call last): Traceback (most recent call last):
. . . . . .
ValueError: list.index(x): x not in list ValueError
>>> ptree.remove('h'); pcheck(ptree) >>> ptree.remove('h'); pcheck(ptree)
ok! (A (B (C (D )) g)) ok! (A (B (C (D )) g))
>>> ptree.remove('h'); >>> ptree.remove('h');
Traceback (most recent call last): Traceback (most recent call last):
. . . . . .
ValueError: list.index(x): x not in list ValueError
>>> # remove() removes the first subtree that is equal (==) to the >>> # remove() removes the first subtree that is equal (==) to the
>>> # given tree, which may not be the identical tree we give it: >>> # given tree, which may not be the identical tree we give it:
>>> ptree = make_ptree('(A (X x) (Y y) (X x))') >>> ptree = make_ptree('(A (X x) (Y y) (X x))')
...@@ -998,7 +998,7 @@ multiple parents.) ...@@ -998,7 +998,7 @@ multiple parents.)
>>> mptree.pop(-100) >>> mptree.pop(-100)
Traceback (most recent call last): Traceback (most recent call last):
. . . . . .
IndexError: index out of range IndexError: list index out of range
**remove()** **remove()**
...@@ -1012,13 +1012,13 @@ multiple parents.) ...@@ -1012,13 +1012,13 @@ multiple parents.)
>>> mptree[0,0].remove(make_mptree('(Q p)')) >>> mptree[0,0].remove(make_mptree('(Q p)'))
Traceback (most recent call last): Traceback (most recent call last):
. . . . . .
ValueError: list.index(x): x not in list ValueError
>>> mptree.remove('h'); mpcheck(mptree) >>> mptree.remove('h'); mpcheck(mptree)
ok! (A (B (C (D )) g)) ok! (A (B (C (D )) g))
>>> mptree.remove('h'); >>> mptree.remove('h');
Traceback (most recent call last): Traceback (most recent call last):
. . . . . .
ValueError: list.index(x): x not in list ValueError
>>> # remove() removes the first subtree that is equal (==) to the >>> # remove() removes the first subtree that is equal (==) to the
>>> # given tree, which may not be the identical tree we give it: >>> # given tree, which may not be the identical tree we give it:
>>> mptree = make_mptree('(A (X x) (Y y) (X x))') >>> mptree = make_mptree('(A (X x) (Y y) (X x))')
......
...@@ -18,6 +18,7 @@ syntax trees and morphological trees. ...@@ -18,6 +18,7 @@ syntax trees and morphological trees.
import re import re
import string import string
from collections import MutableSequence
from nltk.grammar import Production, Nonterminal from nltk.grammar import Production, Nonterminal
from nltk.probability import ProbabilisticMixIn from nltk.probability import ProbabilisticMixIn
...@@ -27,7 +28,7 @@ from nltk.util import slice_bounds ...@@ -27,7 +28,7 @@ from nltk.util import slice_bounds
## Trees ## Trees
###################################################################### ######################################################################
class Tree(list): class Tree(MutableSequence):
""" """
A Tree represents a hierarchical grouping of leaves and subtrees. A Tree represents a hierarchical grouping of leaves and subtrees.
For example, each constituent in a syntax tree is represented by a single Tree. For example, each constituent in a syntax tree is represented by a single Tree.
...@@ -97,13 +98,13 @@ class Tree(list): ...@@ -97,13 +98,13 @@ class Tree(list):
raise TypeError("%s: Expected a node value and child list " raise TypeError("%s: Expected a node value and child list "
"or a single string" % type(self).__name__) "or a single string" % type(self).__name__)
tree = type(self).parse(node_or_str) tree = type(self).parse(node_or_str)
list.__init__(self, tree) self.children = tree.children
self.node = tree.node self.node = tree.node
elif isinstance(children, basestring): elif isinstance(children, basestring):
raise TypeError("%s() argument 2 should be a list, not a " raise TypeError("%s() argument 2 should be a list, not a "
"string" % type(self).__name__) "string" % type(self).__name__)
else: else:
list.__init__(self, children) self.children = list(children)
self.node = node_or_str self.node = node_or_str
#//////////////////////////////////////////////////////////// #////////////////////////////////////////////////////////////
...@@ -112,42 +113,29 @@ class Tree(list): ...@@ -112,42 +113,29 @@ class Tree(list):
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Tree): return False if not isinstance(other, Tree): return False
return self.node == other.node and list.__eq__(self, other) return self.node == other.node and self.children == other.children
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
def __lt__(self, other): def __lt__(self, other):
if not isinstance(other, Tree): return False if not isinstance(other, Tree): return False
return self.node < other.node or list.__lt__(self, other) return self.node < other.node or self.children < other.children
def __le__(self, other): def __le__(self, other):
if not isinstance(other, Tree): return False if not isinstance(other, Tree): return False
return self.node <= other.node or list.__le__(self, other) return self.node <= other.node or self.children <= other.children
def __gt__(self, other): def __gt__(self, other):
if not isinstance(other, Tree): return True if not isinstance(other, Tree): return True
return self.node > other.node or list.__gt__(self, other) return self.node > other.node or self.children > other.children
def __ge__(self, other): def __ge__(self, other):
if not isinstance(other, Tree): return False if not isinstance(other, Tree): return False
return self.node >= other.node or list.__ge__(self, other) return self.node >= other.node or self.children >= other.children
#//////////////////////////////////////////////////////////// #////////////////////////////////////////////////////////////
# Disabled list operations # Required MutableSequence methods
#////////////////////////////////////////////////////////////
def __mul__(self, v):
raise TypeError('Tree does not support multiplication')
def __rmul__(self, v):
raise TypeError('Tree does not support multiplication')
def __add__(self, v):
raise TypeError('Tree does not support addition')
def __radd__(self, v):
raise TypeError('Tree does not support addition')
#////////////////////////////////////////////////////////////
# Indexing (with support for tree positions)
#//////////////////////////////////////////////////////////// #////////////////////////////////////////////////////////////
def __getitem__(self, index): def __getitem__(self, index):
if isinstance(index, (int, slice)): if isinstance(index, (int, slice)):
return list.__getitem__(self, index) return self.children[index]
elif isinstance(index, (list, tuple)): elif isinstance(index, (list, tuple)):
if len(index) == 0: if len(index) == 0:
return self return self
...@@ -161,7 +149,7 @@ class Tree(list): ...@@ -161,7 +149,7 @@ class Tree(list):
def __setitem__(self, index, value): def __setitem__(self, index, value):
if isinstance(index, (int, slice)): if isinstance(index, (int, slice)):
return list.__setitem__(self, index, value) self.children[index] = value
elif isinstance(index, (list, tuple)): elif isinstance(index, (list, tuple)):
if len(index) == 0: if len(index) == 0:
raise IndexError('The tree position () may not be ' raise IndexError('The tree position () may not be '
...@@ -176,7 +164,7 @@ class Tree(list): ...@@ -176,7 +164,7 @@ class Tree(list):
def __delitem__(self, index): def __delitem__(self, index):
if isinstance(index, (int, slice)): if isinstance(index, (int, slice)):
return list.__delitem__(self, index) del self.children[index]
elif isinstance(index, (list, tuple)): elif isinstance(index, (list, tuple)):
if len(index) == 0: if len(index) == 0:
raise IndexError('The tree position () may not be deleted.') raise IndexError('The tree position () may not be deleted.')
...@@ -188,6 +176,18 @@ class Tree(list): ...@@ -188,6 +176,18 @@ class Tree(list):
raise TypeError("%s indices must be integers, not %s" % raise TypeError("%s indices must be integers, not %s" %
(type(self).__name__, type(index).__name__)) (type(self).__name__, type(index).__name__))
def __contains__(self, child):
return child in self.children
def __len__(self):
return len(self.children)
def __iter__(self):
return iter(self.children)
def insert(self, index, child):
self.children.insert(index, child)
#//////////////////////////////////////////////////////////// #////////////////////////////////////////////////////////////
# Basic tree operations # Basic tree operations
#//////////////////////////////////////////////////////////// #////////////////////////////////////////////////////////////
...@@ -482,7 +482,7 @@ class Tree(list): ...@@ -482,7 +482,7 @@ class Tree(list):
return tree return tree
def copy(self, deep=False): def copy(self, deep=False):
if not deep: return type(self)(self.node, self) if not deep: return type(self)(self.node, self.children[:])
else: return type(self).convert(self) else: return type(self).convert(self)
def _frozen_class(self): return ImmutableTree def _frozen_class(self): return ImmutableTree
...@@ -738,7 +738,7 @@ class ImmutableTree(Tree): ...@@ -738,7 +738,7 @@ class ImmutableTree(Tree):
# Precompute our hash value. This ensures that we're really # Precompute our hash value. This ensures that we're really
# immutable. It also means we only have to calculate it once. # immutable. It also means we only have to calculate it once.
try: try:
self._hash = hash( (self.node, tuple(self)) ) self._hash = hash( (self.node, self.children) )
except (TypeError, ValueError): except (TypeError, ValueError):
raise ValueError("%s: node value and children " raise ValueError("%s: node value and children "
"must be immutable" % type(self).__name__) "must be immutable" % type(self).__name__)
...@@ -772,7 +772,7 @@ class ImmutableTree(Tree): ...@@ -772,7 +772,7 @@ class ImmutableTree(Tree):
@property @property
def node(self): def node(self):
"""Get the node value""" """Get the node value."""
return self._node return self._node
@node.setter @node.setter
...@@ -785,6 +785,20 @@ class ImmutableTree(Tree): ...@@ -785,6 +785,20 @@ class ImmutableTree(Tree):
raise ValueError('%s may not be modified' % type(self).__name__) raise ValueError('%s may not be modified' % type(self).__name__)
self._node = value self._node = value
@property
def children(self):
"""Get the list of children."""
return self._children
@children.setter
def children(self, children):
"""
Set the children. This will only succeed the first time the
children are set, which should occur in ImmutableTree.__init__().
"""
if hasattr(self, 'children'):
raise ValueError('%s may not be modified' % type(self).__name__)
self._children = tuple(children)
###################################################################### ######################################################################
## Parented trees ## Parented trees
...@@ -963,17 +977,6 @@ class AbstractParentedTree(Tree): ...@@ -963,17 +977,6 @@ class AbstractParentedTree(Tree):
raise TypeError("%s indices must be integers, not %s" % raise TypeError("%s indices must be integers, not %s" %
(type(self).__name__, type(index).__name__)) (type(self).__name__, type(index).__name__))
def append(self, child):
if isinstance(child, Tree):
self._setparent(child, len(self))
super(AbstractParentedTree, self).append(child)
def extend(self, children):
for child in children:
if isinstance(child, Tree):
self._setparent(child, len(self))
super(AbstractParentedTree, self).append(child)
def insert(self, index, child): def insert(self, index, child):
# Handle negative indexes. Note that if index < -len(self), # Handle negative indexes. Note that if index < -len(self),
# we do *not* raise an IndexError, unlike __getitem__. This # we do *not* raise an IndexError, unlike __getitem__. This
...@@ -985,34 +988,6 @@ class AbstractParentedTree(Tree): ...@@ -985,34 +988,6 @@ class AbstractParentedTree(Tree):
self._setparent(child, index) self._setparent(child, index)
super(AbstractParentedTree, self).insert(index, child) super(AbstractParentedTree, self).insert(index, child)
def pop(self, index=-1):
if index < 0: index += len(self)
if index < 0: raise IndexError('index out of range')
if isinstance(self[index], Tree):
self._delparent(self[index], index)
return super(AbstractParentedTree, self).pop(index)
# n.b.: like `list`, this is done by equality, not identity!
# To remove a specific child, use del ptree[i].
def remove(self, child):
index = self.index(child)
if isinstance(self[index], Tree):
self._delparent(self[index], index)
super(AbstractParentedTree, self).remove(child)
# We need to implement __getslice__ and friends, even though
# they're deprecated, because otherwise list.__getslice__ will get
# called (since we're subclassing from list). Just delegate to
# __getitem__ etc., but use max(0, start) and max(0, stop) because
# because negative indices are already handled *before*
# __getslice__ is called; and we don't want to double-count them.
if hasattr(list, '__getslice__'):
def __getslice__(self, start, stop):
return self.__getitem__(slice(max(0, start), max(0, stop)))
def __delslice__(self, start, stop):
return self.__delitem__(slice(max(0, start), max(0, stop)))
def __setslice__(self, start, stop, value):
return self.__setitem__(slice(max(0, start), max(0, stop)), value)
class ParentedTree(AbstractParentedTree): class ParentedTree(AbstractParentedTree):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment