Source code for rootpy.tree.categories
from __future__ import absolute_import
import re
from .cut import Cut
__all__ = [
'Categories',
]
[docs]class Categories(object):
"""
Implements a mechanism to ease the creation of cuts that describe
non-overlapping categories.
"""
#TODO: use pyparsing
CUT_REGEX = '[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?'
NODE_PATTERN = re.compile(
'^{(?P<variable>[^:|]+)(?::(?P<type>[IFif]))?\|'
'(?P<leftchild>{.+})?(?P<cut>' + CUT_REGEX + ')'
'(?P<rightchild>{.+})?}$')
CATEGORY_PATTERN = re.compile(
'^(?P<left>{.+})(?:x(?P<right>{.+}(?:x{.+})*))$')
CATEGORY_NODE_PATTERN = re.compile(
'^{(?P<variable>[^:|]+)(?::(?P<type>[IFif]))?\|'
'(?P<cuts>[\*]?(?:' + CUT_REGEX + ')(?:,' + CUT_REGEX + ')*[\*]?)}$')
@classmethod
def from_string(cls, string, variables=None):
node = None
if variables is None:
variables = []
nodematch = re.match(Categories.NODE_PATTERN, string)
categorymatch = re.match(Categories.CATEGORY_PATTERN, string)
categorynodematch = re.match(Categories.CATEGORY_NODE_PATTERN, string)
if categorymatch:
node = Categories.from_string(categorymatch.group('left'), variables)
subtree = Categories.from_string(categorymatch.group('right'), variables)
incompletenodes = node.get_incomplete_children()
for child in incompletenodes:
if not child.leftchild and not child.forbidleft:
clone = subtree.clone()
child.set_left(clone)
if not child.rightchild and not child.forbidright:
clone = subtree.clone()
child.set_right(clone)
elif categorynodematch:
var_type = 'F'
if categorynodematch.group('type'):
var_type = categorynodematch.group('type').upper()
variable = (categorynodematch.group('variable'), var_type)
if variable not in variables:
variables.append(variable)
cuts = categorynodematch.group('cuts').split(',')
if len(cuts) != len(set(cuts)):
raise SyntaxError(
"repeated cuts in '{0}'".format(
categorynodematch.group('cuts')))
if sorted(cuts) != cuts:
raise SyntaxError(
"cuts not in ascending order in '{0}'".format(
categorynodematch.group('cuts')))
nodes = []
for cut in cuts:
actual_cut = cut.replace('*', '')
node = Categories(
feature=variables.index(variable),
data=actual_cut,
variables=variables)
if cut.startswith('*'):
node.forbidleft = True
if cut.endswith('*'):
node.forbidright = True
nodes.append(node)
node = Categories.make_balanced_tree(nodes)
elif nodematch:
var_type = 'F'
if nodematch.group('type'):
var_type = nodematch.group('type').upper()
variable = (nodematch.group('variable'), var_type)
if variable not in variables:
variables.append(variable)
node = Categories(
feature=variables.index(variable),
data=nodematch.group('cut'),
variables=variables)
if nodematch.group('leftchild'):
leftchild = Categories.from_string(
nodematch.group('leftchild'), variables)
node.set_left(leftchild)
if nodematch.group('rightchild'):
rightchild = Categories.from_string(
nodematch.group('rightchild'), variables)
node.set_right(rightchild)
else:
raise SyntaxError(
"{0} is not valid category tree syntax".format(string))
return node
@classmethod
def make_balanced_tree(cls, nodes):
if len(nodes) == 0:
return None
if len(nodes) == 1:
return nodes[0]
center = len(nodes) // 2
leftnodes = nodes[:center]
rightnodes = nodes[center + 1:]
node = nodes[center]
leftchild = Categories.make_balanced_tree(leftnodes)
rightchild = Categories.make_balanced_tree(rightnodes)
node.set_left(leftchild)
node.set_right(rightchild)
return node
def __init__(self,
feature,
data,
variables,
leftchild=None,
rightchild=None,
parent=None,
forbidleft=False,
forbidright=False):
self.feature = feature
self.data = data
self.variables = variables
self.leftchild = leftchild
self.rightchild = rightchild
self.parent = parent
self.forbidleft = forbidleft
self.forbidright = forbidright
def clone(self):
leftclone = None
if self.leftchild is not None:
leftclone = self.leftchild.clone()
rightclone = None
if self.rightchild is not None:
rightclone = self.rightchild.clone()
return Categories(
self.feature,
self.data,
self.variables,
leftclone,
rightclone,
self.parent,
self.forbidleft,
self.forbidright)
def __str__(self):
leftstr = ''
rightstr = ''
if self.forbidleft:
leftstr = '*'
elif self.leftchild is not None:
leftstr = str(self.leftchild)
if self.forbidright:
rightstr = '*'
elif self.rightchild is not None:
rightstr = str(self.rightchild)
if self.feature >= 0:
return '{{0}:{1}|{2}{3}{4}}'.format(
self.variables[self.feature],
leftstr, str(self.data), rightstr)
return '{<<leaf>>|{0}}'.format(str(self.data))
def __repr__(self):
return self.__str__()
def set_left(self, child):
if child is self:
raise ValueError("attempted to set self as left child!")
self.leftchild = child
if child is not None:
child.parent = self
def set_right(self, child):
if child is self:
raise ValueError("attempted to set self as right child!")
self.rightchild = child
if child is not None:
child.parent = self
def is_leaf(self):
return self.leftchild is None and self.rightchild is None
def is_complete(self):
return self.leftchild is not None and self.rightchild is not None
def depth(self):
leftdepth = 0
if self.leftchild is not None:
leftdepth = self.leftchild.depth() + 1
rightdepth = 0
if self.rightchild is not None:
rightdepth = self.rightchild.depth() + 1
return max(leftdepth, rightdepth)
def balance(self):
leftdepth = 0
rightdepth = 0
if self.leftchild is not None:
leftdepth = self.leftchild.depth() + 1
if self.rightchild is not None:
rightdepth = self.rightchild.depth() + 1
return rightdepth - leftdepth
def get_leaves(self):
if self.is_leaf():
return [self]
leftleaves = []
if self.leftchild is not None:
leftleaves = self.leftchild.get_leaves()
rightleaves = []
if self.rightchild is not None:
rightleaves = self.rightchild.get_leaves()
return leftleaves + rightleaves
def get_incomplete_children(self):
children = []
if not self.is_complete():
children.append(self)
if self.leftchild is not None:
children += self.leftchild.get_incomplete_children()
if self.rightchild is not None:
children += self.rightchild.get_incomplete_children()
return children
def __len__(self):
"""
Number of categories beneath current node
"""
if self.is_leaf():
total = 0
if not self.forbidleft:
total += 1
if not self.forbidright:
total += 1
return total
total = 0
if not self.forbidleft and self.leftchild is not None:
total += len(self.leftchild)
if not self.forbidright and self.rightchild is not None:
total += len(self.rightchild)
return total
def walk(self, expression=None):
if expression is None:
expression = Cut()
if self.feature < 0:
if expression:
yield expression
if not self.forbidleft:
leftcondition = expression & Cut(
'{0}<={1}'.format(
self.variables[self.feature][0], self.data))
if self.leftchild is not None:
for condition in self.leftchild.walk(leftcondition):
yield condition
else:
yield leftcondition
if not self.forbidright:
rightcondition = expression & Cut(
'{0}>{1}'.format(
self.variables[self.feature][0], self.data))
if self.rightchild is not None:
for condition in self.rightchild.walk(rightcondition):
yield condition
else:
yield rightcondition
def __iter__(self):
"""
Iterator over leaf conditions
"""
for category in self.walk():
yield category