Skip to content

Commit

Permalink
refactor: move counter-party validation to validation service
Browse files Browse the repository at this point in the history
  • Loading branch information
ronjaquensel committed Jun 26, 2023
1 parent 7762b4f commit c444022
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import java.time.Instant;
import java.util.ArrayList;
import java.util.Optional;

import static java.lang.String.format;
import static org.eclipse.edc.connector.contract.spi.ContractId.createContractId;
Expand Down Expand Up @@ -141,7 +142,16 @@ public Result<ContractAgreement> validateAgreement(ClaimToken token, ContractAgr
}
return success(agreement);
}


@Override
public @NotNull Result<Void> validateRequest(ClaimToken token, ContractAgreement agreement) {
var agent = agentService.createFor(token);
return Optional.ofNullable(agent.getIdentity())
.filter(id -> id.equals(agreement.getConsumerId()) || id.equals(agreement.getProviderId()))
.map(id -> Result.success())
.orElse(Result.failure("Invalid counter-party identity"));
}

@Override
@NotNull
public Result<Void> validateRequest(ClaimToken token, ContractNegotiation negotiation) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void testNegotiation_initialOfferAccepted() {
var offer = getContractOffer();
when(validationService.validateInitialOffer(token, offer)).thenReturn(Result.success(new ValidatedConsumerOffer(CONSUMER_ID, offer)));
when(validationService.validateConfirmed(eq(token), any(ContractAgreement.class), any(ContractOffer.class))).thenReturn(Result.success());
when(validationService.validateRequest(eq(token), any())).thenReturn(Result.success());
when(validationService.validateRequest(eq(token), any(ContractNegotiation.class))).thenReturn(Result.success());

// Start provider and consumer negotiation managers
providerManager.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,45 @@ void validateConfirmed_failsIfPoliciesAreNotEqual() {

verify(agentService).createFor(eq(token));
}

@Test
void validateRequest_shouldReturnSuccess_whenRequestingPartyProvider() {
var token = ClaimToken.Builder.newInstance().build();
var agreement = createContractAgreement().build();
var participantAgent = new ParticipantAgent(Map.of(), Map.of(PARTICIPANT_IDENTITY, PROVIDER_ID));

when(agentService.createFor(token)).thenReturn(participantAgent);

var result = validationService.validateRequest(token, agreement);

assertThat(result).isSucceeded();
}

@Test
void validateRequest_shouldReturnSuccess_whenRequestingPartyConsumer() {
var token = ClaimToken.Builder.newInstance().build();
var agreement = createContractAgreement().build();
var participantAgent = new ParticipantAgent(Map.of(), Map.of(PARTICIPANT_IDENTITY, CONSUMER_ID));

when(agentService.createFor(token)).thenReturn(participantAgent);

var result = validationService.validateRequest(token, agreement);

assertThat(result).isSucceeded();
}

@Test
void validateRequest_shouldReturnFailure_whenRequestingPartyUnauthorized() {
var token = ClaimToken.Builder.newInstance().build();
var agreement = createContractAgreement().build();
var participantAgent = new ParticipantAgent(Map.of(), Map.of(PARTICIPANT_IDENTITY, "invalid"));

when(agentService.createFor(token)).thenReturn(participantAgent);

var result = validationService.validateRequest(token, agreement);

assertThat(result).isFailed();
}

@Test
void validateConsumerRequest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,6 @@ public TransferProcessService transferProcessService() {
@Provider
public TransferProcessProtocolService transferProcessProtocolService() {
return new TransferProcessProtocolServiceImpl(transferProcessStore, transactionContext, contractNegotiationStore,
contractValidationService, dataAddressValidator, transferProcessObservable, participantAgentService, clock, monitor, telemetry);
contractValidationService, dataAddressValidator, transferProcessObservable, clock, monitor, telemetry);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
import org.eclipse.edc.connector.transfer.spi.types.protocol.TransferStartMessage;
import org.eclipse.edc.connector.transfer.spi.types.protocol.TransferTerminationMessage;
import org.eclipse.edc.service.spi.result.ServiceResult;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.dataaddress.DataAddressValidator;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.telemetry.Telemetry;
import org.eclipse.edc.transaction.spi.TransactionContext;
import org.jetbrains.annotations.NotNull;
Expand All @@ -57,7 +57,6 @@ public class TransferProcessProtocolServiceImpl implements TransferProcessProtoc
private final ContractValidationService contractValidationService;
private final DataAddressValidator dataAddressValidator;
private final TransferProcessObservable observable;
private final ParticipantAgentService agentService;
private final Clock clock;
private final Monitor monitor;
private final Telemetry telemetry;
Expand All @@ -66,15 +65,13 @@ public TransferProcessProtocolServiceImpl(TransferProcessStore transferProcessSt
TransactionContext transactionContext, ContractNegotiationStore negotiationStore,
ContractValidationService contractValidationService,
DataAddressValidator dataAddressValidator, TransferProcessObservable observable,
ParticipantAgentService agentService,
Clock clock, Monitor monitor, Telemetry telemetry) {
this.transferProcessStore = transferProcessStore;
this.transactionContext = transactionContext;
this.negotiationStore = negotiationStore;
this.contractValidationService = contractValidationService;
this.dataAddressValidator = dataAddressValidator;
this.observable = observable;
this.agentService = agentService;
this.clock = clock;
this.monitor = monitor;
this.telemetry = telemetry;
Expand Down Expand Up @@ -127,7 +124,7 @@ public ServiceResult<TransferProcess> notifyTerminated(TransferTerminationMessag
@NotNull
public ServiceResult<TransferProcess> findById(String id, ClaimToken claimToken) {
return transactionContext.execute(() -> Optional.ofNullable(transferProcessStore.findById(id))
.map(tp -> validateCounterParty(claimToken, tp))
.filter(tp -> validateCounterParty(claimToken, tp))
.map(ServiceResult::success)
.orElse(ServiceResult.notFound(format("No negotiation with id %s found", id))));
}
Expand Down Expand Up @@ -215,16 +212,11 @@ private ServiceResult<TransferProcess> onMessageDo(TransferRemoteMessage message
.orElse(ServiceResult.notFound(format("TransferProcess with DataRequest id %s not found", message.getProcessId()))));
}

private TransferProcess validateCounterParty(ClaimToken claimToken, TransferProcess transferProcess) {
var agentId = agentService.createFor(claimToken).getIdentity();
if (agentId == null) {
return null;
}

private boolean validateCounterParty(ClaimToken claimToken, TransferProcess transferProcess) {
return Optional.ofNullable(negotiationStore.findContractAgreement(transferProcess.getDataRequest().getContractId()))
.filter(agreement -> agentId.equals(agreement.getConsumerId()) || agentId.equals(agreement.getProviderId()))
.map(agreement -> transferProcess)
.orElse(null);
.map(agreement -> contractValidationService.validateRequest(claimToken, agreement))
.filter(Result::succeeded)
.isPresent();
}

private void update(TransferProcess transferProcess) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ void notifyAgreed_shouldReturnBadRequest_whenValidationFails() {
void notifyVerified_shouldTransitionToVerified() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(AGREED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.success());
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.success());
var message = ContractAgreementVerificationMessage.Builder.newInstance()
.protocol("protocol")
.counterPartyAddress("http://any")
Expand All @@ -209,15 +209,15 @@ void notifyVerified_shouldTransitionToVerified() {
assertThat(result).isSucceeded();
verify(store).save(argThat(n -> n.getState() == VERIFIED.code()));
verify(listener).verified(negotiation);
verify(validationService).validateRequest(any(), any());
verify(validationService).validateRequest(any(), any(ContractNegotiation.class));
verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class));
}

@Test
void notifyVerified_shouldReturnBadRequest_whenValidationFails() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(AGREED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.failure("validation error"));
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.failure("validation error"));
var message = ContractAgreementVerificationMessage.Builder.newInstance()
.protocol("protocol")
.counterPartyAddress("http://any")
Expand All @@ -235,7 +235,7 @@ void notifyVerified_shouldReturnBadRequest_whenValidationFails() {
void notifyFinalized_shouldTransitionToFinalized() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(VERIFIED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.success());
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.success());
var message = ContractNegotiationEventMessage.Builder.newInstance()
.type(ContractNegotiationEventMessage.Type.FINALIZED)
.protocol("protocol")
Expand All @@ -249,15 +249,15 @@ void notifyFinalized_shouldTransitionToFinalized() {
assertThat(result).isSucceeded();
verify(store).save(argThat(n -> n.getState() == FINALIZED.code()));
verify(listener).finalized(negotiation);
verify(validationService).validateRequest(any(), any());
verify(validationService).validateRequest(any(), any(ContractNegotiation.class));
verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class));
}

