-
Notifications
You must be signed in to change notification settings - Fork 6
/
pir_test.go
62 lines (50 loc) · 1.61 KB
/
pir_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package main
// Test suite for classical PIR, used as baseline for the experiments.
import (
"encoding/binary"
"fmt"
"io"
"math"
"testing"
"github.com/si-co/vpir-code/lib/client"
"github.com/si-co/vpir-code/lib/database"
"github.com/si-co/vpir-code/lib/field"
"github.com/si-co/vpir-code/lib/monitor"
"github.com/si-co/vpir-code/lib/server"
"github.com/si-co/vpir-code/lib/utils"
"github.com/stretchr/testify/require"
)
// TestPIRPoint builds a random byte database of oneMB and checks that the
// classical two-server PIR scheme retrieves each of its blocks correctly,
// delegating the per-block retrieval and timing to retrievePIRPoint.
//
// NOTE(review): oneMB and testBlockLength are not declared here; the original
// comment pointed to vpir_test.go in this package — confirm.
func TestPIRPoint(t *testing.T) {
	totalLen := oneMB
	bytesPerBlock := testBlockLength * field.Bytes
	bitsPerElem := 8

	// The length is divided by bitsPerElem*bytesPerBlock, which suggests oneMB
	// is expressed in bits — TODO confirm against its declaration.
	blockCount := totalLen / (bitsPerElem * bytesPerBlock)

	// Row count handed to the database constructor: floor(sqrt(blockCount)).
	rows := int(math.Sqrt(float64(blockCount)))

	// Separate randomness sources: one fills the database, the other is
	// given to the PIR client.
	dbPRG := utils.RandomPRG()
	clientPRG := utils.RandomPRG()

	db := database.CreateRandomBytes(dbPRG, totalLen, rows, bytesPerBlock)
	retrievePIRPoint(t, clientPRG, db, blockCount, "PIRPoint")
}
// retrievePIRPoint queries every block index in [0, numBlocks) from db with
// the classical PIR client and two servers built over the same db. It checks
// each reconstructed block against the matching BlockSize-long slice of
// db.Entries, then prints the value of the monitor started before the loop,
// labelled with testName (formatted as milliseconds per the format string).
//
// rnd is the randomness source handed to client.NewPIR. Any query, answer or
// reconstruction error, or any content mismatch, fails the test immediately
// and reports the offending block index.
func retrievePIRPoint(t *testing.T, rnd io.Reader, db *database.Bytes, numBlocks int, testName string) {
	const numServers = 2

	c := client.NewPIR(rnd, &db.Info)
	s0 := server.NewPIRTwo(db)
	s1 := server.NewPIRTwo(db)

	// The encoded index always occupies 4 bytes, so the buffer is allocated
	// once and fully overwritten at the start of every iteration.
	in := make([]byte, 4)

	totalTimer := monitor.NewMonitor()
	for i := 0; i < numBlocks; i++ {
		binary.BigEndian.PutUint32(in, uint32(i))

		queries, err := c.QueryBytes(in, numServers)
		require.NoError(t, err, "query for block %d", i)

		a0, err := s0.AnswerBytes(queries[0])
		require.NoError(t, err, "server 0 answer for block %d", i)
		a1, err := s1.AnswerBytes(queries[1])
		require.NoError(t, err, "server 1 answer for block %d", i)

		res, err := c.ReconstructBytes([][]byte{a0, a1})
		require.NoError(t, err, "reconstruction of block %d", i)

		want := db.Entries[i*db.BlockSize : (i+1)*db.BlockSize]
		require.Equal(t, want, res, "content mismatch for block %d", i)
	}
	fmt.Printf("Total CPU time %s: %.2fms\n", testName, totalTimer.Record())
}