
Plotting API Reference

nelpy.plotting

The nelpy.plotting sub-package provides a variety of plotting functions and tools for visualizing data in nelpy, including raster plots, tuning curves, color utilities, colormaps, and more. It includes convenience wrappers for matplotlib and plotting functions that work directly with nelpy objects.

Main Features
  • Plotting functions for nelpy objects (e.g., rasterplot, epochplot, imagesc)
  • Color palettes and colormaps for scientific visualization
  • Utilities for figure management and aesthetics
  • Context and style management for publication-quality figures

Examples:

>>> import nelpy.plotting as npl
>>> npl.rasterplot(...)
>>> npl.plot_tuning_curves1D(...)

FigureManager

Bases: object

Figure context manager for creating, displaying, and saving figures.

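If neither save nor show is True, the with block is effectively skipped. __enter__ returns (-1, -1), so the first use of fig or ax raises an exception, which __exit__ then suppresses. Statements that come before that first use still run.

For background on this approach, see http://stackoverflow.com/questions/12594148/skipping-execution-of-with-block (a cleaner way to skip the block has not been found so far) and http://stackoverflow.com/questions/11195140/break-or-exit-out-of-with-statement (ideas for nested context managers).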

Parameters:

Name Type Description Default
filename str

Filename without an extension; required when save is True. If the filename has an extension AND formats is empty, that extension is used.

None
save bool

If True, figure will be saved to disk.

False
show bool

If True, figure will be shown.

False
nrows int

Number of subplot rows.

1
ncols int

Number of subplot columns.

1
figsize tuple

Figure size in inches (width, height).

(8, 3)
tight_layout bool

If True, use tight layout.

False
formats list

List of formats to export. If None, defaults to ['pdf', 'png'].

None
dpi float

Resolution of the figure in dots per inch (DPI).

None
verbose bool

If True, print additional output to screen.

True
overwrite bool

If True, file will be overwritten.

False
**kwargs dict

Additional keyword arguments passed to plt.figure().

{}

Examples:

>>> with FigureManager(filename="myfig", save=True, show=False) as (fig, ax):
...     ax.plot([1, 2, 3], [4, 5, 6])
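The skipping behavior described above can be illustrated without matplotlib. The sketch below uses a hypothetical `SkippableBlock` class and a stand-in `FakeAxes` object (neither is part of nelpy) that mimic only the skip logic of FigureManager: when neither save nor show is set, `__enter__` returns (-1, -1) and `__exit__` returns True, so the exception raised by the first use of `ax` ends the block silently.

```python
class FakeAxes:
    """Hypothetical stand-in for a matplotlib Axes; records plot calls."""

    def __init__(self):
        self.calls = []

    def plot(self, *args):
        self.calls.append(args)


class SkippableBlock:
    """Hypothetical sketch of FigureManager's skip logic (no figures are created)."""

    def __init__(self, *, save=False, show=False):
        # As in FigureManager, the block only "runs" when saving or showing.
        self.skip = not (save or show)

    def __enter__(self):
        if self.skip:
            # FigureManager also returns (-1, -1) when skipping.
            return -1, -1
        return "fig", FakeAxes()

    def __exit__(self, exc_type, exc_value, traceback):
        # True suppresses the exception that ended a skipped block;
        # False lets any exception from a real block propagate.
        return self.skip


def run(*, save, show):
    log = []
    with SkippableBlock(save=save, show=show) as (fig, ax):
        log.append("entered")
        ax.plot([1, 2, 3], [4, 5, 6])  # AttributeError when ax is -1
        log.append("plotted")
    return log


print(run(save=False, show=False))
print(run(save=True, show=False))
```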
Source code in nelpy/plotting/utils.py
class FigureManager(object):
    """
    Figure context manager for creating, displaying, and saving figures.

    See http://stackoverflow.com/questions/12594148/skipping-execution-of-with-block
    but I was unable to get a solution so far...

    See http://stackoverflow.com/questions/11195140/break-or-exit-out-of-with-statement
    for additional inspiration for making nested context managers...

    Parameters
    ----------
    filename : str, optional
        Filename without an extension. If an extension is present,
        AND if formats is empty, then the filename extension will be used.
    save : bool, optional
        If True, figure will be saved to disk.
    show : bool, optional
        If True, figure will be shown.
    nrows : int, optional
        Number of subplot rows.
    ncols : int, optional
        Number of subplot columns.
    figsize : tuple, optional
        Figure size in inches (width, height).
    tight_layout : bool, optional
        If True, use tight layout.
    formats : list, optional
        List of formats to export. Defaults to ['pdf', 'png']
    dpi : float, optional
        Resolution of the figure in dots per inch (DPI).
    verbose : bool, optional
        If True, print additional output to screen.
    overwrite : bool, optional
        If True, file will be overwritten.
    **kwargs : dict
        Additional keyword arguments passed to plt.figure().

    Examples
    --------
    >>> with FigureManager(filename="myfig", save=True, show=False) as (fig, ax):
    ...     ax.plot([1, 2, 3], [4, 5, 6])
    """

    class Break(Exception):
        """Exception to break out of the context manager block."""

        pass

    def __init__(
        self,
        *,
        filename=None,
        save=False,
        show=False,
        nrows=1,
        ncols=1,
        figsize=(8, 3),
        tight_layout=False,
        formats=None,
        dpi=None,
        verbose=True,
        overwrite=False,
        **kwargs,
    ):
        self.nrows = nrows
        self.ncols = ncols
        self.figsize = figsize
        self.tight_layout = tight_layout
        self.dpi = dpi
        self.kwargs = kwargs

        self.filename = filename
        self.show = show
        self.save = save
        self.formats = formats
        self.dpi = dpi
        self.verbose = verbose
        self.overwrite = overwrite

        if self.show or self.save:
            self.skip = False
        else:
            self.skip = True

    def __enter__(self):
        """
        Enter the context manager, creating the figure and axes.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The created figure.
        ax : matplotlib.axes.Axes or numpy.ndarray
            The created axes (single or array, depending on nrows/ncols).
        """
        if not self.skip:
            self.fig = plt.figure(figsize=self.figsize, dpi=self.dpi, **self.kwargs)
            self.fig.npl_gs = gridspec.GridSpec(nrows=self.nrows, ncols=self.ncols)

            self.ax = np.array([self.fig.add_subplot(ss) for ss in self.fig.npl_gs])
            # self.fig, self.ax = plt.subplots(nrows=self.nrows,
            #                                  ncols=self.ncols,
            #                                  figsize=self.figsize,
            #                                  tight_layout=self.tight_layout,
            #                                  dpi=self.dpi,
            #                                  **self.kwargs)
            if len(self.ax) == 1:
                self.ax = self.ax[0]

            if self.tight_layout:
                self.fig.npl_gs.tight_layout(self.fig)

            # gs1.tight_layout(fig, rect=[0, 0.03, 1, 0.95])
            if self.fig != plt.gcf():
                self.clear()
                raise RuntimeError("Figure does not match active mpl figure")
            return self.fig, self.ax
        return -1, -1

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit the context manager, saving and/or showing the figure if requested.

        Parameters
        ----------
        exc_type : type
            Exception type, if any.
        exc_value : Exception
            Exception value, if any.
        traceback : traceback
            Traceback object, if any.
        """
        if self.skip:
            return True
        if not exc_type:
            if self.save:
                assert self.filename is not None, "filename has to be specified!"
                savefig(
                    name=self.filename,
                    fig=self.fig,
                    formats=self.formats,
                    dpi=self.dpi,
                    verbose=self.verbose,
                    overwrite=self.overwrite,
                )

            if self.show:
                plt.show(self.fig)
            self.clear()
        else:
            self.clear()
            return False

    def clear(self):
        """
        Close the figure and clean up references.
        """
        plt.close(self.fig)
        del self.ax
        del self.fig

Break

Bases: Exception

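Exception to break out of the context manager block. Note that FigureManager.__exit__ does not handle it specially: unless the block is being skipped, the figure is closed and the exception propagates to the caller.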

Source code in nelpy/plotting/utils.py
class Break(Exception):
    """Exception to break out of the context manager block."""

    pass

clear()

Close the figure and clean up references.

Source code in nelpy/plotting/utils.py
def clear(self):
    """
    Close the figure and clean up references.
    """
    plt.close(self.fig)
    del self.ax
    del self.fig

add_scalebar(ax, *, matchx=False, matchy=False, sizex=None, sizey=None, labelx=None, labely=None, hidex=True, hidey=True, ec='k', **kwargs)

Add scalebars to axes, matching the size to the ticks of the plot and optionally hiding the x and y axes.

Parameters:

Name Type Description Default
ax Axes

The axis to attach scalebars to.

required
matchx bool

If True, set size of x scalebar to spacing between ticks. Default is False.

False
matchy bool

If True, set size of y scalebar to spacing between ticks. Default is False.

False
sizex float

Size of the x scalebar, in data units. Cannot be combined with matchx=True.

None
sizey float

Size of the y scalebar, in data units. Cannot be combined with matchy=True.

None
labelx str

Label for the x scalebar. If None, str(sizex) is used; pass ' ' to force an empty label.

None
labely str

Label for the y scalebar. If None, str(sizey) is used; pass ' ' to force an empty label.

None
hidex bool

If True, hide x-axis of parent. Default is True.

True
hidey bool

If True, hide y-axis of parent. Default is True.

True
ec color

Edge color of the scalebar. Default is 'k'.

'k'
**kwargs dict

Additional arguments passed to AnchoredScaleBar.

{}

Returns:

Name Type Description
ax Axes

The axis containing the scalebar object.

Examples:

>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> add_scalebar(ax, sizex=1, labelx="1 s")
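The decision logic spelled out in the comments of add_scalebar (matching sizes to ticks, choosing the scalebar type, defaulting labels) can be sketched as a pure function. `tick_spacing` and `resolve_scalebar` are hypothetical names; the sketch takes plain lists of tick locations instead of a matplotlib axis and draws nothing.

```python
def tick_spacing(tick_locations):
    """Spacing between the first two major ticks, or False if there are fewer than two."""
    return len(tick_locations) > 1 and (tick_locations[1] - tick_locations[0])


def resolve_scalebar(xticks, yticks, *, matchx=False, matchy=False,
                     sizex=None, sizey=None, labelx=None, labely=None):
    """Hypothetical re-implementation of add_scalebar's size/type/label logic."""
    if matchx and sizex:
        raise ValueError("matchx and sizex cannot both be specified")
    if matchy and sizey:
        raise ValueError("matchy and sizey cannot both be specified")

    # Matching replaces the size with the tick spacing of that axis.
    if matchx:
        sizex = tick_spacing(xticks)
    if matchy:
        sizey = tick_spacing(yticks)

    if not sizex and not sizey:
        raise ValueError("sizex and sizey cannot both be zero")

    # The scalebar type follows from which sizes are set.
    if sizex and sizey:
        sbtype = "both"
    elif sizex:
        sbtype = "horizontal"
    else:
        sbtype = "vertical"

    # Missing labels default to the string form of the size.
    if sizex and labelx is None:
        labelx = str(sizex)
    if sizey and labely is None:
        labely = str(sizey)
    return sbtype, sizex, sizey, labelx, labely


print(resolve_scalebar([0, 0.5, 1.0], [0, 10, 20], matchx=True))
print(resolve_scalebar([], [], sizex=1, sizey=2, labely="2 mV"))
```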
Source code in nelpy/plotting/scalebar.py
def add_scalebar(
    ax,
    *,
    matchx=False,
    matchy=False,
    sizex=None,
    sizey=None,
    labelx=None,
    labely=None,
    hidex=True,
    hidey=True,
    ec="k",
    **kwargs,
):
    """
    Add scalebars to axes, matching the size to the ticks of the plot and optionally hiding the x and y axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis to attach scalebars to.
    matchx : bool, optional
        If True, set size of x scalebar to spacing between ticks. Default is False.
    matchy : bool, optional
        If True, set size of y scalebar to spacing between ticks. Default is False.
    sizex : float, optional
        Size of x scalebar. Used if matchx is False.
    sizey : float, optional
        Size of y scalebar. Used if matchy is False.
    labelx : str, optional
        Label for x scalebar.
    labely : str, optional
        Label for y scalebar.
    hidex : bool, optional
        If True, hide x-axis of parent. Default is True.
    hidey : bool, optional
        If True, hide y-axis of parent. Default is True.
    ec : color, optional
        Edge color of the scalebar. Default is 'k'.
    **kwargs : dict
        Additional arguments passed to AnchoredScaleBar.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis containing the scalebar object.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> add_scalebar(ax, sizex=1, labelx="1 s")
    """

    # determine which type of scalebar to plot:
    # [(horizontal, vertical, both), (matchx, matchy), (labelx, labely)]
    #
    # matchx AND sizex ==> error
    # matchy AND sizey ==> error
    #
    # matchx == True ==> determine sizex
    # matchy == True ==> determine sizey
    #
    # if sizex ==> horizontal
    # if sizey ==> vertical
    # if sizex and sizey ==> both
    #
    # at this point we fully know which type the scalebar is
    #
    # labelx is None ==> determine from size
    # labely is None ==> determine from size
    #
    # NOTE: to force label empty, use labelx = ' '
    #

    # TODO: add logic for inverted axes:
    # yinverted = ax.yaxis_inverted()
    # xinverted = ax.xaxis_inverted()

    def f(axis):
        tick_locations = axis.get_majorticklocs()
        return len(tick_locations) > 1 and (tick_locations[1] - tick_locations[0])

    if matchx and sizex:
        raise ValueError("matchx and sizex cannot both be specified")
    if matchy and sizey:
        raise ValueError("matchy and sizey cannot both be specified")

    if matchx:
        sizex = f(ax.xaxis)
    if matchy:
        sizey = f(ax.yaxis)

    if not sizex and not sizey:
        raise ValueError("sizex and sizey cannot both be zero")

    kwargs["sizex"] = sizex
    kwargs["sizey"] = sizey

    if sizex:
        sbtype = "horizontal"
        if labelx is None:
            labelx = str(sizex)
    if sizey:
        sbtype = "vertical"
        if labely is None:
            labely = str(sizey)
    if sizex and sizey:
        sbtype = "both"

    kwargs["labelx"] = labelx
    kwargs["labely"] = labely
    kwargs["ec"] = ec

    if sbtype == "both":
        # draw horizontal component:
        kwargs["labely"] = " "  # necessary to correct center alignment
        kwargs["ec"] = None  # necessary to correct possible artifact
        sbx = AnchoredScaleBar(ax.transData, xfirst=True, **kwargs)

        # draw vertical component:
        kwargs["ec"] = ec
        kwargs["labelx"] = " "
        kwargs["labely"] = labely
        sby = AnchoredScaleBar(ax.transData, xfirst=False, **kwargs)
        ax.add_artist(sbx)
        ax.add_artist(sby)
    else:
        sb = AnchoredScaleBar(ax.transData, **kwargs)
        ax.add_artist(sb)

    if hidex:
        ax.xaxis.set_visible(False)
    if hidey:
        ax.yaxis.set_visible(False)

    return ax

add_simple_scalebar(text, ax=None, xy=None, length=None, orientation='v', rotation_text=None, xytext=None, **kwargs)

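Add a simple horizontal or vertical scalebar with a label to an axis. In the current implementation, the label is drawn only for vertical scalebars; horizontal scalebars are drawn without text.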

Parameters:

Name Type Description Default
text str

The label for the scalebar.

required
ax Axes

Axis to add the scalebar to. If None, uses current axis.

None
xy tuple of float

Starting (x, y) position for the scalebar. Although the default is None, a value must be supplied.

None
length float

Length of the scalebar. If None, defaults to 10.

None
orientation {'v', 'h', 'vert', 'horz'}

Orientation of the scalebar. 'v' or 'vert' for vertical, 'h' or 'horz' for horizontal. Default is 'v'.

'v'
rotation_text int or str

Rotation of the label text in degrees. If None, defaults to 0; 'v'/'vert' and 'h'/'horz' are converted to 90 and 0, respectively.

None
xytext tuple of float

Position for the label text. Currently ignored: for vertical scalebars the label is always placed at (x + 3, y + length / 2) in data coordinates.

None
**kwargs dict

Additional keyword arguments passed to matplotlib's annotate.

{}

Examples:

>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> add_simple_scalebar("10 s", ax=ax, xy=(0, 0), length=10, orientation="h")
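How add_simple_scalebar normalizes its arguments and where it places the bar and label can be sketched without drawing anything. `simple_scalebar_geometry` is a hypothetical helper that mirrors the source below, including the fixed 3-data-unit label offset and the absence of a label for horizontal bars.

```python
def _normalize_angle(value):
    """Map None/'v'/'vert'/'h'/'horz' to degrees, as add_simple_scalebar does."""
    if value is None:
        return 0
    if value in ("vert", "v"):
        return 90
    if value in ("horz", "h"):
        return 0
    return value


def simple_scalebar_geometry(xy, length=None, orientation="v", rotation_text=None):
    """Hypothetical helper: bar endpoints and label placement, in data coordinates."""
    orientation = _normalize_angle(orientation)
    rotation_text = _normalize_angle(rotation_text)
    if length is None:
        length = 10
    x, y = xy
    if orientation == 0:
        # Horizontal bar from (x, y) to (x + length, y); no label is placed.
        return {"segment": ((x, y), (x + length, y)), "label_xy": None,
                "rotation": rotation_text}
    # Vertical bar; the label sits 3 units to the right, at the bar's midpoint.
    return {"segment": ((x, y), (x, y + length)),
            "label_xy": (x + 3, y + length / 2), "rotation": rotation_text}


print(simple_scalebar_geometry((0, 0), length=10, orientation="h"))
print(simple_scalebar_geometry((5, 2), length=4, rotation_text="v"))
```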
Source code in nelpy/plotting/scalebar.py
def add_simple_scalebar(
    text,
    ax=None,
    xy=None,
    length=None,
    orientation="v",
    rotation_text=None,
    xytext=None,
    **kwargs,
):
    """
    Add a simple horizontal or vertical scalebar with a label to an axis.

    Parameters
    ----------
    text : str
        The label for the scalebar.
    ax : matplotlib.axes.Axes, optional
        Axis to add the scalebar to. If None, uses current axis.
    xy : tuple of float
        Starting (x, y) position for the scalebar.
    length : float, optional
        Length of the scalebar. Default is 10.
    orientation : {'v', 'h', 'vert', 'horz'}, optional
        Orientation of the scalebar. 'v' or 'vert' for vertical, 'h' or 'horz' for horizontal. Default is 'v'.
    rotation_text : int or str, optional
        Rotation of the label text. Default is 0.
    xytext : tuple of float, optional
        Position for the label text. If None, automatically determined.
    **kwargs : dict
        Additional keyword arguments passed to matplotlib's annotate.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> add_simple_scalebar("10 s", ax=ax, xy=(0, 0), length=10, orientation="h")
    """
    if rotation_text is None:
        rotation_text = 0
    if rotation_text == "vert" or rotation_text == "v":
        rotation_text = 90
    if rotation_text == "horz" or rotation_text == "h":
        rotation_text = 0
    if orientation is None:
        orientation = 0
    if orientation == "vert" or orientation == "v":
        orientation = 90
    if orientation == "horz" or orientation == "h":
        orientation = 0

    if length is None:
        length = 10

    if ax is None:
        ax = plt.gca()

    #     if va is None:
    #         if rotation_text == 90:
    #             va = 'bottom'
    #         else:
    #             va = 'baseline'

    if orientation == 0:
        ax.hlines(xy[1], xy[0], xy[0] + length, lw=2, zorder=1000)
    else:
        ax.vlines(xy[0], xy[1], xy[1] + length, lw=2, zorder=1000)
        xytext = (xy[0] + 3, xy[1] + length / 2)
        ax.annotate(
            text, xy=xytext, rotation=rotation_text, va="center", zorder=1000, **kwargs
        )

axes_style(style=None, rc=None)

Return a parameter dict for the aesthetic style of the plots.

This affects things like the color of the axes, whether a grid is enabled by default, and other aesthetic elements.

This function returns an object that can be used in a with statement to temporarily change the style parameters.

Parameters:

Name Type Description Default
style dict, None, or one of {darkgrid, whitegrid, dark, white, ticks}

A dictionary of parameters or the name of a preconfigured set.

None
rc dict

Parameter mappings to override the values in the preset seaborn style dictionaries. This only updates parameters that are considered part of the style definition.

None

Returns:

Name Type Description
style_object _AxesStyle

An object that can be used as a context manager to temporarily set style.

Examples:

>>> st = axes_style("whitegrid")
>>> set_style("ticks", {"xtick.major.size": 8, "ytick.major.size": 8})
>>> import matplotlib.pyplot as plt
>>> with axes_style("white"):
...     f, ax = plt.subplots()
...     ax.plot([0, 1], [0, 1])
See Also

set_style : set the matplotlib parameters for a seaborn theme
plotting_context : return a parameter dict to scale plot elements
color_palette : define the color palette for a plot
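Two behaviors of axes_style can be sketched in plain Python: rc overrides are filtered so that only style keys survive, and the returned object applies the style temporarily inside a with block. `STYLE_KEYS`, `apply_rc_overrides`, and `temporary_style` are hypothetical names, a plain dict stands in for matplotlib.rcParams, and nelpy's actual `_style_keys` list is not reproduced here.

```python
from contextlib import contextmanager

# Hypothetical subset of style keys; nelpy's own _style_keys is defined elsewhere.
STYLE_KEYS = ("axes.facecolor", "axes.grid", "xtick.major.size")


def apply_rc_overrides(style_dict, rc=None, style_keys=STYLE_KEYS):
    """Return a copy of style_dict updated with the style-related entries of rc."""
    style = dict(style_dict)
    if rc is not None:
        # Keys that are not part of the style definition are silently dropped.
        style.update({k: v for k, v in rc.items() if k in style_keys})
    return style


@contextmanager
def temporary_style(params, style):
    """Apply style to params for the duration of a with block, then restore."""
    missing = object()
    saved = {k: params.get(k, missing) for k in style}
    params.update(style)
    try:
        yield params
    finally:
        for k, old in saved.items():
            if old is missing:
                params.pop(k, None)
            else:
                params[k] = old


rc_params = {"axes.grid": False, "axes.facecolor": "white"}
style = apply_rc_overrides(
    {"axes.grid": True}, rc={"axes.facecolor": "#EAEAF2", "lines.linewidth": 3}
)
with temporary_style(rc_params, style):
    print(rc_params)  # style active inside the block
print(rc_params)  # original values restored
```

For the real rcParams, matplotlib's own `matplotlib.rc_context` provides this kind of temporary override.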

Source code in nelpy/plotting/rcmod.py
def axes_style(style=None, rc=None):
    """
    Return a parameter dict for the aesthetic style of the plots.

    This affects things like the color of the axes, whether a grid is
    enabled by default, and other aesthetic elements.

    This function returns an object that can be used in a ``with`` statement
    to temporarily change the style parameters.

    Parameters
    ----------
    style : dict, None, or one of {darkgrid, whitegrid, dark, white, ticks}
        A dictionary of parameters or the name of a preconfigured set.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        style dictionaries. This only updates parameters that are
        considered part of the style definition.

    Returns
    -------
    style_object : _AxesStyle
        An object that can be used as a context manager to temporarily set style.

    Examples
    --------
    >>> st = axes_style("whitegrid")
    >>> set_style("ticks", {"xtick.major.size": 8, "ytick.major.size": 8})
    >>> import matplotlib.pyplot as plt
    >>> with axes_style("white"):
    ...     f, ax = plt.subplots()
    ...     ax.plot([0, 1], [0, 1])

    See Also
    --------
    set_style : set the matplotlib parameters for a seaborn theme
    plotting_context : return a parameter dict to scale plot elements
    color_palette : define the color palette for a plot
    """
    if style is None:
        style_dict = {k: mpl.rcParams[k] for k in _style_keys}

    elif isinstance(style, dict):
        style_dict = style

    else:
        styles = ["white", "dark", "whitegrid", "darkgrid", "ticks"]
        if style not in styles:
            raise ValueError("style must be one of %s" % ", ".join(styles))

        # Define colors here
        dark_gray = ".15"
        light_gray = ".8"

        # Common parameters
        style_dict = {
            "figure.facecolor": "white",
            "text.color": dark_gray,
            "axes.labelcolor": dark_gray,
            "legend.frameon": False,
            "legend.numpoints": 1,
            "legend.scatterpoints": 1,
            "xtick.direction": "out",
            "ytick.direction": "out",
            "xtick.color": dark_gray,
            "ytick.color": dark_gray,
            "axes.axisbelow": True,
            "lines.linewidth": 1.75,
            "image.cmap": "Greys",
            "font.family": ["sans-serif"],
            "font.sans-serif": [
                "DejaVu Sans",
                "Arial",
                "Liberation Sans",
                "Bitstream Vera Sans",
                "sans-serif",
            ],
            "grid.linestyle": "-",
            "lines.solid_capstyle": "round",
        }

        # Set grid on or off
        if "grid" in style:
            style_dict.update(
                {
                    "axes.grid": True,
                }
            )
        else:
            style_dict.update(
                {
                    "axes.grid": False,
                }
            )

        # Set the color of the background, spines, and grids
        if style.startswith("dark"):
            style_dict.update(
                {
                    "axes.facecolor": "#EAEAF2",
                    "axes.edgecolor": "white",
                    "axes.linewidth": 0,
                    "grid.color": "white",
                }
            )

        elif style == "whitegrid":
            style_dict.update(
                {
                    "axes.facecolor": "white",
                    "axes.edgecolor": light_gray,
                    "axes.linewidth": 1,
                    "grid.color": light_gray,
                }
            )

        elif style in ["white", "ticks"]:
            style_dict.update(
                {
                    "axes.facecolor": "white",
                    "axes.edgecolor": dark_gray,
                    "axes.linewidth": 1.25,
                    "grid.color": light_gray,
                }
            )

        # Show or hide the axes ticks
        if style == "ticks":
            style_dict.update(
                {
                    "xtick.major.size": 6,
                    "ytick.major.size": 6,
                    "xtick.minor.size": 3,
                    "ytick.minor.size": 3,
                }
            )
        else:
            style_dict.update(
                {
                    "xtick.major.size": 0,
                    "ytick.major.size": 0,
                    "xtick.minor.size": 0,
                    "ytick.minor.size": 0,
                }
            )

    # Override these settings with the provided rc dictionary
    if rc is not None:
        rc = {k: v for k, v in rc.items() if k in _style_keys}
        style_dict.update(rc)

    # Wrap in an _AxesStyle object so this can be used in a with statement
    style_object = _AxesStyle(style_dict)

    return style_object

colorline(x, y, cmap=None, cm_range=(0, 0.7), **kwargs)

Plot a trajectory of (x, y) points with a colormap along the path.

Parameters:

Name Type Description Default
x array-like

X coordinates of the trajectory.

required
y array-like

Y coordinates of the trajectory.

required
cmap Colormap

Colormap to use for coloring the line. Defaults to plt.cm.Blues_r.

None
cm_range tuple of float

Range of the colormap to use (min, max). Defaults to (0, 0.7).

(0, 0.7)
**kwargs dict

Only two keys are read: ax (axis to draw on; defaults to the current axis) and lw (line width; defaults to 2). Other keys are ignored.

{}

Returns:

Name Type Description
lc LineCollection

The colored line collection added to the axis.

Examples:

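>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> t = np.linspace(0, 2 * np.pi, 100)
>>> lc = colorline(np.cos(t), np.sin(t), cmap=plt.cm.viridis)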
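The core of colorline is turning N points into N - 1 line segments and assigning each point a position along the colormap range. The hypothetical helper `colorline_segments` below reproduces just that numpy step from the source (without creating a LineCollection), which makes the array shapes easy to inspect.

```python
import numpy as np


def colorline_segments(x, y, cm_range=(0, 0.7)):
    """Return (segments, t) as built inside colorline, without plotting."""
    assert len(cm_range) == 2, "cm_range must have (min, max)"
    assert len(x) == len(y), "x and y must have the same number of elements!"

    # One colormap value per point, evenly spaced over cm_range.
    t = np.linspace(cm_range[0], cm_range[1], len(x))

    # Shape (N, 1, 2): each point as a 1x2 array.
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    # Shape (N - 1, 2, 2): consecutive point pairs form the segments.
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    return segments, t


segments, t = colorline_segments([0, 1, 2], [0, 1, 0])
print(segments.shape)  # (2, 2, 2)
print(t)
```

Note that t has one more entry than there are segments; colorline passes it to LineCollection.set_array as is.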
Source code in nelpy/plotting/core.py
def colorline(x, y, cmap=None, cm_range=(0, 0.7), **kwargs):
    """
    Plot a trajectory of (x, y) points with a colormap along the path.

    Parameters
    ----------
    x : array-like
        X coordinates of the trajectory.
    y : array-like
        Y coordinates of the trajectory.
    cmap : matplotlib.colors.Colormap, optional
        Colormap to use for coloring the line. Defaults to plt.cm.Blues_r.
    cm_range : tuple of float, optional
        Range of the colormap to use (min, max). Defaults to (0, 0.7).
    **kwargs : dict
        Additional keyword arguments passed to the plot (e.g., ax, lw).

    Returns
    -------
    lc : matplotlib.collections.LineCollection
        The colored line collection added to the axis.

    Examples
    --------
    >>> colorline(x, y, cmap=plt.cm.viridis)
    """

    # plt.plot(x, y, '-k', zorder=1)
    # plt.scatter(x, y, s=40, c=plt.cm.RdBu(np.linspace(0,1,40)), zorder=2, edgecolor='k')

    assert len(cm_range) == 2, "cm_range must have (min, max)"
    assert len(x) == len(y), "x and y must have the same number of elements!"

    ax = kwargs.get("ax", plt.gca())
    lw = kwargs.get("lw", 2)
    if cmap is None:
        cmap = plt.cm.Blues_r

    t = np.linspace(cm_range[0], cm_range[1], len(x))

    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    lc = LineCollection(segments, cmap=cmap, norm=plt.Normalize(0, 1), zorder=50)
    lc.set_array(t)
    lc.set_linewidth(lw)

    ax.add_collection(lc)

    return lc

decode_and_plot_events1D(*, bst, tc, raster=True, st=None, st_order='track', evt_subset=None, **kwargs)

Decode and plot 1D events with optional raster plot overlay.

Parameters:

Name Type Description Default
bst BinnedSpikeTrainArray

The binned spike train array to decode.

required
tc TuningCurve1D

The tuning curve used for decoding.

required
raster bool

Whether to include a raster plot. Note that the current implementation does not read this flag, so the raster is always drawn.

True
st SpikeTrainArray

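The spike train array for raster plotting. Although the default is None, st must be provided, because its unit IDs are intersected with those of bst and tc.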

None
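st_order str or array-like

Order of units in the raster plot: 'track' (peak-firing order from tc.get_peak_firing_order_ids()), 'first' (firing order from st.get_spike_firing_order()), 'random' (a random permutation), or an array of unit IDs.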

'track'
evt_subset list

Integer indices of the events to decode and plot; sorted in ascending order if not already sorted. If None, all events are used.

None
**kwargs

Additional keyword arguments for plotting. Currently unused (see the TODO in the source).

{}

Returns:

Name Type Description
fig Figure

The figure containing the plot.

Examples:

>>> fig = decode_and_plot_events1D(bst=bst, tc=tc, st=st)
>>> plt.show()
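Before decoding, decode_and_plot_events1D keeps only the units shared by bst, st, and tc, normalizes evt_subset, and later places one x tick at the center of each event. Those bookkeeping steps can be sketched with plain lists and numpy; `common_units`, `normalize_event_subset`, and `event_tick_positions` are hypothetical helpers, and the sketch sorts the shared unit IDs for determinism, whereas the source keeps them in a set.

```python
import numpy as np


def common_units(*unit_id_lists):
    """Unit IDs present in every input (sorted here; the source uses a set)."""
    shared = set(unit_id_lists[0])
    for ids in unit_id_lists[1:]:
        shared &= set(ids)
    return sorted(shared)


def normalize_event_subset(n_events, evt_subset=None):
    """Default to all events and sort the indices, as the source does."""
    if evt_subset is None:
        evt_subset = np.arange(n_events)
    return sorted(int(i) for i in evt_subset)


def event_tick_positions(lengths):
    """Center of each event, in bin units, for events laid end to end."""
    lengths = np.asarray(lengths)
    starts = np.insert(np.cumsum(lengths), 0, 0)[:-1]
    # Bins are drawn centered on integer positions, so an event covering
    # bins s .. s + L - 1 is centered at s + L / 2 - 0.5.
    return starts + lengths / 2 - 0.5


print(common_units([1, 2, 3, 5], [2, 3, 5, 8], [3, 5, 13]))
print(normalize_event_subset(4, [3, 0, 2]))
print(event_tick_positions([4, 2, 6]))
```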
Source code in nelpy/plotting/decoding.py
def decode_and_plot_events1D(
    *, bst, tc, raster=True, st=None, st_order="track", evt_subset=None, **kwargs
):
    """
    Decode and plot 1D events with optional raster plot overlay.

    Parameters
    ----------
    bst : BinnedSpikeTrainArray
        The binned spike train array to decode.
    tc : TuningCurve1D
        The tuning curve used for decoding.
    raster : bool, optional
        Whether to include a raster plot (default is True).
    st : SpikeTrainArray, optional
        The spike train array for raster plotting.
    st_order : str or array-like, optional
        Order of units for raster plot. Options: 'track', 'first', 'random', or array of unit ids.
    evt_subset : list, optional
        List of integer indices for event selection. If not sorted, will be sorted.
    **kwargs
        Additional keyword arguments for plotting.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the plot.

    Examples
    --------
    >>> fig = decode_and_plot_events1D(bst=bst, tc=tc, st=st)
    >>> plt.show()
    """

    # TODO: add **kwargs
    #   fig size, cmap, raster lw, raster color, other axes props, ...

    from mpl_toolkits.axes_grid1 import make_axes_locatable

    unit_ids = set(bst.unit_ids)
    unit_ids = unit_ids.intersection(st.unit_ids)
    unit_ids = unit_ids.intersection(tc.unit_ids)

    bst = bst._unit_subset(unit_ids)
    st = st._unit_subset(unit_ids)
    tc = tc._unit_subset(unit_ids)

    if evt_subset is None:
        evt_subset = np.arange(bst.n_epochs)
    evt_subset = list(evt_subset)
    if not is_sorted(evt_subset):
        evt_subset.sort()
    bst = bst[evt_subset]

    # now that the bst has potentially been restricted by evt_subset, we trim down the spike train as well:
    st = st[bst.support]
    st = collapse_time(st)

    if st_order == "track":
        new_order = tc.get_peak_firing_order_ids()
    elif st_order == "first":
        new_order = st.get_spike_firing_order()
    elif st_order == "random":
        new_order = np.random.permutation(st.unit_ids)
    else:
        new_order = st_order
    st.reorder_units_by_ids(new_order, inplace=True)

    # now decode events in bst:
    posterior, bdries, mode_pth, mean_pth = decoding.decode1D(
        bst=bst, ratemap=tc, xmax=tc.bins[-1]
    )

    fig, ax = plt.subplots(figsize=(bst.n_bins / 5, 4))

    pixel_width = 0.5

    imagesc(
        x=np.arange(bst.n_bins),
        y=np.arange(311),
        data=posterior,
        cmap=plt.cm.Spectral_r,
        ax=ax,
    )
    plotutils.yticks_interval(310)
    plotutils.no_yticks(ax)

    ax.vlines(
        np.arange(bst.lengths.sum()) - pixel_width,
        *ax.get_ylim(),
        lw=1,
        linestyle=":",
        color="0.8",
    )
    ax.vlines(np.cumsum(bst.lengths) - pixel_width, *ax.get_ylim(), lw=1)

    ax.set_xlim(-pixel_width, bst.lengths.sum() - pixel_width)

    event_centers = np.insert(np.cumsum(bst.lengths), 0, 0)
    event_centers = event_centers[:-1] + bst.lengths / 2 - 0.5

    #     ax.set_xticks([0, bst.n_bins-1])
    #     ax.set_xticklabels([1, bst.n_bins])

    ax.set_xticks(event_centers)
    ax.set_xticklabels(evt_subset)
    #     ax.xaxis.tick_top()
    #     ax.xaxis.set_label_position('top')

    plotutils.no_xticks(ax)

    divider = make_axes_locatable(ax)
    axRaster = divider.append_axes("top", size=1.5, pad=0)

    rasterplot(st, vertstack=True, ax=axRaster, lh=1.25, lw=2.5, color="0.1")

    axRaster.set_xlim(st.support.time.squeeze())
    bin_edges = np.linspace(
        st.support.time[0, 0], st.support.time[0, 1], bst.n_bins + 1
    )
    axRaster.vlines(bin_edges, *axRaster.get_ylim(), lw=1, linestyle=":", color="0.8")
    axRaster.vlines(
        bin_edges[np.cumsum(bst.lengths)], *axRaster.get_ylim(), lw=1, color="0.2"
    )
    plotutils.no_xticks(axRaster)
    plotutils.no_xticklabels(axRaster)
    plotutils.no_yticklabels(axRaster)
    plotutils.no_yticks(axRaster)
    ax.set_ylabel("position")
    axRaster.set_ylabel("units")
    ax.set_xlabel("time bins")
    plotutils.clear_left_right(axRaster)
    plotutils.clear_top_bottom(axRaster)

    plotutils.align_ylabels(0, ax, axRaster)
    return fig
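The x tick positions and the solid separators above are derived from the per-event bin counts in `bst.lengths`. Each event occupies a contiguous run of image columns, so its tick sits at its first column plus half its length, shifted by the half-pixel offset used for the image. A minimal NumPy sketch of that arithmetic, assuming `lengths` is a 1D integer array like `bst.lengths` (the helper name `event_layout` is hypothetical):

```python
import numpy as np


def event_layout(lengths, pixel_width=0.5):
    """Hypothetical helper: x positions used to annotate concatenated events.

    Assumes each event contributes `lengths[i]` consecutive time bins
    (image columns), as with bst.lengths in the plotting code above.
    """
    lengths = np.asarray(lengths)
    # first column of every event: 0, l0, l0 + l1, ...
    starts = np.insert(np.cumsum(lengths), 0, 0)[:-1]
    # tick at the middle of each event (pixel centers sit on integers)
    centers = starts + lengths / 2 - pixel_width
    # separator lines sit on the right edge of each event
    boundaries = np.cumsum(lengths) - pixel_width
    return centers, boundaries


centers, boundaries = event_layout([3, 5, 2])
print(centers.tolist())     # [1.0, 5.0, 8.5]
print(boundaries.tolist())  # [2.5, 7.5, 9.5]
```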

epochplot(epochs, data=None, *, ax=None, height=None, fc='0.5', ec='0.5', alpha=0.5, hatch='', label=None, hc=None, **kwargs)

Plot an EpochArray as horizontal bars (intervals) on a timeline.

Parameters:

Name Type Description Default
epochs EpochArray

The epochs to plot.

required
data array-like

Data to plot on y axis; must be of size (epochs.n_epochs,).

None
ax Axes

Axis to plot on. If None, uses current axis.

None
height float

Height of the bars. If None, uses default.

None
fc color

Face color of the bars. Default is '0.5'.

'0.5'
ec color

Edge color of the bars. Default is '0.5'.

'0.5'
alpha float

Transparency of the bars. Default is 0.5.

0.5
hatch str

Hatching pattern for the bars.

''
label str

Label for the bars.

None
hc color

Hatch color for the bars (temporarily overrides rcParams['hatch.color']).

None
**kwargs dict

Additional keyword arguments passed to matplotlib's axvspan (or to plot when data is given).

{}

Returns:

Name Type Description
ax Axes

The axis with the epoch plot.

Examples:

>>> from nelpy import EpochArray
>>> epochs = EpochArray([[0, 1], [2, 3], [5, 6]])
>>> epochplot(epochs)
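Without `data`, `epochplot` draws one `axvspan` per interval, attaches `label` to the first span only (later spans get `"_nolegend_"`), and widens, but never shrinks, the x-limits to cover every epoch. That bookkeeping can be sketched in plain Python without matplotlib; `epoch_spans` is a hypothetical helper, and it assumes the intervals are sorted so that the first start and last stop bound the whole array:

```python
def epoch_spans(intervals, xlim, label=None):
    """Hypothetical helper mirroring epochplot's bookkeeping (no drawing).

    intervals : sequence of (start, stop) pairs, assumed sorted by start.
    xlim      : (xmin, xmax) of the axis before plotting.
    Returns the (start, stop, label) spans and the new x-limits.
    """
    spans = []
    for ii, (start, stop) in enumerate(intervals):
        # only the first span carries the label, so the legend shows
        # a single entry for the whole set of epochs
        spans.append((start, stop, label if ii == 0 else "_nolegend_"))
    xmin, xmax = xlim
    # widen (never shrink) the x-limits so every epoch is visible
    xmin = min(xmin, intervals[0][0])
    xmax = max(xmax, intervals[-1][1])
    return spans, (xmin, xmax)


spans, new_xlim = epoch_spans([(0, 1), (2, 3), (5, 6)], xlim=(0, 1), label="run")
print(spans)     # [(0, 1, 'run'), (2, 3, '_nolegend_'), (5, 6, '_nolegend_')]
print(new_xlim)  # (0, 6)
```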
Source code in nelpy/plotting/core.py
def epochplot(
    epochs,
    data=None,
    *,
    ax=None,
    height=None,
    fc="0.5",
    ec="0.5",
    alpha=0.5,
    hatch="",
    label=None,
    hc=None,
    **kwargs,
):
    """
    Plot an EpochArray as horizontal bars (intervals) on a timeline.

    Parameters
    ----------
    epochs : nelpy.EpochArray
        The epochs to plot.
    data : array-like, optional
        Data to plot on y axis; must be of size (epochs.n_epochs,).
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses current axis.
    height : float, optional
        Height of the bars. If None, uses default.
    fc : color, optional
        Face color of the bars. Default is '0.5'.
    ec : color, optional
        Edge color of the bars. Default is '0.5'.
    alpha : float, optional
        Transparency of the bars. Default is 0.5.
    hatch : str, optional
        Hatching pattern for the bars.
    label : str, optional
        Label for the bars.
    hc : color, optional
        Hatch color for the bars (temporarily overrides rcParams['hatch.color']).
    **kwargs : dict
        Additional keyword arguments passed to matplotlib's axvspan (or to plot when data is given).

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the epoch plot.

    Examples
    --------
    >>> from nelpy import EpochArray
    >>> epochs = EpochArray([[0, 1], [2, 3], [5, 6]])
    >>> epochplot(epochs)
    """
    if ax is None:
        ax = plt.gca()

    # do fixed-value-on-epoch plot if data is not None
    if data is not None:
        if epochs.n_intervals != len(data):
            raise ValueError("epocharray and data must have the same length")

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for epoch, val in zip(epochs, data):
                ax.plot([epoch.start, epoch.stop], [val, val], "-o", **kwargs)
        return ax

    ymin, ymax = ax.get_ylim()
    xmin, xmax = ax.get_xlim()
    if height is None:
        height = ymax - ymin

    if hc is not None:
        try:
            hc_before = mpl.rcParams["hatch.color"]
            mpl.rcParams["hatch.color"] = hc
        except KeyError:
            warnings.warn("Hatch color not supported for matplotlib <2.0")

    for ii, (start, stop) in enumerate(zip(epochs.starts, epochs.stops)):
        ax.axvspan(
            start,
            stop,
            hatch=hatch,
            facecolor=fc,
            edgecolor=ec,
            alpha=alpha,
            label=label if ii == 0 else "_nolegend_",
            **kwargs,
        )

    if epochs.start < xmin:
        xmin = epochs.start
    if epochs.stop > xmax:
        xmax = epochs.stop
    ax.set_xlim([xmin, xmax])

    if hc is not None:
        try:
            mpl.rcParams["hatch.color"] = hc_before
        except UnboundLocalError:
            pass

    return ax

imagesc(x=None, y=None, data=None, *, ax=None, large=False, **kwargs)

Plot a 2D matrix or image similar to Matlab's imagesc.

Parameters:

Name Type Description Default
x array-like

X values (columns).

None
y array-like

Y values (rows).

None
data ndarray of shape (Nrows, Ncols)

Matrix to visualize.

None
ax Axes

Axis to plot on. If None, uses the current axis.

None
large bool

If True, optimize for large matrices. Default is False.

False
**kwargs dict

Additional keyword arguments passed to imshow.

{}

Returns:

Name Type Description
ax Axes

The axis containing the image.

image AxesImage

The image object.

Examples:

Plot a simple matrix using imagesc:

>>> import numpy as np
>>> x = np.linspace(-100, -10, 10)
>>> y = np.array([-8, -3.0])
>>> data = np.random.randn(y.size, x.size)
>>> imagesc(x, y, data)
>>> imagesc(data)  # or, with default integer coordinates

Adding a colorbar:

>>> ax, img = imagesc(data)
>>> from mpl_toolkits.axes_grid1 import make_axes_locatable
>>> divider = make_axes_locatable(ax)
>>> cax = divider.append_axes("right", size="3.5%", pad=0.1)
>>> cb = plt.colorbar(img, cax=cax)
>>> npl.utils.no_yticks(cax)
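`imagesc` places pixel *centers* on the supplied `x` and `y` values by padding the `extent` passed to `imshow` by half a bin on each side. A standalone sketch of that rule (the same logic as the nested `extents` helper in the source listing), assuming evenly spaced coordinates, since only the first spacing is used:

```python
def extents(values):
    """Half-bin-padded limits so pixel centers land on the given coordinates.

    The spacing is taken from the first two values, so the coordinates are
    assumed to be evenly spaced; a single value gets a spacing of 1.
    """
    delta = values[1] - values[0] if len(values) > 1 else 1
    return [values[0] - delta / 2, values[-1] + delta / 2]


x = [-100, -90, -80]
y = [-8.0, -3.0]
extent = extents(x) + extents(y)  # [left, right, bottom, top] for imshow
print(extent)  # [-105.0, -75.0, -10.5, -0.5]
```

With `origin="lower"`, as in `imagesc`, the first row of `data` is drawn at the bottom, centered on `y[0]`.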
Source code in nelpy/plotting/core.py
def imagesc(x=None, y=None, data=None, *, ax=None, large=False, **kwargs):
    """
    Plot a 2D matrix or image similar to Matlab's imagesc.

    Parameters
    ----------
    x : array-like, optional
        X values (columns).
    y : array-like, optional
        Y values (rows).
    data : ndarray of shape (Nrows, Ncols)
        Matrix to visualize.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses the current axis.
    large : bool, optional
        If True, optimize for large matrices. Default is False.
    **kwargs : dict
        Additional keyword arguments passed to imshow.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis containing the image.
    image : matplotlib.image.AxesImage
        The image object.

    Examples
    --------
    Plot a simple matrix using imagesc:

    >>> import numpy as np
    >>> x = np.linspace(-100, -10, 10)
    >>> y = np.array([-8, -3.0])
    >>> data = np.random.randn(y.size, x.size)
    >>> imagesc(x, y, data)
    >>> imagesc(data)  # or, with default integer coordinates

    Adding a colorbar:

    >>> ax, img = imagesc(data)
    >>> from mpl_toolkits.axes_grid1 import make_axes_locatable
    >>> divider = make_axes_locatable(ax)
    >>> cax = divider.append_axes("right", size="3.5%", pad=0.1)
    >>> cb = plt.colorbar(img, cax=cax)
    >>> npl.utils.no_yticks(cax)
    """

    def extents(f):
        if len(f) > 1:
            delta = f[1] - f[0]
        else:
            delta = 1
        return [f[0] - delta / 2, f[-1] + delta / 2]

    if ax is None:
        ax = plt.gca()
    if data is None:
        if x is None:  # no args
            raise ValueError(
                "Unknown input. Usage imagesc(x, y, data) or imagesc(data)."
            )
        elif y is None:  # only one arg, so assume it to be data
            data = x
            x = np.arange(data.shape[1])
            y = np.arange(data.shape[0])
        else:  # x and y, but no data
            raise ValueError(
                "Unknown input. Usage imagesc(x, y, data) or imagesc(data)."
            )

    if data.ndim != 2:
        raise TypeError("data must be 2 dimensional")

    if not large:
        # Matplotlib imshow
        image = ax.imshow(
            data,
            aspect="auto",
            interpolation="none",
            extent=extents(x) + extents(y),
            origin="lower",
            **kwargs,
        )
    else:
        # ModestImage imshow for large images, but 'extent' is still not working well
        image = utils.imshow(
            axes=ax,
            X=data,
            aspect="auto",
            interpolation="none",
            extent=extents(x) + extents(y),
            origin="lower",
            **kwargs,
        )

    return ax, image

matshow(data, *, ax=None, **kwargs)

Display a matrix on an axis using matplotlib's matshow. Currently only nelpy.BinnedSpikeTrainArray input is supported; other types raise NotImplementedError.

Parameters:

Name Type Description Default
data array-like or BinnedSpikeTrainArray

The matrix or nelpy object to display.

required
ax Axes

Axis to plot on. If None, uses current axis.

None
**kwargs dict

Additional keyword arguments passed to matplotlib's matshow.

{}

Returns:

Name Type Description
ax Axes

The axis with the plotted matrix.

Examples:

>>> matshow(bst)  # bst is a nelpy.BinnedSpikeTrainArray
Source code in nelpy/plotting/core.py
def matshow(data, *, ax=None, **kwargs):
    """
    Display a matrix on an axis using matplotlib's matshow. Currently only
    nelpy.BinnedSpikeTrainArray input is supported; other types raise
    NotImplementedError.

    Parameters
    ----------
    data : array-like or nelpy.BinnedSpikeTrainArray
        The matrix or nelpy object to display.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses current axis.
    **kwargs : dict
        Additional keyword arguments passed to matplotlib's matshow.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the plotted matrix.

    Examples
    --------
    >>> matshow(bst)  # bst is a nelpy.BinnedSpikeTrainArray
    """

    # Sort out default values for the parameters
    if ax is None:
        ax = plt.gca()

    # Handle different types of input data
    if isinstance(data, core.BinnedSpikeTrainArray):
        # TODO: split by epoch, and plot matshows in same row, but with
        # a small gap to indicate discontinuities. How about slicing
        # then? Or slicing within an epoch?
        ax.matshow(data.data, **kwargs)
        ax.set_xlabel("time")
        ax.set_ylabel("unit")
        warnings.warn("Automatic x-axis formatting not yet implemented")
    else:
        raise NotImplementedError(
            "matshow({}) not yet supported".format(str(type(data)))
        )

    return ax

overviewstrip(epochs, *, ax=None, lw=5, solid_capstyle='butt', label=None, **kwargs)

Plot an EpochArray as a thin strip (like a scrollbar) above an axis to show gaps, e.g., in matshow plots.

Parameters:

Name Type Description Default
epochs EpochArray

The epochs to plot as a strip.

required
ax Axes

Axis to plot on. If None, uses current axis.

None
lw float

Line width for the strip. Default is 5.

5
solid_capstyle str

Cap style for the strip. Default is 'butt'.

'butt'
label str

Label for the strip.

None
**kwargs dict

Additional keyword arguments passed to matplotlib's plot.

{}

Returns:

Name Type Description
ax Axes

The axis with the overview strip.

Examples:

>>> from nelpy import EpochArray
>>> epochs = EpochArray([[0, 1], [2, 3], [5, 6]])
>>> overviewstrip(epochs)
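The strip draws one thick horizontal segment per epoch at a constant height, so the empty stretches between segments are exactly the gaps between consecutive epochs. A plain-Python sketch of which gaps would appear, using a hypothetical `epoch_gaps` helper on sorted, non-overlapping `(start, stop)` pairs:

```python
def epoch_gaps(intervals):
    """Hypothetical helper: the gaps an overview strip makes visible.

    intervals : sequence of (start, stop) pairs, assumed sorted and
    non-overlapping. Returns (stop_i, start_{i+1}) for each positive gap.
    """
    gaps = []
    for (_, prev_stop), (next_start, _) in zip(intervals, intervals[1:]):
        if next_start > prev_stop:
            gaps.append((prev_stop, next_start))
    return gaps


print(epoch_gaps([(0, 1), (2, 3), (5, 6)]))  # [(1, 2), (3, 5)]
```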
Source code in nelpy/plotting/core.py
def overviewstrip(
    epochs, *, ax=None, lw=5, solid_capstyle="butt", label=None, **kwargs
):
    """
    Plot an EpochArray as a thin strip (like a scrollbar) above an axis to show gaps, e.g., in matshow plots.

    Parameters
    ----------
    epochs : nelpy.EpochArray
        The epochs to plot as a strip.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses current axis.
    lw : float, optional
        Line width for the strip. Default is 5.
    solid_capstyle : str, optional
        Cap style for the strip. Default is 'butt'.
    label : str, optional
        Label for the strip.
    **kwargs : dict
        Additional keyword arguments passed to matplotlib's plot.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the overview strip.

    Examples
    --------
    >>> from nelpy import EpochArray
    >>> epochs = EpochArray([[0, 1], [2, 3], [5, 6]])
    >>> overviewstrip(epochs)
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    if ax is None:
        ax = plt.gca()

    divider = make_axes_locatable(ax)
    ax_ = divider.append_axes("top", size=0.2, pad=0.05)

    for epoch in epochs:
        ax_.plot(
            [epoch.start, epoch.stop],
            [1, 1],
            lw=lw,
            solid_capstyle=solid_capstyle,
            **kwargs,
        )

    if label is not None:
        ax_.set_yticks([1])
        ax_.set_yticklabels([label])
    else:
        ax_.set_yticks([])

    utils.no_yticks(ax_)
    utils.clear_left(ax_)
    utils.clear_right(ax_)
    utils.clear_top_bottom(ax_)

    ax_.set_xlim(ax.get_xlim())

    return ax_

palplot(pal, size=1)

Plot the values in a color palette as a horizontal array.

Parameters:

Name Type Description Default
pal sequence of matplotlib colors

Colors, e.g., as returned by nelpy.color_palette().

required
size float

Scaling factor for size of plot. Default is 1.

1

Examples:

>>> from nelpy.plotting.miscplot import palplot
>>> pal = ["#FF0000", "#00FF00", "#0000FF"]
>>> palplot(pal)
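`palplot` renders the palette as a 1 × n image of color indices drawn through a `ListedColormap` built from `pal`, in a figure whose width scales with the number of colors. The geometry can be sketched with NumPy alone; `palette_layout` is a hypothetical helper that returns what `palplot` would draw instead of drawing it:

```python
import numpy as np


def palette_layout(pal, size=1):
    """Hypothetical helper: the image and geometry palplot builds (no drawing).

    Each color becomes one cell of a 1 x n index image; with imshow's default
    min/max normalization and a ListedColormap of n colors, index k is shown
    in color pal[k].
    """
    n = len(pal)
    figsize = (n * size, size)          # width grows with the number of colors
    index_image = np.arange(n).reshape(1, n)
    cell_edges = np.arange(n) - 0.5     # x tick positions at cell boundaries
    return figsize, index_image, cell_edges


figsize, img, edges = palette_layout(["#FF0000", "#00FF00", "#0000FF"])
print(figsize)         # (3, 1)
print(img.tolist())    # [[0, 1, 2]]
print(edges.tolist())  # [-0.5, 0.5, 1.5]
```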
Source code in nelpy/plotting/miscplot.py
def palplot(pal, size=1):
    """
    Plot the values in a color palette as a horizontal array.

    Parameters
    ----------
    pal : sequence of matplotlib colors
        Colors, e.g., as returned by nelpy.color_palette().
    size : float, optional
        Scaling factor for size of plot. Default is 1.

    Examples
    --------
    >>> from nelpy.plotting.miscplot import palplot
    >>> pal = ["#FF0000", "#00FF00", "#0000FF"]
    >>> palplot(pal)
    """
    n = len(pal)
    f, ax = plt.subplots(1, 1, figsize=(n * size, size))
    ax.imshow(
        np.arange(n).reshape(1, n),
        cmap=mpl.colors.ListedColormap(list(pal)),
        interpolation="nearest",
        aspect="auto",
    )
    ax.set_xticks(np.arange(n) - 0.5)
    ax.set_yticks([-0.5, 0.5])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

plot(obj, *args, **kwargs)

Plot a nelpy object or array-like data using matplotlib.

Parameters:

Name Type Description Default
obj nelpy object or array-like

The object or data to plot. Can be a nelpy RegularlySampledAnalogSignalArray or array-like.

required
*args tuple

Additional positional arguments passed to matplotlib's plot.

()
**kwargs dict

Additional keyword arguments passed to matplotlib's plot. Special keys: ax (matplotlib.axes.Axes, optional): axis to plot on; if None, uses the current axis. autoscale (bool, optional): whether to rescale the x-axis to the plotted data; default is True. xlabel (str, optional): x-axis label. ylabel (str, optional): y-axis label.

{}

Returns:

Name Type Description
ax Axes

The axis with the plotted data.

Examples:

>>> from nelpy.core import RegularlySampledAnalogSignalArray
>>> obj = RegularlySampledAnalogSignalArray(...)  # your data here
>>> plot(obj)
>>> plot([1, 2, 3, 4])
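With `autoscale=True`, `plot` scans every child artist on the axis, takes the running minimum and maximum of each artist's `get_xydata()`, skips artists that fail, and then applies only the x-range via `set_xlim`. A NumPy sketch of that reduction over plain `(N, 2)` arrays; `xy_bounds` is a hypothetical helper, and the exceptions it catches are an assumption about which inputs should be skipped:

```python
import numpy as np


def xy_bounds(xy_arrays):
    """Hypothetical helper: overall bounds across several (N, 2) xy arrays.

    Mirrors plot()'s autoscale loop: keep the running min/max of each
    array's columns, and skip inputs that cannot be reduced this way
    (e.g. empty arrays), as the original loop skips children that raise.
    """
    xmin = ymin = np.inf
    xmax = ymax = -np.inf
    for xy in xy_arrays:
        try:
            cxmin, cymin = np.min(xy, axis=0)
            cxmax, cymax = np.max(xy, axis=0)
        except (ValueError, TypeError):
            continue
        xmin, ymin = min(xmin, cxmin), min(ymin, cymin)
        xmax, ymax = max(xmax, cxmax), max(ymax, cymax)
    return xmin, xmax, ymin, ymax


segments = [
    np.array([[0.0, 1.0], [1.0, 3.0]]),
    np.array([[2.0, -1.0], [4.0, 0.5]]),
    np.empty((0, 2)),  # no data: skipped
]
xmin, xmax, ymin, ymax = xy_bounds(segments)  # 0.0, 4.0, -1.0, 3.0
```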
Source code in nelpy/plotting/core.py
def plot(obj, *args, **kwargs):
    """
    Plot a nelpy object or array-like data using matplotlib.

    Parameters
    ----------
    obj : nelpy object or array-like
        The object or data to plot. Can be a nelpy RegularlySampledAnalogSignalArray or array-like.
    *args : tuple
        Additional positional arguments passed to matplotlib's plot.
    **kwargs : dict
        Additional keyword arguments passed to matplotlib's plot. Special keys:
            ax : matplotlib.axes.Axes, optional
                Axis to plot on. If None, uses current axis.
            autoscale : bool, optional
                Whether to rescale the x-axis to the plotted data. Default is True.
            xlabel : str, optional
                X-axis label.
            ylabel : str, optional
                Y-axis label.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the plotted data.

    Examples
    --------
    >>> from nelpy.core import RegularlySampledAnalogSignalArray
    >>> obj = RegularlySampledAnalogSignalArray(...)  # your data here
    >>> plot(obj)
    >>> plot([1, 2, 3, 4])
    """

    ax = kwargs.pop("ax", None)
    autoscale = kwargs.pop("autoscale", True)
    xlabel = kwargs.pop("xlabel", None)
    ylabel = kwargs.pop("ylabel", None)

    if ax is None:
        ax = plt.gca()

    if isinstance(obj, core.RegularlySampledAnalogSignalArray):
        if obj.n_signals == 1:
            label = kwargs.pop("label", None)
            for ii, (abscissa_vals, data) in enumerate(
                zip(
                    obj._intervaltime.plot_generator(),
                    obj._intervaldata.plot_generator(),
                )
            ):
                ax.plot(
                    abscissa_vals,
                    data.T,
                    *args,
                    label=label if ii == 0 else "_nolegend_",
                    **kwargs,
                )
        elif obj.n_signals > 1:
            # TODO: intercept when any color is requested. This could happen
            # multiple ways, such as plt.plot(x, '-r') or plt.plot(x, c='0.7')
            # or plt.plot(x, color='red'), and maybe some others? Probably have
            # to dig into the matplotlib code to see how they parse this and do
            # conflict resolution... Update: they use the last specified color.
            # but I still need to know how to detect a color that was passed in
            # the *args part, e.g., '-r'

            color = kwargs.pop("color", None)
            carg = kwargs.pop("c", None)

            if color is not None and carg is not None:
                # TODO: fix this so that a warning is issued, not raised
                raise ValueError("saw kwargs ['c', 'color']")
                # raise UserWarning("saw kwargs ['c', 'color'] which are all aliases for 'color'.  Kept value from 'color'")
            if carg:
                color = carg

            if not color:
                colors = []
                for ii in range(obj.n_signals):
                    (line,) = ax.plot(0, 0.5)
                    colors.append(line.get_color())
                    line.remove()

                for ee, (abscissa_vals, data) in enumerate(
                    zip(
                        obj._intervaltime.plot_generator(),
                        obj._intervaldata.plot_generator(),
                    )
                ):
                    if ee > 0:
                        kwargs["label"] = "_nolegend_"
                    for ii, snippet in enumerate(data):
                        ax.plot(
                            abscissa_vals, snippet, *args, color=colors[ii], **kwargs
                        )
            else:
                kwargs["color"] = color
                for ee, (abscissa_vals, data) in enumerate(
                    zip(
                        obj._intervaltime.plot_generator(),
                        obj._intervaldata.plot_generator(),
                    )
                ):
                    if ee > 0:
                        kwargs["label"] = "_nolegend_"
                    for ii, snippet in enumerate(data):
                        ax.plot(abscissa_vals, snippet, *args, **kwargs)

        if xlabel is None:
            xlabel = obj._abscissa.label
        if ylabel is None:
            ylabel = obj._ordinate.label
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
    else:  # if we didn't handle it yet, just pass it through to matplotlib...
        ax.plot(obj, *args, **kwargs)

    if autoscale:
        xmin = np.inf
        xmax = -np.inf
        ymin = np.inf
        ymax = -np.inf
        for child in ax.get_children():
            try:
                cxmin, cymin = np.min(child.get_xydata(), axis=0)
                cxmax, cymax = np.max(child.get_xydata(), axis=0)
                if cxmin < xmin:
                    xmin = cxmin
                if cymin < ymin:
                    ymin = cymin
                if cxmax > xmax:
                    xmax = cxmax
                if cymax > ymax:
                    ymax = cymax
            except Exception:
                pass
        ax.set_xlim(xmin, xmax)

    return ax

plot2d(npl_obj, data=None, *, ax=None, mew=None, color=None, mec=None, markerfacecolor=None, **kwargs)

Plot 2D data for nelpy objects or array-like input.

Parameters:

Name Type Description Default
npl_obj nelpy object or array-like

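The object or data to plot in 2D. NumPy arrays are passed directly to matplotlib's plot; other array-likes (e.g., lists) are currently not plotted.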

required
data array-like

Data to plot. If None, uses npl_obj's data.

None
ax Axes

Axis to plot on. If None, uses current axis.

None
mew float

Marker edge width.

None
color matplotlib color

Trace color.

None
mec matplotlib color

Marker edge color.

None
markerfacecolor matplotlib color

Marker face color.

None
**kwargs dict

Additional keyword arguments passed to matplotlib's plot.

{}

Returns:

Name Type Description
ax Axes

The axis with the plotted data.

Examples:

>>> import numpy as np
>>> plot2d(np.array([[0, 1], [1, 2], [2, 3]]))
Source code in nelpy/plotting/core.py
def plot2d(
    npl_obj,
    data=None,
    *,
    ax=None,
    mew=None,
    color=None,
    mec=None,
    markerfacecolor=None,
    **kwargs,
):
    """
    Plot 2D data for nelpy objects or array-like input.

    Parameters
    ----------
    npl_obj : nelpy object or array-like
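        The object or data to plot in 2D. NumPy arrays are passed directly to
        matplotlib's plot; other array-likes (e.g., lists) are currently not plotted.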
    data : array-like, optional
        Data to plot. If None, uses npl_obj's data.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses current axis.
    mew : float, optional
        Marker edge width.
    color : matplotlib color, optional
        Trace color.
    mec : matplotlib color, optional
        Marker edge color.
    markerfacecolor : matplotlib color, optional
        Marker face color.
    **kwargs : dict
        Additional keyword arguments passed to matplotlib's plot.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the plotted data.

    Examples
    --------
    >>> import numpy as np
    >>> plot2d(np.array([[0, 1], [1, 2], [2, 3]]))
    """

    if ax is None:
        ax = plt.gca()
    if mec is None:
        mec = color
    if markerfacecolor is None:
        markerfacecolor = "w"

    if isinstance(npl_obj, np.ndarray):
        ax.plot(npl_obj, mec=mec, markerfacecolor=markerfacecolor, **kwargs)

    # TODO: better solution for this? we could just iterate over the epochs and
    # plot them but that might take up too much time since a copy is being made
    # each iteration?
    if isinstance(npl_obj, core.RegularlySampledAnalogSignalArray):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for segment in npl_obj:
                if color is not None:
                    ax.plot(
                        segment[:, 0]._data_colsig,
                        segment[:, 1]._data_colsig,
                        color=color,
                        mec=mec,
                        markerfacecolor=markerfacecolor,
                        **kwargs,
                    )
                else:
                    ax.plot(
                        segment[:, 0]._data_colsig,
                        segment[:, 1]._data_colsig,
                        # color=color,
                        mec=mec,
                        markerfacecolor=markerfacecolor,
                        **kwargs,
                    )

    if isinstance(npl_obj, core.PositionArray):
        xlim, ylim = npl_obj.xlim, npl_obj.ylim
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
    return ax
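For a plain `ndarray`, the branch above hands the array straight to `ax.plot`, which draws each column as a separate series against the sample index rather than drawing column 0 against column 1. A trajectory-style x-versus-y plot needs the columns split first. A minimal NumPy sketch, assuming an `(N, 2)` array of positions (`split_xy` is a hypothetical helper, not part of nelpy):

```python
import numpy as np


def split_xy(positions):
    """Hypothetical helper: split an (N, 2) position array into x and y series.

    ax.plot(positions) would draw each column against the sample index;
    ax.plot(x, y) with the split columns draws the 2D trajectory instead.
    """
    positions = np.asarray(positions, dtype=float)
    if positions.ndim != 2 or positions.shape[1] != 2:
        raise ValueError("expected an array of shape (N, 2)")
    return positions[:, 0], positions[:, 1]


x, y = split_xy([[0, 1], [1, 2], [2, 3]])
print(x.tolist(), y.tolist())  # [0.0, 1.0, 2.0] [1.0, 2.0, 3.0]
```

The split columns can then be drawn with, e.g., `ax.plot(x, y, "o-", markerfacecolor="w")`.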

plot_cum_error_dist(*, cumhist=None, bincenters=None, bst=None, extern=None, decodefunc=None, k=None, transfunc=None, n_extern=None, n_bins=None, extmin=None, extmax=None, sigma=None, lw=None, ax=None, inset=True, inset_ax=None, color=None, **kwargs)

Plot (and optionally compute) the cumulative distribution of decoding errors.

Errors are evaluated using a k-fold cross-validation procedure. See Fig. 3(b) of "Analysis of Hippocampal Memory Replay Using Neural Population Decoding", Fabian Kloosterman, 2012.

Parameters:

Name Type Description Default
cumhist array-like

Precomputed cumulative histogram of errors. If None, will be computed.

None
bincenters array-like

Bin centers for the cumulative histogram. If None, will be computed.

None
bst BinnedSpikeTrainArray

Required if cumhist and bincenters are not provided. Used for error computation.

None
extern array-like

External variable (e.g., position) for decoding. Required if cumhist and bincenters are not provided.

None
decodefunc callable

Decoding function to use. Defaults to decoding.decode1D.

None
k int

Number of cross-validation folds. Default is 5.

None
transfunc callable

Optional transformation function for the external variable.

None
n_extern int

Number of external variable samples. Default is 100.

None
n_bins int

Number of bins for the error histogram. Default is 200.

None
extmin float

Minimum value of the external variable. Default is 0.

None
extmax float

Maximum value of the external variable. Default is 100.

None
sigma float

Smoothing parameter. Default is 3.

None
lw float

Line width for the plot. Default is 1.5.

None
ax Axes

Axis to plot on. If None, uses current axis.

None
inset bool

Whether to include an inset plot. Default is True.

True
inset_ax Axes

Axis for the inset plot. If None, one will be created.

None
color color

Line color. If None, uses next color in cycle.

None
**kwargs

Additional keyword arguments for plotting.

{}

Returns:

Name Type Description
ax Axes

The axis with the cumulative error plot.

inset_ax (Axes, optional)

The axis with the inset plot (if inset=True).

Examples:

>>> ax, inset_ax = plot_cum_error_dist(bst=bst, extern=pos)
>>> plt.show()
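The plotted curve is the cumulative distribution of absolute decoding errors, and the inset marks the error at which that distribution reaches 0.5 (the median error) and 0.7. The sketch below is a generic empirical-CDF computation on an array of errors, not nelpy's cross-validated `cumulative_dist_decoding_error_using_xval`. In the source, the returned `cumhist` is called as `cumhist(thresh)` to obtain the error at a probability; here a hypothetical `error_at` helper does that on plain arrays instead:

```python
import numpy as np


def error_cdf(errors, n_bins=200, max_error=None):
    """Empirical cumulative distribution of absolute decoding errors.

    A generic sketch: returns evaluation points (bincenters) from 0 to
    max_error and, for each point, the fraction of errors <= that value.
    """
    errors = np.abs(np.asarray(errors, dtype=float))
    if max_error is None:
        max_error = errors.max()
    bincenters = np.linspace(0, max_error, n_bins)
    # searchsorted(side="right") counts the sorted errors <= each point
    counts = np.searchsorted(np.sort(errors), bincenters, side="right")
    return bincenters, counts / errors.size


def error_at(prob, bincenters, cumhist):
    """Hypothetical helper: first tabulated error whose CDF reaches `prob`."""
    return bincenters[np.argmax(cumhist >= prob)]


bc, ch = error_cdf([1, -2, 3, -4], n_bins=5, max_error=4)
print(ch.tolist())                  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(float(error_at(0.5, bc, ch)))  # 2.0 (median error)
print(float(error_at(0.7, bc, ch)))  # 3.0
```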
Source code in nelpy/plotting/decoding.py
def plot_cum_error_dist(
    *,
    cumhist=None,
    bincenters=None,
    bst=None,
    extern=None,
    decodefunc=None,
    k=None,
    transfunc=None,
    n_extern=None,
    n_bins=None,
    extmin=None,
    extmax=None,
    sigma=None,
    lw=None,
    ax=None,
    inset=True,
    inset_ax=None,
    color=None,
    **kwargs,
):
    """
    Plot (and optionally compute) the cumulative distribution of decoding errors.

    Errors are evaluated using a k-fold cross-validation procedure. See Fig. 3(b) of "Analysis of Hippocampal Memory Replay Using Neural Population Decoding", Fabian Kloosterman, 2012.

    Parameters
    ----------
    cumhist : array-like, optional
        Precomputed cumulative histogram of errors. If None, will be computed.
    bincenters : array-like, optional
        Bin centers for the cumulative histogram. If None, will be computed.
    bst : BinnedSpikeTrainArray, optional
        Required if cumhist and bincenters are not provided. Used for error computation.
    extern : array-like, optional
        External variable (e.g., position) for decoding. Required if cumhist and bincenters are not provided.
    decodefunc : callable, optional
        Decoding function to use. Defaults to decoding.decode1D.
    k : int, optional
        Number of cross-validation folds. Default is 5.
    transfunc : callable, optional
        Optional transformation function for the external variable.
    n_extern : int, optional
        Number of external variable samples. Default is 100.
    n_bins : int, optional
        Number of bins for the error histogram. Default is 200.
    extmin : float, optional
        Minimum value of the external variable. Default is 0.
    extmax : float, optional
        Maximum value of the external variable. Default is 100.
    sigma : float, optional
        Smoothing parameter. Default is 3.
    lw : float, optional
        Line width for the plot. Default is 1.5.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses current axis.
    inset : bool, optional
        Whether to include an inset plot. Default is True.
    inset_ax : matplotlib.axes.Axes, optional
        Axis for the inset plot. If None, one will be created.
    color : color, optional
        Line color. If None, uses next color in cycle.
    **kwargs
        Additional keyword arguments for plotting.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the cumulative error plot.
    inset_ax : matplotlib.axes.Axes, optional
        The axis with the inset plot (if inset=True).

    Examples
    --------
    >>> ax, inset_ax = plot_cum_error_dist(bst=bst, extern=pos)
    >>> plt.show()
    """

    if ax is None:
        ax = plt.gca()
    if lw is None:
        lw = 1.5
    if decodefunc is None:
        decodefunc = decoding.decode1D
    if k is None:
        k = 5
    if n_extern is None:
        n_extern = 100
    if n_bins is None:
        n_bins = 200
    if extmin is None:
        extmin = 0
    if extmax is None:
        extmax = 100
    if sigma is None:
        sigma = 3

    # Get the color from the current color cycle
    if color is None:
        (line,) = ax.plot(0, 0.5)
        color = line.get_color()
        line.remove()

    # if cumhist or bincenters are NOT provided, then compute them
    if cumhist is None or bincenters is None:
        assert bst is not None, (
            "if cumhist and bincenters are not given, then bst must be provided to recompute them!"
        )
        assert extern is not None, (
            "if cumhist and bincenters are not given, then extern must be provided to recompute them!"
        )
        cumhist, bincenters = decoding.cumulative_dist_decoding_error_using_xval(
            bst=bst,
            extern=extern,
            decodefunc=decodefunc,
            k=k,
            transfunc=transfunc,
            n_extern=n_extern,
            extmin=extmin,
            extmax=extmax,
            sigma=sigma,
            n_bins=n_bins,
        )
    # now plot results
    ax.plot(bincenters, cumhist, lw=lw, color=color, **kwargs)
    ax.set_xlim(bincenters[0], bincenters[-1])
    ax.set_xlabel("error [cm]")
    ax.set_ylabel("cumulative probability")

    ax.set_ylim(0)

    if inset:
        if inset_ax is None:
            inset_ax = inset_axes(
                parent_axes=ax, width="60%", height="50%", loc=4, borderpad=2
            )

        inset_ax.plot(bincenters, cumhist, lw=lw, color=color, **kwargs)

        # annotate inset
        thresh1 = 0.7
        inset_ax.hlines(
            thresh1, 0, cumhist(thresh1), color=color, alpha=0.9, lw=lw, linestyle="--"
        )
        inset_ax.vlines(
            cumhist(thresh1), 0, thresh1, color=color, alpha=0.9, lw=lw, linestyle="--"
        )
        inset_ax.set_xlim(0, 12 * np.ceil(cumhist(thresh1) / 10))

        thresh2 = 0.5
        inset_ax.hlines(
            thresh2, 0, cumhist(thresh2), color=color, alpha=0.6, lw=lw, linestyle="--"
        )
        inset_ax.vlines(
            cumhist(thresh2), 0, thresh2, color=color, alpha=0.6, lw=lw, linestyle="--"
        )

        inset_ax.set_yticks((0, thresh1, thresh2, 1))
        inset_ax.set_ylim(0)

        return ax, inset_ax

    return ax
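The curve drawn by plot_cum_error_dist is a normalized cumulative histogram of absolute decoding errors. The dashed inset guides mark the error at which that curve reaches 0.5 and 0.7. The numpy-only sketch below illustrates both steps on a plain array of errors. The `errors` values and the helper names are hypothetical. It does not replace `decoding.cumulative_dist_decoding_error_using_xval`, which also performs the cross-validated decoding.

```python
import numpy as np


def cumulative_error_distribution(errors, n_bins=200, max_error=None):
    """Return (cumhist, bincenters) for a set of absolute decoding errors."""
    errors = np.abs(np.asarray(errors, dtype=float))
    if max_error is None:
        max_error = errors.max()
    hist, edges = np.histogram(errors, bins=n_bins, range=(0, max_error))
    # Normalize the running count so the curve ends at 1
    cumhist = np.cumsum(hist) / hist.sum()
    bincenters = edges[:-1] + np.diff(edges) / 2
    # Prepend (0, 0) so the curve starts at the origin
    cumhist = np.insert(cumhist, 0, 0.0)
    bincenters = np.insert(bincenters, 0, 0.0)
    return cumhist, bincenters


def error_at_probability(cumhist, bincenters, prob):
    """Interpolate the error at which the cumulative probability reaches prob."""
    return float(np.interp(prob, cumhist, bincenters))


# Hypothetical errors spread evenly between 0 and 10 cm
errors = np.arange(1000) * 0.01 + 0.005
cumhist, bincenters = cumulative_error_distribution(errors, n_bins=10, max_error=10)
median_error = error_at_probability(cumhist, bincenters, 0.5)
p70_error = error_at_probability(cumhist, bincenters, 0.7)
print(median_error, p70_error)
```

Each count is assigned to its bin center, so the interpolated error is only resolved to within one bin width. A larger n_bins gives a finer reading.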

plot_posteriors(bst, tuningcurve, idx=None, w=1, bin_px_size=0.08)

Plot posterior probabilities for decoded neural activity.

Parameters:

Name Type Description Default
bst BinnedSpikeTrainArray

The binned spike train array to decode.

required
tuningcurve TuningCurve1D

The tuning curve used for decoding.

required
idx array-like

Indices of events to plot. If None, all events are plotted.

None
w int

Window size for decoding (default is 1).

1
bin_px_size float

Width of each time bin in inches; the figure is bin_px_size * n_bins inches wide (default is 0.08).

0.08

Returns:

Name Type Description
ax Axes

The axis with the posterior plot.

Examples:

>>> ax = plot_posteriors(bst, tc)
>>> plt.show()
Source code in nelpy/plotting/decoding.py
def plot_posteriors(bst, tuningcurve, idx=None, w=1, bin_px_size=0.08):
    """
    Plot posterior probabilities for decoded neural activity.

    Parameters
    ----------
    bst : BinnedSpikeTrainArray
        The binned spike train array to decode.
    tuningcurve : TuningCurve1D
        The tuning curve used for decoding.
    idx : array-like, optional
        Indices of events to plot. If None, all events are plotted.
    w : int, optional
        Window size for decoding (default is 1).
    bin_px_size : float, optional
        Width of each time bin in inches; the figure is
        bin_px_size * n_bins inches wide (default is 0.08).

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the posterior plot.

    Examples
    --------
    >>> ax = plot_posteriors(bst, tc)
    >>> plt.show()
    """
    if idx is not None:
        bst = bst[idx]
    tc = tuningcurve

    # decode neural activity
    posterior, bdries, mode_pth, mean_pth = decoding.decode1D(
        bst=bst, ratemap=tc, xmin=tc.bins[0], xmax=tc.bins[-1], w=w
    )

    pixel_width = 0.5

    n_ext, n_bins = posterior.shape
    lengths = np.diff(bdries)

    plt.figure(figsize=(bin_px_size * n_bins, 2))
    ax = plt.gca()

    imagesc(
        x=np.arange(n_bins),
        y=np.arange(int(tc.bins[-1] + 1)),
        data=posterior,
        cmap=plt.cm.Spectral_r,
        ax=ax,
    )
    plotutils.yticks_interval(tc.bins[-1])
    plotutils.no_yticks(ax)
    # plt.imshow(posterior, cmap=plt.cm.Spectral_r, interpolation='none', aspect='auto')
    ax.vlines(
        np.arange(lengths.sum()) - pixel_width,
        *ax.get_ylim(),
        lw=1,
        linestyle=":",
        color="0.8",
    )
    ax.vlines(np.cumsum(lengths) - pixel_width, *ax.get_ylim(), lw=1)

    ax.set_xlim(-pixel_width, lengths.sum() - pixel_width)

    event_centers = np.insert(np.cumsum(lengths), 0, 0)
    event_centers = event_centers[:-1] + lengths / 2 - 0.5

    ax.set_xticks(event_centers)
    if idx is not None:
        ax.set_xticklabels(idx)
    else:
        ax.set_xticklabels(np.arange(bst.n_intervals))

    plotutils.no_xticks(ax)

    return ax
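plot_posteriors lays the decoded events side by side, one pixel column per time bin. It derives the figure width, the solid event separators and the tick positions purely from the event boundaries. The sketch below repeats that arithmetic with numpy for a hypothetical boundary array describing three events of 3, 5 and 2 bins. It assumes the boundaries start at 0, so that the total bin count equals lengths.sum(). It uses the same half-pixel offset (0.5) as the source, and the helper name is illustrative only.

```python
import numpy as np


def posterior_layout(bdries, bin_px_size=0.08, pixel_width=0.5):
    """Figure width, event separators and tick centres for concatenated events."""
    lengths = np.diff(bdries)                 # time bins per event
    fig_width = bin_px_size * lengths.sum()   # inches, as passed to plt.figure
    # Solid separators fall on the pixel edge after each event's last bin
    separators = np.cumsum(lengths) - pixel_width
    # Tick at the middle of each event, shifted to pixel-centre coordinates
    starts = np.insert(np.cumsum(lengths), 0, 0)[:-1]
    centers = starts + lengths / 2 - pixel_width
    return fig_width, separators, centers


# Hypothetical boundaries for three events of 3, 5 and 2 bins
fig_width, separators, centers = posterior_layout(np.array([0, 3, 8, 10]))
print(fig_width, separators, centers)
```

The last separator coincides with the right x-limit, lengths.sum() - 0.5, which is why the source sets the x-limits to that value.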

plot_tuning_curves1D(ratemap, ax=None, normalize=False, pad=None, unit_labels=None, fill=True, color=None, alpha=0.3)

Plot 1D tuning curves for multiple units.

Parameters:

Name Type Description Default
ratemap TuningCurve1D

Tuning curve whose .ratemap (2D array: n_units x n_ext), .bins, .bin_centers, and .unit_labels are used. Other input types raise NotImplementedError.

required
ax Axes

Axis to plot on. If None, uses current axis.

None
normalize bool

If True, normalize each curve to its peak value.

False
pad float

Vertical offset between curves. If None, uses half the mean of the ratemap, computed before any normalization.

None
unit_labels list

Labels for each unit. If None, uses ratemap.unit_labels.

None
fill bool

Whether to fill under each curve. Default is True.

True
color color or None

Color for all curves. If None, uses default color cycle.

None
alpha float

Transparency for the fill. Default is 0.3.

0.3

Returns:

Name Type Description
ax Axes

The axis with the plotted tuning curves.

Examples:

>>> plot_tuning_curves1D(tc)
Source code in nelpy/plotting/core.py
def plot_tuning_curves1D(
    ratemap,
    ax=None,
    normalize=False,
    pad=None,
    unit_labels=None,
    fill=True,
    color=None,
    alpha=0.3,
):
    """
    Plot 1D tuning curves for multiple units.

    Parameters
    ----------
    ratemap : auxiliary.TuningCurve1D
        Tuning curve whose .ratemap (2D array: n_units x n_ext), .bins,
        .bin_centers, and .unit_labels are used. Other input types raise
        NotImplementedError.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses current axis.
    normalize : bool, optional
        If True, normalize each curve to its peak value.
    pad : float, optional
        Vertical offset between curves. If None, uses half the mean of the
        ratemap, computed before any normalization.
    unit_labels : list, optional
        Labels for each unit. If None, uses ratemap.unit_labels.
    fill : bool, optional
        Whether to fill under each curve. Default is True.
    color : color or None, optional
        Color for all curves. If None, uses default color cycle.
    alpha : float, optional
        Transparency for the fill. Default is 0.3.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the plotted tuning curves.

    Examples
    --------
    >>> plot_tuning_curves1D(tc)
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(
        ratemap, (auxiliary.TuningCurve1D, auxiliary._tuningcurve.TuningCurve1D)
    ):
        xmin = ratemap.bins[0]
        xmax = ratemap.bins[-1]
        xvals = ratemap.bin_centers
        if unit_labels is None:
            unit_labels = ratemap.unit_labels
        ratemap = ratemap.ratemap
    else:
        raise NotImplementedError

    if pad is None:
        pad = ratemap.mean() / 2

    n_units, n_ext = ratemap.shape

    if normalize:
        peak_firing_rates = ratemap.max(axis=1)
        ratemap = (ratemap.T / peak_firing_rates).T

    # determine max firing rate
    # max_firing_rate = ratemap.max()

    if xvals is None:
        xvals = np.arange(n_ext)
    if xmin is None:
        xmin = xvals[0]
    if xmax is None:
        xmax = xvals[-1]

    for unit, curve in enumerate(ratemap):
        if color is None:
            line = ax.plot(
                xvals, unit * pad + curve, zorder=int(10 + 2 * n_units - 2 * unit)
            )
        else:
            line = ax.plot(
                xvals,
                unit * pad + curve,
                zorder=int(10 + 2 * n_units - 2 * unit),
                color=color,
            )
        if fill:
            # Get the color from the current curve
            fillcolor = line[0].get_color()
            ax.fill_between(
                xvals,
                unit * pad,
                unit * pad + curve,
                alpha=alpha,
                color=fillcolor,
                zorder=int(10 + 2 * n_units - 2 * unit - 1),
            )

    ax.set_xlim(xmin, xmax)
    if pad != 0:
        yticks = np.arange(n_units) * pad + 0.5 * pad
        ax.set_yticks(yticks)
        ax.set_yticklabels(unit_labels)
        ax.set_xlabel("external variable")
        ax.set_ylabel("unit")
        utils.no_yticks(ax)
        utils.clear_left(ax)
    else:
        if normalize:
            ax.set_ylabel("normalized firing rate")
        else:
            ax.set_ylabel("firing rate [Hz]")
        ax.set_ylim(0)

    utils.clear_top(ax)
    utils.clear_right(ax)

    return ax
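Each unit in plot_tuning_curves1D is drawn on its own baseline at unit * pad and filled down to that baseline. It is labelled halfway up its strip. The numpy sketch below computes those quantities without plotting; the rate values and the helper name are hypothetical. In the source, the default pad is taken from the raw rates before normalization. With normalize=True and high firing rates, the strips can therefore be spaced far apart relative to the unit-height curves. Passing pad explicitly avoids this.

```python
import numpy as np


def stack_tuning_curves(ratemap, pad=None, normalize=False):
    """Return per-unit baselines, offset curves and y-tick positions."""
    ratemap = np.asarray(ratemap, dtype=float)
    n_units = ratemap.shape[0]
    if pad is None:
        # Same default as plot_tuning_curves1D: half the mean rate,
        # taken before any normalization is applied
        pad = ratemap.mean() / 2
    if normalize:
        # Scale each unit's curve so that its peak equals 1
        ratemap = ratemap / ratemap.max(axis=1, keepdims=True)
    baselines = np.arange(n_units) * pad       # bottom of each unit's strip
    curves = baselines[:, np.newaxis] + ratemap
    yticks = baselines + 0.5 * pad             # unit labels sit mid-strip
    return baselines, curves, yticks


rates = np.array([[1.0, 2.0, 4.0],
                  [3.0, 6.0, 3.0]])
baselines, curves, yticks = stack_tuning_curves(rates, pad=1.0, normalize=True)
print(curves)
print(yticks)
```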

plotting_context(context=None, font_scale=1, rc=None)

Return a parameter dict to scale elements of the figure.

This affects things like the size of the labels, lines, and other elements of the plot, but not the overall style. The base context is "notebook", and the other contexts are "paper", "talk", and "poster", which are versions of the notebook parameters scaled by .8, 1.3, and 1.6, respectively.

This function returns an object that can be used in a with statement to temporarily change the context parameters.

Parameters:

Name Type Description Default
context dict, None, or one of {paper, notebook, talk, poster}

A dictionary of parameters or the name of a preconfigured set.

None
font_scale float

Separate scaling factor to independently scale the size of the font elements.

1
rc dict

Parameter mappings to override the values in the preset seaborn context dictionaries. This only updates parameters that are considered part of the context definition.

None

Returns:

Name Type Description
context_object _PlottingContext

An object that can be used as a context manager to temporarily set context.

Examples:

>>> c = plotting_context("poster")
>>> c = plotting_context("notebook", font_scale=1.5)
>>> c = plotting_context("talk", rc={"lines.linewidth": 2})
>>> import matplotlib.pyplot as plt
>>> with plotting_context("paper"):
...     f, ax = plt.subplots()
...     ax.plot([0, 1], [0, 1])
See Also

set_context : set the matplotlib parameters to scale plot elements
axes_style : return a dict of parameters defining a figure style
color_palette : define the color palette for a plot

Source code in nelpy/plotting/rcmod.py
def plotting_context(context=None, font_scale=1, rc=None):
    """
    Return a parameter dict to scale elements of the figure.

    This affects things like the size of the labels, lines, and other
    elements of the plot, but not the overall style. The base context
    is "notebook", and the other contexts are "paper", "talk", and "poster",
    which are versions of the notebook parameters scaled by .8, 1.3, and 1.6,
    respectively.

    This function returns an object that can be used in a ``with`` statement
    to temporarily change the context parameters.

    Parameters
    ----------
    context : dict, None, or one of {paper, notebook, talk, poster}
        A dictionary of parameters or the name of a preconfigured set.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        context dictionaries. This only updates parameters that are
        considered part of the context definition.

    Returns
    -------
    context_object : _PlottingContext
        An object that can be used as a context manager to temporarily set context.

    Examples
    --------
    >>> c = plotting_context("poster")
    >>> c = plotting_context("notebook", font_scale=1.5)
    >>> c = plotting_context("talk", rc={"lines.linewidth": 2})
    >>> import matplotlib.pyplot as plt
    >>> with plotting_context("paper"):
    ...     f, ax = plt.subplots()
    ...     ax.plot([0, 1], [0, 1])

    See Also
    --------
    set_context : set the matplotlib parameters to scale plot elements
    axes_style : return a dict of parameters defining a figure style
    color_palette : define the color palette for a plot
    """
    if context is None:
        context_dict = {k: mpl.rcParams[k] for k in _context_keys}

    elif isinstance(context, dict):
        context_dict = context

    else:
        contexts = ["paper", "notebook", "talk", "poster"]
        if context not in contexts:
            raise ValueError("context must be in %s" % ", ".join(contexts))

        # Set up dictionary of default parameters
        base_context = {
            "figure.figsize": np.array([8, 5.5]),
            "font.size": 12,
            "axes.labelsize": 11,
            "axes.titlesize": 12,
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
            "legend.fontsize": 10,
            "grid.linewidth": 1,
            "lines.linewidth": 1.75,
            "patch.linewidth": 0.3,
            "lines.markersize": 7,
            "lines.markeredgewidth": 0,
            "xtick.major.width": 1,
            "ytick.major.width": 1,
            "xtick.minor.width": 0.5,
            "ytick.minor.width": 0.5,
            "xtick.major.pad": 7,
            "ytick.major.pad": 7,
        }

        # Scale all the parameters by the same factor depending on the context
        scaling = dict(paper=0.8, notebook=1, talk=1.3, poster=1.6)[context]
        context_dict = {k: v * scaling for k, v in base_context.items()}

        # Now independently scale the fonts
        font_keys = [
            "axes.labelsize",
            "axes.titlesize",
            "legend.fontsize",
            "xtick.labelsize",
            "ytick.labelsize",
            "font.size",
        ]
        font_dict = {k: context_dict[k] * font_scale for k in font_keys}
        context_dict.update(font_dict)

    # Implement hack workaround for matplotlib bug
    # See https://github.com/mwaskom/seaborn/issues/344
    # There is a bug in matplotlib 1.4.2 that makes points invisible when
    # they don't have an edgewidth. It will supposedly be fixed in 1.4.3.
    if mpl.__version__ == "1.4.2":
        context_dict["lines.markeredgewidth"] = 0.01

    # Override these settings with the provided rc dictionary
    if rc is not None:
        rc = {k: v for k, v in rc.items() if k in _context_keys}
        context_dict.update(rc)

    # Wrap in a _PlottingContext object so this can be used in a with statement
    context_object = _PlottingContext(context_dict)

    return context_object
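The preset contexts are produced by two multiplications. First, every base value is scaled by the context factor (0.8, 1, 1.3 or 1.6). Then the font-size entries are scaled again by font_scale. The dictionary-only sketch below copies a few base values from the source to show the resulting numbers. It leaves out the rc overrides, the matplotlib 1.4.2 workaround and the _PlottingContext wrapper, and it never touches matplotlib's rcParams. The names BASE_CONTEXT and scaled_context are illustrative only.

```python
# Subset of the base "notebook" values from the source above
BASE_CONTEXT = {
    "font.size": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 10,
    "lines.linewidth": 1.75,
    "lines.markersize": 7,
}
# Font keys present in this subset; these get the extra font_scale factor
FONT_KEYS = {"font.size", "axes.labelsize", "xtick.labelsize"}
SCALING = {"paper": 0.8, "notebook": 1.0, "talk": 1.3, "poster": 1.6}


def scaled_context(context, font_scale=1.0):
    """Scale every base value by the context factor, then fonts by font_scale."""
    if context not in SCALING:
        raise ValueError("context must be in %s" % ", ".join(SCALING))
    factor = SCALING[context]
    params = {k: v * factor for k, v in BASE_CONTEXT.items()}
    for k in FONT_KEYS:
        params[k] *= font_scale
    return params


talk = scaled_context("talk", font_scale=1.5)
print(talk["font.size"], talk["lines.linewidth"])
```

Line widths and marker sizes follow the context factor only, so font_scale changes text size without thickening lines.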

psdplot(data, *, fs=None, window=None, nfft=None, detrend='constant', return_onesided=True, scaling='density', ax=None)

Plot the power spectrum of a regularly-sampled time-domain signal.

Parameters:

Name Type Description Default
data RegularlySampledAnalogSignalArray

The input signal to analyze. Must be a 1D regularly sampled signal.

required
fs float

Sampling frequency of the time series in Hz. Defaults to data.fs if available.

None
window str or tuple or array_like

Desired window to use. See scipy.signal.get_window for options. If an array, used directly as the window. Defaults to None ('boxcar').

None
nfft int

Length of the FFT used. If None, the length of data will be used.

None
detrend str or function

Specifies how to detrend data prior to computing the spectrum. If a string, passed as the type argument to detrend. If a function, should return a detrended array. Defaults to 'constant'.

'constant'
return_onesided bool

If True, return a one-sided spectrum for real data. If False, return a two-sided spectrum. For complex data, always returns two-sided spectrum.

True
scaling {'density', 'spectrum'}

Selects between computing the power spectral density ('density', units V**2/Hz) and the power spectrum ('spectrum', units V**2). Defaults to 'density'.

'density'
ax Axes

Axis to plot on. If None, uses the current axis (plt.gca()), which creates a figure only if none exists.

None

Returns:

Name Type Description
ax Axes

The axis with the plotted power spectrum.

Examples:

>>> from nelpy.plotting.core import psdplot
>>> ax = psdplot(my_signal)
Source code in nelpy/plotting/core.py
def psdplot(
    data,
    *,
    fs=None,
    window=None,
    nfft=None,
    detrend="constant",
    return_onesided=True,
    scaling="density",
    ax=None,
):
    """
    Plot the power spectrum of a regularly-sampled time-domain signal.

    Parameters
    ----------
    data : RegularlySampledAnalogSignalArray
        The input signal to analyze. Must be a 1D regularly sampled signal.
    fs : float, optional
        Sampling frequency of the time series in Hz. Defaults to data.fs if available.
    window : str or tuple or array_like, optional
        Desired window to use. See scipy.signal.get_window for options. If an array, used directly as the window. Defaults to None ('boxcar').
    nfft : int, optional
        Length of the FFT used. If None, the length of data will be used.
    detrend : str or function, optional
        Specifies how to detrend data prior to computing the spectrum. If a string, passed as the type argument to detrend. If a function, should return a detrended array. Defaults to 'constant'.
    return_onesided : bool, optional
        If True, return a one-sided spectrum for real data. If False, return a two-sided spectrum. For complex data, always returns two-sided spectrum.
    scaling : {'density', 'spectrum'}, optional
        Selects between computing the power spectral density ('density', units V**2/Hz) and the power spectrum ('spectrum', units V**2). Defaults to 'density'.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses the current axis (plt.gca()).

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the plotted power spectrum.

    Examples
    --------
    >>> from nelpy.plotting.core import psdplot
    >>> ax = psdplot(my_signal)
    """

    if ax is None:
        ax = plt.gca()

    if isinstance(data, core.RegularlySampledAnalogSignalArray):
        if fs is None:
            fs = data.fs
        if fs is None:
            raise ValueError(
                "The sampling rate fs cannot be inferred, and must be specified manually!"
            )
        if data.n_signals > 1:
            raise NotImplementedError(
                "more than one signal is not yet supported for psdplot!"
            )
        else:
            data = data.data.squeeze()
    else:
        raise NotImplementedError(
            "datatype {} not yet supported by psdplot!".format(str(type(data)))
        )

    kwargs = {
        "x": data,
        "fs": fs,
        "window": window,
        "nfft": nfft,
        "detrend": detrend,
        "return_onesided": return_onesided,
        "scaling": scaling,
    }

    f, Pxx_den = signal.periodogram(**kwargs)

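    if scaling == "density":
        # The PSD is already in V**2/Hz, so plot it directly
        ax.semilogy(f, Pxx_den)
        ax.set_ylabel("PSD [V**2/Hz]")
    elif scaling == "spectrum":
        # The square root of the power spectrum is the RMS amplitude in V
        ax.semilogy(f, np.sqrt(Pxx_den))
        ax.set_ylabel("Linear spectrum [V RMS]")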
    ax.set_xlabel("frequency [Hz]")

    return ax
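The two scaling options differ only by a constant factor. The square root of the 'spectrum' output is the RMS amplitude of a sinusoid, which is what the 'Linear spectrum [V RMS]' label refers to. The numpy-only sketch below reproduces scipy.signal.periodogram for one special case: a boxcar window, constant detrending and a one-sided spectrum of real data. For other windows, use scipy directly; the helper name is hypothetical. Take a 10 Hz sine of amplitude 2 V, sampled for 2 s at 100 Hz. The 'spectrum' peak is A**2/2 = 2 V**2, and its square root is about 1.414 V RMS. The 'density' peak is 2 V**2 divided by the 0.5 Hz frequency resolution, i.e. 4 V**2/Hz.

```python
import numpy as np


def boxcar_periodogram(x, fs, scaling="density"):
    """One-sided periodogram of a real signal with a boxcar window.

    Intended to match scipy.signal.periodogram with window='boxcar' and
    detrend='constant'; other windows are not handled.
    """
    x = np.asarray(x, dtype=float)
    x = x - x.mean()                       # 'constant' detrend
    n = x.size
    freqs = np.fft.rfftfreq(n, d=1.0 / fs)
    power = np.abs(np.fft.rfft(x)) ** 2
    if scaling == "density":
        power /= fs * n                    # V**2/Hz
    elif scaling == "spectrum":
        power /= n ** 2                    # V**2
    else:
        raise ValueError("scaling must be 'density' or 'spectrum'")
    # Fold in negative frequencies: double all bins except DC (and Nyquist for even n)
    if n % 2 == 0:
        power[1:-1] *= 2
    else:
        power[1:] *= 2
    return freqs, power


fs = 100.0
t = np.arange(200) / fs                    # 2 s of data, 0.5 Hz resolution
x = 2.0 * np.sin(2 * np.pi * 10.0 * t)     # 10 Hz sine, amplitude 2 V
f, spec = boxcar_periodogram(x, fs, scaling="spectrum")
_, psd = boxcar_periodogram(x, fs, scaling="density")
peak = np.argmax(spec)
print(f[peak], spec[peak], np.sqrt(spec[peak]), psd[peak])
```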

rastercountplot(spiketrain, nbins=50, **kwargs)

Plot a raster plot and spike count histogram for a SpikeTrainArray.

Parameters:

Name Type Description Default
spiketrain SpikeTrainArray

The spike train data to plot.

required
nbins int

Number of bins for the histogram. Default is 50.

50
**kwargs dict

Additional keyword arguments passed to rasterplot.

{}

Returns:

Name Type Description
ax1 Axes

The axis with the histogram plot.

ax2 Axes

The axis with the raster plot.

Examples:

>>> from nelpy import SpikeTrainArray
>>> sta = SpikeTrainArray([[1, 2, 3], [2, 4, 6]], fs=10)
>>> rastercountplot(sta, nbins=20)
Source code in nelpy/plotting/core.py
def rastercountplot(spiketrain, nbins=50, **kwargs):
    """
    Plot a raster plot and spike count histogram for a SpikeTrainArray.

    Parameters
    ----------
    spiketrain : nelpy.SpikeTrainArray
        The spike train data to plot.
    nbins : int, optional
        Number of bins for the histogram. Default is 50.
    **kwargs : dict
        Additional keyword arguments passed to rasterplot.

    Returns
    -------
    ax1 : matplotlib.axes.Axes
        The axis with the histogram plot.
    ax2 : matplotlib.axes.Axes
        The axis with the raster plot.

    Examples
    --------
    >>> from nelpy import SpikeTrainArray
    >>> sta = SpikeTrainArray([[1, 2, 3], [2, 4, 6]], fs=10)
    >>> rastercountplot(sta, nbins=20)
    """
    plt.figure(figsize=(14, 6))
    gs = gridspec.GridSpec(2, 1, hspace=0.01, height_ratios=[0.2, 0.8])
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])

    color = kwargs.get("color", None)
    if color is None:
        color = "0.4"

    ds = (spiketrain.support.stop - spiketrain.support.start) / nbins
    flattened = spiketrain.bin(ds=ds).flatten()
    steps = np.squeeze(flattened.data)
    stepsx = np.linspace(
        spiketrain.support.start, spiketrain.support.stop, num=flattened.n_bins
    )

    #     ax1.plot(stepsx, steps, drawstyle='steps-mid', color='none');
    ax1.set_ylim([-0.5, np.max(steps) + 1])
    rasterplot(spiketrain, ax=ax2, **kwargs)

    utils.clear_left_right(ax1)
    utils.clear_top_bottom(ax1)
    utils.clear_top(ax2)

    ax1.fill_between(stepsx, steps, step="mid", color=color)

    utils.sync_xlims(ax1, ax2)

    return ax1, ax2
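The upper panel of rastercountplot shows the spike count of all units pooled together. The support is divided into nbins bins of width ds, the binned array is flattened across units, and the counts are drawn as a filled step curve. The numpy sketch below does the same counting on plain lists of spike times, using hypothetical data. An explicit start and stop stand in for spiketrain.support, and the helper name is illustrative. The sketch places counts at bin centers, whereas the source uses np.linspace(start, stop, n_bins) for the x positions.

```python
import numpy as np


def population_counts(spike_times, start, stop, nbins=50):
    """Pool spikes from every unit and count them in nbins equal-width bins."""
    pooled = np.concatenate([np.asarray(st, dtype=float) for st in spike_times])
    counts, edges = np.histogram(pooled, bins=nbins, range=(start, stop))
    centers = edges[:-1] + np.diff(edges) / 2
    return centers, counts


# Hypothetical spike times (in seconds) for two units
centers, counts = population_counts([[1, 2, 3], [2, 4, 6]], start=0, stop=6, nbins=3)
print(centers, counts)
```

np.histogram closes the last bin on the right, so a spike exactly at stop is still counted.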

rasterplot(data, *, cmap=None, color=None, ax=None, lw=None, lh=None, vertstack=None, labels=None, cmap_lo=0.25, cmap_hi=0.75, **kwargs)

Make a raster plot from a SpikeTrainArray or EventArray object.

Parameters:

Name Type Description Default
data SpikeTrainArray or EventArray

The spike/event data to plot.

required
cmap matplotlib colormap

Colormap to use for the raster lines.

None
color matplotlib color

Plot color; default is '0.25'.

None
ax Axes

Plot in given axis. If None, plots on current axes.

None
lw float

Linewidth, default is 1.5.

None
lh float

Line height, default is 0.95.

None
vertstack bool

If True, stack units in vertically adjacent positions. Default is False.

None
labels list

Labels for input data units. If not specified, uses unit_labels from the input.

None
cmap_lo float

Lower bound for colormap normalization. Default is 0.25.

0.25
cmap_hi float

Upper bound for colormap normalization. Default is 0.75.

0.75
**kwargs dict

Other keyword arguments are passed to main vlines() call.

{}

Returns:

Name Type Description
ax Axes

Axis object with plot data.

Examples:

Instantiate a SpikeTrainArray and create a raster plot:

>>> stdata1 = [1, 2, 4, 5, 6, 10, 20]
>>> stdata2 = [3, 4, 4.5, 5, 5.5, 19]
>>> stdata3 = [5, 12, 14, 15, 16, 18, 22, 23, 24]
>>> stdata4 = [5, 12, 14, 15, 16, 18, 23, 25, 32]

>>> sta1 = nelpy.SpikeTrainArray([stdata1, stdata2, stdata3,
...                               stdata4, stdata1+stdata4],
...                               fs=5, unit_ids=[1,2,3,4,6])
>>> ax = rasterplot(sta1, color="cyan", lw=2, lh=2)

Instantiate another SpikeTrainArray, stack units, and specify labels. Note that the user-specified labels in the call to rasterplot() will be shown instead of the unit_labels associated with the input data:

>>> sta3 = nelpy.SpikeTrainArray([stdata1, stdata4, stdata2+stdata3],
...                              support=ep1, fs=5, unit_ids=[10,5,12],
...                              unit_labels=['some', 'more', 'cells'])
>>> rasterplot(sta3, cmap=plt.cm.Blues, lw=2, lh=2, vertstack=True,
...            labels=['units', 'of', 'interest'])
Source code in nelpy/plotting/core.py
def rasterplot(
    data,
    *,
    cmap=None,
    color=None,
    ax=None,
    lw=None,
    lh=None,
    vertstack=None,
    labels=None,
    cmap_lo=0.25,
    cmap_hi=0.75,
    **kwargs,
):
    """
    Make a raster plot from a SpikeTrainArray or EventArray object.

    Parameters
    ----------
    data : nelpy.SpikeTrainArray or nelpy.EventArray
        The spike/event data to plot.
    cmap : matplotlib colormap, optional
        Colormap to use for the raster lines.
    color : matplotlib color, optional
        Plot color; default is '0.25'.
    ax : matplotlib.axes.Axes, optional
        Plot in given axis. If None, plots on current axes.
    lw : float, optional
        Linewidth, default is 1.5.
    lh : float, optional
        Line height, default is 0.95.
    vertstack : bool, optional
        If True, stack units in vertically adjacent positions. Default is False.
    labels : list, optional
        Labels for input data units. If not specified, uses unit_labels from the input.
    cmap_lo : float, optional
        Lower bound for colormap normalization. Default is 0.25.
    cmap_hi : float, optional
        Upper bound for colormap normalization. Default is 0.75.
    **kwargs : dict
        Other keyword arguments are passed to main vlines() call.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axis object with plot data.

    Examples
    --------
    Instantiate a SpikeTrainArray and create a raster plot:

        >>> stdata1 = [1, 2, 4, 5, 6, 10, 20]
        >>> stdata2 = [3, 4, 4.5, 5, 5.5, 19]
        >>> stdata3 = [5, 12, 14, 15, 16, 18, 22, 23, 24]
        >>> stdata4 = [5, 12, 14, 15, 16, 18, 23, 25, 32]

        >>> sta1 = nelpy.SpikeTrainArray([stdata1, stdata2, stdata3,
        ...                               stdata4, stdata1+stdata4],
        ...                               fs=5, unit_ids=[1,2,3,4,6])
        >>> ax = rasterplot(sta1, color="cyan", lw=2, lh=2)

    Instantiate another SpikeTrainArray, stack units, and specify labels.
    Note that the user-specified labels in the call to rasterplot() will be
    shown instead of the unit_labels associated with the input data:

        >>> sta3 = nelpy.SpikeTrainArray([stdata1, stdata4, stdata2+stdata3],
        ...                              support=ep1, fs=5, unit_ids=[10,5,12],
        ...                              unit_labels=['some', 'more', 'cells'])
        >>> rasterplot(sta3, cmap=plt.cm.Blues, lw=2, lh=2, vertstack=True,
        ...            labels=['units', 'of', 'interest'])
    """

    # Sort out default values for the parameters
    if ax is None:
        ax = plt.gca()
    if cmap is None and color is None:
        color = "0.25"
    if lw is None:
        lw = 1.5
    if lh is None:
        lh = 0.95
    if vertstack is None:
        vertstack = False

    firstplot = False
    if not ax.findobj(match=RasterLabelData):
        firstplot = True
        ax.add_artist(RasterLabelData())

    # override labels
    if labels is not None:
        series_labels = labels
    else:
        series_labels = []

    hh = lh / 2.0  # half the line height

    # Handle different types of input data
    if isinstance(data, core.EventArray):
        label_data = ax.findobj(match=RasterLabelData)[0].label_data
        serieslist = [-np.inf for element in data.series_ids]
        # no override labels so use unit_labels from input
        if not series_labels:
            series_labels = data.series_labels

        if firstplot:
            if vertstack:
                minunit = 1
                maxunit = data.n_series
                serieslist = range(1, data.n_series + 1)
            else:
                minunit = np.array(data.series_ids).min()
                maxunit = np.array(data.series_ids).max()
                serieslist = data.series_ids
        # see if any of the series_ids has already been plotted. If so,
        # then merge
        else:
            for idx, series_id in enumerate(data.series_ids):
                if series_id in label_data.keys():
                    position, _ = label_data[series_id]
                    serieslist[idx] = position
                else:  # unit not yet plotted
                    if vertstack:
                        serieslist[idx] = 1 + max(
                            int(ax.get_yticks()[-1]), max(serieslist)
                        )
                    else:
                        warnings.warn(
                            "Spike trains may be plotted in "
                            "the same vertical position as "
                            "another unit"
                        )
                        serieslist[idx] = data.series_ids[idx]

        if firstplot:
            minunit = int(minunit)
            maxunit = int(maxunit)
        else:
            (prev_ymin, prev_ymax) = ax.findobj(match=RasterLabelData)[0].yrange
            minunit = int(np.min([np.ceil(prev_ymin), np.min(serieslist)]))
            maxunit = int(np.max([np.floor(prev_ymax), np.max(serieslist)]))

        yrange = (minunit - 0.5, maxunit + 0.5)

        if cmap is not None:
            color_range = range(data.n_series)
            # TODO: if we go from 0 then most colormaps are invisible at one end of the spectrum
            colors = cmap(np.linspace(cmap_lo, cmap_hi, data.n_series))
            for series_ii, series, color_idx in zip(serieslist, data.data, color_range):
                ax.vlines(
                    series,
                    series_ii - hh,
                    series_ii + hh,
                    colors=colors[color_idx],
                    lw=lw,
                    **kwargs,
                )
        else:  # use a constant color:
            for series_ii, series in zip(serieslist, data.data):
                ax.vlines(
                    series,
                    series_ii - hh,
                    series_ii + hh,
                    colors=color,
                    lw=lw,
                    **kwargs,
                )

        # get existing label data so we can set some attributes
        rld = ax.findobj(match=RasterLabelData)[0]

        ax.set_ylim(yrange)
        rld.yrange = yrange

        for series_id, loc, label in zip(data.series_ids, serieslist, series_labels):
            rld.label_data[series_id] = (loc, label)
        serieslocs = []
        serieslabels = []
        for loc, label in label_data.values():
            serieslocs.append(loc)
            serieslabels.append(label)
        ax.set_yticks(serieslocs)
        ax.set_yticklabels(serieslabels)

    else:
        raise NotImplementedError(
            "plotting {} not yet supported".format(str(type(data)))
        )
    return ax

reset_defaults()

Restore all matplotlib rc parameters to matplotlib's built-in defaults (mpl.rcParamsDefault), discarding any custom rc configuration.

Source code in nelpy/plotting/rcmod.py
def reset_defaults():
    """
    Restore all matplotlib RC params to default settings.
    """
    mpl.rcParams.update(mpl.rcParamsDefault)

reset_orig()

Restore all matplotlib rc parameters to the original settings stored by nelpy (_orig_rc_params). Unlike reset_defaults, this respects a custom rc configuration.

Source code in nelpy/plotting/rcmod.py
def reset_orig():
    """
    Restore all matplotlib RC params to original settings (respects custom rc).
    """
    mpl.rcParams.update(_orig_rc_params)
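
reset_defaults and reset_orig differ only in which snapshot of rc values they restore. The sketch below illustrates the pattern with plain matplotlib. orig_rc is a hypothetical stand-in for nelpy's module-level _orig_rc_params (assumed here to be a snapshot of rcParams taken before customization), and only two keys are tracked to keep the example small.

```python
import matplotlib as mpl

# Hypothetical stand-in for nelpy's `_orig_rc_params`: a snapshot of a few
# rc values taken before any customization (the real object may hold all keys).
orig_rc = {key: mpl.rcParams[key] for key in ("lines.linewidth", "axes.grid")}

# Customize some settings.
mpl.rcParams["lines.linewidth"] = 4.0
mpl.rcParams["axes.grid"] = not orig_rc["axes.grid"]

# reset_orig()-style restore: return to the saved snapshot.
mpl.rcParams.update(orig_rc)

# reset_defaults()-style restore would instead copy values from matplotlib's
# built-in defaults, ignoring any matplotlibrc customization.
builtin_linewidth = mpl.rcParamsDefault["lines.linewidth"]
```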

savefig(name, fig=None, formats=None, dpi=None, verbose=True, overwrite=False)

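Save a figure in one or more formats. Files are written to the figures/ subdirectory of the current working directory as figures/<name>.<format>; that directory must already exist.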

Parameters:

Name Type Description Default
name str

Filename, normally without an extension. If name has an extension and formats is None, that extension is used as the only format. If formats is also given, the filename extension is ignored and a warning is printed.

required
fig Figure

Figure to save, default uses current figure.

None
formats list

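List of formats to export. If None and name has no extension, defaults to ['pdf', 'png']. Unsupported formats are skipped with a warning. Pass a list: a bare string such as "png" is only wrapped in a list when name also has an extension.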

None
dpi float

Resolution of the figure in dots per inch (DPI). If None, 300 DPI is used.

None
verbose bool

If True, print additional output to screen.

True
overwrite bool

If True, an existing file is overwritten. If False, an existing file is left untouched (a message is still printed).

False

Returns:

Type Description
None

Examples:

>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> ax.plot([1, 2, 3], [4, 5, 6])
>>> savefig("myplot", fig=fig, formats=["png"], overwrite=True)
Source code in nelpy/plotting/utils.py
def savefig(name, fig=None, formats=None, dpi=None, verbose=True, overwrite=False):
    """
    Save a figure in one or multiple formats.

    Parameters
    ----------
    name : str
        Filename without an extension. If an extension is present,
        AND if formats is empty, then the filename extension will be used.
    fig : matplotlib.figure.Figure, optional
        Figure to save, default uses current figure.
    formats : list, optional
        List of formats to export. Defaults to ['pdf', 'png']
    dpi : float, optional
        Resolution of the figure in dots per inch (DPI).
    verbose : bool, optional
        If True, print additional output to screen.
    overwrite : bool, optional
        If True, file will be overwritten.

    Returns
    -------
    None

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> ax.plot([1, 2, 3], [4, 5, 6])
    >>> savefig("myplot", fig=fig, formats=["png"], overwrite=True)
    """
    # Check inputs
    # if not 0 <= prop <= 1:
    #     raise ValueError("prop must be between 0 and 1")

    if dpi is None:
        dpi = 300

    supportedFormats = [
        "eps",
        "jpeg",
        "jpg",
        "pdf",
        "pgf",
        "png",
        "ps",
        "raw",
        "rgba",
        "svg",
        "svgz",
        "tif",
        "tiff",
    ]

    name, ext = get_extension_from_filename(name)

    # if no list of formats is given, use defaults
    if formats is None and ext is None:
        formats = ["pdf", "png"]
    # if the filename has an extension, AND a list of extensions is given, then use only the list
    elif formats is not None and ext is not None:
        if not isinstance(formats, list):
            formats = [formats]
        print("WARNING! Extension in filename ignored in favor of formats list.")
    # if no list of extensions is given, use the extension from the filename
    elif formats is None and ext is not None:
        formats = [ext]
    else:
        pass

    if fig is None:
        fig = plt.gcf()

    for extension in formats:
        if extension not in supportedFormats:
            print("WARNING! Format '{}' not supported. Aborting...".format(extension))
        else:
            my_file = "figures/{}.{}".format(name, extension)

            if os.path.isfile(my_file):
                # file exists
                print("{} already exists!".format(my_file))

                if overwrite:
                    fig.savefig(my_file, dpi=dpi, bbox_inches="tight")

                    if verbose:
                        print(
                            "{} saved successfully... [using overwrite]".format(
                                extension
                            )
                        )
            else:
                fig.savefig(my_file, dpi=dpi, bbox_inches="tight")

                if verbose:
                    print("{} saved successfully...".format(extension))
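
The way savefig turns name and formats into output paths can be summarized as a small pure function. resolve_targets is a hypothetical helper, not part of nelpy, and os.path.splitext stands in for nelpy's get_extension_from_filename, whose exact behavior is assumed. Unlike the source above, the sketch also wraps a bare string such as "png" in a list in every case.

```python
import os

SUPPORTED_FORMATS = {
    "eps", "jpeg", "jpg", "pdf", "pgf", "png", "ps",
    "raw", "rgba", "svg", "svgz", "tif", "tiff",
}


def resolve_targets(name, formats=None, outdir="figures"):
    """Hypothetical helper: list the files savefig would try to write."""
    # Assumption: split "myplot.svg" into ("myplot", "svg"), as
    # get_extension_from_filename is expected to do.
    stem, ext = os.path.splitext(name)
    ext = ext.lstrip(".") or None

    if isinstance(formats, str):
        formats = [formats]  # sketch-only fix; savefig does not always do this
    if formats is None:
        # No explicit formats: use the filename extension, else the defaults.
        formats = [ext] if ext is not None else ["pdf", "png"]
    # If formats was given, any filename extension is ignored.

    # Unsupported formats are skipped (savefig prints a warning for them).
    return [f"{outdir}/{stem}.{fmt}" for fmt in formats if fmt in SUPPORTED_FORMATS]
```

Because savefig writes into figures/ relative to the working directory, create that folder first, for example with os.makedirs("figures", exist_ok=True).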

set_context(context=None, font_scale=1, rc=None)

Set the plotting context parameters.

This affects things like the size of the labels, lines, and other elements of the plot, but not the overall style. The base context is "notebook", and the other contexts are "paper", "talk", and "poster", which are versions of the notebook parameters scaled by .8, 1.3, and 1.6, respectively.

Parameters:

Name Type Description Default
context dict, None, or one of {paper, notebook, talk, poster}

A dictionary of parameters or the name of a preconfigured set.

None
font_scale float

Separate scaling factor to independently scale the size of the font elements.

1
rc dict

Parameter mappings to override the values in the preset seaborn context dictionaries. This only updates parameters that are considered part of the context definition.

None

Examples:

>>> set_context("paper")
>>> set_context("talk", font_scale=1.4)
>>> set_context("talk", rc={"lines.linewidth": 2})
See Also

plotting_context : return a dictionary of rc parameters, or use in a with statement to temporarily set the context.
set_style : set the default parameters for figure style.
set_palette : set the default color palette for figures.

Source code in nelpy/plotting/rcmod.py
def set_context(context=None, font_scale=1, rc=None):
    """
    Set the plotting context parameters.

    This affects things like the size of the labels, lines, and other
    elements of the plot, but not the overall style. The base context
    is "notebook", and the other contexts are "paper", "talk", and "poster",
    which are versions of the notebook parameters scaled by .8, 1.3, and 1.6,
    respectively.

    Parameters
    ----------
    context : dict, None, or one of {paper, notebook, talk, poster}
        A dictionary of parameters or the name of a preconfigured set.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        context dictionaries. This only updates parameters that are
        considered part of the context definition.

    Examples
    --------
    >>> set_context("paper")
    >>> set_context("talk", font_scale=1.4)
    >>> set_context("talk", rc={"lines.linewidth": 2})

    See Also
    --------
    plotting_context : return a dictionary of rc parameters, or use in
                       a ``with`` statement to temporarily set the context.
    set_style : set the default parameters for figure style
    set_palette : set the default color palette for figures
    """
    context_object = plotting_context(context, font_scale, rc)
    mpl.rcParams.update(context_object)
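
The scaling rule described above can be sketched in a few lines. The base values in NOTEBOOK_BASE and the set of font keys are illustrative placeholders, not nelpy's actual "notebook" context; the sketch only shows how a context scale, font_scale, and a key-restricted rc override combine before the result is applied with mpl.rcParams.update.

```python
NOTEBOOK_BASE = {  # illustrative values, not nelpy's real "notebook" context
    "font.size": 12.0,
    "axes.labelsize": 11.0,
    "lines.linewidth": 1.5,
}
FONT_KEYS = {"font.size", "axes.labelsize"}
CONTEXT_SCALES = {"paper": 0.8, "notebook": 1.0, "talk": 1.3, "poster": 1.6}


def scaled_context(context="notebook", font_scale=1.0, rc=None):
    """Hypothetical sketch of how a named context could be derived."""
    scale = CONTEXT_SCALES[context]
    params = {key: value * scale for key, value in NOTEBOOK_BASE.items()}
    # font_scale only affects font-related entries
    for key in FONT_KEYS:
        params[key] *= font_scale
    # rc may only override keys that belong to the context definition
    if rc is not None:
        params.update({k: v for k, v in rc.items() if k in params})
    return params
```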

set_palette(palette, n_colors=None, desat=None)

Set the matplotlib color cycle using a seaborn palette.

Parameters:

Name Type Description Default
palette hls | husl | matplotlib colormap | seaborn color palette

Palette definition. Should be something that color_palette can process.

required
n_colors int

Number of colors in the cycle. The default number of colors depends on the format of palette; see the color_palette documentation for more information.

None
desat float

Proportion to desaturate each color by.

None

Examples:

>>> set_palette("Reds")
>>> set_palette("Set1", 8, 0.75)
See Also

color_palette : build a color palette or set the color cycle temporarily in a with statement.
set_context : set parameters to scale plot elements.
set_style : set the default parameters for figure style.

Source code in nelpy/plotting/rcmod.py
def set_palette(palette, n_colors=None, desat=None):
    """
    Set the matplotlib color cycle using a seaborn palette.

    Parameters
    ----------
    palette : hls | husl | matplotlib colormap | seaborn color palette
        Palette definition. Should be something that :func:`color_palette` can process.
    n_colors : int, optional
        Number of colors in the cycle. The default number of colors will depend
        on the format of ``palette``, see the :func:`color_palette`
        documentation for more information.
    desat : float, optional
        Proportion to desaturate each color by.

    Examples
    --------
    >>> set_palette("Reds")
    >>> set_palette("Set1", 8, 0.75)

    See Also
    --------
    color_palette : build a color palette or set the color cycle temporarily
                    in a ``with`` statement.
    set_context : set parameters to scale plot elements
    set_style : set the default parameters for figure style
    """
    colors = palettes.color_palette(palette, n_colors, desat)
    if mpl_ge_150:
        from cycler import cycler

        cyl = cycler("color", colors)
        mpl.rcParams["axes.prop_cycle"] = cyl
    else:
        mpl.rcParams["axes.color_cycle"] = list(colors)
    mpl.rcParams["patch.facecolor"] = colors[0]
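
On matplotlib 1.5 and later, set_palette comes down to installing a cycler as axes.prop_cycle and setting patch.facecolor to the first color. The sketch below reproduces that step with a hard-coded list of hex colors standing in for the output of palettes.color_palette (an assumption; the real colors depend on the palette argument), and checks that lines on a new axes pick up colors in cycle order.

```python
import matplotlib as mpl

mpl.use("Agg")  # headless backend for the example
import matplotlib.pyplot as plt
from cycler import cycler

# Stand-in for palettes.color_palette(palette, n_colors, desat)
colors = ["#d62728", "#2ca02c", "#1f77b4"]

mpl.rcParams["axes.prop_cycle"] = cycler("color", colors)
mpl.rcParams["patch.facecolor"] = colors[0]

# Axes created afterwards draw successive lines with successive colors.
fig, ax = plt.subplots()
lines = [ax.plot([0, 1], [i, i])[0] for i in range(len(colors))]
line_colors = [line.get_color() for line in lines]
```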

set_style(style=None, rc=None)

Set the aesthetic style of the plots.

This affects things like the color of the axes, whether a grid is enabled by default, and other aesthetic elements.

Parameters:

Name Type Description Default
style dict, None, or one of {darkgrid, whitegrid, dark, white, ticks}

A dictionary of parameters or the name of a preconfigured set.

None
rc dict

Parameter mappings to override the values in the preset seaborn style dictionaries. This only updates parameters that are considered part of the style definition.

None

Examples:

>>> set_style("whitegrid")
>>> set_style("ticks", {"xtick.major.size": 8, "ytick.major.size": 8})
See Also

axes_style : return a dict of parameters or use in a with statement to temporarily set the style.
set_context : set parameters to scale plot elements.
set_palette : set the default color palette for figures.

Source code in nelpy/plotting/rcmod.py
def set_style(style=None, rc=None):
    """
    Set the aesthetic style of the plots.

    This affects things like the color of the axes, whether a grid is
    enabled by default, and other aesthetic elements.

    Parameters
    ----------
    style : dict, None, or one of {darkgrid, whitegrid, dark, white, ticks}
        A dictionary of parameters or the name of a preconfigured set.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        style dictionaries. This only updates parameters that are
        considered part of the style definition.

    Examples
    --------
    >>> set_style("whitegrid")
    >>> set_style("ticks", {"xtick.major.size": 8, "ytick.major.size": 8})

    See Also
    --------
    axes_style : return a dict of parameters or use in a ``with`` statement
                 to temporarily set the style.
    set_context : set parameters to scale plot elements
    set_palette : set the default color palette for figures
    """
    style_object = axes_style(style, rc)
    mpl.rcParams.update(style_object)
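
set_style changes rcParams globally and leaves the new values in place. For a temporary change, the See Also entry points to axes_style in a with statement; the sketch below shows the equivalent mechanism in plain matplotlib via matplotlib.rc_context, using a hypothetical style dictionary that mimics the "ticks" example above.

```python
import matplotlib as mpl

# Hypothetical style dictionary, similar in spirit to the output of axes_style()
tick_style = {"xtick.major.size": 8.0, "ytick.major.size": 8.0}

before = mpl.rcParams["xtick.major.size"]

# Temporarily apply the style; the previous values come back on exit.
with mpl.rc_context(tick_style):
    inside = mpl.rcParams["xtick.major.size"]

after = mpl.rcParams["xtick.major.size"]
```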

setup(context='notebook', style='ticks', palette='sweet', font='sans-serif', font_scale=1, rc=None)

Set aesthetic figure parameters for matplotlib plots.

setup applies all settings globally by updating matplotlib's rcParams. To set an individual group of parameters, or to set it only temporarily, use the functions referenced below.

Parameters:

Name Type Description Default
context str or dict

Plotting context parameters, see plotting_context.

'notebook'
style str or dict

Axes style parameters, see axes_style.

'ticks'
palette str or sequence

Color palette, see color_palette.

'sweet'
font str

Font family, applied through the font.family rc parameter (see the matplotlib font manager).

'sans-serif'
font_scale float

Separate scaling factor to independently scale the size of the font elements.

1
rc dict or None

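Dictionary of rc parameter mappings applied last, overriding the settings above. Unlike the rc arguments of set_context and set_style, it is not restricted to context or style keys.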

None

Examples:

>>> setup(
...     context="talk",
...     style="whitegrid",
...     palette="muted",
...     font="Arial",
...     font_scale=1.2,
... )
Source code in nelpy/plotting/rcmod.py
def setup(
    context="notebook",
    style="ticks",
    palette="sweet",
    font="sans-serif",
    font_scale=1,
    rc=None,
):
    """
    Set aesthetic figure parameters for matplotlib plots.

    Each set of parameters can be set directly or temporarily. See the
    referenced functions below for more information.

    Parameters
    ----------
    context : str or dict, optional
        Plotting context parameters, see :func:`plotting_context`.
    style : str or dict, optional
        Axes style parameters, see :func:`axes_style`.
    palette : str or sequence, optional
        Color palette, see :func:`color_palette`.
    font : str, optional
        Font family, see matplotlib font manager.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements.
    rc : dict or None, optional
        Dictionary of rc parameter mappings to override the above.

    Examples
    --------
    >>> setup(
    ...     context="talk",
    ...     style="whitegrid",
    ...     palette="muted",
    ...     font="Arial",
    ...     font_scale=1.2,
    ... )
    """
    set_context(context, font_scale)
    set_style(style, rc={"font.family": font})
    set_palette(palette=palette)
    if rc is not None:
        mpl.rcParams.update(rc)
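
The order of the calls in setup matters: context first, then style (with font injected as font.family), then palette, and finally the user rc, which wins over everything else. The sketch below reproduces that precedence with plain dictionaries in place of nelpy's context and style presets (hypothetical values) and omits the palette step.

```python
import matplotlib as mpl


def setup_sketch(context_rc, style_rc, font="sans-serif", rc=None):
    """Hypothetical sketch of setup()'s update order (palette step omitted)."""
    mpl.rcParams.update(context_rc)  # like set_context(...)
    # like set_style(..., rc={"font.family": font})
    mpl.rcParams.update({**style_rc, "font.family": font})
    if rc is not None:
        mpl.rcParams.update(rc)  # user rc applied last, unfiltered


setup_sketch(
    context_rc={"lines.linewidth": 2.0},
    style_rc={"axes.grid": True},
    font="serif",
    rc={"axes.grid": False},  # overrides the style's grid setting
)
```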

stripplot(*eps, voffset=None, lw=None, labels=None)

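Plot each EpochArray as a row of line segments on a new 10 x 2 inch figure. Each row takes its color from the current matplotlib color cycle; the cycle is not wrapped, so plotting more EpochArrays than there are colors in the cycle raises an IndexError.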

Parameters:

Name Type Description Default
*eps EpochArray

One or more EpochArray objects to plot.

()
voffset float

Vertical offset between lines. Currently unused: rows are always spaced 0.2 apart.

None
lw float

Line width. Currently unused: segments are always drawn with a line width of 6.

None
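labels list of str

Labels for each EpochArray. If None, the label attribute of each EpochArray is used. If given, this must be a list: an empty label is inserted at its front in place, so the caller's list is modified.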

None

Returns:

Name Type Description
ax Axes

The axis with the strip plot.

Examples:

>>> from nelpy import EpochArray
>>> ep1 = EpochArray([[0, 1], [2, 3]])
>>> ep2 = EpochArray([[4, 5], [6, 7]])
>>> stripplot(ep1, ep2, labels=["A", "B"])
Source code in nelpy/plotting/miscplot.py
def stripplot(*eps, voffset=None, lw=None, labels=None):
    """
    Plot epochs as segments on a line.

    Parameters
    ----------
    *eps : nelpy.EpochArray
        One or more EpochArray objects to plot.
    voffset : float, optional
        Vertical offset between lines.
    lw : float, optional
        Line width.
    labels : array-like of str, optional
        Labels for each EpochArray.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the strip plot.

    Examples
    --------
    >>> from nelpy import EpochArray
    >>> ep1 = EpochArray([[0, 1], [2, 3]])
    >>> ep2 = EpochArray([[4, 5], [6, 7]])
    >>> stripplot(ep1, ep2, labels=["A", "B"])
    """

    # TODO: this plot is in alpha mode; i.e., needs lots of work...
    # TODO: list unpacking if eps is a list of EpochArrays...

    fig = plt.figure(figsize=(10, 2))
    ax0 = fig.add_subplot(111)

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]

    epmin = np.inf
    epmax = -np.inf

    for ii, epa in enumerate(eps):
        epmin = np.min((epa.start, epmin))
        epmax = np.max((epa.stop, epmax))

    # WARNING TODO: this does not yet wrap the color cycler, but it's easy to do with mod arith
    y = 0.2
    for ii, epa in enumerate(eps):
        ax0.hlines(y, epmin, epmax, "0.7")
        for ep in epa:
            ax0.plot(
                [ep.start, ep.stop],
                [y, y],
                lw=6,
                color=colors[ii],
                solid_capstyle="round",
            )
        y += 0.2

    utils.clear_top(ax0)
    #     npl.utils.clear_bottom(ax0)

    if labels is None:
        # try to get labels from epoch arrays
        labels = [""]
        labels.extend([epa.label for epa in eps])
    else:
        labels.insert(0, "")

    ax0.set_yticklabels(labels)

    ax0.set_xlim(epmin - 10, epmax + 10)
    ax0.set_ylim(0, 0.2 * (ii + 2))

    utils.no_yticks(ax0)
    utils.clear_left(ax0)
    utils.clear_right(ax0)

    return ax0
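
A minimal sketch of the same drawing logic in plain matplotlib, assuming each row is given as a list of (start, stop) pairs instead of an EpochArray. strip_sketch is hypothetical; unlike stripplot, it honors its spacing and line-width arguments and wraps the color cycle with modular arithmetic, as the TODO in the source suggests.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for the example
import matplotlib.pyplot as plt


def strip_sketch(*rows, labels=None, ystep=0.2, lw=6):
    """Hypothetical stripplot variant; each row is a list of (start, stop) pairs."""
    fig, ax = plt.subplots(figsize=(10, 2))
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # Shared x-extent across all rows
    lo = min(start for row in rows for start, _ in row)
    hi = max(stop for row in rows for _, stop in row)

    yticks = []
    for ii, row in enumerate(rows):
        y = ystep * (ii + 1)
        yticks.append(y)
        ax.hlines(y, lo, hi, colors="0.7")  # gray baseline for the row
        for start, stop in row:
            ax.plot(
                [start, stop],
                [y, y],
                lw=lw,
                color=colors[ii % len(colors)],  # wrap the color cycle
                solid_capstyle="round",
            )

    ax.set_yticks(yticks)
    ax.set_yticklabels(list(labels) if labels is not None else [""] * len(rows))
    ax.set_xlim(lo - 10, hi + 10)
    ax.set_ylim(0, ystep * (len(rows) + 1))
    return ax


ax = strip_sketch([(0, 1), (2, 3)], [(4, 5), (6, 7)], labels=["A", "B"])
```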

suptitle(t, gs=None, rect=(0, 0, 1, 0.95), **kwargs)

Add a suptitle to the current figure (plt.gcf()) and re-run tight_layout on its embedded gridspec so the subplots leave room for the title.

Parameters:

Name Type Description Default
t str

The suptitle text.

required
gs GridSpec

The gridspec to lay out. If None, the gridspec stored as fig.npl_gs on the current figure is used.

None
rect tuple

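Rectangle (left, bottom, right, top) in normalized figure coordinates into which the gridspec is fitted; passed to gs.tight_layout. The default reserves the top 5% of the figure for the title.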

(0, 0, 1, 0.95)
**kwargs dict

Additional keyword arguments passed to fig.suptitle().

{}

Raises:

Type Description
AttributeError

If gs is None and the current figure has no npl_gs attribute.

See Also

https://matplotlib.org/users/tight_layout_guide.html

Source code in nelpy/plotting/utils.py
def suptitle(t, gs=None, rect=(0, 0, 1, 0.95), **kwargs):
    """
    Add a suptitle to a figure with an embedded gridspec.

    Parameters
    ----------
    t : str
        The suptitle text.
    gs : matplotlib.gridspec.GridSpec, optional
        The gridspec to use. If None, uses fig.npl_gs.
    rect : tuple, optional
        Rectangle in figure coordinates (x1, y1, x2, y2).
    **kwargs : dict
        Additional keyword arguments passed to fig.suptitle().

    Raises
    ------
    AttributeError
        If no gridspec is found in the figure.

    See Also
    --------
    https://matplotlib.org/users/tight_layout_guide.html
    """
    fig = plt.gcf()
    if gs is None:
        try:
            gs = fig.npl_gs
        except AttributeError:
            raise AttributeError(
                "nelpy suptitle requires an embedded gridspec! Use the nelpy FigureManager."
            )

    fig.suptitle(t, **kwargs)
    gs.tight_layout(fig, rect=rect)
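
Outside of FigureManager, the same effect can be reproduced by attaching a gridspec to the figure by hand. In this sketch, setting fig.npl_gs manually is an assumption that mimics what FigureManager is expected to store; the last two lines are what suptitle does once it has a gridspec.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for the example
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(8, 3))
gs = GridSpec(1, 2, figure=fig)
ax_left = fig.add_subplot(gs[0, 0])
ax_right = fig.add_subplot(gs[0, 1])
fig.npl_gs = gs  # assumption: mimic the gridspec attribute suptitle() looks up

# What suptitle("Session summary") does with that gridspec:
fig.suptitle("Session summary")
gs.tight_layout(fig, rect=(0, 0, 1, 0.95))  # fit subplots below 95% of the height
```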

veva_scatter(data, *, cmap=None, color=None, ax=None, lw=None, lh=None, **kwargs)

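Raster-style plot of a ValueEventArray: each series is drawn as a row of vertical ticks (one per event) at y = 0, 1, 2, ..., optionally colored by each event's value.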

Parameters:

Name Type Description Default
data ValueEventArray

The value event data to plot.

required
cmap matplotlib colormap

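Colormap object used to color events by value (e.g. plt.cm.viridis). It is called directly, so a colormap name string is not accepted. Values are normalized to the range [min - 1, max] over all series. Takes precedence over color.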

None
color matplotlib color

Color for the events if cmap is not specified. Defaults to '0.25' (dark gray) when both cmap and color are None.

None
ax Axes

Axis to plot on. If None, uses current axis.

None
lw float

Line width for the event markers. If None, 1.5 is used.

None
lh float

Line height for the event markers, in row units. If None, 0.95 is used.

None
**kwargs dict

Additional keyword arguments passed to vlines.

{}

Returns:

Name Type Description
ax Axes

The axis with the scatter plot.

Examples:

>>> from nelpy.core import ValueEventArray
>>> vea = ValueEventArray(
...     [[1, 2, 3], [4, 5, 6]], values=[[10, 20, 30], [40, 50, 60]]
... )
>>> veva_scatter(vea)
Source code in nelpy/plotting/miscplot.py
def veva_scatter(data, *, cmap=None, color=None, ax=None, lw=None, lh=None, **kwargs):
    """
    Scatter plot for ValueEventArray objects, colored by value.

    Parameters
    ----------
    data : nelpy.ValueEventArray
        The value event data to plot.
    cmap : matplotlib colormap, optional
        Colormap to use for the event values.
    color : matplotlib color, optional
        Color for the events if cmap is not specified.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, uses current axis.
    lw : float, optional
        Line width for the event markers.
    lh : float, optional
        Line height for the event markers.
    **kwargs : dict
        Additional keyword arguments passed to vlines.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis with the scatter plot.

    Examples
    --------
    >>> from nelpy.core import ValueEventArray
    >>> vea = ValueEventArray(
    ...     [[1, 2, 3], [4, 5, 6]], values=[[10, 20, 30], [40, 50, 60]]
    ... )
    >>> veva_scatter(vea)
    """
    # Sort out default values for the parameters
    if ax is None:
        ax = plt.gca()
    if cmap is None and color is None:
        color = "0.25"
    if lw is None:
        lw = 1.5
    if lh is None:
        lh = 0.95

    hh = lh / 2.0  # half the line height

    # Handle different types of input data
    if isinstance(data, core.ValueEventArray):
        vmin = (
            np.min([np.min(x) for x in data.values]) - 1
        )  # TODO: -1 because white is invisible... fix this properly
        vmax = np.max([np.max(x) for x in data.values])
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

        for ii, (events, values) in enumerate(zip(data.events, data.values)):
            if cmap is not None:
                colors = cmap(norm(values))
            else:
                colors = color
            ax.vlines(events, ii - hh, ii + hh, colors=colors, lw=lw, **kwargs)

    else:
        raise NotImplementedError(
            "plotting {} not yet supported".format(str(type(data)))
        )
    return ax
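
The coloring step can be tried without a ValueEventArray by passing plain lists of event times and values. value_raster is a hypothetical helper that mirrors the loop in veva_scatter, including the -1 in vmin carried over from the source, which keeps the smallest value off the very bottom of the colormap (the source's TODO notes that white would otherwise be invisible).

```python
import matplotlib as mpl

mpl.use("Agg")  # headless backend for the example
import matplotlib.pyplot as plt
import numpy as np


def value_raster(events, values, cmap, ax=None, lw=1.5, lh=0.95):
    """Hypothetical helper: one row of value-colored ticks per series."""
    if ax is None:
        ax = plt.gca()
    hh = lh / 2.0  # half the tick height
    vmin = min(np.min(v) for v in values) - 1  # same offset as veva_scatter
    vmax = max(np.max(v) for v in values)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    for ii, (ev, val) in enumerate(zip(events, values)):
        ax.vlines(ev, ii - hh, ii + hh, colors=cmap(norm(np.asarray(val))), lw=lw)
    return ax


cmap = plt.get_cmap("viridis")  # a Colormap object, not the string name
fig, ax = plt.subplots()
ax = value_raster([[1, 2, 3], [4, 5, 6]], [[10, 20, 30], [40, 50, 60]], cmap, ax=ax)
```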