tang-hi

Don't Panic

经典的倒排索引 - Finite State Transducers (实现篇)

前一篇文章中,我们介绍了 FST 作为倒排索引的优势,同时也介绍了 FST 的基本概念以及相应的算法流程。 在这篇文章中,我们会详细介绍如何用 C++ 实现一个 FST 以及相关的实现细节。

  1. 我们会先介绍构建阶段 FST 所需要用到的数据结构。
  2. 再实现 FstBuilder,参照上一篇的四个步骤一步步进行开发。
  3. 最后实现一个 FstSearcher,用于在构建好的 FST 上进行搜索。

0. 最终目标

在开始之前,我们先来看一下最终目标。我们希望实现两个类:FstBuilderFstSearcher。它们的接口如下所示:

FstBuilder builder;

// add must be called with inputs in lexicographical order
builder.add("a", 1);
builder.add("ab", 2);
builder.add("cap", 1);
builder.add("tap", 1);

builder.finish(); // after this no more inputs will be added

// immutable FST is built
FstSearcher searcher = builder.buildFst();

std::optional<int> v1 = searcher.search("ab"); // v1 == 2
std::optional<int> v2 = searcher.search("cad");  // v2 == std::nullopt

为了便于理解,我们在这里实现一个最朴素的 std::string -> int 映射的 FST。

1. 构建时所用到的数据结构

构建时我们需要用到一些数据结构来表示 FST 的节点和边。我们对照上一篇文章中的定义来实现这些数据结构。

struct Node;
struct Arc {
  char label; // edge label
  Node* target; // target node
  int output; // output value

  bool targetCompiled; // whether target node is compiled
  uint64_t targetAddress; // address of target node in serialized FST
  bool isFinal; // whether target node is final
  int finalOutput; // final output value of target node
};

struct Node {
  std::vector<Arc> arcs; // outgoing edges
  bool isFinal; // is final node
  int finalOutput; // final output value

  bool compiled; // whether this node has been compiled
  uint64_t address; // address in the serialized FST
};

为了加深印象,我们再介绍一下 FST 边(Arc)和节点(Node)的定义。

Arc 的定义如下

  • label: 边的标签,表示这条边所代表的字符
  • target: 这条边所指向的目标节点
  • output: 这条边所携带的输出值(值 >= 0)
  • targetCompiled: 一个布尔值,表示目标节点是否已经被序列化
  • targetAddress: 如果目标节点已经被序列化,则表示该节点在文件中的偏移地址
  • isFinal: 一个布尔值,表示目标节点是否为终止节点
  • finalOutput: 如果目标节点是终止节点,则表示该节点的最终输出值(详细作用见上一篇文章)

Node 的定义如下

  • arcs: 该节点的所有出边
  • isFinal: 一个布尔值,表示该节点是否为终止节点
  • finalOutput: 如果该节点是终止节点,则表示该节点的最终输出值(详细作用见上一篇文章)
  • compiled: 一个布尔值,表示该节点是否已经被序列化
  • address: 如果该节点已经被序列化,则表示该节点在文件中的偏移地址

在明确了 FST 的基本数据结构后,我们就可以给出 FstBuilder 的定义了。

class FstBuilder {
  public:
    // add a new key-output pair
    bool add(const std::string &input, int out);

    // finish the building process
    void finish();

    // convert to FST
    FstSearcher buildFst();

  private:
    Meta meta_; // FST metadata
    std::string lastInput_; // last input added
    std::vector<std::unique_ptr<Node>> frontier_; // nodes waiting for processing
    std::unordered_map<std::vector<std::byte>, uint64_t> nodeCache_; // nodes could be reused
};

函数的定义就不说了,语义都很清晰。我们重点介绍一下 FstBuilder 中维护的变量:

  • meta_: FST 的元数据,包含根节点地址和字节大小
  • lastInput_: 上一个被添加的输入字符串,用于寻找前缀
  • frontier_: 一个节点数组,用于存储当前待处理(待冻结)的节点
  • nodeCache_: 一个哈希表,用于存储已经处理过的节点,以便进行节点复用

