|
2 | 2 |
|
3 | 3 | import io.github.sashirestela.openai.SimpleOpenAI; |
4 | 4 | import io.github.sashirestela.openai.domain.DomainTestingHelper; |
| 5 | +import io.github.sashirestela.openai.domain.assistant.ThreadRunStepDelta.MessageCreationStepDetail; |
| 6 | +import io.github.sashirestela.openai.domain.assistant.ThreadRunStepDelta.ToolCallsStepDetail; |
5 | 7 | import org.junit.jupiter.api.BeforeAll; |
6 | 8 | import org.junit.jupiter.api.Test; |
7 | 9 |
|
@@ -163,6 +165,23 @@ void testThreadsRunsCreate() throws IOException { |
163 | 165 | assertNotNull(response); |
164 | 166 | } |
165 | 167 |
|
| 168 | + @Test |
| 169 | + void testThreadsRunsCreateStream() throws IOException { |
| 170 | + DomainTestingHelper.get().mockForStream(httpClient, "src/test/resources/threads_runs_create_stream.txt"); |
| 171 | + var request = ThreadRunRequest.builder() |
| 172 | + .assistantId(assistantId) |
| 173 | + .model("gpt-4-1106-preview") |
| 174 | + .instructions("instructions") |
| 175 | + .additionalInstructions("additional Instructions") |
| 176 | + .metadata(Map.of("key1", "value1")) |
| 177 | + .build(); |
| 178 | + var response = openAI.threads().createRunStream(threadId, request).join(); |
| 179 | + response.filter(e -> e.getName().equals(Events.THREAD_MESSAGE_DELTA)) |
| 180 | + .map(e -> ((TextContent) ((ThreadMessageDelta) e.getData()).getDelta().getContent().get(0)).getValue()) |
| 181 | + .forEach(System.out::print); |
| 182 | + assertNotNull(response); |
| 183 | + } |
| 184 | + |
166 | 185 | @Test |
167 | 186 | void testThreadsRunsModify() throws IOException { |
168 | 187 | DomainTestingHelper.get().mockForObject(httpClient, "src/test/resources/threads_runs_modify.json"); |
@@ -237,6 +256,89 @@ void testThreadsRunsCreateBoth() throws IOException { |
237 | 256 | assertNotNull(response); |
238 | 257 | } |
239 | 258 |
|
| 259 | + @Test |
| 260 | + void testThreadsRunsCreateBothStream() throws IOException { |
| 261 | + DomainTestingHelper.get().mockForStream(httpClient, "src/test/resources/threads_runs_create_both_stream.txt"); |
| 262 | + var request = ThreadCreateAndRunRequest.builder() |
| 263 | + .assistantId(assistantId) |
| 264 | + .thread(ThreadMessageList.builder() |
| 265 | + .message(ThreadMessageRequest.builder() |
| 266 | + .role("user") |
| 267 | + .content( |
| 268 | + "Inspect the content of the attached text file. After that plot graph of the formula requested in it.") |
| 269 | + .build()) |
| 270 | + .metadata(Map.of("stage", "test")) |
| 271 | + .build()) |
| 272 | + .metadata(Map.of("phase", "test")) |
| 273 | + .build(); |
| 274 | + var response = openAI.threads().createThreadAndRunStream(request).join(); |
| 275 | + response.forEach(e -> { |
| 276 | + switch (e.getName()) { |
| 277 | + case Events.THREAD_RUN_STEP_CREATED: |
| 278 | + var runStepCreated = (ThreadRunStep) e.getData(); |
| 279 | + System.out.println("\n===== Thread Run Step Created - " + runStepCreated.getType() + " - " |
| 280 | + + runStepCreated.getId()); |
| 281 | + break; |
| 282 | + case Events.THREAD_RUN_STEP_COMPLETED: |
| 283 | + var runStepCompleted = (ThreadRunStep) e.getData(); |
| 284 | + System.out.println("\n----- Thread Run Step Completed - " + runStepCompleted.getType() + " - " |
| 285 | + + runStepCompleted.getId()); |
| 286 | + break; |
| 287 | + case Events.THREAD_RUN_STEP_DELTA: |
| 288 | + var runStepDeltaDetails = ((ThreadRunStepDelta) e.getData()).getDelta().getStepDetails(); |
| 289 | + if (runStepDeltaDetails instanceof MessageCreationStepDetail) { |
| 290 | + System.out.println( |
| 291 | + ((MessageCreationStepDetail) runStepDeltaDetails).getMessageCreation().getMessageId()); |
| 292 | + } else if (runStepDeltaDetails instanceof ToolCallsStepDetail) { |
| 293 | + var toolCall = ((ToolCallsStepDetail) runStepDeltaDetails).getToolCalls().get(0); |
| 294 | + if (toolCall.getType().equals("code_interpreter")) { |
| 295 | + var codeInterpreter = toolCall.getCodeInterpreter(); |
| 296 | + if (codeInterpreter.getInput() != null) { |
| 297 | + System.out.print(codeInterpreter.getInput()); |
| 298 | + } |
| 299 | + if (codeInterpreter.getOutputs() != null && codeInterpreter.getOutputs().size() > 0) { |
| 300 | + var codeInterpreterOutput = codeInterpreter.getOutputs().get(0); |
| 301 | + if (codeInterpreterOutput.getType().equals("logs")) { |
| 302 | + System.out.print("\nOutput Logs = " + codeInterpreterOutput.getLogs()); |
| 303 | + } else if (codeInterpreterOutput.getType().equals("image")) { |
| 304 | + System.out.print( |
| 305 | + "\nOutput Image File Id = " + codeInterpreterOutput.getImage().getFileId()); |
| 306 | + } |
| 307 | + } |
| 308 | + } else if (toolCall.getType().equals("function")) { |
| 309 | + var functionTool = toolCall.getFunction(); |
| 310 | + if (functionTool.getName() != null) { |
| 311 | + System.out.println("Function Name = " + functionTool.getName()); |
| 312 | + System.out.print("Function Arguments = "); |
| 313 | + } |
| 314 | + if (functionTool.getArguments() != null) { |
| 315 | + System.out.print(functionTool.getArguments()); |
| 316 | + } |
| 317 | + if (functionTool.getOutput() != null) { |
| 318 | + System.out.print("\nFunction Output = " + functionTool.getOutput()); |
| 319 | + } |
| 320 | + } else if (toolCall.getType().equals("retrieval")) { |
| 321 | + // Currently OpenAI is replying an empty Map. |
| 322 | + } |
| 323 | + } |
| 324 | + break; |
| 325 | + case Events.THREAD_MESSAGE_DELTA: |
| 326 | + var messageDeltaFirstContent = ((ThreadMessageDelta) e.getData()).getDelta().getContent().get(0); |
| 327 | + if (messageDeltaFirstContent instanceof TextContent) { |
| 328 | + System.out.print(((TextContent) messageDeltaFirstContent).getValue()); |
| 329 | + } else if (messageDeltaFirstContent instanceof ImageFileContent) { |
| 330 | + System.out.println( |
| 331 | + "File Id = " |
| 332 | + + ((ImageFileContent) messageDeltaFirstContent).getImageFile().getFileId()); |
| 333 | + } |
| 334 | + break; |
| 335 | + default: |
| 336 | + break; |
| 337 | + } |
| 338 | + }); |
| 339 | + assertNotNull(response); |
| 340 | + } |
| 341 | + |
240 | 342 | @Test |
241 | 343 | void testThreadsRunsCancel() throws IOException { |
242 | 344 | DomainTestingHelper.get().mockForObject(httpClient, "src/test/resources/threads_runs_cancel.json"); |
|
0 commit comments