Skip to content

Commit

Permalink
Updated task service to send client notification events (#3578)
Browse files Browse the repository at this point in the history
Co-authored-by: jryu01 <[email protected]>
  • Loading branch information
jryu01 and jryu01 authored May 9, 2024
1 parent 8e86ec2 commit 2071751
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 18 deletions.
11 changes: 11 additions & 0 deletions packages/client/hmi-client/src/services/notification.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import API from '@/api/api';
import { ClientEvent, ClientEventType, NotificationEvent, NotificationGroup } from '@/types/Types';
import { logger } from '@/utils/logger';

/**
* Get notification
Expand Down Expand Up @@ -37,10 +38,20 @@ export async function getLatestUnacknowledgedNotifications(
*/
export function convertToClientEvents<T>(notificationGroup: NotificationGroup) {
const { notificationEvents, type } = notificationGroup;

if (Object.values(ClientEventType).includes(type as ClientEventType)) {
logger.error(`Notification type: ${type} is not supported client event type`, {
showToast: false
});
return [];
}

const events: ClientEvent<T>[] = notificationEvents.map((event: NotificationEvent) => ({
id: event.id || '',
createdAtMs: new Date(event.createdOn || Date.now()).getTime(),
type: type as ClientEventType,
notificationGroupId: notificationGroup.id,
projectId: notificationGroup.projectId,
data: event.data
}));
return events;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import { ClientEvent, ClientEventType, ExtractionStatusUpdate } from '@/types/Types';
import {
ClientEvent,
ClientEventType,
ExtractionStatusUpdate,
TaskNotificationEventData
} from '@/types/Types';
import { logger } from '@/utils/logger';
import { Ref } from 'vue';
import { NotificationItem } from '@/types/common';
Expand Down Expand Up @@ -45,12 +50,12 @@ export const createNotificationEventHandlers = (notificationItems: Ref<Notificat
if (!event.data) return;

const existingItem = notificationItems.value.find(
(item) => item.notificationGroupId === event.data.notificationGroupId
(item) => item.notificationGroupId === event.notificationGroupId
);
if (!existingItem) {
// Create a new notification item
const newItem: NotificationItem = {
notificationGroupId: event.data.notificationGroupId,
notificationGroupId: event.notificationGroupId ?? '',
type: ClientEventType.ExtractionPdf,
assetId: event.data.documentId,
assetName: '',
Expand Down Expand Up @@ -79,6 +84,13 @@ export const createNotificationEventHandlers = (notificationItems: Ref<Notificat
});
};

handlers[ClientEventType.TaskGollmModelCard] = (
event: ClientEvent<TaskNotificationEventData>
) => {
// TODO: Create a notification item and implement notification item UI for this event
console.log(event);
};

const getHandler = (eventType: ClientEventType) => handlers[eventType] ?? (() => {});

return {
Expand All @@ -100,8 +112,9 @@ export const createNotificationEventLogger = (
>(
event: ClientEvent<T>
) => {
if (!event.notificationGroupId) return;
const found = visibleNotificationItems.value.find(
(item) => item.notificationGroupId === event.data.notificationGroupId
(item) => item.notificationGroupId === event.notificationGroupId
);
if (!found) return;
logStatusMessage(
Expand Down
5 changes: 4 additions & 1 deletion packages/client/hmi-client/src/types/Types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ export interface ClientEvent<T> {
id: string;
createdAtMs: number;
type: ClientEventType;
projectId?: string;
notificationGroupId?: string;
data: T;
}

Expand Down Expand Up @@ -542,7 +544,6 @@ export interface ExtractionResponseResult {
}

export interface ExtractionStatusUpdate {
notificationGroupId: string;
documentId: string;
t: number;
message: string;
Expand Down Expand Up @@ -1261,6 +1262,8 @@ export enum ClientEventType {
FileUploadProgress = "FILE_UPLOAD_PROGRESS",
Extraction = "EXTRACTION",
ExtractionPdf = "EXTRACTION_PDF",
TaskUndefinedEvent = "TASK_UNDEFINED_EVENT",
TaskGollmModelCard = "TASK_GOLLM_MODEL_CARD",
}

export enum FileType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import lombok.Builder;
import lombok.Value;
import software.uncharted.terarium.hmiserver.annotations.TSModel;
import software.uncharted.terarium.hmiserver.annotations.TSOptional;

@Builder
@Value
Expand All @@ -18,5 +19,11 @@ public class ClientEvent<T> implements Serializable {

private ClientEventType type;

@TSOptional
private UUID projectId;

@TSOptional
private UUID notificationGroupId;

private T data;
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ public enum ClientEventType {
FILE_UPLOAD_PROGRESS,
EXTRACTION,
EXTRACTION_PDF,
// Events for the task runner notifications
TASK_UNDEFINED_EVENT,
TASK_GOLLM_MODEL_CARD,
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,12 @@
@Data
@TSModel
public class ExtractionStatusUpdate {
private UUID notificationGroupId;
private UUID documentId;
private Double t;
private String message;
private String error;

public ExtractionStatusUpdate(
final UUID notificationGroupId,
final UUID documentId,
final Double t,
final String message,
final String error) {
this.notificationGroupId = notificationGroupId;
public ExtractionStatusUpdate(final UUID documentId, final Double t, final String message, final String error) {
this.documentId = documentId;
this.t = t;
this.message = message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ private static class ExtractionGroupInstance extends NotificationGroupInstance<E
@Override
public ClientEvent<ExtractionStatusUpdate> produceClientEvent(
final Double t, final String message, final String error) {
final ExtractionStatusUpdate update =
new ExtractionStatusUpdate(this.getNotificationGroupId(), documentId, t, message, error);
final ExtractionStatusUpdate update = new ExtractionStatusUpdate(documentId, t, message, error);
return ClientEvent.<ExtractionStatusUpdate>builder()
.type(this.clientEventType)
.notificationGroupId(this.getNotificationGroupId())
.data(update)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import java.util.Map;
import software.uncharted.terarium.hmiserver.models.ClientEventType;

public class TaskNotificationEventTypes {

private static Map<String, ClientEventType> clientEventTypes = Map.of(
ModelCardResponseHandler.NAME, ClientEventType.TASK_GOLLM_MODEL_CARD
// Add more task names and their corresponding event types here
);

public static ClientEventType getTypeFor(String taskName) {
final ClientEventType eventType = clientEventTypes.get(taskName);
return eventType == null ? ClientEventType.TASK_UNDEFINED_EVENT : eventType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,15 @@
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import software.uncharted.terarium.hmiserver.configuration.Config;
import software.uncharted.terarium.hmiserver.models.ClientEvent;
import software.uncharted.terarium.hmiserver.models.ClientEventType;
import software.uncharted.terarium.hmiserver.models.notification.NotificationEvent;
import software.uncharted.terarium.hmiserver.models.notification.NotificationGroup;
import software.uncharted.terarium.hmiserver.models.task.TaskFuture;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
import software.uncharted.terarium.hmiserver.models.task.TaskStatus;
import software.uncharted.terarium.hmiserver.service.ClientEventService;
import software.uncharted.terarium.hmiserver.service.notification.NotificationService;

@Service
Expand Down Expand Up @@ -155,6 +158,7 @@ public synchronized void complete(final TaskResponse resp) {
private final Config config;
private final ObjectMapper objectMapper;
private final NotificationService notificationService;
private final ClientEventService clientEventService;

private final Map<String, TaskResponseHandler> responseHandlers = new ConcurrentHashMap<>();
private final Map<UUID, SseEmitter> taskIdToEmitter = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -436,10 +440,27 @@ private void onTaskResponseOneInstanceReceives(final Message message) {
log.info("Creating notification event under group id: {}", resp.getId());

notificationService.createNotificationEvent(resp.getId(), event);

} catch (final Exception e) {
log.error("Failed to persist notification event for for task {}", resp.getId(), e);
}

try {
// send the client event
final ClientEventType clientEventType = TaskNotificationEventTypes.getTypeFor(resp.getScript());
log.info("Sending client event with type {} for task {} ", clientEventType.toString(), resp.getId());

final ClientEvent<TaskResponse> clientEvent = ClientEvent.<TaskResponse>builder()
.notificationGroupId(resp.getId())
.type(clientEventType)
.data(resp)
.build();
clientEventService.sendToUser(clientEvent, resp.getUserId());

} catch (final Exception e) {
log.error("Failed to send client event for for task {}", resp.getId(), e);
}

log.info("Broadcasting task response for task id {} and status {}", resp.getId(), resp.getStatus());

// once the handler has executed and the response cache is up to date, we now
Expand Down Expand Up @@ -551,7 +572,8 @@ public TaskFuture runTaskAsync(final TaskRequest r) throws JsonProcessingExcepti
// create the notification group for the task
final NotificationGroup group = new NotificationGroup();
group.setId(req.getId()); // use the task id
group.setType(req.getType().toString());
group.setType(
TaskNotificationEventTypes.getTypeFor(req.getScript()).toString());
group.setUserId(req.getUserId());
group.setProjectId(req.getProjectId());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ public class NotificationServiceTests extends TerariumApplicationTests {
private CurrentUserService currentUserService;

ClientEvent<ExtractionStatusUpdate> produceClientEvent(final Double t, final String message, final String error) {
final ExtractionStatusUpdate update =
new ExtractionStatusUpdate(UUID.randomUUID(), UUID.randomUUID(), t, message, error);
final ExtractionStatusUpdate update = new ExtractionStatusUpdate(UUID.randomUUID(), t, message, error);
return ClientEvent.<ExtractionStatusUpdate>builder()
.type(ClientEventType.HEARTBEAT)
.data(update)
Expand Down

0 comments on commit 2071751

Please sign in to comment.