# Simple Autograd

This notebook walks through a self-contained implementation of reverse mode auto-differentiation. The intention is to make it easier to understand PyTorch's implementation of auto-diff and how TorchScript interacts with it without having to work through all the complexity that the real implementation contains.


To get started, we import `torch` and a few typing helpers, and define a small helper that generates unique variable names.

In [None]:
import torch
from typing import List, NamedTuple, Callable, Dict, Optional

_name: int = 0
def fresh_name() -> str:
    """ create a new unique name for a variable: v0, v1, v2 """
    global _name
    r = f'v{_name}'
    _name += 1
    return r


To keep the system fully understandable, it does not rely on PyTorch's autograd at all. It only uses the Tensor object to do compute. We add our own Variable class to track the computation, and a `grad` function to compute gradients.


Similar to PyTorch, we use tape-based reverse mode auto-differentiation to
compute the gradient. For some scalar loss `l`, we will compute the value `dl/dX` for
_every_ value `X` computed in the program (`l` is always a scalar, but the `X`s can be tensors).
We do this by starting with `dl/dl == 1` and using the partial derivatives plus
the chain rule to propagate the values backward,
e.g. `dl/dx * dx/dy = dl/dy`.
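As a minimal illustration of that backward sweep, the sketch below applies the chain rule by hand to the scalar function `l = (a + b) * b`, the same function used as `simple` later in the notebook. It uses plain Python floats instead of tensors, and the function name `simple_forward_backward` is hypothetical; it is not part of the implementation that follows.

```python
def simple_forward_backward(a: float, b: float):
    # forward pass: compute and remember the intermediate t
    t = a + b
    l = t * b

    # backward pass: start from the seed dl/dl = 1
    dl_dl = 1.0
    # l = t * b  ->  dl/dt = b, and b's direct contribution is t
    dl_dt = dl_dl * b
    dl_db_from_mul = dl_dl * t
    # t = a + b  ->  dt/da = 1 and dt/db = 1
    dl_da = dl_dt * 1.0
    dl_db_from_add = dl_dt * 1.0
    # b is used twice, so its two contributions are summed
    dl_db = dl_db_from_mul + dl_db_from_add
    return l, dl_da, dl_db


print(simple_forward_backward(2.0, 3.0))  # (15.0, 3.0, 8.0)
```

Note that `b` receives two contributions, one for each place it is used. The rest of the notebook automates exactly this bookkeeping.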

https://sidsite.com/posts/autodiff/ might be a good place to start if you 
haven't seen reverse mode auto-diff before.

For the purpose of this example, we primarily use point-wise tensor operators like `+` to keep the partial derivatives simple.


# The Implementation

Variable is a wrapper around Tensor that tracks the compute.
Each variable has a globally unique name so that we can track the gradient
for this Variable in a dictionary. For ease of understanding,
we sometimes provide this name as an argument. Otherwise, we
generate a fresh temporary name each time.

In [None]:
class Variable:
    def __init__(self, value: torch.Tensor, name: Optional[str] = None):
        self.value = value
        self.name = name or fresh_name()

    # We need to start with some tensors whose values were not computed
    # inside the autograd. This function constructs leaf nodes. 
    @staticmethod
    def constant(value: torch.Tensor, name: Optional[str] = None) -> 'Variable':
        r = Variable(value, name)
        print(f'{r.name} = {value}')
        return r

    def __repr__(self):
        return repr(self.value)


    # This performs a pointwise multiplication of a Variable, tracking gradients
    def __mul__(self, rhs: 'Variable') -> 'Variable':
        # defined later in the notebook
        return operator_mul(self, rhs)

    def __add__(self, rhs: 'Variable') -> 'Variable':
        return operator_add(self, rhs)
            
    def sum(self, name: Optional[str]=None) -> 'Variable':
        return operator_sum(self, name)
    
    def expand(self, sizes: List[int]) -> 'Variable':
        return operator_expand(self, sizes)


We need to keep track of all the computation so we can apply the
chain rule backward. A tape entry will help us do this.

In [None]:
class TapeEntry(NamedTuple):
    # names of the inputs to the original computation
    inputs : List[str]
    # names of the outputs of the original computation
    outputs: List[str]
    # closure that applies the chain rule: maps dL/doutputs to dL/dinputs
    propagate: Callable[[List[Variable]], List[Variable]]

The `inputs` and `outputs` are the unique names of the Variables that are inputs and outputs of the _original_ computation. `propagate` is a closure that propagates the gradient of the outputs of this function to the inputs using the chain rule. It is specific to each operator. Its inputs are `dL/dOutputs`, and its outputs are `dL/dInputs`. The tape is just a list of accumulated entries recording all compute. We provide a way to reset it so we can run multiple examples.
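To make the tape concrete, here is a float-only sketch of the same idea. `ScalarTapeEntry`, `mul`, `add`, and `backward` are hypothetical names for this illustration, and for brevity it treats a missing gradient as `0.0` rather than `None` (the notebook's `grad` uses `None`, for reasons discussed in the section on zero gradients).

```python
from typing import Callable, Dict, List, NamedTuple


# Hypothetical float-only analogue of TapeEntry, used only for this sketch.
class ScalarTapeEntry(NamedTuple):
    inputs: List[str]
    outputs: List[str]
    propagate: Callable[[List[float]], List[float]]


scalar_tape: List[ScalarTapeEntry] = []
env: Dict[str, float] = {}  # name -> forward value


def mul(x: str, y: str, out: str) -> str:
    xv, yv = env[x], env[y]  # captured by the closure for the backward pass
    env[out] = xv * yv

    def propagate(dL_douts: List[float]) -> List[float]:
        (dL_dout,) = dL_douts
        return [dL_dout * yv, dL_dout * xv]  # [dL/dx, dL/dy]

    scalar_tape.append(ScalarTapeEntry([x, y], [out], propagate))
    return out


def add(x: str, y: str, out: str) -> str:
    env[out] = env[x] + env[y]

    def propagate(dL_douts: List[float]) -> List[float]:
        (dL_dout,) = dL_douts
        return [dL_dout, dL_dout]

    scalar_tape.append(ScalarTapeEntry([x, y], [out], propagate))
    return out


def backward(loss: str) -> Dict[str, float]:
    dL_d = {loss: 1.0}  # seed: dL/dL = 1
    for entry in reversed(scalar_tape):  # replay the tape from the end
        dL_douts = [dL_d.get(o, 0.0) for o in entry.outputs]
        for name, g in zip(entry.inputs, entry.propagate(dL_douts)):
            dL_d[name] = dL_d.get(name, 0.0) + g  # accumulate every use
    return dL_d


env.update(a=2.0, b=3.0)
l = mul(add("a", "b", "t"), "b", "l")  # l = (a + b) * b
grads = backward(l)
print(grads["a"], grads["b"])  # 3.0 8.0
```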

In [None]:
gradient_tape : List[TapeEntry] = []

def reset_tape():
  gradient_tape.clear()
  global _name
  _name = 0 # reset variable names too to keep them small.


Now let's look at how an operator is defined. First we calculate the forward result and create a new Variable to represent it. Then we define the `propagate` closure, which uses the chain rule to backprop the gradient.

In [None]:
def operator_mul(self : Variable, rhs: Variable) -> Variable:
    if isinstance(rhs, float) and rhs == 1.0:
        # peephole optimization
        return self

    # define forward
    r = Variable(self.value * rhs.value)
    print(f'{r.name} = {self.name} * {rhs.name}')

    # record what the inputs and outputs of the op were
    inputs = [self.name, rhs.name]
    outputs = [r.name]

    # define backprop
    def propagate(dL_doutputs: List[Variable]):
        dL_dr, = dL_doutputs
    
        dr_dself = rhs # partial derivative of r = self*rhs
        dr_drhs = self # partial derivative of r = self*rhs

        # chain rule propagation from outputs to inputs of multiply
        dL_dself = dL_dr * dr_dself
        dL_drhs = dL_dr * dr_drhs
        dL_dinputs = [dL_dself, dL_drhs] 
        return dL_dinputs
    # finally, we record the compute we did on the tape
    gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate))
    return r

Notice how both `rhs` and `self` are captured by this closure.
Their values have to be saved for the backward pass.
PyTorch does something similar, but because PyTorch allows for
mutable tensors, it has additional logic to make sure these captured
variables are not mutated.

We'll define the other operators later. Let's look at how we can define a `grad` function that puts these pieces together. `grad` calculates the gradient of `L` with respect to `desired_results`. We first calculate the gradient of `L` with respect to _all_ computed values and then just extract `desired_results` from them. Real systems do more pruning ahead of time to make sure we are not computing unused values.


In [None]:
def grad(L, desired_results: List[Variable]) -> List[Variable]:
    # this map holds dL/dX for all values X
    dL_d : Dict[str, Variable] = {}
    # It starts by initializing the 'seed' dL/dL, which is 1
    dL_d[L.name] = Variable(torch.ones(()))
    print(f'd{L.name} ------------------------')

    # look up dL/d(entry) for each entry. If a variable is never used to compute
    # the loss, we consider its gradient None; see the note below about zeros.
    def gather_grad(entries: List[str]):
        return [dL_d[entry] if entry in dL_d else None for entry in entries]

    # propagate the gradient information backward
    for entry in reversed(gradient_tape):
        dL_doutputs = gather_grad(entry.outputs)
        if all(dL_doutput is None for dL_doutput in dL_doutputs):
            # optimize for the case where some gradient pathways are zero. See
            # The note below for more details.
            continue

        # perform chain rule propagation specific to each compute
        dL_dinputs = entry.propagate(dL_doutputs)

        # Accumulate the gradient produced for each input.
        # Each use of a variable produces some gradient dL_dinput for that
        # use. The multivariate chain rule tells us it is safe to sum
        # all the contributions together.
        for input, dL_dinput in zip(entry.inputs, dL_dinputs):
            if dL_dinput is None:
                # a None gradient is known to be zero: nothing to accumulate
                continue
            if input not in dL_d:
                dL_d[input] = dL_dinput
            else:
                dL_d[input] += dL_dinput

    # print some information to understand the values of each intermediate
    for name, value in dL_d.items():
        print(f'd{L.name}_d{name} = {value.name}')
    print(f'------------------------')

    return gather_grad(desired.name for desired in desired_results)


# Some more operators

We'll use these in our examples. Their implementation is very similar to `operator_mul`.

In [None]:
def operator_add(self : Variable, rhs: Variable) -> Variable:
    # Add follows a similar pattern to Mul, but it doesn't end up
    # capturing any variables.
    r = Variable(self.value + rhs.value)
    print(f'{r.name} = {self.name} + {rhs.name}')
    def propagate(dL_doutputs: List[Variable]):
        dL_dr, = dL_doutputs
        dr_dself = 1.0
        dr_drhs = 1.0
        dL_dself = dL_dr * dr_dself
        dL_drhs = dL_dr * dr_drhs
        return [dL_dself, dL_drhs]
    gradient_tape.append(TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate))
    return r

# sum is used to turn our matrices into a single scalar to get a loss.
# expand is the backward of sum, so it is added to make sure our Variable
# is closed under differentiation. Both have rules similar to mul above.

def operator_sum(self: Variable, name: Optional[str]) -> 'Variable':
    r = Variable(torch.sum(self.value), name=name)
    print(f'{r.name} = {self.name}.sum()')
    def propagate(dL_doutputs: List[Variable]):
        dL_dr, = dL_doutputs
        size = self.value.size()
        return [dL_dr.expand(*size)]
    gradient_tape.append(TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate))
    return r


def operator_expand(self: Variable, sizes: List[int]) -> 'Variable':
    assert(self.value.dim() == 0) # only works for scalars
    r = Variable(self.value.expand(sizes))
    print(f'{r.name} = {self.name}.expand({sizes})')
    def propagate(dL_doutputs: List[Variable]):
        dL_dr, = dL_doutputs
        return [dL_dr.sum()]
    gradient_tape.append(TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate))
    return r

# Using `grad`
Let's use the implementation to calculate some gradients.

In [None]:
a_global, b_global = torch.rand(4), torch.rand(4)

def simple(a, b):
    t = a + b
    return t * b

reset_tape() # reset any compute from other cells
a = Variable.constant(a_global, name='a')
b = Variable.constant(b_global, name='b')
loss = simple(a, b)
da, db = grad(loss, [a, b])
print("da", da)
print("db", db)

a = tensor([0.0171, 0.1633, 0.5833, 0.3794])
b = tensor([0.3774, 0.6308, 0.5239, 0.1387])
v0 = a + b
v1 = v0 * b
dv1 ------------------------
v3 = v2 * b
v4 = v2 * v0
v5 = v4 + v3
dv1_dv1 = v2
dv1_dv0 = v3
dv1_db = v5
dv1_da = v3
------------------------
da tensor([0.3774, 0.6308, 0.5239, 0.1387])
db tensor([0.7719, 1.4249, 1.6311, 0.6567])


# Zero Gradients

An interesting case to look at is when the gradient is zero.

In [None]:
reset_tape()
loss = a*a
da, db = grad(loss, [a, b])
print("da", da)
print("db", db)

v0 = a * a
dv0 ------------------------
v2 = v1 * a
v3 = v1 * a
v4 = v2 + v3
dv0_dv0 = v1
dv0_da = v4
------------------------
da tensor([0.9209, 0.8121, 1.8843, 0.7893])
db None


Notice that `db` has the value `None`. Another, perhaps more mathematically appropriate, choice would be to return a 4-element tensor of zeros, because a value that does not contribute to the loss has a gradient of zero. So why do we use `None` instead? The reason is that we want to be able to quickly check that a gradient value is zero, so that we can skip `propagate` functions that involve it in `grad`:

```
if all(dL_doutput is None for dL_doutput in dL_doutputs):
    # optimize for the case where some gradient pathways are zero. See
    # The note below for more details.
    continue
```

How does this skipping optimization work? Each propagate function applies the chain rule.
In the general case, where the function has a vector of inputs and a vector
of outputs, the Jacobian `J` holds the pairwise
partial derivatives of each output with respect to each input (`doutput_j/dinput_i`) in matrix form.
The multiplication `v*J` (equivalently `J^T*v` if you treat `v` as a column vector) propagates the chain
rule backward. This is why propagate is sometimes called the
vector-Jacobian product, or `vjp` (and also why forward-mode autodiff is built on the Jacobian-vector product).

In practice, we do not construct the `J` matrix, because it often
has a lot of structure in it. For instance, for pointwise operations
it is a diagonal matrix (element `i` of the input affects only element `i` of the output). Constructing it would create `N^2` entries when only `N` of them are non-zero.

However, we know that propagate always computes a matrix product
against `J`. Because matrix multiplication is a linear operator, if `v` is 0,
then `v*J` is also 0. This is what the `if`-statement is saying: if all the
incoming gradients (`dL/doutputs`) are 0, we know the outgoing gradients (`dL/dinputs`) are 0, even without
running propagate. This property is important in autograd because we often
do extra compute that is not related to the loss, and we do not
want to waste time computing zero gradients for it.
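The sketch below illustrates both points for a pointwise multiply `r = x * y`, using plain Python lists; the helper names are hypothetical. It builds the dense `N x N` Jacobian with entries `J[j][i] = dr_j/dx_i` (diagonal, with `y` on the diagonal), checks that the dense product `v*J` equals the `N`-element product `v * y` that a `propagate` function computes directly, and checks that a zero `v` gives a zero result.

```python
from typing import List


def jacobian_of_mul(y: List[float]) -> List[List[float]]:
    # J[j][i] = d r_j / d x_i for r = x * y: y[i] on the diagonal, 0 elsewhere
    n = len(y)
    return [[y[i] if i == j else 0.0 for i in range(n)] for j in range(n)]


def vjp_dense(v: List[float], J: List[List[float]]) -> List[float]:
    # row vector times matrix: (v*J)_i = sum_j v[j] * J[j][i]  (N^2 work)
    n_in = len(J[0])
    return [sum(v[j] * J[j][i] for j in range(len(v))) for i in range(n_in)]


def vjp_pointwise(v: List[float], y: List[float]) -> List[float]:
    # what propagate computes instead: N multiplies, no N x N matrix
    return [vj * yj for vj, yj in zip(v, y)]


y = [0.5, -2.0, 3.0]
v = [1.0, 4.0, 0.25]
print(vjp_dense(v, jacobian_of_mul(y)))  # [0.5, -8.0, 0.75]
print(vjp_pointwise(v, y))               # [0.5, -8.0, 0.75]

# linearity: a zero incoming gradient always yields a zero outgoing gradient
print(vjp_dense([0.0, 0.0, 0.0], jacobian_of_mul(y)))  # [0.0, 0.0, 0.0]
```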

Checking for zero would be expensive if we had to verify that every element of a tensor was zero. So `grad` uses `None` to represent a value _known_ to be full of zeros, which makes the check constant time. PyTorch's autograd does the exact same check. For historical reasons it uses undefined tensors (`at::Tensor()`) in C++ to represent these known-to-be-zero tensors. This has implications for how we generate gradients for aggregate operators, as we will see later. When working with the PyTorch autograd, keep in mind that undefined tensors are always used to represent these known-to-be-zero values.
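Here is a small float-only sketch of that skipping logic, mirroring the loop in `grad`. The `Entry` tuple, the `make_scale` op, and `grad_with_skip` are hypothetical names for this illustration. One entry writes a value that never reaches the loss, so its incoming gradient stays `None` and its `propagate` is never called.

```python
from typing import Callable, Dict, List, NamedTuple, Optional

Grads = List[Optional[float]]


class Entry(NamedTuple):
    inputs: List[str]
    outputs: List[str]
    propagate: Callable[[Grads], Grads]


calls: List[str] = []  # records which propagate functions actually ran


def make_scale(label: str, k: float) -> Callable[[Grads], Grads]:
    # hypothetical op out = k * x; its vjp is dL/dx = k * dL/dout
    def propagate(dL_douts: Grads) -> Grads:
        calls.append(label)
        (dL_dout,) = dL_douts
        return [k * dL_dout]
    return propagate


demo_tape = [
    Entry(["x"], ["unused"], make_scale("unused", 2.0)),  # never reaches the loss
    Entry(["x"], ["loss"], make_scale("loss", 3.0)),
]


def grad_with_skip(loss: str, wanted: List[str]) -> Grads:
    dL_d: Dict[str, float] = {loss: 1.0}
    for entry in reversed(demo_tape):
        dL_douts = [dL_d.get(name) for name in entry.outputs]  # None == known zero
        if all(g is None for g in dL_douts):
            continue  # v*J is zero when v is zero, so propagate is skipped
        for name, g in zip(entry.inputs, entry.propagate(dL_douts)):
            if g is not None:
                dL_d[name] = dL_d[name] + g if name in dL_d else g
    return [dL_d.get(name) for name in wanted]


print(grad_with_skip("loss", ["x", "unused"]), calls)  # [3.0, None] ['loss']
```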

# Gradients of Gradients

Notice how the definition of `propagate` works on `Variable`s, not `Tensor`s. This is so that we can calculate the gradient of some other gradient. Just think of the first gradient computation like any other compute you can do: there is no reason why you can't take a gradient of that compute as well. As a concrete example, let's look at this code:



In [None]:
def run_gradients(my_fn, second_loss=True):
    reset_tape()
    a = Variable.constant(a_global, name='a')
    b = Variable.constant(b_global, name='b')

    # our first loss
    L0 = (my_fn(a, b)).sum(name='L0')

    # compute derivatives of our inputs
    dL0_da, dL0_db = grad(L0, [a, b])
    if not second_loss:
      return dL0_da, dL0_db

    # now let's compute the squared L2 norm of our derivatives
    L1 = (dL0_da*dL0_da + dL0_db*dL0_db).sum(name='L1')

    # and take the gradient of that.
    # notice there are two losses involved.
    dL1_da, dL1_db = grad(L1, [a, b])
    return dL1_da, dL1_db

da, db = run_gradients(simple)
print("da", da)
print("db", db)

a = tensor([0.4605, 0.4061, 0.9422, 0.3946])
b = tensor([0.0850, 0.3296, 0.9888, 0.6494])
v0 = a + b
v1 = v0 * b
L0 = v1.sum()
dL0 ------------------------
v3 = v2.expand(4)
v4 = v3 * b
v5 = v3 * v0
v6 = v5 + v4
dL0_dL0 = v2
dL0_dv1 = v3
dL0_dv0 = v4
dL0_db = v6
dL0_da = v4
------------------------
v7 = v4 * v4
v8 = v6 * v6
v9 = v7 + v8
L1 = v9.sum()
dL1 ------------------------
v11 = v10.expand(4)
v12 = v11 * v6
v13 = v11 * v6
v14 = v12 + v13
v15 = v11 * v4
v16 = v11 * v4
v17 = v15 + v16
v18 = v17 + v14
v19 = v14 * v0
v20 = v14 * v3
v21 = v18 * b
v22 = v18 * v3
v23 = v19 + v21
v24 = v23.sum()
v25 = v22 + v20
dL1_dL1 = v10
dL1_dv9 = v11
dL1_dv7 = v11
dL1_dv8 = v11
dL1_dv6 = v14
dL1_dv4 = v18
dL1_dv5 = v14
dL1_dv3 = v23
dL1_dv0 = v20
dL1_db = v25
dL1_dv2 = v24
dL1_da = v20
------------------------
da tensor([1.2611, 2.1304, 5.8394, 3.3869])
db tensor([ 2.6923,  4.9201, 13.6563,  8.0727])


Notice how the `gradient_tape` just keeps accumulating more entries as we run `grad` twice. This is because in the second call to `grad` we still have to consider all the pathways through which the gradient flows from `L1` back to the inputs `a` and `b`. One implication is that the entries that ran in the first call to `grad` actually run _again_ in the second call to `grad`. In practice this means that if you append a `propagate` function to the tape in a gradient-of-gradient scenario, you should expect it to run multiple times! If a single gradient compute is "forward, backward", then a gradient-of-gradient compute can be thought of as "forward-part-0, backward-part-0, forward-part-1, backward-part-1, backward-part-0 (again)".
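The numbers printed by `run_gradients(simple)` can be checked by hand. For `simple`, `L0 = sum((a + b) * b)`, so `dL0/da = b` and `dL0/db = a + 2b`. Then `L1 = sum(b^2 + (a + 2b)^2)`, which gives `dL1/da = 2(a + 2b)` and `dL1/db = 2b + 4(a + 2b) = 4a + 10b`. The short check below plugs in the (rounded) values of `a` and `b` printed in that trace and compares against the printed `da` and `db`; the tolerance allows for the 4-digit rounding of the printout.

```python
a_vals = [0.4605, 0.4061, 0.9422, 0.3946]
b_vals = [0.0850, 0.3296, 0.9888, 0.6494]
printed_da = [1.2611, 2.1304, 5.8394, 3.3869]
printed_db = [2.6923, 4.9201, 13.6563, 8.0727]

# closed form, elementwise: dL1/da = 2 (a + 2b), dL1/db = 4a + 10b
expected_da = [2 * (a + 2 * b) for a, b in zip(a_vals, b_vals)]
expected_db = [4 * a + 10 * b for a, b in zip(a_vals, b_vals)]

for got, want in zip(printed_da + printed_db, expected_da + expected_db):
    assert abs(got - want) < 1e-3, (got, want)
print("second-order gradients match the closed form")
```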

Issues with how autograd functions behave often _only_ appear when considering higher order gradients so it is important to test changes on these cases. We'll see an example later.

# Rules of thumb for Autograd

## Every use of a Variable generates a gradient specific to that use

If you use a temporary variable `t` in two different subsequent computations, each _use_ of that value will have a gradient associated with it from the using operator. The multivariate chain rule tells us we can sum these gradients to get the overall contribution of `t`. We always have to account for all uses of a variable; if we forget about one, we will calculate the wrong value.
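A scalar sketch of this rule, with made-up values: in `l = t * t + 3 * t`, `t` is read three times (twice by the multiply and once by the scale), so `dl/dt` is the sum of three per-use contributions. This is the same accumulation that produced `v4 = v2 + v3` for `a * a` in the zero-gradients example. The dictionary keys are just descriptive labels.

```python
def f(t: float) -> float:
    return t * t + 3 * t


t = 2.0
# gradient contributed by each *use* of t (dl/d(each op's result) is 1 here)
per_use = {
    "t * t, left operand": t,   # d(x * y)/dx = y = t
    "t * t, right operand": t,  # d(x * y)/dy = x = t
    "3 * t": 3.0,               # d(3 * x)/dx = 3
}
dl_dt = sum(per_use.values())
print(dl_dt)  # 7.0, i.e. 2 * t + 3

# forgetting one use silently gives the wrong answer
print(dl_dt - per_use["3 * t"])  # 4.0

# central finite difference as an independent check
h = 1e-4
print(abs((f(t + h) - f(t - h)) / (2 * h) - dl_dt) < 1e-6)  # True
```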

## Inputs become outputs, outputs become inputs, reads become writes, writes become reads

When we record a `TapeEntry` we also record the inputs and outputs of the compute _from the perspective of the forward pass_. The inputs/outputs of the propagate function in the backward pass are _flipped_: you get `dL/doutputs` and you produce `dL/dinputs`. It is easy to get confused by names like input or output, so keep in mind what they are relative to. A corollary applies at the level of compute. Because every read of a value in a matrix produces a gradient, the backward pass computes (and writes) a value for every read in the forward pass. For instance, the `sum` operator reads an entire matrix and produces one value, so its reverse must be an operator that reads one value and writes an entire matrix. Indeed, the backward of `sum` is `expand`, which does precisely that.
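A list-based sketch of the `sum`/`expand` pair, using plain Python floats and hypothetical helper names. The backward of `sum` is an `expand` of the incoming scalar gradient, and the backward of `expand` is a `sum`. For a linear operator `f`, one standard way to confirm that a function is its vjp is the dot-product (adjoint) test: `<vjp(g), x>` must equal `<g, f(x)>`.

```python
from typing import List


def sum_op(x: List[float]) -> float:
    return sum(x)  # reads N values, writes 1


def expand_op(s: float, n: int) -> List[float]:
    return [s] * n  # reads 1 value, writes N


def sum_vjp(dL_dr: float, n: int) -> List[float]:
    return expand_op(dL_dr, n)  # backward of sum is expand


def expand_vjp(dL_dr: List[float]) -> float:
    return sum_op(dL_dr)  # backward of expand is sum


def dot(u: List[float], v: List[float]) -> float:
    return sum(ui * vi for ui, vi in zip(u, v))


# adjoint test for sum: <sum_vjp(g), x> == g * sum(x)
x, g = [1.0, 2.0, 3.0], 0.5
print(dot(sum_vjp(g, len(x)), x), g * sum_op(x))  # 3.0 3.0

# adjoint test for expand: <expand(s), y> == s * expand_vjp(y)
s, y = 2.0, [4.0, 5.0, 6.0]
print(dot(expand_op(s, len(y)), y), s * expand_vjp(y))  # 30.0 30.0
```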

## Each call to grad produces gradients for a different loss

When you call `grad(l, [a, b])` you are computing a set of gradients `dl_da`, `dl_db`. A subsequent call to `grad` will use a different loss, and potentially care about different inputs. If you abbreviate the loss, e.g. by writing `da`, make sure there aren't additional losses, or you will quickly get confused. Gradient-of-gradient, or higher-order gradient, just means that we are computing some loss that was based on the gradients of another loss. There are infinitely many computations that involve gradients; there isn't a single "grad-of-grad" compute.


# Creating Aggregate Operations

While autograd is a great way to piece together fundamental operators, sometimes you want to create aggregate operators that do not perform autograd operations internally. Fusion is one common example of this where, for instance, you may want to generate a single CUDA kernel for `(a + b)*b`. TorchScript's symbolic autograd internally can separate this compute from autograd and generate explicit forward/backward passes. Let's look at what issues arise when trying to do this. This should help with understanding the PyTorch implementation, and also make it possible to create custom aggregate operators with correct autograd implementations. Let's try to turn our `simple` function from before into one that computes its entire body as an aggregate:

In [None]:
def simple(a, b):
    t = a + b
    return t * b

def simple_type_error(a, b):
    t = (a.value + b.value)
    r = Variable(t * b.value)
    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:
        # manually apply the chain rule to the compute,
        # in practice a symbolic differentiator might create this code
        dL_dr, = dL_doutputs
        dr_dt = b # partial from: r = t * b
        dr_db = t # partial from: r = t * b
        dL_dt = dL_dr*dr_dt # chain rule
        dt_da = 1.0 # partial from t = a + b
        dt_db = 1.0 # partial from t = a + b
        dL_da = dL_dt * dt_da # chain rule
        dL_db = dL_dt * dt_db + dL_dr * dr_db # ERROR! dr_db is a Tensor not a Variable
        return [dL_da, dL_db]
    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=propagate))
    return r

da, db = run_gradients(simple_type_error)

a = tensor([0.4605, 0.4061, 0.9422, 0.3946])
b = tensor([0.0850, 0.3296, 0.9888, 0.6494])
L0 = v0.sum()
dL0 ------------------------
v2 = v1.expand(4)
v3 = v2 * b


AttributeError: 'Tensor' object has no attribute 'value'

This doesn't work because `t` is captured and used in `propagate`, but `propagate` expects to compute on Variables. Because `t` was computed outside of autograd, it cannot directly participate in the `propagate` call. One way to fix this is to recompute `t`:

In [None]:
def simple_recompute(a, b):
    t = (a.value + b.value)
    r = Variable(t * b.value)
    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:
        dL_dr, = dL_doutputs
        dr_dt = b # partial from: r = t * b
        t = a + b # RECOMPUTE!
        dr_db = t # partial from: r = t * b
        dL_dt = dL_dr*dr_dt # chain rule
        dt_da = 1.0 # partial from t = a + b
        dt_db = 1.0 # partial from t = a + b
        dL_da = dL_dt * dt_da # chain rule
        dL_db = dL_dt * dt_db + dL_dr * dr_db # chain rule
        return [dL_da, dL_db]
    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=propagate))
    return r

da, db = run_gradients(simple_recompute)
da_ref, db_ref = run_gradients(simple)
print("da", da, da_ref)
print("db", db, db_ref)

a = tensor([0.4605, 0.4061, 0.9422, 0.3946])
b = tensor([0.0850, 0.3296, 0.9888, 0.6494])
L0 = v0.sum()
dL0 ------------------------
v2 = v1.expand(4)
v3 = a + b
v4 = v2 * b
v5 = v2 * v3
v6 = v4 + v5
dL0_dL0 = v1
dL0_dv0 = v2
dL0_da = v4
dL0_db = v6
------------------------
v7 = v4 * v4
v8 = v6 * v6
v9 = v7 + v8
L1 = v9.sum()
dL1 ------------------------
v11 = v10.expand(4)
v12 = v11 * v6
v13 = v11 * v6
v14 = v12 + v13
v15 = v11 * v4
v16 = v11 * v4
v17 = v15 + v16
v18 = v17 + v14
v19 = v14 * v3
v20 = v14 * v2
v21 = v18 * b
v22 = v18 * v2
v23 = v19 + v21
v24 = v22 + v20
v25 = v23.sum()
dL1_dL1 = v10
dL1_dv9 = v11
dL1_dv7 = v11
dL1_dv8 = v11
dL1_dv6 = v14
dL1_dv4 = v18
dL1_dv5 = v14
dL1_dv2 = v23
dL1_dv3 = v20
dL1_db = v24
dL1_da = v20
dL1_dv1 = v25
------------------------
a = tensor([0.4605, 0.4061, 0.9422, 0.3946])
b = tensor([0.0850, 0.3296, 0.9888, 0.6494])
v0 = a + b
v1 = v0 * b
L0 = v1.sum()
dL0 ------------------------
v3 = v2.expand(4)
v4 = v3 * b
v5 = v3 * v0
v6 = v5 + v4
dL0_d

This recompute works but it is not ideal. First, the original compute may have been expensive (think a bunch of convolutions and multiplies), so redoing it in the backward pass may take significant time. Second, we now need to save `a` and `b` to recompute `t`, whereas previously we saved `t` and `b`. What if `a` was a _huge_ matrix but `t` was small? Then this recompute also uses _more total memory_. In general, we want to avoid recomputing things unless we know it won't be expensive in time or space.

Let's consider another approach. What happens if we just make `t` into a Variable?

In [None]:
def simple_variable_wrong(a, b):
    t = (a.value + b.value)
    t_v = Variable(t, name='t') # named for debugging
    r = Variable(t * b.value)
    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:
        dL_dr, = dL_doutputs
        dr_dt = b # partial from: r = t * b
        dr_db = t_v # partial from: r = t * b
        dL_dt = dL_dr*dr_dt # chain rule
        dt_da = 1.0 # partial from t = a + b
        dt_db = 1.0 # partial from t = a + b
        dL_da = dL_dt * dt_da # chain rule
        dL_db = dL_dt * dt_db + dL_dr * dr_db # chain rule
        return [dL_da, dL_db]
    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=propagate))
    return r

da, db = run_gradients(simple_variable_wrong)
print("da", da) # ERROR: da is None!!!????
print("db", db)

a = tensor([0.4605, 0.4061, 0.9422, 0.3946])
b = tensor([0.0850, 0.3296, 0.9888, 0.6494])
L0 = v0.sum()
dL0 ------------------------
v2 = v1.expand(4)
v3 = v2 * b
v4 = v2 * t
v5 = v3 + v4
dL0_dL0 = v1
dL0_dv0 = v2
dL0_da = v3
dL0_db = v5
------------------------
v6 = v3 * v3
v7 = v5 * v5
v8 = v6 + v7
L1 = v8.sum()
dL1 ------------------------
v10 = v9.expand(4)
v11 = v10 * v5
v12 = v10 * v5
v13 = v11 + v12
v14 = v10 * v3
v15 = v10 * v3
v16 = v14 + v15
v17 = v16 + v13
v18 = v13 * t
v19 = v13 * v2
v20 = v17 * b
v21 = v17 * v2
v22 = v18 + v20
v23 = v22.sum()
dL1_dL1 = v9
dL1_dv8 = v10
dL1_dv6 = v10
dL1_dv7 = v10
dL1_dv5 = v13
dL1_dv3 = v17
dL1_dv4 = v13
dL1_dv2 = v22
dL1_dt = v19
dL1_db = v21
dL1_dv1 = v23
------------------------
da None
db tensor([1.4312, 2.7896, 7.8169, 4.6857])


While we do not get an error, something is clearly wrong. `dL1/da` is None, but we _know_ that the value of `a` affects the norm of the gradients of the original loss so this value should not be None. We are not propagating a gradient somewhere!

Let's see what happens when we run just the first gradient.


In [None]:
da, db = run_gradients(simple_variable_wrong, second_loss=False)
da_ref, db_ref = run_gradients(simple, second_loss=False)
print("da", da, da_ref) 
print("db", db, db_ref)

a = tensor([0.4605, 0.4061, 0.9422, 0.3946])
b = tensor([0.0850, 0.3296, 0.9888, 0.6494])
L0 = v0.sum()
dL0 ------------------------
v2 = v1.expand(4)
v3 = v2 * b
v4 = v2 * t
v5 = v3 + v4
dL0_dL0 = v1
dL0_dv0 = v2
dL0_da = v3
dL0_db = v5
------------------------
a = tensor([0.4605, 0.4061, 0.9422, 0.3946])
b = tensor([0.0850, 0.3296, 0.9888, 0.6494])
v0 = a + b
v1 = v0 * b
L0 = v1.sum()
dL0 ------------------------
v3 = v2.expand(4)
v4 = v3 * b
v5 = v3 * v0
v6 = v5 + v4
dL0_dL0 = v2
dL0_dv1 = v3
dL0_dv0 = v4
dL0_db = v6
dL0_da = v4
------------------------
da tensor([0.0850, 0.3296, 0.9888, 0.6494]) tensor([0.0850, 0.3296, 0.9888, 0.6494])
db tensor([0.6306, 1.0652, 2.9197, 1.6935]) tensor([0.6306, 1.0652, 2.9197, 1.6935])


In the single-backward case, we get the right answer! This illustrates a key part of autograd: it is _very easy_ to make it appear to work for a single backward pass but have the code be broken when trying higher order gradients. 

So what is going wrong? Look at the debug trace from the first time we ran `simple_variable_wrong`. Inside the compute of `dL0` (the first backward), you can see a line: `v4 = v2 * t`. The first backward is using the value of `t`. But if a computation _uses_ `t`, then any future loss (`L1`) that depends on the results of that computation can have a non-zero gradient `dL1/dt`. This future use of `t` is not accounted for in `simple_variable_wrong`! We account for the gradient that flows to `t` through `r` as `dL_dt = dL_dr*dr_dt`, but do not consider uses of `t` outside the local aggregate. This is because the way `t` can be used in the future is subtle: it escapes from our compute _only_ through its use as a closed-over variable in `propagate`. So this gradient pathway can only be non-zero for higher-order gradients where we are differentiating through this use.

The problem originates because `t` was not declared as an output of the original computation, even though it was defined by the computation and used by later computations. We can fix this by defining it as an output in the gradient tape and then using the derivative contribution that comes from it.

In [None]:
def simple_variable_almost(a, b):
    t = (a.value + b.value)
    t_v = Variable(t, name='t_v')
    r = Variable(t * b.value)
    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:
        # t is considered an output, so we now get dL_dt0 as an input.
        dL_dr, dL_dt0 = dL_doutputs
               ###### new gradient contribution

        # Handle cases where one incoming gradient is zero (None)
        if dL_dr is None:
          dL_dr = Variable.constant(torch.zeros(()))
        if dL_dt0 is None:
          dL_dt0 = Variable.constant(torch.zeros(()))
               

        dr_dt = b 
        dr_db = t_v 
        # we combine this with the contribution from r to calculate 
        # all gradient paths to dL_dt
        dL_dt = dL_dt0 + dL_dr*dr_dt # chain rule
                ######

        dt_da = 1.0 
        dt_db = 1.0 
        dL_db = dL_dr * dr_db + dL_dt * dt_db 
        dL_da = dL_dt * dt_da
        return [dL_da, dL_db]

    # note: t_v is now considered an output in the tape
    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name, t_v.name], propagate=propagate))
                                                                             ######### new output
    return r
da, db = run_gradients(simple_variable_almost)
print("da", da) 
print("db", db)

a = tensor([0.0171, 0.1633, 0.5833, 0.3794])
b = tensor([0.3774, 0.6308, 0.5239, 0.1387])
L0 = v0.sum()
dL0 ------------------------
v2 = v1.expand(4)
v3 = 0.0
v4 = v2 * b
v5 = v3 + v4
v6 = v2 * t_v
v7 = v6 + v5
dL0_dL0 = v1
dL0_dv0 = v2
dL0_da = v5
dL0_db = v7
------------------------
v8 = v5 * v5
v9 = v7 * v7
v10 = v8 + v9
L1 = v10.sum()
dL1 ------------------------
v12 = v11.expand(4)
v13 = v12 * v7
v14 = v12 * v7
v15 = v13 + v14
v16 = v12 * v5
v17 = v12 * v5
v18 = v16 + v17
v19 = v18 + v15
v20 = v15 * t_v
v21 = v15 * v2
v22 = v19 * b
v23 = v19 * v2
v24 = v20 + v22
v25 = v24.sum()
v26 = 0.0
v27 = v26 * b
v28 = v21 + v27
v29 = v26 * t_v
v30 = v29 + v28
v31 = v23 + v30
dL1_dL1 = v11
dL1_dv10 = v12
dL1_dv8 = v12
dL1_dv9 = v12
dL1_dv7 = v15
dL1_dv5 = v19
dL1_dv6 = v15
dL1_dv2 = v24
dL1_dt_v = v21
dL1_dv3 = v19
dL1_dv4 = v19
dL1_db = v31
dL1_dv1 = v25
dL1_da = v28
------------------------
da tensor([1.5438, 2.8499, 3.2622, 1.3134])
db tensor([3.8424, 6.9614, 7.5721, 2.9042])


This code is now correct! However, it has some non-optimal behavior. Notice how at the beginning of `propagate` we need to handle the cases where the incoming gradients are `None`. Recall that when a pathway has no gradient we give it the value `None`. The first time through `propagate`, `dL_dt0` will be `None`, since `t` is not used outside of the propagate function itself on the first backward. The _second_ time through `propagate`, `dL_dt0` will have a value but `dL_dr` will be `None`. Exercise: convince yourself why `dL_dr` is `None` the second time through. Turning the `None`s into zeros gives the right answer, but at the cost of always doing more compute. For instance, in this case every single-backward call performs an extra pointwise addition of a zero tensor to handle the `dL_dt0` input, which will always be zero there.

It makes sense to use a constant-time check for zero to eliminate a tensor-sized amount of work. So we optimize this code by replicating some of the `None` handling logic from `grad` directly in the aggregate op. It is a little messy, but it handles the cases where inputs might be `None` with a minimal amount of compute.

In [None]:
def add_optional(a: Optional['Variable'], b: Optional['Variable']):
    if a is None:
        return b
    if b is None:
        return a
    return a + b

def simple_variable(a, b):
    t = (a.value + b.value)
    t_v = Variable(t, name='t_v')
    r = Variable(t * b.value)
    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:
        dL_dr, dL_dt0 = dL_doutputs
        dr_dt = b # partial from: r = t * b
        dr_db = t_v # partial from: r = t * b
        dL_dt = dL_dt0
        if dL_dr is not None:
            dL_dt = add_optional(dL_dt, dL_dr*dr_dt) # chain rule

        dt_da = 1.0 # partial from t = a + b
        dt_db = 1.0 # partial from t = a + b
        if dL_dr is not None:
            dL_db = dL_dr * dr_db # chain rule
        else:
            dL_db = None

        if dL_dt is not None:
            dL_da = dL_dt * dt_da # chain rule
            dL_db = add_optional(dL_db, dL_dt * dt_db)
        else:
            dL_da = None

        return [dL_da, dL_db]

    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name, t_v.name], propagate=propagate))
    return r
da, db = run_gradients(simple_variable)
print("da", da) 
print("db", db)

a = tensor([0.0171, 0.1633, 0.5833, 0.3794])
b = tensor([0.3774, 0.6308, 0.5239, 0.1387])
L0 = v0.sum()
dL0 ------------------------
v2 = v1.expand(4)
v3 = v2 * b
v4 = v2 * t_v
v5 = v4 + v3
dL0_dL0 = v1
dL0_dv0 = v2
dL0_da = v3
dL0_db = v5
------------------------
v6 = v3 * v3
v7 = v5 * v5
v8 = v6 + v7
L1 = v8.sum()
dL1 ------------------------
v10 = v9.expand(4)
v11 = v10 * v5
v12 = v10 * v5
v13 = v11 + v12
v14 = v10 * v3
v15 = v10 * v3
v16 = v14 + v15
v17 = v16 + v13
v18 = v13 * t_v
v19 = v13 * v2
v20 = v17 * b
v21 = v17 * v2
v22 = v18 + v20
v23 = v22.sum()
v24 = v21 + v19
dL1_dL1 = v9
dL1_dv8 = v10
dL1_dv6 = v10
dL1_dv7 = v10
dL1_dv5 = v13
dL1_dv3 = v17
dL1_dv4 = v13
dL1_dv2 = v22
dL1_dt_v = v19
dL1_db = v24
dL1_dv1 = v23
dL1_da = v19
------------------------
da tensor([1.5438, 2.8499, 3.2622, 1.3134])
db tensor([3.8424, 6.9614, 7.5721, 2.9042])


**Exercise**: modify `run_gradients` such that the second call to `grad` produces non-zero values for both `dL_dr` and `dL_dt`. Hint: it can be done with the addition of 2 characters.

In PyTorch's symbolic autodiff implementation, the handling of zero tensors is done with undefined tensors in the place of `None` values, but the logic in TorchScript is very similar. The function `any_defined(...)` is used to check if any inputs are non-zero and guards the calculation of unused parts of the autograd. The `AutogradAdd(a, b)` operator adds two tensors, handling the case where either is undefined, similar to `add_optional`. 
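As a rough sketch, with `None` standing in for an undefined tensor, Python analogues of those two helpers could look like this. `any_defined` and `autograd_add` here are hypothetical stand-ins written for this illustration, not the TorchScript implementations, and `backward_r_and_t` is a float-only version of the backward of `simple_variable` with an `any_defined` guard in front.

```python
from typing import Optional, Tuple


def any_defined(*grads: Optional[float]) -> bool:
    # True if at least one incoming gradient is not known to be zero
    return any(g is not None for g in grads)


def autograd_add(x: Optional[float], y: Optional[float]) -> Optional[float]:
    # like add_optional: None (known zero) acts as the identity for addition
    if x is None:
        return y
    if y is None:
        return x
    return x + y


def backward_r_and_t(
    dL_dr: Optional[float], dL_dt0: Optional[float], b: float, t: float
) -> Tuple[Optional[float], Optional[float]]:
    # hypothetical backward for r = t * b with t = a + b, where t is also an output
    if not any_defined(dL_dr, dL_dt0):
        return None, None  # nothing flows in, so nothing flows out
    dL_dt = autograd_add(dL_dt0, None if dL_dr is None else dL_dr * b)
    dL_da = dL_dt
    dL_db = autograd_add(None if dL_dr is None else dL_dr * t, dL_dt)
    return dL_da, dL_db


print(backward_r_and_t(1.0, None, b=3.0, t=5.0))   # (3.0, 8.0)
print(backward_r_and_t(None, None, b=3.0, t=5.0))  # (None, None)
```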

The backward pass is very messy as-is with all of this conditional logic. Furthermore, as you have seen in these examples, in many cases the logic will branch in the same direction. This is especially true for single-backward, where gradients from captured temporaries will always be zero. It is profitable to try to specialize this code for particular patterns of non-zeros, since it allows more aggressive fusion of the backward pass.

# PyTorch vs Simple Grad

Simple Grad gives a good overview of how PyTorch's autograd works. TorchScript's symbolic gradient pass can generate aggregate operators from subsets of the IR by automating the process we went through to define `simple` as an aggregate operator.

The real PyTorch autograd has some features that go beyond this example, related to mutable tensors. Simple Grad assumes that tensors are immutable, so saving a Tensor for use in `propagate` is as simple as saving a reference to it. In PyTorch, the gradient formulas need to explicitly mark a Tensor as needing to be saved so we can track future potential mutations. The idea is to be able to detect if a user mutated a tensor that is needed by backward and report an error on use. Mutable ops themselves also affect how the `propagate` functions get recorded. If a tensor is mutated, uses of the tensor _before_ the mutation need to propagate gradient to the original value, while uses _after_ propagate gradient to the new mutated value. Since tensors can be views of other mutable tensors, PyTorch needs bookkeeping to make sure that any time a tensor is updated, all views of the tensor now propagate gradient to the new value and not the old one.
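The mutation check described above can be observed directly in PyTorch. In this small sketch, `y * y` saves `y` for its backward pass; mutating `y` in place afterwards makes PyTorch's saved-tensor version check reject the backward call with a `RuntimeError` instead of silently computing a wrong gradient. The variable names are illustrative only.

```python
import torch

x = torch.rand(3, requires_grad=True)
y = x * 2    # non-leaf tensor that requires grad
z = y * y    # the multiply saves y for its backward pass
y.add_(1.0)  # in-place mutation of a tensor that backward still needs

raised = False
try:
    z.sum().backward()
except RuntimeError as err:
    raised = True
    print("backward refused:", err)
print(raised)  # True
```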

# Where to go from here

If you still have questions about how this process works, I encourage you to edit this notebook with additional debug information and play around with compute. You can try:
* Add a new operator with a `propagate` formula (use `torch.autograd.grad` to verify correctness)
* Modify `run_gradients` to calculate weirder higher-order gradients and see if it behaves as you expect
* Remove `None` and implement gradients using Tensor zeros
* Try to manually define another aggregate operator for something similar to `simple`
* Write a 'compiler' that can take a small expression similar to `simple` and transform it automatically into a forward and `propagate`, similar to `autodiff.cpp`
* Rewrite `simple_variable` so all the branching for `None` checks is at the top of `propagate`. Can you generalize this such that a compiler can generate specializations for the seen non-zero patterns?
* Read `autodiff.cpp` and add a description to this document about how concepts in here directly relate to that code
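For the first suggestion, `torch.autograd.grad` can supply reference values. The sketch below checks a candidate formula for a hypothetical pointwise `sin` operator, written on plain tensors for brevity rather than on this notebook's `Variable`; `sin_propagate` is an illustrative name. Passing `grad_outputs` makes PyTorch compute the vector-Jacobian product for an arbitrary incoming gradient `dL_dr`, which is exactly what a `propagate` function must return.

```python
import torch


def sin_propagate(x: torch.Tensor, dL_dr: torch.Tensor) -> torch.Tensor:
    # candidate propagate formula for r = sin(x): dL/dx = dL/dr * cos(x)
    return dL_dr * torch.cos(x)


x = torch.rand(5, requires_grad=True)
dL_dr = torch.rand(5)  # arbitrary incoming gradient
r = torch.sin(x)

# reference vjp from PyTorch's own autograd
(expected,) = torch.autograd.grad(r, [x], grad_outputs=[dL_dr])
got = sin_propagate(x.detach(), dL_dr)
print(torch.allclose(got, expected))  # True
```

Using a random `dL_dr` rather than ones helps catch formulas that only happen to be right for a particular incoming gradient.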