Torch-TensorRT v1.1.0

@narendasan narendasan released this 10 May 08:23

Support for PyTorch 1.11, Various Bug Fixes, Partial aten::Int support, New Debugging Tools, Removing Max Batch Size

Torch-TensorRT 1.1.0 targets PyTorch 1.11, CUDA 11.3, cuDNN 8.2 and TensorRT 8.2. Due to recent JetPack upgrades, this release does not support Jetson (JetPack 5.0DP or otherwise). JetPack 5.0DP support will arrive in a mid-cycle release (Torch-TensorRT 1.1.x) along with support for TensorRT 8.4. 1.1.0 also drops support for Python 3.6, which has reached end of life. Following 1.0.0, this release focuses on stabilizing and improving the core of Torch-TensorRT. Many improvements have been made to the partitioning system, addressing limitations many users hit while trying to partially compile PyTorch modules. Torch-TensorRT 1.1.0 also addresses, albeit partially, a long-standing issue with aten::Int operators. Certain common patterns that use aten::Int can now be handled by the compiler without resorting to partial compilation. Most notably, this means that models like BERT can be run end to end with Torch-TensorRT, resulting in significant performance gains.

New Debugging Tools

With this release we are introducing new syntactic sugar, in the form of context managers, that makes it easier to debug Torch-TensorRT compilation and execution. For example, in Torch-TensorRT 1.0.0 a common pattern for turning debug info on and then off again was:

import torch_tensorrt
...
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)
trt_module = torch_tensorrt.compile(my_module, ...)
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Warning)
results = trt_module(input_tensors)

With Torch-TensorRT 1.1.0, this can now be done with the following code:

import torch_tensorrt
...
with torch_tensorrt.logging.debug():
    trt_module = torch_tensorrt.compile(my_module,...)
results = trt_module(input_tensors)

You can also use this API to debug the Torch-TensorRT runtime:

import torch_tensorrt
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Error)
...
trt_module = torch_tensorrt.compile(my_module,...)
with torch_tensorrt.logging.warnings():
    results = trt_module(input_tensors)

The following levels are available:

# Only internal TensorRT failures will be logged
with torch_tensorrt.logging.internal_errors():
    ...

# Internal TensorRT failures + Torch-TensorRT errors will be logged
with torch_tensorrt.logging.errors():
    ...

# All errors plus warnings will be logged
with torch_tensorrt.logging.warnings():
    ...

# First verbosity level, information about major steps occurring during compilation and execution
with torch_tensorrt.logging.info():
    ...

# Second verbosity level, each step is logged + information about compiler state will be output
with torch_tensorrt.logging.debug():
    ...

# Third verbosity level, all above information + intermediate transformations of the graph during lowering
with torch_tensorrt.logging.graphs():
    ...
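
To illustrate the pattern behind these context managers, here is a minimal sketch that uses Python's standard logging module as a stand-in. It is not Torch-TensorRT's implementation: the logger name and the helpers reportable_level, debug and warnings defined here are hypothetical and only mirror the behavior described above, where a level is set on entry and the previous level is restored on exit, even if the body raises.

```python
import contextlib
import logging

# Stand-in logger for this sketch; Torch-TensorRT uses its own logging backend.
logger = logging.getLogger("logging_level_sketch")


@contextlib.contextmanager
def reportable_level(level):
    """Temporarily set the logger level, restoring the previous one on exit."""
    previous = logger.level
    logger.setLevel(level)
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises an exception.
        logger.setLevel(previous)


def debug():
    """Hypothetical analogue of a debug-level context manager."""
    return reportable_level(logging.DEBUG)


def warnings():
    """Hypothetical analogue of a warning-level context manager."""
    return reportable_level(logging.WARNING)


# Usage: the global level stays at ERROR outside the block.
logger.setLevel(logging.ERROR)
with debug():
    level_inside = logger.level
level_after = logger.level
```

Restoring the level in a finally clause is what lets the block replace the paired set/reset calls of the 1.0.0 pattern without leaving verbose logging enabled after a failed compilation.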

Removing Max Batch Size, Strict Types

In this release we are removing the max_batch_size and strict_types settings. These settings corresponded directly to TensorRT settings; however, they were not always respected, which often led to confusion. Since deterministic behavior could not be ensured, we thought it best to remove these settings.

Porting forward from max_batch_size, strict_types:

  • max_batch_size: The first dimension of each shape provided to Torch-TensorRT is treated as the batch dimension, so instead of setting max_batch_size you can express the batch size directly through the Input objects.
  • strict_types: A replacement with more deterministic behavior will come with an upcoming TensorRT release.
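
As a rough sketch of the max_batch_size migration, the helper below builds the min_shape, opt_shape and max_shape keyword arguments for a torch_tensorrt.Input, with the batch size as the first dimension. The helper name batch_range_shapes and the example sizes are assumptions made for illustration; they are not part of Torch-TensorRT.

```python
def batch_range_shapes(sample_shape, min_batch, opt_batch, max_batch):
    """Build shape kwargs whose first dimension is the batch size.

    Hypothetical helper: sample_shape is the shape of one sample
    without the batch dimension, e.g. (3, 224, 224).
    """
    if not (1 <= min_batch <= opt_batch <= max_batch):
        raise ValueError("expected 1 <= min_batch <= opt_batch <= max_batch")
    dims = list(sample_shape)
    return {
        "min_shape": [min_batch, *dims],
        "opt_shape": [opt_batch, *dims],
        "max_shape": [max_batch, *dims],
    }


# Example sizes are arbitrary; the largest batch takes the place of max_batch_size.
shapes = batch_range_shapes((3, 224, 224), min_batch=1, opt_batch=8, max_batch=32)

# With Torch-TensorRT installed, the kwargs can be forwarded to an Input object:
#   inputs = [torch_tensorrt.Input(**shapes, dtype=torch.float32)]
#   trt_module = torch_tensorrt.compile(my_module, inputs=inputs)
```

For a fixed batch size, passing the same value for all three arguments gives identical minimum, optimal and maximum shapes.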

Dependencies

- Bazel 5.1.1
- LibTorch 1.11.0
- CUDA 11.3 (on x86_64, by default, newer CUDA 11 supported with compatible PyTorch Build)
- cuDNN 8.2.4.15
- TensorRT 8.2.4.2

1.1.0 (2022-05-10)

Bug Fixes

  • add at::adaptive_avg_pool1d in interpolate plugin and fix #791 (deb9f74)
  • Added ipywidget dependency to notebook (0b2040a)
  • Added test case names (296e98a)
  • Added truncate_long_and_double (417c096)
  • Adding truncate_long_and_double to ptq tests (3a0640a)
  • Avoid resolving non-tensor inputs to torch segment_blocks unnecessarily (3e090ee)
  • Considering rtol and atol in threshold comparison for floating point numbers (0b0ba8d)
  • Disabled mobilenet_v2 test for DLFW CI (40c611f)
  • fix bug that python api doesn't pass truncate_long_and_double value to internal.partition_info (828336d)
  • fix bugs in aten::to (2ecd187)
  • Fix BUILD file for tests/accuracy (8b0170e)
  • Fix existing uninstallation of Torch-TRT (9ddd7a8)
  • Fix for torch scripted module failure with DLFW (88c02d9)
  • Fix fuse addmm pass (58e9ea0)
  • Fix pre_built name change in bazelrc (3ecee21)
  • fix the bug that introduces kLong Tensor in prim::NumToTensor (2c3e1d9)
  • Fix when TRT prunes away an output (9465e1d)
  • Fixed bugs and addressed review comments (588e1d1)
  • Fixed failures for host deps sessions (ec2232f)
  • Fixed typo in the path (43fab56)
  • Getting unsupported ops will now bypass non-schema ops avoiding redundant failures (d7d1511)
  • Guard test activation for CI testing (6d1a1fd)
  • Implement a patch for gelu schema change in older NGC containers (9ee3a04)
  • Missing log severity (6a4daef)
  • Preempt torch package override via timm in nox session (8964d1b)
  • refactor the resegmentation for TensorRT segments in ResolveNonTensorInput (3cc2dfb)
  • remove outdated member variables (0268da2)
  • Removed models directory dependencies (c4413e1)
  • Resolve issues in exception elimination pass (99cea1b)
  • Review comments incorporated (962660d)
  • Review comments incorporated (e9865c2)
  • support dict type for input in shape analysis (630f9c4)
  • truncate_long_and_double incur torchscript inference issues (c83aa15)
  • Typo fix for test case name (2a516b2)
  • Update "reduceAxes" variable in GlobalPoolingConverter function and add corresponding uTests (f6f5e3e)
  • //core/conversion/evaluators: Change how schemas are handled (20e5d41)
  • Update base container for dockerfile (1b3245a)
  • //core: Take user setting in the case we can't determine the (01c89d1), closes #814
  • Update test for new Exception syntax (2357099)
  • //core/conversion: Add special case for If and Loop (eacde8d)
  • //core/runtime: Support more delimiter variants (819c911)
  • //cpp/bin/torchtrtc: Fix mbs (aca175f)
  • //docsrc: Fix dependencies for docgen (806e663)
  • //notebooks: Render citrinet (12dbda1)
  • //py: Constrain the CUDA version in container builds (a21a045)
  • Use user provided dtype when we can't infer it from the graph (14650d1)

Code Refactoring

  • removing the strict_types and max_batch_size apis (b30cbd9)
  • Rename enabled precisions argument to (10957eb)
  • Removing the max-batch-size argument (03bafc5)

Features

  • //core/conversion: Better tooling for debugging (c5c5c47)
  • //core/conversion/evaluators: aten::pow support (c4fdfcb)
  • //docker: New base container to let master build in container (446bf18)
  • //py: Context managers to quickly switch logging level (12e470f)
  • Add converter files for reflection pad 1d and 2d (406d860)
  • Add converter files for torch::max (f628aca)
  • Add converter files for torch::max (569bcde)
  • Add converter files for torch::max (dd7a44e)
  • Add converter for reflection pad 1d and 2d operation (2484a43)
  • Added comprehensive perf benchmark script (a8016ff)
  • Added compute capability for Orin (af3d0ff)
  • Added env var for TOP_DIR (c26180e)
  • Added Python accuracy tests using Nox (6ae8652)
  • Enable prim::DictConstruct to fallback without conversion check error (01d98c7)
  • Handle empty schemas for unsupported ops (bf6c929)
  • Implement fast approximation of Gelu as lowering pass to improve performance (8024ea2)
  • Implement lowering for aten::to.dtype schema (4b3ae3a)
  • Implement test case for aten::to.dtype lowering (bde8ee0)
  • Perf benchmark initial draft (f2d1655)
  • replace view with reshape during lowering (d39b918)
  • Review comment incorporated (161ef3d)
  • support aten::adaptive_max_pool1d, aten::adaptive_avg_pool3d and aten::adaptive_max_pool3d operators (e554dbd)
  • support aten::div.Tensor_mode (bb3046a)
  • support aten::extend evaluator (33c523d)
  • support aten::format evaluator (3a33d33)
  • Update Pytorch version to 1.11 (c009a1f)
  • Upgrade TensorRT to 8.2.4.2 (f1f151b)
  • //tests: Adding BERT to the test suite (7996a10)
  • aten::__range_length: Adding range length evaluator (11c4608)
  • aten::add: adding string concat evaluator (65dbf90)
  • aten::Int: Adding a new pass to remove single use (46ac757)
  • aten::Int: Lowers out aten::Int (908340f)
  • core//conversion: Implement converter for torch unbind (268a49b)

BREAKING CHANGES

  • This commit removes the strict types and max_batch_size apis. We are doing this because the functionality of these APIs in TRT is convoluted and likely to be ignored during building. A replacement for strict types with actual guarantees will be added at a later date.

Signed-off-by: Dheeraj Peri [email protected]

  • This is a minor change but may cause scripts
    using torchtrtc to fail. We are renaming enabled-precisions to
    enable-precision, which makes more sense since the argument can
    be repeated.

Signed-off-by: Naren Dasan [email protected]

  • This PR removes --max-batch-size from the CLI
    as it has no real functional effect.

Signed-off-by: Naren Dasan [email protected]

Operators Supported

Operators Currently Supported Through Converters

  • aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor)
  • aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)
  • aten::abs(Tensor self) -> (Tensor)
  • aten::acos(Tensor self) -> (Tensor)
  • aten::acosh(Tensor self) -> (Tensor)
  • aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)
  • aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)
  • aten::adaptive_avg_pool3d(Tensor self, int[3] output_size) -> (Tensor)
  • aten::adaptive_max_pool1d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
  • aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
  • aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
  • aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
  • aten::asin(Tensor self) -> (Tensor)
  • aten::asinh(Tensor self) -> (Tensor)
  • aten::atan(Tensor self) -> (Tensor)
  • aten::atanh(Tensor self) -> (Tensor)
  • aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[0], bool ceil_mode=False, bool count_include_pad=True) -> (Tensor)
  • aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  • aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  • aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, Tensor? mean, Tensor? var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::bmm(Tensor self, Tensor mat2) -> (Tensor)
  • aten::cat(Tensor[] tensors, int dim=0) -> (Tensor)
  • aten::ceil(Tensor self) -> (Tensor)
  • aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)
  • aten::clamp_max(Tensor self, Scalar max) -> (Tensor)
  • aten::clamp_min(Tensor self, Scalar min) -> (Tensor)
  • aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
  • aten::cos(Tensor self) -> (Tensor)
  • aten::cosh(Tensor self) -> (Tensor)
  • aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)
  • aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::div.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> (Tensor)
  • aten::div_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))
  • aten::div_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
  • aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
  • aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)
  • aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::erf(Tensor self) -> (Tensor)
  • aten::exp(Tensor self) -> (Tensor)
  • aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))
  • aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))
  • aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)
  • aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)
  • aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
  • aten::floor(Tensor self) -> (Tensor)
  • aten::floor_divide(Tensor self, Tensor other) -> (Tensor)
  • aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor)
  • aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)
  • aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))
  • aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)
  • aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)
  • aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> (Tensor(a!))
  • aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor)
  • aten::log(Tensor self) -> (Tensor)
  • aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
  • aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)
  • aten::matmul(Tensor self, Tensor other) -> (Tensor)
  • aten::max(Tensor self) -> (Tensor)
  • aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
  • aten::max.other(Tensor self, Tensor other) -> (Tensor)
  • aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> (Tensor)
  • aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)
  • aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], int[3] dilation=[], bool ceil_mode=False) -> (Tensor)
  • aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::min(Tensor self) -> (Tensor)
  • aten::min.other(Tensor self, Tensor other) -> (Tensor)
  • aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::mul.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
  • aten::narrow(Tensor(a) self, int dim, int start, int length) -> (Tensor(a))
  • aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> (Tensor(a))
  • aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::neg(Tensor self) -> (Tensor)
  • aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)
  • aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))
  • aten::pixel_shuffle(Tensor self, int upscale_factor) -> (Tensor)
  • aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)
  • aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)
  • aten::prelu(Tensor self, Tensor weight) -> (Tensor)
  • aten::prod(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::reciprocal(Tensor self) -> (Tensor)
  • aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor)
  • aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor)
  • aten::relu(Tensor input) -> (Tensor)
  • aten::relu_(Tensor(a!) self) -> (Tensor(a!))
  • aten::repeat(Tensor self, int[] repeats) -> (Tensor)
  • aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor)
  • aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor)
  • aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor)
  • aten::reshape(Tensor self, int[] shape) -> (Tensor)
  • aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)
  • aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))
  • aten::sigmoid(Tensor input) -> (Tensor)
  • aten::sigmoid_(Tensor(a!) self) -> (Tensor(a!))
  • aten::sin(Tensor self) -> (Tensor)
  • aten::sinh(Tensor self) -> (Tensor)
  • aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> (Tensor(a))
  • aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)
  • aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])
  • aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])
  • aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])
  • aten::sqrt(Tensor self) -> (Tensor)
  • aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))
  • aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)
  • aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
  • aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::t(Tensor self) -> (Tensor)
  • aten::tan(Tensor self) -> (Tensor)
  • aten::tanh(Tensor input) -> (Tensor)
  • aten::tanh_(Tensor(a!) self) -> (Tensor(a!))
  • aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))
  • aten::to.dtype(Tensor self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)
  • aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)
  • aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(a|b))
  • aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
  • aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))
  • aten::unbind.int(Tensor(a -> *) self, int dim=0) -> (Tensor[])
  • aten::unsqueeze(Tensor(a) self, int dim) -> (Tensor(a))
  • aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> (Tensor)
  • aten::upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)
  • aten::upsample_nearest1d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_nearest3d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_trilinear3d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::view(Tensor(a) self, int[] size) -> (Tensor(a))
  • trt::const(Tensor self) -> (Tensor)

Operators Currently Supported Through Evaluators

  • aten::Bool.float(float b) -> (bool)
  • aten::Bool.int(int a) -> (bool)
  • aten::Float.Scalar(Scalar a) -> float
  • aten::Float.bool(bool a) -> float
  • aten::Float.int(int a) -> float
  • aten::Int.Scalar(Scalar a) -> int
  • aten::Int.bool(bool a) -> int
  • aten::Int.float(float a) -> int
  • aten::Int.int(int a) -> int
  • aten::and(int a, int b) -> (bool)
  • aten::and.bool(bool a, bool b) -> (bool)
  • aten::getitem.t(t list, int idx) -> (t(*))
  • aten::is(t1 self, t2 obj) -> bool
  • aten::isnot(t1 self, t2 obj) -> bool
  • aten::not(bool self) -> bool
  • aten::or(int a, int b) -> (bool)
  • aten::__range_length(int lo, int hi, int step) -> int
  • aten::__round_to_zero_floordiv(int a, int b) -> (int)
  • aten::xor(int a, int b) -> (bool)
  • aten::add.float(float a, float b) -> (float)
  • aten::add.int(int a, int b) -> (int)
  • aten::add.str(str a, str b) -> (str)
  • aten::add_.t(t self, t[] b) -> (t[])
  • aten::append.t(t self, t(c -> *) el) -> (t)
  • aten::arange(Scalar end, *, int? dtype=None, int? layout=None,
    Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None,
    Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None,
    Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::clone(Tensor self, *, int? memory_format=None) -> (Tensor)
  • aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> (Tensor(a!))
  • aten::dim(Tensor self) -> int
  • aten::div.float(float a, float b) -> (float)
  • aten::div.int(int a, int b) -> (float)
  • aten::eq.bool(bool a, bool b) -> (bool)
  • aten::eq.float(float a, float b) -> (bool)
  • aten::eq.float_int(float a, int b) -> (bool)
  • aten::eq.int(int a, int b) -> (bool)
  • aten::eq.int_float(int a, float b) -> (bool)
  • aten::eq.str(str a, str b) -> (bool)
  • aten::extend.t(t self, t[] other) -> ()
  • aten::floor.float(float a) -> (int)
  • aten::floor.int(int a) -> (int)
  • aten::floordiv.float(float a, float b) -> (int)
  • aten::floordiv.int(int a, int b) -> (int)
  • aten::format(str self, ...) -> (str)
  • aten::ge.bool(bool a, bool b) -> (bool)
  • aten::ge.float(float a, float b) -> (bool)
  • aten::ge.float_int(float a, int b) -> (bool)
  • aten::ge.int(int a, int b) -> (bool)
  • aten::ge.int_float(int a, float b) -> (bool)
  • aten::gt.bool(bool a, bool b) -> (bool)
  • aten::gt.float(float a, float b) -> (bool)
  • aten::gt.float_int(float a, int b) -> (bool)
  • aten::gt.int(int a, int b) -> (bool)
  • aten::gt.int_float(int a, float b) -> (bool)
  • aten::is_floating_point(Tensor self) -> (bool)
  • aten::le.bool(bool a, bool b) -> (bool)
  • aten::le.float(float a, float b) -> (bool)
  • aten::le.float_int(float a, int b) -> (bool)
  • aten::le.int(int a, int b) -> (bool)
  • aten::le.int_float(int a, float b) -> (bool)
  • aten::len.t(t[] a) -> (int)
  • aten::lt.bool(bool a, bool b) -> (bool)
  • aten::lt.float(float a, float b) -> (bool)
  • aten::lt.float_int(float a, int b) -> (bool)
  • aten::lt.int(int a, int b) -> (bool)
  • aten::lt.int_float(int a, float b) -> (bool)
  • aten::mul.float(float a, float b) -> (float)
  • aten::mul.int(int a, int b) -> (int)
  • aten::ne.bool(bool a, bool b) -> (bool)
  • aten::ne.float(float a, float b) -> (bool)
  • aten::ne.float_int(float a, int b) -> (bool)
  • aten::ne.int(int a, int b) -> (bool)
  • aten::ne.int_float(int a, float b) -> (bool)
  • aten::neg.int(int a) -> (int)
  • aten::numel(Tensor self) -> int
  • aten::pow.float(float a, float b) -> (float)
  • aten::pow.float_int(float a, int b) -> (float)
  • aten::pow.int(int a, int b) -> (float)
  • aten::pow.int_float(int a, float b) -> (float)
  • aten::size(Tensor self) -> (int[])
  • aten::size.int(Tensor self, int dim) -> (int)
  • aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])
  • aten::sqrt.float(float a) -> (float)
  • aten::sqrt.int(int a) -> (float)
  • aten::sub.float(float a, float b) -> (float)
  • aten::sub.int(int a, int b) -> (int)
  • aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)
  • prim::dtype(Tensor a) -> (int)
  • prim::max.bool(bool a, bool b) -> (bool)
  • prim::max.float(float a, float b) -> (bool)
  • prim::max.float_int(float a, int b) -> (bool)
  • prim::max.int(int a, int b) -> (bool)
  • prim::max.int_float(int a, float b) -> (bool)
  • prim::max.self_int(int[] self) -> (int)
  • prim::min.bool(bool a, bool b) -> (bool)
  • prim::min.float(float a, float b) -> (bool)
  • prim::min.float_int(float a, int b) -> (bool)
  • prim::min.int(int a, int b) -> (bool)
  • prim::min.int_float(int a, float b) -> (bool)
  • prim::min.self_int(int[] self) -> (int)
  • prim::shape(Tensor a) -> (int[])