from __future__ import absolute_import
import multiprocessing
import time
from .. import log; log = log[__name__]
from .. import QROOT
from ..io import root_open, DoesNotExist
from ..utils.extras import humanize_bytes
from ..context import preserve_current_directory
from ..plotting.graph import _GraphBase
from ..extern.six import string_types
from .filtering import EventFilterList
# Public API of this module: the names exported by a star-import.
__all__ = [
'TreeChain',
'TreeQueue',
]
class BaseTreeChain(object):
    """
    Shared machinery for reading one named tree from a sequence of files.

    Subclasses provide the sequence of files by overriding
    :meth:`_next_file`. Attribute lookups that fail on the chain itself are
    forwarded to the tree of the file that is currently open.
    """

    def __init__(self, name,
                 treebuffer=None,
                 branches=None,
                 ignore_branches=None,
                 events=-1,
                 onfilechange=None,
                 read_branches_on_demand=False,
                 cache=False,
                 # 30 MB cache by default
                 cache_size=30000000,
                 learn_entries=10,
                 always_read=None,
                 ignore_unsupported=False,
                 filters=None):
        """
        Open the first usable file and configure its tree.

        Parameters
        ----------
        name : the name used to fetch the tree from every file.
        treebuffer : buffer attached to every tree with ``set_buffer``. When
            None, a buffer is created on the first tree and reused for the
            following ones.
        branches : if not None, passed to ``tree.activate(..., exclusive=True)``.
        ignore_branches : if not None, passed to
            ``tree.deactivate(..., exclusive=False)``.
        events : iteration stops once this many entries have passed the
            filters. The default of -1 never matches, i.e. no limit.
        onfilechange : list of ``(callable, args)`` pairs. Each callable is
            invoked as ``callable(*args, name=..., file=..., tree=...)``
            every time a new tree has been set up.
        read_branches_on_demand : assigned to the attribute of the same name
            on every tree.
        cache, cache_size, learn_entries : when ``cache`` is true,
            ``SetCacheSize(cache_size)`` and
            ``SetCacheLearnEntries(learn_entries)`` are called on every tree.
        always_read : a branch name or an iterable of branch names. Names
            containing ``*`` are expanded with ``tree.glob`` against the
            first tree.
        ignore_unsupported : forwarded to ``tree.create_buffer``.
        filters : callable applied to every entry; it must also provide
            ``finalize()``. Defaults to an empty ``EventFilterList``.

        Raises
        ------
        RuntimeError : if no file could be opened and configured.
        """
        self._name = name
        self._buffer = treebuffer
        self._branches = branches
        self._ignore_branches = ignore_branches
        self._tree = None
        self._file = None
        self._events = events
        self._total_events = 0
        self._ignore_unsupported = ignore_unsupported
        self._initialized = False
        if filters is None:
            self._filters = EventFilterList([])
        else:
            self._filters = filters
        if onfilechange is None:
            onfilechange = []
        self._filechange_hooks = onfilechange
        self._read_branches_on_demand = read_branches_on_demand
        self._use_cache = cache
        self._cache_size = cache_size
        self._learn_entries = learn_entries
        self.weight = 1.
        self.userdata = {}

        # _rollover() reads self._always_read, so it has to exist before the
        # first call instead of being resolved through __getattr__.
        self._always_read = []
        if not self._rollover():
            raise RuntimeError("unable to initialize TreeChain")

        if always_read is None:
            self._always_read = []
        else:
            if isinstance(always_read, string_types):
                always_read = [always_read]
            # Wildcards can only be expanded now that a tree is available.
            expanded = []
            for pattern in always_read:
                if '*' in pattern:
                    expanded += self._tree.glob(pattern)
                else:
                    expanded.append(pattern)
            self.always_read(expanded)

    def __nonzero__(self):
        # Relies on __len__, which is supplied by the subclasses.
        return len(self) > 0

    __bool__ = __nonzero__

    def _next_file(self):
        """
        Override in subclasses

        Return the next filename to open, or None when there is none left.
        """
        return None

    def always_read(self, branches):
        """
        Remember ``branches`` and forward them to the current tree's
        ``always_read``; they are re-applied to every tree opened later.
        """
        self._always_read = branches
        self._tree.always_read(branches)

    def reset(self):
        """
        Drop the current tree and close the current file, if any.
        """
        if self._tree is not None:
            self._tree = None
        if self._file is not None:
            self._file.Close()
            self._file = None

    def Draw(self, *args, **kwargs):
        """
        Loop over subfiles, draw each, and sum the output into a single
        histogram.

        All arguments are forwarded to the ``Draw`` method of every tree.
        Returns None if no tree produced any output.
        """
        self.reset()
        output = None
        while self._rollover():
            if output is None:
                # Make our own copy of the drawn histogram
                output = self._tree.Draw(*args, **kwargs)
                if output is not None:
                    output = output.Clone()
                    # Make it memory resident (histograms)
                    if hasattr(output, 'SetDirectory'):
                        output.SetDirectory(0)
            else:
                newoutput = self._tree.Draw(*args, **kwargs)
                if newoutput is not None:
                    if isinstance(output, _GraphBase):
                        output.Append(newoutput)
                    else:  # histogram
                        output += newoutput
        return output

    draw = Draw

    def __getattr__(self, attr):
        """
        Forward attributes that are missing on the chain to the current tree.
        """
        # This hook only runs after normal lookup has failed. ``_tree`` is
        # read from the instance dict because ``self._tree`` would re-enter
        # this method without end on an instance where it was never assigned.
        tree = self.__dict__.get('_tree')
        try:
            return getattr(tree, attr)
        except AttributeError:
            raise AttributeError("{0} instance has no attribute '{1}'".format(
                self.__class__.__name__, attr))

    def __getitem__(self, item):
        return self._tree.__getitem__(item)

    def __contains__(self, branch):
        return self._tree.__contains__(branch)

    def __iter__(self):
        """
        Yield every entry of every file that passes the filters.

        Iteration restarts from a reset chain, stops early once
        ``events`` entries have been yielded, and calls
        ``filters.finalize()`` when the loop over the files ends.
        """
        passed_events = 0
        self.reset()
        while self._rollover():
            entries = 0
            total_entries = float(self._tree.GetEntries())
            t1 = time.time()
            t2 = t1
            for entry in self._tree:
                entries += 1
                self.userdata = {}
                if self._filters(entry):
                    yield entry
                    passed_events += 1
                    if self._events == passed_events:
                        break
                # Progress report at most once a minute.
                if time.time() - t2 > 60:
                    entry_rate = int(entries / (time.time() - t1))
                    log.info(
                        "{0:d} entr{1} per second. "
                        "{2:.0f}% done current tree.".format(
                            entry_rate,
                            'ies' if entry_rate != 1 else 'y',
                            100 * entries / total_entries))
                    t2 = time.time()
            # Count the entries read from this file before a possible early
            # exit, so that a partially read last file is not left out.
            self._total_events += entries
            if self._events == passed_events:
                break
            elapsed = time.time() - t1
            # A very short (or empty) tree can finish within the clock
            # resolution; report 0 rather than dividing by zero.
            rate = int(entries / elapsed) if elapsed > 0 else 0
            log.info("{0:d} entries per second".format(rate))
            log.info("read {0:d} bytes in {1:d} transactions".format(
                self._file.GetBytesRead(),
                self._file.GetReadCalls()))
        self._filters.finalize()

    def _rollover(self):
        """
        Advance to the next usable file and set up its tree.

        Files that cannot be opened, that lack the tree, or whose tree has no
        branches are skipped with a warning. Returns True once a tree is
        ready, False when :meth:`_next_file` has nothing left.
        """
        # Iterative on purpose: one stack frame per skipped file would
        # overflow the recursion limit on a long run of unusable files.
        while True:
            filename = self._next_file()
            if filename is None:
                return False
            log.info("current file: {0}".format(filename))
            try:
                with preserve_current_directory():
                    if self._file is not None:
                        self._file.Close()
                    self._file = root_open(filename)
            except IOError:
                self._file = None
                log.warning(
                    "could not open file {0} (skipping)".format(filename))
                continue
            try:
                self._tree = self._file.Get(self._name)
            except DoesNotExist:
                log.warning(
                    "tree {0} does not exist in file {1} (skipping)".format(
                        self._name, filename))
                continue
            if len(self._tree.GetListOfBranches()) == 0:
                log.warning(
                    "tree with no branches in file {0} (skipping)".format(
                        filename))
                continue
            break

        if self._branches is not None:
            self._tree.activate(self._branches, exclusive=True)
        if self._ignore_branches is not None:
            self._tree.deactivate(self._ignore_branches, exclusive=False)
        if self._buffer is None:
            self._tree.create_buffer(self._ignore_unsupported)
            self._buffer = self._tree._buffer
        else:
            self._tree.set_buffer(
                self._buffer,
                ignore_missing=True,
                transfer_objects=True)
            self._buffer = self._tree._buffer
        if self._use_cache:
            # enable TTreeCache for this tree
            log.info(
                "enabling a {0} TTreeCache for the current tree "
                "({1:d} learning entries)".format(
                    humanize_bytes(self._cache_size), self._learn_entries))
            self._tree.SetCacheSize(self._cache_size)
            self._tree.SetCacheLearnEntries(self._learn_entries)
        self._tree.read_branches_on_demand = self._read_branches_on_demand
        self._tree.always_read(self._always_read)
        self.weight = self._tree.GetWeight()
        for target, args in self._filechange_hooks:
            # run any user-defined functions
            target(*args, name=self._name, file=self._file, tree=self._tree)
        return True
