Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

couple of useful extensions (for me at least) #16

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/Chains/VectorDbQa/VectorDBQA.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ class VectorDBQA extends Chain
public string $outputKey = 'result';
public bool $returnSourceDocuments = false;
public array $searchKwargs;
public string $searchType = 'similarity';
public string $searchType = VectorStore::SIMILARITY_SEARCH;

/**
 * Build the chain from a parameter array.
 *
 * Keys read from $params in this constructor:
 *  - 'vectorstore' and 'combine_documents_chain' are read without a fallback;
 *  - 'k', 'search_kwargs' and 'search_type' are optional.
 *
 * @param mixed $params Untyped; presumably an array, since it is only accessed by string key here.
 */
public function __construct($params)
{
// The complete parameter set is also handed to the parent constructor as its fourth argument.
parent::__construct(null, null, null, $params);
$this->vectorstore = $params['vectorstore'];
// When 'k' is absent or null the property keeps its current value.
$this->k = $params['k'] ?? $this->k;
$this->combineDocumentsChain = $params['combine_documents_chain'];
// Extra search arguments default to an empty array.
$this->searchKwargs = $params['search_kwargs'] ?? [];
// When 'search_type' is absent or null the property keeps its declared default.
$this->searchType = $params['search_type'] ?? $this->searchType;
// Rejects a search type that is not in the allowed list by throwing.
$this->validateSearchType();
}

