Skip to content

Commit

Permalink
fix: enforce contract negotiation request and transfer request consis…
Browse files Browse the repository at this point in the history
…tency
  • Loading branch information
bscholtes1A committed Jun 11, 2024
1 parent 6034040 commit 4a22187
Show file tree
Hide file tree
Showing 21 changed files with 202 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
import static org.eclipse.edc.connector.controlplane.contract.spi.types.negotiation.ContractNegotiationStates.VERIFIED;
import static org.eclipse.edc.connector.controlplane.services.contractnegotiation.ContractNegotiationProtocolServiceImpl.CONTRACT_NEGOTIATION_REQUEST_SCOPE;
import static org.eclipse.edc.connector.controlplane.services.contractnegotiation.ContractNegotiationProtocolServiceImplTest.TestFunctions.contractOffer;
import static org.eclipse.edc.connector.controlplane.services.contractnegotiation.ContractNegotiationProtocolServiceImplTest.TestFunctions.createPolicy;
import static org.eclipse.edc.junit.assertions.AbstractResultAssert.assertThat;
import static org.eclipse.edc.spi.agent.ParticipantAgent.PARTICIPANT_IDENTITY;
import static org.eclipse.edc.spi.result.ServiceFailure.Reason.BAD_REQUEST;
Expand Down Expand Up @@ -357,7 +356,7 @@ <M extends RemoteMessage> void notify_shouldReturnBadRequest_whenValidationFails
var tokenRepresentation = tokenRepresentation();
var validatableOffer = mock(ValidatableConsumerOffer.class);

when(validatableOffer.getContractPolicy()).thenReturn(createPolicy());
when(validatableOffer.getContractPolicy()).thenReturn(Policy.Builder.newInstance().build());
when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer));
when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message)))
.thenReturn(ServiceResult.success(participantAgent()));
Expand All @@ -380,7 +379,7 @@ <M extends RemoteMessage> void notify_shouldReturnUnauthorized_whenTokenValidati
var tokenRepresentation = tokenRepresentation();
var validatableOffer = mock(ValidatableConsumerOffer.class);

when(validatableOffer.getContractPolicy()).thenReturn(createPolicy());
when(validatableOffer.getContractPolicy()).thenReturn(Policy.Builder.newInstance().build());
when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer));
when(store.findById(any())).thenReturn(createContractNegotiationOffered());
when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.unauthorized("unauthorized"));
Expand Down Expand Up @@ -437,16 +436,13 @@ private interface MethodCall<M extends RemoteMessage> {

interface TestFunctions {
static ContractOffer contractOffer() {
var assetId = "test-asset-id";
return ContractOffer.Builder.newInstance()
.id(ContractOfferId.create("1", "test-asset-id").toString())
.policy(createPolicy())
.assetId("assetId")
.id(ContractOfferId.create("1", assetId).toString())
.policy(Policy.Builder.newInstance().target(assetId).build())
.assetId(assetId)
.build();
}

static Policy createPolicy() {
return Policy.Builder.newInstance().build();
}
}

private static class NotifyArguments implements ArgumentsProvider {
Expand Down Expand Up @@ -531,7 +527,7 @@ void shouldInitiateNegotiation_whenNegotiationDoesNotExist() {
.build();
var validatableOffer = mock(ValidatableConsumerOffer.class);

when(validatableOffer.getContractPolicy()).thenReturn(createPolicy());
when(validatableOffer.getContractPolicy()).thenReturn(Policy.Builder.newInstance().target(contractOffer.getAssetId()).build());
when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.success(validatableOffer));
when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent));
when(store.findByIdAndLease(any())).thenReturn(StoreResult.notFound("not found"));
Expand Down Expand Up @@ -574,7 +570,7 @@ void shouldTransitionToRequested_whenNegotiationFound() {

var validatableOffer = mock(ValidatableConsumerOffer.class);

when(validatableOffer.getContractPolicy()).thenReturn(createPolicy());
when(validatableOffer.getContractPolicy()).thenReturn(Policy.Builder.newInstance().target(contractOffer.getAssetId()).build());
when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.success(validatableOffer));
when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent));
when(store.findById(any())).thenReturn(negotiation);
Expand Down Expand Up @@ -612,7 +608,7 @@ void shouldReturnNotFound_whenOfferNotFound() {
.build();
var validatableOffer = mock(ValidatableConsumerOffer.class);

when(validatableOffer.getContractPolicy()).thenReturn(createPolicy());
when(validatableOffer.getContractPolicy()).thenReturn(Policy.Builder.newInstance().target(contractOffer.getAssetId()).build());
when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.notFound(""));

