From 010a5c43344127d904a19eddc4c5a3bf54a4d308 Mon Sep 17 00:00:00 2001 From: Martynas Dervinis <50549306+dervinism@users.noreply.github.com> Date: Tue, 1 Sep 2020 14:00:51 +0100 Subject: [PATCH] Initial commit --- RUN.txt | 53 ++++ displayWaveforms.m | 30 ++ extractWaveform.m | 375 +++++++++++++++++++++++++ getCluResFromKSdir.m | 40 +++ init.m | 24 ++ ks_batch.m | 86 ++++++ ks_master_file.m | 108 ++++++++ ks_master_file2.m | 66 +++++ postprocessingPipeline.m | 21 ++ resCluFromKilosort.m | 154 +++++++++++ spikeSortingPipeline.m | 573 +++++++++++++++++++++++++++++++++++++++ 11 files changed, 1530 insertions(+) create mode 100644 RUN.txt create mode 100644 displayWaveforms.m create mode 100644 extractWaveform.m create mode 100644 getCluResFromKSdir.m create mode 100644 init.m create mode 100644 ks_batch.m create mode 100644 ks_master_file.m create mode 100644 ks_master_file2.m create mode 100644 postprocessingPipeline.m create mode 100644 resCluFromKilosort.m create mode 100644 spikeSortingPipeline.m diff --git a/RUN.txt b/RUN.txt new file mode 100644 index 0000000..06883fb --- /dev/null +++ b/RUN.txt @@ -0,0 +1,53 @@ +When recording: +* Some channels might be faulty (e.g. if you are using B-stick Neuronexus probe this is bound to be the case) + Do not remove these channels in OpenEphys while recording - this will screw up the pipeline, which assumes all the channels are there. + However do use the pipeline to remove them when processing the recording - doing so is particularly important if you are applying common average referencing (CAR). + +Data analysis workflow (no file stitching or truncating): +1. Copy the recording data on R drive after finishing an experiment. +2. Copy the recording data on your local drive (C: or D:). +3. Run spikeSortingPipeline.m located in R:\CSN\Shared\Dynamics\Code\sortingPipeline\spike_sorting\. The file also contains an example of how to set up the structure of the input variable. 
+ spikeSortingPipeline will rearrange recording channels in correct geometric order, subtract the median, save a file with the median trace, the swap order, + and the probe-to-headstage configuration (appended by _medianTrace.mat) and a file containing channel info (forPRB_.mat). Kilosort will be run on the + re-ordered data and files necessary to run phy will be saved in output folders. Finally, electrode drift map figures will be produced and saved in Drift_plot_all_spikes.fig + and Drift_plot_large_spikes.fig files in the output folders. +4. Open the .dat file with re-ordered channels in Neuroscope and inspect it. +5. Open Anaconda prompt and run phy: activate phy >> cd data directory >> phy template-gui params.py. +6. Run postprocessingPipeline.m to extract and save average waveforms (output is saved in the waveforms.mat file). This function also inspects spike sorting cluster quality + and saves data in .qua.1.mat. +7. Copy the recording data back on R: drive. +8. Run your data analysis scripts. + + + +Data analysis workflow with file stitching: +1. Copy the recording data on R drive after finishing an experiment. +2. Copy the recording data on your local drive (C: or D:). + Do not rename files and subfolders at any stage until the data is moved back on R:. +3. Run spikeSortingPipeline.m located in R:\CSN\Shared\Dynamics\Code\sortingPipeline\spike_sorting\. The file also contains an example of how to set up the structure of the input variable. + Make sure you specify files to be stitched. spikeSortingPipeline will rearrange recording channels in correct geometric order, subtract the median, save a file with the median trace, the swap order, + and the probe-to-headstage configuration (appended by _medianTrace.mat) and a file containing channel info (forPRB_.mat). Kilosort will be run on the + re-ordered data and files necessary to run phy will be saved in output folders. 
Finally, electrode drift map figures will be produced and saved in Drift_plot_all_spikes.fig + and Drift_plot_large_spikes.fig files in the output folders. +4. Open the stitched dat file with re-ordered channels in Neuroscope and inspect it. +5. Open Anaconda prompt and run phy: activate phy >> cd stitched data directory >> phy template-gui params.py. +6. Run makeResClu_fromKilosort.m script located in R:\CSN\Shared\Dynamics\Code\matlib\spikes\. Specify dirname, datFilename, nCh, and chsh variables. + This will produce res and clu files for the stitched data. Note that this stage is necessary only if you intend to examine stitched files separately or compare them to one another. +7. Run zplit_dat.m located in R:\CSN\Shared\Dynamics\Code\matlib\IO\. This will produce res and clu files for individual files that were used to create the stitched data file. + The files are saved in the original folders. Note that this stage is necessary only if you intend to examine stitched files separately or compare them to one another. +8. Copy npy, py, and csv files from the folder with stitched data to folders with individual files used in stitching. +9. Run postprocessingPipeline.m for every file used in stitching separately in order to extract and save average waveforms (output is saved in the waveforms.mat file). + This function also inspects spike sorting cluster quality and saves data in .qua.1.mat. +10. Copy the recording data back on R drive. +11. Run your data analysis scripts. + + + +Data analysis workflow with file splitting (when two probes are used): +1. Copy the recording data on R drive after finishing an experiment. +2. Copy the recording data on your local drive (C: or D:). +3. Create folders for storing files created while processing split recordings (e.g., animal/series/probe1, animal/series/probe2). +4. Do not rename files and subfolders at any stage until the data is moved back on R:. +5. 
Run splitChannels Matlab function located in R:\CSN\Shared\Dynamics\Code\sortingPipeline\geometric_layout.
+6. Run all procedures starting with step 3 in the data pre-processing lists above depending on whether you stitch files or not.
+   The recordings obtained with different probes should be analysed separately.
\ No newline at end of file
diff --git a/displayWaveforms.m b/displayWaveforms.m
new file mode 100644
index 0000000..63d5f2e
--- /dev/null
+++ b/displayWaveforms.m
@@ -0,0 +1,47 @@
+function displayWaveforms(cluIDs, maxWaveforms, datFileList)
+% displayWaveforms(cluIDs, maxWaveforms, datFileList)
+% displayWaveforms function plots the average waveform of every unit in a
+% separate full-screen figure.
+% Input: cluIDs - a vector of unit IDs shared by all files, or a cell array
+%                 with one vector of unit IDs per file.
+%        maxWaveforms - a matrix (units x data points), or a cell array of
+%                       such matrices (one per file).
+%        datFileList - a file name, or a cell array of file names. If
+%                      several names come with a single waveform matrix,
+%                      the figures are labelled 'all'.
+
+if (iscell(datFileList) && numel(datFileList) > 1) && (iscell(maxWaveforms) && numel(maxWaveforms) > 1)
+  nFiles = numel(datFileList);
+else
+  nFiles = 1;
+  if iscell(datFileList) && numel(datFileList) > 1
+    datFileList = {'all'};
+  elseif ~iscell(datFileList)
+    datFileList = {datFileList};
+  end
+  if ~iscell(maxWaveforms)
+    maxWaveforms = {maxWaveforms};
+  end
+end
+if ~iscell(cluIDs)
+  cluIDs = repmat({cluIDs}, 1, nFiles); % the same unit list is used for every file
+end
+
+for iFile = 1:nFiles
+  fileCluIDs = cluIDs{iFile};
+  nWaves = numel(fileCluIDs);
+  if nWaves > size(maxWaveforms{iFile}, 1)
+    % More unit IDs than waveforms: plot the available rows rather than
+    % fail with an index error part of the way through.
+    warning('displayWaveforms:sizeMismatch', 'file %s: %d unit IDs but only %d waveforms; plotting the first %d.', ...
+      datFileList{iFile}, nWaves, size(maxWaveforms{iFile}, 1), size(maxWaveforms{iFile}, 1));
+    nWaves = size(maxWaveforms{iFile}, 1);
+  end
+  for iWave = 1:nWaves
+    figure('units', 'normalized', 'position', [0.002, .04, 1, .88], 'Visible','on');
+    plot(maxWaveforms{iFile}(iWave,:))
+    title(['file ' datFileList{iFile} ' unit ' num2str(fileCluIDs(iWave))], 'Interpreter','none');
+    ylabel('\muV');
+    xlabel('data points')
+  end
+end
\ No newline at end of file
diff --git a/extractWaveform.m b/extractWaveform.m
new file mode 100644
index 0000000..bc698b5
--- /dev/null
+++ b/extractWaveform.m
@@ -0,0 +1,321 @@
+function [waveforms, maxWaveforms, cluIDs, spikeCentreIndex, amplitudes, maxChan] = extractWaveform(inp)
+% [waveforms, maxWaveforms, cluIDs, spikeCentreIndex, amplitudes, maxChan] = extractWaveform(inp)
+% extractWaveform function extracts average waveforms and other associated
+% information about spike waveforms following kilosort and phy spike
+% sorting.
+% Input: inp structure with the following fields (some have defaults and are optional)
+%        inp.dataDir - the folder holding the binary file and the sorting output.
+%        inp.dataFile - the binary file name (a three-letter extension is assumed).
+%        inp.chansIgnore - a number of channels to be ignored (not electrode channels).
+%                          Accepted for compatibility but currently not used.
+%        inp.outputFile - ('waveforms.mat' by default)
+%        inp.display - true (default) or false for displaying average waveforms.
+%        inp.wavelength - the duration of the spike waveform in data points (200 by default)
+%        inp.merge - merge waveforms for all files: true or false (default).
+%                    Set to true to save the results in inp.dataDir
+%        inp.dataType ('int16' by default, specify otherwise)
+%
+% Output: waveforms - average waveforms (units x data points x channels).
+%         maxWaveforms - average waveforms on the channel with the largest
+%                        amplitude only (units x data points).
+%         cluIDs - a vector of the IDs of all sorted units (before units
+%                  with too few spikes are discarded).
+%         spikeCentreIndex - a spike centre index on the waveform.
+%         amplitudes - average spike amplitudes.
+%         maxChan - channel IDs with the largest spike amplitudes.
+% The returned waveforms, maxWaveforms, amplitudes and maxChan are those of
+% the last saved file (of the merged data if inp.merge is true). Units with
+% fewer than 300 spikes are removed from them, so their rows correspond to
+% the first column of the saved chanMap variable rather than to cluIDs.
+% The function also saves an output file containing output variables, as
+% well as the list of stitched files corresponding to the output data.
+
+
+%% User input
+dataDir = inp.dataDir;
+dataFile = inp.dataFile;
+
+if ~isfield(inp, 'dataType') || isempty(inp.dataType)
+  dataType = 'int16';
+else
+  dataType = inp.dataType;
+end
+if ~isfield(inp, 'outputFile') || isempty(inp.outputFile)
+  outputFile = 'waveforms.mat';
+else
+  outputFile = inp.outputFile;
+end
+if ~isfield(inp, 'display') || isempty(inp.display)
+  display = true;
+else
+  display = inp.display;
+end
+if ~isfield(inp, 'wavelength') || isempty(inp.wavelength)
+  wavelength = 200;
+else
+  wavelength = inp.wavelength;
+end
+if ~isfield(inp, 'merge') || isempty(inp.merge)
+  merge = false;
+else
+  merge = inp.merge;
+end
+minSpikes = 300; % units with fewer spikes in a file are discarded
+chunkSize = 1000000; % data points read from the binary file at a time
+
+
+%% Extract the spike cluster info and template waveforms
+sp = loadKSdir(dataDir);
+cluFull = sp.clu;
+resFull = round(sp.st * sp.sample_rate);
+basePath = [dataDir filesep dataFile(1:end-4)];
+clu = [];
+res = [];
+for sh = 1:16
+  cluFile = [basePath '.clu.' num2str(sh)];
+  resFile = [basePath '.res.' num2str(sh)];
+  if exist(cluFile, 'file') && exist(resFile, 'file')
+    resSh = load(resFile);
+    cluSh = load(cluFile);
+    assert(numel(resSh) == numel(cluSh) - 1); cluSh = cluSh(2:end);
+    resSh = resSh(cluSh > 0); % removing noise spikes
+    cluSh = cluSh(cluSh > 0);
+    res = [res; resSh]; %#ok<AGROW>
+    clu = [clu; cluSh]; %#ok<AGROW>
+  elseif sh == 1
+    [clu, res] = resCluFromKilosort(dataDir, 1, 10000, 1:10000);
+    clu = clu(2:end);
+  end
+end
+[res, swapOrder] = sort(res);
+clu = clu(swapOrder);
+% Locate the loaded spike train within the full kilosort spike train by
+% matching the sequences of inter-spike intervals.
+resFullDiff = resFull' - [0 resFull(1:end-1)'];
+resDiff = res' - [0 res(1:end-1)'];
+resDiff = resDiff(2:end);
+matchStart = strfind(resFullDiff, resDiff);
+if isempty(matchStart)
+  error('Waveform amplitudes could not be determined');
+end
+whichAmps = matchStart(1)-1:matchStart(1)-1+numel(resDiff);
+amps = sp.tempScalingAmps(whichAmps);
+assert(numel(clu) == numel(res) && numel(res) == numel(amps));
+cids = double(sp.cids); cgs = sp.cgs;
+cluIDsFull = double(unique(cluFull));
+if any(cluIDsFull == 0)
+  m = max(cluIDsFull) + 1;
+  cids(cids == 0) = m;
+  cluIDsFull(cluIDsFull == 0) = m;
+end
+if any(cluIDsFull == 1)
+  m = max(cluIDsFull) + 1;
+  cids(cids == 1) = m;
+  cluIDsFull(cluIDsFull == 1) = m;
+end
+assert(max(abs(sort(torow(cluIDsFull)) - sort(torow(cids)))) == 0, 'should be fully compatible');
+assert(~any(cgs >= 3), 'unsorted units remain');
+cluIDs = double(unique(clu));
+cluIDs = cluIDs(cluIDs > 1);
+for iClu = 1:numel(cluIDsFull)
+  if ~sum(cluIDsFull(iClu) == cluIDs)
+    cids(iClu) = NaN;
+    cgs(iClu) = NaN;
+  end
+end
+cids = cids(~isnan(cids));
+cgs = cgs(~isnan(cgs));
+if isempty(cluIDs) && isempty(cids)
+  disp('The file contains no single unit activity. Please check with Phy if this is indeed correct. extractWaveform function is terminating.');
+  waveforms = []; maxWaveforms = []; cluIDs = []; spikeCentreIndex = []; amplitudes = []; maxChan = [];
+  return
+end
+assert(max(abs(sort(torow(cluIDs)) - sort(torow(cids)))) == 0, 'should be fully compatible');
+
+
+%% Load raw data
+probeConfFile = dir([dataDir filesep 'forPRB*']);
+if isempty(probeConfFile) % dir returns an empty struct array when nothing matches
+  error('No probe configuration file found in the data folder')
+end
+probeConf = load([dataDir filesep probeConfFile(1).name],'connected'); % the first match is used
+connected = logical(probeConf.connected);
+fileName = fullfile(dataDir,dataFile);
+fprintf('extractWaveform: working on %s, which is presumed to have %d channels\n', fileName, sp.n_channels_dat)
+filenamestruct = dir(fileName);
+dataTypeNBytes = numel(typecast(cast(0, dataType), 'uint8')); % The size of a single data point in bytes
+nSampsTotal = filenamestruct.bytes/sp.n_channels_dat/dataTypeNBytes;
+nChunksTotal = ceil(nSampsTotal/chunkSize);
+
+stitchFile = [basePath '.mat'];
+if exist(stitchFile, 'file')
+  % The companion .mat file holds the per-file lengths (dataPoints) and the file names (datFileList)
+  stitchInfo = load(stitchFile);
+  dataPoints = stitchInfo.dataPoints;
+  datFileList = stitchInfo.datFileList;
+else
+  dataPoints = nSampsTotal;
+  datFileList = {fileName};
+end
+if merge
+  nFiles = 1; % everything is pooled into a single set of averages
+  fileStarts = 1;
+  fileEnds = sum(dataPoints);
+else
+  nFiles = numel(dataPoints);
+  fileEnds = cumsum(dataPoints(:))'; % the last data point of every file
+  fileStarts = [1 fileEnds(1:end-1)+1]; % the first data point of every file
+end
+
+fid = fopen(fileName, 'r');
+if fid == -1
+  error('extractWaveform:fileOpen', 'Could not open %s', fileName);
+end
+closeFile = onCleanup(@() fclose(fid)); % close the file even if an error occurs below
+templateLength = wavelength;
+spikeCentreIndex = templateLength/2 + 1;
+templateCh = sum(connected);
+for iFile = 1:nFiles
+  % Running sums and counts; dividing one by the other yields the averages (0/0 = NaN for no data)
+  templateWaveformsNew{iFile} = zeros(numel(cluIDs), templateLength, templateCh); %#ok<AGROW>
+  nTemplateWaveformsNew{iFile} = zeros(numel(cluIDs), templateLength, templateCh); %#ok<AGROW>
+  amplitudes{iFile} = zeros(size(cluIDs)); %#ok<AGROW>
+  nAmplitudes{iFile} = zeros(size(cluIDs)); %#ok<AGROW>
+end
+chunkInd = 1;
+while 1
+  fprintf(1, 'chunk %d/%d\n', chunkInd, nChunksTotal);
+  dat = fread(fid, [sp.n_channels_dat chunkSize], ['*' dataType]);
+  if isempty(dat)
+    break
+  end
+
+  %% Add the spikes to the waveforms
+  dataRange = round([(chunkInd-1)*chunkSize+1 chunkInd*chunkSize]);
+  inChunk = (res >= dataRange(1)) & (res <= dataRange(2));
+  resOI = res(inChunk);
+  cluOI = clu(inChunk);
+  ampOI = amps(inChunk);
+  for i = 1:numel(resOI)
+    iCluIDs = find(cluIDs == cluOI(i));
+    if isempty(iCluIDs)
+      continue % the spike belongs to a cluster that is not averaged (e.g. MUA)
+    end
+
+    % Find the file the spike belongs to (the last one if none matches)
+    filePointer = find(resOI(i) <= fileEnds, 1);
+    if isempty(filePointer)
+      filePointer = nFiles;
+    end
+
+    % Clip the spike window to both the file and the chunk held in memory
+    spikeRange = round([resOI(i)-templateLength/2 resOI(i)+templateLength/2-1]);
+    firstSample = max([spikeRange(1) fileStarts(filePointer) dataRange(1)]);
+    lastSample = min([spikeRange(2) fileEnds(filePointer) dataRange(2)]);
+    startWaveform = firstSample - spikeRange(1) + 1;
+    endWaveform = templateLength - (spikeRange(2) - lastSample);
+
+    % Count the data points contributed by this spike
+    iWaveform = zeros(1,templateLength);
+    iWaveform(startWaveform:endWaveform) = 1;
+    iWaveformFull = reshape(repmat(iWaveform,templateCh,1)',1,templateLength,templateCh);
+    nTemplateWaveformsNew{filePointer}(iCluIDs,:,:) = nTemplateWaveformsNew{filePointer}(iCluIDs,:,:) + iWaveformFull;
+    nAmplitudes{filePointer}(iCluIDs) = nAmplitudes{filePointer}(iCluIDs) + 1;
+
+    % Add the raw data of the connected channels to the running sums
+    waveform = zeros(templateCh,templateLength);
+    waveform(:,startWaveform:endWaveform) = dat(connected,(firstSample:lastSample)-(dataRange(1)-1));
+    waveform = reshape(waveform',1,templateLength,templateCh);
+    templateWaveformsNew{filePointer}(iCluIDs,:,:) = templateWaveformsNew{filePointer}(iCluIDs,:,:) + waveform;
+    amplitudes{filePointer}(iCluIDs) = amplitudes{filePointer}(iCluIDs) + ampOI(i);
+  end
+  chunkInd = chunkInd+1;
+end
+for iFile = 1:nFiles
+  waveforms{iFile} = templateWaveformsNew{iFile}./nTemplateWaveformsNew{iFile}; %#ok<AGROW>
+  amplitudes{iFile} = amplitudes{iFile}./nAmplitudes{iFile};
+end
+
+
+%% Pick the largest waveforms
+baselineInds = [1:spikeCentreIndex-10 spikeCentreIndex+10:templateLength]; % data points away from the spike
+peakInds = spikeCentreIndex-3:spikeCentreIndex+3; % data points around the spike centre
+for iFile = 1:nFiles
+  maxWaveforms{iFile} = zeros(numel(cluIDs), templateLength); %#ok<AGROW>
+  maxChan{iFile} = zeros(size(cluIDs)); %#ok<AGROW>
+  chanMap{iFile} = zeros(numel(cluIDs), 3); %#ok<AGROW>
+  for iWave = 1:numel(cluIDs)
+    prevMaxDif = 0;
+    for iChan = 1:templateCh
+      chanWaveform = waveforms{iFile}(iWave,:,iChan);
+      maxDif = max(abs(chanWaveform(peakInds) - mean(chanWaveform(baselineInds))));
+      if maxDif > prevMaxDif
+        prevMaxDif = maxDif;
+        maxWaveforms{iFile}(iWave,:) = chanWaveform;
+        maxChan{iFile}(iWave) = iChan;
+      end
+    end
+    unitGroup = cgs(cids == cluIDs(iWave));
+    chanMap{iFile}(iWave,1) = cluIDs(iWave);
+    chanMap{iFile}(iWave,2) = maxChan{iFile}(iWave);
+    if unitGroup == 0 % it's noise
+      chanMap{iFile}(iWave, 3) = 0;
+    elseif unitGroup == 1 % it's MUA
+      chanMap{iFile}(iWave, 3) = 1;
+    else % it's a unit
+      chanMap{iFile}(iWave, 3) = cluIDs(iWave);
+    end
+  end
+end
+
+
+%% Clean-up waveforms
+for iFile = 1:nFiles
+  tooFewSpikes = nAmplitudes{iFile} < minSpikes; % also true for units without any spikes in the file
+  waveforms{iFile}(tooFewSpikes,:,:) = [];
+  amplitudes{iFile}(tooFewSpikes) = [];
+  maxWaveforms{iFile}(tooFewSpikes,:) = [];
+  maxChan{iFile}(tooFewSpikes) = [];
+  chanMap{iFile}(tooFewSpikes,:) = [];
+end
+
+
+%% Save waveforms
+waveforms_temp = waveforms;
+maxWaveforms_temp = maxWaveforms;
+amplitudes_temp = amplitudes;
+maxChan_temp = maxChan;
+chanMap_temp = chanMap;
+if merge
+  iOutputFile = [dataDir filesep outputFile];
+  datFile = datFileList;
+  waveforms = waveforms_temp{1};
+  maxWaveforms = maxWaveforms_temp{1};
+  amplitudes = amplitudes_temp{1};
+  maxChan = maxChan_temp{1};
+  chanMap = chanMap_temp{1};
+  save(iOutputFile, 'datFile','cluIDs','spikeCentreIndex','waveforms','maxWaveforms','amplitudes','maxChan','chanMap', '-v7.3');
+else
+  for iFile = 1:nFiles
+    pathStr = fileparts(datFileList{iFile});
+    iOutputFile = [pathStr filesep outputFile];
+    datFile = datFileList{iFile}; %#ok<*NASGU>
+    waveforms = waveforms_temp{iFile};
+    maxWaveforms = maxWaveforms_temp{iFile};
+    amplitudes = amplitudes_temp{iFile};
+    maxChan = maxChan_temp{iFile};
+    chanMap = chanMap_temp{iFile};
+    save(iOutputFile, 'datFile','cluIDs','spikeCentreIndex','waveforms','maxWaveforms','amplitudes','maxChan','chanMap', '-v7.3');
+  end
+end
+
+
+%% Display waveforms
+if display
+  % Label the figures with the IDs of the units that survived the clean-up
+  displayIDs = cell(1, nFiles);
+  for iFile = 1:nFiles
+    displayIDs{iFile} = chanMap_temp{iFile}(:,1);
+  end
+  displayWaveforms(displayIDs, maxWaveforms_temp, datFileList);
+end
\ No newline at end of file
diff --git a/getCluResFromKSdir.m b/getCluResFromKSdir.m
new file mode 100644
index 0000000..156b5a1
--- /dev/null
+++ b/getCluResFromKSdir.m
@@ -0,0 +1,40 @@
+% Input: directory where KiloSort output resides
+% Output: clu and res in the legacy format (i.e. time is in samples), clu=0 is noise, clu=1 is MUA
+function [clu, res] = getCluResFromKSdir(ksdir)
+
+sp = loadKSdir(ksdir);
+
+res = sp.st * sp.sample_rate;
+u_clu = sort(unique(sp.clu), 'ascend');
+
+if ~isempty(setxor(u_clu, sp.cids))
+  error(['some incompatibility in the units: ' num2str(reshape(setxor(u_clu, sp.cids), 1, []))]) % row vector keeps the message on one line
+end
+
+if ~isempty(u_clu) && u_clu(1) == 0 % convert template 0 to the next available id, as 0 will be reserved for noise clusters
+  sp.clu(sp.clu == 0) = u_clu(end) + 1;
+  u_clu = [u_clu(2:end); u_clu(end) + 1]; % new id
+  sp.cgs = circshift(sp.cgs, -1); % move the classification of cluster 0 to the end, to correspond to u_clu
+end
+
+if ~isempty(u_clu) && u_clu(1) == 1 % convert template 1 to the next available id, as 1 will be reserved for MUA clusters
+  sp.clu(sp.clu == 1) = u_clu(end) + 1;
+  u_clu = [u_clu(2:end); u_clu(end) + 1]; % new id
+  sp.cgs = circshift(sp.cgs, -1); % move the classification of cluster 1 to the end, to correspond to u_clu
+end
+
+clu = sp.clu;
+
+for i = 1:numel(u_clu)
+  if sp.cgs(i) == 0 % noise cluster
+    clu(clu == u_clu(i)) = 0;
+  elseif sp.cgs(i) == 1 % MUA cluster
+    clu(clu == u_clu(i)) = 1;
+  end
+end
+
+clu = [numel(unique(clu)); clu]; % In the legacy format clu has the total number of clusters as its first element
+
+
+
+
diff --git a/init.m b/init.m
new file mode 100644
index 0000000..f11d3a1
--- /dev/null
+++ 
b/init.m
@@ -0,0 +1,24 @@
+fclose all; % Example driver script: resets the session, fills in the inp structure and runs spikeSortingPipeline
+close all;
+clear all %#ok<*CLALL>
+clc
+
+inp.io.dataFiles = {'R:\Neuropix\md406\continuous.dat'}; % NOTE(review): the meaning of the inp fields is defined by spikeSortingPipeline.m, which is not visible here
+inp.io.procFolders = {'R:\Neuropix\md406'};
+inp.io.outputFolders = {'R:\Neuropix\md406'};
+inp.io.deleteDataFiles = false;
+inp.conf.probe = 'A64-A4x4-tet-5mm-150-200-121';
+inp.conf.probeFlip = false;
+inp.conf.headstage = 'RHD2164';
+inp.conf.nChans = {1:64};
+inp.conf.samplingFrequency = 30000;
+inp.conf.tempFact = 3;
+inp.tasks.reorder = false;
+inp.tasks.subtractMedian = true;
+inp.tasks.stitchFiles = false;
+inp.tasks.chanMap = true;
+inp.tasks.runKS = 1;
+inp.tasks.driftPlot = true;
+inp.tasks.deleteChans = [];
+
+spikeSortingPipeline(inp);
\ No newline at end of file
diff --git a/ks_batch.m b/ks_batch.m
new file mode 100644
index 0000000..b1f1e4e
--- /dev/null
+++ b/ks_batch.m
@@ -0,0 +1,86 @@
+%script for running Kilosort: sets the recording-specific options in ops and then runs ks_master_file
+
+%%
+% mouseName = 'ALK052';
+% thisDate = datestr(now, 'yyyy-mm-dd');
+% thisDate = '2017-08-23';
+% fnBase = [mouseName '_' thisDate '_g0_t0.imec.'];
+fnBase = 'continuous_swappedNoCAR';
+%% parameters
+
+%ops.chanMap = 'forPRBimecP3opt3.mat';
+ops.chanMap = 'forPRB_A2x2_tet_3mm_150_150_121.mat';
+ops.NchanTOT = 16;
+ops.Nfilt = 4*16; % number of filters to use (2-4 times more than Nchan, should be a multiple of 32)
+% ops.root = ['D:\Data', filesep, mouseName, filesep, thisDate, filesep];
+ops.root = 'R:\CSN\Shared\Dynamics\spikeSorting\';
+ops.fproc = fullfile('R:\CSN\Shared\Dynamics\spikeSorting', 'temp_wh.dat'); % temporary file; ks_master_file deletes it at the end
+
+% fn = fullfile([ops.root fnBase '.bin']);
+% fnAfterCAR = fullfile([ops.root fnBase '.bin']);
+fn = fullfile([ops.root fnBase '.dat']);
+fnAfterCAR = fullfile([ops.root fnBase '.dat']); % same file as fn: the CAR step below is commented out
+
+load(ops.chanMap); % NOTE(review): ks_master_file reads a variable named probe that this script never sets; presumably it is loaded here - confirm
+
+ext = fn(end-2:end); % last three characters of fn; not used in the code visible here
+
+ops.fbinary = fnAfterCAR;
+
+
+% % first perform CAR
+% tic
+% medianTrace = applyCARtoDat(fn, ops.NchanTOT, ['/localdisk/mush/' mouseName filesep thisDate]);
+% toc
+
+%% then run KS
+ks_master_file;
+fclose('all');
+return % the script stops here; everything below is commented out
+
+% %% then copy to 
server
+% tic
+% % basketDrive = 'Z:\';
+% % zserverDrive = 'X:\';
+% % lugaroDrive = 'Y:\';
+% basketDrive = '\\basket.cortexlab.net\data\';
+% zserverDrive = '\\zserver.cortexlab.net\data\';
+% lugaroDrive = '\\lugaro.cortexlab.net\staging\';
+% lugaroDrive2 = '\\lugaro.cortexlab.net\toarchive\';
+%
+% zserverDest = fullfile(zserverDrive, 'multichanspikes', mouseName, thisDate);
+% basketDest = fullfile(basketDrive, 'nick', mouseName, thisDate);
+% lugaroDest = fullfile(lugaroDrive, [mouseName '_' thisDate]);
+%
+% % npy files go to basket
+% mkdir(basketDest)
+% fprintf(1, 'moving npy to basket\n');
+% movefile(fullfile(ops.root, '*.npy'), basketDest);
+% movefile(fullfile(ops.root, 'params.py'), basketDest);
+% movefile(fullfile(ops.root, 'rez.mat'), basketDest);
+% toc
+% %% afterCAR (along with LFP, meta, median) goes to zserver
+% mkdir(zserverDest);
+% fprintf(1, 'moving CAR to zserver\n');
+% movefile(fnAfterCAR, zserverDest);
+% toc
+%
+% %%
+% fprintf(1, 'moving other to zserver\n');
+% movefile(fullfile(ops.root, [fnBase 'ap.meta']), zserverDest);
+% movefile(fullfile(ops.root, [fnBase 'ap_medianTrace.mat']), zserverDest);
+% movefile(fullfile(ops.root, [fnBase 'lf.bin']), zserverDest);
+% movefile(fullfile(ops.root, [fnBase 'lf.meta']), zserverDest);
+% toc
+% %% raw goes to toarchive
+% fprintf(1, 'moving raw to lugaro\n');
+% tic
+% mkdir(lugaroDest);
+% movefile(fullfile(ops.root, [fnBase 'ap.bin']), lugaroDest);
+% toc
+% fprintf(1, 'copying to lugaro''s toarchive\n');
+% tic
+% movefile(lugaroDest, lugaroDrive2);
+% toc
+% fprintf(1, 'done!\n');
+% toc
\ No newline at end of file
diff --git a/ks_master_file.m b/ks_master_file.m
new file mode 100644
index 0000000..a9cd393
--- /dev/null
+++ b/ks_master_file.m
@@ -0,0 +1,113 @@
+% Kilosort configuration and run script. ks_batch.m sets ops.chanMap,
+% ops.NchanTOT, ops.Nfilt, ops.root, ops.fproc and ops.fbinary before running it.
+% default options are in parenthesis after the comment
+
+try
+  gpuArray(1);
+  ops.GPU = 1; % whether to run this code on an Nvidia GPU (much faster, mexGPUall first)
+catch
+  ops.GPU = 0;
+end
+ops.parfor = 0; % whether to use parfor to accelerate some parts of the algorithm
+ops.verbose = 1; % whether to print command line progress
+ops.showfigures = 1; % whether to plot figures during optimization
+
+% probe may be defined by the calling script; the 'dat' format is assumed when it is not
+if exist('probe', 'var') && strcmpi(probe, 'Neuropixels')
+  ops.datatype = 'bin'; % binary ('dat', 'bin') or 'openEphys'
+else
+  ops.datatype = 'dat';
+end
+
+ops.fs = 30000; % sampling rate
+% ops.NchanTOT = 32; % total number of channels
+% ops.Nchan = sum(connected); % number of active channels
+% ops.Nfilt = 768; % number of filters to use (2-4 times more than Nchan, should be a multiple of 32)
+ops.nNeighPC = 4; % visualization only (Phy): number of channels to mask the PCs, leave empty to skip (12)
+ops.nNeigh = 16; % visualization only (Phy): number of neighboring templates to retain projections of (16)
+
+% options for channel whitening
+ops.whitening = 'full'; % type of whitening (default 'full', for 'noSpikes' set options for spike detection below)
+ops.nSkipCov = 1; % compute whitening matrix from every N-th batch (1)
+ops.whiteningRange = 32; % how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32)
+
+% define the channel map as a filename (string) or simply an array
+% ops.chanMap = 'C:\DATA\Spikes\Piroska\chanMap.mat'; % make this file using createChannelMapFile.m
+ops.criterionNoiseChannels = 0.2; % fraction of "noise" templates allowed to span all channel groups (see createChannelMapFile for more info).
+% ops.chanMap = 1:ops.Nchan; % treated as linear probe if a chanMap file
+
+% other options for controlling the model and optimization
+ops.Nrank = 3; % matrix rank of spike template model (3)
+ops.nfullpasses = 6; % number of complete passes through data during optimization (6)
+ops.maxFR = 20000; % maximum number of spikes to extract per batch (20000)
+ops.fshigh = 300; % frequency for high pass filtering
+ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection
+ops.scaleproc = 200; % int16 scaling of whitened data
+ops.NT = 32*1024+ ops.ntbuff;% this is the batch size (try decreasing if out of memory)
+% for GPU should be multiple of 32 + ntbuff
+
+% the following options can improve/deteriorate results.
+% when multiple values are provided for an option, the first two are beginning and ending anneal values,
+% the third is the value used in the final pass.
+ops.Th = [4 10 10]; % threshold for detecting spikes on template-filtered data ([6 12 12])
+ops.lam = [5 20 20]; % large means amplitudes are forced around the mean ([10 30 30])
+ops.nannealpasses = 4; % should be less than nfullpasses (4)
+ops.momentum = 1./[20 400]; % start with high momentum and anneal (1./[20 1000])
+ops.shuffle_clusters = 1; % allow merges and splits during optimization (1)
+ops.mergeT = .1; % upper threshold for merging (.1)
+ops.splitT = .1; % lower threshold for splitting (.1)
+
+% options for initializing spikes from data
+ops.initialize = 'no'; %'fromData' or 'no'
+ops.spkTh = -6; % spike threshold in standard deviations (4)
+ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1])
+ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6])
+ops.maskMaxChannels = 5; % how many channels to mask up/down ([5])
+ops.crit = .65; % upper criterion for discarding spike repeats (0.65)
+ops.nFiltMax = 10000; % maximum "unique" spikes to consider (10000)
+
+% load predefined principal components (visualization only (Phy): used for features)
+dd = load('PCspikes2.mat'); % you might want to recompute this from your own data
+ops.wPCA = dd.Wi(:,1:7); % PCs
+
+% options for posthoc merges (under construction)
+ops.fracse = 0.1; % binning step along discriminant axis for posthoc merges (in units of sd)
+ops.epu = Inf;
+
+ops.ForceMaxRAMforDat = 20e9; % maximum RAM the algorithm will try to use; on Windows it will autodetect.
+
+%%
+tic; % start timer
+
+if strcmp(ops.datatype , 'openEphys')
+  ops = convertOpenEphysToRawBInary(ops); % convert data, only for OpenEphys
+end
+%
+[rez, DATA, uproj] = preprocessData(ops);
+
+if strcmp(ops.initialize, 'fromData')
+  % do scaled kmeans to initialize the algorithm (not sure if functional yet for CPU)
+  optimizePeaks(uproj);
+end
+%
+rez = fitTemplates(rez, DATA, uproj);
+
+%
+% extracts final spike times (overlapping extraction)
+if ops.GPU
+  gpuDevice(1); % select the first GPU; skipped when no GPU was detected above, where this call would throw
+end
+rez = fullMPMU(rez, DATA);
+
+% posthoc merge templates (under construction)
+% rez = merge_posthoc2(rez);
+
+% save matlab results file
+save(fullfile(ops.root, 'rez.mat'), 'rez', '-v7.3');
+
+% save python results file for Phy
+rezToPhy(rez, ops.root);
+
+% remove temporary file
+delete(ops.fproc);
+%%
diff --git a/ks_master_file2.m b/ks_master_file2.m
new file mode 100644
index 0000000..ce91f2b
--- /dev/null
+++ b/ks_master_file2.m
@@ -0,0 +1,66 @@
+ops.GPU = 1; % whether to run this code on an Nvidia GPU (much faster, mexGPUall first)
+ops.fs = 30000; % sampling rate
+ops.nSkipCov = 25; % compute whitening matrix from every N-th batch (1)
+ops.whiteningRange = 32; % how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32)
+ops.fshigh = 150; % frequency for high pass filtering
+ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection
+ops.scaleproc = 200; % int16 scaling of whitened data
+ops.NT = 64*1024+ ops.ntbuff;% this is the batch size (try decreasing if out of memory)
+ops.Th = [10 2]; % threshold for detecting spikes on 
template-filtered data ([6 12 12]) +ops.lam = 10; % large means amplitudes are forced around the mean ([10 30 30]) +ops.momentum = [20 400]; % start with high momentum and anneal (1./[20 1000]) +ops.spkTh = -6; % spike threshold in standard deviations (4) +ops.trange = [0 Inf]; % time range to sort +ops.ThPre = 8; % threshold crossings for pre-clustering (in PCA projection space) +ops.minfr_goodchannels = 0.1; % minimum firing rate on a "good" channel (0 to skip) +ops.AUCsplit = 0.9; % splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) +ops.minFR = 1/50; % minimum spike rate (Hz), if a cluster falls below this for too long it gets removed +ops.sigmaMask = 30; % spatial constant in um for computing residual variance of spike +ops.nPCs = feature('numcores'); % how many PCs to project the spikes into +ops.useRAM = 0; % not yet available + +%% +tic; % start timer + +gpuDevice(1); %re-initialize GPU + +rez = preprocessDataSub(ops); + +% time-reordering as a function of drift +rez = clusterSingleBatches(rez); + +% main tracking and template matching algorithm +rez = learnAndSolve8b(rez); + +% final merges +rez = find_merges(rez, 1); + +% final splits by SVD +rez = splitAllClusters(rez, 1); + +% final splits by amplitudes +rez = splitAllClusters(rez, 0); + +% decide on cutoff +rez = set_cutoff(rez); + +fprintf('found %d good units \n', sum(rez.good>0)) + +% write to Phy +fprintf('Saving results to Phy \n') +rezToPhy(rez, ops.root); + +%% if you want to save the results to a Matlab file... 
+
+% discard features in final rez file (too slow to save)
+rez.cProj = [];
+rez.cProjPC = [];
+
+% save final results as rez2
+fprintf('Saving final results in rez2 \n')
+fname = fullfile(ops.root, 'rez2.mat');
+save(fname, 'rez', '-v7.3');
+
+% remove temporary file
+delete(ops.fproc);
+%%
diff --git a/postprocessingPipeline.m b/postprocessingPipeline.m
new file mode 100644
index 0000000..67eb6d2
--- /dev/null
+++ b/postprocessingPipeline.m
@@ -0,0 +1,21 @@
+% Performs postprocessing steps of the spike-sorting pipeline.
+% Should be called after the manual refinement step in phy.
+% At this stage the postprocessing includes two steps:
+% 1. computing the average spike waveforms of the units (by averaging the raw recorded voltage traces)
+% 2. calculating the quality of each unit
+% Inputs: binaryFilename - the full name of the binary file (on which Kilosort ran)
+%         noAIchans - number of AI (analog input) channels, which are not part of the silicon probe (typically the last channels)
+%         sr - sampling rate of the binary file (in Hz; 3e4 is used when sr is omitted or empty)
+function postprocessingPipeline(binaryFilename, noAIchans, sr)
+if nargin < 3 || isempty(sr) % an explicit [] selects the default as well
+    sr = 3e4;
+end
+
+[inp.dataDir, inp.dataFile, ext] = fileparts(binaryFilename);
+inp.dataFile = [inp.dataFile, ext];
+inp.display = false;
+inp.merge = true;
+inp.chansIgnore = noAIchans;
+
+extractWaveform(inp); % will extract the waveform of each unit, and save the results in waveforms.mat in the same directory
+createQualityFileKilosort(inp.dataDir, sr); % This function computes autocorrelograms, and relies on CCG function (e.g. in R:\CSN\Shared\Dynamics\Code\github_cortex-lab_spikes\analysis\helpers\ )
diff --git a/resCluFromKilosort.m b/resCluFromKilosort.m
new file mode 100644
index 0000000..4650527
--- /dev/null
+++ b/resCluFromKilosort.m
@@ -0,0 +1,154 @@
+function [clu, res, templates] = resCluFromKilosort(dirname, shankOI, shCh, chOI, probeFile)
+% [clu, res, templates] = resCluFromKilosort(dirname, shankOI, shCh, chOI, probeFile)
+%
+% This function extracts clu, res and templates variables from a kilosort
+% output directory. It is a helper function to AnPSD_load,
+% loadAsMUA_noResClu, loadAsRasterSparse, and extractWaveform.
+%
+% Inputs: dirname - a kilosort output directory containing npy files.
+%         shankOI - the shank of interest (0 or empty: do not select by shank).
+%         shCh - a number of recording channels per shank.
+%         chOI - channels of interest (so that you can look at specific
+%                shank sections).
+%         probeFile - a full path to a probe configuration file (forPRB*).
+%                If you used kilosort2 to spikesort your data, you must
+%                supply this file. Otherwise, channel locations of
+%                units will not be identified correctly.
+%
+% Outputs: clu - a vector with spike IDs. The first element of the vector
+%                is the number of unique IDs among the returned spikes (noise
+%                is grouped into cluster ID 0 and MUA's into cluster ID 1).
+%          res - spike index vector. Divide by the sampling frequency in
+%                order to obtain the spike times. The following is true:
+%                numel(res) == numel(clu)-1.
+%          templates - ID of the unit of each returned spike before the noise/MUA grouping; numel(templates) == numel(res).
+
+if nargin < 5
+    probeFile = [];
+end
+
+ySelection = []; % optional [min max] window on spike y-coordinates; no selection while empty
+
+sp = loadKSdir(dirname);
+
+clu = sp.clu; %readNPY([dirname filesep 'spike_clusters.npy']);
+res = sp.st * sp.sample_rate; %readNPY([dirname filesep 'spike_times.npy']);
+tmpl = sp.spikeTemplates+1;% readNPY([dirname filesep 'spike_templates.npy']); tmpl = tmpl+1; % they start from 0 (python way)
+templateWaveforms = sp.temps; %readNPY([dirname filesep 'templates.npy']); % templates x time x channel
+ycoordsCh = sp.ycoords;
+xcoordsCh = sp.xcoords;
+
+if isempty(clu) && isempty(res) && isempty(tmpl)
+    templates = [];
+    return
+else
+    assert(numel(clu) == numel(res) && numel(res) == numel(tmpl) && max(tmpl) <= size(templateWaveforms, 1))
+end
+
+cids = sp.cids; cgs = sp.cgs; %[cids, cgs] = readClusterGroupsCSV([dirname filesep 'cluster_groups.csv']);
+%[cids, cgs] = readClusterGroupsCSV([dirname filesep 'cluster_group.tsv']);
+
+uClu = double(unique(clu));
+
+% The cluster IDs of the spikes must match sp.cids exactly, and no cluster may be in a group >= 3.
+assert(max(abs(double(sort(torow(uClu))) - double(sort(torow(cids))))) == 0, 'should be fully compatible')
+assert(~any(cgs >= 3), 'unsorted units remain')
+
+if ~isempty(ySelection)
+    [~, max_site] = max(max(abs(sp.temps),[],2),[],3); % the maximal site for each template
+    spike_ycoord = sp.ycoords(max_site(sp.spikeTemplates+1));
+    inWindow = spike_ycoord >= ySelection(1) & spike_ycoord <= ySelection(2);
+    clu = clu(inWindow); res = res(inWindow); tmpl = tmpl(inWindow); % tmpl has to stay aligned with clu for the loop below
+    uClu = double(unique(clu));
+end
+clear max_site spike_ycoord inWindow
+
+% Make sure no unit is named 0 or 1
+if any(uClu == 0)
+    m = max(uClu) + 1;
+    clu(clu == 0) = m;
+    cids(cids == 0) = m;
+    uClu(uClu == 0) = m;
+end
+if any(uClu == 1)
+    m = max(uClu) + 1;
+    clu(clu == 1) = m;
+    cids(cids == 1) = m;
+    uClu(uClu == 1) = m;
+end
+% Remember the unit of every spike before noise and MUA units are pooled below.
+templates = clu;
+templateWaveforms2D = reshape(templateWaveforms, size(templateWaveforms, 1), []);
+sh = zeros(size(clu)); % will hold the shank on which each spike resides (according to the template it's assigned to)
+oi = false(size(clu)); % will be true if resides on the channel of interest (logical, so that it is a valid mask on its own)
+
+% Site coordinates as listed in the probe configuration file.
+if ~isempty(probeFile)
+    load(probeFile, 'ycoords','xcoords')
+end
+
+% For every unit: approximate its average waveform, locate the channel
+% with the largest deflection, derive the shank and the
+% channel-of-interest flag of its spikes, and move noise and MUA units
+% into the cluster IDs 0 and 1.
+chanMap = []; % one row per unit: [unit, channel, final cluster ID]
+for u = torow(uClu)
+    h = histc(tmpl(clu == u), 1:size(templateWaveforms, 1)); h = h/sum(h); %#ok
+    w = squeeze(reshape(torow(h)*templateWaveforms2D, size(templateWaveforms, 2), size(templateWaveforms, 3)));
+    % approximation for the average waveform of this unit. The computation is
+    % done this way because some units can be assigned to more than one
+    % template (e.g. after merge in phy)
+
+    % w is time x channel (sp.temps is templates x time x channel), so the
+    % column of its largest absolute value is the channel with the highest
+    % spike template waveform.
+    [~, pos] = max(abs(w(:)));
+    ch = ceil(pos / size(w, 1));
+    if ~isempty(probeFile)
+        % Find the site with the same coordinates in the probe file and use
+        % its position there as the channel.
+        posY = find(ycoordsCh(ch) == ycoords);
+        posX = find(xcoordsCh(ch) == xcoords(posY));
+        ch = posY(1) + posX - 1;
+    end
+
+    % Shanks are consecutive groups of shCh channels. The message shows the
+    % channel and, in brackets, the channel minus one.
+    sh(clu == u) = ceil(ch/shCh);
+    fprintf('Unit %d, on ch %d(%d) ==> it''s on shank %d and is ', u, ch, ch-1, ceil(ch/shCh))
+    oi(clu == u) = any(ch == chOI);
+    chanMap(end+1, :) = [u ch NaN]; %#ok
+
+    % Cluster group of the unit: 0 - noise, 1 - MUA, anything else - good.
+    % Spikes of noise and MUA units are pooled into the clusters 0 and 1;
+    % templates keeps the ID the unit had before the pooling.
+    unitGroup = cgs(cids == u);
+    if unitGroup == 0 % it's noise
+        clu(clu == u) = 0;
+        chanMap(end, 3) = 0;
+        fprintf('noise')
+    elseif unitGroup == 1 % it's MUA
+        clu(clu == u) = 1;
+        chanMap(end, 3) = 1;
+        fprintf('MUA')
+    else
+        chanMap(end,3) = u;
+        fprintf('good')
+    end
+    fprintf('\n')
+end
+
+% Keep the spikes on the channels of interest, restricted to the shank of
+% interest when one is given. The first element of clu is the number of
+% distinct cluster IDs among the kept spikes.
+if shankOI
+    keep = sh == shankOI & oi;
+else
+    keep = oi;
+end
+clu = [numel(unique(clu(keep))); clu(keep)];
+res = res(keep);
+templates = templates(keep);
+clu = round(clu);
+res = round(res);
+templates = round(templates);
\ No newline at end of file
diff --git a/spikeSortingPipeline.m b/spikeSortingPipeline.m
new file mode 100644
index 0000000..e30f613
--- /dev/null
+++ b/spikeSortingPipeline.m
@@ -0,0 +1,573 @@
+function spikeSortingPipeline(inp)
+% A function for re-ordering channels and running kilosort.
+% spikeSortingPipeline(inp)
+% Input: A structure with the following fields:
+%   io.dataFiles - raw data file names. If multiple, have to be in a
+%     cell array.
+%   io.procFolders - kilosort processing folders for storing
+%     temp_wh.dat files. If you are stitching input data files
+%     together, only a single processing folder is needed. Otherwise,
+%     the number of processing folders should correspond to the number
+%     of input files. If you do not supply processing folder paths,
+%     the output folders will be used as processing folders (the last
+%     one in case you are stitching). The intention should be to place
+%     processing folders on a local SSD drive, to increase the
+%     efficiency of the kilosort algorithm.
+%   io.outputFolders - output folders for saving data files with
+%     channels re-ordered and/or deleted and possibly medians
+%     subtracted depending on the tasks. Each input file has to be
+%     matched by an output folder. If you are merging files, there
+%     also has to be an extra output folder specified for stitched
+%     data. Other output files produced by automated Kilosort spike
+%     sorting and electrode drift plots will also be saved in
+%     corresponding output folders. Folders where Kilosort will run
+%     should be free from results from any previous Kilosort runs,
+%     which might create problems.
+% io.deleteDataFiles - logical that if true, the files in the +% dataFiles input cell array will be deleted; false by default. +% conf.probe - a string specifying the probe type. Currently +% supported probes are Neuropixels, A32-A1x32-Edge-5mm-20-177, +% A32-A1x32-5mm-25-177, A32-Buzsaki32-5mm-BUZ-200-160, +% A32-A1x32-Poly3-5mm-25s-177, A32-A1x32-Poly3-10mm-50-177, +% A64-Buzsaki64-5mm-BUZ-200-160, A64-A4x4-tet-5mm-150-200-121, +% CM16LP-A2x2-tet-3mm-150-150-121, CM16LP-A4x4-3mm-100-125-177, +% CM16LP-A1x16-Poly2-5mm-50s-177, CM16-A1x16-5mm-25-177, +% CM32-A32-Poly2-5mm-50s-177, CM32-A32-Poly3-5mm-25s-177, +% CM32-A1x32-6mm-100-177, CM32-A1x32-Edge-5mm-100-177, +% H32-A1x32-Edge-5mm-20-177, and H32-Buzsaki32-5mm-BUZ-200-160. +% conf.probeFlip - a logical that is true if the probe/adaptor was +% connected to the headstage upside-down during the recording +% session (the labels on the headstage and probe connectors facing +% opposite sides); default is false. +% conf.headstage - a string specifying the type of headstage used in +% combination with the probe. Supported headstages are +% RHD2132_16ch, RHD2132_32ch, RHD2164_top, RHD2164_bottom, +% RHD2164, and Neuropixels. +% conf.nChans - a cell array with the first element being an EEG +% data channel configuration vector indicating which channels from +% the original file are contained within the current data file. If +% full original file is used, then the vector simply corresponds +% to the original channels (1:end). The second element in the +% array corresponds to the number of extra input channels that are +% not electrode recordings. If the cell array is left empty, the +% default number of EEG recording sites will be assumed based on +% the probe configuration (conf.probe). If only a single element +% is supplied, it will be assumed to correspond to the EEG +% channels only. +% conf.samplingFrequency (default: 30000). 
+%   conf.tempFact is the multiplication factor used to determine the
+%     number of spike sorting templates to be used by kilosort. The
+%     number of templates will be the multiple of conf.tempFact and
+%     the number of recording channels. Default is 6.
+%   tasks.reorder - a logical that is true for re-ordering channels;
+%     false by default.
+%   tasks.subtractMedian - a logical that is true for subtracting the
+%     median recording trace; false by default
+%   tasks.stitchFiles - a logical that if true, input data files are
+%     stitched together; false by default. Note that if you are
+%     merging files, there has to be an extra output folder specified
+%     for stitched data.
+%   tasks.chanMap - a logical that if true, creates a channel map file
+%     (forPRB...); false by default.
+%   tasks.runKS - a scalar indicating whether to run automated
+%     kilosort and which version. Available options are:
+%       1 - kilosort 1;
+%       2 - kilosort 2;
+%       0 - don't run kilosort (default).
+%   tasks.driftPlot - a logical that if true, displays and saves
+%     electrode drift plots (relevant only if tasks.runKS was true);
+%     false by default.
+%   tasks.deleteChans - a vector with channels to be deleted
+%     (specified according to the original order).
+
+
+
+%% Test user input
+if ~isfield(inp, 'io') || isempty(inp.io)
+    errMsg = 'data file input and output folders not specified. An example of setting up function input is at the bottom of the file';
+    error(['spikeSortingPipeline: ' errMsg])
+else
+    io = inp.io;
+end
+
+if ~isfield(inp, 'conf') || isempty(inp.conf)
+    errMsg = 'probe and headstage names not specified. An example of setting up function input is at the bottom of the file';
+    error(['spikeSortingPipeline: ' errMsg])
+else
+    conf = inp.conf;
+end
+
+if ~isfield(inp, 'tasks') || isempty(inp.tasks)
+    errMsg = 'no tasks specified. An example of setting up function input is at the bottom of the file';
+    error(['spikeSortingPipeline: ' errMsg])
+else
+    tasks = inp.tasks;
+    if ~isfield(tasks, 'reorder') && ~isfield(tasks, 'subtractMedian') && ~isfield(tasks, 'stitchFiles') &&...
+            ~isfield(tasks, 'chanMap') && ~isfield(tasks, 'runKS') && ~isfield(tasks, 'driftPlot') && ~isfield(tasks, 'deleteChans')
+        error('spikeSortingPipeline: no tasks specified')
+    end
+end
+
+
+if ~isfield(io, 'dataFiles') || isempty(io.dataFiles)
+    error('spikeSortingPipeline: io.dataFiles input not provided')
+else
+    dataFiles = io.dataFiles;
+    if ~iscell(dataFiles)
+        dataFiles = {dataFiles};
+    end
+    for iFile = 1:numel(dataFiles) % go through the cell-wrapped list, so that a single char file name is accepted too
+        [~, ~, fileExt] = fileparts(dataFiles{iFile});
+        if ~any(strcmpi(fileExt, {'.dat', '.bin'}))
+            error('spikeSortingPipeline: incorrect data file format. Other than dat and bin file formats are not accepted')
+        end
+    end
+end
+
+if ~isfield(io, 'outputFolders') || isempty(io.outputFolders)
+    error('spikeSortingPipeline: io.outputFolders input not provided')
+else
+    outputFolders = io.outputFolders;
+    if ~iscell(outputFolders)
+        outputFolders = {outputFolders};
+    end
+end
+% When stitching, the last output folder is the extra one for the stitched data and is left out of the comparison.
+nCheck = numel(outputFolders) - (isfield(tasks, 'stitchFiles') && ~isempty(tasks.stitchFiles) && any(tasks.stitchFiles));
+for i1 = 2:nCheck
+    assert(~any(strcmpi(outputFolders{i1}, outputFolders(1:i1-1))), ...
+        'spikeSortingPipeline: using the same output folder several times creates risk of files being overwritten')
+end
+clear i1 nCheck
+
+if ~isfield(io, 'procFolders') || isempty(io.procFolders)
+    procFolders = outputFolders;
+else
+    procFolders = io.procFolders;
+    if ~iscell(procFolders)
+        procFolders = {procFolders};
+    end
+end
+
+if ~isfield(io, 'deleteDataFiles') || isempty(io.deleteDataFiles)
+    deleteDataFiles = false;
+else
+    deleteDataFiles = io.deleteDataFiles;
+end
+
+
+if ~isfield(conf, 'probe') || isempty(conf.probe)
+    error('spikeSortingPipeline: conf.probe input not provided')
+else
+    probe = conf.probe;
+    if ~strcmpi(probe,'A32-A1x32-Edge-5mm-20-177') && ~strcmpi(probe,'A32-A1x32-5mm-25-177') &&...
+            ~strcmpi(probe,'A32-Buzsaki32-5mm-BUZ-200-160') && ~strcmpi(probe,'A32-A1x32-Poly3-5mm-25s-177') &&...
+            ~strcmpi(probe,'A32-A1x32-Poly3-10mm-50-177') && ~strcmpi(probe,'A64-Buzsaki64-5mm-BUZ-200-160') &&...
+            ~strcmpi(probe,'A64-A4x4-tet-5mm-150-200-121') && ~strcmpi(probe,'CM16LP-A2x2-tet-3mm-150-150-121') &&...
+            ~strcmpi(probe,'CM16LP-A4x4-3mm-100-125-177') && ~strcmpi(probe,'CM16LP-A1x16-Poly2-5mm-50s-177') &&...
+            ~strcmpi(probe,'CM16-A1x16-5mm-25-177') && ~strcmpi(probe,'CM32-A32-Poly2-5mm-50s-177') &&...
+            ~strcmpi(probe,'CM32-A32-Poly3-5mm-25s-177') && ~strcmpi(probe,'CM32-A1x32-6mm-100-177') &&...
+            ~strcmpi(probe,'CM32-A1x32-Edge-5mm-100-177') && ~strcmpi(probe,'H32-A1x32-Edge-5mm-20-177') &&...
+            ~strcmpi(probe,'H32-Buzsaki32-5mm-BUZ-200-160') && ~strcmpi(probe,'Neuropixels')
+        errMsg = ['probe ' probe ' is not supported. Currently supported probes are A32-A1x32-Edge-5mm-20-177, '...
+            'A32-A1x32-5mm-25-177, A32-Buzsaki32-5mm-BUZ-200-160, A32-A1x32-Poly3-5mm-25s-177, A32-A1x32-Poly3-10mm-50-177, '...
+            'A64-Buzsaki64-5mm-BUZ-200-160, A64-A4x4-tet-5mm-150-200-121, CM16LP-A2x2-tet-3mm-150-150-121, '...
+            'CM16LP-A4x4-3mm-100-125-177, CM16LP-A1x16-Poly2-5mm-50s-177, CM16-A1x16-5mm-25-177, CM32-A1x32-6mm-100-177, '...
+            'CM32-A32-Poly2-5mm-50s-177, CM32-A32-Poly3-5mm-25s-177, CM32-A1x32-Edge-5mm-100-177, H32-A1x32-Edge-5mm-20-177, '...
+            'H32-Buzsaki32-5mm-BUZ-200-160, and Neuropixels'];
+        error(['spikeSortingPipeline: ' errMsg])
+    end
+end
+
+if ~isfield(conf, 'probeFlip') || isempty(conf.probeFlip)
+    probeFlip = false;
+else
+    probeFlip = conf.probeFlip;
+end
+
+if ~isfield(conf, 'headstage') || isempty(conf.headstage)
+    if strcmpi(probe(1:4), 'CM16')
+        headstage = 'RHD2132_16ch';
+    elseif strcmpi(probe(1:3), 'A32')
+        headstage = 'RHD2164_top';
+    elseif strcmpi(probe(1:3), 'A64')
+        headstage = 'RHD2164';
+    elseif strcmpi(probe, 'Neuropixels')
+        headstage = 'Neuropixels';
+    else
+        error('spikeSortingPipeline: conf.headstage input not provided')
+    end
+else
+    headstage = conf.headstage;
+    if ~strcmpi(headstage,'RHD2132_16ch') && ~strcmpi(headstage,'RHD2132_32ch') && ~strcmpi(headstage,'RHD2164_top') &&...
+            ~strcmpi(headstage,'RHD2164_bottom') && ~strcmpi(headstage,'RHD2164') && ~strcmpi(headstage,'Neuropixels')
+        errMsg = ['headstage ' headstage ' is not supported. Currently supported headstages are RHD2132_16ch, RHD2132_32ch, '...
+            'RHD2164_top, RHD2164_bottom, RHD2164, and Neuropixels'];
+        error(['spikeSortingPipeline: ' errMsg])
+    end
+    if (strcmp(probe,'Neuropixels') && ~strcmp(headstage,'Neuropixels')) ||...
+            (strcmpi(probe(1:3),'A64') && ~strcmp(headstage,'RHD2164')) ||...
+            (strcmpi(probe(1:3),'A32') && (~strcmp(headstage,'RHD2164_top') && ~strcmp(headstage,'RHD2164_bottom') && ~strcmp(headstage,'RHD2164'))) ||...
+            ((strcmpi(probe(1:3),'H32') || strcmpi(probe(1:4),'CM32')) &&...
+            (~strcmp(headstage,'RHD2132_32ch') && ~strcmp(headstage,'RHD2164_top') && ~strcmp(headstage,'RHD2164_bottom'))) ||...
+            (strcmpi(probe(1:4), 'CM16') && ~strcmp(headstage,'RHD2132_16ch'))
+        error('spikeSortingPipeline: your probe and headstage are incompatible')
+    end
+end
+
+if ~isfield(conf, 'nChans') || isempty(conf.nChans)
+    chansIgnore = false;
+    nChans = [];
+else
+    chansIgnore = true;
+    nChans = conf.nChans;
+end
+
+if ~isfield(conf, 'samplingFrequency') || isempty(conf.samplingFrequency)
+    samplingFrequency = 30000;
+else
+    samplingFrequency = conf.samplingFrequency;
+end
+
+if ~isfield(conf, 'tempFact') || isempty(conf.tempFact)
+    tempFact = 6;
+else
+    tempFact = conf.tempFact;
+end
+
+
+if ~isfield(tasks, 'reorder') || isempty(tasks.reorder)
+    reorder = false;
+else
+    reorder = tasks.reorder;
+end
+
+if ~isfield(tasks, 'subtractMedian') || isempty(tasks.subtractMedian)
+    subtractMedian = false;
+else
+    subtractMedian = tasks.subtractMedian;
+end
+
+if ~isfield(tasks, 'stitchFiles') || isempty(tasks.stitchFiles)
+    stitchFiles = false;
+else
+    stitchFiles = tasks.stitchFiles;
+end
+if stitchFiles
+    procFolders = procFolders(end); % a single processing folder is used when stitching: the last one
+end
+
+if ~isfield(tasks, 'chanMap') || isempty(tasks.chanMap)
+    chanMapFile = false;
+else
+    chanMapFile = tasks.chanMap;
+end
+
+if ~isfield(tasks, 'runKS') || isempty(tasks.runKS)
+    runKS = 0;
+else
+    runKS = tasks.runKS;
+end
+
+if ~isfield(tasks, 'driftPlot') || isempty(tasks.driftPlot)
+    drift = false;
+else
+    drift = tasks.driftPlot;
+end
+% From here on only the normalised flag is used: testing tasks.driftPlot
+% directly fails inside && when the field is present but empty.
+if drift && isempty(which('ksDriftmap'))
+    addpath(genpath('R:\CSN\Shared\Dynamics\Code\github_cortex-lab_spikes'))
+    rmpath(genpath('R:\CSN\Shared\Dynamics\Code\github_cortex-lab_spikes\.git'))
+end
+if drift
+    assert(isempty(strfind(which('findpeaks'), 'chronux')), 'remove chronux from path, otherwise matlab''s own findpeaks function inaccessible') %#ok
+end
+
+if ~isfield(tasks, 'deleteChans') || isempty(tasks.deleteChans)
+    deleteChans = [];
+else
+    deleteChans = tasks.deleteChans;
+end
+
+if ~reorder && ~subtractMedian && ~stitchFiles && ~chanMapFile && ~runKS && ~drift && isempty(deleteChans)
+    error('spikeSortingPipeline: no tasks specified')
+end
+
+
+%% Re-order channels and run kilosort
+try
+    reset(gpuDevice);
+catch
+    %do nothing, keep going
+end
+% Remember the original inputs: dataFiles is overwritten with swapCAR output below.
+dataFilesToDelete = dataFiles;
+if stitchFiles
+    channelMedian = [];
+    medianTrace = [];
+    for iFile = 1:numel(dataFiles)
+        if ~exist(outputFolders{iFile},'dir')
+            mkdir(outputFolders{iFile});
+        end
+        if chansIgnore
+% [channelMedianFile, medianTraceFile, ~, ops.chanMap, dataFilesFull{iFile}, dataFiles{iFile}, ops.NchanTOT, swapOrder,...
+% probe2headstageConf] = swapCAR(dataFiles{iFile}, probe, probeFlip, headstage, reorder, subtractMedian, chanMapFile,...
+% outputFolders{iFile}, deleteChans, nChans); %#ok
+            [channelMedianFile, medianTraceFile, ops.NchanTOT, ops.chanMap, dataFiles{iFile}, ~, ~, swapOrder, probe2headstageConf] = swapCAR(...
+                dataFiles{iFile}, probe, probeFlip, headstage, reorder, subtractMedian, chanMapFile, outputFolders{iFile}, deleteChans, nChans); %#ok<*ASGLU,*AGROW>
+% connected = zeros(1,nChans);
+% connected(1:ops.NchanTOT) = ones(1,ops.NchanTOT); %#ok<*NASGU>
+        else
+            [channelMedianFile, medianTraceFile, ~, ops.chanMap, dataFiles{iFile}, ~, ops.NchanTOT, swapOrder, probe2headstageConf] = swapCAR(...
+                dataFiles{iFile}, probe, probeFlip, headstage, reorder, subtractMedian, chanMapFile, outputFolders{iFile}, deleteChans);
+        end
+        channelMedian = [channelMedian channelMedianFile];
+        medianTrace = [medianTrace medianTraceFile];
+    end
+    % The folder after the per-file ones receives the stitched recording.
+    if ~exist(outputFolders{iFile+1},'dir')
+        mkdir(outputFolders{iFile+1});
+    end
+    stitchedDataFile = [outputFolders{iFile+1} filesep 'stitchedData.dat'];
+    save([stitchedDataFile(1:end-4) '_medianTrace.mat'], 'channelMedian', 'medianTrace', 'swapOrder', 'probe2headstageConf', '-v7.3');
+    % Load the channel map variables of the last file; they are re-saved below when tasks.chanMap is set.
+    load(ops.chanMap);
+    if chanMapFile
+        probeConfFile = [outputFolders{iFile+1} filesep 'forPRB_' probe2headstageConf.probeConf.probe '.mat'];
+        save(probeConfFile, 'chanMap', 'chanMap0ind', 'connected', 'shankInd', 'xcoords', 'ycoords', '-v7.3'); %#ok<*USENS>
+    else
+        % Re-use an existing channel map. Keep its folder in the name, so that
+        % loading it does not depend on the current folder or the MATLAB path.
+        probeConfFile = dir([outputFolders{iFile+1} filesep 'forPRB*.mat']);
+        assert(~isempty(probeConfFile), 'spikeSortingPipeline: no forPRB*.mat file found in %s', outputFolders{iFile+1})
+        probeConfFile = fullfile(outputFolders{iFile+1}, probeConfFile(1).name);
+    end
+    ops.chanMap = probeConfFile;
+    fix_dat_stitch(dataFiles, ops.NchanTOT, samplingFrequency, stitchedDataFile);
+    if runKS
+        addKilosortToPath(runKS);
+
+        load(ops.chanMap);
+        if runKS == 1
+            [ops.Nfilt, ops.root, ops.fproc, ops.fbinary] = initKS(ops.NchanTOT, outputFolders{iFile+1},...
+                procFolders{1}, stitchedDataFile, true, tempFact);
+            % Do not leave the temporary file behind if sorting fails.
+            try
+                ks_master_file;
+            catch me
+                if exist(ops.fproc, 'file')
+                    delete(ops.fproc);
+                end
+                disp(getReport(me))
+                throw(me);
+            end
+        elseif runKS == 2
+            [~, ops.root, ops.fproc, ops.fbinary] = initKS(ops.NchanTOT, outputFolders{iFile+1},...
+                procFolders{1}, stitchedDataFile, true, tempFact);
+            ops.nfilt_factor = tempFact;
+            % Do not leave the temporary file behind if sorting fails.
+            try
+                ks_master_file2;
+            catch me
+                if exist(ops.fproc, 'file')
+                    delete(ops.fproc);
+                end
+                disp(getReport(me))
+                throw(me);
+            end
+        end
+        fclose('all');
+    end
+else
+    for iFile = 1:numel(dataFiles) %#ok<*UNRCH>
+        if ~exist(outputFolders{iFile},'dir')
+            mkdir(outputFolders{iFile});
+        end
+        if chansIgnore
+% [~, ~, ~, ops.chanMap, dataFilesFull{iFile}, dataFiles{iFile}, ops.NchanTOT] = swapCAR(dataFiles{iFile}, probe,...
+% probeFlip, headstage, reorder, subtractMedian, chanMapFile, outputFolders{iFile}, deleteChans, nChans); %#ok
+            [~, ~, ops.NchanTOT, ops.chanMap, dataFiles{iFile}] = swapCAR(dataFiles{iFile}, probe, probeFlip, headstage, reorder,...
+                subtractMedian, chanMapFile, outputFolders{iFile}, deleteChans, nChans);
+% connected = zeros(1,nChans);
+% connected(1:ops.NchanTOT) = ones(1,ops.NchanTOT);
+        else
+            [~, ~, ~, ops.chanMap, dataFiles{iFile}, ~, ops.NchanTOT] = swapCAR(dataFiles{iFile}, probe, probeFlip, headstage, reorder,...
+                subtractMedian, chanMapFile, outputFolders{iFile}, deleteChans);
+        end
+        if runKS
+            addKilosortToPath(runKS);
+
+            if ~chanMapFile
+                % Re-use an existing channel map next to the data file. Keep its folder in the
+                % name, so that loading it does not depend on the current folder or the MATLAB path.
+                mapFolder = fileparts(dataFiles{iFile});
+                mapFile = dir([mapFolder filesep 'forPRB*.mat']);
+                assert(~isempty(mapFile), 'spikeSortingPipeline: no forPRB*.mat file found in %s', mapFolder)
+                ops.chanMap = fullfile(mapFolder, mapFile(1).name);
+            end
+            load(ops.chanMap);
+            if runKS == 1
+                [ops.Nfilt, ops.root, ops.fproc, ops.fbinary] = initKS(ops.NchanTOT, outputFolders{iFile},...
+                    procFolders{iFile}, dataFiles{iFile}, false, tempFact);
+                % Do not leave the temporary file behind if sorting fails.
+                try
+                    ks_master_file;
+                catch me
+                    if exist(ops.fproc, 'file')
+                        delete(ops.fproc);
+                    end
+                    disp(getReport(me))
+                    throw(me);
+                end
+            elseif runKS == 2
+                [~, ops.root, ops.fproc, ops.fbinary] = initKS(ops.NchanTOT, outputFolders{iFile},...
+                    procFolders{iFile}, dataFiles{iFile}, false, tempFact);
+                ops.nfilt_factor = tempFact;
+                % Do not leave the temporary file behind if sorting fails.
+                try
+                    ks_master_file2;
+                catch me
+                    if exist(ops.fproc, 'file')
+                        delete(ops.fproc);
+                    end
+                    disp(getReport(me))
+                    throw(me);
+                end
+            end
+            fclose('all');
+        end
+    end
+end
+
+fclose('all');
+% Deleted inputs go to the recycle bin; the previous recycle setting is restored afterwards.
+recycleState = recycle('on');
+if deleteDataFiles
+    for iFile = 1:numel(dataFilesToDelete)
+        delete(dataFilesToDelete{iFile});
+    end
+end
+% if chansIgnore
+% for iFile = 1:numel(dataFiles)
+% delete(dataFiles{iFile});
+% end
+% if stitchFiles
+% delete(stitchedDataFile);
+% stitchedDataFile = [outputFolders{iFile+1} filesep 'stitchedDataFull.dat'];
+% fix_dat_stitch(dataFilesFull, nChans, samplingFrequency, stitchedDataFile);
+% for iFile = 1:numel(dataFilesFull)
+% delete(dataFilesFull{iFile});
+% end
+% end
+% % elseif stitchFiles
+% % for iFile = 1:numel(dataFiles)
+% % delete(dataFiles{iFile});
+% % end
+% end
+recycle(recycleState);
+fclose('all');
+
+
+%% Inspect electrode drift
+if drift
+    close all
+    if stitchFiles
+        [spikeTimes, spikeAmps, spikeDepths] = ksDriftmap(outputFolders{end});
+        plotDriftmap(spikeTimes, spikeAmps, spikeDepths);
+        f1 = gcf;
+        hgsave(f1, [outputFolders{end} filesep 'Drift_plot_all_spikes']);
+        close(f1);
+        plotDriftmap(spikeTimes, spikeAmps, spikeDepths, 'show');
+        f2 = gcf;
+        hgsave(f2, [outputFolders{end} filesep 'Drift_plot_large_spikes']);
+        close(f2);
+    else
+        for iFile = 1:numel(dataFiles)
+            [spikeTimes, spikeAmps, spikeDepths] = ksDriftmap(outputFolders{iFile});
+            plotDriftmap(spikeTimes, spikeAmps, spikeDepths);
+            f1 = gcf;
+            hgsave(f1, [outputFolders{iFile} filesep 'Drift_plot_all_spikes']);
+            close(f1);
+            plotDriftmap(spikeTimes, spikeAmps, spikeDepths, 'show');
+            f2 = gcf;
+            hgsave(f2, [outputFolders{iFile} filesep 'Drift_plot_large_spikes']);
+            close(f2);
+        end
+    end
+end
+return
+
+
+%% Functions
+function addKilosortToPath(ksVersion)
+% Puts the requested Kilosort version on the MATLAB path.
+% Input: ksVersion - 1 for Kilosort, 2 for Kilosort2. Any other value leaves
+%                    the path untouched.
+% The repository is added recursively, its .git and CUDA subfolders are then
+% removed again, and finally only the CUDA subfolder selected by the running
+% MATLAB release is added back. For MATLAB 9.8 and later no CUDA subfolder
+% is added.
+
+if ksVersion == 1
+    ksRoot = 'R:\Neuropix\Shared\Code\github_cortex-lab_KiloSort';
+elseif ksVersion == 2
+    ksRoot = 'R:\Neuropix\Shared\Code\github_cortex-lab_KiloSort2';
+else
+    return
+end
+addpath(genpath(ksRoot))
+rmpath(genpath([ksRoot '\.git']))
+rmpath(genpath([ksRoot '\CUDA']))
+if verLessThan('matlab', '9.4') % add CUDA8 mex file directory
+    addpath([ksRoot '\CUDA\CUDA8'])
+elseif verLessThan('matlab', '9.6') % add CUDA9 mex file directory
+    addpath([ksRoot '\CUDA\CUDA9'])
+elseif verLessThan('matlab', '9.7') % add CUDA10.0 mex file directory
+    addpath([ksRoot '\CUDA\CUDA10'])
+elseif verLessThan('matlab', '9.8') % add CUDA10.1 mex file directory
+    addpath([ksRoot '\CUDA\CUDA101'])
+end
+return
+
+function [Nfilt, root, fproc, fbinary] = initKS(Nchan, root, proc, dataFile, stitchFiles, tempFact)
+% Derives the Kilosort template count and file locations.
+% Inputs: Nchan - channel count (the callers pass ops.NchanTOT).
+%         root - Kilosort output folder (returned unchanged).
+%         proc - folder that will hold the temporary temp_wh.dat file.
+%         dataFile - binary data file to be sorted.
+%         stitchFiles - kept for the callers; the result does not depend on it.
+%         tempFact - multiplication factor for the number of templates.
+% Outputs: Nfilt - tempFact*(Nchan-1) rounded up to a multiple of 32.
+%          root - same as the input.
+%          fproc - temp_wh.dat inside proc.
+%          fbinary - same as dataFile.
+
+Nfilt = ceil(tempFact*(Nchan-1)/32)*32; % number of filters to use, should be divisible by 32
+fproc = fullfile(proc, 'temp_wh.dat');
+fbinary = dataFile;
+return
+
+%% Example:
+
+
+
+inp.io.dataFiles = {'R:\Neuropix\Shared\Data\M191018_MD\original_data\TCB-2_g0_t0.imec0.ap.bin'};
+inp.io.procFolders = {'D:\'};
+inp.io.outputFolders = {'R:\Neuropix\Shared\Data\M191018_MD'};
+inp.io.deleteDataFiles = false;
+inp.conf.probe = 'Neuropixels';
+inp.conf.probeFlip = false;
+inp.conf.headstage = 'Neuropixels';
+inp.conf.nChans = {1:384; 1};
+inp.conf.samplingFrequency = 30000;
+inp.conf.tempFact = 3;
+inp.tasks.reorder = false;
+inp.tasks.subtractMedian = false;
+inp.tasks.stitchFiles = false;
+inp.tasks.chanMap = true;
+inp.tasks.runKS = 1;
+inp.tasks.driftPlot = true;
+inp.tasks.deleteChans = [];
+
+spikeSortingPipeline(inp);