# Source code for rootpy.plotting.graph

from __future__ import absolute_import

import math
import numbers
from operator import add, sub

import ROOT

from .. import log; log = log[__name__]
from .. import QROOT
from ..extern.six.moves import range
from ..base import NamelessConstructorObject
from ..decorators import snake_case_methods
from .base import Plottable

__all__ = [
    'Graph',
    'Graph1D',
    'Graph2D',
]


class _GraphBase(object):
    """
    Point-access helpers shared by the 1D and 2D graph classes.

    Subclasses are combined with a ROOT TGraph*/TGraph2D* class, which is
    expected to provide GetN, GetX, GetY, SetPoint, Set and the error
    getters/setters used below.
    """

    class GraphPoint(object):
        """
        Class similar to BinProxy for histograms, useful for
        getting single point information.
        """
        class Measurement(object):
            """
            Generalized measurement class, each graph point
            has one for each axis.

            If the graph has no accessor for the axis (e.g. 'z' on a 1D
            graph) every getter returns 0. and the error setters do nothing.
            """
            def __init__(self, graph, axis, idx):
                self.isdefault = not hasattr(graph, axis)
                self.axis_ = axis
                self.index_ = idx
                self.graph_ = graph

            def _get(self, suffix):
                """
                Return ``graph.<axis><suffix>(index)``, or 0. when the graph
                has no such axis.
                """
                if self.isdefault:
                    return 0.
                return getattr(self.graph_, self.axis_ + suffix)(self.index_)

            def _set_error(self, side, val):
                """
                Call ``graph.SetPointE<AXIS><side>(index, val)``; a no-op
                when the graph has no such axis.
                """
                if self.isdefault:
                    return
                getattr(
                    self.graph_,
                    'SetPointE{0}{1}'.format(self.axis_.upper(), side)
                    )(self.index_, val)

            @property
            def value(self):
                return self._get('')

            @value.setter
            def value(self, value):
                # SetPoint takes every coordinate at once, so re-read the
                # coordinates of the other axes and only replace this one
                axes = ['x', 'y']
                if hasattr(self.graph_, 'z'):
                    axes.append('z')
                vals = [value if axis == self.axis_
                        else getattr(self.graph_, axis)(self.index_)
                        for axis in axes]
                self.graph_.SetPoint(self.index_, *vals)

            @property
            def error(self):
                return self._get('err')

            @property
            def error_hi(self):
                return self._get('errh')

            @error_hi.setter
            def error_hi(self, val):
                self._set_error('high', val)

            @property
            def error_low(self):
                return self._get('errl')

            @error_low.setter
            def error_low(self, val):
                self._set_error('low', val)

            @property
            def error_avg(self):
                return self._get('erravg')

            @property
            def error_max(self):
                return self._get('errmax')

        def __init__(self, graph, idx):
            self.graph_ = graph
            self.idx_ = idx

        @property
        def x(self):
            """returns the x coordinate
            """
            return _GraphBase.GraphPoint.Measurement(self.graph_, 'x', self.idx_)

        @property
        def y(self):
            """returns the y coordinate
            """
            return _GraphBase.GraphPoint.Measurement(self.graph_, 'y', self.idx_)

        @property
        def z(self):
            """returns the z coordinate
            """
            return _GraphBase.GraphPoint.Measurement(self.graph_, 'z', self.idx_)

    @classmethod
    def from_file(cls, filename, sep=' ', name=None, title=None):
        """
        Create a graph from a text file holding one point per line.

        Each line must contain ``cls.DIM + 1`` values separated by ``sep``;
        they are passed in order to SetPoint.

        Raises ValueError if a line does not hold that many values.
        """
        with open(filename, 'r') as gfile:
            lines = gfile.readlines()
        numpoints = len(lines)
        graph = cls(numpoints, name=name, title=title)
        for idx, line in enumerate(lines):
            point = list(map(float, line.rstrip().split(sep)))
            if len(point) != cls.DIM + 1:
                raise ValueError(
                    "line {0:d} does not contain "
                    "{1:d} values: {2}".format(
                        idx + 1, cls.DIM + 1, line))
            graph.SetPoint(idx, *point)
        graph.Set(numpoints)
        return graph

    def __len__(self):
        return self.GetN()

    def __iter__(self):
        for index in range(len(self)):
            yield self[index]

    @property
    def num_points(self):
        """Number of points in the graph (settable; must be >= 0)."""
        return self.GetN()

    @num_points.setter
    def num_points(self, n):
        if n < 0:
            raise ValueError("number of points in a graph must "
                             "be non-negative")
        # ROOT, why not SetN with GetN?
        self.Set(n)

    # Each accessor below returns the value for point ``index`` (negative
    # indices wrap around) or, when ``index`` is None, a generator over all
    # points.

    def x(self, index=None):
        if index is None:
            return (self.GetX()[i] for i in range(self.GetN()))
        index = index % len(self)
        return self.GetX()[index]

    def xerr(self, index=None):
        """(low, high) x-error pair(s)."""
        if index is None:
            return ((self.GetEXlow()[i], self.GetEXhigh()[i])
                    for i in range(self.GetN()))
        index = index % len(self)
        return (self.GetErrorXlow(index), self.GetErrorXhigh(index))

    def xerrh(self, index=None):
        if index is None:
            return (self.GetEXhigh()[i] for i in range(self.GetN()))
        index = index % len(self)
        return self.GetErrorXhigh(index)

    def xerrl(self, index=None):
        if index is None:
            return (self.GetEXlow()[i] for i in range(self.GetN()))
        index = index % len(self)
        return self.GetErrorXlow(index)

    def xerravg(self, index=None):
        """Quadrature sum sqrt(high**2 + low**2) of the x errors."""
        if index is None:
            return (self.xerravg(i) for i in range(self.GetN()))
        index = index % len(self)
        return math.sqrt(self.GetErrorXhigh(index) ** 2 +
                         self.GetErrorXlow(index) ** 2)

    def xerrmax(self, index=None):
        """Larger of the low and high x errors."""
        if index is None:
            return (self.xerrmax(i) for i in range(self.GetN()))
        index = index % len(self)
        return max(self.GetErrorXhigh(index),
                   self.GetErrorXlow(index))

    def y(self, index=None):
        if index is None:
            return (self.GetY()[i] for i in range(self.GetN()))
        index = index % len(self)
        return self.GetY()[index]

    def yerr(self, index=None):
        """(low, high) y-error pair(s)."""
        if index is None:
            return (self.yerr(i) for i in range(self.GetN()))
        index = index % len(self)
        return (self.GetErrorYlow(index), self.GetErrorYhigh(index))

    def yerrh(self, index=None):
        if index is None:
            return (self.GetEYhigh()[i] for i in range(self.GetN()))
        index = index % len(self)
        return self.GetEYhigh()[index]

    def yerrl(self, index=None):
        if index is None:
            return (self.GetEYlow()[i] for i in range(self.GetN()))
        index = index % len(self)
        return self.GetEYlow()[index]

    def yerravg(self, index=None):
        """Quadrature sum sqrt(high**2 + low**2) of the y errors."""
        if index is None:
            return (self.yerravg(i) for i in range(self.GetN()))
        index = index % len(self)
        return math.sqrt(self.GetEYhigh()[index] ** 2 +
                         self.GetEYlow()[index] ** 2)

    def yerrmax(self, index=None):
        """Larger of the low and high y errors."""
        if index is None:
            return (self.yerrmax(i) for i in range(self.GetN()))
        index = index % len(self)
        return max(self.GetEYhigh()[index],
                   self.GetEYlow()[index])

    def __getitem__(self, idx):
        return _GraphBase.GraphPoint(self, idx)

    def __setitem__(self, index, point):
        # index == GetN() is accepted; presumably so that assigning one past
        # the end appends via SetPoint (ROOT extends the graph) -- confirm
        if not 0 <= index <= self.GetN():
            raise IndexError("graph point index out of range")
        self.SetPoint(index, *point)


