如何扩展 LINQ

所有基于 LINQ 的方法都遵循两种类似的模式之一。 它们采用可枚举序列。 它们会返回不同的序列或单个值。 通过形状的一致性,可以通过编写具有类似形状的方法来扩展 LINQ。 事实上,自首次引入 LINQ 以来,.NET 库就在许多 .NET 版本中都获得了新的方法。 在本文中,你将看到通过编写遵循相同模式的自己的方法来扩展 LINQ 的示例。

为 LINQ 查询添加自定义方法

通过向 IEnumerable<T> 接口添加扩展方法扩展可用于 LINQ 查询的方法集。 例如,除了标准平均值或最大值运算,还可创建自定义聚合方法,从一系列值计算单个值。 此外,还可创建一种方法,用作值序列的自定义筛选器或特定数据转换,并返回新的序列。 DistinctSkipReverse 就是此类方法的示例。

扩展 IEnumerable<T> 接口时,可以将自定义方法应用于任何可枚举集合。 有关详细信息,请参阅扩展方法

聚合方法可从一组值计算单个值。 LINQ 提供多个聚合方法,包括 AverageMinMax。 可以通过向 IEnumerable<T> 接口添加扩展方法来创建自己的聚合方法。

下面的代码示例演示如何创建名为 Median 的扩展方法来计算类型为 double 的数字序列的中间值。

public static class EnumerableExtension
{
    public static double Median(this IEnumerable<double>? source)
    {
        if (source is null || !source.Any())
        {
            throw new InvalidOperationException("Cannot compute median for a null or empty set.");
        }

        var sortedList =
            source.OrderBy(number => number).ToList();

        int itemIndex = sortedList.Count / 2;

        if (sortedList.Count % 2 == 0)
        {
            // Even number of items.
            return (sortedList[itemIndex] + sortedList[itemIndex - 1]) / 2;
        }
        else
        {
            // Odd number of items.
            return sortedList[itemIndex];
        }
    }
}

使用从 IEnumerable<T> 接口调用其他聚合方法的方式为任何可枚举集合调用此扩展方法。

下面的代码示例说明如何为类型 double 的数组使用 Median 方法。

double[] numbers = [1.9, 2, 8, 4, 5.7, 6, 7.2, 0];
var query = numbers.Median();

Console.WriteLine($"double: Median = {query}");
// This code produces the following output:
//     double: Median = 4.85

可以重载聚合方法,以便其接受各种类型的序列。 标准做法是为每种类型都创建一个重载。 另一种方法是创建一个采用泛型类型的重载,并使用委托将其转换为特定类型。 还可以将两种方法结合。

可以为要支持的每种类型创建特定重载。 下面的代码示例演示 int 类型的 Median 方法的重载。

// int overload
public static double Median(this IEnumerable<int> source) =>
    (from number in source select (double)number).Median();

现在便可以为 integerdouble 类型调用 Median 重载了,如以下代码中所示:

double[] numbers1 = [1.9, 2, 8, 4, 5.7, 6, 7.2, 0];
var query1 = numbers1.Median();

Console.WriteLine($"double: Median = {query1}");

int[] numbers2 = [1, 2, 3, 4, 5];
var query2 = numbers2.Median();

Console.WriteLine($"int: Median = {query2}");
// This code produces the following output:
//     double: Median = 4.85
//     int: Median = 3

还可以创建接受泛型对象序列的重载。 此重载采用委托作为参数,并使用该参数将泛型类型的对象序列转换为特定类型。

下面的代码展示 Median 方法的重载,该重载将 Func<T,TResult> 委托作为参数。 此委托采用泛型类型 T 的对象,并返回类型 double 的对象。

// generic overload
public static double Median<T>(
    this IEnumerable<T> numbers, Func<T, double> selector) =>
    (from num in numbers select selector(num)).Median();

现在,可以为任何类型的对象序列调用 Median 方法。 如果类型没有它自己的方法重载,必须手动传递委托参数。 在 C# 中,可以使用 lambda 表达式实现此目的。 此外,仅限在 Visual Basic 中,如果使用 AggregateGroup By 子句而不是方法调用,可以传递此子句范围内的任何值或表达式。

下面的代码示例演示如何为整数数组和字符串数组调用 Median 方法。 对于字符串,将计算数组中字符串长度的中值。 该示例演示如何将 Func<T,TResult> 委托参数传递给每个用例的 Median 方法。

int[] numbers3 = [1, 2, 3, 4, 5];

/*
    You can use the num => num lambda expression as a parameter for the Median method
    so that the compiler will implicitly convert its value to double.
    If there is no implicit conversion, the compiler will display an error message.
*/
var query3 = numbers3.Median(num => num);

Console.WriteLine($"int: Median = {query3}");

string[] numbers4 = ["one", "two", "three", "four", "five"];

// With the generic overload, you can also use numeric properties of objects.
var query4 = numbers4.Median(str => str.Length);

Console.WriteLine($"string: Median = {query4}");
// This code produces the following output:
//     int: Median = 3
//     string: Median = 4

