Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active November 27, 2023 22:20
Show Gist options
  • Save AmosLewis/c0d551bfdd11b24eb10a5f8de10809d1 to your computer and use it in GitHub Desktop.
Save AmosLewis/c0d551bfdd11b24eb10a5f8de10809d1 to your computer and use it in GitHub Desktop.
import shark_turbine.aot as aot
import torch
import torch.nn as nn
class ExMod(nn.Module):
    """Minimal module wrapping a single ``nn.BatchNorm2d`` layer.

    Serves as a reproducer for exporting and compiling BatchNorm2d through
    ``shark_turbine.aot``.

    Args:
        num_features: Number of channels ``C`` of the ``(N, C, H, W)`` input
            that the batch-norm layer normalizes. Defaults to 100, the value
            this reproducer was originally written for.
    """

    def __init__(self, num_features=100):
        super().__init__()
        self.m = nn.BatchNorm2d(num_features)

    def forward(self, x):
        """Apply the wrapped batch-norm layer to ``x`` and return the result."""
        return self.m(x)
# Example input for export, shaped (N=20, C=100, H=35, W=45); C has to match
# the channel count of the BatchNorm2d(100) held by ExMod.
x = torch.zeros((20, 100, 35, 45))
# eval() switches BatchNorm to its stored running statistics. nn.Module.eval()
# returns the module itself, so it is chained directly onto construction.
mod = ExMod().eval()
# Export the module using the example input, then write the MLIR to disk.
export_output = aot.export(mod, x)
export_output.save_mlir("bnex.mlir")
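# BUG1: export_output.save_mlir() fails with the TypeError below.
# iree / SHARK-Turbine version: https://github.com/AmosLewis/SHARK-Turbine/commits/main
# commit ID: 4117974833a7c5b35c81f30a8c9d523f483b514f
# Cause, per the message: the MLIR Operation.print() binding is called with the
# output file as its second *positional* argument (plus binary=True). Neither
# overload accepts a file in that position -- overload 1 expects a PyAsmState
# and overload 2 an Optional[int] large_elements_limit -- so the file would
# have to be passed as the `file=` keyword instead.
# TypeError: print(): incompatible function arguments. The following argument types are supported:
# 1. (self: iree.compiler._mlir_libs._mlir.ir._OperationBase, state: mlir::python::PyAsmState, file: object = None, binary: bool = False) -> None
# 2. (self: iree.compiler._mlir_libs._mlir.ir._OperationBase, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False, file: object = None, binary: bool = False) -> None
# Invoked with: <iree.compiler._mlir_libs._mlir.ir.Operation object at 0x7f9fea26e030>, <_io.BufferedWriter name='bnex.mlir'>; kwargs: binary=True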
# Compile the exported module without a save_to path (presumably the compiled
# artifact is returned rather than written -- confirm against shark_turbine).
binary = export_output.compile(save_to=None)
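# BUG2: export_output.compile() fails while lowering the exported IR.
# Command (VS Code debugpy launch of tests/aot/batchnorm2d_test.py; wrapped for readability):
# (iree_venv) ➜ src cd /nodclouddata/chi/src ; /usr/bin/env
#     /nodclouddata/chi/src/SHARK-Turbine/.venv/bin/python3.11
#     /home/chi/.vscode-server/extensions/ms-python.python-2023.20.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher
#     38747 -- /nodclouddata/chi/src/SHARK-Turbine/tests/aot/batchnorm2d_test.py
# loc("<eval_with_key>.0 from /nodclouddata/chi/src/SHARK-Turbine/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":5:0):
# error: 'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible
# The shapes are compatible (dynamic `?` vs static 0); the element types are
# not (ui8 vs i8), and tensor.cast cannot change the element type. The cast is
# emitted while lowering torch.aten.empty.memory_format, whose result is a
# [0]-element ui8 tensor.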
@AmosLewis
Copy link
Author

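BUG2 also reproduces when compiling the saved IR directly. `--compile-to=input` stops after IREE's input-conversion phase, and `--debug` prints the MLIR pass and pattern log shown below.

iree-compile bnex.mlir --compile-to=input --debug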

// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (!torch.vtensor<[20,100,35,45],f32>) -> !torch.vtensor<[20,100,35,45],f32>, sym_name = "forward", sym_visibility = "private"}> ({
^bb0(%arg0: !torch.vtensor<[20,100,35,45],f32>):
  %0 = "builtin.unrealized_conversion_cast"(%arg0) : (!torch.vtensor<[20,100,35,45],f32>) -> tensor<20x100x35x45xf32>
  %1 = "torch.constant.int"() <{value = 6 : i64}> : () -> !torch.int
  %2 = "builtin.unrealized_conversion_cast"(%1) : (!torch.int) -> i64
  %3 = "torch.constant.int"() <{value = -1 : i64}> : () -> !torch.int
  %4 = "builtin.unrealized_conversion_cast"(%3) : (!torch.int) -> i64
  %5 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
  %6 = "builtin.unrealized_conversion_cast"(%5) : (!torch.bool) -> i1
  %7 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int
  %8 = "builtin.unrealized_conversion_cast"(%7) : (!torch.int) -> i64
  %9 = "torch.constant.float"() <{value = 1.000000e-05 : f64}> : () -> !torch.float
  %10 = "builtin.unrealized_conversion_cast"(%9) : (!torch.float) -> f64
  %11 = "torch.vtensor.literal"() <{value = dense<1.000000e+00> : tensor<100xf32>}> : () -> !torch.vtensor<[100],f32>
  %12 = "builtin.unrealized_conversion_cast"(%11) : (!torch.vtensor<[100],f32>) -> tensor<100xf32>
  %13 = "torch.vtensor.literal"() <{value = dense<0.000000e+00> : tensor<100xf32>}> : () -> !torch.vtensor<[100],f32>
  %14 = "builtin.unrealized_conversion_cast"(%13) : (!torch.vtensor<[100],f32>) -> tensor<100xf32>
  %15 = "torch.constant.none"() : () -> !torch.none
  %16 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int
  %17 = "builtin.unrealized_conversion_cast"(%16) : (!torch.int) -> i64
  %18 = "torch.prim.ListConstruct"(%16) : (!torch.int) -> !torch.list<int>
  %19 = "torch.constant.device"() <{value = "cpu"}> : () -> !torch.Device
  %20 = "torch_c.to_i64"(%16) : (!torch.int) -> i64
  %21 = "arith.index_cast"(%20) : (i64) -> index
  %22 = "tensor.empty"(%21) : (index) -> tensor<?xui8>
  %23 = "tensor.cast"(%22) : (tensor<?xui8>) -> tensor<0xi8>
  %24 = "torch.aten.empty.memory_format"(%18, %16, %16, %19, %15, %15) : (!torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.none, !torch.none) -> !torch.vtensor<[0],ui8>
  %25 = "arith.constant"() <{value = 1 : index}> : () -> index
  %26 = "arith.constant"() <{value = 0 : index}> : () -> index
  %27 = "arith.constant"() <{value = 100 : index}> : () -> index
  %28 = "tensor.empty"() : () -> tensor<100xf32>
  %29 = "linalg.generic"(%12, %28) <{indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = [#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32):
    %151 = "arith.truncf"(%10) : (f64) -> f32
    %152 = "arith.sitofp"(%8) : (i64) -> f32
    %153 = "arith.mulf"(%151, %152) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    %154 = "arith.addf"(%arg1, %153) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%154) : (f32) -> ()
  }) : (tensor<100xf32>, tensor<100xf32>) -> tensor<100xf32>
  %30 = "tensor.cast"(%29) : (tensor<100xf32>) -> tensor<100xf32>
  %31 = "torch.aten.add.Scalar"(%11, %9, %7) : (!torch.vtensor<[100],f32>, !torch.float, !torch.int) -> !torch.vtensor<[100],f32>
  %32 = "arith.constant"() <{value = 1 : index}> : () -> index
  %33 = "arith.constant"() <{value = 0 : index}> : () -> index
  %34 = "arith.constant"() <{value = 100 : index}> : () -> index
  %35 = "tensor.empty"() : () -> tensor<100xf32>
  %36 = "linalg.generic"(%30, %35) <{indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = [#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32):
    %151 = "math.sqrt"(%arg1) <{fastmath = #arith.fastmath<none>}> : (f32) -> f32
    "linalg.yield"(%151) : (f32) -> ()
  }) : (tensor<100xf32>, tensor<100xf32>) -> tensor<100xf32>
  %37 = "tensor.cast"(%36) : (tensor<100xf32>) -> tensor<100xf32>
  %38 = "torch.aten.sqrt"(%31) : (!torch.vtensor<[100],f32>) -> !torch.vtensor<[100],f32>
  %39 = "arith.constant"() <{value = 1 : index}> : () -> index
  %40 = "arith.constant"() <{value = 0 : index}> : () -> index
  %41 = "arith.constant"() <{value = 100 : index}> : () -> index
  %42 = "tensor.empty"() : () -> tensor<100xf32>
  %43 = "linalg.generic"(%37, %42) <{indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = [#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32):
    %151 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
    %152 = "arith.cmpf"(%arg1, %151) <{predicate = 6 : i64}> : (f32, f32) -> i1
    "cf.assert"(%152) <{msg = "unimplemented: tensor with zero element"}> : (i1) -> ()
    %153 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
    %154 = "arith.divf"(%153, %arg1) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%154) : (f32) -> ()
  }) : (tensor<100xf32>, tensor<100xf32>) -> tensor<100xf32>
  %44 = "tensor.cast"(%43) : (tensor<100xf32>) -> tensor<100xf32>
  %45 = "torch.aten.reciprocal"(%38) : (!torch.vtensor<[100],f32>) -> !torch.vtensor<[100],f32>
  %46 = "arith.constant"() <{value = 1 : index}> : () -> index
  %47 = "arith.constant"() <{value = 0 : index}> : () -> index
  %48 = "arith.constant"() <{value = 100 : index}> : () -> index
  %49 = "tensor.empty"() : () -> tensor<100xf32>
  %50 = "linalg.generic"(%44, %49) <{indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = [#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32):
    %151 = "arith.sitofp"(%8) : (i64) -> f32
    %152 = "arith.mulf"(%arg1, %151) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%152) : (f32) -> ()
  }) : (tensor<100xf32>, tensor<100xf32>) -> tensor<100xf32>
  %51 = "tensor.cast"(%50) : (tensor<100xf32>) -> tensor<100xf32>
  %52 = "torch.aten.mul.Scalar"(%45, %7) : (!torch.vtensor<[100],f32>, !torch.int) -> !torch.vtensor<[100],f32>
  %53 = "torch.prim.ListConstruct"(%16) : (!torch.int) -> !torch.list<int>
  %54 = "torch_c.to_i64"(%16) : (!torch.int) -> i64
  %55 = "arith.index_cast"(%54) : (i64) -> index
  %56 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
  %57 = "tensor.empty"(%55) : (index) -> tensor<?xf32>
  %58 = "linalg.fill"(%56, %57) <{operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32):
    "linalg.yield"(%arg1) : (f32) -> ()
  }) : (f32, tensor<?xf32>) -> tensor<?xf32>
  %59 = "tensor.cast"(%58) : (tensor<?xf32>) -> tensor<0xf32>
  %60 = "torch.aten.zeros"(%53, %1, %15, %15, %5) : (!torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool) -> !torch.vtensor<[0],f32>
  %61 = "torch.prim.ListConstruct"(%16) : (!torch.int) -> !torch.list<int>
  %62 = "torch_c.to_i64"(%16) : (!torch.int) -> i64
  %63 = "arith.index_cast"(%62) : (i64) -> index
  %64 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
  %65 = "tensor.empty"(%63) : (index) -> tensor<?xf32>
  %66 = "linalg.fill"(%64, %65) <{operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32):
    "linalg.yield"(%arg1) : (f32) -> ()
  }) : (f32, tensor<?xf32>) -> tensor<?xf32>
  %67 = "tensor.cast"(%66) : (tensor<?xf32>) -> tensor<0xf32>
  %68 = "torch.aten.zeros"(%61, %1, %15, %15, %5) : (!torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool) -> !torch.vtensor<[0],f32>
  %69 = "tensor.expand_shape"(%14) <{reassociation = [[0, 1]]}> : (tensor<100xf32>) -> tensor<100x1xf32>
  %70 = "torch.aten.unsqueeze"(%13, %3) : (!torch.vtensor<[100],f32>, !torch.int) -> !torch.vtensor<[100,1],f32>
  %71 = "tensor.expand_shape"(%69) <{reassociation = [[0], [1, 2]]}> : (tensor<100x1xf32>) -> tensor<100x1x1xf32>
  %72 = "torch.aten.unsqueeze"(%70, %3) : (!torch.vtensor<[100,1],f32>, !torch.int) -> !torch.vtensor<[100,1,1],f32>
  %73 = "tensor.expand_shape"(%51) <{reassociation = [[0, 1]]}> : (tensor<100xf32>) -> tensor<100x1xf32>
  %74 = "torch.aten.unsqueeze"(%52, %3) : (!torch.vtensor<[100],f32>, !torch.int) -> !torch.vtensor<[100,1],f32>
  %75 = "tensor.expand_shape"(%73) <{reassociation = [[0], [1, 2]]}> : (tensor<100x1xf32>) -> tensor<100x1x1xf32>
  %76 = "torch.aten.unsqueeze"(%74, %3) : (!torch.vtensor<[100,1],f32>, !torch.int) -> !torch.vtensor<[100,1,1],f32>
  %77 = "arith.constant"() <{value = 1 : index}> : () -> index
  %78 = "arith.constant"() <{value = 0 : index}> : () -> index
  %79 = "arith.constant"() <{value = 20 : index}> : () -> index
  %80 = "arith.constant"() <{value = 1 : index}> : () -> index
  %81 = "arith.constant"() <{value = 100 : index}> : () -> index
  %82 = "arith.constant"() <{value = 2 : index}> : () -> index
  %83 = "arith.constant"() <{value = 35 : index}> : () -> index
  %84 = "arith.constant"() <{value = 3 : index}> : () -> index
  %85 = "arith.constant"() <{value = 45 : index}> : () -> index
  %86 = "arith.constant"() <{value = 0 : index}> : () -> index
  %87 = "arith.constant"() <{value = 100 : index}> : () -> index
  %88 = "tensor.empty"() : () -> tensor<20x100x35x45xf32>
  %89 = "linalg.generic"(%0, %71, %88) <{indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 2, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %151 = "arith.sitofp"(%8) : (i64) -> f32
    %152 = "arith.mulf"(%arg2, %151) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    %153 = "arith.subf"(%arg1, %152) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%153) : (f32) -> ()
  }) : (tensor<20x100x35x45xf32>, tensor<100x1x1xf32>, tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32>
  %90 = "tensor.cast"(%89) : (tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32>
  %91 = "torch.aten.sub.Tensor"(%arg0, %72, %7) : (!torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32>, !torch.int) -> !torch.vtensor<[20,100,35,45],f32>
  %92 = "arith.constant"() <{value = 1 : index}> : () -> index
  %93 = "arith.constant"() <{value = 0 : index}> : () -> index
  %94 = "arith.constant"() <{value = 20 : index}> : () -> index
  %95 = "arith.constant"() <{value = 1 : index}> : () -> index
  %96 = "arith.constant"() <{value = 100 : index}> : () -> index
  %97 = "arith.constant"() <{value = 2 : index}> : () -> index
  %98 = "arith.constant"() <{value = 35 : index}> : () -> index
  %99 = "arith.constant"() <{value = 3 : index}> : () -> index
  %100 = "arith.constant"() <{value = 45 : index}> : () -> index
  %101 = "arith.constant"() <{value = 0 : index}> : () -> index
  %102 = "arith.constant"() <{value = 100 : index}> : () -> index
  %103 = "tensor.empty"() : () -> tensor<20x100x35x45xf32>
  %104 = "linalg.generic"(%90, %75, %103) <{indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 2, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %151 = "arith.mulf"(%arg1, %arg2) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%151) : (f32) -> ()
  }) : (tensor<20x100x35x45xf32>, tensor<100x1x1xf32>, tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32>
  %105 = "tensor.cast"(%104) : (tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32>
  %106 = "torch.aten.mul.Tensor"(%91, %76) : (!torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32>) -> !torch.vtensor<[20,100,35,45],f32>
  %107 = "util.global.load"() <{global = @_params.m.weight}> : () -> tensor<100xf32>
  %108 = "torch_c.from_builtin_tensor"(%107) : (tensor<100xf32>) -> !torch.vtensor<[100],f32>
  %109 = "builtin.unrealized_conversion_cast"(%108) : (!torch.vtensor<[100],f32>) -> tensor<100xf32>
  %110 = "tensor.expand_shape"(%109) <{reassociation = [[0, 1]]}> : (tensor<100xf32>) -> tensor<100x1xf32>
  %111 = "torch.aten.unsqueeze"(%108, %3) : (!torch.vtensor<[100],f32>, !torch.int) -> !torch.vtensor<[100,1],f32>
  %112 = "tensor.expand_shape"(%110) <{reassociation = [[0], [1, 2]]}> : (tensor<100x1xf32>) -> tensor<100x1x1xf32>
  %113 = "torch.aten.unsqueeze"(%111, %3) : (!torch.vtensor<[100,1],f32>, !torch.int) -> !torch.vtensor<[100,1,1],f32>
  %114 = "arith.constant"() <{value = 1 : index}> : () -> index
  %115 = "arith.constant"() <{value = 0 : index}> : () -> index
  %116 = "arith.constant"() <{value = 20 : index}> : () -> index
  %117 = "arith.constant"() <{value = 1 : index}> : () -> index
  %118 = "arith.constant"() <{value = 100 : index}> : () -> index
  %119 = "arith.constant"() <{value = 2 : index}> : () -> index
  %120 = "arith.constant"() <{value = 35 : index}> : () -> index
  %121 = "arith.constant"() <{value = 3 : index}> : () -> index
  %122 = "arith.constant"() <{value = 45 : index}> : () -> index
  %123 = "arith.constant"() <{value = 0 : index}> : () -> index
  %124 = "arith.constant"() <{value = 100 : index}> : () -> index
  %125 = "tensor.empty"() : () -> tensor<20x100x35x45xf32>
  %126 = "linalg.generic"(%105, %112, %125) <{indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 2, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %151 = "arith.mulf"(%arg1, %arg2) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%151) : (f32) -> ()
  }) : (tensor<20x100x35x45xf32>, tensor<100x1x1xf32>, tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32>
  %127 = "tensor.cast"(%126) : (tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32>
  %128 = "torch.aten.mul.Tensor"(%106, %113) : (!torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32>) -> !torch.vtensor<[20,100,35,45],f32>
  %129 = "util.global.load"() <{global = @_params.m.bias}> : () -> tensor<100xf32>
  %130 = "torch_c.from_builtin_tensor"(%129) : (tensor<100xf32>) -> !torch.vtensor<[100],f32>
  %131 = "builtin.unrealized_conversion_cast"(%130) : (!torch.vtensor<[100],f32>) -> tensor<100xf32>
  %132 = "tensor.expand_shape"(%131) <{reassociation = [[0, 1]]}> : (tensor<100xf32>) -> tensor<100x1xf32>
  %133 = "torch.aten.unsqueeze"(%130, %3) : (!torch.vtensor<[100],f32>, !torch.int) -> !torch.vtensor<[100,1],f32>
  %134 = "tensor.expand_shape"(%132) <{reassociation = [[0], [1, 2]]}> : (tensor<100x1xf32>) -> tensor<100x1x1xf32>
  %135 = "torch.aten.unsqueeze"(%133, %3) : (!torch.vtensor<[100,1],f32>, !torch.int) -> !torch.vtensor<[100,1,1],f32>
  %136 = "arith.constant"() <{value = 1 : index}> : () -> index
  %137 = "arith.constant"() <{value = 0 : index}> : () -> index
  %138 = "arith.constant"() <{value = 20 : index}> : () -> index
  %139 = "arith.constant"() <{value = 1 : index}> : () -> index
  %140 = "arith.constant"() <{value = 100 : index}> : () -> index
  %141 = "arith.constant"() <{value = 2 : index}> : () -> index
  %142 = "arith.constant"() <{value = 35 : index}> : () -> index
  %143 = "arith.constant"() <{value = 3 : index}> : () -> index
  %144 = "arith.constant"() <{value = 45 : index}> : () -> index
  %145 = "arith.constant"() <{value = 0 : index}> : () -> index
  %146 = "arith.constant"() <{value = 100 : index}> : () -> index
  %147 = "tensor.empty"() : () -> tensor<20x100x35x45xf32>
  %148 = "linalg.generic"(%127, %134, %147) <{indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 2, 1>}> ({
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %151 = "arith.sitofp"(%8) : (i64) -> f32
    %152 = "arith.mulf"(%arg2, %151) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    %153 = "arith.addf"(%arg1, %152) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%153) : (f32) -> ()
  }) : (tensor<20x100x35x45xf32>, tensor<100x1x1xf32>, tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32>
  %149 = "tensor.cast"(%148) : (tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32>
  %150 = "torch.aten.add.Tensor"(%128, %135, %7) : (!torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32>, !torch.int) -> !torch.vtensor<[20,100,35,45],f32>
  "func.return"(%150) : (!torch.vtensor<[20,100,35,45],f32>) -> ()
}) {torch.assume_strict_symbolic_shapes} : () -> ()


} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x5618c3142670) {
  "func.return"(%150) : (!torch.vtensor<[20,100,35,45],f32>) -> ()

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
** Insert  : 'torch_c.to_builtin_tensor'(0x7fe5800298a0)
** Insert  : 'torch_c.to_builtin_tensor'(0x7fe580029990)
** Insert  : 'torch_c.to_builtin_tensor'(0x7fe580029a20)
** Insert  : 'torch_c.to_builtin_tensor'(0x7fe580029ab0)
** Insert  : 'torch_c.to_i64'(0x7fe580029b40)
** Insert  : 'torch_c.to_f64'(0x7fe580029bd0)
** Insert  : 'torch_c.to_builtin_tensor'(0x7fe580029c60)
** Insert  : 'torch_c.to_i64'(0x7fe580029d70)
** Insert  : 'torch_c.from_builtin_tensor'(0x7fe580029050)
bnex.mlir:18:10: error: 'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible
    %1 = torch.aten.empty.memory_format %0, %int0_0, %int0_1, %cpu, %none, %none_2 : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.none, !torch.none -> !torch.vtensor<[0],ui8>
         ^
bnex.mlir:18:10: note: see current operation: %20 = "tensor.cast"(%19) : (tensor<?xui8>) -> tensor<0xi8>

@AmosLewis
Copy link
Author

bnex.mlir

module @ExMod {  // IR exported from the ExMod (BatchNorm2d) reproducer; comments are trailing so line:col positions in error logs stay valid.
  util.global private @_params.m.weight {noinline} = dense<1.000000e+00> : tensor<100xf32>  // BatchNorm2d weight parameter (all ones here).
  util.global private @_params.m.bias {noinline} = dense<0.000000e+00> : tensor<100xf32>  // BatchNorm2d bias parameter (all zeros here).
  func.func @main(%arg0: tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
    %0 = torch_c.from_builtin_tensor %arg0 : tensor<20x100x35x45xf32> -> !torch.vtensor<[20,100,35,45],f32>  // Public entry: builtin tensor -> torch vtensor.
    %1 = call @forward(%0) : (!torch.vtensor<[20,100,35,45],f32>) -> !torch.vtensor<[20,100,35,45],f32>
    %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[20,100,35,45],f32> -> tensor<20x100x35x45xf32>  // Torch vtensor -> builtin tensor for the return value.
    return %2 : tensor<20x100x35x45xf32>
  }
  func.func private @forward(%arg0: !torch.vtensor<[20,100,35,45],f32>) -> !torch.vtensor<[20,100,35,45],f32> {
    %int0 = torch.constant.int 0
    %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %int0_0 = torch.constant.int 0
    %int0_1 = torch.constant.int 0
    %cpu = torch.constant.device "cpu"
    %none = torch.constant.none
    %none_2 = torch.constant.none
    %1 = torch.aten.empty.memory_format %0, %int0_0, %int0_1, %cpu, %none, %none_2 : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.none, !torch.none -> !torch.vtensor<[0],ui8>  // [0]-element ui8 tensor, never used; its lowering emits the rejected tensor.cast ui8 -> i8 (reported at bnex.mlir:18:10).
    %2 = torch.vtensor.literal(dense<0.000000e+00> : tensor<100xf32>) : !torch.vtensor<[100],f32>  // Per-channel term subtracted from the input in %18 (the mean).
    %int6 = torch.constant.int 6
    %3 = torch.prims.convert_element_type %2, %int6 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100],f32>  // f32 -> f32 here, i.e. a no-op conversion.
    %4 = torch.vtensor.literal(dense<1.000000e+00> : tensor<100xf32>) : !torch.vtensor<[100],f32>  // Per-channel variance term used in %6..%8.
    %int6_3 = torch.constant.int 6
    %5 = torch.prims.convert_element_type %4, %int6_3 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100],f32>  // f32 -> f32 here, i.e. a no-op conversion.
    %float1.000000e-05 = torch.constant.float 1.000000e-05
    %int1 = torch.constant.int 1
    %6 = torch.aten.add.Scalar %5, %float1.000000e-05, %int1 : !torch.vtensor<[100],f32>, !torch.float, !torch.int -> !torch.vtensor<[100],f32>  // var + 1e-5
    %7 = torch.aten.sqrt %6 : !torch.vtensor<[100],f32> -> !torch.vtensor<[100],f32>
    %8 = torch.aten.reciprocal %7 : !torch.vtensor<[100],f32> -> !torch.vtensor<[100],f32>  // 1 / sqrt(var + 1e-5)
    %int1_4 = torch.constant.int 1
    %9 = torch.aten.mul.Scalar %8, %int1_4 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100],f32>
    %int0_5 = torch.constant.int 0
    %10 = torch.prim.ListConstruct %int0_5 : (!torch.int) -> !torch.list<int>
    %none_6 = torch.constant.none
    %none_7 = torch.constant.none
    %none_8 = torch.constant.none
    %false = torch.constant.bool false
    %11 = torch.aten.new_zeros %arg0, %10, %none_6, %none_7, %none_8, %false : !torch.vtensor<[20,100,35,45],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[0],f32>  // [0]-element tensor, never used below.
    %int0_9 = torch.constant.int 0
    %12 = torch.prim.ListConstruct %int0_9 : (!torch.int) -> !torch.list<int>
    %none_10 = torch.constant.none
    %none_11 = torch.constant.none
    %none_12 = torch.constant.none
    %false_13 = torch.constant.bool false
    %13 = torch.aten.new_zeros %arg0, %12, %none_10, %none_11, %none_12, %false_13 : !torch.vtensor<[20,100,35,45],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[0],f32>  // [0]-element tensor, never used below.
    %int-1 = torch.constant.int -1
    %14 = torch.aten.unsqueeze %3, %int-1 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100,1],f32>
    %int-1_14 = torch.constant.int -1
    %15 = torch.aten.unsqueeze %14, %int-1_14 : !torch.vtensor<[100,1],f32>, !torch.int -> !torch.vtensor<[100,1,1],f32>  // [100] -> [100,1,1] to broadcast over N, H, W.
    %int-1_15 = torch.constant.int -1
    %16 = torch.aten.unsqueeze %9, %int-1_15 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100,1],f32>
    %int-1_16 = torch.constant.int -1
    %17 = torch.aten.unsqueeze %16, %int-1_16 : !torch.vtensor<[100,1],f32>, !torch.int -> !torch.vtensor<[100,1,1],f32>
    %int1_17 = torch.constant.int 1
    %18 = torch.aten.sub.Tensor %arg0, %15, %int1_17 : !torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32>, !torch.int -> !torch.vtensor<[20,100,35,45],f32>  // x - mean
    %19 = torch.aten.mul.Tensor %18, %17 : !torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32> -> !torch.vtensor<[20,100,35,45],f32>  // * 1 / sqrt(var + 1e-5)
    %_params.m.weight = util.global.load @_params.m.weight : tensor<100xf32>
    %20 = torch_c.from_builtin_tensor %_params.m.weight : tensor<100xf32> -> !torch.vtensor<[100],f32>
    %int-1_18 = torch.constant.int -1
    %21 = torch.aten.unsqueeze %20, %int-1_18 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100,1],f32>
    %int-1_19 = torch.constant.int -1
    %22 = torch.aten.unsqueeze %21, %int-1_19 : !torch.vtensor<[100,1],f32>, !torch.int -> !torch.vtensor<[100,1,1],f32>
    %23 = torch.aten.mul.Tensor %19, %22 : !torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32> -> !torch.vtensor<[20,100,35,45],f32>  // * weight
    %_params.m.bias = util.global.load @_params.m.bias : tensor<100xf32>
    %24 = torch_c.from_builtin_tensor %_params.m.bias : tensor<100xf32> -> !torch.vtensor<[100],f32>
    %int-1_20 = torch.constant.int -1
    %25 = torch.aten.unsqueeze %24, %int-1_20 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100,1],f32>
    %int-1_21 = torch.constant.int -1
    %26 = torch.aten.unsqueeze %25, %int-1_21 : !torch.vtensor<[100,1],f32>, !torch.int -> !torch.vtensor<[100,1,1],f32>
    %int1_22 = torch.constant.int 1
    %27 = torch.aten.add.Tensor %23, %26, %int1_22 : !torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32>, !torch.int -> !torch.vtensor<[20,100,35,45],f32>  // + bias
    return %27 : !torch.vtensor<[20,100,35,45],f32>
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment