diff --git a/src/dynamic_routing_analysis/decoding_utils.py b/src/dynamic_routing_analysis/decoding_utils.py index e2bc2c9..9f787f1 100644 --- a/src/dynamic_routing_analysis/decoding_utils.py +++ b/src/dynamic_routing_analysis/decoding_utils.py @@ -1096,7 +1096,7 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units= # print(f'time elapsed: {time.time()-start_time}') -def concat_decoder_results(files,savepath=None,return_table=True): +def concat_decoder_results(files,savepath=None,return_table=True,single_session=False): use_half_shifts=False n_repeats=25 @@ -1131,6 +1131,9 @@ def concat_decoder_results(files,savepath=None,return_table=True): #loop through sessions + if single_session: + if type(files) is not list: + files=[files] for file in files: try: decoder_results=pickle.load(open(file,'rb')) @@ -1251,7 +1254,10 @@ def concat_decoder_results(files,savepath=None,return_table=True): linear_shift_df=pd.DataFrame(linear_shift_dict) if savepath is not None: try: - linear_shift_df.to_csv(os.path.join(savepath,'all_unit_linear_shift_use_more_trials.csv')) + if single_session: + linear_shift_df.to_csv(os.path.join(savepath,session_id+'_linear_shift_decoding_results.csv')) + else: + linear_shift_df.to_csv(os.path.join(savepath,'all_linear_shift_decoding_results.csv')) except Exception as e: print(e) print('error saving linear shift df') @@ -1374,7 +1380,7 @@ def compute_significant_decoding_by_area(all_decoder_results): return all_frac_sig_df,all_diff_from_null_df -def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_units=None): +def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_units=None,single_session=False): #load sessions as we go @@ -1477,6 +1483,10 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un start_time=time.time() ##loop through sessions## + if single_session: + if type(files) is not list: + files=[files] + for file in files: try: session_start_time=time.time() @@ -1876,21 +1886,26 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un if not os.path.exists(savepath): os.makedirs(savepath) if n_units is not None: - n_units_str=str(n_units)+'_units' + n_units_str='_'+str(n_units)+'_units' else: n_units_str='' - decoder_confidence_versus_response_type.to_csv(os.path.join(savepath,'decoder_confidence_versus_response_type'+n_units_str+'.csv'),index=False) - decoder_confidence_dprime_by_block.to_csv(os.path.join(savepath,'decoder_confidence_dprime_by_block'+n_units_str+'.csv'),index=False) - decoder_confidence_by_switch.to_csv(os.path.join(savepath,'decoder_confidence_by_switch'+n_units_str+'.csv'),index=False) - decoder_confidence_versus_trials_since_rewarded_target.to_csv(os.path.join(savepath,'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.csv'),index=False) - decoder_confidence_before_after_target.to_csv(os.path.join(savepath,'decoder_confidence_before_after_target'+n_units_str+'.csv'),index=False) - - decoder_confidence_versus_response_type.to_pickle(os.path.join(savepath,'decoder_confidence_versus_response_type'+n_units_str+'.pkl')) - decoder_confidence_dprime_by_block.to_pickle(os.path.join(savepath,'decoder_confidence_dprime_by_block'+n_units_str+'.pkl')) - decoder_confidence_by_switch.to_pickle(os.path.join(savepath,'decoder_confidence_by_switch'+n_units_str+'.pkl')) - decoder_confidence_versus_trials_since_rewarded_target.to_pickle(os.path.join(savepath,'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.pkl')) - decoder_confidence_before_after_target.to_pickle(os.path.join(savepath,'decoder_confidence_before_after_target'+n_units_str+'.pkl')) + if single_session: + temp_session_str=session_id_str+'_' + else: + temp_session_str='' + + decoder_confidence_versus_response_type.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.csv'),index=False) + decoder_confidence_dprime_by_block.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.csv'),index=False) + decoder_confidence_by_switch.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.csv'),index=False) + decoder_confidence_versus_trials_since_rewarded_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.csv'),index=False) + decoder_confidence_before_after_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.csv'),index=False) + + decoder_confidence_versus_response_type.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.pkl')) + decoder_confidence_dprime_by_block.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.pkl')) + decoder_confidence_by_switch.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.pkl')) + decoder_confidence_versus_trials_since_rewarded_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.pkl')) + decoder_confidence_before_after_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.pkl')) if return_table: return decoder_confidence_versus_response_type,decoder_confidence_dprime_by_block,decoder_confidence_by_switch,decoder_confidence_versus_trials_since_rewarded_target,decoder_confidence_before_after_target