Source code for tofu._plot

""" Module providing a basic routine for plotting a shot overview """
# Built-in
import warnings

# Common
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# tofu
try:
    from tofu.version import __version__
    import tofu.utils as utils
except Exception:
    from tofu.version import __version__
    from .. import utils as utils


__all__ = ['plot_shotoverview']

_fs = (12,6)
__github = 'https://github.com/ToFuProject/tofu/issues'
_wintit = 'tofu-%s        report issues / requests at %s'%(__version__, __github)
_dmargin = dict(left=0.04, right=0.99,
                bottom=0.07, top=0.93,
                wspace=0.25, hspace=0.12)
_fontsize = 8
_labelpad = 0
_dcol = {'Ip':'k', 'B':'b', 'Bt':'b',
         'PLH1':(1.,0.,0.),'PLH2':(1.,0.5,0.),
         'PIC1':'',
         'Prad':(1.,0.,1.),
         'q1rhot':(0.8,0.8,0.8),
         'Ax':(0.,1.,0.)}
_lct = [plt.cm.tab20.colors[ii] for ii in [0,2,4,1,3,5]]
_ntMax = 3


[docs]def plot_shotoverview(db, ntMax=_ntMax, indt=0, config=None, inct=[1,5], dcol=None, lct=_lct, fmt_t='06.3f', fs=None, dmargin=None, tit=None, wintit=None, fontsize=_fontsize, labelpad=_labelpad, sharet=True, sharey=True, shareRZ=True, connect=True, draw=True): kh = _plot_shotoverview(db, ntMax=ntMax, indt=0, config=config, inct=inct, dcol=dcol, lct=lct, fmt_t=fmt_t, fs=fs, dmargin=dmargin, tit=tit, wintit=wintit, fontsize=fontsize, labelpad=labelpad, sharet=sharet, sharey=sharey, shareRZ=shareRZ, connect=connect, draw=draw) return kh
###################################################### # plot new ###################################################### def _plot_shotoverview_init(ns=1, sharet=True, sharey=True, shareRZ=True, fontsize=_fontsize, fs=None, wintit=None, dmargin=None): # Fromat inputs if fs is None: fs = _fs elif type(fs) is str and fs.lower()=='a4': fs = (11.7,8.3) if wintit is None: wintit = _wintit if dmargin is None: dmargin = _dmargin # Make figure and axes fig = plt.figure(figsize=fs) if wintit is not None: fig.canvas.set_window_title(wintit) axarr = GridSpec(ns, 3, **dmargin) laxt = [None for ii in range(0,ns)] laxc = [None for ii in range(0,ns)] for ii in range(0,ns): if ii == 0: laxt[ii] = fig.add_subplot(axarr[ii,:2]) laxc[ii] = fig.add_subplot(axarr[ii,2]) sht = laxt[0] if sharet else None shy = laxt[0] if sharey else None shRZ = laxc[0] if shareRZ else None else: laxt[ii] = fig.add_subplot(axarr[ii,:2], sharex=sht, sharey=shy) laxc[ii] = fig.add_subplot(axarr[ii,2], sharex=shRZ, sharey=shRZ) if not shareRZ: ax2.set_aspect('equal', adjustable='datalim') laxc[-1].set_xlabel(r'$R$ ($m$)') laxt[-1].set_xlabel(r'$t$ ($s$)', fontsize=fontsize) # datalim or box must be chosen for shared axis depending on matplotlib # version => let matplotlib decide until support for matplotlib 2.X.X is # stopped laxc[0].set_aspect('equal')#, adjustable='box') xtxt = laxc[0].get_position().bounds[0] dx = laxc[0].get_position().bounds[2] Ytxt, DY = np.sum(laxc[0].get_position().bounds[1::2]), 0.1 axtxtt = fig.add_axes([xtxt, Ytxt, dx, DY], fc='None') # xtxt, Ytxt, dx, DY = 0.01, 0.98, 0.15, 0.02 # axtxtg = fig.add_axes([xtxt, Ytxt, dx, DY], fc='None') # Dict dax = {'t':laxt, 'cross':laxc, 'txtt':[axtxtt]} #'txtg':[axtxtg] # not useful, one group only # Formatting for kk in dax.keys(): for ii in range(0,len(dax[kk])): dax[kk][ii].tick_params(labelsize=fontsize) if 'txt' in kk: dax[kk][ii].patch.set_alpha(0.) for ss in ['left','right','bottom','top']: dax[kk][ii].spines[ss].set_visible(False) dax[kk][ii].set_xticks([]), dax[kk][ii].set_yticks([]) dax[kk][ii].set_xlim(0,1), dax[kk][ii].set_ylim(0,1) return dax def _plot_shotoverview(db, ntMax=_ntMax, indt=0, config=None, inct=[1,5], dcol=None, lct=_lct, fmt_t='06.3f', fs=None, dmargin=None, tit=None, wintit=None, fontsize=_fontsize, labelpad=_labelpad, sharet=True, sharey=True, shareRZ=True, connect=True, draw=True): ######### # Prepare ######### fldict = dict(fontsize=fontsize, labelpad=labelpad) # Preformat if dcol is None: dcol = _dcol ls = sorted(list(db.keys())) ns = len(ls) lcol = ['k','b','r','m'] # Find common time limits tlim = np.vstack([np.vstack([(np.nanmin(vv['t']), np.nanmax(vv['t'])) if 't' in vv.keys() else (-np.inf,np.inf) for vv in db[ss].values()]) for ss in ls]) tlim = (np.min(tlim),np.max(tlim)) # Find common (R,Z) lims if config=None lEq = ['Ax','X','Sep','q1'] if config is None: Anycross = False Rmin, Rmax = np.full((ns,),np.inf), np.full((ns,),-np.inf) Zmin, Zmax = np.full((ns,),np.inf), np.full((ns,),-np.inf) for ii in range(0,ns): for kk in set(db[ls[ii]].keys()).intersection(lEq): if db[ls[ii]][kk]['data2D'].ndim == 2: Rmin[ii] = min(Rmin[ii],np.nanmin(db[ls[ii]][kk]['data2D'][:,0])) Rmax[ii] = max(Rmax[ii],np.nanmax(db[ls[ii]][kk]['data2D'][:,0])) Zmin[ii] = min(Zmin[ii],np.nanmin(db[ls[ii]][kk]['data2D'][:,1])) Zmax[ii] = max(Zmax[ii],np.nanmax(db[ls[ii]][kk]['data2D'][:,1])) else: Rmin[ii] = min(Rmin[ii],np.nanmin(db[ls[ii]][kk]['data2D'][:,0,:])) Rmax[ii] = max(Rmax[ii],np.nanmax(db[ls[ii]][kk]['data2D'][:,0,:])) Zmin[ii] = min(Zmin[ii],np.nanmin(db[ls[ii]][kk]['data2D'][:,1,:])) Zmax[ii] = max(Zmax[ii],np.nanmax(db[ls[ii]][kk]['data2D'][:,1,:])) Anycross = True Rlim = (np.nanmin(Rmin),np.nanmax(Rmax)) Zlim = (np.nanmin(Zmin),np.nanmax(Zmax)) if Anycross is False: Rlim = (1,3) Zlim = (-1,-1) # time vectors and refs lt = [None for ss in ls] lidt = [0 for ss in ls] for ii in range(0,ns): for kk in set(db[ls[ii]].keys()).intersection(lEq): lt[ii] = db[ls[ii]][kk]['t'] lidt[ii] = id(db[ls[ii]][kk]['t']) break else: for kk in set(db[ls[ii]].keys()).difference(lEq): lt[ii] = db[ls[ii]][kk]['t'] lidt[ii] = id(db[ls[ii]][kk]['t']) break else: msg = "No reference time vector found for shot %s"%str(ls[ii]) warnings.warn(msg) # dlextra id for ii in range(0,ns): for kk in set(db[ls[ii]].keys()).intersection(lEq): db[ls[ii]][kk]['id'] = id(db[ls[ii]][kk]['data2D']) ############## # Plot static ############## dax = _plot_shotoverview_init(ns=ns, sharet=sharet, sharey=sharey, shareRZ=shareRZ, fontsize=fontsize, fs=fs, wintit=wintit, dmargin=dmargin) fig = dax['t'][0].figure if tit is None: tit = r"overview of shots " + ', '.join(map('{0:05.0f}'.format,ls)) fig.suptitle(tit) # Plot config and time traces for ii in range(0,ns): dd = db[ls[ii]] # config if config is not None: dax['cross'][ii] = config.plot(proj='cross', lax=dax['cross'][ii], element='P', dLeg=None, draw=False) # time traces for kk in set(dd.keys()).difference(lEq): if 'c' in dd[kk].keys(): c = dd[kk]['c'] else: c = dcol[kk] lab = dd[kk]['label'] + ' (%s)'%dd[kk]['units'] dax['t'][ii].plot(dd[kk]['t'], dd[kk]['data'], ls='-', lw=1., c=c, label=lab) kk = 'Ax' if kk in dd.keys(): if 'c' in dd[kk].keys(): c = dd[kk]['c'] else: c = dcol[kk] x = db[ls[ii]][kk]['data2D'][:,0] y = db[ls[ii]][kk]['data2D'][:,1] dax['t'][ii].plot(lt[ii], x, lw=1., ls='-', label=r'$R_{Ax}$ (m)') dax['t'][ii].plot(lt[ii], y, lw=1., ls='-', label=r'$Z_{Ax}$ (m)') dax['t'][0].axhline(0., ls='--', lw=1., c='k') dax['t'][0].legend(bbox_to_anchor=(0.,1.01,1.,0.1), loc=3, ncol=5, mode='expand', borderaxespad=0., prop={'size':fontsize}) dax['t'][0].set_xlim(tlim) if config is None: try: # DB dax['cross'][0].set_xlim(Rlim) dax['cross'][0].set_ylim(Zlim) except Exception as err: # DB print(Rlim, Zlim) print(Rmin, Rmax) print(Zmin, Zmax) raise err for ii in range(0,ns): dax['t'][ii].set_ylabel('{0:05.0f} data'.format(ls[ii]), fontsize=fontsize) dax['cross'][-1].set_ylabel(r'$Z$ ($m$)', fontsize=fontsize) ################## # Interactivity dict ################## dgroup = {'time': {'nMax':ntMax, 'key':'f1', 'defid':lidt[0], 'defax':dax['t'][0]}} # Group info (make dynamic in later versions ?) # msg = ' '.join(['%s: %s'%(v['key'],k) for k, v in dgroup.items()]) # l0 = dax['txtg'][0].text(0., 0., msg, # color='k', fontweight='bold', # fontsize=6., ha='left', va='center') # dref dref = dict([(lidt[ii], {'group':'time', 'val':lt[ii], 'inc':inct}) for ii in range(0,ns)]) # ddata ddat = {} for ii in range(0,ns): for kk in set(db[ls[ii]].keys()).intersection(lEq): ddat[db[ls[ii]][kk]['id']] = {'val':db[ls[ii]][kk]['data2D'], 'refids':[lidt[ii]]} # dax lax_fix = dax['cross'] + dax['txtt'] # + dax['txtg'] dax2 = dict([(dax['t'][ii], {'ref':{lidt[ii]:'x'}}) for ii in range(0,ns)]) dobj = {} ################## # Populating dobj # One-axes time txt for jj in range(0,ntMax): l0 = dax['txtt'][0].text((0.5+jj)/ntMax, 0., r'', color='k', fontweight='bold', fontsize=fontsize, ha='left', va='bottom') dobj[l0] = {'dupdate':{'txt':{'id':lidt[0], 'lrid':[lidt[0]], 'bstr':'{0:%s}'%fmt_t}}, 'drefid':{lidt[0]:jj}} # Time-dependent nan2 = np.array([np.nan]) for ii in range(0,ns): # time vlines for jj in range(0,ntMax): l0 = dax['t'][ii].axvline(np.nan, c=lct[jj], ls='-', lw=1.) dobj[l0] = {'dupdate':{'xdata':{'id':lidt[ii], 'lrid':[lidt[ii]]}}, 'drefid':{lidt[ii]:jj}} # Eq for kk in set(db[ls[ii]].keys()).intersection(lEq): id_ = db[ls[ii]][kk]['id'] for jj in range(0,ntMax): l0, = dax['cross'][ii].plot(nan2, nan2, ls='-', c=lct[jj], lw=1.) dobj[l0] = {'dupdate':{'data':{'id':id_, 'lrid':[lidt[ii]]}}, 'drefid':{lidt[ii]:jj}} ################## # Instanciate KeyHandler can = fig.canvas can.draw() kh = utils.KeyHandler_mpl(can=can, dgroup=dgroup, dref=dref, ddata=ddat, dobj=dobj, dax=dax2, lax_fix=lax_fix, groupinit='time', follow=True) if connect: kh.disconnect_old() kh.connect() if draw: fig.canvas.draw() return kh