PyTorch 1.0 版本向 PyTorch 引入了一種新的編程模型,稱為 TorchScript 。 TorchScript 是 Python 編程語(yǔ)言的子集,可以通過 TorchScript 編譯器進(jìn)行解析,編譯和優(yōu)化。 此外,已編譯的 TorchScript 模型可以選擇序列化為磁盤文件格式,然后可以從純 C ++(以及 Python)加載并運(yùn)行該文件格式以進(jìn)行推理。
TorchScript 支持torch
包提供的大量操作子集,使您可以純粹表示為 PyTorch 的“標(biāo)準(zhǔn)庫(kù)”中的一系列張量操作來(lái)表示多種復(fù)雜模型。 但是,有時(shí)您可能需要使用自定義 C ++或 CUDA 函數(shù)擴(kuò)展 TorchScript。 雖然我們建議您僅在無(wú)法(簡(jiǎn)單有效地)將您的想法表達(dá)為簡(jiǎn)單的 Python 函數(shù)時(shí)才訴諸該選項(xiàng),但我們確實(shí)提供了一個(gè)非常友好且簡(jiǎn)單的界面,用于使用 ATen 定義自定義 C ++和 CUDA 內(nèi)核。 ,PyTorch 的高性能 C ++張量庫(kù)。 綁定到 TorchScript 后,您可以將這些自定義內(nèi)核(或“ ops”)嵌入到 TorchScript 模型中,并以 Python 或直接以 C ++的序列化形式執(zhí)行它們。
以下段落提供了編寫 TorchScript 自定義操作以調(diào)用 OpenCV (使用 C ++編寫的計(jì)算機(jī)視覺庫(kù))的示例。 我們將討論如何在 C ++中使用張量,如何有效地將它們轉(zhuǎn)換為第三方張量格式(在這種情況下為 OpenCV 或 Mat),如何在 TorchScript 運(yùn)行時(shí)中注冊(cè)您的運(yùn)算符 最后是如何編譯運(yùn)算符并在 Python 和 C ++中使用它。
在本教程中,我們將公開 warpPerspective 函數(shù),該函數(shù)將透視轉(zhuǎn)換應(yīng)用于圖像,從 OpenCV 到 TorchScript 作為自定義運(yùn)算符。 第一步是用 C ++編寫自定義運(yùn)算符的實(shí)現(xiàn)。 讓我們將此實(shí)現(xiàn)的文件稱為op.cpp
,并使其如下所示:
#include <opencv2/opencv.hpp>
#include <torch/script.h>
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data<float>());
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data<float>());
cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});
torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
return output.clone();
}
該運(yùn)算符的代碼很短。 在文件頂部,我們包含 OpenCV 標(biāo)頭文件opencv2/opencv.hpp
和torch/script.h
標(biāo)頭,該標(biāo)頭暴露了 PyTorch C ++ API 中所有需要編寫自定義 TorchScript 運(yùn)算符的必需屬性。 我們的函數(shù)warp_perspective
具有兩個(gè)參數(shù):輸入image
和我們希望應(yīng)用于圖像的warp
變換矩陣。 這些輸入的類型是torch::Tensor
,這是 C ++中 PyTorch 的張量類型(也是 Python 中所有張量的基礎(chǔ)類型)。 我們的warp_perspective
函數(shù)的返回類型也將是torch::Tensor
。
小費(fèi)
有關(guān) ATen 的更多信息,請(qǐng)參見本說(shuō)明,ATen 是為 PyTorch 提供Tensor
類的庫(kù)。 此外,本教程的描述了如何在 C ++中分配和初始化新的張量對(duì)象(此運(yùn)算符不需要)。
注意
TorchScript 編譯器了解固定數(shù)量的類型。 只有這些類型可以用作自定義運(yùn)算符的參數(shù)。 當(dāng)前這些類型是:這些類型的torch::Tensor
,torch::Scalar
,double
,int64_t
和std::vector
。 請(qǐng)注意,僅,,double
和不,,float
,僅,,int64_t
和,等其他整數(shù)類型,例如int
支持short
或long
。
在函數(shù)內(nèi)部,我們要做的第一件事是將 PyTorch 張量轉(zhuǎn)換為 OpenCV 矩陣,因?yàn)?OpenCV 的warpPerspective
期望cv::Mat
對(duì)象作為輸入。 幸運(yùn)的是,有一種方法可以執(zhí)行此,而無(wú)需復(fù)制任何數(shù)據(jù)。 在前幾行中
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data<float>());
我們正在將稱為 OpenCV Mat
類的構(gòu)造函數(shù),將張量轉(zhuǎn)換為Mat
對(duì)象。 我們將原始image
張量的行數(shù)和列數(shù),數(shù)據(jù)類型(在此示例中,我們將其固定為float32
)傳遞給它,最后傳遞指向基礎(chǔ)數(shù)據(jù)的原始指針– float*
。 Mat
類的此構(gòu)造方法的特殊之處在于它不會(huì)復(fù)制輸入數(shù)據(jù)。 取而代之的是,它將簡(jiǎn)單地引用此內(nèi)存來(lái)執(zhí)行Mat
上的所有操作。 如果在image_mat
上執(zhí)行就地操作,這將反映在原始image
張量中(反之亦然)。 即使我們實(shí)際上將數(shù)據(jù)存儲(chǔ)在 PyTorch 張量中,這也使我們能夠使用庫(kù)的本機(jī)矩陣類型調(diào)用后續(xù)的 OpenCV 例程。 我們重復(fù)此過程將warp
PyTorch 張量轉(zhuǎn)換為warp_mat
OpenCV 矩陣:
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data<float>());
接下來(lái),我們準(zhǔn)備調(diào)用我們渴望在 TorchScript 中使用的 OpenCV 函數(shù):warpPerspective
。 為此,我們將image_mat
和warp_mat
矩陣以及稱為output_mat
的空輸出矩陣傳遞給 OpenCV 函數(shù)。 我們還指定了我們希望輸出矩陣(圖像)為dsize
的大小。 對(duì)于此示例,它被硬編碼為8 x 8
:
cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});
我們的自定義運(yùn)算符實(shí)現(xiàn)的最后一步是將output_mat
轉(zhuǎn)換回 PyTorch 張量,以便我們可以在 PyTorch 中進(jìn)一步使用它。 這與我們先前在另一個(gè)方向進(jìn)行轉(zhuǎn)換的操作極為相似。 在這種情況下,PyTorch 提供了torch::from_blob
方法。 在這種情況下, blob 旨在表示一些不透明的,扁平的指向內(nèi)存的指針,我們希望將其解釋為 PyTorch 張量。 對(duì)torch::from_blob
的調(diào)用如下所示:
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8})
我們?cè)?OpenCV Mat
類上使用.ptr<float>()
方法來(lái)獲取指向基礎(chǔ)數(shù)據(jù)的原始指針(就像之前的 PyTorch 張量的.data<float>()
一樣)。 我們還指定了張量的輸出形狀,我們將其硬編碼為8 x 8
。 然后torch::from_blob
的輸出是torch::Tensor
,指向 OpenCV 矩陣擁有的內(nèi)存。
從我們的運(yùn)算符實(shí)現(xiàn)返回該張量之前,我們必須在張量上調(diào)用.clone()
以執(zhí)行基礎(chǔ)數(shù)據(jù)的存儲(chǔ)副本。 這樣做的原因是torch::from_blob
返回的張量不擁有其數(shù)據(jù)。 那時(shí),數(shù)據(jù)仍歸 OpenCV 矩陣所有。 但是,此 OpenCV 矩陣將超出范圍,并在函數(shù)末尾重新分配。 如果我們按原樣返回output
張量,那么當(dāng)我們?cè)诤瘮?shù)外使用它時(shí),它將指向無(wú)效的內(nèi)存。 調(diào)用.clone()
將返回一個(gè)新的張量,其中包含新張量自己擁有的原始數(shù)據(jù)的副本。 因此,返回外部世界是安全的。
現(xiàn)在,已經(jīng)在 C ++中實(shí)現(xiàn)了自定義運(yùn)算符,我們需要在 TorchScript 運(yùn)行時(shí)和編譯器中將注冊(cè)為。 這將使 TorchScript 編譯器可以在 TorchScript 代碼中解析對(duì)我們自定義運(yùn)算符的引用。 注冊(cè)非常簡(jiǎn)單。 對(duì)于我們的情況,我們需要編寫:
static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
op.cpp
文件的全局范圍內(nèi)的某個(gè)位置。 這將創(chuàng)建一個(gè)全局變量registry
,該變量將在其構(gòu)造函數(shù)中向 TorchScript 注冊(cè)我們的運(yùn)算符(即每個(gè)程序一次)。 我們指定運(yùn)算符的名稱,以及指向其實(shí)現(xiàn)的指針(我們之前編寫的函數(shù))。 該名稱包括兩部分:命名空間(my_ops
)和我們正在注冊(cè)的特定運(yùn)算符的名稱(warp_perspective
)。 名稱空間和操作員名稱由兩個(gè)冒號(hào)(::
)分隔。
Tip
如果要注冊(cè)多個(gè)運(yùn)算符,可以在構(gòu)造函數(shù)之后將調(diào)用鏈接到.op()
:
static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective)
.op("my_ops::another_op", &another_op)
.op("my_ops::and_another_op", &and_another_op);
在后臺(tái),RegisterOperators
將執(zhí)行許多相當(dāng)復(fù)雜的 C ++模板元編程魔術(shù)技巧,以推斷我們傳遞給它的函數(shù)指針的參數(shù)和返回值類型(&warp_perspective
)。 此信息用于為我們的操作員形成功能模式。 函數(shù)模式是操作員的結(jié)構(gòu)化表示形式,一種“簽名”或“原型”,由 TorchScript 編譯器用來(lái)驗(yàn)證 TorchScript 程序的正確性。
現(xiàn)在,我們已經(jīng)用 C ++實(shí)現(xiàn)了自定義運(yùn)算符并編寫了其注冊(cè)代碼,是時(shí)候?qū)⒃撨\(yùn)算符構(gòu)建到一個(gè)(共享的)庫(kù)中了,可以將其加載到 Python 中進(jìn)行研究和實(shí)驗(yàn),或者加載到 C ++中以在非 Python 中進(jìn)行推理。 環(huán)境。 有多種方法可以使用純 CMake 或setuptools
之類的 Python 替代方法來(lái)構(gòu)建我們的運(yùn)算符。 為簡(jiǎn)潔起見,以下段落僅討論 CMake 方法。 本教程的附錄深入探討了基于 Python 的替代方法。
為了使用 CMake 構(gòu)建系統(tǒng)將自定義運(yùn)算符構(gòu)建到共享庫(kù)中,我們需要編寫一個(gè)簡(jiǎn)短的CMakeLists.txt
文件并將其與之前的op.cpp
文件一起放置。 為此,讓我們就一個(gè)看起來(lái)像這樣的目錄結(jié)構(gòu)達(dá)成一致:
warp-perspective/
op.cpp
CMakeLists.txt
另外,請(qǐng)確保從 pytorch.org 中獲取 LibTorch 發(fā)行版的最新版本,該軟件包打包了 PyTorch 的 C ++庫(kù)和 CMake 構(gòu)建文件。 將解壓縮的發(fā)行版放置在文件系統(tǒng)中可訪問的位置。 以下段落將將該位置稱為/path/to/libtorch
。 我們的CMakeLists.txt
文件的內(nèi)容應(yīng)為以下內(nèi)容:
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(warp_perspective)
find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)
## Define our library target
add_library(warp_perspective SHARED op.cpp)
## Enable C++11
target_compile_features(warp_perspective PRIVATE cxx_range_for)
## Link against LibTorch
target_link_libraries(warp_perspective "${TORCH_LIBRARIES}")
## Link against OpenCV
target_link_libraries(warp_perspective opencv_core opencv_imgproc)
警告
此設(shè)置對(duì)構(gòu)建環(huán)境進(jìn)行了一些假設(shè),特別是有關(guān) OpenCV 安裝的假設(shè)。 上面的CMakeLists.txt
文件已在運(yùn)行 Ubuntu Xenial 的 Docker 容器中通過apt
安裝了libopencv-dev
進(jìn)行了測(cè)試。 如果它對(duì)您不起作用,并且您感到困惑,請(qǐng)使用隨附的教程資料庫(kù)中的Dockerfile
構(gòu)建一個(gè)隔離的,可復(fù)制的環(huán)境,在其中可以使用本教程中的代碼。 如果您遇到其他麻煩,請(qǐng)?jiān)诮坛藤Y料庫(kù)中提交問題,或在我們的論壇中發(fā)布問題。
現(xiàn)在要構(gòu)建我們的操作員,我們可以從warp_perspective
文件夾中運(yùn)行以下命令:
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found torch: /libtorch/lib/libtorch.so
-- Configuring done
-- Generating done
-- Build files have been written to: /warp_perspective/build
$ make -j
Scanning dependencies of target warp_perspective
[ 50%] Building CXX object CMakeFiles/warp_perspective.dir/op.cpp.o
[100%] Linking CXX shared library libwarp_perspective.so
[100%] Built target warp_perspective
它將在build
文件夾中放置libwarp_perspective.so
共享庫(kù)文件。 在上面的cmake
命令中,應(yīng)將/path/to/libtorch
替換為未壓縮的 LibTorch 發(fā)行版的路徑。
我們將在下面進(jìn)一步探討如何使用和調(diào)用我們的運(yùn)算符,但是為了早日獲得成功,我們可以嘗試在 Python 中運(yùn)行以下代碼:
>>> import torch
>>> torch.ops.load_library("/path/to/libwarp_perspective.so")
>>> print(torch.ops.my_ops.warp_perspective)
在這里,/path/to/libwarp_perspective.so
應(yīng)該是我們剛剛構(gòu)建的libwarp_perspective.so
共享庫(kù)的相對(duì)或絕對(duì)路徑。 如果一切順利,這應(yīng)該打印類似
<built-in method my_ops::warp_perspective of PyCapsule object at 0x7f618fc6fa50>
這是我們稍后將用來(lái)調(diào)用自定義運(yùn)算符的 Python 函數(shù)。
將我們的自定義運(yùn)算符構(gòu)建到共享庫(kù)后,我們就可以在 Python 的 TorchScript 模型中使用此運(yùn)算符了。 這有兩個(gè)部分:首先將運(yùn)算符加載到 Python 中,其次在 TorchScript 代碼中使用運(yùn)算符。
您已經(jīng)了解了如何將運(yùn)算符導(dǎo)入 Python:torch.ops.load_library()
。 此函數(shù)采用包含自定義運(yùn)算符的共享庫(kù)的路徑,并將其加載到當(dāng)前進(jìn)程中。 加載共享庫(kù)還將執(zhí)行我們放入自定義運(yùn)算符實(shí)現(xiàn)文件中的全局RegisterOperators
對(duì)象的構(gòu)造函數(shù)。 這將在 TorchScript 編譯器中注冊(cè)我們的自定義運(yùn)算符,并允許我們?cè)?TorchScript 代碼中使用該運(yùn)算符。
您可以將已加載的運(yùn)算符稱為torch.ops.<namespace>.<function>
,其中<namespace>
是運(yùn)算符名稱的名稱空間部分,而<function>
是運(yùn)算符的函數(shù)名稱。 對(duì)于我們上面編寫的運(yùn)算符,名稱空間為my_ops
,函數(shù)名稱為warp_perspective
,這意味著我們的運(yùn)算符可以作為torch.ops.my_ops.warp_perspective
使用。 盡管可以在腳本化或跟蹤的 TorchScript 模塊中使用此函數(shù),但我們也可以僅在原始的 PyTorch 中使用它,并將其傳遞給常規(guī) PyTorch 張量:
>>> import torch
>>> torch.ops.load_library("libwarp_perspective.so")
>>> torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3))
tensor([[0.0000, 0.3218, 0.4611, ..., 0.4636, 0.4636, 0.4636],
[0.3746, 0.0978, 0.5005, ..., 0.4636, 0.4636, 0.4636],
[0.3245, 0.0169, 0.0000, ..., 0.4458, 0.4458, 0.4458],
...,
[0.1862, 0.1862, 0.1692, ..., 0.0000, 0.0000, 0.0000],
[0.1862, 0.1862, 0.1692, ..., 0.0000, 0.0000, 0.0000],
[0.1862, 0.1862, 0.1692, ..., 0.0000, 0.0000, 0.0000]])
注意
幕后發(fā)生的事情是,第一次使用 Python 訪問torch.ops.namespace.function
時(shí),TorchScript 編譯器(在 C ++平臺(tái)上)將查看是否已注冊(cè)函數(shù)namespace::function
,如果已注冊(cè),則將 Python 句柄返回給該函數(shù), 我們可以隨后使用它從 Python 調(diào)用我們的 C ++運(yùn)算符實(shí)現(xiàn)。 這是 TorchScript 自定義運(yùn)算符和 C ++擴(kuò)展之間的一個(gè)值得注意的區(qū)別:C ++擴(kuò)展是使用 pybind11 手動(dòng)綁定的,而 TorchScript 自定義操作則是由 PyTorch 自己動(dòng)態(tài)綁定的。 Pybind11 在綁定到 Python 的類型和類方面為您提供了更大的靈活性,因此建議將其用于純粹渴望的代碼,但 TorchScript ops 不支持它。
從這里開始,您可以在腳本或跟蹤代碼中使用自定義運(yùn)算符,就像torch
包中的其他函數(shù)一樣。 實(shí)際上,諸如torch.matmul
之類的“標(biāo)準(zhǔn)庫(kù)”功能與自定義運(yùn)算符的注冊(cè)路徑大致相同,這使得自定義運(yùn)算符在 TorchScript 中的使用方式和位置方面真正成為一等公民。
首先,將我們的運(yùn)算符嵌入到跟蹤函數(shù)中。 回想一下,為了進(jìn)行跟蹤,我們從一些原始的 Pytorch 代碼開始:
def compute(x, y, z):
return x.matmul(y) + torch.relu(z)
然后調(diào)用torch.jit.trace
。 我們進(jìn)一步傳遞torch.jit.trace
一些示例輸入,它將輸入到我們的實(shí)現(xiàn)中,以記錄輸入流過它時(shí)發(fā)生的操作順序。 這樣的結(jié)果實(shí)際上是渴望的 PyTorch 程序的“凍結(jié)”版本,TorchScript 編譯器可以對(duì)其進(jìn)行進(jìn)一步的分析,優(yōu)化和序列化:
>>> inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(4, 5)]
>>> trace = torch.jit.trace(compute, inputs)
>>> print(trace.graph)
graph(%x : Float(4, 8)
%y : Float(8, 5)
%z : Float(4, 5)) {
%3 : Float(4, 5) = aten::matmul(%x, %y)
%4 : Float(4, 5) = aten::relu(%z)
%5 : int = prim::Constant[value=1]()
%6 : Float(4, 5) = aten::add(%3, %4, %5)
return (%6);
}
現(xiàn)在,令人興奮的啟示是,我們可以簡(jiǎn)單地將自定義運(yùn)算符放到 PyTorch 跟蹤中,就好像它是torch.relu
或任何其他torch
函數(shù)一樣:
torch.ops.load_library("libwarp_perspective.so")
def compute(x, y, z):
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
return x.matmul(y) + torch.relu(z)
然后像以前一樣跟蹤它:
>>> inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(8, 5)]
>>> trace = torch.jit.trace(compute, inputs)
>>> print(trace.graph)
graph(%x.1 : Float(4, 8)
%y : Float(8, 5)
%z : Float(8, 5)) {
%3 : int = prim::Constant[value=3]()
%4 : int = prim::Constant[value=6]()
%5 : int = prim::Constant[value=0]()
%6 : int[] = prim::Constant[value=[0, -1]]()
%7 : Float(3, 3) = aten::eye(%3, %4, %5, %6)
%x : Float(8, 8) = my_ops::warp_perspective(%x.1, %7)
%11 : Float(8, 5) = aten::matmul(%x, %y)
%12 : Float(8, 5) = aten::relu(%z)
%13 : int = prim::Constant[value=1]()
%14 : Float(8, 5) = aten::add(%11, %12, %13)
return (%14);
}
如此簡(jiǎn)單地將 TorchScript 自定義操作集成到跟蹤的 PyTorch 代碼中!
除了跟蹤之外,獲得 PyTorch 程序的 TorchScript 表示形式的另一種方法是直接在 TorchScript 中編寫代碼。 TorchScript 在很大程度上是 Python 語(yǔ)言的子集,它具有一些限制,使 TorchScript 編譯器更容易推理程序。 您可以使用@torch.jit.script
標(biāo)記免費(fèi)功能,使用@torch.jit.script_method
標(biāo)記類中的方法(也必須從torch.jit.ScriptModule
派生),將常規(guī) PyTorch 代碼轉(zhuǎn)換為 TorchScript。 有關(guān) TorchScript 注釋的更多詳細(xì)信息,請(qǐng)參見此處的。
使用 TorchScript 而不是跟蹤的一個(gè)特殊原因是,跟蹤無(wú)法捕獲 PyTorch 代碼中的控制流。 因此,讓我們考慮使用控制流的此函數(shù):
def compute(x, y):
if bool(x[0][0] == 42):
z = 5
else:
z = 10
return x.matmul(y) + z
要將此功能從原始 PyTorch 轉(zhuǎn)換為 TorchScript,我們用@torch.jit.script
對(duì)其進(jìn)行注釋:
@torch.jit.script
def compute(x, y):
if bool(x[0][0] == 42):
z = 5
else:
z = 10
return x.matmul(y) + z
這將及時(shí)將compute
函數(shù)編譯為圖形表示形式,我們可以在compute.graph
屬性中進(jìn)行檢查:
>>> compute.graph
graph(%x : Dynamic
%y : Dynamic) {
%14 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=0]()
%7 : int = prim::Constant[value=42]()
%z.1 : int = prim::Constant[value=5]()
%z.2 : int = prim::Constant[value=10]()
%4 : Dynamic = aten::select(%x, %2, %2)
%6 : Dynamic = aten::select(%4, %2, %2)
%8 : Dynamic = aten::eq(%6, %7)
%9 : bool = prim::TensorToBool(%8)
%z : int = prim::If(%9)
block0() {
-> (%z.1)
}
block1() {
-> (%z.2)
}
%13 : Dynamic = aten::matmul(%x, %y)
%15 : Dynamic = aten::add(%13, %z, %14)
return (%15);
}
現(xiàn)在,就像以前一樣,我們可以像腳本代碼中的任何其他函數(shù)一樣使用自定義運(yùn)算符:
torch.ops.load_library("libwarp_perspective.so")
@torch.jit.script
def compute(x, y):
if bool(x[0] == 42):
z = 5
else:
z = 10
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
return x.matmul(y) + z
當(dāng) TorchScript 編譯器看到對(duì)torch.ops.my_ops.warp_perspective
的引用時(shí),它將找到我們通過 C ++中的RegisterOperators
對(duì)象注冊(cè)的實(shí)現(xiàn),并將其編譯為圖形表示形式:
>>> compute.graph
graph(%x.1 : Dynamic
%y : Dynamic) {
%20 : int = prim::Constant[value=1]()
%16 : int[] = prim::Constant[value=[0, -1]]()
%14 : int = prim::Constant[value=6]()
%2 : int = prim::Constant[value=0]()
%7 : int = prim::Constant[value=42]()
%z.1 : int = prim::Constant[value=5]()
%z.2 : int = prim::Constant[value=10]()
%13 : int = prim::Constant[value=3]()
%4 : Dynamic = aten::select(%x.1, %2, %2)
%6 : Dynamic = aten::select(%4, %2, %2)
%8 : Dynamic = aten::eq(%6, %7)
%9 : bool = prim::TensorToBool(%8)
%z : int = prim::If(%9)
block0() {
-> (%z.1)
}
block1() {
-> (%z.2)
}
%17 : Dynamic = aten::eye(%13, %14, %2, %16)
%x : Dynamic = my_ops::warp_perspective(%x.1, %17)
%19 : Dynamic = aten::matmul(%x, %y)
%21 : Dynamic = aten::add(%19, %z, %20)
return (%21);
}
請(qǐng)?zhí)貏e注意圖形末尾對(duì)my_ops::warp_perspective
的引用。
Attention
TorchScript 圖形表示仍可能更改。 不要依靠它看起來(lái)像這樣。
在 Python 中使用自定義運(yùn)算符時(shí),確實(shí)如此。 簡(jiǎn)而言之,您可以使用torch.ops.load_library
導(dǎo)入包含運(yùn)算符的庫(kù),并像其他任何torch
運(yùn)算符一樣,從跟蹤或編寫腳本的 TorchScript 代碼中調(diào)用自定義操作。
TorchScript 的一項(xiàng)有用功能是能夠?qū)⒛P托蛄谢酱疟P文件中。 該文件可以通過有線方式發(fā)送,存儲(chǔ)在文件系統(tǒng)中,或者更重要的是,可以動(dòng)態(tài)反序列化和執(zhí)行,而無(wú)需保留原始源代碼。 這在 Python 中是可能的,但在 C ++中也是可能的。 為此,PyTorch 為提供了純 C ++ API ,用于反序列化以及執(zhí)行 TorchScript 模型。 如果還沒有的話,請(qǐng)閱讀有關(guān)使用 C ++ 加載和運(yùn)行序列化 TorchScript 模型的教程,接下來(lái)的幾段將基于該教程構(gòu)建。
簡(jiǎn)而言之,即使從文件反序列化并以 C ++運(yùn)行,也可以像常規(guī)torch
運(yùn)算符一樣執(zhí)行自定義運(yùn)算符。 唯一的要求是將我們先前構(gòu)建的自定義運(yùn)算符共享庫(kù)與執(zhí)行模型的 C ++應(yīng)用程序鏈接。 在 Python 中,只需調(diào)用torch.ops.load_library
即可。 在 C ++中,您需要在使用的任何構(gòu)建系統(tǒng)中將共享庫(kù)與主應(yīng)用程序鏈接。 下面的示例將使用 CMake 展示這一點(diǎn)。
Note
從技術(shù)上講,您還可以在運(yùn)行時(shí)將共享庫(kù)動(dòng)態(tài)加載到 C ++應(yīng)用程序中,就像在 Python 中一樣。 在 Linux 上,可以使用 dlopen 來(lái)執(zhí)行此操作。 在其他平臺(tái)上也存在等效項(xiàng)。
在上面鏈接的 C ++執(zhí)行教程的基礎(chǔ)上,讓我們從一個(gè)文件中的最小 C ++應(yīng)用程序開始,該文件位于與自定義運(yùn)算符不同的文件夾中的main.cpp
,該文件加載并執(zhí)行序列化的 TorchScript 模型:
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
// Deserialize the ScriptModule from a file using torch::jit::load().
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::randn({4, 8}));
inputs.push_back(torch::randn({8, 5}));
torch::Tensor output = module->forward(std::move(inputs)).toTensor();
std::cout << output << std::endl;
}
以及一個(gè)小的CMakeLists.txt
文件:
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(example_app)
find_package(Torch REQUIRED)
add_executable(example_app main.cpp)
target_link_libraries(example_app "${TORCH_LIBRARIES}")
target_compile_features(example_app PRIVATE cxx_range_for)
在這一點(diǎn)上,我們應(yīng)該能夠構(gòu)建應(yīng)用程序:
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found torch: /libtorch/lib/libtorch.so
-- Configuring done
-- Generating done
-- Build files have been written to: /example_app/build
$ make -j
Scanning dependencies of target example_app
[ 50%] Building CXX object CMakeFiles/example_app.dir/main.cpp.o
[100%] Linking CXX executable example_app
[100%] Built target example_app
并在尚未通過模型的情況下運(yùn)行它:
$ ./example_app
usage: example_app <path-to-exported-script-module>
接下來(lái),讓我們序列化我們先前編寫的使用自定義運(yùn)算符的腳本函數(shù):
torch.ops.load_library("libwarp_perspective.so")
@torch.jit.script
def compute(x, y):
if bool(x[0][0] == 42):
z = 5
else:
z = 10
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
return x.matmul(y) + z
compute.save("example.pt")
最后一行將腳本功能序列化為一個(gè)名為“ example.pt”的文件。 如果我們隨后將此序列化模型傳遞給我們的 C ++應(yīng)用程序,則可以立即運(yùn)行它:
$ ./example_app example.pt
terminate called after throwing an instance of 'torch::jit::script::ErrorReport'
what():
Schema not found for node. File a bug report.
Node: %16 : Dynamic = my_ops::warp_perspective(%0, %19)
或者可能不是。 也許還沒有。 當(dāng)然! 我們尚未將自定義運(yùn)算符庫(kù)與我們的應(yīng)用程序鏈接。 讓我們立即執(zhí)行此操作,并正確進(jìn)行操作,讓我們稍微更新一下文件組織,如下所示:
example_app/
CMakeLists.txt
main.cpp
warp_perspective/
CMakeLists.txt
op.cpp
這將允許我們將warp_perspective
庫(kù) CMake 目標(biāo)添加為應(yīng)用目標(biāo)的子目錄。 example_app
文件夾中的頂層CMakeLists.txt
應(yīng)該如下所示:
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(example_app)
find_package(Torch REQUIRED)
add_subdirectory(warp_perspective)
add_executable(example_app main.cpp)
target_link_libraries(example_app "${TORCH_LIBRARIES}")
target_link_libraries(example_app -Wl,--no-as-needed warp_perspective)
target_compile_features(example_app PRIVATE cxx_range_for)
基本的 CMake 配置與以前非常相似,只是我們將warp_perspective
CMake 構(gòu)建添加為子目錄。 一旦其 CMake 代碼運(yùn)行,我們就將我們的example_app
應(yīng)用程序與warp_perspective
共享庫(kù)鏈接起來(lái)。
Attention
上面的示例中嵌入了一個(gè)關(guān)鍵細(xì)節(jié):warp_perspective
鏈接行的-Wl,--no-as-needed
前綴。 這是必需的,因?yàn)槲覀儗?shí)際上不會(huì)在應(yīng)用程序代碼中從warp_perspective
共享庫(kù)中調(diào)用任何函數(shù)。 我們只需要運(yùn)行全局RegisterOperators
對(duì)象的構(gòu)造函數(shù)即可。 麻煩的是,這使鏈接器感到困惑,并使其認(rèn)為可以完全跳過針對(duì)庫(kù)的鏈接。 在 Linux 上,-Wl,--no-as-needed
標(biāo)志強(qiáng)制執(zhí)行鏈接(注意:該標(biāo)志特定于 Linux!)。 還有其他解決方法。 最簡(jiǎn)單的方法是在操作員庫(kù)中定義一些函數(shù),您需要從主應(yīng)用程序中調(diào)用該函數(shù)。 這可能就像在某個(gè)標(biāo)頭中聲明的函數(shù)void init();
一樣簡(jiǎn)單,然后在運(yùn)算符庫(kù)中將其定義為void init() { }
。 在主應(yīng)用程序中調(diào)用此init()
函數(shù)會(huì)給鏈接器以印象,這是一個(gè)值得鏈接的庫(kù)。 不幸的是,這不在我們的控制范圍之內(nèi),我們寧愿讓您知道其原因和簡(jiǎn)單的解決方法,而不是讓您將一些不透明的宏放入代碼中。
現(xiàn)在,由于我們現(xiàn)在在頂層找到了Torch
軟件包,因此warp_perspective
子目錄中的CMakeLists.txt
文件可以縮短一些。 它看起來(lái)應(yīng)該像這樣:
find_package(OpenCV REQUIRED)
add_library(warp_perspective SHARED op.cpp)
target_compile_features(warp_perspective PRIVATE cxx_range_for)
target_link_libraries(warp_perspective PRIVATE "${TORCH_LIBRARIES}")
target_link_libraries(warp_perspective PRIVATE opencv_core opencv_photo)
讓我們重新構(gòu)建示例應(yīng)用程序,該應(yīng)用程序還將與自定義運(yùn)算符庫(kù)鏈接。 在頂層example_app
目錄中:
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found torch: /libtorch/lib/libtorch.so
-- Configuring done
-- Generating done
-- Build files have been written to: /warp_perspective/example_app/build
$ make -j
Scanning dependencies of target warp_perspective
[ 25%] Building CXX object warp_perspective/CMakeFiles/warp_perspective.dir/op.cpp.o
[ 50%] Linking CXX shared library libwarp_perspective.so
[ 50%] Built target warp_perspective
Scanning dependencies of target example_app
[ 75%] Building CXX object CMakeFiles/example_app.dir/main.cpp.o
[100%] Linking CXX executable example_app
[100%] Built target example_app
如果現(xiàn)在運(yùn)行example_app
二進(jìn)制文件并將其傳遞給序列化模型,我們應(yīng)該得出一個(gè)圓滿的結(jié)局:
$ ./example_app example.pt
11.4125 5.8262 9.5345 8.6111 12.3997
7.4683 13.5969 9.0850 11.0698 9.4008
7.4597 15.0926 12.5727 8.9319 9.0666
9.4834 11.1747 9.0162 10.9521 8.6269
10.0000 10.0000 10.0000 10.0000 10.0000
10.0000 10.0000 10.0000 10.0000 10.0000
10.0000 10.0000 10.0000 10.0000 10.0000
10.0000 10.0000 10.0000 10.0000 10.0000
[ Variable[CPUFloatType]{8,5} ]
成功! 您現(xiàn)在可以推斷了。
本教程向您介紹了如何在 C ++中實(shí)現(xiàn)自定義 TorchScript 運(yùn)算符,如何將其構(gòu)建到共享庫(kù)中,如何在 Python 中使用它來(lái)定義 TorchScript 模型,最后如何將其加載到 C ++應(yīng)用程序中以進(jìn)行推理工作負(fù)載。 現(xiàn)在,您可以使用與第三方 C ++庫(kù)進(jìn)行接口的 C ++運(yùn)算符擴(kuò)展 TorchScript 模型,編寫自定義的高性能 CUDA 內(nèi)核,或?qū)崿F(xiàn)任何其他需要 Python,TorchScript 和 C ++之間的界線才能平穩(wěn)融合的用例。
與往常一樣,如果您遇到任何問題或疑問,可以使用我們的論壇或 GitHub 問題進(jìn)行聯(lián)系。 另外,我們的常見問題解答(FAQ)頁(yè)面可能包含有用的信息。
“構(gòu)建自定義運(yùn)算符”一節(jié)介紹了如何使用 CMake 將自定義運(yùn)算符構(gòu)建到共享庫(kù)中。 本附錄概述了兩種進(jìn)一步的編譯方法。 他們倆都使用 Python 作為編譯過程的“驅(qū)動(dòng)程序”或“接口”。 此外,兩者都重新使用了現(xiàn)有基礎(chǔ)結(jié)構(gòu) PyTorch 提供了 C ++擴(kuò)展 ,它們是依賴于 [pybind11 用于將功能從 C ++“顯式”綁定到 Python。
第一種方法是使用 C ++擴(kuò)展程序的方便的即時(shí)(JIT)編譯界面在您首次運(yùn)行 PyTorch 腳本時(shí)在后臺(tái)編譯代碼。 第二種方法依賴于古老的setuptools
包,并涉及編寫單獨(dú)的setup.py
文件。 這樣可以進(jìn)行更高級(jí)的配置,并與其他基于setuptools
的項(xiàng)目集成。 我們將在下面詳細(xì)探討這兩種方法。
PyTorch C ++擴(kuò)展工具包提供的 JIT 編譯功能可將您的自定義運(yùn)算符的編譯直接嵌入到您的 Python 代碼中,例如 在訓(xùn)練腳本的頂部。
Note
這里的“ JIT 編譯”與 TorchScript 編譯器中用于優(yōu)化程序的 JIT 編譯無(wú)關(guān)。 這只是意味著您的自定義運(yùn)算符 C ++代碼將在您首次導(dǎo)入時(shí)在系統(tǒng) <cite>/ tmp</cite> 目錄下的文件夾中編譯,就像您自己事先對(duì)其進(jìn)行編譯一樣。
此 JIT 編譯功能有兩種形式。 首先,您仍然將操作員實(shí)現(xiàn)保存在單獨(dú)的文件(op.cpp
)中,然后使用torch.utils.cpp_extension.load()
編譯擴(kuò)展名。 通常,此函數(shù)將返回暴露您的 C ++擴(kuò)展的 Python 模塊。 但是,由于我們沒有將自定義運(yùn)算符編譯到其自己的 Python 模塊中,因此我們只想編譯一個(gè)普通的共享庫(kù)。 幸運(yùn)的是,torch.utils.cpp_extension.load()
有一個(gè)參數(shù)is_python_module
,可以將其設(shè)置為False
,以表明我們僅對(duì)構(gòu)建共享庫(kù)感興趣,而對(duì) Python 模塊不感興趣。 然后torch.utils.cpp_extension.load()
將會(huì)編譯并將共享庫(kù)也加載到當(dāng)前進(jìn)程中,就像torch.ops.load_library
之前所做的那樣:
import torch.utils.cpp_extension
torch.utils.cpp_extension.load(
name="warp_perspective",
sources=["op.cpp"],
extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
is_python_module=False,
verbose=True
)
print(torch.ops.my_ops.warp_perspective)
這應(yīng)該大致打印:
<built-in method my_ops::warp_perspective of PyCapsule object at 0x7f3e0f840b10>
JIT 編譯的第二種形式使您可以將自定義 TorchScript 運(yùn)算符的源代碼作為字符串傳遞。 為此,請(qǐng)使用torch.utils.cpp_extension.load_inline
:
import torch
import torch.utils.cpp_extension
op_source = """
#include <opencv2/opencv.hpp>
#include <torch/script.h>
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data<float>());
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data<float>());
cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{64, 64});
torch::Tensor output =
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{64, 64});
return output.clone();
}
static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
"""
torch.utils.cpp_extension.load_inline(
name="warp_perspective",
cpp_sources=op_source,
extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
is_python_module=False,
verbose=True,
)
print(torch.ops.my_ops.warp_perspective)
自然,最佳實(shí)踐是僅在源代碼相當(dāng)短的情況下才使用torch.utils.cpp_extension.load_inline
。
請(qǐng)注意,如果您在 Jupyter Notebook 中使用此功能,則不應(yīng)多次執(zhí)行單元格的注冊(cè),因?yàn)槊看螆?zhí)行都會(huì)注冊(cè)一個(gè)新庫(kù)并重新注冊(cè)自定義運(yùn)算符。 如果需要重新執(zhí)行它,請(qǐng)事先重新啟動(dòng)筆記本的 Python 內(nèi)核。
從 Python 專門構(gòu)建自定義運(yùn)算符的第二種方法是使用setuptools
。 這樣做的好處是setuptools
具有用于構(gòu)建用 C ++編寫的 Python 模塊的功能非常強(qiáng)大且廣泛的接口。 但是,由于setuptools
實(shí)際上是用于構(gòu)建 Python 模塊而不是普通的共享庫(kù)(它們沒有 Python 期望從模塊中獲得的必要入口點(diǎn)),因此這種方法可能有點(diǎn)古怪。 也就是說(shuō),您需要的是一個(gè)setup.py
文件來(lái)代替CMakeLists.txt
,該文件看起來(lái)像這樣:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name="warp_perspective",
ext_modules=[
CppExtension(
"warp_perspective",
["example_app/warp_perspective/op.cpp"],
libraries=["opencv_core", "opencv_imgproc"],
)
],
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
)
請(qǐng)注意,我們?cè)诘撞康?code>BuildExtension中啟用了no_python_abi_suffix
選項(xiàng)。 這指示setuptools
在產(chǎn)生的共享庫(kù)的名稱中省略任何特定于 Python-3 的 ABI 后綴。 否則,例如在 Python 3.7 上,該庫(kù)可能被稱為warp_perspective.cpython-37m-x86_64-linux-gnu.so
,其中cpython-37m-x86_64-linux-gnu
是 ABI 標(biāo)簽,但我們確實(shí)只是希望將其稱為warp_perspective.so
如果現(xiàn)在從setup.py
所在的文件夾中的終端中運(yùn)行python setup.py build develop
,我們應(yīng)該看到類似以下內(nèi)容:
$ python setup.py build develop
running build
running build_ext
building 'warp_perspective' extension
creating build
creating build/temp.linux-x86_64-3.7
gcc -pthread -B /root/local/miniconda/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/root/local/miniconda/lib/python3.7/site-packages/torch/lib/include -I/root/local/miniconda/lib/python3.7/site-packages/torch/lib/include/torch/csrc/api/include -I/root/local/miniconda/lib/python3.7/site-packages/torch/lib/include/TH -I/root/local/miniconda/lib/python3.7/site-packages/torch/lib/include/THC -I/root/local/miniconda/include/python3.7m -c op.cpp -o build/temp.linux-x86_64-3.7/op.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=warp_perspective -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11
cc1plus: warning: command line option '-Wstrict-prototypes' is valid for C/ObjC but not for C++
creating build/lib.linux-x86_64-3.7
g++ -pthread -shared -B /root/local/miniconda/compiler_compat -L/root/local/miniconda/lib -Wl,-rpath=/root/local/miniconda/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.7/op.o -lopencv_core -lopencv_imgproc -o build/lib.linux-x86_64-3.7/warp_perspective.so
running develop
running egg_info
creating warp_perspective.egg-info
writing warp_perspective.egg-info/PKG-INFO
writing dependency_links to warp_perspective.egg-info/dependency_links.txt
writing top-level names to warp_perspective.egg-info/top_level.txt
writing manifest file 'warp_perspective.egg-info/SOURCES.txt'
reading manifest file 'warp_perspective.egg-info/SOURCES.txt'
writing manifest file 'warp_perspective.egg-info/SOURCES.txt'
running build_ext
copying build/lib.linux-x86_64-3.7/warp_perspective.so ->
Creating /root/local/miniconda/lib/python3.7/site-packages/warp-perspective.egg-link (link to .)
Adding warp-perspective 0.0.0 to easy-install.pth file
Installed /warp_perspective
Processing dependencies for warp-perspective==0.0.0
Finished processing dependencies for warp-perspective==0.0.0
這將產(chǎn)生一個(gè)名為warp_perspective.so
的共享庫(kù),我們可以像之前那樣將其傳遞給torch.ops.load_library
,以使我們的操作員對(duì) TorchScript 可見:
>>> import torch
>>> torch.ops.load_library("warp_perspective.so")
>>> print(torch.ops.custom.warp_perspective)
<built-in method custom::warp_perspective of PyCapsule object at 0x7ff51c5b7bd0>
更多建議: