Files
scully/src/AST/CodeGenVisitor.cpp

439 lines
11 KiB
C++
Raw Normal View History

2013-06-02 09:47:27 +02:00
/* The scully programming language.
*
* Copyright (c) Peter Dahlberg, Markus Hauschild and Florian Sattler, 2013.
* Licensed under the GNU GPL v2.
*/
2013-06-01 13:04:53 +02:00
#include "AST/CodeGenVisitor.h"
#include "llvm/Analysis/Verifier.h"
#include "llvm/Support/Casting.h"
2013-06-01 17:22:52 +02:00
#include <iostream>
2013-06-01 13:04:53 +02:00
2013-06-01 21:10:05 +02:00
/* Set up a code generator emitting into `module`.
 * module, fpm and ee are stored as-is; the destructor does not delete them.
 * Declares the external runtime helpers the generated code may call:
 *   int put_char(int), int time_seed(), bool random_if(int)
 * (their definitions presumably come from the host program - verify). */
CodeGenVisitor::CodeGenVisitor(llvm::Module* module, llvm::FunctionPassManager *fpm, llvm::ExecutionEngine *ee) {
    module_ = module;
    fpm_ = fpm;
    ee_ = ee;
    // owned by this visitor, released in the destructor
    builder_ = new llvm::IRBuilder<>(llvm::getGlobalContext());

    // a single outermost symbol table (scope 0)
    scope_ = 0;
    namedValues_.push_back(std::map<std::string, llvm::Value*>());

    std::vector<llvm::Type*> oneIntParam(1, typeToLLVMType(Type::INT));

    // int put_char(int)
    llvm::Function::Create(
        llvm::FunctionType::get(typeToLLVMType(Type::INT), oneIntParam, false),
        llvm::Function::ExternalLinkage, "put_char", module_);

    // int time_seed()
    llvm::Function::Create(
        llvm::FunctionType::get(typeToLLVMType(Type::INT), false),
        llvm::Function::ExternalLinkage, "time_seed", module_);

    // bool random_if(int)
    llvm::Function::Create(
        llvm::FunctionType::get(typeToLLVMType(Type::BOOL), oneIntParam, false),
        llvm::Function::ExternalLinkage, "random_if", module_);
}
/* Release the IR builder allocated in the constructor.
 * module_, fpm_ and ee_ are not deleted here. */
CodeGenVisitor::~CodeGenVisitor() {
    llvm::IRBuilder<>* owned = builder_;
    delete owned;
}
/* Emit `id = expr`: evaluate the right-hand side and store it into the
 * stack slot registered for `id`. Leaves value_ set to the stored value.
 * Throws (const char*) if the expression yields no value or `id` is not
 * a known variable. */
void CodeGenVisitor::visit(AssignmentExpression* e) {
    value_ = 0;
    e->getExpr()->accept(this);
    if (value_ == 0) {
        throw "error creating expression";
    }
    llvm::Value* target = getNamedValue(e->getId());
    if (!target) {
        // without this check CreateStore would be handed a null pointer
        throw "unknown variable name";
    }
    builder_->CreateStore(value_, target);
}
/* Emit a binary operation; the result is left in value_.
 * Both operands must have the same LLVM type. Arithmetic uses signed
 * integer instructions; EQUALS/LESS produce an i1 comparison.
 * Throws (const char*) on operand failure, type mismatch or unknown op. */
void CodeGenVisitor::visit(BinOpExpression* e) {
    // Reset value_ before each operand: otherwise a sub-expression that
    // produces nothing would silently reuse the previous (stale) value.
    value_ = 0;
    e->getLeftExp()->accept(this);
    if (!value_) {
        throw "error evaluating expression (lhs)";
    }
    llvm::Value* lhs = value_;

    value_ = 0;
    e->getRightExp()->accept(this);
    if (!value_) {
        throw "error evaluating expression (rhs)";
    }
    llvm::Value* rhs = value_;

    if (lhs->getType() != rhs->getType()) {
        throw "lhs type of binop != rhs type of binop";
    }

    switch (e->getOp()) {
    case BinOp::PLUS:
        value_ = builder_->CreateAdd(lhs, rhs, "addtmp");
        break;
    case BinOp::MINUS:
        value_ = builder_->CreateSub(lhs, rhs, "subtmp");
        break;
    case BinOp::TIMES:
        value_ = builder_->CreateMul(lhs, rhs, "multmp");
        break;
    case BinOp::DIV:
        value_ = builder_->CreateSDiv(lhs, rhs, "divtmp");
        break;
    case BinOp::EQUALS:
        value_ = builder_->CreateICmpEQ(lhs, rhs, "cmptmp");
        break;
    case BinOp::LESS:
        value_ = builder_->CreateICmpSLT(lhs, rhs, "cmptmp");
        break;
    default:
        throw "Unkown Operator, This is a Bug!";
    }
}
/* Emit a literal into value_: "true"/"false" become i1 constants, any
 * other text is parsed as a decimal 32-bit integer. */
