Skip to content

Commit

Permalink
Merge pull request #359 from cmu-sei/v8
Browse files Browse the repository at this point in the history
api interface work
  • Loading branch information
sei-dupdyke committed May 28, 2024
2 parents 8865ee5 + a5bbca3 commit 6f18350
Show file tree
Hide file tree
Showing 15 changed files with 183 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Text.Json;
using System.Threading.Tasks;
using ghosts.api.Areas.Animator.Infrastructure.Animations.AnimationDefinitions.Chat.Mattermost;
using ghosts.api.Areas.Animator.Infrastructure.ContentServices;
using ghosts.api.Areas.Animator.Infrastructure.ContentServices.Ollama;
using ghosts.api.Areas.Animator.Infrastructure.Extensions;
using ghosts.api.Areas.Animator.Infrastructure.Models;
Expand All @@ -27,12 +28,14 @@ public class ChatClient
private readonly HttpClient _client;
private string _token;
private string UserId { get; set; }
private IFormatterService _formatterService;

public ChatClient(ChatJobConfiguration config)
public ChatClient(ChatJobConfiguration config, IFormatterService formatterService)
{
_configuration = config;
this._baseUrl = _configuration.Chat.BaseUrl;
this._client = new HttpClient();
this._formatterService = formatterService;
}

private async Task<User> AdminLogin()
Expand Down Expand Up @@ -406,7 +409,7 @@ private async Task<string> ExecuteRequest(HttpRequestMessage request)
}
}

public async Task Step(OllamaConnectorService llm, Random random, IEnumerable<NpcRecord> agents)
public async Task Step(Random random, IEnumerable<NpcRecord> agents)
{
await this.AdminLogin();

Expand All @@ -425,11 +428,11 @@ await this.CreateUser(new UserCreate
});
}

await this.StepEx(llm, random, username, _configuration.Chat.DefaultUserPassword);
await this.StepEx(random, username, _configuration.Chat.DefaultUserPassword);
}
}