看了前一篇文章的话,你对这些变量应该都不陌生。不过即使你印象不深,我们在后续的实现过程中也会一一解释它们的作用。

2. 构建主流程

我们依旧先给出整个构建流程的代码框架,然后再聚焦每一个步骤的实现细节。

bool FstBuilder::add(const std::string &input, int out) {

    if (input < lastInput_) {
        throw std::invalid_argument("Inputs must be added in lexicographical order");
    }
    
    // step 1: find common prefix length
    uint64_t prefixLen = commonPrefixLength(lastInput_, input);

    // step 2: process suffix of last input
    freezeTail(prefixLen + 1);

    // step 3: insert new input
    insertNewInput(input, prefixLen + 1);

    // step 4: adjust outputs
    adjustOutputs(input, out, prefixLen + 1);

    // update last input
    lastInput_ = input;
    return true;
}

后续我们详细介绍每一个步骤的实现细节。

2.1 寻找前缀:commonPrefixLength

commonPrefixLength 用于寻找两个字符串的最长公共前缀长度。通过这个长度,我们可以确定哪些节点在后续构建的过程中不会再新增出边,从而可以进行冻结操作(写到磁盘中)。

uint64_t commonPrefixLength(const std::string &a, const std::string &b) {
    uint64_t minLength = std::min(a.size(), b.size());
    for (uint64_t i = 0; i < minLength; ++i) {
        if (a[i] != b[i]) {
            return i;
        }
    }
    return minLength;
}

2.2 处理后缀:freezeTail

在第一步中,我们找到了最长公共前缀的长度,所以在这个长度之后的节点都不会再有变化了(因为输入是按照字典序添加的,读者可以自己思考一下)。我们可以对这些节点进行冻结处理。

这是整个构建流程中较为复杂的一步,我们仍旧先给出代码框架,然后对照代码进行解释。

void FstBuilder::freezeTail(uint64_t startIndex) {
  startIndex = std::max(startIndex, 1UL); ---------------------- a
  for (auto i = lastInput_.size(); i >= startIndex; --i) {
    auto &current = frontier_[i];
    auto &parent = frontier_[i - 1];
    bool isFinal = current->isFinal;  
    auto finalOutput = current->finalOutput;
    auto compiledAddress = compileNode(current.get());  -------- b

    // update the arc in parent pointing to current
    Arc &pointingArc = parent->arcs.back();     ---------------|
    pointingArc.targetCompiled = true;                         |
    pointingArc.targetAddress = compiledAddress;               | c
    pointingArc.target = nullptr;                              |
    pointingArc.isFinal = isFinal;                             |
    pointingArc.finalOutput = finalOutput; --------------------| 

    current->reset(); ------------------------------------------ d
  }
}

我们首先看代码 a 处,这里我们确保 startIndex 至少为 1,因为根节点在构建时是不会被冻结的。只有当我们所有的输入全部处理完毕后,根节点才会被处理。

接下来我们进入循环,从上一个输入的最后一个节点开始,一直到公共前缀的下一个节点为止。对于每一个节点,我们都需要进行冻结处理。

在代码 b 处,我们调用 compileNode 函数对当前处理的节点进行序列化,并将其写入磁盘中。然后将序列化后的节点在文件中的偏移作为返回值返回。

在完成节点序列化后,我们需要更新其父节点中指向该节点的边(Arc)。 这部分代码在 c 处。我们通过 parent->arcs.back() 获取到指向当前节点的边,然后将其目标节点信息更新为刚才序列化后的地址,并把目标节点指针置为空。 同时,我们还需要把当前节点的终止状态和 finalOutput 也更新到边上。原因后面会讲到。

最后,在代码 d 处,我们调用 current->reset() 函数来重置当前节点。这是因为构建过程中 frontier_ 中的节点是会复用的,当该节点的所有信息都已经 保存完毕后,我们就可以将其重置,方便后续进行复用。

现在我们再来看一下 compileNode 函数的实现:

uint64_t FstBuilder::compileNode(Node* node) {
    // if already compiled, return its address
    if (node->compiled) {
        return node->address;
    }

    // if node has no arcs and is not final, return SPECIAL_ADDRESS 
    if (node->arcs.empty()) {  -----------------------------------------                                        
        node->reset();                                                 |
        node->compiled = true;                                         | a
        node->address = detail::SPECIAL_ADDRESS;                       |
        return node->address;                                          |
    } -----------------------------------------------------------------|

    meta_.arcCount += node->arcs.size();

    // serialize node into a temporary buffer
    storage::BufferWriter compiledBuffer;
    serializeNode(node, compiledBuffer); ------------------------------- b

    // write to buffer if not found in cache                                      
    uint64_t nodeAddress = buffer_.size(); -------------------------------
    if (auto iter = nodeCache_.find(compiledBuffer.data());             |
        iter != nodeCache_.end()) {                                     |
        // found in cache                                               |
        nodeAddress = iter->second;                                     |
    } else {                                                            |
        // not found in cache, append to buffer                         |
        buffer_.insert(buffer_.end(), compiledBuffer.rawPtr(),          |c
                       compiledBuffer.rawPtr()                          |
                       + compiledBuffer.size());                        |
        nodeCache_.emplace(compiledBuffer.data(), nodeAddress);         |
        meta_.nodeCount += 1;                                           |
    }-------------------------------------------------------------------|

    node->reset();
    node->compiled = true;
    node->address = nodeAddress;
    return node->address;
}

compileNode 函数中,我们首先检查该节点是否已经被编译过,如果是的话直接返回其地址。

接下来在代码 a 处我们检查当前节点是不是没有出边(arcs)。在这种情况下,我们不需要将其序列化后写入磁盘,可以直接返回一个特殊地址 SPECIAL_ADDRESS,从而节省存储空间。

然后我们在代码 b 处调用 serializeNode 函数,将节点序列化到一个临时缓冲区中。

代码 c 处,则是 FST 的精华所在。我们通过一个哈希表(nodeCache_)来缓存已经序列化过的节点。如果当前节点的序列化结果已经存在于缓存中,我们就直接复用其地址,这就完成了共享后缀。 通过这种方式,我们可以大幅度减少 FST 的存储空间。

这里我用一个大的 buffer_ 模拟“磁盘文件”,nodeAddress 实际上就是节点在这个 buffer_ 中的偏移。 真正工程里可以把这个 buffer_ 写到文件里,或者用 mmap 映射进来。

至此,我们还剩下 serializeNode 函数的实现没有介绍:

void FstBuilder::serializeNode(Node* node, storage::BufferWriter &writer) {
    for (size_t i = 0; i < node->arcs.size(); ++i) {
        auto &arc = node->arcs[i];
        serializeArc(arc, writer, i == node->arcs.size() - 1);
    }
}

serializeNode 函数中我们可以看到,虽然说是序列化节点,但实际上我们只是将节点的所有出边依次序列化而已。

这也解释了之前为什么我们要把节点的终止状态和 finalOutput 信息存储到边上,因为节点本身并没有被序列化,只有边被序列化了。 所以我们需要把这些信息存储到边上,以便在搜索时能够获取这些信息。

至于 serializeArc 函数的实现,这里就不展开讲了,你怎么实现边的序列化都可以,只要你能反序列化回来就行。

2.3 插入新的输入:insertNewInput

在处理完上一个输入的后缀后,我们需要将当前输入的新增部分插入到 FST 中。

void FstBuilder::insertNewInput(const std::string &input, uint64_t fromIndex) {
    // ensure frontier_ has enough nodes
    while (frontier_.size() < input.size() + 1) {
        frontier_.emplace_back(std::make_unique<Node>());
    }

    for (uint64_t i = fromIndex; i <= input.size(); ++i) { ------
        auto &current = frontier_[i];                           |
        auto &parent = frontier_[i - 1];                        |
        Label label = input[i - 1];                             |
                                                                | a
        Arc arc;                                                | 
        arc.label = label;                                      |
        arc.target = current.get();                             |
        parent->arcs.push_back(arc);                            |
    }-----------------------------------------------------------|

    std::unique_ptr<Node> &lastNode = frontier_[input.size()];
    lastNode->isFinal = true;
    lastNode->finalOutput = 0;
}