class _Graph1DBase(_GraphBase):
    """
    Arithmetic and manipulation helpers shared by the 1D graph classes.
    """

    @staticmethod
    def _ordered_errors(low, high, factor):
        """
        Return ``(low, high)``, swapped when ``factor`` is negative.

        Multiplying or dividing a value by a negative number mirrors it, so
        its upper uncertainty becomes the lower one and vice versa.
        """
        return (high, low) if factor < 0 else (low, high)

    def _snapshot(self):
        """
        Return a list of ``(x, y, exl, exh, eyl, eyh)`` tuples, one per point.

        Copying the values out first lets callers overwrite this graph's
        points without reading back values they already changed.
        """
        X = self.GetX()
        Y = self.GetY()
        EXlow = self.GetEXlow()
        EXhigh = self.GetEXhigh()
        EYlow = self.GetEYlow()
        EYhigh = self.GetEYhigh()
        return [(X[i], Y[i], EXlow[i], EXhigh[i], EYlow[i], EYhigh[i])
                for i in range(self.GetN())]

    @classmethod
    def divide(cls, top, bottom, option='cp'):
        """
        Return an asymmetric-error Graph holding ``top / bottom`` computed
        by ROOT's TGraphAsymmErrors.Divide with ``option``.

        Graph arguments are first converted to Hist.
        """
        from .hist import Hist
        if isinstance(top, _Graph1DBase):
            top = Hist(top)
        if isinstance(bottom, _Graph1DBase):
            bottom = Hist(bottom)
        ratio = Graph(type='asymm')
        ratio.Divide(top, bottom, option)
        return ratio

    def __add__(self, other):
        copy = self.Clone()
        copy += other
        return copy

    def __radd__(self, other):
        return self + other

    def __sub__(self, other):
        copy = self.Clone()
        copy -= other
        return copy

    def __rsub__(self, other):
        return -1 * (self - other)

    def __div__(self, other):
        copy = self.Clone()
        copy /= other
        return copy

    __truediv__ = __div__

    def __mul__(self, other):
        copy = self.Clone()
        copy *= other
        return copy

    def __rmul__(self, other):
        return self * other

    def _add_in_place(self, other, op):
        """
        Shared body of __iadd__ and __isub__: apply ``op`` (operator.add or
        operator.sub) to the y-values point by point and return self.

        For a number only the y-values change.  For another graph the
        y-errors of both graphs are added in quadrature and the x-errors of
        this graph are kept.
        """
        if isinstance(other, numbers.Real):
            for index in range(len(self)):
                point = self[index]
                self.SetPoint(index, point.x.value, op(point.y.value, other))
            return self
        for index in range(len(self)):
            mypoint = self[index]
            otherpoint = other[index]
            xlow = self.GetEXlow()[index]
            xhigh = self.GetEXhigh()[index]
            ylow = math.hypot(self.GetEYlow()[index],
                              other.GetEYlow()[index])
            yhigh = math.hypot(self.GetEYhigh()[index],
                               other.GetEYhigh()[index])
            self.SetPoint(index, mypoint.x.value,
                          op(mypoint.y.value, otherpoint.y.value))
            self.SetPointError(index, xlow, xhigh, ylow, yhigh)
        return self

    def __iadd__(self, other):
        return self._add_in_place(other, add)

    def __isub__(self, other):
        return self._add_in_place(other, sub)

    def __idiv__(self, other):
        """
        Divide y-values by a number or, point by point, by another graph.

        Errors stay non-negative; dividing by a negative number swaps the
        low/high y-errors.  Raises ZeroDivisionError on a zero divisor.
        """
        if isinstance(other, numbers.Real):
            scale = abs(other)
            for index in range(len(self)):
                point = self[index]
                xlow, xhigh = self.GetEXlow()[index], self.GetEXhigh()[index]
                ylow, yhigh = self._ordered_errors(
                    self.GetEYlow()[index] / scale,
                    self.GetEYhigh()[index] / scale,
                    other)
                self.SetPoint(index, point.x.value, point.y.value / other)
                self.SetPointError(index, xlow, xhigh, ylow, yhigh)
            return self
        for index in range(len(self)):
            mypoint = self[index]
            otherpoint = other[index]
            num = mypoint.y.value
            den = otherpoint.y.value
            ratio = num / den
            xlow = self.GetEXlow()[index]
            xhigh = self.GetEXhigh()[index]
            # |num/den| * sqrt((e_num/num)**2 + (e_den/den)**2) rewritten so
            # it neither divides by num nor goes negative when ratio < 0
            ylow = math.hypot(self.GetEYlow()[index],
                              ratio * other.GetEYlow()[index]) / abs(den)
            yhigh = math.hypot(self.GetEYhigh()[index],
                               ratio * other.GetEYhigh()[index]) / abs(den)
            self.SetPoint(index, mypoint.x.value, ratio)
            self.SetPointError(index, xlow, xhigh, ylow, yhigh)
        return self

    __itruediv__ = __idiv__

    def __imul__(self, other):
        """
        Multiply y-values by a number or, point by point, by another graph.

        Errors stay non-negative; multiplying by a negative number swaps the
        low/high y-errors.
        """
        if isinstance(other, numbers.Real):
            scale = abs(other)
            for index in range(len(self)):
                point = self[index]
                xlow, xhigh = self.GetEXlow()[index], self.GetEXhigh()[index]
                ylow, yhigh = self._ordered_errors(
                    self.GetEYlow()[index] * scale,
                    self.GetEYhigh()[index] * scale,
                    other)
                self.SetPoint(index, point.x.value, point.y.value * other)
                self.SetPointError(index, xlow, xhigh, ylow, yhigh)
            return self
        for index in range(len(self)):
            mypoint = self[index]
            otherpoint = other[index]
            a = mypoint.y.value
            b = otherpoint.y.value
            xlow = self.GetEXlow()[index]
            xhigh = self.GetEXhigh()[index]
            # |a*b| * sqrt((e_a/a)**2 + (e_b/b)**2) == hypot(b*e_a, a*e_b),
            # which is safe for zero values and never negative
            ylow = math.hypot(b * self.GetEYlow()[index],
                              a * other.GetEYlow()[index])
            yhigh = math.hypot(b * self.GetEYhigh()[index],
                               a * other.GetEYhigh()[index])
            self.SetPoint(index, mypoint.x.value, a * b)
            self.SetPointError(index, xlow, xhigh, ylow, yhigh)
        return self

    def GetMaximum(self, include_error=False):
        """Largest y (plus its upper error when ``include_error``)."""
        if not include_error:
            return self.GetYmax()
        summed = map(add, self.y(), self.yerrh())
        return max(summed)

    def GetMinimum(self, include_error=False):
        """Smallest y (minus its lower error when ``include_error``)."""
        if not include_error:
            return self.GetYmin()
        summed = map(sub, self.y(), self.yerrl())
        return min(summed)

    def GetXmin(self):
        if len(self) == 0:
            raise ValueError("Attemping to get xmin of empty graph")
        return ROOT.TMath.MinElement(self.GetN(), self.GetX())

    def GetXmax(self):
        if len(self) == 0:
            raise ValueError("Attempting to get xmax of empty graph")
        return ROOT.TMath.MaxElement(self.GetN(), self.GetX())

    def GetYmin(self):
        if len(self) == 0:
            raise ValueError("Attempting to get ymin of empty graph")
        return ROOT.TMath.MinElement(self.GetN(), self.GetY())

    def GetYmax(self):
        if len(self) == 0:
            raise ValueError("Attempting to get ymax of empty graph!")
        return ROOT.TMath.MaxElement(self.GetN(), self.GetY())

    # TGraphErrors stores one symmetric error per axis, so both the low and
    # high getters map onto GetEX/GetEY for it.

    def GetEXhigh(self):
        if isinstance(self, ROOT.TGraphErrors):
            return self.GetEX()
        return super(_Graph1DBase, self).GetEXhigh()

    def GetEXlow(self):
        if isinstance(self, ROOT.TGraphErrors):
            return self.GetEX()
        return super(_Graph1DBase, self).GetEXlow()

    def GetEYhigh(self):
        if isinstance(self, ROOT.TGraphErrors):
            return self.GetEY()
        return super(_Graph1DBase, self).GetEYhigh()

    def GetEYlow(self):
        if isinstance(self, ROOT.TGraphErrors):
            return self.GetEY()
        return super(_Graph1DBase, self).GetEYlow()

    def Crop(self, x1, x2, copy=False):
        """
        Remove points which lie outside of [x1, x2].

        If x1 and/or x2 is below/above the current lowest/highest
        x-coordinate, an extra point is added at x1 and/or x2 whose y-value
        is ``Eval`` of the original graph there (ROOT's interpolation) and
        whose errors are zero.  Points are assumed to be ordered by
        increasing x, so these extra points go first/last.

        If ``copy`` is True a cropped clone is returned and this graph is
        left unchanged; otherwise this graph is cropped in place and
        returned.

        Raises ValueError if the graph is empty.
        """
        xmin = self.GetXmin()
        xmax = self.GetXmax()
        points = [point for point in self._snapshot()
                  if x1 <= point[0] <= x2]
        # Eval is called before any point of this graph is overwritten
        if x1 < xmin:
            points.insert(0, (x1, self.Eval(x1), 0., 0., 0., 0.))
        if x2 > xmax:
            points.append((x2, self.Eval(x2), 0., 0., 0., 0.))
        cropGraph = self.Clone() if copy else self
        cropGraph.Set(len(points))
        for i, (x, y, exl, exh, eyl, eyh) in enumerate(points):
            cropGraph.SetPoint(i, x, y)
            cropGraph.SetPointError(i, exl, exh, eyl, eyh)
        return cropGraph

    def Reverse(self, copy=False):
        """
        Reverse the order of the points.

        Works on a clone when ``copy`` is True, otherwise in place; returns
        the reversed graph.
        """
        revGraph = self.Clone() if copy else self
        # the snapshot is taken before writing, so an in-place reversal does
        # not read points it has already overwritten
        points = self._snapshot()
        for i, (x, y, exl, exh, eyl, eyh) in enumerate(reversed(points)):
            revGraph.SetPoint(i, x, y)
            revGraph.SetPointError(i, exl, exh, eyl, eyh)
        return revGraph

    def Invert(self, copy=False):
        """
        Interchange the x and y coordinates (and errors) of all points.

        Works on a clone when ``copy`` is True, otherwise in place.
        """
        invGraph = self.Clone() if copy else self
        for i, (x, y, exl, exh, eyl, eyh) in enumerate(self._snapshot()):
            invGraph.SetPoint(i, y, x)
            invGraph.SetPointError(i, eyl, eyh, exl, exh)
        return invGraph

    def Scale(self, value, copy=False):
        """
        Scale the graph vertically by value.

        y-errors are scaled by ``abs(value)`` (and swapped for a negative
        value) so they stay non-negative.  Works on a clone when ``copy``
        is True, otherwise in place.
        """
        scaleGraph = self.Clone() if copy else self
        scale = abs(value)
        for i, (x, y, exl, exh, eyl, eyh) in enumerate(self._snapshot()):
            eyl, eyh = self._ordered_errors(eyl * scale, eyh * scale, value)
            scaleGraph.SetPoint(i, x, y * value)
            scaleGraph.SetPointError(i, exl, exh, eyl, eyh)
        return scaleGraph

    def Stretch(self, value, copy=False):
        """
        Stretch the graph horizontally by a factor of value.

        x-errors are scaled by ``abs(value)`` (and swapped for a negative
        value) so they stay non-negative.  Works on a clone when ``copy``
        is True, otherwise in place.
        """
        stretchGraph = self.Clone() if copy else self
        scale = abs(value)
        for i, (x, y, exl, exh, eyl, eyh) in enumerate(self._snapshot()):
            exl, exh = self._ordered_errors(exl * scale, exh * scale, value)
            stretchGraph.SetPoint(i, x * value, y)
            stretchGraph.SetPointError(i, exl, exh, eyl, eyh)
        return stretchGraph

    def Shift(self, value, copy=False):
        """
        Shift the graph left or right by value.

        Errors are untouched.  Works on a clone when ``copy`` is True,
        otherwise in place.
        """
        shiftGraph = self.Clone() if copy else self
        X = self.GetX()
        Y = self.GetY()
        points = [(X[i], Y[i]) for i in range(self.GetN())]
        for i, (x, y) in enumerate(points):
            shiftGraph.SetPoint(i, x + value, y)
        return shiftGraph

    def Integrate(self):
        """
        Integrate using the trapazoidal method (points assumed ordered in x).
        """
        area = 0.
        X = self.GetX()
        Y = self.GetY()
        for i in range(self.GetN() - 1):
            area += (X[i + 1] - X[i]) * (Y[i] + Y[i + 1]) / 2.
        return area

    def Append(self, other):
        """
        Append points from another graph (errors too when this graph has
        SetPointError).
        """
        orig_len = len(self)
        self.Set(orig_len + len(other))
        ipoint = orig_len
        if hasattr(self, 'SetPointError'):
            for point in other:
                self.SetPoint(ipoint, point.x.value, point.y.value)
                self.SetPointError(
                    ipoint,
                    point.x.error_low, point.x.error_hi,
                    point.y.error_low, point.y.error_hi)
                ipoint += 1
        else:
            for point in other:
                self.SetPoint(ipoint, point.x.value, point.y.value)
                ipoint += 1


