3. Kaleidoscope: LLVM IRへのコード生成¶
3.1. 第3章 イントロダクション¶
「LLVMを使った言語の実装」チュートリアルの第3章へようこそ。この章では、第2章で構築した抽象構文木をLLVM IRに変換する方法を紹介します。LLVMの動作を少し理解するのに役立つだけでなく、LLVMがいかに簡単に使えるかを実証します。レクサーとパーサーを構築するよりも、LLVM IRコードを生成する方がはるかに簡単です。:)
**注意**: この章以降のコードはLLVM 3.7以降が必要です。LLVM 3.6以前では動作しません。また、LLVMのリリースに合ったバージョンのチュートリアルを使用する必要があります。公式のLLVMリリースを使用している場合は、リリースに含まれているか、llvm.orgのリリースページにあるドキュメントのバージョンを使用してください。
3.2. コード生成の準備¶
LLVM IRを生成するために、まず簡単な準備が必要です。最初に、各ASTクラスに仮想コード生成(codegen)メソッドを定義します。
/// ExprAST - Base class for all expression nodes.
class ExprAST {
public:
virtual ~ExprAST() = default;
virtual Value *codegen() = 0;
};
/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double Val;
public:
NumberExprAST(double Val) : Val(Val) {}
Value *codegen() override;
};
...
codegen()メソッドは、そのASTノードとその依存関係すべてについてIRを生成するように指示し、すべてLLVM Valueオブジェクトを返します。「Value」は、LLVMにおける「静的単一代入(SSA)レジスタ」または「SSA値」を表すために使用されるクラスです。SSA値の最も顕著な側面は、その値が関連する命令の実行時に計算され、命令が再実行されるまで(そして再実行された場合)新しい値を取得しないことです。言い換えれば、SSA値を「変更」する方法はありません。詳細については、静的単一代入について読んでみてください。一度理解すれば、概念は非常に自然なものです。
ExprASTクラス階層に仮想メソッドを追加する代わりに、Visitorパターンやその他の方法を使用してモデル化することもできます。繰り返しますが、このチュートリアルでは優れたソフトウェアエンジニアリング手法については詳しく説明しません。ここでは、仮想メソッドを追加するのが最も簡単です。
次に必要なのは、パーサーで使用した「LogError」メソッドのようなもので、コード生成中に見つかったエラー(例えば、未定義のパラメータの使用)を報告するために使用されます。
static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<IRBuilder<>> Builder(TheContext);
static std::unique_ptr<Module> TheModule;
static std::map<std::string, Value *> NamedValues;
Value *LogErrorV(const char *Str) {
LogError(Str);
return nullptr;
}
静的変数はコード生成中に使用されます。TheContext
は、型や定数値テーブルなど、多くのLLVMコアデータ構造を所有する不透明なオブジェクトです。詳細を理解する必要はなく、APIに渡すための単一のインスタンスが必要です。
Builder
オブジェクトは、LLVM命令を簡単に生成するためのヘルパーオブジェクトです。IRBuilderクラステンプレートのインスタンスは、命令を挿入する現在の場所を追跡し、新しい命令を作成するためのメソッドを持っています。
TheModule
は、関数とグローバル変数を格納するLLVMの構造体です。多くの点で、LLVM IRがコードを格納するために使用する最上位の構造体です。生成するすべてのIRのメモリを所有します。そのため、codegen()メソッドはunique_ptr<Value>ではなく、生のValue*を返します。
NamedValues
マップは、現在のスコープで定義されている値とそのLLVM表現を追跡します。(言い換えれば、コードのシンボルテーブルです)。この形式のKaleidoscopeでは、参照できるのは関数パラメータのみです。そのため、関数本体のコードを生成する際には、関数パラメータがこのマップに存在します。
これらの基本が整ったので、各式のコードを生成する方法について説明を始めましょう。これは、Builder
が何かにコードを生成するように設定されていることを前提としていることに注意してください。ここでは、これがすでに完了していると仮定し、コードを生成するために使用します。
3.3. 式のコード生成¶
式ノードのLLVMコードの生成は非常に簡単です。4つの式ノードすべてについて、45行未満のコメント付きコードです。最初に数値リテラルを行います。
Value *NumberExprAST::codegen() {
return ConstantFP::get(*TheContext, APFloat(Val));
}
LLVM IRでは、数値定数はConstantFP
クラスで表されます。このクラスは、数値を内部的にAPFloat
で保持します(APFloat
は任意精度の浮動小数点定数を保持できます)。このコードは基本的にConstantFP
を作成して返すだけです。LLVM IRでは、定数はすべて一意にまとめられ、共有されることに注意してください。このため、APIは「new foo(..)」または「foo::Create(..)」の代わりに「foo::get(…)」というイディオムを使用します。
Value *VariableExprAST::codegen() {
// Look this variable up in the function.
Value *V = NamedValues[Name];
if (!V)
LogErrorV("Unknown variable name");
return V;
}
変数への参照もLLVMを使用すると非常に簡単です。Kaleidoscopeの単純なバージョンでは、変数はすでにどこかで生成されており、その値が使用可能であると仮定します。実際には、NamedValues
マップに含まれる値は関数引数のみです。このコードは、指定された名前がマップにあるかどうかを確認し(そうでない場合は、不明な変数が参照されています)、その値を返します。今後の章では、シンボルテーブルにループ誘導変数のサポートを追加し、ローカル変数のサポートを追加します。
Value *BinaryExprAST::codegen() {
Value *L = LHS->codegen();
Value *R = RHS->codegen();
if (!L || !R)
return nullptr;
switch (Op) {
case '+':
return Builder->CreateFAdd(L, R, "addtmp");
case '-':
return Builder->CreateFSub(L, R, "subtmp");
case '*':
return Builder->CreateFMul(L, R, "multmp");
case '<':
L = Builder->CreateFCmpULT(L, R, "cmptmp");
// Convert bool 0/1 to double 0.0 or 1.0
return Builder->CreateUIToFP(L, Type::getDoubleTy(TheContext),
"booltmp");
default:
return LogErrorV("invalid binary operator");
}
}
二項演算子はもう少し面白くなります。基本的な考え方は、式の左辺のコードを再帰的に生成し、次に右辺のコードを生成し、次に二項式の結果を計算するというものです。このコードでは、opcodeに対して単純なswitchを実行して、正しいLLVM命令を作成します。
上記の例では、LLVMビルダークラスの価値が示され始めています。IRBuilderは新しく作成された命令をどこに挿入するかを知っており、作成する命令(例:CreateFAdd
)、使用するオペランド(ここではL
とR
)、およびオプションで生成された命令の名前を指定するだけです。
LLVMの良い点の1つは、名前が単なるヒントであることです。たとえば、上記のコードが複数の「addtmp」変数を生成する場合、LLVMはそれぞれに増加する一意の数値サフィックスを自動的に提供します。命令のローカル値名は純粋にオプションですが、IRダンプを読みやすくします。
LLVM命令は厳密なルールによって制約されます。たとえば、add命令の左オペランドと右オペランドは同じ型でなければならず、addの結果型はオペランド型と一致する必要があります。Kaleidoscopeのすべての値はdoubleなので、add、sub、mulのコードは非常にシンプルになります。
一方、LLVMはfcmp命令が常に「i1」値(1ビット整数)を返すように指定しています。問題は、Kaleidoscopeは値が0.0または1.0であることを望んでいることです。これらのセマンティクスを取得するために、fcmp命令とuitofp命令を組み合わせます。この命令は、入力整数を符号なし値として扱うことにより、入力整数を浮動小数点値に変換します。対照的に、sitofp命令を使用すると、Kaleidoscopeの「<」演算子は入力値に応じて0.0と-1.0を返します。
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = TheModule->getFunction(Callee);
if (!CalleeF)
return LogErrorV("Unknown function referenced");
// If argument mismatch error.
if (CalleeF->arg_size() != Args.size())
return LogErrorV("Incorrect # arguments passed");
std::vector<Value *> ArgsV;
for (unsigned i = 0, e = Args.size(); i != e; ++i) {
ArgsV.push_back(Args[i]->codegen());
if (!ArgsV.back())
return nullptr;
}
return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}
LLVMを使用した関数呼び出しのコード生成は非常に簡単です。上記のコードは、最初にLLVMモジュールのシンボルテーブルで関数名検索を実行します。LLVMモジュールは、JITコンパイルする関数を保持するコンテナであることを思い出してください。各関数にユーザーが指定したものと同じ名前を付けることで、LLVMシンボルテーブルを使用して関数名を解決できます。
呼び出す関数が決まったら、渡される各引数のコードを再帰的に生成し、LLVM call命令を作成します。LLVMはデフォルトでネイティブC呼び出し規約を使用するため、これらの呼び出しは追加の労力なしに「sin」や「cos」などの標準ライブラリ関数も呼び出すことができます。
これで、Kaleidoscopeでこれまでに使用した4つの基本的な式の処理は終わりです。自由にさらに追加してください。たとえば、LLVM言語リファレンスを参照すると、基本的なフレームワークに簡単に組み込むことができる他の興味深い命令がいくつか見つかります。
3.4. 関数のコード生成¶
プロトタイプと関数のコード生成では、多くの詳細を処理する必要があります。そのため、式のコード生成ほど美しくはありませんが、重要なポイントを説明することができます。まず、プロトタイプのコード生成について説明しましょう。プロトタイプは、関数本体と外部関数宣言の両方で使用されます。コードは次のように始まります。
Function *PrototypeAST::codegen() {
// Make the function type: double(double,double) etc.
std::vector<Type*> Doubles(Args.size(),
Type::getDoubleTy(*TheContext));
FunctionType *FT =
FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
Function *F =
Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
このコードは数行で多くのことを実現しています。まず、この関数は「Value*」ではなく「Function*」を返すことに注意してください。「プロトタイプ」は実際には関数の外部インターフェース(式によって計算された値ではない)について説明しているため、codegenされたときに対応するLLVM関数を返すのは理にかなっています。
FunctionType::get
の呼び出しは、特定のプロトタイプに使用する必要があるFunctionType
を作成します。Kaleidoscopeのすべての関数引数はdouble型であるため、最初の行は「N」個のLLVM double型のベクトルを作成します。次に、Functiontype::get
メソッドを使用して、引数として「N」個のdoubleを取り、結果として1つのdoubleを返し、可変長引数ではない(falseパラメータはこれを示します)関数型を作成します。LLVMの型は定数と同様に一意であるため、「new」で型を作成するのではなく、「get」で型を取得することに注意してください。
上記の最後の行は、実際にはプロトタイプに対応する IR 関数を作成します。これは、使用する型、リンケージ、名前、および挿入するモジュールを示します。「外部リンケージ」とは、関数が現在のモジュールの外部で定義されている可能性があり、かつ/またはモジュールの外部の関数から呼び出し可能であることを意味します。渡された名前は、ユーザーが指定した名前です。「TheModule
」が指定されているため、この名前は「TheModule
」のシンボルテーブルに登録されます。
// Set names for all arguments.
unsigned Idx = 0;
for (auto &Arg : F->args())
Arg.setName(Args[Idx++]);
return F;
最後に、プロトタイプで指定された名前を使用して、各関数引数の名前を設定します。この手順は厳密には必須ではありませんが、名前の一貫性を保つことで IR の可読性が向上し、後続のコードがプロトタイプ AST で引数を検索するのではなく、名前で直接参照できるようになります。
この時点で、本体のない関数プロトタイプができました。これは、LLVM IR が関数宣言を表す方法です。Kaleidoscope の extern 文の場合、これで完了です。ただし、関数定義の場合は、関数本体をコード生成してアタッチする必要があります。
Function *FunctionAST::codegen() {
// First, check for an existing function from a previous 'extern' declaration.
Function *TheFunction = TheModule->getFunction(Proto->getName());
if (!TheFunction)
TheFunction = Proto->codegen();
if (!TheFunction)
return nullptr;
if (!TheFunction->empty())
return (Function*)LogErrorV("Function cannot be redefined.");
関数定義の場合、まず 'extern' 文を使用して既に作成されている場合に備えて、TheModule のシンボルテーブルでこの関数の既存バージョンを検索します。 Module::getFunction が null を返す場合は、以前のバージョンが存在しないため、プロトタイプからコード生成します。いずれの場合も、開始する前に関数が空であること(つまり、まだ本体がないこと)を確認する必要があります。
// Create a new basic block to start insertion into.
BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
Builder->SetInsertPoint(BB);
// Record the function arguments in the NamedValues map.
NamedValues.clear();
for (auto &Arg : TheFunction->args())
NamedValues[std::string(Arg.getName())] = &Arg;
ここで、Builder
が設定されるポイントに到達します。最初の行は、TheFunction
に挿入される新しい 基本ブロック(「entry」という名前)を作成します。2 行目は、新しい命令を新しい基本ブロックの最後に挿入する必要があることをビルダーに指示します。LLVM の基本ブロックは、制御フローグラフ を定義する関数の重要な部分です。制御フローがないため、現時点では関数は 1 つのブロックのみを含みます。これは 第 5 章 で修正します :)
次に、VariableExprAST
ノードからアクセスできるように、関数引数を NamedValues マップに追加します(最初にクリアした後)。
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder->CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
return TheFunction;
}
挿入ポイントが設定され、NamedValues マップが設定されたら、関数のルート式に対して codegen()
メソッドを呼び出します。エラーが発生しない場合、これは式を計算してエントリブロックにコードを発行し、計算された値を返します。エラーがないと仮定すると、LLVM ret 命令 を作成し、関数を完了します。関数がビルドされたら、LLVM によって提供される verifyFunction
を呼び出します。この関数は、生成されたコードに対してさまざまな整合性チェックを実行し、コンパイラがすべて正しく行われているかどうかを判断します。これを使用することは重要です。多くのバグをキャッチできます。関数が完了して検証されたら、それを返します。
// Error reading body, remove function.
TheFunction->eraseFromParent();
return nullptr;
}
残っているのは、エラーケースの処理だけです。簡単にするために、eraseFromParent
メソッドを使用して生成した関数を単に削除することで、これを処理します。これにより、ユーザーは以前に入力ミスをした関数を再定義できます。削除しないと、シンボルテーブルに本体とともに残ってしまい、将来の再定義が妨げられます。
ただし、このコードにはバグがあります。FunctionAST::codegen()
メソッドが既存の IR 関数を見つけると、定義自身のプロトタイプに対してそのシグネチャを検証しません。これは、以前の「extern」宣言が関数定義のシグネチャよりも優先されることを意味し、たとえば関数引数の名前が異なる場合などに、コード生成が失敗する可能性があります。このバグを修正するには、いくつかの方法があります。何が思いつくか試してみてください。テストケースはこちらです
extern foo(a); # ok, defines foo.
def foo(b) b; # Error: Unknown variable name. (decl using 'a' takes precedence).
3.5. ドライバの変更とまとめ¶
今のところ、LLVM へのコード生成は、きれいな IR 呼び出しを見ることができる以外は、それほど役に立ちません。サンプルコードは、「HandleDefinition
」、「HandleExtern
」などの関数に codegen の呼び出しを挿入し、LLVM IR をダンプします。これは、単純な関数の LLVM IR を調べるための良い方法です。例えば
ready> 4+5;
Read top-level expression:
define double @0() {
entry:
ret double 9.000000e+00
}
パーサーがトップレベルの式を匿名関数に変換する方法に注目してください。これは、次の章で JIT サポート を追加するときに役立ちます。また、コードは非常に文字通りに転記されており、IRBuilder によって行われる単純な定数畳み込みを除いて、最適化は実行されていません。次の章で 最適化を追加 します。
ready> def foo(a b) a*a + 2*a*b + b*b;
Read function definition:
define double @foo(double %a, double %b) {
entry:
%multmp = fmul double %a, %a
%multmp1 = fmul double 2.000000e+00, %a
%multmp2 = fmul double %multmp1, %b
%addtmp = fadd double %multmp, %multmp2
%multmp3 = fmul double %b, %b
%addtmp4 = fadd double %addtmp, %multmp3
ret double %addtmp4
}
これは、単純な算術を示しています。命令を作成するために使用する LLVM ビルダー呼び出しとの顕著な類似性に注目してください。
ready> def bar(a) foo(a, 4.0) + bar(31337);
Read function definition:
define double @bar(double %a) {
entry:
%calltmp = call double @foo(double %a, double 4.000000e+00)
%calltmp1 = call double @bar(double 3.133700e+04)
%addtmp = fadd double %calltmp, %calltmp1
ret double %addtmp
}
これは、いくつかの関数呼び出しを示しています。この関数を呼び出すと、実行に時間がかかることに注意してください。将来的には、再帰を実際に役立てるために、条件付き制御フローを追加します:)。
ready> extern cos(x);
Read extern:
declare double @cos(double)
ready> cos(1.234);
Read top-level expression:
define double @1() {
entry:
%calltmp = call double @cos(double 1.234000e+00)
ret double %calltmp
}
これは、libm の「cos」関数の extern と、それへの呼び出しを示しています。
ready> ^D
; ModuleID = 'my cool jit'
define double @0() {
entry:
%addtmp = fadd double 4.000000e+00, 5.000000e+00
ret double %addtmp
}
define double @foo(double %a, double %b) {
entry:
%multmp = fmul double %a, %a
%multmp1 = fmul double 2.000000e+00, %a
%multmp2 = fmul double %multmp1, %b
%addtmp = fadd double %multmp, %multmp2
%multmp3 = fmul double %b, %b
%addtmp4 = fadd double %addtmp, %multmp3
ret double %addtmp4
}
define double @bar(double %a) {
entry:
%calltmp = call double @foo(double %a, double 4.000000e+00)
%calltmp1 = call double @bar(double 3.133700e+04)
%addtmp = fadd double %calltmp, %calltmp1
ret double %addtmp
}
declare double @cos(double)
define double @1() {
entry:
%calltmp = call double @cos(double 1.234000e+00)
ret double %calltmp
}
現在のデモを終了すると(Linux では CTRL + D、Windows では CTRL + Z と ENTER を送信することにより)、生成されたモジュール全体の IR がダンプされます。ここで、すべての関数が相互に参照している全体像を見ることができます。
これで、Kaleidoscope チュートリアルの第 3 章は終わりです。次は、JIT コード生成とオプティマイザのサポートを追加 して、実際にコードを実行できるようにする方法について説明します。
3.6. 完全なコードリスト¶
LLVM コードジェネレーターで拡張された、実行中の例の完全なコードリストを以下に示します。これは LLVM ライブラリを使用するため、それらをリンクする必要があります。これを行うには、llvm-config ツールを使用して、使用するオプションについて makefile /コマンドラインに指示します
# Compile
clang++ -g -O3 toy.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core` -o toy
# Run
./toy
コードは次のとおりです
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>
using namespace llvm;
//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//
// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
tok_eof = -1,
// commands
tok_def = -2,
tok_extern = -3,
// primary
tok_identifier = -4,
tok_number = -5
};
static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal; // Filled in if tok_number
/// gettok - Return the next token from standard input.
static int gettok() {
static int LastChar = ' ';
// Skip any whitespace.
while (isspace(LastChar))
LastChar = getchar();
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
IdentifierStr = LastChar;
while (isalnum((LastChar = getchar())))
IdentifierStr += LastChar;
if (IdentifierStr == "def")
return tok_def;
if (IdentifierStr == "extern")
return tok_extern;
return tok_identifier;
}
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
std::string NumStr;
do {
NumStr += LastChar;
LastChar = getchar();
} while (isdigit(LastChar) || LastChar == '.');
NumVal = strtod(NumStr.c_str(), nullptr);
return tok_number;
}
if (LastChar == '#') {
// Comment until end of line.
do
LastChar = getchar();
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
if (LastChar != EOF)
return gettok();
}
// Check for end of file. Don't eat the EOF.
if (LastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
int ThisChar = LastChar;
LastChar = getchar();
return ThisChar;
}
//===----------------------------------------------------------------------===//
// Abstract Syntax Tree (aka Parse Tree)
//===----------------------------------------------------------------------===//
namespace {
/// ExprAST - Base class for all expression nodes.
class ExprAST {
public:
virtual ~ExprAST() = default;
virtual Value *codegen() = 0;
};
/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double Val;
public:
NumberExprAST(double Val) : Val(Val) {}
Value *codegen() override;
};
/// VariableExprAST - Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
std::string Name;
public:
VariableExprAST(const std::string &Name) : Name(Name) {}
Value *codegen() override;
};
/// BinaryExprAST - Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
char Op;
std::unique_ptr<ExprAST> LHS, RHS;
public:
BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
std::unique_ptr<ExprAST> RHS)
: Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
Value *codegen() override;
};
/// CallExprAST - Expression class for function calls.
class CallExprAST : public ExprAST {
std::string Callee;
std::vector<std::unique_ptr<ExprAST>> Args;
public:
CallExprAST(const std::string &Callee,
std::vector<std::unique_ptr<ExprAST>> Args)
: Callee(Callee), Args(std::move(Args)) {}
Value *codegen() override;
};
/// PrototypeAST - This class represents the "prototype" for a function,
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
class PrototypeAST {
std::string Name;
std::vector<std::string> Args;
public:
PrototypeAST(const std::string &Name, std::vector<std::string> Args)
: Name(Name), Args(std::move(Args)) {}
Function *codegen();
const std::string &getName() const { return Name; }
};
/// FunctionAST - This class represents a function definition itself.
class FunctionAST {
std::unique_ptr<PrototypeAST> Proto;
std::unique_ptr<ExprAST> Body;
public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
std::unique_ptr<ExprAST> Body)
: Proto(std::move(Proto)), Body(std::move(Body)) {}
Function *codegen();
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
/// token the parser is looking at. getNextToken reads another token from the
/// lexer and updates CurTok with its results.
static int CurTok;
static int getNextToken() { return CurTok = gettok(); }
/// BinopPrecedence - This holds the precedence for each binary operator that is
/// defined.
static std::map<char, int> BinopPrecedence;
/// GetTokPrecedence - Get the precedence of the pending binary operator token.
static int GetTokPrecedence() {
if (!isascii(CurTok))
return -1;
// Make sure it's a declared binop.
int TokPrec = BinopPrecedence[CurTok];
if (TokPrec <= 0)
return -1;
return TokPrec;
}
/// LogError* - These are little helper functions for error handling.
std::unique_ptr<ExprAST> LogError(const char *Str) {
fprintf(stderr, "Error: %s\n", Str);
return nullptr;
}
std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
LogError(Str);
return nullptr;
}
static std::unique_ptr<ExprAST> ParseExpression();
/// numberexpr ::= number
static std::unique_ptr<ExprAST> ParseNumberExpr() {
auto Result = std::make_unique<NumberExprAST>(NumVal);
getNextToken(); // consume the number
return std::move(Result);
}
/// parenexpr ::= '(' expression ')'
static std::unique_ptr<ExprAST> ParseParenExpr() {
getNextToken(); // eat (.
auto V = ParseExpression();
if (!V)
return nullptr;
if (CurTok != ')')
return LogError("expected ')'");
getNextToken(); // eat ).
return V;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression* ')'
static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
std::string IdName = IdentifierStr;
getNextToken(); // eat identifier.
if (CurTok != '(') // Simple variable ref.
return std::make_unique<VariableExprAST>(IdName);
// Call.
getNextToken(); // eat (
std::vector<std::unique_ptr<ExprAST>> Args;
if (CurTok != ')') {
while (true) {
if (auto Arg = ParseExpression())
Args.push_back(std::move(Arg));
else
return nullptr;
if (CurTok == ')')
break;
if (CurTok != ',')
return LogError("Expected ')' or ',' in argument list");
getNextToken();
}
}
// Eat the ')'.
getNextToken();
return std::make_unique<CallExprAST>(IdName, std::move(Args));
}
/// primary
/// ::= identifierexpr
/// ::= numberexpr
/// ::= parenexpr
static std::unique_ptr<ExprAST> ParsePrimary() {
switch (CurTok) {
default:
return LogError("unknown token when expecting an expression");
case tok_identifier:
return ParseIdentifierExpr();
case tok_number:
return ParseNumberExpr();
case '(':
return ParseParenExpr();
}
}
/// binoprhs
/// ::= ('+' primary)*
static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
std::unique_ptr<ExprAST> LHS) {
// If this is a binop, find its precedence.
while (true) {
int TokPrec = GetTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
if (TokPrec < ExprPrec)
return LHS;
// Okay, we know this is a binop.
int BinOp = CurTok;
getNextToken(); // eat binop
// Parse the primary expression after the binary operator.
auto RHS = ParsePrimary();
if (!RHS)
return nullptr;
// If BinOp binds less tightly with RHS than the operator after RHS, let
// the pending operator take RHS as its LHS.
int NextPrec = GetTokPrecedence();
if (TokPrec < NextPrec) {
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
if (!RHS)
return nullptr;
}
// Merge LHS/RHS.
LHS =
std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
}
}
/// expression
/// ::= primary binoprhs
///
static std::unique_ptr<ExprAST> ParseExpression() {
auto LHS = ParsePrimary();
if (!LHS)
return nullptr;
return ParseBinOpRHS(0, std::move(LHS));
}
/// prototype
/// ::= id '(' id* ')'
static std::unique_ptr<PrototypeAST> ParsePrototype() {
if (CurTok != tok_identifier)
return LogErrorP("Expected function name in prototype");
std::string FnName = IdentifierStr;
getNextToken();
if (CurTok != '(')
return LogErrorP("Expected '(' in prototype");
std::vector<std::string> ArgNames;
while (getNextToken() == tok_identifier)
ArgNames.push_back(IdentifierStr);
if (CurTok != ')')
return LogErrorP("Expected ')' in prototype");
// success.
getNextToken(); // eat ')'.
return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
}
/// definition ::= 'def' prototype expression
static std::unique_ptr<FunctionAST> ParseDefinition() {
getNextToken(); // eat def.
auto Proto = ParsePrototype();
if (!Proto)
return nullptr;
if (auto E = ParseExpression())
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
return nullptr;
}
/// toplevelexpr ::= expression
static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
if (auto E = ParseExpression()) {
// Make an anonymous proto.
auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
std::vector<std::string>());
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
}
return nullptr;
}
/// external ::= 'extern' prototype
static std::unique_ptr<PrototypeAST> ParseExtern() {
getNextToken(); // eat extern.
return ParsePrototype();
}
//===----------------------------------------------------------------------===//
// Code Generation
//===----------------------------------------------------------------------===//
static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<Module> TheModule;
static std::unique_ptr<IRBuilder<>> Builder;
static std::map<std::string, Value *> NamedValues;
Value *LogErrorV(const char *Str) {
LogError(Str);
return nullptr;
}
Value *NumberExprAST::codegen() {
return ConstantFP::get(*TheContext, APFloat(Val));
}
Value *VariableExprAST::codegen() {
// Look this variable up in the function.
Value *V = NamedValues[Name];
if (!V)
return LogErrorV("Unknown variable name");
return V;
}
Value *BinaryExprAST::codegen() {
Value *L = LHS->codegen();
Value *R = RHS->codegen();
if (!L || !R)
return nullptr;
switch (Op) {
case '+':
return Builder->CreateFAdd(L, R, "addtmp");
case '-':
return Builder->CreateFSub(L, R, "subtmp");
case '*':
return Builder->CreateFMul(L, R, "multmp");
case '<':
L = Builder->CreateFCmpULT(L, R, "cmptmp");
// Convert bool 0/1 to double 0.0 or 1.0
return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
default:
return LogErrorV("invalid binary operator");
}
}
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = TheModule->getFunction(Callee);
if (!CalleeF)
return LogErrorV("Unknown function referenced");
// If argument mismatch error.
if (CalleeF->arg_size() != Args.size())
return LogErrorV("Incorrect # arguments passed");
std::vector<Value *> ArgsV;
for (unsigned i = 0, e = Args.size(); i != e; ++i) {
ArgsV.push_back(Args[i]->codegen());
if (!ArgsV.back())
return nullptr;
}
return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}
Function *PrototypeAST::codegen() {
// Make the function type: double(double,double) etc.
std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
FunctionType *FT =
FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
Function *F =
Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
// Set names for all arguments.
unsigned Idx = 0;
for (auto &Arg : F->args())
Arg.setName(Args[Idx++]);
return F;
}
Function *FunctionAST::codegen() {
// First, check for an existing function from a previous 'extern' declaration.
Function *TheFunction = TheModule->getFunction(Proto->getName());
if (!TheFunction)
TheFunction = Proto->codegen();
if (!TheFunction)
return nullptr;
// Create a new basic block to start insertion into.
BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
Builder->SetInsertPoint(BB);
// Record the function arguments in the NamedValues map.
NamedValues.clear();
for (auto &Arg : TheFunction->args())
NamedValues[std::string(Arg.getName())] = &Arg;
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder->CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
return TheFunction;
}
// Error reading body, remove function.
TheFunction->eraseFromParent();
return nullptr;
}
//===----------------------------------------------------------------------===//
// Top-Level parsing and JIT Driver
//===----------------------------------------------------------------------===//
static void InitializeModule() {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("my cool jit", *TheContext);
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
}
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
if (auto FnAST = ParseTopLevelExpr()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read top-level expression:");
FnIR->print(errs());
fprintf(stderr, "\n");
// Remove the anonymous expression.
FnIR->eraseFromParent();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
/// top ::= definition | external | expression | ';'
static void MainLoop() {
while (true) {
fprintf(stderr, "ready> ");
switch (CurTok) {
case tok_eof:
return;
case ';': // ignore top-level semicolons.
getNextToken();
break;
case tok_def:
HandleDefinition();
break;
case tok_extern:
HandleExtern();
break;
default:
HandleTopLevelExpression();
break;
}
}
}
//===----------------------------------------------------------------------===//
// Main driver code.
//===----------------------------------------------------------------------===//
int main() {
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
// Make the module, which holds all the code.
InitializeModule();
// Run the main "interpreter loop" now.
MainLoop();
// Print out all of the generated code.
TheModule->print(errs(), nullptr);
return 0;
}