class TreeChain(BaseTreeChain):
    """
    A ROOT.TChain replacement
    """

    def __init__(self, name, files, **kwargs):
        """
        Build a chain over ``files`` for the tree called ``name``.

        ``files`` may be a list, a tuple, or a single item; a private copy is
        kept so the caller's sequence is never modified. Remaining keyword
        arguments are handed to the base class constructor.

        Raises RuntimeError when ``files`` is empty.
        """
        if isinstance(files, tuple):
            files = list(files)
        elif not isinstance(files, list):
            # anything else is treated as one single file
            files = [files]
        else:
            files = files[:]
        if not files:
            raise RuntimeError(
                "unable to initialize TreeChain: no files")
        self._files = files
        self.curr_file_idx = 0
        super(TreeChain, self).__init__(name, **kwargs)
        # A TChain over the same files, used for the entry counts below.
        self._tchain = QROOT.TChain(name)
        for filename in self._files:
            self._tchain.Add(filename)

    def GetEntries(self, *args, **kwargs):
        """Delegate to ``GetEntries`` of the internal TChain."""
        return self._tchain.GetEntries(*args, **kwargs)

    def GetEntriesFast(self, *args, **kwargs):
        """Delegate to ``GetEntriesFast`` of the internal TChain."""
        return self._tchain.GetEntriesFast(*args, **kwargs)

    def reset(self):
        """
        Reset the chain to the first file
        Note: not valid when in queue mode
        """
        super(TreeChain, self).reset()
        self.curr_file_idx = 0

    def __len__(self):
        """Return the number of files in the chain."""
        return len(self._files)

    def _next_file(self):
        """
        Return the next filename and advance the file index, or return None
        once every file has been handed out.
        """
        if self.curr_file_idx >= len(self._files):
            return None
        filename = self._files[self.curr_file_idx]
        # Computed before the index moves on, so the count includes the
        # file that is being returned.
        nfiles_remaining = len(self._files) - self.curr_file_idx
        log.info("{0:d} file{1} remaining".format(
            nfiles_remaining,
            's' if nfiles_remaining > 1 else ''))
        self.curr_file_idx += 1
        return filename
