@@ -460,140 +460,152 @@ private enum Vision {
460460 }
461461 }
462462
463- fileprivate class PhiMLP : Module , UnaryLayer {
463+ fileprivate class MLP : Module , UnaryLayer {
464464
465+ @ModuleInfo var activation : GELU
465466 @ModuleInfo var fc1 : Linear
466467 @ModuleInfo var fc2 : Linear
467468
468- public init ( _ config: Qwen2VLConfiguration . VisionConfiguration ) {
469- self . fc1 = Linear ( config. hiddenSize, config. intermediateSize, bias: true )
470- self . fc2 = Linear ( config. intermediateSize, config. hiddenSize, bias: true )
469+ public init ( dimensions: Int , hiddenDimensions: Int ) {
470+ self . activation = GELU ( approximation: . fast)
471+ self . fc1 = Linear ( dimensions, hiddenDimensions)
472+ self . fc2 = Linear ( hiddenDimensions, dimensions)
471473 }
472474
473475 public func callAsFunction( _ x: MLXArray ) -> MLXArray {
474- fc2 ( geluApproximate ( fc1 ( x) ) )
475- }
476- }
477-
478- fileprivate class EncoderLayer : Module {
479-
480- @ModuleInfo ( key: " self_attn " ) var attention : Attention
481- @ModuleInfo ( key: " layer_norm1 " ) var layerNorm1 : LayerNorm
482- @ModuleInfo var mlp : PhiMLP
483- @ModuleInfo ( key: " layer_norm2 " ) var layerNorm2 : LayerNorm
484-
485- public init ( _ config: Qwen2VLConfiguration . VisionConfiguration ) {
486- self . _attention. wrappedValue = Attention (
487- dims: config. hiddenSize, numHeads: config. attentionHeads, bias: true )
488- self . _layerNorm1. wrappedValue = LayerNorm (
489- dimensions: config. hiddenSize, eps: config. layerNormEps)
490- self . mlp = PhiMLP ( config)
491- self . _layerNorm2. wrappedValue = LayerNorm (
492- dimensions: config. hiddenSize, eps: config. layerNormEps)
493- }
494-
495- public func callAsFunction( _ x: MLXArray , mask: MLXArray ? = nil ) -> MLXArray {
496- var r = attention ( layerNorm1 ( x) , mask: mask)
497- let h = x + r
498- r = mlp ( layerNorm2 ( h) )
499- return h + r
476+ fc2 ( activation ( fc1 ( x) ) )
500477 }
501478 }
502-
503- fileprivate class Encoder : Module {
504- var layers : [ EncoderLayer ]
505-
479+
480+ fileprivate class Qwen2VLVisionBlock : Module {
481+
482+ @ModuleInfo var norm1 : LayerNorm
483+ @ModuleInfo var norm2 : LayerNorm
484+ @ModuleInfo ( key: " attn " ) var attention : Attention
485+ @ModuleInfo var mlp : MLP
486+
506487 public init ( _ config: Qwen2VLConfiguration . VisionConfiguration ) {
507- self . layers = ( 0 ..< config. hiddenLayers) . map { _ in
508- EncoderLayer ( config)
509- }
510- }
511-
512- public func callAsFunction(
513- _ x: MLXArray , outputHiddenStates: Bool = false , mask: MLXArray ? = nil
514- ) -> ( MLXArray , [ MLXArray ] ? ) {
515- var encoderStates : [ MLXArray ] ? = outputHiddenStates ? [ ] : nil
516- var h = x
517- var x = x
518- for l in layers {
519- x = l ( x, mask: mask)
520- if outputHiddenStates {
521- encoderStates? . append ( x)
522- }
523- h = x [ 0 ]
524- }
525- return ( h, encoderStates)
488+ self . norm1 = LayerNorm ( dimensions: config. embedDimensions, eps: 1e-6 )
489+ self . norm2 = LayerNorm ( dimensions: config. embedDimensions, eps: 1e-6 )
490+
491+ self . _attention. wrappedValue = Attention ( dims: config. embedDimensions, numHeads: config. numHeads)
492+
493+ let mlpHiddenDimensions = Int ( Float ( config. embedDimensions) * config. mlpRatio)
494+ self . mlp = MLP ( dimensions: config. embedDimensions, hiddenDimensions: mlpHiddenDimensions)
526495 }
527- }
528-
529- fileprivate class VisionEmbeddings : Module , UnaryLayer {
530-
531- @ModuleInfo ( key: " patch_embedding " ) var patchEmbedding : Conv2d
532- @ModuleInfo ( key: " position_embedding " ) var positionEmbedding : Embedding
533-
534- let positions : Int
535- let positionIds : MLXArray
536-
537- public init ( _ config: Qwen2VLConfiguration . VisionConfiguration ) {
538- self . _patchEmbedding. wrappedValue = Conv2d (
539- inputChannels: config. channels, outputChannels: config. hiddenSize,
540- kernelSize: . init( config. patchSize) , stride: . init( config. patchSize)
541- )
542- let d = config. imageSize / config. patchSize
543- self . positions = d * d
544- self . _positionEmbedding. wrappedValue = Embedding (
545- embeddingCount: positions, dimensions: config. hiddenSize
496+
497+ func callAsFunction( _ hiddenStates: MLXArray , cuSequenceLengths: MLXArray , rotaryPositionEmbedding: MLXArray ) -> MLXArray {
498+ var hiddenStates = hiddenStates + attention(
499+ norm1 ( hiddenStates) ,
500+ cuSequenceLengths: cuSequenceLengths,
501+ rotaryPositionEmbedding: rotaryPositionEmbedding
546502 )
547- self . positionIds = MLXArray ( 0 ..< positions) [ . newAxis, 0 ... ]
548- }
549-
550- public func callAsFunction( _ x: MLXArray ) -> MLXArray {
551- var patchEmbeddings = self . patchEmbedding ( x)
552- patchEmbeddings = patchEmbeddings. flattened ( start: 1 , end: 2 )
553- let embeddings = patchEmbeddings + self . positionEmbedding ( self . positionIds)
554- return embeddings
555- }
556- }
557-
558- fileprivate class SigLipVisionModel : Module {
559-
560- @ModuleInfo var embeddings : VisionEmbeddings
561- @ModuleInfo var encoder : Encoder
562- @ModuleInfo ( key: " post_layernorm " ) var postLayerNorm : LayerNorm
563-
564- public init ( _ config: Qwen2VLConfiguration . VisionConfiguration ) {
565- self . embeddings = VisionEmbeddings ( config)
566- self . encoder = Encoder ( config)
567- self . _postLayerNorm. wrappedValue = LayerNorm ( dimensions: config. hiddenSize)
568- }
569-
570- public func callAsFunction( _ x: MLXArray , outputHiddenStates: Bool = false ) -> (
571- MLXArray , MLXArray , MLXArray ?
572- ) {
573- let x = embeddings ( x)
574-
575- let ( encoderOutput, hiddenStates) = encoder ( x, outputHiddenStates: outputHiddenStates)
576- let poolerOutput = postLayerNorm ( encoderOutput)
577-
578- return ( poolerOutput, x, hiddenStates? . last)
503+ hiddenStates = hiddenStates + mlp( norm2 ( hiddenStates) )
504+ return hiddenStates
579505 }
580506 }
581507
582508 fileprivate class VisionModel : Module {
583509
584- @ModuleInfo ( key: " vision_model " ) var visionModel : SigLipVisionModel
510+ @ModuleInfo ( key: " patch_embed " ) var patchEmbed : PatchEmbed
511+ @ModuleInfo ( key: " rotary_pos_emb " ) var rotaryPositionEmbedding : VisionRotaryEmbedding
512+ @ModuleInfo ( key: " blocks " ) var blocks : [ Qwen2VLVisionBlock ]
513+ @ModuleInfo ( key: " merger " ) var patchMerger : PatchMerger
514+
515+ let spatialMergeSize : Int
585516
586517 public init ( _ config: Qwen2VLConfiguration . VisionConfiguration ) {
587518 precondition (
588- config. modelType == " siglip_vision_model " ,
519+ config. modelType == " qwen2_vl " ,
589520 " Unsupported modelType: \( config. modelType) " )
590- self . _visionModel. wrappedValue = SigLipVisionModel ( config)
521+
522+ self . spatialMergeSize = config. spatialMergeSize
523+
524+ self . _patchEmbed. wrappedValue = PatchEmbed (
525+ patchSize: config. patchSize,
526+ temporalPatchSize: config. temporalPatchSize,
527+ inChannels: config. inChannels,
528+ embedDimensions: config. embedDimensions)
529+
530+ let headDimensions = config. embedDimensions / config. numHeads
531+ self . _rotaryPositionEmbedding. wrappedValue = VisionRotaryEmbedding ( dimensions: headDimensions, theta: 10_000 )
532+
533+ self . _blocks. wrappedValue = ( 0 ..< config. depth) . map { _ in
534+ Qwen2VLVisionBlock ( config)
535+ }
536+ self . patchMerger = PatchMerger ( dimensions: config. hiddenSize, contextDimensions: config. embedDimensions, spatialMergeSize: config. spatialMergeSize)
537+ }
538+
539+ func rotaryPositionEmbedding( _ gridThw: MLXArray ) -> MLXArray {
540+ var positionIds = [ MLXArray] ( )
541+
542+ for row in gridThw {
543+ // TODO NOTE: this evaluates gridThw -- it shouldn't do that
544+ let t = row [ 0 ] . item ( Int . self)
545+ let h = row [ 1 ] . item ( Int . self)
546+ let w = row [ 2 ] . item ( Int . self)
547+
548+ var hposIds = expandedDimensions ( MLXArray ( 0 ..< h) , axis: 1 )
549+ hposIds = repeated ( hposIds, count: w, axis: 1 )
550+ hposIds = hposIds
551+ . reshaped (
552+ h / spatialMergeSize,
553+ spatialMergeSize,
554+ w / spatialMergeSize,
555+ spatialMergeSize)
556+ . transposed ( 0 , 2 , 1 , 3 )
557+ . flattened ( )
558+
559+ var wposIds = expandedDimensions ( MLXArray ( 0 ..< w) , axis: 0 )
560+ wposIds = repeated ( wposIds, count: h, axis: 0 )
561+ wposIds = wposIds
562+ . reshaped (
563+ h / spatialMergeSize,
564+ spatialMergeSize,
565+ w / spatialMergeSize,
566+ spatialMergeSize)
567+ . transposed ( 0 , 2 , 1 , 3 )
568+ . flattened ( )
569+
570+ let stackedPosIds = stacked ( [ hposIds, wposIds] , axis: - 1 )
571+ positionIds. append ( repeated ( stackedPosIds, count: t, axis: 0 ) )
572+ }
573+
574+ let indices = concatenated ( positionIds, axis: 0 )
575+ let maxGridSize = max ( gridThw [ 0 ... , 1 ... ] )
576+ let rotaryPositionEmbedFull = rotaryPositionEmbedding ( maxGridSize) [ indices]
577+
578+ return rotaryPositionEmbedFull. reshaped ( indices. dim ( 0 ) , - 1 )
591579 }
592580
593- public func callAsFunction( _ x: MLXArray , outputHiddenStates: Bool = false ) -> (
594- MLXArray , MLXArray , MLXArray ?
595- ) {
596- visionModel ( x, outputHiddenStates: outputHiddenStates)
581+ public func callAsFunction( _ hiddenStates: MLXArray , gridThw: MLXArray ) -> MLXArray {
582+ var hiddenStates = patchEmbed ( hiddenStates)
583+ let rotaryPositionEmbedding = rotaryPositionEmbedding ( gridThw)
584+
585+ // Assuming grid_thw has shape (batch_size, 3)
586+ let batchSize = gridThw. dim ( 0 )
587+
588+ // Calculate cu_seqlens for each item in the batch
589+ var collect = [ MLXArray] ( )
590+ for i in 0 ..< batchSize {
591+ let sequenceLength = gridThw [ i, 1 ] * gridThw[ i, 2 ]
592+
593+ // TODO NOTE: this evaluates gridThw -- it shouldn't do that
594+ let t = gridThw [ i, 0 ] . item ( Int . self)
595+ collect. append ( repeated ( sequenceLength, count: t) )
596+ }
597+
598+ // Concatenate the cu_seqlens for all items in the batch
599+ var cuSeqLengths = concatenated ( collect)
600+
601+ cuSeqLengths = cumsum ( cuSeqLengths. asType ( Int32 . self) , axis: 0 )
602+ cuSeqLengths = padded ( cuSeqLengths, width: [ 1 , 0 ] , mode: . constant, value: MLXArray ( 0 ) )
603+
604+ for block in blocks {
605+ hiddenStates = block ( hiddenStates, cuSequenceLengths: cuSeqLengths, rotaryPositionEmbedding: rotaryPositionEmbedding)
606+ }
607+
608+ return patchMerger ( hiddenStates)
597609 }
598610
599611 private func isMLXWeight( _ array: MLXArray ) -> Bool {
@@ -616,15 +628,18 @@ private enum Vision {
616628 if k. contains ( " position_id " ) {
617629 // Remove unused position_ids
618630 continue
619- } else if k. contains ( " patch_embedding.weight " ) {
631+ } else if k. contains ( " patch_embed.proj.weight " ) {
632+ // NOTE: patch_embed.proj is now a Conv3d, so the comment below is stale:
633+ // PyTorch Conv3d weights are [out_channels, in_channels, kT, kH, kW] and
634+ // MLX expects [out_channels, kT, kH, kW, in_channels] -- see the 5-axis transpose.
620635 // PyTorch conv2d weight tensors have shape:
621636 // [out_channels, in_channels, kH, KW]
622637 // MLX conv2d expects the weight be of shape:
623638 // [out_channels, kH, KW, in_channels]
624639 if isMLXWeight ( v) {
625640 sanitizedWeights [ k] = v
626641 } else {
627- sanitizedWeights [ k] = v. transposed ( 0 , 2 , 3 , 1 )
642+ sanitizedWeights [ k] = v. transposed ( 0 , 2 , 3 , 4 , 1 )
628643 }
629644 } else {
630645 sanitizedWeights [ k] = v
@@ -882,8 +897,8 @@ public struct Qwen2VLConfiguration: Codable, Sendable {
882897 public let patchSize : Int
883898 public let vocabularySize : Int
884899 public let mlpRatio : Float
885- public let _channels : Int ?
886- public var channels : Int { _channels ?? 3 }
900+ public let _inChannels : Int ?
901+ public var inChannels : Int { _inChannels ?? 3 }
887902 public let _layerNormEps : Float ?
888903 public var layerNormEps : Float { _layerNormEps ?? 1e-6 }
889904 public let spatialPatchSize : Int
@@ -900,7 +915,7 @@ public struct Qwen2VLConfiguration: Codable, Sendable {
900915 case patchSize = " patch_size "
901916 case vocabularySize = " vocab_size "
902917 case mlpRatio = " mlp_ratio "
903- case _channels = " num_channels "
918+ case _inChannels = " in_channels "
904919 case _layerNormEps = " layer_norm_eps "
905920 case spatialPatchSize = " spatial_patch_size "
906921 case spatialMergeSize = " spatial_merge_size "
0 commit comments