Skip to content

Commit 5ffe9b3

Browse files
committed
finished vision model
1 parent f17d3f8 commit 5ffe9b3

File tree

1 file changed

+131
-116
lines changed

1 file changed

+131
-116
lines changed

Libraries/VLM/Models/Qwen2VL.swift

Lines changed: 131 additions & 116 deletions
Original file line number | Diff line number | Diff line change
@@ -460,140 +460,152 @@ private enum Vision {
460460
}
461461
}
462462

463-
fileprivate class PhiMLP: Module, UnaryLayer {
463+
/// Two-layer feed-forward block used inside each vision transformer layer.
///
/// Projects `dimensions -> hiddenDimensions -> dimensions` with a fast-GELU
/// nonlinearity between the two linear maps.
fileprivate class MLP: Module, UnaryLayer {

    @ModuleInfo var activation: GELU
    @ModuleInfo var fc1: Linear
    @ModuleInfo var fc2: Linear

    /// - Parameters:
    ///   - dimensions: input/output feature size
    ///   - hiddenDimensions: width of the intermediate projection
    public init(dimensions: Int, hiddenDimensions: Int) {
        self.activation = GELU(approximation: .fast)
        self.fc1 = Linear(dimensions, hiddenDimensions)
        self.fc2 = Linear(hiddenDimensions, dimensions)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        // up-project, activate, down-project
        let hidden = activation(fc1(x))
        return fc2(hidden)
    }
}
502-
503-
fileprivate class Encoder: Module {
504-
var layers: [EncoderLayer]
505-
479+
480+
/// One transformer block of the Qwen2-VL vision tower: pre-norm attention
/// followed by a pre-norm MLP, each wrapped in a residual connection.
fileprivate class Qwen2VLVisionBlock: Module {

    @ModuleInfo var norm1: LayerNorm
    @ModuleInfo var norm2: LayerNorm
    @ModuleInfo(key: "attn") var attention: Attention
    @ModuleInfo var mlp: MLP

    public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
        self.norm1 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6)
        self.norm2 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6)

        self._attention.wrappedValue = Attention(
            dims: config.embedDimensions, numHeads: config.numHeads)

        // MLP hidden width is the embedding width scaled by the configured ratio.
        let hiddenDimensions = Int(Float(config.embedDimensions) * config.mlpRatio)
        self.mlp = MLP(dimensions: config.embedDimensions, hiddenDimensions: hiddenDimensions)
    }

    func callAsFunction(
        _ hiddenStates: MLXArray, cuSequenceLengths: MLXArray,
        rotaryPositionEmbedding: MLXArray
    ) -> MLXArray {
        // residual + attention over the normalized input
        var result = hiddenStates
            + attention(
                norm1(hiddenStates),
                cuSequenceLengths: cuSequenceLengths,
                rotaryPositionEmbedding: rotaryPositionEmbedding)
        // residual + feed-forward
        result = result + mlp(norm2(result))
        return result
    }
}
581507

582508
/// Qwen2-VL vision tower: patch embedding, per-head rotary position
/// embeddings, a stack of transformer blocks, and a final patch merger.
/// (Class continues below with weight-sanitization helpers.)
fileprivate class VisionModel: Module {

    @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed
    @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding
    @ModuleInfo(key: "blocks") var blocks: [Qwen2VLVisionBlock]
    @ModuleInfo(key: "merger") var patchMerger: PatchMerger

    let spatialMergeSize: Int

    public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
        precondition(
            config.modelType == "qwen2_vl",
            "Unsupported modelType: \(config.modelType)")

        self.spatialMergeSize = config.spatialMergeSize

        self._patchEmbed.wrappedValue = PatchEmbed(
            patchSize: config.patchSize,
            temporalPatchSize: config.temporalPatchSize,
            inChannels: config.inChannels,
            embedDimensions: config.embedDimensions)

        // rotary embedding operates per attention head
        let headDimensions = config.embedDimensions / config.numHeads
        self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding(
            dimensions: headDimensions, theta: 10_000)

        self._blocks.wrappedValue = (0 ..< config.depth).map { _ in
            Qwen2VLVisionBlock(config)
        }

        // NOTE(review): merge size is hard-coded to 2 here while the config
        // carries spatialMergeSize -- confirm whether these should agree.
        self.patchMerger = PatchMerger(
            dimensions: config.hiddenSize, contextDimensions: config.embedDimensions,
            spatialMergeSize: 2)
    }
