diff --git a/packages/backend/src/controller/csrfCheck.ts b/packages/backend/src/controller/csrfCheck.ts index 727b65b..f5018d5 100644 --- a/packages/backend/src/controller/csrfCheck.ts +++ b/packages/backend/src/controller/csrfCheck.ts @@ -5,18 +5,27 @@ import { HttpUtils } from "../utils/http"; const httpUtils = new HttpUtils(); export class CsrfCheck { + private isTargetUri(uri: string): boolean { + if ( + uri.includes("client_id=") && + (uri.includes("response_type=") || + uri.includes("grant_type=") || + uri.includes("redirect_uri=") || + uri.includes("scope=") || + uri.includes("state=") || + uri.includes("nonce=")) + ) { + return true; + } + + return false; + } + private isOauthUri(request: Request): boolean { - const query = request.getQuery() || ""; + const uri = request.getUrl() || ""; // Check if the request is an OAuth authorization request - if ( - query.includes("client_id=") && - (query.includes("response_type=") || - query.includes("grant_type=") || - query.includes("redirect_uri=") || - query.includes("scope=") || - query.includes("state=")) - ) { + if (this.isTargetUri(uri)) { return true; } @@ -25,23 +34,10 @@ export class CsrfCheck { private isOauthRedirectResponse(response: Response): boolean { const status = response.getCode(); - const locationHeader = httpUtils.getHeaderValue( - response.getHeaders(), - "location" - ); + const uri = + httpUtils.getHeaderValue(response.getHeaders(), "location") || ""; - if ( - status >= 300 && - status < 400 && - locationHeader && - (locationHeader.includes("client_id=") || - locationHeader.includes("response_type=") || - locationHeader.includes("grant_type=") || - locationHeader.includes("redirect_uri=") || - locationHeader.includes("scope=") || - locationHeader.includes("state=") || - locationHeader.includes("code=")) // code is also common in OAuth redirects - ) { + if (status >= 300 && status < 400 && this.isTargetUri(uri)) { return true; } return false; @@ -49,7 +45,9 @@ export class CsrfCheck { private isStateInQuery(request: Request): boolean { const query = request.getQuery(); - const stateValue = httpUtils.getQueryParam(query || "", "state"); + const stateValue = + httpUtils.getQueryParam(query || "", "state") || + httpUtils.getQueryParam(query || "", "nonce"); if (!stateValue) { return false; } @@ -72,17 +70,18 @@ export class CsrfCheck { // 요청에서 보낸 state 추출 const query = request.getQuery() || ""; - const originalState = httpUtils.getQueryParam(query, "state"); + const originalState = + httpUtils.getQueryParam(query, "state") || + httpUtils.getQueryParam(query || "", "nonce"); // 리다이렉트 URL에서 쿼리 부분만 추출 const locationHeader = httpUtils.getHeaderValue( response.getHeaders(), "location" ); - const responseState = httpUtils.getQueryParamFromURI( - locationHeader || "", - "state" - ); + const responseState = + httpUtils.getQueryParamFromURI(locationHeader || "", "state") || + httpUtils.getQueryParamFromURI(locationHeader || "", "nonce"); // state가 없거나, 요청값과 다르면 CSRF 위험 if (!responseState) {