from __future__ import absolute_import
import os
import re
import shutil
from glob import glob
import ROOT
from . import log; log = log[__name__]
from ...extern.six import string_types
from ...memory.keepalive import keepalive
from ...utils.silence import silence_sout_serr
from ...utils.path import mkdir_p
from ...context import (
do_nothing, working_directory, preserve_current_directory)
from ...io import root_open
from ... import asrootpy
from . import Channel, Measurement, HistoSys, OverallSys
# Names exported by ``from <this module> import *``; every entry is a function
# defined in this file.
__all__ = [
'make_channel',
'make_measurement',
'make_workspace',
'measurements_from_xml',
'write_measurement',
'patch_xml',
'split_norm_shape',
]
def make_channel(name, samples, data=None, verbose=False):
    """
    Assemble a Channel out of a collection of Samples.

    Parameters
    ----------
    name : string
        Name of the channel. The Channel is created under the name
        ``channel_<name>``.
    samples : iterable
        Samples to attach; each one is passed to ``Channel.AddSample``.
    data : optional (default=None)
        If not None, it is handed to ``Channel.SetData``.
    verbose : bool, optional (default=False)
        If True, log each step on the ``make_channel`` sub-logger.

    Returns
    -------
    channel : Channel
        The new channel, with its stat error config set to
        ``(0.05, "Poisson")``.
    """
    # the sub-logger is only looked up when it is actually going to be used
    logger = log['make_channel'] if verbose else None
    if verbose:
        logger.info("creating channel {0}".format(name))

    # avoid segfault if name begins with a digit by using "channel_" prefix
    channel = Channel('channel_{0}'.format(name))
    channel.SetStatErrorConfig(0.05, "Poisson")

    if data is not None:
        if verbose:
            logger.info("setting data")
        channel.SetData(data)

    for item in samples:
        if verbose:
            logger.info("adding sample {0}".format(item.GetName()))
        channel.AddSample(item)

    return channel
def make_measurement(name,
                     channels,
                     lumi=1.0, lumi_rel_error=0.1,
                     output_prefix='./histfactory',
                     POI=None,
                     const_params=None,
                     verbose=False):
    """
    Assemble a Measurement out of one or more Channels.

    Parameters
    ----------
    name : string
        Name of the measurement. The Measurement is created under the name
        ``measurement_<name>``.
    channels : Channel, or list/tuple of Channels
        Anything that is not a list or tuple is treated as a single channel.
    lumi : float, optional (default=1.0)
        Assigned to the measurement's ``lumi`` attribute.
    lumi_rel_error : float, optional (default=0.1)
        Assigned to the measurement's ``lumi_rel_error`` attribute.
    output_prefix : string, optional (default='./histfactory')
        Passed to ``Measurement.SetOutputFilePrefix``.
    POI : string or iterable of strings, optional (default=None)
        A single string is set with ``SetPOI``; any other non-None value is
        iterated and each element is added with ``AddPOI``.
    const_params : iterable of strings, optional (default=None)
        Each element is added with ``AddConstantParam``.
    verbose : bool, optional (default=False)
        If True, log each step on the ``make_measurement`` sub-logger.

    Returns
    -------
    measurement : Measurement
    """
    logger = log['make_measurement'] if verbose else None
    if verbose:
        logger.info("creating measurement {0}".format(name))

    # accept a bare channel as shorthand for a one-element list
    if not isinstance(channels, (list, tuple)):
        channels = [channels]

    measurement = Measurement('measurement_{0}'.format(name), '')
    measurement.SetOutputFilePrefix(output_prefix)

    if POI is not None:
        single_poi = isinstance(POI, string_types)
        if single_poi:
            if verbose:
                logger.info("setting POI {0}".format(POI))
            measurement.SetPOI(POI)
        else:
            if verbose:
                logger.info("adding POIs {0}".format(', '.join(POI)))
            for poi_name in POI:
                measurement.AddPOI(poi_name)

    if verbose:
        logger.info(
            "setting lumi={0:f} +/- {1:f}".format(lumi, lumi_rel_error))
    measurement.lumi = lumi
    measurement.lumi_rel_error = lumi_rel_error

    for chan in channels:
        if verbose:
            logger.info("adding channel {0}".format(chan.GetName()))
        measurement.AddChannel(chan)

    if const_params is not None:
        if verbose:
            logger.info("adding constant parameters {0}".format(
                ', '.join(const_params)))
        for param_name in const_params:
            measurement.AddConstantParam(param_name)

    return measurement
