Skip to content
Snippets Groups Projects
Commit 9d22fc67 authored by Ivan Kondov's avatar Ivan Kondov
Browse files

init object from existing pyplot object, replaced deprecated as_matrix()

parent 875caf96
No related branches found
No related tags found
No related merge requests found
""" utility classes to create xy plots and bar charts with pandas """
import itertools
from itertools import cycle
import pandas as pd
import matplotlib
matplotlib.use('Qt5Cairo')
......@@ -8,16 +8,23 @@ import matplotlib.pyplot as plt
class PandasPlot:
""" create xy plots and bar charts from a pandas DataFrame """
def __init__(self, df, showx=None, showy=None, labels=None, ptype='xyplot'):
fig = plt.figure()
self.plotobj = fig.add_subplot(111)
self.plotobj.set_xlabel(labels[showx])
self.plotobj.set_ylabel(labels[showy])
color_defs = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
def __init__(self, df, showx=None, showy=None, labels=None,
ptype='xyplot', plotobj=None, colors=None):
if plotobj is None:
fig = plt.figure()
self.plotobj = fig.add_subplot(111)
self.plotobj.set_xlabel(labels[showx])
self.plotobj.set_ylabel(labels[showy])
else:
self.plotobj = plotobj
self.df = df
self.showx = showx
self.showy = showy
self.labels = labels
self.ptype = ptype
self.colors = cycle(self.color_defs) if colors is None else colors
@classmethod
def from_file(cls, filename, **kwargs):
......@@ -27,9 +34,6 @@ class PandasPlot:
def add_datasets(self, select=None, vary=None, marker='o'):
""" add selected datasets from a dataframe to a plot object """
color_defs = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
colors = itertools.cycle(color_defs)
sel = True
labs = []
for key in select:
......@@ -43,16 +47,18 @@ class PandasPlot:
self.barwidth = 0.5/self.nbars
for numb, val in enumerate(vary[key]):
df2 = self.df[sel & (self.df[key] == val)].sort_values(self.showx)
xdata = df2[self.showx].as_matrix()
ydata = df2[self.showy].as_matrix()
xdata = df2[self.showx].values
ydata = df2[self.showy].values
if len(ydata) == 0:
continue
legtx = ', '.join(labs) +', '+self.labels[key]+': '+str(val)
if self.ptype == 'xyplot':
self.plotobj.plot(xdata, ydata, marker=marker,
c=next(colors), linestyle='-',
c=next(self.colors), linestyle='-',
label=legtx)
elif self.ptype == 'barchart':
self.plotobj.bar(xdata+0.5*numb/self.nbars, ydata,
self.barwidth, color=next(colors),
self.barwidth, color=next(self.colors),
label=legtx)
def show(self, legend=False):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment