From 99285f67295af6a0e617f5ca8104ebc3a82be4b5 Mon Sep 17 00:00:00 2001 From: Flavien David Date: Wed, 12 Feb 2025 13:56:07 -0800 Subject: [PATCH 1/5] tmp --- extension/app/main.tsx | 5 +- extension/app/src/chrome/platform.ts | 7 + extension/app/src/chrome/services/auth.ts | 173 +++++++++++++++ .../app/src/components/auth/AuthProvider.tsx | 12 +- extension/app/src/components/auth/useAuth.ts | 28 ++- extension/app/src/front/main.tsx | 37 +++ extension/app/src/front/services/auth.ts | 137 ++++++++++++ extension/app/src/hooks/useAuthErrorCheck.ts | 12 +- extension/app/src/hooks/useEventSource.ts | 5 +- extension/app/src/lib/auth.ts | 210 ------------------ extension/app/src/lib/conversation.ts | 5 +- extension/app/src/shared/context/platform.ts | 6 + extension/app/src/shared/services/auth.ts | 58 +++++ extension/app/src/shared/services/platform.ts | 6 + extension/package-lock.json | 6 + extension/package.json | 2 + 16 files changed, 471 insertions(+), 238 deletions(-) create mode 100644 extension/app/src/chrome/platform.ts create mode 100644 extension/app/src/chrome/services/auth.ts create mode 100644 extension/app/src/front/main.tsx create mode 100644 extension/app/src/front/services/auth.ts delete mode 100644 extension/app/src/lib/auth.ts create mode 100644 extension/app/src/shared/context/platform.ts create mode 100644 extension/app/src/shared/services/auth.ts create mode 100644 extension/app/src/shared/services/platform.ts diff --git a/extension/app/main.tsx b/extension/app/main.tsx index 5bf3187f168d..75741996399b 100644 --- a/extension/app/main.tsx +++ b/extension/app/main.tsx @@ -8,6 +8,7 @@ import "./src/css/components.css"; import "./src/css/custom.css"; import { Notification } from "@dust-tt/sparkle"; +import { ChromeAuth } from "@extension/chrome/services/auth"; import { AuthProvider } from "@extension/components/auth/AuthProvider"; import { PortProvider } from "@extension/components/PortContext"; import { routes } from "@extension/pages/routes"; @@ -17,9 +18,11 @@ import { createBrowserRouter, RouterProvider } from "react-router-dom"; const router = createBrowserRouter(routes); const App = () => { + const authService = new ChromeAuth(); + return ( - + diff --git a/extension/app/src/chrome/platform.ts b/extension/app/src/chrome/platform.ts new file mode 100644 index 000000000000..a9b329bbd810 --- /dev/null +++ b/extension/app/src/chrome/platform.ts @@ -0,0 +1,7 @@ +import { ChromeAuth } from "@extension/chrome/services/auth"; +import type { PlatformService } from "@extension/shared/services/platform"; + +export const chromePlatform: PlatformService = { + platform: "chrome", + auth: new ChromeAuth(), +}; diff --git a/extension/app/src/chrome/services/auth.ts b/extension/app/src/chrome/services/auth.ts new file mode 100644 index 000000000000..ba16fbb83cd4 --- /dev/null +++ b/extension/app/src/chrome/services/auth.ts @@ -0,0 +1,173 @@ +import type { Result } from "@dust-tt/client"; +import { Err, Ok } from "@dust-tt/client"; +import { + AUTH0_CLAIM_NAMESPACE, + DEFAULT_DUST_API_DOMAIN, + DUST_EU_URL, + DUST_US_URL, +} from "@extension/lib/config"; +import { + sendAuthMessage, + sendRefreshTokenMessage, + sentLogoutMessage, +} from "@extension/lib/messages"; +import type { + StoredTokens, + UserTypeWithExtensionWorkspaces, +} from "@extension/lib/storage"; +import { + clearStoredData, + getStoredTokens, + saveTokens, + saveUser, +} from "@extension/lib/storage"; +import type { AuthService } from "@extension/shared/services/auth"; +import { + AuthError, + makeEnterpriseConnectionName, +} from "@extension/shared/services/auth"; +import { jwtDecode } from "jwt-decode"; + +const REGIONS = ["europe-west1", "us-central1"] as const; +export type RegionType = (typeof REGIONS)[number]; + +const isRegionType = (region: string): region is RegionType => + REGIONS.includes(region as RegionType); + +const REGION_CLAIM = `${AUTH0_CLAIM_NAMESPACE}region`; +const CONNECTION_STRATEGY_CLAIM = `${AUTH0_CLAIM_NAMESPACE}connection.strategy`; +const WORKSPACE_ID_CLAIM = `${AUTH0_CLAIM_NAMESPACE}workspaceId`; + +const DOMAIN_FOR_REGION: Record = { + "us-central1": DUST_US_URL, + "europe-west1": DUST_EU_URL, +}; + +const log = console.error; + +export class ChromeAuth implements AuthService { + constructor() {} + + // Login sends a message to the background script to call the auth0 login endpoint. + // It saves the tokens in the extension and schedules a token refresh. + // Then it calls the /me route to get the user info. + async login(isForceLogin?: boolean, forcedConnection?: string) { + try { + const response = await sendAuthMessage(isForceLogin, forcedConnection); + if (!response.accessToken) { + throw new Error("No access token received."); + } + const tokens = await saveTokens(response); + + const claims = jwtDecode>(tokens.accessToken); + + const dustDomain = getDustDomain(claims); + const connectionDetails = getConnectionDetails(claims); + + const res = await fetchMe(tokens.accessToken, dustDomain); + if (res.isErr()) { + return res; + } + const workspaces = res.value.user.workspaces; + const user = await saveUser({ + ...res.value.user, + ...connectionDetails, + dustDomain, + selectedWorkspace: workspaces.length === 1 ? workspaces[0].sId : null, + }); + return new Ok({ tokens, user }); + } catch (error) { + return new Err(new AuthError("not_authenticated", error?.toString())); + } + } + + // Logout sends a message to the background script to call the auth0 logout endpoint. + // It also clears the stored tokens in the extension. + async logout(): Promise { + try { + const response = await sentLogoutMessage(); + if (!response?.success) { + throw new Error("No success response received."); + } + return true; + } catch (error) { + log("Logout failed: Unknown error.", error); + return false; + } finally { + await clearStoredData(); + } + } + + // Refresh token sends a message to the background script to call the auth0 refresh token endpoint. + // It updates the stored tokens with the new access token. + // If the refresh token is invalid, it will call handleLogout. + async refreshToken( + tokens?: StoredTokens | null + ): Promise> { + try { + tokens = tokens ?? (await getStoredTokens()); + if (!tokens) { + return new Err(new AuthError("not_authenticated", "No tokens found.")); + } + const response = await sendRefreshTokenMessage(tokens.refreshToken); + if (!response?.accessToken) { + return new Err( + new AuthError("not_authenticated", "No access token received") + ); + } + return new Ok(await saveTokens(response)); + } catch (error) { + log("Refresh token: unknown error.", error); + return new Err(new AuthError("not_authenticated", error?.toString())); + } + } + + async getAccessToken(): Promise { + let tokens = await getStoredTokens(); + if (!tokens || !tokens.accessToken || tokens.expiresAt < Date.now()) { + const refreshRes = await this.refreshToken(tokens); + if (refreshRes.isOk()) { + tokens = refreshRes.value; + } + } + + return tokens?.accessToken ?? null; + } +} + +const getDustDomain = (claims: Record) => { + const region = claims[REGION_CLAIM]; + + return ( + (isRegionType(region) && DOMAIN_FOR_REGION[region]) || + DEFAULT_DUST_API_DOMAIN + ); +}; + +const getConnectionDetails = (claims: Record) => { + const connectionStrategy = claims[CONNECTION_STRATEGY_CLAIM]; + const ws = claims[WORKSPACE_ID_CLAIM]; + return { + connectionStrategy, + connection: ws ? makeEnterpriseConnectionName(ws) : undefined, + }; +}; + +// Fetch me sends a request to the /me route to get the user info. +const fetchMe = async ( + accessToken: string, + dustDomain: string +): Promise> => { + const response = await fetch(`${dustDomain}/api/v1/me`, { + headers: { + Authorization: `Bearer ${accessToken}`, + "X-Request-Origin": "extension", + }, + }); + const me = await response.json(); + + if (!response.ok) { + return new Err(new AuthError(me.error.type, me.error.message)); + } + return new Ok(me); +}; diff --git a/extension/app/src/components/auth/AuthProvider.tsx b/extension/app/src/components/auth/AuthProvider.tsx index 22bfc3efbcf0..fabc6921962a 100644 --- a/extension/app/src/components/auth/AuthProvider.tsx +++ b/extension/app/src/components/auth/AuthProvider.tsx @@ -1,7 +1,7 @@ import type { ExtensionWorkspaceType, WorkspaceType } from "@dust-tt/client"; import { useAuthHook } from "@extension/components/auth/useAuth"; -import type { AuthError } from "@extension/lib/auth"; import type { StoredUser } from "@extension/lib/storage"; +import type { AuthError, AuthService } from "@extension/shared/services/auth"; import type { ReactNode } from "react"; import React, { createContext, useContext } from "react"; @@ -22,7 +22,13 @@ type AuthContextType = { const AuthContext = createContext(null); -export const AuthProvider = ({ children }: { children: ReactNode }) => { +export const AuthProvider = ({ + authService, + children, +}: { + authService: AuthService; + children: ReactNode; +}) => { const { token, isAuthenticated, @@ -36,7 +42,7 @@ export const AuthProvider = ({ children }: { children: ReactNode }) => { handleLogin, handleLogout, handleSelectWorkspace, - } = useAuthHook(); + } = useAuthHook(authService); return ( { +export const useAuthHook = (authService: AuthService) => { const [tokens, setTokens] = useState(null); const [user, setUser] = useState(null); const [authError, setAuthError] = useState(null); @@ -42,7 +40,7 @@ export const useAuthHook = () => { const handleLogout = useCallback(async () => { setIsLoading(true); - const success = await logout(); + const success = await authService.logout(); if (!success) { // TODO(EXT): User facing error message if logout failed. setIsLoading(false); @@ -58,7 +56,7 @@ export const useAuthHook = () => { }, []); const handleRefreshToken = useCallback(async () => { - const savedTokens = await refreshToken(); + const savedTokens = await authService.refreshToken(); if (savedTokens.isErr()) { setAuthError(savedTokens.error); log("Refresh token: No access token received."); @@ -144,14 +142,14 @@ export const useAuthHook = () => { ) ); setForcedConnection(makeEnterpriseConnectionName(workspace.sId)); - await logout(); + await authService.logout(); }, [setAuthError, setForcedConnection] ); const handleSelectWorkspace = async (workspace: WorkspaceType) => { const updatedUser = await saveSelectedWorkspace(workspace.sId); - if (!isValidEnterpriseConnection(updatedUser, workspace)) { + if (!isValidEnterpriseConnectionName(updatedUser, workspace)) { await redirectToSSOLogin(workspace); return; } @@ -161,7 +159,7 @@ export const useAuthHook = () => { const handleLogin = useCallback( async (isForceLogin?: boolean) => { setIsLoading(true); - const response = await login(isForceLogin, forcedConnection); + const response = await authService.login(isForceLogin, forcedConnection); if (response.isErr()) { setAuthError(response.error); setIsLoading(false); @@ -177,7 +175,7 @@ export const useAuthHook = () => { ); if ( selectedWorkspace && - !isValidEnterpriseConnection(user, selectedWorkspace) + !isValidEnterpriseConnectionName(user, selectedWorkspace) ) { await redirectToSSOLogin(selectedWorkspace); setIsLoading(false); diff --git a/extension/app/src/front/main.tsx b/extension/app/src/front/main.tsx new file mode 100644 index 000000000000..32863b27adb2 --- /dev/null +++ b/extension/app/src/front/main.tsx @@ -0,0 +1,37 @@ +// Tailwind base globals +import "./src/css/global.css"; +// Use sparkle styles, override local globals +import "@dust-tt/sparkle/dist/sparkle.css"; +// Local tailwind components override sparkle styles +import "./src/css/components.css"; +// Local custom styles +import "../css/custom.css"; + +import { Notification } from "@dust-tt/sparkle"; +import { AuthProvider } from "@extension/components/auth/AuthProvider"; +import { PortProvider } from "@extension/components/PortContext"; +import { FrontAuth } from "@extension/front/services/auth"; +import { routes } from "@extension/pages/routes"; +import ReactDOM from "react-dom/client"; +import { createBrowserRouter, RouterProvider } from "react-router-dom"; + +const router = createBrowserRouter(routes); + +const App = () => { + const authService = new FrontAuth(); + + return ( + + + + + + + + ); +}; +const rootElement = document.getElementById("root"); +if (rootElement) { + const root = ReactDOM.createRoot(rootElement); + root.render(); +} diff --git a/extension/app/src/front/services/auth.ts b/extension/app/src/front/services/auth.ts new file mode 100644 index 000000000000..4f1e6cde5a46 --- /dev/null +++ b/extension/app/src/front/services/auth.ts @@ -0,0 +1,137 @@ +import { Auth0Client } from "@auth0/auth0-spa-js"; +import type { Result } from "@dust-tt/client"; +import { Err, Ok } from "@dust-tt/client"; +import { + AUTH0_CLAIM_NAMESPACE, + DEFAULT_DUST_API_DOMAIN, + DUST_EU_URL, + DUST_US_URL, +} from "@extension/lib/config"; +import type { + StoredTokens, + UserTypeWithExtensionWorkspaces, +} from "@extension/lib/storage"; +import { jwtDecode } from "jwt-decode"; + +import type { AuthService } from "../../shared/services/auth"; +import { + AuthError, + makeEnterpriseConnectionName, +} from "../../shared/services/auth"; + +const REGION_CLAIM = `${AUTH0_CLAIM_NAMESPACE}region`; + +export class FrontAuth implements AuthService { + private auth0: Auth0Client; + + constructor() { + this.auth0 = new Auth0Client({ + domain: "your-domain", + clientId: "your-client-id", + }); + } + + async getAccessToken(): Promise { + throw new Error("Method not implemented."); + } + + async login(isForceLogin?: boolean, forcedConnection?: string) { + // TODO: Implement force login. + try { + await this.auth0.loginWithPopup(); + } catch (error) { + return new Err(new AuthError("not_authenticated", error?.toString())); + } + + const token = await this.auth0.getTokenSilently(); + const claims = jwtDecode>(token); + const dustDomain = getDustDomain(claims); + const connectionDetails = getConnectionDetails(claims); + + const res = await fetchMe(token, dustDomain); + if (res.isErr()) { + return res; + } + + // TODO: + return new Ok({ + tokens: { + accessToken: token, + refreshToken: "refreshToken", + expiresAt: Date.now() + 1000 * 60 * 60 * 24, + }, + user: { + ...res.value.user, + ...connectionDetails, + dustDomain, + selectedWorkspace: + res.value.user.workspaces.length === 1 + ? res.value.user.workspaces[0].sId + : null, + }, + }); + } + + logout(): Promise { + throw new Error("Method not implemented."); + } + + async refreshToken( + tokens?: StoredTokens | null + ): Promise> { + throw new Error("Method not implemented."); + } +} + +// TODO: Clean up. Duplicate code from chrome/services/auth.ts. + +const REGIONS = ["europe-west1", "us-central1"] as const; +export type RegionType = (typeof REGIONS)[number]; + +const DOMAIN_FOR_REGION: Record = { + "us-central1": DUST_US_URL, + "europe-west1": DUST_EU_URL, +}; + +const CONNECTION_STRATEGY_CLAIM = `${AUTH0_CLAIM_NAMESPACE}connection.strategy`; +const WORKSPACE_ID_CLAIM = `${AUTH0_CLAIM_NAMESPACE}workspaceId`; + +const isRegionType = (region: string): region is RegionType => + REGIONS.includes(region as RegionType); + +const getDustDomain = (claims: Record) => { + const region = claims[REGION_CLAIM]; + + return ( + (isRegionType(region) && DOMAIN_FOR_REGION[region]) || + DEFAULT_DUST_API_DOMAIN + ); +}; + +const getConnectionDetails = (claims: Record) => { + const connectionStrategy = claims[CONNECTION_STRATEGY_CLAIM]; + const ws = claims[WORKSPACE_ID_CLAIM]; + return { + connectionStrategy, + connection: ws ? makeEnterpriseConnectionName(ws) : undefined, + }; +}; + +// Fetch me sends a request to the /me route to get the user info. +const fetchMe = async ( + accessToken: string, + dustDomain: string +): Promise> => { + const response = await fetch(`${dustDomain}/api/v1/me`, { + headers: { + Authorization: `Bearer ${accessToken}`, + "X-Request-Origin": "extension", + }, + }); + const me = await response.json(); + + if (!response.ok) { + return new Err(new AuthError(me.error.type, me.error.message)); + } + return new Ok(me); +}; diff --git a/extension/app/src/hooks/useAuthErrorCheck.ts b/extension/app/src/hooks/useAuthErrorCheck.ts index d1fbd1d425dc..b4e219d956b6 100644 --- a/extension/app/src/hooks/useAuthErrorCheck.ts +++ b/extension/app/src/hooks/useAuthErrorCheck.ts @@ -1,8 +1,10 @@ import { useAuth } from "@extension/components/auth/AuthProvider"; -import { logout, refreshToken } from "@extension/lib/auth"; +import { usePlatform } from "@extension/shared/context/platform"; import { useEffect } from "react"; export const useAuthErrorCheck = (error: any, mutate: () => any) => { + const platform = usePlatform(); + const { setAuthError, redirectToSSOLogin, workspace } = useAuth(); useEffect(() => { const handleError = async () => { @@ -13,19 +15,19 @@ export const useAuthErrorCheck = (error: any, mutate: () => any) => { return redirectToSSOLogin(workspace); } setAuthError(error); - void logout(); + void platform.auth.logout(); break; case "not_authenticated": case "invalid_oauth_token_error": setAuthError(error); - void logout(); + void platform.auth.logout(); break; case "expired_oauth_token_error": - const res = await refreshToken(); + const res = await platform.auth.refreshToken(); if (res.isOk()) { mutate(); } else { - void logout(); + void platform.auth.logout(); } break; case "user_not_found": diff --git a/extension/app/src/hooks/useEventSource.ts b/extension/app/src/hooks/useEventSource.ts index 843246b0d0ba..4d581746cfc0 100644 --- a/extension/app/src/hooks/useEventSource.ts +++ b/extension/app/src/hooks/useEventSource.ts @@ -1,6 +1,6 @@ -import { getAccessToken } from "@extension/lib/auth"; import { useCallback, useEffect, useRef, useState } from "react"; const RECONNECT_DELAY = 5000; // 5 seconds. +import { usePlatform } from "@extension/shared/context/platform"; import { EventSourcePolyfill } from "event-source-polyfill"; /** @@ -89,6 +89,7 @@ export function useEventSource( const [isError, setIsError] = useState(null); const lastEvent = useRef(null); const reconnectAttempts = useRef(0); + const platform = usePlatform(); // We use a counter to trigger reconnects when the counter changes. const [reconnectCounter, setReconnectCounter] = useState(0); @@ -173,7 +174,7 @@ export function useEventSource( return; } const call = async () => { - const token = await getAccessToken(); + const token = await platform.auth.getAccessToken(); if (token) { connect(token); } diff --git a/extension/app/src/lib/auth.ts b/extension/app/src/lib/auth.ts deleted file mode 100644 index 1b710918db76..000000000000 --- a/extension/app/src/lib/auth.ts +++ /dev/null @@ -1,210 +0,0 @@ -import type { Result, WorkspaceType } from "@dust-tt/client"; -import { Err, Ok } from "@dust-tt/client"; -import { - AUTH0_CLAIM_NAMESPACE, - DEFAULT_DUST_API_DOMAIN, - DUST_EU_URL, - DUST_US_URL, -} from "@extension/lib/config"; -import { - sendAuthMessage, - sendRefreshTokenMessage, - sentLogoutMessage, -} from "@extension/lib/messages"; -import type { - StoredTokens, - StoredUser, - UserTypeWithExtensionWorkspaces, -} from "@extension/lib/storage"; -import { - clearStoredData, - getStoredTokens, - saveTokens, - saveUser, -} from "@extension/lib/storage"; -import { jwtDecode } from "jwt-decode"; - -const REGIONS = ["europe-west1", "us-central1"] as const; -export type RegionType = (typeof REGIONS)[number]; - -const isRegionType = (region: string): region is RegionType => - REGIONS.includes(region as RegionType); - -const REGION_CLAIM = `${AUTH0_CLAIM_NAMESPACE}region`; -const CONNECTION_STRATEGY_CLAIM = `${AUTH0_CLAIM_NAMESPACE}connection.strategy`; -const WORKSPACE_ID_CLAIM = `${AUTH0_CLAIM_NAMESPACE}workspaceId`; - -const DOMAIN_FOR_REGION: Record = { - "us-central1": DUST_US_URL, - "europe-west1": DUST_EU_URL, -}; - -export const SUPPORTED_ENTERPRISE_CONNECTIONS_STRATEGIES = [ - "okta", - "samlp", - "waad", -]; - -const log = console.error; - -type AuthErrorCode = - | "user_not_found" - | "sso_enforced" - | "not_authenticated" - | "invalid_oauth_token_error" - | "expired_oauth_token_error"; - -export class AuthError extends Error { - readonly type = "AuthError"; - constructor( - readonly code: AuthErrorCode, - msg?: string - ) { - super(msg); - } -} - -// Login sends a message to the background script to call the auth0 login endpoint. -// It saves the tokens in the extension and schedules a token refresh. -// Then it calls the /me route to get the user info. -export const login = async ( - isForceLogin?: boolean, - forcedConnection?: string -): Promise> => { - try { - const response = await sendAuthMessage(isForceLogin, forcedConnection); - if (!response.accessToken) { - throw new Error("No access token received."); - } - const tokens = await saveTokens(response); - - const claims = jwtDecode>(tokens.accessToken); - - const dustDomain = getDustDomain(claims); - const connectionDetails = getConnectionDetails(claims); - - const res = await fetchMe(tokens.accessToken, dustDomain); - if (res.isErr()) { - return res; - } - const workspaces = res.value.user.workspaces; - const user = await saveUser({ - ...res.value.user, - ...connectionDetails, - dustDomain, - selectedWorkspace: workspaces.length === 1 ? workspaces[0].sId : null, - }); - return new Ok({ tokens, user }); - } catch (error) { - return new Err(new AuthError("not_authenticated", error?.toString())); - } -}; - -// Logout sends a message to the background script to call the auth0 logout endpoint. -// It also clears the stored tokens in the extension. -export const logout = async (): Promise => { - try { - const response = await sentLogoutMessage(); - if (!response?.success) { - throw new Error("No success response received."); - } - return true; - } catch (error) { - log("Logout failed: Unknown error.", error); - return false; - } finally { - await clearStoredData(); - } -}; - -// Refresh token sends a message to the background script to call the auth0 refresh token endpoint. -// It updates the stored tokens with the new access token. -// If the refresh token is invalid, it will call handleLogout. -export const refreshToken = async ( - tokens?: StoredTokens | null -): Promise> => { - try { - tokens = tokens ?? (await getStoredTokens()); - if (!tokens) { - return new Err(new AuthError("not_authenticated", "No tokens found.")); - } - const response = await sendRefreshTokenMessage(tokens.refreshToken); - if (!response?.accessToken) { - return new Err( - new AuthError("not_authenticated", "No access token received") - ); - } - return new Ok(await saveTokens(response)); - } catch (error) { - log("Refresh token: unknown error.", error); - return new Err(new AuthError("not_authenticated", error?.toString())); - } -}; - -export const getAccessToken = async (): Promise => { - let tokens = await getStoredTokens(); - if (!tokens || !tokens.accessToken || tokens.expiresAt < Date.now()) { - const refreshRes = await refreshToken(tokens); - if (refreshRes.isOk()) { - tokens = refreshRes.value; - } - } - - return tokens?.accessToken ?? null; -}; - -export function makeEnterpriseConnectionName(workspaceId: string) { - return `workspace-${workspaceId}`; -} - -export function isValidEnterpriseConnectionName( - user: StoredUser, - workspace: WorkspaceType -) { - if (!workspace.ssoEnforced) { - return true; - } - - return ( - SUPPORTED_ENTERPRISE_CONNECTIONS_STRATEGIES.includes( - user.connectionStrategy - ) && makeEnterpriseConnectionName(workspace.sId) === user.connection - ); -} - -const getDustDomain = (claims: Record) => { - const region = claims[REGION_CLAIM]; - - return ( - (isRegionType(region) && DOMAIN_FOR_REGION[region]) || - DEFAULT_DUST_API_DOMAIN - ); -}; - -const getConnectionDetails = (claims: Record) => { - const connectionStrategy = claims[CONNECTION_STRATEGY_CLAIM]; - const ws = claims[WORKSPACE_ID_CLAIM]; - return { - connectionStrategy, - connection: ws ? makeEnterpriseConnectionName(ws) : undefined, - }; -}; - -// Fetch me sends a request to the /me route to get the user info. -const fetchMe = async ( - accessToken: string, - dustDomain: string -): Promise> => { - const response = await fetch(`${dustDomain}/api/v1/me`, { - headers: { - Authorization: `Bearer ${accessToken}`, - "X-Request-Origin": "extension", - }, - }); - const me = await response.json(); - - if (!response.ok) { - return new Err(new AuthError(me.error.type, me.error.message)); - } - return new Ok(me); -}; diff --git a/extension/app/src/lib/conversation.ts b/extension/app/src/lib/conversation.ts index d53a0d9b305f..f26ae00249ec 100644 --- a/extension/app/src/lib/conversation.ts +++ b/extension/app/src/lib/conversation.ts @@ -15,11 +15,11 @@ import type { UserType, } from "@dust-tt/client"; import { Err, Ok } from "@dust-tt/client"; -import { getAccessToken } from "@extension/lib/auth"; import type { GetActiveTabOptions } from "@extension/lib/messages"; import { sendGetActiveTabMessage } from "@extension/lib/messages"; import { getStoredUser } from "@extension/lib/storage"; import type { UploadedFileWithSupersededContentFragmentId } from "@extension/lib/types"; +import { usePlatform } from "@extension/shared/context/platform"; type SubmitMessageError = { type: @@ -296,7 +296,8 @@ export async function retryMessage({ conversationId: string; messageId: string; }): Promise> { - const token = await getAccessToken(); + const platform = usePlatform(); + const token = await platform.auth.getAccessToken(); const user = await getStoredUser(); if (!user) { diff --git a/extension/app/src/shared/context/platform.ts b/extension/app/src/shared/context/platform.ts new file mode 100644 index 000000000000..3b536b3b0326 --- /dev/null +++ b/extension/app/src/shared/context/platform.ts @@ -0,0 +1,6 @@ +import type { PlatformService } from "@extension/shared/services/platform"; +import React, { useContext } from "react"; + +export const PlatformContext = React.createContext(null!); + +export const usePlatform = () => useContext(PlatformContext); diff --git a/extension/app/src/shared/services/auth.ts b/extension/app/src/shared/services/auth.ts new file mode 100644 index 000000000000..80e54838ae65 --- /dev/null +++ b/extension/app/src/shared/services/auth.ts @@ -0,0 +1,58 @@ +import type { Result, WorkspaceType } from "@dust-tt/client"; +import type { StoredTokens, StoredUser } from "@extension/lib/storage"; + +type AuthErrorCode = + | "user_not_found" + | "sso_enforced" + | "not_authenticated" + | "invalid_oauth_token_error" + | "expired_oauth_token_error"; + +export class AuthError extends Error { + readonly type = "AuthError"; + constructor( + readonly code: AuthErrorCode, + msg?: string + ) { + super(msg); + } +} + +export const SUPPORTED_ENTERPRISE_CONNECTIONS_STRATEGIES = [ + "okta", + "samlp", + "waad", +]; + +export interface AuthService { + login( + isForceLogin?: boolean, + forcedConnection?: string + ): Promise>; + logout(): Promise; + + refreshToken( + tokens?: StoredTokens | null + ): Promise>; + + getAccessToken(): Promise; +} + +export function makeEnterpriseConnectionName(workspaceId: string) { + return `workspace-${workspaceId}`; +} + +export function isValidEnterpriseConnectionName( + user: StoredUser, + workspace: WorkspaceType +) { + if (!workspace.ssoEnforced) { + return true; + } + + return ( + SUPPORTED_ENTERPRISE_CONNECTIONS_STRATEGIES.includes( + user.connectionStrategy + ) && makeEnterpriseConnectionName(workspace.sId) === user.connection + ); +} diff --git a/extension/app/src/shared/services/platform.ts b/extension/app/src/shared/services/platform.ts new file mode 100644 index 000000000000..89c7fbd65b4c --- /dev/null +++ b/extension/app/src/shared/services/platform.ts @@ -0,0 +1,6 @@ +import type { AuthService } from "@extension/shared/services/auth"; + +export interface PlatformService { + auth: AuthService; + platform: "chrome" | "front"; +} diff --git a/extension/package-lock.json b/extension/package-lock.json index f36bd96cb139..b3a4d492e403 100644 --- a/extension/package-lock.json +++ b/extension/package-lock.json @@ -9,6 +9,7 @@ "version": "0.0.1", "license": "ISC", "dependencies": { + "@auth0/auth0-spa-js": "^2.1.3", "@dust-tt/client": "^1.0.26", "@dust-tt/sparkle": "^0.2.391", "@tailwindcss/forms": "^0.5.9", @@ -111,6 +112,11 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/@auth0/auth0-spa-js": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@auth0/auth0-spa-js/-/auth0-spa-js-2.1.3.tgz", + "integrity": "sha512-NMTBNuuG4g3rame1aCnNS5qFYIzsTUV5qTFPRfTyYFS1feS6jsCBR+eTq9YkxCp1yuoM2UIcjunPaoPl77U9xQ==" + }, "node_modules/@babel/code-frame": { "version": "7.25.7", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.25.7.tgz", diff --git a/extension/package.json b/extension/package.json index 4ee179dbb8df..e6aafe4bd144 100644 --- a/extension/package.json +++ b/extension/package.json @@ -5,6 +5,7 @@ "scripts": { "clean": "rm -rf build/*", "dev": "npm run clean && tsx run/watch.ts --mode=development", + "dev-front": "webpack serve --config config/webpack.front.js", "package:production": "tsx run/build.ts", "package:release": "tsx run/build.ts --release", "analyze": "tsx run/build.ts --analyze", @@ -49,6 +50,7 @@ "zip-webpack-plugin": "^4.0.1" }, "dependencies": { + "@auth0/auth0-spa-js": "^2.1.3", "@dust-tt/client": "^1.0.26", "@dust-tt/sparkle": "^0.2.391", "@tailwindcss/forms": "^0.5.9", From 8e0683078587f333f20ff8dc912bff51649c0a12 Mon Sep 17 00:00:00 2001 From: Flavien David Date: Wed, 12 Feb 2025 17:05:49 -0800 Subject: [PATCH 2/5] working --- extension/app/background.ts | 6 +- extension/app/src/chrome/platform.ts | 2 + extension/app/src/chrome/services/auth.ts | 16 +- extension/app/src/chrome/storage.ts | 21 + .../src/components/auth/ProtectedRoute.tsx | 73 +- extension/app/src/components/auth/useAuth.ts | 40 +- .../conversation/AttachFragment.tsx | 21 +- .../conversation/ConversationContainer.tsx | 26 +- .../app/src/components/input_bar/InputBar.tsx | 30 +- extension/app/src/front/main.tsx | 11 +- extension/app/src/front/platform.ts | 9 + extension/app/src/front/services/auth.ts | 89 +- extension/app/src/front/storage.ts | 16 + extension/app/src/lib/conversation.ts | 69 +- extension/app/src/lib/dust_api.ts | 5 +- extension/app/src/lib/messages.ts | 4 +- extension/app/src/lib/storage.ts | 123 +- extension/app/src/pages/RunPage.tsx | 4 +- .../app/src/shared/interfaces/storage.ts | 5 + extension/app/src/shared/services/platform.ts | 2 + extension/config/webpack.front.js | 73 + extension/package-lock.json | 2109 ++++++++++++++++- extension/package.json | 2 + extension/run/watch.ts | 23 +- front/lib/api/auth0.ts | 7 +- front/lib/api/auth_wrappers.ts | 11 +- front/lib/auth.ts | 3 + front/middleware.ts | 39 + 28 files changed, 2583 insertions(+), 256 deletions(-) create mode 100644 extension/app/src/chrome/storage.ts create mode 100644 extension/app/src/front/platform.ts create mode 100644 extension/app/src/front/storage.ts create mode 100644 extension/app/src/shared/interfaces/storage.ts create mode 100644 extension/config/webpack.front.js diff --git a/extension/app/background.ts b/extension/app/background.ts index 72192fe64866..e2ad912b6ccb 100644 --- a/extension/app/background.ts +++ b/extension/app/background.ts @@ -1,3 +1,4 @@ +import { ChromeStorageService } from "@extension/chrome/storage"; import type { PendingUpdate } from "@extension/lib/storage"; import { getStoredUser, savePendingUpdate } from "@extension/lib/storage"; @@ -41,7 +42,7 @@ chrome.runtime.onUpdateAvailable.addListener(async (details) => { version: details.version, detectedAt: Date.now(), }; - await savePendingUpdate(pendingUpdate); + await savePendingUpdate(new ChromeStorageService(), pendingUpdate); }); /** @@ -78,7 +79,7 @@ const shouldDisableContextMenuForDomain = async ( return true; } - const user = await getStoredUser(); + const user = await getStoredUser(new ChromeStorageService()); if (!user || !user.selectedWorkspace) { return false; } @@ -620,6 +621,7 @@ const exchangeCodeForTokens = async ( * Logout the user from Auth0. */ const logout = (sendResponse: (response: AuthBackgroundResponse) => void) => { + console.log("logout"); const redirectUri = chrome.identity.getRedirectURL(); const logoutUrl = `https://${AUTH0_CLIENT_DOMAIN}/v2/logout?client_id=${AUTH0_CLIENT_ID}&returnTo=${encodeURIComponent(redirectUri)}`; diff --git a/extension/app/src/chrome/platform.ts b/extension/app/src/chrome/platform.ts index a9b329bbd810..984564867293 100644 --- a/extension/app/src/chrome/platform.ts +++ b/extension/app/src/chrome/platform.ts @@ -1,7 +1,9 @@ import { ChromeAuth } from "@extension/chrome/services/auth"; +import { ChromeStorageService } from "@extension/chrome/storage"; import type { PlatformService } from "@extension/shared/services/platform"; export const chromePlatform: PlatformService = { platform: "chrome", auth: new ChromeAuth(), + storage: new ChromeStorageService(), }; diff --git a/extension/app/src/chrome/services/auth.ts b/extension/app/src/chrome/services/auth.ts index ba16fbb83cd4..13f843195128 100644 --- a/extension/app/src/chrome/services/auth.ts +++ b/extension/app/src/chrome/services/auth.ts @@ -1,5 +1,6 @@ import type { Result } from "@dust-tt/client"; import { Err, Ok } from "@dust-tt/client"; +import { ChromeStorageService } from "@extension/chrome/storage"; import { AUTH0_CLAIM_NAMESPACE, DEFAULT_DUST_API_DOMAIN, @@ -57,7 +58,7 @@ export class ChromeAuth implements AuthService { if (!response.accessToken) { throw new Error("No access token received."); } - const tokens = await saveTokens(response); + const tokens = await saveTokens(new ChromeStorageService(), response); const claims = jwtDecode>(tokens.accessToken); @@ -69,7 +70,7 @@ export class ChromeAuth implements AuthService { return res; } const workspaces = res.value.user.workspaces; - const user = await saveUser({ + const user = await saveUser(new ChromeStorageService(), { ...res.value.user, ...connectionDetails, dustDomain, @@ -105,17 +106,20 @@ export class ChromeAuth implements AuthService { tokens?: StoredTokens | null ): Promise> { try { - tokens = tokens ?? (await getStoredTokens()); + tokens = tokens ?? (await getStoredTokens(new ChromeStorageService())); if (!tokens) { return new Err(new AuthError("not_authenticated", "No tokens found.")); } - const response = await sendRefreshTokenMessage(tokens.refreshToken); + const response = await sendRefreshTokenMessage( + new ChromeStorageService(), + tokens.refreshToken + ); if (!response?.accessToken) { return new Err( new AuthError("not_authenticated", "No access token received") ); } - return new Ok(await saveTokens(response)); + return new Ok(await saveTokens(new ChromeStorageService(), response)); } catch (error) { log("Refresh token: unknown error.", error); return new Err(new AuthError("not_authenticated", error?.toString())); @@ -123,7 +127,7 @@ export class ChromeAuth implements AuthService { } async getAccessToken(): Promise { - let tokens = await getStoredTokens(); + let tokens = await getStoredTokens(new ChromeStorageService()); if (!tokens || !tokens.accessToken || tokens.expiresAt < Date.now()) { const refreshRes = await this.refreshToken(tokens); if (refreshRes.isOk()) { diff --git a/extension/app/src/chrome/storage.ts b/extension/app/src/chrome/storage.ts new file mode 100644 index 000000000000..c7bb7dfc4db9 --- /dev/null +++ b/extension/app/src/chrome/storage.ts @@ -0,0 +1,21 @@ +import type { StorageService } from "@extension/shared/interfaces/storage"; + +export class ChromeStorageService implements StorageService { + async get(key: string | string[]): Promise { + if (Array.isArray(key)) { + const result = await chrome.storage.local.get(key); + return result as T; + } + + const result = await chrome.storage.local.get([key]); + return result[key] ?? null; + } + + async set(key: string, value: T): Promise { + await chrome.storage.local.set({ [key]: value }); + } + + async remove(key: string): Promise { + await chrome.storage.local.remove(key); + } +} diff --git a/extension/app/src/components/auth/ProtectedRoute.tsx b/extension/app/src/components/auth/ProtectedRoute.tsx index 554ae1c743b4..f041494522fc 100644 --- a/extension/app/src/components/auth/ProtectedRoute.tsx +++ b/extension/app/src/components/auth/ProtectedRoute.tsx @@ -1,14 +1,7 @@ import type { ExtensionWorkspaceType } from "@dust-tt/client"; -import { - Button, - LogoHorizontalColorLogo, - Page, - Spinner, -} from "@dust-tt/sparkle"; +import { LogoHorizontalColorLogo, Page, Spinner } from "@dust-tt/sparkle"; import { useAuth } from "@extension/components/auth/AuthProvider"; -import type { RouteChangeMesssage } from "@extension/lib/messages"; import type { StoredUser } from "@extension/lib/storage"; -import { getPendingUpdate } from "@extension/lib/storage"; import type { ReactNode } from "react"; import { useEffect, useState } from "react"; import { useNavigate } from "react-router-dom"; @@ -36,19 +29,19 @@ export const ProtectedRoute = ({ children }: ProtectedRouteProps) => { const navigate = useNavigate(); const [isLatestVersion, setIsLatestVersion] = useState(true); - useEffect(() => { - const listener = (message: RouteChangeMesssage) => { - const { type } = message; - if (type === "EXT_ROUTE_CHANGE") { - navigate({ pathname: message.pathname, search: message.search }); - return false; - } - }; - chrome.runtime.onMessage.addListener(listener); - return () => { - chrome.runtime.onMessage.removeListener(listener); - }; - }, [navigate]); + // useEffect(() => { + // const listener = (message: RouteChangeMesssage) => { + // const { type } = message; + // if (type === "EXT_ROUTE_CHANGE") { + // navigate({ pathname: message.pathname, search: message.search }); + // return false; + // } + // }; + // chrome.runtime.onMessage.addListener(listener); + // return () => { + // chrome.runtime.onMessage.removeListener(listener); + // }; + // }, [navigate]); useEffect(() => { if (!isAuthenticated || !isUserSetup || !user || !workspace) { @@ -57,25 +50,25 @@ export const ProtectedRoute = ({ children }: ProtectedRouteProps) => { } }, [navigate, isLoading, isAuthenticated, isUserSetup, user, workspace]); - const checkIsLatestVersion = async () => { - const pendingUpdate = await getPendingUpdate(); - if (!pendingUpdate) { - return null; - } - if (pendingUpdate.version > chrome.runtime.getManifest().version) { - setIsLatestVersion(false); - } - }; + // const checkIsLatestVersion = async () => { + // const pendingUpdate = await getPendingUpdate(); + // if (!pendingUpdate) { + // return null; + // } + // if (pendingUpdate.version > chrome.runtime.getManifest().version) { + // setIsLatestVersion(false); + // } + // }; - useEffect(() => { - void checkIsLatestVersion(); + // useEffect(() => { + // void checkIsLatestVersion(); - chrome.storage.local.onChanged.addListener((changes) => { - if (changes.pendingUpdate) { - void checkIsLatestVersion(); - } - }); - }, []); + // chrome.storage.local.onChanged.addListener((changes) => { + // if (changes.pendingUpdate) { + // void checkIsLatestVersion(); + // } + // }); + // }, []); if (isLoading || !isAuthenticated || !isUserSetup || !user || !workspace) { return ( @@ -96,12 +89,12 @@ export const ProtectedRoute = ({ children }: ProtectedRouteProps) => { -