Add configurable token limit and truncation warning to Lab 1 confidence chat

c4ch3c4d3 committed 2026-04-27 10:58:13 -06:00
parent 269a4e4985
commit a7c1bda07c
5 changed files with 121 additions and 6 deletions
+17 -5
@@ -4,11 +4,12 @@ import { normalizeUpstreamChatEndpoint } from "~/lib/lab2-chat";
 import {
   clampLab1Messages,
   extractLab1AssistantContent,
+  extractLab1FinishReason,
   extractLab1ResponseTokens,
   getLab1SystemPrompt,
   LAB1_CONFIDENCE_MODEL_ALIAS,
-  LAB1_DEFAULT_MAX_TOKENS,
   LAB1_DEFAULT_TEMPERATURE,
+  parseLab1MaxTokens,
   type Lab1ConfidenceMessage,
 } from "~/lib/lab1-confidence";
@@ -32,6 +33,10 @@ function getLab1ModelAlias() {
   );
 }
 
+function getLab1MaxTokens() {
+  return parseLab1MaxTokens(process.env.COURSEWARE_LAB1_MAX_TOKENS?.trim());
+}
+
 export async function POST(request: Request) {
   let body: ChatRouteRequestBody;
@@ -62,10 +67,11 @@ export async function POST(request: Request) {
   );
 
   try {
+    const maxTokens = getLab1MaxTokens();
     const upstreamResponse = await fetch(getLocalOllamaEndpoint(), {
       body: JSON.stringify({
         logprobs: true,
-        max_tokens: LAB1_DEFAULT_MAX_TOKENS,
+        max_tokens: maxTokens,
         messages: [
           {
             content: getLab1SystemPrompt(),
@@ -131,13 +137,18 @@ export async function POST(request: Request) {
     const content =
       extractLab1AssistantContent(parsedBody) ||
       tokens.map((token) => token.token).join("");
+    const finishReason = extractLab1FinishReason(parsedBody);
+    const isTruncated = finishReason === "length";
 
     return NextResponse.json({
       content,
+      finishReason,
+      isTruncated,
+      maxTokens,
       model:
-        ("model" in parsedBody && typeof parsedBody.model === "string"
+        "model" in parsedBody && typeof parsedBody.model === "string"
           ? parsedBody.model
-          : getLab1ModelAlias()),
+          : getLab1ModelAlias(),
       role: "assistant",
       tokens,
     });
@@ -153,7 +164,8 @@ export async function POST(request: Request) {
     return NextResponse.json(
       {
-        error: "The Lab 1 confidence route could not reach the local Ollama endpoint.",
+        error:
+          "The Lab 1 confidence route could not reach the local Ollama endpoint.",
      },
      { status: 502 },
    );
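For context, the truncation detection assumes the OpenAI-compatible response shape the route already parses for logprobs. A truncated upstream body would look roughly like the sketch below; the choices[].finish_reason and choices[].logprobs.content fields match the OpenAiCompatibilityPayload type in this commit, while the message field and all values are assumptions from the standard chat-completions format.

// Illustrative only: finish_reason === "length" means generation stopped
// because it hit max_tokens, which the route maps to isTruncated: true.
const truncatedUpstreamBody = {
  choices: [
    {
      finish_reason: "length", // "stop" for a natural end of generation
      logprobs: { content: [] }, // per-token logprob entries in real replies
      message: { content: "partial output", role: "assistant" }, // assumed field
    },
  ],
  model: "batiai/gemma4-e2b:q4",
};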
+45
@@ -15,6 +15,9 @@ describe("Lab1ConfidenceChat", () => {
       return {
         json: async () => ({
           content: "often works",
+          finishReason: "stop",
+          isTruncated: false,
+          maxTokens: 512,
           model: "batiai/gemma4-e2b:q4",
           role: "assistant",
           tokens: [
@@ -86,4 +89,46 @@ describe("Lab1ConfidenceChat", () => {
       await screen.findByText("The local Ollama request failed."),
     ).toBeInTheDocument();
   });
+
+  it("explains when the response hit the configured token limit", async () => {
+    vi.stubGlobal(
+      "fetch",
+      vi.fn(async () => {
+        return {
+          json: async () => ({
+            content: "partial output",
+            finishReason: "length",
+            isTruncated: true,
+            maxTokens: 512,
+            model: "batiai/gemma4-e2b:q4",
+            role: "assistant",
+            tokens: [
+              {
+                logprob: Math.log(0.5),
+                probability: 50,
+                token: "partial",
+                topAlternatives: [],
+              },
+            ],
+          }),
+          ok: true,
+        };
+      }),
+    );
+
+    render(<Lab1ConfidenceChat />);
+
+    fireEvent.change(screen.getByLabelText("Prompt"), {
+      target: { value: "Write a longer answer." },
+    });
+    fireEvent.submit(
+      screen.getByRole("button", { name: "Generate Output" }).closest("form")!,
+    );
+
+    expect(
+      await screen.findByText(
+        /Response reached the configured 512-token limit/,
+      ),
+    ).toBeInTheDocument();
+  });
 });
+9
@@ -304,6 +304,15 @@ export function Lab1ConfidenceChat() {
                 })}
               </div>
+              {message.isTruncated ? (
+                <p className="lab1-confidence__message-warning">
+                  Response reached the configured{" "}
+                  {message.maxTokens ? `${message.maxTokens}-token` : "token"}{" "}
+                  limit. Increase <code>COURSEWARE_LAB1_MAX_TOKENS</code> to
+                  allow longer Lab 1 generations.
+                </p>
+              ) : null}
               {message.error ? (
                 <p className="lab1-confidence__message-warning">
                   {message.error}
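End to end, the warning is driven entirely by fields the route now returns. A representative truncated turn as the component would receive it might look like this sketch; the field set comes from the updated Lab1ConfidenceResponse type, and the values are illustrative.

import type { Lab1ConfidenceResponse } from "~/lib/lab1-confidence";

// Illustrative truncated assistant turn: the warning renders because
// isTruncated is true, and the copy interpolates maxTokens into the text.
const truncatedTurn: Lab1ConfidenceResponse = {
  content: "partial output",
  finishReason: "length",
  isTruncated: true,
  maxTokens: 512,
  model: "batiai/gemma4-e2b:q4",
  role: "assistant",
  tokens: [],
};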
+24
@@ -2,10 +2,12 @@ import { describe, expect, it } from "vitest";
 import {
   extractLab1AssistantContent,
+  extractLab1FinishReason,
   extractLab1ResponseTokens,
   formatProbabilityPercent,
   getConfidenceBand,
   logprobToProbabilityPercent,
+  parseLab1MaxTokens,
 } from "~/lib/lab1-confidence";
 
 describe("logprobToProbabilityPercent", () => {
@@ -30,6 +32,28 @@ describe("extractLab1AssistantContent", () => {
   });
 });
 
+describe("extractLab1FinishReason", () => {
+  it("reads the upstream finish reason when it is present", () => {
+    expect(
+      extractLab1FinishReason({
+        choices: [
+          {
+            finish_reason: "length",
+          },
+        ],
+      }),
+    ).toBe("length");
+  });
+});
+
+describe("parseLab1MaxTokens", () => {
+  it("uses a bounded positive environment override", () => {
+    expect(parseLab1MaxTokens("768")).toBe(768);
+    expect(parseLab1MaxTokens("999999")).toBe(2048);
+    expect(parseLab1MaxTokens("nope")).toBe(512);
+  });
+});
+
 describe("extractLab1ResponseTokens", () => {
   it("maps token logprobs and alternate candidates into display data", () => {
     expect(
+26 -1
@@ -1,6 +1,7 @@
 export const LAB1_CONFIDENCE_MODEL_ALIAS = "batiai/gemma4-e2b:q4";
-export const LAB1_DEFAULT_MAX_TOKENS = 64;
+export const LAB1_DEFAULT_MAX_TOKENS = 512;
 export const LAB1_DEFAULT_TEMPERATURE = 0.7;
+export const LAB1_MAX_COMPLETION_TOKENS = 2048;
 export const LAB1_MAX_CONTEXT_MESSAGES = 10;
 export const LAB1_MAX_MESSAGE_LENGTH = 4000;
@@ -25,6 +26,9 @@ export type Lab1ResponseToken = {
 export type Lab1ConfidenceResponse = {
   content: string;
+  finishReason: string | null;
+  isTruncated: boolean;
+  maxTokens: number;
   model: string;
   role: "assistant";
   tokens: Lab1ResponseToken[];
@@ -43,6 +47,7 @@ type OpenAiLogprobToken = {
 type OpenAiCompatibilityPayload = {
   choices?: Array<{
+    finish_reason?: string;
     logprobs?: {
       content?: OpenAiLogprobToken[];
     };
@@ -61,6 +66,19 @@ export function getLab1SystemPrompt() {
   ].join(" ");
 }
 
+export function parseLab1MaxTokens(value: string | undefined) {
+  if (!value) {
+    return LAB1_DEFAULT_MAX_TOKENS;
+  }
+
+  const parsedValue = Number.parseInt(value, 10);
+  if (!Number.isFinite(parsedValue) || parsedValue <= 0) {
+    return LAB1_DEFAULT_MAX_TOKENS;
+  }
+
+  return Math.min(parsedValue, LAB1_MAX_COMPLETION_TOKENS);
+}
+
 export function clampLab1Messages(messages: Lab1ConfidenceMessage[]) {
   return messages
     .filter((message) => {
@@ -117,6 +135,13 @@ export function extractLab1AssistantContent(
   return content || null;
 }
 
+export function extractLab1FinishReason(payload: OpenAiCompatibilityPayload) {
+  const finishReason = payload.choices?.[0]?.finish_reason;
+  return typeof finishReason === "string" && finishReason.trim()
+    ? finishReason
+    : null;
+}
+
 export function extractLab1ResponseTokens(
   payload: OpenAiCompatibilityPayload,
 ): Lab1ResponseToken[] {
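As a quick reference, parseLab1MaxTokens resolves the COURSEWARE_LAB1_MAX_TOKENS override as sketched below; the identifiers and the 512/2048 bounds come from this commit, and the sample inputs are illustrative.

import { parseLab1MaxTokens } from "~/lib/lab1-confidence";

// Missing, non-numeric, or non-positive input falls back to
// LAB1_DEFAULT_MAX_TOKENS (512).
parseLab1MaxTokens(undefined); // 512
parseLab1MaxTokens("nope");    // 512
parseLab1MaxTokens("-3");      // 512

// Valid overrides pass through, capped at LAB1_MAX_COMPLETION_TOKENS (2048).
parseLab1MaxTokens("768");     // 768
parseLab1MaxTokens("999999");  // 2048

Clamping rather than throwing means a misconfigured deployment keeps serving the lab at a sane limit instead of failing the route.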