void CodeGenVisitor::visit(ConstantExpression* e) {
    llvm::LLVMContext& ctx = llvm::getGlobalContext();
    if (e->getValue() == "true") {
        // The original APInt(1, 1, 10) passed 10 as APInt's `isSigned` flag
        // (not a radix); getTrue/getFalse state the intent directly.
        value_ = llvm::ConstantInt::getTrue(ctx);
    } else if (e->getValue() == "false") {
        value_ = llvm::ConstantInt::getFalse(ctx);
    } else {
        // here the third argument really is the radix (string constructor)
        value_ = llvm::ConstantInt::get(ctx, llvm::APInt(32, e->getValue(), 10));
    }
}
/* Emit an expression used as a statement. Its value is not consumed here,
 * but the expression must still produce one.
 * Throws (const char*) if it does not. */
void CodeGenVisitor::visit(ExpressionStatement* e) {
    auto* expression = e->getExpr();
    value_ = 0;
    expression->accept(this);
    if (value_ == 0) {
        throw "error evaluating expression";
    }
}
2013-06-01 18:57:41 +02:00
/* Emit `for (init; cond; step) stmt`.
 * Layout: <current>: init; cond; condbr -> loop / afterLoop
 *         loop:      stmt; step; cond; condbr -> loop / afterLoop
 *         afterLoop: (insertion continues here)
 * Throws (const char*) if the condition yields no value. */
void CodeGenVisitor::visit(ForStatement* e)
{
    llvm::LLVMContext& ctx = llvm::getGlobalContext();

    value_ = 0;
    e->getInit()->accept(this);

    // Test the condition BEFORE the first iteration. The original evaluated
    // it here but then branched unconditionally, so the body always ran once.
    value_ = 0;
    e->getCond()->accept(this);
    if (!value_) {
        throw "error evaluating expression";
    }
    llvm::Function* f = builder_->GetInsertBlock()->getParent();
    llvm::BasicBlock* loopBB = llvm::BasicBlock::Create(ctx, "loop", f);
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(ctx, "afterLoop");
    builder_->CreateCondBr(value_, loopBB, afterBB);

    builder_->SetInsertPoint(loopBB);
    value_ = 0;
    e->getStmt()->accept(this);
    if (builder_->GetInsertBlock()->getTerminator()) {
        // The body already ended its block (e.g. with a return); emitting the
        // step after a terminator would be invalid IR, so use a fresh block.
        builder_->SetInsertPoint(llvm::BasicBlock::Create(ctx, "loopStep", f));
    }
    value_ = 0;
    e->getStep()->accept(this);
    value_ = 0;
    e->getCond()->accept(this);
    if (!value_) {
        throw "error evaluating expression";
    }
    builder_->CreateCondBr(value_, loopBB, afterBB);

    // append the exit block after everything the body generated
    f->getBasicBlockList().push_back(afterBB);
    builder_->SetInsertPoint(afterBB);
}
/* Emit a call to the module function named by the expression; the result
 * is left in value_. Calls to void functions evaluate to the i32 constant 0.
 * Throws (const char*) if the function is unknown, the argument count
 * differs, or an argument yields no value. */
void CodeGenVisitor::visit(FunctionCallExpression* e) {
    llvm::Function* cf = module_->getFunction(e->getId());
    if (!cf) {
        throw "function to call not found";
    }
    const auto& values = e->getValues()->getValues();
    if (cf->arg_size() != values.size()) {
        throw "argument size mismatch";
    }

    std::vector<llvm::Value*> args;
    args.reserve(values.size());
    for (auto iter = values.begin(); iter != values.end(); ++iter) {
        Expression* expr = *iter;
        // reset so a failing argument is not masked by the previous value
        value_ = 0;
        expr->accept(this);
        if (!value_) {
            throw "error evaluating expression";
        }
        args.push_back(value_);
    }

    if (cf->getFunctionType()->getReturnType() == typeToLLVMType(Type::VOID)) {
        builder_->CreateCall(cf, args);
        // Treat void calls as returning 0. The original APInt(0, 32, 10)
        // requested a zero-bit integer, which LLVM rejects.
        value_ = llvm::ConstantInt::get(llvm::getGlobalContext(), llvm::APInt(32, 0));
    } else {
        value_ = builder_->CreateCall(cf, args, "calltmp");
    }
}
/* Emit a function definition, verify it, run fpm_ over it and leave the
 * llvm::Function in value_.
 * Parameters are copied into stack slots and registered in a fresh scope,
 * which is popped afterwards, so they do not leak into later functions.
 * A void function whose body falls off the end receives an implicit
 * `ret void`.
 * Throws (const char*) on redefinition, or if the name is taken with a
 * different arity. If generating the body throws, the scope is restored,
 * the partial body is dropped (the function reverts to a declaration) and
 * the exception is rethrown. */
