// A suggested solution to task 1 of the 2023 ZKP MOOC lab (https://github.com/rdi-berkeley/zkp-mooc-lab).
// My implementation was highlighted in green in the original write-up; everything else is part of the original template.
// Note: the RightShift template was updated after the initial submission to address an error in the original task.
pragma circom 2.0.0;
// Logical AND of two signals, computed as their product.
// Assumes a and b are 0/1; that is not enforced by this template.
template AND() {
    signal input a;
    signal input b;
    signal output out;

    a * b ==> out;
}
// Logical OR of two signals.
// Assumes a and b are 0/1; that is not enforced by this template.
template OR() {
    signal input a;
    signal input b;
    signal output out;

    // De Morgan form: a OR b = 1 - (NOT a AND NOT b) = a + b - a*b.
    1 - (1 - a) * (1 - b) ==> out;
}
// Selects between two values: out = L when cond == 1, out = R when cond == 0.
// `cond` is constrained to be boolean. Without that constraint a prover could
// pass any field element as the selector and obtain an arbitrary affine
// combination cond*L + (1-cond)*R instead of one of the two inputs.
template IfThenElse() {
    signal input cond;
    signal input L;
    signal input R;
    signal output out;

    cond * (1 - cond) === 0;
    out <== cond * (L - R) + R;
}
// Conditionally swaps two values.
// For sel == 0: (outL, outR) = (L, R); for sel == 1: (outL, outR) = (R, L).
// sel is not constrained to be boolean here.
template Switcher() {
    signal input sel;
    signal input L;
    signal input R;
    signal output outL;
    signal output outR;
    signal delta;

    // delta is (R - L) when swapping and 0 otherwise; it is shared by both outputs.
    delta <== sel * (R - L);
    outL <== L + delta;
    outR <== R - delta;
}
// Little-endian binary decomposition of `in` into b bits.
// Each bit is constrained to 0/1 and the bits must recompose to `in`. This
// also forces in < 2^b, provided 2^b does not exceed the field modulus.
template Num2Bits(b) {
    signal input in;
    signal output bits[b];

    // Hint each bit, force it to be boolean, and accumulate sum(bits[j] * 2^j)
    // in a single pass.
    var acc = 0;
    var weight = 1;
    for (var j = 0; j < b; j++) {
        bits[j] <-- (in >> j) & 1;
        bits[j] * (bits[j] - 1) === 0;
        acc += bits[j] * weight;
        weight = weight + weight;
    }
    acc === in;
}
// Recomposes a little-endian bit array into a number: out = sum bits[i] * 2^i.
// The inputs are not checked to be boolean here.
template Bits2Num(b) {
    signal input bits[b];
    signal output out;

    var total = 0;
    var pow2 = 1;
    for (var idx = 0; idx < b; idx++) {
        total += pow2 * bits[idx];
        pow2 *= 2;
    }
    out <== total;
}
// out = 1 if in == 0, else out = 0.
// The prover hints inv = 1/in (or 0 when in is zero). The two constraints
// below leave exactly one valid out for every in.
template IsZero() {
    signal input in;
    signal output out;
    signal inv;

    inv <-- (in == 0) ? 0 : 1 / in;
    out <== 1 - in * inv;
    // If in != 0 this forces out = 0 (and hence inv = 1/in).
    out * in === 0;
}
// out = 1 if in[0] == in[1], else 0 (the two are equal exactly when their difference is zero).
template IsEqual() {
    signal input in[2];
    signal output out;

    component diff_is_zero = IsZero();
    diff_is_zero.in <== in[1] - in[0];
    out <== diff_is_zero.out;
}
// out = 1 if in[0] < in[1], else 0.
// Assumes both inputs fit in n bits (not enforced here).
template LessThan(n) {
    assert(n <= 252);
    signal input in[2];
    signal output out;

    // Offset by 2^n so the value is non-negative for n-bit inputs. Bit n of
    // (2^n + in[0] - in[1]) is set exactly when in[0] >= in[1].
    component offset_bits = Num2Bits(n + 1);
    offset_bits.in <== (1 << n) + in[0] - in[1];
    out <== 1 - offset_bits.bits[n];
}
// Outputs out = 1 if `in` fits in b bits (in < 2^b) and out = 0 otherwise,
// assuming an honest prover.
//
// Soundness caveat: the bits are hinted with `<--` and only constrained to be
// boolean. So out == 1 is binding: it proves `in` equals a sum of b bits.
// out == 0 is NOT binding: a prover can pick bits that do not recompose to
// `in` and report 0 even when `in` fits. Only rely on this gadget where the
// circuit subsequently requires out == 1.
//
// b must stay below 254. Otherwise 2^b - 1 can exceed the field modulus
// (~2^254 on circom's default curve), and the weighted bit sum can wrap around.
template CheckBitLength(b) {
    assert(b < 254);
    signal input in;
    signal output out;
    signal bits[b];

    for (var i = 0; i < b; i++) {
        bits[i] <-- (in >> i) & 1;
        bits[i] * (1 - bits[i]) === 0;
    }

    var sum_of_bits = 0;
    for (var i = 0; i < b; i++) {
        sum_of_bits += (1 << i) * bits[i];
    }

    // out = 1 iff the b-bit recomposition matches the input exactly.
    component is_zero = IsZero();
    is_zero.in <== in - sum_of_bits;
    out <== is_zero.out;
}
// Enforces that (e, m) is a well-formed floating-point encoding:
//   - e == 0 requires m == 0 (the zero value);
//   - otherwise e must fit in k bits and 2^p <= m < 2^(p+1),
//     i.e. m - 2^p must fit in p bits.
template CheckWellFormedness(k, p) {
    signal input e;
    signal input m;

    component e_is_zero = IsZero();
    e_is_zero.in <== e;
    component m_is_zero = IsZero();
    m_is_zero.in <== m;

    component e_fits = CheckBitLength(k);
    e_fits.in <== e;
    component m_fits = CheckBitLength(p);
    m_fits.in <== m - (1 << p);

    // Both bit-length outputs are 0/1, so their product is their logical AND.
    component choose = IfThenElse();
    choose.cond <== e_is_zero.out;
    choose.L <== m_is_zero.out;
    choose.R <== e_fits.out * m_fits.out;
    choose.out === 1;
}
// Computes y = floor(x / 2^shift), where x is required to fit in b bits.
//
// x is decomposed into b bits with Num2Bits(b). That simultaneously enforces
// x < 2^b and makes the decomposition unique (b < 254). y is then the
// recomposition of the top (b - shift) bits. This replaces a hint-based
// version whose remainder was never range-checked, which left y
// unconstrained. That version also ran its range check through
// LessThan(shift) against 2^b, a value that does not fit in `shift` bits.
template RightShift(b, shift) {
    assert(shift < b);
    assert(b < 254);
    signal input x;
    signal output y;

    component x_bits = Num2Bits(b);
    x_bits.in <== x;

    // Drop the low `shift` bits: y = sum_{i >= shift} bits[i] * 2^(i - shift).
    var acc = 0;
    for (var i = shift; i < b; i++) {
        acc += x_bits.bits[i] * (1 << (i - shift));
    }
    y <== acc;
}
// Rounds a mantissa m with scale P down to p+1 bits (scale p), rounding to
// nearest with ties rounded up, and adjusts the exponent if rounding overflows.
// Assumes m fits in P+1 bits (LessThan(P+1) below relies on it).
template RoundAndCheck(k, p, P) {
    signal input e;
    signal input m;
    signal output e_out;
    signal output m_out;
    assert(P > p);

    // No overflow iff m + 2^(P-p-1) < 2^(P+1). In that case the rounded
    // mantissa still fits in p+1 bits.
    component if_no_overflow = LessThan(P+1);
    if_no_overflow.in[0] <== m;
    if_no_overflow.in[1] <== (1 << (P+1)) - (1 << (P-p-1));
    signal no_overflow <== if_no_overflow.out;

    var round_amt = P-p;
    // Case I (no overflow): add half a unit in the last place, then drop round_amt bits.
    var m_prime = m + (1 << (round_amt-1));
    // m_prime fits in P+1 bits when there is no overflow but can reach P+2 bits
    // otherwise. The shifter is sized for P+2 bits so its constraints hold in
    // both cases; its output is discarded on overflow.
    component right_shift = RightShift(P+2, round_amt);
    right_shift.x <== m_prime;
    var m_out_1 = right_shift.y;
    var e_out_1 = e;

    // Case II (overflow): the mantissa rounds up to 2^(p+1), renormalized as 2^p with e + 1.
    var e_out_2 = e + 1;
    var m_out_2 = (1 << p);

    component if_else[2];
    for (var i = 0; i < 2; i++) {
        if_else[i] = IfThenElse();
        if_else[i].cond <== no_overflow;
    }
    if_else[0].L <== e_out_1;
    if_else[0].R <== e_out_2;
    if_else[1].L <== m_out_1;
    if_else[1].R <== m_out_2;
    e_out <== if_else[0].out;
    m_out <== if_else[1].out;
}
// Left-shifts x by `shift` bits: y = x * 2^shift.
// Enforces 0 <= shift < shift_bound unless skip_checks == 1, in which case
// the range check is waived (and y evaluates to 0 for an out-of-range shift).
//
// 2^shift is obtained from a one-hot selector over [0, shift_bound). This
// replaces a version that left y unconstrained (assigned only with `<--`).
// That version also inverted the skip_checks semantics, forcing
// shift >= shift_bound when skip_checks == 1.
template LeftShift(shift_bound) {
    signal input x;
    signal input shift;
    signal input skip_checks;
    signal output y;

    // sel[i].out == 1 iff shift == i. The indices are distinct, so at most one
    // selector is 1, and their sum is 1 exactly when shift is in range.
    component sel[shift_bound];
    var in_range = 0;
    var pow = 0;
    for (var i = 0; i < shift_bound; i++) {
        sel[i] = IsEqual();
        sel[i].in[0] <== shift;
        sel[i].in[1] <== i;
        in_range += sel[i].out;
        pow += sel[i].out * (1 << i);
    }

    // Range check, disabled when skip_checks == 1.
    (1 - in_range) * (1 - skip_checks) === 0;

    // pow == 2^shift whenever shift is in range.
    y <== x * pow;
}
// One-hot encodes the most significant non-zero bit of the b-bit input `in`:
// one_hot[i] == 1 exactly for the highest i with bit i of `in` set.
// Enforces in != 0 unless skip_checks == 1. The original constrained
// is_zero.out === skip_checks, which forced in == 0 when skip_checks was 1
// instead of waiving the check. `in` must fit in b bits regardless of
// skip_checks (Num2Bits below).
template MSNZB(b) {
    signal input in;
    signal input skip_checks;
    signal output one_hot[b];

    component is_zero = IsZero();
    is_zero.in <== in;
    is_zero.out * (1 - skip_checks) === 0;

    component in_bits = Num2Bits(b);
    in_bits.in <== in;

    // none_above[i] == 1 iff bits i+1 .. b-1 of `in` are all zero.
    signal none_above[b];
    none_above[b-1] <== 1;
    for (var i = b-2; i >= 0; i--) {
        none_above[i] <== (1 - in_bits.bits[i+1]) * none_above[i+1];
    }

    // A bit is the MSNZB iff it is set and nothing above it is set.
    for (var i = 0; i < b; i++) {
        one_hot[i] <== in_bits.bits[i] * none_above[i];
    }
}
// Normalizes a (P+1)-bit mantissa so its most significant set bit sits at
// position P. If that bit is at index ell, then
//   m_out = m * 2^(P - ell)  and  e_out = e + ell - p.
// The non-zero check on m is waived when skip_checks == 1 (passed through to MSNZB).
template Normalize(k, p, P) {
    signal input e;
    signal input m;
    signal input skip_checks;
    signal output e_out;
    signal output m_out;
    assert(P > p);

    component msb = MSNZB(P+1);
    msb.in <== m;
    msb.skip_checks <== skip_checks;

    // Both quantities are linear combinations of the one-hot signals:
    // msb_index = ell and scale = 2^(P - ell).
    var msb_index = 0;
    var scale = 0;
    for (var pos = P; pos >= 0; pos--) {
        msb_index += pos * msb.one_hot[pos];
        scale += (2 ** (P - pos)) * msb.one_hot[pos];
    }

    m_out <== m * scale;
    e_out <== e - p + msb_index;
}
// Adds two floating-point numbers given as (exponent, mantissa) pairs with
// k-bit exponents and (p+1)-bit mantissas of scale p. Both inputs are checked
// for well-formedness.
// If the exponent gap exceeds p+1, or the larger operand has a zero exponent,
// the larger operand is returned unchanged. Otherwise the aligned mantissas
// are added, normalized, and rounded back to p+1 bits.
template FloatAdd(k, p) {
    signal input e[2];
    signal input m[2];
    signal output e_out;
    signal output m_out;

    component wellformedness_0 = CheckWellFormedness(k, p);
    wellformedness_0.e <== e[0];
    wellformedness_0.m <== m[0];
    component wellformedness_1 = CheckWellFormedness(k, p);
    wellformedness_1.e <== e[1];
    wellformedness_1.m <== m[1];

    // Magnitudes e * 2^(p+1) + m order the operands by absolute value.
    // With e < 2^k and m < 2^(p+1) they occupy up to k+p+1 bits, so the
    // comparator needs k+p+1 bits. With only k+p bits its internal Num2Bits
    // can become unsatisfiable for operands whose magnitudes differ by
    // 2^(k+p) or more.
    signal n0 <== e[0] * 2 ** (p+1) + m[0];
    signal n1 <== e[1] * 2 ** (p+1) + m[1];
    signal largest_e, largest_m, smallest_e, smallest_m;
    component less_than = LessThan(k+p+1);
    less_than.in[0] <== n0;
    less_than.in[1] <== n1;

    // Swap iff n0 < n1, so "largest_*" is always the larger-magnitude operand.
    component switcher_e = Switcher();
    switcher_e.sel <== less_than.out;
    switcher_e.L <== e[0];
    switcher_e.R <== e[1];
    largest_e <== switcher_e.outL; smallest_e <== switcher_e.outR;
    component switcher_m = Switcher();
    switcher_m.sel <== less_than.out;
    switcher_m.L <== m[0];
    switcher_m.R <== m[1];
    largest_m <== switcher_m.outL; smallest_m <== switcher_m.outR;

    // is_subnormal == 1 selects the shortcut (return the larger operand):
    // either the exponent gap exceeds p+1 or the larger exponent is zero.
    signal is_subnormal;
    signal diff_e <== largest_e - smallest_e;
    component is_zero = IsZero();
    is_zero.in <== largest_e;
    component less_than_diff = LessThan(k);
    less_than_diff.in[0] <== p+1;
    less_than_diff.in[1] <== diff_e;
    component or = OR();
    or.a <== less_than_diff.out;
    or.b <== is_zero.out;
    is_subnormal <== or.out;

    // On the shortcut path, feed harmless dummy values (m = 1, e = 1, shift 0)
    // into the add pipeline so its constraints stay satisfiable; its result is
    // discarded below.
    component if_else_largest_m = IfThenElse();
    if_else_largest_m.cond <== is_subnormal;
    if_else_largest_m.L <== 1;
    if_else_largest_m.R <== largest_m;
    component if_else_smallest_e = IfThenElse();
    if_else_smallest_e.cond <== is_subnormal;
    if_else_smallest_e.L <== 1;
    if_else_smallest_e.R <== smallest_e;
    component if_else_diff_e = IfThenElse();
    if_else_diff_e.cond <== is_subnormal;
    if_else_diff_e.L <== 0;
    if_else_diff_e.R <== diff_e;

    // Align: shift the larger mantissa up by the exponent gap (at most p+1).
    component m_largest_left_shift = LeftShift(p+2);
    m_largest_left_shift.x <== if_else_largest_m.out;
    m_largest_left_shift.shift <== if_else_diff_e.out;
    m_largest_left_shift.skip_checks <== 0;

    // Sum (at most 2p+2 bits), then normalize with P = 2p+1 and round back to p+1 bits.
    component normalize = Normalize(k, p, 2*p+1);
    normalize.e <== if_else_smallest_e.out;
    normalize.m <== m_largest_left_shift.y + smallest_m;
    normalize.skip_checks <== 0;
    component round = RoundAndCheck(k, p, 2*p+1);
    round.e <== normalize.e_out;
    round.m <== normalize.m_out;

    // Pick the shortcut result or the computed sum.
    component if_else_e = IfThenElse();
    if_else_e.cond <== is_subnormal;
    if_else_e.L <== largest_e;
    if_else_e.R <== round.e_out;
    component if_else_m = IfThenElse();
    if_else_m.cond <== is_subnormal;
    if_else_m.L <== largest_m;
    if_else_m.R <== round.m_out;
    e_out <== if_else_e.out;
    m_out <== if_else_m.out;
}