Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement for Recommendation Endpoint #51

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions src/Endpoints/Collections/Points.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

use Qdrant\Endpoints\AbstractEndpoint;
use Qdrant\Endpoints\Collections\Points\Payload;
use Qdrant\Endpoints\Collections\Points\Recommend;
use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\PointsStruct;
use Qdrant\Models\Request\PointsBatch;
use Qdrant\Models\Request\RecommendRequest;
use Qdrant\Models\Request\ScrollRequest;
use Qdrant\Models\Request\SearchRequest;
use Qdrant\Response;
Expand All @@ -28,6 +28,11 @@ public function payload(): Payload
return (new Payload($this->client))->setCollectionName($this->collectionName);
}

public function recommend(): Recommend
{
return (new Recommend($this->client))->setCollectionName($this->collectionName);
}

/**
* @throws InvalidArgumentException
*/
Expand Down Expand Up @@ -180,18 +185,4 @@ public function batch(PointsBatch $points, array $queryParams = []): Response
)
);
}

/**
* @throws InvalidArgumentException
*/
public function recommend(RecommendRequest $recommendParams): Response
{
return $this->client->execute(
$this->createRequest(
'POST',
'collections/' . $this->collectionName . '/points/recommend',
$recommendParams->toArray()
)
);
}
}
62 changes: 62 additions & 0 deletions src/Endpoints/Collections/Points/Recommend.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
<?php
/**
* Payload
*
* @since Mar 2023
* @author Haydar KULEKCI <[email protected]>
*/

namespace Qdrant\Endpoints\Collections\Points;

use Qdrant\Endpoints\AbstractEndpoint;
use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Request\Points\BatchRecommendRequest;
use Qdrant\Models\Request\Points\GroupRecommendRequest;
use Qdrant\Models\Request\Points\RecommendRequest;
use Qdrant\Response;

class Recommend extends AbstractEndpoint
{
/**
* Retrieves points that are closer to stored positive examples and further from negative examples.
*
* @throws InvalidArgumentException
*/
public function recommend(RecommendRequest $request, array $queryParams = []): Response
{
return $this->client->execute(
$this->createRequest(
'POST',
'/collections/' . $this->getCollectionName() . '/points/recommend' . $this->queryBuild($queryParams),
$request->toArray()
)
);
}

/**
* Retrieves points in batches that are closer to stored positive examples and further from negative examples.
*
* @param BatchRecommendRequest $request
* @param array $queryParams
* @return Response
*/
public function batch(BatchRecommendRequest $request, array $queryParams = []): Response
{

return $this->client->execute(
$this->createRequest(
'POST',
'/collections/' . $this->getCollectionName() . '/points/recommend/batch' . $this->queryBuild($queryParams),
$request->toArray()
)
);
}

/**
* @throws InvalidArgumentException
*/
public function groups($request, array $queryParams = []): Response

Check warning on line 58 in src/Endpoints/Collections/Points/Recommend.php

View check run for this annotation

Codecov / codecov/patch

src/Endpoints/Collections/Points/Recommend.php#L58

Added line #L58 was not covered by tests
{
throw new \RuntimeException('Not implemented on client!');

Check warning on line 60 in src/Endpoints/Collections/Points/Recommend.php

View check run for this annotation

Codecov / codecov/patch

src/Endpoints/Collections/Points/Recommend.php#L60

Added line #L60 was not covered by tests
}
}
42 changes: 42 additions & 0 deletions src/Models/Filter/Condition/GeoPolygon.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<?php
/**
* @since May 2023
* @author Haydar KULEKCI <[email protected]>
*/

namespace Qdrant\Models\Filter\Condition;

use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Domain\Assert;

class GeoPolygon extends AbstractCondition implements ConditionInterface
{
public function __construct(string $key, protected array $exterior, protected ?array $interiors = null)
{
parent::__construct($key);

if (empty($this->exterior)) {
throw new InvalidArgumentException('Exteriors required!');
}

foreach ($this->exterior as $point) {
Assert::keysExists($point, ['lat', 'lon'], 'Each point of polygon needs lat and lon parameters');
}
if ($interiors) {
foreach ($this->interiors as $point) {
Assert::keysExists($point, ['lat', 'lon'], 'Each point of polygon needs lat and lon parameters');
}
}
}

public function toArray(): array
{
return [
'key' => $this->key,
'geo_polygon' => [
'exterior' => $this->exterior,
'interiors' => $this->interiors ?? []
]
];
}
}
26 changes: 26 additions & 0 deletions src/Models/Filter/Filter.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class Filter implements ConditionInterface
protected array $must = [];
protected array $must_not = [];
protected array $should = [];
protected array $minShould = [];
protected ?int $minShouldCount;