@Test
void notifyFinalized_shouldReturnBadRequest_whenValidationFails() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(VERIFIED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.failure("validation error"));
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.failure("validation error"));
var message = ContractNegotiationEventMessage.Builder.newInstance()
.type(ContractNegotiationEventMessage.Type.FINALIZED)
.protocol("protocol")
Expand All @@ -277,7 +277,7 @@ void notifyFinalized_shouldReturnBadRequest_whenValidationFails() {
void notifyTerminated_shouldTransitionToTerminated() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(VERIFIED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.success());
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.success());
var message = ContractNegotiationTerminationMessage.Builder.newInstance()
.protocol("protocol")
.processId("processId")
Expand All @@ -291,15 +291,15 @@ void notifyTerminated_shouldTransitionToTerminated() {
assertThat(result).isSucceeded();
verify(store).save(argThat(n -> n.getState() == TERMINATED.code()));
verify(listener).terminated(negotiation);
verify(validationService).validateRequest(any(), any());
verify(validationService).validateRequest(any(), any(ContractNegotiation.class));
verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class));
}

@Test
void notifyTerminated_shouldReturnBadRequest_whenValidationFails() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(VERIFIED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.failure("validation error"));
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.failure("validation error"));
var message = ContractNegotiationTerminationMessage.Builder.newInstance()
.protocol("protocol")
.processId("processId")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import org.eclipse.edc.policy.model.Policy;
import org.eclipse.edc.service.spi.result.ServiceFailure;
import org.eclipse.edc.service.spi.result.ServiceResult;
import org.eclipse.edc.spi.agent.ParticipantAgent;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.dataaddress.DataAddressValidator;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.monitor.Monitor;
Expand All @@ -55,7 +53,6 @@
import org.mockito.ArgumentCaptor;

import java.time.Clock;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Stream;

Expand All @@ -69,7 +66,6 @@
import static org.eclipse.edc.service.spi.result.ServiceFailure.Reason.BAD_REQUEST;
import static org.eclipse.edc.service.spi.result.ServiceFailure.Reason.CONFLICT;
import static org.eclipse.edc.service.spi.result.ServiceFailure.Reason.NOT_FOUND;
import static org.eclipse.edc.spi.agent.ParticipantAgent.PARTICIPANT_IDENTITY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.atLeastOnce;
Expand All @@ -88,7 +84,6 @@ class TransferProcessProtocolServiceImplTest {
private final ContractValidationService validationService = mock(ContractValidationService.class);
private final DataAddressValidator dataAddressValidator = mock(DataAddressValidator.class);
private final TransferProcessListener listener = mock(TransferProcessListener.class);
private final ParticipantAgentService agentService = mock(ParticipantAgentService.class);

private TransferProcessProtocolService service;

Expand All @@ -97,7 +92,7 @@ void setUp() {
var observable = new TransferProcessObservableImpl();
observable.registerListener(listener);
service = new TransferProcessProtocolServiceImpl(store, transactionContext, negotiationStore, validationService,
dataAddressValidator, observable, agentService, mock(Clock.class), mock(Monitor.class), mock(Telemetry.class));
dataAddressValidator, observable, mock(Clock.class), mock(Monitor.class), mock(Telemetry.class));
}

