
Accessing DataPipe state with MultiProcessingReadingService #1033

Open · jhoareau opened this issue Feb 20, 2023 · 9 comments · May be fixed by #1039

jhoareau commented Feb 20, 2023

Hi TorchData team,

I'm wondering how to access the state of the datapipe in a multi-processing context with DataLoader2 + MultiProcessingReadingService. When no reading service is used, we can simply access the graph via dataloader.datapipe and then easily read the state of the datapipe using the code shown below.

However, in the multi-processing case, the datapipe graph is replaced with QueueWrapper instances, and I cannot find any way to communicate with the workers to get access to the state of the datapipe (I get an error that my StatefulIterator cannot be found on the datapipe). If I access dl2._datapipe_before_reading_service_adapt, I get only the initial state, which makes sense since there is no state sync between the main and worker processes.

As far as I understand, this will also be a blocker for state capturing for proper DataLoader checkpointing when the MultiProcessingReadingService is being used.

Could we potentially add a getstate communication primitive in communication.messages, in order to capture the state (via getstate) of a datapipe living in a worker process?
We're also open to using sharding_round_robin_dispatch in order to keep more information in the main process, but I'm a bit confused about how to use it. Do you have some sample code for the following case?

Running against today's master (commit a3b34a0):

```python
import torchdata.datapipes as dp
from torch.utils.data.graph_settings import get_all_graph_pipes, traverse_dps
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService


class StatefulIterator(dp.iter.IterDataPipe):
    def __init__(self, datapipe):
        self.datapipe = datapipe
        self.custom_index = 0

    def __iter__(self):
        self.custom_index = 0
        for item in self.datapipe:
            self.custom_index += 1
            yield item
        self.custom_index = 0


def get_datapipe():
    initial_data = dp.iter.IterableWrapper([1, 2, 3, 4])
    stateful_data = StatefulIterator(initial_data)
    sharded_data = stateful_data.sharding_filter()
    return sharded_data


def get_datapipe_state(datapipe):
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    for pipe in all_pipes:
        if hasattr(pipe, "custom_index"):
            return pipe.custom_index

    raise ValueError("This datapipe does not contain a StatefulIterator.")


def main_no_multiprocessing():
    pipe = get_datapipe()  # renamed from `dp` to avoid shadowing the module alias above
    dl2 = DataLoader2(pipe)
    for item in dl2:
        print("Custom index", get_datapipe_state(dl2.datapipe))
        print("Item", item)


def main_multiprocessing():
    pipe = get_datapipe()
    dl2 = DataLoader2(pipe, reading_service=MultiProcessingReadingService(num_workers=4))
    for item in dl2:
        print("Custom index", get_datapipe_state(dl2.datapipe))
        print("Item", item)


if __name__ == "__main__":
    main_no_multiprocessing()
    main_multiprocessing()
```

cc: @ejguan @VitalyFedyunin @NivekT

ejguan (Contributor) commented Feb 21, 2023

> I'm wondering how to access the state of the datapipe in the multi-processing context with DataLoader2 + MultiProcessingReadingService. When using no reading service, we can simply access the graph using dataloader.datapipe, then I can easily access the state of my datapipe using the code shown below.

When MP gets involved, the partial DataPipe graph is sent to the worker processes, so there is no reference to that partial graph from the main process. QueueWrapper is what connects a worker process to the main process.

> As far as I understand, this will also be a blocker for state capturing for proper DataLoader checkpointing when the MultiProcessingReadingService is being used.

Yes, it is, and we are working on a solution for it. We probably want to add a new request type like https://github.com/pytorch/data/blob/a3b34a00e7d2b6694ea0d5e21fcc084080a3abae/torchdata/dataloader2/communication/messages.py#LL89C7-L89C21 to pass a state request to the worker process and have the worker process send back the state of its graph.

Out of curiosity, do you have any specific use case for accessing datapipe state beyond checkpointing?
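For illustration, a request/response pair along those lines might look like the sketch below, using plain threads and queues to stand in for torchdata's worker protocol. The names GetStateRequest and GetStateResponse are hypothetical, not the actual torchdata API:

```python
import pickle
import queue
import threading

# Hypothetical message types, modeled loosely on the existing
# request/response classes in torchdata.dataloader2.communication.messages.
class GetStateRequest:
    pass

class GetStateResponse:
    def __init__(self, serialized_state):
        self.serialized_state = serialized_state

def worker_loop(req_q, res_q, datapipe_state):
    # Stand-in for a worker process: answers one state request over its queue.
    msg = req_q.get()
    if isinstance(msg, GetStateRequest):
        res_q.put(GetStateResponse(pickle.dumps(datapipe_state)))

req_q, res_q = queue.Queue(), queue.Queue()
t = threading.Thread(target=worker_loop, args=(req_q, res_q, {"custom_index": 7}))
t.start()

req_q.put(GetStateRequest())          # "main process" asks for the state
state = pickle.loads(res_q.get().serialized_state)
t.join()
```

In the real multi-processing setting the queues would be inter-process queues and the state would come from the worker's datapipe graph, but the request/response shape is the same.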

jhoareau (Author) commented Feb 22, 2023

> we probably want to add a new request like https://github.com/pytorch/data/blob/a3b34a00e7d2b6694ea0d5e21fcc084080a3abae/torchdata/dataloader2/communication/messages.py#LL89C7-L89C21 to pass the request for state to worker process and let worker process send back the state of graph.

This is what I had envisioned as well. Glad to hear it's being worked on.
Would you accept a PR adding this functionality?

Our specific use case is a data-loading progress bar, but instead of counting after sharding, we want to count batch sizes before sharding. That's because training can run on multiple ranks and we want to avoid multi-rank synchronisation, so we want to see where the rank-0 datapipe currently is pre-sharding.

Our datapipe is like so:
FileOpener -> LineReader -> Map (tokenization) -> MaxTokenBucketizer -> Shard -> Collate
We want to measure the total size of batches produced by MaxTokenBucketizer pre-sharding.

A potential workaround is to return this size alongside each batch via an extra Map before Shard, but we'd prefer not to.
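For concreteness, the extra-Map workaround mentioned above might look like the following sketch; the attach_size helper is hypothetical, and plain lists stand in for tokenized batches produced by MaxTokenBucketizer:

```python
# Hypothetical sketch of the "extra Map before Shard" workaround:
# each batch is tagged with its pre-sharding size so the consumer
# can sum sizes for a progress bar without querying worker state.

def attach_size(batch):
    """Wrap a batch with its size; `batch` stands in for a list of tokenized samples."""
    return {"size": len(batch), "batch": batch}

batches = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]  # toy stand-ins for bucketized batches
tagged = [attach_size(b) for b in batches]
total_seen = sum(t["size"] for t in tagged)  # running count for the progress bar
```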

jhoareau (Author) commented:

FYI, I have started working on a PR that adds that functionality via the snapshot function of the ReadingService, as a PoC. I hope it will fit well with your plans for the feature.

ejguan (Contributor) commented Feb 22, 2023

> Would you accept a PR adding this functionality?

cc: @NivekT as the point of contact for snapshot/checkpoint.
From my perspective, you can open a PR as an RFC; Kevin can discuss it on the PR, since he has a working solution right now, and we can see whether the solutions are aligned.

> Our specific use case is for a data loading progress bar, but instead of counting after sharding, we want to count batch sizes before sharding (that's because we can have training on multiple ranks, and we want to avoid multi-rank synchronisation, so we want to see where the rank 0 datapipe is currently pre-sharding).

Do you mean batch sizes or number of batches?

@jhoareau jhoareau linked a pull request Feb 22, 2023 that will close this issue
jhoareau (Author) commented:

I mean summed batch sizes. I've now created a PR as an RFC.

NivekT (Contributor) commented Feb 23, 2023

@jhoareau I responded here. Let me know if I missed something about your use case. Thanks for opening the issue and PR!

jhoareau (Author) commented:

Hi @NivekT, thanks for the detailed reply. I'll keep the conversation about state checkpointing in the PR, and will focus on the specific problem I'm trying to solve in this issue.

The documentation is quite vague on how to use sharding_round_robin_dispatch, and I've gotten odd results with it (4x the amount of data with 4 workers). Would you have any example code showing how to replace sharding_filter with it?

NivekT (Contributor) commented Feb 24, 2023

@jhoareau Can you tell us more about the setup where you are seeing duplicate data (what is the data pipeline)?

For example, here is a multiprocessing example (run with the nightly version; imports added here for completeness, and the loop variable renamed to avoid shadowing):

```python
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES  # import path may vary by version
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

def _fn(x):
    # Placeholder map function (the original snippet assumes `_fn` is defined elsewhere).
    return x

dp1 = IterableWrapper(range(10)).sharding_filter().map(_fn)
dp2 = IterableWrapper(range(10)).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(_fn)

for pipe in [dp1, dp2]:
    rs = MultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(pipe, reading_service=rs)
    print(list(dl))  # [0, 1, ..., 9] in both cases
```

If you are using DistributedReadingService, then you will want to place a .sharding_filter() prior to .sharding_round_robin_dispatch() in order to divide up the work among nodes first.
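To illustrate why omitting a sharding step duplicates data, here is a torchdata-free sketch: sharding_filter behaves roughly like per-worker round-robin element selection, so the workers' outputs partition the data, while an unsharded pipeline is replayed wholesale in every worker (hence "4x the data with 4 workers"):

```python
def shard(data, worker_id, num_workers):
    # Approximates sharding_filter: each worker keeps every num_workers-th element.
    return data[worker_id::num_workers]

data = list(range(10))
num_workers = 4

# With a sharding step, the union of the workers' outputs covers the data exactly once.
sharded = [shard(data, w, num_workers) for w in range(num_workers)]
combined = sorted(x for part in sharded for x in part)

# Without one, every worker replays the full pipeline: num_workers copies of the data.
unsharded = [list(data) for _ in range(num_workers)]
duplicated = [x for part in unsharded for x in part]
```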

Let us know if this is unclear.

jhoareau (Author) commented Mar 1, 2023

Hi @NivekT, it works with the sharding filter placed before the sharding round robin; indeed, we're running multiprocessing + distributed. Thanks for the pointer. However, I needed to monkey-patch round_robin_demux to set its buffer size to -1 (unlimited) for our use case: we collect 50k samples before building batches, so the default buffer size of 1000 does not work for us.
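A monkey-patch like the one described above can follow the generic wrap-and-override pattern sketched below. The module and function here (demo_module, round_robin_demux with a buffer_size keyword) are stand-ins, since the thread does not show the real torchdata module path; this only illustrates the technique:

```python
import types

def override_default_kwarg(module, func_name, **forced_defaults):
    """Replace module.func_name with a wrapper that injects new default kwargs.

    Explicit caller-supplied kwargs still win over the forced defaults.
    """
    original = getattr(module, func_name)

    def wrapper(*args, **kwargs):
        return original(*args, **{**forced_defaults, **kwargs})

    setattr(module, func_name, wrapper)
    return original  # returned so the patch can be undone later

# Hypothetical stand-in for the module holding round_robin_demux.
demo_module = types.ModuleType("demo_module")

def round_robin_demux(pipe, num_instances, buffer_size=1000):
    return (pipe, num_instances, buffer_size)

demo_module.round_robin_demux = round_robin_demux

override_default_kwarg(demo_module, "round_robin_demux", buffer_size=-1)
result = demo_module.round_robin_demux("pipe", 4)
# result[2] is now -1 (unlimited buffer) unless the caller overrides it
```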

I still see value in extracting state from the underlying datapipes with the MultiProcessingReadingService, so I'll leave my PR up in the hope that we can also discuss that separately.
