In-place operations in multi-branch modules

Problem Description

When a model has multiple branches and one branch starts with an operation that mutates its input in place (e.g. PReLU), the other branches receive an incorrect input: the uninitialized buffer that gets swapped in during the PReLU's in-place operation.
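
The following standalone sketch (illustrative only, not actual SADL code) shows the failure mode in isolation: a swap-based in-place layer hands its old scratch buffer to the shared input tensor, so a sibling branch that reads the same tensor afterwards sees stale data instead of the backbone output.

// Standalone sketch of the failure mode (illustrative only, not SADL code).
#include <cstdio>
#include <utility>
#include <vector>

using Tensor = std::vector<float>;   // stand-in for SADL's Tensor<T>

struct InplacePReLU
{
  Tensor m_out;                      // scratch buffer, initially empty
  float  alpha = 0.25f;

  const Tensor &apply(Tensor &in)
  {
    std::swap(in, m_out);            // analogous to swap(*in[0], m_out)
    for (auto &v: m_out)
      if (v < 0)
        v *= alpha;
    return m_out;
  }
};

int main()
{
  Tensor backbone_out = { -1.f, 2.f, -3.f, 4.f };             // shared by both branches
  InplacePReLU prelu;
  const Tensor &branch_one_in = prelu.apply(backbone_out);    // branch one: fine
  // Branch two still reads backbone_out, which after the swap is the old,
  // empty scratch buffer rather than the backbone output.
  std::printf("branch one sees %zu values, branch two sees %zu values\n",
              branch_one_in.size(), backbone_out.size());     // prints 4 and 0
  return 0;
}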

Bug Replication

An example model that exhibits this issue is defined below.

import torch

class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Conv2d(3, 16, 3, 1, 1)
        # branch_one starts with PReLU, which SADL applies in place
        self.branch_one = torch.nn.Sequential(
            torch.nn.PReLU(16),
            torch.nn.Conv2d(16, 1, 1)
        )
        self.branch_two = torch.nn.Conv2d(16, 1, 1)

    def forward(self, x):
        # the backbone output is consumed by both branches
        x = self.backbone(x)
        return self.branch_one(x) + self.branch_two(x)

if __name__ == "__main__":
    torch.onnx.export(TestModel(), torch.randn(1, 3, 16, 16), 'prelu_bug.onnx')

Assuming the above code is placed at utests/prelu_bug.py, the bug can be replicated as follows:

cd utests
python prelu_bug.py 
python onnx_inference.py --input_onnx prelu_bug.onnx --output prelu_bug.results
python ../converter/main.py --input_onnx prelu_bug.onnx --output prelu_bug.sadl

mkdir build
cd build
cmake ../ -DCMAKE_BUILD_TYPE=Release
make test_scalar
cd ..

./build/test_scalar prelu_bug.sadl prelu_bug.results 0.001

This produces the error message [ERROR] test FAILED 254/256 [ERROR] difference onnx/sadl. The failure can be traced back to the problem described above by naively removing PReLU's in-place operation with the patch below and recompiling/re-running test_scalar as above.

diff --git a/sadl/layer_prelu.h b/sadl/layer_prelu.h
index 7531056..da23989 100644
--- a/sadl/layer_prelu.h
+++ b/sadl/layer_prelu.h
@@ -134,7 +134,12 @@ template<typename T> template<bool multialpha> bool PReLU<T>::apply_scalar(std::
   const int in_C{ in[0]->dims()[3] };
 
   const Tensor<T> &A = *in[1];
-  swap(*in[0], m_out);
+  // swap(*in[0], m_out);
+  const Tensor<T> &x = *in[0];
+  m_out.resize(x.dims());
+  m_out.quantizer = x.quantizer;
+  m_out.border_skip = x.border_skip;
+
   // keep same qunatiz as input
   const int alpha_q = A.quantizer;
   if (multialpha)
@@ -149,9 +154,9 @@ template<typename T> template<bool multialpha> bool PReLU<T>::apply_scalar(std::
         {
           for (int w_nb = 0; w_nb < in_W; w_nb++)
           {
-            if (m_out(n_nb, h_nb, w_nb, c_nb) < 0)
+            if (x(n_nb, h_nb, w_nb, c_nb) < 0)
             {
-              typename ComputationType<T>::type z = m_out(n_nb, h_nb, w_nb, c_nb) * alpha;
+              typename ComputationType<T>::type z = x(n_nb, h_nb, w_nb, c_nb) * alpha;
               ComputationType<T>::quantize(z, alpha_q);
               COUNTERS(z);
               COUNTERS_MAC(z);
@@ -160,6 +165,7 @@ template<typename T> template<bool multialpha> bool PReLU<T>::apply_scalar(std::
             }
             else
             {
+              m_out(n_nb, h_nb, w_nb, c_nb) = x(n_nb, h_nb, w_nb, c_nb);
               COUNTERS_MAC_NOP(1);
             }
           }
@@ -170,19 +176,22 @@ template<typename T> template<bool multialpha> bool PReLU<T>::apply_scalar(std::
   else
   {
     const typename ComputationType<T>::type alpha = A[0];
-    for (auto &x: m_out)
+    int idx = 0;
+    for (auto &m: x)
+
     {
-      if (x < 0)
+      if (m < 0)
       {
-        typename ComputationType<T>::type z = x * alpha;
+        typename ComputationType<T>::type z = m * alpha;
         ComputationType<T>::quantize(z, alpha_q);
         COUNTERS(z);
         COUNTERS_MAC(z);
         SATURATE(z);
-        x = static_cast<T>(z);
+        m_out[idx++] = static_cast<T>(z);
       }
       else
       {
+        m_out[idx++] = m;
         COUNTERS_MAC_NOP(1);
       }
     }

Possible solution

This problem can possibly be resolved using the copy layer mechanism in model<T>::insertCopyLayers.
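
For reference, below is a rough sketch of the general idea; the identifiers (Node, mutates_input, consumerCount) are hypothetical and do not reflect SADL's actual graph representation or the existing insertCopyLayers code. When a tensor feeds several layers and one of them mutates its input in place, an explicit Copy node can be inserted in front of the in-place layer so that the mutation only touches a private copy.

// Hypothetical sketch of the copy-layer idea; identifiers do not reflect
// SADL's actual graph representation or its insertCopyLayers implementation.
#include <algorithm>
#include <vector>

struct Node
{
  int              id;
  std::vector<int> inputs;         // ids of the producing nodes
  bool             mutates_input;  // e.g. a swap-based in-place PReLU
};

// Number of nodes that consume the output of 'producer'.
static int consumerCount(const std::vector<Node> &graph, int producer)
{
  int n = 0;
  for (const auto &node: graph)
    for (int in: node.inputs)
      if (in == producer)
        ++n;
  return n;
}

// Route every in-place layer whose input is shared through a private Copy
// node, so the in-place mutation cannot corrupt the other consumers.
// (Execution ordering of the appended Copy nodes is ignored in this sketch.)
static void insertCopyLayers(std::vector<Node> &graph)
{
  int next_id = 0;
  for (const auto &node: graph)
    next_id = std::max(next_id, node.id + 1);

  std::vector<Node> copies;
  for (auto &node: graph)
  {
    if (!node.mutates_input)
      continue;
    for (int &in: node.inputs)
      if (consumerCount(graph, in) > 1)
      {
        copies.push_back(Node{ next_id, { in }, false });  // the Copy node
        in = next_id++;                                    // reroute the in-place layer
      }
  }
  graph.insert(graph.end(), copies.begin(), copies.end());
}

int main()
{
  // backbone (0) feeds the in-place PReLU (1) and the second branch's conv (2).
  std::vector<Node> graph = { { 0, {}, false }, { 1, { 0 }, true }, { 2, { 0 }, false } };
  insertCopyLayers(graph);  // PReLU now reads a Copy (3) of the backbone output
  return 0;
}

Whether SADL's existing insertCopyLayers already handles this case for PReLU, or needs to be extended to detect it, would have to be checked against the actual implementation.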