可以使用会返回值序列的自定义查询方法来扩展 IEnumerable<T> 接口。 在这种情况下,该方法必须返回类型 IEnumerable<T> 的集合。 此类方法可用于将筛选器或数据转换应用于值序列。

下面的示例演示如何创建名为 AlternateElements 的扩展方法,该方法从集合中第一个元素开始按相隔一个元素的方式返回集合中的元素。

// Extension method for the IEnumerable<T> interface.
// The method returns every other element of a sequence.
public static IEnumerable<T> AlternateElements<T>(this IEnumerable<T> source)
{
    int index = 0;
    foreach (T element in source)
    {
        if (index % 2 == 0)
        {
            yield return element;
        }

        index++;
    }
}

可使用从 IEnumerable<T> 接口调用其他方法的方式对任何可枚举集合调用此扩展方法,如下面的代码中所示:

string[] strings = ["a", "b", "c", "d", "e"];

var query5 = strings.AlternateElements();

foreach (var element in query5)
{
    Console.WriteLine(element);
}
// This code produces the following output:
//     a
//     c
//     e

按连续键对结果进行分组

下面的示例演示如何将元素分组为表示连续键子序列的区块。 例如,假设给定下列键值对的序列:

密钥
A We
A think
A that
B Linq
C is
A really
B cool
B !

以下组将按此顺序创建:

  1. We, think, that
  2. Linq
  3. is
  4. really
  5. cool, !

此解决方案是以线程安全扩展方法实现的,该扩展方法以流的方式返回其结果。 换言之,它在源序列中遍历移动时生成其组。 与 grouporderby 运算符不同,它能在读取所有序列之前开始将组返回给调用方。 下面的示例演示该扩展方法以及使用它的客户端代码:

public static class ChunkExtensions
{
    public static IEnumerable<IGrouping<TKey, TSource>> ChunkBy<TSource, TKey>(
            this IEnumerable<TSource> source,
            Func<TSource, TKey> keySelector) =>
                source.ChunkBy(keySelector, EqualityComparer<TKey>.Default);

    public static IEnumerable<IGrouping<TKey, TSource>> ChunkBy<TSource, TKey>(
            this IEnumerable<TSource> source,
            Func<TSource, TKey> keySelector,
            IEqualityComparer<TKey> comparer)
    {
        // Flag to signal end of source sequence.
        const bool noMoreSourceElements = true;

        // Auto-generated iterator for the source array.
        IEnumerator<TSource>? enumerator = source.GetEnumerator();

        // Move to the first element in the source sequence.
        if (!enumerator.MoveNext())
        {
            yield break;        // source collection is empty
        }

        while (true)
        {
            var key = keySelector(enumerator.Current);

            Chunk<TKey, TSource> current = new(key, enumerator, value => comparer.Equals(key, keySelector(value)));

            yield return current;

            if (current.CopyAllChunkElements() == noMoreSourceElements)
            {
                yield break;
            }
        }
    }
}
public static class GroupByContiguousKeys
{
    // The source sequence.
    static readonly KeyValuePair<string, string>[] list = [
        new("A", "We"),
        new("A", "think"),
        new("A", "that"),
        new("B", "LINQ"),
        new("C", "is"),
        new("A", "really"),
        new("B", "cool"),
        new("B", "!")
    ];

    // Query variable declared as class member to be available
    // on different threads.
    static readonly IEnumerable<IGrouping<string, KeyValuePair<string, string>>> query =
        list.ChunkBy(p => p.Key);

    public static void GroupByContiguousKeys1()
    {
        // ChunkBy returns IGrouping objects, therefore a nested
        // foreach loop is required to access the elements in each "chunk".
        foreach (var item in query)
        {
            Console.WriteLine($"Group key = {item.Key}");
            foreach (var inner in item)
            {
                Console.WriteLine($"\t{inner.Value}");
            }
        }
    }
}

ChunkExtensions

在呈现的 ChunkExtensions 类实现代码中,ChunkBy 方法中的循环 while(true) 循环访问源序列并创建每个区块的副本。 在每次传递中,迭代器前进到源序列的下一个“区块”的第一个元素(由 Chunk 对象代表)。 此循环对应于执行查询的外部 foreach 循环。 在该循环中,代码执行以下操作:

  1. 获取当前区块的键并将其分配给 key 变量。 源迭代器将遍历源序列,直到找到具有不匹配键的元素。
  2. 创建一个新的区块(组)对象,并将其存储在 current 变量中。 它具有一个 GroupItem,即当前源元素的副本。
  3. 返回该区块。 区块是一个 IGrouping<TKey,TSource>,即 ChunkBy 方法的返回值。 区块仅具有其源序列中的第一个元素。 仅当客户端代码 foreach 遍历此区块时,才会返回剩余的元素。 有关详细信息,请参阅 Chunk.GetEnumerator
  4. 检查是否存在一下情况:
    • 区块具有其所有源元素的副本,或
    • 迭代器已到达源序列的末尾。
  5. 当调用方枚举所有区块项时,Chunk.GetEnumerator 方法已复制所有区块项。 如果 Chunk.GetEnumerator 循环未枚举区块中的所有元素,则我们需要在此处执行此操作,以避免损坏可能在单独线程上进行调用的客户端的迭代器。

Chunk

Chunk 类是一个或多个具有相同键的源元素的连续组。 区块具有一个键和一个 ChunkItem 对象列表,这些对象是源序列中元素的副本:

class Chunk<TKey, TSource> : IGrouping<TKey, TSource>
{
    // INVARIANT: DoneCopyingChunk == true ||
    //   (predicate != null && predicate(enumerator.Current) && current.Value == enumerator.Current)

    // A Chunk has a linked list of ChunkItems, which represent the elements in the current chunk. Each ChunkItem
    // has a reference to the next ChunkItem in the list.
    class ChunkItem
    {
        public ChunkItem(TSource value) => Value = value;
        public readonly TSource Value;
        public ChunkItem? Next;
    }

    public TKey Key { get; }

    // Stores a reference to the enumerator for the source sequence
    private IEnumerator<TSource> enumerator;

    // A reference to the predicate that is used to compare keys.
    private Func<TSource, bool> predicate;

    // Stores the contents of the first source element that
    // belongs with this chunk.
    private readonly ChunkItem head;

    // End of the list. It is repositioned each time a new
    // ChunkItem is added.
    private ChunkItem? tail;

    // Flag to indicate the source iterator has reached the end of the source sequence.
    internal bool isLastSourceElement;

    // Private object for thread synchronization
    private readonly object m_Lock;

    // REQUIRES: enumerator != null && predicate != null
    public Chunk(TKey key, [DisallowNull] IEnumerator<TSource> enumerator, [DisallowNull] Func<TSource, bool> predicate)
    {
        Key = key;
        this.enumerator = enumerator;
        this.predicate = predicate;

        // A Chunk always contains at least one element.
        head = new ChunkItem(enumerator.Current);

        // The end and beginning are the same until the list contains > 1 elements.
        tail = head;

        m_Lock = new object();
    }

    // Indicates that all chunk elements have been copied to the list of ChunkItems.
    private bool DoneCopyingChunk => tail == null;

    // Adds one ChunkItem to the current group
    // REQUIRES: !DoneCopyingChunk && lock(this)
    private void CopyNextChunkElement()
    {
        // Try to advance the iterator on the source sequence.
        isLastSourceElement = !enumerator.MoveNext();

        // If we are (a) at the end of the source, or (b) at the end of the current chunk
        // then null out the enumerator and predicate for reuse with the next chunk.
        if (isLastSourceElement || !predicate(enumerator.Current))
        {
            enumerator = default!;
            predicate = default!;
        }
        else
        {
            tail!.Next = new ChunkItem(enumerator.Current);
        }

        // tail will be null if we are at the end of the chunk elements
        // This check is made in DoneCopyingChunk.
        tail = tail!.Next;
    }

    // Called after the end of the last chunk was reached.
    internal bool CopyAllChunkElements()
    {
        while (true)
        {
            lock (m_Lock)
            {
                if (DoneCopyingChunk)
                {
                    return isLastSourceElement;
                }
                else
                {
                    CopyNextChunkElement();
                }
            }
        }
    }

    // Stays just one step ahead of the client requests.
    public IEnumerator<TSource> GetEnumerator()
    {
        // Specify the initial element to enumerate.
        ChunkItem? current = head;

        // There should always be at least one ChunkItem in a Chunk.
        while (current != null)
        {
            // Yield the current item in the list.
            yield return current.Value;

            // Copy the next item from the source sequence,
            // if we are at the end of our local list.
            lock (m_Lock)
            {
                if (current == tail)
                {
                    CopyNextChunkElement();
                }
            }

            // Move to the next ChunkItem in the list.
            current = current.Next;
        }
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => GetEnumerator();
}

每个 ChunkItem(由 ChunkItem 类表示)引用列表中的下一个 ChunkItem。 该列表由其 head(存储属于此区块的第一个源元素的内容)及其 tail(列表的末尾)组成。 每次添加新的 ChunkItem 时,都会重新定位尾巴。 如果下一个元素的键与当前区块的键不匹配,或者源中没有更多元素,则链接列表的尾部将设置为 CopyNextChunkElement 方法中的 null

Chunk 类的 CopyNextChunkElement 方法向当前项组添加一个 ChunkItem。 它尝试在源序列上递进迭代器。 如果 MoveNext() 方法返回 false,则表示迭代位于末尾,并且 isLastSourceElement 设置为 true

在到达最后一个区块的末尾后调用 CopyAllChunkElements 方法。 它首先检查源序列中是否有其他元素。 如果有,在区块的枚举器已耗尽的情况下,此方法将返回 true。 在此方法中,当检查专用 DoneCopyingChunk 字段是否为 true 时,如果 isLastSourceElement 为 false,则会向外部迭代器发出信号以继续迭代。

Chunk 类的 GetEnumerator 方法由内部 foreach 循环调用。 此方法仅领先于客户端请求一个元素。 它仅在客户端请求列表中上一个最后一个元素之后,才添加区块的下一个元素。