void CodeGenVisitor::visit(FunctionDefinition* e) {
    auto params = e->getParams()->getParameters();

    std::vector<llvm::Type*> argTypes;
    for (auto iter = params.begin(); iter != params.end(); ++iter) {
        argTypes.push_back(typeToLLVMType(iter->first));
    }
    llvm::FunctionType* ft = llvm::FunctionType::get(typeToLLVMType(e->getType()), argTypes, false);
    llvm::Function* f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, e->getName(), module_);

    // If the name was taken, LLVM gave the new function a different name.
    // Fall back to the existing function, but only if it is a bodyless
    // declaration with the same number of arguments.
    if (f->getName() != e->getName()) {
        f->eraseFromParent();
        f = module_->getFunction(e->getName());
        if (!f) {
            // the name belongs to a global that is not a function
            throw "redefinition of function";
        }
        if (!f->empty()) {
            throw "redefinition of function";
        }
        if (f->arg_size() != params.size()) {
            throw "redefinition of function with different # of arguments";
        }
    }

    llvm::BasicBlock* bb = llvm::BasicBlock::Create(llvm::getGlobalContext(), "entry", f);
    builder_->SetInsertPoint(bb);

    scope_++;
    namedValues_.push_back(std::map<std::string, llvm::Value*>());
    try {
        // name each argument and spill it into an alloca so it can be assigned
        unsigned idx = 0;
        for (auto ai = f->arg_begin(); idx != params.size(); ++ai, ++idx) {
            ai->setName(params[idx].second);
            llvm::Value* alloca = builder_->CreateAlloca(typeToLLVMType(params[idx].first), 0, params[idx].second);
            builder_->CreateStore(ai, alloca);
            putNamedValue(params[idx].second, alloca);
        }

        e->getSl()->accept(this);

        if (!builder_->GetInsertBlock()->getTerminator() && f->getReturnType()->isVoidTy()) {
            builder_->CreateRetVoid();
        }
    } catch (...) {
        namedValues_.pop_back();
        scope_--;
        // The insertion block is about to be destroyed; detach the builder
        // first, then drop the half-built body so a later definition of the
        // same name is not rejected as a redefinition.
        builder_->ClearInsertionPoint();
        f->deleteBody();
        throw;
    }
    namedValues_.pop_back();
    scope_--;

    llvm::verifyFunction(*f);
    fpm_->run(*f);
    value_ = f;
}
/* Emit `if (cond) stmt` as then/else/merge blocks; insertion continues in
 * the merge block. The else block is currently empty.
 * Throws (const char*) if the condition yields no value. */
void CodeGenVisitor::visit(IfStatement* e) {
    value_ = 0;
    e->getCond()->accept(this);
    if (!value_) {
        throw "error evaluating expression";
    }
    llvm::LLVMContext& ctx = llvm::getGlobalContext();
    llvm::Function* f = builder_->GetInsertBlock()->getParent();
    llvm::BasicBlock* thenBB = llvm::BasicBlock::Create(ctx, "then", f);
    llvm::BasicBlock* elseBB = llvm::BasicBlock::Create(ctx, "else");
    llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(ctx, "merge");
    builder_->CreateCondBr(value_, thenBB, elseBB);

    builder_->SetInsertPoint(thenBB);
    e->getStmt()->accept(this);
    // If the body already terminated its block (e.g. `return`), a second
    // terminator would be invalid IR.
    if (!builder_->GetInsertBlock()->getTerminator()) {
        builder_->CreateBr(mergeBB);
    }

    f->getBasicBlockList().push_back(elseBB);
    builder_->SetInsertPoint(elseBB);
    // we can add an else part here later ...
    builder_->CreateBr(mergeBB);

    f->getBasicBlockList().push_back(mergeBB);
    builder_->SetInsertPoint(mergeBB);
}
/* Parameter lists are read directly by visit(FunctionDefinition*);
 * visiting one on its own emits nothing. */