这部分代码比较简单,我们首先确保 frontier_ 中有足够的节点来存储当前输入的所有节点。

然后在代码 a 处,我们从 fromIndex 开始,依次将当前输入的新增字符插入到 FST 中。 同时父节点需要增加一条出边,指向当前节点。

在完成上述步骤后,我们需要把最后一个节点标记为终止节点,并把它的 finalOutput 设为 0。

2.4 调整输出:adjustOutputs

终于到了构建的最后一步,但这也是较为复杂的一步。在增加了新的输入后,我们需要重新调整路径上的输出值。

void FstBuilder::adjustOutputs(const std::string &input, int out, uint64_t toIndex) {
  int residual = out;          ------------------------------- a
  for (size_t i = 1; i < toIndex; ++i) { 
    auto &parent = frontier_[i - 1];
    auto &current = frontier_[i];

    Arc &arc = parent->arcs.back();

    int lastArcOutput = arc.output;--------------------------|
    int commonOutput = std::min(residual, lastArcOutput);    |b
    int suffixOutput = lastArcOutput - commonOutput;         |
    arc.output = commonOutput;  -----------------------------|
    prependOutput(current.get(), suffixOutput);-------------- c
    residual = residual - commonOutput;
  }

  Arc &arc = frontier_[toIndex - 1]->arcs.back();
  arc.output = residual;
}

在代码 a 处,我们首先将输入值赋给 residual,这个变量用于记录当前还有多少输出值需要分配到边上。

不同于之前的步骤,这次我们是从根节点开始,顺序调整边上的输出值,直到公共前缀的下一个节点为止。 我们在代码 b 处,对 residuallastArcOutput 取最小值,并使用该值作为当前边的输出值。

为什么这么做呢?假设我们现在的 residual 是 10,而当前边的输出值是 15。为了确保我们不会破坏之前路径的输出值,同时又可以让当前路径的累加值等于输入值,我们只能把 当前边的输出值设为 10,即 min(residual, lastArcOutput) 的值。

接下来,我们要将当前边剩余的输出值 suffixOutput 通过 prependOutput 函数下推到下一个节点上。这部分代码在 c 处。然后我们更新 residual,减去当前边的输出值。

最后,在循环结束后,我们需要将 residual 的值赋给公共前缀下一个节点的边的输出值。这样就完成了输出值的调整。

下面我们再简单介绍一下 prependOutput 函数的实现:

void FstBuilder::prependOutput(Node* node, int output) {
  if (output == 0) {
      return;
    }

  for (auto &arc : node->arcs) {
      arc.output = arc.output + output;
    }

  if (node->isFinal) {
      node->finalOutput = node->finalOutput + output;
  }
}

逻辑相当简单,就是把传入的 output 累加到该节点的所有出边上。如果该节点恰好是终止节点,那么我们还需要把 output 累加到 finalOutput 上。

至此,整个 FST 的构建流程就算完成了,只要我们最后调用一下 finish 函数就可以了:

void FstBuilder::finish() {
    // freeze all remaining nodes
    freezeTail(0);
    auto &rootNode = frontier_[0];
    uint64_t rootAddress = compileNode(rootNode.get());
    meta_.rootAddress = rootAddress; // record root address
    meta_.byteSize = buffer_.size();
    isFrozen_ = true;
}

我们冻结所有剩余的节点,然后编译根节点,并记录根节点的地址。最后将构建状态标记为“已冻结”。

调整输出的逻辑理解起来可能比较吃力,可以结合上一篇文章中的例子来配合理解。

3. 搜索的实现

FstSearcher 的实现相对简单,我们只需要沿着输入字符串对应的路径遍历 FST 即可。我们先给出 FstSearcher 的定义:

class FstSearcher {
  public:
    FstSearcher(const storage::BufferReader &buffer, const Meta &meta);

    // search for input string, return output if found
    std::optional<int> search(const std::string &input) const;

  private:
    storage::BufferReader reader_;
    Meta meta_;
};

