markly

Markov chain for text generation
git clone git://git.yotsev.xyz/markly.git
Log | Files | Refs | README | LICENSE

main.cpp (7867B)


#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/serialization/map.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/vector.hpp>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <unistd.h>
#include <utility>
#include <vector>
     12 
     13 using namespace std;
     14 
     15 unsigned long mix(unsigned long a, unsigned long b, unsigned long c)
     16 {
     17     a = a - b;
     18     a = a - c;
     19     a = a ^ (c >> 13);
     20     b = b - c;
     21     b = b - a;
     22     b = b ^ (a << 8);
     23     c = c - a;
     24     c = c - b;
     25     c = c ^ (b >> 13);
     26     a = a - b;
     27     a = a - c;
     28     a = a ^ (c >> 12);
     29     b = b - c;
     30     b = b - a;
     31     b = b ^ (a << 16);
     32     c = c - a;
     33     c = c - b;
     34     c = c ^ (b >> 5);
     35     a = a - b;
     36     a = a - c;
     37     a = a ^ (c >> 3);
     38     b = b - c;
     39     b = b - a;
     40     b = b ^ (a << 10);
     41     c = c - a;
     42     c = c - b;
     43     c = c ^ (b >> 15);
     44     return c;
     45 }
     46 
     47 struct chain {
     48     friend class boost::serialization::access;
     49     template <class Archive>
     50     void serialize(Archive& ar, const unsigned int version)
     51     {
     52         ar& format;
     53         ar& order;
     54         ar& ngram;
     55         ar& beginnings;
     56     }
     57     char format;
     58     int order;
     59     map<string, vector<pair<char, int>>> ngram;
     60     vector<string> beginnings;
     61 };
     62 
     63 int main(int argc, char** argv)
     64 {
     65     // seed the random generator with quality values
     66     srand(mix(clock(), time(NULL), getpid()));
     67 
     68     // declaring state
     69     string filename;
     70     string chain_filename = "chain";
     71 
     72     char format = 'n';
     73     int order = 3;
     74     int length = 7;
     75     int itterations = 10;
     76 
     77     bool generating = true;
     78     bool saving = false;
     79     bool infinite = false;
     80     bool continuous = false;
     81     bool loud = false;
     82 
     83     //
     84     // processes arguments
     85     //
     86     string arg;
     87     for (int i = 1; i < argc; ++i) {
     88         arg = argv[i];
     89         if (arg == "-g") {
     90             filename = argv[i + 1];
     91             if (argv[i + 2] && argv[i + 2][0] != '-')
     92                 chain_filename = argv[i + 2];
     93             itterations = 0;
     94             generating = true;
     95             saving = true;
     96         } else if (arg == "-c") {
     97             chain_filename = argv[i + 1];
     98             generating = false;
     99             saving = false;
    100         } else if (arg == "-f") {
    101             filename = argv[i + 1];
    102             generating = true;
    103             saving = false;
    104         } else if (arg == "-o") {
    105             order = std::stoi(argv[i + 1]);
    106         } else if (arg == "-l") {
    107             length = std::stoi(argv[i + 1]);
    108         } else if (arg == "-m") {
    109             infinite = true;
    110         } else if (arg == "-C") {
    111             continuous = true;
    112         } else if (arg == "-n") {
    113             itterations = std::stoi(argv[i + 1]);
    114         } else if (arg == "-s") {
    115             format = 's';
    116         } else if (arg == "-v") {
    117             loud = true;
    118         }
    119     }
    120 
    121     map<string, vector<pair<char, int>>> ngram;
    122     vector<string> beginnings;
    123 
    124     vector<pair<char, int>>* chars;
    125 
    126     //
    127     // gets the chain in one way or another
    128     //
    129     if (generating) {
    130         // generates the chain
    131         if (loud)
    132             cerr << "Generating chain, this may take a while" << endl;
    133         string line;
    134         ifstream file(filename);
    135         if (file.is_open()) {
    136             if (format == 's') {
    137                 while (getline(file, line)) {
    138                     if (line.size() > order)
    139                         beginnings.push_back(line.substr(0, order));
    140                     for (int i = 1; i + order < line.size(); ++i) {
    141                         for (auto p : ngram[line.substr(i, order)]) {
    142                             if (p.first == line[i + order]) {
    143                                 p.second++;
    144                                 goto recorded;
    145                             }
    146                         }
    147                         {
    148                             pair<char, int> p(line[i + order], 1);
    149                             ngram[line.substr(i, order)].push_back(p);
    150                         }
    151                     recorded:;
    152                     }
    153                 }
    154             } else {
    155                 while (getline(file, line)) {
    156                     for (int i = 0; i + order < line.size(); ++i) {
    157                         for (auto p : ngram[line.substr(i, order)]) {
    158                             if (p.first == line[i + order]) {
    159                                 p.second++;
    160                                 goto recorded2;
    161                             }
    162                         }
    163                         {
    164                             pair<char, int> p(line[i + order], 1);
    165                             ngram[line.substr(i, order)].push_back(p);
    166                         }
    167                     recorded2:;
    168                     }
    169                 }
    170             }
    171             file.close();
    172         } else {
    173             cerr << "Error: file " << filename << " does not exist\n";
    174             return 1;
    175         }
    176     } else {
    177         // loads the chain if it exists
    178         ifstream ifs(chain_filename, ios::binary);
    179         if (ifs.is_open()) {
    180             if (loud)
    181                 cerr << "Loading chain from " << chain_filename << endl;
    182             boost::archive::binary_iarchive ia(ifs);
    183             chain c;
    184             ia >> c;
    185             format = c.format;
    186             order = c.order;
    187             ngram = c.ngram;
    188             if (format == 's')
    189                 beginnings = c.beginnings;
    190         } else {
    191             cerr << "Error: file " << chain_filename << " does not exist\n";
    192             return 1;
    193         }
    194     }
    195 
    196     // saves the chain
    197     if (saving) {
    198         chain_filename += ".o";
    199         chain_filename += std::to_string(order);
    200         if (loud)
    201             cerr << "Saving chain to " << chain_filename << endl;
    202         ofstream ofs(chain_filename, ios::binary);
    203         boost::archive::binary_oarchive oa(ofs);
    204         chain c;
    205         c.format = format;
    206         c.order = order;
    207         c.ngram = ngram;
    208         if (format == 's')
    209             c.beginnings = beginnings;
    210         oa << c;
    211     }
    212 
    213     //
    214     // generates text from the chain
    215     //
    216     for (int i = 0; i < itterations || continuous; ++i) {
    217 
    218         // get random starting gram
    219         string cgram;
    220 
    221         if (format == 's') {
    222             do {
    223                 cgram = beginnings[rand() % beginnings.size()];
    224             } while (ngram[cgram].size() == 0);
    225         } else {
    226             std::map<string, vector<pair<char, int>>>::iterator it;
    227             do {
    228                 it = ngram.begin();
    229                 for (int i = 0; i < rand() % ngram.size(); ++i)
    230                     ++it;
    231                 cgram = it->first;
    232             } while (ngram[cgram].size() == 0);
    233         }
    234 
    235         // print the beginning gram
    236         string result = cgram;
    237         cout << result;
    238 
    239         // find the next letter
    240         char next;
    241 
    242         int sum = 0;
    243         for (auto p : ngram[cgram])
    244             sum += p.second;
    245         int index = (rand() % sum) + 1;
    246         for (auto p : ngram[cgram]) {
    247             sum -= p.second;
    248             if (sum < index) {
    249                 next = p.first;
    250                 break;
    251             }
    252         }
    253 
    254         for (int i = 0; i < length - order || infinite; ++i) {
    255             // print out next letter
    256             cout << next;
    257             // get next cgram
    258             result += next;
    259             cgram = result.substr(result.length() - order, order);
    260             if (ngram[cgram].size() == 0)
    261                 break;
    262             // find the next letter
    263             sum = 0;
    264             for (auto p : ngram[cgram])
    265                 sum += p.second;
    266             index = (rand() % sum) + 1;
    267             for (auto p : ngram[cgram]) {
    268                 sum -= p.second;
    269                 if (sum < index) {
    270                     next = p.first;
    271                     break;
    272                 }
    273             }
    274         }
    275 
    276         if (format == 's')
    277             cout << endl;
    278     }
    279 
    280     return 0;
    281 }