@Test
Expand Down Expand Up @@ -313,6 +308,24 @@ void notifyTerminated_shouldReturnConflict_whenStatusIsNotValid() {
verify(store, never()).updateOrCreate(any());
verifyNoInteractions(listener);
}

@Test
void findById_shouldReturnTransferProcess_whenValidCounterParty() {
var processId = "transferProcessId";
var transferProcess = transferProcess(INITIAL, processId);
var token = claimToken();
var agreement = contractAgreement();

when(store.findById(processId)).thenReturn(transferProcess);
when(negotiationStore.findContractAgreement(any())).thenReturn(agreement);
when(validationService.validateRequest(token, agreement)).thenReturn(Result.success());

var result = service.findById(processId, token);

assertThat(result)
.isSucceeded()
.isEqualTo(transferProcess);
}

@Test
void findById_shouldReturnNotFound_whenNegotiationNotFound() {
Expand All @@ -326,35 +339,23 @@ void findById_shouldReturnNotFound_whenNegotiationNotFound() {
.isEqualTo(NOT_FOUND);
}

@ParameterizedTest
@ArgumentsSource(FindByIdArguments.class)
void findById_shouldSucceedOrFail_dependingOnCounterParty(String counterPartyId, boolean shouldSucceed) {
@Test
void findById_shouldReturnNotFound_whenCounterPartyUnauthorized() {
var processId = "transferProcessId";
var contractId = "contractId";
var transferProcess = transferProcess(INITIAL, processId).toBuilder()
.dataRequest(dataRequest(contractId))
.build();

var token = ClaimToken.Builder.newInstance().build();
var agent = new ParticipantAgent(Map.of(), Map.of(PARTICIPANT_IDENTITY, counterPartyId));
var transferProcess = transferProcess(INITIAL, processId);
var token = claimToken();
var agreement = contractAgreement();

when(store.findById(processId)).thenReturn(transferProcess);
when(agentService.createFor(token)).thenReturn(agent);
when(negotiationStore.findContractAgreement(contractId)).thenReturn(agreement);

when(negotiationStore.findContractAgreement(any())).thenReturn(agreement);
when(validationService.validateRequest(token, agreement)).thenReturn(Result.failure("error"));
var result = service.findById(processId, token);

if (shouldSucceed) {
assertThat(result)
.isSucceeded()
.isEqualTo(transferProcess);
} else {
assertThat(result)
.isFailed()
.extracting(ServiceFailure::getReason)
.isEqualTo(NOT_FOUND);
}

assertThat(result)
.isFailed()
.extracting(ServiceFailure::getReason)
.isEqualTo(NOT_FOUND);
}

@ParameterizedTest
Expand All @@ -373,6 +374,7 @@ private TransferProcess transferProcess(TransferProcessStates state, String id)
return TransferProcess.Builder.newInstance()
.state(state.code())
.id(id)
.dataRequest(dataRequest())
.build();
}

Expand All @@ -392,9 +394,9 @@ private ContractAgreement contractAgreement() {
.build();
}

private DataRequest dataRequest(String contractId) {
private DataRequest dataRequest() {
return DataRequest.Builder.newInstance()
.contractId(contractId)
.contractId("contractId")
.destinationType("type")
.build();
}
Expand Down Expand Up @@ -422,15 +424,4 @@ public Stream<? extends Arguments> provideArguments(ExtensionContext extensionCo
);
}
}

private static class FindByIdArguments implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
return Stream.of(
Arguments.of("provider", true),
Arguments.of("consumer", true),
Arguments.of("invalid", false)
);
}
}
}
Loading

0 comments on commit c444022

Please sign in to comment.