Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dichotomic symbolic clz and ctz #195

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/int32.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

include Stdlib.Int32

let clz n = Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_leading_zeros n)

let ctz n = Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_trailing_zeros n)
let clz =
Some
(fun n -> Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_leading_zeros n))

let ctz =
Some
(fun n ->
Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_trailing_zeros n) )

(* Taken from Base https://github.com/janestreet/base *)
let popcnt =
Expand Down
4 changes: 2 additions & 2 deletions src/int32.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ val unsigned_to_int : t -> int option

(** unary operators *)

val clz : t -> t
val clz : (t -> t) option

val ctz : t -> t
val ctz : (t -> t) option

val popcnt : t -> t

Expand Down
11 changes: 8 additions & 3 deletions src/int64.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

include Stdlib.Int64

let clz n = Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_leading_zeros n)

let ctz n = Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_trailing_zeros n)
let clz =
Some
(fun n -> Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_leading_zeros n))

let ctz =
Some
(fun n ->
Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_trailing_zeros n) )

(* Taken from Base: https://github.com/janestreet/base *)
let popcnt =
Expand Down
4 changes: 2 additions & 2 deletions src/int64.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ val extend_s : int -> t -> t

val abs : t -> t

val clz : t -> t
val clz : (t -> t) option

val ctz : t -> t
val ctz : (t -> t) option

val popcnt : t -> t

Expand Down
84 changes: 79 additions & 5 deletions src/interpret.ml
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,92 @@ module Make (P : Interpret_intf.P) :

let consti i = const_i32 (Int32.of_int i)

let clz_impl_32 n =
let rec aux (lb : int) ub =
if ub = lb + 1 then return (const_i32 (Int32.of_int (32 - ub)))
else begin
let mid = (lb + ub) / 2 in
let two_pow_mid = Int32.shl 1l (Int32.of_int mid) in
let> cond = I32.(lt_u n (const_i32 two_pow_mid)) in
if cond then aux lb mid else aux mid ub
end
in
let> cond = I32.(eqz n) in
if cond then return @@ const_i32 32l else aux 0 32

let clz_impl_64 n =
let rec aux (lb : int) ub =
if ub = lb + 1 then return (const_i64 (Int64.of_int (64 - ub)))
else begin
let mid = (lb + ub) / 2 in
let two_pow_mid = Int64.shl 1L (Int64.of_int mid) in
(* Could be more efficient with a shift right mid, to bench *)
let> cond = I64.(lt_u n (const_i64 two_pow_mid)) in
if cond then aux lb mid else aux mid ub
end
in
let> cond = I64.(eqz n) in
if cond then return @@ const_i64 64L else aux 0 64

let ctz_impl_32 n =
let rec aux (lb : int) ub =
if ub = lb + 1 then return (const_i32 (Int32.of_int lb))
else begin
let mid = (lb + ub) / 2 in
let two_pow_mid = Int32.shl 1l (Int32.of_int mid) in
let> cond = I32.(eqz @@ rem n (const_i32 two_pow_mid)) in
if cond then aux mid ub else aux lb mid
end
in
let> cond = I32.(eqz n) in
if cond then return @@ const_i32 32l else aux 0 32

let ctz_impl_64 n =
let rec aux (lb : int) ub =
if ub = lb + 1 then return (const_i64 (Int64.of_int lb))
else begin
let mid = (lb + ub) / 2 in
let two_pow_mid = Int64.shl 1L (Int64.of_int mid) in
let> cond = I64.(eqz @@ rem n (const_i64 two_pow_mid)) in
if cond then aux mid ub else aux lb mid
end
in
let> cond = I64.(eqz n) in
if cond then return @@ const_i64 64L else aux 0 64

let with_choosing_default_impl f ch_f =
match f with
| Some f -> fun n -> Choice.return (f n)
| None -> fun n -> ch_f n

let exec_iunop stack nn op =
match nn with
| S32 ->
let n, stack = Stack.pop_i32 stack in
let res =
let+ res =
let open I32 in
match op with Clz -> clz n | Ctz -> ctz n | Popcnt -> popcnt n
match op with
| Clz ->
let clz = with_choosing_default_impl clz clz_impl_32 in
clz n
| Ctz ->
let ctz = with_choosing_default_impl ctz ctz_impl_32 in
ctz n
| Popcnt -> Choice.return @@ popcnt n
in
Stack.push_i32 stack res
| S64 ->
let n, stack = Stack.pop_i64 stack in
let res =
let+ res =
let open I64 in
match op with Clz -> clz n | Ctz -> ctz n | Popcnt -> popcnt n
match op with
| Clz ->
let clz = with_choosing_default_impl clz clz_impl_64 in
clz n
| Ctz ->
let ctz = with_choosing_default_impl ctz ctz_impl_64 in
ctz n
| Popcnt -> Choice.return @@ popcnt n
in
Stack.push_i64 stack res

Expand Down Expand Up @@ -831,7 +903,9 @@ module Make (P : Interpret_intf.P) :
| I64_const n -> st @@ Stack.push_const_i64 stack n
| F32_const f -> st @@ Stack.push_const_f32 stack f
| F64_const f -> st @@ Stack.push_const_f64 stack f
| I_unop (nn, op) -> st @@ exec_iunop stack nn op
| I_unop (nn, op) ->
let* stack = exec_iunop stack nn op in
st stack
| F_unop (nn, op) -> st @@ exec_funop stack nn op
| I_binop (nn, op) ->
let* stack = exec_ibinop stack nn op in
Expand Down
2 changes: 1 addition & 1 deletion src/interpret_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ module type S = sig
-> Func_intf.t
-> value list Result.t choice

val exec_iunop : State.stack -> Types.nn -> Types.iunop -> State.stack
val exec_iunop : State.stack -> Types.nn -> Types.iunop -> State.stack choice

val exec_funop : State.stack -> Types.nn -> Types.funop -> State.stack

Expand Down
12 changes: 4 additions & 8 deletions src/symbolic_value.ml
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,9 @@ module I32 = struct

let zero = const_i32 0l

let clz e = unop ty Clz e
let clz = None

let ctz _ =
(* TODO *)
assert false
let ctz = None

let popcnt _ =
(* TODO *)
Expand Down Expand Up @@ -281,11 +279,9 @@ module I64 = struct

let zero = const_i64 0L

let clz e = unop ty Clz e
let clz = None

let ctz _ =
(* TODO *)
assert false
let ctz = None

let popcnt _ =
(* TODO *)
Expand Down
4 changes: 2 additions & 2 deletions src/value_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ module type Iop = sig

val zero : num

val clz : num -> num
val clz : (num -> num) option

val ctz : num -> num
val ctz : (num -> num) option

val popcnt : num -> num

Expand Down
64 changes: 64 additions & 0 deletions test/sym/clz_32.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
(module
(import "symbolic" "i32_symbol" (func $i32_symbol (result i32)))
(import "symbolic" "assume" (func $assume (param i32)))
(import "symbolic" "assert" (func $assert (param i32)))

(func $countLeadingZeros (param i32) (result i32)
(local $x i32)
(local $res i32)


;; Initialize local variables
(local.set $res (i32.const 32)) ;; Initialize with the highest possible index of a bit
(local.set $x (local.get 0)) ;; Store the input

;; Loop to find the leading zeros
(block $outter
(loop $loop

;; Check if all bits are shifted out
(if (i32.eqz (local.get $x))
(then (br $outter))
)

;; Shift the input to the right by 1 bit
(local.set $x (i32.shr_u (local.get $x) (i32.const 1)))

;; Decrement the count of zero bits
(local.set $res (i32.sub (local.get $res) (i32.const 1)))

(br $loop)
)
)

;; Return the number of leading zeros
(return (local.get $res))
)

(func $start

(local $n i32)
(local.set $n (call $i32_symbol))

(call $assert (i32.eq
(call $countLeadingZeros (local.get $n))
(i32.clz (local.get $n))
))

(call $assert (i32.eq
(i32.ctz (local.get $n))
;; Implem of ctz using clz
;; from hacker's delight p107
(i32.sub
(i32.const 32)
(i32.clz ( i32.and
(i32.xor (local.get $n) (i32.const -1))
(i32.sub (local.get $n) (i32.const 1))
))
)
))
)


(start $start)
)
64 changes: 64 additions & 0 deletions test/sym/clz_64.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
(module
(import "symbolic" "i64_symbol" (func $i64_symbol (result i64)))
(import "symbolic" "assume" (func $assume (param i32)))
(import "symbolic" "assert" (func $assert (param i32)))

(func $countLeadingZeros (param i64) (result i64)
(local $x i64)
(local $res i64)


;; Initialize local variables
(local.set $res (i64.const 64)) ;; Initialize with the highest possible index of a bit
(local.set $x (local.get 0)) ;; Store the input

;; Loop to find the leading zeros
(block $outter
(loop $loop

;; Check if all bits are shifted out
(if (i64.eqz (local.get $x))
(then (br $outter))
)

;; Shift the input to the right by 1 bit
(local.set $x (i64.shr_u (local.get $x) (i64.const 1)))

;; Decrement the count of zero bits
(local.set $res (i64.sub (local.get $res) (i64.const 1)))

(br $loop)
)
)

;; Return the number of leading zeros
(return (local.get $res))
)

(func $start

(local $n i64)
(local.set $n (call $i64_symbol))

(call $assert (i64.eq
(call $countLeadingZeros (local.get $n))
(i64.clz (local.get $n))
))

(call $assert (i64.eq
(i64.ctz (local.get $n))
;; Implem of ctz using clz
;; from hacker's delight p107
(i64.sub
(i64.const 64)
(i64.clz ( i64.and
(i64.xor (local.get $n) (i64.const -1))
(i64.sub (local.get $n) (i64.const 1))
))
)
))
)


(start $start)
)