Skip to content

Commit f461a7e

Browse files
authored
[TORCH][MLIR] Add E2E support for aten._softmax operation. (#431)
Signed-Off-By: Prateek Gupta <[email protected]>
1 parent 67ce816 commit f461a7e

File tree

5 files changed

+118
-13
lines changed

5 files changed

+118
-13
lines changed

e2e_testing/torchscript/basic.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,23 @@ def forward(self, tensor):
457457
def SoftmaxIntModule_basic(module, tu: TestUtils):
458458
module.forward(torch.randn(3, 2, 4))
459459

460+
class _SoftmaxModule(torch.nn.Module):
461+
def __init__(self):
462+
super().__init__()
463+
464+
@export
465+
@annotate_args([
466+
None,
467+
([-1, -1, -1], torch.float32, True),
468+
])
469+
def forward(self, tensor):
470+
return torch.ops.aten._softmax(tensor, 0, False)
471+
472+
473+
@register_test_case(module_factory=lambda: _SoftmaxModule())
474+
def _SoftmaxModule_basic(module, tu: TestUtils):
475+
module.forward(torch.randn(3, 2, 4))
476+
460477

461478
class SoftmaxIntNegDimModule(torch.nn.Module):
462479
def __init__(self):

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,22 @@ def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [
13921392
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
13931393
}
13941394

1395+
def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
1396+
AllowsTypeRefinement,
1397+
HasValueSemantics
1398+
]> {
1399+
let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`";
1400+
let arguments = (ins
1401+
AnyTorchTensorType:$self,
1402+
Torch_IntType:$dim,
1403+
Torch_BoolType:$half_to_float
1404+
);
1405+
let results = (outs
1406+
AnyTorchTensorType:$result
1407+
);
1408+
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
1409+
}
1410+
13951411
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
13961412
AllowsTypeRefinement
13971413
]> {

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,34 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
126126
};
127127
} // namespace
128128

129+
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
130+
// exp(x)/sum(exp(x)).
131+
template <typename OpTy>
132+
static Value getSoftmaxResult(OpTy op, Type resultType,
133+
PatternRewriter &rewriter) {
134+
Location loc = op.getLoc();
135+
Value dim = op.dim();
136+
Value self = op.self();
137+
138+
// exp(x)
139+
Value exp = rewriter.create<AtenExpOp>(loc, resultType, self);
140+
// sum(exp(x))
141+
Value sum =
142+
createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
143+
if (!sum)
144+
return nullptr;
145+
// exp(x) / sum(exp(x))
146+
return rewriter.create<AtenDivTensorOp>(loc, resultType, exp, sum);
147+
}
148+
129149
// Decompose softmax into: exp(x) / sum(exp(x))
130150
namespace {
131151
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
132152
public:
133153
using OpRewritePattern::OpRewritePattern;
134154
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
135155
PatternRewriter &rewriter) const override {
136-
Location loc = op.getLoc();
137156
Value self = op.self();
138-
Value dim = op.dim();
139157
if (!op.dtype().getType().isa<Torch::NoneType>())
140158
return rewriter.notifyMatchFailure(
141159
op, "Unimplemented non-None dtype for softmax");
@@ -144,14 +162,40 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
144162
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
145163
return rewriter.notifyMatchFailure(op, "Only support floating type");
146164

147-
// exp(x)
148-
Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
149-
// sum(exp(x))
150-
Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
151-
if (!sum)
165+
Value result = getSoftmaxResult(op, tensorType, rewriter);
166+
if (!result)
152167
return failure();
153-
// exp(x) / sum(exp(x))
154-
Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
168+
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
169+
result);
170+
return success();
171+
}
172+
};
173+
} // namespace
174+
175+
namespace {
176+
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
177+
public:
178+
using OpRewritePattern::OpRewritePattern;
179+
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
180+
PatternRewriter &rewriter) const override {
181+
Value self = op.self();
182+
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
183+
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
184+
return rewriter.notifyMatchFailure(op, "Only support floating type");
185+
bool halfToFloat;
186+
if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
187+
return rewriter.notifyMatchFailure(
188+
op, "Expected a boolean value for half_to_float");
189+
190+
// Currently, setting `halfToFloat` is not supported as the E2E testing for
191+
// the same is not present on CPU.
192+
if (halfToFloat)
193+
return rewriter.notifyMatchFailure(
194+
op, "halfToFloat is currently not supported.");
195+
196+
Value result = getSoftmaxResult(op, tensorType, rewriter);
197+
if (!result)
198+
return op.emitError("failed to get softmax result");
155199
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
156200
result);
157201
return success();
@@ -406,6 +450,8 @@ class DecomposeComplexOpsPass
406450

407451
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
408452
target.addIllegalOp<AtenSoftmaxIntOp>();
453+
patterns.add<DecomposeAten_SoftmaxOp>(context);
454+
target.addIllegalOp<Aten_SoftmaxOp>();
409455
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
410456
target.addIllegalOp<AtenLogSoftmaxIntOp>();
411457
patterns.add<DecomposeAtenExpandOp>(context);

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
418418
return visitAtenMatmulOp(matmul, operands);
419419
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
420420
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
421+
} else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
422+
return visitAten_SoftmaxOp(_softmaxOp, operands);
421423
} else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
422424
return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
423425
} else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
@@ -541,6 +543,10 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
541543
ChangeResult
542544
visitAtenAddCLikeOp(Operation *op,
543545
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
546+
547+
ChangeResult
548+
visitAten_SoftmaxOp(Aten_SoftmaxOp op,
549+
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
544550
};
545551
} // namespace
546552

@@ -1332,21 +1338,40 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
13321338
return getLatticeElement(op.getResult()).join(knowledge);
13331339
}
13341340

1341+
static ValueKnowledge
1342+
getSameSizeAsInput(Operation *op,
1343+
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
1344+
auto input = operands[0]->getValue();
1345+
auto knowledge =
1346+
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
1347+
knowledge.hasSizes = input.hasSizes;
1348+
knowledge.sizes = input.sizes;
1349+
return knowledge;
1350+
}
13351351

13361352
// Common template for softmax like ops, eg., log_softmax.
13371353
template <typename OpTy>
13381354
ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
13391355
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
13401356
auto input = operands[0]->getValue();
13411357
auto dtype = op.dtype();
1342-
auto knowledge =
1343-
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
1344-
knowledge.hasSizes = input.hasSizes;
1345-
knowledge.sizes = input.sizes;
1358+
ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
13461359
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
13471360
return getLatticeElement(op.getResult()).join(knowledge);
13481361
}
13491362

1363+
ChangeResult TypeAnalyzer::visitAten_SoftmaxOp(
1364+
Aten_SoftmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
1365+
auto input = operands[0]->getValue();
1366+
ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
1367+
bool halfToFloat;
1368+
if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
1369+
knowledge.dtype =
1370+
halfToFloat ? Float32Type::get(op->getContext()) : input.dtype;
1371+
}
1372+
return getLatticeElement(op.getResult()).join(knowledge);
1373+
}
1374+
13501375
ChangeResult TypeAnalyzer::visitAtenBmmOp(
13511376
AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
13521377
auto knowledge =

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ def emit_with_mutating_variants(key, **kwargs):
516516
emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)")
517517
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
518518
emit("aten::sqrt : (Tensor) -> (Tensor)")
519+
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
519520

520521
# Misc tensor ops.
521522
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")

0 commit comments

Comments
 (0)