Linear track demo¶
Here we load sorted unit data obtained from MatClust to identify putative place cells. The session has about 15 minutes of linear track running, followed by 15 minutes of rest.
Data was recorded by Joshua Chu, with a Spikegadgets wireless headstage, on July 8th, 2017, from the CA1 area of a male Long-Evans rat named install
.
Notebook was created by Etienne Ackermann.
Preliminaries¶
We use the following Python packages; they can be installed either with conda install <pkg>
or with pip install <pkg>
. Most of these packages are very common, and it is more than likely that you have them installed already.
It doesn't hurt to type pip install <pkg>
if the package is already installed, so don't be nervous to try it out!
pip install nelpy
For downloading example data from the web:
1. requests # used to download data from web
2. tqdm # used to show progress of download
the rest of this notebook:
3. numpy # numerical powerhouse for Python
4. matplotlib # used to make plots and figures
5. scipy # signal processing, stats, etc.; used here to smooth signals
6. sklearn # machine learning in Python; used here to do train-test split
# (under the hood) when evaluating Bayesian decoding performance
and of course:
7. nelpy # Ephys object models, and analysis routines
Now that we have all the packages we need, let's get started! First, we need to get the sample data...
1. Obtain example data¶
We will look for data in the example-data\linear-track\
directory inside your current working directory. If the data doesn't exist, we will download it from https://github.com/nelpy/example-data, and save it to your local machine.
If you already have the data, it won't be downloaded again.
In particular, we will download two files, namely
trajectory.videoPositionTracking
which is a binary file with (x,y) position coordinate pairs and timestamps, and spikes.mat
which is a Matlab file containing information about sorted units (cells) obtained by using MatClust (https://bitbucket.org/mkarlsso/matclust).
import os
import requests
# from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm
# Locate (and if necessary download) the two example files. Existing files are
# never downloaded again.
#
# The directory is built from separate path components so that it is correct
# on every OS: a backslash is only a path separator on Windows, and "\l" is
# not a valid escape sequence in a normal string literal.
datadir = os.path.join(os.getcwd(), "example-data", "linear-track")
os.makedirs(datadir, exist_ok=True)

base_url = "https://github.com/nelpy/example-data/raw/master/linear-track/"
basenames = ["trajectory.videoPositionTracking", "spikes.mat"]
filenames = [os.path.join(datadir, basename) for basename in basenames]
urls = [base_url + basename for basename in basenames]

chunk_size = 1024  # number of bytes to process at a time (NOTE: progress bar unit only accurate if this is 1 kB)

for filename, url in zip(filenames, urls):
    if os.path.exists(filename):
        print("you already have the example data, skipping download...")
        continue

    print("downloading data from {}".format(url))
    # Streaming, so we can iterate over the response.
    r = requests.get(url, stream=True)
    # Fail loudly on an HTTP error instead of saving an error page as data.
    r.raise_for_status()
    # Total size in bytes (0 if the server did not report it).
    total_size = int(r.headers.get("content-length", 0))
    # Round up so that a final partial chunk is counted; None = unknown length.
    n_chunks = ((total_size + chunk_size - 1) // chunk_size) or None

    # Write to a temporary name and rename only once the download completed,
    # so an interrupted download is never mistaken for a complete file by the
    # os.path.exists() check on the next run.
    partial_filename = filename + ".part"
    with open(partial_filename, "wb") as f:
        for data in tqdm(r.iter_content(chunk_size), total=n_chunks, unit="kB"):
            f.write(data)
    os.replace(partial_filename, filename)
    print("data saved to local directory {}".format(filename))

filename_pos = filenames[0]
filename_spikes = filenames[1]
you already have the example data, skipping download... you already have the example data, skipping download...
2. Import packages¶
import struct
import matplotlib.pyplot as plt
import numpy as np
import nelpy as nel
import nelpy.plotting as npl
# assume default aesthetics
npl.setup()
%matplotlib inline
C:\Users\etien\Anaconda3\lib\site-packages\matplotlib\cbook\deprecation.py:106: MatplotlibDeprecationWarning: The mpl_toolkits.axes_grid module was deprecated in version 2.1. Use mpl_toolkits.axes_grid1 and mpl_toolkits.axisartist provies the same functionality instead.
3. Extract position data¶
The position data is stored in a binary file, and we already made a reference to it as filename_pos
, which should point to ./example-data/linear-track/trajectory.videoPositionTracking
, inside your current working directory.
More information about the format of the file can be obtained at https://github.com/nelpy/example-data/tree/master/linear-track .
Here we define a simple function to read the plaintext header from the file, since the header also informs us of the format of the rest of the data.
def print_header(filename, timeout=50):
    """Print the plaintext header of a SpikeGadgets data file.

    The file is read in binary mode and every header line is printed as a
    bytes object, up to and including the ``<End settings>`` line.

    Parameters
    ----------
    filename : str
        Path to the file whose header should be printed.
    timeout : int, optional
        Safety limit on the header length. Printing stops once more than
        `timeout` lines have been read without finding the end marker.
        Default is 50.

    Notes
    -----
    Reading also stops at the end of the file, so a file without an
    ``<End settings>`` line can no longer cause an endless loop.
    """
    end_marker = b"<End settings>\n"
    with open(filename, "rb") as fileobj:
        # Iterating a binary file yields one b"\n"-terminated line at a time
        # and ends at EOF.
        for linecount, instr in enumerate(fileobj, start=1):
            print(instr)
            if instr == end_marker or linecount > timeout:
                break
print_header(filename_pos)
b'<Start settings>\n' b'threshold: 199\n' b'dark: 0\n' b'clockrate: 30000\n' b'camera resolution: 640x480\n' b'pixel scale: 0 pix/cm\n' b'Fields: <time uint32><xloc uint16><yloc uint16><xloc2 uint16><yloc2 uint16>\n' b'<End settings>\n'
Armed with this new information, we read over the header, and then extract (1) the 32 bit unsigned timestamp
, as well as (2) the 16 bit unsigned position data into lists x1
and y1
(x2
and y2
were not used in this recording, and can be ignored).
# Read the binary position packets that follow the plaintext header.
#
# Packet layout, as announced by the "Fields:" header line:
#   <time uint32><xloc uint16><yloc uint16><xloc2 uint16><yloc2 uint16>
# i.e., 12 little-endian bytes per packet.
packet_format = struct.Struct("<LHHHH")

n_packets = 500000  # upper limit on the number of packets to read
n_max_header_lines = 50  # give up if the header is longer than this

timestamps = []
x1 = []
y1 = []
x2 = []
y2 = []
ii = 0  # number of packets read so far

with open(filename_pos, "rb") as fileobj:
    # consume all header lines
    instr = fileobj.readline()
    hh = 0
    while instr != b"<End settings>\n":
        hh += 1
        instr = fileobj.readline()
        # readline() returns b"" at end of file, so stop there as well
        if not instr or hh > n_max_header_lines:
            break

    if instr != b"<End settings>\n":
        # Without the end marker we do not know where the binary data starts,
        # so no packets are decoded.
        print("End of header not found! Aborting...")
    else:
        while True:
            packet = fileobj.read(packet_format.size)
            if len(packet) < packet_format.size:
                # end of file, or an incomplete trailing packet
                break
            ts_, x1_, y1_, x2_, y2_ = packet_format.unpack(packet)
            timestamps.append(ts_)
            x1.append(x1_)
            y1.append(y1_)
            x2.append(x2_)
            y2.append(y2_)
            ii += 1
            if ii >= n_packets:
                print("Stopped before reaching end of file")
                break
Remark: up to this point, we have not used nelpy
yet. We have our position as two lists x1
and y1
, along with a list of timestamps
. But working with these lists of coordinates and timestamps directly can be tedious and error-prone. Consequently, we will wrap the data into nelpy
objects to make our life easier.
3.a) Get session boundaries¶
First, we will estimate epochs during which the animal was supposed to be running on the track, and those for which the animal was in its sleep box. Note that this task can be a little tricky to estimate directly from the data, since the animal is not always running while on the track, and may in fact just be stationary for significant periods.
We can (and probably should) of course get the experimental epochs (that of running on the track, and that of resting in the sleep box) from our lab notes, but it's always good to make sure that your data agrees with your notes, and moreover, it demonstrates how we can use nelpy
to estimate these epochs in case our notes were lost, or inaccurate.
# Large periods of inactivity are estimated as stretches of at least 10 seconds
# during which the animal's estimated x-position did not change.
minLength = 600  # 10 seconds @ 60 fps
bounds, _, _ = nel.utils.get_events_boundaries(
    np.gradient(x1),
    PrimaryThreshold=0,
    SecondaryThreshold=0,
    mode="below",
    minLength=minLength,
    ds=1,
)

# bounds holds sample numbers; look each one up in timestamps to obtain the
# corresponding clock ticks (as floats, same shape as bounds)
FS = 30000
bounds_ts = np.asarray(timestamps, dtype=float)[np.asarray(bounds, dtype=int)]
Here we've used nelpy
to get event boundaries, in sample numbers, where the animal spent at least 10 seconds with zero change in its x1
coordinate. We could do this, because the tracker does not track the animal when it is in the sleep box (it is outside of the camera's view).
Next, we create our first nelpy
object, namely an EpochArray
containing all the epochs defined by the bounds returned above.
In general it is not necessary to sort epochs like we do below, but we have to do that here because the way we build our session_epochs
is a little convoluted: we build it by first estimating sleep box epochs, and then combine the sleep box epochs with the complement of the sleep box epochs. When we take a union of epochs, they are no longer guaranteed to be in sorted order, and so we sort them for good measure.
# the domain spans the whole recording, from the first to the last position sample
session_domain = nel.EpochArray((timestamps[0] / FS, timestamps[-1] / FS))
sleep_box = nel.EpochArray(bounds_ts / FS, domain=session_domain)

# the entire session is the sleep box epochs together with their complement (~),
# i.e., the times when the animal was not in its sleep box
not_sleep_box = ~sleep_box
session_epochs = sleep_box + not_sleep_box
session_epochs._sort()  # a union of epochs is not guaranteed to be sorted
We can always inspect nelpy
objects by printing them to screen like so:
print(session_epochs)
<EpochArray at 0x2a5cc455550: 3 epochs> of length 33:02:423 minutes
Update: we used to have three epochs as a result of the above construction, but somewhere along the line we changed an implementation detail in nelpy so that when contiguous epochs are added together, they are automatically merged. This is sometimes very useful, and sometimes quite frustrating. It's unclear which approach is superior, but the automatically merged case is better for a majority of underlying epoch array processing functions. If this is an issue, you can always subtract epsilon from the right endpoint of your epochs to prevent them from being viewed as contiguous. This is a hacky-but-effective fix, with very low chances of having any unintended consequences.
It used to be that we had three resultant epochs, two in the rest state, and one while running on the track. But since all the epochs are contiguous, we now only have a single epoch. Consequently, to demonstrate what I wanted to demonstrate (before the change in nelpy), I'll artificially reduce the lengths of the epochs such that they are no longer contiguous:
session_epochs = sleep_box.shrink(0.000001, direction="stop") + (~sleep_box).shrink(
0.000001, direction="stop"
) # hacky "fix" described above
print(session_epochs)
<EpochArray at 0x2a5d01934e0: 3 epochs> of length 33:02:423 minutes
which tells us that we have 3 epochs, with a total duration of approximately 33 minutes.
HOWEVER! We expected to only have two epochs! One 15 minute run, followed by a 15 minute sleep box session. So what happened? Well, we can have a closer look at the epochs, by printing out their durations:
for ep in session_epochs:
print(ep.duration)
25.8235 seconds 15:59:115 minutes 16:37:485 minutes
We see that the first epoch is only about 25 seconds long, and this corresponds to when we moved the animal onto the track, so we're happy to ignore that particular epoch, and focus on the other two.
Note also that nelpy
returned a nice human readable duration for each epoch. Under the hood, the durations are stored (calculated, more accurately) in seconds, and we could have computed it ourselves like so:
print(
session_epochs.time
) # .time returns the bounds defining the epochs, in seconds; each row is [start, stop]
[[4397.0317 4422.85523233] [4422.85523333 5381.970299 ] [5381.9703 6379.455599 ]]
print(
session_epochs.time[:, 1] - session_epochs.time[:, 0]
) # durations are stops - starts
[ 25.82353233 959.11506567 997.485299 ]
We see that the durations that we computed above agree with the ones we printed out earlier, but 15:59:115 minutes is arguably easier to parse than 959.11506567 seconds.
In fact, we don't even really need to call print()
on many of the objects, and we can just invoke them to see a representation of the objects. Let's try that on our sleep_box
object:
sleep_box
<EpochArray at 0x2a5d017e710: 2 epochs> of length 17:03:308 minutes
We see that (as we already know) there are two epochs during which we think the animal was in its sleep box, namely the first 25 seconds, and then the last 16 minutes of the recording session.
We can similarly look at the complement of the sleep box epochs, which should then be (naturally) when the animal was on the linear track:
~sleep_box
<EpochArray at 0x2a5d01b3518: 1 epoch> of length 15:59:115 minutes
3.b) Create 2D trajectory object¶
Next we will use our lists x1
and y1
to build a nelpy
AnalogSignalArray
. This is pretty easy: we pass in the position coordinates as a 2 x n_samples array, along with the corresponding timestamps, and we specify that we want this object to be defined on the epoch where the animal was NOT in its sleep box, i.e., when the animal was on the track.
Note also that we defined the trajectory object on a slightly smaller epoch, namely we shrunk the epoch by 20 seconds from both directions. We did this in order to be sure that the first and last parts of the trajectory were not those where we were moving the animal onto and off of the track.
# on-track epoch, shrunk by 20 seconds to drop the moments when the animal was
# being moved onto and off of the track
track_support = (~sleep_box).shrink(20)
xy = np.vstack((x1, y1))  # 2 x n_samples
time_in_seconds = np.array(timestamps) / FS
pos = nel.AnalogSignalArray(
    xy, timestamps=time_in_seconds, support=track_support, fs=60
)
d:\dropbox\code\nelpy\nelpy\core\_analogsignalarray.py:789: UserWarning: ignoring signal outside of support
pos # inspect pos object to see if everything looks good
<AnalogSignalArray at 0x2a5d102b3c8: 2 signals> for a total of 15:19:115 minutes
We can plot the position/trajectory object like so:
npl.plot2d(pos) # plot the entire trajectory
ax = plt.gca()
ax.set_aspect("equal")
Recall that the linear track epoch ranged from approximately 4420 to 5380 seconds (we can always get these by simply inspecting pos again, like below):
pos.support.time # print the time boundaries on which pos is defined
array([[4442.85523333, 5361.9703 ]])
and so we can easily plot the trajectory, say, for the first 30 seconds while the animal was on the track. One way to do this, is to define an epoch, and then to restrict the position object to that newly defined epoch. Let's do this now:
ep = nel.EpochArray([4442, 4472])
ep
<EpochArray at 0x2a5d01b3ef0: 1 epoch> of length 30 seconds
npl.plot2d(pos) # plot the entire trajectory
npl.plot2d(pos[ep], color="k") # plot pos restricted to ep
ax = plt.gca()
ax.set_aspect("equal")
d:\dropbox\code\nelpy\nelpy\core\_analogsignalarray.py:789: UserWarning: ignoring signal outside of support
This way of filtering / restricting the data is very powerful, and we will use it later on to get spikes only during times when the animal was running faster than some threshold, and so on.
Before we move onto the next phase (linearization), let's take a quick aside to make the plot above a bit nicer. First, we may want to set the aspect ratio to equal, and we may also want to add titles, labels, etc. We may even want to smooth the trajectory.
# Smooth the trajectory in time with three different sigmas, given in seconds:
# 0 (no smoothing), 0.1 (= 100 ms) and 0.3 (= 300 ms); one panel per sigma.
sigmas = [0, 0.1, 0.3]
fig, axes = plt.subplots(ncols=3, figsize=(16, 4))
for sigma, ax in zip(sigmas, axes):
    npl.plot2d(pos.smooth(sigma=sigma), lw=0.5, color="0.3", ax=ax)
    ax.set_aspect("equal")
    ax.set_xlabel("x position")
    ax.set_ylabel("y position")
    npl.utils.clear_right(ax)
    npl.utils.clear_top(ax)
    # Raw string: "\s" is not a valid escape sequence in a normal string
    # literal. The resulting title text is unchanged.
    ax.set_title(r"Smoothed trajectory, $\sigma$={} ms".format(round(sigma * 1000)))
3.c) Linearize position¶
We can do the rest of our analysis in 2D, but sometimes it's nicer to linearize the position data first. Here, we simply use PCA (a dimensionality reduction technique) to find the direction of maximal variance in our position data, and we use this as our new 1D linear track axis.
NOTE: To use PCA from scikit-learn, we need to provide our data as an n_samples x n_features matrix. The internal representation of the nelpy
AnalogSignalArray
can actually be changed, but we can always explicitly ask to get the underlying data back in a particular format. We do this, by asking for _ydata_colsig
, which means that we want each feature (or signal, or component) to be a column of the data matrix. Here you shouldn't be concerned with the fact that we asked for _ydata
: this is simply an internal naming convention of the AnalogSignalArray
, and has nothing to do with our x
and y
coordinates. More precisely, _ydata
refers to the entire data matrix, in contrast to _tdata
which stores the timestamp info. So _ydata
contains both x1
and y1
.
from sklearn.decomposition import PCA
# ask for the position data with one signal per column (n_samples x n_features)
X = pos._ydata_colsig

# keep a single principal component: this becomes the 1D track axis
pca = PCA(n_components=1)
Xlinear = pca.fit_transform(X)
# map the 1D coordinate back into 2D so it can be drawn over the trajectory
Xlinear_ = pca.inverse_transform(Xlinear)

ax = npl.plot2d(pos, lw=0.5, color="0.8", label="original trajectory")
plt.plot(Xlinear_[:, 0], Xlinear_[:, 1], color="0.2", label="linearization")
plt.legend()

ax.set_aspect("equal")
ax.set(xlabel="x position", ylabel="y position")
for clear_spine in (npl.utils.clear_right, npl.utils.clear_top):
    clear_spine(ax)
ax.set_title("linearization using PCA")
Text(0.5,1,'linearization using PCA')
We also re-scale our data to range from 0 to 100 (so that we can express movement along the track as a percentage). Typically we would actually scale our data to some physical dimension or unit, such as cm, and not pixels as we have it above, or percentage, as we're transforming it into here.
# shift so that the minimum is 0, then scale so that the maximum is 100
shifted = Xlinear - np.min(Xlinear)
Xlinear = (shifted / np.max(shifted)) * 100
We now have Xlinear
as a scaled (0 to 100) position variable, as a numpy array. We want to put it back into a nelpy
AnalogSignalArray
container, so that we can slice and interrogate it as before:
pos1d = nel.AnalogSignalArray(Xlinear, timestamps=pos.time, support=pos.support, fs=60)
print("pos: ", pos)
print("pos1d:", pos1d)
pos: <AnalogSignalArray at 0x2a5d102b3c8: 2 signals> for a total of 15:19:115 minutes pos1d: <AnalogSignalArray at 0x2a5d20df160: 1 signals> for a total of 15:19:115 minutes
We see that pos1d
now only has 1 signal, as expected, and if we plot it, we can confirm that it ranges between 0 and 100:
# plot the linearized position over time on a wide figure
fig = plt.figure(figsize=(15, 4))
npl.plot(pos1d)

ax = plt.gca()
ax.set(xlabel="time [s]", ylabel="linearized position (%)")
for clear_spine in (npl.utils.clear_right, npl.utils.clear_top):
    clear_spine(ax)
4. Extract spike times¶
We recorded extracellular activity in region CA1 of the hippocampus from a male Long-Evans rat, and we already performed spike detection and sorting (clustering) in a separate program named MatClust (https://bitbucket.org/mkarlsso/matclust). MatClust is Matlab-based, and it gives us a .mat file containing our sorted units and their associated spike times.
Luckily, we can read .mat files directly into Python, but the file format is still a little clunky, and so we need to parse it into a more intuitive format first. For more information on the format of the .mat file, you can refer to https://github.com/nelpy/example-data/tree/master/linear-track.
At any rate, we extract spike times below, so that we end up with a list of lists, where the ith list is a list of the spike times associated with the ith unit. For example,
spikes = [[1, 2, 5, 10], [3, 5], [2, 9, 20]]
contains three units (three inner lists), with the spike times for the 1st unit being [1, 2, 5, 10]
and those for the last unit being [2, 9, 20]
. Spike times are given in seconds, so floating point numbers are used.
# load the Matlab file containing the sorted spikes
mat = nel.io.matlab.load(filename_spikes)

# maps the 1-based unit number to the index of its entry in mat["spikes"]
unit_id_to_matclust = dict()

# parse mat file contents into list of lists as described above
spikes = []
ct = 0
num_array = 0
for ii, array in enumerate(mat["spikes"]):
    if array.size > 1:
        # several entries: consider every non-empty one
        candidates = [subarray for subarray in array if subarray.size != 0]
    elif array.size == 1:
        candidates = [array]
    else:
        # empty array: that particular tetrode was not sorted
        candidates = []

    for candidate in candidates:
        spike_times = candidate["time"].ravel()[0]
        if len(spike_times) == 0:
            # exclude entries with no spikes
            continue
        spikes.append(spike_times)
        unit_id_to_matclust[ct + 1] = ii
        ct += 1

print("Found {} non-empty units total".format(ct))
# unit_id_to_matclust
Found 31 non-empty units total
As before, we now put the list of spike times into a nelpy
container object to make it easier to interact with the data. In particular, we put the spikes into a SpikeTrainArray
. But first, let's also get out the time boundaries for which the spikes were recorded. We can approximate these boundaries by looking for the first and last recorded spikes, but the .mat file actually contains this information explicitly, so we'll get it directly from there:
# Find the time range over which spikes were recorded. It is stored in the
# "timerange" field of any record in mat["spikes"] that has the dtype below,
# so the first such record found is used.
mattdtype = np.dtype([("time", "O"), ("timerange", "O"), ("meanrate", "O")])

singleidx = 0
multiidx = 0
found = False  # explicit flag, because index 0 is a valid (but falsy) match

# First pass: entries of mat["spikes"] that are themselves such a record.
for nn, entry in enumerate(mat["spikes"]):
    if entry.dtype == mattdtype:
        singleidx = nn
        start, stop = entry["timerange"].ravel()[0]
        found = True
        break

if found:
    if singleidx:
        print("singleidx")
else:
    # Second pass: entries that contain several records; stop at the very
    # first match instead of letting later entries overwrite start and stop.
    for nn, entry in enumerate(mat["spikes"]):
        try:
            for mm in range(len(entry)):
                if entry[mm].dtype == mattdtype:
                    singleidx = nn
                    multiidx = mm
                    start, stop = entry[mm]["timerange"].ravel()[0]
                    found = True
                    break
        except TypeError:
            # this entry cannot be searched; try the next one
            continue
        if found:
            break
    if multiidx:
        print("multiidx")

if not found:
    raise ValueError("no record with a 'timerange' field found in mat['spikes']")

# start, stop = mat['spikes'][0][0]['timerange'].ravel()[0]
# Epoch for which spikes were recorded
session_bounds = nel.EpochArray([start, stop])
Now we're ready to build our SpikeTrainArray
. We simply pass in the list of spike times, the epoch during which spikes were recorded, and the sampling rate (recall that we previously specified the sampling rate, FS=30000 Hz).
st = nel.SpikeTrainArray(timestamps=spikes, support=session_bounds, fs=FS)
print(st)
<SpikeTrainArray at 0x2a5d218f048: 31 series> at 30000 Hz
Printing out our SpikeTrainArray
(st
) tells us that we have 31 units, sampled at 30,000 Hz.
We can ask several things from this SpikeTrainArray
object, such as the number of spikes for each unit. To do that, we simply access the .n_spikes
property, which returns the number of spikes associated with each unit as an array:
st.n_spikes
array([1748, 106, 352, 88, 875, 305, 145, 113, 408, 557, 1613, 491, 270, 984, 1381, 7959, 931, 71, 477, 1183, 487, 816, 479, 44, 1065, 92, 41, 2127, 901, 1179, 1541])
Of course, we can also compute these things manually. Let's take a step back, and compute the average firing rate for each unit, using not the nelpy
SpikeTrainArray
, but the original list of lists:
# average firing rate per unit = number of spikes / duration of the recording
st_duration = stop - start
avg_firing_rates = [len(unit) / st_duration for unit in spikes]
print(avg_firing_rates)
[0.8880880967134036, 0.05385431250092722, 0.17883696226723, 0.0447092405668075, 0.44455210790859734, 0.15495816332813964, 0.07366863502485327, 0.057410729364196, 0.20728829717338024, 0.28298917040581567, 0.8195000572075057, 0.24945723998071004, 0.13717607901179574, 0.49993059906521115, 0.7016302411677404, 4.043645973536601, 0.47300344281474754, 0.036072228184583326, 0.24234440625417247, 0.6010344498924236, 0.24742500177312787, 0.4145765943467605, 0.24336052535796357, 0.02235462028340375, 0.5410834227687499, 0.04674147877438966, 0.020830441627717133, 1.0806426668818132, 0.45776165625788134, 0.5990022116848415, 0.7829197694710268]
This wasn't too difficult, but nelpy
can make this a little easier.
Like all other core nelpy
objects, a SpikeTrainArray
has a support on which it is defined, and so we can always get timing information by accessing the underlying support object (instead of using start
and stop
explicitly, as we did above).
Let's calculate the average firing rates using nelpy
now:
avg_firing_rates = st.n_spikes / st.support.duration
print(avg_firing_rates)
[0.8880881 0.05385431 0.17883696 0.04470924 0.44455211 0.15495816 0.07366864 0.05741073 0.2072883 0.28298917 0.81950006 0.24945724 0.13717608 0.4999306 0.70163024 4.04364597 0.47300344 0.03607223 0.24234441 0.60103445 0.247425 0.41457659 0.24336053 0.02235462 0.54108342 0.04674148 0.02083044 1.08064267 0.45776166 0.59900221 0.78291977]
Now wasn't that a little simpler? But the real beauty comes in when the SpikeTrainArray
is not simply defined on one continuous epoch, but on many smaller epochs. For example, we might be interested in asking "what is the average firing rate for each unit, during periods when the animal was running?"---in that case, the firing rate calculation using the list of lists approach would become exceedingly painful, but we would be able to simply write
avg_firing_rates_during_run = st[run_epochs].n_spikes / st[run_epochs].support.duration
which will work just the same, even though st[run_epochs].support
now consists of many discontiguous epochs.
Visualize spikes with raster plots¶
Nelpy has several built-in plot types (we have already seen npl.plot2d()
before), and one that is frequently useful is the rasterplot()
.
fig, ax = plt.subplots(ncols=1, figsize=(14, 5))

# spike raster of all units, drawn with nelpy
npl.rasterplot(st, lw=0.5, ax=ax)

ax.set(xlabel="time [seconds]", ylabel="unit")
# restrict the x-axis to the epoch for which spikes were recorded
ax.set_xlim(*session_bounds.time)
We notice that there is a clear difference in overall activity between the run session (4400--5400 seconds) and when the animal was in its sleep box (5550--). For example, units 4, 7, 8, 26, 27, and 29 seem to be largely inactive during run, but active during the sleep box. Let's use the rasterplot()
to highlight those units:
fig, ax = plt.subplots(ncols=1, figsize=(14, 5))

units_in_sleepbox = [2, 4, 6, 7, 8, 26, 27, 29]
units_in_run = [11, 15, 17, 21]
transition_epoch = nel.EpochArray([5382, 5530])

# one raster per group of units, both drawn on the same axes
for unit_ids, unit_color in (
    (units_in_sleepbox, npl.colors.sweet.red),
    (units_in_run, npl.colors.sweet.blue),
):
    npl.rasterplot(st[:, unit_ids], lw=1, ax=ax, color=unit_color)

# shade the sleep box epochs (minus the transition epoch) and the track epochs
shading = dict(alpha=0.2, hatch="")
npl.epochplot(
    sleep_box - transition_epoch,
    color=npl.colors.sweet.red,
    label="in sleep box",
    **shading
)
npl.epochplot(~sleep_box, color=npl.colors.sweet.blue, label="on track", **shading)

ax.set_xlabel("time [seconds]")
ax.set_ylabel("unit")
ax.set_xlim(*session_bounds.time)
plt.legend(loc=(1.02, 0.9))
d:\dropbox\code\nelpy\nelpy\plotting\core.py:906: UserWarning: Spike trains may be plotted in the same vertical position as another unit C:\Users\etien\Anaconda3\lib\site-packages\matplotlib\patches.py:91: UserWarning: Setting the 'color' property will overridethe edgecolor or facecolor properties.
<matplotlib.legend.Legend at 0x2a5d5d0bd68>
The rasterplot()
function has quite a bit of flexibility. As another example, let's plot the same as above, but let's collapse (stack) all the units onto each other:
fig, ax = plt.subplots(ncols=1, figsize=(14, 2.4))

units_in_sleepbox = [2, 4, 6, 7, 8, 26, 27, 29]
units_in_run = [11, 15, 17, 21]
transition_epoch = nel.EpochArray([5382, 5530])

# same rasters as before, but drawn with vertstack=True
raster_style = dict(lw=1, ax=ax, vertstack=True)
npl.rasterplot(st[:, units_in_sleepbox], color=npl.colors.sweet.red, **raster_style)
npl.rasterplot(st[:, units_in_run], color=npl.colors.sweet.blue, **raster_style)

# shade the sleep box epochs (minus the transition epoch) and the track epochs
epoch_style = dict(alpha=0.2, hatch="")
npl.epochplot(
    sleep_box - transition_epoch,
    color=npl.colors.sweet.red,
    label="in sleep box",
    **epoch_style
)
npl.epochplot(
    ~sleep_box, color=npl.colors.sweet.blue, label="on track", **epoch_style
)

ax.set_xlabel("time [seconds]")
ax.set_ylabel("unit")
ax.set_xlim(*session_bounds.time)
plt.legend(loc=(1.02, 0.9))
C:\Users\etien\Anaconda3\lib\site-packages\matplotlib\patches.py:91: UserWarning: Setting the 'color' property will overridethe edgecolor or facecolor properties.
<matplotlib.legend.Legend at 0x2a5d4a227f0>
We can even get a histogram with the number of spikes across all units by using the npl.rastercountplot()
function. Note that this function is not yet fully developed, and still needs to be modified to increase flexibility.
# rastercountplot returns two axes, unpacked here as (axc, axr)
axc, axr = npl.rastercountplot(st, lw=0.5, nbins=130)
axes = (axr, axc)

axr.set(xlabel="time [seconds]", ylabel="unit")

# use the same time limits on both axes
for ax in axes:
    ax.set_xlim(*session_bounds.time)
d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:2099: UserWarning: _restrict_to_interval_array() not yet implemented for BinnedTypes
One final raster example, where we call rastercountplot()
on the list of units that seem to be more active in the sleep box than on the track:
# raster and spike count for the units that seem more active in the sleep box
axc, axr = npl.rastercountplot(
    st[:, units_in_sleepbox], nbins=130, color=npl.colors.sweet.red, vertstack=True
)
axes = (axr, axc)

axr.set(xlabel="time [seconds]", ylabel="unit")
for ax in axes:
    ax.set_xlim(*session_bounds.time)

# shade the epochs on axc, which also carries the legend
for epochs, epoch_color, epoch_label in (
    (sleep_box - transition_epoch, npl.colors.sweet.red, "in sleep box"),
    (~sleep_box, npl.colors.sweet.blue, "on track"),
):
    npl.epochplot(
        epochs, ax=axc, alpha=0.2, hatch="", color=epoch_color, label=epoch_label
    )

axc.legend(loc=(1.0, 0.5))
d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:2099: UserWarning: _restrict_to_interval_array() not yet implemented for BinnedTypes C:\Users\etien\Anaconda3\lib\site-packages\matplotlib\patches.py:91: UserWarning: Setting the 'color' property will overridethe edgecolor or facecolor properties.
<matplotlib.legend.Legend at 0x2a5d0a8c828>
5. Estimate tuning curves¶
Next, we will use the position and spikes to estimate the spatial tuning curves of our 31 units.
Place fields are usually estimated using only those epochs when the animal is running faster than some speed threshold, so we find these epochs first. We smooth our speed estimates with a Gaussian kernel with sigma = 100 ms, and then we return those epochs where the animal was running at least 8 percent per second.
Note: Usually we would measure the speed in cm/s or similar, but recall that our position has been scaled to be a percentage; the track was about 117 cm long, so that percent/s is not too different from cm/s.
sigma_100ms = 0.1
min_speed = 8  # percent per second; used for both thresholds

# speed of the linearized position, smoothed with sigma = 100 ms
speed1 = nel.utils.ddt_asa(pos1d, smooth=True, sigma=sigma_100ms, norm=True)

# epochs where the animal ran at least as fast as min_speed
run_epochs = nel.utils.get_run_epochs(speed1, v1=min_speed, v2=min_speed)
The get_run_epochs()
function has two thresholds (both set to 8 above). More specifically, the first threshold (v1
) is referred to as the primary threshold, and this must be reached or exceeded for an epoch to be considered a candidate. But then the secondary threshold (v2
) is used to determine the epoch boundaries.
For example, if we had used v1=8, v2=5
then only epochs during which the animal reached a speed of at least 8 u/s would be considered, and the epoch boundaries would be where the animal ran at or slower than 5 u/s, working outwards from the peak speed of >8 u/s.
There are other possible arguments, too, so that we can e.g. ask that the animal must run at least 8 u/s for some minimum period of time, and so on.
Let's see for how much time the animal actually ran faster than 8 percent per second:
print(run_epochs)
<EpochArray at 0x2a5d60f64e0: 221 epochs> of length 4:18:131 minutes
We see that there are 221 (!) short epochs during which the animal ran sufficiently fast, and that the total amount of time spent running at this speed or faster is about 4 minutes and 18 seconds (out of approximately 16 minutes spent on the track).
As promised before, we can now easily (trivially) calculate the firing rates of all the units, only during times when the animal was running:
# Per-unit firing rate while running: number of spikes inside the run epochs
# divided by the duration of the restricted support.
st_during_run = st[run_epochs]  # restrict once, reuse for both terms
avg_firing_rates_during_run = (
    st_during_run.n_spikes / st_during_run.support.duration
)
# Compare against the whole-session rates; units are printed 1-based.
for unit, (run_rate, avg_rate) in enumerate(
    zip(avg_firing_rates_during_run, avg_firing_rates)
):
    print(
        "unit {}: {:2.3f} Hz (run and rest: {:2.3f} Hz)".format(
            unit + 1, run_rate, avg_rate
        )
    )
unit 1: 1.058 Hz (run and rest: 0.888 Hz) unit 2: 0.004 Hz (run and rest: 0.054 Hz) unit 3: 0.023 Hz (run and rest: 0.179 Hz) unit 4: 0.000 Hz (run and rest: 0.045 Hz) unit 5: 0.112 Hz (run and rest: 0.445 Hz) unit 6: 0.046 Hz (run and rest: 0.155 Hz) unit 7: 0.000 Hz (run and rest: 0.074 Hz) unit 8: 0.015 Hz (run and rest: 0.057 Hz) unit 9: 0.360 Hz (run and rest: 0.207 Hz) unit 10: 0.019 Hz (run and rest: 0.283 Hz) unit 11: 3.281 Hz (run and rest: 0.820 Hz) unit 12: 0.132 Hz (run and rest: 0.249 Hz) unit 13: 0.418 Hz (run and rest: 0.137 Hz) unit 14: 2.243 Hz (run and rest: 0.500 Hz) unit 15: 1.786 Hz (run and rest: 0.702 Hz) unit 16: 6.318 Hz (run and rest: 4.044 Hz) unit 17: 0.856 Hz (run and rest: 0.473 Hz) unit 18: 0.077 Hz (run and rest: 0.036 Hz) unit 19: 0.666 Hz (run and rest: 0.242 Hz) unit 20: 1.251 Hz (run and rest: 0.601 Hz) unit 21: 1.449 Hz (run and rest: 0.247 Hz) unit 22: 0.701 Hz (run and rest: 0.415 Hz) unit 23: 0.209 Hz (run and rest: 0.243 Hz) unit 24: 0.000 Hz (run and rest: 0.022 Hz) unit 25: 0.050 Hz (run and rest: 0.541 Hz) unit 26: 0.008 Hz (run and rest: 0.047 Hz) unit 27: 0.000 Hz (run and rest: 0.021 Hz) unit 28: 3.913 Hz (run and rest: 1.081 Hz) unit 29: 0.023 Hz (run and rest: 0.458 Hz) unit 30: 1.143 Hz (run and rest: 0.599 Hz) unit 31: 1.550 Hz (run and rest: 0.783 Hz)
We can also plot the linearized trajectory only during those run epochs. Each new color indicates that it is a new epoch. The black on the bottom figure corresponds to epochs when the animal did not meet the 8 percent per second requirement, and so those are when the animal is considered to be at rest.
# Two stacked panels: the linearized trajectory restricted to the run epochs
# (top) and to their complement, i.e. rest (bottom).
with npl.FigureManager(show=True, nrows=2, figsize=(16, 4)) as (fig2, axes):
    npl.utils.skip_if_no_output(fig2)
    ax0, ax1 = axes
    # Faint gray copy of the full trajectory as a backdrop in both panels.
    for ax in axes:
        ax.plot(pos1d.time, pos1d.asarray().yvals, lw=1, alpha=0.2, color="gray")
        ax.set_ylabel("position (%)")
    npl.plot(pos1d[run_epochs], ax=ax0, lw=1, label="run")
    # The complement of the run epochs is rest, so label it as such (it was
    # previously labelled "run" as well).
    npl.plot(pos1d[~run_epochs], ax=ax1, lw=1, label="rest", color="k")
    npl.utils.no_xticklabels(ax0)
    ax0.set_title("run")
    ax1.set_title("~run (=rest)")
d:\dropbox\code\nelpy\nelpy\core\_analogsignalarray.py:789: UserWarning: ignoring signal outside of support d:\dropbox\code\nelpy\nelpy\core\_analogsignalarray.py:789: UserWarning: ignoring signal outside of support
Now that we have the epochs during which the animal was running, we can use it to restrict our SpikeTrainArray
to those epochs. Then we will use these spikes, along with the position data, to estimate the spatial tuning curves.
# Keep only the spikes that fall inside the run epochs.
st_run = st[run_epochs]
We first bin our spikes into 50 ms bins (which are then re-binned into 500 ms bins below), so that we can count the number of spikes in a small window of time, which we will then later associate with the particular position bin that the animal was at when those spikes occurred.
We also apply a little bit of spike time smoothing.
ds_run = 0.5 # 500 ms: final bin width after re-binning, in seconds
ds_50ms = 0.05 # 50 ms: initial bin width, in seconds
# smooth and re-bin:
sigma = 0.3 # 300 ms spike smoothing
# bin at 50 ms, smooth the counts in place, then merge ds_run / ds_50ms = 10
# consecutive 50 ms bins into each final bin
bst_run = (
st_run.bin(ds=ds_50ms).smooth(sigma=sigma, inplace=True).rebin(w=ds_run / ds_50ms)
)
d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:1689: UserWarning: interval duration is less than bin size: ignoring... d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:2099: UserWarning: _restrict_to_interval_array() not yet implemented for BinnedTypes
We now estimate the tuning curves using the binned spikes during run (bst_run
), and we partition our linear track into n_extern=50
equal-sized spatial bins, and we specify that the linearized position ranges from extmin=0
to extmax=100
. We also smooth the estimated tuning curves with a 0.2 cm smoothing kernel (this should be imperceptible, and a larger value for sigma should really be used, but we can always apply the smoothing later, so it's not a real problem).
sigma = 0.2 # smoothing std dev in cm
# Estimate 1D spatial tuning curves from the binned run spikes and the
# linearized position: 50 equal-sized spatial bins spanning 0 to 100.
# NOTE(review): min_duration=1 presumably discards spatial bins visited for
# less than that duration -- confirm against the TuningCurve1D documentation.
tc = nel.TuningCurve1D(
bst=bst_run,
extern=pos1d,
n_extern=50,
extmin=0,
extmax=100,
sigma=sigma,
min_duration=1,
)
d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified
Finally, we reorder the units by their peak firing positions on the track, simply to make visualization a little nicer:
tc = tc.reorder_units()
That's it! Let's see what we've got:
# Tuning curves at their true firing rates (extra smoothing for display only).
npl.set_palette(npl.colors.rainbow)
with npl.FigureManager(show=True, figsize=(8, 8)) as (fig_tc, ax):
    npl.utils.skip_if_no_output(fig_tc)
    # No ax is passed, so the curves presumably land on the current axes.
    npl.plot_tuning_curves1D(tc.smooth(sigma=2), normalize=False, pad=3)
We see that some units have spatially localized tuning curves, and some don't. Also, some units have large firing rates, and some don't. To see the shape of all the tuning curves better, we can normalize the peak firing rates like so:
# Same tuning curves, each normalized so that their shapes can be compared.
npl.set_palette(npl.colors.rainbow)
with npl.FigureManager(show=True, figsize=(8, 8)) as (fig_tc_norm, ax):
    npl.utils.skip_if_no_output(fig_tc_norm)
    # No ax is passed, so the curves presumably land on the current axes.
    npl.plot_tuning_curves1D(tc.smooth(sigma=2), normalize=True, pad=0.8)
This normalized view allows us to see the shapes of the tuning curves more clearly, and we can see that units 4, 24, and 7 are essentially flat, meaning they have NO spatial localization for their firing rates. This is partly due to there just being almost no spikes during run for those units though.
At any rate, let's filter out the units a little further, to try and figure out which ones might actually be place cells.
Restrict subset of cells (units) to use for subsequent decoding and/or analysis¶
Here we may request to use
- unimodal cells only,
- pyramidal cells only,
- active cells only,
- any combination of the above, and other criteria
Here we impose a minimum peak firing rate of 1 Hz, we reject putative interneurons by imposing a maximum average firing rate of 5 Hz averaged over the entire track, and we require the peak firing rate to be at least 3.5 times the mean firing rate.
# set criteria for units used in decoding
min_peakfiringrate = 1  # Hz
max_avgfiringrate = 5  # Hz
# peak firing rate should be at least 3.5 times the mean firing rate
peak_to_mean_ratio_threshold = 3.5
# unimodal_cells = find_unimodal_tuningcurves1D(smoothed_rate, peakthresh=0.5)

unit_ids = np.asanyarray(tc.unit_ids)

# Each criterion is a boolean mask with one entry per unit (row of the
# ratemap). Masks stay 1-D however many units pass; the previous
# argwhere(...).squeeze().tolist() collapsed to a bare scalar when exactly one
# unit passed, and building a set from that scalar raised a TypeError.

# enforce minimum peak firing rate
keep = tc.ratemap.max(axis=1) > min_peakfiringrate
# enforce maximum average firing rate
keep &= tc.ratemap.mean(axis=1) < max_avgfiringrate
# enforce peak to mean firing ratio
peak_firing_rates = tc.max(axis=1)
mean_firing_rates = tc.mean(axis=1)
ratio = peak_firing_rates / mean_firing_rates
keep &= np.atleast_1d(ratio >= peak_to_mean_ratio_threshold)

# finally, convert remaining units into a (sorted) list of unit ids
unit_ids_to_keep = list(np.sort(unit_ids[keep]))

# modify spike trains and ratemap to only include those units that passed all the criteria
sta_placecells = st._unit_subset(unit_ids_to_keep)
tc = tc._unit_subset(unit_ids_to_keep)
# reorder cells by peak firing location on track (this is nice for visualization, but doesn't affect decoding)
tc.reorder_units(inplace=True)
# with plt.xkcd():
# Three views of the selected units' tuning curves: true rates, normalized
# rates, and all units collapsed onto a common baseline (pad=0).
with npl.palettes.color_palette(npl.colors.rainbow):
    with npl.FigureManager(show=True, nrows=1, ncols=3, figsize=(16, 4)) as (fig, axes):
        npl.utils.skip_if_no_output(fig)
        ax0, ax1, ax2 = axes
        npl.plot_tuning_curves1D(tc.smooth(sigma=3), ax=ax0, pad=5.5)
        npl.plot_tuning_curves1D(tc.smooth(sigma=3), ax=ax1, normalize=True, pad=0.9)
        npl.plot_tuning_curves1D(tc.smooth(sigma=3), ax=ax2, pad=0)
        for ax in axes:
            ax.set_xlabel("position [cm]")
        npl.utils.xticks_interval(25, *axes)
        npl.utils.yticks_interval(5, ax2)
        # Vertical scale bar on the true-rate panel.
        npl.add_simple_scalebar(
            "10 Hz",
            ax=ax0,
            xy=(10, 57),
            length=10,
            orientation="v",
            rotation_text="h",
            size=14,
        )
        # npl.add_simple_scalebar("5 Hz", ax=ax1, xy=(10, 17.5), length=5, orientation='v', rotation_text='h', size=14)
        ax0.set_title("True firing rates", size=12)
        ax1.set_title("Normalized firing rates", size=12)
        ax2.set_title("Collapsed units (pad=0)", size=12)
6. Evaluate decoding performance¶
Next, having defined a subset of 12 place cells, we may reasonably want to ask how well these cells can represent the animal's location? We can evaluate the expected decoding performance by using a Bayesian decoder, and by evaluating the decoding accuracy on a test set.
Here, we use 5-fold cross-validation and the tuning curves of our 12 place cells to evaluate the performance, and the results are summarized in the figure below:
# Cumulative distribution of the decoding error, drawn into a new figure.
with npl.FigureManager(show=True, figsize=(5, 5)) as (fig, ax):
    npl.utils.skip_if_no_output(fig)
    ds_run = 0.5  # 500 ms: final bin width after re-binning, in seconds
    ds_50ms = 0.05  # 50 ms: initial bin width, in seconds
    # st_run = st[run_epochs]
    # smooth and re-bin:
    bst_run = (
        st_run.bin(ds=ds_50ms)
        .smooth(sigma=0.15, inplace=True)
        .rebin(w=ds_run / ds_50ms)
    )
    bst = bst_run
    npl.plot_cum_error_dist(bst=bst, extern=pos1d, extmin=0, extmax=100, sigma=0.0)
d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:1689: UserWarning: interval duration is less than bin size: ignoring... d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:2099: UserWarning: _restrict_to_interval_array() not yet implemented for BinnedTypes d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified
# Compute the cross-validated cumulative decoding-error distribution directly,
# so that cumhist can be evaluated afterwards; the upper end of the position
# range is the largest observed position, rounded up.
cumhist, bincenters = nel.decoding.cumulative_dist_decoding_error_using_xval(
bst_run, extern=pos1d, sigma=0.0, extmax=np.ceil(pos1d.max())
)
# plot the precomputed distribution
npl.plot_cum_error_dist(cumhist=cumhist, bincenters=bincenters, label="a")
d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified d:\dropbox\code\nelpy\nelpy\core\_eventarray.py:385: UserWarning: series tags have not yet been specified
We can evaluate cumhist
at any value between p=[0,1]
, which will return the error at that percentile. For example, to get the median decoding error, we simply evaluate cumhist(0.5)
:
print("median decoding error: {:2.2f}".format(cumhist(0.5)))
print("90th percentile decoding error: {:2.2f}".format(cumhist(0.9)))
median decoding error: 4.47 90th percentile decoding error: 19.38
This tells us that, with probability 0.9, we can decode the animal's position to within 20 percent of the true position. Recall that in our particular example, percent and cm are almost the same. The inset shows a zoomed-in view, from where we can see that with probability 0.7, we have an error of less than 8 percent.
It's not the best decoding accuracy imaginable, but overall, not bad!
plt.figure(figsize=(16, 4))
# Bayesian decoding using only the units that passed the selection criteria
# and their tuning curves.
posteriors, lengths, mode_pth, mean_pth = nel.decoding.decode1D(
bst_run.loc[:, unit_ids_to_keep], tc
)
# true linearized position evaluated at the centers of the decoding bins
actual_pos = pos1d(bst_run.bin_centers)
plt.plot(actual_pos, c="0.8", label="actual")
plt.ylabel("position (% along track)")
plt.xlabel("time bin (concatenated)")
# decoded position: the mean_pth output of decode1D
plt.plot(mean_pth, c="0.2", label="decoded using sorted units")
plt.legend()
plt.show()
C:\Users\etien\Anaconda3\lib\site-packages\numpy\core\_methods.py:26: RuntimeWarning: invalid value encountered in reduce
Some additional views¶
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Raster of all units, with the selected place cells overdrawn in green and the
# run epochs shaded; a thin axis appended below shows the linearized position.
with npl.FigureManager(show=True, figsize=(30, 5)) as (fig, ax):
    npl.utils.skip_if_no_output(fig)
    npl.rasterplot(st, lw=0.5, ax=ax)
    npl.rasterplot(sta_placecells, lw=0.5, color=npl.colors.sweet.green)
    npl.epochplot(run_epochs)
    divider = make_axes_locatable(ax)
    # NOTE: despite its name, this axis shows position (pos1d), not speed.
    axSpeed1d = divider.append_axes("bottom", size=0.6, pad=0.1)
    npl.plot(pos1d, ax=axSpeed1d)
    npl.epochplot(run_epochs, ax=axSpeed1d)
    # Zoom both axes onto the same time window.
    xlims = (4500, 5300)
    ax.set_xlim(xlims)
    axSpeed1d.set_xlim(xlims)
Next, we take a look at where on the track spikes from a particular unit occurred:
# 2D trajectory in light gray as a backdrop
ax = npl.plot2d(pos, lw=0.5, c="0.8")
unit_id = 9
# Evaluate the 2D position at this unit's spike times.
# NOTE(review): this uses st[:, unit_id], whereas the next cell uses
# st.loc[:, unit_id]; confirm whether positional or label-based unit
# selection is intended here.
_, pos_at_spikes = pos.asarray(at=st[:, unit_id].time)
# row 0 is plotted on the x axis and row 1 on the y axis
ax.plot(pos_at_spikes[0, :], pos_at_spikes[1, :], ".")
ax.set_aspect("equal")
# Where on the track did one unit fire? Three panels: all spikes, spikes
# during run, and spikes during rest.
fig, axes = plt.subplots(ncols=3, figsize=(16, 4))
ax0, ax1, ax2 = axes
# plot trajectory in gray
for ax in axes:
    npl.plot2d(pos, lw=0.5, c="0.8", ax=ax)
    ax.set_aspect("equal")
    ax.set_xlabel("x position")
    ax.set_ylabel("y position")
    npl.utils.clear_right(ax)
    npl.utils.clear_top(ax)
unit_id = 19
# One (axis, epochs to restrict to, title template) entry per panel.
panels = (
    (ax0, pos.support, "All spikes, unit {}"),
    (ax1, run_epochs, "Spikes while running, unit {}"),
    (ax2, ~run_epochs, "Spikes while at rest, unit {}"),
)
for ax, epochs, title in panels:
    # spike times of the unit inside the epochs, then the position at those times
    at = st.loc[:, unit_id][epochs].time
    _, pos_at_spikes = pos.asarray(at=at)
    ax.plot(pos_at_spikes[0, :], pos_at_spikes[1, :], ".", color="k")
    ax.set_aspect("equal")
    ax.set_title(title.format(unit_id))
Text(0.5,1,'Spikes while at rest, unit 19')
Conclusion¶
This demo is still a work in progress, but hopefully it highlights a few ways in which using nelpy
might make it easier to combine and analyze ephys data with different sampling rates, and different epochs where they are defined. There's a lot more that we can do with nelpy
, so feel free to dig a little deeper to explore!
Feedback about this notebook would be appreciated, and can be sent to era3@rice.edu.
Supplementary analysis: hidden Markov models (HMMs)¶
Preliminaries¶
Nelpy makes use of hmmlearn
(https://github.com/hmmlearn/hmmlearn), but modified to include a Poisson emissions model. The modified version can be found at https://github.com/eackermann/hmmlearn . In addition to the Poisson emissions model, this modified branch also fixed up an issue in calculating the log probabilities, as well as consistency with random seeds. The random seed issue has since been fixed in the main repository, and it is possible that the log probability issue has also been fixed, and if not, I (Etienne) should make a pull request to incorporate the fixes into the main branch. At any rate, for now, nelpy makes use of the fork at https://github.com/eackermann/hmmlearn .
Installing hmmlearn
should be trivial on Linux, fairly easy on MacOS, and doable on Windows. To get it to compile on Windows, all you really need is to download and install the free version of Microsoft Visual Studio (with the Python extensions), and then hmmlearn
should compile with no issues. After installation, you can remove Visual Studio again. However, if you don't want to go through the trouble of compiling it yourself, there is also a 64-bit Windows binary in the nelpy repository.
Overview¶
Here we give a quick demo of how we may use nelpy to learn hidden Markov models. Typically, HMMs are used in the absence of any [observable] behavioral correlate, but here we demonstrate the approach on the run data, where we do actually have access to the position data. Nevertheless, we don't use the position data when training the HMM, nor do we need it to decode neural activity to the state space. But it does make it nice for a demo, since we can use the position data to interpret what those hidden states might represent.
Briefly, we will use a SpikeTrainArray
during bouts of running activity to learn our HMM. The HMM makes use of vectors of spike counts, so we need to bin our SpikeTrainArray
into an appropriate bin size. Since we are learning a model on run data, we want to choose a bin size that's small enough to capture the animal's behavioral dynamics, but large enough to capture sufficiently many spikes (lots of empty bins will do us no good). Here, we will use 250 ms bins.
We will also apply a little bit of spike smoothing (300 ms standard deviation Gaussian kernel), since this has been shown to improve decoding accuracy (in general, not just for HMMs; the idea is roughly that there is some inherent noise in the signals, and so modeling that uncertainty with spike smoothing actually improves our model robustness).
Thereafter, we train our HMM on a subset of the BinnedSpikeTrainArray
(the train set), and we can evaluate our model on the test set. We demonstrate how to decode to the state space, and also how to decode back to an actual behavioral correlate (in this case, position), if we augment our HMM with an additional mapping that we've learned between the state space and the external behavioral correlates.
There is much more that we can do with HMMs, and this section is by no means meant to be comprehensive, nor definitive, but hopefully it demonstrates how to get started, and how to do basic inference.
import nelpy.hmmutils
ds_run = 0.25 # 250 ms: final bin width used for the HMM, in seconds
ds_50ms = 0.05 # 50 ms: initial bin width, in seconds
# smooth and re-bin:
sigma = 0.3 # 300 ms spike smoothing
# bin at 50 ms, smooth in place, then merge ds_run / ds_50ms = 5 consecutive
# 50 ms bins into each final bin
bst_run = (
st_run.bin(ds=ds_50ms).smooth(sigma=sigma, inplace=True).rebin(w=ds_run / ds_50ms)
)
/home/etienne/Dropbox/code/nelpy/nelpy/core/_eventarray.py:1672: UserWarning: interval duration is less than bin size: ignoring... /home/etienne/Dropbox/code/nelpy/nelpy/core/_eventarray.py:2089: UserWarning: _restrict_to_interval_array() not yet implemented for BinnedTypes
from sklearn.model_selection import train_test_split
# Split the epoch indices (not individual bins) 80/20 into train and test sets,
# with a fixed seed so the split is reproducible.
trainidx, testidx = train_test_split(
np.arange(bst_run.n_epochs), test_size=0.2, random_state=1
)
# sort each index set in place so the selected epochs are in ascending order
trainidx.sort()
testidx.sort()
bst_train = bst_run[trainidx]
bst_test = bst_run[testidx]
print("{} train sequences and {} test sequences".format(len(trainidx), len(testidx)))
138 train sequences and 35 test sequences
num_states = 20  # number of hidden states (n_components) in the model
# Poisson-emission HMM with a fixed random seed, fit on the training epochs only
hmm = nel.hmmutils.PoissonHMM(n_components=num_states, random_state=0, verbose=False)
hmm.fit(bst_train)
/home/etienne/temp/hmmlearn/hmmlearn/utils.py:87: RuntimeWarning: divide by zero encountered in log /home/etienne/temp/hmmlearn/hmmlearn/utils.py:87: RuntimeWarning: invalid value encountered in log /home/etienne/Dropbox/code/nelpy/nelpy/core/_eventarray.py:385: UserWarning: series tags have not yet been specified
nelpy.PoissonHMM(init_params='stm', n_components=20, n_iter=50, params='stm', random_state=0, verbose=False); fit=True, fit_ext=False
# Transition matrix before (left) and after (right) reordering the states.
with npl.FigureManager(show=True, nrows=1, ncols=2, figsize=(8, 8)) as (fig, axes):
    npl.utils.skip_if_no_output(fig)
    ax0, ax1 = axes
    ax0.matshow(hmm.transmat_, cmap=plt.cm.RdPu)
    ax0.set_title("Before reordering")
    # Reorder the states in place, using the order derived from the
    # transition matrix.
    transmat_order = hmm.get_state_order("transmat")
    hmm.reorder_states(transmat_order)
    ax1.matshow(hmm.transmat_, cmap=plt.cm.GnBu)
    ax1.set_title("After reordering")
    for ax in axes:
        npl.utils.no_xticks(ax)
        npl.utils.no_xticklabels(ax)
        npl.utils.no_yticks(ax)
        npl.utils.no_yticklabels(ax)
    # Cell edges of the matshow image: num_states + 1 values, -0.5 .. num_states - 0.5.
    ys = np.arange(-0.5, num_states + 0.5, step=1)
    xs = np.arange(-0.5, num_states + 0.5, step=1)
    # Staircase outline along the diagonal, identical on both panels. All
    # slices are written out so that the position and both endpoints of every
    # call have the same length; the earlier mismatched slices (e.g. 20 x
    # positions against 21 y values) only worked with matplotlib versions that
    # silently resized the endpoints.
    for ax in axes:
        ax.hlines(ys[1:-2], xs[2:-1], xs[3:])
        ax.hlines(ys[1:-1], xs[:-2], xs[1:-1])
        ax.vlines(xs[:-1], ys[:-1], ys[1:])
        ax.vlines(xs[2:-1], ys[:-3], ys[1:-2])
Estimate virtual tuning curves¶
ds_run = 0.125 # 125 ms: requested final bin width, in seconds
ds_50ms = 0.05 # 50 ms: initial bin width, in seconds
st_run = st[run_epochs]
# smooth and re-bin:
sigma = 0.3 # 300 ms spike smoothing
# NOTE(review): ds_run / ds_50ms = 2.5 here, which is not a whole number of
# 50 ms bins -- confirm how rebin() treats a non-integer w.
bst_run = (
st_run.bin(ds=ds_50ms).smooth(sigma=sigma, inplace=True).rebin(w=ds_run / ds_50ms)
)
/home/etienne/Dropbox/code/nelpy/nelpy/core/_eventarray.py:1672: UserWarning: interval duration is less than bin size: ignoring... /home/etienne/Dropbox/code/nelpy/nelpy/core/_eventarray.py:2089: UserWarning: _restrict_to_interval_array() not yet implemented for BinnedTypes
sigma_tc = 2  # smoothing applied to the virtual tuning curves below
bst = bst_run
# linearized position sampled at bst.centers
# NOTE(review): other cells use .bin_centers -- confirm that .centers is equivalent
xpos = pos1d.asarray(at=bst.centers).yvals
# the observed position range defines the spatial extent
x0 = pos1d.min()
xl = pos1d.max()
n_extern = 50  # number of spatial bins
# n_extern + 1 bin edges, and the n_extern bin midpoints
xx_left = np.linspace(x0, xl, n_extern + 1)
xx_mid = np.linspace(x0, xl, n_extern + 1)[:-1]
xx_mid += (xx_mid[1] - xx_mid[0]) / 2
ext_x = np.digitize(xpos, xx_left) - 1 # spatial bin numbers
# cast to float so that invalid samples can be marked with NaN
ext_x = ext_x.astype(float)
# NOTE(review): bin number 0 is the first spatial bin, so this line discards
# every sample in that bin -- confirm this is intended (rather than `< 0`).
ext_x[ext_x == 0] = np.nan
# discard samples at or beyond the last edge
ext_x[ext_x >= n_extern] = np.nan
# learn the mapping between hidden states and spatial bins
extern = hmm.fit_ext(X=bst_run, ext=ext_x, n_extern=n_extern)
# wrap the learned mapping as a tuning curve object ("virtual" tuning curves)
vtc = nel.TuningCurve1D(ratemap=extern, min_duration=0, extmin=x0, extmax=xl)
vtc = vtc.smooth(sigma=sigma_tc)
# presumably converts 1-based ids to 0-based state indices -- TODO confirm
states_in_track_order = np.array(vtc.get_peak_firing_order_ids()) - 1
# apply the track order to both the tuning curves and the HMM states
vtc.reorder_units(inplace=True)
hmm.reorder_states(states_in_track_order)
/home/etienne/anaconda3/lib/python3.5/site-packages/ipykernel/__main__.py:13: RuntimeWarning: invalid value encountered in greater_equal /home/etienne/anaconda3/lib/python3.5/site-packages/numpy/lib/function_base.py:780: RuntimeWarning: invalid value encountered in greater_equal /home/etienne/anaconda3/lib/python3.5/site-packages/numpy/lib/function_base.py:781: RuntimeWarning: invalid value encountered in less_equal
# Plot the virtual tuning curves, one per hidden state.
npl.setup(font_scale=1.2)
with npl.palettes.color_palette(npl.colors.rainbow):
    with npl.FigureManager(show=True, figsize=(6, 6)) as (fig, ax):
        npl.utils.skip_if_no_output(fig)
        # Rebind ax to whatever axes the plotting call returns.
        ax = npl.plot_tuning_curves1D(vtc, pad=0.08)
        ax.set_xlabel("position [cm]")
        ax.set_ylabel("state")
        npl.utils.xticks_interval(25)
        fig.suptitle("virtual tuning curves estimated from run data")
Decoding to the state space:¶
_, posterior_states = hmm.score_samples(bst_test)
bst_test.lengths
array([ 1, 15, 1, 1, 12, 2, 1, 14, 1, 1, 1, 15, 15, 15, 1, 14, 1, 9, 4, 15, 1, 1, 8, 1, 1, 4, 1, 15, 1, 1, 7, 4, 1, 1, 3])
npl.imagesc(posterior_states[1])
(<matplotlib.axes._subplots.AxesSubplot at 0x7ffa140d5ef0>, <matplotlib.image.AxesImage at 0x7ffa140e7048>)
posteriors, bdries, mode_pth, mean_pth = hmm.decode_ext(
X=bst_test[1], ext_shape=(n_extern,)
)
npl.imagesc(posteriors)
(<matplotlib.axes._subplots.AxesSubplot at 0x7ffa140bd438>, <matplotlib.image.AxesImage at 0x7ffa14088a58>)
# Decode all test epochs through the learned state-to-position mapping, with
# one external bin per bin of the virtual tuning curves.
posterior_pos, bdries, mode_pth, mean_pth = hmm.decode_ext(
X=bst_test, ext_shape=(vtc.n_bins,)
)
# Rescale the decoded path onto the position range spanned by vtc.bins.
# NOTE(review): this assumes decode_ext returns mean_pth normalized to [0, 1]
# -- confirm.
mean_pth = vtc.bins[0] + mean_pth * (vtc.bins[-1] - vtc.bins[0])
# true linearized position at the centers of the test bins
plt.plot(
pos1d.asarray(at=bst_test.bin_centers).yvals,
c="0.6",
label="True linearized position",
)
plt.plot(mean_pth, label="HMM decoded position")
plt.legend(loc=(1.025, 0.5))
plt.ylabel("position (cm)")
Text(0,0.5,'position (cm)')
def first_event(st):
    """Returns the [time of the] first event across all series.

    Parameters
    ----------
    st : object with a ``data`` attribute
        ``st.data`` is iterated as a collection of per-series event-time
        sequences. Only the first element of each series is inspected, so
        each series is assumed to be sorted in ascending order.

    Returns
    -------
    first : float
        The smallest leading event time over all non-empty series, or
        ``np.inf`` when there are no events at all.
    """
    first = np.inf
    for series in st.data:
        # A series without events has no first element; skip it rather than
        # failing with an IndexError.
        if len(series) == 0:
            continue
        if series[0] < first:
            first = series[0]
    return first
def last_event(st):
    """Returns the [time of the] last event across all series.

    Parameters
    ----------
    st : object with a ``data`` attribute
        ``st.data`` is iterated as a collection of per-series event-time
        sequences. Only the last element of each series is inspected, so
        each series is assumed to be sorted in ascending order.

    Returns
    -------
    last : float
        The largest trailing event time over all non-empty series, or
        ``-np.inf`` when there are no events at all.
    """
    last = -np.inf
    for series in st.data:
        # A series without events has no last element; skip it rather than
        # failing with an IndexError.
        if len(series) == 0:
            continue
        if series[-1] > last:
            last = series[-1]
    return last