Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore: ✏️ Modification of Refresh Token Specification #203

Merged
merged 15 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ public record General(
String username,
@Schema(description = "비밀번호", example = "pennyway1234")
@NotBlank(message = "비밀번호를 입력해주세요")
String password
String password,
@Schema(description = "사용자 기기 고유 식별자", example = "AA-BBB-CCC")
@NotBlank(message = "사용자 기기 고유 식별자를 입력해주세요")
String deviceId
) {
}

Expand All @@ -25,7 +28,10 @@ public record Oauth(
String idToken,
@Schema(description = "OIDC nonce")
@NotBlank(message = "OIDC nonce는 필수 입력값입니다.")
String nonce
String nonce,
@Schema(description = "사용자 기기 고유 식별자", example = "AA-BBB-CCC")
@NotBlank(message = "사용자 기기 고유 식별자를 입력해주세요")
String deviceId
) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* 일반 회원가입 시엔 General, 소셜 회원가입 시엔 Oauth를 사용합니다.
*/
public class SignUpReq {
public record Info(String username, String name, String password, String phone, String code) {
public record Info(String username, String name, String password, String phone, String code, String deviceId) {
public String password(PasswordEncoder passwordEncoder) {
return passwordEncoder.encode(password);
}
Expand All @@ -43,7 +43,7 @@ public String password() {
}

public record OauthInfo(String oauthId, String idToken, String nonce, String name, String username, String phone,
String code) {
String code, String deviceId) {
public User toUser() {
return User.builder()
.username(username)
Expand Down Expand Up @@ -77,10 +77,13 @@ public record General(
@Schema(description = "6자리 정수 인증번호", example = "123456")
@NotBlank(message = "인증번호는 필수입니다.")
@Pattern(regexp = "^\\d{6}$", message = "인증번호는 6자리 숫자여야 합니다.")
String code
String code,
@Schema(description = "사용자 기기 고유 식별자", example = "AA-BBB-CCC")
@NotBlank(message = "사용자 기기 고유 식별자를 입력해주세요")
String deviceId
) {
public Info toInfo() {
return new Info(username, name, password, phone, code);
return new Info(username, name, password, phone, code, deviceId);
}
}

Expand All @@ -97,10 +100,13 @@ public record SyncWithOauth(
@Schema(description = "6자리 정수 인증번호", example = "123456")
@NotBlank(message = "인증번호는 필수입니다.")
@Pattern(regexp = "^\\d{6}$", message = "인증번호는 6자리 숫자여야 합니다.")
String code
String code,
@Schema(description = "사용자 기기 고유 식별자", example = "AA-BBB-CCC")
@NotBlank(message = "사용자 기기 고유 식별자를 입력해주세요")
String deviceId
) {
public Info toInfo() {
return new Info(null, null, password, phone, code);
return new Info(null, null, password, phone, code, deviceId);
}
}

Expand Down Expand Up @@ -130,10 +136,13 @@ public record Oauth(
@Schema(description = "6자리 정수 인증번호", example = "123456")
@NotBlank(message = "인증번호는 필수입니다.")
@Pattern(regexp = "^\\d{6}$", message = "인증번호는 6자리 숫자여야 합니다.")
String code
String code,
@Schema(description = "사용자 기기 고유 식별자", example = "AA-BBB-CCC")
@NotBlank(message = "사용자 기기 고유 식별자를 입력해주세요")
String deviceId
) {
public OauthInfo toOauthInfo() {
return new OauthInfo(oauthId, idToken, nonce, name, username, phone, code);
return new OauthInfo(oauthId, idToken, nonce, name, username, phone, code, deviceId);
}
}

Expand All @@ -155,10 +164,13 @@ public record SyncWithAuth(
@Schema(description = "6자리 정수 인증번호", example = "123456")
@NotBlank(message = "인증번호는 필수입니다.")
@Pattern(regexp = "^\\d{6}$", message = "인증번호는 6자리 숫자여야 합니다.")
String code
String code,
@Schema(description = "사용자 기기 고유 식별자", example = "AA-BBB-CCC")
@NotBlank(message = "사용자 기기 고유 식별자를 입력해주세요")
String deviceId
) {
public OauthInfo toOauthInfo() {
return new OauthInfo(oauthId, idToken, nonce, null, null, phone, code);
return new OauthInfo(oauthId, idToken, nonce, null, null, phone, code, deviceId);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ public JwtAuthHelper(
* 사용자 정보 기반으로 access token과 refresh token을 생성하는 메서드 <br/>
* refresh token은 redis에 저장된다.
*
* @param user {@link User}
* @param user {@link User} : 사용자 정보
* @param deviceId String : 사용자의 디바이스 고유 식별자
* @return {@link Jwts}
*/
public Jwts createToken(User user) {
public Jwts createToken(User user, String deviceId) {
String accessToken = accessTokenProvider.generateToken(AccessTokenClaim.of(user.getId(), user.getRole().getType()));
String refreshToken = refreshTokenProvider.generateToken(RefreshTokenClaim.of(user.getId(), user.getRole().getType()));
String refreshToken = refreshTokenProvider.generateToken(RefreshTokenClaim.of(user.getId(), deviceId, user.getRole().getType()));

refreshTokenService.save(RefreshToken.of(user.getId(), refreshToken, toSeconds(refreshTokenProvider.getExpiryDate(refreshToken))));
refreshTokenService.save(RefreshToken.of(user.getId(), deviceId, refreshToken, toSeconds(refreshTokenProvider.getExpiryDate(refreshToken))));
return Jwts.of(accessToken, refreshToken);
}

Expand All @@ -62,11 +63,12 @@ public Pair<Long, Jwts> refresh(String refreshToken) {

Long userId = JwtClaimsParserUtil.getClaimsValue(claims, RefreshTokenClaimKeys.USER_ID.getValue(), Long::parseLong);
String role = JwtClaimsParserUtil.getClaimsValue(claims, RefreshTokenClaimKeys.ROLE.getValue(), String.class);
log.debug("refresh token userId : {}, role : {}", userId, role);
String deviceId = JwtClaimsParserUtil.getClaimsValue(claims, RefreshTokenClaimKeys.DEVICE_ID.getValue(), String.class);
log.debug("refresh token userId : {}, deviceId: {}, role : {}", userId, deviceId, role);

RefreshToken newRefreshToken;
try {
newRefreshToken = refreshTokenService.refresh(userId, refreshToken, refreshTokenProvider.generateToken(RefreshTokenClaim.of(userId, role)));
newRefreshToken = refreshTokenService.refresh(userId, deviceId, refreshToken, refreshTokenProvider.generateToken(RefreshTokenClaim.of(userId, deviceId, role)));
log.debug("new refresh token : {}", newRefreshToken.getToken());
} catch (IllegalArgumentException e) {
throw new JwtErrorException(JwtErrorCode.EXPIRED_TOKEN);
Expand Down Expand Up @@ -102,22 +104,23 @@ public void removeAccessTokenAndRefreshToken(Long userId, String accessToken, St
}

if (jwtClaims != null) {
deleteRefreshToken(userId, jwtClaims, refreshToken);
deleteRefreshToken(userId, jwtClaims);
}

deleteAccessToken(userId, accessToken);
}

private void deleteRefreshToken(Long userId, JwtClaims jwtClaims, String refreshToken) {
Long refreshTokenUserId = Long.parseLong((String) jwtClaims.getClaims().get(RefreshTokenClaimKeys.USER_ID.getValue()));
log.info("로그아웃 요청 refresh token id : {}", refreshTokenUserId);
private void deleteRefreshToken(Long userId, JwtClaims jwtClaims) {
Long refreshTokenUserId = JwtClaimsParserUtil.getClaimsValue(jwtClaims, RefreshTokenClaimKeys.USER_ID.getValue(), Long::parseLong);
String refreshTokenDeviceId = JwtClaimsParserUtil.getClaimsValue(jwtClaims, RefreshTokenClaimKeys.DEVICE_ID.getValue(), String.class);
log.info("로그아웃 요청 refresh token userId : {}, deviceId : {}", refreshTokenUserId, refreshTokenDeviceId);

if (!userId.equals(refreshTokenUserId)) {
throw new JwtErrorException(JwtErrorCode.WITHOUT_OWNERSHIP_REFRESH_TOKEN);
}

try {
refreshTokenService.delete(refreshTokenUserId, refreshToken);
refreshTokenService.deleteAll(refreshTokenUserId);
} catch (IllegalArgumentException e) {
log.warn("refresh token not found. id : {}", userId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ public Pair<Long, Jwts> signUp(SignUpReq.Info request) {
UserSyncDto userSync = checkOauthUserNotGeneralSignUp(request.phone());
User user = userGeneralSignService.saveUserWithEncryptedPassword(request, userSync);

return Pair.of(user.getId(), jwtAuthHelper.createToken(user));
return Pair.of(user.getId(), jwtAuthHelper.createToken(user, request.deviceId()));
}

@Transactional(readOnly = true)
public Pair<Long, Jwts> signIn(SignInReq.General request) {
User user = userGeneralSignService.readUserIfValid(request.username(), request.password());

return Pair.of(user.getId(), jwtAuthHelper.createToken(user));
return Pair.of(user.getId(), jwtAuthHelper.createToken(user, request.deviceId()));
}

public Pair<Long, Jwts> refresh(String refreshToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public Pair<Long, Jwts> signIn(Provider provider, SignInReq.Oauth request) {

User user = userOauthSignService.readUser(request.oauthId(), provider);

return (user != null) ? Pair.of(user.getId(), jwtAuthHelper.createToken(user)) : Pair.of(-1L, null);
return (user != null) ? Pair.of(user.getId(), jwtAuthHelper.createToken(user, request.deviceId())) : Pair.of(-1L, null);
}

@Transactional(readOnly = true)
Expand Down Expand Up @@ -67,7 +67,7 @@ public Pair<Long, Jwts> signUp(Provider provider, SignUpReq.OauthInfo request) {
OidcDecodePayload payload = oauthOidcHelper.getPayload(provider, request.oauthId(), request.idToken(), request.nonce());
User user = userOauthSignService.saveUser(request, userSync, provider, payload.sub());

return Pair.of(user.getId(), jwtAuthHelper.createToken(user));
return Pair.of(user.getId(), jwtAuthHelper.createToken(user, request.deviceId()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@

import java.util.Map;

import static kr.co.pennyway.api.common.security.jwt.refresh.RefreshTokenClaimKeys.ROLE;
import static kr.co.pennyway.api.common.security.jwt.refresh.RefreshTokenClaimKeys.USER_ID;
import static kr.co.pennyway.api.common.security.jwt.refresh.RefreshTokenClaimKeys.*;

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public class RefreshTokenClaim implements JwtClaims {
private final Map<String, ?> claims;

public static RefreshTokenClaim of(Long userId, String role) {
public static RefreshTokenClaim of(Long userId, String deviceToken, String role) {
Map<String, Object> claims = Map.of(
USER_ID.getValue(), userId.toString(),
DEVICE_ID.getValue(), deviceToken,
ROLE.getValue(), role
);
return new RefreshTokenClaim(claims);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

public enum RefreshTokenClaimKeys {
USER_ID("id"),
ROLE("role");
ROLE("role"),
DEVICE_ID("deviceId");

private final String value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
import java.util.Date;
import java.util.Map;

import static kr.co.pennyway.api.common.security.jwt.refresh.RefreshTokenClaimKeys.ROLE;
import static kr.co.pennyway.api.common.security.jwt.refresh.RefreshTokenClaimKeys.USER_ID;
import static kr.co.pennyway.api.common.security.jwt.refresh.RefreshTokenClaimKeys.*;

@Slf4j
@Component
Expand Down Expand Up @@ -56,7 +55,11 @@ public String generateToken(JwtClaims claims) {
@Override
public JwtClaims getJwtClaimsFromToken(String token) {
Claims claims = getClaimsFromToken(token);
return RefreshTokenClaim.of(Long.parseLong(claims.get(USER_ID.getValue(), String.class)), claims.get(ROLE.getValue(), String.class));
return RefreshTokenClaim.of(
Long.parseLong(claims.get(USER_ID.getValue(), String.class)),
claims.get(DEVICE_ID.getValue(), String.class),
claims.get(ROLE.getValue(), String.class)
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ void setUp(WebApplicationContext webApplicationContext) {
.build();
}

@DisplayName("[1] 아이디, 이름, 비밀번호, 전화번호, 인증번호 필수 입력")
@DisplayName("[1] 아이디, 이름, 비밀번호, 전화번호, 인증번호, 디바이스 아이디 필수 입력")
@Test
void requiredInputError() throws Exception {
// given
SignUpReq.General request = new SignUpReq.General("", "", "", "", "");
SignUpReq.General request = new SignUpReq.General("", "", "", "", "", "");

// when
ResultActions resultActions = mockMvc.perform(
Expand All @@ -73,6 +73,7 @@ void requiredInputError() throws Exception {
.andExpect(jsonPath("$.fieldErrors.password").exists())
.andExpect(jsonPath("$.fieldErrors.phone").exists())
.andExpect(jsonPath("$.fieldErrors.code").exists())
.andExpect(jsonPath("$.fieldErrors.deviceId").exists())
.andDo(print());
}

Expand All @@ -81,7 +82,7 @@ void requiredInputError() throws Exception {
void idValidError() throws Exception {
// given
SignUpReq.General request = new SignUpReq.General("#pennyway", "페니웨이", "pennyway1234",
"010-1234-5678", "123456");
"010-1234-5678", "123456", "AA-BBB-CCC");

// when
ResultActions resultActions = mockMvc.perform(
Expand All @@ -102,7 +103,7 @@ void idValidError() throws Exception {
void nameValidError() throws Exception {
// given
SignUpReq.General request = new SignUpReq.General("pennyway", "페니웨이12345", "pennyway1234",
"010-1234-5678", "123456");
"010-1234-5678", "123456", "AA-BBB-CCC");

// when
ResultActions resultActions = mockMvc.perform(
Expand All @@ -123,7 +124,7 @@ void nameValidError() throws Exception {
void passwordValidError() throws Exception {
// given
SignUpReq.General request = new SignUpReq.General("pennyway", "페니웨이", "pennyway",
"010-1234-5678", "123456");
"010-1234-5678", "123456", "AA-BBB-CCC");

// when
ResultActions resultActions = mockMvc.perform(
Expand All @@ -145,7 +146,7 @@ void passwordValidError() throws Exception {
void phoneValidError() throws Exception {
// given
SignUpReq.General request = new SignUpReq.General("pennyway", "페니웨이", "pennyway1234",
"01012345673", "123456");
"01012345673", "123456", "AA-BBB-CCC");

// when
ResultActions resultActions = mockMvc.perform(
Expand All @@ -166,7 +167,7 @@ void phoneValidError() throws Exception {
void codeValidError() throws Exception {
// given
SignUpReq.General request = new SignUpReq.General("pennyway", "페니웨이", "pennyway1234",
"010-1234-5678", "12345");
"010-1234-5678", "12345", "AA-BBB-CCC");

// when
ResultActions resultActions = mockMvc.perform(
Expand All @@ -187,7 +188,7 @@ void codeValidError() throws Exception {
void someFieldMissingError() throws Exception {
// given
SignUpReq.General request = new SignUpReq.General("pennyway", "페니웨이", "pennyway1234",
"010-1234-5678", "123456");
"010-1234-5678", "123456", "AA-BBB-CCC");

// when
ResultActions resultActions = mockMvc.perform(
Expand All @@ -210,7 +211,7 @@ void someFieldMissingError() throws Exception {
void signUp() throws Exception {
// given
SignUpReq.General request = new SignUpReq.General("pennyway123", "페니웨이", "pennyway1234",
"010-1234-5678", "123456");
"010-1234-5678", "123456", "AA-BBB-CCC");
ResponseCookie expectedCookie = ResponseCookie.from("refreshToken", "refreshToken")
.maxAge(Duration.ofDays(7).toSeconds()).httpOnly(true).path("/").build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ public void RefreshTokenRefreshSuccess() {
// given
RefreshToken refreshToken = RefreshToken.builder()
.userId(1L)
.deviceId("AA-BBB-CC-DDD")
.token("refreshToken")
.ttl(1000L)
.build();
refreshTokenRepository.save(refreshToken);
given(refreshTokenProvider.getJwtClaimsFromToken(refreshToken.getToken())).willReturn(RefreshTokenClaim.of(refreshToken.getUserId(), Role.USER.getType()));
given(refreshTokenProvider.getJwtClaimsFromToken(refreshToken.getToken())).willReturn(RefreshTokenClaim.of(refreshToken.getUserId(), refreshToken.getDeviceId(), Role.USER.getType()));
given(accessTokenProvider.generateToken(any())).willReturn("newAccessToken");
given(refreshTokenProvider.generateToken(any())).willReturn("newRefreshToken");

Expand All @@ -76,7 +77,7 @@ public void RefreshTokenRefreshSuccess() {
assertEquals("사용자 아이디가 일치하지 않습니다.", refreshToken.getUserId(), jwts.getLeft());
assertEquals("갱신된 액세스 토큰이 일치하지 않습니다.", "newAccessToken", jwts.getRight().accessToken());
assertEquals("리프레시 토큰이 갱신되지 않았습니다.", "newRefreshToken", jwts.getRight().refreshToken());
log.info("갱신된 리프레시 토큰 정보 : {}", refreshTokenRepository.findById(refreshToken.getUserId()).orElse(null));
log.info("갱신된 리프레시 토큰 정보 : {}", refreshTokenRepository.findById(refreshToken.getId()).orElse(null));
}

@Test
Expand All @@ -85,19 +86,20 @@ public void RefreshTokenRefreshFail() {
// given
RefreshToken refreshToken = RefreshToken.builder()
.userId(1L)
.deviceId("AA-BBB-CC-DDD")
.token("refreshToken")
.ttl(1000L)
.build();
refreshTokenRepository.save(refreshToken);

given(refreshTokenProvider.getJwtClaimsFromToken("anotherRefreshToken")).willReturn(RefreshTokenClaim.of(refreshToken.getUserId(), Role.USER.toString()));
given(refreshTokenProvider.getJwtClaimsFromToken("anotherRefreshToken")).willReturn(RefreshTokenClaim.of(refreshToken.getUserId(), refreshToken.getDeviceId(), Role.USER.toString()));
given(refreshTokenProvider.generateToken(any())).willReturn("newRefreshToken");

// when
JwtErrorException jwtErrorException = assertThrows(JwtErrorException.class, () -> jwtAuthHelper.refresh("anotherRefreshToken"));

// then
assertEquals("탈취 시나리오 예외가 발생하지 않았습니다.", JwtErrorCode.TAKEN_AWAY_TOKEN, jwtErrorException.getErrorCode());
assertFalse("리프레시 토큰이 삭제되지 않았습니다.", refreshTokenRepository.existsById(refreshToken.getUserId()));
assertFalse("리프레시 토큰이 삭제되지 않았습니다.", refreshTokenRepository.existsById(refreshToken.getId()));
}
}
Loading
Loading