diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index ae00ed3773..da443496d4 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -493,6 +493,7 @@ private static StringMethodCheckStatus RecognizeStringMethodCheck( private static CollectionCheckStatus RecognizeCollectionMethodCheck( IOperation operation, INamedTypeSymbol objectTypeSymbol, + INamedTypeSymbol? enumerableTypeSymbol, out SyntaxNode? collectionExpression, out SyntaxNode? itemExpression) { @@ -527,6 +528,19 @@ private static CollectionCheckStatus RecognizeCollectionMethodCheck( } } } + + // Handle LINQ Enumerable.Contains(this IEnumerable, TSource) + // In the Roslyn operation model, LINQ extension calls appear with ContainingType == Enumerable + // and Arguments includes the 'this' parameter, so Arguments.Length == 2. + if (methodName == "Contains" && + invocation.Arguments.Length == 2 && + enumerableTypeSymbol is not null && + SymbolEqualityComparer.Default.Equals(invocation.TargetMethod.ContainingType, enumerableTypeSymbol)) + { + collectionExpression = invocation.Arguments[0].Value.Syntax; + itemExpression = invocation.Arguments[1].Value.Syntax; + return CollectionCheckStatus.Contains; + } } collectionExpression = null; @@ -748,7 +762,7 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co } // Check for collection method patterns: myCollection.Contains(...) - CollectionCheckStatus collectionMethodStatus = RecognizeCollectionMethodCheck(conditionArgument, objectTypeSymbol, out SyntaxNode? collectionExpr, out SyntaxNode? itemExpr); + CollectionCheckStatus collectionMethodStatus = RecognizeCollectionMethodCheck(conditionArgument, objectTypeSymbol, enumerableTypeSymbol, out SyntaxNode? collectionExpr, out SyntaxNode? itemExpr); if (collectionMethodStatus != CollectionCheckStatus.Unknown) { if (collectionMethodStatus == CollectionCheckStatus.Contains) diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index 43a74c4115..8fb29a6385 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -3853,4 +3853,286 @@ await VerifyCS.VerifyCodeFixAsync( VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNull", "AreEqual"), fixedCode); } + + #region LINQ Enumerable.Contains on interface types + + [TestMethod] + public async Task WhenAssertIsTrueWithLinqContainsOnIReadOnlyList() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IReadOnlyList list = new List { 1, 2, 3 }; + {|#0:Assert.IsTrue(list.Contains(2))|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IReadOnlyList list = new List { 1, 2, 3 }; + Assert.Contains(2, list); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsFalseWithLinqContainsOnIReadOnlyList() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IReadOnlyList list = new List { 1, 2, 3 }; + {|#0:Assert.IsFalse(list.Contains(4))|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IReadOnlyList list = new List { 1, 2, 3 }; + Assert.DoesNotContain(4, list); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithLinqContainsOnIEnumerable() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { "a", "b", "c" }; + {|#0:Assert.IsTrue(items.Contains("b"))|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { "a", "b", "c" }; + Assert.Contains("b", items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithLinqContainsOnIReadOnlyList_WithMessage() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IReadOnlyList list = new List { 1, 2, 3 }; + {|#0:Assert.IsTrue(list.Contains(2), + "should contain the value")|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IReadOnlyList list = new List { 1, 2, 3 }; + Assert.Contains(2, list, + "should contain the value"); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithLinqContainsOnIReadOnlyCollection() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IReadOnlyCollection items = new List { 1, 2, 3 }; + {|#0:Assert.IsTrue(items.Contains(2))|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IReadOnlyCollection items = new List { 1, 2, 3 }; + Assert.Contains(2, items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithLinqContainsOnCustomNonBCLCollection_Reports() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var custom = new MyCustomCollection(new[] { 1, 2, 3 }); + {|#0:Assert.IsTrue(custom.Contains(2))|}; + } + + internal sealed class MyCustomCollection : IEnumerable + { + private readonly IEnumerable _items; + public MyCustomCollection(IEnumerable items) => _items = items; + public IEnumerator GetEnumerator() => _items.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var custom = new MyCustomCollection(new[] { 1, 2, 3 }); + Assert.Contains(2, custom); + } + + internal sealed class MyCustomCollection : IEnumerable + { + private readonly IEnumerable _items; + public MyCustomCollection(IEnumerable items) => _items = items; + public IEnumerator GetEnumerator() => _items.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + #endregion }