Skip to content

Commit

Permalink
Merge pull request #1312 from stan-dev/comma-separated-filenames-option
Browse files Browse the repository at this point in the history
Allow comma-separated filenames in multi-chain configurations
  • Loading branch information
WardBrian authored Feb 21, 2025
2 parents 6858878 + a6fe257 commit f842450
Show file tree
Hide file tree
Showing 21 changed files with 594 additions and 136 deletions.
5 changes: 5 additions & 0 deletions make/command
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
src/cmdstan/stansummary.d: DEPTARGETS = -MT bin/cmdstan/stansummary.o
# don't build anything during a `make clean`
ifneq ($(MAKECMDGOALS),)
ifeq ($(filter clean%,$(MAKECMDGOALS)),)
-include src/cmdstan/stansummary.d
endif
endif

bin/cmdstan/%.o : src/cmdstan/%.cpp
@mkdir -p $(dir $@)
Expand Down
14 changes: 9 additions & 5 deletions make/tests
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ test/%$(EXE) : INC += $(INC_GTEST) -I $(RAPIDJSON)
test/%$(EXE) : test/%.o $(GTEST)/src/gtest_main.cc $(GTEST)/src/gtest-all.o $(SUNDIALS_TARGETS) $(MPI_TARGETS) $(TBB_TARGETS)
$(LINK.cpp) $(filter-out src/test/test-models/% src/%.csv bin/% test/%.hpp %.hpp-test,$^) $(LDLIBS) $(OUTPUT_OPTION)

test/%.o : src/test/%.cpp
.PRECIOUS: test/%.o
test/%.o : src/test/%.cpp src/test/utility.hpp $(wildcard src/cmdstan/*.hpp)
@mkdir -p $(dir $@)
$(COMPILE.cpp) $< $(OUTPUT_OPTION)

Expand All @@ -23,7 +24,6 @@ src/test/%.d : test/%.o

ifneq ($(filter test/%,$(MAKECMDGOALS)),)
-include $(patsubst test/%$(EXE),src/test/%.d,$(filter test/%,$(MAKECMDGOALS)))
-include $(patsubst %.cpp,%.d,$(STANC_TEMPLATE_INSTANTIATION_CPP))
endif

############################################################
Expand Down Expand Up @@ -56,10 +56,14 @@ test-headers: $(HEADER_TESTS)
##
TEST_MODELS := $(wildcard src/test/test-models/*.stan)

ifneq ($(filter test-models-hpp,$(MAKECMDGOALS)),)
-include $(patsubst %.stan,%.d,$(TEST_MODELS))
include src/cmdstan/main.d
endif

.PHONY: test-models-hpp
test-models-hpp:
$(MAKE) $(patsubst %.stan,%.hpp,$(TEST_MODELS))
$(MAKE) $(patsubst %.stan,%$(EXE),$(TEST_MODELS))
test-models-hpp: $(patsubst %.stan,%.hpp,$(TEST_MODELS)) $(patsubst %.stan,%$(EXE),$(TEST_MODELS))

##
# Tests that depend on compiled models
##
Expand Down
8 changes: 8 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,14 @@ build-mpi: $(MPI_TARGETS)
@echo ''
@echo '--- boost mpi bindings built ---'

# don't build anything during a `make clean`
# but otherwise, we always want to check main.d
ifneq ($(MAKECMDGOALS),)
ifeq ($(filter clean%,$(MAKECMDGOALS)),)
include src/cmdstan/main.d
endif
endif

.PHONY: build
build: bin/stanc$(EXE) $(SUNDIALS_TARGETS) $(MPI_TARGETS) $(TBB_TARGETS) $(CMDSTAN_MAIN_O) $(PRECOMPILED_MODEL_HEADER) bin/stansummary$(EXE) bin/print$(EXE) bin/diagnose$(EXE)
@echo ''
Expand Down
7 changes: 5 additions & 2 deletions src/cmdstan/arguments/arg_diagnostic_file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ class arg_diagnostic_file : public string_argument {
public:
arg_diagnostic_file() : string_argument() {
_name = "diagnostic_file";
_description = "Auxiliary output file for diagnostic information";
_validity = "Path to existing file";
_description
= "Auxiliary output file for diagnostic information. If multiple "
"chains are run, this can either be a single path, in which case its "
"name will have _ID appended, or a comma-separated list of names.";
_validity = "File(s) should not already exist";
_default = "\"\"";
_default_value = "";
_value = _default_value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ class arg_generate_quantities_fitted_params : public string_argument {
_name = "fitted_params";
_description
= "Input file of sample of fitted parameter values for model "
"conditioned on data";
"conditioned on data. If multiple chains are run, this can either "
"be a single path, in which case its name will have _ID appended, or "
"a comma-separated list of names.";
_validity = "Path to existing file";
_default = "\"\"";
_default_value = "";
Expand Down
6 changes: 5 additions & 1 deletion src/cmdstan/arguments/arg_init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ class arg_init : public string_argument {
_description = std::string("Initialization method: ")
+ std::string("\"x\" initializes randomly between [-x, x], ")
+ std::string("\"0\" initializes to 0, ")
+ std::string("anything else identifies a file of values");
+ std::string(
"anything else identifies a file of values. If "
"multiple chains are run, this can either be a single "
"path, in which case its name will have _ID appended, "
"or a comma-separated list of names.");
_default = "\"2\"";
_default_value = "2";
_value = _default_value;
Expand Down
7 changes: 5 additions & 2 deletions src/cmdstan/arguments/arg_output_file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ class arg_output_file : public string_argument {
public:
arg_output_file() : string_argument() {
_name = "file";
_description = "Output file";
_validity = "Path to existing file";
_description
= "Output file. If multiple chains are run, this can either be a "
"single path, in which case its name will have _ID appended, or a "
"comma-separated list of names.";
_validity = "File(s) should not already exist";
_default = "output.csv";
_default_value = "output.csv";
_value = _default_value;
Expand Down
15 changes: 10 additions & 5 deletions src/cmdstan/command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,17 @@ int command(int argc, const char *argv[]) {
"Missing fitted_params argument, cannot run generate_quantities "
"without fitted sample.");
}
auto file_info = file::get_basename_suffix(fname);
if (file_info.second != ".csv") {
throw std::invalid_argument("Fitted params file must be a CSV file.");
}

std::vector<std::string> fname_vec
= file::make_filenames(file_info.first, "", ".csv", num_chains, id);
= file::make_filenames(fname, "", ".csv", num_chains, id);

for (auto &f : fname_vec) {
auto file_info = file::get_basename_suffix(f);
if (file_info.second != ".csv") {
throw std::invalid_argument("Fitted params file must be a CSV file.");
}
}

std::vector<std::string> param_names = get_constrained_param_names(model);
std::vector<Eigen::MatrixXd> fitted_params_vec;
fitted_params_vec.reserve(num_chains);
Expand Down
190 changes: 101 additions & 89 deletions src/cmdstan/command_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,35 @@ inline constexpr auto get_arg_val(List &&arg_list, Args &&... args) {
}
}

/**
* Check that a file is either a .json or .R (dump) extension
*/
inline void check_valid_context_file_name(const std::string &file) {
auto [file_name, file_ending] = file::get_basename_suffix(file);
if (file_ending != ".json") {
if (file_ending != ".R") {
std::stringstream msg;
msg << "User specified files must end in .json or .R. Found: ";
if (file_ending.empty()) {
msg << file;
} else {
msg << file_ending;
}
msg << std::endl;
throw std::invalid_argument(msg.str());
}

std::cerr << "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new "
"features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
}
}

using shared_context_ptr = std::shared_ptr<stan::io::var_context>;

/**
* Given the name of a file, return a shared pointer holding the data contents.
* @param file A system file to read from
Expand All @@ -117,131 +145,115 @@ inline shared_context_ptr get_var_context(const std::string &file) {
if (file.empty()) {
return std::make_shared<stan::io::empty_var_context>();
}
check_valid_context_file_name(file);
std::ifstream stream = file::safe_open(file);
if (file::get_suffix(file) == ".json") {
stan::json::json_data var_context(stream);
return std::make_shared<stan::json::json_data>(var_context);
}
std::cerr
<< "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
stan::io::dump var_context(stream);
return std::make_shared<stan::io::dump>(var_context);
}

using context_vector = std::vector<shared_context_ptr>;
/**
* Make a vector of shared pointers to contexts.
* @param file The name of the file. For multi-chain we will attempt to find
* {file_name}_1{file_ending} and if that fails try to use the named file as
* @param file The name of the file. For multi-chain, this can either be a
* comma-separated list, or else we will attempt to find
* {file_name}_{id}{file_ending} and if that fails try to use the named file as
* the data for each chain.
* @param num_chains The number of chains to run
* @param id The id of the first chain
* @return a std vector of shared pointers to var contexts
*/
context_vector get_vec_var_context(const std::string &file, size_t num_chains,
unsigned int id) {
using stan::io::var_context;
// simple handling for 1 chain
if (num_chains == 1) {
return context_vector(1, get_var_context(file));
}
auto make_context = [](auto &&file, auto &&stream,
auto &&file_ending) -> shared_context_ptr {
// use default for all chain inits
if (file.empty()) {
return context_vector(num_chains,
std::make_shared<stan::io::empty_var_context>());
}

const bool has_commas = file.find(',') != std::string::npos;
auto filenames = file::make_filenames(file, "", "", num_chains, id);

std::vector<std::string> missing_files;
std::vector<std::fstream> streams;
streams.reserve(num_chains);

// check files are valid and exist, or build up a list of the missing ones
for (auto &&file_name : filenames) {
check_valid_context_file_name(file_name);
std::fstream stream(file_name.c_str(), std::fstream::in);
if (stream.rdstate() & std::ifstream::failbit) {
missing_files.push_back(file_name);
}
streams.push_back(std::move(stream));
}

auto make_context = [](auto &&file, auto &&stream) -> shared_context_ptr {
auto [file_name, file_ending] = file::get_basename_suffix(file);
if (file_ending == ".json") {
using stan::json::json_data;
return std::make_shared<json_data>(json_data(stream));
} else if (file_ending == ".R") {
using stan::io::dump;
return std::make_shared<stan::io::dump>(dump(stream));
return std::make_shared<dump>(dump(stream));

} else {
// should never happen, caught by check_valid_context_file_name above
std::stringstream msg;
msg << "file ending of " << file_ending << " is not supported by cmdstan";
throw std::invalid_argument(msg.str());
using stan::io::dump;
return std::make_shared<dump>(dump(stream));
}
};
// use default for all chain inits
if (file.empty()) {
return context_vector(num_chains,
std::make_shared<stan::io::empty_var_context>());
} else {
size_t file_marker_pos = file.find_last_of(".");
if (file_marker_pos > file.size()) {
std::stringstream msg;
msg << "Found: \"" << file
<< "\" but user specified files must end in .json or .R";
throw std::invalid_argument(msg.str());
}
std::string file_name = file.substr(0, file_marker_pos);
std::string file_ending = file.substr(file_marker_pos, file.size());
if (file_ending != ".json" && file_ending != ".R") {
std::stringstream msg;
msg << "file ending of " << file_ending << " is not supported by cmdstan";
throw std::invalid_argument(msg.str());
}
if (file_ending != ".json") {
std::cerr
<< "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
}

auto filenames
= file::make_filenames(file_name, "", file_ending, num_chains, id);
auto &file_1 = filenames[0];
std::fstream stream_1(file_1.c_str(), std::fstream::in);
// if file_1 exists we'll assume num_chains of these files exist
if (stream_1.rdstate() & std::ifstream::failbit) {
// if that fails we will try to find a base file
std::fstream stream(file.c_str(), std::fstream::in);
if (stream.rdstate() & std::ifstream::failbit) {
std::string file_name_err
= std::string("\"" + file_1 + "\" and base file \"" + file + "\"");
std::stringstream msg;
msg << "Searching for \"" << file_name_err << std::endl;
msg << "Can't open either of specified files," << file_name_err
<< std::endl;
throw std::invalid_argument(msg.str());
} else {
return context_vector(num_chains,
make_context(file, stream, file_ending));
}
} else {
// If we found file_1 then we'll assume file_{1...N} exists
context_vector ret;
ret.reserve(num_chains);
ret.push_back(make_context(file_1, stream_1, file_ending));
for (size_t i = 1; i < num_chains; ++i) {
auto &file_i = filenames[i];
std::fstream stream_i(file_i.c_str(), std::fstream::in);
// If any stream fails here something went wrong with file names
if (stream_i.rdstate() & std::ifstream::failbit) {
std::string file_name_err = std::string(
"\"" + file_1 + "\" but cannot open \"" + file_i + "\"");
std::stringstream msg;
msg << "Found " << file_name_err << std::endl;
throw std::invalid_argument(msg.str());
}
ret.push_back(make_context(file_i, stream_i, file_ending));
}
return ret;
}
// happy path - all files exist and we can return the contexts
if (missing_files.empty()) {
context_vector ret(num_chains);
std::transform(filenames.cbegin(), filenames.cend(), streams.begin(),
ret.begin(), make_context);
return ret;
}

// user directly specified a list of files, some of which don't exist
if (has_commas && !missing_files.empty()) {
std::stringstream msg;
msg << "Cannot open some of the requested files: [";
msg << boost::algorithm::join(missing_files, ", ");
msg << "]" << std::endl;
throw std::invalid_argument(msg.str());
}
// This should not happen
std::cerr
<< "Warning: file '" << file
<< "' is being read as an 'RDump' file.\n"
"\tThis format is deprecated and will not receive new features.\n"
"\tConsider saving your data in JSON format instead."
<< std::endl;
using stan::io::dump;

// legacy -- if the user requested 'init.json', we looked for 'init_1.json'
// but if that fails, we try 'init.json' as well
std::fstream stream(file.c_str(), std::fstream::in);
return context_vector(num_chains, std::make_shared<dump>(dump(stream)));
if (stream.rdstate() & std::ifstream::failbit) {
std::stringstream msg;
msg << "Cannot open some of the requested files: [";
msg << boost::algorithm::join(missing_files, ", ");
msg << "]" << std::endl;
msg << "Also failed to find base file " << file << std::endl;
msg << "When cmdstan is given a file 'name' and there are "
"multiple chains or pathfinders,"
" cmdstan will look for files 'name_{N..(N + "
"num_processes)' where N is the id (typically, 1)."
" If these are not found, then it looks for the exact "
"file name as passed."
" In this case, neither option was found.";

throw std::invalid_argument(msg.str());
} else {
std::cerr << "Warning: file '" << file
<< "' is being used to initialize all " << num_chains
<< " chains!" << std::endl;
return context_vector(num_chains, make_context(file, std::move(stream)));
}
}

/**
Expand Down
Loading

0 comments on commit f842450

Please sign in to comment.