@@ -24,6 +24,31 @@ Add this to your `Cargo.toml`:
24
24
mpsgraph = "0.1.0"
25
25
```
26
26
27
For development with the latest version:

```toml
[dependencies]
mpsgraph = { git = "https://github.com/eugenebokhan/mpsgraph-rs", package = "mpsgraph" }
```
33
+
34
## Dependencies

This crate depends on:

- **objc2** (0.6.0): Safe Rust bindings to Objective-C
- **objc2-foundation** (0.3.0): Rust bindings for Apple's Foundation framework
- **metal** (0.32.0): Rust bindings for Apple's Metal API
- **bitflags** (2.9.0): Macro for generating bitflag structures
- **foreign-types** (0.5): FFI type handling utilities
- **block** (0.1.6): Support for Objective-C blocks
- **rand** (0.9.0): Random number generation utilities

The crate also requires linking against:

- MetalPerformanceShaders.framework
- Metal.framework
- Foundation.framework
27
52
## Example
28
53
29
54
```rust
@@ -89,116 +114,6 @@ fn main() {
89
114
- Tensor reshaping and transposition
90
115
- Graph compilation for repeated execution
91
116
92
## Advanced Example: Neural Network with MPSGraph

```rust
use mpsgraph::{Graph, MPSShapeDescriptor, MPSDataType};
use metal::{Device, MTLResourceOptions};
use std::collections::HashMap;

fn main() {
    // Get the system default Metal device
    let device = Device::system_default().expect("No Metal device found");

    // Create a graph
    let graph = Graph::new().expect("Failed to create graph");

    // Define the neural network architecture
    // Input: 784 features (28x28 image flattened)
    // Hidden layer: 128 neurons with ReLU activation
    // Output: 10 classes with softmax activation

    // Input placeholder
    let input_shape = MPSShapeDescriptor::new(vec![1, 784], MPSDataType::Float32);
    let x = graph.placeholder(&input_shape, Some("input"));

    // First layer weights and biases
    let w1_shape = MPSShapeDescriptor::new(vec![784, 128], MPSDataType::Float32);
    let b1_shape = MPSShapeDescriptor::new(vec![1, 128], MPSDataType::Float32);

    // Create random weights (would normally be trained or loaded from a file)
    let mut w1_data = vec![0.0f32; 784 * 128];
    let mut b1_data = vec![0.0f32; 128];

    // Initialize with small random values (simplified)
    for i in 0..w1_data.len() {
        w1_data[i] = (i as f32 * 0.0001) - 0.05;
    }

    let w1 = graph.constant_with_data(&w1_data, &w1_shape, Some("w1"));
    let b1 = graph.constant_with_data(&b1_data, &b1_shape, Some("b1"));

    // First layer computation: h1 = ReLU(x · w1 + b1)
    let xw1 = graph.matmul(&x, &w1, Some("xw1"));
    let xw1_plus_b1 = graph.add(&xw1, &b1, Some("logits1"));
    let h1 = graph.relu(&xw1_plus_b1, Some("hidden1"));

    // Second layer (output layer)
    let w2_shape = MPSShapeDescriptor::new(vec![128, 10], MPSDataType::Float32);
    let b2_shape = MPSShapeDescriptor::new(vec![1, 10], MPSDataType::Float32);

    let mut w2_data = vec![0.0f32; 128 * 10];
    let mut b2_data = vec![0.0f32; 10];

    // Initialize with small random values (simplified)
    for i in 0..w2_data.len() {
        w2_data[i] = (i as f32 * 0.001) - 0.05;
    }

    let w2 = graph.constant_with_data(&w2_data, &w2_shape, Some("w2"));
    let b2 = graph.constant_with_data(&b2_data, &b2_shape, Some("b2"));

    // Output layer computation: y = softmax(h1 · w2 + b2)
    let h1w2 = graph.matmul(&h1, &w2, Some("h1w2"));
    let logits = graph.add(&h1w2, &b2, Some("logits"));
    let probs = graph.softmax(&logits, 1, Some("probabilities"));

    // Create a sample input (a simplified image)
    let mut input_data = vec![0.0f32; 784];
    for i in 0..784 {
        // Create a simple pattern
        input_data[i] = if (i / 28 + i % 28) % 2 == 0 { 0.9 } else { 0.1 };
    }

    // Create input buffer
    let input_buffer = device.new_buffer_with_data(
        input_data.as_ptr() as *const _,
        (784 * std::mem::size_of::<f32>()) as u64,
        MTLResourceOptions::StorageModeShared
    );

    // Create feed dictionary
    let mut feed_dict = HashMap::new();
    feed_dict.insert(&x, input_buffer.deref());

    // Run the graph
    let results = graph.run(&device, feed_dict, &[&probs]);
    assert_eq!(results.len(), 1);

    // Process and print the results (class probabilities)
    unsafe {
        let probs_ptr = results[0].contents() as *const f32;
        let probabilities = std::slice::from_raw_parts(probs_ptr, 10);

        println!("Class probabilities:");
        for (i, &prob) in probabilities.iter().enumerate() {
            println!("Class {}: {:.6}", i, prob);
        }

        // Find the predicted class (highest probability)
        let mut max_idx = 0;
        let mut max_prob = probabilities[0];
        for i in 1..10 {
            if probabilities[i] > max_prob {
                max_idx = i;
                max_prob = probabilities[i];
            }
        }

        println!("Predicted class: {} (probability: {:.6})", max_idx, max_prob);
    }
}
```
117
## License
203
118
204
119
Licensed under the MIT License.
0 commit comments