使用 C++ 构建图

C++ 图表构建器是一个功能强大的工具,用于:

  • 构建复杂图表
  • 将图表参数化(例如,在 InferenceCalculator 上设置委托、启用/停用图表的某些部分)
  • 删除重复图表(例如,您可以使用一个代码来构建所需的图表,而不是 pbtxt 中的 CPU 和 GPU 专用图表,尽可能多地共享)
  • 支持可选的图表输入/输出
  • 按平台自定义图表

基本用法

我们来看一看 C++ 图表构建器如何用于简单的图表:

# Graph inputs.
input_stream: "input_tensors"
input_side_packet: "model"

# Graph outputs.
output_stream: "output_tensors"

node {
  calculator: "InferenceCalculator"
  input_stream: "TENSORS:input_tensors"
  input_side_packet: "MODEL:model"
  output_stream: "TENSORS:output_tensors"
  options: {
    [drishti.InferenceCalculatorOptions.ext] {
      # Requesting GPU delegate.
      delegate { gpu {} }
    }
  }
}

用于构建上述 CalculatorGraphConfig 的函数可能如下所示:

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Graph inputs.
  Stream<std::vector<Tensor>> input_tensors =
      graph.In(0).SetName("input_tensors").Cast<std::vector<Tensor>>();
  SidePacket<TfLiteModelPtr> model =
      graph.SideIn(0).SetName("model").Cast<TfLiteModelPtr>();

  auto& inference_node = graph.AddNode("InferenceCalculator");
  auto& inference_opts =
      inference_node.GetOptions<InferenceCalculatorOptions>();
  // Requesting GPU delegate.
  inference_opts.mutable_delegate()->mutable_gpu();
  input_tensors.ConnectTo(inference_node.In("TENSORS"));
  model.ConnectTo(inference_node.SideIn("MODEL"));
  Stream<std::vector<Tensor>> output_tensors =
      inference_node.Out("TENSORS").Cast<std::vector<Tensor>>();

  // Graph outputs.
  output_tensors.SetName("output_tensors").ConnectTo(graph.Out(0));

  // Get `CalculatorGraphConfig` to pass it into `CalculatorGraph`
  return graph.GetConfig();
}

简短摘要:

  • 使用 Graph::In/SideIn 获取图表输入作为 Stream/SidePacket
  • 使用 Node::Out/SideOut 获取节点输出作为 Stream/SidePacket
  • 使用 Stream/SidePacket::ConnectTo 将数据流和旁数据包连接到节点输入 (Node::In/SideIn) 和图表输出 (Graph::Out/SideOut)
    • 您可以使用“快捷方式”运算符 >> 代替 ConnectTo 函数(例如 x >> node.In("IN"))。
  • Stream/SidePacket::Cast 用于将 AnyType 的流或边包(例如 Stream<AnyType> in = graph.In(0);)转换为特定类型
    • 使用实际类型而不是 AnyType,您将找到更好的途径来释放图表构建器功能和提高图表可读性。

高级用法

实用函数

我们将推断构建代码提取到专用的实用函数中,以帮助提高可读性和代码重用:

// Updates graph to run inference.
Stream<std::vector<Tensor>> RunInference(
    Stream<std::vector<Tensor>> tensors, SidePacket<TfLiteModelPtr> model,
    const InferenceCalculatorOptions::Delegate& delegate, Graph& graph) {
  auto& inference_node = graph.AddNode("InferenceCalculator");
  auto& inference_opts =
      inference_node.GetOptions<InferenceCalculatorOptions>();
  *inference_opts.mutable_delegate() = delegate;
  tensors.ConnectTo(inference_node.In("TENSORS"));
  model.ConnectTo(inference_node.SideIn("MODEL"));
  return inference_node.Out("TENSORS").Cast<std::vector<Tensor>>();
}

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Graph inputs.
  Stream<std::vector<Tensor>> input_tensors =
      graph.In(0).SetName("input_tensors").Cast<std::vector<Tensor>>();
  SidePacket<TfLiteModelPtr> model =
      graph.SideIn(0).SetName("model").Cast<TfLiteModelPtr>();

  InferenceCalculatorOptions::Delegate delegate;
  delegate.mutable_gpu();
  Stream<std::vector<Tensor>> output_tensors =
      RunInference(input_tensors, model, delegate, graph);

  // Graph outputs.
  output_tensors.SetName("output_tensors").ConnectTo(graph.Out(0));

  return graph.GetConfig();
}

