Skip to content

Commit e4aa58e

Browse files
committed
Sketch, copy makefile from bridgestan
0 parents  commit e4aa58e

9 files changed

+438
-0
lines changed

.clang-format

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
---
2+
Language: Cpp
3+
# BasedOnStyle: Google
4+
AccessModifierOffset: -1
5+
AlignAfterOpenBracket: Align
6+
AlignConsecutiveAssignments: false
7+
AlignConsecutiveDeclarations: false
8+
AlignEscapedNewlines: Left
9+
AlignOperands: true
10+
AlignTrailingComments: true
11+
AllowAllParametersOfDeclarationOnNextLine: true
12+
AllowShortBlocksOnASingleLine: false
13+
AllowShortCaseLabelsOnASingleLine: false
14+
AllowShortFunctionsOnASingleLine: All
15+
AllowShortIfStatementsOnASingleLine: false
16+
AllowShortLoopsOnASingleLine: false
17+
AlwaysBreakAfterDefinitionReturnType: None
18+
AlwaysBreakAfterReturnType: None
19+
AlwaysBreakBeforeMultilineStrings: true
20+
AlwaysBreakTemplateDeclarations: true
21+
BinPackArguments: true
22+
BinPackParameters: true
23+
BraceWrapping:
24+
AfterClass: false
25+
AfterControlStatement: false
26+
AfterEnum: false
27+
AfterFunction: false
28+
AfterNamespace: false
29+
AfterObjCDeclaration: false
30+
AfterStruct: false
31+
AfterUnion: false
32+
BeforeCatch: false
33+
BeforeElse: false
34+
IndentBraces: false
35+
BreakBeforeBinaryOperators: All
36+
BreakBeforeBraces: Attach
37+
BreakBeforeInheritanceComma: false
38+
BreakBeforeTernaryOperators: true
39+
BreakConstructorInitializersBeforeComma: false
40+
BreakConstructorInitializers: BeforeColon
41+
BreakAfterJavaFieldAnnotations: false
42+
BreakStringLiterals: true
43+
ColumnLimit: 80
44+
CommentPragmas: '^ IWYU pragma:'
45+
CompactNamespaces: false
46+
ConstructorInitializerAllOnOneLineOrOnePerLine: true
47+
ConstructorInitializerIndentWidth: 4
48+
ContinuationIndentWidth: 4
49+
Cpp11BracedListStyle: true
50+
DerivePointerAlignment: true
51+
DisableFormat: false
52+
ExperimentalAutoDetectBinPacking: false
53+
FixNamespaceComments: true
54+
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
55+
IncludeCategories:
56+
- Regex: '^<.*\.h>'
57+
Priority: 1
58+
- Regex: '^<.*'
59+
Priority: 2
60+
- Regex: '.*'
61+
Priority: 3
62+
IncludeIsMainRegex: '([-_](test|unittest))?$'
63+
IndentCaseLabels: true
64+
IndentWidth: 2
65+
IndentWrappedFunctionNames: false
66+
JavaScriptQuotes: Leave
67+
JavaScriptWrapImports: true
68+
KeepEmptyLinesAtTheStartOfBlocks: false
69+
MacroBlockBegin: ''
70+
MacroBlockEnd: ''
71+
MaxEmptyLinesToKeep: 1
72+
NamespaceIndentation: None
73+
ObjCBlockIndentWidth: 2
74+
ObjCSpaceAfterProperty: false
75+
ObjCSpaceBeforeProtocolList: false
76+
PenaltyBreakAssignment: 2
77+
PenaltyBreakBeforeFirstCallParameter: 1
78+
PenaltyBreakComment: 300
79+
PenaltyBreakFirstLessLess: 120
80+
PenaltyBreakString: 1000
81+
PenaltyExcessCharacter: 1000000
82+
PenaltyReturnTypeOnItsOwnLine: 200
83+
PointerAlignment: Left
84+
ReflowComments: true
85+
SortIncludes: false
86+
SpaceAfterCStyleCast: false
87+
SpaceAfterTemplateKeyword: true
88+
SpaceBeforeAssignmentOperators: true
89+
SpaceBeforeParens: ControlStatements
90+
SpaceInEmptyParentheses: false
91+
SpacesBeforeTrailingComments: 2
92+
SpacesInAngles: false
93+
SpacesInContainerLiterals: true
94+
SpacesInCStyleCastParentheses: false
95+
SpacesInParentheses: false
96+
SpacesInSquareBrackets: false
97+
Standard: Auto
98+
TabWidth: 8
99+
UseTab: Never
100+
...

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*.o
2+
*.so
3+
bin/

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "stan"]
2+
path = stan
3+
url = https://github.com/stan-dev/stan.git

Makefile

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
## include paths
2+
CSTAN_ROOT ?= .
3+
SRC ?= $(CSTAN_ROOT)/src/
4+
STAN ?= $(CSTAN_ROOT)/stan/
5+
STANC ?= $(CSTAN_ROOT)/bin/stanc$(EXE)
6+
MATH ?= $(STAN)lib/stan_math/
7+
RAPIDJSON ?= $(STAN)lib/rapidjson_1.1.0/
8+
9+
## required C++ includes
10+
INC_FIRST ?= -I $(STAN)src -I $(RAPIDJSON)
11+
12+
## makefiles needed for math library
13+
-include $(CSTAN_ROOT)/make/local
14+
-include $(MATH)make/compiler_flags
15+
-include $(MATH)make/libraries
16+
17+
## Set -fPIC globally since we're always building a shared library
18+
CXXFLAGS += -fPIC
19+
20+
## set flags for stanc compiler (math calls MIGHT? set STAN_OPENCL)
21+
ifdef STAN_OPENCL
22+
STANCFLAGS += --use-opencl
23+
STAN_FLAG_OPENCL=_opencl
24+
else
25+
STAN_FLAG_OPENCL=
26+
endif
27+
ifdef STAN_THREADS
28+
STAN_FLAG_THREADS=_threads
29+
else
30+
STAN_FLAG_THREADS=
31+
endif
32+
ifdef BRIDGESTAN_AD_HESSIAN
33+
CXXFLAGS+=-DSTAN_MODEL_FVAR_VAR -DBRIDGESTAN_AD_HESSIAN
34+
STAN_FLAG_HESS=_adhessian
35+
else
36+
STAN_FLAG_HESS=
37+
endif
38+
STAN_FLAGS=$(STAN_FLAG_THREADS)$(STAN_FLAG_OPENCL)$(STAN_FLAG_HESS)
39+
40+
CSTAN_DEPS = $(SRC)cstan.cpp
41+
CSTAN_O = $(patsubst %.cpp,%$(STAN_FLAGS).o,$(SRC)cstan.cpp)
42+
43+
$(CSTAN_O) : $(CSTAN_DEPS)
44+
@echo ''
45+
@echo '--- Compiling Stan C++ code ---'
46+
@mkdir -p $(dir $@)
47+
$(COMPILE.cpp) $(OUTPUT_OPTION) $(LDLIBS) $<
48+
49+
## generate .hpp file from .stan file using stanc
50+
%.hpp : %.stan $(STANC)
51+
@echo ''
52+
@echo '--- Translating Stan model to C++ code ---'
53+
$(STANC) $(STANCFLAGS) --o=$(subst \,/,$@) $(subst \,/,$<)
54+
55+
## declares we want to keep .hpp even though it's an intermediate
56+
.PRECIOUS: %.hpp
57+
58+
## builds executable (suffix depends on platform)
59+
%_model.so : %.hpp $(CSTAN_O) $(LIBSUNDIALS) $(MPI_TARGETS) $(TBB_TARGETS)
60+
@echo ''
61+
@echo '--- Compiling C++ code ---'
62+
$(COMPILE.cpp) -x c++ -o $(subst \,/,$*).o $(subst \,/,$<)
63+
@echo '--- Linking C++ code ---'
64+
$(LINK.cpp) -shared -lm -o $(patsubst %.hpp, %_model.so, $(subst \,/,$<)) $(subst \,/,$*.o) $(CSTAN_O) $(LDLIBS) $(LIBSUNDIALS) $(MPI_TARGETS) $(TBB_TARGETS)
65+
$(RM) $(subst \,/,$*).o
66+
67+
.PHONY: docs
68+
docs:
69+
$(MAKE) -C docs/ html
70+
71+
.PHONY: clean
72+
clean:
73+
$(RM) $(SRC)/*.o
74+
$(RM) test_models/**/*.so
75+
$(RM) test_models/**/*.hpp
76+
$(RM) bin/stanc$(EXE)
77+
78+
79+
# build all test models at once
80+
TEST_MODEL_NAMES = $(patsubst $(CSTAN_ROOT)/test_models/%/, %, $(sort $(dir $(wildcard $(CSTAN_ROOT)/test_models/*/))))
81+
TEST_MODEL_NAMES := $(filter-out syntax_error, $(TEST_MODEL_NAMES))
82+
TEST_MODEL_LIBS = $(join $(addprefix test_models/, $(TEST_MODEL_NAMES)), $(addsuffix _model.so, $(addprefix /, $(TEST_MODEL_NAMES))))
83+
84+
.PHONY: test_models
85+
test_models: $(TEST_MODEL_LIBS)
86+
87+
.PHONY: stan-update stan-update-version
88+
stan-update:
89+
git submodule update --init --recursive
90+
91+
stan-update-remote:
92+
git submodule update --remote --init --recursive
93+
94+
# print compilation command line config
95+
.PHONY: compile_info
96+
compile_info:
97+
@echo '$(LINK.cpp) $(STANC_O) $(LDLIBS) $(LIBSUNDIALS) $(MPI_TARGETS) $(TBB_TARGETS)'
98+
99+
## print value of makefile variable (e.g., make print-TBB_TARGETS)
100+
.PHONY: print-%
101+
print-% : ; @echo $* = $($*) ;
102+
103+
# handles downloading of stanc
104+
STANC_DL_RETRY = 5
105+
STANC_DL_DELAY = 10
106+
STANC3_TEST_BIN_URL ?=
107+
STANC3_VERSION ?= v2.32.2
108+
109+
ifeq ($(OS),Windows_NT)
110+
OS_TAG := windows
111+
else ifeq ($(OS),Darwin)
112+
OS_TAG := mac
113+
else ifeq ($(OS),Linux)
114+
OS_TAG := linux
115+
ifeq ($(shell uname -m),mips64)
116+
ARCH_TAG := -mips64el
117+
else ifeq ($(shell uname -m),ppc64le)
118+
ARCH_TAG := -ppc64el
119+
else ifeq ($(shell uname -m),s390x)
120+
ARCH_TAG := -s390x
121+
else ifeq ($(shell uname -m),aarch64)
122+
ARCH_TAG := -arm64
123+
else ifeq ($(shell uname -m),armv7l)
124+
ifeq ($(shell readelf -A /usr/bin/file | grep Tag_ABI_VFP_args),)
125+
ARCH_TAG := -armel
126+
else
127+
ARCH_TAG := -armhf
128+
endif
129+
endif
130+
endif
131+
132+
ifeq ($(OS_TAG),windows)
133+
$(STANC):
134+
@mkdir -p $(dir $@)
135+
$(shell echo "curl -L https://github.com/stan-dev/stanc3/releases/download/$(STANC3_VERSION)/$(OS_TAG)-stanc -o $(STANC) --retry $(STANC_DL_RETRY) --retry-delay $(STANC_DL_DELAY)")
136+
else
137+
$(STANC):
138+
@mkdir -p $(dir $@)
139+
curl -L https://github.com/stan-dev/stanc3/releases/download/$(STANC3_VERSION)/$(OS_TAG)$(ARCH_TAG)-stanc -o $(STANC) --retry $(STANC_DL_RETRY) --retry-delay $(STANC_DL_DELAY)
140+
chmod +x $(STANC)
141+
endif

bernoulli.data.json

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"N" : 10,
3+
"y" : [0,1,0,0,0,0,0,0,0,1]
4+
}

bernoulli.stan

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
data {
2+
int<lower=0> N;
3+
array[N] int<lower=0,upper=1> y;
4+
}
5+
parameters {
6+
real<lower=0,upper=1> theta;
7+
}
8+
model {
9+
theta ~ beta(1,1); // uniform prior on interval 0,1
10+
y ~ bernoulli(theta);
11+
}

src/cstan.cpp

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#include <stan/callbacks/logger.hpp>
2+
#include <stan/callbacks/interrupt.hpp>
3+
#include <stan/callbacks/writer.hpp>
4+
#include <stan/io/json/json_data.hpp>
5+
#include <stan/io/var_context.hpp>
6+
#include <stan/io/empty_var_context.hpp>
7+
#include <stan/model/model_base.hpp>
8+
#include <stan/services/util/create_rng.hpp>
9+
#include <stan/services/sample/hmc_nuts_diag_e.hpp>
10+
#include <fstream>
11+
#include <iostream>
12+
#include <ostream>
13+
#include <set>
14+
#include <sstream>
15+
#include <stdexcept>
16+
#include <string>
17+
#include <vector>
18+
19+
// #include "err.h"
20+
// #include "cstan.h"
21+
22+
// globals for Stan model output
23+
std::streambuf *buf = nullptr;
24+
std::ostream *outstream = &std::cout;
25+
26+
// error handling
27+
28+
class stan_error {
29+
public:
30+
stan_error(char *msg) : msg(msg) {}
31+
32+
~stan_error() { free(this->msg); }
33+
34+
char *msg;
35+
};
36+
37+
extern "C" {
38+
const char *cstan_get_error_message(const stan_error *err) { return err->msg; }
39+
40+
void cstan_free_stan_error(stan_error *err) { delete (err); }
41+
}
42+
43+
/**
44+
* Allocate and return a new model as a reference given the specified
45+
* data context, seed, and message stream. This function is defined
46+
* in the generated model class.
47+
*
48+
* @param[in] data_context context for reading model data
49+
* @param[in] seed random seed for transformed data block
50+
* @param[in] msg_stream stream to which to send messages printed by the model
51+
*/
52+
stan::model::model_base &new_model(stan::io::var_context &data_context,
53+
unsigned int seed, std::ostream *msg_stream);
54+
55+
class buffer_writer : public stan::callbacks::writer {
56+
public:
57+
buffer_writer(double *buf) : buf(buf), pos(0){};
58+
~buffer_writer(){};
59+
60+
void operator()(const std::vector<double> &v) override {
61+
for (auto d : v) {
62+
buf[pos++] = d;
63+
}
64+
}
65+
66+
private:
67+
double *buf;
68+
int pos;
69+
};
70+
71+
std::unique_ptr<stan::io::var_context> load_data(const char *data_char) {
72+
if (data_char == nullptr) {
73+
return std::unique_ptr<stan::io::var_context>(
74+
new stan::io::empty_var_context());
75+
}
76+
std::string data(data_char);
77+
if (data.empty()) {
78+
return std::unique_ptr<stan::io::var_context>(
79+
new stan::io::empty_var_context());
80+
}
81+
std::ifstream data_stream(data);
82+
if (!data_stream.good()) {
83+
throw std::invalid_argument("Could not open data file " + data);
84+
}
85+
return std::unique_ptr<stan::io::var_context>(
86+
new stan::json::json_data(data_stream));
87+
}
88+
extern "C" {
89+
90+
int cstan_sample(const char *data, const char *inits, unsigned int seed,
91+
unsigned int chain_id, double init_radius, int num_warmup,
92+
int num_samples, bool save_warmup, int refresh,
93+
double stepsize, double stepsize_jitter, int max_depth,
94+
double *out, stan_error **err) {
95+
auto json_data = load_data(data);
96+
auto json_inits = load_data(inits);
97+
try {
98+
auto &model = new_model(*json_data, seed, outstream);
99+
buffer_writer sample_writer(out);
100+
stan::callbacks::interrupt interrupt;
101+
stan::callbacks::logger logger;
102+
stan::callbacks::writer null_writer;
103+
104+
return stan::services::sample::hmc_nuts_diag_e(
105+
model, *json_inits, seed, chain_id, init_radius, num_warmup,
106+
num_samples, /*no thinning*/ 1, save_warmup, refresh, stepsize,
107+
stepsize_jitter, max_depth, interrupt, logger, null_writer,
108+
sample_writer, null_writer);
109+
110+
} catch (const std::exception &e) {
111+
if (err != nullptr) {
112+
*err = new stan_error(strdup(e.what()));
113+
}
114+
} catch (...) {
115+
if (err != nullptr) {
116+
*err = new stan_error(strdup("Unknown error"));
117+
}
118+
}
119+
return -1;
120+
}
121+
}

stan

Submodule stan added at ca02539

0 commit comments

Comments
 (0)