预备知识:
如果现有的库没有涵盖你想要的操作, 你可以自己定制一个. 为了使定制的 Op 能够兼容原有的库 , 你必须做以下工作:
向 TensorFlow 系统注册来定义 Op 的接口. 在注册时, 指定 Op 的名称, 它的输入(类型和名称) 和输出(类型和名称), 和所需要任何 属性的文档说明.
为了让你有直观的认识, 创建一个简单的 Op 作为例子. 该 Op 接受一个 int32 类型 tensor 作为 输入, 输出这个 tensor 的一个副本, 副本与原 tensor 唯一的区别在于第一个元素被置为 0. 创建 文件tensorflow/core/user_ops/zero_out.cc, 并调用 REGISTER_OP 宏来定义 Op 的接口.
#include "tensorflow/core/framework/op.h"
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32");
ZeroOut Op 接受 32 位整型的 tensor to_zero 作为输入, 输出 32 位整型的 tensor zeroed.
在定义接口之后, 提供一个或多个 Op 的实现. 为这些 kernel 的每一个创建一个对应的类, 继承 OpKernel, 覆盖Compute 方法. Compute 方法提供一个类型为 OpKernelContext* 的参数 context, 用于访问一些有用的信息, 例如输入和输出的 tensor.
将 kernel 添加到刚才创建的文件中, kernel 看起来和下面的代码类似:
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// 获取输入 tensor.
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// 创建一个输出 tensor.
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat<int32>();
// 设置 tensor 除第一个之外的元素均设为 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output(i) = 0;
}
// 尽可能地保留第一个元素的值.
if (N > 0) output(0) = input(0);
}
};
实现 kernel 后, 将其注册到 TensorFlow 系统中. 注册时, 可以指定该 kernel 运行时的多个约束 条件. 例如可以指定一个 kernel 在 CPU 上运行, 另一个在 GPU 上运行.
将下列代码加入到 zero_out.cc 中, 注册 ZeroOut op:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
一旦创建和重新安装了 TensorFlow , Tensorflow 系统可以在需要时引用和使用该 Op.
当编译 TensorFlow 时, 所有放在 tensorflow/core/user_ops 目录下 的 Op 会自动在 bazel-genfiles/tensorflow/python/ops/gen_user_ops.py 文件 中生成 Python Op 包装器. 通过以下声明, 把那些 Op 引入到tensorflow/python/user_ops/user_ops.py 中:
from tensorflow.python.ops.gen_user_ops import *你可以选择性将部分函数替换为自己的实现. 为此, 首先要隐藏自动生成的代码, 在
tensorflow/python/BUILD 文件中, 将其名字添加到 "user_ops" 的 hidden 列表.
tf_gen_op_wrapper_py(
name = "user_ops",
hidden = [
"Fact",
],
require_shape_functions = False,
)
紧接着 "Fact" 列出自己的 Op. 然后, 在 tensorflow/python/user_ops/user_ops.py 中添加你的替代实现函数. 通常, 替代实现函数也会调用自动生成函数来真正把 Op 添加 到图中. 被隐藏的自动生成函数位于 gen_user_ops 包中, 名称多了一个下划线前缀 ("_"). 例如:
def my_fact():
"""覆盖一个 Op 自动生成代码的示例."""
return gen_user_ops._fact()
C++ Op 包装器
当编译 TensorFlow 时, 所有 tensorflow/core/user_ops 文件夹 下的 Op 会自动创建 C++ Op 包装器. 例如,tensorflow/core/user_ops/zero_out.cc 中的 Op 会自动在 bazel-genfiles/tensorflow/cc/ops/user_ops.{h,cc} 中生成包装器.
tensorflow/cc/ops/standard_ops.h 通过下述申明, 导入用户自定义 Op 自动生成的包装器.
#include "tensorflow/cc/ops/user_ops.h"
验证已经成功实现 Op 的方式是编写测试程序. 创建文件 tensorflow/python/kernel_tests/zero_out_op_test.py, 包含以下内容:
import tensorflow as tf
class ZeroOutTest(tf.test.TestCase):
def testZeroOut(self):
with self.test_session():
result = tf.user_ops.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
然后运行测试:
$ bazel test tensorflow/python:zero_out_op_test验证条件
上述示例假定 Op 能够应用在任何 shape 的 tensor 上. 如果只想应用到 vector 上 呢? 这意味需要在上述 OpKernel 实现中添加相关的检查.
void Compute(OpKernelContext* context) override {
// 获取输入 tensor
const Tensor& input_tensor = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
errors::InvalidArgument("ZeroOut expects a 1-D vector."));
// ...
}
OP_REQUIRES 断言的输入是一个 vector, 如果不是 vector, 将设置 InvalidArgument 状态并返回. OP_REQUIRES 宏有三个参数:
context: 可以是一个 OpKernelContext 或 OpKernelConstruction 指针 (参见tensorflow/core/framework/op_kernel.h), 其 SetStatus() 方法将被使用到.
tensorflow/core/public/tensor_shape.h 中有一些验证 tensor shape 的函数.
Status 对象表示, 参见 tensorflow/core/public/status.h. Status 包含一个类型 (通常是 InvalidArgument, 但也可以是任何类型) 和一个消息. 构造 一个错误的函数位于tensorflow/core/lib/core/errors.h 中.
如果想要测试一个函数返回的 Status 对象是否是一个错误, 可以使用 OP_REQUIRES_OK. 这些宏如果检测到错误, 会直接跳出函数, 终止函数执行.
Op 可以有属性, 属性的值在 Op 添加到图中时被设置. 属性值用于配置 Op, 在 kernel 实现中, Op 注册的输入和输出类型中, 均可访问这些属性值. 尽可能地使用输入代替属性, 因为输入的灵活性更高, 例如可以在执行步骤中 中被更改, 可以使用 feed 等等. 属性可用于实现一些输入无法做到的事情, 例如影响 Op 签名 (即输入输出的数量和类型) 的配置或只读配置可以通过属性实现.
注册 Op 时可以用 Attr 方法指定属性的名称和类型, 以此来定义一个属性, 形式如下:
<name>: <attr-type-expr>
<name> 必须以字母开头, 可以由数字, 字母, 下划线组成. <attr-type-expr> 是一个类型表达式, 形式如下:
例如, 如果想要 ZeroOut Op 保存一个用户索引, 指示该 Op 不仅仅只有一个元素, 你可以注册 Op 如下:
REGISTER_OP("ZeroOut")
.Attr("preserve_index: int")
.Input("to_zero: int32")
.Output("zeroed: int32");
你的 kernel 可以在构造函数里, 通过 context 参数访问这个属性:
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction * context) : OpKernel(context) {
// 获取欲保存的索引值
OP_REQUIRES_OK(context,
context->GetAttr("preserve_index", &preserve_index_));
// 检查 preserve_index 是否为正
OP_REQUIRES(context, preserve_index_ >= 0,
errors::InvalidArgument("Need preserve_index >= 0, got ",
preserve_index_));
}
void Compute(OpKernelContext* context) override {
// ...
}
private:
int preserve_index_;
};
该值可以在 Compute 方法中被使用:
void Compute(OpKernelContext* context) override {
// ...
// 检查 preserve_index 范围是否合法
OP_REQUIRES(context, preserve_index_ < input.dimension(0),
errors::InvalidArgument("preserve_index out of range"));
// 设置输出 tensor 所有的元素值为 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output_flat(i) = 0;
}
// 保存请求的输入值
output_flat(preserve_index_) = input(preserve_index_);
}
为了维持向后兼容性, 将一个属性添加到一个已有的 Op 时, 必须指定一个默认值:REGISTER_OP("ZeroOut")
.Attr("preserve_index: int = 0")
.Input("to_zero: int32")
.Output("zeroed: int32");