void CodeGenVisitor::visit(ParameterList* e) {
    (void)e; // intentionally unused
}
/* Emit a "random for": like a for loop, but each iteration is taken when
 * random_if(prob) returns true.
 * Layout: <current>: init; prob; random_if; condbr -> loop / afterLoop
 *         loop:      stmt; step; prob; random_if; condbr -> loop / afterLoop
 * Throws (const char*) if the probability expression yields no value. */
void CodeGenVisitor::visit(RandomForStatement* e) {
    llvm::LLVMContext& ctx = llvm::getGlobalContext();

    value_ = 0;
    e->getInit()->accept(this);

    value_ = 0;
    e->getProb()->accept(this);
    if (!value_) {
        throw "error evaluating expression";
    }
    llvm::Function* cf = module_->getFunction("random_if");
    llvm::Value* prob = builder_->CreateCall(cf, value_, "callTmp");

    llvm::Function* f = builder_->GetInsertBlock()->getParent();
    llvm::BasicBlock* loopBB = llvm::BasicBlock::Create(ctx, "loop", f);
    llvm::BasicBlock* afterBB = llvm::BasicBlock::Create(ctx, "afterLoop");
    // Use the initial random_if result to decide whether to enter the loop
    // at all. The original computed it but then branched unconditionally.
    builder_->CreateCondBr(prob, loopBB, afterBB);

    builder_->SetInsertPoint(loopBB);
    value_ = 0;
    e->getStmt()->accept(this);
    if (builder_->GetInsertBlock()->getTerminator()) {
        // body ended with a terminator (e.g. return); continue in a fresh block
        builder_->SetInsertPoint(llvm::BasicBlock::Create(ctx, "loopStep", f));
    }
    value_ = 0;
    e->getStep()->accept(this);
    value_ = 0;
    e->getProb()->accept(this);
    if (value_ == 0) {
        throw "error evaluating expression";
    }
    prob = builder_->CreateCall(cf, value_, "callTmp");
    builder_->CreateCondBr(prob, loopBB, afterBB);

    f->getBasicBlockList().push_back(afterBB);
    builder_->SetInsertPoint(afterBB);
}
/* Emit a "random if": the statement runs when random_if(prob) returns true.
 * Uses then/else/merge blocks like visit(IfStatement*); insertion continues
 * in the merge block.
 * Throws (const char*) if the probability expression yields no value. */
void CodeGenVisitor::visit(RandomIfStatement* e) {
    value_ = 0;
    e->getProb()->accept(this);
    if (!value_) {
        throw "error evaluating expression";
    }
    llvm::Function* cf = module_->getFunction("random_if");
    llvm::Value* cond = builder_->CreateCall(cf, value_, "callTmp");

    llvm::LLVMContext& ctx = llvm::getGlobalContext();
    llvm::Function* f = builder_->GetInsertBlock()->getParent();
    llvm::BasicBlock* thenBB = llvm::BasicBlock::Create(ctx, "then", f);
    llvm::BasicBlock* elseBB = llvm::BasicBlock::Create(ctx, "else");
    llvm::BasicBlock* mergeBB = llvm::BasicBlock::Create(ctx, "merge");
    builder_->CreateCondBr(cond, thenBB, elseBB);

    // then
    builder_->SetInsertPoint(thenBB);
    e->getStmt()->accept(this);
    // do not append a branch after a terminator the body already emitted
    if (!builder_->GetInsertBlock()->getTerminator()) {
        builder_->CreateBr(mergeBB);
    }

    f->getBasicBlockList().push_back(elseBB);
    builder_->SetInsertPoint(elseBB);
    builder_->CreateBr(mergeBB);

    f->getBasicBlockList().push_back(mergeBB);
    builder_->SetInsertPoint(mergeBB);
}
/* Emit `return;` (no expression) or `return expr;`.
 * Throws (const char*) if the expression yields no value. */
void CodeGenVisitor::visit(ReturnStatement* e) {
    if (e->getExpr() == 0) {
        builder_->CreateRetVoid();
        return;
    }
    // reset first, otherwise a stale value_ would defeat the check below
    value_ = 0;
    e->getExpr()->accept(this);
    if (!value_) {
        throw "error evaluating expression";
    }
    builder_->CreateRet(value_);
}
/* Emit a nested block with its own symbol table. The table is popped
 * again even if code generation for the body throws, so scope_ and
 * namedValues_ stay in sync. */
