Skip to content

Commit

Permalink
Merge pull request #113 from jackschedel/2.1.0
Browse files Browse the repository at this point in the history
api key migration + bug fixup
  • Loading branch information
jackschedel authored Feb 12, 2024
2 parents 8def22a + 473a6cd commit d5bfb23
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 20 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "koala-client",
"private": true,
"version": "2.0.8",
"version": "2.1.0",
"type": "module",
"homepage": "./",
"main": "electron/index.cjs",
Expand Down
11 changes: 11 additions & 0 deletions src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ function App() {
const setHideSideMenu = useStore((state) => state.setHideSideMenu);
const hideSideMenu = useStore((state) => state.hideSideMenu);
const bottomMessageRef = useStore((state) => state.bottomMessageRef);
const setApiAuth = useStore((state) => state.setApiAuth);
const apiAuth = useStore((state) => state.apiAuth);
const apiKey = useStore((state) => state.apiKey);
const setApiKey = useStore((state) => state.setApiKey);

const initialiseNewChat = useInitialiseNewChat();
const addChat = useAddChat();
Expand All @@ -30,6 +34,13 @@ function App() {
const copyCodeBlock = useCopyCodeBlock();
const { handleSubmit } = useSubmit();

if (apiKey && !apiAuth[0].apiKey) {
const old = apiAuth;
old[0].apiKey = apiKey;
setApiAuth(old);
setApiKey('');
}

const handleGenerate = () => {
if (useStore.getState().generating) return;
const updatedChats: ChatInterface[] = JSON.parse(
Expand Down
4 changes: 2 additions & 2 deletions src/api/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export const getChatCompletion = async (
if (isAzureEndpoint(endpoint) && apiKey) {
headers['api-key'] = apiKey;

const modelName = modelDef.name;
const modelName = modelDef.model;

const apiVersion = '2023-03-15-preview';

Expand Down Expand Up @@ -69,7 +69,7 @@ export const getChatCompletionStream = async (
if (isAzureEndpoint(endpoint) && apiKey) {
headers['api-key'] = apiKey;

const modelName = modelDef.name;
const modelName = modelDef.model;

const apiVersion = '2023-03-15-preview';

Expand Down
10 changes: 6 additions & 4 deletions src/components/ApiMenu/ApiMenu.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,12 @@ const ApiMenu = ({
}}
>
<span className='inline-block truncate max-w-full'>
{_apiAuth[modelDef.endpoint].endpoint.replace(
/^https?:\/\//,
''
)}
{_apiAuth[modelDef.endpoint]
? _apiAuth[modelDef.endpoint].endpoint.replace(
/^https?:\/\//,
''
)
: 'Endpoint Undefined'}
</span>

<DownChevronArrow className='absolute right-0 mr-1 flex items-center' />
Expand Down
6 changes: 5 additions & 1 deletion src/components/ConfigMenu/ModelSelect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ export const ModelSelect = ({
const [dropDown, setDropDown, dropDownRef] = useHideOnOutsideClick();
const modelDefs = useStore((state: StoreState) => state.modelDefs);

if (typeof _model !== 'number') {
_setModel(0);
}

return (
<div className='mb-4'>
<button
Expand All @@ -25,7 +29,7 @@ export const ModelSelect = ({
onClick={() => setDropDown((prev) => !prev)}
aria-label='model'
>
{modelDefs[model]?.name}
{modelDefs[model]?.name || modelDefs[model].model}
<DownChevronArrow />
</button>
<div
Expand Down
8 changes: 3 additions & 5 deletions src/hooks/useSubmit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ const useSubmit = () => {
if (!apiKey || apiKey.length === 0) {
// official endpoint
if (apiEndpoint === officialAPIEndpoint) {
const error = new Error(t('noApiKeyWarning') as string);
setError(error.message);
throw error;
throw new Error(t('noApiKeyWarning') as string);
}

// other endpoints
Expand Down Expand Up @@ -223,15 +221,15 @@ const useSubmit = () => {

// update tokens used for generating title
if (countTotalTokens) {
const model = config.model_selection;
updateTotalTokenUsed(model, [message], {
updateTotalTokenUsed(0, [message], {
role: 'assistant',
content: title,
});
}
}
} catch (e: unknown) {
setError((e as Error).message);
setGenerating(false);
throw e;
}
setGenerating(false);
Expand Down
7 changes: 1 addition & 6 deletions src/store/auth-slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { ModelDefinition } from '@type/chat';
import { StoreSlice } from './store';

export interface AuthSlice {
apiKey?: string;
firstVisit: boolean;
apiAuth: EndpointAuth[];
modelDefs: ModelDefinition[];
Expand Down Expand Up @@ -44,12 +45,6 @@ export const createAuthSlice: StoreSlice<AuthSlice> = (set) => ({
apiKey: apiKey,
}));
},
setApiEndpoint: (apiEndpoint: string) => {
set((prev: AuthSlice) => ({
...prev,
apiEndpoint: apiEndpoint,
}));
},
setFirstVisit: (firstVisit: boolean) => {
set((prev: AuthSlice) => ({
...prev,
Expand Down
2 changes: 1 addition & 1 deletion src/utils/messageUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ export const useUpdateTotalTokenUsed = () => {
completionMessage: MessageInterface
) => {
const updatedTotalTokenUsed = JSON.parse(JSON.stringify(totalTokenUsed));
const modelName = modelDefs[model].name;
const modelName = modelDefs[model].model;

const newPromptTokens = countTokens(promptMessages, modelName);
const newCompletionTokens = countTokens([completionMessage], modelName);
Expand Down

0 comments on commit d5bfb23

Please sign in to comment.