Files
LLM-Labs/src/components/labs/QuantizationGridExplorer.tsx
T
2026-04-07 16:02:48 -06:00

151 lines
4.2 KiB
TypeScript

"use client";
import { useMemo, useState } from "react";
/**
 * Precisions the slider steps through, in slider order. The 16 entry is not
 * integer-quantized; `quantizeCell` renders it as a BF16 word instead.
 */
const BIT_DEPTHS = [2, 4, 6, 8, 16] as const;

/** Union of the supported precisions, derived from BIT_DEPTHS so they cannot drift. */
type BitDepth = (typeof BIT_DEPTHS)[number];

/**
 * Sixteen hand-picked demo weights, one per grid cell. Signs alternate and
 * magnitudes shrink monotonically, so low bit depths visibly collapse the
 * small tail into the same bucket.
 */
const TOY_LAYER_WEIGHTS = [
  0.91, -0.72, 0.64, -0.58,
  0.41, -0.33, 0.28, -0.19,
  0.15, -0.11, 0.09, -0.06,
  0.04, -0.03, 0.02, -0.01,
] as const;

/** What one grid cell displays: the stored representation, already stringified. */
interface StoredCell {
  stored: string;
}
/**
 * Restricts `value` to the closed range [min, max].
 *
 * The lower bound is applied first and the upper bound last, so if the
 * bounds are ever inverted (min > max) the result is `max`.
 */
function clamp(value: number, min: number, max: number) {
  const atLeastMin = Math.max(min, value);
  const withinRange = Math.min(max, atLeastMin);
  return withinRange;
}
/**
 * Renders a number in fixed-point notation for display.
 *
 * @param value - number to format
 * @param decimals - digits after the decimal point (defaults to 3)
 * @returns the `toFixed` string representation
 */
function formatFloat(value: number, decimals = 3): string {
  const fixedPoint = value.toFixed(decimals);
  return fixedPoint;
}
/**
 * Converts a number to bfloat16 by rounding its float32 bit pattern to the
 * upper 16 bits with round-to-nearest, ties-to-even.
 *
 * @param value - number to convert (first narrowed to float32 by the typed array)
 * @returns `reconstructed`: the BF16 value read back as a number;
 *          `word`: the 16-bit BF16 pattern as a 4-digit lowercase hex string
 */
function toBfloat16Word(value: number) {
  // One 4-byte buffer viewed two ways: write as float32, read the raw bits.
  const scratch = new Float32Array(1);
  const rawBits = new Uint32Array(scratch.buffer);
  scratch[0] = value;
  const f32Bits = rawBits[0] ?? 0;

  // Adding 0x7fff (plus 1 when the lowest kept bit is odd) before truncating
  // the low 16 bits implements round-half-to-even on the discarded half.
  const keptLowestBit = (f32Bits >>> 16) & 1;
  const truncatedBits = (f32Bits + 0x7fff + keptLowestBit) & 0xffff0000;

  // Write the rounded pattern back so the float view decodes the BF16 value.
  rawBits[0] = truncatedBits;
  const upperHalf = truncatedBits >>> 16;
  const hexDigits = upperHalf.toString(16).padStart(4, "0");

  return {
    reconstructed: scratch[0] ?? value,
    word: `0x${hexDigits}`,
  };
}
/**
 * Computes the symmetric, per-tensor quantization parameters for the toy layer.
 *
 * One scale is shared by every weight: the largest absolute weight maps to
 * the largest representable bucket `qmax = 2^(bits-1) - 1`.
 *
 * @param bitDepth - integer precision (BF16 is handled elsewhere)
 * @returns `qmax`: largest bucket magnitude; `scale`: weight units per bucket
 */
function getSharedScale(bitDepth: Exclude<BitDepth, 16>) {
  let peakMagnitude = -Infinity;
  for (const weight of TOY_LAYER_WEIGHTS) {
    peakMagnitude = Math.max(peakMagnitude, Math.abs(weight));
  }

  const qmax = 2 ** (bitDepth - 1) - 1;
  const scale = peakMagnitude / qmax;
  return { qmax, scale };
}
/**
 * Produces the text a grid cell shows for one weight at the given precision.
 *
 * - 2/4/6/8 bits: the integer bucket `round(value / scale)`, clamped to
 *   [-qmax, qmax] using the layer-wide scale from `getSharedScale`.
 * - 16 bits: the BF16 word (hex) from `toBfloat16Word`.
 *
 * @param value - the original weight
 * @param bitDepth - selected precision
 * @returns the cell's stored representation as a string
 */
function quantizeCell(value: number, bitDepth: BitDepth): StoredCell {
  if (bitDepth !== 16) {
    const { qmax, scale } = getSharedScale(bitDepth);
    const bucket = clamp(Math.round(value / scale), -qmax, qmax);
    return { stored: String(bucket) };
  }

  // BF16 is a float format with no shared scale, so show the raw 16-bit word.
  return { stored: toBfloat16Word(value).word };
}
/**
 * Interactive quantization demo.
 *
 * A range slider selects a precision from BIT_DEPTHS, and a grid renders one
 * cell per entry of TOY_LAYER_WEIGHTS. Each cell shows how that weight is
 * stored at the chosen precision: an integer bucket for 2–8 bits, or a hex
 * BF16 word for 16 bits.
 */
export function QuantizationGridExplorer() {
  // The slider position is an index into BIT_DEPTHS, not the bit depth itself.
  const [bitDepthIndex, setBitDepthIndex] = useState(0);
  const bitDepth = BIT_DEPTHS[bitDepthIndex] ?? BIT_DEPTHS[0];
  // Comparing against 16 narrows `bitDepth` in the false branch, so the
  // integer path type-checks without an `as` assertion.
  const scaleSummary = bitDepth === 16 ? null : getSharedScale(bitDepth);
  // Shared by the visible label and the slider's aria-valuetext. Without the
  // latter, assistive technology would announce the raw index ("0".."4").
  const precisionLabel = bitDepth === 16 ? "16-bit BF16" : `${bitDepth}-bit`;
  const cells = useMemo(
    () => TOY_LAYER_WEIGHTS.map((weight) => quantizeCell(weight, bitDepth)),
    [bitDepth],
  );

  return (
    <div className="quantization-grid-explorer" data-widget-enhanced="true">
      <div className="quantization-grid-explorer__header">
        <p className="quantization-grid-explorer__eyebrow">Toy Layer View</p>
        <h3>Watch a tiny layer get stored as 16 buckets</h3>
        <p className="quantization-grid-explorer__lede">
          Each square below is one toy weight slot. The number shown is the
          stored bucket value.
        </p>
      </div>
      <div className="quantization-grid-explorer__slider-card">
        <label
          className="quantization-grid-explorer__slider-label"
          htmlFor="grid-quant-depth"
        >
          Precision: <strong>{precisionLabel}</strong>
        </label>
        <input
          id="grid-quant-depth"
          type="range"
          min={0}
          max={BIT_DEPTHS.length - 1}
          step={1}
          value={bitDepthIndex}
          aria-valuetext={precisionLabel}
          onChange={(event) => setBitDepthIndex(Number(event.target.value))}
        />
        <div
          className="quantization-grid-explorer__tick-row"
          aria-hidden="true"
        >
          {BIT_DEPTHS.map((depth) => (
            <span key={depth}>{depth}</span>
          ))}
        </div>
      </div>
      <p className="quantization-grid-explorer__helper">
        {scaleSummary === null ? (
          <>
            In BF16, each square stores a 16-bit word instead of a tiny bucket.
          </>
        ) : (
          <>
            Smaller bit depths force more weights into the same few bucket
            values. Scale = <code>{formatFloat(scaleSummary.scale)}</code>
          </>
        )}
      </p>
      <div className="quantization-grid-explorer__grid">
        {cells.map((cell, index) => (
          <div
            className="quantization-grid-explorer__cell"
            // bitDepth is part of the key, so React remounts every cell when
            // the precision changes instead of reusing the previous nodes.
            key={`${bitDepth}-${index}`}
          >
            <span className="quantization-grid-explorer__cell-label">
              W{index}
            </span>
            <strong className="quantization-grid-explorer__cell-value">
              {cell.stored}
            </strong>
          </div>
        ))}
      </div>
    </div>
  );
}