Add configurable token limit and truncation warning to Lab 1 confidence chat
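In rough terms (a sketch summarizing the hunks below, not code taken verbatim from this commit), the route now resolves its token budget like this:

    // COURSEWARE_LAB1_MAX_TOKENS is read from the environment and clamped by parseLab1MaxTokens:
    //   unset or invalid  -> LAB1_DEFAULT_MAX_TOKENS (512)
    //   "768"             -> 768
    //   "999999"          -> capped at LAB1_MAX_COMPLETION_TOKENS (2048)
    const maxTokens = parseLab1MaxTokens(process.env.COURSEWARE_LAB1_MAX_TOKENS?.trim());
    // The upstream Ollama request sends max_tokens: maxTokens; a finish_reason of
    // "length" marks the reply as truncated and surfaces a warning in the chat UI.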
@@ -4,11 +4,12 @@ import { normalizeUpstreamChatEndpoint } from "~/lib/lab2-chat";
 import {
   clampLab1Messages,
   extractLab1AssistantContent,
+  extractLab1FinishReason,
   extractLab1ResponseTokens,
   getLab1SystemPrompt,
   LAB1_CONFIDENCE_MODEL_ALIAS,
-  LAB1_DEFAULT_MAX_TOKENS,
   LAB1_DEFAULT_TEMPERATURE,
+  parseLab1MaxTokens,
   type Lab1ConfidenceMessage,
 } from "~/lib/lab1-confidence";
 
@@ -32,6 +33,10 @@ function getLab1ModelAlias() {
   );
 }
 
+function getLab1MaxTokens() {
+  return parseLab1MaxTokens(process.env.COURSEWARE_LAB1_MAX_TOKENS?.trim());
+}
+
 export async function POST(request: Request) {
   let body: ChatRouteRequestBody;
 
@@ -62,10 +67,11 @@ export async function POST(request: Request) {
   );
 
   try {
+    const maxTokens = getLab1MaxTokens();
     const upstreamResponse = await fetch(getLocalOllamaEndpoint(), {
       body: JSON.stringify({
         logprobs: true,
-        max_tokens: LAB1_DEFAULT_MAX_TOKENS,
+        max_tokens: maxTokens,
         messages: [
           {
             content: getLab1SystemPrompt(),
@@ -131,13 +137,18 @@ export async function POST(request: Request) {
     const content =
       extractLab1AssistantContent(parsedBody) ||
       tokens.map((token) => token.token).join("");
+    const finishReason = extractLab1FinishReason(parsedBody);
+    const isTruncated = finishReason === "length";
 
     return NextResponse.json({
       content,
+      finishReason,
+      isTruncated,
+      maxTokens,
       model:
-        ("model" in parsedBody && typeof parsedBody.model === "string"
+        "model" in parsedBody && typeof parsedBody.model === "string"
           ? parsedBody.model
-          : getLab1ModelAlias()),
+          : getLab1ModelAlias(),
       role: "assistant",
       tokens,
     });
@@ -153,7 +164,8 @@ export async function POST(request: Request) {
 
     return NextResponse.json(
       {
-        error: "The Lab 1 confidence route could not reach the local Ollama endpoint.",
+        error:
+          "The Lab 1 confidence route could not reach the local Ollama endpoint.",
       },
       { status: 502 },
     );

@@ -15,6 +15,9 @@ describe("Lab1ConfidenceChat", () => {
         return {
           json: async () => ({
             content: "often works",
+            finishReason: "stop",
+            isTruncated: false,
+            maxTokens: 512,
             model: "batiai/gemma4-e2b:q4",
             role: "assistant",
             tokens: [
@@ -86,4 +89,46 @@ describe("Lab1ConfidenceChat", () => {
       await screen.findByText("The local Ollama request failed."),
     ).toBeInTheDocument();
   });
+
+  it("explains when the response hit the configured token limit", async () => {
+    vi.stubGlobal(
+      "fetch",
+      vi.fn(async () => {
+        return {
+          json: async () => ({
+            content: "partial output",
+            finishReason: "length",
+            isTruncated: true,
+            maxTokens: 512,
+            model: "batiai/gemma4-e2b:q4",
+            role: "assistant",
+            tokens: [
+              {
+                logprob: Math.log(0.5),
+                probability: 50,
+                token: "partial",
+                topAlternatives: [],
+              },
+            ],
+          }),
+          ok: true,
+        };
+      }),
+    );
+
+    render(<Lab1ConfidenceChat />);
+
+    fireEvent.change(screen.getByLabelText("Prompt"), {
+      target: { value: "Write a longer answer." },
+    });
+    fireEvent.submit(
+      screen.getByRole("button", { name: "Generate Output" }).closest("form")!,
+    );
+
+    expect(
+      await screen.findByText(
+        /Response reached the configured 512-token limit/,
+      ),
+    ).toBeInTheDocument();
+  });
 });

@@ -304,6 +304,15 @@ export function Lab1ConfidenceChat() {
                 })}
               </div>
 
+              {message.isTruncated ? (
+                <p className="lab1-confidence__message-warning">
+                  Response reached the configured{" "}
+                  {message.maxTokens ? `${message.maxTokens}-token` : "token"}{" "}
+                  limit. Increase <code>COURSEWARE_LAB1_MAX_TOKENS</code> to
+                  allow longer Lab 1 generations.
+                </p>
+              ) : null}
+
               {message.error ? (
                 <p className="lab1-confidence__message-warning">
                   {message.error}

@@ -2,10 +2,12 @@ import { describe, expect, it } from "vitest";
 
 import {
   extractLab1AssistantContent,
+  extractLab1FinishReason,
   extractLab1ResponseTokens,
   formatProbabilityPercent,
   getConfidenceBand,
   logprobToProbabilityPercent,
+  parseLab1MaxTokens,
 } from "~/lib/lab1-confidence";
 
 describe("logprobToProbabilityPercent", () => {
@@ -30,6 +32,28 @@ describe("extractLab1AssistantContent", () => {
   });
 });
 
+describe("extractLab1FinishReason", () => {
+  it("reads the upstream finish reason when it is present", () => {
+    expect(
+      extractLab1FinishReason({
+        choices: [
+          {
+            finish_reason: "length",
+          },
+        ],
+      }),
+    ).toBe("length");
+  });
+});
+
+describe("parseLab1MaxTokens", () => {
+  it("uses a bounded positive environment override", () => {
+    expect(parseLab1MaxTokens("768")).toBe(768);
+    expect(parseLab1MaxTokens("999999")).toBe(2048);
+    expect(parseLab1MaxTokens("nope")).toBe(512);
+  });
+});
+
 describe("extractLab1ResponseTokens", () => {
   it("maps token logprobs and alternate candidates into display data", () => {
     expect(

@@ -1,6 +1,7 @@
 export const LAB1_CONFIDENCE_MODEL_ALIAS = "batiai/gemma4-e2b:q4";
-export const LAB1_DEFAULT_MAX_TOKENS = 64;
+export const LAB1_DEFAULT_MAX_TOKENS = 512;
 export const LAB1_DEFAULT_TEMPERATURE = 0.7;
+export const LAB1_MAX_COMPLETION_TOKENS = 2048;
 export const LAB1_MAX_CONTEXT_MESSAGES = 10;
 export const LAB1_MAX_MESSAGE_LENGTH = 4000;
 
@@ -25,6 +26,9 @@ export type Lab1ResponseToken = {
 
 export type Lab1ConfidenceResponse = {
   content: string;
+  finishReason: string | null;
+  isTruncated: boolean;
+  maxTokens: number;
   model: string;
   role: "assistant";
   tokens: Lab1ResponseToken[];
@@ -43,6 +47,7 @@ type OpenAiLogprobToken = {
 
 type OpenAiCompatibilityPayload = {
   choices?: Array<{
+    finish_reason?: string;
     logprobs?: {
       content?: OpenAiLogprobToken[];
     };
@@ -61,6 +66,19 @@ export function getLab1SystemPrompt() {
   ].join(" ");
 }
 
+export function parseLab1MaxTokens(value: string | undefined) {
+  if (!value) {
+    return LAB1_DEFAULT_MAX_TOKENS;
+  }
+
+  const parsedValue = Number.parseInt(value, 10);
+  if (!Number.isFinite(parsedValue) || parsedValue <= 0) {
+    return LAB1_DEFAULT_MAX_TOKENS;
+  }
+
+  return Math.min(parsedValue, LAB1_MAX_COMPLETION_TOKENS);
+}
+
 export function clampLab1Messages(messages: Lab1ConfidenceMessage[]) {
   return messages
     .filter((message) => {
@@ -117,6 +135,13 @@ export function extractLab1AssistantContent(
   return content || null;
 }
 
+export function extractLab1FinishReason(payload: OpenAiCompatibilityPayload) {
+  const finishReason = payload.choices?.[0]?.finish_reason;
+  return typeof finishReason === "string" && finishReason.trim()
+    ? finishReason
+    : null;
+}
+
 export function extractLab1ResponseTokens(
   payload: OpenAiCompatibilityPayload,
 ): Lab1ResponseToken[] {