Skip to content

Commit

Permalink
improved recommendation endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
hkulekci committed Jun 12, 2024
1 parent a6cf8ed commit bcfda90
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 53 deletions.
21 changes: 6 additions & 15 deletions src/Endpoints/Collections/Points.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

use Qdrant\Endpoints\AbstractEndpoint;
use Qdrant\Endpoints\Collections\Points\Payload;
use Qdrant\Endpoints\Collections\Points\Recommend;
use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\PointsStruct;
use Qdrant\Models\Request\PointsBatch;
use Qdrant\Models\Request\RecommendRequest;
use Qdrant\Models\Request\ScrollRequest;
use Qdrant\Models\Request\SearchRequest;
use Qdrant\Response;
Expand All @@ -28,6 +28,11 @@ public function payload(): Payload
return (new Payload($this->client))->setCollectionName($this->collectionName);
}

public function recommend(): Recommend
{
return (new Recommend($this->client))->setCollectionName($this->collectionName);
}

/**
* @throws InvalidArgumentException
*/
Expand Down Expand Up @@ -180,18 +185,4 @@ public function batch(PointsBatch $points, array $queryParams = []): Response
)
);
}

/**
* @throws InvalidArgumentException
*/
public function recommend(RecommendRequest $recommendParams, array $queryParams = []): Response
{
return $this->client->execute(
$this->createRequest(
'POST',
'collections/' . $this->collectionName . '/points/recommend' . $this->queryBuild($queryParams),
$recommendParams->toArray()
)
);
}
}
44 changes: 17 additions & 27 deletions src/Endpoints/Collections/Points/Recommend.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,68 +10,58 @@

use Qdrant\Endpoints\AbstractEndpoint;
use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\Request\Points\BatchRecommendRequest;
use Qdrant\Models\Request\Points\GroupRecommendRequest;
use Qdrant\Models\Request\Points\RecommendRequest;
use Qdrant\Response;

class Recommend extends AbstractEndpoint
{
/**
* Retrieves points that are closer to stored positive examples and further from negative examples.
*
* @throws InvalidArgumentException
*/
public function clear(array $points): Response
public function recommend(RecommendRequest $request, array $queryParams = []): Response
{
return $this->client->execute(
$this->createRequest(
'POST',
'/collections/' . $this->getCollectionName() . '/points/recommend',
[
'points' => $points,
]
'/collections/' . $this->getCollectionName() . '/points/recommend' . $this->queryBuild($queryParams),
$request->toArray()
)
);
}

/**
* Delete specified key payload for points
* Retrieves points in batches that are closer to stored positive examples and further from negative examples.
*
* @param array $points
* @param array $keys
* @param Filter|null $filter
* @param BatchRecommendRequest $request
* @param array $queryParams
* @return Response
* @throws InvalidArgumentException
*/
public function delete(array $points, array $keys, Filter $filter = null, array $queryParams = []): Response
public function batch(BatchRecommendRequest $request, array $queryParams = []): Response
{
$data = [
'points' => $points,
'keys' => $keys
];
if ($filter) {
$data['filters'] = $filter->toArray();
}

return $this->client->execute(
$this->createRequest(
'POST',
'/collections/' . $this->getCollectionName() . '/points/payload/delete' . $this->queryBuild($queryParams),
$data
'/collections/' . $this->getCollectionName() . '/points/recommend/batch' . $this->queryBuild($queryParams),
$request->toArray()
)
);
}

/**
* @throws InvalidArgumentException
*/
public function set(array $points, array $payload, array $queryParams = []): Response
public function groups(GroupRecommendRequest $request, array $queryParams = []): Response

Check warning on line 58 in src/Endpoints/Collections/Points/Recommend.php

View check run for this annotation

Codecov / codecov/patch

src/Endpoints/Collections/Points/Recommend.php#L58

Added line #L58 was not covered by tests
{
return $this->client->execute(
$this->createRequest(
'POST',
'/collections/' . $this->getCollectionName() . '/points/payload' . $this->queryBuild($queryParams),
[
'payload' => $payload,
'points' => $points,
]
'/collections/' . $this->getCollectionName() . '/points/recommend/groups' . $this->queryBuild($queryParams),
$request->toArray()
)
);

Check warning on line 66 in src/Endpoints/Collections/Points/Recommend.php

View check run for this annotation

Codecov / codecov/patch

src/Endpoints/Collections/Points/Recommend.php#L60-L66

Added lines #L60 - L66 were not covered by tests
}
Expand Down
50 changes: 50 additions & 0 deletions src/Models/Request/Points/BatchRecommendRequest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<?php
/**
* RecommendRequest
*
* @since Jun 2023
* @author Greg Priday <[email protected]>
*/
namespace Qdrant\Models\Request\Points;

use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\Traits\ProtectedPropertyAccessor;

class BatchRecommendRequest
{
use ProtectedPropertyAccessor;

/** @var RecommendRequest[] $searches */
protected array $searches = [];

/**
* @param RecommendRequest[] $searches
*/
public function __construct(array $searches)
{
foreach ($searches as $search) {
$this->addSearch($search);
}
}

public function addSearch(RecommendRequest $request): static
{
$this->searches[] = $request;

return $this;
}

public function toArray(): array
{
$searches = [];

foreach ($this->searches as $search) {
$searches[] = $search->toArray();
}

return [
'searches' => $searches
];
}
}
15 changes: 15 additions & 0 deletions src/Models/Request/Points/GroupRecommendRequest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<?php
/**
* GroupRecommendRequest
*
* @since Jun 2024
* @author Greg Priday <[email protected]>
* @author Haydar KULEKCI <[email protected]>
*/
namespace Qdrant\Models\Request\Points;


class GroupRecommendRequest
{

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* @since Jun 2023
* @author Greg Priday <[email protected]>
*/
namespace Qdrant\Models\Request;
namespace Qdrant\Models\Request\Points;

use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
Expand All @@ -29,7 +29,7 @@ class RecommendRequest
const STRATEGY_BEST_SCORE = 'best_score';

protected ?string $shardKey = null;
protected string $strategy = '';
protected ?string $strategy = null;
protected ?Filter $filter = null;
protected ?string $using = null;
protected ?int $limit = null;
Expand Down Expand Up @@ -109,16 +109,19 @@ public function toArray(): array
if ($this->filter !== null && $this->filter->toArray()) {
$body['filter'] = $this->filter->toArray();
}
if($this->scoreThreshold) {
if($this->scoreThreshold !== null) {
$body['score_threshold'] = $this->scoreThreshold;
}
if ($this->using) {
if ($this->using !== null) {
$body['using'] = $this->using;
}
if ($this->limit) {
if ($this->limit !== null) {
$body['limit'] = $this->limit;
}
if ($this->offset) {
if ($this->strategy !== null) {
$body['strategy'] = $this->strategy;
}
if ($this->offset !== null) {
$body['offset'] = $this->offset;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
* @author Greg Priday <[email protected]>
*/

namespace Qdrant\Tests\Integration\Endpoints\Collections;
namespace Integration\Endpoints\Collections\Points;

use Qdrant\Endpoints\Collections;
use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Condition\MatchString;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\PointsStruct;
use Qdrant\Models\Request\RecommendRequest;
use Qdrant\Models\Request\Points\BatchRecommendRequest;
use Qdrant\Models\Request\Points\RecommendRequest;
use Qdrant\Models\VectorStruct;
use Qdrant\Tests\Integration\AbstractIntegration;

Expand Down Expand Up @@ -47,6 +48,20 @@ public static function basicPointDataProvider(): array
'image' => 'sample image'
]
],
[
'id' => 3,
'vector' => new VectorStruct([1, 3, 200], 'image'),
'payload' => [
'image' => 'sample image'
]
],
[
'id' => 4,
'vector' => new VectorStruct([1, 3, 100], 'image'),
'payload' => [
'image' => 'sample image'
]
],
]
]
];
Expand Down Expand Up @@ -74,7 +89,7 @@ public function testRecommendPoint(array $positive, array $negative): void
)
);

$response = $this->getCollections('sample-collection')->points()->recommend($recommendRequest);
$response = $this->getCollections('sample-collection')->points()->recommend()->recommend($recommendRequest);

$this->assertEquals('ok', $response['status']);
}
Expand Down Expand Up @@ -104,11 +119,13 @@ public function testRecommendWithThreshold(): void
// Perform recommend without score threshold
$responseWithoutThreshold = $this->getCollections('sample-collection')
->points()
->recommend()
->recommend($recommendRequestWithoutThreshold);

// Perform recommend with score threshold
$responseWithThreshold = $this->getCollections('sample-collection')
->points()
->recommend()
->recommend($recommendRequestWithThreshold);

// Check that we got a response in both cases
Expand All @@ -123,6 +140,35 @@ public function testRecommendWithThreshold(): void
);
}

public static function batchRecommendQueryProvider(): array
{

return [
[
[
(new RecommendRequest([1], [2]))->setLimit(3)->setUsing('image'),
(new RecommendRequest([1, 2], []))->setLimit(3)->setUsing('image'),
(new RecommendRequest([1], [2, 3]))->setLimit(3)->setUsing('image'),
]
]
];
}

/**
* @dataProvider batchRecommendQueryProvider
*/
public function testBatchRecommendPoint(array $batch): void
{
$recommendRequest = new BatchRecommendRequest($batch);

$response = $this->getCollections('sample-collection')
->points()
->recommend()
->batch($recommendRequest);

$this->assertEquals('ok', $response['status']);
}


protected function tearDown(): void
{
Expand Down
37 changes: 37 additions & 0 deletions tests/Unit/Models/Request/Points/BatchRecommendRequestTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
<?php
/**
* @since Mar 2023
* @author Haydar KULEKCI <[email protected]>
*/

namespace Qdrant\Tests\Unit\Models\Request\Points;

use PHPUnit\Framework\TestCase;
use Qdrant\Models\Filter\Condition\MatchExcept;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\Request\Points\BatchRecommendRequest;
use Qdrant\Models\Request\Points\RecommendRequest;

class BatchRecommendRequestTest extends TestCase
{
public function testBasicRecommendRequest(): void
{
$request = new BatchRecommendRequest([
new RecommendRequest([100, 101], [110]),
new RecommendRequest([101, 102], [112]),
]);

$this->assertEquals([
'searches' =>[
[
'positive' => [100, 101],
'negative' => [110],
],
[
'positive' => [101, 102],
'negative' => [112],
]
]
], $request->toArray());
}
}
Loading

0 comments on commit bcfda90

Please sign in to comment.