
I've got a C# string extension method that should return an IEnumerable<int> of all the indexes of a substring within a string. It works perfectly for its intended purpose and the expected results are returned (as proven by one of my tests, although not the one below), but another unit test has discovered a problem with it: it can't handle null arguments.

Here's the extension method I'm testing:

public static IEnumerable<int> AllIndexesOf(this string str, string searchText)
{
    if (searchText == null)
    {
        throw new ArgumentNullException("searchText");
    }
    for (int index = 0; ; index += searchText.Length)
    {
        index = str.IndexOf(searchText, index);
        if (index == -1)
            break;
        yield return index;
    }
}

Here is the test that flagged up the problem:

[TestMethod]
[ExpectedException(typeof(ArgumentNullException))]
public void Extensions_AllIndexesOf_HandlesNullArguments()
{
    string test = "a.b.c.d.e";
    test.AllIndexesOf(null);
}

When the test runs against my extension method, it fails with the standard error message that the method "did not throw an exception".

This is confusing: I have clearly passed null into the function, yet for some reason the comparison null == null is returning false. Therefore, no exception is thrown and the code continues.

I have confirmed this is not a bug with the test: when running the method in my main project with a call to Console.WriteLine in the null-comparison if block, nothing is shown on the console and no exception is caught by any catch block I add. Furthermore, using string.IsNullOrEmpty instead of == null has the same problem.

Why does this supposedly-simple comparison fail?

  • Have you tried stepping through the code? That'll probably get it resolved pretty quickly. May 11, 2015 at 19:35
  • What does happen? (Does it throw an exception; if so, which one and what line?) May 11, 2015 at 19:36
  • @user2864740 I have described everything that happens. No exceptions, just a failed test and a run method.
    – ArtOfCode
    May 11, 2015 at 19:37
  • Iterators are not executed until they're iterated-over May 12, 2015 at 5:12
  • You're welcome. This one also made Jon's "worst gotcha" list: stackoverflow.com/a/241180/88656. This is quite a common problem. May 12, 2015 at 19:10

3 Answers


You are using yield return. When you do, the compiler rewrites your method so that it returns an instance of a compiler-generated class that implements a state machine.

Broadly speaking, it turns locals into fields of that class, and each part of your algorithm between the yield return statements becomes a state. You can use a decompiler to check what this method becomes after compilation (make sure to turn off smart decompilation, which would reconstruct the yield return).

But the bottom line is: the code of your method won't be executed until you start iterating.
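Here's a small sketch of that consequence, using the question's method unchanged. The class name `LazyCheckDemo` and the `Probe` helper are made up for illustration: the call itself returns normally, and the `ArgumentNullException` only appears once a `foreach` starts pulling values.

```csharp
using System;
using System.Collections.Generic;

public static class LazyCheckDemo
{
    // The question's method: the null check sits inside an iterator block,
    // so it runs on the first MoveNext, not when the method is called.
    public static IEnumerable<int> AllIndexesOf(this string str, string searchText)
    {
        if (searchText == null)
            throw new ArgumentNullException("searchText");
        for (int index = 0; ; index += searchText.Length)
        {
            index = str.IndexOf(searchText, index);
            if (index == -1)
                break;
            yield return index;
        }
    }

    // Hypothetical helper: reports which step raised the exception.
    public static string Probe()
    {
        IEnumerable<int> indexes;
        try
        {
            indexes = "a.b.c.d.e".AllIndexesOf(null);   // only builds the state machine
        }
        catch (ArgumentNullException)
        {
            return "thrown by the call";
        }

        try
        {
            foreach (int index in indexes) { }          // first MoveNext runs the null check
        }
        catch (ArgumentNullException)
        {
            return "thrown while iterating";
        }
        return "never thrown";
    }
}
```

`LazyCheckDemo.Probe()` returns "thrown while iterating", which is exactly why the MSTest `[ExpectedException]` test fails: it never iterates.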

The usual way to check for preconditions is to split your method in two:

public static IEnumerable<int> AllIndexesOf(this string str, string searchText)
{
    if (str == null)
        throw new ArgumentNullException("str");
    if (searchText == null)
        throw new ArgumentNullException("searchText");

    return AllIndexesOfCore(str, searchText);
}

private static IEnumerable<int> AllIndexesOfCore(string str, string searchText)
{
    for (int index = 0; ; index += searchText.Length)
    {
        index = str.IndexOf(searchText, index);
        if (index == -1)
            break;
        yield return index;
    }
}

This works because the first method will behave just like you expect (immediate execution), and will return the state machine implemented by the second method.

Note that you should also check the str parameter for null, because extension methods can be called on null values: they're just syntactic sugar for a static method call.
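To make both points concrete, here's a sketch of the two-method split under a hypothetical class name, `EagerIndexes`. An extension call like `s.AllIndexesOf(".")` compiles to `EagerIndexes.AllIndexesOf(s, ".")`, so a null `s` reaches the method without a `NullReferenceException`, and the wrapper reports it immediately. Two things are my own additions, not part of the answer: the ordinal comparison, and a guard for an empty `searchText` (without it, the loop would yield index 0 forever because `index += 0` never advances).

```csharp
using System;
using System.Collections.Generic;

public static class EagerIndexes
{
    // Not an iterator block: these checks run as soon as the method is called.
    public static IEnumerable<int> AllIndexesOf(this string str, string searchText)
    {
        if (str == null)
            throw new ArgumentNullException("str");
        if (searchText == null)
            throw new ArgumentNullException("searchText");
        // Added in this sketch: an empty searchText would never advance the loop.
        if (searchText.Length == 0)
            throw new ArgumentException("searchText must not be empty.", "searchText");
        return AllIndexesOfCore(str, searchText);
    }

    // The iterator block: only runs when the result is enumerated.
    private static IEnumerable<int> AllIndexesOfCore(string str, string searchText)
    {
        for (int index = 0; ; index += searchText.Length)
        {
            // Ordinal (exact character) matching; the original uses the culture-sensitive overload.
            index = str.IndexOf(searchText, index, StringComparison.Ordinal);
            if (index == -1)
                break;
            yield return index;
        }
    }

    // Hypothetical helper: returns the ParamName reported by the eager null checks,
    // or null if nothing was thrown. Note that the result is never enumerated here.
    public static string FailingParameter(string str, string searchText)
    {
        try
        {
            str.AllIndexesOf(searchText);   // extension call on a possibly-null receiver
            return null;
        }
        catch (ArgumentNullException ex)
        {
            return ex.ParamName;
        }
    }
}
```

`EagerIndexes.FailingParameter(null, ".")` returns "str" and `FailingParameter("a.b", null)` returns "searchText", even though no enumeration happens; enumerating `"a.b.c.d.e".AllIndexesOf(".")` still gives 1, 3, 5, 7.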


If you're curious about what the compiler does to your code, here's your method, decompiled with dotPeek using the Show Compiler-generated Code option.

public static IEnumerable<int> AllIndexesOf(this string str, string searchText)
{
  Test.<AllIndexesOf>d__0 allIndexesOfD0 = new Test.<AllIndexesOf>d__0(-2);
  allIndexesOfD0.<>3__str = str;
  allIndexesOfD0.<>3__searchText = searchText;
  return (IEnumerable<int>) allIndexesOfD0;
}

[CompilerGenerated]
private sealed class <AllIndexesOf>d__0 : IEnumerable<int>, IEnumerable, IEnumerator<int>, IEnumerator, IDisposable
{
  private int <>2__current;
  private int <>1__state;
  private int <>l__initialThreadId;
  public string str;
  public string <>3__str;
  public string searchText;
  public string <>3__searchText;
  public int <index>5__1;

  int IEnumerator<int>.Current
  {
    [DebuggerHidden] get
    {
      return this.<>2__current;
    }
  }

  object IEnumerator.Current
  {
    [DebuggerHidden] get
    {
      return (object) this.<>2__current;
    }
  }

  [DebuggerHidden]
  public <AllIndexesOf>d__0(int <>1__state)
  {
    base..ctor();
    this.<>1__state = param0;
    this.<>l__initialThreadId = Environment.CurrentManagedThreadId;
  }

  [DebuggerHidden]
  IEnumerator<int> IEnumerable<int>.GetEnumerator()
  {
    Test.<AllIndexesOf>d__0 allIndexesOfD0;
    if (Environment.CurrentManagedThreadId == this.<>l__initialThreadId && this.<>1__state == -2)
    {
      this.<>1__state = 0;
      allIndexesOfD0 = this;
    }
    else
      allIndexesOfD0 = new Test.<AllIndexesOf>d__0(0);
    allIndexesOfD0.str = this.<>3__str;
    allIndexesOfD0.searchText = this.<>3__searchText;
    return (IEnumerator<int>) allIndexesOfD0;
  }

  [DebuggerHidden]
  IEnumerator IEnumerable.GetEnumerator()
  {
    return (IEnumerator) this.System.Collections.Generic.IEnumerable<System.Int32>.GetEnumerator();
  }

  bool IEnumerator.MoveNext()
  {
    switch (this.<>1__state)
    {
      case 0:
        this.<>1__state = -1;
        if (this.searchText == null)
          throw new ArgumentNullException("searchText");
        this.<index>5__1 = 0;
        break;
      case 1:
        this.<>1__state = -1;
        this.<index>5__1 += this.searchText.Length;
        break;
      default:
        return false;
    }
    this.<index>5__1 = this.str.IndexOf(this.searchText, this.<index>5__1);
    if (this.<index>5__1 != -1)
    {
      this.<>2__current = this.<index>5__1;
      this.<>1__state = 1;
      return true;
    }
    goto default;
  }

  [DebuggerHidden]
  void IEnumerator.Reset()
  {
    throw new NotSupportedException();
  }

  void IDisposable.Dispose()
  {
  }
}

This is not valid C#, because the compiler is allowed to do things the language doesn't allow but which are legal in IL, for instance giving fields and types names you couldn't write in C#, so that they can never collide with your own names.

But as you can see, AllIndexesOf only constructs and returns an object whose constructor just initializes some state. GetEnumerator only hands out the object (itself, or a fresh copy if it is called again or from another thread) and copies the arguments into it. The real work is done when you start enumerating, by calling the MoveNext method.
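You can also watch this happen without a decompiler. In this sketch, the `Traced` iterator and its log messages are hypothetical, purely for illustration. It records when each piece of the iterator body runs while the enumerator is stepped by hand: nothing from the body runs on the call or on `GetEnumerator`, and each `MoveNext` resumes the body just past the previous `yield return`.

```csharp
using System.Collections.Generic;

public static class TraceDemo
{
    // A tiny iterator that writes to a log each time its body makes progress.
    private static IEnumerable<int> Traced(List<string> log)
    {
        log.Add("body started");
        yield return 1;
        log.Add("resumed after 1");
        yield return 2;
        log.Add("body finished");
    }

    public static List<string> Run()
    {
        var log = new List<string>();
        IEnumerable<int> sequence = Traced(log);
        log.Add("method returned");               // the body has not run yet
        using (IEnumerator<int> e = sequence.GetEnumerator())
        {
            log.Add("GetEnumerator returned");    // still nothing from the body
            while (e.MoveNext())                  // each call runs up to the next yield return
                log.Add("Current = " + e.Current);
        }
        return log;
    }
}
```

The log comes back as: method returned, GetEnumerator returned, body started, Current = 1, resumed after 1, Current = 2, body finished.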

  • BTW, I added the following important point to the answer: Note that you should also check the str parameter for null, because extension methods can be called on null values, as they're just syntactic sugar. May 12, 2015 at 7:47
  • yield return is a nice idea in principle, but it has so many weird gotchas. Thanks for bringing this one to light!
    – nateirvin
    May 12, 2015 at 21:32
  • So, basically an error would be thrown if the enumerator was run, as in a foreach?
    – MVCDS
    May 13, 2015 at 15:42
  • @MVCDS Exactly. MoveNext is called under the hood by the foreach construct. I wrote an explanation of what foreach does in my answer explaining collection semantics if you'd like to see the exact pattern. May 13, 2015 at 15:46
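For reference, here's roughly what a foreach over an IEnumerable<int> expands to, written out by hand. It's a simplified sketch of the pattern the C# specification describes (the method names `SumWithForeach` and `SumByHand` are hypothetical). The point is that `MoveNext` is where an iterator's body, including any argument checks inside it, actually runs.

```csharp
using System.Collections.Generic;

public static class ForeachLowering
{
    public static int SumWithForeach(IEnumerable<int> source)
    {
        int total = 0;
        foreach (int item in source)
            total += item;
        return total;
    }

    // Approximately what the compiler generates for the loop above.
    public static int SumByHand(IEnumerable<int> source)
    {
        int total = 0;
        IEnumerator<int> e = source.GetEnumerator();
        try
        {
            while (e.MoveNext())      // runs the iterator body up to the next yield return
            {
                int item = e.Current;
                total += item;
            }
        }
        finally
        {
            if (e != null)
                e.Dispose();          // lets an iterator run its pending finally blocks on early exit
        }
        return total;
    }
}
```

Both methods return 6 for `new[] { 1, 2, 3 }`.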

You have an iterator block. None of the code in that method is ever run outside of calls to MoveNext on the returned iterator. Calling the method does nothing but create the state machine, and that won't ever fail (outside of extremes such as out-of-memory errors, stack overflows, or thread abort exceptions).

When you actually attempt to iterate the sequence you'll get the exceptions.

This is why the LINQ methods actually need two methods to get the error-handling semantics they want. They have a private method that is an iterator block, and a public non-iterator method that does nothing but validate the arguments (so that validation happens eagerly rather than being deferred) and then hands off to the iterator, so all other work is still deferred.

So this is the general pattern:

public static IEnumerable<T> Foo<T>(
    this IEnumerable<T> source, Func<T, bool> anotherArgument)
{
    // Note: not an iterator block, so these checks run as soon as Foo is called.
    if (source == null)
        throw new ArgumentNullException("source");
    if (anotherArgument == null)
        throw new ArgumentNullException("anotherArgument");
    return FooImpl(source, anotherArgument);
}

private static IEnumerable<T> FooImpl<T>(
    IEnumerable<T> source, Func<T, bool> anotherArgument)
{
    // TODO: actual implementation as an iterator block (this is where yield return goes).
    yield break;
}
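As a concrete instance of that pattern, here's a sketch of a Where-like filter. The names `FilterExtensions`, `Filter` and `FilterImpl` are hypothetical, and this is not the actual LINQ source, just the same eager/deferred split applied to a simple predicate filter.

```csharp
using System;
using System.Collections.Generic;

public static class FilterExtensions
{
    // Eager part: a plain method, so the checks run at the call site.
    public static IEnumerable<T> Filter<T>(
        this IEnumerable<T> source, Func<T, bool> predicate)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        if (predicate == null)
            throw new ArgumentNullException("predicate");
        return FilterImpl(source, predicate);
    }

    // Deferred part: the iterator block, which only runs during enumeration.
    private static IEnumerable<T> FilterImpl<T>(
        IEnumerable<T> source, Func<T, bool> predicate)
    {
        foreach (T item in source)
        {
            if (predicate(item))
                yield return item;
        }
    }
}
```

For example, `new[] { 1, 2, 3, 4 }.Filter(n => n % 2 == 0)` yields 2 and 4 once enumerated, while `new[] { 1 }.Filter(null)` throws ArgumentNullException straight away, before anything is enumerated.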

Iterator methods, as the others have said, aren't evaluated until they start being enumerated (i.e. until IEnumerator.MoveNext is called). Thus this

IEnumerable<int> indexes = "a.b.c.d.e".AllIndexesOf(null);

doesn't get evaluated until you start enumerating, i.e.

foreach(int index in indexes)
{
    // ArgumentNullException
}
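Note that any operator that materializes the sequence, such as `ToList()`, `Count()` or `First()`, enumerates it on the spot, so calling `.ToList()` on the result makes the exception surface at that call rather than in a later loop. Here's a small sketch, with a hypothetical `Indexes` iterator standing in for the question's method:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class MaterializeDemo
{
    // Stand-in for the question's method: the check lives inside the iterator block.
    private static IEnumerable<int> Indexes(string searchText)
    {
        if (searchText == null)
            throw new ArgumentNullException("searchText");
        yield return 0;
    }

    // Returns true if running the step throws ArgumentNullException.
    private static bool Throws(Action step)
    {
        try { step(); return false; }
        catch (ArgumentNullException) { return true; }
    }

    // Just calling the iterator method: nothing is enumerated, so no exception.
    public static bool CallThrows()
    {
        return Throws(() => Indexes(null));
    }

    // ToList enumerates immediately, so the deferred check fires here.
    public static bool ToListThrows()
    {
        return Throws(() => Indexes(null).ToList());
    }
}
```

`CallThrows()` returns false and `ToListThrows()` returns true.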
