Commit

local data backend should have file locking for writes and reads, so that we do not read partial writes or corrupt multiprocess writes
bghira committed Nov 14, 2024
1 parent ca023b5 commit 78d2d07
Showing 1 changed file with 77 additions and 26 deletions.
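
The fix combines two standard techniques: readers take a shared flock so they never observe a half-written file, and writers assemble new content in a hidden temp file under an exclusive flock, then atomically rename it over the target. A minimal standalone sketch of that pattern (the function names are illustrative, not the backend's API):

import fcntl
import os

def locked_read(path: str) -> bytes:
    with open(path, "rb") as f:
        fcntl.flock(f, fcntl.LOCK_SH)  # wait out any in-progress writer
        try:
            return f.read()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)

def locked_write(path: str, payload: bytes) -> None:
    # Build the new contents beside the target so the final rename
    # stays on one filesystem.
    tmp = os.path.join(os.path.dirname(path), f".{os.path.basename(path)}.tmp")
    with open(tmp, "wb") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # serialize writers sharing the temp path
        try:
            f.write(payload)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before the swap
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)
    os.rename(tmp, path)  # atomic: readers see the old or the new file, never a mix

Note that flock is advisory: it only coordinates processes that also take locks, so the atomic rename is what protects any reader that skips locking entirely.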
103 changes: 77 additions & 26 deletions helpers/data_backend/local.py
@@ -7,6 +7,9 @@
 import torch
 from typing import Any
 from regex import regex
+import fcntl
+import tempfile
+import shutil
 
 logger = logging.getLogger("LocalDataBackend")
 logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
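
The new fcntl import above is POSIX-only, and the flock calls it enables are advisory locks: they exclude only processes that also call flock on the same file. When blocking behind a writer is unacceptable, flock can also be acquired non-blockingly; a sketch under that assumption (the helper name, retry count, and backoff are illustrative, not part of this commit):

import fcntl
import time

def try_read(filepath: str, attempts: int = 3, delay: float = 0.05) -> bytes:
    """Read under a shared lock, backing off instead of blocking."""
    for _ in range(attempts):
        with open(filepath, "rb") as file:
            try:
                # LOCK_NB makes flock raise instead of waiting
                fcntl.flock(file, fcntl.LOCK_SH | fcntl.LOCK_NB)
            except BlockingIOError:
                time.sleep(delay)  # an exclusive lock is held elsewhere; retry
                continue
            try:
                return file.read()
            finally:
                fcntl.flock(file, fcntl.LOCK_UN)
    raise TimeoutError(f"could not acquire shared lock on {filepath}")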
@@ -21,29 +24,50 @@ def __init__(self, accelerator, id: str, compress_cache: bool = False):
 
     def read(self, filepath, as_byteIO: bool = False):
         """Read and return the content of the file."""
         # Open filepath as BytesIO:
         with open(filepath, "rb") as file:
-            data = file.read()
-            if not as_byteIO:
-                return data
-            return BytesIO(data)
+            # Acquire a shared lock
+            fcntl.flock(file, fcntl.LOCK_SH)
+            try:
+                data = file.read()
+                if not as_byteIO:
+                    return data
+                return BytesIO(data)
+            finally:
+                # Release the lock
+                fcntl.flock(file, fcntl.LOCK_UN)

     def write(self, filepath: str, data: Any) -> None:
         """Write the provided data to the specified filepath."""
         os.makedirs(os.path.dirname(filepath), exist_ok=True)
-        with open(filepath, "wb") as file:
-            # Check if data is a Tensor, and if so, save it appropriately
-            if isinstance(data, torch.Tensor):
-                # logger.debug(f"Writing a torch file to disk.")
-                return self.torch_save(data, file)
-            elif isinstance(data, str):
-                # logger.debug(f"Writing a string to disk as {filepath}: {data}")
-                data = data.encode("utf-8")
-            else:
-                logger.debug(
-                    f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
-                )
-            file.write(data)
+        temp_dir = os.path.dirname(filepath)
+        temp_file_path = os.path.join(temp_dir, f".{os.path.basename(filepath)}.tmp")
+
+        # Open the temporary file for writing
+        with open(temp_file_path, "wb") as temp_file:
+            # Acquire an exclusive lock on the temporary file
+            fcntl.flock(temp_file, fcntl.LOCK_EX)
+            try:
+                # Write data to the temporary file
+                if isinstance(data, torch.Tensor):
+                    # Serialize the tensor into the locked temp file, then
+                    # fall through to the flush and atomic rename below
+                    self.torch_save(data, temp_file)
+                elif isinstance(data, str):
+                    temp_file.write(data.encode("utf-8"))
+                else:
+                    logger.debug(
+                        f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
+                    )
+                    temp_file.write(data)
+                temp_file.flush()
+                os.fsync(temp_file.fileno())
+            finally:
+                # Release the lock
+                fcntl.flock(temp_file, fcntl.LOCK_UN)
+
+        # Atomically replace the target file with the temporary file
+        os.rename(temp_file_path, filepath)
 
     def delete(self, filepath):
         """Delete the specified file."""
@@ -212,16 +236,43 @@ def torch_save(self, data, original_location):
         Save a torch tensor to a file.
         """
         if isinstance(original_location, str):
-            location = self.open_file(original_location, "wb")
-        else:
-            location = original_location
+            filepath = original_location
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+            temp_dir = os.path.dirname(filepath)
+            temp_file_path = os.path.join(temp_dir, f".{os.path.basename(filepath)}.tmp")
 
-        if self.compress_cache:
-            compressed_data = self._compress_torch(data)
-            location.write(compressed_data)
-        else:
-            torch.save(data, location)
-        location.close()
+            with open(temp_file_path, "wb") as temp_file:
+                # Acquire an exclusive lock on the temporary file
+                fcntl.flock(temp_file, fcntl.LOCK_EX)
+                try:
+                    if self.compress_cache:
+                        compressed_data = self._compress_torch(data)
+                        temp_file.write(compressed_data)
+                    else:
+                        torch.save(data, temp_file)
+                    temp_file.flush()
+                    os.fsync(temp_file.fileno())
+                finally:
+                    # Release the lock
+                    fcntl.flock(temp_file, fcntl.LOCK_UN)
+            # Atomically replace the target file with the temporary file
+            os.rename(temp_file_path, filepath)
+        else:
+            # Handle the case where original_location is a file object
+            temp_file = original_location
+            # Acquire an exclusive lock on the file object
+            fcntl.flock(temp_file, fcntl.LOCK_EX)
+            try:
+                if self.compress_cache:
+                    compressed_data = self._compress_torch(data)
+                    temp_file.write(compressed_data)
+                else:
+                    torch.save(data, temp_file)
+                temp_file.flush()
+                os.fsync(temp_file.fileno())
+            finally:
+                # Release the lock
+                fcntl.flock(temp_file, fcntl.LOCK_UN)
 
     def write_batch(self, filepaths: list, data_list: list) -> None:
         """Write a batch of data to the specified filepaths."""