|
| 1 | +import { invariant } from '../errors.ts'; |
| 2 | +import { roundUp } from '../mathUtils.ts'; |
| 3 | +import { alignmentOf } from '../data/alignmentOf.ts'; |
| 4 | +import { offsetsForProps } from '../data/offsets.ts'; |
| 5 | +import { sizeOf } from '../data/sizeOf.ts'; |
| 6 | +import type { BaseData, TypedArrayFor, WgslArray, WgslStruct } from '../data/wgslTypes.ts'; |
| 7 | +import { isMat, isMat2x2f, isMat3x3f, isWgslArray } from '../data/wgslTypes.ts'; |
| 8 | +import type { BufferWriteOptions, TgpuBuffer } from '../core/buffer/buffer.ts'; |
| 9 | +import type { Prettify } from '../shared/utilityTypes.ts'; |
| 10 | + |
| 11 | +type UnwrapWgslArray<T> = T extends WgslArray<infer U> ? UnwrapWgslArray<U> : T; |
| 12 | +type PackedSoAInputFor<T> = TypedArrayFor<UnwrapWgslArray<T>>; |
| 13 | + |
| 14 | +type SoAFieldsFor<T extends Record<string, BaseData>> = { |
| 15 | + [K in keyof T as [PackedSoAInputFor<T[K]>] extends [never] ? never : K]: PackedSoAInputFor<T[K]>; |
| 16 | +}; |
| 17 | + |
| 18 | +type SoAInputFor<T extends Record<string, BaseData>> = [keyof T] extends [keyof SoAFieldsFor<T>] |
| 19 | + ? Prettify<SoAFieldsFor<T>> |
| 20 | + : never; |
| 21 | + |
| 22 | +function getPackedMatrixLayout(schema: BaseData) { |
| 23 | + if (!isMat(schema)) { |
| 24 | + return undefined; |
| 25 | + } |
| 26 | + |
| 27 | + const dim = isMat3x3f(schema) ? 3 : isMat2x2f(schema) ? 2 : 4; |
| 28 | + const packedColumnSize = dim * 4; |
| 29 | + |
| 30 | + return { |
| 31 | + dim, |
| 32 | + packedColumnSize, |
| 33 | + packedSize: dim * packedColumnSize, |
| 34 | + } as const; |
| 35 | +} |
| 36 | + |
| 37 | +function packedSizeOf(schema: BaseData): number { |
| 38 | + const matrixLayout = getPackedMatrixLayout(schema); |
| 39 | + if (matrixLayout) { |
| 40 | + return matrixLayout.packedSize; |
| 41 | + } |
| 42 | + |
| 43 | + if (isWgslArray(schema)) { |
| 44 | + return schema.elementCount * packedSizeOf(schema.elementType); |
| 45 | + } |
| 46 | + |
| 47 | + return sizeOf(schema); |
| 48 | +} |
| 49 | + |
| 50 | +function inferSoAElementCount( |
| 51 | + arraySchema: WgslArray, |
| 52 | + soaData: Record<string, ArrayBufferView>, |
| 53 | +): number | undefined { |
| 54 | + const structSchema = arraySchema.elementType as WgslStruct; |
| 55 | + let inferredCount: number | undefined; |
| 56 | + |
| 57 | + for (const key in soaData) { |
| 58 | + const srcArray = soaData[key]; |
| 59 | + const fieldSchema = structSchema.propTypes[key]; |
| 60 | + if (srcArray === undefined || fieldSchema === undefined) { |
| 61 | + continue; |
| 62 | + } |
| 63 | + |
| 64 | + const fieldPackedSize = packedSizeOf(fieldSchema); |
| 65 | + if (fieldPackedSize === 0) { |
| 66 | + continue; |
| 67 | + } |
| 68 | + |
| 69 | + const fieldElementCount = Math.floor(srcArray.byteLength / fieldPackedSize); |
| 70 | + inferredCount = |
| 71 | + inferredCount === undefined ? fieldElementCount : Math.min(inferredCount, fieldElementCount); |
| 72 | + } |
| 73 | + |
| 74 | + return inferredCount; |
| 75 | +} |
| 76 | + |
| 77 | +function computeSoAByteLength( |
| 78 | + arraySchema: WgslArray, |
| 79 | + soaData: Record<string, ArrayBufferView>, |
| 80 | +): number | undefined { |
| 81 | + const elementCount = inferSoAElementCount(arraySchema, soaData); |
| 82 | + if (elementCount === undefined) { |
| 83 | + return undefined; |
| 84 | + } |
| 85 | + const elementStride = roundUp( |
| 86 | + sizeOf(arraySchema.elementType), |
| 87 | + alignmentOf(arraySchema.elementType), |
| 88 | + ); |
| 89 | + return elementCount * elementStride; |
| 90 | +} |
| 91 | + |
| 92 | +function writePackedValue( |
| 93 | + target: Uint8Array, |
| 94 | + schema: BaseData, |
| 95 | + srcBytes: Uint8Array, |
| 96 | + dstOffset: number, |
| 97 | + srcOffset: number, |
| 98 | +): void { |
| 99 | + const matrixLayout = getPackedMatrixLayout(schema); |
| 100 | + if (matrixLayout) { |
| 101 | + const gpuColumnStride = roundUp(matrixLayout.packedColumnSize, alignmentOf(schema)); |
| 102 | + |
| 103 | + for (let col = 0; col < matrixLayout.dim; col++) { |
| 104 | + target.set( |
| 105 | + srcBytes.subarray( |
| 106 | + srcOffset + col * matrixLayout.packedColumnSize, |
| 107 | + srcOffset + col * matrixLayout.packedColumnSize + matrixLayout.packedColumnSize, |
| 108 | + ), |
| 109 | + dstOffset + col * gpuColumnStride, |
| 110 | + ); |
| 111 | + } |
| 112 | + |
| 113 | + return; |
| 114 | + } |
| 115 | + |
| 116 | + if (isWgslArray(schema)) { |
| 117 | + const packedElementSize = packedSizeOf(schema.elementType); |
| 118 | + const gpuElementStride = roundUp(sizeOf(schema.elementType), alignmentOf(schema.elementType)); |
| 119 | + |
| 120 | + for (let i = 0; i < schema.elementCount; i++) { |
| 121 | + writePackedValue( |
| 122 | + target, |
| 123 | + schema.elementType, |
| 124 | + srcBytes, |
| 125 | + dstOffset + i * gpuElementStride, |
| 126 | + srcOffset + i * packedElementSize, |
| 127 | + ); |
| 128 | + } |
| 129 | + |
| 130 | + return; |
| 131 | + } |
| 132 | + |
| 133 | + target.set(srcBytes.subarray(srcOffset, srcOffset + sizeOf(schema)), dstOffset); |
| 134 | +} |
| 135 | + |
| 136 | +function scatterSoA( |
| 137 | + target: Uint8Array, |
| 138 | + arraySchema: WgslArray, |
| 139 | + soaData: Record<string, ArrayBufferView>, |
| 140 | + startOffset: number, |
| 141 | + endOffset: number, |
| 142 | +): void { |
| 143 | + const structSchema = arraySchema.elementType as WgslStruct; |
| 144 | + const offsets = offsetsForProps(structSchema); |
| 145 | + const elementStride = roundUp(sizeOf(structSchema), alignmentOf(structSchema)); |
| 146 | + invariant( |
| 147 | + startOffset % elementStride === 0, |
| 148 | + `startOffset (${startOffset}) must be aligned to the element stride (${elementStride})`, |
| 149 | + ); |
| 150 | + const startElement = Math.floor(startOffset / elementStride); |
| 151 | + const endElement = Math.min(arraySchema.elementCount, Math.ceil(endOffset / elementStride)); |
| 152 | + const elementCount = Math.max(0, endElement - startElement); |
| 153 | + |
| 154 | + for (const key in structSchema.propTypes) { |
| 155 | + const fieldSchema = structSchema.propTypes[key]; |
| 156 | + if (fieldSchema === undefined) { |
| 157 | + continue; |
| 158 | + } |
| 159 | + const srcArray = soaData[key]; |
| 160 | + invariant(srcArray !== undefined, `Missing SoA data for field '${key}'`); |
| 161 | + |
| 162 | + const fieldOffset = offsets[key]?.offset; |
| 163 | + invariant(fieldOffset !== undefined, `Field ${key} not found in struct schema`); |
| 164 | + const srcBytes = new Uint8Array(srcArray.buffer, srcArray.byteOffset, srcArray.byteLength); |
| 165 | + |
| 166 | + const packedFieldSize = packedSizeOf(fieldSchema); |
| 167 | + for (let i = 0; i < elementCount; i++) { |
| 168 | + writePackedValue( |
| 169 | + target, |
| 170 | + fieldSchema, |
| 171 | + srcBytes, |
| 172 | + (startElement + i) * elementStride + fieldOffset, |
| 173 | + i * packedFieldSize, |
| 174 | + ); |
| 175 | + } |
| 176 | + } |
| 177 | +} |
| 178 | + |
| 179 | +export function writeSoA<TProps extends Record<string, BaseData>>( |
| 180 | + buffer: TgpuBuffer<WgslArray<WgslStruct<TProps>>>, |
| 181 | + data: SoAInputFor<TProps>, |
| 182 | + options?: BufferWriteOptions, |
| 183 | +): void { |
| 184 | + const arrayBuffer = buffer.arrayBuffer; |
| 185 | + const startOffset = options?.startOffset ?? 0; |
| 186 | + const bufferSize = sizeOf(buffer.dataType); |
| 187 | + const naturalSize = computeSoAByteLength( |
| 188 | + buffer.dataType, |
| 189 | + data as Record<string, ArrayBufferView>, |
| 190 | + ); |
| 191 | + const endOffset = |
| 192 | + options?.endOffset ?? |
| 193 | + (naturalSize === undefined ? bufferSize : Math.min(startOffset + naturalSize, bufferSize)); |
| 194 | + |
| 195 | + scatterSoA( |
| 196 | + new Uint8Array(arrayBuffer), |
| 197 | + buffer.dataType, |
| 198 | + data as Record<string, ArrayBufferView>, |
| 199 | + startOffset, |
| 200 | + endOffset, |
| 201 | + ); |
| 202 | + buffer.write(arrayBuffer, { startOffset, endOffset }); |
| 203 | +} |
| 204 | + |
| 205 | +export namespace writeSoA { |
| 206 | + export type InputFor<TProps extends Record<string, BaseData>> = SoAInputFor<TProps>; |
| 207 | +} |
0 commit comments