Reasoning support for claude 3.7 sonnet (#909)
maekawataiki authored Feb 26, 2025
1 parent ef0302a commit eb2e9cf
Showing 11 changed files with 316 additions and 111 deletions.
4 changes: 3 additions & 1 deletion docs/mkdocs/material/overrides/home.html
@@ -323,7 +323,9 @@ <h2 class="mb-16 text-center text-3xl font-bold">導入企業</h2>
alt="タキヒヨー"
class="mb-4 h-16 object-contain" />
<p class="text-center text-sm text-gray-600">
生成 AI を活用し社内業務効率化と 450 時間超の工数削減を実現。Amazon Bedrock を衣服デザイン等に適用、デジタル人材育成を推進。
生成 AI を活用し社内業務効率化と 450
時間超の工数削減を実現。Amazon Bedrock
を衣服デザイン等に適用、デジタル人材育成を推進。
</p>
</a>
</div>
43 changes: 19 additions & 24 deletions packages/cdk/lambda/utils/bedrockApi.ts
@@ -15,6 +15,7 @@ import {
ApiInterface,
BedrockImageGenerationResponse,
GenerateImageParams,
Model,
UnrecordedMessage,
} from 'generative-ai-use-cases-jp';
import { BEDROCK_TEXT_GEN_MODELS, BEDROCK_IMAGE_GEN_MODELS } from './models';
@@ -84,11 +85,11 @@ const initBedrockClient = async () => {
};

const createConverseCommandInput = (
model: string,
model: Model,
messages: UnrecordedMessage[],
id: string
): ConverseCommandInput => {
const modelConfig = BEDROCK_TEXT_GEN_MODELS[model];
const modelConfig = BEDROCK_TEXT_GEN_MODELS[model.modelId];
return modelConfig.createConverseCommandInput(
messages,
id,
@@ -99,11 +100,11 @@ const createConverseCommandInput = (
};

const createConverseStreamCommandInput = (
model: string,
model: Model,
messages: UnrecordedMessage[],
id: string
): ConverseStreamCommandInput => {
const modelConfig = BEDROCK_TEXT_GEN_MODELS[model];
const modelConfig = BEDROCK_TEXT_GEN_MODELS[model.modelId];
return modelConfig.createConverseStreamCommandInput(
messages,
id,
@@ -114,34 +115,31 @@ const createConverseStreamCommandInput = (
};

const extractConverseOutputText = (
model: string,
model: Model,
output: ConverseCommandOutput
): string => {
const modelConfig = BEDROCK_TEXT_GEN_MODELS[model];
const modelConfig = BEDROCK_TEXT_GEN_MODELS[model.modelId];
return modelConfig.extractConverseOutputText(output);
};

const extractConverseStreamOutputText = (
model: string,
model: Model,
output: ConverseStreamOutput
): string => {
const modelConfig = BEDROCK_TEXT_GEN_MODELS[model];
const modelConfig = BEDROCK_TEXT_GEN_MODELS[model.modelId];
return modelConfig.extractConverseStreamOutputText(output);
};

const createBodyImage = (
model: string,
params: GenerateImageParams
): string => {
const modelConfig = BEDROCK_IMAGE_GEN_MODELS[model];
const createBodyImage = (model: Model, params: GenerateImageParams): string => {
const modelConfig = BEDROCK_IMAGE_GEN_MODELS[model.modelId];
return modelConfig.createBodyImage(params);
};

const extractOutputImage = (
model: string,
model: Model,
response: BedrockImageGenerationResponse
): string => {
const modelConfig = BEDROCK_IMAGE_GEN_MODELS[model];
const modelConfig = BEDROCK_IMAGE_GEN_MODELS[model.modelId];
return modelConfig.extractOutputImage(response);
};

@@ -150,21 +148,21 @@ const bedrockApi: Omit<ApiInterface, 'invokeFlow'> = {
const client = await initBedrockClient();

const converseCommandInput = createConverseCommandInput(
model.modelId,
model,
messages,
id
);
const command = new ConverseCommand(converseCommandInput);
const output = await client.send(command);

return extractConverseOutputText(model.modelId, output);
return extractConverseOutputText(model, output);
},
invokeStream: async function* (model, messages, id) {
const client = await initBedrockClient();

try {
const converseStreamCommandInput = createConverseStreamCommandInput(
model.modelId,
model,
messages,
id
);
@@ -182,10 +180,7 @@ const bedrockApi: Omit<ApiInterface, 'invokeFlow'> = {
break;
}

const outputText = extractConverseStreamOutputText(
model.modelId,
response
);
const outputText = extractConverseStreamOutputText(model, response);

if (outputText) {
yield streamingChunk({ text: outputText });
@@ -231,13 +226,13 @@ const bedrockApi: Omit<ApiInterface, 'invokeFlow'> = {
// Image generation with Stable Diffusion or Titan Image Generator is not supported by the Converse API, so InvokeModelCommand is used instead
const command = new InvokeModelCommand({
modelId: model.modelId,
body: createBodyImage(model.modelId, params),
body: createBodyImage(model, params),
contentType: 'application/json',
});
const res = await client.send(command);
const body = JSON.parse(Buffer.from(res.body).toString('utf-8'));

return extractOutputImage(model.modelId, body);
return extractOutputImage(model, body);
},
};

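The refactor above passes the full Model object, rather than a bare modelId string, into each helper so that model.modelParameters (for example the new reasoningConfig) can reach the Converse command builders. A minimal usage sketch follows; it is not part of this commit, and the default export, the UnrecordedMessage shape, and the concrete reasoning values are assumptions based on the surrounding code.

```ts
// Hypothetical caller of the refactored API (all concrete values are illustrative).
import bedrockApi from './bedrockApi'; // assumes a default export
import { Model, UnrecordedMessage } from 'generative-ai-use-cases-jp';

const model: Model = {
  type: 'bedrock',
  modelId: 'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
  // Field names mirror the reads in createConverseCommandInput; the budget is a placeholder.
  modelParameters: {
    reasoningConfig: { type: 'enabled', budgetTokens: 1024 },
  },
};

// UnrecordedMessage is assumed to be a simple { role, content } record here.
const messages: UnrecordedMessage[] = [
  { role: 'user', content: 'Outline the migration plan, thinking step by step.' },
];

(async () => {
  // The helpers now receive `model` itself and resolve the per-model config
  // via BEDROCK_TEXT_GEN_MODELS[model.modelId] internally.
  const text = await bedrockApi.invoke(model, messages, 'example-chat-id');
  console.log(text);
})();
```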
65 changes: 46 additions & 19 deletions packages/cdk/lambda/utils/models.ts
@@ -85,6 +85,12 @@ const RINNA_PROMPT: PromptTemplate = {

// Model Params

const CLAUDE_3_5_DEFAULT_PARAMS: ConverseInferenceParams = {
maxTokens: 8192,
temperature: 0.6,
topP: 0.8,
};

const CLAUDE_DEFAULT_PARAMS: ConverseInferenceParams = {
maxTokens: 4096,
temperature: 0.6,
@@ -187,7 +193,7 @@ function normalizeId(id: string): string {
const createConverseCommandInput = (
messages: UnrecordedMessage[],
id: string,
modelId: string,
model: Model,
defaultConverseInferenceParams: ConverseInferenceParams,
usecaseConverseInferenceParams: UsecaseConverseInferenceParams
) => {
@@ -266,13 +272,34 @@ const createConverseCommandInput = (
const guardrailConfig = createGuardrailConfig();

const converseCommandInput: ConverseCommandInput = {
modelId: modelId,
modelId: model.modelId,
messages: conversation,
system: systemContext,
inferenceConfig: inferenceConfig,
guardrailConfig: guardrailConfig,
};

if (
modelFeatureFlags[model.modelId].reasoning &&
model.modelParameters?.reasoningConfig?.type === 'enabled'
) {
converseCommandInput.inferenceConfig = {
...inferenceConfig,
temperature: 1, // temperature must be set to 1 when reasoning is enabled
topP: undefined, // topP is not needed when reasoning is enabled
maxTokens:
(model.modelParameters?.reasoningConfig?.budgetTokens || 0) +
(inferenceConfig?.maxTokens || 0),
};
converseCommandInput.additionalModelRequestFields = {
reasoning_config: {
type: model.modelParameters?.reasoningConfig?.type,
budget_tokens:
model.modelParameters?.reasoningConfig?.budgetTokens || 0,
},
};
}

return converseCommandInput;
};
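For reference, this is roughly the ConverseCommandInput the new branch builds when reasoning is enabled, assuming the CLAUDE_3_5_DEFAULT_PARAMS maxTokens of 8192 and an illustrative budgetTokens of 1024, so the effective maxTokens becomes 8192 + 1024 = 9216. The messages, system prompt, and omitted guardrailConfig are placeholders:

```ts
// Sketch of the command input built by createConverseCommandInput for a
// reasoning-enabled request; all concrete values below are illustrative.
const exampleConverseCommandInput = {
  modelId: 'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
  messages: [{ role: 'user', content: [{ text: 'Hello' }] }],
  system: [{ text: 'You are a helpful assistant.' }],
  inferenceConfig: {
    temperature: 1, // forced to 1 when reasoning is enabled
    topP: undefined, // dropped when reasoning is enabled
    maxTokens: 8192 + 1024, // default maxTokens + reasoningConfig.budgetTokens = 9216
  },
  additionalModelRequestFields: {
    reasoning_config: {
      type: 'enabled',
      budget_tokens: 1024, // passed through from reasoningConfig.budgetTokens
    },
  },
  // guardrailConfig omitted for brevity
};
```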

@@ -283,7 +310,7 @@ const createConverseCommandInput = (
const createConverseCommandInputWithoutSystemContext = (
messages: UnrecordedMessage[],
id: string,
modelId: string,
model: Model,
defaultConverseInferenceParams: ConverseInferenceParams,
usecaseConverseInferenceParams: UsecaseConverseInferenceParams
) => {
@@ -305,7 +332,7 @@ const createConverseCommandInputWithoutSystemContext = (
const guardrailConfig = createGuardrailConfig();

const converseCommandInput: ConverseCommandInput = {
modelId: modelId,
modelId: model.modelId,
messages: conversation,
inferenceConfig: inferenceConfig,
guardrailConfig: guardrailConfig,
@@ -318,14 +345,14 @@ const createConverseCommandInputWithoutSystemContext = (
const createConverseStreamCommandInput = (
messages: UnrecordedMessage[],
id: string,
modelId: string,
model: Model,
defaultParams: ConverseInferenceParams,
usecaseParams: UsecaseConverseInferenceParams
): ConverseStreamCommandInput => {
const converseCommandInput = createConverseCommandInput(
messages,
id,
modelId,
model,
defaultParams,
usecaseParams
);
@@ -343,14 +370,14 @@ const createConverseStreamCommandInput = (
const createConverseStreamCommandInputWithoutSystemContext = (
messages: UnrecordedMessage[],
id: string,
modelId: string,
model: Model,
defaultParams: ConverseInferenceParams,
usecaseParams: UsecaseConverseInferenceParams
): ConverseStreamCommandInput => {
const converseCommandInput = createConverseCommandInputWithoutSystemContext(
messages,
id,
modelId,
model,
defaultParams,
usecaseParams
);
@@ -626,14 +653,14 @@ export const BEDROCK_TEXT_GEN_MODELS: {
createConverseCommandInput: (
messages: UnrecordedMessage[],
id: string,
modelId: string,
model: Model,
defaultParams: ConverseInferenceParams,
usecaseParams: UsecaseConverseInferenceParams
) => ConverseCommandInput;
createConverseStreamCommandInput: (
messages: UnrecordedMessage[],
id: string,
modelId: string,
model: Model,
defaultParams: ConverseInferenceParams,
usecaseParams: UsecaseConverseInferenceParams
) => ConverseStreamCommandInput;
@@ -642,71 +669,71 @@ export const BEDROCK_TEXT_GEN_MODELS: {
};
} = {
'anthropic.claude-3-5-sonnet-20241022-v2:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
extractConverseOutputText: extractConverseOutputText,
extractConverseStreamOutputText: extractConverseStreamOutputText,
},
'us.anthropic.claude-3-5-sonnet-20241022-v2:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
extractConverseOutputText: extractConverseOutputText,
extractConverseStreamOutputText: extractConverseStreamOutputText,
},
'anthropic.claude-3-5-haiku-20241022-v1:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
extractConverseOutputText: extractConverseOutputText,
extractConverseStreamOutputText: extractConverseStreamOutputText,
},
'us.anthropic.claude-3-7-sonnet-20250219-v1:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
extractConverseOutputText: extractConverseOutputText,
extractConverseStreamOutputText: extractConverseStreamOutputText,
},
'us.anthropic.claude-3-5-haiku-20241022-v1:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
extractConverseOutputText: extractConverseOutputText,
extractConverseStreamOutputText: extractConverseStreamOutputText,
},
'anthropic.claude-3-5-sonnet-20240620-v1:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
extractConverseOutputText: extractConverseOutputText,
extractConverseStreamOutputText: extractConverseStreamOutputText,
},
'us.anthropic.claude-3-5-sonnet-20240620-v1:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
extractConverseOutputText: extractConverseOutputText,
extractConverseStreamOutputText: extractConverseStreamOutputText,
},
'eu.anthropic.claude-3-5-sonnet-20240620-v1:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
extractConverseOutputText: extractConverseOutputText,
extractConverseStreamOutputText: extractConverseStreamOutputText,
},
'apac.anthropic.claude-3-5-sonnet-20240620-v1:0': {
defaultParams: CLAUDE_DEFAULT_PARAMS,
defaultParams: CLAUDE_3_5_DEFAULT_PARAMS,
usecaseParams: USECASE_DEFAULT_PARAMS,
createConverseCommandInput: createConverseCommandInput,
createConverseStreamCommandInput: createConverseStreamCommandInput,
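With this change the Claude 3.5 and 3.7 entries share CLAUDE_3_5_DEFAULT_PARAMS, so their default maxTokens is 8192 instead of 4096. A small lookup sketch, mirroring how bedrockApi.ts consumes the registry; the model id is illustrative and the exposed defaultParams property is inferred from the entries above:

```ts
// Illustrative registry lookup, similar to what bedrockApi.ts does internally.
import { BEDROCK_TEXT_GEN_MODELS } from './models';

const modelId = 'us.anthropic.claude-3-7-sonnet-20250219-v1:0';
const modelConfig = BEDROCK_TEXT_GEN_MODELS[modelId];

// Claude 3.5 / 3.7 entries now share CLAUDE_3_5_DEFAULT_PARAMS.
console.log(modelConfig.defaultParams.maxTokens); // 8192
```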
10 changes: 9 additions & 1 deletion packages/common/src/application/model.ts
@@ -7,6 +7,13 @@ const MODEL_FEATURE: Record<string, FeatureFlags> = {
TEXT_ONLY: { text: true, doc: false, image: false, video: false },
TEXT_DOC: { text: true, doc: true, image: false, video: false },
TEXT_DOC_IMAGE: { text: true, doc: true, image: true, video: false },
TEXT_DOC_IMAGE_REASONING: {
text: true,
doc: true,
image: true,
video: false,
reasoning: true,
},
TEXT_DOC_IMAGE_VIDEO: { text: true, doc: true, image: true, video: true },
IMAGE_GEN: { image_gen: true },
VIDEO_GEN: { video_gen: true },
@@ -31,7 +38,8 @@ export const modelFeatureFlags: Record<string, FeatureFlags> = {
...MODEL_FEATURE.TEXT_DOC_IMAGE,
...MODEL_FEATURE.LIGHT,
},
'us.anthropic.claude-3-7-sonnet-20250219-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE,
'us.anthropic.claude-3-7-sonnet-20250219-v1:0':
MODEL_FEATURE.TEXT_DOC_IMAGE_REASONING,
'us.anthropic.claude-3-5-sonnet-20241022-v2:0': MODEL_FEATURE.TEXT_DOC_IMAGE,
'us.anthropic.claude-3-5-haiku-20241022-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE,
'us.anthropic.claude-3-5-sonnet-20240620-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE,
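With the reasoning flag registered for us.anthropic.claude-3-7-sonnet-20250219-v1:0, callers can gate reasoning-specific behavior on modelFeatureFlags. A minimal sketch; the import path is an assumption and may differ from the actual package layout:

```ts
// Hypothetical consumer-side check; the import path is assumed and may differ.
import { modelFeatureFlags } from './application/model';

const modelId = 'us.anthropic.claude-3-7-sonnet-20250219-v1:0';

// Only models flagged with `reasoning: true` should expose a budgetTokens
// control and attach modelParameters.reasoningConfig to the request.
if (modelFeatureFlags[modelId]?.reasoning) {
  // e.g. show the reasoning toggle / budget input in the UI
}
```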
2 changes: 2 additions & 0 deletions packages/types/src/message.d.ts
@@ -1,10 +1,12 @@
import { PrimaryKey } from './base';
import { AdditionalModelRequestFields } from './text';

export type Role = 'system' | 'user' | 'assistant';

export type Model = {
type: 'bedrock' | 'bedrockAgent' | 'bedrockKb' | 'sagemaker';
modelId: string;
modelParameters?: AdditionalModelRequestFields;
sessionId?: string;
};

3 changes: 3 additions & 0 deletions packages/types/src/model.d.ts
@@ -5,8 +5,11 @@ export type FeatureFlags = {
doc?: boolean;
image?: boolean;
video?: boolean;
reasoning?: boolean;

image_gen?: boolean;
video_gen?: boolean;

embedding?: boolean;
reranking?: boolean;
// Additional Flags