def make_workspace(measurement, channel=None, name=None, silence=False):
    """
    Build the workspace holding the model of a measurement.

    Parameters
    ----------
    measurement : Measurement
        The measurement to build the model for.
    channel : optional (default=None)
        If None, the combined model over all channels is built with
        ``MakeCombinedModel``. Otherwise a single-channel model is built
        for this channel with ``MakeSingleChannelModel``.
    name : string, optional (default=None)
        If not None, the workspace is renamed to ``workspace_<name>``.
    silence : bool, optional (default=False)
        If True, silence HistFactory's output on stdout and stderr while
        the model is being built.

    Returns
    -------
    workspace
        The ``asrootpy``'d workspace.
    """
    if silence:
        quiet = silence_sout_serr
    else:
        quiet = do_nothing

    with quiet():
        factory = ROOT.RooStats.HistFactory.HistoToWorkspaceFactoryFast(
            measurement)
        if channel is None:
            model = factory.MakeCombinedModel(measurement)
        else:
            model = factory.MakeSingleChannelModel(measurement, channel)

    model = asrootpy(model)
    # register the measurement with keepalive for this workspace
    keepalive(model, measurement)
    if name is not None:
        model.SetName('workspace_{0}'.format(name))
    return model
def measurements_from_xml(filename,
                          collect_histograms=True,
                          cd_parent=False,
                          silence=False):
    """
    Parse a HistFactory XML file into a list of Measurements.

    Parameters
    ----------
    filename : string
        Path to the top-level XML file.
    collect_histograms : bool, optional (default=True)
        If True, call ``CollectHistograms()`` on every parsed measurement.
    cd_parent : bool, optional (default=False)
        If True, parsing runs inside ``working_directory(parent)``, where
        ``parent`` is the directory one level above the directory that
        contains the XML file.
    silence : bool, optional (default=False)
        If True, silence stdout/stderr during the HistFactory calls.

    Returns
    -------
    measurements : list
        The ``asrootpy``'d measurements found in the XML.

    Raises
    ------
    OSError
        If ``filename`` is not an existing file.
    """
    if not os.path.isfile(filename):
        raise OSError("the file {0} does not exist".format(filename))

    quiet = silence_sout_serr if silence else do_nothing
    filename = os.path.abspath(os.path.normpath(filename))

    if cd_parent:
        parent = os.path.abspath(
            os.path.join(os.path.dirname(filename), os.pardir))
        enter_directory = working_directory
    else:
        parent = None
        enter_directory = do_nothing

    log.info("parsing XML in {0} ...".format(filename))
    with enter_directory(parent):
        xml_parser = ROOT.RooStats.HistFactory.ConfigParser()
        with quiet():
            parsed = xml_parser.GetMeasurementsFromXML(filename)
        # prevent the returned vector from being garbage collected
        ROOT.SetOwnership(parsed, False)
        result = []
        for meas in parsed:
            if collect_histograms:
                with quiet():
                    meas.CollectHistograms()
            result.append(asrootpy(meas))
    return result
