@@ -23,16 +23,11 @@ import (
23
23
"strings"
24
24
25
25
clientnative "github.com/haproxytech/client-native/v4"
26
+ "github.com/haproxytech/client-native/v4/configuration"
26
27
"github.com/haproxytech/client-native/v4/models"
27
28
"github.com/haproxytech/dataplaneapi/log"
28
29
)
29
30
30
- var configVersion string
31
-
32
- func ConfigVersion () string {
33
- return configVersion
34
- }
35
-
36
31
// Adapter is just a wrapper over http handler function
37
32
type Adapter func (http.Handler ) http.Handler
38
33
@@ -125,6 +120,32 @@ func ApacheLogMiddleware(logger *log.ACLLogger) Adapter {
125
120
}
126
121
}
127
122
123
+ type serverWriter struct {
124
+ w http.ResponseWriter
125
+ client configuration.Configuration
126
+ transactionID string
127
+ wroteHeader bool
128
+ }
129
+
130
+ func (s serverWriter ) WriteHeader (code int ) {
131
+ if ! s .wroteHeader {
132
+ version , err := fetchConfgVersion (s .client , s .transactionID )
133
+ if err == nil {
134
+ s .w .Header ().Set ("Configuration-Version" , version )
135
+ }
136
+ s .wroteHeader = true //nolint:staticcheck
137
+ }
138
+ s .w .WriteHeader (code )
139
+ }
140
+
141
+ func (s serverWriter ) Write (b []byte ) (int , error ) {
142
+ return s .w .Write (b )
143
+ }
144
+
145
+ func (s serverWriter ) Header () http.Header {
146
+ return s .w .Header ()
147
+ }
148
+
128
149
func ConfigVersionMiddleware (client clientnative.HAProxyClient ) Adapter {
129
150
return func (h http.Handler ) http.Handler {
130
151
return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -134,20 +155,30 @@ func ConfigVersionMiddleware(client clientnative.HAProxyClient) Adapter {
134
155
if err != nil {
135
156
http .Error (w , err .Error (), http .StatusNotImplemented )
136
157
}
137
- var v int64
138
- if tID == "" {
139
- v , err = configuration .GetConfigurationVersion ("" )
140
- } else {
141
- tr , _ := configuration .GetTransaction (tID )
142
- if tr != nil && tr .Status == models .TransactionStatusInProgress {
143
- v , err = configuration .GetConfigurationVersion (tr .ID )
144
- }
158
+ sw := serverWriter {
159
+ w : w ,
160
+ client : configuration ,
161
+ transactionID : tID ,
162
+ wroteHeader : false ,
145
163
}
146
- if err == nil {
147
- configVersion = strconv .FormatInt (v , 10 )
148
- w .Header ().Add ("Configuration-Version" , configVersion )
149
- }
150
- h .ServeHTTP (w , r )
164
+ h .ServeHTTP (sw , r )
151
165
})
152
166
}
153
167
}
168
+
169
+ func fetchConfgVersion (client configuration.Configuration , transactionID string ) (string , error ) {
170
+ var v int64
171
+ var err error
172
+ if transactionID == "" {
173
+ v , err = client .GetConfigurationVersion ("" )
174
+ } else {
175
+ tr , _ := client .GetTransaction (transactionID )
176
+ if tr != nil && tr .Status == models .TransactionStatusInProgress {
177
+ v , err = client .GetConfigurationVersion (tr .ID )
178
+ }
179
+ }
180
+ if err == nil {
181
+ return strconv .FormatInt (v , 10 ), nil
182
+ }
183
+ return "" , err
184
+ }
0 commit comments