Skip to content

Commit

Permalink
Sync default payment method to the backend (#10172)
Browse files Browse the repository at this point in the history
* Sync default payment method to the backend

* Add unit tests for parameters

* Fix lint issue

* Fix lint issue
  • Loading branch information
amk-stripe authored Feb 18, 2025
1 parent 0414891 commit bd3790c
Show file tree
Hide file tree
Showing 15 changed files with 266 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,14 @@ abstract class AbsFakeStripeRepository : StripeRepository {
TODO("Not yet implemented")
}

override suspend fun setDefaultPaymentMethod(
customerId: String,
paymentMethodId: String?,
options: ApiRequest.Options
): Result<Customer> {
TODO("Not yet implemented")
}

override suspend fun logOut(
consumerSessionClientSecret: String,
consumerAccountPublishableKey: String?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,21 @@ class StripeApiRepository @JvmOverloads internal constructor(
}
}

override suspend fun setDefaultPaymentMethod(
customerId: String,
paymentMethodId: String?,
options: ApiRequest.Options
): Result<Customer> {
return fetchStripeModelResult(
apiRequest = apiRequestFactory.createPost(
url = getSetDefaultPaymentMethodUrl(customerId = customerId),
options = options,
params = mapOf("payment_method" to (paymentMethodId ?: ""))
),
jsonParser = CustomerJsonParser()
)
}

/**
* Create a [Token] using the input token parameters.
*
Expand Down Expand Up @@ -2139,6 +2154,16 @@ class StripeApiRepository @JvmOverloads internal constructor(
return getApiUrl("payment_methods/$paymentMethodId")
}

/**
* @return `https://api.stripe.com/v1/elements/customers/:customerId/set_default_payment_method`
*/
@VisibleForTesting
internal fun getSetDefaultPaymentMethodUrl(
customerId: String,
): String {
return getApiUrl("elements/customers/$customerId/set_default_payment_method")
}

private fun getApiUrl(path: String, vararg args: Any): String {
return getApiUrl(String.format(Locale.ENGLISH, path, *args))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ interface StripeRepository {
options: ApiRequest.Options
): Result<PaymentMethod>

/**
* Set the customer's default payment method.
*
* @param customerId Id of the customer to update
* @param paymentMethodId Id of the payment method to set as the default. If null, the user's existing default
* payment method will be unset.
* */
@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
suspend fun setDefaultPaymentMethod(
customerId: String,
paymentMethodId: String?,
options: ApiRequest.Options,
): Result<Customer>

@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
suspend fun createToken(
tokenParams: TokenParams,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,17 @@ internal class StripeApiRepositoryTest {
assertThat(attachUrl).isEqualTo(expectedUrl)
}

@Test
fun testSetDefaultPaymentMethodUrl() {
val customerId = "cus_123"
val setDefaultPaymentMethodUrl = StripeApiRepository.getSetDefaultPaymentMethodUrl(
customerId
)
assertThat(setDefaultPaymentMethodUrl).isEqualTo(
"https://api.stripe.com/v1/elements/customers/$customerId/set_default_payment_method"
)
}

@Test
fun testGetDetachPaymentMethodUrl() {
val paymentMethodId = "pm_1ETDEa2eZvKYlo2CN5828c52"
Expand Down Expand Up @@ -1491,6 +1502,49 @@ internal class StripeApiRepositoryTest {
)
}

@Test
fun setDefaultPaymentMethod_sendPaymentMethodParameter() = runTest {
val stripeResponse = StripeResponse(
code = 200,
body = "",
headers = emptyMap()
)
whenever(stripeNetworkClient.executeRequest(any<ApiRequest>()))
.thenReturn(stripeResponse)

val expectedPaymentMethodId = "pm_123"
create().setDefaultPaymentMethod(
customerId = "cus_123",
paymentMethodId = expectedPaymentMethodId,
DEFAULT_OPTIONS,
)

verify(stripeNetworkClient).executeRequest(apiRequestArgumentCaptor.capture())
val apiRequest = apiRequestArgumentCaptor.firstValue
assertThat(apiRequest.params?.get("payment_method")).isEqualTo(expectedPaymentMethodId)
}

@Test
fun setDefaultPaymentMethod_sendsNullPaymentMethodAsEmptyString() = runTest {
val stripeResponse = StripeResponse(
code = 200,
body = "",
headers = emptyMap()
)
whenever(stripeNetworkClient.executeRequest(any<ApiRequest>()))
.thenReturn(stripeResponse)

create().setDefaultPaymentMethod(
customerId = "cus_123",
paymentMethodId = null,
DEFAULT_OPTIONS,
)

verify(stripeNetworkClient).executeRequest(apiRequestArgumentCaptor.capture())
val apiRequest = apiRequestArgumentCaptor.firstValue
assertThat(apiRequest.params?.get("payment_method")).isEqualTo("")
}

@Test
fun createCardPaymentMethod_setsCorrectPaymentUserAgent() =
runTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ internal class CustomerSessionCustomerSheetTest {
enqueueSetupIntentRetrieval()
enqueueSetupIntentConfirmation()

val paymentMethodId = "pm_12345"
enqueueElementsSession(
cards = listOf(
PaymentMethodFactory.card(id = "pm_12345").update(
PaymentMethodFactory.card(id = paymentMethodId).update(
last4 = "4242",
addCbcNetworks = false,
brand = CardBrand.Visa,
Expand All @@ -170,7 +171,7 @@ internal class CustomerSessionCustomerSheetTest {
page.clickSaveButton()
assertOnlySavedCardIsDisplayed()

page.clickConfirmButton()
context.markTestSucceeded()
}

private fun assertOnlySavedCardIsDisplayed() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,9 @@ internal class CustomerSheetViewModel(
private fun selectSavedPaymentMethod(savedPaymentSelection: PaymentSelection.Saved?) {
viewModelScope.launch(workContext) {
awaitSavedSelectionDataSource().setSavedSelection(
savedPaymentSelection?.toSavedSelection()
savedPaymentSelection?.toSavedSelection(),
shouldSyncDefault =
customerState.value.metadata?.customerMetadata?.isPaymentMethodSetAsDefaultEnabled == true,
).onSuccess {
confirmPaymentSelection(
paymentSelection = savedPaymentSelection,
Expand All @@ -1104,7 +1106,7 @@ internal class CustomerSheetViewModel(

private fun selectGooglePay() {
viewModelScope.launch(workContext) {
awaitSavedSelectionDataSource().setSavedSelection(SavedSelection.GooglePay)
awaitSavedSelectionDataSource().setSavedSelection(SavedSelection.GooglePay, shouldSyncDefault = false)
.onSuccess {
confirmPaymentSelection(
paymentSelection = PaymentSelection.GooglePay,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ internal class CustomerAdapterDataSource @Inject constructor(
}
}

override suspend fun setSavedSelection(selection: SavedSelection?) = runCatchingAdapterTask {
override suspend fun setSavedSelection(
selection: SavedSelection?,
shouldSyncDefault: Boolean
) = runCatchingAdapterTask {
customerAdapter.setSelectedPaymentOption(selection?.toPaymentOption())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import com.stripe.android.model.ElementsSession
import com.stripe.android.paymentsheet.PrefsRepository
import com.stripe.android.paymentsheet.model.SavedSelection
import com.stripe.android.paymentsheet.model.toSavedSelection
import com.stripe.android.paymentsheet.repositories.CustomerRepository
import kotlinx.coroutines.withContext
import java.io.IOException
import javax.inject.Inject
import kotlin.coroutines.CoroutineContext

internal class CustomerSessionSavedSelectionDataSource @Inject constructor(
private val elementsSessionManager: CustomerSessionElementsSessionManager,
private val customerRepository: CustomerRepository,
private val prefsRepositoryFactory: @JvmSuppressWildcards (String) -> PrefsRepository,
@IOContext private val workContext: CoroutineContext,
) : CustomerSheetSavedSelectionDataSource {
Expand Down Expand Up @@ -63,18 +65,48 @@ internal class CustomerSessionSavedSelectionDataSource @Inject constructor(
}
}

override suspend fun setSavedSelection(selection: SavedSelection?): CustomerSheetDataResult<Unit> {
override suspend fun setSavedSelection(
selection: SavedSelection?,
shouldSyncDefault: Boolean,
): CustomerSheetDataResult<Unit> {
return withContext(workContext) {
createPrefsRepository().mapCatching { prefsRepository ->
val result = prefsRepository.setSavedSelection(selection)

if (!result) {
throw IOException("Unable to persist payment option $selection")
elementsSessionManager.fetchCustomerSessionEphemeralKey().mapCatching { ephemeralKey ->
if (shouldSyncDefault) {
saveSelectionToBackend(ephemeralKey, selection)
} else {
saveSelectionToPrefs(selection)
}
}.toCustomerSheetDataResult()
}
}

private suspend fun saveSelectionToPrefs(
selection: SavedSelection?
) {
createPrefsRepository().mapCatching { prefsRepository ->
val result = prefsRepository.setSavedSelection(selection)

if (!result) {
throw IOException("Unable to persist payment option $selection")
}
}
}

private suspend fun saveSelectionToBackend(
ephemeralKey: CachedCustomerEphemeralKey,
selection: SavedSelection?
) {
val paymentMethodId = (selection as? SavedSelection.PaymentMethod)?.id
customerRepository.setDefaultPaymentMethod(
paymentMethodId = paymentMethodId,
customerInfo = CustomerRepository.CustomerInfo(
id = ephemeralKey.customerId,
ephemeralKeySecret = ephemeralKey.ephemeralKey,
customerSessionClientSecret = ephemeralKey.customerSessionClientSecret,
)
).getOrThrow()
}

private suspend fun createPrefsRepository(): CustomerSheetDataResult<PrefsRepository> {
return elementsSessionManager.fetchCustomerSessionEphemeralKey().mapCatching { ephemeralKey ->
prefsRepositoryFactory(ephemeralKey.customerId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@ internal interface CustomerSheetSavedSelectionDataSource {
customerSessionElementsSession: CustomerSessionElementsSession?
): CustomerSheetDataResult<SavedSelection?>

suspend fun setSavedSelection(selection: SavedSelection?): CustomerSheetDataResult<Unit>
suspend fun setSavedSelection(
selection: SavedSelection?,
shouldSyncDefault: Boolean,
): CustomerSheetDataResult<Unit>
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ internal class CustomerApiRepository @Inject constructor(
logger.error("Failed to update payment method $paymentMethodId.", it)
}

override suspend fun setDefaultPaymentMethod(
customerInfo: CustomerRepository.CustomerInfo,
paymentMethodId: String?
): Result<Customer> = stripeRepository.setDefaultPaymentMethod(
paymentMethodId = paymentMethodId,
customerId = customerInfo.id,
options = ApiRequest.Options(
apiKey = customerInfo.ephemeralKeySecret,
stripeAccount = lazyPaymentConfig.get().stripeAccountId,
)
)

private fun filterPaymentMethods(allPaymentMethods: List<PaymentMethod>): List<PaymentMethod> {
val paymentMethods = mutableListOf<PaymentMethod>()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ internal interface CustomerRepository {
params: PaymentMethodUpdateParams
): Result<PaymentMethod>

suspend fun setDefaultPaymentMethod(
customerInfo: CustomerInfo,
paymentMethodId: String?,
): Result<Customer>

data class CustomerInfo(
val id: String,
val ephemeralKeySecret: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class CustomerAdapterDataSourceTest {
),
)

val result = dataSource.setSavedSelection(SavedSelection.GooglePay)
val result = dataSource.setSavedSelection(SavedSelection.GooglePay, false)

assertThat(result).isInstanceOf<CustomerSheetDataResult.Success<Unit>>()
}
Expand All @@ -176,7 +176,7 @@ class CustomerAdapterDataSourceTest {
)
)

val result = dataSource.setSavedSelection(SavedSelection.GooglePay)
val result = dataSource.setSavedSelection(SavedSelection.GooglePay, false)

assertThat(result).isInstanceOf<CustomerSheetDataResult.Failure<Unit>>()

Expand All @@ -199,7 +199,7 @@ class CustomerAdapterDataSourceTest {
)
)

val result = dataSource.setSavedSelection(SavedSelection.GooglePay)
val result = dataSource.setSavedSelection(SavedSelection.GooglePay, false)

assertThat(result).isInstanceOf<CustomerSheetDataResult.Failure<Unit>>()

Expand Down
Loading

0 comments on commit bd3790c

Please sign in to comment.