void CodeGenVisitor::visit(Scope* e) {
    scope_++;
    namedValues_.push_back(std::map<std::string, llvm::Value*>());
    try {
        e->getSl()->accept(this);
    } catch (...) {
        namedValues_.pop_back();
        scope_--;
        throw;
    }
    namedValues_.pop_back();
    scope_--;
}
/* Emit every statement of the list, in order, working on a copy of the
 * list as returned by getStatements(). */
void CodeGenVisitor::visit(StatementList* e) {
    const auto snapshot = e->getStatements();
    for (const auto& statement : snapshot) {
        statement->accept(this);
    }
}
/* Value lists are read directly by visit(FunctionCallExpression*);
 * visiting one on its own emits nothing. */
void CodeGenVisitor::visit(ValueList* e) {
    (void)e; // intentionally unused
}
/* Allocate a stack slot for a local variable and register it in the
 * current scope. The alloca is placed at the top of the enclosing
 * function's entry block (presumably so the passes in fpm_ can promote it
 * to a register - confirm against the pass setup).
 * Throws (const char*) if no function is currently being generated. */
void CodeGenVisitor::visit(VariableDefinition* e) {
    llvm::BasicBlock* current = builder_->GetInsertBlock();
    if (!current || !current->getParent()) {
        // previously a null dereference
        throw "variable definition outside of a function";
    }
    llvm::Function* f = current->getParent();
    llvm::BasicBlock& entry = f->getEntryBlock();
    llvm::IRBuilder<> tmpBuilder(&entry, entry.begin());
    llvm::Value* alloca = tmpBuilder.CreateAlloca(typeToLLVMType(e->getType()), 0, e->getName());
    putNamedValue(e->getName(), alloca);
}
2013-06-01 15:48:28 +02:00
/* Load the current value of a variable from its stack slot into value_.
 * Throws (const char*) if the name is not defined in any visible scope. */
void CodeGenVisitor::visit(LoadExpression *e) {
    llvm::Value* slot = getNamedValue(e->getId());
    if (slot == 0) {
        throw "unknown variable name";
    }
    llvm::Value* loaded = builder_->CreateLoad(slot, e->getId());
    value_ = loaded;
}
2013-06-01 20:29:32 +02:00
/* Evaluate a single expression interactively.
 * Wraps `e` as `return e;` in an anonymous int-returning function, JIT
 * compiles it through ee_, runs it, prints "Evaluated to: <result>" and
 * removes the temporary function from the module again.
 * Throws (const char*) if code generation does not yield a function. */
void CodeGenVisitor::JIT(Expression* e) {
    StatementList* sl = new StatementList();
    sl->addStatement(new ReturnStatement(e));
    FunctionDefinition* fd = new FunctionDefinition(Type::INT, "", new ParameterList(), sl);
    value_ = 0;
    fd->accept(this);
    // LLVM's own RTTI (dyn_cast) instead of dynamic_cast, which needs C++
    // RTTI that LLVM is typically built without.
    llvm::Function* f = llvm::dyn_cast_or_null<llvm::Function>(value_);
    if (!f) {
        delete fd;
        throw "error evaluating expression";
    }
    // NOTE(review): fd is not freed on this success path; whether fd owns
    // `e` (and therefore whether deleting it here is safe for the caller)
    // cannot be determined from this file - confirm before adding a delete.
    void* fPtr = ee_->getPointerToFunction(f);
    // object pointer -> integer -> function pointer
    int (*fP)() = reinterpret_cast<int (*)()>(reinterpret_cast<intptr_t>(fPtr));
    std::cout << "Evaluated to: " << fP() << std::endl;
    // throw it away
    f->eraseFromParent();
}
2013-06-01 17:22:52 +02:00
void CodeGenVisitor::putNamedValue(const std::string& name, llvm::Value* value) {
2013-06-01 20:29:32 +02:00
namedValues_[scope_][name] = value;
2013-06-01 17:22:52 +02:00
}
/* Look `name` up from the innermost scope (scope_) outwards to scope 0 and
 * return the first non-null binding, or null if there is none.
 * Uses find() rather than operator[], which inserted an empty entry for
 * `name` into every scope searched on a miss. */
llvm::Value* CodeGenVisitor::getNamedValue(const std::string& name) {
    for (int i = scope_; i >= 0; i--) {
        const auto& table = namedValues_[i];
        auto it = table.find(name);
        if (it != table.end() && it->second != 0) {
            return it->second;
        }
    }
    return 0;
}