Skip to content

Commit

Permalink
feat: invalidate MAv2 account when switching mode (#1384)
Browse files Browse the repository at this point in the history
* WIP

* fix: switching works (WIP)

* fix: don't mutate store in getAccount

* fix: get everything working smoothly

* chore: clean up

* chore: remove comment & rebuild docs

* fix: swap order of guards to make more sense

* chore: add type guards

* fix: typo

* fix: use accountParams instead of params when creating client
  • Loading branch information
jakehobbs authored Feb 25, 2025
1 parent 7df49c8 commit c75826e
Show file tree
Hide file tree
Showing 14 changed files with 215 additions and 298 deletions.
188 changes: 109 additions & 79 deletions account-kit/core/src/actions/createAccount.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import type {
import { getBundlerClient } from "./getBundlerClient.js";
import { getSigner } from "./getSigner.js";
import { getSignerStatus } from "./getSignerStatus.js";
import type { GetAccountParams } from "./getAccount";

type OmitSignerTransportChain<T> = Omit<T, "signer" | "transport" | "chain">;

Expand Down Expand Up @@ -75,14 +76,15 @@ export type CreateAccountParams<TAccount extends SupportedAccountTypes> = {
* @returns {Promise<SupportedAccounts>} A promise that resolves to the created account object
*/
export async function createAccount<TAccount extends SupportedAccountTypes>(
{ type, accountParams: params }: CreateAccountParams<TAccount>,
params: CreateAccountParams<TAccount>,
config: AlchemyAccountsConfig
): Promise<SupportedAccounts> {
const store = config.store;
const accounts = store.getState().accounts;
if (!accounts) {
throw new ClientOnlyPropertyError("account");
}
const accountConfigs = store.getState().accountConfigs;

const bundlerClient = getBundlerClient(config);
const transport = custom(bundlerClient);
Expand All @@ -94,102 +96,106 @@ export async function createAccount<TAccount extends SupportedAccountTypes>(
throw new Error("Signer not connected");
}

const cachedAccount = accounts[chain.id]?.[type];
const cachedAccount = accounts[chain.id]?.[params.type];
if (cachedAccount.status !== "RECONNECTING" && cachedAccount.account) {
return cachedAccount.account;
}
const cachedConfig = store.getState().accountConfigs[chain.id]?.[type];

const accountPromise = (() => {
switch (type) {
case "LightAccount":
return createLightAccount({
...params,
...cachedConfig,
signer,
transport: (opts) => transport({ ...opts, retryCount: 0 }),
chain,
}).then((account) => {
CoreLogger.trackEvent({
name: "account_initialized",
data: {
accountType: "LightAccount",
accountVersion: account.getLightAccountVersion(),
},
});

return account;
if (isLightAccountParams(params)) {
return createLightAccount({
...accountConfigs[chain.id]?.[params.type],
...params.accountParams,
signer,
transport: (opts) => transport({ ...opts, retryCount: 0 }),
chain,
}).then((account) => {
CoreLogger.trackEvent({
name: "account_initialized",
data: {
accountType: "LightAccount",
accountVersion: account.getLightAccountVersion(),
},
});
case "MultiOwnerLightAccount":
return createMultiOwnerLightAccount({
...(params as AccountConfig<"MultiOwnerLightAccount">),
...(cachedConfig as OmitSignerTransportChain<CreateMultiOwnerLightAccountParams>),
signer,
transport: (opts) => transport({ ...opts, retryCount: 0 }),
chain,
}).then((account) => {
CoreLogger.trackEvent({
name: "account_initialized",
data: {
accountType: "MultiOwnerLightAccount",
accountVersion: account.getLightAccountVersion(),
},
});
return account;
return account;
});
} else if (isMultiOwnerLightAccountParams(params)) {
return createMultiOwnerLightAccount({
...accountConfigs[chain.id]?.[params.type],
...params.accountParams,
signer,
transport: (opts) => transport({ ...opts, retryCount: 0 }),
chain,
}).then((account) => {
CoreLogger.trackEvent({
name: "account_initialized",
data: {
accountType: "MultiOwnerLightAccount",
accountVersion: account.getLightAccountVersion(),
},
});
case "MultiOwnerModularAccount":
return createMultiOwnerModularAccount({
...(params as AccountConfig<"MultiOwnerModularAccount">),
...(cachedConfig as OmitSignerTransportChain<CreateMultiOwnerModularAccountParams>),
signer,
transport: (opts) => transport({ ...opts, retryCount: 0 }),
chain,
}).then((account) => {
CoreLogger.trackEvent({
name: "account_initialized",
data: {
accountType: "MultiOwnerModularAccount",
accountVersion: "v1.0.0",
},
});

return account;
return account;
});
} else if (isMultiOwnerModularAccountParams(params)) {
return createMultiOwnerModularAccount({
...accountConfigs[chain.id]?.[params.type],
...params.accountParams,
signer,
transport: (opts) => transport({ ...opts, retryCount: 0 }),
chain,
}).then((account) => {
CoreLogger.trackEvent({
name: "account_initialized",
data: {
accountType: "MultiOwnerModularAccount",
accountVersion: "v1.0.0",
},
});
case "ModularAccountV2":
return createModularAccountV2({
...(params as AccountConfig<"ModularAccountV2">),
...(cachedConfig as OmitSignerTransportChain<CreateModularAccountV2Params>),
signer,
transport: (opts) => transport({ ...opts, retryCount: 0 }),
chain,
}).then((account) => {
CoreLogger.trackEvent({
name: "account_initialized",
data: {
accountType: "ModularAccountV2",
accountVersion: "v2.0.0",
},
});

return account;
return account;
});
} else if (isModularV2AccountParams(params)) {
return createModularAccountV2({
...accountConfigs[chain.id]?.[params.type],
...params.accountParams,
signer,
transport: (opts) => transport({ ...opts, retryCount: 0 }),
chain,
}).then((account) => {
CoreLogger.trackEvent({
name: "account_initialized",
data: {
accountType: "ModularAccountV2",
accountVersion: "v2.0.0",
},
});
default:
throw new Error("Unsupported account type");
return account;
});
} else {
throw new Error(`Unsupported account type: ${params.type}`);
}
})();

if (cachedAccount.status !== "RECONNECTING") {
store.setState(() => ({
store.setState((state) => ({
accounts: {
...accounts,
[chain.id]: {
...accounts[chain.id],
[type]: {
[params.type]: {
status: "INITIALIZING",
account: accountPromise,
},
},
},
accountConfigs: {
...state.accountConfigs,
[chain.id]: {
...state.accountConfigs[chain.id],
[params.type]: {
...params.accountParams,
},
},
},
}));
}

Expand All @@ -201,7 +207,7 @@ export async function createAccount<TAccount extends SupportedAccountTypes>(
...accounts,
[chain.id]: {
...accounts[chain.id],
[type]: {
[params.type]: {
status: "READY",
account,
},
Expand All @@ -211,8 +217,8 @@ export async function createAccount<TAccount extends SupportedAccountTypes>(
...state.accountConfigs,
[chain.id]: {
...state.accountConfigs[chain.id],
[type]: {
...params,
[params.type]: {
...params.accountParams,
accountAddress: account.address,
initCode,
},
Expand All @@ -225,7 +231,7 @@ export async function createAccount<TAccount extends SupportedAccountTypes>(
...accounts,
[chain.id]: {
...accounts[chain.id],
[type]: {
[params.type]: {
status: "ERROR",
error,
},
Expand All @@ -236,3 +242,27 @@ export async function createAccount<TAccount extends SupportedAccountTypes>(

return accountPromise;
}

export const isModularV2AccountParams = (
params: CreateAccountParams<SupportedAccountTypes>
): params is GetAccountParams<"ModularAccountV2"> => {
return params.type === "ModularAccountV2";
};

export const isLightAccountParams = (
params: CreateAccountParams<SupportedAccountTypes>
): params is GetAccountParams<"LightAccount"> => {
return params.type === "LightAccount";
};

export const isMultiOwnerLightAccountParams = (
params: CreateAccountParams<SupportedAccountTypes>
): params is GetAccountParams<"MultiOwnerLightAccount"> => {
return params.type === "MultiOwnerLightAccount";
};

export const isMultiOwnerModularAccountParams = (
params: CreateAccountParams<SupportedAccountTypes>
): params is GetAccountParams<"MultiOwnerModularAccount"> => {
return params.type === "MultiOwnerModularAccount";
};
25 changes: 18 additions & 7 deletions account-kit/core/src/actions/getAccount.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import { defaultAccountState } from "../store/store.js";
import type { AccountState } from "../store/types.js";
import type { AlchemyAccountsConfig, SupportedAccountTypes } from "../types.js";
import { type CreateAccountParams } from "./createAccount.js";
import {
isModularV2AccountParams,
type CreateAccountParams,
} from "./createAccount.js";
import { getChain } from "./getChain.js";

export type GetAccountResult<TAccount extends SupportedAccountTypes> =
Expand Down Expand Up @@ -28,17 +32,24 @@ export type GetAccountParams<TAccount extends SupportedAccountTypes> =
* @returns {GetAccountResult<TAccount>} The result which includes the account if found and its status
*/
export const getAccount = <TAccount extends SupportedAccountTypes>(
{ type }: GetAccountParams<TAccount>,
params: GetAccountParams<TAccount>,
config: AlchemyAccountsConfig
): GetAccountResult<TAccount> => {
const accounts = config.store.getState().accounts;
const chain = getChain(config);
const account = accounts?.[chain.id]?.[type];
const account = accounts?.[chain.id]?.[params.type];
if (!account) {
return {
account: undefined,
status: "DISCONNECTED",
};
return defaultAccountState();
}

if (isModularV2AccountParams(params) && account?.status === "READY") {
const accountConfig =
config.store.getState().accountConfigs[chain.id]?.[params.type];
const haveMode = accountConfig?.mode ?? "default";
const wantMode = params.accountParams?.mode ?? "default";
if (haveMode !== wantMode) {
return defaultAccountState();
}
}

return account;
Expand Down
10 changes: 5 additions & 5 deletions account-kit/core/src/actions/getSmartAccountClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import type {
SupportedAccounts,
SupportedAccountTypes,
} from "../types";
import { createAccount } from "./createAccount.js";
import { createAccount, isModularV2AccountParams } from "./createAccount.js";
import { getAccount, type GetAccountParams } from "./getAccount.js";
import { getAlchemyTransport } from "./getAlchemyTransport.js";
import { getConnection } from "./getConnection.js";
Expand Down Expand Up @@ -142,8 +142,9 @@ export function getSmartAccountClient(
signerStatus.isAuthenticating ||
signerStatus.isInitializing
) {
if (!account && signerStatus.isConnected)
if (!account && signerStatus.isConnected) {
createAccount({ type, accountParams }, config);
}

if (clientState && clientState.isLoadingClient) {
return clientState;
Expand Down Expand Up @@ -220,9 +221,8 @@ export function getSmartAccountClient(
};
case "ModularAccountV2":
const is7702 =
params.accountParams &&
"mode" in params.accountParams &&
params.accountParams.mode === "7702";
isModularV2AccountParams(params) &&
params.accountParams?.mode === "7702";
return {
client: createAlchemySmartAccountClient({
transport,
Expand Down
9 changes: 3 additions & 6 deletions examples/ui-demo/src/app/config.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,11 @@ export type Config = {
}
| undefined;
};
walletType: WalletTypes;
accountMode: AccountMode;
supportUrl?: string;
};

export enum WalletTypes {
smart = "smart",
hybrid7702 = "7702",
}
export type AccountMode = "default" | "7702";

export const DEFAULT_CONFIG: Config = {
auth: {
Expand Down Expand Up @@ -77,7 +74,7 @@ export const DEFAULT_CONFIG: Config = {
logoLight: undefined,
logoDark: undefined,
},
walletType: WalletTypes.smart,
accountMode: "default",
};

export const queryClient = new QueryClient();
Expand Down
Loading

0 comments on commit c75826e

Please sign in to comment.