def write_measurement(measurement,
                      root_file=None,
                      xml_path=None,
                      output_path=None,
                      output_suffix=None,
                      write_workspaces=False,
                      apply_xml_patches=True,
                      silence=False):
    """
    Write a measurement and RooWorkspaces for all contained channels
    into a ROOT file and write the XML files into a directory.

    Parameters
    ----------
    measurement : HistFactory::Measurement
        An asrootpy'd ``HistFactory::Measurement`` object
    root_file : ROOT TFile or string, optional (default=None)
        A ROOT file or string file name. The measurement and workspaces
        will be written to this file. If ``root_file is None`` then a
        new file will be created with the same name as the measurement and
        with the prefix ``ws_``. A file given by name (or created by
        default) is opened in ``recreate`` mode and is closed again before
        returning, even if an error occurs; a file object passed in by the
        caller is left open.
    xml_path : string, optional (default=None)
        A directory path to write the XML into. If None, a new directory with
        the same name as the measurement and with the prefix ``xml_`` will be
        created.
    output_path : string, optional (default=None)
        If ``root_file is None``, create the ROOT file under this path.
        If ``xml_path is None``, create the XML directory under this path.
    output_suffix : string, optional (default=None)
        If ``root_file is None`` then a new file is created with the same name
        as the measurement and with the prefix ``ws_``. ``output_suffix`` will
        append a suffix to this file name (before the .root extension).
        If ``xml_path is None``, then a new directory is created with the
        same name as the measurement and with the prefix ``xml_``.
        ``output_suffix`` will append a suffix to this directory name.
    write_workspaces : bool, optional (default=False)
        If True then also write a RooWorkspace for each channel and for all
        channels combined.
    apply_xml_patches : bool, optional (default=True)
        Apply fixes on the output of ``Measurement::PrintXML()`` to avoid known
        HistFactory bugs. Some of the patches assume that the ROOT file
        containing the histograms will exist one directory level up from the
        XML and that hist2workspace, or any tool that later reads the XML will
        run from that same directory containing the ROOT file.
    silence : bool, optional (default=False)
        If True then capture and silence all stdout/stderr output from
        HistFactory.
    """
    context = silence_sout_serr if silence else do_nothing

    # base name shared by the default XML directory and default ROOT file
    output_name = measurement.name
    if output_suffix is not None:
        output_name += '_{0}'.format(output_suffix)
    output_name = output_name.replace(' ', '_')

    if xml_path is None:
        xml_path = 'xml_{0}'.format(output_name)
        if output_path is not None:
            xml_path = os.path.join(output_path, xml_path)

    if not os.path.exists(xml_path):
        mkdir_p(xml_path)

    if root_file is None:
        root_file = 'ws_{0}.root'.format(output_name)
        if output_path is not None:
            root_file = os.path.join(output_path, root_file)

    # only close the file at the end if it was opened here
    own_file = False
    if isinstance(root_file, string_types):
        root_file = root_open(root_file, 'recreate')
        own_file = True

    try:
        with preserve_current_directory():
            root_file.cd()

            log.info("writing histograms and measurement in {0} ...".format(
                root_file.GetName()))
            with context():
                measurement.writeToFile(root_file)
            # get modified measurement
            out_m = root_file.Get(measurement.name)
            log.info("writing XML in {0} ...".format(xml_path))
            with context():
                out_m.PrintXML(xml_path)

            if write_workspaces:
                log.info("writing combined model in {0} ...".format(
                    root_file.GetName()))
                workspace = make_workspace(measurement, silence=silence)
                workspace.Write()
                for channel in measurement.channels:
                    log.info(
                        "writing model for channel `{0}` in {1} ...".format(
                            channel.name, root_file.GetName()))
                    workspace = make_workspace(
                        measurement, channel=channel, silence=silence)
                    workspace.Write()

        if apply_xml_patches:
            # patch the output XML to avoid HistFactory bugs
            patch_xml(glob(os.path.join(xml_path, '*.xml')),
                      root_file=os.path.basename(root_file.GetName()))
    finally:
        # do not leak a file opened by this function when a step above fails
        if own_file:
            root_file.Close()
