From 6903e1c2ecedc4c01664e7d6cb3155f5fe2d3d29 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Tue, 12 Sep 2023 17:02:50 -0700 Subject: [PATCH] Apply similar changes to AZURE --- src/filesystem/implementations/as.h | 56 +++++++++++++++++------------ 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/src/filesystem/implementations/as.h b/src/filesystem/implementations/as.h index cf86f3952..100ae51ed 100644 --- a/src/filesystem/implementations/as.h +++ b/src/filesystem/implementations/as.h @@ -116,7 +116,7 @@ class ASFileSystem : public FileSystem { Status DownloadFolder( const std::string& container, const std::string& path, - const std::string& dest); + const std::string& dest, const bool recursive); std::shared_ptr client_; re2::RE2 as_regex_; @@ -392,7 +392,7 @@ ASFileSystem::FileExists(const std::string& path, bool* exists) Status ASFileSystem::DownloadFolder( const std::string& container, const std::string& path, - const std::string& dest) + const std::string& dest, const bool recursive) { auto container_client = client_->GetBlobContainerClient(container); auto func = [&](const std::vector& blobs, @@ -408,17 +408,20 @@ ASFileSystem::DownloadFolder( "Failed to download file at " + blob_item.Name + ":" + ex.what()); } } - for (const auto& directory_item : blob_prefixes) { - const auto& local_path = JoinPath({dest, BaseName(directory_item)}); - int status = mkdir( - const_cast(local_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); - if (status == -1) { - return Status( - Status::Code::INTERNAL, - "Failed to create local folder: " + local_path + - ", errno:" + strerror(errno)); + if (recursive) { + for (const auto& directory_item : blob_prefixes) { + const auto& local_path = JoinPath({dest, BaseName(directory_item)}); + int status = mkdir( + const_cast(local_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); + if (status == -1 && errno != EEXIST) { + return Status( + Status::Code::INTERNAL, + "Failed to create local folder: " + local_path + + ", errno:" + strerror(errno)); + } + RETURN_IF_ERROR( + DownloadFolder(container, directory_item, local_path, recursive)); } - RETURN_IF_ERROR(DownloadFolder(container, directory_item, local_path)); } return Status::Success; }; @@ -445,21 +448,30 @@ ASFileSystem::LocalizePath( "AS file localization not yet implemented " + path); } - std::string folder_template = "/tmp/folderXXXXXX"; - char* tmp_folder = mkdtemp(const_cast(folder_template.c_str())); - if (tmp_folder == nullptr) { - return Status( - Status::Code::INTERNAL, - "Failed to create local temp folder: " + folder_template + - ", errno:" + strerror(errno)); + // Create a local directory for s3 model store. + // If `mount_dir` or ENV variable are not set, + // creates a temporary directory under `/tmp` with the format: "folderXXXXXX". + // Otherwise, will create a folder under specified directory with the name + // indicated in path (i.e. everything after the last encounter of `/`). + const char* env_mount_dir = std::getenv("TRITON_AZURE_MOUNT_DIRECTORY"); + std::string tmp_folder; + if (mount_dir.empty() && env_mount_dir == nullptr) { + RETURN_IF_ERROR(triton::core::MakeTemporaryDirectory( + FileSystemType::LOCAL, &tmp_folder)); + } else { + tmp_folder = mount_dir.empty() ? std::string(env_mount_dir) : mount_dir; + tmp_folder = + JoinPath({tmp_folder, path.substr(path.find_last_of('/') + 1)}); + RETURN_IF_ERROR(triton::core::MakeDirectory( + tmp_folder, true /*recursive*/, true /*allow_dir_exist*/)); } - localized->reset(new LocalizedPath(path, tmp_folder)); - std::string dest(folder_template); + localized->reset(new LocalizedPath(path, tmp_folder)); + std::string dest(tmp_folder); std::string container, blob; RETURN_IF_ERROR(ParsePath(path, &container, &blob)); - return DownloadFolder(container, blob, dest); + return DownloadFolder(container, blob, dest, recursive); } Status