Expand Down Expand Up @@ -74,7 +76,7 @@ public function outputKeys(): array
*/
private function validateSearchType(): void
{
if (!in_array($this->searchType, ['similarity', 'mmr'])) {
if (!in_array($this->searchType, [VectorStore::SIMILARITY_SEARCH, VectorStore::MAX_MARGINAL_RELEVANCE_SEARCH])) {
throw new Exception('search_type of ' . $this->searchType . ' not allowed.');
}
}
Expand Down Expand Up @@ -122,13 +124,17 @@ public static function fromLLM(
public static function fromChainType(
BaseLLM $llm,
string $chainType = 'stuff',
?BasePromptTemplate $promptTemplate = null,
?string $documentVariableName = 'context',
?array $chainType_kwargs = null,
array $kwargs = []
): VectorDBQA {
$chainType_kwargs = $chainType_kwargs ?? [];
$combineDocuments_chain = self::loadQAChain(
$llm,
$chainType,
$promptTemplate,
$documentVariableName,
null,
null,
$chainType_kwargs
Expand All @@ -152,15 +158,17 @@ public static function fromChainType(
public static function loadQAChain(
BaseLanguageModel $llm,
string $chainType = 'stuff',
?BasePromptTemplate $promptTemplate = null,
?string $documentVariableName = 'context',
?bool $verbose = null,
?BaseCallbackManager $callbackManager = null,
?array $kwargs = []
): BaseCombineDocumentsChain {
return match ($chainType) {
'stuff' => self::loadStuffChain(
$llm,
null,
'context',
$promptTemplate,
$documentVariableName,
$verbose,
$callbackManager,
$kwargs
Expand Down Expand Up @@ -254,9 +262,9 @@ protected function call(array $inputs): array
{
$question = $inputs[$this->inputKey];

if ($this->searchType === 'similarity') {
if ($this->searchType === VectorStore::SIMILARITY_SEARCH) {
$docs = $this->vectorstore->similaritySearch($question, $this->k, $this->searchKwargs);
} elseif ($this->searchType === 'mmr') {
} elseif ($this->searchType === VectorStore::MAX_MARGINAL_RELEVANCE_SEARCH) {
$docs = $this->vectorstore->maxMarginalRelevanceSearch($question, $this->k, $this->searchKwargs);
} else {
throw new Exception('search_type of ' . $this->searchType . ' not allowed.');
Expand Down
27 changes: 27 additions & 0 deletions src/DocumentLoaders/StringLoader.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<?php

namespace Kambo\Langchain\DocumentLoaders;

use Kambo\Langchain\Docstore\Document;

/**
 * Document loader backed by an in-memory string.
 *
 * The whole string becomes the page content of a single document.
 */
final class StringLoader extends BaseLoader
{
    /**
     * @param string $text Text that will be exposed as a document.
     */
    public function __construct(private string $text)
    {
    }

    /**
     * Wrap the stored text in one Document carrying empty metadata.
     *
     * @return Document[] Always a one-element list.
     */
    public function load(): array
    {
        $emptyMetadata = [];
        $document = new Document(pageContent: $this->text, metadata: $emptyMetadata);

        return [$document];
    }
}
13 changes: 11 additions & 2 deletions src/Indexes/VectorStoreIndexWrapper.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
namespace Kambo\Langchain\Indexes;

use Kambo\Langchain\LLMs\BaseLLM;
use Kambo\Langchain\Prompts\BasePromptTemplate;
use Kambo\Langchain\Prompts\PromptTemplate;
use Kambo\Langchain\VectorStores\VectorStore;
use Kambo\Langchain\LLMs\OpenAI;
use Kambo\Langchain\Chains\VectorDbQa\VectorDBQA;
Expand All @@ -27,12 +29,19 @@ public function __construct(public VectorStore $vectorStore)
*
* @return string
*/
public function query(string $question, ?BaseLLM $llm = null, array $additionalParams = []): string
{
public function query(
string $question,
?BaseLLM $llm = null,
?BasePromptTemplate $promptTemplate = null,
?string $documentVariableName = 'context',
array $additionalParams = []
): string {
$llm = $llm ?? new OpenAI(['temperature' => 0]);
$chain = VectorDBQA::fromChainType(
$llm,
'stuff',
$promptTemplate,
$documentVariableName,
null,
array_merge(['vectorstore' => $this->vectorStore], $additionalParams)
);
Expand Down
62 changes: 62 additions & 0 deletions src/VectorStores/CachedSimpleStupidVectorStore.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
<?php

namespace Kambo\Langchain\VectorStores;

use Kambo\Langchain\Embeddings\Embeddings;
use Kambo\Langchain\VectorStores\SimpleStupidVectorStore\SimpleStupidVectorStore as SSVStorage;
use Psr\Cache\CacheItemPoolInterface;

/**
 * SimpleStupidVectorStore variant that memoises embeddings in a PSR-6 cache pool.
 *
 * The pool is taken from $options['cacheItemPool']. When that entry is missing, or is
 * not a CacheItemPoolInterface, every call is delegated to the parent unchanged.
 *
 * NOTE(review): cache keys are derived from the input text only, so a pool shared
 * between different embedding models would return embeddings from the wrong model.
 */
class CachedSimpleStupidVectorStore extends SimpleStupidVectorStore
{
    /** Pool used to store embeddings, or null when caching is disabled. */
    private ?CacheItemPoolInterface $cacheItemPool = null;

    /**
     * @param Embeddings      $embedding                 Embedding provider used on cache misses.
     * @param SSVStorage|null $simpleStupidVectorStorage Forwarded to the parent constructor.
     * @param array           $options                   Forwarded to the parent constructor; the
     *                                                   'cacheItemPool' entry enables caching.
     */
    public function __construct(
        private Embeddings $embedding,
        ?SSVStorage $simpleStupidVectorStorage = null,
        $options = []
    ) {
        parent::__construct($embedding, $simpleStupidVectorStorage, $options);

        // Anything that is not a PSR-6 pool is ignored, matching the previous
        // behaviour where such values silently disabled caching.
        $pool = $options['cacheItemPool'] ?? null;
        $this->cacheItemPool = $pool instanceof CacheItemPoolInterface ? $pool : null;
    }

    /**
     * Add texts to the store, reusing cached document embeddings when available.
     *
     * @param iterable   $texts               Texts to add.
     * @param array|null $metadata            Forwarded to the parent.
     * @param array      $additionalArguments Forwarded to the parent; its 'embeddings' entry is
     *                                        overwritten by the cached/computed embeddings when
     *                                        caching is enabled.
     *
     * @return array Whatever the parent addTexts() returns.
     */
    public function addTexts(iterable $texts, ?array $metadata = null, array $additionalArguments = []): array
    {
        if ($this->cacheItemPool === null) {
            return parent::addTexts($texts, $metadata, $additionalArguments);
        }

        // Materialise non-array iterables once: generators cannot be serialized for the
        // cache key and could not be iterated a second time by the parent.
        $texts = is_array($texts) ? $texts : iterator_to_array($texts, false);

        $embeddings = $this->rememberEmbeddings(
            md5(serialize($texts)),
            fn () => $this->embedding->embedDocuments($texts)
        );

        return parent::addTexts(
            $texts,
            $metadata,
            array_merge($additionalArguments, ['embeddings' => $embeddings])
        );
    }

    /**
     * Run a similarity search, reusing a cached query embedding when available.
     *
     * @param string $query               Query text.
     * @param int    $k                   Number of results requested from the parent.
     * @param array  $additionalArguments Forwarded to the parent; its 'embeddings' entry is
     *                                    overwritten by the cached/computed embedding when
     *                                    caching is enabled.
     *
     * @return array Whatever the parent similaritySearch() returns.
     */
    public function similaritySearch(string $query, int $k = 4, array $additionalArguments = []): array
    {
        if ($this->cacheItemPool === null) {
            return parent::similaritySearch($query, $k, $additionalArguments);
        }

        $embeddings = $this->rememberEmbeddings(
            md5(serialize($query)),
            fn () => $this->embedding->embedQuery($query)
        );

        // The result must be returned here; otherwise the search would run a second
        // time without the embedding and call the embedding provider again.
        return parent::similaritySearch(
            $query,
            $k,
            array_merge($additionalArguments, ['embeddings' => $embeddings])
        );
    }

    /**
     * Fetch the value cached under $key, computing and saving it on a miss.
     *
     * Must only be called when a cache pool is configured.
     *
     * @param string   $key     Cache key.
     * @param callable $compute Produces the value to cache when the key is not a hit.
     *
     * @return mixed The cached or freshly computed value.
     */
    private function rememberEmbeddings(string $key, callable $compute)
    {
        $item = $this->cacheItemPool->getItem($key);
        if ($item->isHit()) {
            return $item->get();
        }

        $value = $compute();
        $this->cacheItemPool->save($item->set($value));

        return $value;
    }
}
8 changes: 4 additions & 4 deletions src/VectorStores/SimpleStupidVectorStore.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SimpleStupidVectorStore extends VectorStore
{
private const LANGCHAIN_DEFAULT_COLLECTION_NAME = 'langchain';
private ?SSVStorage $storage;
private $collection;
protected $collection;

public function __construct(
private Embeddings $embedding,
Expand All @@ -35,7 +35,7 @@ public function __construct(

public function addTexts(iterable $texts, ?array $metadata = null, array $additionalArguments = []): array
{
$embeddings = $this->embedding->embedDocuments($texts);
$embeddings = $additionalArguments['embeddings'] ?? $this->embedding->embedDocuments($texts);

$uuids = [];
for ($i = 0; $i < count($texts); $i++) {
Expand All @@ -50,7 +50,7 @@ public function addTexts(iterable $texts, ?array $metadata = null, array $additi

public function similaritySearch(string $query, int $k = 4, array $additionalArguments = []): array
{
$embeddings = $this->embedding->embedQuery($query);
$embeddings = $additionalArguments['embeddings'] ?? $this->embedding->embedQuery($query);
$data = $this->collection->similaritySearchWithScore($embeddings, $k);

$documents = [];
Expand All @@ -67,7 +67,7 @@ public static function fromTexts(
?array $metadata = null,
array $additionalArguments = []
): VectorStore {
$self = new self($embedding, null, $additionalArguments);
$self = new static($embedding, null, $additionalArguments);

$self->addTexts($texts, $metadata);
return $self;
Expand Down
3 changes: 3 additions & 0 deletions src/VectorStores/VectorStore.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
*/
abstract class VectorStore
{
const SIMILARITY_SEARCH = 'similarity';
const MAX_MARGINAL_RELEVANCE_SEARCH = 'mmr';

/**
* @param iterable $texts Iterable of strings to add to the vectorstore.
* @param array|null $metadata Optional list of metadatas associated with the texts.
Expand Down
8 changes: 2 additions & 6 deletions tests/Chains/VectorDBQA/VectorDBQATest.php
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ public function testRun(): void

$chain = VectorDBQA::fromChainType(
$openAI,
'stuff',
null,
[
kwargs: [
'vectorstore' => new SimpleStupidVectorStore($embeddings)
]
);
Expand All @@ -105,9 +103,7 @@ public function testToArray(): void

$chain = VectorDBQA::fromChainType(
$openAI,
'stuff',
null,
[
kwargs: [
'vectorstore' => new SimpleStupidVectorStore($embeddings)
]
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
<?php

namespace Kambo\Langchain\Tests\DocumentLoaders;

use Kambo\Langchain\DocumentLoaders\BaseLoader;
use Kambo\Langchain\TextSplitter\RecursiveCharacterTextSplitter;
use PHPUnit\Framework\TestCase;

use const PHP_EOL;

/**
 * Shared test cases for loaders that read the common fixture text.
 *
 * Concrete test classes only supply the loader under test.
 */
abstract class AbstractDocumentFromFixturesLoaderTestCase extends TestCase
{
    /**
     * Provide the loader instance the shared tests run against.
     */
    abstract protected function getDocumentLoader(): BaseLoader;

    /**
     * load() yields the fixture text, including its trailing line break, as the first document.
     */
    public function testLoad()
    {
        $expected = 'Obchodní řetězec Billa v pondělí v Česku spustil pilotní verzi svého e-shopu, '
            . 'dostupný je zatím v Praze, v Brně a v jejich blízkém okolí.' . PHP_EOL;

        $documents = $this->getDocumentLoader()->load();

        self::assertEquals($expected, $documents[0]->pageContent);
    }

    /**
     * loadAndSplit() without a splitter yields the fixture text without the trailing line break.
     */
    public function testLoadAndSplitDefault()
    {
        $expected = 'Obchodní řetězec Billa v pondělí v Česku spustil pilotní verzi svého e-shopu, '
            . 'dostupný je zatím v Praze, v Brně a v jejich blízkém okolí.';

        $documents = $this->getDocumentLoader()->loadAndSplit();

        self::assertEquals($expected, $documents[0]->pageContent);
    }

    /**
     * loadAndSplit() with a 60-character splitter produces three chunks.
     */
    public function testLoadAndSplit()
    {
        $splitter = new RecursiveCharacterTextSplitter(['chunk_size' => 60, 'chunk_overlap' => 1]);
        $loader = $this->getDocumentLoader();

        $chunks = $loader->loadAndSplit($splitter);

        self::assertCount(3, $chunks);
        self::assertEquals(
            'Obchodní řetězec Billa v pondělí v Česku spustil',
            $chunks[0]->pageContent
        );
        self::assertEquals(
            'pilotní verzi svého e-shopu, dostupný je zatím v Praze,',
            $chunks[1]->pageContent
        );
    }
}
26 changes: 26 additions & 0 deletions tests/DocumentLoaders/StringLoaderTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<?php

namespace Kambo\Langchain\Tests\DocumentLoaders;

use Kambo\Langchain\DocumentLoaders\BaseLoader;
use Kambo\Langchain\DocumentLoaders\StringLoader;

use SplFileInfo;
use const DIRECTORY_SEPARATOR;

/**
 * Runs the shared fixture-based loader tests against StringLoader.
 */
class StringLoaderTest extends AbstractDocumentFromFixturesLoaderTestCase
{
    /** Contents of fixtures.txt, re-read before every test. */
    private $text;

    /**
     * Read the fixture file that sits next to this test class.
     */
    public function setUp(): void
    {
        parent::setUp();

        $fixture = new SplFileInfo(__DIR__ . DIRECTORY_SEPARATOR . 'fixtures.txt');
        $resolvedPath = $fixture->getRealPath();
        $this->text = file_get_contents($resolvedPath);
    }

    /**
     * Supply a StringLoader wrapping the fixture contents.
     */
    protected function getDocumentLoader(): BaseLoader
    {
        $loader = new StringLoader($this->text);

        return $loader;
    }
}
Loading