def patch_xml(files, root_file=None, float_precision=3):
    """
    Apply patches to HistFactory XML output from PrintXML

    Each file is rewritten line by line into ``<file>.tmp``, which then
    replaces the original. If no ``HistFactorySchema.dtd`` sits next to a
    patched file, it is copied from ``$ROOTSYS/etc`` when available;
    otherwise a warning is logged.

    Parameters
    ----------
    files : iterable of strings
        Paths of the XML files to patch in place.
    root_file : string, optional (default=None)
        If not None, every ``InputFile="..."`` attribute is rewritten to
        point to this name.
    float_precision : int, optional (default=3)
        Quoted decimal numbers with more than this many digits after the
        decimal point are rounded to this many digits.

    Raises
    ------
    ValueError
        If ``float_precision`` is negative.
    """
    if float_precision < 0:
        raise ValueError("float_precision must not be negative")

    def fix_path(match):
        # keep only the last directory component and the file name
        path = match.group(1)
        if path:
            head, tail = os.path.split(path)
            new_path = os.path.join(os.path.basename(head), tail)
        else:
            new_path = ''
        return '<Input>{0}</Input>'.format(new_path)

    def round_float(match):
        return '"{0}"'.format(
            str(round(float(match.group(1)), float_precision)))

    # loop-invariant pieces, built once instead of once per line
    long_float = re.compile(
        r'"(\d*\.\d{{{0:d},}})"'.format(float_precision + 1))
    if root_file is not None:
        input_file_attr = 'InputFile="{0}"'.format(root_file)

    for xmlfilename in files:
        xmlfilename = os.path.abspath(os.path.normpath(xmlfilename))
        patched_xmlfilename = '{0}.tmp'.format(xmlfilename)
        log.info("patching {0} ...".format(xmlfilename))
        # context managers close both handles even if a line fails to patch
        with open(xmlfilename, 'r') as fin:
            with open(patched_xmlfilename, 'w') as fout:
                for line in fin:
                    if root_file is not None:
                        line = re.sub(
                            r'InputFile="[^"]*"', input_file_attr, line)
                    line = line.replace(
                        '<StatError Activate="True" InputFile="" '
                        'HistoName="" HistoPath="" />',
                        '<StatError Activate="True" />')
                    line = re.sub(
                        r'<Combination OutputFilePrefix="(\S*)" >',
                        '<Combination OutputFilePrefix="hist2workspace" >',
                        line)
                    # drop empty attributes and tidy the whitespace left over
                    line = re.sub(r'\w+=""', '', line)
                    line = re.sub(r'\s+/>', ' />', line)
                    line = re.sub(r'(\S)\s+</', r'\1</', line)
                    # HistFactory bug:
                    line = re.sub(r'InputFileHigh="\S+"', '', line)
                    line = re.sub(r'InputFileLow="\S+"', '', line)
                    # HistFactory bug:
                    line = line.replace(
                        '<ParamSetting Const="True"></ParamSetting>', '')
                    # chop off floats to desired precision
                    line = long_float.sub(round_float, line)
                    line = re.sub(r'"\s\s+(\S)', r'" \1', line)
                    line = re.sub(r'<Input>(.*)</Input>', fix_path, line)
                    fout.write(line)
        shutil.move(patched_xmlfilename, xmlfilename)

        # make sure the DTD sits next to the patched XML
        target = os.path.dirname(xmlfilename)
        if os.path.isfile(os.path.join(target, 'HistFactorySchema.dtd')):
            continue
        rootsys = os.getenv('ROOTSYS', None)
        if rootsys is None:
            log.warning(
                "$ROOTSYS is not set so cannot find HistFactorySchema.dtd")
            continue
        dtdfile = os.path.join(rootsys, 'etc/HistFactorySchema.dtd')
        if os.path.isfile(dtdfile):
            log.info("copying {0} to {1} ...".format(dtdfile, target))
            shutil.copy(dtdfile, target)
        else:
            log.warning("{0} does not exist".format(dtdfile))
def split_norm_shape(histosys, nominal_hist):
    """
    Factorize a HistoSys into a normalization part (OverallSys) and a shape
    part (HistoSys).

    It is recommended to use OverallSys as much as possible, which tries to
    enforce continuity up to the second derivative during
    interpolation/extrapolation. So, if there is indeed a shape variation,
    then factorize it into shape and normalization components.

    The high and low histograms are cloned (with ``_shape`` appended to
    their names) and each clone is scaled to the integral of
    ``nominal_hist``, overflow included; a variation whose integral is zero
    is left unscaled. The normalization factors are the ratios of the
    variation integrals to the nominal integral, or 1.0 for both when the
    nominal integral is zero.

    Returns
    -------
    (norm, shape) : tuple of (OverallSys, HistoSys)
        Both carry the name of the input ``histosys``.
    """
    high = histosys.GetHistoHigh()
    low = histosys.GetHistoLow()
    high = high.Clone(name=high.name + '_shape')
    low = low.Clone(name=low.name + '_shape')

    total_nominal = nominal_hist.integral(overflow=True)
    total_high = high.integral(overflow=True)
    total_low = low.integral(overflow=True)

    # normalize the shape variations to the nominal yield
    if total_high != 0:
        high.Scale(total_nominal / total_high)
    if total_low != 0:
        low.Scale(total_nominal / total_low)

    shape = HistoSys(histosys.GetName(), low=low, high=high)

    if total_nominal != 0:
        low_factor = total_low / total_nominal
        high_factor = total_high / total_nominal
    else:
        low_factor = 1.
        high_factor = 1.
    norm = OverallSys(histosys.GetName(), low=low_factor, high=high_factor)

    return norm, shape