FstSearcher 中维护了两个变量:

  • reader_: 用于读取 FST 数据的缓冲区
  • meta_: FST 的元数据,包含根节点地址和字节大小

接下来我们来看一下 search 函数的实现:

std::optional<int> FstSearcher::search(const std::string &input) const {
    uint64_t currentAddress = meta_.rootAddress;

    int accumulatedOutput = 0;
    Arc currentArc;

    for (size_t i = 0; i < input.size(); ++i) {           
        reader_.seek(currentAddress);                     
        char currentLabel = input[i];                     
        currentArc.reset();                               
        if (!findLabel(currentAddress, currentLabel, currentArc)) {       
            return std::nullopt;                                          
        }                                                                 
        accumulatedOutput = accumulatedOutput + currentArc.output;
        currentAddress = currentArc.targetAddress;
    }

    if (currentArc.isFinal()) {
        if (currentArc.hasFinalOutput()) {
            accumulatedOutput = accumulatedOutput + currentArc.finalOutput;
        }
        return accumulatedOutput;
    }
    return std::nullopt;
}

首先,我们从 meta_ 中获取根节点的地址,并将输出值 accumulatedOutput 初始化为 0。

随后,我们遍历输入字符串的每一个字符,通过 findLabel 函数在当前节点中查找对应的边。如果找不到对应的边,则说明输入字符串不存在于 FST 中,我们返回 std::nullopt。 如果找到了对应的边,我们将该边的输出值累加到 accumulatedOutput 上,并将当前节点地址更新为该边的目标节点地址。

在遍历完所有字符后,我们检查当前边是否为终止边。如果是的话,我们还需要将其 finalOutput 累加到 accumulatedOutput 上,然后返回最终的输出值。 如果当前边不是终止边,则说明输入字符串不存在于 FST 中,我们返回 std::nullopt

我们接下来再看一下 findLabel 函数的实现:

bool FstSearcher::findLabel(uint64_t nodeAddress, char label, Arc &outArc) const {
    if (nodeAddress == SPECIAL_ADDRESS) {
        return false;
    }

    reader_.seek(nodeAddress);
    while (true) {
        outArc.flag = static_cast<ArcFlag>(reader_.read<uint8_t>());
        outArc.label = reader_.read<char>();

        outArc.output = 0;
        if ((outArc.flag & ArcFlag::VALUE_ARC) != ArcFlag::EMPTY) {
            outArc.output = reader_.read<int>();
        }

        outArc.finalOutput = 0;
        if ((outArc.flag & ArcFlag::FINAL_VALUE_ARC) != ArcFlag::EMPTY) {
            outArc.finalOutput = reader_.read<int>();
        }

        outArc.targetAddress = detail::SPECIAL_ADDRESS;
        if ((outArc.flag & ArcFlag::STOP_NODE) == ArcFlag::EMPTY) {
            outArc.targetAddress = reader_.read<uint64_t>();
        }

        if (outArc.label == label) {
            return true;
        }

        if ((outArc.flag & ArcFlag::LAST) != ArcFlag::EMPTY) {
            break;
        }
    }
    return false;
}

findLabel 函数中,我们首先检查当前节点地址是否为特殊地址 SPECIAL_ADDRESS,如果是的话,说明该节点不存在,我们直接返回 false

然后我们定位到当前节点地址,并开始遍历该节点的所有出边。 这实际上就是反序列化的过程,我们依次读取每一条边的信息,并检查其标签是否与目标标签匹配。如果匹配,我们将该边的信息存储到 outArc 中,并返回 true。如果遍历完所有边都没有找到匹配的标签,则返回 false。 因为序列化时我们是依次序列化每一条边的,它们都顺序存储在一起,所以我们只需要顺序读取,直至遇到最后一条边即可。

总结

在这篇文章中,我们详细介绍了如何用 C++ 实现一个 FST 以及相关的实现细节。不过我们实现的 FST 是最简单的版本,在实际应用中,你可能还需要考虑更多的优化。但是 对于初学者理解和掌握 FST 的基本原理,这个实现已经足够了。完整的代码可以在 这里 找到。