diff --git a/packages/cli/src/commands/login/apply-login.mts b/packages/cli/src/commands/login/apply-login.mts index 3645c8682..471eb3597 100644 --- a/packages/cli/src/commands/login/apply-login.mts +++ b/packages/cli/src/commands/login/apply-login.mts @@ -2,18 +2,52 @@ import { CONFIG_KEY_API_BASE_URL, CONFIG_KEY_API_PROXY, CONFIG_KEY_API_TOKEN, + CONFIG_KEY_AUTH_BASE_URL, CONFIG_KEY_ENFORCED_ORGS, + CONFIG_KEY_OAUTH_CLIENT_ID, + CONFIG_KEY_OAUTH_REDIRECT_URI, + CONFIG_KEY_OAUTH_REFRESH_TOKEN, + CONFIG_KEY_OAUTH_SCOPES, + CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, } from '../../constants/config.mts' import { updateConfigValue } from '../../utils/config.mts' -export function applyLogin( - apiToken: string, - enforcedOrgs: string[], - apiBaseUrl: string | undefined, - apiProxy: string | undefined, -) { - updateConfigValue(CONFIG_KEY_ENFORCED_ORGS, enforcedOrgs) - updateConfigValue(CONFIG_KEY_API_TOKEN, apiToken) - updateConfigValue(CONFIG_KEY_API_BASE_URL, apiBaseUrl) - updateConfigValue(CONFIG_KEY_API_PROXY, apiProxy) +export function applyLogin(params: { + apiToken: string + enforcedOrgs: string[] + apiBaseUrl: string | undefined + apiProxy: string | undefined + authBaseUrl?: string | null | undefined + oauthClientId?: string | null | undefined + oauthRedirectUri?: string | null | undefined + oauthRefreshToken?: string | null | undefined + oauthScopes?: string[] | readonly string[] | null | undefined + oauthTokenExpiresAt?: number | null | undefined +}) { + updateConfigValue(CONFIG_KEY_ENFORCED_ORGS, params.enforcedOrgs) + updateConfigValue(CONFIG_KEY_API_TOKEN, params.apiToken) + updateConfigValue(CONFIG_KEY_API_BASE_URL, params.apiBaseUrl) + updateConfigValue(CONFIG_KEY_API_PROXY, params.apiProxy) + + if (params.authBaseUrl !== undefined) { + updateConfigValue(CONFIG_KEY_AUTH_BASE_URL, params.authBaseUrl) + } + if (params.oauthClientId !== undefined) { + updateConfigValue(CONFIG_KEY_OAUTH_CLIENT_ID, params.oauthClientId) + } + if (params.oauthRedirectUri !== undefined) { + updateConfigValue(CONFIG_KEY_OAUTH_REDIRECT_URI, params.oauthRedirectUri) + } + if (params.oauthRefreshToken !== undefined) { + updateConfigValue(CONFIG_KEY_OAUTH_REFRESH_TOKEN, params.oauthRefreshToken) + } + if (params.oauthScopes !== undefined) { + updateConfigValue(CONFIG_KEY_OAUTH_SCOPES, params.oauthScopes) + } + if (params.oauthTokenExpiresAt !== undefined) { + updateConfigValue( + CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, + params.oauthTokenExpiresAt, + ) + } } diff --git a/packages/cli/src/commands/login/attempt-login.mts b/packages/cli/src/commands/login/attempt-login.mts index be38b4fee..603e081fa 100644 --- a/packages/cli/src/commands/login/attempt-login.mts +++ b/packages/cli/src/commands/login/attempt-login.mts @@ -2,19 +2,27 @@ import { joinAnd } from '@socketsecurity/lib/arrays' import { SOCKET_PUBLIC_API_TOKEN } from '@socketsecurity/lib/constants/socket' import { getDefaultLogger } from '@socketsecurity/lib/logger' import { confirm, password, select } from '@socketsecurity/lib/stdio/prompts' +import { isNonEmptyString } from '@socketsecurity/lib/strings' import { applyLogin } from './apply-login.mts' +import { oauthLogin } from './oauth-login.mts' import { CONFIG_KEY_API_BASE_URL, CONFIG_KEY_API_PROXY, CONFIG_KEY_API_TOKEN, + CONFIG_KEY_AUTH_BASE_URL, CONFIG_KEY_DEFAULT_ORG, + CONFIG_KEY_OAUTH_CLIENT_ID, + CONFIG_KEY_OAUTH_REDIRECT_URI, + CONFIG_KEY_OAUTH_SCOPES, } from '../../constants/config.mts' +import ENV from '../../constants/env.mts' import { getConfigValueOrUndef, isConfigFromFlag, updateConfigValue, } from '../../utils/config.mts' +import { deriveAuthBaseUrlFromApiBaseUrl } from '../../utils/auth/oauth.mts' import { failMsgWithBadge } from '../../utils/error/fail-msg-with-badge.mts' import { getEnterpriseOrgs, getOrgSlugs } from '../../utils/organization.mts' import { setupSdk } from '../../utils/socket/sdk.mjs' @@ -23,27 +31,150 @@ import { setupTabCompletion } from '../install/setup-tab-completion.mts' import { fetchOrganization } from '../organization/fetch-organization-list.mts' import type { Choice } from '@socketsecurity/lib/stdio/prompts' +import requirements from '../../../data/command-api-requirements.json' with { + type: 'json', +} const logger = getDefaultLogger() type OrgChoice = Choice type OrgChoices = OrgChoice[] +type LoginMethod = 'oauth' | 'token' + +function getDefaultOAuthScopes(): string[] { + const permissions: string[] = [] + const api = (requirements as any)?.api ?? {} + for (const value of Object.values(api) as any[]) { + const perms = (value?.permissions ?? []) as unknown + if (Array.isArray(perms)) { + for (const p of perms) { + if (typeof p === 'string' && p) { + permissions.push(p) + } + } + } + } + return [...new Set(permissions)].sort() +} + +function parseScopes(value: unknown): string[] | undefined { + if (Array.isArray(value)) { + return value + .filter((v): v is string => typeof v === 'string' && v.length > 0) + .sort() + } + if (!isNonEmptyString(String(value ?? ''))) { + return undefined + } + const raw = String(value) + return raw + .split(/[,\s]+/u) + .map(s => s.trim()) + .filter(Boolean) +} + export async function attemptLogin( apiBaseUrl: string | undefined, apiProxy: string | undefined, + options?: { + method?: LoginMethod | undefined + authBaseUrl?: string | undefined + oauthClientId?: string | undefined + oauthRedirectUri?: string | undefined + oauthScopes?: string | undefined + }, ) { apiBaseUrl ??= getConfigValueOrUndef(CONFIG_KEY_API_BASE_URL) ?? undefined apiProxy ??= getConfigValueOrUndef(CONFIG_KEY_API_PROXY) ?? undefined - const apiTokenInput = await password({ - message: `Enter your ${socketDocsLink('/docs/api-keys', 'Socket.dev API token')} (leave blank to use a limited public token)`, - }) + const method: LoginMethod = options?.method ?? 'oauth' - if (apiTokenInput === undefined) { - logger.fail('Canceled by user') - return { ok: false, message: 'Canceled', cause: 'Canceled by user' } - } + let apiToken: string + let oauthRefreshToken: string | null | undefined + let oauthTokenExpiresAt: number | null | undefined + let authBaseUrl: string | null | undefined + let oauthClientId: string | null | undefined + let oauthRedirectUri: string | null | undefined + let oauthScopes: string[] | null | undefined - const apiToken = apiTokenInput || SOCKET_PUBLIC_API_TOKEN + if (method === 'token') { + const apiTokenInput = await password({ + message: `Enter your ${socketDocsLink('/docs/api-keys', 'Socket.dev API token')} (leave blank to use a limited public token)`, + }) + + if (apiTokenInput === undefined) { + logger.fail('Canceled by user') + return { ok: false, message: 'Canceled', cause: 'Canceled by user' } + } + + apiToken = apiTokenInput || SOCKET_PUBLIC_API_TOKEN + + // Explicitly disable OAuth refresh flow when using a legacy org-wide token. + oauthRefreshToken = null + oauthTokenExpiresAt = null + authBaseUrl = null + oauthClientId = null + oauthRedirectUri = null + oauthScopes = null + } else { + const resolvedAuthBaseUrl = + options?.authBaseUrl || + ENV.SOCKET_CLI_AUTH_BASE_URL || + getConfigValueOrUndef(CONFIG_KEY_AUTH_BASE_URL) || + deriveAuthBaseUrlFromApiBaseUrl(apiBaseUrl) + + if (!isNonEmptyString(resolvedAuthBaseUrl)) { + process.exitCode = 1 + logger.fail( + 'OAuth auth base URL is not configured. Provide --auth-base-url or set SOCKET_CLI_AUTH_BASE_URL.', + ) + return + } + + const resolvedClientId = + options?.oauthClientId || + ENV.SOCKET_CLI_OAUTH_CLIENT_ID || + getConfigValueOrUndef(CONFIG_KEY_OAUTH_CLIENT_ID) || + 'socket-cli' + + const resolvedRedirectUri = + options?.oauthRedirectUri || + ENV.SOCKET_CLI_OAUTH_REDIRECT_URI || + getConfigValueOrUndef(CONFIG_KEY_OAUTH_REDIRECT_URI) || + 'http://127.0.0.1:53682/callback' + + const resolvedScopes = + parseScopes( + options?.oauthScopes || + ENV.SOCKET_CLI_OAUTH_SCOPES || + getConfigValueOrUndef(CONFIG_KEY_OAUTH_SCOPES) || + getDefaultOAuthScopes(), + ) ?? [] + + logger.log( + `Opening your browser to complete login (client_id: ${resolvedClientId})...`, + ) + + const oauthResult = await oauthLogin({ + authBaseUrl: resolvedAuthBaseUrl, + clientId: resolvedClientId, + redirectUri: resolvedRedirectUri, + scopes: resolvedScopes, + apiProxy, + }) + if (!oauthResult.ok) { + process.exitCode = 1 + logger.fail(failMsgWithBadge(oauthResult.message, oauthResult.cause)) + return + } + + apiToken = oauthResult.data.accessToken + oauthRefreshToken = oauthResult.data.refreshToken + oauthTokenExpiresAt = oauthResult.data.expiresAt + authBaseUrl = resolvedAuthBaseUrl + oauthClientId = resolvedClientId + oauthRedirectUri = resolvedRedirectUri + oauthScopes = resolvedScopes + } const sockSdkCResult = await setupSdk({ apiBaseUrl, apiProxy, apiToken }) if (!sockSdkCResult.ok) { @@ -155,7 +286,18 @@ export async function attemptLogin( const previousPersistedToken = getConfigValueOrUndef(CONFIG_KEY_API_TOKEN) try { - applyLogin(apiToken, enforcedOrgs, apiBaseUrl, apiProxy) + applyLogin({ + apiToken, + enforcedOrgs, + apiBaseUrl, + apiProxy, + authBaseUrl, + oauthClientId, + oauthRedirectUri, + oauthRefreshToken, + oauthScopes, + oauthTokenExpiresAt, + }) logger.success( `API credentials ${previousPersistedToken === apiToken ? 'refreshed' : previousPersistedToken ? 'updated' : 'set'}`, ) diff --git a/packages/cli/src/commands/login/cmd-login.mts b/packages/cli/src/commands/login/cmd-login.mts index c5552a1ff..847be5531 100644 --- a/packages/cli/src/commands/login/cmd-login.mts +++ b/packages/cli/src/commands/login/cmd-login.mts @@ -20,7 +20,7 @@ const logger = getDefaultLogger() export const CMD_NAME = 'login' -const description = 'Setup Socket CLI with an API token and defaults' +const description = 'Authenticate Socket CLI and store credentials' const hidden = false @@ -41,6 +41,11 @@ async function run( hidden, flags: { ...commonFlags, + method: { + type: 'string', + default: 'oauth', + description: 'Login method: oauth (default) or token (legacy)', + }, apiBaseUrl: { type: 'string', default: '', @@ -51,6 +56,29 @@ async function run( default: '', description: 'Proxy to use when making connection to API server', }, + authBaseUrl: { + type: 'string', + default: '', + description: + 'OAuth authorization server base URL (defaults to derived from apiBaseUrl)', + }, + oauthClientId: { + type: 'string', + default: '', + description: 'OAuth client_id (defaults to socket-cli)', + }, + oauthRedirectUri: { + type: 'string', + default: '', + description: + 'OAuth redirect URI (must match registered redirect URIs for client)', + }, + oauthScopes: { + type: 'string', + default: '', + description: + 'OAuth scopes to request (space or comma separated; defaults to CLI-required scopes)', + }, }, help: (command, config) => ` Usage @@ -59,13 +87,16 @@ async function run( API Token Requirements ${getFlagApiRequirementsOutput(`${parentName}:${CMD_NAME}`)} - Logs into the Socket API by prompting for an API token + Logs into the Socket API using a browser-based OAuth flow (default). + Use --method=token to enter an API token manually (legacy). Options ${getFlagListOutput(config.flags)} Examples $ ${command} + $ ${command} --method=token + $ ${command} --auth-base-url=https://api.socket.dev --oauth-client-id=socket-cli $ ${command} --api-proxy=http://localhost:1234 `, } @@ -86,14 +117,47 @@ async function run( if (!isInteractive()) { throw new InputError( - 'Cannot prompt for credentials in a non-interactive shell. Use SOCKET_CLI_API_TOKEN environment variable instead', + 'Cannot complete interactive login in a non-interactive shell. Use SOCKET_CLI_API_TOKEN environment variable instead', ) } - const { apiBaseUrl, apiProxy } = cli.flags as unknown as { + const { + apiBaseUrl, + apiProxy, + authBaseUrl, + method, + oauthClientId, + oauthRedirectUri, + oauthScopes, + } = cli.flags as unknown as { apiBaseUrl?: string | undefined apiProxy?: string | undefined + authBaseUrl?: string | undefined + method?: string | undefined + oauthClientId?: string | undefined + oauthRedirectUri?: string | undefined + oauthScopes?: string | undefined + } + + let normalizedMethod: 'oauth' | 'token' | undefined + if (method === 'oauth' || method === 'token') { + normalizedMethod = method + } else if (!method) { + normalizedMethod = undefined + } else { + normalizedMethod = undefined + } + if (method && !normalizedMethod) { + throw new InputError( + `Invalid --method value: ${method}. Expected "oauth" or "token".`, + ) } - await attemptLogin(apiBaseUrl, apiProxy) + await attemptLogin(apiBaseUrl, apiProxy, { + method: normalizedMethod, + authBaseUrl: authBaseUrl || undefined, + oauthClientId: oauthClientId || undefined, + oauthRedirectUri: oauthRedirectUri || undefined, + oauthScopes: oauthScopes || undefined, + }) } diff --git a/packages/cli/src/commands/login/oauth-login.mts b/packages/cli/src/commands/login/oauth-login.mts new file mode 100644 index 000000000..82d1e3eef --- /dev/null +++ b/packages/cli/src/commands/login/oauth-login.mts @@ -0,0 +1,322 @@ +import { createHash, randomBytes } from 'node:crypto' +import http from 'node:http' +import { setTimeout as wait } from 'node:timers/promises' + +import open from 'open' + +import { + exchangeAuthorizationCodeForToken, + fetchOAuthAuthorizationServerMetadata, +} from '../../utils/auth/oauth.mts' + +import type { CResult } from '../../types.mts' + +type OAuthLoginResult = { + accessToken: string + refreshToken: string + expiresAt: number + scope?: string | undefined +} + +function randomBase64Url(bytes = 32): string { + return randomBytes(bytes).toString('base64url') +} + +function sha256Base64Url(value: string): string { + return createHash('sha256').update(value).digest('base64url') +} + +function buildAuthorizeUrl(params: { + authorizationEndpoint: string + clientId: string + redirectUri: string + scopes: string[] + state: string + codeChallenge: string +}): string { + const url = new URL(params.authorizationEndpoint) + url.searchParams.set('response_type', 'code') + url.searchParams.set('client_id', params.clientId) + url.searchParams.set('redirect_uri', params.redirectUri) + if (params.scopes.length) { + url.searchParams.set('scope', params.scopes.join(' ')) + } + url.searchParams.set('state', params.state) + url.searchParams.set('code_challenge', params.codeChallenge) + url.searchParams.set('code_challenge_method', 'S256') + return url.toString() +} + +async function waitForCallback(params: { + redirectUri: string + expectedState: string + timeoutMs: number +}): Promise<{ + ready: Promise> + result: Promise> + close: () => void +}> { + let redirect: URL + try { + redirect = new URL(params.redirectUri) + } catch { + return { + ready: Promise.resolve({ + ok: false, + message: 'Invalid OAuth redirect URI', + cause: `Not a valid URL: ${params.redirectUri}`, + }), + result: Promise.resolve({ + ok: false, + message: 'Invalid OAuth redirect URI', + cause: `Not a valid URL: ${params.redirectUri}`, + }), + close: () => {}, + } + } + + if (redirect.protocol !== 'http:') { + const err: CResult = { + ok: false, + message: 'Invalid OAuth redirect URI', + cause: 'Redirect URI must use http:// for loopback redirect handling', + } + return { + ready: Promise.resolve(err), + result: Promise.resolve(err), + close: () => {}, + } + } + + const port = redirect.port ? Number(redirect.port) : 80 + if (!Number.isFinite(port) || port <= 0 || port > 65535) { + const err: CResult = { + ok: false, + message: 'Invalid OAuth redirect URI', + cause: `Invalid port in redirect URI: ${redirect.port || '(empty)'}`, + } + return { + ready: Promise.resolve(err), + result: Promise.resolve(err), + close: () => {}, + } + } + + const expectedPath = redirect.pathname || '/' + const host = redirect.hostname || '127.0.0.1' + if (!['127.0.0.1', 'localhost', '::1'].includes(host)) { + const err: CResult = { + ok: false, + message: 'Invalid OAuth redirect URI', + cause: `Redirect hostname must be a loopback address (got: ${host})`, + } + return { + ready: Promise.resolve(err), + result: Promise.resolve(err), + close: () => {}, + } + } + + let resolved = false + + let readyResolve: ((value: CResult) => void) | undefined + const ready = new Promise>(resolve => { + readyResolve = resolve + }) + + let resultResolve: ((value: CResult<{ code: string }>) => void) | undefined + const result = new Promise>(resolve => { + resultResolve = resolve + }) + + const server = http.createServer((req, res) => { + if (resolved) { + res.statusCode = 200 + res.end() + return + } + + if (!req.url) { + res.statusCode = 400 + res.end() + return + } + + const reqUrl = new URL(req.url, `http://${host}:${port}`) + if (req.method !== 'GET' || reqUrl.pathname !== expectedPath) { + res.statusCode = 404 + res.end() + return + } + + const state = reqUrl.searchParams.get('state') || '' + const code = reqUrl.searchParams.get('code') || '' + if (!code) { + res.statusCode = 400 + res.setHeader('content-type', 'text/plain; charset=utf-8') + res.end('Missing OAuth code') + return + } + if (state !== params.expectedState) { + res.statusCode = 400 + res.setHeader('content-type', 'text/plain; charset=utf-8') + res.end('Invalid OAuth state') + return + } + + resolved = true + res.statusCode = 200 + res.setHeader('content-type', 'text/html; charset=utf-8') + res.end( + 'Socket CLI Login

Login complete

You can close this tab and return to the Socket CLI.

', + ) + + server.close(() => { + resultResolve?.({ ok: true, data: { code } }) + }) + }) + + server.on('error', err => { + if (resolved) { + return + } + resolved = true + const failure: CResult = { + ok: false, + message: 'Failed to start OAuth callback server', + cause: err instanceof Error ? err.message : String(err), + } + readyResolve?.(failure) + resultResolve?.(failure) + }) + + server.listen(port, host, () => { + readyResolve?.({ ok: true, data: undefined }) + void wait(params.timeoutMs).then(() => { + if (resolved) { + return + } + resolved = true + server.close(() => { + resultResolve?.({ + ok: false, + message: 'OAuth login timed out', + cause: `No callback received within ${Math.round(params.timeoutMs / 1000)}s`, + }) + }) + }) + }) + + return { + ready, + result, + close: () => { + if (resolved) { + return + } + resolved = true + server.close(() => { + resultResolve?.({ + ok: false, + message: 'OAuth login canceled', + cause: 'OAuth callback server was closed before receiving a code', + }) + }) + }, + } +} + +export async function oauthLogin(params: { + authBaseUrl: string + clientId: string + redirectUri: string + scopes: string[] + apiProxy?: string | undefined + timeoutMs?: number | undefined +}): Promise> { + const timeoutMs = params.timeoutMs ?? 5 * 60 * 1000 + const metaResult = await fetchOAuthAuthorizationServerMetadata({ + authBaseUrl: params.authBaseUrl, + apiProxy: params.apiProxy, + }) + if (!metaResult.ok) { + return metaResult + } + + const { + authorization_endpoint: authorizationEndpoint, + token_endpoint: tokenEndpoint, + } = metaResult.data + + const codeVerifier = randomBase64Url(32) + const codeChallenge = sha256Base64Url(codeVerifier) + const state = randomBase64Url(16) + + const authorizeUrl = buildAuthorizeUrl({ + authorizationEndpoint, + clientId: params.clientId, + redirectUri: params.redirectUri, + scopes: params.scopes, + state, + codeChallenge, + }) + + const callbackWaiter = await waitForCallback({ + redirectUri: params.redirectUri, + expectedState: state, + timeoutMs, + }) + const readyResult = await callbackWaiter.ready + if (!readyResult.ok) { + return readyResult + } + + try { + await open(authorizeUrl, { wait: false }) + } catch (e) { + callbackWaiter.close() + return { + ok: false, + message: 'Failed to open browser for OAuth login', + cause: e instanceof Error ? e.message : String(e), + } + } + + const callbackResult = await callbackWaiter.result + if (!callbackResult.ok) { + return callbackResult + } + const { code } = callbackResult.data + + const tokenResult = await exchangeAuthorizationCodeForToken({ + tokenEndpoint, + clientId: params.clientId, + code, + redirectUri: params.redirectUri, + codeVerifier, + apiProxy: params.apiProxy, + }) + if (!tokenResult.ok) { + return tokenResult + } + + const token = tokenResult.data + if (!token.refresh_token) { + return { + ok: false, + message: 'OAuth login failed', + cause: 'Server did not return a refresh token', + } + } + + const expiresAt = Date.now() + Math.max(0, token.expires_in) * 1000 + return { + ok: true, + data: { + accessToken: token.access_token, + refreshToken: token.refresh_token, + expiresAt, + scope: token.scope, + }, + } +} diff --git a/packages/cli/src/commands/logout/apply-logout.mts b/packages/cli/src/commands/logout/apply-logout.mts index 242511cb3..c8c11e09b 100644 --- a/packages/cli/src/commands/logout/apply-logout.mts +++ b/packages/cli/src/commands/logout/apply-logout.mts @@ -2,7 +2,13 @@ import { CONFIG_KEY_API_BASE_URL, CONFIG_KEY_API_PROXY, CONFIG_KEY_API_TOKEN, + CONFIG_KEY_AUTH_BASE_URL, CONFIG_KEY_ENFORCED_ORGS, + CONFIG_KEY_OAUTH_CLIENT_ID, + CONFIG_KEY_OAUTH_REDIRECT_URI, + CONFIG_KEY_OAUTH_REFRESH_TOKEN, + CONFIG_KEY_OAUTH_SCOPES, + CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, } from '../../constants/config.mts' import { updateConfigValue } from '../../utils/config.mts' @@ -11,4 +17,10 @@ export function applyLogout() { updateConfigValue(CONFIG_KEY_API_BASE_URL, null) updateConfigValue(CONFIG_KEY_API_PROXY, null) updateConfigValue(CONFIG_KEY_ENFORCED_ORGS, null) + updateConfigValue(CONFIG_KEY_AUTH_BASE_URL, null) + updateConfigValue(CONFIG_KEY_OAUTH_CLIENT_ID, null) + updateConfigValue(CONFIG_KEY_OAUTH_REDIRECT_URI, null) + updateConfigValue(CONFIG_KEY_OAUTH_REFRESH_TOKEN, null) + updateConfigValue(CONFIG_KEY_OAUTH_SCOPES, null) + updateConfigValue(CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, null) } diff --git a/packages/cli/src/constants/config.mts b/packages/cli/src/constants/config.mts index 848ded6c5..86676a8a6 100644 --- a/packages/cli/src/constants/config.mts +++ b/packages/cli/src/constants/config.mts @@ -5,6 +5,12 @@ export const CONFIG_KEY_API_BASE_URL = 'apiBaseUrl' export const CONFIG_KEY_API_PROXY = 'apiProxy' export const CONFIG_KEY_API_TOKEN = 'apiToken' +export const CONFIG_KEY_AUTH_BASE_URL = 'authBaseUrl' export const CONFIG_KEY_DEFAULT_ORG = 'defaultOrg' export const CONFIG_KEY_ENFORCED_ORGS = 'enforcedOrgs' export const CONFIG_KEY_ORG = 'org' +export const CONFIG_KEY_OAUTH_CLIENT_ID = 'oauthClientId' +export const CONFIG_KEY_OAUTH_REDIRECT_URI = 'oauthRedirectUri' +export const CONFIG_KEY_OAUTH_REFRESH_TOKEN = 'oauthRefreshToken' +export const CONFIG_KEY_OAUTH_SCOPES = 'oauthScopes' +export const CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT = 'oauthTokenExpiresAt' diff --git a/packages/cli/src/constants/env.mts b/packages/cli/src/constants/env.mts index 1ae2313b9..6b0ad82a9 100644 --- a/packages/cli/src/constants/env.mts +++ b/packages/cli/src/constants/env.mts @@ -35,6 +35,7 @@ import { getPythonVersion } from '../env/python-version.mts' import { RUN_E2E_TESTS } from '../env/run-e2e-tests.mts' import { getSwfVersion } from '../env/sfw-version.mts' import { SOCKET_CLI_ACCEPT_RISKS } from '../env/socket-cli-accept-risks.mts' +import { SOCKET_CLI_AUTH_BASE_URL } from '../env/socket-cli-auth-base-url.mts' import { SOCKET_CLI_API_BASE_URL } from '../env/socket-cli-api-base-url.mts' import { SOCKET_CLI_API_PROXY } from '../env/socket-cli-api-proxy.mts' import { SOCKET_CLI_API_TIMEOUT } from '../env/socket-cli-api-timeout.mts' @@ -58,6 +59,9 @@ import { SOCKET_CLI_MODELS_PATH } from '../env/socket-cli-models-path.mts' import { SOCKET_CLI_NO_API_TOKEN } from '../env/socket-cli-no-api-token.mts' import { SOCKET_CLI_NPM_PATH } from '../env/socket-cli-npm-path.mts' import { SOCKET_CLI_OPTIMIZE } from '../env/socket-cli-optimize.mts' +import { SOCKET_CLI_OAUTH_CLIENT_ID } from '../env/socket-cli-oauth-client-id.mts' +import { SOCKET_CLI_OAUTH_REDIRECT_URI } from '../env/socket-cli-oauth-redirect-uri.mts' +import { SOCKET_CLI_OAUTH_SCOPES } from '../env/socket-cli-oauth-scopes.mts' import { SOCKET_CLI_ORG_SLUG } from '../env/socket-cli-org-slug.mts' import { SOCKET_CLI_PYCLI_LOCAL_PATH } from '../env/socket-cli-pycli-local-path.mts' import { SOCKET_CLI_SEA_NODE_VERSION } from '../env/socket-cli-sea-node-version.mts' @@ -96,6 +100,7 @@ export { PREBUILT_NODE_DOWNLOAD_URL, RUN_E2E_TESTS, SOCKET_CLI_ACCEPT_RISKS, + SOCKET_CLI_AUTH_BASE_URL, SOCKET_CLI_API_BASE_URL, SOCKET_CLI_API_PROXY, SOCKET_CLI_API_TIMEOUT, @@ -119,6 +124,9 @@ export { SOCKET_CLI_NO_API_TOKEN, SOCKET_CLI_NPM_PATH, SOCKET_CLI_OPTIMIZE, + SOCKET_CLI_OAUTH_CLIENT_ID, + SOCKET_CLI_OAUTH_REDIRECT_URI, + SOCKET_CLI_OAUTH_SCOPES, SOCKET_CLI_ORG_SLUG, SOCKET_CLI_PYCLI_LOCAL_PATH, SOCKET_CLI_SEA_NODE_VERSION, @@ -178,6 +186,7 @@ const envSnapshot = { PREBUILT_NODE_DOWNLOAD_URL, RUN_E2E_TESTS, SOCKET_CLI_ACCEPT_RISKS, + SOCKET_CLI_AUTH_BASE_URL, SOCKET_CLI_API_BASE_URL, SOCKET_CLI_API_PROXY, SOCKET_CLI_API_TIMEOUT, @@ -201,6 +210,9 @@ const envSnapshot = { SOCKET_CLI_NO_API_TOKEN, SOCKET_CLI_NPM_PATH, SOCKET_CLI_OPTIMIZE, + SOCKET_CLI_OAUTH_CLIENT_ID, + SOCKET_CLI_OAUTH_REDIRECT_URI, + SOCKET_CLI_OAUTH_SCOPES, SOCKET_CLI_ORG_SLUG, SOCKET_CLI_PYCLI_LOCAL_PATH, SOCKET_CLI_SEA_NODE_VERSION, diff --git a/packages/cli/src/env/socket-cli-auth-base-url.mts b/packages/cli/src/env/socket-cli-auth-base-url.mts new file mode 100644 index 000000000..a60922b77 --- /dev/null +++ b/packages/cli/src/env/socket-cli-auth-base-url.mts @@ -0,0 +1,5 @@ +/** @fileoverview SOCKET_CLI_AUTH_BASE_URL environment variable. */ + +import { env } from 'node:process' + +export const SOCKET_CLI_AUTH_BASE_URL = env['SOCKET_CLI_AUTH_BASE_URL'] diff --git a/packages/cli/src/env/socket-cli-oauth-client-id.mts b/packages/cli/src/env/socket-cli-oauth-client-id.mts new file mode 100644 index 000000000..149daaa12 --- /dev/null +++ b/packages/cli/src/env/socket-cli-oauth-client-id.mts @@ -0,0 +1,5 @@ +/** @fileoverview SOCKET_CLI_OAUTH_CLIENT_ID environment variable. */ + +import { env } from 'node:process' + +export const SOCKET_CLI_OAUTH_CLIENT_ID = env['SOCKET_CLI_OAUTH_CLIENT_ID'] diff --git a/packages/cli/src/env/socket-cli-oauth-redirect-uri.mts b/packages/cli/src/env/socket-cli-oauth-redirect-uri.mts new file mode 100644 index 000000000..aa7f24676 --- /dev/null +++ b/packages/cli/src/env/socket-cli-oauth-redirect-uri.mts @@ -0,0 +1,6 @@ +/** @fileoverview SOCKET_CLI_OAUTH_REDIRECT_URI environment variable. */ + +import { env } from 'node:process' + +export const SOCKET_CLI_OAUTH_REDIRECT_URI = + env['SOCKET_CLI_OAUTH_REDIRECT_URI'] diff --git a/packages/cli/src/env/socket-cli-oauth-scopes.mts b/packages/cli/src/env/socket-cli-oauth-scopes.mts new file mode 100644 index 000000000..7dec1599d --- /dev/null +++ b/packages/cli/src/env/socket-cli-oauth-scopes.mts @@ -0,0 +1,5 @@ +/** @fileoverview SOCKET_CLI_OAUTH_SCOPES environment variable. */ + +import { env } from 'node:process' + +export const SOCKET_CLI_OAUTH_SCOPES = env['SOCKET_CLI_OAUTH_SCOPES'] diff --git a/packages/cli/src/utils/auth/oauth.mts b/packages/cli/src/utils/auth/oauth.mts new file mode 100644 index 000000000..b2a46a6f3 --- /dev/null +++ b/packages/cli/src/utils/auth/oauth.mts @@ -0,0 +1,265 @@ +import http from 'node:http' +import https from 'node:https' + +import { HttpProxyAgent, HttpsProxyAgent } from 'hpagent' + +import { isNonEmptyString } from '@socketsecurity/lib/strings' +import { isUrl } from '@socketsecurity/lib/url' + +import type { CResult } from '../../types.mts' + +export type OAuthAuthorizationServerMetadata = { + issuer: string + authorization_endpoint: string + token_endpoint: string + introspection_endpoint?: string | undefined + response_types_supported?: string[] | undefined + grant_types_supported?: string[] | undefined + code_challenge_methods_supported?: string[] | undefined + token_endpoint_auth_methods_supported?: string[] | undefined +} + +export type OAuthTokenResponse = { + access_token: string + token_type: 'Bearer' + expires_in: number + refresh_token?: string | undefined + scope?: string | undefined +} + +export function normalizeUrlBase(value: string): string { + return value.replace(/\/+$/u, '') +} + +export function joinUrl(base: string, path: string): string { + const normalizedBase = normalizeUrlBase(base) + const normalizedPath = path.startsWith('/') ? path : `/${path}` + return `${normalizedBase}${normalizedPath}` +} + +export function deriveAuthBaseUrlFromApiBaseUrl( + apiBaseUrl: string | undefined, +): string | undefined { + if (!apiBaseUrl) { + return undefined + } + if (!isUrl(apiBaseUrl)) { + return undefined + } + const url = new URL(apiBaseUrl) + + const normalizedPath = url.pathname.replace(/\/+$/u, '') + const strippedPath = normalizedPath.replace(/\/v0$/u, '') + + url.pathname = strippedPath || '/' + url.search = '' + url.hash = '' + return normalizeUrlBase(url.toString()) +} + +function createProxyDispatcher(params: { + url: string + apiProxy: string | undefined +}): HttpProxyAgent | HttpsProxyAgent | undefined { + const { apiProxy, url } = params + if (!apiProxy || !isUrl(apiProxy)) { + return undefined + } + const ProxyAgent = url.startsWith('http:') ? HttpProxyAgent : HttpsProxyAgent + return new ProxyAgent({ proxy: apiProxy }) +} + +async function requestText(params: { + url: string + method: 'GET' | 'POST' + apiProxy?: string | undefined + headers?: Record | undefined + body?: string | undefined +}): Promise> { + try { + const agent = createProxyDispatcher({ + url: params.url, + apiProxy: params.apiProxy, + }) + const url = new URL(params.url) + const transport = url.protocol === 'http:' ? http : https + + const body = params.body ?? '' + const headers: Record = { + ...(params.headers ?? {}), + ...(params.method === 'POST' + ? { 'content-length': Buffer.byteLength(body).toString() } + : {}), + } + + return await new Promise(resolve => { + const req = transport.request( + url, + { + method: params.method, + headers, + ...(agent ? { agent } : {}), + }, + res => { + const chunks: Buffer[] = [] + res.on('data', chunk => + chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)), + ) + res.on('end', () => { + resolve({ + ok: true, + data: { + status: res.statusCode ?? 0, + statusText: res.statusMessage ?? '', + text: Buffer.concat(chunks).toString('utf8'), + }, + }) + }) + }, + ) + + req.on('error', e => { + resolve({ + ok: false, + message: 'OAuth request failed', + cause: e instanceof Error ? e.message : String(e), + }) + }) + + if (params.method === 'POST') { + req.write(body) + } + req.end() + }) + } catch (e) { + return { + ok: false, + message: 'OAuth request failed', + cause: e instanceof Error ? e.message : String(e), + } + } +} + +async function postFormJson(params: { + url: string + apiProxy?: string | undefined + body: URLSearchParams +}): Promise> { + const resResult = await requestText({ + url: params.url, + method: 'POST', + apiProxy: params.apiProxy, + headers: { + 'content-type': 'application/x-www-form-urlencoded', + }, + body: params.body.toString(), + }) + if (!resResult.ok) { + return resResult + } + + const { status, statusText, text } = resResult.data + if (status < 200 || status >= 300) { + return { + ok: false, + message: `OAuth request failed (HTTP ${status})`, + cause: isNonEmptyString(text) ? text : statusText, + } + } + + try { + return { ok: true, data: JSON.parse(text) as T } + } catch { + return { + ok: false, + message: 'OAuth request failed', + cause: 'Server returned invalid JSON', + } + } +} + +export async function fetchOAuthAuthorizationServerMetadata(params: { + authBaseUrl: string + apiProxy?: string | undefined +}): Promise> { + const url = joinUrl( + params.authBaseUrl, + '/.well-known/oauth-authorization-server', + ) + const resResult = await requestText({ + url, + method: 'GET', + apiProxy: params.apiProxy, + }) + if (!resResult.ok) { + return { + ok: false, + message: 'OAuth metadata request failed', + cause: resResult.cause, + } + } + + const { status, statusText, text } = resResult.data + if (status < 200 || status >= 300) { + return { + ok: false, + message: `OAuth metadata request failed (HTTP ${status})`, + cause: isNonEmptyString(text) ? text : statusText, + } + } + + try { + return { + ok: true, + data: JSON.parse(text) as OAuthAuthorizationServerMetadata, + } + } catch { + return { + ok: false, + message: 'OAuth metadata request failed', + cause: 'Server returned invalid JSON', + } + } +} + +export async function exchangeAuthorizationCodeForToken(params: { + tokenEndpoint: string + clientId: string + code: string + redirectUri: string + codeVerifier: string + apiProxy?: string | undefined +}): Promise> { + const body = new URLSearchParams({ + grant_type: 'authorization_code', + client_id: params.clientId, + code: params.code, + redirect_uri: params.redirectUri, + code_verifier: params.codeVerifier, + }) + + return await postFormJson({ + url: params.tokenEndpoint, + apiProxy: params.apiProxy, + body, + }) +} + +export async function refreshOAuthAccessToken(params: { + tokenEndpoint: string + clientId: string + refreshToken: string + apiProxy?: string | undefined +}): Promise> { + const body = new URLSearchParams({ + grant_type: 'refresh_token', + client_id: params.clientId, + refresh_token: params.refreshToken, + }) + + return await postFormJson({ + url: params.tokenEndpoint, + apiProxy: params.apiProxy, + body, + }) +} diff --git a/packages/cli/src/utils/config.mts b/packages/cli/src/utils/config.mts index f9c5de13a..11624cf47 100644 --- a/packages/cli/src/utils/config.mts +++ b/packages/cli/src/utils/config.mts @@ -36,9 +36,15 @@ import { CONFIG_KEY_API_BASE_URL, CONFIG_KEY_API_PROXY, CONFIG_KEY_API_TOKEN, + CONFIG_KEY_AUTH_BASE_URL, CONFIG_KEY_DEFAULT_ORG, CONFIG_KEY_ENFORCED_ORGS, CONFIG_KEY_ORG, + CONFIG_KEY_OAUTH_CLIENT_ID, + CONFIG_KEY_OAUTH_REDIRECT_URI, + CONFIG_KEY_OAUTH_REFRESH_TOKEN, + CONFIG_KEY_OAUTH_SCOPES, + CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, } from '../constants/config.mts' import { getSocketAppDataPath } from '../constants/paths.mts' import { SOCKET_YAML, SOCKET_YML } from '../constants/socket.mts' @@ -53,8 +59,14 @@ export interface LocalConfig { apiBaseUrl?: string | null | undefined apiProxy?: string | null | undefined apiToken?: string | null | undefined + authBaseUrl?: string | null | undefined defaultOrg?: string | undefined enforcedOrgs?: string[] | readonly string[] | null | undefined + oauthClientId?: string | null | undefined + oauthRedirectUri?: string | null | undefined + oauthRefreshToken?: string | null | undefined + oauthScopes?: string[] | readonly string[] | null | undefined + oauthTokenExpiresAt?: number | null | undefined skipAskToPersistDefaultOrg?: boolean | undefined // Convenience alias for defaultOrg. org?: string | undefined @@ -62,6 +74,7 @@ export interface LocalConfig { const sensitiveConfigKeyLookup: Set = new Set([ CONFIG_KEY_API_TOKEN, + CONFIG_KEY_OAUTH_REFRESH_TOKEN, ]) const supportedConfig: Map = new Map([ @@ -71,6 +84,10 @@ const supportedConfig: Map = new Map([ CONFIG_KEY_API_TOKEN, 'The Socket API token required to access most Socket API endpoints', ], + [ + CONFIG_KEY_AUTH_BASE_URL, + 'Base URL of the OAuth authorization server (used by `socket login`)', + ], [ CONFIG_KEY_DEFAULT_ORG, 'The default org slug to use; usually the org your Socket API token has access to. When set, all orgSlug arguments are implied to be this value.', @@ -79,6 +96,20 @@ const supportedConfig: Map = new Map([ CONFIG_KEY_ENFORCED_ORGS, 'Orgs in this list have their security policies enforced on this machine', ], + [CONFIG_KEY_OAUTH_CLIENT_ID, 'OAuth client_id used by the Socket CLI'], + [ + CONFIG_KEY_OAUTH_REDIRECT_URI, + 'OAuth redirect URI used by the Socket CLI (must match the registered client redirect URIs)', + ], + [CONFIG_KEY_OAUTH_REFRESH_TOKEN, 'OAuth refresh token (sensitive)'], + [ + CONFIG_KEY_OAUTH_SCOPES, + 'OAuth scopes requested during CLI login (array of scope strings)', + ], + [ + CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, + 'OAuth access token expiry timestamp (ms since epoch)', + ], [ 'skipAskToPersistDefaultOrg', 'This flag prevents the Socket CLI from asking you to persist the org slug when you selected one interactively', @@ -109,8 +140,6 @@ function getConfigValues(): LocalConfig { logger.warn(`Failed to parse config at ${configFilePath}`) debugConfig(configFilePath, false, e) } - } else { - safeMkdirSync(socketAppDataPath, { recursive: true }) } } } @@ -348,6 +377,7 @@ export function updateConfigValue( writeFileSync( configFilePath, Buffer.from(jsonContent).toString('base64'), + { mode: 0o600 }, ) } }) diff --git a/packages/cli/src/utils/socket/api-wrapper.mts b/packages/cli/src/utils/socket/api-wrapper.mts index b229aefe3..dbfb78883 100644 --- a/packages/cli/src/utils/socket/api-wrapper.mts +++ b/packages/cli/src/utils/socket/api-wrapper.mts @@ -1,7 +1,11 @@ /** @fileoverview Simplified API wrapper to DRY out repetitive fetch-*.mts files */ import { handleApiCall } from './api.mts' -import { setupSdk } from './sdk.mts' +import { + hasOAuthRefreshTokenConfigured, + refreshOAuthApiTokenFromConfig, + setupSdk, +} from './sdk.mts' import type { BaseFetchOptions, CResult } from '../../types.mts' import type { SocketSdk } from '@socketsecurity/sdk' @@ -22,7 +26,26 @@ export async function apiCall( } const sdk = sdkResult.data - return await handleApiCall((sdk[method] as any)(...args), { description }) + const run = async (sdkInstance: SocketSdk) => + await handleApiCall((sdkInstance[method] as any)(...args), { description }) + + let result = await run(sdk) + const code = (result as any)?.data?.code + + if (code === 401 && hasOAuthRefreshTokenConfigured()) { + const refreshResult = await refreshOAuthApiTokenFromConfig({ + apiBaseUrl: options?.sdkOpts?.apiBaseUrl, + apiProxy: options?.sdkOpts?.apiProxy, + }) + if (refreshResult.ok) { + const refreshedSdkResult = await setupSdk(options?.sdkOpts) + if (refreshedSdkResult.ok) { + result = await run(refreshedSdkResult.data) + } + } + } + + return result } /** diff --git a/packages/cli/src/utils/socket/api.mts b/packages/cli/src/utils/socket/api.mts index 22713124b..e6e954a47 100644 --- a/packages/cli/src/utils/socket/api.mts +++ b/packages/cli/src/utils/socket/api.mts @@ -26,7 +26,11 @@ import { getDefaultLogger } from '@socketsecurity/lib/logger' import { getDefaultSpinner } from '@socketsecurity/lib/spinner' import { isNonEmptyString } from '@socketsecurity/lib/strings' -import { getDefaultApiToken } from './sdk.mts' +import { + getDefaultApiToken, + hasOAuthRefreshTokenConfigured, + refreshOAuthApiTokenFromConfig, +} from './sdk.mts' import { CONFIG_KEY_API_BASE_URL } from '../../constants/config.mts' import ENV from '../../constants/env.mts' import { @@ -67,6 +71,10 @@ const logger = getDefaultLogger() const NO_ERROR_MESSAGE = 'No error message returned' +function getAuthorizationHeaderValue(apiToken: string): string { + return `Bearer ${apiToken}` +} + export type CommandRequirements = { permissions?: string[] | undefined quota?: number | undefined @@ -358,7 +366,7 @@ export async function queryApi(path: string, apiToken: string) { return await fetch(`${baseUrl}${baseUrl.endsWith('/') ? '' : '/'}${path}`, { method: 'GET', headers: { - Authorization: `Basic ${btoa(`${apiToken}:`)}`, + Authorization: getAuthorizationHeaderValue(apiToken), }, }) } @@ -444,6 +452,36 @@ export async function queryApiSafeText( durationMs, headers: { Authorization: '[REDACTED]' }, }) + // If OAuth is configured and we got a 401, try one refresh+retry. + if (status === 401 && hasOAuthRefreshTokenConfigured()) { + const refreshResult = await refreshOAuthApiTokenFromConfig({ + apiBaseUrl: baseUrl, + apiProxy: undefined, + }) + if (refreshResult.ok) { + const retriedToken = getDefaultApiToken() + if (retriedToken) { + try { + result = await queryApi(path, retriedToken) + } catch { + // ignore; fall through to normal handling below + } + if (result?.ok) { + try { + const data = await result.text() + return { ok: true, data } + } catch { + return { + ok: false, + message: 'API request failed', + cause: 'Unexpected error reading response text', + } + } + } + } + } + } + // Log required permissions for 403 errors when in a command context. if (commandPath && status === 403) { logPermissionsFor403(commandPath) @@ -555,7 +593,7 @@ export async function sendApiRequest( const fetchOptions = { method, headers: { - Authorization: `Basic ${btoa(`${apiToken}:`)}`, + Authorization: getAuthorizationHeaderValue(apiToken), 'Content-Type': 'application/json', }, ...(body ? { body: JSON.stringify(body) } : {}), @@ -626,6 +664,36 @@ export async function sendApiRequest( 'Content-Type': 'application/json', }, }) + // If OAuth is configured and we got a 401, try one refresh+retry. + if (status === 401 && hasOAuthRefreshTokenConfigured()) { + const refreshResult = await refreshOAuthApiTokenFromConfig({ + apiBaseUrl: baseUrl, + apiProxy: undefined, + }) + if (refreshResult.ok) { + const retriedToken = getDefaultApiToken() + if (retriedToken) { + try { + const retryFetchOptions = { + method, + headers: { + Authorization: getAuthorizationHeaderValue(retriedToken), + 'Content-Type': 'application/json', + }, + ...(body ? { body: JSON.stringify(body) } : {}), + } + const retriedResult = await fetch(fullUrl, retryFetchOptions) + if (retriedResult.ok) { + const data = await retriedResult.json() + return { ok: true, data: data as T } + } + } catch { + // ignore; fall through to normal handling below + } + } + } + } + // Log required permissions for 403 errors when in a command context. if (commandPath && status === 403) { logPermissionsFor403(commandPath) diff --git a/packages/cli/src/utils/socket/sdk.mts b/packages/cli/src/utils/socket/sdk.mts index 43a1cc2f3..0f3ca4d36 100644 --- a/packages/cli/src/utils/socket/sdk.mts +++ b/packages/cli/src/utils/socket/sdk.mts @@ -39,12 +39,22 @@ import { CONFIG_KEY_API_BASE_URL, CONFIG_KEY_API_PROXY, CONFIG_KEY_API_TOKEN, + CONFIG_KEY_AUTH_BASE_URL, + CONFIG_KEY_OAUTH_CLIENT_ID, + CONFIG_KEY_OAUTH_REFRESH_TOKEN, + CONFIG_KEY_OAUTH_SCOPES, + CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, } from '../../constants/config.mts' import ENV from '../../constants/env.mts' import { TOKEN_PREFIX_LENGTH } from '../../constants/socket.mts' -import { getConfigValueOrUndef } from '../config.mts' +import { getConfigValueOrUndef, updateConfigValue } from '../config.mts' import { debugApiRequest, debugApiResponse } from '../debug.mts' import { trackCliEvent } from '../telemetry/integration.mts' +import { + deriveAuthBaseUrlFromApiBaseUrl, + fetchOAuthAuthorizationServerMetadata, + refreshOAuthAccessToken, +} from '../auth/oauth.mts' import type { CResult } from '../../types.mts' import type { @@ -118,6 +128,93 @@ export type SetupSdkOptions = { apiToken?: string | undefined } +export function hasOAuthRefreshTokenConfigured(): boolean { + const refreshToken = getConfigValueOrUndef(CONFIG_KEY_OAUTH_REFRESH_TOKEN) + return isNonEmptyString(refreshToken) +} + +export async function refreshOAuthApiTokenFromConfig(params: { + apiBaseUrl: string | undefined + apiProxy: string | undefined +}): Promise> { + const refreshToken = getConfigValueOrUndef(CONFIG_KEY_OAUTH_REFRESH_TOKEN) + const clientId = getConfigValueOrUndef(CONFIG_KEY_OAUTH_CLIENT_ID) + const storedAuthBaseUrl = getConfigValueOrUndef(CONFIG_KEY_AUTH_BASE_URL) + + if (!isNonEmptyString(refreshToken) || !isNonEmptyString(clientId)) { + return { + ok: false, + message: 'Auth Error', + cause: 'OAuth refresh token is not configured. Run `socket login`.', + } + } + + const derivedAuthBaseUrl = deriveAuthBaseUrlFromApiBaseUrl(params.apiBaseUrl) + const authBaseUrl = + (isNonEmptyString(storedAuthBaseUrl) ? storedAuthBaseUrl : undefined) ?? + derivedAuthBaseUrl + + if (!isNonEmptyString(authBaseUrl)) { + return { + ok: false, + message: 'Auth Error', + cause: + 'OAuth authBaseUrl is not configured. Run `socket login` or set SOCKET_CLI_AUTH_BASE_URL.', + } + } + + const metaResult = await fetchOAuthAuthorizationServerMetadata({ + authBaseUrl, + apiProxy: params.apiProxy, + }) + if (!metaResult.ok) { + return { + ok: false, + message: metaResult.message, + cause: metaResult.cause, + } + } + + const tokenEndpoint = metaResult.data.token_endpoint + const refreshed = await refreshOAuthAccessToken({ + tokenEndpoint, + clientId, + refreshToken, + apiProxy: params.apiProxy, + }) + if (!refreshed.ok) { + return { + ok: false, + message: refreshed.message, + cause: + refreshed.cause || + 'OAuth refresh failed. Run `socket login` to re-authenticate.', + } + } + + const nextRefreshToken = + refreshed.data.refresh_token && + isNonEmptyString(refreshed.data.refresh_token) + ? refreshed.data.refresh_token + : refreshToken + + const expiresAt = Date.now() + Math.max(0, refreshed.data.expires_in) * 1000 + + updateConfigValue(CONFIG_KEY_API_TOKEN, refreshed.data.access_token) + updateConfigValue(CONFIG_KEY_OAUTH_REFRESH_TOKEN, nextRefreshToken) + updateConfigValue(CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, expiresAt) + + // If scopes were not stored at login time, store whatever the server returned. + if (!getConfigValueOrUndef(CONFIG_KEY_OAUTH_SCOPES) && refreshed.data.scope) { + updateConfigValue( + CONFIG_KEY_OAUTH_SCOPES, + refreshed.data.scope.split(' ').filter(Boolean), + ) + } + + return { ok: true, data: { accessToken: refreshed.data.access_token } } +} + export async function setupSdk( options?: SetupSdkOptions | undefined, ): Promise> { @@ -147,6 +244,44 @@ export async function setupSdk( const { apiBaseUrl = getDefaultApiBaseUrl() } = opts + const oauthRefreshToken = getConfigValueOrUndef( + CONFIG_KEY_OAUTH_REFRESH_TOKEN, + ) + const oauthExpiresAt = getConfigValueOrUndef( + CONFIG_KEY_OAUTH_TOKEN_EXPIRES_AT, + ) + const shouldSkipOAuthRefresh = + isNonEmptyString(opts.apiToken) || + isNonEmptyString(ENV.SOCKET_CLI_API_TOKEN) + + // If OAuth is configured, treat the persisted `apiToken` as a short-lived access token and + // transparently refresh it when expired/near-expiry. + if (!shouldSkipOAuthRefresh && isNonEmptyString(oauthRefreshToken)) { + const now = Date.now() + const expiresAtMs = + typeof oauthExpiresAt === 'number' && Number.isFinite(oauthExpiresAt) + ? oauthExpiresAt + : null + const isExpiredOrMissing = + !apiToken || !expiresAtMs || expiresAtMs - now <= 60_000 + + if (isExpiredOrMissing) { + const refreshResult = await refreshOAuthApiTokenFromConfig({ + apiBaseUrl, + apiProxy, + }) + if (!refreshResult.ok) { + return { + ok: false, + message: refreshResult.message, + cause: refreshResult.cause, + } + } + apiToken = refreshResult.data.accessToken + _defaultToken = apiToken + } + } + // Usage of HttpProxyAgent vs. HttpsProxyAgent based on the chart at: // https://github.com/delvedor/hpagent?tab=readme-ov-file#usage const ProxyAgent = apiBaseUrl?.startsWith('http:') diff --git a/packages/cli/test/unit/commands/scan/handle-create-new-scan.test.mts b/packages/cli/test/unit/commands/scan/handle-create-new-scan.test.mts index c4372f0d7..043cf924e 100644 --- a/packages/cli/test/unit/commands/scan/handle-create-new-scan.test.mts +++ b/packages/cli/test/unit/commands/scan/handle-create-new-scan.test.mts @@ -229,10 +229,11 @@ describe('handleCreateNewScan', () => { }) it('handles no eligible files found', async () => { - const { fetchSupportedScanFileNames } = await import( - '../../../../src/commands/scan/fetch-supported-scan-file-names.mts' - ) - const { getPackageFilesForScan } = await import( + const { fetchSupportedScanFileNames: _fetchSupportedScanFileNames } = + await import( + '../../../../src/commands/scan/fetch-supported-scan-file-names.mts' + ) + const { getPackageFilesForScan: _getPackageFilesForScan } = await import( '../../../../src/utils/fs/path-resolve.mts' ) const { checkCommandInput } = await import( diff --git a/packages/cli/test/unit/utils/config-oauth.test.mts b/packages/cli/test/unit/utils/config-oauth.test.mts new file mode 100644 index 000000000..9d99ac295 --- /dev/null +++ b/packages/cli/test/unit/utils/config-oauth.test.mts @@ -0,0 +1,22 @@ +import { describe, expect, it } from 'vitest' + +import { + getSupportedConfigKeys, + isSensitiveConfigKey, +} from '../../../src/utils/config.mts' + +describe('utils/config (oauth keys)', () => { + it('includes oauth-related config keys', () => { + const keys = getSupportedConfigKeys() + expect(keys).toContain('authBaseUrl') + expect(keys).toContain('oauthClientId') + expect(keys).toContain('oauthRedirectUri') + expect(keys).toContain('oauthRefreshToken') + expect(keys).toContain('oauthScopes') + expect(keys).toContain('oauthTokenExpiresAt') + }) + + it('treats oauthRefreshToken as sensitive', () => { + expect(isSensitiveConfigKey('oauthRefreshToken')).toBe(true) + }) +}) diff --git a/packages/cli/test/unit/utils/config.test.mts b/packages/cli/test/unit/utils/config.test.mts index 1b51db960..afb3fcdce 100644 --- a/packages/cli/test/unit/utils/config.test.mts +++ b/packages/cli/test/unit/utils/config.test.mts @@ -258,6 +258,10 @@ describe('utils/config', () => { // Set a config value. updateConfigValue('defaultOrg', 'test-org') + // Wait for nextTick to complete the async write and avoid side effects when + // this test's afterEach resets env vars. + await new Promise(resolve => process.nextTick(resolve)) + // Read it back immediately (from cache). const result = getConfigValue('defaultOrg') expect(result.ok).toBe(true) diff --git a/packages/cli/test/unit/utils/oauth.test.mts b/packages/cli/test/unit/utils/oauth.test.mts new file mode 100644 index 000000000..dc2b9c539 --- /dev/null +++ b/packages/cli/test/unit/utils/oauth.test.mts @@ -0,0 +1,52 @@ +import { describe, expect, it } from 'vitest' + +import { + deriveAuthBaseUrlFromApiBaseUrl, + joinUrl, + normalizeUrlBase, +} from '../../../src/utils/auth/oauth.mts' + +describe('utils/auth/oauth', () => { + describe('normalizeUrlBase', () => { + it('removes trailing slashes', () => { + expect(normalizeUrlBase('https://api.socket.dev/')).toBe( + 'https://api.socket.dev', + ) + expect(normalizeUrlBase('https://api.socket.dev////')).toBe( + 'https://api.socket.dev', + ) + }) + }) + + describe('joinUrl', () => { + it('joins base + path with single slash', () => { + expect(joinUrl('https://api.socket.dev/', '/.well-known/test')).toBe( + 'https://api.socket.dev/.well-known/test', + ) + expect(joinUrl('https://api.socket.dev', '.well-known/test')).toBe( + 'https://api.socket.dev/.well-known/test', + ) + }) + }) + + describe('deriveAuthBaseUrlFromApiBaseUrl', () => { + it('strips /v0 from API base URL', () => { + expect( + deriveAuthBaseUrlFromApiBaseUrl('https://api.socket.dev/v0/'), + ).toBe('https://api.socket.dev') + expect(deriveAuthBaseUrlFromApiBaseUrl('https://api.socket.dev/v0')).toBe( + 'https://api.socket.dev', + ) + }) + + it('normalizes trailing slashes and preserves host', () => { + expect(deriveAuthBaseUrlFromApiBaseUrl('https://api.socket.dev/')).toBe( + 'https://api.socket.dev', + ) + }) + + it('returns undefined for invalid URLs', () => { + expect(deriveAuthBaseUrlFromApiBaseUrl('not a url')).toBe(undefined) + }) + }) +}) diff --git a/packages/cli/test/unit/utils/socket/sdk.test.mts b/packages/cli/test/unit/utils/socket/sdk.test.mts index 35f456ad7..6c5ea6b6a 100644 --- a/packages/cli/test/unit/utils/socket/sdk.test.mts +++ b/packages/cli/test/unit/utils/socket/sdk.test.mts @@ -20,6 +20,10 @@ import { describe, expect, it } from 'vitest' +import http from 'node:http' + +import { SocketSdk } from '@socketsecurity/sdk' + import { getPublicApiToken, getVisibleTokenPrefix, @@ -49,4 +53,46 @@ describe('SDK Utilities', () => { expect(typeof hasToken).toBe('boolean') }) }) + + it('sends API tokens via Bearer authorization', async () => { + const token = + 'sktsec_t_--RAN5U4ivauy4w37-6aoKyYPDt5ZbaT5JBVMqiwKo_api' + + let receivedAuthorization: string | undefined + + const server = http.createServer((req, res) => { + receivedAuthorization = req.headers.authorization + res.statusCode = 200 + res.setHeader('content-type', 'application/json; charset=utf-8') + res.end(JSON.stringify({ success: true, data: { organizations: {} } })) + }) + + await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => resolve()) + }) + + const address = server.address() + const port = + typeof address === 'object' && address && 'port' in address + ? address.port + : 0 + + try { + const sdk = new SocketSdk(token, { + baseUrl: `http://127.0.0.1:${port}/v0/`, + }) + + try { + await sdk.listOrganizations() + } catch { + // Ignore parse/shape errors; this test only asserts the auth scheme. + } + + expect(receivedAuthorization).toBe(`Bearer ${token}`) + } finally { + await new Promise(resolve => { + server.close(() => resolve()) + }) + } + }) }) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 20fa4a230..c591ac467 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -106,8 +106,8 @@ catalogs: specifier: 2.0.2 version: 2.0.2 '@socketsecurity/sdk': - specifier: 3.2.0 - version: 3.2.0 + specifier: 3.3.0 + version: 3.3.0 '@types/cmd-shim': specifier: 5.0.2 version: 5.0.2 @@ -487,7 +487,7 @@ importers: version: 2.0.2(typescript@5.9.3) '@socketsecurity/sdk': specifier: 'catalog:' - version: 3.2.0(typescript@5.9.3) + version: 3.3.0(typescript@5.9.3) '@types/cmd-shim': specifier: 'catalog:' version: 5.0.2 @@ -799,7 +799,7 @@ importers: version: 2.0.2(typescript@5.9.3) '@socketsecurity/sdk': specifier: 'catalog:' - version: 3.2.0(typescript@5.9.3) + version: 3.3.0(typescript@5.9.3) '@types/react': specifier: ^19.2.9 version: 19.2.9 @@ -2270,8 +2270,8 @@ packages: typescript: optional: true - '@socketsecurity/sdk@3.2.0': - resolution: {integrity: sha512-Cj5qqZV6nZ3JL6/7MHzPyggr3xS2YZ7jBh9SwIZzpxq3NFxGEuJBE/gkhzdEcv29OCRToAT27AG7QhbVmZKvCA==} + '@socketsecurity/sdk@3.3.0': + resolution: {integrity: sha512-LuGjybeo9tP+ErUru5E6N5V8eBefLYYBtMNp3Fovy7EIZ4h5kDEolbsZEiJyCyOHbxImgRqKhJLHpYBFIXxHrA==} engines: {node: '>=18', pnpm: '>=10.25.0'} '@standard-schema/spec@1.0.0': @@ -3413,10 +3413,6 @@ packages: flatted@3.3.3: resolution: {integrity: sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==} - form-data@4.0.4: - resolution: {integrity: sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==} - engines: {node: '>= 6'} - form-data@4.0.5: resolution: {integrity: sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==} engines: {node: '>= 6'} @@ -6664,7 +6660,7 @@ snapshots: optionalDependencies: typescript: 5.9.3 - '@socketsecurity/sdk@3.2.0(typescript@5.9.3)': + '@socketsecurity/sdk@3.3.0(typescript@5.9.3)': dependencies: '@socketregistry/packageurl-js': 1.3.5 '@socketsecurity/lib': 5.5.3(typescript@5.9.3) @@ -6810,7 +6806,7 @@ snapshots: '@types/node-fetch@2.6.13': dependencies: '@types/node': 24.9.2 - form-data: 4.0.4 + form-data: 4.0.5 '@types/node@24.9.2': dependencies: @@ -7943,14 +7939,6 @@ snapshots: flatted@3.3.3: {} - form-data@4.0.4: - dependencies: - asynckit: 0.4.0 - combined-stream: 1.0.8 - es-set-tostringtag: '@socketregistry/es-set-tostringtag@1.0.10' - hasown: '@socketregistry/hasown@1.0.7' - mime-types: 2.1.35 - form-data@4.0.5: dependencies: asynckit: 0.4.0