var result = service.notifyRequested(message, tokenRepresentation);
Expand Down Expand Up @@ -731,7 +727,7 @@ <M extends ProcessRemoteMessage> void notify_shouldStoreReceivedMessageId(Method
var negotiation = contractNegotiationBuilder().state(currentState.code()).type(type).contractOffer(offer).build();
var validatableOffer = mock(ValidatableConsumerOffer.class);

when(validatableOffer.getContractPolicy()).thenReturn(createPolicy());
when(validatableOffer.getContractPolicy()).thenReturn(Policy.Builder.newInstance().target(offer.getAssetId()).build());
when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer));
when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message)))
.thenReturn(ServiceResult.success(participantAgent()));
Expand Down Expand Up @@ -761,7 +757,7 @@ <M extends ProcessRemoteMessage> void notify_shouldIgnoreMessage_whenAlreadyRece
negotiation.protocolMessageReceived(message.getId());
var validatableOffer = mock(ValidatableConsumerOffer.class);

when(validatableOffer.getContractPolicy()).thenReturn(createPolicy());
when(validatableOffer.getContractPolicy()).thenReturn(Policy.Builder.newInstance().target(offer.getAssetId()).build());
when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer));
when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message)))
.thenReturn(ServiceResult.success(participantAgent()));
Expand All @@ -787,7 +783,7 @@ <M extends ProcessRemoteMessage> void notify_shouldIgnoreMessage_whenFinalState(
var negotiation = contractNegotiationBuilder().state(FINALIZED.code()).type(type).contractOffer(offer).build();
var validatableOffer = mock(ValidatableConsumerOffer.class);

when(validatableOffer.getContractPolicy()).thenReturn(createPolicy());
when(validatableOffer.getContractPolicy()).thenReturn(Policy.Builder.newInstance().target(offer.getAssetId()).build());
when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer));
when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message)))
.thenReturn(ServiceResult.success(participantAgent()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package org.eclipse.edc.connector.controlplane.services.contractnegotiation;

import org.eclipse.edc.connector.controlplane.contract.spi.ContractOfferId;
import org.eclipse.edc.connector.controlplane.contract.spi.negotiation.ConsumerContractNegotiationManager;
import org.eclipse.edc.connector.controlplane.contract.spi.negotiation.store.ContractNegotiationStore;
import org.eclipse.edc.connector.controlplane.contract.spi.types.agreement.ContractAgreement;
Expand Down Expand Up @@ -269,10 +270,11 @@ private ContractNegotiation.Builder createContractNegotiationBuilder(String nego
}

private ContractOffer createContractOffer() {
var assetId = "test-asset";
return ContractOffer.Builder.newInstance()
.id(UUID.randomUUID().toString())
.policy(Policy.Builder.newInstance().build())
.assetId("test-asset")
.id(ContractOfferId.create("1", assetId).toString())
.policy(Policy.Builder.newInstance().target(assetId).build())
.assetId(assetId)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void shouldDispatchEventsOnTransferProcessStateChanges(TransferProcessService se
when(agent.getIdentity()).thenReturn(providerId);

dispatcherRegistry.register(getTestDispatcher());
when(policyArchive.findPolicyForContract(matches("contractId"))).thenReturn(mock(Policy.class));
when(policyArchive.findPolicyForContract(matches("contractId"))).thenReturn(Policy.Builder.newInstance().target("assetId").build());
when(negotiationStore.findContractAgreement("contractId")).thenReturn(agreement);
when(agentService.createFor(token)).thenReturn(agent);
eventRouter.register(TransferProcessEvent.class, eventSubscriber);
Expand Down Expand Up @@ -215,7 +215,7 @@ void shouldDispatchEventOnTransferProcessTerminated(TransferProcessService servi
RemoteMessageDispatcherRegistry dispatcherRegistry,
PolicyArchive policyArchive) {

when(policyArchive.findPolicyForContract(matches("contractId"))).thenReturn(mock(Policy.class));
when(policyArchive.findPolicyForContract(matches("contractId"))).thenReturn(Policy.Builder.newInstance().target("assetId").build());
dispatcherRegistry.register(getTestDispatcher());
eventRouter.register(TransferProcessEvent.class, eventSubscriber);
var transferRequest = createTransferRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.eclipse.edc.connector.controlplane.contract.negotiation;

import org.eclipse.edc.connector.controlplane.contract.observe.ContractNegotiationObservableImpl;
import org.eclipse.edc.connector.controlplane.contract.spi.ContractOfferId;
import org.eclipse.edc.connector.controlplane.contract.spi.negotiation.ContractNegotiationPendingGuard;
import org.eclipse.edc.connector.controlplane.contract.spi.negotiation.observe.ContractNegotiationListener;
import org.eclipse.edc.connector.controlplane.contract.spi.negotiation.store.ContractNegotiationStore;
Expand Down Expand Up @@ -85,6 +86,7 @@

class ConsumerContractNegotiationManagerImplTest {

private static final String ASSET_ID = "assetId";
private static final String PARTICIPANT_ID = "participantId";
private static final int RETRY_LIMIT = 1;

Expand Down Expand Up @@ -307,7 +309,7 @@ void dispatchException(ContractNegotiationStates starting, ContractNegotiationSt
}

private Criterion[] stateIs(int state) {
return aryEq(new Criterion[]{ hasState(state), isNotPending(), new Criterion("type", "=", "CONSUMER") });
return aryEq(new Criterion[] {hasState(state), isNotPending(), new Criterion("type", "=", "CONSUMER")});
}

private ContractNegotiation.Builder contractNegotiationBuilder() {
Expand All @@ -320,20 +322,21 @@ private ContractNegotiation.Builder contractNegotiationBuilder() {
.stateTimestamp(Instant.now().toEpochMilli());
}

private ContractOffer contractOffer() {
return ContractOffer.Builder.newInstance().id("id:assetId:random")
.policy(Policy.Builder.newInstance().assigner("providerId").build())
.assetId("assetId")
private static ContractOffer contractOffer() {
return ContractOffer.Builder.newInstance()
.id(ContractOfferId.create("1", ASSET_ID).toString())
.policy(Policy.Builder.newInstance().target(ASSET_ID).assigner("providerId").build())
.assetId(ASSET_ID)
.build();
}

private ContractAgreement createContractAgreement() {
private static ContractAgreement createContractAgreement() {
return ContractAgreement.Builder.newInstance()
.id("contractId")
.consumerId("consumerId")
.providerId("providerId")
.assetId("assetId")
.policy(Policy.Builder.newInstance().build())
.assetId(ASSET_ID)
.policy(Policy.Builder.newInstance().target(ASSET_ID).build())
.build();
}

Expand Down Expand Up @@ -362,23 +365,6 @@ public Stream<? extends Arguments> provideArguments(ExtensionContext extensionCo
new DispatchFailure(TERMINATING, TERMINATED, completedFuture(StatusResult.failure(FATAL_ERROR)), b -> b.stateCount(RETRIES_NOT_EXHAUSTED).errorDetail("an error").contractOffer(contractOffer()))
);
}

private ContractAgreement createContractAgreement() {
return ContractAgreement.Builder.newInstance()
.id("contractId")
.consumerId("consumerId")
.providerId("providerId")
.assetId("assetId")
.policy(Policy.Builder.newInstance().build())
.build();
}

private ContractOffer contractOffer() {
return ContractOffer.Builder.newInstance().id("id:assetId:random")
.policy(Policy.Builder.newInstance().build())
.assetId("assetId")
.build();
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@

class ProviderContractNegotiationManagerImplTest {

private static final String ASSET_ID = "assetId";
private static final String PROVIDER_ID = "provider";
private static final int RETRY_LIMIT = 1;
private final ContractNegotiationStore store = mock();
Expand Down Expand Up @@ -210,7 +211,7 @@ void agreeing_shouldSendAgreementAndTransitionToConfirmed() {

@Test
void finalizing_shouldSendMessageAndTransitionToFinalized() {
var negotiation = contractNegotiationBuilder().state(FINALIZING.code()).contractOffer(contractOffer()).contractAgreement(contractAgreementBuilder().build()).build();
var negotiation = contractNegotiationBuilder().state(FINALIZING.code()).contractOffer(contractOffer()).contractAgreement(createContractAgreement()).build();
when(store.nextNotLeased(anyInt(), stateIs(FINALIZING.code()))).thenReturn(List.of(negotiation)).thenReturn(emptyList());
when(store.findById(negotiation.getId())).thenReturn(negotiation);
when(dispatcherRegistry.dispatch(any(), any())).thenReturn(completedFuture(StatusResult.success("any")));
Expand Down Expand Up @@ -291,25 +292,26 @@ private ContractNegotiation.Builder contractNegotiationBuilder() {
.stateTimestamp(Instant.now().toEpochMilli());
}

private ContractAgreement.Builder contractAgreementBuilder() {
private static ContractAgreement createContractAgreement() {
return ContractAgreement.Builder.newInstance()
.id(ContractOfferId.create(UUID.randomUUID().toString(), "test-asset-id").toString())
.providerId("any")
.consumerId("any")
.assetId("default")
.policy(Policy.Builder.newInstance().build());
.policy(Policy.Builder.newInstance().build())
.build();
}

private ContractOffer contractOffer() {
private static ContractOffer contractOffer() {
return ContractOffer.Builder.newInstance()
.id(ContractOfferId.create("1", "test-asset-id").toString())
.policy(Policy.Builder.newInstance().build())
.assetId("assetId")
.id(ContractOfferId.create("1", ASSET_ID).toString())
.policy(Policy.Builder.newInstance().target(ASSET_ID).build())
.assetId(ASSET_ID)
.build();
}

private Criterion[] stateIs(int state) {
return aryEq(new Criterion[]{ hasState(state), isNotPending(), new Criterion("type", "=", "PROVIDER") });
return aryEq(new Criterion[] {hasState(state), isNotPending(), new Criterion("type", "=", "PROVIDER")});
}

private static class DispatchFailureArguments implements ArgumentsProvider {
Expand Down Expand Up @@ -338,23 +340,6 @@ public Stream<? extends Arguments> provideArguments(ExtensionContext extensionCo
);
}

private ContractOffer contractOffer() {
return ContractOffer.Builder.newInstance().id("id:assetId:random")
.policy(Policy.Builder.newInstance().build())
.assetId("assetId")
.build();
}

private ContractAgreement createContractAgreement() {
return ContractAgreement.Builder.newInstance()
.id("contractId")
.consumerId("consumerId")
.providerId("providerId")
.assetId("assetId")
.policy(Policy.Builder.newInstance().build())
.build();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ void verifyContractAgreementValidation() {
}

@ParameterizedTest
@ValueSource(strings = { "malicious-actor" })
@ValueSource(strings = {"malicious-actor"})
@NullSource
void verifyContractAgreementValidation_failedIfInvalidCredentials(String counterPartyId) {
var participantAgent = new ParticipantAgent(emptyMap(), counterPartyId != null ? Map.of(PARTICIPANT_IDENTITY, counterPartyId) : Map.of());
Expand Down Expand Up @@ -354,7 +354,7 @@ void validateInitialOffer_fails_whenContractPolicyEvaluationFails() {
}

@ParameterizedTest
@ValueSource(strings = { PROVIDER_ID })
@ValueSource(strings = {PROVIDER_ID})
@NullSource
void validateConsumerRequest_failsInvalidCredentials(String counterPartyId) {
var negotiation = ContractNegotiation.Builder.newInstance()
Expand Down Expand Up @@ -397,11 +397,11 @@ private Result<ContractAgreement> validateAgreementDate(long signingDate) {
return validationService.validateAgreement(new ParticipantAgent(emptyMap(), emptyMap()), agreement);
}

private ContractOffer createContractOffer(Asset asset, Policy policy) {
private ContractOffer createContractOffer(Asset asset) {
return ContractOffer.Builder.newInstance()
.id(ContractOfferId.create("1", asset.getId()).toString())
.assetId(asset.getId())
.policy(policy)
.policy(Policy.Builder.newInstance().target(asset.getId()).build())
.build();
}

Expand All @@ -425,7 +425,7 @@ private ValidatableConsumerOffer createValidatableConsumerOffer(Asset asset, Pol

@NotNull
private ContractOffer createContractOffer() {
return createContractOffer(Asset.Builder.newInstance().build(), Policy.Builder.newInstance().build());
return createContractOffer(Asset.Builder.newInstance().build());
}

private ContractDefinition createContractDefinition() {
Expand Down
Loading

0 comments on commit 4a22187

Please sign in to comment.