4. Kaleidoscope: JITとオプティマイザのサポートを追加する¶
4.1. 第4章 はじめに¶
「LLVMを使った言語の実装」チュートリアルの第4章へようこそ。第1章から第3章では、単純な言語の実装について説明し、LLVM IRの生成をサポートしました。この章では、2つの新しいテクニックについて説明します。言語へのオプティマイザのサポートの追加と、JITコンパイラのサポートの追加です。これらの追加により、Kaleidoscope言語で優れた効率的なコードを取得する方法を示します。
4.2. 簡単な定数畳み込み¶
第3章のデモは、洗練されていて拡張しやすいものです。しかし、残念ながら素晴らしいコードは生成されません。ただし、IRBuilderは、単純なコードをコンパイルする際に明らかな最適化を提供します。
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
ret double %addtmp
}
このコードは、入力を解析して構築されたASTをそのまま書き出したものではありません。それは次のようになります。
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 2.000000e+00, 1.000000e+00
%addtmp1 = fadd double %addtmp, %x
ret double %addtmp1
}
特に、上記のような定数畳み込みは、非常に一般的で非常に重要な最適化です。そのため、多くの言語実装者は、AST表現に定数畳み込みのサポートを実装しています。
LLVMを使用すると、ASTでこのサポートは必要ありません。LLVM IRを構築するためのすべての呼び出しはLLVM IRビルダーを経由するため、ビルダー自体が呼び出されたときに定数畳み込みの機会があるかどうかを確認します。もしそうであれば、命令を作成する代わりに、定数畳み込みを実行して定数を返します。
さて、簡単でしたね:)。実際には、このようなコードを生成する際には、常にIRBuilder
を使用することをお勧めします。使用のための「構文上のオーバーヘッド」はなく(いたるところに定数チェックでコンパイラを醜くする必要はありません)、場合によっては生成されるLLVM IRの量を劇的に削減できます(特にマクロプリプロセッサを持つ言語や多くの定数を使用する言語の場合)。
一方、IRBuilder
は、コードが構築されるときにインラインでそのすべての分析を実行するという事実によって制限されます。もう少し複雑な例を見てみましょう。
ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
%addtmp1 = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp1
ret double %multmp
}
この場合、乗算のLHSとRHSは同じ値です。「x+3
」を2回計算する代わりに、「tmp = x+3; result = tmp*tmp;
」を生成したいと考えています。
残念ながら、どれだけローカル分析を行っても、これを検出して修正することはできません。これには、2つの変換が必要です。式の再結合(加算を字句的に同一にするため)と共通部分式除去(CSE)(冗長な加算命令を削除するため)です。幸いなことに、LLVMは「パス」という形で、使用できる幅広い最適化を提供しています。
4.3. LLVM最適化パス¶
LLVMは、さまざまな種類の処理を行い、さまざまなトレードオフを持つ、多くの最適化パスを提供します。他のシステムとは異なり、LLVMは、1つの最適化セットがすべての言語とすべての状況に適しているという間違った考えにとらわれていません。LLVMを使用すると、コンパイラの実装者は、どの最適化を使用するか、どの順序で使用するか、どの状況で使用するかを完全に決定できます。
具体的な例として、LLVMは、可能な限り大きなコード本体(多くの場合ファイル全体ですが、リンク時に実行される場合はプログラム全体の相当な部分になる可能性があります)を対象とする「モジュール全体」パスをサポートしています。また、他の関数を見ずに、一度に1つの関数のみを操作する「関数ごと」パスもサポートし、含まれています。パスとその実行方法の詳細については、「パスの書き方」ドキュメントと「LLVMパスのリスト」を参照してください。
Kaleidoscopeの場合、現在、ユーザーが入力するたびに、一度に1つずつ関数をオンザフライで生成しています。この設定では究極の最適化エクスペリエンスを目指しているわけではありませんが、可能な限り簡単で迅速な処理を実行したいと考えています。そのため、ユーザーが関数を入力する際に、関数ごとの最適化をいくつか実行することを選択します。「静的Kaleidoscopeコンパイラ」を作成したい場合は、ファイル全体が解析されるまでオプティマイザの実行を延期するという点を除いて、現在使用しているコードとまったく同じコードを使用します。
関数パスとモジュールパスの区別に加えて、パスは変換パスと分析パスに分類できます。変換パスはIRを変更し、分析パスは他のパスが使用できる情報を計算します。変換パスを追加するには、依存するすべての分析パスを事前に登録する必要があります.
関数ごとの最適化を実行するには、実行したいLLVM最適化を保持および整理するためのFunctionPassManagerを設定する必要があります。それができたら、実行する最適化のセットを追加できます。最適化したいモジュールごとに新しいFunctionPassManagerが必要になるため、前の章(InitializeModule()
)で作成した関数に追加します。
void InitializeModuleAndManagers(void) {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create new pass and analysis managers.
TheFPM = std::make_unique<FunctionPassManager>();
TheLAM = std::make_unique<LoopAnalysisManager>();
TheFAM = std::make_unique<FunctionAnalysisManager>();
TheCGAM = std::make_unique<CGSCCAnalysisManager>();
TheMAM = std::make_unique<ModuleAnalysisManager>();
ThePIC = std::make_unique<PassInstrumentationCallbacks>();
TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
/*DebugLogging*/ true);
TheSI->registerCallbacks(*ThePIC, TheMAM.get());
...
グローバルモジュールTheModule
とFunctionPassManagerを初期化したら、フレームワークの他の部分を初期化する必要があります。4つのAnalysisManagersを使用すると、IR階層の4つのレベルすべてで実行される分析パスを追加できます。PassInstrumentationCallbacksとStandardInstrumentationsは、パスインストルメンテーションフレームワークに必要であり、開発者はパス間の動作をカスタマイズできます。
これらのマネージャーが設定されたら、「addPass」呼び出しを連続して使用して、LLVM変換パスをいくつか追加します。
// Add transform passes.
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->addPass(InstCombinePass());
// Reassociate expressions.
TheFPM->addPass(ReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->addPass(GVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->addPass(SimplifyCFGPass());
この場合、4つの最適化パスを追加することを選択します。ここで選択したパスは、さまざまなコードに役立つかなり標準的な「クリーンアップ」最適化のセットです。ここでは、それらが何をするかについては詳しく説明しません。信じてください、良い出発点です:)。
次に、変換パスで使用される分析パスを登録します.
// Register analysis passes used in these transform passes.
PassBuilder PB;
PB.registerModuleAnalyses(*TheMAM);
PB.registerFunctionAnalyses(*TheFAM);
PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}
PassManagerが設定されたら、それを使用する必要があります。これは、新しく作成された関数が構築された後(FunctionAST::codegen()
内)、クライアントに返される前に実行することで行います。
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder.CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
// Optimize the function.
TheFPM->run(*TheFunction, *TheFAM);
return TheFunction;
}
ご覧のとおり、これは非常に簡単です。FunctionPassManager
は、LLVM Function *をインプレースで最適化および更新し、(うまくいけば)その本体を改善します。これが設定されたら、上記のテストをもう一度試すことができます。
ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp
ret double %multmp
}
期待どおり、適切に最適化されたコードが得られ、この関数を実行するたびに浮動小数点加算命令が節約されます。
LLVMは、特定の状況で使用できるさまざまな最適化を提供します。いくつかのパスに関するドキュメントが利用可能ですが、あまり完全ではありません。別の良いアイデアのソースは、Clang
が開始するために実行するパスを見ることです。「opt
」ツールを使用すると、コマンドラインからパスを試すことができるため、何かを行うかどうかを確認できます。
フロントエンドから適切なコードが得られたので、実行について話しましょう!
4.4. JITコンパイラの追加¶
LLVM IRで利用可能なコードには、さまざまなツールを適用できます。たとえば、最適化を実行したり(上記のように)、テキスト形式またはバイナリ形式でダンプしたり、ターゲットのアセンブリファイル(.s)にコンパイルしたり、JITコンパイルしたりできます。LLVM IR表現の良いところは、コンパイラの多くの異なる部分間の「共通通貨」であるということです。
このセクションでは、インタプリタにJITコンパイラのサポートを追加します。Kaleidoscopeの基本的なアイデアは、ユーザーが現在行っているように関数本体を入力できるようにすることですが、入力したトップレベルの式をすぐに評価することです。たとえば、「1 + 2;」と入力した場合、3を評価して出力する必要があります。関数を定義した場合、コマンドラインから呼び出すことができます。
これを行うために、まず現在のネイティブターゲットのコードを作成するための環境を準備し、JITを宣言して初期化します。これは、InitializeNativeTarget\*
関数を呼び出してグローバル変数TheJIT
を追加し、main
で初期化することで行います。
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
int main() {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
TheJIT = std::make_unique<KaleidoscopeJIT>();
// Run the main "interpreter loop" now.
MainLoop();
return 0;
}
また、JITのデータレイアウトを設定する必要があります.
void InitializeModuleAndPassManager(void) {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("my cool jit", TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create a new pass manager attached to it.
TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
...
KaleidoscopeJITクラスは、これらのチュートリアル専用に構築されたシンプルなJITであり、LLVMソースコード内のllvm-src/examples/Kaleidoscope/include/KaleidoscopeJIT.hにあります。後の章では、その仕組みと新機能を追加する方法について説明しますが、ここでは所与のものとします。そのAPIは非常にシンプルです。addModule
はLLVM IRモジュールをJITに追加し、その関数を(ResourceTracker
によってメモリが管理されて)実行できるようにします。lookup
を使用すると、コンパイルされたコードへのポインタを検索できます。
このシンプルなAPIを使用して、トップレベルの式を解析するコードを次のように変更できます。
static ExitOnError ExitOnErr;
...
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
if (auto FnAST = ParseTopLevelExpr()) {
if (FnAST->codegen()) {
// Create a ResourceTracker to track JIT'd memory allocated to our
// anonymous expression -- that way we can free it after executing.
auto RT = TheJIT->getMainJITDylib().createResourceTracker();
auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
InitializeModuleAndPassManager();
// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
assert(ExprSymbol && "Function not found");
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
fprintf(stderr, "Evaluated to %f\n", FP());
// Delete the anonymous expression module from the JIT.
ExitOnErr(RT->remove());
}
解析とコード生成が成功した場合、次の手順は、トップレベルの式を含むモジュールをJITに追加することです。これは、addModuleを呼び出すことで行います。addModuleは、モジュール内のすべての関数のコード生成をトリガーし、後でJITからモジュールを削除するために使用できるResourceTracker
を受け入れます。モジュールがJITに追加されると、モジュールを変更できなくなるため、InitializeModuleAndPassManager()
を呼び出して、後続のコードを保持する新しいモジュールを開きます。
モジュールをJITに追加したら、最終的に生成されたコードへのポインタを取得する必要があります。これは、JITのlookup
メソッドを呼び出し、トップレベルの式関数の名前__anon_expr
を渡すことで行います。この関数を追加したばかりなので、lookup
が結果を返したことをアサートします。
次に、シンボルに対してgetAddress()
を呼び出すことで、__anon_expr
関数のメモリ内アドレスを取得します。 トップレベルの式は、引数を取らずに計算された倍精度浮動小数点数を返す、自己完結型のLLVM関数にコンパイルされることを思い出してください。 LLVM JITコンパイラはネイティブプラットフォームのABIと一致するため、結果のポインタをその型の関数ポインタにキャストして直接呼び出すことができます。 つまり、JITコンパイルされたコードとアプリケーションに静的にリンクされたネイティブマシンコードとの間に違いはありません。
最後に、トップレベルの式の再評価はサポートしていないため、完了したらJITからモジュールを削除して、関連付けられたメモリを解放します。 ただし、数行前に(InitializeModuleAndPassManager
を介して)作成したモジュールは、まだ開いていて新しいコードが追加されるのを待っていることに注意してください。
これらの2つの変更だけで、Kaleidoscopeがどのように動作するようになったかを見てみましょう!
ready> 4+5;
Read top-level expression:
define double @0() {
entry:
ret double 9.000000e+00
}
Evaluated to 9.000000
これで基本的に機能しているように見えます。 関数のダンプは、入力された各トップレベルの式に対して合成する「引数なしで常に倍精度浮動小数点数を返す関数」を示しています。 これは非常に基本的な機能を示していますが、もっとできるでしょうか?
ready> def testfunc(x y) x + y*2;
Read function definition:
define double @testfunc(double %x, double %y) {
entry:
%multmp = fmul double %y, 2.000000e+00
%addtmp = fadd double %multmp, %x
ret double %addtmp
}
ready> testfunc(4, 10);
Read top-level expression:
define double @1() {
entry:
%calltmp = call double @testfunc(double 4.000000e+00, double 1.000000e+01)
ret double %calltmp
}
Evaluated to 24.000000
ready> testfunc(5, 10);
ready> LLVM ERROR: Program used external function 'testfunc' which could not be resolved!
関数定義と呼び出しも機能しますが、最後の行で何かが非常に間違っていました。 呼び出しは有効に見えますが、何が起こったのでしょうか? APIから推測できるように、モジュールはJITの割り当て単位であり、testfuncは匿名式を含むのと同じモジュールの一部でした。 匿名式のメモリを解放するためにJITからそのモジュールを削除したときに、testfunc
の定義も一緒に削除しました。 その後、testfuncを2回目に呼び出そうとしたとき、JITはそれを見つけることができなくなりました。
これを修正する最も簡単な方法は、匿名式を他の関数定義とは別のモジュールに配置することです。 呼び出される各関数にプロトタイプがあり、呼び出される前にJITに追加されている限り、JITはモジュール境界を越えた関数呼び出しを問題なく解決します。 匿名式を別のモジュールに配置することで、他の関数に影響を与えることなく削除できます。
実際には、さらに一歩進んで、すべての関数を独自のモジュールに配置します。 これにより、環境をよりREPLのようにするKaleidoscopeJITの便利なプロパティを利用できます。関数はJITに複数回追加できます(すべての関数が一意の定義を持つ必要があるモジュールとは異なります)。 KaleidoscopeJITでシンボルを検索すると、常に最新の定義が返されます。
ready> def foo(x) x + 1;
Read function definition:
define double @foo(double %x) {
entry:
%addtmp = fadd double %x, 1.000000e+00
ret double %addtmp
}
ready> foo(2);
Evaluated to 3.000000
ready> def foo(x) x + 2;
define double @foo(double %x) {
entry:
%addtmp = fadd double %x, 2.000000e+00
ret double %addtmp
}
ready> foo(2);
Evaluated to 4.000000
各関数を独自のモジュールに配置できるようにするには、開く新しいモジュールごとに以前の関数宣言を再生成する方法が必要です。
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
Function *getFunction(std::string Name) {
// First, see if the function has already been added to the current module.
if (auto *F = TheModule->getFunction(Name))
return F;
// If not, check whether we can codegen the declaration from some existing
// prototype.
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
return FI->second->codegen();
// If no existing prototype exists, return null.
return nullptr;
}
...
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = getFunction(Callee);
...
Function *FunctionAST::codegen() {
// Transfer ownership of the prototype to the FunctionProtos map, but keep a
// reference to it for use below.
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
return nullptr;
これを有効にするために、各関数の最新の プロトタイプを保持する新しいグローバル変数FunctionProtos
を追加することから始めます。 また、TheModule->getFunction()
の呼び出しを置き換えるための便利なメソッドgetFunction()
も追加します。 この便利なメソッドは、TheModule
で既存の関数宣言を検索し、見つからない場合はFunctionProtosから新しい宣言を生成することで対応します。 CallExprAST::codegen()
では、TheModule->getFunction()
の呼び出しを置き換えるだけです。 FunctionAST::codegen()
では、最初にFunctionProtosマップを更新してから、getFunction()
を呼び出す必要があります。 これを行うことで、以前に宣言された関数の関数宣言を現在のモジュールで常に取得できます。
HandleDefinitionとHandleExternも更新する必要があります
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
ExitOnErr(TheJIT->addModule(
ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
InitializeModuleAndPassManager();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
HandleDefinitionでは、新しく定義された関数をJITに転送し、新しいモジュールを開くための2行を追加します。 HandleExternでは、FunctionProtosにプロトタイプを追加する行を1行追加するだけです。
警告
LLVM-9以降、別々のモジュールでのシンボルの重複は許可されていません。 つまり、以下に示すように、Kaleidoscopeで関数を再定義することはできません。 この部分はスキップしてください。
その理由は、新しいOrcV2 JIT APIが、重複するシンボルを拒否することを含め、静的および動的リンカーのルールに非常に厳密に従おうとしているためです。 シンボル名の一意性を要求することで、追跡のためのキーとして(一意の)シンボル名を使用して、シンボルの同時コンパイルをサポートできます。
これらの変更を加えて、REPLをもう一度試してみましょう(今回は匿名関数のダンプを削除しました。これでアイデアが得られるはずです:)
ready> def foo(x) x + 1;
ready> foo(2);
Evaluated to 3.000000
ready> def foo(x) x + 2;
ready> foo(2);
Evaluated to 4.000000
動作します!
この単純なコードでさえ、驚くほど強力な機能が得られます。これをチェックしてください
ready> extern sin(x);
Read extern:
declare double @sin(double)
ready> extern cos(x);
Read extern:
declare double @cos(double)
ready> sin(1.0);
Read top-level expression:
define double @2() {
entry:
ret double 0x3FEAED548F090CEE
}
Evaluated to 0.841471
ready> def foo(x) sin(x)*sin(x) + cos(x)*cos(x);
Read function definition:
define double @foo(double %x) {
entry:
%calltmp = call double @sin(double %x)
%multmp = fmul double %calltmp, %calltmp
%calltmp2 = call double @cos(double %x)
%multmp4 = fmul double %calltmp2, %calltmp2
%addtmp = fadd double %multmp, %multmp4
ret double %addtmp
}
ready> foo(4.0);
Read top-level expression:
define double @3() {
entry:
%calltmp = call double @foo(double 4.000000e+00)
ret double %calltmp
}
Evaluated to 1.000000
すごい、JITはsinとcosをどのように知っているのでしょうか? 答えは驚くほど簡単です。KaleidoscopeJITには、特定のモジュールで使用できないシンボルを見つけるために使用する簡単なシンボル解決ルールがあります。まず、JITにすでに追加されているすべてのモジュールを、最新のものから最も古いものまで検索して、最新の定義を見つけます。 JIT内で定義が見つからない場合、Kaleidoscopeプロセス自体で「dlsym("sin")
」を呼び出すことで対応します。 「sin
」はJITのアドレス空間内で定義されているため、モジュール内の呼び出しをlibmバージョンのsin
を直接呼び出すようにパッチを適用するだけです。 しかし、場合によっては、これはさらに進みます。sinとcosは標準の数学関の名前であるため、定数フォルダーは、上記の「sin(1.0)
」のように定数で呼び出されたときに、関数呼び出しを正しい結果に直接評価します。
今後、このシンボル解決ルールを調整することで、セキュリティ(JITでコンパイルされたコードで使用可能なシンボルのセットを制限すること)から、シンボル名に基づく動的コード生成、さらには遅延コンパイルまで、あらゆる種類の便利な機能を有効にする方法を見ていきます。
シンボル解決ルールの直接的な利点の1つは、操作を実装するための任意のC ++コードを記述することで、言語を拡張できるようになったことです。 たとえば、以下を追加する場合
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}
Windowsの場合、動的シンボルローダーがシンボルを見つけるためにGetProcAddress
を使用するため、実際に関数をエクスポートする必要があることに注意してください。
これで、「extern putchard(x); putchard(120);
」のようにすることで、コンソールに単純な出力を生成できます。これは、コンソールに小文字の「x」を出力します(120は「x」のASCIIコードです)。 同様のコードを使用して、ファイルI / O、コンソール入力、およびKaleidoscopeの他の多くの機能を実装できます。
これで、KaleidoscopeチュートリアルのJITとオプティマイザの章は終わりです。 この時点で、チューリング完全ではないプログラミング言語をコンパイルし、最適化し、ユーザー主導の方法でJITコンパイルできます。 次に、制御フローコンストラクトを使用して言語を拡張する方法を見て、途中でいくつかの興味深いLLVM IRの問題に取り組みます。
4.5. 完全なコードリスト¶
LLVM JITとオプティマイザで拡張された、実行中の例の完全なコードリストを以下に示します。 この例をビルドするには、次を使用します
# Compile
clang++ -g toy.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core orcjit native` -O3 -o toy
# Run
./toy
Linuxでこれをコンパイルする場合は、「-rdynamic」オプションも追加してください。 これにより、外部関数が実行時に正しく解決されるようになります。
コードは次のとおりです
#include "../include/KaleidoscopeJIT.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/Reassociate.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>
using namespace llvm;
using namespace llvm::orc;
//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//
// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
tok_eof = -1,
// commands
tok_def = -2,
tok_extern = -3,
// primary
tok_identifier = -4,
tok_number = -5
};
static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal; // Filled in if tok_number
/// gettok - Return the next token from standard input.
static int gettok() {
static int LastChar = ' ';
// Skip any whitespace.
while (isspace(LastChar))
LastChar = getchar();
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
IdentifierStr = LastChar;
while (isalnum((LastChar = getchar())))
IdentifierStr += LastChar;
if (IdentifierStr == "def")
return tok_def;
if (IdentifierStr == "extern")
return tok_extern;
return tok_identifier;
}
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
std::string NumStr;
do {
NumStr += LastChar;
LastChar = getchar();
} while (isdigit(LastChar) || LastChar == '.');
NumVal = strtod(NumStr.c_str(), nullptr);
return tok_number;
}
if (LastChar == '#') {
// Comment until end of line.
do
LastChar = getchar();
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
if (LastChar != EOF)
return gettok();
}
// Check for end of file. Don't eat the EOF.
if (LastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
int ThisChar = LastChar;
LastChar = getchar();
return ThisChar;
}
//===----------------------------------------------------------------------===//
// Abstract Syntax Tree (aka Parse Tree)
//===----------------------------------------------------------------------===//
namespace {
/// ExprAST - Base class for all expression nodes.
class ExprAST {
public:
virtual ~ExprAST() = default;
virtual Value *codegen() = 0;
};
/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double Val;
public:
NumberExprAST(double Val) : Val(Val) {}
Value *codegen() override;
};
/// VariableExprAST - Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
std::string Name;
public:
VariableExprAST(const std::string &Name) : Name(Name) {}
Value *codegen() override;
};
/// BinaryExprAST - Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
char Op;
std::unique_ptr<ExprAST> LHS, RHS;
public:
BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
std::unique_ptr<ExprAST> RHS)
: Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
Value *codegen() override;
};
/// CallExprAST - Expression class for function calls.
class CallExprAST : public ExprAST {
std::string Callee;
std::vector<std::unique_ptr<ExprAST>> Args;
public:
CallExprAST(const std::string &Callee,
std::vector<std::unique_ptr<ExprAST>> Args)
: Callee(Callee), Args(std::move(Args)) {}
Value *codegen() override;
};
/// PrototypeAST - This class represents the "prototype" for a function,
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
class PrototypeAST {
std::string Name;
std::vector<std::string> Args;
public:
PrototypeAST(const std::string &Name, std::vector<std::string> Args)
: Name(Name), Args(std::move(Args)) {}
Function *codegen();
const std::string &getName() const { return Name; }
};
/// FunctionAST - This class represents a function definition itself.
class FunctionAST {
std::unique_ptr<PrototypeAST> Proto;
std::unique_ptr<ExprAST> Body;
public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
std::unique_ptr<ExprAST> Body)
: Proto(std::move(Proto)), Body(std::move(Body)) {}
Function *codegen();
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
/// token the parser is looking at. getNextToken reads another token from the
/// lexer and updates CurTok with its results.
static int CurTok;
static int getNextToken() { return CurTok = gettok(); }
/// BinopPrecedence - This holds the precedence for each binary operator that is
/// defined.
static std::map<char, int> BinopPrecedence;
/// GetTokPrecedence - Get the precedence of the pending binary operator token.
static int GetTokPrecedence() {
if (!isascii(CurTok))
return -1;
// Make sure it's a declared binop.
int TokPrec = BinopPrecedence[CurTok];
if (TokPrec <= 0)
return -1;
return TokPrec;
}
/// LogError* - These are little helper functions for error handling.
std::unique_ptr<ExprAST> LogError(const char *Str) {
fprintf(stderr, "Error: %s\n", Str);
return nullptr;
}
std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
LogError(Str);
return nullptr;
}
static std::unique_ptr<ExprAST> ParseExpression();
/// numberexpr ::= number
static std::unique_ptr<ExprAST> ParseNumberExpr() {
auto Result = std::make_unique<NumberExprAST>(NumVal);
getNextToken(); // consume the number
return std::move(Result);
}
/// parenexpr ::= '(' expression ')'
static std::unique_ptr<ExprAST> ParseParenExpr() {
getNextToken(); // eat (.
auto V = ParseExpression();
if (!V)
return nullptr;
if (CurTok != ')')
return LogError("expected ')'");
getNextToken(); // eat ).
return V;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression* ')'
static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
std::string IdName = IdentifierStr;
getNextToken(); // eat identifier.
if (CurTok != '(') // Simple variable ref.
return std::make_unique<VariableExprAST>(IdName);
// Call.
getNextToken(); // eat (
std::vector<std::unique_ptr<ExprAST>> Args;
if (CurTok != ')') {
while (true) {
if (auto Arg = ParseExpression())
Args.push_back(std::move(Arg));
else
return nullptr;
if (CurTok == ')')
break;
if (CurTok != ',')
return LogError("Expected ')' or ',' in argument list");
getNextToken();
}
}
// Eat the ')'.
getNextToken();
return std::make_unique<CallExprAST>(IdName, std::move(Args));
}
/// primary
/// ::= identifierexpr
/// ::= numberexpr
/// ::= parenexpr
static std::unique_ptr<ExprAST> ParsePrimary() {
switch (CurTok) {
default:
return LogError("unknown token when expecting an expression");
case tok_identifier:
return ParseIdentifierExpr();
case tok_number:
return ParseNumberExpr();
case '(':
return ParseParenExpr();
}
}
/// binoprhs
/// ::= ('+' primary)*
static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
std::unique_ptr<ExprAST> LHS) {
// If this is a binop, find its precedence.
while (true) {
int TokPrec = GetTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
if (TokPrec < ExprPrec)
return LHS;
// Okay, we know this is a binop.
int BinOp = CurTok;
getNextToken(); // eat binop
// Parse the primary expression after the binary operator.
auto RHS = ParsePrimary();
if (!RHS)
return nullptr;
// If BinOp binds less tightly with RHS than the operator after RHS, let
// the pending operator take RHS as its LHS.
int NextPrec = GetTokPrecedence();
if (TokPrec < NextPrec) {
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
if (!RHS)
return nullptr;
}
// Merge LHS/RHS.
LHS =
std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
}
}
/// expression
/// ::= primary binoprhs
///
static std::unique_ptr<ExprAST> ParseExpression() {
auto LHS = ParsePrimary();
if (!LHS)
return nullptr;
return ParseBinOpRHS(0, std::move(LHS));
}
/// prototype
/// ::= id '(' id* ')'
static std::unique_ptr<PrototypeAST> ParsePrototype() {
if (CurTok != tok_identifier)
return LogErrorP("Expected function name in prototype");
std::string FnName = IdentifierStr;
getNextToken();
if (CurTok != '(')
return LogErrorP("Expected '(' in prototype");
std::vector<std::string> ArgNames;
while (getNextToken() == tok_identifier)
ArgNames.push_back(IdentifierStr);
if (CurTok != ')')
return LogErrorP("Expected ')' in prototype");
// success.
getNextToken(); // eat ')'.
return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
}
/// definition ::= 'def' prototype expression
static std::unique_ptr<FunctionAST> ParseDefinition() {
getNextToken(); // eat def.
auto Proto = ParsePrototype();
if (!Proto)
return nullptr;
if (auto E = ParseExpression())
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
return nullptr;
}
/// toplevelexpr ::= expression
static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
if (auto E = ParseExpression()) {
// Make an anonymous proto.
auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
std::vector<std::string>());
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
}
return nullptr;
}
/// external ::= 'extern' prototype
static std::unique_ptr<PrototypeAST> ParseExtern() {
getNextToken(); // eat extern.
return ParsePrototype();
}
//===----------------------------------------------------------------------===//
// Code Generation
//===----------------------------------------------------------------------===//
static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<Module> TheModule;
static std::unique_ptr<IRBuilder<>> Builder;
static std::map<std::string, Value *> NamedValues;
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
static std::unique_ptr<FunctionPassManager> TheFPM;
static std::unique_ptr<LoopAnalysisManager> TheLAM;
static std::unique_ptr<FunctionAnalysisManager> TheFAM;
static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
static std::unique_ptr<ModuleAnalysisManager> TheMAM;
static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
static std::unique_ptr<StandardInstrumentations> TheSI;
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
static ExitOnError ExitOnErr;
Value *LogErrorV(const char *Str) {
LogError(Str);
return nullptr;
}
Function *getFunction(std::string Name) {
// First, see if the function has already been added to the current module.
if (auto *F = TheModule->getFunction(Name))
return F;
// If not, check whether we can codegen the declaration from some existing
// prototype.
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
return FI->second->codegen();
// If no existing prototype exists, return null.
return nullptr;
}
Value *NumberExprAST::codegen() {
return ConstantFP::get(*TheContext, APFloat(Val));
}
Value *VariableExprAST::codegen() {
// Look this variable up in the function.
Value *V = NamedValues[Name];
if (!V)
return LogErrorV("Unknown variable name");
return V;
}
Value *BinaryExprAST::codegen() {
Value *L = LHS->codegen();
Value *R = RHS->codegen();
if (!L || !R)
return nullptr;
switch (Op) {
case '+':
return Builder->CreateFAdd(L, R, "addtmp");
case '-':
return Builder->CreateFSub(L, R, "subtmp");
case '*':
return Builder->CreateFMul(L, R, "multmp");
case '<':
L = Builder->CreateFCmpULT(L, R, "cmptmp");
// Convert bool 0/1 to double 0.0 or 1.0
return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
default:
return LogErrorV("invalid binary operator");
}
}
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = getFunction(Callee);
if (!CalleeF)
return LogErrorV("Unknown function referenced");
// If argument mismatch error.
if (CalleeF->arg_size() != Args.size())
return LogErrorV("Incorrect # arguments passed");
std::vector<Value *> ArgsV;
for (unsigned i = 0, e = Args.size(); i != e; ++i) {
ArgsV.push_back(Args[i]->codegen());
if (!ArgsV.back())
return nullptr;
}
return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}
Function *PrototypeAST::codegen() {
// Make the function type: double(double,double) etc.
std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
FunctionType *FT =
FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
Function *F =
Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
// Set names for all arguments.
unsigned Idx = 0;
for (auto &Arg : F->args())
Arg.setName(Args[Idx++]);
return F;
}
Function *FunctionAST::codegen() {
// Transfer ownership of the prototype to the FunctionProtos map, but keep a
// reference to it for use below.
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
return nullptr;
// Create a new basic block to start insertion into.
BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
Builder->SetInsertPoint(BB);
// Record the function arguments in the NamedValues map.
NamedValues.clear();
for (auto &Arg : TheFunction->args())
NamedValues[std::string(Arg.getName())] = &Arg;
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder->CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
// Run the optimizer on the function.
TheFPM->run(*TheFunction, *TheFAM);
return TheFunction;
}
// Error reading body, remove function.
TheFunction->eraseFromParent();
return nullptr;
}
//===----------------------------------------------------------------------===//
// Top-Level parsing and JIT Driver
//===----------------------------------------------------------------------===//
static void InitializeModuleAndManagers() {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create new pass and analysis managers.
TheFPM = std::make_unique<FunctionPassManager>();
TheLAM = std::make_unique<LoopAnalysisManager>();
TheFAM = std::make_unique<FunctionAnalysisManager>();
TheCGAM = std::make_unique<CGSCCAnalysisManager>();
TheMAM = std::make_unique<ModuleAnalysisManager>();
ThePIC = std::make_unique<PassInstrumentationCallbacks>();
TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
/*DebugLogging*/ true);
TheSI->registerCallbacks(*ThePIC, TheMAM.get());
// Add transform passes.
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->addPass(InstCombinePass());
// Reassociate expressions.
TheFPM->addPass(ReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->addPass(GVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->addPass(SimplifyCFGPass());
// Register analysis passes used in these transform passes.
PassBuilder PB;
PB.registerModuleAnalyses(*TheMAM);
PB.registerFunctionAnalyses(*TheFAM);
PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
ExitOnErr(TheJIT->addModule(
ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
InitializeModuleAndManagers();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
if (auto FnAST = ParseTopLevelExpr()) {
if (FnAST->codegen()) {
// Create a ResourceTracker to track JIT'd memory allocated to our
// anonymous expression -- that way we can free it after executing.
auto RT = TheJIT->getMainJITDylib().createResourceTracker();
auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
InitializeModuleAndManagers();
// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
fprintf(stderr, "Evaluated to %f\n", FP());
// Delete the anonymous expression module from the JIT.
ExitOnErr(RT->remove());
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
/// top ::= definition | external | expression | ';'
static void MainLoop() {
while (true) {
fprintf(stderr, "ready> ");
switch (CurTok) {
case tok_eof:
return;
case ';': // ignore top-level semicolons.
getNextToken();
break;
case tok_def:
HandleDefinition();
break;
case tok_extern:
HandleExtern();
break;
default:
HandleTopLevelExpression();
break;
}
}
}
//===----------------------------------------------------------------------===//
// "Library" functions that can be "extern'd" from user code.
//===----------------------------------------------------------------------===//
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}
/// printd - printf that takes a double prints it as "%f\n", returning 0.
extern "C" DLLEXPORT double printd(double X) {
fprintf(stderr, "%f\n", X);
return 0;
}
//===----------------------------------------------------------------------===//
// Main driver code.
//===----------------------------------------------------------------------===//
int main() {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
InitializeModuleAndManagers();
// Run the main "interpreter loop" now.
MainLoop();
return 0;
}