private async Task StepEx(OllamaConnectorService llm, Random random, string username, string password)
private async Task StepEx(Random random, string username, string password)
{
_log.Trace($"Managing {username}...");

Expand Down Expand Up @@ -542,7 +545,7 @@ private async Task StepEx(OllamaConnectorService llm, Random random, string user
respondingTo = history.UserName;
}

var message = await llm.ExecuteQuery(prompt);
var message = await this._formatterService.ExecuteQuery(prompt);

message = message.Clean(this._configuration.Replacements, random);
if (!string.IsNullOrEmpty(respondingTo))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Ghosts.Animator.Extensions;
using ghosts.api.Areas.Animator.Hubs;
using ghosts.api.Areas.Animator.Infrastructure.Animations.AnimationDefinitions.Chat;
using ghosts.api.Areas.Animator.Infrastructure.ContentServices;
using ghosts.api.Areas.Animator.Infrastructure.ContentServices.Ollama;
using ghosts.api.Areas.Animator.Infrastructure.Models;
using Ghosts.Api.Infrastructure;
Expand All @@ -27,6 +28,7 @@ public class ChatJob
private readonly ChatClient _chatClient;
private readonly int _currentStep;
private CancellationToken _cancellationToken;
private IFormatterService _formatterService;

public ChatJob(ApplicationSettings configuration, IServiceScopeFactory scopeFactory, Random random,
IHubContext<ActivityHub> activityHubContext, CancellationToken cancellationToken)
Expand All @@ -44,8 +46,10 @@ public ChatJob(ApplicationSettings configuration, IServiceScopeFactory scopeFact
var chatConfiguration = JsonSerializer.Deserialize<ChatJobConfiguration>(File.ReadAllText("config/chat.json"),
new JsonSerializerOptions { PropertyNameCaseInsensitive = true }) ?? throw new InvalidOperationException();

var llm = new OllamaConnectorService(_configuration.AnimatorSettings.Animations.Chat.ContentEngine);
this._chatClient = new ChatClient(chatConfiguration);
this._formatterService =
new ContentCreationService(_configuration.AnimatorSettings.Animations.Chat.ContentEngine).FormatterService;

this._chatClient = new ChatClient(chatConfiguration, this._formatterService);

while (!_cancellationToken.IsCancellationRequested)
{
Expand All @@ -55,17 +59,17 @@ public ChatJob(ApplicationSettings configuration, IServiceScopeFactory scopeFact
return;
}

this.Step(llm, random, chatConfiguration);
Thread.Sleep(this._configuration.AnimatorSettings.Animations.SocialSharing.TurnLength);
this.Step(random, chatConfiguration);
Thread.Sleep(this._configuration.AnimatorSettings.Animations.Chat.TurnLength);

this._currentStep++;
}
}

private async void Step(OllamaConnectorService llm, Random random, ChatJobConfiguration chatConfiguration)
private async void Step(Random random, ChatJobConfiguration chatConfiguration)
{
_log.Trace("Executing a chat step...");
var agents = this._context.Npcs.ToList().Shuffle(_random).Take(chatConfiguration.Chat.AgentsPerBatch);
await this._chatClient.Step(llm, random, agents);
await this._chatClient.Step(random, agents);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class SocialSharingJob
private readonly CancellationToken _cancellationToken;
private readonly ApplicationDbContext _context;
private readonly IMachineUpdateService _updateService;
private readonly IFormatterService _formatterService;

public SocialSharingJob(ApplicationSettings configuration, IServiceScopeFactory scopeFactory, Random random,
IHubContext<ActivityHub> activityHubContext, CancellationToken cancellationToken)
Expand All @@ -48,6 +49,9 @@ public SocialSharingJob(ApplicationSettings configuration, IServiceScopeFactory

this._cancellationToken = cancellationToken;
this._updateService = innerScope.ServiceProvider.GetRequiredService<IMachineUpdateService>();

_formatterService =
new ContentCreationService(_configuration.AnimatorSettings.Animations.Chat.ContentEngine).FormatterService;

if (!_configuration.AnimatorSettings.Animations.SocialSharing.IsInteracting)
{
Expand Down Expand Up @@ -84,9 +88,6 @@ private async void Step()
{
_log.Trace("Social sharing step proceeding...");

var contentService =
new ContentCreationService(_configuration.AnimatorSettings.Animations.SocialSharing.ContentEngine);

//take some random NPCs
var activities = new List<NpcActivity>();
var rawAgents = this._context.Npcs.ToList();
Expand All @@ -102,7 +103,7 @@ private async void Step()
foreach (var agent in agents)
{
_log.Trace($"Processing agent {agent.NpcProfile.Email}...");
var tweetText = await contentService.GenerateTweet(agent);
var tweetText = await this._formatterService.GenerateTweet(agent);
if (string.IsNullOrEmpty(tweetText))
{
_log.Trace($"Content service generated no payload...");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ public Task StopAsync(CancellationToken cancellationToken)
this._socialGraphJobCancellationTokenSource.Cancel();
this._socialGraphJobThread?.Join();
this.RemoveJob("SOCIALGRAPH");
this._socialGraphJobCancellationTokenSource = new CancellationTokenSource();
}
catch
{
Expand All @@ -120,6 +121,7 @@ public Task StopAsync(CancellationToken cancellationToken)
this._socialSharingJobCancellationTokenSource.Cancel();
this._socialSharingJobThread?.Join();
this.RemoveJob("SOCIALSHARING");
this._socialGraphJobCancellationTokenSource = new CancellationTokenSource();
}
catch
{
Expand All @@ -131,6 +133,7 @@ public Task StopAsync(CancellationToken cancellationToken)
this._socialSharingJobCancellationTokenSource.Cancel();
this._socialBeliefsJobThread?.Join();
this.RemoveJob("SOCIALBELIEF");
this._socialGraphJobCancellationTokenSource = new CancellationTokenSource();
}
catch
{
Expand All @@ -142,6 +145,7 @@ public Task StopAsync(CancellationToken cancellationToken)
this._chatJobJobCancellationTokenSource.Cancel();
this._chatJobThread?.Join();
this.RemoveJob("CHAT");
this._socialGraphJobCancellationTokenSource = new CancellationTokenSource();
}
catch
{
Expand All @@ -153,6 +157,7 @@ public Task StopAsync(CancellationToken cancellationToken)
this._fullAutonomyCancellationTokenSource.Cancel();
this._fullAutonomyJobThread?.Join();
this.RemoveJob("FULLAUTONOMY");
this._socialGraphJobCancellationTokenSource = new CancellationTokenSource();
}
catch
{
Expand Down Expand Up @@ -181,22 +186,27 @@ public Task StopJob(string jobId)
case "SOCIALGRAPH":
this._socialGraphJobCancellationTokenSource.Cancel();
this._socialGraphJobThread?.Join();
this._socialGraphJobCancellationTokenSource = new CancellationTokenSource();
break;
case "SOCIALSHARING":
this._socialSharingJobCancellationTokenSource.Cancel();
this._socialSharingJobThread?.Join();
this._socialSharingJobCancellationTokenSource = new CancellationTokenSource();
break;
case "SOCIALBELIEFS":
this._socialSharingJobCancellationTokenSource.Cancel();
this._socialBeliefsJobThread?.Join();
this._socialSharingJobCancellationTokenSource = new CancellationTokenSource();
break;
case "CHAT":
this._chatJobJobCancellationTokenSource.Cancel();
this._chatJobThread?.Join();
this._chatJobJobCancellationTokenSource = new CancellationTokenSource();
break;
case "FULLAUTONOMY":
this._fullAutonomyCancellationTokenSource.Cancel();
this._fullAutonomyJobThread?.Join();
this._chatJobJobCancellationTokenSource = new CancellationTokenSource();
break;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
// Copyright 2017 Carnegie Mellon University. All Rights Reserved. See LICENSE.md file for terms.

using System;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using ghosts.api.Areas.Animator.Infrastructure.ContentServices.Native;
using ghosts.api.Areas.Animator.Infrastructure.ContentServices.Ollama;
using ghosts.api.Areas.Animator.Infrastructure.ContentServices.OpenAi;
using ghosts.api.Areas.Animator.Infrastructure.ContentServices.Shadows;
using ghosts.api.Areas.Animator.Infrastructure.Models;
using Ghosts.Api.Infrastructure;
using Ghosts.Api.Infrastructure.Extensions;
using NLog;

namespace ghosts.api.Areas.Animator.Infrastructure.ContentServices;
Expand All @@ -21,6 +18,7 @@ public class ContentCreationService
private OpenAiFormatterService _openAiFormatterService;
private OllamaFormatterService _ollamaFormatterService;
private ShadowsFormatterService _shadowsFormatterService;
public IFormatterService FormatterService;

public ContentCreationService(ApplicationSettings.AnimatorSettingsDetail.ContentEngineSettings configuration)
{
Expand All @@ -31,124 +29,36 @@ public ContentCreationService(ApplicationSettings.AnimatorSettingsDetail.Content
configuration.Model;

if (_configuration.Source.ToLower() == "openai" && this._openAiFormatterService.IsReady)
{
_openAiFormatterService = new OpenAiFormatterService();
this.FormatterService = _openAiFormatterService;
}
else if (_configuration.Source.ToLower() == "ollama")
{
_ollamaFormatterService = new OllamaFormatterService(_configuration);
this.FormatterService = _ollamaFormatterService;
}
else if (_configuration.Source.ToLower() == "shadows")
{
_shadowsFormatterService = new ShadowsFormatterService(_configuration);
this.FormatterService = _shadowsFormatterService;
}

_log.Trace($"Content service configured for {_configuration.Source} on {_configuration.Host} running {_configuration.Model}");
}

public async Task<string> GenerateNextAction(NpcRecord agent, string history)
{
var nextAction = string.Empty;
try
{
if (_configuration.Source.ToLower() == "openai" && this._openAiFormatterService.IsReady)
{
nextAction = await this._openAiFormatterService.GenerateNextAction(agent, history).ConfigureAwait(false);
}
else if (_configuration.Source.ToLower() == "ollama")
{
nextAction = await this._ollamaFormatterService.GenerateNextAction(agent, history);
}
else if (_configuration.Source.ToLower() == "shadows")
{
nextAction = await this._shadowsFormatterService.GenerateNextAction(agent, history);
}

_log.Info($"{agent.NpcProfile.Name}'s next action is: {nextAction}");
}
catch (Exception e)
{
_log.Error(e);
}
var nextAction = await this.FormatterService.GenerateNextAction(agent, history);
_log.Info($"{agent.NpcProfile.Name}'s next action is: {nextAction}");
return nextAction;
}

public async Task<string> GenerateTweet(NpcRecord agent)
public async Task<string> GenerateTweet(NpcRecord npc)
{
string tweetText = null;

try
{
if (_configuration.Source.ToLower() == "openai" && this._openAiFormatterService.IsReady)
{
tweetText = await this._openAiFormatterService.GenerateTweet(agent).ConfigureAwait(false);
}
else if (_configuration.Source.ToLower() == "ollama")
{
var tries = 0;
while (string.IsNullOrEmpty(tweetText))
{
tweetText = await this._ollamaFormatterService.GenerateTweet(agent);
tries++;
if (tries > 5)
return null;
}

var regArray = new [] {"\"activities\": \\[\"([^\"]+)\"", "\"activity\": \"([^\"]+)\"", "'activities': \\['([^\\']+)'\\]", "\"activities\": \\[\"([^\\']+)'\\]"} ;

foreach (var reg in regArray)
{
var match = Regex.Match(tweetText,reg);
if (match.Success)
{
// Extract the activity
tweetText = match.Groups[1].Value;
break;
}
}
}
else if (_configuration.Source.ToLower() == "shadows")
{
var tries = 0;
while (string.IsNullOrEmpty(tweetText))
{
tweetText = await this._shadowsFormatterService.GenerateTweet(agent);
tries++;
if (tries > 5)
return null;
}

var regArray = new [] {"\"activities\": \\[\"([^\"]+)\"", "\"activity\": \"([^\"]+)\"", "'activities': \\['([^\\']+)'\\]", "\"activities\": \\[\"([^\\']+)'\\]"} ;

foreach (var reg in regArray)
{
var match = Regex.Match(tweetText,reg);
if (match.Success)
{
// Extract the activity
tweetText = match.Groups[1].Value;
break;
}
}
}

while (string.IsNullOrEmpty(tweetText))
{
tweetText = NativeContentFormatterService.GenerateTweet(agent);
}

tweetText = tweetText.ReplaceDoubleQuotesWithSingleQuotes(); // else breaks csv file, //TODO should replace this with a proper csv library

tweetText = Clean(tweetText);

_log.Info($"{agent.NpcProfile.Name} said: {tweetText}");
}
catch (Exception e)
{
_log.Info(e);
}
var tweetText = await this.FormatterService.GenerateTweet(npc);
_log.Info($"{npc.NpcProfile.Name} said: {tweetText}");
return tweetText;
}

private string Clean(string raw)
{
raw = raw.Replace("`", "");
raw = raw.Replace("\"", "");
raw = raw.Replace("'", "");
return raw;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using System.Threading.Tasks;

namespace ghosts.api.Areas.Animator.Infrastructure.ContentServices;

public interface IContentService
{
Task<string> ExecuteQuery(string prompt);
}
Loading

0 comments on commit 6f18350

Please sign in to comment.