Skip to content

Commit

Permalink
test: fix item values to real numbers to check the effect of dtype pa…
Browse files Browse the repository at this point in the history
…rameter
  • Loading branch information
yoshoku committed Sep 30, 2022
1 parent d32746b commit 97d93d1
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .rubocop.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ RSpec/ExampleLength:
Max: 16

RSpec/MultipleMemoizedHelpers:
Max: 8
Max: 10

RSpec/NamedSubject:
Enabled: false
Expand Down
73 changes: 51 additions & 22 deletions spec/annoy_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
end

describe Annoy::AnnoyIndex do
let(:tol) { 1e-7 }
let(:n_features) { 4 }
let(:metric) { 'manhattan' }
let(:dtype) { 'float64' }
Expand All @@ -16,11 +17,11 @@

before do
index.seed(1)
index.add_item(0, [0, 1, 2, 3])
index.add_item(1, [1, 1, 2, 3])
index.add_item(2, [1, 1, 2, 4])
index.add_item(3, [2, 1, 2, 5])
index.add_item(4, [2, 1, 2, 6])
index.add_item(0, [0.0, 0.1, 0.2, 0.3])
index.add_item(1, [0.1, 0.1, 0.2, 0.3])
index.add_item(2, [0.1, 0.1, 0.2, 0.4])
index.add_item(3, [0.2, 0.1, 0.2, 0.5])
index.add_item(4, [0.2, 0.1, 0.2, 0.6])
index.build(n_trees)
end

Expand Down Expand Up @@ -94,14 +95,24 @@
context 'when include_distances is true' do
subject { index.get_nns_by_item(0, n_neighbors, include_distances: true) }

it 'returns id list and distances of nearest neighbors' do
expect(subject).to match([[0, 1, 2], [0, 1, 2]])
it 'returns id list and distances of nearest neighbors', :aggregate_failures do
expect(subject[0]).to match([0, 1, 2])
expect(subject[1]).to be_within(tol).of([0.0, 0.1, 0.2])
end

context 'with float32 data type' do
let(:dtype) { 'float32' }

it 'returns id list and distances of nearest neighbors', :aggregate_failures do
expect(subject[0]).to match([0, 1, 2])
expect(subject[1]).to be_within(tol).of([0.0, 0.1, 0.2])
end
end
end
end

describe '#get_nns_by_vector' do
let(:query) { [0, 1, 2, 3] }
let(:query) { [0.0, 0.1, 0.2, 0.3] }

context 'when include_distances is false' do
subject { index.get_nns_by_vector(query, n_neighbors) }
Expand All @@ -114,8 +125,18 @@
context 'when include_distances is true' do
subject { index.get_nns_by_vector(query, n_neighbors, include_distances: true) }

it 'returns id list and distances of nearest neighbors' do
expect(subject).to match([[0, 1, 2], [0, 1, 2]])
it 'returns id list and distances of nearest neighbors', :aggregate_failures do
expect(subject[0]).to match([0, 1, 2])
expect(subject[1]).to be_within(tol).of([0.0, 0.1, 0.2])
end

context 'with float32 data type' do
let(:dtype) { 'float32' }

it 'returns id list and distances of nearest neighbors', :aggregate_failures do
expect(subject[0]).to match([0, 1, 2])
expect(subject[1]).to be_within(tol).of([0.0, 0.1, 0.2])
end
end
end
end
Expand All @@ -124,7 +145,15 @@
subject { index.get_item(0) }

it 'returns the item vector specified by ID' do
expect(subject).to match([0, 1, 2, 3])
expect(subject).to be_within(tol).of([0.0, 0.1, 0.2, 0.3])
end

context 'with float32 data type' do
let(:dtype) { 'float32' }

it 'returns the item vector specified by ID' do
expect(subject).to be_within(tol).of([0.0, 0.1, 0.2, 0.3])
end
end
end

Expand Down Expand Up @@ -187,11 +216,11 @@
# from annoy import AnnoyIndex
#
# t = AnnoyIndex(4, 'angular')
# t.add_item(0, [1, 2, 3, 4])
# t.add_item(1, [5, 6, 7, 8])
# t.add_item(2, [9, 0, 1, 2])
# t.add_item(3, [3, 4, 5, 6])
# t.add_item(4, [7, 8, 9, 0])
# t.add_item(0, [0.1, 0.2, 0.3, 0.4])
# t.add_item(1, [0.5, 0.6, 0.7, 0.8])
# t.add_item(2, [0.9, 0.0, 0.1, 0.2])
# t.add_item(3, [0.3, 0.4, 0.5, 0.6])
# t.add_item(4, [0.7, 0.8, 0.9, 0.0])
# t.build(5)
# t.save('pytest.ann')
let(:filename) { File.expand_path("#{__dir__}/pytest.ann") }
Expand All @@ -200,11 +229,11 @@
loaded = described_class.new(n_features: 4, metric: 'angular', dtype: 'float32')
loaded.load(filename)
expect(loaded.n_items).to eq(5)
expect(loaded.get_item(0)).to eq([1, 2, 3, 4])
expect(loaded.get_item(1)).to eq([5, 6, 7, 8])
expect(loaded.get_item(2)).to eq([9, 0, 1, 2])
expect(loaded.get_item(3)).to eq([3, 4, 5, 6])
expect(loaded.get_item(4)).to eq([7, 8, 9, 0])
expect(loaded.get_item(0)).to be_within(tol).of([0.1, 0.2, 0.3, 0.4])
expect(loaded.get_item(1)).to be_within(tol).of([0.5, 0.6, 0.7, 0.8])
expect(loaded.get_item(2)).to be_within(tol).of([0.9, 0.0, 0.1, 0.2])
expect(loaded.get_item(3)).to be_within(tol).of([0.3, 0.4, 0.5, 0.6])
expect(loaded.get_item(4)).to be_within(tol).of([0.7, 0.8, 0.9, 0.0])
end
end
end
Expand Down Expand Up @@ -246,7 +275,7 @@
let(:metric) { 'euclidean' }

it 'returns euclidean distance between items' do
expect(subject).to be_within(1e-8).of(1.41421356)
expect(subject).to be_within(tol).of(1.41421356)
end
end

Expand Down
Binary file modified spec/pytest.ann
Binary file not shown.
19 changes: 19 additions & 0 deletions spec/spec_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,22 @@
c.syntax = :expect
end
end

module RSpec
module Matchers
module BuiltIn
class BeWithin < BaseMatcher
def matches?(actual)
@actual = actual
raise needs_expected unless defined? @expected

if @expected.is_a?(Array) && @actual.is_a?(Array)
@actual.zip(@expected).all? { |ac, ex| (ac - ex).abs <= @tolerance }
else
numeric? && (@actual - @expected).abs <= @tolerance
end
end
end
end
end
end
Binary file modified spec/test.ann
Binary file not shown.

0 comments on commit 97d93d1

Please sign in to comment.