Skip to content

Commit e3d8667

Browse files
authored
Merge pull request #1329 from mcneel/kike/1.32
2 parents 954be49 + 4de7c4f commit e3d8667

File tree

3 files changed

+478
-442
lines changed

3 files changed

+478
-442
lines changed
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
#if !NET
2+
using System.Collections.Generic;
3+
using System.ComponentModel;
4+
using System.Diagnostics;
5+
using System.Linq;
6+
using System.Reflection;
7+
using System.Threading;
8+
9+
namespace System.Runtime.Loader
10+
{
11+
class AssemblyLoadContext
12+
{
13+
static AssemblyLoadContext()
14+
{
15+
AppDomain = AppDomain.CurrentDomain;
16+
17+
var assemblyResolve = _AssemblyResolve.GetValue(AppDomain) as ResolveEventHandler;
18+
if (assemblyResolve?.GetInvocationList() is Delegate[] invocationList)
19+
{
20+
foreach (var invocation in invocationList)
21+
AppDomain.AssemblyResolve -= invocation as ResolveEventHandler;
22+
23+
AppDomain.AssemblyResolve += Resolve;
24+
25+
foreach (var invocation in invocationList)
26+
AppDomain.AssemblyResolve += invocation as ResolveEventHandler;
27+
}
28+
else AppDomain.AssemblyResolve += Resolve;
29+
30+
AppDomain.AssemblyLoad += AssemblyLoad;
31+
}
32+
33+
private static readonly AppDomain AppDomain;
34+
private static readonly List<AssemblyLoadContext> AllContexts = new List<AssemblyLoadContext>();
35+
private static readonly Dictionary<Assembly, AssemblyLoadContext> AssemblyContexts = new Dictionary<Assembly, AssemblyLoadContext>();
36+
public static IEnumerable<AssemblyLoadContext> All => AllContexts;
37+
public static readonly AssemblyLoadContext Default = new AssemblyLoadContext("Default", isCollectible: false);
38+
public static AssemblyLoadContext GetLoadContext(Assembly assembly)
39+
{
40+
if (assembly is null) throw new ArgumentNullException(nameof(assembly));
41+
return AssemblyContexts.TryGetValue(assembly, out var context) ? context : Default;
42+
}
43+
public static AssemblyName GetAssemblyName(string assemblyPath) => AssemblyName.GetAssemblyName(assemblyPath);
44+
45+
#region ContextualReflectionContext
46+
private static readonly AsyncLocal<AssemblyLoadContext> _CurrentContextualReflectionContext = new AsyncLocal<AssemblyLoadContext>();
47+
public static AssemblyLoadContext CurrentContextualReflectionContext => _CurrentContextualReflectionContext?.Value;
48+
49+
public ContextualReflectionScope EnterContextualReflection() => new ContextualReflectionScope(this);
50+
public static ContextualReflectionScope EnterContextualReflection(Assembly activating)
51+
{
52+
if (activating == null)
53+
return new ContextualReflectionScope(null);
54+
55+
return GetLoadContext(activating).EnterContextualReflection();
56+
}
57+
58+
[EditorBrowsable(EditorBrowsableState.Never)]
59+
public readonly struct ContextualReflectionScope : IDisposable
60+
{
61+
private readonly AssemblyLoadContext _previous;
62+
63+
internal ContextualReflectionScope(AssemblyLoadContext activating)
64+
{
65+
_previous = CurrentContextualReflectionContext;
66+
_CurrentContextualReflectionContext.Value = activating;
67+
}
68+
69+
public void Dispose()
70+
{
71+
_CurrentContextualReflectionContext.Value = _previous;
72+
}
73+
}
74+
#endregion
75+
76+
public override string ToString() => $"\"{Name}\" {GetType()} #{Id}";
77+
78+
private readonly int Id;
79+
public string Name { get; }
80+
public bool IsCollectible => !AppDomain.IsDefaultAppDomain();
81+
public IEnumerable<Assembly> Assemblies => AppDomain.GetAssemblies().Where(x => ReferenceEquals(GetLoadContext(x), this));
82+
83+
public event Func<AssemblyLoadContext, AssemblyName, Assembly> Resolving;
84+
public event Func<Assembly, string, IntPtr> ResolvingUnmanagedDll;
85+
86+
public AssemblyLoadContext() : this(default, default) { }
87+
public AssemblyLoadContext(bool isCollectible) : this(default, isCollectible) { }
88+
public AssemblyLoadContext(string name, bool isCollectible = false)
89+
{
90+
if (isCollectible && AppDomain.CurrentDomain.IsDefaultAppDomain()) throw new NotSupportedException();
91+
92+
Name = name;
93+
94+
lock (AllContexts)
95+
{
96+
AllContexts.Add(this);
97+
Id = AllContexts.Count;
98+
}
99+
}
100+
101+
public Assembly LoadFromAssemblyName(AssemblyName assemblyName)
102+
{
103+
using (EnterContextualReflection())
104+
return AppDomain.Load(assemblyName);
105+
}
106+
107+
public Assembly LoadFromAssemblyPath(string assemblyPath)
108+
{
109+
using (EnterContextualReflection())
110+
return Assembly.LoadFrom(assemblyPath);
111+
}
112+
113+
protected virtual Assembly Load(AssemblyName assemblyName) => null;
114+
115+
protected virtual IntPtr LoadUnmanagedDll(string unmanagedDllName) => IntPtr.Zero;
116+
117+
#region Resolve
118+
private static readonly FieldInfo _AssemblyResolve = typeof(AppDomain).GetField("_AssemblyResolve", BindingFlags.Instance | BindingFlags.NonPublic);
119+
private static Assembly Resolve(object sender, ResolveEventArgs args)
120+
{
121+
if ((args.RequestingAssembly ?? GetRequestingAssembly()) is Assembly requestingAssembly)
122+
{
123+
if (GetLoadContext(requestingAssembly) is AssemblyLoadContext context)
124+
{
125+
var assemblyName = new AssemblyName(args.Name);
126+
127+
if (context.ResolveUsingLoad(assemblyName) is Assembly loadedAssembly)
128+
return loadedAssembly;
129+
130+
return context.ResolveUsingEvent(assemblyName);
131+
}
132+
}
133+
134+
return null;
135+
}
136+
137+
private static void AssemblyLoad(object sender, AssemblyLoadEventArgs args)
138+
{
139+
if (args.LoadedAssembly.ReflectionOnly) return;
140+
141+
var context = CurrentContextualReflectionContext;
142+
if (context is null && GetRequestingAssembly() is Assembly requestingAssembly)
143+
AssemblyContexts.TryGetValue(requestingAssembly, out context);
144+
145+
if (context is object)
146+
AssemblyContexts.Add(args.LoadedAssembly, context);
147+
}
148+
149+
private Assembly ResolveUsingLoad(AssemblyName assemblyName)
150+
{
151+
try
152+
{
153+
if (Load(assemblyName) is Assembly loadedAssembly)
154+
return AssertValidAssemblyName(loadedAssembly, assemblyName);
155+
}
156+
catch {}
157+
158+
return null;
159+
}
160+
161+
private Assembly ResolveUsingEvent(AssemblyName assemblyName)
162+
{
163+
if (Resolving?.GetInvocationList() is Delegate[] invocationList)
164+
{
165+
foreach (Func<AssemblyLoadContext, AssemblyName, Assembly> resolver in invocationList)
166+
{
167+
try
168+
{
169+
if (resolver(this, assemblyName) is Assembly resolvedAssembly)
170+
return AssertValidAssemblyName(resolvedAssembly, assemblyName);
171+
}
172+
catch { }
173+
}
174+
}
175+
176+
return null;
177+
}
178+
179+
private static Assembly AssertValidAssemblyName(Assembly assembly, AssemblyName assemblyName)
180+
{
181+
var loadedName = assembly.GetName().Name;
182+
if (string.IsNullOrEmpty(loadedName) || !assemblyName.Name.Equals(loadedName, StringComparison.InvariantCultureIgnoreCase))
183+
throw new InvalidOperationException("Invalid assembly name.");
184+
185+
return assembly;
186+
}
187+
188+
static Assembly GetRequestingAssembly()
189+
{
190+
var trace = new StackTrace(1);
191+
var frames = trace.GetFrames();
192+
193+
var callingAssembly = Assembly.GetCallingAssembly();
194+
195+
// Skip Calling Assembly
196+
int f = 0;
197+
for (; f < frames.Length; ++f)
198+
{
199+
var method = frames[f].GetMethod();
200+
if (method is null) continue;
201+
var frameAssembly = method.DeclaringType?.Assembly ?? method.Module?.Assembly;
202+
if (frameAssembly != callingAssembly)
203+
break;
204+
}
205+
206+
// Skip mscorlib
207+
for (; f < frames.Length; ++f)
208+
{
209+
var method = frames[f].GetMethod();
210+
if (method is null) continue;
211+
var frameAssembly = method.DeclaringType?.Assembly ?? method.Module?.Assembly;
212+
if (frameAssembly != typeof(object).Assembly)
213+
return frameAssembly;
214+
}
215+
216+
return null;
217+
}
218+
#endregion
219+
}
220+
}
221+
#endif

0 commit comments

Comments
 (0)