From c6e752ddf8612fe4d450a86e83c94aac165ce7d7 Mon Sep 17 00:00:00 2001 From: Ayushi Daksh <37770155+AyushiDaksh@users.noreply.github.com> Date: Fri, 30 Jun 2023 07:05:36 -0700 Subject: [PATCH] Make MontecarloTransport and montecarlo_numba track the last line interaction shell ids for real and virtual packets (#2344) * Add in/out last line interaction shell ids for real and virtual packets * Update mailmap * Do not handle last_line_interaction_in_shell_id and last_line_interaction_out_shell_id separately --- docs/io/output/vpacket_logging.rst | 5 ++++ tardis/analysis.py | 15 ++++++++++++ tardis/io/model_reader.py | 4 ++++ tardis/io/tests/test_model_reader.py | 4 ++++ tardis/montecarlo/base.py | 2 ++ tardis/montecarlo/montecarlo_numba/base.py | 23 ++++++++++++++++++- .../montecarlo_numba/interaction.py | 1 + .../montecarlo_numba/numba_interface.py | 13 +++++++++++ .../montecarlo/montecarlo_numba/r_packet.py | 2 ++ .../tests/test_numba_interface.py | 11 ++++++++- tardis/montecarlo/montecarlo_numba/vpacket.py | 2 +- tardis/tests/test_tardis_full.py | 1 + tardis/transport/r_packet_transport.py | 3 +-- 13 files changed, 81 insertions(+), 5 deletions(-) diff --git a/docs/io/output/vpacket_logging.rst b/docs/io/output/vpacket_logging.rst index 5b91370beb9..a5874e18987 100644 --- a/docs/io/output/vpacket_logging.rst +++ b/docs/io/output/vpacket_logging.rst @@ -51,4 +51,9 @@ After running the simulation, the following information can be retrieved: - Numpy array - | If the last interaction was a line interaction, the | line_interaction_out_id for that interaction + | (see :doc:`physical_quantities`) + * - ``transport.virt_packet_last_line_interaction_shell_id`` + - Numpy array + - | If the last interaction was a line interaction, the + | line_interaction_shell_id for that interaction | (see :doc:`physical_quantities`) \ No newline at end of file diff --git a/tardis/analysis.py b/tardis/analysis.py index 649938a686f..987ee5fe0c0 100644 --- a/tardis/analysis.py +++ b/tardis/analysis.py @@ -54,6 +54,7 @@ def __init__( self._wavelength_end = np.inf * u.angstrom self._atomic_number = None self._ion_number = None + self._shell = None self.packet_filter_mode = packet_filter_mode self.update_last_interaction_filter() @@ -97,6 +98,15 @@ def ion_number(self, value): self._ion_number = value self.update_last_interaction_filter() + @property + def shell(self): + return self._shell + + @shell.setter + def shell(self, value): + self._shell = value + self.update_last_interaction_filter() + def update_last_interaction_filter(self): if self.packet_filter_mode == "packet_out_nu": packet_filter = ( @@ -122,6 +132,11 @@ def update_last_interaction_filter(self): "allowed are: packet_out_nu, packet_in_nu, line_in_nu" ) + if self.shell is not None: + packet_filter = packet_filter & ( + self.last_line_interaction_shell_id == self.shell + ) + self.last_line_in = self.lines.iloc[ self.last_line_interaction_in_id[packet_filter] ] diff --git a/tardis/io/model_reader.py b/tardis/io/model_reader.py index 94c360c9f71..d5536a1dc35 100644 --- a/tardis/io/model_reader.py +++ b/tardis/io/model_reader.py @@ -600,6 +600,7 @@ def transport_to_dict(transport): "virt_packet_last_interaction_type": transport.virt_packet_last_interaction_type, "virt_packet_last_line_interaction_in_id": transport.virt_packet_last_line_interaction_in_id, "virt_packet_last_line_interaction_out_id": transport.virt_packet_last_line_interaction_out_id, + "virt_packet_last_line_interaction_shell_id": transport.virt_packet_last_line_interaction_shell_id, "virt_packet_nus": transport.virt_packet_nus, "volume_cgs": transport.volume, } @@ -798,6 +799,9 @@ def transport_from_hdf(fname): new_transport.virt_packet_last_line_interaction_out_id = d[ "virt_packet_last_line_interaction_out_id" ] + new_transport.virt_packet_last_line_interaction_shell_id = d[ + "virt_packet_last_line_interaction_shell_id" + ] new_transport.virt_packet_nus = d["virt_packet_nus"] new_transport.volume = d["volume_cgs"] diff --git a/tardis/io/tests/test_model_reader.py b/tardis/io/tests/test_model_reader.py index 3967cbb7821..fe1af51f6a2 100644 --- a/tardis/io/tests/test_model_reader.py +++ b/tardis/io/tests/test_model_reader.py @@ -436,6 +436,10 @@ def test_store_transport_to_hdf(simulation_verysimple, tmp_path): f["transport/virt_packet_last_line_interaction_out_id"], transport_data["virt_packet_last_line_interaction_out_id"], ) + assert np.array_equal( + f["transport/virt_packet_last_line_interaction_shell_id"], + transport_data["virt_packet_last_line_interaction_shell_id"], + ) assert np.array_equal( f["transport/virt_packet_nus"], transport_data["virt_packet_nus"] ) diff --git a/tardis/montecarlo/base.py b/tardis/montecarlo/base.py index b2e800cdd98..b116d73f265 100644 --- a/tardis/montecarlo/base.py +++ b/tardis/montecarlo/base.py @@ -65,6 +65,7 @@ class MontecarloTransport(HDFWriterMixin): "virt_packet_last_interaction_type", "virt_packet_last_line_interaction_in_id", "virt_packet_last_line_interaction_out_id", + "virt_packet_last_line_interaction_shell_id", ] hdf_name = "transport" @@ -120,6 +121,7 @@ def __init__( self.virt_packet_last_interaction_in_nu = np.ones(2) * -1.0 self.virt_packet_last_line_interaction_in_id = np.ones(2) * -1 self.virt_packet_last_line_interaction_out_id = np.ones(2) * -1 + self.virt_packet_last_line_interaction_shell_id = np.ones(2) * -1 self.virt_packet_nus = np.ones(2) * -1.0 self.virt_packet_energies = np.ones(2) * -1.0 self.virt_packet_initial_rs = np.ones(2) * -1.0 diff --git a/tardis/montecarlo/montecarlo_numba/base.py b/tardis/montecarlo/montecarlo_numba/base.py index 1681075dd54..bf9fd3b847a 100644 --- a/tardis/montecarlo/montecarlo_numba/base.py +++ b/tardis/montecarlo/montecarlo_numba/base.py @@ -80,6 +80,7 @@ def montecarlo_radial1d( last_interaction_in_nu, last_line_interaction_in_id, last_line_interaction_out_id, + last_line_interaction_shell_id, virt_packet_nus, virt_packet_energies, virt_packet_initial_mus, @@ -88,6 +89,7 @@ def montecarlo_radial1d( virt_packet_last_interaction_type, virt_packet_last_line_interaction_in_id, virt_packet_last_line_interaction_out_id, + virt_packet_last_line_interaction_shell_id, rpacket_trackers, ) = montecarlo_main_loop( packet_collection, @@ -109,6 +111,7 @@ def montecarlo_radial1d( transport.last_interaction_in_nu = last_interaction_in_nu transport.last_line_interaction_in_id = last_line_interaction_in_id transport.last_line_interaction_out_id = last_line_interaction_out_id + transport.last_line_interaction_shell_id = last_line_interaction_shell_id if montecarlo_configuration.VPACKET_LOGGING and number_of_vpackets > 0: transport.virt_packet_nus = np.concatenate(virt_packet_nus).ravel() @@ -133,6 +136,9 @@ def montecarlo_radial1d( transport.virt_packet_last_line_interaction_out_id = np.concatenate( virt_packet_last_line_interaction_out_id ).ravel() + transport.virt_packet_last_line_interaction_shell_id = np.concatenate( + virt_packet_last_line_interaction_shell_id + ).ravel() update_iterations_pbar(1) refresh_packet_pbar() # Condition for Checking if RPacket Tracking is enabled @@ -188,6 +194,9 @@ def montecarlo_main_loop( last_line_interaction_out_ids = ( np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1 ) + last_line_interaction_shell_ids = ( + np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1 + ) v_packets_energy_hist = np.zeros_like(spectrum_frequency) delta_nu = spectrum_frequency[1] - spectrum_frequency[0] @@ -239,10 +248,10 @@ def montecarlo_main_loop( virt_packet_last_interaction_type = [] virt_packet_last_line_interaction_in_id = [] virt_packet_last_line_interaction_out_id = [] + virt_packet_last_line_interaction_shell_id = [] for i in prange(len(output_nus)): tid = get_thread_id() if show_progress_bars: - if tid == main_thread_id: with objmode: update_amount = 1 * n_threads @@ -281,6 +290,9 @@ def montecarlo_main_loop( last_interaction_in_nus[i] = r_packet.last_interaction_in_nu last_line_interaction_in_ids[i] = r_packet.last_line_interaction_in_id last_line_interaction_out_ids[i] = r_packet.last_line_interaction_out_id + last_line_interaction_shell_ids[ + i + ] = r_packet.last_line_interaction_shell_id if r_packet.status == PacketStatus.REABSORBED: output_energies[i] = -r_packet.energy @@ -360,6 +372,13 @@ def montecarlo_main_loop( ] ) ) + virt_packet_last_line_interaction_shell_id.append( + np.ascontiguousarray( + vpacket_collection.last_interaction_shell_id[ + : vpacket_collection.idx + ] + ) + ) if montecarlo_configuration.RPACKET_TRACKING: for rpacket_tracker in rpacket_trackers: @@ -373,6 +392,7 @@ def montecarlo_main_loop( last_interaction_in_nus, last_line_interaction_in_ids, last_line_interaction_out_ids, + last_line_interaction_shell_ids, virt_packet_nus, virt_packet_energies, virt_packet_initial_mus, @@ -381,5 +401,6 @@ def montecarlo_main_loop( virt_packet_last_interaction_type, virt_packet_last_line_interaction_in_id, virt_packet_last_line_interaction_out_id, + virt_packet_last_line_interaction_shell_id, rpacket_trackers, ) diff --git a/tardis/montecarlo/montecarlo_numba/interaction.py b/tardis/montecarlo/montecarlo_numba/interaction.py index d56464556c8..aaf7225ab2b 100644 --- a/tardis/montecarlo/montecarlo_numba/interaction.py +++ b/tardis/montecarlo/montecarlo_numba/interaction.py @@ -447,6 +447,7 @@ def line_emission(r_packet, emission_line_id, time_explosion, numba_plasma): """ r_packet.last_line_interaction_out_id = emission_line_id + r_packet.last_line_interaction_shell_id = r_packet.current_shell_id if emission_line_id != r_packet.next_line_id: pass diff --git a/tardis/montecarlo/montecarlo_numba/numba_interface.py b/tardis/montecarlo/montecarlo_numba/numba_interface.py index d248fc3df01..ee12a3ac4fe 100644 --- a/tardis/montecarlo/montecarlo_numba/numba_interface.py +++ b/tardis/montecarlo/montecarlo_numba/numba_interface.py @@ -310,6 +310,7 @@ def __init__( ("last_interaction_type", int64[:]), ("last_interaction_in_id", int64[:]), ("last_interaction_out_id", int64[:]), + ("last_interaction_shell_id", int64[:]), ] @@ -344,6 +345,9 @@ def __init__( self.last_interaction_out_id = -1 * np.ones( temporary_v_packet_bins, dtype=np.int64 ) + self.last_interaction_shell_id = -1 * np.ones( + temporary_v_packet_bins, dtype=np.int64 + ) self.idx = 0 self.rpacket_index = rpacket_index self.length = temporary_v_packet_bins @@ -358,6 +362,7 @@ def set_properties( last_interaction_type, last_interaction_in_id, last_interaction_out_id, + last_interaction_shell_id, ): if self.idx >= self.length: temp_length = self.length * 2 + self.number_of_vpackets @@ -371,6 +376,9 @@ def set_properties( temp_last_interaction_type = np.empty(temp_length, dtype=np.int64) temp_last_interaction_in_id = np.empty(temp_length, dtype=np.int64) temp_last_interaction_out_id = np.empty(temp_length, dtype=np.int64) + temp_last_interaction_shell_id = np.empty( + temp_length, dtype=np.int64 + ) temp_nus[: self.length] = self.nus temp_energies[: self.length] = self.energies @@ -388,6 +396,9 @@ def set_properties( temp_last_interaction_out_id[ : self.length ] = self.last_interaction_out_id + temp_last_interaction_shell_id[ + : self.length + ] = self.last_interaction_shell_id self.nus = temp_nus self.energies = temp_energies @@ -397,6 +408,7 @@ def set_properties( self.last_interaction_type = temp_last_interaction_type self.last_interaction_in_id = temp_last_interaction_in_id self.last_interaction_out_id = temp_last_interaction_out_id + self.last_interaction_shell_id = temp_last_interaction_shell_id self.length = temp_length self.nus[self.idx] = nu @@ -407,6 +419,7 @@ def set_properties( self.last_interaction_type[self.idx] = last_interaction_type self.last_interaction_in_id[self.idx] = last_interaction_in_id self.last_interaction_out_id[self.idx] = last_interaction_out_id + self.last_interaction_shell_id[self.idx] = last_interaction_shell_id self.idx += 1 diff --git a/tardis/montecarlo/montecarlo_numba/r_packet.py b/tardis/montecarlo/montecarlo_numba/r_packet.py index 88f985cf927..719e9193a14 100644 --- a/tardis/montecarlo/montecarlo_numba/r_packet.py +++ b/tardis/montecarlo/montecarlo_numba/r_packet.py @@ -43,6 +43,7 @@ class PacketStatus(IntEnum): ("last_interaction_in_nu", float64), ("last_line_interaction_in_id", int64), ("last_line_interaction_out_id", int64), + ("last_line_interaction_shell_id", int64), ] @@ -61,6 +62,7 @@ def __init__(self, r, mu, nu, energy, seed, index=0): self.last_interaction_in_nu = 0.0 self.last_line_interaction_in_id = -1 self.last_line_interaction_out_id = -1 + self.last_line_interaction_shell_id = -1 def initialize_line_id(self, numba_plasma, numba_model): inverse_line_list_nu = numba_plasma.line_list_nu[::-1] diff --git a/tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py b/tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py index 207aa1a60a8..232866600ff 100644 --- a/tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py +++ b/tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py @@ -60,7 +60,6 @@ def test_configuration_initialize(): def test_VPacketCollection_set_properties(verysimple_3vpacket_collection): - assert verysimple_3vpacket_collection.length == 0 nus = [3.0e15, 0.0, 1e15, 1e5] @@ -73,6 +72,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection): last_interaction_types = np.array([1, 1, 3, 2], dtype=np.int64) last_interaction_in_ids = np.array([100, 0, 1, 1000], dtype=np.int64) last_interaction_out_ids = np.array([1201, 123, 545, 1232], dtype=np.int64) + last_interaction_shell_ids = np.array([2, -1, 6, 0], dtype=np.int64) for ( nu, @@ -83,6 +83,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection): last_interaction_type, last_interaction_in_id, last_interaction_out_id, + last_interaction_shell_id, ) in zip( nus, energies, @@ -92,6 +93,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection): last_interaction_types, last_interaction_in_ids, last_interaction_out_ids, + last_interaction_shell_ids, ): verysimple_3vpacket_collection.set_properties( nu, @@ -102,6 +104,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection): last_interaction_type, last_interaction_in_id, last_interaction_out_id, + last_interaction_shell_id, ) npt.assert_array_equal( @@ -152,4 +155,10 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection): ], last_interaction_out_ids, ) + npt.assert_array_equal( + verysimple_3vpacket_collection.last_interaction_shell_id[ + : verysimple_3vpacket_collection.idx + ], + last_interaction_shell_ids, + ) assert verysimple_3vpacket_collection.length == 9 diff --git a/tardis/montecarlo/montecarlo_numba/vpacket.py b/tardis/montecarlo/montecarlo_numba/vpacket.py index 10a137400dd..69276009273 100644 --- a/tardis/montecarlo/montecarlo_numba/vpacket.py +++ b/tardis/montecarlo/montecarlo_numba/vpacket.py @@ -242,7 +242,6 @@ def trace_vpacket_volley( if (r_packet.nu < vpacket_collection.v_packet_spawn_start_frequency) or ( r_packet.nu > vpacket_collection.v_packet_spawn_end_frequency ): - return no_of_vpackets = vpacket_collection.number_of_vpackets @@ -335,4 +334,5 @@ def trace_vpacket_volley( r_packet.last_interaction_type, r_packet.last_line_interaction_in_id, r_packet.last_line_interaction_out_id, + r_packet.last_line_interaction_shell_id, ) diff --git a/tardis/tests/test_tardis_full.py b/tardis/tests/test_tardis_full.py index d477a4e1840..bfea4b35c6e 100644 --- a/tardis/tests/test_tardis_full.py +++ b/tardis/tests/test_tardis_full.py @@ -100,6 +100,7 @@ def test_transport_properties(self, transport): ("virt_packet_last_interaction_type", virt_type), ("virt_packet_last_line_interaction_in_id", virt_type), ("virt_packet_last_line_interaction_out_id", virt_type), + ("virt_packet_last_line_interaction_shell_id", virt_type), ("virt_packet_last_interaction_in_nu", virt_type), ("virt_packet_nus", virt_type), ("virt_packet_energies", virt_type), diff --git a/tardis/transport/r_packet_transport.py b/tardis/transport/r_packet_transport.py index 7f81de251d5..b050b43989b 100644 --- a/tardis/transport/r_packet_transport.py +++ b/tardis/transport/r_packet_transport.py @@ -74,7 +74,6 @@ def trace_packet( # - do not remove last_line_id = len(numba_plasma.line_list_nu) - 1 for cur_line_id in range(start_line_id, len(numba_plasma.line_list_nu)): - # Going through the lines nu_line = numba_plasma.line_list_nu[cur_line_id] @@ -107,7 +106,6 @@ def trace_packet( distance = min(distance_trace, distance_boundary, distance_continuum) if distance_trace != 0: - if distance == distance_boundary: interaction_type = InteractionType.BOUNDARY # BOUNDARY r_packet.next_line_id = cur_line_id @@ -143,6 +141,7 @@ def trace_packet( interaction_type = InteractionType.LINE # Line r_packet.last_interaction_in_nu = r_packet.nu r_packet.last_line_interaction_in_id = cur_line_id + r_packet.last_line_interaction_shell_id = r_packet.current_shell_id r_packet.next_line_id = cur_line_id distance = distance_trace break