This is the second part of the Traits in C# series. For your convenience, you can find the other parts in the table of contents in Part 1: Basic implementation with Fody.

Last time we saw a simple implementation of traits in C#. Today we are going to handle various overrides. Let’s go.

Test code

We want to write the following application:

using System;
using TraitIntroducer;

namespace TraitsDemo
{
    public interface IA
    {
    }

    public interface IB
    {
    }

    public interface IC
    {
    }

    [TraitFor(typeof(IA))]
    public static class IA_Implementation
    {
        public static void Print(this IA instance)
        {
            Console.WriteLine("I'm IA");
        }
    }

    [TraitFor(typeof(IB))]
    public static class IB_Implementation
    {
        public static void Print(this IB instance)
        {
            Console.WriteLine("I'm IB");
        }
    }

    [TraitFor(typeof(IC))]
    public static class IC_Implementation
    {
        public static void Print(this IC instance)
        {
            Console.WriteLine("I'm IC");
        }
    }

    [TraitFor(typeof(IC))]
    public static class IC_Implementation2
    {
        public static void Print(this IC instance)
        {
            Console.WriteLine("I'm IC2");
        }
    }

    public class A : IA
    {
        public virtual void Print()
        {
            Console.WriteLine("I'm A");
        }
    }

    public class B : A
    {
    }

    public class C : B
    {
        public override void Print()
        {
            Console.WriteLine("I'm C");
        }
    }

    public class D : C, IB, IC
    {
    }

    class Program
    {
        static void Main(string[] args)
        {
            IA a = new D();
            a.Print(); // Should print "I'm IC2"
        }
    }
}

What exactly is going on here? We have three interfaces that do not inherit from each other. We also have a simple class hierarchy: the first class (A) implements the first interface and contains the virtual method Print. The second class (B) inherits from the first one and adds nothing. The third class (C) inherits from the second one and overrides Print. The fourth class (D) inherits from the third one and implements the two remaining interfaces.

We create an instance of the last class and store its reference in a variable typed as the first interface. Next, we call Print. Without the traits logic wired up, this prints “I’m IA”: IA declares no Print member, so the compiler binds the call to the extension method IA_Implementation.Print at compile time. However, with traits we want the chain of overrides to look like this:
IA_Implementation.Print -> A.Print -> C.Print -> IB_Implementation.Print -> IC_Implementation.Print -> IC_Implementation2.Print
so we would like to see “I’m IC2” printed out. This is the task we are going to implement today: selecting the correct override. Next time we will see how to implement base.Print in order to be able to stack traits.
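To see why the unwoven program prints “I’m IA”, here is a minimal standalone sketch of the C# binding rules involved. It uses a trimmed-down hierarchy (A, C and D only) and returns strings instead of writing to the console so the result can be checked; BindingDemo, CallThroughInterface and CallThroughClass are names made up for this illustration.

```csharp
using System;

public interface IA { }

public static class IA_Implementation
{
    // Extension method: picked by the compiler only when the receiver's static type has no Print.
    public static string Print(this IA instance) => "I'm IA";
}

public class A : IA
{
    public virtual string Print() => "I'm A";
}

public class C : A
{
    public override string Print() => "I'm C";
}

public class D : C { }

public static class BindingDemo
{
    public static string CallThroughInterface()
    {
        IA a = new D();
        // IA declares no Print, so this compiles to a static call to IA_Implementation.Print.
        return a.Print();
    }

    public static string CallThroughClass()
    {
        A a = new D();
        // A declares Print and instance methods win over extension methods,
        // so this is a virtual call that ends up in C.Print.
        return a.Print();
    }

    public static void Main()
    {
        Console.WriteLine(CallThroughInterface()); // I'm IA
        Console.WriteLine(CallThroughClass());     // I'm C
    }
}
```

The same object gives different answers depending on the static type of the variable. The weaver cannot change which method the compiler picked, so it makes the extension method forward to a virtual Print instead.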

Implementation

Since we want to select overrides, it might be helpful to be able to extract the whole inheritance hierarchy. Let’s start with the following code:

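private IEnumerable<TypeDefinition> GetTypeHierarchy(TypeDefinition type)
{
	// Types from other assemblies (e.g. System.Object) are not TypeDefinitions here, so the recursion stops at them
	if (type == null)
	{
		return Enumerable.Empty<TypeDefinition>();
	}

	// Base chain first, then interfaces declared on this type, then the type itself;
	// GroupBy keeps the first occurrence of every type and preserves that order
	return GetTypeHierarchy(type.BaseType as TypeDefinition)
		.Concat(type.Interfaces.OfType<TypeDefinition>().SelectMany(GetTypeHierarchy))
		.Concat(new[] { type })
		.GroupBy(t => t.FullName)
		.Select(x => x.First());
}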

If the passed type is null, we return an empty collection. Otherwise, we take the base type’s hierarchy, append the hierarchies of the interfaces declared on this type, append the type itself, and finally remove duplicates while keeping the first occurrence of each type. So for class D in the demo program we get the following hierarchy:

IA -> A -> B -> C -> IB -> IC -> D

We omit System.Object here.
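The same algorithm can be tried without Mono.Cecil. The sketch below runs it over a hypothetical TypeNode class that only stores a name, a base type, and the interfaces declared directly on the type (standing in for TypeDefinition, BaseType and Interfaces), and prints the hierarchy computed for D.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for Cecil's TypeDefinition: a name, an optional base type,
// and the interfaces declared directly on the type, in declaration order.
public sealed class TypeNode
{
    public TypeNode(string name, TypeNode baseType = null, params TypeNode[] interfaces)
    {
        Name = name;
        BaseType = baseType;
        Interfaces = interfaces;
    }

    public string Name { get; }
    public TypeNode BaseType { get; }
    public IReadOnlyList<TypeNode> Interfaces { get; }
}

public static class HierarchyDemo
{
    // Same shape as GetTypeHierarchy: base chain, then declared interfaces, then the type itself;
    // GroupBy keeps the first occurrence of each name, in order.
    public static IEnumerable<TypeNode> GetTypeHierarchy(TypeNode type)
    {
        if (type == null)
        {
            return Enumerable.Empty<TypeNode>();
        }

        return GetTypeHierarchy(type.BaseType)
            .Concat(type.Interfaces.SelectMany(GetTypeHierarchy))
            .Concat(new[] { type })
            .GroupBy(t => t.Name)
            .Select(g => g.First());
    }

    public static string Describe(TypeNode type) =>
        string.Join(" -> ", GetTypeHierarchy(type).Select(t => t.Name));

    public static TypeNode BuildDemoD()
    {
        // System.Object is left out, as in the article.
        var ia = new TypeNode("IA");
        var ib = new TypeNode("IB");
        var ic = new TypeNode("IC");
        var a = new TypeNode("A", null, ia);
        var b = new TypeNode("B", a);
        var c = new TypeNode("C", b);
        return new TypeNode("D", c, ib, ic);
    }

    public static void Main()
    {
        Console.WriteLine(Describe(BuildDemoD())); // IA -> A -> B -> C -> IB -> IC -> D
    }
}
```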

With this tool in hand we can implement the weaver:

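public void Execute()
{
	// Types defined in this module plus the referenced ones that can be resolved (Resolve() may return null)
	AllTypes = ModuleDefinition.Types
		.Concat(ModuleDefinition.GetTypeReferences().Select(t => t.Resolve()))
		.Where(t => t != null)
		.ToList();
	TraitForAttributeTypeDefinition = AllTypes.FirstOrDefault(t => t.FullName == typeof(TraitForAttribute).FullName);
	if (TraitForAttributeTypeDefinition == null) return;

	// An ancestor always has a shorter hierarchy than its descendants, so this is a top-down order;
	// ToList() runs the recursive sort once instead of once per loop below
	var orderedTypes = AllTypes.OrderBy(type => GetTypeHierarchy(type).Count()).ToList();

	// Classes first: they copy the original trait bodies before FixInterface replaces them
	foreach (var type in orderedTypes.Where(t => t.IsClass))
	{
		FixClass(type);
	}
	foreach (var type in orderedTypes.Where(t => t.IsInterface))
	{
		FixInterface(type);
	}
}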

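We first collect all types (those defined in the module and the resolved referenced ones) and look up the type of our attribute. Next, we sort the types by the size of their hierarchy: an ancestor’s hierarchy is always shorter than its descendant’s, so this gives us the top-down order we need. Finally, we fix all classes first and only then the interfaces. The order matters: fixing a class copies the original body of each trait method, while fixing an interface replaces that body with a forwarding call.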
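A small sketch of the visiting order produced by Execute, again on a hypothetical TypeNode model rather than Cecil types. The input list is deliberately shuffled; sorting by hierarchy size puts every ancestor before its descendants, and since OrderBy is a stable sort, types with equal sizes keep their input order.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical minimal type model standing in for Cecil's TypeDefinition.
public sealed class TypeNode
{
    public string Name;
    public bool IsInterface;
    public TypeNode BaseType;
    public TypeNode[] Interfaces = new TypeNode[0];
}

public static class OrderingDemo
{
    // Same linearization as GetTypeHierarchy in the weaver.
    static IEnumerable<TypeNode> Hierarchy(TypeNode type)
    {
        if (type == null) return Enumerable.Empty<TypeNode>();
        return Hierarchy(type.BaseType)
            .Concat(type.Interfaces.SelectMany(Hierarchy))
            .Concat(new[] { type })
            .GroupBy(t => t.Name)
            .Select(g => g.First());
    }

    // Mirrors Execute: order by hierarchy size, then visit classes before interfaces.
    public static List<string> VisitOrder(IEnumerable<TypeNode> allTypes)
    {
        var ordered = allTypes.OrderBy(t => Hierarchy(t).Count()).ToList();
        return ordered.Where(t => !t.IsInterface)
            .Concat(ordered.Where(t => t.IsInterface))
            .Select(t => t.Name)
            .ToList();
    }

    public static List<TypeNode> DemoTypes()
    {
        var ia = new TypeNode { Name = "IA", IsInterface = true };
        var ib = new TypeNode { Name = "IB", IsInterface = true };
        var ic = new TypeNode { Name = "IC", IsInterface = true };
        var a = new TypeNode { Name = "A", Interfaces = new[] { ia } };
        var b = new TypeNode { Name = "B", BaseType = a };
        var c = new TypeNode { Name = "C", BaseType = b };
        var d = new TypeNode { Name = "D", BaseType = c, Interfaces = new[] { ib, ic } };
        // Deliberately shuffled, like the arbitrary order of AllTypes.
        return new List<TypeNode> { d, ib, c, ia, a, ic, b };
    }

    public static void Main()
    {
        Console.WriteLine(string.Join(", ", VisitOrder(DemoTypes()))); // A, B, C, D, IB, IA, IC
    }
}
```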

In order to fix anything, we need to be able to find the extension classes for a given interface:

private IEnumerable<TypeDefinition> GetExtenders(TypeDefinition typeToExtend)
{
	return AllTypes.Where(type =>
	{
		// Find the [TraitFor(typeof(...))] attribute on the candidate type, if any
		var traitAttribute = type.CustomAttributes.FirstOrDefault(attribute => attribute.AttributeType.FullName == TraitForAttributeTypeDefinition.FullName);
		if (traitAttribute == null) return false;

		// The typeof(...) argument is stored as a TypeReference; compare it by full name
		var extendedInterfaceType = traitAttribute.ConstructorArguments.First().Value as TypeReference;
		return extendedInterfaceType != null && extendedInterfaceType.FullName == typeToExtend.FullName;
	});
}

We iterate over all types, pick those marked with the TraitFor attribute, and keep the ones whose attribute argument is the requested interface. Since a single interface may have multiple extension classes, we return a collection.
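For comparison, here is roughly the same lookup done with reflection at runtime instead of Cecil at weave time. TraitForAttribute is declared locally as a stand-in for the real TraitIntroducer attribute (its property name ExtendedInterface is an assumption), and types are compared by identity rather than by full name.

```csharp
using System;
using System.Linq;
using System.Reflection;

// Local stand-in for TraitIntroducer.TraitForAttribute (hypothetical shape).
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false)]
public sealed class TraitForAttribute : Attribute
{
    public TraitForAttribute(Type extendedInterface) { ExtendedInterface = extendedInterface; }
    public Type ExtendedInterface { get; }
}

public interface IB { }
public interface IC { }

[TraitFor(typeof(IB))] public static class IB_Implementation { }
[TraitFor(typeof(IC))] public static class IC_Implementation { }
[TraitFor(typeof(IC))] public static class IC_Implementation2 { }

public static class ExtendersDemo
{
    // Every type in this assembly whose [TraitFor] argument is the requested interface.
    public static string[] GetExtenderNames(Type typeToExtend)
    {
        return typeof(ExtendersDemo).Assembly.GetTypes()
            .Where(t => t.GetCustomAttribute<TraitForAttribute>()?.ExtendedInterface == typeToExtend)
            .Select(t => t.Name)
            .OrderBy(name => name, StringComparer.Ordinal) // GetTypes() has no guaranteed order
            .ToArray();
    }

    public static void Main()
    {
        Console.WriteLine(string.Join(", ", GetExtenderNames(typeof(IC)))); // IC_Implementation, IC_Implementation2
        Console.WriteLine(string.Join(", ", GetExtenderNames(typeof(IB)))); // IB_Implementation
    }
}
```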

Now we can fix the classes. For each class we need to add the methods from the extension classes of the interfaces implemented directly by this class (and not by some ancestor), and add a virtual method for each interface method:

private void FixClass(TypeDefinition type)
{
	// Interfaces introduced by this class itself: the run of interfaces right before the class in its hierarchy
	var hierarchy = GetTypeHierarchy(type).Reverse().Skip(1).TakeWhile(t => t.IsInterface).Reverse().ToArray();
	// Method name -> (trait method, injected copy); later extenders overwrite earlier ones
	var introducedMethods = new Dictionary<string, Tuple<MethodDefinition, MethodDefinition>>();
	foreach (var implementedInterface in hierarchy)
	{
		var extenders = GetExtenders(implementedInterface);
		foreach (var extender in extenders)
		{
			foreach (var method in extender.Methods)
			{
				// Copy the trait body into the class as {method}_{extender}
				var injected = InjectExtensionMethodToInheritor(type, extender, method);
				var matchingMethod = type.Methods.FirstOrDefault(m => m.Name == method.Name);
				if (matchingMethod != null)
				{
					// The class already declares the method: make it virtual and let it reuse the inherited slot
					matchingMethod.Attributes |= MethodAttributes.Virtual;
					matchingMethod.Attributes &= ~MethodAttributes.NewSlot;
				}
				else
				{
					introducedMethods[method.Name] = Tuple.Create(method, injected);
				}
			}
		}
	}

	// Add one virtual method per name, forwarding to the last injected copy
	foreach (var introducedMethod in introducedMethods.Values)
	{
		InjectVirtualMethodToInheritor(type, introducedMethod.Item1, introducedMethod.Item2);
	}
}

Looks like black magic, so let’s go step by step. We first get the hierarchy of the modified type. Since we want to introduce only methods from interfaces implemented by our class (and not by an ancestor), we take only the interfaces at the end of the hierarchy: the ones placed after the nearest base class and right before the class itself.

Next, we create a dictionary for introduced methods. The virtual method we add at the end should call the most recently injected implementation, so we need to keep it for later.

We iterate over all interfaces, all their extension classes, and all methods of those classes. Each such method is injected into the class under the name {methodName}_{extensionClassName}. So the method Print of interface IC implemented in IC_Implementation2 becomes Print_IC_Implementation2.

Next, we check whether the class already has a method with a matching name; in our case, whether it implements Print itself. If it does, we fix its attributes to make the method virtual and drop NewSlot, so it reuses (overrides) the inherited slot. If it does not, we save the injected method so we can generate the virtual method later.

Finally, when the three nested loops are done, we iterate over the dictionary and add all missing virtual methods.
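The bookkeeping in FixClass can be followed on plain data. This sketch uses hypothetical names (Node, FixClassDemo) and models only method names: it applies the same Reverse/Skip/TakeWhile slicing to D’s hierarchy and then fills a dictionary the way introducedMethods is filled, leaving out the branch that patches an already existing method.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for a hierarchy entry: just a name and whether it is an interface.
public sealed class Node
{
    public Node(string name, bool isInterface) { Name = name; IsInterface = isInterface; }
    public string Name { get; }
    public bool IsInterface { get; }
}

public static class FixClassDemo
{
    // Same slicing as FixClass: walk backwards, skip the class itself, and keep the run
    // of interfaces directly before it (the ones not inherited from a base class).
    public static string[] OwnInterfaces(IEnumerable<Node> hierarchy) =>
        hierarchy.Reverse().Skip(1).TakeWhile(t => t.IsInterface).Reverse().Select(t => t.Name).ToArray();

    // Same bookkeeping as introducedMethods: each trait method gets a copy named
    // {method}_{extender}, and for a given method name the last copy seen wins.
    public static Dictionary<string, string> LastInjectedByName(
        IEnumerable<string> ownInterfaces,
        IDictionary<string, string[]> extendersByInterface,
        IDictionary<string, string[]> methodsByExtender)
    {
        var introduced = new Dictionary<string, string>();
        foreach (var implementedInterface in ownInterfaces)
        {
            foreach (var extender in extendersByInterface[implementedInterface])
            {
                foreach (var method in methodsByExtender[extender])
                {
                    introduced[method] = $"{method}_{extender}";
                }
            }
        }
        return introduced;
    }

    public static Dictionary<string, string> RunForD()
    {
        var hierarchyOfD = new[]
        {
            new Node("IA", true), new Node("A", false), new Node("B", false), new Node("C", false),
            new Node("IB", true), new Node("IC", true), new Node("D", false)
        };
        var extendersByInterface = new Dictionary<string, string[]>
        {
            ["IA"] = new[] { "IA_Implementation" },
            ["IB"] = new[] { "IB_Implementation" },
            ["IC"] = new[] { "IC_Implementation", "IC_Implementation2" }
        };
        var methodsByExtender = new Dictionary<string, string[]>
        {
            ["IA_Implementation"] = new[] { "Print" },
            ["IB_Implementation"] = new[] { "Print" },
            ["IC_Implementation"] = new[] { "Print" },
            ["IC_Implementation2"] = new[] { "Print" }
        };
        return LastInjectedByName(OwnInterfaces(hierarchyOfD), extendersByInterface, methodsByExtender);
    }

    public static void Main()
    {
        foreach (var pair in RunForD())
        {
            Console.WriteLine($"{pair.Key} -> {pair.Value}"); // Print -> Print_IC_Implementation2
        }
    }
}
```

Because IC_Implementation2 is visited last, its copy is the one the generated Print calls, which is exactly the “I’m IC2” we are after.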

These are the methods that inject code into concrete classes:

private string GetInheritorMethodName(TypeDefinition extender, MethodDefinition method)
{
	// e.g. Print + IC_Implementation2 -> Print_IC_Implementation2
	return $"{method.Name}_{extender.Name}";
}

private MethodDefinition InjectExtensionMethodToInheritor(TypeDefinition inheritor, TypeDefinition extender, MethodDefinition method)
{
	// NewSlot: the copy starts its own vtable slot instead of overriding anything
	var newMethod = new MethodDefinition(GetInheritorMethodName(extender, method), MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, method.ReturnType);
	// Only the instructions are copied (not locals or exception handlers), which is enough for simple trait bodies
	foreach (var instruction in method.Body.Instructions)
	{
		newMethod.Body.Instructions.Add(instruction);
	}
	inheritor.Methods.Add(newMethod);

	return newMethod;
}

private MethodDefinition InjectVirtualMethodToInheritor(TypeDefinition inheritor, MethodDefinition method, MethodDefinition methodToCall)
{
	// No NewSlot: the method reuses (overrides) a matching inherited virtual slot, e.g. C.Print
	var newMethod = new MethodDefinition(method.Name, MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig, method.ReturnType);
	// Body: this.{methodToCall}(); return;
	newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Ldarg_0));
	newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Callvirt, methodToCall));
	newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Ret));
	inheritor.Methods.Add(newMethod);

	return newMethod;
}

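The extension method body is copied as is. This works because ldarg.0 refers to the extension’s this parameter in the static method and to this in the injected instance method. The virtual method simply calls the most recently injected method.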
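Whether a virtual method overrides an inherited one or starts a new slot is controlled by the NewSlot flag, which is why the injected Print is emitted without it. The same distinction can be seen in plain C#: override compiles to a virtual method without NewSlot, while new virtual compiles to one with NewSlot. The class names below are made up for the example.

```csharp
using System;

public class Base
{
    public virtual string Print() => "Base";
}

// 'override' compiles to virtual without NewSlot (ReuseSlot), so it replaces Base.Print
// in the vtable, like the Print method the weaver injects into D.
public class Overriding : Base
{
    public override string Print() => "Overriding";
}

// 'new virtual' compiles to Virtual | NewSlot, so it starts a fresh slot
// and calls through a Base reference still reach Base.Print.
public class Hiding : Base
{
    public new virtual string Print() => "Hiding";
}

public static class SlotDemo
{
    public static string ViaBase(Base instance) => instance.Print();

    public static void Main()
    {
        Console.WriteLine(ViaBase(new Overriding())); // Overriding
        Console.WriteLine(ViaBase(new Hiding()));     // Base
    }
}
```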

Concrete classes are done. Now we need to fix the interfaces:

private void FixInterface(TypeDefinition type)
{
	// Ancestor interfaces of this interface, nearest first (the interface itself is skipped)
	var hierarchy = GetTypeHierarchy(type).Reverse().Skip(1).ToArray();
	var extenders = GetExtenders(type);
	foreach (var extender in extenders)
	{
		foreach (var method in extender.Methods)
		{
			// Look for a method with the same name in an ancestor first, then in the interface itself
			var existingMethod = hierarchy
				.Select(t => t.Methods.FirstOrDefault(m => m.Name == method.Name))
				.FirstOrDefault(m => m != null)
				?? type.Methods.FirstOrDefault(m => m.Name == method.Name);
			if (existingMethod == null)
			{
				// Nobody declares it yet: add an abstract interface method and forward the trait to it
				InjectMethodToInterface(type, method);
			}
			else
			{
				FixTraitMethod(method, existingMethod);
			}
		}
	}
}

We first calculate the hierarchy and find all extension classes. Next, for each method we check whether any of our ancestors already declares a method with the same name, or whether we already have it ourselves. If not, we inject the method:

private void InjectMethodToInterface(TypeDefinition extendedInterface, MethodDefinition method)
{
	// Interface methods are emitted as public abstract virtual (as the C# compiler does)
	var newMethod = new MethodDefinition(method.Name, MethodAttributes.Public | MethodAttributes.Abstract | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, method.ReturnType);
	extendedInterface.Methods.Add(newMethod);
	FixTraitMethod(method, newMethod);
}

No magic here; we do the same as in the previous part. We also need to fix the method in the extension class so that it only forwards the call to the interface method:

private void FixTraitMethod(MethodDefinition method, MethodDefinition methodToCall)
{
	method.Body.Instructions.Clear();
	method.Body.Instructions.Add(Instruction.Create(OpCodes.Ldarg_0));
	method.Body.Instructions.Add(Instruction.Create(OpCodes.Callvirt, methodToCall));
	method.Body.Instructions.Add(Instruction.Create(OpCodes.Ret));
}
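The three instructions emitted by FixTraitMethod can be tried out at runtime with System.Reflection.Emit. This sketch builds a DynamicMethod with the body ldarg.0, callvirt IA::Print, ret. Unlike the demo, Print returns a string so the result is easy to check, and the two-class hierarchy is a simplification of the woven one.

```csharp
using System;
using System.Reflection.Emit;

public interface IA
{
    string Print();
}

public class A : IA
{
    public virtual string Print() => "I'm A";
}

// Stands in for the woven D, whose Print forwards to Print_IC_Implementation2.
public class D : A
{
    public override string Print() => "I'm IC2";
}

public static class ForwarderDemo
{
    // Builds a static method equivalent to a rewritten trait method:
    //   ldarg.0; callvirt IA::Print; ret
    public static Func<IA, string> BuildForwarder()
    {
        var interfaceMethod = typeof(IA).GetMethod(nameof(IA.Print));
        var forwarder = new DynamicMethod("Print_Forwarder", typeof(string), new[] { typeof(IA) });
        var il = forwarder.GetILGenerator();
        il.Emit(OpCodes.Ldarg_0);                   // the "this IA instance" parameter
        il.Emit(OpCodes.Callvirt, interfaceMethod); // dispatched on the runtime type of the object
        il.Emit(OpCodes.Ret);
        return (Func<IA, string>)forwarder.CreateDelegate(typeof(Func<IA, string>));
    }

    public static void Main()
    {
        var print = BuildForwarder();
        Console.WriteLine(print(new A())); // I'm A
        Console.WriteLine(print(new D())); // I'm IC2
    }
}
```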

And we are done.

Result

Let’s compile the code and look at the result. These are the interfaces:

// Decompiled with JetBrains decompiler
// Type: TraitsDemo.IA
// Assembly: TraitsDemo, Version=1.0.0.0, Culture=neutral, PublicKeyToken=null
// MVID: 7CA9337F-2598-4715-B2DF-60EC6ECB1D59
// Assembly location: C:\Adam\TraitsDemo\bin\Debug\TraitsDemo.exe

namespace TraitsDemo
{
  public interface IA
  {
    void Print();
  }
}


namespace TraitsDemo
{
  public interface IB
  {
    void Print();
  }
}


namespace TraitsDemo
{
  public interface IC
  {
    void Print();
  }
}

Classes:

using System;
using TraitIntroducer;

namespace TraitsDemo
{
  public class A : IA
  {
    public virtual void Print()
    {
      Console.WriteLine("I'm A");
    }

    public virtual void Print_IA_Implementation()
    {
      Console.WriteLine("I'm IA");
    }
  }
}

namespace TraitsDemo
{
  public class B : A
  {
  }
}


using System;
using TraitIntroducer;

namespace TraitsDemo
{
  public class C : B
  {
    public override void Print()
    {
      Console.WriteLine("I'm C");
    }
  }
}

using System;
using TraitIntroducer;

namespace TraitsDemo
{
  public class D : C, IB, IC
  {
    public virtual void Print_IB_Implementation()
    {
      Console.WriteLine("I'm IB");
    }

    public virtual void Print_IC_Implementation()
    {
      Console.WriteLine("I'm IC");
    }

    public virtual void Print_IC_Implementation2()
    {
      Console.WriteLine("I'm IC2");
    }

    public override void Print()
    {
      this.Print_IC_Implementation2();
    }
  }
}

Test program:

namespace TraitsDemo
{
  internal class Program
  {
    private static void Main(string[] args)
    {
      IA_Implementation.Print(new D());
    }
  }
}

And extension classes:

using TraitIntroducer;

namespace TraitsDemo
{
  [TraitFor(typeof (IA))]
  public static class IA_Implementation
  {
    public static void Print(this IA instance)
    {
      instance.Print();
    }
  }
}

using TraitIntroducer;

namespace TraitsDemo
{
  [TraitFor(typeof (IB))]
  public static class IB_Implementation
  {
    public static void Print(this IB instance)
    {
      instance.Print();
    }
  }
}

using TraitIntroducer;

namespace TraitsDemo
{
  [TraitFor(typeof (IC))]
  public static class IC_Implementation
  {
    public static void Print(this IC instance)
    {
      instance.Print();
    }
  }
}

using System;
using TraitIntroducer;

namespace TraitsDemo
{
  [TraitFor(typeof (IC))]
  public static class IC_Implementation2
  {
    public static void Print(this IC instance)
    {
      instance.Print();
    }
  }
}

So we create an instance of class D and call IA_Implementation.Print, which in turn calls instance.Print(). That call is dispatched through IA.Print to D.Print, which finally calls this.Print_IC_Implementation2() and prints “I’m IC2”. We are done!
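To double-check the call chain, here is a hand-written C# equivalent of the woven types as the decompiler shows them, trimmed to the IA_Implementation extension class and returning strings instead of printing. It is a sketch for reasoning about dispatch, not the weaver’s output byte for byte.

```csharp
using System;

public interface IA { string Print(); }
public interface IB { string Print(); }
public interface IC { string Print(); }

public static class IA_Implementation
{
    // After weaving, the trait method only forwards to the interface method.
    public static string Print(this IA instance) => instance.Print();
}

public class A : IA
{
    public virtual string Print() => "I'm A";
    public virtual string Print_IA_Implementation() => "I'm IA";
}

public class B : A { }

public class C : B
{
    public override string Print() => "I'm C";
}

public class D : C, IB, IC
{
    public virtual string Print_IB_Implementation() => "I'm IB";
    public virtual string Print_IC_Implementation() => "I'm IC";
    public virtual string Print_IC_Implementation2() => "I'm IC2";

    // The injected override calls the copy from the last extender seen (IC_Implementation2).
    public override string Print() => Print_IC_Implementation2();
}

public static class WovenDemo
{
    // Main was compiled before weaving, so it calls the extension method directly.
    public static string Run() => IA_Implementation.Print(new D());

    public static void Main()
    {
        Console.WriteLine(Run()); // I'm IC2
    }
}
```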

Summary

We saw how to inject the correct methods and bind them at compile time. In the next part we are going to implement trait stacking, which will allow us to create decorators in a really nice way, just like we can do in Scala.
You can find the whole code in this gist.
Bonus chatter: what happens if we define two void Print() methods in an interface? Will .NET load the type correctly?
Bonus chatter 2: how does base.Print work internally? Does it perform static or dynamic dispatch? What IL instruction does it use?