Trade-off memory for compute, Windows support, 24 distributions with cdf, variance etc., dtypes, zero-dimensional Tensors, Tensor-Variable merge, faster distributed, perf and bug fixes, CuDNN 7.1

@soumith soumith released this 24 Apr 20:49

PyTorch 0.4.0 release notes

Table of Contents

  • Major Core Changes
    • Tensor / Variable merged
    • Zero-dimensional Tensors
    • dtypes
    • migration guide
  • New Features
    • Tensors
      • Full support for advanced indexing
      • Fast Fourier Transforms
    • Neural Networks
      • Trade-off memory for compute
      • bottleneck - a tool to identify hotspots in your code
    • torch.distributions
      • 24 basic probability distributions
      • Added cdf, variance, entropy, perplexity etc.
    • Distributed Training
      • Launcher utility for ease of use
      • NCCL2 backend
    • C++ Extensions
    • Windows Support
    • ONNX Improvements
      • RNN support
  • Performance improvements
  • Bug fixes

Major Core changes

Here is a summary of the updates to the most important core features users will use daily.

Major Changes and Potentially Breaking Changes:

  • Tensors and Variables have merged
  • Some operations now return 0-dimensional (scalar) Tensors
  • Deprecation of the volatile flag

Improvements:

  • dtypes, devices, and Numpy-style Tensor creation functions added
  • Support for writing device-agnostic code

We wrote a migration guide that should help you transition your code to new APIs and style. Please read it if you have code in a previous version of PyTorch that you would like to migrate.

The contents of this section (Major Core changes) are included in the migration guide.

Merging Tensor and Variable classes

torch.autograd.Variable and torch.Tensor are now the same class. More precisely, torch.Tensor is capable of tracking history and behaves like the old Variable; Variable wrapping continues to work as before but returns an object of type torch.Tensor. This means that you don't need the Variable wrapper everywhere in your code anymore.

The type() of a Tensor has changed

Note also that the type() of a Tensor no longer reflects the data type. Use isinstance() or x.type() instead:

>>> x = torch.DoubleTensor([1, 1, 1])
>>> print(type(x)) # was torch.DoubleTensor
<class 'torch.Tensor'>
>>> print(x.type())  # OK: 'torch.DoubleTensor'
'torch.DoubleTensor'
>>> print(isinstance(x, torch.DoubleTensor))  # OK: True
True

When does autograd start tracking history now?

requires_grad, the central flag for autograd, is now an attribute on Tensors. Let's see how this change manifests in code.

autograd uses the same rules previously used for Variables. It starts tracking history when any input Tensor of an operation has requires_grad=True. For example,

>>> x = torch.ones(1)  # create a tensor with requires_grad=False (default)
>>> x.requires_grad
False
>>> y = torch.ones(1)  # another tensor with requires_grad=False
>>> z = x + y
>>> # both inputs have requires_grad=False. so does the output
>>> z.requires_grad
False
>>> # then autograd won't track this computation. let's verify!
>>> z.backward()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>> # now create a tensor with requires_grad=True
>>> w = torch.ones(1, requires_grad=True)
>>> w.requires_grad
True
>>> # add to the previous result that has requires_grad=False
>>> total = w + z
>>> # the total sum now requires grad!
>>> total.requires_grad
True
>>> # autograd can compute the gradients as well
>>> total.backward()
>>> w.grad
tensor([ 1.])
>>> # and no computation is wasted to compute gradients for x, y and z, which don't require grad
>>> z.grad == x.grad == y.grad == None
True

Manipulating requires_grad flag

Other than directly setting the attribute, you can change this flag in-place using my_tensor.requires_grad_(requires_grad=True), or, as in the above example, at creation time by passing it in as an argument (default is False), e.g.,

>>> existing_tensor.requires_grad_()
>>> existing_tensor.requires_grad
True
>>> my_tensor = torch.zeros(3, 4, requires_grad=True)
>>> my_tensor.requires_grad
True

What about .data?

.data was the primary way to get the underlying Tensor from a Variable. After this merge, calling y = x.data still has similar semantics: y will be a Tensor that shares the same data with x, is unrelated to the computation history of x, and has requires_grad=False.

However, .data can be unsafe in some cases. Any changes on x.data are not tracked by autograd, and the computed gradients will be incorrect if x is needed in a backward pass. A safer alternative is x.detach(), which also returns a Tensor that shares data with x and has requires_grad=False, but whose in-place changes are reported by autograd if x is needed in backward.
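
The difference can be seen in a small sketch. The tensor names here are illustrative and not taken from the release itself. Zeroing a saved output through .detach() makes the backward pass raise an error, while the same change through .data goes unnoticed and produces wrong gradients:

```python
import torch

# Case 1: an in-place change through detach() is reported by autograd.
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = a.sigmoid()          # sigmoid saves its output for the backward pass
out.detach().zero_()       # modifies the saved output in place
try:
    out.sum().backward()
    detach_raised = False
except RuntimeError:
    detach_raised = True   # autograd detected the in-place modification

# Case 2: the same change through .data is not tracked.
b = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = b.sigmoid()
out.data.zero_()           # invisible to autograd
out.sum().backward()       # runs without error...
silent_grad = b.grad       # ...but the gradient is computed from the zeroed output
```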

Some operations now return 0-dimensional (scalar) Tensors

Previously, indexing into a Tensor vector (1-dimensional tensor) gave a Python number, but indexing into a Variable vector gave (inconsistently!) a vector of size (1,)! Similar behavior existed with reduction functions, e.g. tensor.sum() would return a Python number, but variable.sum() would return a vector of size (1,).

Fortunately, this release introduces proper scalar (0-dimensional tensor) support in PyTorch! Scalars can be created using the new torch.tensor function (which will be explained in more detail later; for now just think of it as the PyTorch equivalent of numpy.array). Now you can do things like:

>>> torch.tensor(3.1416)         # create a scalar directly
tensor(3.1416)
>>> torch.tensor(3.1416).size()  # scalar is 0-dimensional
torch.Size([])
>>> torch.tensor([3]).size()     # compare to a vector of size 1
torch.Size([1])
>>>
>>> vector = torch.arange(2, 6)  # this is a vector
>>> vector
tensor([ 2.,  3.,  4.,  5.])
>>> vector.size()
torch.Size([4])
>>> vector[3]                    # indexing into a vector gives a scalar
tensor(5.)
>>> vector[3].item()             # .item() gives the value as a Python number
5.0
>>> sum = torch.tensor([2, 3]).sum()
>>> sum
tensor(5)
>>> sum.size()
torch.Size([])

Accumulating losses

Consider the widely used pattern total_loss += loss.data[0] before 0.4.0. loss was a Variable wrapping a tensor of size (1,), but in 0.4.0 loss is now a scalar and has 0 dimensions. Indexing into a scalar doesn't make sense (it gives a warning now, but will be a hard error in 0.5.0): use loss.item() to get the Python number from a scalar.

Note that if you don't convert to a Python number when accumulating losses, you may find increased memory usage in your program. This is because the right-hand-side of the above expression used to be a Python float, while it is now a zero-dim Tensor. The total loss is thus accumulating Tensors and their gradient history, which may keep around large autograd graphs for much longer than necessary.
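
A minimal sketch of the recommended accumulation pattern follows. The model, data, and optimizer settings are placeholders chosen only to make the loop runnable:

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)                 # placeholder model
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

total_loss = 0.0
for _ in range(5):
    inputs, targets = torch.randn(4, 3), torch.randn(4, 1)  # random stand-in data
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)  # 0-dimensional Tensor
    loss.backward()
    optimizer.step()
    total_loss += loss.item()   # Python float, so no autograd history is retained
```

Writing total_loss += loss instead would make total_loss a Tensor that keeps every iteration's graph alive.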

Deprecation of volatile flag

The volatile flag is now deprecated and has no effect. Previously, any computation that involved a Variable with volatile=True was not tracked by autograd. This has now been replaced by a set of more flexible context managers including torch.no_grad(), torch.set_grad_enabled(grad_mode), and others.

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False
>>>
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

dtypes, devices and NumPy-style creation functions

In previous versions of PyTorch, we used to specify data type (e.g. float vs double), device type (cpu vs cuda) and layout (dense vs sparse) together as a "tensor type". For example, torch.cuda.sparse.DoubleTensor was the Tensor type representing the double data type, living on CUDA devices, and with COO sparse tensor layout.

In this release, we introduce torch.dtype, torch.device and torch.layout classes to allow better management of these properties via NumPy-style creation functions.

torch.dtype

Below is a complete list of available torch.dtypes (data types) and their corresponding tensor types.

Data type torch.dtype Tensor types
32-bit floating point torch.float32 or torch.float torch.*.FloatTensor
64-bit floating point torch.float64 or torch.double torch.*.DoubleTensor
16-bit floating point torch.float16 or torch.half torch.*.HalfTensor
8-bit integer (unsigned) torch.uint8 torch.*.ByteTensor
8-bit integer (signed) torch.int8 torch.*.CharTensor
16-bit integer (signed) torch.int16 or torch.short torch.*.ShortTensor
32-bit integer (signed) torch.int32 or torch.int torch.*.IntTensor
64-bit integer (signed) torch.int64 or torch.long torch.*.LongTensor

Use torch.set_default_dtype and torch.get_default_dtype to manipulate default dtype for floating point tensors.
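
As a short illustration (variable names are ours), changing the default dtype affects tensors created from floating point data, while integer data keeps its inferred type:

```python
import torch

previous = torch.get_default_dtype()      # remember the current default
torch.set_default_dtype(torch.float64)
x = torch.tensor([1.5, 2.5])              # floating point data picks up the new default
i = torch.tensor([1, 2])                  # integer data is unaffected
torch.set_default_dtype(previous)         # restore the default for later code
y = torch.tensor([1.5, 2.5])
```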

torch.device

A torch.device contains a device type ('cpu' or 'cuda') and an optional device ordinal (id) for the device type. It can be initialized with torch.device('{device_type}') or torch.device('{device_type}:{device_ordinal}').

If the device ordinal is not present, this represents the current device for the device type; e.g., torch.device('cuda') is equivalent to torch.device('cuda:X') where X is the result of torch.cuda.current_device().
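
A brief sketch of constructing and inspecting device objects. It assumes that creating a torch.device does not touch the hardware, so it should also run on a CPU-only machine:

```python
import torch

cpu = torch.device("cpu")
gpu1 = torch.device("cuda:1")        # type and ordinal in one string
same_gpu1 = torch.device("cuda", 1)  # type and ordinal given separately

print(cpu.type, cpu.index)      # the ordinal is None when it is omitted
print(gpu1.type, gpu1.index)
```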

torch.layout

torch.layout represents the data layout of a Tensor. Currently, torch.strided (dense tensors) and torch.sparse_coo (sparse tensors with COO format) are supported.

Creating Tensors

Methods that create a Tensor now also take in dtype, device, layout, and requires_grad options to specify the desired attributes on the returned Tensor. For example,

>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
>>> x
tensor([[-0.6344,  0.8562, -1.2758],
        [ 0.8414,  1.7962,  1.0589],
        [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad  # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

torch.tensor

torch.tensor is one of the newly added tensor creation methods. It takes in array-like data of all kinds and copies the contained values into a new Tensor. As mentioned earlier, torch.tensor is the PyTorch equivalent of NumPy's numpy.array constructor. Unlike the torch.*Tensor methods, you can also create zero-dimensional Tensors (aka scalars) this way (a single Python number is treated as a Size in the torch.*Tensor methods). Moreover, if a dtype argument isn't given, it will infer the suitable dtype from the data. It is the recommended way to create a tensor from existing data like a Python list. For example,

>>> cuda = torch.device("cuda")
>>> torch.tensor([[1], [2], [3]], dtype=torch.half, device=cuda)
tensor([[ 1],
        [ 2],
        [ 3]], device='cuda:0')
>>> torch.tensor(1)               # scalar
tensor(1)
>>> torch.tensor([1, 2.3]).dtype  # type inference
torch.float32
>>> torch.tensor([1, 2]).dtype    # type inference
torch.int64

We've also added more tensor creation methods. Some of them have torch.*_like and/or tensor.new_* variants.

  1. torch.*_like takes in an input Tensor instead of a shape. It returns a Tensor with same attributes as the input Tensor by default unless otherwise specified:

    >>> x = torch.randn(3, dtype=torch.float64)
    >>> torch.zeros_like(x)
    tensor([ 0.,  0.,  0.], dtype=torch.float64)
    >>> torch.zeros_like(x, dtype=torch.int)
    tensor([ 0,  0,  0], dtype=torch.int32)

  2. tensor.new_* can also create Tensors with same attributes as tensor, but it always takes in a shape argument:

    >>> x = torch.randn(3, dtype=torch.float64)
    >>> x.new_ones(2)
    tensor([ 1.,  1.], dtype=torch.float64)
    >>> x.new_ones(4, dtype=torch.int)
    tensor([ 1,  1,  1,  1], dtype=torch.int32)

To specify the desired shape, you can either use a tuple (e.g., torch.zeros((2, 3))) or variable arguments (e.g., torch.zeros(2, 3)) in most cases.

Name Returned Tensor torch.*_like variant tensor.new_* variant
torch.empty uninitialized memory
torch.zeros all zeros
torch.ones all ones
torch.full filled with a given value
torch.rand i.i.d. continuous Uniform[0, 1)
torch.randn i.i.d. Normal(0, 1)
torch.randint i.i.d. discrete Uniform in given range
torch.randperm random permutation of {0, 1, ..., n - 1}
torch.tensor copied from existing data (list, NumPy ndarray, etc.)
torch.from_numpy* from NumPy ndarray (sharing storage without copying)
torch.arange, torch.range, and torch.linspace uniformly spaced values in a given range
torch.logspace logarithmically spaced values in a given range
torch.eye identity matrix

*: torch.from_numpy only takes in a NumPy ndarray as its input argument.
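
The copy versus share distinction between torch.tensor and torch.from_numpy can be checked with a small sketch (array contents are arbitrary):

```python
import numpy as np
import torch

arr = np.array([1.0, 2.0, 3.0])
shared = torch.from_numpy(arr)   # shares storage with the ndarray
copied = torch.tensor(arr)       # copies the values into a new Tensor

arr[0] = 10.0                    # modify the ndarray in place
print(shared[0].item())          # reflects the change
print(copied[0].item())          # keeps the original value
```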

Writing device-agnostic code

Previous versions of PyTorch made it difficult to write code that was device agnostic (i.e. that could run on both CUDA-enabled and CPU-only machines without modification).

PyTorch 0.4.0 makes this easier in two ways:

  • The device attribute of a Tensor gives the torch.device for all Tensors (get_device only works for CUDA tensors)
  • The to method of Tensors and Modules can be used to easily move objects to different devices (instead of having to call cpu() or cuda() based on the context)

We recommend the following pattern:

# at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

...

# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)
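
Below is a self-contained variant of the same pattern, as a sketch. nn.Linear stands in for MyModule and the data is random; it should run unchanged on both CPU-only and CUDA machines:

```python
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2).to(device)      # stand-in for MyModule(...)
batch = torch.randn(8, 4).to(device)    # stand-in for real input data

output = model(batch)
target = torch.zeros_like(output)       # inherits dtype and device from output
loss = (output - target).pow(2).mean()
loss.backward()
```

Creating new tensors with torch.*_like or tensor.new_* keeps them on the same device as the tensor they are derived from, so no further device checks are needed.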

Tensors

Full support for Advanced indexing

PyTorch now has full support for advanced indexing, following numpy's advanced indexing rules. The following examples are now possible:

a = torch.rand(10, 10, 10, 10)

# the indexing elements can have other shapes than 1
b = a[[[3, 2]], :, [[1, 3]]]

# broadcasting also supported in the indices, as well as lists,
# negative indices, slices, ellipses, numbers
c = a[[1, -2], 2:4, :, [1]]

# can also support tensors as indices
index = torch.tensor([2, 4])
d = a[index]

# and the indices can be on the GPU
# or CPU
e = a[index.cuda()]
f = a.cuda()[index]


mask = torch.rand(10) > 0.5
# we can now index with a mask that has fewer
# dimensions than the tensor being indexed
c = a[mask, :5]
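
One way to convince yourself that the results follow NumPy's rules is to apply the same index expressions to the NumPy view of the tensor and compare. This is a sketch on CPU only:

```python
import numpy as np
import torch

a = torch.rand(10, 10, 10, 10)
a_np = a.numpy()   # NumPy view that shares memory with `a`

b = a[[[3, 2]], :, [[1, 3]]]
c = a[[1, -2], 2:4, :, [1]]
index = torch.tensor([2, 4])
d = a[index]

# Apply identical index expressions on the NumPy side and compare values and shapes.
matches = (
    np.array_equal(b.numpy(), a_np[[[3, 2]], :, [[1, 3]]])
    and np.array_equal(c.numpy(), a_np[[1, -2], 2:4, :, [1]])
    and np.array_equal(d.numpy(), a_np[index.numpy()])
)
print(matches, b.shape, c.shape, d.shape)
```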

Fast Fourier Transform

  • Add new FFT methods #5856
  • Add torch.stft (short time Fourier transform) and hann/hamming/bartlett window functions. #4095
  • Support arbitrary number of batch dimensions in *FFT #6528

New and updated Torch operators

  • Added torch.log2 and torch.log10 #6272
  • Added torch.isnan #5273
  • Add torch.reshape, which is similar to numpy.reshape. It is roughly equivalent to tensor.contiguous().view(), but avoids copying in certain cases #5575
  • Add CPU implementation of torch.unique, which outputs the unique elements of a Tensor #5503
  • Add torch.det, torch.logdet and torch.slogdet, for computing the (log-)determinant of square 2D tensors. For negative determinants, torch.logdet returns nan, while torch.slogdet returns the sign of the determinant and the log of the absolute value of the determinant. #3816 and #5393
  • Add nn.functional.gumbel_softmax, which lets you use the reparametrization trick for discrete variables #3341
  • Add torch.take and Tensor.put_. Those functions are equivalent to numpy.take and numpy.put, and are the base for full support of advanced indexing in PyTorch #3263
  • Add torch.randint, similar to numpy.random.randint #6136
  • Add torch.diagonal and torch.diagflat, similar to numpy.diagonal and numpy.diagflat. They are meant as a replacement for torch.diag, which handled both the cases of constructing a diagonal tensor as well as extracting the diagonal of a matrix #5622
  • Add torch.einsum, equivalent to numpy.einsum. einsum allows you to perform operations using Einstein's notation. #5503

    a = torch.arange(0, 9).reshape(3, 3)
    # the following transposes a
    b = torch.einsum('ij->ji', (a,))

  • Add torch.expm1, a numerically stable exp(x)-1 for small x. #4350
  • Allow users to specify individual split sizes with torch.split #3837
  • Add torch.where(condition, tensor1, tensor2) that returns a tensor of elements selected from tensor1 or tensor2 based on condition. #4259
  • Add Tensor.norm(dim) for sparse tensors. #4882
  • Implement torch.neg for all types. #4075
  • Implement gradient calculation for torch.trtrs. #3972
  • Deprecate out-of-place Tensor.resize and Tensor.resize_as. These have weird semantics and are hard to use correctly. Please use their in-place variants Tensor.resize_ and Tensor.resize_as_. #4886
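
A few of the operators listed above in action, as a small sketch with arbitrary inputs:

```python
import torch

# torch.where: keep positive entries, replace the rest with zero
x = torch.tensor([-1.0, 2.0, -3.0])
clipped = torch.where(x > 0, x, torch.zeros_like(x))

# torch.logdet vs torch.slogdet on a matrix with determinant -1
m = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
logdet = torch.logdet(m)              # nan, because the determinant is negative
sign, logabsdet = torch.slogdet(m)    # sign of det and log(|det|)

# torch.split with individual split sizes
parts = torch.split(torch.arange(10), [3, 7])
```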

Rename async argument in .cuda() to non_blocking

The async keyword argument in conversion calls is now deprecated in PyTorch, and it has been replaced by non_blocking. This was necessary because async will be a keyword in Python 3.7.
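
A minimal sketch of the renamed argument. The pinned-memory step is our addition (asynchronous host-to-device copies generally need pinned host memory), and the CUDA branch is skipped on CPU-only machines:

```python
import torch

batch = torch.randn(4, 3)
if torch.cuda.is_available():
    # previously written as batch.cuda(async=True)
    batch = batch.pin_memory().cuda(non_blocking=True)
    torch.cuda.synchronize()   # wait for the copy before using the data on the host
```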

Neural Networks

A new autograd container that lets you trade compute for memory

The new checkpoint container allows you to only store a subset of the outputs necessary for backpropagation. If an output is missing (to save memory), the checkpoint container will recompute the intermediate outputs from the closest checkpoint, so that memory usage can be reduced (with an increase in computation time).
Here is an example:

# input
input = torch.rand(1, 10)
# suppose we have a very deep model
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)
output = model(input)

The above model uses a lot of memory, because it needs to keep the intermediate values of every operation for backpropagation. checkpoint lets you reduce the memory requirements:

# create the input tensors and set the requires_grad=True
# NOTE: the requires_grad=True for the input is a current
# limitation of checkpointing. At least one of the 
# model inputs should have requires_grad=True. 
# If you don't do it, you might have empty gradients.
input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]

# define function that will define where
# we will checkpoint and store
# intermediate gradients. In this case,
# we will only store one intermediate
# gradient, in the middle of the
# model

def run_first_half(*args):
    x = args[0]
    for layer in layers[:500]:
        x = layer(x)
    return x

def run_second_half(*args):
    x = args[0]
    for layer in layers[500:-1]:
        x = layer(x)
    return x

# now uses the new checkpoint functionality
from torch.utils.checkpoint import checkpoint

x = checkpoint(run_first_half, input)
x = checkpoint(run_second_half, x)
# the last layer needs to be run without checkpoint
x = layers[-1](x)
x.sum().backward()  # works!

For sequential modules (which can have arbitrary blocks inside), a helper function checkpoint_sequential is provided, which takes care of the most common use-cases:

input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)

from torch.utils.checkpoint import checkpoint_sequential

# split in two blocks
num_segments = 2
x = checkpoint_sequential(model, num_segments, input)
x.sum().backward()  # works!

bottleneck - a tool to identify hotspots in your code

torch.utils.bottleneck (#5216, #6425) is a tool that can be used as an initial step for
debugging bottlenecks in your program. It summarizes runs of your script with
the Python profiler and PyTorch’s autograd profiler. See the bottleneck docs for more details.

reduce=False Losses

As of this release, all of our loss functions support the reduce keyword. Specifying reduce=False gives a Tensor per unit of loss instead of a single reduced loss. #4924, #5346, #5646, #4231, #4705, #5680
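
For illustration, here is the difference between the default reduced loss and reduce=False, using nn.MSELoss on made-up values. Later PyTorch releases may emit a deprecation warning for this keyword:

```python
import torch
import torch.nn as nn

prediction = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.tensor([[1.0, 0.0], [0.0, 4.0]])

reduced = nn.MSELoss()(prediction, target)                  # single averaged value
per_element = nn.MSELoss(reduce=False)(prediction, target)  # one loss per element
```

The unreduced form is useful when individual losses need to be weighted or masked before averaging.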

New modules and module improvements

  • Add DistributedDataParallelCPU. This is similar to DistributedDataParallel, but with specific support for models running on the CPU (contrary to DistributedDataParallel, which targets GPU), and supports mpi, gloo and tcp backends #5919.
  • Add Group Normalization (nn.GroupNorm), an alternative to batch normalization that doesn't suffer from the same issues as BatchNorm for small batch sizes
  • Add Layer Normalization (nn.LayerNorm), an alternative for batch normalization often used in NLP tasks. #4922
  • Add Local Response Normalization (nn.LocalResponseNorm). #4922
  • MaxPool3d now supports double backwards. MaxPool3d and MaxUnpool3d now use indices consistent with the rest of the pooling layers. #5328
  • All loss functions now support a reduce argument to return a batch of losses. #264
  • Add util to clip gradient value in torch.nn.utils.clip_grad and add param to He initialization scheme in torch.nn.init. #6173
  • Renamed torch.nn.init.* methods to have an underscore at the end, as they operate in-place, and deprecated the old versions #6093
  • Added support for returning dictionaries in DataParallel #6113
  • Added support for N-D tensors in torch.nn.Bilinear #5764
  • Add Embedding.from_pretrained factory. This lets you initialize an Embedding layer with an existing tensor, bypassing the initial random initialization of its weights.
  • You can now slice nn.Sequential, nn.ModuleList, and nn.ParameterList #4491
  • Registered nn.Module integer parameters and buffers are now immune to module.float(), module.double() and module.half() calls. #3820
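
Two of the items above, Embedding.from_pretrained and slicing of nn.Sequential, can be sketched as follows (weights and layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

# build an Embedding directly from an existing weight matrix
weights = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
embedding = nn.Embedding.from_pretrained(weights)
vector = embedding(torch.tensor([1]))    # looks up row 1 of `weights`

# slicing a Sequential gives another Sequential with the selected layers
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
head = model[:2]
```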

torch.distributions

torch.distributions has expanded to include 24 basic probability distributions: Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2, Dirichlet, Exponential, FisherSnedecor, Gamma, Geometric, Gumbel, Laplace, LogNormal, Multinomial, MultivariateNormal, Normal, OneHotCategorical, Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical, StudentT, and Uniform.

The Distribution interface has expanded to include many methods and properties, including .cdf(), .icdf(), .mean, .variance, .entropy(), and .perplexity(). Distributions now split tensor dimensions into sample_shape+batch_shape+event_shape. Most continuous distributions now also implement a differentiable .rsample() method to compute pathwise derivatives, aka the reparameterization trick (check .has_rsample for availability):

>>> from torch.autograd import grad
>>> from torch.distributions import Normal
>>> loc = torch.tensor(0., requires_grad=True)
>>> scale = torch.tensor(1., requires_grad=True)
>>> samples = Normal(loc, scale).rsample(sample_shape=(1000,))
>>> loss = (samples - 0.5).pow(4).mean()  # average over 1000 monte carlo samples
>>> grad(loss, [loc, scale])
(tensor(-7.5092), tensor(15.2704))

Most discrete distributions implement an .enumerate_support() method to make it easy to sum over all possible sample values (check .has_enumerate_support for availability).
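
A short sketch of some of these interface methods on a standard Normal and a three-outcome Categorical (the parameters are arbitrary):

```python
import torch
from torch.distributions import Categorical, Normal

normal = Normal(torch.tensor(0.0), torch.tensor(1.0))
median_cdf = normal.cdf(torch.tensor(0.0))                # CDF at the mean
roundtrip = normal.icdf(normal.cdf(torch.tensor(1.5)))    # icdf inverts cdf

categorical = Categorical(probs=torch.tensor([0.2, 0.3, 0.5]))
support = categorical.enumerate_support()                 # every possible sample value
# summing probabilities over the full support should give 1
total = categorical.log_prob(support).exp().sum()
```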

kl_divergence is defined for many pairs of distributions, e.g.

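>>> from torch.autograd import grad
>>> from torch.distributions import Normal, Uniform, kl_divergence
>>> x = torch.tensor(1.0, requires_grad=True)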
>>> kl = kl_divergence(Uniform(-x, x), Normal(0., 1.))
>>> grad(kl, [x])[0]
tensor(-0.6667)

Distribution Transforms

New distributions can be created by combining TransformedDistribution with any number of Transform objects from the torch.distributions.transforms library, including: ExpTransform, PowerTransform, SigmoidTransform, AbsTransform, AffineTransform, SoftmaxTransform, StickBreakingTransform, LowerCholeskyTransform, and their inverses via the .inv property.
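
As a sketch of how this composes, a log-normal distribution can be built from a Normal base and an ExpTransform, then compared against the built-in LogNormal (evaluation points are arbitrary):

```python
import torch
from torch.distributions import LogNormal, Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

loc, scale = torch.tensor(0.0), torch.tensor(1.0)
built = TransformedDistribution(Normal(loc, scale), [ExpTransform()])
reference = LogNormal(loc, scale)

x = torch.tensor([0.5, 1.0, 2.0])
# largest difference between the two log-densities; expected to be near zero
max_diff = (built.log_prob(x) - reference.log_prob(x)).abs().max()
```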

Distribution Constraints

Distributions provide metadata about the constraints of their .support and about their arguments (.arg_constraints). These Constraint objects are registered with transforms using transform_to() and biject_to(). Together, constraints and transforms make it easy to specify new distributions in a generic way:

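>>> from torch.distributions import Normal, TransformedDistribution
>>> from torch.distributions import constraints, transform_to
>>> scale = torch.tensor(1., requires_grad=True)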
>>> p = Normal(0., scale)
>>> assert p.arg_constraints['scale'] == constraints.positive
>>> prior = TransformedDistribution(Normal(0., 1.),
...                                 transform_to(constraints.positive))

Constraints in the torch.distributions.constraints library include: boolean, greater_than(lower_bound), integer_interval(lower_bound, upper_bound), interval(lower_bound, upper_bound), lower_cholesky, lower_triangular, nonnegative_integer, positive, positive_definite, positive_integer, real, real_vector, simplex, and unit_interval.
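
A small sketch of using a constraint together with its registered transform, for example to keep an unconstrained parameter positive (input values are arbitrary):

```python
import torch
from torch.distributions import constraints, transform_to

unconstrained = torch.tensor([-2.0, 0.0, 3.0])
to_positive = transform_to(constraints.positive)   # transform registered for the constraint

positive = to_positive(unconstrained)              # maps real values into the positive domain
satisfied = constraints.positive.check(positive)   # elementwise constraint check
recovered = to_positive.inv(positive)              # inverse maps back to the real line
```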

Distributed

Helper utility for launching Distributed Training jobs

We have added a utility function to help launch jobs on a distributed setup.
In order to launch a script that leverages DistributedDataParallel on either a single node or multiple nodes, we can make use of torch.distributed.launch as follows:

python -m torch.distributed.launch my_script.py --arg1 --arg2 --arg3

The script simplifies day to day usability of the distributed package.

You can read about its usage here: http://pytorch.org/docs/stable/distributed.html#launch-utility

A new distributed backend based on NCCL 2.0

PyTorch now has a new distributed backend, which leverages NCCL 2.0 for maximum speed.
It also provides new APIs for collective operations on multiple GPUs.
You can enable the new backend via

torch.distributed.init_process_group("nccl")

Other distributed improvements

  • Coalesce many small broadcasts to improve performance #4978
  • Add mixed-precision support for distributed training #4891
  • Release NCCL distributed backend. Previously it was marked as experimental. #4921
  • Enable Infiniband support for Gloo data channel with automatic IB device detection #4795

C++ extensions

Previously, the official way of writing extensions using C or CUDA for custom modules was through the cffi extension. The drawback of this method was that it required a separate step for compiling the CUDA kernels, which could be a bit messy.

PyTorch now provides a better system for writing your own C++ / CUDA extensions. Example implementations using this new extension support can be found in the pytorch/cpp_extensions repo.

We provide two compilation modes:

  • ahead of time compilation: you write a setup.py script using the new CppExtension or CUDAExtension, which is an extension of setuptools.Extension module;
  • just-in-time compilation: you pass the list of C++ / CUDA files that you want to compile to torch.utils.cpp_extension.load, and it will compile on the fly and cache the libraries for you.

Here is an example illustrating how easy it is to implement an extension:

In C++

// my_implementation.cpp
#include <torch/torch.h>
#include <unordered_set>

// can use templates as well. But let's keep it
// simple
using scalar_t = float;

at::Tensor unique_float(at::Tensor input_) {
  // only works for floats
  AT_ASSERT(input_.type().scalarType() == at::ScalarType::Float, "input must be a float tensor");
  // and CPU tensors
  AT_ASSERT(!input_.type().is_cuda(), "input must be a CPU tensor");
  
  // make the input contiguous, to simplify the implementation
  at::Tensor input = input_.contiguous();
  
  // get the pointer that holds the data
  scalar_t* input_data = input.data<scalar_t>();
  // let's use a function from the std library to implement
  // the unique function
  std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
  
  // create the output tensor, with size set.size()
  at::Tensor output = input.type().tensor({static_cast<int64_t>(set.size())});
  scalar_t* output_data = output.data<scalar_t>();
  // copy the content of the set to the output tensor
  std::copy(set.begin(), set.end(), output_data);
  
  return output;
}

// this defines the functions exposed to Python
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("unique_float", &unique_float, "Unique for float tensors");
}

And then in Python

import torch
from torch.utils.cpp_extension import load as load_ext
# pass the source files, they will be compiled on the fly 
# and will return a python module
_C = load_ext('my_unique_lib', sources=['my_implementation.cpp'])

# now can use the functions implemented in C++
unique = _C.unique_float

a = torch.tensor([1.0, 2.0, 1.0])
print(unique(a))
# tensor([ 2.,  1.])

Windows support

PyTorch now officially supports Windows. We provide pre-compiled Conda binaries and pip wheels for Python 3.5 and 3.6.
PyTorch on Windows doesn't support distributed training and might be a tad bit slower than Linux / OSX because Visual Studio supports an older version of OpenMP.

As always, you can use the commands at http://pytorch.org to install PyTorch on Windows.
We have an FAQ that answers most questions you might have around Windows here: http://pytorch.org/docs/stable/notes/windows.html

ONNX Improvements

New ONNX operators

  • Support export torch.max(input, dim) and torch.min(input, dim) #6220
  • Add symbolic for ReLU to support exporting to ONNX #5759
  • Add sum, prod, sqrt and improve log_softmax #4579
  • Add ONNX support for InstanceNorm #4626
  • Add ONNX symbolic for Elu #3453
  • Add ONNX symbolic for UpsamplingNearest2d #3450

Improvements

  • Print source location when ONNX export fails for a node #5652
  • Export onnx protobuf bindings to python #6651
  • Support output_padding in ConvTranspose #4583

Better RNN support

PyTorch can now export a subset of RNNs to ONNX #4409

  • Add Elman RNN export to ONNX #4613
  • Support batch-first in ONNX export of padded sequences #5360
  • Bidirectional Elman RNN export to ONNX #5120
  • Handle sequence lengths correctly when exporting RNNs to ONNX #4695
  • Support GRU export to ONNX #4390

Bugfixes

  • Fix a bug in ONNX symbolic of 3d average pooling #6101
  • Fix onnx export of replication/reflection pad #4263

Miscellaneous improvements

  • Implement __dir__ for Tensors, so that editors can auto-complete and query for the possible fields in Tensors

  • Add numpy() and from_numpy() to HalfTensor

  • Enable TensorDataset to have any number of input tensors.

  • Add padding_value to torch.nn.utils.rnn.pad_sequence

  • Add total_length option to pack_padded_sequence, which is useful when using DataParallel, as we can ensure that we have sequences of the same length.

  • Improve numerical precision of torch.arange, making it consistent with numpy.arange

  • torch.load() and torch.save() support arbitrary file-like object

  • torch.nn.functional.grid_sample now supports 2D (spatial) and 3D (volumetric) inputs

  • set python random seed in DataLoader workers, in order to improve experiment reproducibility

  • Add __delitem__ to nn.Sequential. Now one can delete arbitrary elements of a nn.Sequential.

For example:

model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
del model[1]  # deletes nn.ReLU

  • ReduceLROnPlateau is now serializable #5300

  • Add option to flush denormal numbers on CPU. #5294

  • PyTorch now exposes the gradients of conv1d, conv2d and conv3d with respect to the input and the weights #5408

  • Add support for calling pack_padded_sequence with either list or with a Tensor #5133

  • Support negative indexing for padding_idx in nn.Embedding #4496

  • Implement backward pass for pack_padded_sequence #4512

  • Add nn.utils.rnn.pad_sequence and nn.utils.rnn.pack_sequence to pad lists of variable length Tensors with 0 and to pack a list of variable length Tensors.

  • Add torch.cuda.memory_cached, torch.cuda.max_memory_cached, torch.cuda.memory_allocated, and torch.cuda.max_memory_allocated methods for checking CUDA memory usage #4511

  • Allow viewing on noncontiguous tensors if the new view size is compatible with the tensor's original size and stride. #4062

  • NLLLoss and CrossEntropyLoss now support more than 2 dimensions. #4654

  • Add an option to not show model_zoo download progress bar #4135

  • You can now assign modules to indices of nn.Sequential. #4931

  • You can create tensors with a numpy np.longlong array #4367

  • Change the autograd execution order to use good heuristics. This greatly improves memory usage for large models. #4746

  • Add AMSgrad mode to Adam and SparseAdam optimizers. #4034

  • Better torch.autograd.profiler support for CUDA profiling using the cudaEvent API. #3734

  • torch.set_num_threads also sets the respective MKL option so you won't need to use an environment variable to control it. #4949
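
The padding and packing utilities listed above can be combined roughly as follows. This is a minimal sketch, not code from the release: the sample sequences, the padding value of -1.0 and total_length=5 are illustrative choices, and total_length is passed to pad_packed_sequence when unpacking.

```python
import torch
from torch.nn.utils.rnn import (
    pack_padded_sequence,
    pad_packed_sequence,
    pad_sequence,
)

# Three variable-length sequences, sorted by decreasing length.
seqs = [
    torch.tensor([1.0, 2.0, 3.0]),
    torch.tensor([4.0, 5.0]),
    torch.tensor([6.0]),
]
lengths = [len(s) for s in seqs]

# Pad to the longest sequence with a custom padding value.
padded = pad_sequence(seqs, batch_first=True, padding_value=-1.0)

# Pack for an RNN, then unpack to a fixed time dimension of 5,
# regardless of the longest sequence in this particular batch.
packed = pack_padded_sequence(padded, lengths, batch_first=True)
unpacked, out_lengths = pad_packed_sequence(
    packed, batch_first=True, padding_value=-1.0, total_length=5
)
```

Fixing total_length gives every DataParallel replica the same output time dimension, so the per-replica outputs can be gathered without shape mismatches.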

Performance improvements

  • Speed up CPU nn.EmbeddingBag, making training overall 30% faster #5433
  • Move nn.MarginRankingLoss, nn.CosineEmbeddingLoss, nn.HingeEmbeddingLoss, and nn.TripletMarginLoss from Python to our ATen backend, resulting in up to 3x performance gains in some cases.
    #5346, #5646, #5080, #5680
  • Implement pin_memory() as a NativeFunction #4094
  • Save self.numel() for backward computation instead of self to save memory #5747
  • Rearrange dimensions for pointwise operations for up to 10x better performance in one case. #4174
  • Vectorize normal_ for a 5-6x speed up in a small case #4312
  • Allowing usage of GPU Direct within PyTorch for the Broadcast operation #4183
  • Speed-up nn.Linear for the 3D input case #5279
  • Speed up Conv3D on the CPU by parallelizing vol2col and col2vol #4824
  • Add AVX2 implementation for sigmoid function, showing around 10x speedup #5010
  • Use fast integer division algorithm to avoid division ops inside kernels. #5054
  • Improve occupancy for CUDA random number generation #5710
  • Add optimization to norm for common norms #5722
  • Add a fast fused GLU backward #5782
  • Optimize unique sorting by using std::vector+sort instead of std::set, giving up to 5x speedup. #5913
  • Speed up sum over a dimension #6026
  • Enable MKLDNN convolution forward and backward. #6062
  • Parallelize non-contiguous point-wise operations with OpenMP #2764
  • Add cudnn Tensor Core ops to RNNs for Volta #3409
  • Vectorize exp, log, sin, cos #6078
  • Reuse intermediate results over multiple backwards grad_inputs #3526
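
To illustrate the idea behind the fast integer division entry above, here is a minimal pure-Python sketch of one well-known multiply-and-shift scheme for dividing by a divisor that is fixed ahead of time. It is an assumption that #5054 uses a scheme of this kind; the names make_divider and fast_div are hypothetical and not part of PyTorch.

```python
def make_divider(d):
    """Precompute (magic, shift) so that dividing by d needs no division op."""
    if not 1 <= d < 2**31:
        raise ValueError("divisor must satisfy 1 <= d < 2**31")
    shift = (d - 1).bit_length()  # smallest s with 2**s >= d
    magic = ((2**32) * (2**shift - d)) // d + 1
    return magic, shift


def fast_div(n, magic, shift):
    """Return n // d for 0 <= n < 2**32 using a multiply, an add and shifts."""
    high = (n * magic) >> 32  # upper half of the product, as a mulhi would give
    return (high + n) >> shift


# The one-time division happens in make_divider; fast_div is the hot path.
magic, shift = make_divider(7)
quotients = [fast_div(n, magic, shift) for n in (0, 6, 7, 13, 14, 2**32 - 1)]
```

The precomputation is paid once per divisor (for example, once per tensor size or stride), while each per-element index computation only needs a multiply and shifts.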

Distributed

  • DistributedDataParallel: 10% performance improvement for the NCCL backend, with mixed-precision support #5064
  • Slightly improve DistributedDataParallel (single-GPU binding) multi-process distributed training performance #4870

Bug fixes

torch operators

  • Improve torch.digamma precision near poles #6517
  • Fix incorrect behavior of Tensor.random_ on negative inputs #6463
  • Fix undefined behavior in backward pass for tensor.permute(dims) with negative dims #5945
  • Fix integer overflow in torch.remainder operator (it would break with a divisor above 2**48) #5906
  • Fix memory leak in torch.bmm #5744
  • Make dimension checker of scatter_add_ consistent with scatter_'s #5659
  • Fix CPU torch.multinomial with noncontiguous probability tensor input (previously, it would overwrite input data) #5093
  • Fix CUDA torch.multinomial using incorrect strides and being able to select zero-probability events. #5774, #5238
  • Support empty index tensor for index_select #3429
  • Support empty indices tensor in CUDA Tensor.put_ #4486
  • Improve stability of torch.cat with empty tensors #3602, #5971, #5819
  • Fix torch.fft in the case where any of the input dimensions is not aligned #6118
  • Improve the CUDA btrifact error message #5644
  • Return zeros for eigenvector tensor when not requested in torch.symeig #3411
  • Fix torch.btrifact on tensors. #4318
  • Fix torch.pstrf on tensors. #4883
  • Fix memory leak in torch.median #6889
  • Fix SVD backward on non-square matrices when some=False #6870
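
As a quick sanity check for the torch.multinomial fixes above, the following sketch draws many samples from a distribution with zero-probability entries and collects the indices that were drawn. The probability vector and sample count are made up for illustration. It runs on the CPU so that it works without a GPU; moving probs to a CUDA device would exercise the code path the CUDA fix targets.

```python
import torch

torch.manual_seed(0)

# Indices 0 and 2 have zero probability and must never be sampled.
probs = torch.tensor([0.0, 0.5, 0.0, 0.5])
samples = torch.multinomial(probs, num_samples=1000, replacement=True)

# Distinct indices that actually appeared in the draw.
drawn = set(samples.tolist())
```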

core

  • Detect re-initialization of _C shared library that would often result in segfaults on exit #6232
  • Fix indexing with all zero ByteTensors #3926
  • Only allow dense floating-point types as the default tensor type. #5674
  • Initialize CUDA before setting CUDA tensor types as default to prevent crash #4788
  • Fix a bug where from_dlpack fails if CUDA is not initialized. #4182
  • Fix crash in creating a CUDA tensor with a numpy array #5850
  • Fix broken sharing of empty tensor in multiprocessing on some OSes #6229

autograd

  • Restore allow_unused functionality: throw error when differentiated input is unused or unreachable. #6553
  • Fix output_nr not being incremented correctly. This caused crashes in the backward pass of operations where some inputs do not require gradients. #4812
  • Fix nvprof parsing in the torch.autograd.profiler #5840
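
The allow_unused behavior described above can be sketched as follows. The tensors and the helper loss() are hypothetical, chosen only to make one input unused: torch.autograd.grad is expected to raise for the unused input by default and to return None for it when allow_unused=True.

```python
import torch

x = torch.ones(2, requires_grad=True)
unused = torch.ones(2, requires_grad=True)  # never takes part in the loss


def loss():
    # Build a fresh graph on every call so each grad() call is independent.
    return (x * 2).sum()


# Default: asking for the gradient of an unused input is an error.
try:
    torch.autograd.grad(loss(), [x, unused])
    raised = False
except RuntimeError:
    raised = True

# With allow_unused=True, the unused input simply gets None.
grad_x, grad_unused = torch.autograd.grad(loss(), [x, unused], allow_unused=True)
```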

nn layers

  • Support specifying the output size in only some dimensions for adaptive pooling #3127
  • Fix reflection padding boundary checks to not cause invalid memory access #6438
  • Improve error messages for NLLLoss. #5299, #6072
  • Fix kl_div backward on CUDA. Previously it would not respect gradOutput when computing gradInput. #5814
  • Fix incorrect bias size assert for Linear #5992
  • Fix incorrect nn.functional.convNd and nn.functional.conv_transposeNd error message #5701
  • Check that the shapes of input and target match, rather than only the number of elements, for some loss functions #5085
  • Fix torch.diag backward returning square grad with non-square input #4538
  • Fix convolution type mismatch error message #5815
  • Add align_corners option to linearly interpolating upsampling and make the default upsampling behavior more consistent with other frameworks #5927
  • Prevent numerical issues with poisson_nll_loss when log_input=False #3336
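
To make the effect of the align_corners option above concrete, this small sketch upsamples a 2x2 image with both settings using nn.Upsample. The input values are illustrative. With align_corners=True the corner pixels of input and output are aligned and the interior is spaced evenly between them; with align_corners=False pixels are treated as cell centers, so the interior values differ.

```python
import torch
import torch.nn as nn

# One image, one channel, 2x2 pixels: shape (N=1, C=1, H=2, W=2).
image = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])

aligned = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)(image)
unaligned = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)(image)

# First output row for each setting, to compare the interpolation grids.
first_row_aligned = aligned[0, 0, 0]
first_row_unaligned = unaligned[0, 0, 0]
```

Both outputs have shape (1, 1, 4, 4); only the sampling grid differs, which matters when comparing results against other frameworks.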

CUDA

  • Ensure convolution weights are contiguous to fix CUDA ConvTranspose double backward #4543
  • Fix CUDA double backwards #4460

sparse

  • Fix embedding with sparse=True #4686
  • Fix sparse embedding backward when input contains only padding_idx #6211
  • Handle copying empty sparse tensors to/from CPU, GPU. #5361

dataloader

  • Add argument checks to the torch.utils.data.Sampler classes, fixing a bug where DataLoader tries to load the entire dataset when given a non-integer batch_size. #6249
  • Set dataloader.batch_size = None when batch_sampler is given, fixing a bug where DataLoader would report batch_size as 1. #6108
  • Improve signal handling in DataLoader #4643
  • Ignore FileNotFoundError when shutting down #5380
  • Make preprocessing deterministic #4640
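
The batch_sampler behavior above can be checked with a sketch like this one. The dataset contents and the batch size of 4 are illustrative. When a batch_sampler is supplied, the sampler alone decides the batch sizes, and the loader's batch_size attribute is expected to be None.

```python
import torch
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    SequentialSampler,
    TensorDataset,
)

# Ten samples with one feature each.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))

# The batch sampler yields lists of indices: [0..3], [4..7], [8, 9].
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

# Record the size of each batch the loader produces.
batch_sizes = [batch[0].shape[0] for batch in loader]
```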

optim

  • Cast tensors when loading optimizer state dicts to improve usability #3658
  • List model parameters in deterministic order to improve stability of load_state_dict() #6031
  • Add parameter range checks for all optimizers #6000
  • Fix AMSGrad mode for SparseAdam #4314
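
A short sketch of the optimizer entries above: AMSGrad is enabled through a constructor flag, and out-of-range hyperparameters are rejected at construction time. The parameter tensor and the negative learning rate are made up for illustration.

```python
import torch

params = [torch.zeros(3, requires_grad=True)]

# AMSGrad is opt-in via a keyword argument on Adam.
optimizer = torch.optim.Adam(params, lr=1e-3, amsgrad=True)

# Range checks: a negative learning rate is rejected up front
# instead of silently producing a diverging run.
try:
    torch.optim.SGD(params, lr=-0.1)
    rejected = False
except ValueError:
    rejected = True
```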

distributed and multi-gpu

  • Fix a number of distributed training errors caused by an in-place detach error #5829
  • Don't modify requires_grad when running DataParallel in no_grad mode #5880
  • Add GPU guard for broadcast_coalesce for Distributed Data Parallel stability #5655