Skip to content

Commit

Permalink
Merge pull request #381 from supertokens/tests-fix
Browse files Browse the repository at this point in the history
refactor: Check for status in github validate access token
  • Loading branch information
rishabhpoddar authored Oct 6, 2023
2 parents a8829da + e2318f3 commit ee60032
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 12 deletions.
4 changes: 2 additions & 2 deletions recipe/thirdparty/providers/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ func Github(input tpmodels.ProviderInput) *tpmodels.TypeProvider {
basicAuthToken := base64.StdEncoding.EncodeToString([]byte(clientConfig.ClientID + ":" + clientConfig.ClientSecret))
wrongClientIdError := errors.New("Access token does not belong to your application")

resp, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{
resp, status, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{
"access_token": accessToken,
}, map[string]interface{}{
"Authorization": "Basic " + basicAuthToken,
"Content-Type": "application/json",
})

if err != nil {
if err != nil || status != 200 {
return errors.New("Invalid access token")
}

Expand Down
2 changes: 1 addition & 1 deletion recipe/thirdparty/providers/oauth2_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func oauth2_ExchangeAuthCodeForOAuthTokens(config tpmodels.ProviderConfigForClie
}
/* Transformation needed for dev keys END */

oAuthTokens, err := doPostRequest(tokenAPIURL, accessTokenAPIParams, nil)
oAuthTokens, _, err := doPostRequest(tokenAPIURL, accessTokenAPIParams, nil)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion recipe/thirdparty/providers/twitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ func Twitter(input tpmodels.ProviderInput) *tpmodels.TypeProvider {
twitterOauthParams["redirect_uri"] = redirectUri
twitterOauthParams["code"] = redirectURIInfo.RedirectURIQueryParams["code"]

return doPostRequest(originalImplementation.Config.TokenEndpoint, twitterOauthParams, map[string]interface{}{
resp, _, err := doPostRequest(originalImplementation.Config.TokenEndpoint, twitterOauthParams, map[string]interface{}{
"Authorization": "Basic " + basicAuthToken,
})

return resp, err
}

if oOverride != nil {
Expand Down
16 changes: 8 additions & 8 deletions recipe/thirdparty/providers/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,16 @@ func doGetRequest(url string, queryParams map[string]interface{}, headers map[st
return result, nil
}

func doPostRequest(url string, params map[string]interface{}, headers map[string]interface{}) (map[string]interface{}, error) {
func doPostRequest(url string, params map[string]interface{}, headers map[string]interface{}) (map[string]interface{}, int, error) {
supertokens.LogDebugMessage(fmt.Sprintf("POST request to %s, with form fields %v and headers %v", url, params, headers))

postBody, err := qs.Marshal(params)
if err != nil {
return nil, err
return nil, -1, err
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer([]byte(postBody)))
if err != nil {
return nil, err
return nil, -1, err
}
for key, value := range headers {
req.Header.Set(key, value.(string))
Expand All @@ -110,28 +110,28 @@ func doPostRequest(url string, params map[string]interface{}, headers map[string
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, resp.StatusCode, err
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
return nil, resp.StatusCode, err
}

supertokens.LogDebugMessage(fmt.Sprintf("Received response with status %d and body %s", resp.StatusCode, string(body)))

var result map[string]interface{}
err = json.Unmarshal(body, &result)
if err != nil {
return nil, err
return nil, resp.StatusCode, err
}

if resp.StatusCode >= 300 {
return nil, fmt.Errorf("POST request to %s resulted in %d status with body %s", url, resp.StatusCode, string(body))
return nil, resp.StatusCode, fmt.Errorf("POST request to %s resulted in %d status with body %s", url, resp.StatusCode, string(body))
}

return result, nil
return result, resp.StatusCode, nil
}

// JWKS utils
Expand Down

0 comments on commit ee60032

Please sign in to comment.