因此,RunInference 提供了一个清晰的接口,说明了什么是输入/输出及其类型。

它可以轻松重复使用,例如,如果您想运行额外的模型推断,只需几行代码:

  // Run first inference.
  Stream<std::vector<Tensor>> output_tensors =
      RunInference(input_tensors, model, delegate, graph);
  // Run second inference on the output of the first one.
  Stream<std::vector<Tensor>> extra_output_tensors =
      RunInference(output_tensors, extra_model, delegate, graph);

您无需复制名称和标记(InferenceCalculatorTENSORSMODEL)或在此处及处引入专用常量,这些详细信息已本地化为 RunInference 函数。

实用程序类

当然,这不仅仅涉及函数,在某些情况下,引入实用程序类会大有裨益,这有助于提高图构建代码的可读性,并降低出错几率。

MediaPipe 提供 PassThroughCalculator 计算器,只需通过其输入即可:

input_stream: "float_value"
input_stream: "int_value"
input_stream: "bool_value"

output_stream: "passed_float_value"
output_stream: "passed_int_value"
output_stream: "passed_bool_value"

node {
  calculator: "PassThroughCalculator"
  input_stream: "float_value"
  input_stream: "int_value"
  input_stream: "bool_value"
  # The order must be the same as for inputs (or you can use explicit indexes)
  output_stream: "passed_float_value"
  output_stream: "passed_int_value"
  output_stream: "passed_bool_value"
}

我们来看一下用于创建上述图表的简单 C++ 构建代码:

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Graph inputs.
  Stream<float> float_value = graph.In(0).SetName("float_value").Cast<float>();
  Stream<int> int_value = graph.In(1).SetName("int_value").Cast<int>();
  Stream<bool> bool_value = graph.In(2).SetName("bool_value").Cast<bool>();

  auto& pass_node = graph.AddNode("PassThroughCalculator");
  float_value.ConnectTo(pass_node.In("")[0]);
  int_value.ConnectTo(pass_node.In("")[1]);
  bool_value.ConnectTo(pass_node.In("")[2]);
  Stream<float> passed_float_value = pass_node.Out("")[0].Cast<float>();
  Stream<int> passed_int_value = pass_node.Out("")[1].Cast<int>();
  Stream<bool> passed_bool_value = pass_node.Out("")[2].Cast<bool>();

  // Graph outputs.
  passed_float_value.SetName("passed_float_value").ConnectTo(graph.Out(0));
  passed_int_value.SetName("passed_int_value").ConnectTo(graph.Out(1));
  passed_bool_value.SetName("passed_bool_value").ConnectTo(graph.Out(2));

  // Get `CalculatorGraphConfig` to pass it into `CalculatorGraph`
  return graph.GetConfig();
}

虽然 pbtxt 表示法可能容易出错(当我们有许多输入要传递时),但 C++ 代码看起来更糟糕:重复的空标记和 Cast 调用。我们来看看如何通过引入 PassThroughNodeBuilder 来做得更好:

class PassThroughNodeBuilder {
 public:
  explicit PassThroughNodeBuilder(Graph& graph)
      : node_(graph.AddNode("PassThroughCalculator")) {}

  template <typename T>
  Stream<T> PassThrough(Stream<T> stream) {
    stream.ConnectTo(node_.In(index_));
    return node_.Out(index_++).Cast<T>();
  }

