diff --git a/src/int32.ml b/src/int32.ml index cb3b914b7..9954cea41 100644 --- a/src/int32.ml +++ b/src/int32.ml @@ -2,9 +2,14 @@ include Stdlib.Int32 -let clz n = Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_leading_zeros n) - -let ctz n = Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_trailing_zeros n) +let clz = + Some + (fun n -> Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_leading_zeros n)) + +let ctz = + Some + (fun n -> + Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_trailing_zeros n) ) (* Taken from Base https://github.com/janestreet/base *) let popcnt = diff --git a/src/int32.mli b/src/int32.mli index ec565297e..66e1acc94 100644 --- a/src/int32.mli +++ b/src/int32.mli @@ -34,9 +34,9 @@ val unsigned_to_int : t -> int option (** unary operators *) -val clz : t -> t +val clz : (t -> t) option -val ctz : t -> t +val ctz : (t -> t) option val popcnt : t -> t diff --git a/src/int64.ml b/src/int64.ml index 570e71801..bd10c431f 100644 --- a/src/int64.ml +++ b/src/int64.ml @@ -2,9 +2,14 @@ include Stdlib.Int64 -let clz n = Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_leading_zeros n) - -let ctz n = Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_trailing_zeros n) +let clz = + Some + (fun n -> Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_leading_zeros n)) + +let ctz = + Some + (fun n -> + Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_trailing_zeros n) ) (* Taken from Base: https://github.com/janestreet/base *) let popcnt = diff --git a/src/int64.mli b/src/int64.mli index d82226f5c..88cecfeb4 100644 --- a/src/int64.mli +++ b/src/int64.mli @@ -34,9 +34,9 @@ val extend_s : int -> t -> t val abs : t -> t -val clz : t -> t +val clz : (t -> t) option -val ctz : t -> t +val ctz : (t -> t) option val popcnt : t -> t diff --git a/src/interpret.ml b/src/interpret.ml index 7a713026f..c3c38be3e 100644 --- a/src/interpret.ml +++ b/src/interpret.ml @@ -99,20 +99,92 @@ module Make (P : Interpret_intf.P) : let consti i = const_i32 (Int32.of_int i) + let clz_impl_32 n = + let rec aux (lb : int) ub = + if ub = lb + 1 then return (const_i32 (Int32.of_int (32 - ub))) + else begin + let mid = (lb + ub) / 2 in + let two_pow_mid = Int32.shl 1l (Int32.of_int mid) in + let> cond = I32.(lt_u n (const_i32 two_pow_mid)) in + if cond then aux lb mid else aux mid ub + end + in + let> cond = I32.(eqz n) in + if cond then return @@ const_i32 32l else aux 0 32 + + let clz_impl_64 n = + let rec aux (lb : int) ub = + if ub = lb + 1 then return (const_i64 (Int64.of_int (64 - ub))) + else begin + let mid = (lb + ub) / 2 in + let two_pow_mid = Int64.shl 1L (Int64.of_int mid) in + (* Could be more efficient with a shift right mid, to bench *) + let> cond = I64.(lt_u n (const_i64 two_pow_mid)) in + if cond then aux lb mid else aux mid ub + end + in + let> cond = I64.(eqz n) in + if cond then return @@ const_i64 64L else aux 0 64 + + let ctz_impl_32 n = + let rec aux (lb : int) ub = + if ub = lb + 1 then return (const_i32 (Int32.of_int lb)) + else begin + let mid = (lb + ub) / 2 in + let two_pow_mid = Int32.shl 1l (Int32.of_int mid) in + let> cond = I32.(eqz @@ rem n (const_i32 two_pow_mid)) in + if cond then aux mid ub else aux lb mid + end + in + let> cond = I32.(eqz n) in + if cond then return @@ const_i32 32l else aux 0 32 + + let ctz_impl_64 n = + let rec aux (lb : int) ub = + if ub = lb + 1 then return (const_i64 (Int64.of_int lb)) + else begin + let mid = (lb + ub) / 2 in + let two_pow_mid = Int64.shl 1L (Int64.of_int mid) in + let> cond = I64.(eqz @@ rem n (const_i64 two_pow_mid)) in + if cond then aux mid ub else aux lb mid + end + in + let> cond = I64.(eqz n) in + if cond then return @@ const_i64 64L else aux 0 64 + + let with_choosing_default_impl f ch_f = + match f with + | Some f -> fun n -> Choice.return (f n) + | None -> fun n -> ch_f n + let exec_iunop stack nn op = match nn with | S32 -> let n, stack = Stack.pop_i32 stack in - let res = + let+ res = let open I32 in - match op with Clz -> clz n | Ctz -> ctz n | Popcnt -> popcnt n + match op with + | Clz -> + let clz = with_choosing_default_impl clz clz_impl_32 in + clz n + | Ctz -> + let ctz = with_choosing_default_impl ctz ctz_impl_32 in + ctz n + | Popcnt -> Choice.return @@ popcnt n in Stack.push_i32 stack res | S64 -> let n, stack = Stack.pop_i64 stack in - let res = + let+ res = let open I64 in - match op with Clz -> clz n | Ctz -> ctz n | Popcnt -> popcnt n + match op with + | Clz -> + let clz = with_choosing_default_impl clz clz_impl_64 in + clz n + | Ctz -> + let ctz = with_choosing_default_impl ctz ctz_impl_64 in + ctz n + | Popcnt -> Choice.return @@ popcnt n in Stack.push_i64 stack res @@ -831,7 +903,9 @@ module Make (P : Interpret_intf.P) : | I64_const n -> st @@ Stack.push_const_i64 stack n | F32_const f -> st @@ Stack.push_const_f32 stack f | F64_const f -> st @@ Stack.push_const_f64 stack f - | I_unop (nn, op) -> st @@ exec_iunop stack nn op + | I_unop (nn, op) -> + let* stack = exec_iunop stack nn op in + st stack | F_unop (nn, op) -> st @@ exec_funop stack nn op | I_binop (nn, op) -> let* stack = exec_ibinop stack nn op in diff --git a/src/interpret_intf.ml b/src/interpret_intf.ml index ef72a1208..27b9224b1 100644 --- a/src/interpret_intf.ml +++ b/src/interpret_intf.ml @@ -210,7 +210,7 @@ module type S = sig -> Func_intf.t -> value list Result.t choice - val exec_iunop : State.stack -> Types.nn -> Types.iunop -> State.stack + val exec_iunop : State.stack -> Types.nn -> Types.iunop -> State.stack choice val exec_funop : State.stack -> Types.nn -> Types.funop -> State.stack diff --git a/src/symbolic_value.ml b/src/symbolic_value.ml index ae7b497f0..3fc3805a6 100644 --- a/src/symbolic_value.ml +++ b/src/symbolic_value.ml @@ -165,11 +165,9 @@ module I32 = struct let zero = const_i32 0l - let clz e = unop ty Clz e + let clz = None - let ctz _ = - (* TODO *) - assert false + let ctz = None let popcnt _ = (* TODO *) @@ -281,11 +279,9 @@ module I64 = struct let zero = const_i64 0L - let clz e = unop ty Clz e + let clz = None - let ctz _ = - (* TODO *) - assert false + let ctz = None let popcnt _ = (* TODO *) diff --git a/src/value_intf.ml b/src/value_intf.ml index 1202483ac..1c0b816b3 100644 --- a/src/value_intf.ml +++ b/src/value_intf.ml @@ -17,9 +17,9 @@ module type Iop = sig val zero : num - val clz : num -> num + val clz : (num -> num) option - val ctz : num -> num + val ctz : (num -> num) option val popcnt : num -> num diff --git a/test/sym/clz_32.wat b/test/sym/clz_32.wat new file mode 100644 index 000000000..acee1ef6d --- /dev/null +++ b/test/sym/clz_32.wat @@ -0,0 +1,64 @@ +(module + (import "symbolic" "i32_symbol" (func $i32_symbol (result i32))) + (import "symbolic" "assume" (func $assume (param i32))) + (import "symbolic" "assert" (func $assert (param i32))) + + (func $countLeadingZeros (param i32) (result i32) + (local $x i32) + (local $res i32) + + + ;; Initialize local variables + (local.set $res (i32.const 32)) ;; Initialize with the highest possible index of a bit + (local.set $x (local.get 0)) ;; Store the input + + ;; Loop to find the leading zeros + (block $outter + (loop $loop + + ;; Check if all bits are shifted out + (if (i32.eqz (local.get $x)) + (then (br $outter)) + ) + + ;; Shift the input to the right by 1 bit + (local.set $x (i32.shr_u (local.get $x) (i32.const 1))) + + ;; Decrement the count of zero bits + (local.set $res (i32.sub (local.get $res) (i32.const 1))) + + (br $loop) + ) + ) + + ;; Return the number of leading zeros + (return (local.get $res)) + ) + + (func $start + + (local $n i32) + (local.set $n (call $i32_symbol)) + + (call $assert (i32.eq + (call $countLeadingZeros (local.get $n)) + (i32.clz (local.get $n)) + )) + + (call $assert (i32.eq + (i32.ctz (local.get $n)) + ;; Implem of ctz using clz + ;; from hacker's delight p107 + (i32.sub + (i32.const 32) + (i32.clz ( i32.and + (i32.xor (local.get $n) (i32.const -1)) + (i32.sub (local.get $n) (i32.const 1)) + )) + ) + )) + ) + + + (start $start) +) diff --git a/test/sym/clz_64.wat b/test/sym/clz_64.wat new file mode 100644 index 000000000..0430d61de --- /dev/null +++ b/test/sym/clz_64.wat @@ -0,0 +1,64 @@ +(module + (import "symbolic" "i64_symbol" (func $i64_symbol (result i64))) + (import "symbolic" "assume" (func $assume (param i32))) + (import "symbolic" "assert" (func $assert (param i32))) + + (func $countLeadingZeros (param i64) (result i64) + (local $x i64) + (local $res i64) + + + ;; Initialize local variables + (local.set $res (i64.const 64)) ;; Initialize with the highest possible index of a bit + (local.set $x (local.get 0)) ;; Store the input + + ;; Loop to find the leading zeros + (block $outter + (loop $loop + + ;; Check if all bits are shifted out + (if (i64.eqz (local.get $x)) + (then (br $outter)) + ) + + ;; Shift the input to the right by 1 bit + (local.set $x (i64.shr_u (local.get $x) (i64.const 1))) + + ;; Decrement the count of zero bits + (local.set $res (i64.sub (local.get $res) (i64.const 1))) + + (br $loop) + ) + ) + + ;; Return the number of leading zeros + (return (local.get $res)) + ) + + (func $start + + (local $n i64) + (local.set $n (call $i64_symbol)) + + (call $assert (i64.eq + (call $countLeadingZeros (local.get $n)) + (i64.clz (local.get $n)) + )) + + (call $assert (i64.eq + (i64.ctz (local.get $n)) + ;; Implem of ctz using clz + ;; from hacker's delight p107 + (i64.sub + (i64.const 64) + (i64.clz ( i64.and + (i64.xor (local.get $n) (i64.const -1)) + (i64.sub (local.get $n) (i64.const 1)) + )) + ) + )) + ) + + + (start $start) +)