538+
539+
func rotaryPositionEmbedding(_ gridThw: MLXArray) -> MLXArray {
540+
var positionIds = [MLXArray]()
541+
542+
for row in gridThw {
543+
// TODO NOTE: this evaluates gridThw -- it shouldn't do that
544+
let t = row[0].item(Int.self)
545+
let h = row[1].item(Int.self)
546+
let w = row[2].item(Int.self)
547+
548+
var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1)
549+
hposIds = repeated(hposIds, count: w, axis: 1)
550+
hposIds = hposIds
551+
.reshaped(
552+
h / spatialMergeSize,
553+
spatialMergeSize,
554+
w / spatialMergeSize,
555+
spatialMergeSize)
556+
.transposed(0, 2, 1, 3)
557+
.flattened()
558+
559+
var wposIds = expandedDimensions(MLXArray(0 ..< w), axis: 0)
560+
wposIds = repeated(wposIds, count: h, axis: 0)
561+
wposIds = hposIds
562+
.reshaped(
563+
h / spatialMergeSize,
564+
spatialMergeSize,
565+
w / spatialMergeSize,
566+
spatialMergeSize)
567+
.transposed(0, 2, 1, 3)
568+
.flattened()
569+
570+
let stackedPosIds = stacked([hposIds, wposIds], axis: -1)
571+
positionIds.append(repeated(stackedPosIds, count: t, axis: 0))
572+
}
573+
574+
let indices = concatenated(positionIds, axis: 0)
575+
let maxGridSize = max(gridThw[0..., 1...])
576+
let rotaryPositionEmbedFull = rotaryPositionEmbedding(maxGridSize)[indices]
577+
578+
return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
591579
}
592580

593-
public func callAsFunction(_ x: MLXArray, outputHiddenStates: Bool = false) -> (
594-
MLXArray, MLXArray, MLXArray?
595-
) {
596-
visionModel(x, outputHiddenStates: outputHiddenStates)
581+
public func callAsFunction(_ hiddenStates: MLXArray, gridThw: MLXArray) -> MLXArray {
582+
var hiddenStates = patchEmbed(hiddenStates)
583+
let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw)
584+
585+
// Assuming grid_thw has shape (batch_size, 3)
586+
let batchSize = gridThw.dim(0)
587+
588+
// Calculate cu_seqlens for each item in the batch
589+
var collect = [MLXArray]()
590+
for i in 0 ..< batchSize {
591+
let sequenceLength = gridThw[i, 1] * gridThw[i, 2]
592+
593+
// TODO NOTE: this evaluates gridThw -- it shouldn't do that
594+
let t = gridThw[i, 0].item(Int.self)
595+
collect.append(repeated(sequenceLength, count: t))
596+
}
597+
598+
// Concatenate the cu_seqlens for all items in the batch
599+
var cuSeqLengths = concatenated(collect)
600+
601+
cuSeqLengths = cumsum(cuSeqLengths.asType(Int32.self), axis: 0)
602+
cuSeqLengths = padded(cuSeqLengths, width: [1, 0], mode: .constant, value: MLXArray(0))
603+
604+
for block in blocks {
605+
hiddenStates = block(hiddenStates, cuSequenceLengths: cuSeqLengths, rotaryPositionEmbedding: rotaryPositionEmbedding)
606+
}
607+
608+
return patchMerger(hiddenStates)
597609
}
598610

599611
private func isMLXWeight(_ array: MLXArray) -> Bool {
@@ -616,15 +628,18 @@ private enum Vision {
616628
if k.contains("position_id") {
617629
// Remove unused position_ids
618630
continue
619-
} else if k.contains("patch_embedding.weight") {
631+
} else if k.contains("patch_embed.proj.weight") {
632+
// TODO: this comment doesn't match -- based on above code I presume
633+
// the first dimension is now B
634+
620635
// PyTorch conv2d weight tensors have shape:
621636
// [out_channels, in_channels, kH, KW]
622637
// MLX conv2d expects the weight be of shape:
623638
// [out_channels, kH, KW, in_channels]
624639
if isMLXWeight(v) {
625640
sanitizedWeights[k] = v
626641
} else {
627-
sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
642+
sanitizedWeights[k] = v.transposed(0, 2, 3, 4, 1)
628643
}
629644
} else {
630645
sanitizedWeights[k] = v
@@ -882,8 +897,8 @@ public struct Qwen2VLConfiguration: Codable, Sendable {
882897
public let patchSize: Int
883898
public let vocabularySize: Int
884899
public let mlpRatio: Float
885-
public let _channels: Int?
886-
public var channels: Int { _channels ?? 3 }
900+
public let _inChannels: Int?
901+
public var inChannels: Int { _inChannels ?? 3 }
887902
public let _layerNormEps: Float?
888903
public var layerNormEps: Float { _layerNormEps ?? 1e-6 }
889904
public let spatialPatchSize: Int
@@ -900,7 +915,7 @@ public struct Qwen2VLConfiguration: Codable, Sendable {
900915
case patchSize = "patch_size"
901916
case vocabularySize = "vocab_size"
902917
case mlpRatio = "mlp_ratio"
903-
case _channels = "num_channels"
918+
case _inChannels = "in_channels"
904919
case _layerNormEps = "layer_norm_eps"
905920
case spatialPatchSize = "spatial_patch_size"
906921
case spatialMergeSize = "spatial_merge_size"

0 commit comments

Comments
 (0)