 private:
  int index_ = 0;
  GenericNode& node_;
};

现在,图表构建代码可能如下所示:

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Graph inputs.
  Stream<float> float_value = graph.In(0).SetName("float_value").Cast<float>();
  Stream<int> int_value = graph.In(1).SetName("int_value").Cast<int>();
  Stream<bool> bool_value = graph.In(2).SetName("bool_value").Cast<bool>();

  PassThroughNodeBuilder pass_node_builder(graph);
  Stream<float> passed_float_value = pass_node_builder.PassThrough(float_value);
  Stream<int> passed_int_value = pass_node_builder.PassThrough(int_value);
  Stream<bool> passed_bool_value = pass_node_builder.PassThrough(bool_value);

  // Graph outputs.
  passed_float_value.SetName("passed_float_value").ConnectTo(graph.Out(0));
  passed_int_value.SetName("passed_int_value").ConnectTo(graph.Out(1));
  passed_bool_value.SetName("passed_bool_value").ConnectTo(graph.Out(2));

  // Get `CalculatorGraphConfig` to pass it into `CalculatorGraph`
  return graph.GetConfig();
}

现在,传递构造代码中的顺序或索引不能不正确,并且您可以通过从 PassThrough 输入中猜测 Cast 的类型来减少一些输入。

正确做法和错误做法

如果可能,在一开始就定义图表输入

在以下代码中:

  • 您可能很难猜出图表中有多少输入。
  • 整体上容易出错且将来难以维护(例如,索引是正确的吗?如果某些输入被移除或设为可选,该怎么办?等)。
  • RunSomething 的重复使用受到限制,因为其他图可能具有不同的输入

错误做法 - 不良代码示例。

Stream<D> RunSomething(Stream<A> a, Stream<B> b, Graph& graph) {
  Stream<C> c = graph.In(2).SetName("c").Cast<C>();  // Bad.
  // ...
}

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  Stream<A> a = graph.In(0).SetName("a").Cast<A>();
  // 10/100/N lines of code.
  Stream<B> b = graph.In(1).SetName("b").Cast<B>()  // Bad.
  Stream<D> d = RunSomething(a, b, graph);
  // ...

  return graph.GetConfig();
}

请改为在图表构建器的开头定义图表输入:

正确做法 - 优秀代码示例。

Stream<D> RunSomething(Stream<A> a, Stream<B> b, Stream<C> c, Graph& graph) {
  // ...
}

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Inputs.
  Stream<A> a = graph.In(0).SetName("a").Cast<A>();
  Stream<B> b = graph.In(1).SetName("b").Cast<B>();
  Stream<C> c = graph.In(2).SetName("c").Cast<C>();

  // 10/100/N lines of code.
  Stream<D> d = RunSomething(a, b, c, graph);
  // ...

  return graph.GetConfig();
}

如果您的输入流或旁数据包并非始终定义,请使用 std::optional,并将其放在最前面:

正确做法 - 优秀代码示例。

std::optional<Stream<A>> a;
if (needs_a) {
  a = graph.In(0).SetName(a).Cast<A>();
}

在最后定义图表输出

在以下代码中:

  • 您可能很难猜出图表中具有多少输出。
  • 整体上容易出错且未来难以维护(例如,索引是否正确?名称?如果某些输出被移除或设为可选,该怎么办?等)。
  • RunSomething 的重复使用受到限制,因为其他图表可能具有不同的输出

错误做法 - 不良代码示例。

void RunSomething(Stream<Input> input, Graph& graph) {
  // ...
  node.Out("OUTPUT_F")
      .SetName("output_f").ConnectTo(graph.Out(2));  // Bad.
}

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // 10/100/N lines of code.
  node.Out("OUTPUT_D")
      .SetName("output_d").ConnectTo(graph.Out(0));  // Bad.
  // 10/100/N lines of code.
  node.Out("OUTPUT_E")
      .SetName("output_e").ConnectTo(graph.Out(1));  // Bad.
  // 10/100/N lines of code.
  RunSomething(input, graph);
  // ...

  return graph.GetConfig();
}

