Performance improvements, new layers, ship models to other frameworks (via ONNX), CUDA9, CuDNNv7, lots of bug fixes

@soumith soumith released this 05 Dec 01:57

Table of contents

  • Breaking changes: removed reinforce()
  • New features
    • Unreduced losses
    • A profiler for the autograd engine
    • More functions support Higher order gradients
    • New features in Optimizers
    • New layers and nn functionality
    • New Tensor functions and Features
    • Other additions
  • API changes
  • Performance improvements
    • Big reduction in framework overhead (helps small models)
    • 4x to 256x faster Softmax/LogSoftmax
    • More...
  • Framework Interoperability
    • DLPack Interoperability
    • Model Exporter to ONNX (ship PyTorch to Caffe2, CoreML, CNTK, MXNet, Tensorflow)
  • Bug Fixes (a lot of them)

Breaking changes

Stochastic functions (i.e. Variable.reinforce()) were removed because of their limited functionality and broad performance implications. The motivation for stochastic functions was to avoid book-keeping of sampled values. In practice, users were still book-keeping in their code for various reasons. We constructed an alternative, equally effective API, but did not have a reasonable deprecation path to the new API. Hence this removal is a breaking change.

We introduce the torch.distributions package to replace Stochastic functions.

Your previous code typically looked like this:

probs = policy_network(state)
action = probs.multinomial()
next_state, reward = env.step(action)
action.reinforce(reward)
action.backward()

This is the new equivalent code:

probs = policy_network(state)
# NOTE: categorical is equivalent to what used to be called multinomial
m = torch.distributions.Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

New features

Unreduced losses

Some loss functions can now compute per-sample losses in a mini-batch.

  • By default, PyTorch reduces the losses over the mini-batch to a single scalar loss (the mean, or the sum if size_average=False). This was limiting to users.
  • Now, a subset of loss functions allow specifying reduce=False to return individual losses for each sample in the mini-batch
  • Example: loss = nn.CrossEntropyLoss(..., reduce=False)
  • Currently supported losses: MSELoss, NLLLoss, NLLLoss2d, KLDivLoss, CrossEntropyLoss, SmoothL1Loss, L1Loss
  • More loss functions will be covered in the next release
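
To make the difference concrete, here is a minimal pure-Python sketch of what an unreduced loss means for a squared-error loss. It does not call PyTorch: the function mse_loss below is a hypothetical stand-in, and its reduce and size_average arguments only imitate the semantics described above.

```python
def mse_loss(predictions, targets, reduce=True, size_average=True):
    """Squared error per sample, optionally reduced to a single scalar."""
    per_sample = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    if not reduce:
        # Unreduced: one loss value per sample in the mini-batch.
        return per_sample
    total = sum(per_sample)
    # Reduced: mean over the batch, or the plain sum if size_average is False.
    return total / len(per_sample) if size_average else total


predictions = [1.0, 2.0, 4.0]
targets = [1.0, 0.0, 2.0]

unreduced = mse_loss(predictions, targets, reduce=False)
reduced = mse_loss(predictions, targets)
```

Per-sample losses are useful when each sample needs its own weight or mask before the final reduction.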

An in-built Profiler in the autograd engine

We built a low-level profiler to help you identify bottlenecks in your models

Let us start with an example:

>>> x = Variable(torch.randn(1, 1), requires_grad=True)
>>> with torch.autograd.profiler.profile() as prof:
...     y = x ** 2
...     y.backward()
>>> # NOTE: some columns were removed for brevity
>>> print(prof)
--------------------------------  ----------  ---------
Name                               CPU time   CUDA time
--------------------------------  ----------  ---------
PowConstant                        142.036us    0.000us
N5torch8autograd9GraphRootE         63.524us    0.000us
PowConstantBackward                184.228us    0.000us
MulConstant                         50.288us    0.000us
PowConstant                         28.439us    0.000us
Mul                                 20.154us    0.000us
N5torch8autograd14AccumulateGradE   13.790us    0.000us
N5torch8autograd5CloneE              4.088us    0.000us

The profiler works for both CPU and CUDA models.
For CUDA models, you have to run your python program with a special nvprof prefix. For example:

nvprof --profile-from-start off -o trace_name.prof -- python <your arguments>

# in python
>>> with torch.cuda.profiler.profile():
...     model(x) # Warmup CUDA memory allocator and profiler
...     with torch.autograd.profiler.emit_nvtx():
...         model(x)

Then, you can load trace_name.prof in PyTorch and print a summary profile report.

>>> prof = torch.autograd.profiler.load_nvprof('trace_name.prof')
>>> print(prof)

Read additional documentation here

Higher order gradients

Added higher-order gradients support for the following layers

  • ConvTranspose, AvgPool1d, AvgPool2d, LPPool2d, AvgPool3d, MaxPool1d, MaxPool2d, AdaptiveMaxPool, AdaptiveAvgPool, FractionalMaxPool2d, MaxUnpool1d, MaxUnpool2d, nn.Upsample, ReplicationPad2d, ReplicationPad3d, ReflectionPad2d
  • PReLU, HardTanh, L1Loss, SoftSign, ELU, RReLU, Hardshrink, Softplus, SoftShrink, LogSigmoid, Softmin, GLU
  • MSELoss, SmoothL1Loss, KLDivLoss, HingeEmbeddingLoss, SoftMarginLoss, MarginRankingLoss, CrossEntropyLoss
  • DataParallel

Optimizers

  • optim.SparseAdam: Implements a lazy version of the Adam algorithm suitable for sparse tensors.
    • In this variant, only moments that show up in the gradient get updated, and only those portions of the gradient get applied to the parameters.
  • Optimizers now have an add_param_group function that lets you add new parameter groups to an already constructed optimizer.
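
The "lazy" update described above can be sketched in plain Python. This is an illustration under stated assumptions, not the optim.SparseAdam implementation: lazy_adam_step, the dictionary-based sparse gradient, and the state layout are all hypothetical, and the formulas are the standard Adam update applied only to the indices present in the gradient.

```python
import math


def lazy_adam_step(params, sparse_grad, state, lr=0.001,
                   beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam step, touching only indices present in sparse_grad.

    sparse_grad maps a parameter index to its gradient value (hypothetical
    stand-in for a sparse gradient tensor).
    """
    state["step"] += 1
    bias1 = 1 - beta1 ** state["step"]
    bias2 = 1 - beta2 ** state["step"]
    for i, g in sparse_grad.items():
        # Only the moments of indices that appear in the gradient are updated.
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * g
        state["v"][i] = beta2 * state["v"][i] + (1 - beta2) * g * g
        m_hat = state["m"][i] / bias1
        v_hat = state["v"][i] / bias2
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


params = [1.0, 1.0, 1.0, 1.0]
state = {"step": 0, "m": [0.0] * 4, "v": [0.0] * 4}
lazy_adam_step(params, {1: 0.5, 3: -2.0}, state, lr=0.1)
```

Entries 0 and 2 receive no gradient, so their values and moments stay untouched, which is what keeps the step cheap for large embedding tables.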

New layers and nn functionality

  • Added AdaptiveMaxPool3d and AdaptiveAvgPool3d
  • Added LPPool1d
  • F.pad now has support for:
    • reflection and replication padding (mode='reflect' and mode='replicate') on 1D, 2D, and 3D signals (so 3D, 4D, and 5D Tensors)
    • constant padding on n-d signals
  • nn.Upsample now works for 1D signals (i.e. B x C x L Tensors) in nearest and linear modes.
  • grid_sample now allows padding with the border value via padding_mode="border". grid_sample expects a grid in the range of [-1, 1], and if the values are out of these bounds, padding with the value 0.0 is applied by default. However, in a lot of cases, using the border value (i.e. the nearest valid value) helps improve accuracy of the overall model.
  • Introducing nn.utils.parameters_to_vector and nn.utils.vector_to_parameters
    • parameters_to_vector takes net.parameters() and returns a 1D vector that contains all the parameters
    • vector_to_parameters takes a vector of flattened parameters and copies the values over to a network's parameters
    • Convenient for some reinforcement learning algorithms, such as cross-entropy method, TRPO etc., which need to pull all network parameters as one big vector, modify them, and put the modified vector back.
  • Allow user to not specify certain input dimensions for AdaptivePool*d and infer them at runtime.
    • For example:
    # target output width of 7; the height is inferred from the input (None)
    m = nn.AdaptiveMaxPool2d((None, 7))
  • DataParallel container on CPU is now a no-op (instead of erroring out)
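
The flatten / modify / write-back round trip that parameters_to_vector and vector_to_parameters enable can be sketched without PyTorch. In this illustration each parameter is a plain Python list, and params_to_vector and vector_to_params are hypothetical stand-ins that only mimic the behavior described above.

```python
def params_to_vector(parameters):
    """Concatenate every parameter's values into one flat list."""
    return [value for param in parameters for value in param]


def vector_to_params(vector, parameters):
    """Copy consecutive slices of the flat vector back into each parameter."""
    expected = sum(len(param) for param in parameters)
    if len(vector) != expected:
        raise ValueError(
            "vector has %d values, parameters need %d" % (len(vector), expected)
        )
    offset = 0
    for param in parameters:
        # Slice assignment updates the existing list in place.
        param[:] = vector[offset:offset + len(param)]
        offset += len(param)


weights = [1.0, 2.0, 3.0]
bias = [0.5]

flat = params_to_vector([weights, bias])     # one big vector
modified = [2.0 * value for value in flat]   # e.g. a perturbation step
vector_to_params(modified, [weights, bias])  # written back in place
```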

New Tensor functions and features

  • Introduced torch.erf and torch.erfinv that compute the error function and the inverse error function of each element in the Tensor.
  • Added broadcasting support to bitwise operators
  • Added Tensor.put_ and torch.take, similar to numpy.put and numpy.take.
    • The take function allows you to linearly index into a tensor without viewing it as a 1D tensor first. The output has the same shape as the indices.
    • The put function copies values into a tensor, also using linear indices.
    • Differences from numpy equivalents:
      • numpy.take has an optional axis argument, which behaves like index_select. This axis argument is not yet present.
      • numpy.put repeats the values if necessary to make them as long as indices. This behavior is not yet replicated.
  • Added zeros and zeros_like for sparse Tensors.
  • 1-element Tensors can now be cast to Python scalars. For example: int(torch.Tensor([5])) works now.
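
Linear indexing, as used by take and put, can be illustrated on a nested list. The sketch below assumes row-major order and non-negative indices; take and put_ here are hypothetical pure-Python stand-ins, not the PyTorch functions.

```python
def take(matrix, indices):
    """Read elements by row-major linear index; output has the shape of indices."""
    ncols = len(matrix[0])
    return [[matrix[i // ncols][i % ncols] for i in row] for row in indices]


def put_(matrix, indices, values):
    """Write values into matrix in place at the given row-major linear indices."""
    ncols = len(matrix[0])
    for i, value in zip(indices, values):
        matrix[i // ncols][i % ncols] = value


src = [[4, 3, 5],
       [6, 7, 8]]

picked = take(src, [[0, 2], [5, 1]])  # 2 x 2 result, same shape as the indices
put_(src, [1, 4], [-1, -2])           # overwrites linear positions 1 and 4
```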

Other additions

  • Added torch.cuda.get_device_name and torch.cuda.get_device_capability that do what the names say. Example:
    >>> torch.cuda.get_device_name(0)
    'Quadro GP100'
    >>> torch.cuda.get_device_capability(0)
    (6, 0)
  • If one sets torch.backends.cudnn.deterministic = True, then the CuDNN convolutions use deterministic algorithms
  • torch.cuda.get_rng_state_all and torch.cuda.set_rng_state_all are introduced to let you save / load the state of the random number generator over all GPUs at once
  • torch.cuda.empty_cache() frees the cached memory blocks in PyTorch's caching allocator. This is useful when running long-lived IPython notebooks while sharing the GPU with other processes.

API changes

  • softmax and log_softmax now take a dim argument that specifies the dimension in which slices are taken for the softmax operation. dim allows negative dimensions as well (dim = -1 will be the last dimension)
  • torch.potrf (Cholesky decomposition) is now differentiable and defined on Variable
  • Remove all instances of device_id and replace it with device, to make things consistent
  • torch.autograd.grad now allows you to specify inputs that are unused in the autograd graph if you use allow_unused=True
    This gets useful when using torch.autograd.grad in large graphs with lists of inputs / outputs
    For example:
    x, y = Variable(...), Variable(...)
    torch.autograd.grad(x * 2, [x, y]) # errors
    torch.autograd.grad(x * 2, [x, y], allow_unused=True) # works
  • pad_packed_sequence now allows a padding_value argument that can be used instead of zero-padding
  • Dataset now has a + operator (which uses ConcatDataset). You can do something like MNIST(...) + FashionMNIST(...) for example, and you will get a concatenated dataset containing samples from both.
  • torch.distributed.recv allows Tensors to be received from any sender (hence, src is optional). recv returns the rank of the sender.
  • Added zero_() to Variable
  • Variable.shape returns the size of the Tensor (now made consistent with Tensor)
  • torch.version.cuda specifies the CUDA version that PyTorch was compiled with
  • Add a missing function random_ for CUDA.
  • torch.load and torch.save can now take a pathlib.Path object, which is a standard Python3 typed filepath object
  • If you want to load a model's state_dict into another model (for example to fine-tune a pre-trained network), load_state_dict was strict on matching the key names of the parameters. Now we provide a strict=False option to load_state_dict where it only loads in parameters where the keys match, and ignores the other parameter keys.
  • added nn.functional.embedding_bag that is equivalent to nn.EmbeddingBag
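
The effect of the dim argument on softmax can be sketched for a 2D input in plain Python. This is an illustration of the semantics only: softmax_2d and its helper are hypothetical names, and the input is a nested list rather than a Tensor.

```python
import math


def _softmax_1d(values):
    """Numerically stable softmax of a flat list."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_2d(rows, dim):
    """Softmax over slices taken along dim (0 or 1; negative values wrap)."""
    if dim < 0:
        dim += 2  # dim=-1 refers to the last dimension
    if dim == 1:
        return [_softmax_1d(row) for row in rows]
    if dim == 0:
        columns = [_softmax_1d(list(col)) for col in zip(*rows)]
        return [list(row) for row in zip(*columns)]
    raise ValueError("dim out of range for a 2D input")


scores = [[1.0, 2.0, 3.0],
          [1.0, 1.0, 1.0]]

over_last = softmax_2d(scores, dim=-1)   # each row sums to 1
over_first = softmax_2d(scores, dim=0)   # each column sums to 1
```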

Performance Improvements

  • The overhead of torch functions on Variables was around 10 microseconds. This has been brought down to ~1.5 microseconds by moving most of the core autograd formulas into C++ using our ATen library. This speeds up very small models, such as small LSTMs and other common models seen in NLP.
  • softmax and log_softmax are now 4x to 256x faster on the GPU after rewriting the gpu kernels
  • 2.5x to 3x performance improvement of the distributed AllReduce (gloo backend) by enabling GPUDirect
  • nn.Embedding's renorm option is much faster on the GPU. For embedding dimensions of 100k x 128 and a batch size of 1024, it is 33x faster.
  • All pointwise ops now use OpenMP and get multi-core CPU benefits
  • Added dedicated CUDA kernels for group convolutions where groups == nInputPlane (depthwise convolution). Speedups range from 5x to 1000x for tested layer sizes. See the benchmark table for more details as well as this table.
  • Fixed optim.SGD's memory usage for sparse gradients (for ex. nn.Embedding(..., sparse=True)), reducing the usage on a user-provided test script by 10x.
  • Optional NNPack integration for faster CPU convolutions (not part of binaries)
  • Reduce overhead of broadcasting if Tensors aren't broadcastable
  • torch.nn.utils.weight_norm over the right-most dimensions is faster
  • Backward of torch.norm is sped up by ~1.5x
  • Improve the performance of pack_padded_sequence
  • Add a single-argument version of torch.arange. For example torch.arange(10)

Framework Interoperability

DLPack Interoperability

DLPack Tensors are cross-framework Tensor formats. We now have torch.utils.dlpack.to_dlpack(x) and torch.utils.dlpack.from_dlpack(x) to convert between DLPack and torch Tensor formats. The conversion involves no memory copy and hence is very efficient.

Model exporter to ONNX

ONNX is a common model interchange format that can currently be executed in Caffe2, CoreML, CNTK, MXNet, and TensorFlow. PyTorch models that are ConvNet-like and RNN-like (static graphs) can now be shipped to the ONNX format.

  • There is a new module torch.onnx (http://pytorch.org/docs/0.3.0/onnx.html) which provides the API for exporting ONNX models.

  • The operations supported in this release are:

    • add, sub (nonzero alpha not supported), mul, div, cat, mm, addmm, neg, tanh, sigmoid, mean, t, transpose, view, split, squeeze
    • expand (only when used before a broadcasting ONNX operator; e.g., add)
    • prelu (single weight shared among input channels not supported)
    • threshold (non-zero threshold/non-zero value not supported)
    • Conv, ConvTranspose, BatchNorm, MaxPool, RNN, Dropout, ConstantPadNd, Negate
    • elu, leaky_relu, glu, softmax, log_softmax, avg_pool2d
    • unfold (experimental support with ATen-Caffe2 integration)
    • Embedding (no optional arguments supported)
    • FeatureDropout (training mode not supported)
    • Index (constant integer and tuple indices supported)

Usability Improvements

  • More cogent error messages during indexing of Tensors / Variables
  • Add proper error message for specifying dimension on a tensor with no dimensions
  • Better error messages for Conv*d input shape checking
  • More user-friendly error messages for LongTensor indexing
  • Better error messages and argument checking for Conv*d routines
  • Trying to construct a Tensor from a Variable fails more appropriately
  • If you are using a PyTorch binary with insufficient CUDA version, then a warning is printed to the user.
  • Fixed incoherent error messages in load_state_dict
  • Fix error message for type mismatches with sparse tensors

Bug fixes

torch

  • Fix CUDA lazy initialization to not trigger on calls to torch.manual_seed (instead, the calls are queued and run when CUDA is initialized)

Tensor

  • If x is 2D, x[[0, 3],] was previously needed to trigger advanced indexing. The trailing comma is no longer needed, and you can do x[[0, 3]]
  • x.sort(descending=True) used to incorrectly fail for Tensors. Fixed a bug in the argument checking logic to allow this.
  • Tensor constructors with numpy input: torch.DoubleTensor(np.array([0,1,2], dtype=np.float32))
    • torch will now copy the contents of the array in a storage of appropriate type.
    • If types match, it will share the underlying array (no-copy), with equivalent semantics to initializing a tensor with another tensor.
    • On CUDA, torch.cuda.FloatTensor(np.random.rand(10,2).astype(np.float32)) will now work by making a copy.
  • ones_like and zeros_like now create Tensors on the same device as the original Tensor
  • torch.multinomial on the CPU would reshape the input prob_dist in-place. Fixed this to make sure the prob_dist input's shape is unchanged after the call to multinomial
  • expand and expand_as allow expanding an empty Tensor to another empty Tensor
  • When [..., None, ...] was given (i.e. newaxis placement in indexing was specified), PyTorch had different behavior from NumPy. This is made consistent with NumPy in all cases.
  • Fix exponential distribution implementation to never sample infinity - cuRAND returns numbers in (0, 1]
  • torch.HalfTensor supports numpy() and torch.from_numpy
  • Add additional size checking for torch.scatter
  • fix torch.tril and torch.triu on the GPU for storage-offset Tensors (would return incorrect result).
  • Fix a memory leak in CUDA qr decomposition
  • Fix stream-awareness issues in THCUNN kernels
  • Fix kwargs parsing in torch.topk
  • Fixed random_ on CPU (which previously had a max value of 2^32) for DoubleTensor and LongTensor
  • Fix ZeroDivisionError: float division by zero when printing certain Tensors
  • torch.gels when m > n had a truncation bug on the CPU and returned incorrect results. Fixed.
  • Add a check in tensor.numpy() that checks if no positional arguments are passed
  • Before a Tensor is moved to CUDA pinned memory, added a check to ensure that it is contiguous
  • any and all now work on empty Tensors on the CPU (previously errored out)
  • Fix symeig on CUDA for large matrices. The bug is that not enough space was being allocated for the workspace, causing some undefined behavior.
  • Improved the numerical stability of torch.var and torch.std by using Welford's algorithm
  • The random number generator returned uniform samples with inconsistent bounds (due to an inconsistency in the CPU implementation and a cublas bug).
    • Now, all uniformly sampled numbers fall within the bounds [0, 1), across all types and devices
  • Fix torch.svd to not segfault on large CUDA Tensors (fixed an overflow error in the magma bindings)
  • Allows empty index Tensor for index_select (instead of erroring out)
  • Previously, when eigenvectors=False, symeig returned unspecified values for the eigenvectors. Now we zero them out.
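
Welford's algorithm, mentioned above for torch.var and torch.std, computes the variance in a single pass while avoiding the cancellation that affects the textbook sum-of-squares formula. The following pure-Python sketch shows the idea; welford_variance and naive_variance are hypothetical helper names, and this is not the PyTorch kernel.

```python
def welford_variance(values, unbiased=True):
    """Single-pass variance using Welford's running mean and M2 update."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in values:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
    denominator = count - 1 if unbiased else count
    if denominator <= 0:
        raise ValueError("not enough values to compute the variance")
    return m2 / denominator


def naive_variance(values):
    """Textbook formula; subtracts two huge numbers and may lose precision."""
    n = len(values)
    total = sum(values)
    total_sq = sum(x * x for x in values)
    return (total_sq - total * total / n) / (n - 1)


# Small spread around a large offset: a hard case for the naive formula.
shifted = [1e9 + x for x in (4.0, 7.0, 13.0, 16.0)]
stable = welford_variance(shifted)
naive = naive_variance(shifted)
```

Comparing stable and naive on inputs like shifted shows why the running-update form is preferred when values are large relative to their spread.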

sparse

  • Fix bug with 'coalesced' calculation in sparse 'cadd'
  • Fixes .type() not converting indices tensor.
  • Fixes sparse tensor coalesce on the GPU in corner cases

autograd

  • Fixed crashes when calling backward on a leaf variable with requires_grad=False
  • Fixed a bug in Variable.type() around non-default GPU input.
  • When torch.norm returned 0.0, the gradient was NaN. We now use the subgradient at 0.0, so the gradient is 0.0.
  • Fixed a correctness issue with advanced indexing and higher-order gradients
  • torch.prod's backward was failing on the GPU due to a type error, fixed.
  • Advanced Indexing on Variables now allows the index to be a LongTensor backed Variable
  • Variable.cuda() and Tensor.cuda() are consistent in kwargs options
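
The torch.norm fix above can be illustrated for the L2 norm, whose gradient x / ||x|| is a 0 / 0 expression at the origin. The sketch below is a pure-Python illustration under that assumption; l2_norm_grad is a hypothetical helper, not PyTorch's backward formula.

```python
import math


def l2_norm_grad(vector):
    """Gradient of the L2 norm, using the zero subgradient at the origin."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0.0:
        # x / norm would be 0 / 0 here; zero is a valid subgradient choice.
        return [0.0] * len(vector)
    return [x / norm for x in vector]


at_origin = l2_norm_grad([0.0, 0.0])
elsewhere = l2_norm_grad([3.0, 4.0])
```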

optim

  • torch.optim.lr_scheduler is now imported by default.

nn

  • Returning a dictionary from a nn.Module's forward function is now supported (used to throw an error)
  • When register_buffer("foo", ...) is called and self.foo already exists, a KeyError is now raised instead of silently failing
  • Fixed loading of older checkpoints of RNN/LSTM which were missing _data_ptrs attributes.
  • nn.Embedding had a hard error when using the max_norm option. This is fixed now.
  • When using the max_norm option of nn.Embedding, the passed-in indices were overwritten by the underlying implementation. This is fixed by passing a clone of the indices to the renorm kernel.
  • F.affine_grid now can take non-contiguous inputs
  • EmbeddingBag can accept both 1D and 2D inputs now.
  • Work around a CuDNN bug where batch sizes greater than 131070 fail in CuDNN BatchNorm
  • Fix nn.init.orthogonal to correctly return orthonormal vectors when rows < cols
  • If BatchNorm has only 1 value per channel in total, an error is now raised in training mode.
  • Make cuDNN bindings respect the current cuda stream (previously raised incoherent error)
  • fix grid_sample backward when gradOutput is a zero-strided Tensor
  • Fix a segmentation fault when reflection padding is out of Tensor bounds.
  • If LogSoftmax has only 1 element, -inf was returned. Now this correctly returns 0.0
  • Fix pack_padded_sequence to accept inputs of arbitrary sizes (not just 3D inputs)
  • Detect pointer aliasing in cuDNN RNN flatten_parameters and avoid that path.
  • Fixed ELU higher order gradients when applied in-place
  • Work around a CuDNN RNN bug for half-precision
  • Prevent numerical issues with poisson_nll_loss when log_input=False by adding a small epsilon

distributed and multi-gpu

  • Allow kwargs-only inputs to DataParallel. This used to fail: n = nn.DataParallel(Net()); out = n(input=i)
  • DistributedDataParallel calculates num_samples correctly in python2
  • Fix the case of DistributedDataParallel when 1-GPU per process is used.
  • Fixed DataParallel to specify GPUs that don't include GPU-0
  • DistributedDataParallel's exit doesn't error out anymore, the daemon flag is set.
  • Fix a bug in DistributedDataParallel in the case when model has no buffers (previously raised incoherent error)
  • Fix __getstate__ to be functional in DistributedDataParallel (was returning nothing)
  • Fix a deadlock in the NCCL bindings when GIL and CudaFreeMutex were starving each other

Others

  • model_zoo.load_url now first attempts to use the requests library if available, and then falls back to urllib
  • Fix error when default_collate is passed a collection of numpy.str_