public function addMust(ConditionInterface $condition): Filter
{
Expand All @@ -35,6 +37,20 @@ public function addShould(ConditionInterface $condition): Filter
return $this;
}

public function addMinShould(ConditionInterface $condition): Filter
{
$this->minShould[] = $condition;

return $this;
}

public function setMinShouldCount(int $count): Filter
{
$this->minShouldCount = $count;

return $this;
}

public function toArray(): array
{
$filter = [];
Expand All @@ -59,6 +75,16 @@ public function toArray(): array
$filter['should'][] = $should->toArray();
}
}
if ($this->minShould && $this->minShouldCount) {
$filter['min_should'] = [
'conditions' => [],
'min_count' => $this->minShouldCount
];
foreach ($this->minShould as $should) {
/** ConditionInterface $must */
$filter['min_should']['conditions'][] = $should->toArray();
}
}

return $filter;
}
Expand Down
50 changes: 50 additions & 0 deletions src/Models/Request/Points/BatchRecommendRequest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<?php
/**
* RecommendRequest
*
* @since Jun 2023
* @author Greg Priday <[email protected]>
*/
namespace Qdrant\Models\Request\Points;

use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\Traits\ProtectedPropertyAccessor;

class BatchRecommendRequest
{
use ProtectedPropertyAccessor;

/** @var RecommendRequest[] $searches */
protected array $searches = [];

/**
* @param RecommendRequest[] $searches
*/
public function __construct(array $searches)
{
foreach ($searches as $search) {
$this->addSearch($search);
}
}

public function addSearch(RecommendRequest $request): static
{
$this->searches[] = $request;

return $this;
}

public function toArray(): array
{
$searches = [];

foreach ($this->searches as $search) {
$searches[] = $search->toArray();
}

return [
'searches' => $searches
];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,31 @@
* @since Jun 2023
* @author Greg Priday <[email protected]>
*/
namespace Qdrant\Models\Request;
namespace Qdrant\Models\Request\Points;

use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\Traits\ProtectedPropertyAccessor;

class RecommendRequest
{
use ProtectedPropertyAccessor;

/**
* average_vector - Average positive and negative vectors and create a single query with the formula
* query = avg_pos + avg_pos - avg_neg. Then performs normal search.
*/
const STRATEGY_AVERAGE_VECTOR = 'average_vector';

/**
* best_score - Uses custom search objective. Each candidate is compared against all examples, its
* score is then chosen from the max(max_pos_score, max_neg_score). If the max_neg_score is chosen
* then it is squared and negated, otherwise it is just the max_pos_score.
*/
const STRATEGY_BEST_SCORE = 'best_score';

protected ?string $shardKey = null;
protected ?string $strategy = null;
protected ?Filter $filter = null;
protected ?string $using = null;
protected ?int $limit = null;
Expand All @@ -31,6 +47,27 @@ public function setFilter(Filter $filter): static
return $this;
}

public function setShardKey(string $shardKey): static
{
$this->shardKey = $shardKey;

return $this;
}

public function setStrategy(string $strategy): static
{
$strategies = [
self::STRATEGY_AVERAGE_VECTOR,
self::STRATEGY_BEST_SCORE,
];
if (!in_array($strategy, $strategies)) {
throw new InvalidArgumentException('Invalid strategy for recommendation.');
}
$this->strategy = $strategy;

return $this;
}

public function setScoreThreshold(float $scoreThreshold): static
{
$this->scoreThreshold = $scoreThreshold;
Expand Down Expand Up @@ -66,19 +103,25 @@ public function toArray(): array
'negative' => $this->negative,
];

if ($this->shardKey !== null) {
$body['shard_key'] = $this->shardKey;
}
if ($this->filter !== null && $this->filter->toArray()) {
$body['filter'] = $this->filter->toArray();
}
if($this->scoreThreshold) {
if($this->scoreThreshold !== null) {
$body['score_threshold'] = $this->scoreThreshold;
}
if ($this->using) {
if ($this->using !== null) {
$body['using'] = $this->using;
}
if ($this->limit) {
if ($this->limit !== null) {
$body['limit'] = $this->limit;
}
if ($this->offset) {
if ($this->strategy !== null) {
$body['strategy'] = $this->strategy;
}
if ($this->offset !== null) {
$body['offset'] = $this->offset;
}

Expand Down
Loading
Loading