class _Graph2DBase(_GraphBase):
    """Accessors for the z axis of 2D graphs."""

    def z(self, index=None):
        """
        Return the z coordinate of point ``index`` (negative indices wrap
        around), or a generator over every z coordinate if ``index`` is None.
        """
        if index is not None:
            return self.GetZ()[index % len(self)]
        return (self.GetZ()[i] for i in range(self.GetN()))

    def zerr(self, index=None):
        """
        Return the z error of point ``index`` (negative indices wrap
        around), or a generator over every z error if ``index`` is None.
        """
        if index is not None:
            return self.GetErrorZ(index % len(self))
        return (self.zerr(i) for i in range(self.GetN()))


# ROOT base class for each value of the ``type`` keyword accepted by Graph
_GRAPH1D_BASES = dict(
    default=QROOT.TGraph,
    asymm=QROOT.TGraphAsymmErrors,
    errors=QROOT.TGraphErrors,
    benterrors=QROOT.TGraphBentErrors,
)
# filled below with one concrete class per entry of _GRAPH1D_BASES
_GRAPH1D_CLASSES = dict()


def _Graph_class(base):
    """
    Build a concrete 1D graph class that derives from the ROOT class ``base``
    and mixes in the rootpy helpers (_Graph1DBase, Plottable,
    NamelessConstructorObject).
    """

    class Graph(_Graph1DBase, Plottable, NamelessConstructorObject,
                base):
        _ROOT = base
        DIM = 1

        def __init__(self, npoints_or_hist=None,
                     name=None, title=None, **kwargs):
            # forward the positional argument only when the caller gave one
            if npoints_or_hist is None:
                ctor_args = ()
            else:
                ctor_args = (npoints_or_hist,)
            super(Graph, self).__init__(*ctor_args, name=name, title=title)
            self._post_init(**kwargs)

    return Graph

