diff --git a/authapi/authapi.go b/authapi/authapi.go index c782651..03c71c0 100644 --- a/authapi/authapi.go +++ b/authapi/authapi.go @@ -48,6 +48,7 @@ func (api *AuthApi) Ping() (*PingResult, error) { if err = json.Unmarshal(body, ret); err != nil { return nil, err } + ret.SyncCode() return ret, nil } @@ -72,6 +73,7 @@ func (api *AuthApi) Check() (*CheckResult, error) { if err = json.Unmarshal(body, ret); err != nil { return nil, err } + ret.SyncCode() return ret, nil } @@ -98,6 +100,7 @@ func (api *AuthApi) Logo() (*LogoResult, error) { if err = json.Unmarshal(body, ret); err != nil { return nil, err } + ret.SyncCode() return ret, nil } @@ -145,6 +148,7 @@ func (api *AuthApi) Enroll(options ...func(*url.Values)) (*EnrollResult, error) if err = json.Unmarshal(body, ret); err != nil { return nil, err } + ret.SyncCode() return ret, nil } @@ -174,6 +178,7 @@ func (api *AuthApi) EnrollStatus(userid string, if err = json.Unmarshal(body, ret); err != nil { return nil, err } + ret.SyncCode() return ret, nil } @@ -239,6 +244,7 @@ func (api *AuthApi) Preauth(options ...func(*url.Values)) (*PreauthResult, error if err = json.Unmarshal(body, ret); err != nil { return nil, err } + ret.SyncCode() return ret, nil } @@ -348,6 +354,7 @@ func (api *AuthApi) Auth(factor string, options ...func(*url.Values)) (*AuthResu if err = json.Unmarshal(body, ret); err != nil { return nil, err } + ret.SyncCode() return ret, nil } @@ -377,5 +384,6 @@ func (api *AuthApi) AuthStatus(txid string) (*AuthStatusResult, error) { if err = json.Unmarshal(body, ret); err != nil { return nil, err } + ret.SyncCode() return ret, nil } diff --git a/authapi/authapi_test.go b/authapi/authapi_test.go index f373756..320d550 100644 --- a/authapi/authapi_test.go +++ b/authapi/authapi_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "github.com/duosecurity/duo_api_golang" + duoapi "github.com/duosecurity/duo_api_golang" ) func buildAuthApi(url string, proxy func(*http.Request) (*url.URL, error)) *AuthApi { @@ -203,7 +203,7 @@ func TestLogo(t *testing.T) { } } -// Test a failure logo reqeust / response. +// Test a failure logo request / response. func TestLogoError(t *testing.T) { ts := httptest.NewTLSServer( http.HandlerFunc( @@ -602,3 +602,40 @@ func TestAuthStatus(t *testing.T) { t.Error("Unexpected response status msg: " + res.Response.Status_Msg) } } + +// Test a response with empty code. +func TestEmptyResponseCode(t *testing.T) { + ts := httptest.NewTLSServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Return a 400, as if the logo was not found. + w.WriteHeader(400) + fmt.Fprintln(w, ` + { + "stat": "FAIL", + "code": "", + "message": "Code is empty", + "message_detail": "Deal with it" + }`) + })) + defer ts.Close() + + duo := buildAuthApi(ts.URL, nil) + + res, err := duo.Logo() + if err != nil { + t.Error("Failed TestCheck: " + err.Error()) + } + if res.Stat != "FAIL" { + t.Error("Expected FAIL, but got " + res.Stat) + } + if res.Code == nil || *res.Code != 0 { + t.Error("Unexpected response code.") + } + if res.Message == nil || *res.Message != "Code is empty" { + t.Error("Unexpected message.") + } + if res.Message_Detail == nil || *res.Message_Detail != "Deal with it" { + t.Error("Unexpected message detail.") + } +} diff --git a/duoapi.go b/duoapi.go index 5932b59..32e8b97 100644 --- a/duoapi.go +++ b/duoapi.go @@ -259,14 +259,41 @@ func (duoapi *DuoApi) buildOptions(options ...DuoApiOption) *requestOptions { return opts } +type NullableInt32 struct { + value *int32 +} + // API calls will return a StatResult object. On success, Stat is 'OK'. // On error, Stat is 'FAIL', and Code, Message, and Message_Detail // contain error information. type StatResult struct { - Stat string - Code *int32 - Message *string - Message_Detail *string + Stat string `json:"stat"` + Ncode NullableInt32 `json:"code"` + Code *int32 `json:"-"` + Message *string `json:"message"` + Message_Detail *string `json:"message_detail"` +} + +func (n *NullableInt32) UnmarshalJSON(data []byte) error { + var raw interface{} + + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + switch v := raw.(type) { + case float64: + intVal := int32(v) + n.value = &intVal + case string: + intVal := int32(0) + n.value = &intVal + } + return nil +} + +func (s *StatResult) SyncCode() { + s.Code = s.Ncode.value } // SetCustomHTTPClient allows one to set a completely custom http client that