diff --git a/CHANGELOG.md b/CHANGELOG.md index 0185c420a..35b804e26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added support for numeric `minimum_should_match` values in `TermsSet` query [#2293](https://github.com/ruflin/Elastica/pull/2293) * Added support for "search after" based pagination [#1645](https://github.com/ruflin/Elastica/issues/1645) * Added support for the `seq_no_primary_term` search option and the `if_seq_no` / `if_primary_term` index options to enable optimistic concurrency control [#2284](https://github.com/ruflin/Elastica/pull/2284) +* Added `Elastica\Query\Knn` and `Elastica\Query::setKnn()` to support top-level kNN search (single or multiple kNN searches), with optional filters, similarity threshold and boost ### Changed * `Elastica\Util::convertDate()` now throws `Elastica\Exception\InvalidException` when given a string that `strtotime()` cannot parse, instead of silently producing `1970-01-01T00:00:00Z`. The return type has also been narrowed to `string` and a typo-prone `null` argument is no longer accepted at runtime. ### Deprecated diff --git a/src/Query.php b/src/Query.php index 923c3e758..d39a70db1 100644 --- a/src/Query.php +++ b/src/Query.php @@ -7,6 +7,7 @@ use Elastica\Aggregation\AbstractAggregation; use Elastica\Exception\InvalidException; use Elastica\Query\AbstractQuery; +use Elastica\Query\Knn; use Elastica\Query\MatchAll; use Elastica\Query\QueryString; use Elastica\Rescore\Query as QueryRescore; @@ -52,6 +53,7 @@ * from?: int, * highlight?: THighlightArgs, * indices_boost?: array, + * knn?: array|list>, * min_score?: float, * pit?: PointInTime, * post_filter?: AbstractQuery, @@ -336,7 +338,7 @@ public function addAggregation(AbstractAggregation $agg): self */ public function toArray(): array { - if (!$this->hasSuggest && !isset($this->_params['query'])) { + if (!isset($this->_params['knn']) && !$this->hasSuggest && !isset($this->_params['query'])) { $this->setQuery(new MatchAll()); } @@ -494,4 +496,26 @@ public function setSearchAfter(array $searchAfter): self return $this; } + + /** + * Sets a top-level kNN search. + * + * Pass a list of {@see Knn} to combine several kNN searches in the same request. + * The Knn is serialized at call time, so further mutations on the passed object + * are not reflected in the query - compose it fully before calling setKnn(). + * + * @param Knn|list $knn + * + * @see https://www.elastic.co/docs/solutions/search/vector/knn + */ + public function setKnn(Knn|array $knn): self + { + if (\is_array($knn)) { + $value = \array_map(static fn (Knn $k): array => $k->toArray()['knn'], $knn); + } else { + $value = $knn->toArray()['knn']; + } + + return $this->setParam('knn', $value); + } } diff --git a/src/Query/Knn.php b/src/Query/Knn.php new file mode 100644 index 000000000..34f3c356e --- /dev/null +++ b/src/Query/Knn.php @@ -0,0 +1,55 @@ +setParam('field', $field); + $this->setParam('query_vector', $queryVector); + $this->setParam('k', $k); + $this->setParam('num_candidates', $numCandidates); + } + + /** + * Adds a Query DSL filter applied before the kNN search. + * + * Filters are ANDed together by Elasticsearch. + */ + public function addFilter(AbstractQuery $filter): self + { + return $this->addParam('filter', $filter); + } + + /** + * Sets the minimum similarity required for a document to be considered a match. + */ + public function setSimilarity(float $similarity): self + { + return $this->setParam('similarity', $similarity); + } + + /** + * Boost applied to the kNN score before it is combined with other clauses. + */ + public function setBoost(float $boost): self + { + return $this->setParam('boost', $boost); + } +} diff --git a/tests/Query/KnnTest.php b/tests/Query/KnnTest.php new file mode 100644 index 000000000..da40b734b --- /dev/null +++ b/tests/Query/KnnTest.php @@ -0,0 +1,134 @@ + [ + 'field' => 'vector', + 'query_vector' => [0.1, 0.2, 0.3], + 'k' => 100, + 'num_candidates' => 200, + ], + ]; + + $this->assertSame($expected, $knn->toArray()); + } + + #[Group('unit')] + public function testToArrayWithFiltersSimilarityAndBoost(): void + { + $knn = new Knn('vector', [0.5, 0.5], 10, 20); + $knn->addFilter(new Terms('tag', ['foo'])); + $knn->addFilter(new Range('age', ['gte' => 20])); + $knn->setSimilarity(0.7); + $knn->setBoost(1.5); + + $expected = [ + 'knn' => [ + 'field' => 'vector', + 'query_vector' => [0.5, 0.5], + 'k' => 10, + 'num_candidates' => 20, + 'filter' => [ + ['terms' => ['tag' => ['foo']]], + ['range' => ['age' => ['gte' => 20]]], + ], + 'similarity' => 0.7, + 'boost' => 1.5, + ], + ]; + + $this->assertSame($expected, $knn->toArray()); + } + + #[Group('unit')] + public function testQuerySetKnnEmbedsSingleKnnAtTopLevel(): void + { + $query = new Query(); + $query->setKnn(new Knn('vector', [0.1, 0.2], 5, 10)); + + $body = $query->toArray(); + + $this->assertSame([ + 'field' => 'vector', + 'query_vector' => [0.1, 0.2], + 'k' => 5, + 'num_candidates' => 10, + ], $body['knn']); + $this->assertArrayNotHasKey('query', $body, 'knn-only requests must not be auto-padded with a match_all query'); + } + + #[Group('unit')] + public function testQuerySetKnnAcceptsListOfKnnForMultipleKnnSearches(): void + { + $query = new Query(); + $query->setKnn([ + new Knn('a.vector', [0.1], 5, 10), + new Knn('b.vector', [0.2], 5, 10), + ]); + + $body = $query->toArray(); + + $this->assertCount(2, $body['knn']); + $this->assertSame('a.vector', $body['knn'][0]['field']); + $this->assertSame('b.vector', $body['knn'][1]['field']); + $this->assertArrayNotHasKey('query', $body, 'multi-knn requests must not be auto-padded with a match_all query'); + } + + #[Group('functional')] + public function testKnnSearchAgainstDenseVectorField(): void + { + $index = $this->_createIndex(); + $index->setMapping(new Mapping([ + 'tag' => ['type' => 'keyword'], + 'vector' => [ + 'type' => 'dense_vector', + 'dims' => 3, + 'index' => true, + 'similarity' => 'cosine', + ], + ])); + + $index->addDocuments([ + new Document('1', ['tag' => 'foo', 'vector' => [1.0, 0.0, 0.0]]), + new Document('2', ['tag' => 'foo', 'vector' => [0.9, 0.1, 0.0]]), + new Document('3', ['tag' => 'bar', 'vector' => [0.0, 0.0, 1.0]]), + ]); + $index->refresh(); + + $knn = new Knn('vector', [1.0, 0.0, 0.0], 2, 10); + $knn->addFilter(new Terms('tag', ['foo'])); + + $query = new Query(); + $query->setKnn($knn); + + $results = $index->search($query); + + $ids = \array_map(static fn ($r): string => $r->getId(), $results->getResults()); + + $this->assertContains('1', $ids); + $this->assertContains('2', $ids); + $this->assertNotContains('3', $ids, 'tag filter must exclude documents with another tag value'); + } +}