# Build one concrete class per ROOT base (with snake_case method aliases
# added by snake_case_methods); Graph.dynamic_cls looks them up by key.
# NOTE: the loop variables ``name`` and ``base`` stay bound at module level.
for name, base in _GRAPH1D_BASES.items():
    _GRAPH1D_CLASSES[name] = snake_case_methods(_Graph_class(base))


class Graph(_Graph1DBase, QROOT.TGraph):
    """
    Returns a Graph object which inherits from the associated ROOT.TGraph*
    class (TGraph, TGraphErrors, TGraphAsymmErrors, TGraphBentErrors).

    The concrete class is selected with the ``type`` keyword argument
    ('default', 'errors', 'asymm' or 'benterrors', case-insensitive;
    'asymm' if omitted).  All other arguments are forwarded to it.
    """
    _ROOT = QROOT.TGraph
    DIM = 1

    @classmethod
    def dynamic_cls(cls, type='asymm'):
        """
        Return the concrete graph class registered under ``type``.

        Raises KeyError for an unknown type.
        """
        return _GRAPH1D_CLASSES[type]

    def __new__(cls, *args, **kwargs):
        graph_type = kwargs.pop('type', 'asymm').lower()
        return cls.dynamic_cls(graph_type)(*args, **kwargs)


# alias Graph1D -> Graph
Graph1D = Graph


