diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 918520ec..df8aaae3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -203,6 +203,18 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( frames = allocateEmptyHWCTensor(height, width, options.device, numFrames); } +bool VideoDecoder::SwsContextKey::operator==( + const VideoDecoder::SwsContextKey& other) const { + return decodedWidth == other.decodedWidth && decodedHeight == other.decodedHeight && + decodedFormat == other.decodedFormat && + outputWidth == other.outputWidth && outputHeight == other.outputHeight; +} + +bool VideoDecoder::SwsContextKey::operator!=( + const VideoDecoder::SwsContextKey& other) const { + return !(*this == other); +} + VideoDecoder::VideoDecoder() {} void VideoDecoder::initializeDecoder() { @@ -1340,7 +1352,11 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale( int expectedOutputHeight = outputTensor.sizes()[0]; int expectedOutputWidth = outputTensor.sizes()[1]; auto curFrameSwsContextKey = SwsContextKey{ - frame->width, frame->height, frameFormat, expectedOutputWidth, expectedOutputHeight}; + frame->width, + frame->height, + frameFormat, + expectedOutputWidth, + expectedOutputHeight}; if (activeStream.swsContext.get() == nullptr || activeStream.swsContextKey != curFrameSwsContextKey) { SwsContext* swsContext = sws_getContext( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index aefeb1fc..8ae7cc17 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -323,8 +323,8 @@ class VideoDecoder { AVPixelFormat decodedFormat; int outputWidth; int outputHeight; - bool operator==(const SwsContextKey&) const = default; - bool operator!=(const SwsContextKey&) const = default; + bool operator==(const SwsContextKey&) const; + bool operator!=(const SwsContextKey&) const; }; // Stores 
information for each stream. struct StreamInfo { diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 19ca1a20..4ad6fed1 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -11,7 +11,14 @@ from torchcodec.decoders import _core, VideoDecoder -from ..utils import assert_tensor_close, assert_tensor_equal, H265_VIDEO, NASA_VIDEO +from ..utils import ( + assert_tensor_close, + assert_tensor_equal, + cpu_and_cuda, + get_frame_compare_function, + H265_VIDEO, + NASA_VIDEO, +) class TestVideoDecoder: @@ -56,18 +63,22 @@ def test_create_fails(self): decoder = VideoDecoder(NASA_VIDEO.path, stream_index=1) # noqa @pytest.mark.parametrize("num_ffmpeg_threads", (1, 4)) - def test_getitem_int(self, num_ffmpeg_threads): - decoder = VideoDecoder(NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_getitem_int(self, num_ffmpeg_threads, device): + decoder = VideoDecoder( + NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads, device=device + ) - ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1) - ref_frame180 = NASA_VIDEO.get_frame_data_by_index(180) - ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289) + ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) + ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) + ref_frame180 = NASA_VIDEO.get_frame_data_by_index(180).to(device) + ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289).to(device) - assert_tensor_equal(ref_frame0, decoder[0]) - assert_tensor_equal(ref_frame1, decoder[1]) - assert_tensor_equal(ref_frame180, decoder[180]) - assert_tensor_equal(ref_frame_last, decoder[-1]) + frame_compare_function = get_frame_compare_function(device) + frame_compare_function(ref_frame0, decoder[0]) + frame_compare_function(ref_frame1, decoder[1]) + frame_compare_function(ref_frame180, decoder[180]) + 
frame_compare_function(ref_frame_last, decoder[-1]) def test_getitem_numpy_int(self): decoder = VideoDecoder(NASA_VIDEO.path) @@ -99,12 +110,14 @@ def test_getitem_numpy_int(self): assert_tensor_equal(ref_frame1, decoder[numpy.uint32(1)]) assert_tensor_equal(ref_frame180, decoder[numpy.uint32(180)]) - def test_getitem_slice(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_getitem_slice(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) + frame_compare_function = get_frame_compare_function(device) # ensure that the degenerate case of a range of size 1 works - ref0 = NASA_VIDEO.get_frame_data_by_range(0, 1) + ref0 = NASA_VIDEO.get_frame_data_by_range(0, 1).to(device) slice0 = decoder[0:1] assert slice0.shape == torch.Size( [ @@ -114,9 +127,9 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref0, slice0) + frame_compare_function(ref0, slice0) - ref4 = NASA_VIDEO.get_frame_data_by_range(4, 5) + ref4 = NASA_VIDEO.get_frame_data_by_range(4, 5).to(device) slice4 = decoder[4:5] assert slice4.shape == torch.Size( [ @@ -126,9 +139,9 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref4, slice4) + frame_compare_function(ref4, slice4) - ref8 = NASA_VIDEO.get_frame_data_by_range(8, 9) + ref8 = NASA_VIDEO.get_frame_data_by_range(8, 9).to(device) slice8 = decoder[8:9] assert slice8.shape == torch.Size( [ @@ -138,9 +151,9 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref8, slice8) + frame_compare_function(ref8, slice8) - ref180 = NASA_VIDEO.get_frame_data_by_index(180) + ref180 = NASA_VIDEO.get_frame_data_by_index(180).to(device) slice180 = decoder[180:181] assert slice180.shape == torch.Size( [ @@ -150,10 +163,10 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref180, slice180[0]) + frame_compare_function(ref180, slice180[0]) # contiguous ranges - ref0_9 = 
NASA_VIDEO.get_frame_data_by_range(0, 9) + ref0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9).to(device) slice0_9 = decoder[0:9] assert slice0_9.shape == torch.Size( [ @@ -163,9 +176,9 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref0_9, slice0_9) + frame_compare_function(ref0_9, slice0_9) - ref4_8 = NASA_VIDEO.get_frame_data_by_range(4, 8) + ref4_8 = NASA_VIDEO.get_frame_data_by_range(4, 8).to(device) slice4_8 = decoder[4:8] assert slice4_8.shape == torch.Size( [ @@ -175,10 +188,10 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref4_8, slice4_8) + frame_compare_function(ref4_8, slice4_8) # ranges with a stride - ref15_35 = NASA_VIDEO.get_frame_data_by_range(15, 36, 5) + ref15_35 = NASA_VIDEO.get_frame_data_by_range(15, 36, 5).to(device) slice15_35 = decoder[15:36:5] assert slice15_35.shape == torch.Size( [ @@ -188,9 +201,9 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref15_35, slice15_35) + frame_compare_function(ref15_35, slice15_35) - ref0_9_2 = NASA_VIDEO.get_frame_data_by_range(0, 9, 2) + ref0_9_2 = NASA_VIDEO.get_frame_data_by_range(0, 9, 2).to(device) slice0_9_2 = decoder[0:9:2] assert slice0_9_2.shape == torch.Size( [ @@ -200,10 +213,10 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref0_9_2, slice0_9_2) + frame_compare_function(ref0_9_2, slice0_9_2) # negative numbers in the slice - ref386_389 = NASA_VIDEO.get_frame_data_by_range(386, 390) + ref386_389 = NASA_VIDEO.get_frame_data_by_range(386, 390).to(device) slice386_389 = decoder[-4:] assert slice386_389.shape == torch.Size( [ @@ -213,18 +226,18 @@ def test_getitem_slice(self): NASA_VIDEO.width, ] ) - assert_tensor_equal(ref386_389, slice386_389) + frame_compare_function(ref386_389, slice386_389) # an empty range is valid! 
empty_frame = decoder[5:5] - assert_tensor_equal(empty_frame, NASA_VIDEO.empty_chw_tensor) + frame_compare_function(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) # slices that are out-of-range are also valid - they return an empty tensor also_empty = decoder[10000:] - assert_tensor_equal(also_empty, NASA_VIDEO.empty_chw_tensor) + frame_compare_function(also_empty, NASA_VIDEO.empty_chw_tensor.to(device)) # should be just a copy - all_frames = decoder[:] + all_frames = decoder[:].to(device) assert all_frames.shape == torch.Size( [ len(decoder), @@ -234,10 +247,11 @@ def test_getitem_slice(self): ] ) for sliced, ref in zip(all_frames, decoder): - assert_tensor_equal(sliced, ref) + frame_compare_function(sliced, ref) - def test_getitem_fails(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_getitem_fails(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) with pytest.raises(IndexError, match="out of bounds"): frame = decoder[1000] # noqa @@ -251,35 +265,37 @@ def test_getitem_fails(self): with pytest.raises(TypeError, match="Unsupported key type"): frame = decoder[2.3] # noqa - def test_iteration(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_iteration(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) - ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1) - ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9) - ref_frame35 = NASA_VIDEO.get_frame_data_by_index(35) - ref_frame180 = NASA_VIDEO.get_frame_data_by_index(180) - ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289) + ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) + ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) + ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device) + ref_frame35 = NASA_VIDEO.get_frame_data_by_index(35).to(device) + ref_frame180 = 
NASA_VIDEO.get_frame_data_by_index(180).to(device) + ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289).to(device) + frame_compare_function = get_frame_compare_function(device) # Access an arbitrary frame to make sure that the later iteration # still works as expected. The underlying C++ decoder object is # actually stateful, and accessing a frame will move its internal # cursor. - assert_tensor_equal(ref_frame35, decoder[35]) + frame_compare_function(ref_frame35, decoder[35]) for i, frame in enumerate(decoder): if i == 0: - assert_tensor_equal(ref_frame0, frame) + frame_compare_function(ref_frame0, frame) elif i == 1: - assert_tensor_equal(ref_frame1, frame) + frame_compare_function(ref_frame1, frame) elif i == 9: - assert_tensor_equal(ref_frame9, frame) + frame_compare_function(ref_frame9, frame) elif i == 35: - assert_tensor_equal(ref_frame35, frame) + frame_compare_function(ref_frame35, frame) elif i == 180: - assert_tensor_equal(ref_frame180, frame) + frame_compare_function(ref_frame180, frame) elif i == 389: - assert_tensor_equal(ref_frame_last, frame) + frame_compare_function(ref_frame_last, frame) def test_iteration_slow(self): decoder = VideoDecoder(NASA_VIDEO.path) @@ -295,13 +311,15 @@ def test_iteration_slow(self): assert iterations == len(decoder) == 390 - def test_get_frame_at(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_at(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) + frame_compare_function = get_frame_compare_function(device) - ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9) + ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device) frame9 = decoder.get_frame_at(9) - assert_tensor_equal(ref_frame9, frame9.data) + frame_compare_function(ref_frame9, frame9.data) assert isinstance(frame9.pts_seconds, float) expected_frame_info = NASA_VIDEO.get_frame_info(9) assert frame9.pts_seconds == pytest.approx(expected_frame_info.pts_seconds) @@ 
-312,22 +330,23 @@ def test_get_frame_at(self): # test numpy.int64 frame9 = decoder.get_frame_at(numpy.int64(9)) - assert_tensor_equal(ref_frame9, frame9.data) + frame_compare_function(ref_frame9, frame9.data) # test numpy.int32 frame9 = decoder.get_frame_at(numpy.int32(9)) - assert_tensor_equal(ref_frame9, frame9.data) + frame_compare_function(ref_frame9, frame9.data) # test numpy.uint64 frame9 = decoder.get_frame_at(numpy.uint64(9)) - assert_tensor_equal(ref_frame9, frame9.data) + frame_compare_function(ref_frame9, frame9.data) # test numpy.uint32 frame9 = decoder.get_frame_at(numpy.uint32(9)) - assert_tensor_equal(ref_frame9, frame9.data) + frame_compare_function(ref_frame9, frame9.data) - def test_get_frame_at_tuple_unpacking(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_at_tuple_unpacking(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) frame = decoder.get_frame_at(50) data, pts, duration = decoder.get_frame_at(50) @@ -336,8 +355,9 @@ def test_get_frame_at_tuple_unpacking(self): assert frame.pts_seconds == pts assert frame.duration_seconds == duration - def test_get_frame_at_fails(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_at_fails(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) with pytest.raises(IndexError, match="out of bounds"): frame = decoder.get_frame_at(-1) # noqa @@ -345,16 +365,23 @@ def test_get_frame_at_fails(self): with pytest.raises(IndexError, match="out of bounds"): frame = decoder.get_frame_at(10000) # noqa - def test_get_frames_at(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frames_at(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) + frame_compare_function = get_frame_compare_function(device) frames = decoder.get_frames_at([35, 25]) assert 
isinstance(frames, FrameBatch) - assert_tensor_equal(frames[0].data, NASA_VIDEO.get_frame_data_by_index(35)) - assert_tensor_equal(frames[1].data, NASA_VIDEO.get_frame_data_by_index(25)) + frame_compare_function( + frames[0].data, NASA_VIDEO.get_frame_data_by_index(35).to(device) + ) + frame_compare_function( + frames[1].data, NASA_VIDEO.get_frame_data_by_index(25).to(device) + ) + assert frames.pts_seconds.device.type == "cpu" expected_pts_seconds = torch.tensor( [ NASA_VIDEO.get_frame_info(35).pts_seconds, @@ -366,6 +393,7 @@ def test_get_frames_at(self): frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0 ) + assert frames.duration_seconds.device.type == "cpu" expected_duration_seconds = torch.tensor( [ NASA_VIDEO.get_frame_info(35).duration_seconds, @@ -377,8 +405,9 @@ def test_get_frames_at(self): frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0 ) - def test_get_frames_at_fails(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frames_at_fails(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) with pytest.raises(RuntimeError, match="Invalid frame index=-1"): decoder.get_frames_at([-1]) @@ -389,17 +418,19 @@ def test_get_frames_at_fails(self): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_at([0.3]) - def test_get_frame_played_at(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_played_at(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) + frame_compare_function = get_frame_compare_function(device) - ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180) - assert_tensor_equal( + ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180).to(device) + frame_compare_function( ref_frame_played_at_6, decoder.get_frame_played_at(6.006).data ) - assert_tensor_equal( + frame_compare_function( 
ref_frame_played_at_6, decoder.get_frame_played_at(6.02).data ) - assert_tensor_equal( + frame_compare_function( ref_frame_played_at_6, decoder.get_frame_played_at(6.039366).data ) assert isinstance(decoder.get_frame_played_at(6.02).pts_seconds, float) @@ -407,12 +438,16 @@ def test_get_frame_played_at(self): def test_get_frame_played_at_h265(self): # Non-regression test for https://github.com/pytorch/torchcodec/issues/179 + # We don't parametrize with CUDA because the current GPUs on CI do not + # support x265: + # https://github.com/pytorch/torchcodec/pull/350#issuecomment-2465011730 decoder = VideoDecoder(H265_VIDEO.path) ref_frame6 = H265_VIDEO.get_frame_data_by_index(5) assert_tensor_equal(ref_frame6, decoder.get_frame_played_at(0.5).data) - def test_get_frame_played_at_fails(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_played_at_fails(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) with pytest.raises(IndexError, match="Invalid pts in seconds"): frame = decoder.get_frame_played_at(-1.0) # noqa @@ -420,9 +455,11 @@ def test_get_frame_played_at_fails(self): with pytest.raises(IndexError, match="Invalid pts in seconds"): frame = decoder.get_frame_played_at(100.0) # noqa - def test_get_frames_played_at(self): + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frames_played_at(self, device): - decoder = VideoDecoder(NASA_VIDEO.path) + decoder = VideoDecoder(NASA_VIDEO.path, device=device) + frame_compare_function = get_frame_compare_function(device) # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has # index 35. We use those indices as reference to test against. 
@@ -433,10 +470,12 @@ def test_get_frames_played_at(self): assert isinstance(frames, FrameBatch) for i in range(len(reference_indices)): - assert_tensor_equal( - frames.data[i], NASA_VIDEO.get_frame_data_by_index(reference_indices[i]) + frame_compare_function( + frames.data[i], + NASA_VIDEO.get_frame_data_by_index(reference_indices[i]).to(device), ) + assert frames.pts_seconds.device.type == "cpu" expected_pts_seconds = torch.tensor( [NASA_VIDEO.get_frame_info(i).pts_seconds for i in reference_indices], dtype=torch.float64, @@ -445,6 +484,7 @@ def test_get_frames_played_at(self): frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0 ) + assert frames.duration_seconds.device.type == "cpu" expected_duration_seconds = torch.tensor( [NASA_VIDEO.get_frame_info(i).duration_seconds for i in reference_indices], dtype=torch.float64, @@ -453,8 +493,9 @@ def test_get_frames_played_at(self): frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0 ) - def test_get_frames_played_at_fails(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frames_played_at_fails(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) with pytest.raises(RuntimeError, match="must be in range"): decoder.get_frames_played_at([-1]) @@ -465,21 +506,28 @@ def test_get_frames_played_at_fails(self): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_played_at(["bad"]) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("stream_index", [0, 3, None]) - def test_get_frames_in_range(self, stream_index): - decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index) + def test_get_frames_in_range(self, stream_index, device): + decoder = VideoDecoder( + NASA_VIDEO.path, stream_index=stream_index, device=device + ) + frame_compare_function = get_frame_compare_function(device) # test degenerate case where we only actually get 1 frame ref_frames9 = 
NASA_VIDEO.get_frame_data_by_range( start=9, stop=10, stream_index=stream_index - ) + ).to(device) frames9 = decoder.get_frames_in_range(start=9, stop=10) - assert_tensor_equal(ref_frames9, frames9.data) + frame_compare_function(ref_frames9, frames9.data) + + assert frames9.pts_seconds.device.type == "cpu" assert frames9.pts_seconds[0].item() == pytest.approx( NASA_VIDEO.get_frame_info(9, stream_index=stream_index).pts_seconds, rel=1e-3, ) + assert frames9.duration_seconds.device.type == "cpu" assert frames9.duration_seconds[0].item() == pytest.approx( NASA_VIDEO.get_frame_info(9, stream_index=stream_index).duration_seconds, rel=1e-3, @@ -488,7 +536,7 @@ def test_get_frames_in_range(self, stream_index): # test simple ranges ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range( start=0, stop=10, stream_index=stream_index - ) + ).to(device) frames0_9 = decoder.get_frames_in_range(start=0, stop=10) assert frames0_9.data.shape == torch.Size( [ @@ -498,7 +546,7 @@ def test_get_frames_in_range(self, stream_index): NASA_VIDEO.get_width(stream_index=stream_index), ] ) - assert_tensor_equal(ref_frames0_9, frames0_9.data) + frame_compare_function(ref_frames0_9, frames0_9.data) assert_tensor_close( NASA_VIDEO.get_pts_seconds_by_range(0, 10, stream_index=stream_index), frames0_9.pts_seconds, @@ -511,7 +559,7 @@ def test_get_frames_in_range(self, stream_index): # test steps ref_frames0_8_2 = NASA_VIDEO.get_frame_data_by_range( start=0, stop=10, step=2, stream_index=stream_index - ) + ).to(device) frames0_8_2 = decoder.get_frames_in_range(start=0, stop=10, step=2) assert frames0_8_2.data.shape == torch.Size( [ @@ -521,7 +569,7 @@ def test_get_frames_in_range(self, stream_index): NASA_VIDEO.get_width(stream_index=stream_index), ] ) - assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data) + frame_compare_function(ref_frames0_8_2, frames0_8_2.data) assert_tensor_close( NASA_VIDEO.get_pts_seconds_by_range(0, 10, 2, stream_index=stream_index), frames0_8_2.pts_seconds, @@ -537,13 
+585,13 @@ def test_get_frames_in_range(self, stream_index): frames0_8_2 = decoder.get_frames_in_range( start=numpy.int64(0), stop=numpy.int64(10), step=numpy.int64(2) ) - assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data) + frame_compare_function(ref_frames0_8_2, frames0_8_2.data) # an empty range is valid! empty_frames = decoder.get_frames_in_range(5, 5) assert_tensor_equal( empty_frames.data, - NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index), + NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index).to(device), ) assert_tensor_equal(empty_frames.pts_seconds, NASA_VIDEO.empty_pts_seconds) assert_tensor_equal( @@ -563,8 +611,11 @@ def test_get_frames_in_range(self, stream_index): lambda decoder: decoder.get_frames_played_in_range(0, 1).data, ), ) - def test_dimension_order(self, dimension_order, frame_getter): - decoder = VideoDecoder(NASA_VIDEO.path, dimension_order=dimension_order) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_dimension_order(self, dimension_order, frame_getter, device): + decoder = VideoDecoder( + NASA_VIDEO.path, dimension_order=dimension_order, device=device + ) frame = frame_getter(decoder) C, H, W = NASA_VIDEO.num_color_channels, NASA_VIDEO.height, NASA_VIDEO.width @@ -584,8 +635,12 @@ def test_dimension_order_fails(self): VideoDecoder(NASA_VIDEO.path, dimension_order="NCDHW") @pytest.mark.parametrize("stream_index", [0, 3, None]) - def test_get_frames_by_pts_in_range(self, stream_index): - decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frames_by_pts_in_range(self, stream_index, device): + decoder = VideoDecoder( + NASA_VIDEO.path, stream_index=stream_index, device=device + ) + frame_compare_function = get_frame_compare_function(device) # Note that we are comparing the results of VideoDecoder's method: # get_frames_played_in_range() @@ -608,9 +663,11 @@ def test_get_frames_by_pts_in_range(self, stream_index): frames0_4 
= decoder.get_frames_played_in_range( decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(5).pts_seconds ) - assert_tensor_equal( + frame_compare_function( frames0_4.data, - NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index), + NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index).to( + device + ), ) # Range where the stop seconds is about halfway between pts values for two frames. @@ -618,7 +675,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(4).pts_seconds + HALF_DURATION, ) - assert_tensor_equal(also_frames0_4.data, frames0_4.data) + frame_compare_function(also_frames0_4.data, frames0_4.data) # Again, the intention here is to provide the exact values we care about. In practice, our # pts values are slightly smaller, so we nudge the start upwards. @@ -626,9 +683,11 @@ def test_get_frames_by_pts_in_range(self, stream_index): decoder.get_frame_at(5).pts_seconds, decoder.get_frame_at(10).pts_seconds, ) - assert_tensor_equal( + frame_compare_function( frames5_9.data, - NASA_VIDEO.get_frame_data_by_range(5, 10, stream_index=stream_index), + NASA_VIDEO.get_frame_data_by_range(5, 10, stream_index=stream_index).to( + device + ), ) # Range where we provide start_seconds and stop_seconds that are different, but @@ -638,9 +697,11 @@ def test_get_frames_by_pts_in_range(self, stream_index): decoder.get_frame_at(6).pts_seconds, decoder.get_frame_at(6).pts_seconds + HALF_DURATION, ) - assert_tensor_equal( + frame_compare_function( frame6.data, - NASA_VIDEO.get_frame_data_by_range(6, 7, stream_index=stream_index), + NASA_VIDEO.get_frame_data_by_range(6, 7, stream_index=stream_index).to( + device + ), ) # Very small range that falls in the same frame. 
@@ -648,9 +709,11 @@ def test_get_frames_by_pts_in_range(self, stream_index): decoder.get_frame_at(35).pts_seconds, decoder.get_frame_at(35).pts_seconds + 1e-10, ) - assert_tensor_equal( + frame_compare_function( frame35.data, - NASA_VIDEO.get_frame_data_by_range(35, 36, stream_index=stream_index), + NASA_VIDEO.get_frame_data_by_range(35, 36, stream_index=stream_index).to( + device + ), ) # Single frame where the start seconds is before frame i's pts, and the stop is @@ -662,9 +725,11 @@ def test_get_frames_by_pts_in_range(self, stream_index): NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds + HALF_DURATION, ) - assert_tensor_equal( + frame_compare_function( frames7_8.data, - NASA_VIDEO.get_frame_data_by_range(7, 9, stream_index=stream_index), + NASA_VIDEO.get_frame_data_by_range(7, 9, stream_index=stream_index).to( + device + ), ) # Start and stop seconds are the same value, which should not return a frame. @@ -672,8 +737,9 @@ def test_get_frames_by_pts_in_range(self, stream_index): NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds, NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds, ) - assert_tensor_equal( - empty_frame.data, NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index) + frame_compare_function( + empty_frame.data, + NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index).to(device), ) assert_tensor_equal( empty_frame.pts_seconds, @@ -689,9 +755,11 @@ def test_get_frames_by_pts_in_range(self, stream_index): NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds + HALF_DURATION, ) - assert_tensor_equal( + frame_compare_function( frame0.data, - NASA_VIDEO.get_frame_data_by_range(0, 1, stream_index=stream_index), + NASA_VIDEO.get_frame_data_by_range(0, 1, stream_index=stream_index).to( + device + ), ) # We should be able to get all frames by giving the beginning and ending time @@ -699,10 +767,11 @@ def test_get_frames_by_pts_in_range(self, stream_index): all_frames = 
decoder.get_frames_played_in_range( decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds ) - assert_tensor_equal(all_frames.data, decoder[:]) + frame_compare_function(all_frames.data, decoder[:]) - def test_get_frames_by_pts_in_range_fails(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frames_by_pts_in_range_fails(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) with pytest.raises(ValueError, match="Invalid start seconds"): frame = decoder.get_frames_played_in_range(100.0, 1.0) # noqa