请改为在图表构建器的末尾定义图表输出:

正确做法 - 优秀代码示例。

Stream<F> RunSomething(Stream<Input> input, Graph& graph) {
  // ...
  return node.Out("OUTPUT_F").Cast<F>();
}

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // 10/100/N lines of code.
  Stream<D> d = node.Out("OUTPUT_D").Cast<D>();
  // 10/100/N lines of code.
  Stream<E> e = node.Out("OUTPUT_E").Cast<E>();
  // 10/100/N lines of code.
  Stream<F> f = RunSomething(input, graph);
  // ...

  // Outputs.
  d.SetName("output_d").ConnectTo(graph.Out(0));
  e.SetName("output_e").ConnectTo(graph.Out(1));
  f.SetName("output_f").ConnectTo(graph.Out(2));

  return graph.GetConfig();
}

使节点彼此分离

在 MediaPipe 中,数据包流和旁数据包与处理节点一样有意义。此外,任何节点输入要求和输出产品都根据其使用和生成的数据流和旁包来清楚独立地表达。

错误做法 - 不良代码示例。

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Inputs.
  Stream<A> a = graph.In(0).Cast<A>();

  auto& node1 = graph.AddNode("Calculator1");
  a.ConnectTo(node1.In("INPUT"));

  auto& node2 = graph.AddNode("Calculator2");
  node1.Out("OUTPUT").ConnectTo(node2.In("INPUT"));  // Bad.

  auto& node3 = graph.AddNode("Calculator3");
  node1.Out("OUTPUT").ConnectTo(node3.In("INPUT_B"));  // Bad.
  node2.Out("OUTPUT").ConnectTo(node3.In("INPUT_C"));  // Bad.

  auto& node4 = graph.AddNode("Calculator4");
  node1.Out("OUTPUT").ConnectTo(node4.In("INPUT_B"));  // Bad.
  node2.Out("OUTPUT").ConnectTo(node4.In("INPUT_C"));  // Bad.
  node3.Out("OUTPUT").ConnectTo(node4.In("INPUT_D"));  // Bad.

  // Outputs.
  node1.Out("OUTPUT").SetName("b").ConnectTo(graph.Out(0));  // Bad.
  node2.Out("OUTPUT").SetName("c").ConnectTo(graph.Out(1));  // Bad.
  node3.Out("OUTPUT").SetName("d").ConnectTo(graph.Out(2));  // Bad.
  node4.Out("OUTPUT").SetName("e").ConnectTo(graph.Out(3));  // Bad.

  return graph.GetConfig();
}

在上述代码中:

  • 节点彼此耦合,例如 node4 知道其输入来自何处(node1node2node3),这会使重构、维护和代码重用都变得复杂
    • 这种使用模式是从 proto 表示法降级,在这种表示法中,节点默认分离。
  • node#.Out("OUTPUT") 调用会重复,可读性会受到影响,因为您可以改用更简洁的名称,也可以提供实际类型。

因此,如需解决上述问题,您可以编写以下图表构建代码:

