From bcfda9058d9828fe7338ed99ea2f20fe9442b1c8 Mon Sep 17 00:00:00 2001 From: Haydar Kulekci Date: Wed, 12 Jun 2024 14:48:42 +0300 Subject: [PATCH] improved recommendation endpoint --- src/Endpoints/Collections/Points.php | 21 +++----- .../Collections/Points/Recommend.php | 44 ++++++---------- .../Request/Points/BatchRecommendRequest.php | 50 ++++++++++++++++++ .../Request/Points/GroupRecommendRequest.php | 15 ++++++ .../Request/{ => Points}/RecommendRequest.php | 15 +++--- .../{ => Points}/RecommendTest.php | 52 +++++++++++++++++-- .../Points/BatchRecommendRequestTest.php | 37 +++++++++++++ .../{ => Points}/RecommendRequestTest.php | 20 ++++++- 8 files changed, 201 insertions(+), 53 deletions(-) create mode 100644 src/Models/Request/Points/BatchRecommendRequest.php create mode 100644 src/Models/Request/Points/GroupRecommendRequest.php rename src/Models/Request/{ => Points}/RecommendRequest.php (90%) rename tests/Integration/Endpoints/Collections/{ => Points}/RecommendTest.php (73%) create mode 100644 tests/Unit/Models/Request/Points/BatchRecommendRequestTest.php rename tests/Unit/Models/Request/{ => Points}/RecommendRequestTest.php (82%) diff --git a/src/Endpoints/Collections/Points.php b/src/Endpoints/Collections/Points.php index 2fa4488..ea25a30 100644 --- a/src/Endpoints/Collections/Points.php +++ b/src/Endpoints/Collections/Points.php @@ -12,11 +12,11 @@ use Qdrant\Endpoints\AbstractEndpoint; use Qdrant\Endpoints\Collections\Points\Payload; +use Qdrant\Endpoints\Collections\Points\Recommend; use Qdrant\Exception\InvalidArgumentException; use Qdrant\Models\Filter\Filter; use Qdrant\Models\PointsStruct; use Qdrant\Models\Request\PointsBatch; -use Qdrant\Models\Request\RecommendRequest; use Qdrant\Models\Request\ScrollRequest; use Qdrant\Models\Request\SearchRequest; use Qdrant\Response; @@ -28,6 +28,11 @@ public function payload(): Payload return (new Payload($this->client))->setCollectionName($this->collectionName); } + public function recommend(): Recommend + { + return (new Recommend($this->client))->setCollectionName($this->collectionName); + } + /** * @throws InvalidArgumentException */ @@ -180,18 +185,4 @@ public function batch(PointsBatch $points, array $queryParams = []): Response ) ); } - - /** - * @throws InvalidArgumentException - */ - public function recommend(RecommendRequest $recommendParams, array $queryParams = []): Response - { - return $this->client->execute( - $this->createRequest( - 'POST', - 'collections/' . $this->collectionName . '/points/recommend' . $this->queryBuild($queryParams), - $recommendParams->toArray() - ) - ); - } } diff --git a/src/Endpoints/Collections/Points/Recommend.php b/src/Endpoints/Collections/Points/Recommend.php index d7249b7..b196dd9 100644 --- a/src/Endpoints/Collections/Points/Recommend.php +++ b/src/Endpoints/Collections/Points/Recommend.php @@ -10,51 +10,44 @@ use Qdrant\Endpoints\AbstractEndpoint; use Qdrant\Exception\InvalidArgumentException; -use Qdrant\Models\Filter\Filter; +use Qdrant\Models\Request\Points\BatchRecommendRequest; +use Qdrant\Models\Request\Points\GroupRecommendRequest; +use Qdrant\Models\Request\Points\RecommendRequest; use Qdrant\Response; class Recommend extends AbstractEndpoint { /** + * Retrieves points that are closer to stored positive examples and further from negative examples. + * * @throws InvalidArgumentException */ - public function clear(array $points): Response + public function recommend(RecommendRequest $request, array $queryParams = []): Response { return $this->client->execute( $this->createRequest( 'POST', - '/collections/' . $this->getCollectionName() . '/points/recommend', - [ - 'points' => $points, - ] + '/collections/' . $this->getCollectionName() . '/points/recommend' . $this->queryBuild($queryParams), + $request->toArray() ) ); } /** - * Delete specified key payload for points + * Retrieves points in batches that are closer to stored positive examples and further from negative examples. * - * @param array $points - * @param array $keys - * @param Filter|null $filter + * @param BatchRecommendRequest $request + * @param array $queryParams * @return Response - * @throws InvalidArgumentException */ - public function delete(array $points, array $keys, Filter $filter = null, array $queryParams = []): Response + public function batch(BatchRecommendRequest $request, array $queryParams = []): Response { - $data = [ - 'points' => $points, - 'keys' => $keys - ]; - if ($filter) { - $data['filters'] = $filter->toArray(); - } return $this->client->execute( $this->createRequest( 'POST', - '/collections/' . $this->getCollectionName() . '/points/payload/delete' . $this->queryBuild($queryParams), - $data + '/collections/' . $this->getCollectionName() . '/points/recommend/batch' . $this->queryBuild($queryParams), + $request->toArray() ) ); } @@ -62,16 +55,13 @@ public function delete(array $points, array $keys, Filter $filter = null, array /** * @throws InvalidArgumentException */ - public function set(array $points, array $payload, array $queryParams = []): Response + public function groups(GroupRecommendRequest $request, array $queryParams = []): Response { return $this->client->execute( $this->createRequest( 'POST', - '/collections/' . $this->getCollectionName() . '/points/payload' . $this->queryBuild($queryParams), - [ - 'payload' => $payload, - 'points' => $points, - ] + '/collections/' . $this->getCollectionName() . '/points/recommend/groups' . $this->queryBuild($queryParams), + $request->toArray() ) ); } diff --git a/src/Models/Request/Points/BatchRecommendRequest.php b/src/Models/Request/Points/BatchRecommendRequest.php new file mode 100644 index 0000000..aeaf838 --- /dev/null +++ b/src/Models/Request/Points/BatchRecommendRequest.php @@ -0,0 +1,50 @@ + + */ +namespace Qdrant\Models\Request\Points; + +use Qdrant\Exception\InvalidArgumentException; +use Qdrant\Models\Filter\Filter; +use Qdrant\Models\Traits\ProtectedPropertyAccessor; + +class BatchRecommendRequest +{ + use ProtectedPropertyAccessor; + + /** @var RecommendRequest[] $searches */ + protected array $searches = []; + + /** + * @param RecommendRequest[] $searches + */ + public function __construct(array $searches) + { + foreach ($searches as $search) { + $this->addSearch($search); + } + } + + public function addSearch(RecommendRequest $request): static + { + $this->searches[] = $request; + + return $this; + } + + public function toArray(): array + { + $searches = []; + + foreach ($this->searches as $search) { + $searches[] = $search->toArray(); + } + + return [ + 'searches' => $searches + ]; + } +} \ No newline at end of file diff --git a/src/Models/Request/Points/GroupRecommendRequest.php b/src/Models/Request/Points/GroupRecommendRequest.php new file mode 100644 index 0000000..97825b0 --- /dev/null +++ b/src/Models/Request/Points/GroupRecommendRequest.php @@ -0,0 +1,15 @@ + + * @author Haydar KULEKCI + */ +namespace Qdrant\Models\Request\Points; + + +class GroupRecommendRequest +{ + +} \ No newline at end of file diff --git a/src/Models/Request/RecommendRequest.php b/src/Models/Request/Points/RecommendRequest.php similarity index 90% rename from src/Models/Request/RecommendRequest.php rename to src/Models/Request/Points/RecommendRequest.php index 488e2e3..31e40bc 100644 --- a/src/Models/Request/RecommendRequest.php +++ b/src/Models/Request/Points/RecommendRequest.php @@ -5,7 +5,7 @@ * @since Jun 2023 * @author Greg Priday */ -namespace Qdrant\Models\Request; +namespace Qdrant\Models\Request\Points; use Qdrant\Exception\InvalidArgumentException; use Qdrant\Models\Filter\Filter; @@ -29,7 +29,7 @@ class RecommendRequest const STRATEGY_BEST_SCORE = 'best_score'; protected ?string $shardKey = null; - protected string $strategy = ''; + protected ?string $strategy = null; protected ?Filter $filter = null; protected ?string $using = null; protected ?int $limit = null; @@ -109,16 +109,19 @@ public function toArray(): array if ($this->filter !== null && $this->filter->toArray()) { $body['filter'] = $this->filter->toArray(); } - if($this->scoreThreshold) { + if($this->scoreThreshold !== null) { $body['score_threshold'] = $this->scoreThreshold; } - if ($this->using) { + if ($this->using !== null) { $body['using'] = $this->using; } - if ($this->limit) { + if ($this->limit !== null) { $body['limit'] = $this->limit; } - if ($this->offset) { + if ($this->strategy !== null) { + $body['strategy'] = $this->strategy; + } + if ($this->offset !== null) { $body['offset'] = $this->offset; } diff --git a/tests/Integration/Endpoints/Collections/RecommendTest.php b/tests/Integration/Endpoints/Collections/Points/RecommendTest.php similarity index 73% rename from tests/Integration/Endpoints/Collections/RecommendTest.php rename to tests/Integration/Endpoints/Collections/Points/RecommendTest.php index 0a2a53f..f6440a0 100644 --- a/tests/Integration/Endpoints/Collections/RecommendTest.php +++ b/tests/Integration/Endpoints/Collections/Points/RecommendTest.php @@ -4,14 +4,15 @@ * @author Greg Priday */ -namespace Qdrant\Tests\Integration\Endpoints\Collections; +namespace Integration\Endpoints\Collections\Points; use Qdrant\Endpoints\Collections; use Qdrant\Exception\InvalidArgumentException; use Qdrant\Models\Filter\Condition\MatchString; use Qdrant\Models\Filter\Filter; use Qdrant\Models\PointsStruct; -use Qdrant\Models\Request\RecommendRequest; +use Qdrant\Models\Request\Points\BatchRecommendRequest; +use Qdrant\Models\Request\Points\RecommendRequest; use Qdrant\Models\VectorStruct; use Qdrant\Tests\Integration\AbstractIntegration; @@ -47,6 +48,20 @@ public static function basicPointDataProvider(): array 'image' => 'sample image' ] ], + [ + 'id' => 3, + 'vector' => new VectorStruct([1, 3, 200], 'image'), + 'payload' => [ + 'image' => 'sample image' + ] + ], + [ + 'id' => 4, + 'vector' => new VectorStruct([1, 3, 100], 'image'), + 'payload' => [ + 'image' => 'sample image' + ] + ], ] ] ]; @@ -74,7 +89,7 @@ public function testRecommendPoint(array $positive, array $negative): void ) ); - $response = $this->getCollections('sample-collection')->points()->recommend($recommendRequest); + $response = $this->getCollections('sample-collection')->points()->recommend()->recommend($recommendRequest); $this->assertEquals('ok', $response['status']); } @@ -104,11 +119,13 @@ public function testRecommendWithThreshold(): void // Perform recommend without score threshold $responseWithoutThreshold = $this->getCollections('sample-collection') ->points() + ->recommend() ->recommend($recommendRequestWithoutThreshold); // Perform recommend with score threshold $responseWithThreshold = $this->getCollections('sample-collection') ->points() + ->recommend() ->recommend($recommendRequestWithThreshold); // Check that we got a response in both cases @@ -123,6 +140,35 @@ public function testRecommendWithThreshold(): void ); } + public static function batchRecommendQueryProvider(): array + { + + return [ + [ + [ + (new RecommendRequest([1], [2]))->setLimit(3)->setUsing('image'), + (new RecommendRequest([1, 2], []))->setLimit(3)->setUsing('image'), + (new RecommendRequest([1], [2, 3]))->setLimit(3)->setUsing('image'), + ] + ] + ]; + } + + /** + * @dataProvider batchRecommendQueryProvider + */ + public function testBatchRecommendPoint(array $batch): void + { + $recommendRequest = new BatchRecommendRequest($batch); + + $response = $this->getCollections('sample-collection') + ->points() + ->recommend() + ->batch($recommendRequest); + + $this->assertEquals('ok', $response['status']); + } + protected function tearDown(): void { diff --git a/tests/Unit/Models/Request/Points/BatchRecommendRequestTest.php b/tests/Unit/Models/Request/Points/BatchRecommendRequestTest.php new file mode 100644 index 0000000..a34e5cb --- /dev/null +++ b/tests/Unit/Models/Request/Points/BatchRecommendRequestTest.php @@ -0,0 +1,37 @@ + + */ + +namespace Qdrant\Tests\Unit\Models\Request\Points; + +use PHPUnit\Framework\TestCase; +use Qdrant\Models\Filter\Condition\MatchExcept; +use Qdrant\Models\Filter\Filter; +use Qdrant\Models\Request\Points\BatchRecommendRequest; +use Qdrant\Models\Request\Points\RecommendRequest; + +class BatchRecommendRequestTest extends TestCase +{ + public function testBasicRecommendRequest(): void + { + $request = new BatchRecommendRequest([ + new RecommendRequest([100, 101], [110]), + new RecommendRequest([101, 102], [112]), + ]); + + $this->assertEquals([ + 'searches' =>[ + [ + 'positive' => [100, 101], + 'negative' => [110], + ], + [ + 'positive' => [101, 102], + 'negative' => [112], + ] + ] + ], $request->toArray()); + } +} \ No newline at end of file diff --git a/tests/Unit/Models/Request/RecommendRequestTest.php b/tests/Unit/Models/Request/Points/RecommendRequestTest.php similarity index 82% rename from tests/Unit/Models/Request/RecommendRequestTest.php rename to tests/Unit/Models/Request/Points/RecommendRequestTest.php index d811269..302303b 100644 --- a/tests/Unit/Models/Request/RecommendRequestTest.php +++ b/tests/Unit/Models/Request/Points/RecommendRequestTest.php @@ -4,12 +4,12 @@ * @author Haydar KULEKCI */ -namespace Qdrant\Tests\Unit\Models\Request; +namespace Qdrant\Tests\Unit\Models\Request\Points; use PHPUnit\Framework\TestCase; use Qdrant\Models\Filter\Condition\MatchExcept; use Qdrant\Models\Filter\Filter; -use Qdrant\Models\Request\RecommendRequest; +use Qdrant\Models\Request\Points\RecommendRequest; class RecommendRequestTest extends TestCase { @@ -108,4 +108,20 @@ public function testRecommendRequestWithLimit(): void 'offset'=> 1, ], $request->toArray()); } + + public function testRecommendRequestWithLimitAndStrategy(): void + { + $request = (new RecommendRequest([100, 101], [110])) + ->setLimit(10) + ->setOffset(1) + ->setStrategy(RecommendRequest::STRATEGY_AVERAGE_VECTOR); + + $this->assertEquals([ + 'positive' => [100, 101], + 'negative' => [110], + 'limit'=> 10, + 'offset'=> 1, + 'strategy' => RecommendRequest::STRATEGY_AVERAGE_VECTOR + ], $request->toArray()); + } } \ No newline at end of file