Skip to content

Commit 05879b4

Browse files
authored
Fix test layer (zml#124)
* testLayer was silently broken by zml#115 : test + fix so we detect this in the future. * mapAlloc had a regression in zml#115 wrt struct with 0 size fields: test + fix * add Python `venv` to gitignore
1 parent c154327 commit 05879b4

File tree

6 files changed

+52
-10
lines changed

6 files changed

+52
-10
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ zig-out/
2323
# Data files
2424
*.wav
2525
*.ecdc
26+
27+
# Python slope
2628
.venv
29+
venv
2730

2831
# Editor specific
2932
## Neovim

zml/context.zig

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ const runtimes = @import("runtimes");
66
const std = @import("std");
77
const stdx = @import("stdx");
88

9-
const platform = @import("platform.zig");
9+
const zml_platform = @import("platform.zig");
1010
const pjrt = @import("pjrtx.zig");
1111

1212
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
@@ -18,6 +18,10 @@ const Target = @import("platform.zig").Target;
1818
const available_targets = @import("platform.zig").available_targets;
1919
const log = std.log.scoped(.@"zml/context");
2020

21+
test {
22+
std.testing.refAllDecls(Context);
23+
}
24+
2125
/// Every program using ZML must start with a `zml.Context.init(.{});`
2226
/// The ZML context contains global state to interact with the different
2327
/// devices available on your system.
@@ -149,15 +153,15 @@ pub const Context = struct {
149153
return platform_ orelse @panic("No platform found !");
150154
}
151155

152-
pub fn printAvailablePlatforms(self: Context, selected: platform.Platform) void {
156+
pub fn printAvailablePlatforms(self: Context, selected: Platform) void {
153157
// List available targets
154158
log.info("Available Platforms:", .{});
155159
const selected_prefix = "✅";
156160
const not_selected_prefix = "• ";
157161
const selected_postfix = "(AUTO-SELECTED)";
158162
const not_selected_postfix = "";
159163

160-
for (platform.available_targets) |target| {
164+
for (zml_platform.available_targets) |target| {
161165
log.info(" {s} {s} {s}", .{
162166
if (target == selected.target) selected_prefix else not_selected_prefix,
163167
@tagName(target),

zml/exe.zig

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,7 @@ pub const BaseExe = struct {
241241
shards.appendAssumeCapacity(dev_out[i]);
242242
}
243243

244-
const out_shape = self.inner.result_buffer_shapes[i];
245-
return Buffer.fromPjrtBuffers(self.platform(), out_shape, shards.constSlice());
244+
return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
246245
}
247246
};
248247

zml/meta.zig

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,11 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
120120
return;
121121
}
122122

123+
if (@sizeOf(ToStruct) == 0) return;
124+
123125
switch (type_info_to) {
124126
.Struct => |info| inline for (info.fields) |field| {
125-
// if (field.is_comptime) continue;
127+
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
126128
const field_type_info = @typeInfo(field.type);
127129
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
128130
switch (field_type_info) {
@@ -187,6 +189,8 @@ test mapAlloc {
187189
}
188190
};
189191

192+
const Empty = struct {};
193+
190194
const AA = struct {
191195
field: A,
192196
array: [2]A,
@@ -195,6 +199,7 @@ test mapAlloc {
195199
// We want to allow conversion from comptime to runtime, because Zig type inference works like this.
196200
comptime static_val: u8 = 8,
197201
comptime static_slice: [2]A = .{ .{ .a = 11 }, .{ .a = 12 } },
202+
field_with_empty: struct { A, Empty },
198203
};
199204
const BB = struct {
200205
field: B,
@@ -203,13 +208,15 @@ test mapAlloc {
203208
other: u8,
204209
static_val: u8,
205210
static_slice: []B,
211+
field_with_empty: struct { B, Empty },
206212
};
207213

208214
const aa: AA = .{
209215
.field = .{ .a = 4 },
210216
.array = .{ .{ .a = 5 }, .{ .a = 6 } },
211217
.other = 7,
212218
.slice = &.{ .{ .a = 9 }, .{ .a = 10 } },
219+
.field_with_empty = .{ .{ .a = 9 }, .{} },
213220
};
214221
var bb: BB = undefined;
215222

zml/testing.zig

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ pub fn testLayerOut(
201201
const exe = try zml.compileModel(alloc, fwd, layer, input_shapes, platform);
202202

203203
const n_out_exp = activations.countLayers(out_name);
204-
if (exe.inner.result_buffer_count != n_out_exp) {
205-
log.warn("Reference models produces {d} outputs, but implementation produces {d}", .{ n_out_exp, exe.inner.result_buffer_count });
204+
if (exe.inner.result_shapes.len != n_out_exp) {
205+
log.warn("Reference models produces {d} outputs, but implementation produces {d}", .{ n_out_exp, exe.inner.result_shapes.len });
206206
}
207207
const mod = exe.prepare(layer_weights);
208208

@@ -243,13 +243,13 @@ pub fn testLayerOut(
243243

244244
var buf: [1024]u8 = undefined;
245245
var failed: bool = false;
246-
for (0..mod.inner.result_buffer_count) |i| {
246+
for (0..mod.inner.result_shapes.len) |i| {
247247
const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable;
248248
const expected_out = activations.get(full_name) orelse {
249249
log.warn("Output buffer not found: {s}", .{full_name});
250250
continue;
251251
};
252-
zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| switch (err) {
252+
zml.testing.expectClose(expected_out, mod.inner.getOutputBuffer(i), tolerance) catch |err| switch (err) {
253253
error.TestUnexpectedResult => {
254254
log.err("{s}.{d} doesn't match !", .{ out_name, i });
255255
failed = true;
@@ -263,6 +263,34 @@ pub fn testLayerOut(
263263
log.info("all good for {s} !", .{name});
264264
}
265265

266+
test testLayer {
267+
const platform = env();
268+
269+
// create a model
270+
const layer: zml.nn.Linear = .{
271+
.weight = zml.Tensor{ ._shape = zml.Shape.init(.{ 5, 2 }, .f32), ._id = .{ .buffer_id = 42 } },
272+
};
273+
const layer_weights: zml.Bufferized(zml.nn.Linear) = .{
274+
.weight = try zml.Buffer.fromArray(
275+
platform,
276+
[5][2]f32{ .{ 0, 0 }, .{ 0, 1 }, .{ 1, 2 }, .{ -1, -1 }, .{ -1, 0 } },
277+
),
278+
};
279+
280+
// create a buffer store containing the activations:
281+
var activations = try zml.aio.BufferStore.init(std.testing.allocator, &.{});
282+
defer activations.deinit();
283+
{
284+
const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 });
285+
try activations.buffers.put(activations.arena.allocator(), "model.layer.in.0", input);
286+
const output = zml.HostBuffer.fromArray(&[5]f32{ 0, -1, -1, 0, -1 });
287+
try activations.buffers.put(activations.arena.allocator(), "model.layer.out.0", output);
288+
}
289+
290+
// test the ZML layer reproduces the "captured" activations:
291+
try zml.testing.testLayer(platform, activations, "model.layer", layer, layer_weights, 1e-5);
292+
}
293+
266294
pub inline fn expectEqual(expected: anytype, actual: @TypeOf(expected)) !void {
267295
return std.testing.expectEqual(expected, actual);
268296
}

zml/zml.zig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub const compileFn = exe.compileFn;
3636
pub const compileModel = exe.compileModel;
3737
pub const FnExe = exe.FnExe;
3838
pub const ModuleExe = exe.ModuleExe;
39+
pub const ModuleSignature = exe.ModuleSignature;
3940

4041
pub const ops = @import("ops.zig");
4142
pub const tools = struct {

0 commit comments

Comments
 (0)