正确做法 - 优秀代码示例。

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Inputs.
  Stream<A> a = graph.In(0).Cast<A>();

  // `node1` usage is limited to 3 lines below.
  auto& node1 = graph.AddNode("Calculator1");
  a.ConnectTo(node1.In("INPUT"));
  Stream<B> b = node1.Out("OUTPUT").Cast<B>();

  // `node2` usage is limited to 3 lines below.
  auto& node2 = graph.AddNode("Calculator2");
  b.ConnectTo(node2.In("INPUT"));
  Stream<C> c = node2.Out("OUTPUT").Cast<C>();

  // `node3` usage is limited to 4 lines below.
  auto& node3 = graph.AddNode("Calculator3");
  b.ConnectTo(node3.In("INPUT_B"));
  c.ConnectTo(node3.In("INPUT_C"));
  Stream<D> d = node3.Out("OUTPUT").Cast<D>();

  // `node4` usage is limited to 5 lines below.
  auto& node4 = graph.AddNode("Calculator4");
  b.ConnectTo(node4.In("INPUT_B"));
  c.ConnectTo(node4.In("INPUT_C"));
  d.ConnectTo(node4.In("INPUT_D"));
  Stream<E> e = node4.Out("OUTPUT").Cast<E>();

  // Outputs.
  b.SetName("b").ConnectTo(graph.Out(0));
  c.SetName("c").ConnectTo(graph.Out(1));
  d.SetName("d").ConnectTo(graph.Out(2));
  e.SetName("e").ConnectTo(graph.Out(3));

  return graph.GetConfig();
}

现在,如果需要,您可以轻松移除 node1 并将 b 设为图表输入,而无需对 node2node3node4 进行更新(顺便说一下,这与 proto 表示法中相同),因为它们彼此分离。

总体而言,上面的代码进一步复制了 proto 图:

input_stream: "a"

node {
  calculator: "Calculator1"
  input_stream: "INPUT:a"
  output_stream: "OUTPUT:b"
}

node {
  calculator: "Calculator2"
  input_stream: "INPUT:b"
  output_stream: "OUTPUT:C"
}

node {
  calculator: "Calculator3"
  input_stream: "INPUT_B:b"
  input_stream: "INPUT_C:c"
  output_stream: "OUTPUT:d"
}

node {
  calculator: "Calculator4"
  input_stream: "INPUT_B:b"
  input_stream: "INPUT_C:c"
  input_stream: "INPUT_D:d"
  output_stream: "OUTPUT:e"
}

output_stream: "b"
output_stream: "c"
output_stream: "d"
output_stream: "e"

除此之外,您现在可以提取实用函数,以便在其他图中进一步重复使用:

正确做法 - 优秀代码示例。

Stream<B> RunCalculator1(Stream<A> a, Graph& graph) {
  auto& node = graph.AddNode("Calculator1");
  a.ConnectTo(node.In("INPUT"));
  return node.Out("OUTPUT").Cast<B>();
}

Stream<C> RunCalculator2(Stream<B> b, Graph& graph) {
  auto& node = graph.AddNode("Calculator2");
  b.ConnectTo(node.In("INPUT"));
  return node.Out("OUTPUT").Cast<C>();
}

Stream<D> RunCalculator3(Stream<B> b, Stream<C> c, Graph& graph) {
  auto& node = graph.AddNode("Calculator3");
  b.ConnectTo(node.In("INPUT_B"));
  c.ConnectTo(node.In("INPUT_C"));
  return node.Out("OUTPUT").Cast<D>();
}

Stream<E> RunCalculator4(Stream<B> b, Stream<C> c, Stream<D> d, Graph& graph) {
  auto& node = graph.AddNode("Calculator4");
  b.ConnectTo(node.In("INPUT_B"));
  c.ConnectTo(node.In("INPUT_C"));
  d.ConnectTo(node.In("INPUT_D"));
  return node.Out("OUTPUT").Cast<E>();
}

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Inputs.
  Stream<A> a = graph.In(0).Cast<A>();

  Stream<B> b = RunCalculator1(a, graph);
  Stream<C> c = RunCalculator2(b, graph);
  Stream<D> d = RunCalculator3(b, c, graph);
  Stream<E> e = RunCalculator4(b, c, d, graph);

  // Outputs.
  b.SetName("b").ConnectTo(graph.Out(0));
  c.SetName("c").ConnectTo(graph.Out(1));
  d.SetName("d").ConnectTo(graph.Out(2));
  e.SetName("e").ConnectTo(graph.Out(3));

  return graph.GetConfig();
}

