#include "table/merger.h"
#include "rocksdb/comparator.h"
#include "rocksdb/iterator.h"
#include "table/iter_heap.h"
#include "table/iterator_wrapper.h"
#include <vector>
namespace
rocksdb {
namespace
{
class
MergingIterator :
public
Iterator {
public
:
MergingIterator(
const
Comparator* comparator, Iterator** children,
int
n)
: comparator_(comparator),
children_(n),
current_(
nullptr
),
direction_(kForward),
maxHeap_(NewMaxIterHeap(comparator_)),
minHeap_ (NewMinIterHeap(comparator_)) {
for
(
int
i = 0; i < n; i++) {
children_[i].Set(children[i]);
}
for
(
auto
& child : children_) {
if
(child.Valid()) {
minHeap_.push(&child);
}
}
}
virtual
~MergingIterator() { }
virtual
bool
Valid()
const
{
return
(current_ !=
nullptr
);
}
virtual
void
SeekToFirst() {
ClearHeaps();
for
(
auto
& child : children_) {
child.SeekToFirst();
if
(child.Valid()) {
minHeap_.push(&child);
}
}
FindSmallest();
direction_ = kForward;
}
virtual
void
SeekToLast() {
ClearHeaps();
for
(
auto
& child : children_) {
child.SeekToLast();
if
(child.Valid()) {
maxHeap_.push(&child);
}
}
FindLargest();
direction_ = kReverse;
}
virtual
void
Seek(
const
Slice& target) {
ClearHeaps();
for
(
auto
& child : children_) {
child.Seek(target);
if
(child.Valid()) {
minHeap_.push(&child);
}
}
FindSmallest();
direction_ = kForward;
}
virtual
void
Next() {
assert
(Valid());
if
(direction_ != kForward) {
ClearHeaps();
for
(
auto
& child : children_) {
if
(&child != current_) {
child.Seek(key());
if
(child.Valid() &&
comparator_->Compare(key(), child.key()) == 0) {
child.Next();
}
if
(child.Valid()) {
minHeap_.push(&child);
}
}
}
direction_ = kForward;
}
current_->Next();
if
(current_->Valid()){
minHeap_.push(current_);
}
FindSmallest();
}
virtual
void
Prev() {
assert
(Valid());
if
(direction_ != kReverse) {
ClearHeaps();
for
(
auto
& child : children_) {
if
(&child != current_) {
child.Seek(key());
if
(child.Valid()) {
child.Prev();
}
else
{
child.SeekToLast();
}
if
(child.Valid()) {
maxHeap_.push(&child);
}
}
}
direction_ = kReverse;
}
current_->Prev();
if
(current_->Valid()) {
maxHeap_.push(current_);
}
FindLargest();
}
virtual
Slice key()
const
{
assert
(Valid());
return
current_->key();
}
virtual
Slice value()
const
{
assert
(Valid());
return
current_->value();
}
virtual
Status status()
const
{
Status status;
for
(
auto
& child : children_) {
status = child.status();
if
(!status.ok()) {
break
;
}
}
return
status;
}
private
:
void
FindSmallest();
void
FindLargest();
void
ClearHeaps();
const
Comparator* comparator_;
std::vector<IteratorWrapper> children_;
IteratorWrapper* current_;
enum
Direction {
kForward,
kReverse
};
Direction direction_;
MaxIterHeap maxHeap_;
MinIterHeap minHeap_;
};
void
MergingIterator::FindSmallest() {
if
(minHeap_.empty()) {
current_ =
nullptr
;
}
else
{
current_ = minHeap_.top();
assert
(current_->Valid());
minHeap_.pop();
}
}
void
MergingIterator::FindLargest() {
if
(maxHeap_.empty()) {
current_ =
nullptr
;
}
else
{
current_ = maxHeap_.top();
assert
(current_->Valid());
maxHeap_.pop();
}
}
// Discards the contents of both heaps by replacing each with a newly created
// heap built on the same comparator.
void MergingIterator::ClearHeaps() {
  minHeap_ = NewMinIterHeap(comparator_);
  maxHeap_ = NewMaxIterHeap(comparator_);
}
}
// Builds an iterator over the merged contents of list[0] .. list[n-1].
//
//   n == 0 : returns NewEmptyIterator().
//   n == 1 : returns list[0] itself; no merging wrapper is created.
//   n >= 2 : returns a newly allocated MergingIterator over all n children.
//
// n must be non-negative (checked by assert only).
Iterator* NewMergingIterator(const Comparator* cmp, Iterator** list, int n) {
  assert(n >= 0);
  switch (n) {
    case 0:
      return NewEmptyIterator();
    case 1:
      return list[0];
    default:
      return new MergingIterator(cmp, list, n);
  }
}
}