Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ private static StringMethodCheckStatus RecognizeStringMethodCheck(
private static CollectionCheckStatus RecognizeCollectionMethodCheck(
IOperation operation,
INamedTypeSymbol objectTypeSymbol,
INamedTypeSymbol? enumerableTypeSymbol,
out SyntaxNode? collectionExpression,
out SyntaxNode? itemExpression)
{
Expand Down Expand Up @@ -527,6 +528,19 @@ private static CollectionCheckStatus RecognizeCollectionMethodCheck(
}
}
}

// Handle LINQ Enumerable.Contains<TSource>(this IEnumerable<TSource>, TSource)
// In the Roslyn operation model, LINQ extension calls appear with ContainingType == Enumerable
// and Arguments includes the 'this' parameter, so Arguments.Length == 2.
if (methodName == "Contains" &&
invocation.Arguments.Length == 2 &&
enumerableTypeSymbol is not null &&
SymbolEqualityComparer.Default.Equals(invocation.TargetMethod.ContainingType, enumerableTypeSymbol))
{
collectionExpression = invocation.Arguments[0].Value.Syntax;
itemExpression = invocation.Arguments[1].Value.Syntax;
return CollectionCheckStatus.Contains;
}
}

collectionExpression = null;
Expand Down Expand Up @@ -748,7 +762,7 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
}

// Check for collection method patterns: myCollection.Contains(...)
CollectionCheckStatus collectionMethodStatus = RecognizeCollectionMethodCheck(conditionArgument, objectTypeSymbol, out SyntaxNode? collectionExpr, out SyntaxNode? itemExpr);
CollectionCheckStatus collectionMethodStatus = RecognizeCollectionMethodCheck(conditionArgument, objectTypeSymbol, enumerableTypeSymbol, out SyntaxNode? collectionExpr, out SyntaxNode? itemExpr);
if (collectionMethodStatus != CollectionCheckStatus.Unknown)
{
if (collectionMethodStatus == CollectionCheckStatus.Contains)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3853,4 +3853,286 @@ await VerifyCS.VerifyCodeFixAsync(
VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNull", "AreEqual"),
fixedCode);
}

#region LINQ Enumerable.Contains on interface types

[TestMethod]
public async Task WhenAssertIsTrueWithLinqContainsOnIReadOnlyList()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IReadOnlyList<int> list = new List<int> { 1, 2, 3 };
{|#0:Assert.IsTrue(list.Contains(2))|};
}
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IReadOnlyList<int> list = new List<int> { 1, 2, 3 };
Assert.Contains(2, list);
}
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"),
fixedCode);
}

[TestMethod]
public async Task WhenAssertIsFalseWithLinqContainsOnIReadOnlyList()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IReadOnlyList<int> list = new List<int> { 1, 2, 3 };
{|#0:Assert.IsFalse(list.Contains(4))|};
}
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IReadOnlyList<int> list = new List<int> { 1, 2, 3 };
Assert.DoesNotContain(4, list);
}
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"),
fixedCode);
}

[TestMethod]
public async Task WhenAssertIsTrueWithLinqContainsOnIEnumerable()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IEnumerable<string> items = new List<string> { "a", "b", "c" };
{|#0:Assert.IsTrue(items.Contains("b"))|};
}
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IEnumerable<string> items = new List<string> { "a", "b", "c" };
Assert.Contains("b", items);
}
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"),
fixedCode);
}

[TestMethod]
public async Task WhenAssertIsTrueWithLinqContainsOnIReadOnlyList_WithMessage()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IReadOnlyList<int> list = new List<int> { 1, 2, 3 };
{|#0:Assert.IsTrue(list.Contains(2),
"should contain the value")|};
}
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IReadOnlyList<int> list = new List<int> { 1, 2, 3 };
Assert.Contains(2, list,
"should contain the value");
}
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"),
fixedCode);
}

[TestMethod]
public async Task WhenAssertIsTrueWithLinqContainsOnIReadOnlyCollection()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IReadOnlyCollection<int> items = new List<int> { 1, 2, 3 };
{|#0:Assert.IsTrue(items.Contains(2))|};
}
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
IReadOnlyCollection<int> items = new List<int> { 1, 2, 3 };
Assert.Contains(2, items);
}
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"),
fixedCode);
}

[TestMethod]
public async Task WhenAssertIsTrueWithLinqContainsOnCustomNonBCLCollection_Reports()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
var custom = new MyCustomCollection<int>(new[] { 1, 2, 3 });
{|#0:Assert.IsTrue(custom.Contains(2))|};
}

internal sealed class MyCustomCollection<T> : IEnumerable<T>
{
private readonly IEnumerable<T> _items;
public MyCustomCollection(IEnumerable<T> items) => _items = items;
public IEnumerator<T> GetEnumerator() => _items.GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
var custom = new MyCustomCollection<int>(new[] { 1, 2, 3 });
Assert.Contains(2, custom);
}

internal sealed class MyCustomCollection<T> : IEnumerable<T>
{
private readonly IEnumerable<T> _items;
public MyCustomCollection(IEnumerable<T> items) => _items = items;
public IEnumerator<T> GetEnumerator() => _items.GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"),
fixedCode);
}

#endregion
}