New Lab 2
This commit is contained in:
@@ -0,0 +1,150 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo, useState } from "react";
|
||||
|
||||
const BIT_DEPTHS = [2, 4, 6, 8, 16] as const;
|
||||
const TOY_LAYER_WEIGHTS = [
|
||||
0.91, -0.72, 0.64, -0.58, 0.41, -0.33, 0.28, -0.19, 0.15, -0.11, 0.09, -0.06,
|
||||
0.04, -0.03, 0.02, -0.01,
|
||||
] as const;
|
||||
|
||||
type BitDepth = (typeof BIT_DEPTHS)[number];
|
||||
|
||||
type StoredCell = {
|
||||
stored: string;
|
||||
};
|
||||
|
||||
function clamp(value: number, min: number, max: number) {
|
||||
return Math.min(max, Math.max(min, value));
|
||||
}
|
||||
|
||||
function formatFloat(value: number, decimals = 3) {
|
||||
return value.toFixed(decimals);
|
||||
}
|
||||
|
||||
function toBfloat16Word(value: number) {
|
||||
const floatView = new Float32Array(1);
|
||||
const intView = new Uint32Array(floatView.buffer);
|
||||
|
||||
floatView[0] = value;
|
||||
const current = intView[0] ?? 0;
|
||||
const leastSignificantBit = (current >> 16) & 1;
|
||||
const roundingBias = 0x7fff + leastSignificantBit;
|
||||
const rounded = (current + roundingBias) & 0xffff0000;
|
||||
|
||||
intView[0] = rounded;
|
||||
|
||||
return {
|
||||
reconstructed: floatView[0] ?? value,
|
||||
word: `0x${(rounded >>> 16).toString(16).padStart(4, "0")}`,
|
||||
};
|
||||
}
|
||||
|
||||
function getSharedScale(bitDepth: Exclude<BitDepth, 16>) {
|
||||
const maxMagnitude = Math.max(
|
||||
...TOY_LAYER_WEIGHTS.map((value) => Math.abs(value)),
|
||||
);
|
||||
const qmax = Math.pow(2, bitDepth - 1) - 1;
|
||||
|
||||
return {
|
||||
qmax,
|
||||
scale: maxMagnitude / qmax,
|
||||
};
|
||||
}
|
||||
|
||||
function quantizeCell(value: number, bitDepth: BitDepth): StoredCell {
|
||||
if (bitDepth === 16) {
|
||||
const bf16 = toBfloat16Word(value);
|
||||
return {
|
||||
stored: bf16.word,
|
||||
};
|
||||
}
|
||||
|
||||
const { scale, qmax } = getSharedScale(bitDepth);
|
||||
const stored = clamp(Math.round(value / scale), -qmax, qmax);
|
||||
|
||||
return {
|
||||
stored: String(stored),
|
||||
};
|
||||
}
|
||||
|
||||
export function QuantizationGridExplorer() {
|
||||
const [bitDepthIndex, setBitDepthIndex] = useState(0);
|
||||
|
||||
const bitDepth = BIT_DEPTHS[bitDepthIndex] ?? BIT_DEPTHS[0];
|
||||
const scaleSummary =
|
||||
bitDepth === 16 ? null : getSharedScale(bitDepth as Exclude<BitDepth, 16>);
|
||||
const cells = useMemo(
|
||||
() => TOY_LAYER_WEIGHTS.map((weight) => quantizeCell(weight, bitDepth)),
|
||||
[bitDepth],
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="quantization-grid-explorer" data-widget-enhanced="true">
|
||||
<div className="quantization-grid-explorer__header">
|
||||
<p className="quantization-grid-explorer__eyebrow">Toy Layer View</p>
|
||||
<h3>Watch a tiny layer get stored as 16 buckets</h3>
|
||||
<p className="quantization-grid-explorer__lede">
|
||||
Each square below is one toy weight slot. The number shown is the
|
||||
stored bucket value.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="quantization-grid-explorer__slider-card">
|
||||
<label
|
||||
className="quantization-grid-explorer__slider-label"
|
||||
htmlFor="grid-quant-depth"
|
||||
>
|
||||
Precision:{" "}
|
||||
<strong>{bitDepth === 16 ? "16-bit BF16" : `${bitDepth}-bit`}</strong>
|
||||
</label>
|
||||
<input
|
||||
id="grid-quant-depth"
|
||||
type="range"
|
||||
min={0}
|
||||
max={BIT_DEPTHS.length - 1}
|
||||
step={1}
|
||||
value={bitDepthIndex}
|
||||
onChange={(event) => setBitDepthIndex(Number(event.target.value))}
|
||||
/>
|
||||
<div
|
||||
className="quantization-grid-explorer__tick-row"
|
||||
aria-hidden="true"
|
||||
>
|
||||
{BIT_DEPTHS.map((depth) => (
|
||||
<span key={depth}>{depth}</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p className="quantization-grid-explorer__helper">
|
||||
{bitDepth === 16 ? (
|
||||
<>
|
||||
In BF16, each square stores a 16-bit word instead of a tiny bucket.
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
Smaller bit depths force more weights into the same few bucket
|
||||
values. Scale = <code>{formatFloat(scaleSummary?.scale ?? 0)}</code>
|
||||
</>
|
||||
)}
|
||||
</p>
|
||||
|
||||
<div className="quantization-grid-explorer__grid">
|
||||
{cells.map((cell, index) => (
|
||||
<div
|
||||
className="quantization-grid-explorer__cell"
|
||||
key={`${bitDepth}-${index}`}
|
||||
>
|
||||
<span className="quantization-grid-explorer__cell-label">
|
||||
W{index}
|
||||
</span>
|
||||
<strong className="quantization-grid-explorer__cell-value">
|
||||
{cell.stored}
|
||||
</strong>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user