class TreeQueue(BaseTreeChain):
    """
    A chain of files in a multiprocessing Queue.
    Note that asking for the number of files in the queue with len(treequeue)
    can be unreliable. Also, methods not overridden by TreeQueue will always be
    called on the current tree, so GetEntries will give you the number of
    entries in the current tree.

    NOTE(review): the base constructor already takes the first filename off
    the queue, and iterating or drawing resets the chain and then takes the
    next one, so the first queued file looks like it is never read in queue
    mode -- confirm against the intended usage.
    """
    # Queue item that marks the end of the file list.
    SENTINEL = None

    def __init__(self, name, files, **kwargs):
        """
        Build a chain that takes its filenames from the queue ``files``.

        Remaining keyword arguments are handed to the base class constructor.

        Raises TypeError when ``files`` is not a multiprocessing queue.
        """
        # The ``queues`` submodule is not guaranteed to be loaded by a plain
        # ``import multiprocessing``. Import it explicitly rather than
        # creating (and leaking) a throwaway Queue just for that side effect.
        import multiprocessing.queues
        if not isinstance(files, multiprocessing.queues.Queue):
            raise TypeError("files must be a multiprocessing.Queue")
        self._files = files
        super(TreeQueue, self).__init__(name, **kwargs)

    def __len__(self):
        # not reliable
        return self._files.qsize()

    def __nonzero__(self):
        # not reliable
        return not self._files.empty()

    __bool__ = __nonzero__

    def _next_file(self):
        """
        Block until the queue delivers an item; return it, or None when the
        item is the sentinel.
        """
        filename = self._files.get()
        # Compared with == rather than ``is``: items travel through the queue
        # as copies, so a non-None sentinel would not keep its identity.
        if filename == self.SENTINEL:
            return None
        return filename