单独的节点以提高可读性

错误做法 - 不良代码示例。

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Inputs.
  Stream<A> a = graph.In(0).Cast<A>();
  auto& node1 = graph.AddNode("Calculator1");
  a.ConnectTo(node1.In("INPUT"));
  Stream<B> b = node1.Out("OUTPUT").Cast<B>();
  auto& node2 = graph.AddNode("Calculator2");
  b.ConnectTo(node2.In("INPUT"));
  Stream<C> c = node2.Out("OUTPUT").Cast<C>();
  auto& node3 = graph.AddNode("Calculator3");
  b.ConnectTo(node3.In("INPUT_B"));
  c.ConnectTo(node3.In("INPUT_C"));
  Stream<D> d = node3.Out("OUTPUT").Cast<D>();
  auto& node4 = graph.AddNode("Calculator4");
  b.ConnectTo(node4.In("INPUT_B"));
  c.ConnectTo(node4.In("INPUT_C"));
  d.ConnectTo(node4.In("INPUT_D"));
  Stream<E> e = node4.Out("OUTPUT").Cast<E>();
  // Outputs.
  b.SetName("b").ConnectTo(graph.Out(0));
  c.SetName("c").ConnectTo(graph.Out(1));
  d.SetName("d").ConnectTo(graph.Out(2));
  e.SetName("e").ConnectTo(graph.Out(3));

  return graph.GetConfig();
}

在上面的代码中,可能很难理解每个节点的开始和结束位置。为了改进这一点并帮助代码读取器,您只需在每个节点前后添加空白行:

正确做法 - 优秀代码示例。

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Inputs.
  Stream<A> a = graph.In(0).Cast<A>();

  auto& node1 = graph.AddNode("Calculator1");
  a.ConnectTo(node1.In("INPUT"));
  Stream<B> b = node1.Out("OUTPUT").Cast<B>();

  auto& node2 = graph.AddNode("Calculator2");
  b.ConnectTo(node2.In("INPUT"));
  Stream<C> c = node2.Out("OUTPUT").Cast<C>();

  auto& node3 = graph.AddNode("Calculator3");
  b.ConnectTo(node3.In("INPUT_B"));
  c.ConnectTo(node3.In("INPUT_C"));
  Stream<D> d = node3.Out("OUTPUT").Cast<D>();

  auto& node4 = graph.AddNode("Calculator4");
  b.ConnectTo(node4.In("INPUT_B"));
  c.ConnectTo(node4.In("INPUT_C"));
  d.ConnectTo(node4.In("INPUT_D"));
  Stream<E> e = node4.Out("OUTPUT").Cast<E>();

  // Outputs.
  b.SetName("b").ConnectTo(graph.Out(0));
  c.SetName("c").ConnectTo(graph.Out(1));
  d.SetName("d").ConnectTo(graph.Out(2));
  e.SetName("e").ConnectTo(graph.Out(3));

  return graph.GetConfig();
}

此外,上述表示法与 CalculatorGraphConfig proto 表示法更匹配。

如果将节点提取到实用函数中,它们已经限定在函数内,并且它们的开始和结束位置很清晰,因此可以:

正确做法 - 优秀代码示例。

CalculatorGraphConfig BuildGraph() {
  Graph graph;

  // Inputs.
  Stream<A> a = graph.In(0).Cast<A>();

  Stream<B> b = RunCalculator1(a, graph);
  Stream<C> c = RunCalculator2(b, graph);
  Stream<D> d = RunCalculator3(b, c, graph);
  Stream<E> e = RunCalculator4(b, c, d, graph);

  // Outputs.
  b.SetName("b").ConnectTo(graph.Out(0));
  c.SetName("c").ConnectTo(graph.Out(1));
  d.SetName("d").ConnectTo(graph.Out(2));
  e.SetName("e").ConnectTo(graph.Out(3));

  return graph.GetConfig();
}