MLLIF
an MLIR-based Language to Language Interoperability Flyover
Loading...
Searching...
No Matches
BindGenerator.cs
Go to the documentation of this file.
1using System.Text;
2using Microsoft.CodeAnalysis;
3
5
/// <summary>
/// Incremental source generator that emits native-callable bindings for every method
/// annotated with <c>[MLLIF.Export]</c>.
/// </summary>
/// <remarks>
/// For each exported method it adds, to a <c>partial</c> declaration of the containing type,
/// a delegate type <c>__delegate{mangled}</c> and a <c>public static</c> wrapper method
/// <c>{mangled}</c> with the same native signature. Reference types cross the boundary as
/// <c>System.IntPtr</c> values: incoming ones are resolved via
/// <c>GCHandle.FromIntPtr(..).Target</c>, and reference-type results are wrapped in a freshly
/// allocated <c>GCHandle</c>. Presumably the native caller releases those handles through the
/// generated <c>MLLIF.Internal.Free</c> helper — the release side is not visible here.
/// </remarks>
[Generator(LanguageNames.CSharp)]
public class BindGenerator : IIncrementalGenerator
{
    private const string GCHandleType = "System.Runtime.InteropServices.GCHandle";

    /// <summary>
    /// Registers the always-present support source (<c>MLLIF.ExportAttribute</c> and the
    /// <c>MLLIF.Internal</c> helpers) and the pipeline that emits one binding file per
    /// <c>[Export]</c>-annotated method.
    /// </summary>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        context.RegisterPostInitializationOutput(static context =>
        {
            context.AddSource("ExportAttribute.cs",
                """
                // <auto-generated />
                using System;

                namespace MLLIF {
                    [AttributeUsage(AttributeTargets.Method, AllowMultiple=false, Inherited=false)]
                    internal sealed class ExportAttribute : Attribute {}

                    public static class Internal {
                        public delegate void FreeDelegate(IntPtr ptr);

                        public static void Free(IntPtr ptr) {
                            System.Runtime.InteropServices.GCHandle.FromIntPtr(ptr).Free();
                        }
                    }
                }
                """);
        });

        // NOTE(review): the transform passes the whole GeneratorAttributeSyntaxContext through
        // (it carries a SemanticModel), which likely defeats incremental caching so Emit re-runs
        // for every exported method on each edit. Kept as-is because MangleName is invoked with
        // the SourceProductionContext inside Emit.
        var source = context.SyntaxProvider.ForAttributeWithMetadataName(
            "MLLIF.ExportAttribute",
            static (_, _) => true,
            static (context, _) => context);

        context.RegisterSourceOutput(source, Emit);
    }

    /// <summary>
    /// Maps a managed type to the type used on the native side of a binding: reference types
    /// become <c>System.IntPtr</c> (a GCHandle), everything else keeps its C# display name.
    /// </summary>
    private static string ToNativeType(ITypeSymbol type)
        => type.IsReferenceType ? "System.IntPtr" : type.ToDisplayString();

    /// <summary>
    /// Builds the <c>MLLIF02</c> error reported when <paramref name="method"/> cannot be mangled.
    /// The diagnostic is attached to the method's first source location when one exists, so it
    /// shows up on the offending declaration rather than without a location.
    /// </summary>
    private static Diagnostic CannotMangle(IMethodSymbol method)
        => Diagnostic.Create(
            "MLLIF02", "",
            $"'{method.ToDisplayString()}' cannot be mangled",
            DiagnosticSeverity.Error, DiagnosticSeverity.Error,
            isEnabledByDefault: true, warningLevel: 0,
            location: method.Locations.IsEmpty ? null : method.Locations[0]);

    /// <summary>
    /// Emits the binding file for one exported method. Reports <c>MLLIF02</c> and emits
    /// nothing when the method's name cannot be mangled.
    /// </summary>
    private void Emit(SourceProductionContext spc, GeneratorAttributeSyntaxContext gasc)
    {
        var symbol = (IMethodSymbol)gasc.TargetSymbol;

        // MangleName is an extension defined elsewhere in the project; a null result means failure.
        if (symbol.MangleName(spc) is not { } mangling)
        {
            spc.ReportDiagnostic(CannotMangle(symbol));
            return;
        }

        var containingType = symbol.ContainingType;
        var builder = new StringBuilder();

        // Delegate type describing the wrapper's native signature.
        builder.Append("\tpublic delegate ")
            .Append(ToNativeType(symbol.ReturnType))
            .Append(" __delegate")
            .Append(mangling)
            .Append('(');
        AppendNativeParameters(builder, symbol);
        builder.Append(");\n");

        // Static wrapper: unwraps handles, calls the exported method, re-wraps a reference result.
        builder.Append("\tpublic static ")
            .Append(ToNativeType(symbol.ReturnType))
            .Append(' ')
            .Append(mangling)
            .Append('(');
        AppendNativeParameters(builder, symbol);
        builder.Append(") {\n\t\t");

        // `return <void expression>;` does not compile (CS0127), so void methods get a plain call.
        if (!symbol.ReturnsVoid)
            builder.Append("return ");

        if (symbol.ReturnType.IsReferenceType)
            builder.Append(GCHandleType).Append(".ToIntPtr(").Append(GCHandleType).Append(".Alloc(");

        // Resolve the receiver inline rather than through a local, so no generated local name
        // can clash with one of the method's parameter names.
        if (!symbol.IsStatic)
        {
            builder.Append('(').Append(GCHandleType).Append(".FromIntPtr(self).Target as ")
                .Append(containingType.ToDisplayString())
                .Append(").");
        }

        builder.Append(Escape(symbol.Name)).Append('(');
        for (var i = 0; i < symbol.Parameters.Length; i++)
        {
            if (i > 0)
                builder.Append(", ");

            var param = symbol.Parameters[i];
            if (param.Type.IsReferenceType)
            {
                // `as` rejects nullable-annotated reference types (CS8651), so drop a top-level `?`.
                builder.Append(GCHandleType).Append(".FromIntPtr(")
                    .Append(Escape(param.Name))
                    .Append(").Target as ")
                    .Append(param.Type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString());
            }
            else
            {
                builder.Append(Escape(param.Name));
            }
        }
        builder.Append(')');

        if (symbol.ReturnType.IsReferenceType)
            builder.Append("))");

        builder.Append(";\n\t}\n");

        // NOTE(review): nested and generic containing types are emitted as a top-level,
        // non-generic partial class, which will not merge with the user's declaration.
        var ns = containingType.ContainingNamespace;
        var namespaceDeclaration = ns.IsGlobalNamespace ? "" : $"namespace {ns.ToDisplayString()};";
        var hintPrefix = ns.IsGlobalNamespace ? "" : ns.ToDisplayString() + ".";

        // Hint names must be unique per generator (AddSource throws otherwise); the namespace and
        // the mangled name keep overloads and same-named types in other namespaces apart.
        spc.AddSource($"{hintPrefix}{containingType.Name}.{mangling}.g.cs",
            $$"""
            // <auto-generated />
            {{namespaceDeclaration}}

            public partial class {{containingType.Name}} {
            {{builder}}
            }
            """);
    }

    /// <summary>
    /// Appends the native parameter list of <paramref name="method"/> without parentheses:
    /// a leading <c>System.IntPtr self</c> receiver for instance methods, then every parameter
    /// mapped through <see cref="ToNativeType"/>, separated by <c>", "</c>.
    /// </summary>
    /// <remarks>NOTE(review): <c>ref</c>/<c>out</c>/<c>in</c> modifiers are not propagated.</remarks>
    private static void AppendNativeParameters(StringBuilder builder, IMethodSymbol method)
    {
        var needsSeparator = false;
        if (!method.IsStatic)
        {
            builder.Append("System.IntPtr self");
            needsSeparator = true;
        }

        foreach (var param in method.Parameters)
        {
            if (needsSeparator)
                builder.Append(", ");
            needsSeparator = true;

            builder.Append(ToNativeType(param.Type)).Append(' ').Append(Escape(param.Name));
        }
    }

    /// <summary>
    /// Prefixes <paramref name="name"/> with <c>@</c> so that identifiers which are C# keywords
    /// (e.g. a parameter named <c>event</c>) remain valid in generated code.
    /// </summary>
    private static string Escape(string name) => "@" + name;
}
void Initialize(IncrementalGeneratorInitializationContext context)