Question

I would like to know if using a DataLoader connected to a MongoDB is a sensible thing to do and how this could be implemented.

Background

I have about 20 million documents sitting in a (local) MongoDB, way more than fit in memory. I would like to train a deep neural net on the data. So far, I have been exporting the data to the file system first, with subfolders named after the classes of the documents. But I find this approach nonsensical. Why export first (and later delete) if the data is already well maintained, sitting in a DB?

Question 1:

Am I right? Would it make sense to directly connect to the MongoDB? Or are there reasons not to do it (e.g. DBs generally being too slow etc.)? If DBs are too slow (why?), can one prefetch the data somehow?

Question 2:

How would one implement a PyTorch DataLoader? I have found only a few code snippets online ([1] and [2]), which makes me doubt my approach.

Code snippet

The general way I access MongoDB is shown below. Nothing special about it, I think.

import pymongo
from pymongo import MongoClient

myclient = pymongo.MongoClient("mongodb://localhost:27017/")
mydb = myclient["xyz"]
mycol = mydb["xyz_documents"]

query = {
    # some filters
}

results = mycol.find(query)

# results is now a cursor that can run through all docs
# Assume, for the sake of this example, that each doc contains a class name and some image that I want to train a classifier on

Answers

Introduction

This one is a little open-ended, but let's try; also, please correct me if I'm wrong somewhere.

So far, I have been exporting the data to the file system first, with subfolders named after the classes of the documents.

IMO this isn't sensible because:

  • you are essentially duplicating data
  • any time you would like to train anew, given only the code and the database, this export would have to be repeated
  • you can access multiple datapoints at once and cache them in RAM for later reuse, without reading from the hard drive multiple times (which is quite heavy)

Am I right? Would it make sense to directly connect to the MongoDB?

Given the above, probably yes (especially when it comes to a clear and portable implementation).

Or are there reasons not to do it (e.g. DBs generally being too slow etc.)?

AFAIK the DB shouldn't be slower in this case, as it caches data access, but I'm no DB expert unfortunately. Many tricks for faster access are implemented out of the box for databases.
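
For example, a minimal sketch of one such out-of-the-box trick, assuming your documents have a class_name field (a guess about your schema, adjust to yours): an index on the fields you filter or sort by lets MongoDB answer such queries without scanning the whole collection.

import pymongo

myclient = pymongo.MongoClient("mongodb://localhost:27017/")
mycol = myclient["xyz"]["xyz_documents"]

# "class_name" is an assumed field name; adjust it to your actual schema
mycol.create_index([("class_name", pymongo.ASCENDING)])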

can one prefetch the data somehow?

Yes. If you just want to get data, you could load a larger chunk of data (say 1024 records) in one go and return batches from that (say batch_size=128).
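
As an illustration, a minimal sketch of such prefetching, where the helper name and the chunk/batch sizes are just assumptions:

import itertools

# Hypothetical helper: pull chunk_size documents from a cursor in one go,
# keep them in RAM and yield batch_size-sized batches from that chunk
def prefetched_batches(cursor, chunk_size=1024, batch_size=128):
    while True:
        chunk = list(itertools.islice(cursor, chunk_size))
        if not chunk:
            break
        for i in range(0, len(chunk), batch_size):
            yield chunk[i:i + batch_size]

# e.g. for batch in prefetched_batches(mycol.find(query)): ...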

Implementation

How would one implement a PyTorch DataLoader? I have found only a few code snippets online ([1] and [2]), which makes me doubt my approach.

I'm not sure why you would want to do that. What you should go for is torch.utils.data.Dataset, as shown in the examples you've listed.

I would start with a simple, non-optimized approach similar to the one here, so:

  • open the connection to the db in __init__ and keep it as long as it's used (I would turn torch.utils.data.Dataset into a context manager so the connection is closed after the epochs are finished)
  • I would not transform the results to a list (especially since you cannot fit it in RAM for obvious reasons), as it misses the point of generators
  • I would perform batching inside this Dataset (there is a batch_size argument here).
  • I am not sure about the __getitem__ function, but it seems it can return multiple datapoints at once, hence I'd use that; it should also allow us to use num_workers>0 (given that mycol.find(query) returns data in the same order every time)

Given that, something along those lines is what I'd do:

import math

import pymongo
import torch.utils.data


class DatabaseDataset(torch.utils.data.Dataset):
    def __init__(self, query, batch_size, path: str, database: str, collection: str):
        self.batch_size = batch_size

        self.client = pymongo.MongoClient(path)
        # find() lives on a collection, so keep a handle to it
        self.collection = self.client[database][collection]
        self.query = query
        # Or the exact count_documents(query); if the approximate method
        # returns a smaller number of items you should be fine
        self.length = self.collection.estimated_document_count()

    def __enter__(self):
        return self

    def shuffle(self):
        # Find a way to shuffle the data so it is returned in a different order
        # If that happens out of the box you might be fine without it actually
        pass

    def __exit__(self, *_, **__):
        # Or anything else how to close the connection
        self.client.close()

    def __len__(self):
        # Length in batches, since batching is done inside the dataset
        return math.ceil(self.length / self.batch_size)

    def __getitem__(self, index):
        # Reads take long, hence loading a batch of documents at once should speed things up
        # Ensure that this find returns the documents in the same order every time
        # If not, you might get duplicated data
        # It is rather unlikely (depending on batch size), shouldn't be a problem
        # for 20 million samples anyway
        examples = list(
            self.collection.find(self.query)
            .skip(index * self.batch_size)
            .limit(self.batch_size)
        )
        # Do something with this data
        ...
        # Return the whole batch
        return data, labels

Now batching is taken care of by DatabaseDataset, hence torch.utils.data.DataLoader can have batch_size=1. You might need to squeeze the additional dimension.
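
A minimal sketch of what that squeeze could look like, assuming a dataloader built as in the usage example further below and __getitem__ returning tensors shaped (batch_size, ...):

for data, labels in dataloader:
    # DataLoader with batch_size=1 adds a leading dimension of size 1,
    # so tensors arrive as (1, batch_size, ...); drop it before the forward pass
    data, labels = data.squeeze(0), labels.squeeze(0)
    ...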

As MongoDB uses locks (which is no surprise, but see here), num_workers>0 shouldn't be a problem.

Possible usage (schematically):

with DatabaseDataset(...) as dataset:
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
    for epoch in range(epochs):
        for batch in dataloader:
            # And all the stuff
            ...
        dataset.shuffle()  # after each epoch

Remember about the shuffling implementation in such a case! (Shuffling can also be done inside the context manager, and you might then want to close the connection manually, or something along those lines.)
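
As a hedged sketch of one way such a shuffle could look, assuming the DatabaseDataset above (the subclass name is hypothetical, and note it shuffles at batch granularity only, not individual documents):

import numpy as np

class ShuffledDatabaseDataset(DatabaseDataset):
    # Hypothetical variant: map each requested batch index to a random one,
    # with the mapping re-drawn once per epoch via shuffle()
    def shuffle(self):
        self.permutation = np.random.permutation(len(self))

    def __getitem__(self, index):
        permutation = getattr(self, "permutation", None)
        if permutation is not None:
            index = int(permutation[index])
        return super().__getitem__(index)

Usage is the same as above, just with ShuffledDatabaseDataset(...) in the with statement.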