# ROOT base class for each value of the ``type`` keyword accepted by Graph2D
_GRAPH2D_BASES = {
    'default': QROOT.TGraph2D,
    'errors': QROOT.TGraph2DErrors,
}
_GRAPH2D_CLASSES = {}


def _Graph2D_class(base):
    """
    Build a concrete 2D graph class that derives from the ROOT class ``base``
    and mixes in the rootpy helpers (_Graph2DBase, Plottable,
    NamelessConstructorObject).
    """

    class Graph2D(_Graph2DBase, Plottable, NamelessConstructorObject,
                  base):
        _ROOT = base
        DIM = 2

        def __init__(self, npoints_or_hist=None,
                     name=None, title=None, **kwargs):
            if npoints_or_hist is not None:
                super(Graph2D, self).__init__(npoints_or_hist,
                                              name=name, title=title)
            else:
                super(Graph2D, self).__init__(name=name, title=title)
            if isinstance(npoints_or_hist, int):
                # ROOT bug in TGraph2D: presumably the npoints constructor
                # does not size the graph, so Set() is called explicitly
                self.Set(npoints_or_hist)
            self._post_init(**kwargs)

    return Graph2D


# Build one concrete class per ROOT base (with snake_case method aliases);
# the loop variables ``name`` and ``base`` stay bound at module level.
for name, base in _GRAPH2D_BASES.items():
    _GRAPH2D_CLASSES[name] = snake_case_methods(_Graph2D_class(base))


class Graph2D(_Graph2DBase, QROOT.TGraph2D):
    """
    Returns a Graph2D object which inherits from the associated
    ROOT.TGraph2D* class (TGraph2D, TGraph2DErrors).

    The concrete class is selected with the ``type`` keyword argument
    ('default' or 'errors', case-insensitive; 'errors' if omitted).  All
    other arguments are forwarded to it.
    """
    _ROOT = QROOT.TGraph2D
    DIM = 2

    @classmethod
    def dynamic_cls(cls, type='errors'):
        """
        Return the concrete 2D graph class registered under ``type``.

        Raises KeyError for an unknown type.
        """
        return _GRAPH2D_CLASSES[type]

    def __new__(cls, *args, **kwargs):
        graph_type = kwargs.pop('type', 'errors').lower()
        return cls.dynamic_cls(graph_type)(*args, **kwargs)