From de379496ee411993dff94e52f393f6e19437a204 Mon Sep 17 00:00:00 2001 From: redpony Date: Mon, 18 Oct 2010 23:24:01 +0000 Subject: kenneth's LM preliminary integration git-svn-id: https://ws10smt.googlecode.com/svn/trunk@681 ec762483-ff6d-05da-a07a-a48fb63a330f --- klm/COPYING | 674 ++++++++++++++++++++++++++++++++++++ klm/COPYING.LESSER | 165 +++++++++ klm/LICENSE | 19 + klm/README | 31 ++ klm/clean.sh | 2 + klm/compile.sh | 11 + klm/lm/Makefile.am | 20 ++ klm/lm/exception.cc | 21 ++ klm/lm/exception.hh | 40 +++ klm/lm/facade.hh | 64 ++++ klm/lm/ngram.cc | 522 ++++++++++++++++++++++++++++ klm/lm/ngram.hh | 226 ++++++++++++ klm/lm/ngram_build_binary.cc | 13 + klm/lm/ngram_config.hh | 58 ++++ klm/lm/ngram_query.cc | 72 ++++ klm/lm/ngram_test.cc | 91 +++++ klm/lm/sri.cc | 115 ++++++ klm/lm/sri.hh | 102 ++++++ klm/lm/sri_test.cc | 65 ++++ klm/lm/test.arpa | 112 ++++++ klm/lm/test.binary | Bin 0 -> 1660 bytes klm/lm/virtual_interface.cc | 22 ++ klm/lm/virtual_interface.hh | 156 +++++++++ klm/lm/word_index.hh | 11 + klm/test.sh | 8 + klm/util/Makefile.am | 18 + klm/util/ersatz_progress.cc | 47 +++ klm/util/ersatz_progress.hh | 50 +++ klm/util/exception.cc | 35 ++ klm/util/exception.hh | 72 ++++ klm/util/file_piece.cc | 224 ++++++++++++ klm/util/file_piece.hh | 105 ++++++ klm/util/file_piece_test.cc | 41 +++ klm/util/joint_sort.hh | 145 ++++++++ klm/util/joint_sort_test.cc | 50 +++ klm/util/key_value_packing.hh | 122 +++++++ klm/util/key_value_packing_test.cc | 75 ++++ klm/util/mmap.cc | 95 +++++ klm/util/mmap.hh | 101 ++++++ klm/util/murmur_hash.cc | 129 +++++++ klm/util/murmur_hash.hh | 14 + klm/util/probing_hash_table.hh | 97 ++++++ klm/util/probing_hash_table_test.cc | 30 ++ klm/util/proxy_iterator.hh | 94 +++++ klm/util/scoped.cc | 12 + klm/util/scoped.hh | 66 ++++ klm/util/sorted_uniform.hh | 139 ++++++++ klm/util/sorted_uniform_test.cc | 116 +++++++ klm/util/string_piece.cc | 57 +++ klm/util/string_piece.hh | 260 ++++++++++++++ 50 files changed, 4814 insertions(+) create mode 100644 klm/COPYING create mode 100644 klm/COPYING.LESSER create mode 100644 klm/LICENSE create mode 100644 klm/README create mode 100755 klm/clean.sh create mode 100755 klm/compile.sh create mode 100644 klm/lm/Makefile.am create mode 100644 klm/lm/exception.cc create mode 100644 klm/lm/exception.hh create mode 100644 klm/lm/facade.hh create mode 100644 klm/lm/ngram.cc create mode 100644 klm/lm/ngram.hh create mode 100644 klm/lm/ngram_build_binary.cc create mode 100644 klm/lm/ngram_config.hh create mode 100644 klm/lm/ngram_query.cc create mode 100644 klm/lm/ngram_test.cc create mode 100644 klm/lm/sri.cc create mode 100644 klm/lm/sri.hh create mode 100644 klm/lm/sri_test.cc create mode 100644 klm/lm/test.arpa create mode 100644 klm/lm/test.binary create mode 100644 klm/lm/virtual_interface.cc create mode 100644 klm/lm/virtual_interface.hh create mode 100644 klm/lm/word_index.hh create mode 100755 klm/test.sh create mode 100644 klm/util/Makefile.am create mode 100644 klm/util/ersatz_progress.cc create mode 100644 klm/util/ersatz_progress.hh create mode 100644 klm/util/exception.cc create mode 100644 klm/util/exception.hh create mode 100644 klm/util/file_piece.cc create mode 100644 klm/util/file_piece.hh create mode 100644 klm/util/file_piece_test.cc create mode 100644 klm/util/joint_sort.hh create mode 100644 klm/util/joint_sort_test.cc create mode 100644 klm/util/key_value_packing.hh create mode 100644 klm/util/key_value_packing_test.cc create mode 100644 klm/util/mmap.cc create mode 100644 klm/util/mmap.hh create mode 100644 klm/util/murmur_hash.cc create mode 100644 klm/util/murmur_hash.hh create mode 100644 klm/util/probing_hash_table.hh create mode 100644 klm/util/probing_hash_table_test.cc create mode 100644 klm/util/proxy_iterator.hh create mode 100644 klm/util/scoped.cc create mode 100644 klm/util/scoped.hh create mode 100644 klm/util/sorted_uniform.hh create mode 100644 klm/util/sorted_uniform_test.cc create mode 100644 klm/util/string_piece.cc create mode 100644 klm/util/string_piece.hh (limited to 'klm') diff --git a/klm/COPYING b/klm/COPYING new file mode 100644 index 00000000..94a9ed02 --- /dev/null +++ b/klm/COPYING @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/klm/COPYING.LESSER b/klm/COPYING.LESSER new file mode 100644 index 00000000..cca7fc27 --- /dev/null +++ b/klm/COPYING.LESSER @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/klm/LICENSE b/klm/LICENSE new file mode 100644 index 00000000..20b76c13 --- /dev/null +++ b/klm/LICENSE @@ -0,0 +1,19 @@ +Most of the code here is licensed under the LGPL. There are exceptions which have their own licenses, listed below. See comments in those files for more details. + +util/murmur_hash.cc is under the MIT license. +util/string_piece.hh and util/string_piece.cc are Google code and contains its own license. + +For the rest: + + Avenue code is free software: you can redistribute it and/or modify + it under the terms of the GNU Lesser General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Avenue code is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public License + along with Avenue code. If not, see . diff --git a/klm/README b/klm/README new file mode 100644 index 00000000..8d1050a8 --- /dev/null +++ b/klm/README @@ -0,0 +1,31 @@ +Language model inference code by Kenneth Heafield +See LICENSE for list of files by other people and their licenses. + +Compile: ./compile.sh +Run: ./query lm/test.arpa +#include + +namespace lm { + +LoadException::LoadException() throw() {} +LoadException::~LoadException() throw() {} +VocabLoadException::VocabLoadException() throw() {} +VocabLoadException::~VocabLoadException() throw() {} + +FormatLoadException::FormatLoadException() throw() {} +FormatLoadException::~FormatLoadException() throw() {} + +SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() { + *this << "Missing special word " << which; +} +SpecialWordMissingException::~SpecialWordMissingException() throw() {} + +} // namespace lm diff --git a/klm/lm/exception.hh b/klm/lm/exception.hh new file mode 100644 index 00000000..95109012 --- /dev/null +++ b/klm/lm/exception.hh @@ -0,0 +1,40 @@ +#ifndef LM_EXCEPTION__ +#define LM_EXCEPTION__ + +#include "util/exception.hh" +#include "util/string_piece.hh" + +#include +#include + +namespace lm { + +class LoadException : public util::Exception { + public: + virtual ~LoadException() throw(); + + protected: + LoadException() throw(); +}; + +class VocabLoadException : public LoadException { + public: + virtual ~VocabLoadException() throw(); + VocabLoadException() throw(); +}; + +class FormatLoadException : public LoadException { + public: + FormatLoadException() throw(); + ~FormatLoadException() throw(); +}; + +class SpecialWordMissingException : public VocabLoadException { + public: + explicit SpecialWordMissingException(StringPiece which) throw(); + ~SpecialWordMissingException() throw(); +}; + +} // namespace lm + +#endif diff --git a/klm/lm/facade.hh b/klm/lm/facade.hh new file mode 100644 index 00000000..8b186017 --- /dev/null +++ b/klm/lm/facade.hh @@ -0,0 +1,64 @@ +#ifndef LM_FACADE__ +#define LM_FACADE__ + +#include "lm/virtual_interface.hh" +#include "util/string_piece.hh" + +#include + +namespace lm { +namespace base { + +// Common model interface that depends on knowing the specific classes. +// Curiously recurring template pattern. +template class ModelFacade : public Model { + public: + typedef StateT State; + typedef VocabularyT Vocabulary; + + // Default Score function calls FullScore. Model can override this. + float Score(const State &in_state, const WordIndex new_word, State &out_state) const { + return static_cast(this)->FullScore(in_state, new_word, out_state).prob; + } + + /* Translate from void* to State */ + FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const { + return static_cast(this)->FullScore( + *reinterpret_cast(in_state), + new_word, + *reinterpret_cast(out_state)); + } + float Score(const void *in_state, const WordIndex new_word, void *out_state) const { + return static_cast(this)->Score( + *reinterpret_cast(in_state), + new_word, + *reinterpret_cast(out_state)); + } + + const State &BeginSentenceState() const { return begin_sentence_; } + const State &NullContextState() const { return null_context_; } + const Vocabulary &GetVocabulary() const { return *static_cast(&BaseVocabulary()); } + + protected: + ModelFacade() : Model(sizeof(State)) {} + + virtual ~ModelFacade() {} + + // begin_sentence and null_context can disappear after. vocab should stay. + void Init(const State &begin_sentence, const State &null_context, const Vocabulary &vocab, unsigned char order) { + begin_sentence_ = begin_sentence; + null_context_ = null_context; + begin_sentence_memory_ = &begin_sentence_; + null_context_memory_ = &null_context_; + base_vocab_ = &vocab; + order_ = order; + } + + private: + State begin_sentence_, null_context_; +}; + +} // mamespace base +} // namespace lm + +#endif // LM_FACADE__ diff --git a/klm/lm/ngram.cc b/klm/lm/ngram.cc new file mode 100644 index 00000000..a87c82aa --- /dev/null +++ b/klm/lm/ngram.cc @@ -0,0 +1,522 @@ +#include "lm/ngram.hh" + +#include "lm/exception.hh" +#include "util/file_piece.hh" +#include "util/joint_sort.hh" +#include "util/murmur_hash.hh" +#include "util/probing_hash_table.hh" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace lm { +namespace ngram { + +size_t hash_value(const State &state) { + return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); +} + +namespace detail { +uint64_t HashForVocab(const char *str, std::size_t len) { + // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000 + // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit. + return util::MurmurHash64A(str, len, 0); +} + +void Prob::SetBackoff(float to) { + UTIL_THROW(FormatLoadException, "Attempt to set backoff " << to << " for the highest order n-gram"); +} + +// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok. +const uint64_t kUnknownHash = HashForVocab("", 5); +// Sadly some LMs have . +const uint64_t kUnknownCapHash = HashForVocab("", 5); + +} // namespace detail + +SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL) {} + +std::size_t SortedVocabulary::Size(std::size_t entries, float ignored) { + // Lead with the number of entries. + return sizeof(uint64_t) + sizeof(Entry) * entries; +} + +void SortedVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) { + assert(allocated >= Size(entries)); + // Leave space for number of entries. + begin_ = reinterpret_cast(reinterpret_cast(start) + 1); + end_ = begin_; + saw_unk_ = false; +} + +WordIndex SortedVocabulary::Insert(const StringPiece &str) { + uint64_t hashed = detail::HashForVocab(str); + if (hashed == detail::kUnknownHash || hashed == detail::kUnknownCapHash) { + saw_unk_ = true; + return 0; + } + end_->key = hashed; + ++end_; + // This is 1 + the offset where it was inserted to make room for unk. + return end_ - begin_; +} + +bool SortedVocabulary::FinishedLoading(detail::ProbBackoff *reorder_vocab) { + util::JointSort(begin_, end_, reorder_vocab + 1); + SetSpecial(Index(""), Index(""), 0, end_ - begin_ + 1); + // Save size. + *(reinterpret_cast(begin_) - 1) = end_ - begin_; + return saw_unk_; +} + +void SortedVocabulary::LoadedBinary() { + end_ = begin_ + *(reinterpret_cast(begin_) - 1); + SetSpecial(Index(""), Index(""), 0, end_ - begin_ + 1); +} + +namespace detail { + +template MapVocabulary::MapVocabulary() {} + +template void MapVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) { + lookup_ = Lookup(start, allocated); + available_ = 1; + // Later if available_ != expected_available_ then we can throw UnknownMissingException. + saw_unk_ = false; +} + +template WordIndex MapVocabulary::Insert(const StringPiece &str) { + uint64_t hashed = HashForVocab(str); + // Prevent unknown from going into the table. + if (hashed == kUnknownHash || hashed == kUnknownCapHash) { + saw_unk_ = true; + return 0; + } else { + lookup_.Insert(Lookup::Packing::Make(hashed, available_)); + return available_++; + } +} + +template bool MapVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { + lookup_.FinishedInserting(); + SetSpecial(Index(""), Index(""), 0, available_); + return saw_unk_; +} + +template void MapVocabulary::LoadedBinary() { + lookup_.LoadedBinary(); + SetSpecial(Index(""), Index(""), 0, available_); +} + +/* All of the entropy is in low order bits and boost::hash does poorly with + * these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a + * stable point: 0. But 0 is which won't be queried here anyway. + */ +inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { + uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast(next) * 17894857484156487943ULL); + return ret; +} + +uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) { + if (word == word_end) return 0; + uint64_t current = static_cast(*word); + for (++word; word != word_end; ++word) { + current = CombineWordHash(current, *word); + } + return current; +} + +bool IsEntirelyWhiteSpace(const StringPiece &line) { + for (size_t i = 0; i < static_cast(line.size()); ++i) { + if (!isspace(line.data()[i])) return false; + } + return true; +} + +void ReadARPACounts(util::FilePiece &in, std::vector &number) { + number.clear(); + StringPiece line; + if (!IsEntirelyWhiteSpace(line = in.ReadLine())) UTIL_THROW(FormatLoadException, "First line was \"" << line << "\" not blank"); + if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); + while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); + // So strtol doesn't go off the end of line. + std::string remaining(line.data() + 6, line.size() - 6); + char *end_ptr; + unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); + if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); + if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); + ++end_ptr; + const char *start = end_ptr; + long int count = std::strtol(start, &end_ptr, 10); + if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); + if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); + number.push_back(count); + } +} + +void ReadNGramHeader(util::FilePiece &in, unsigned int length) { + StringPiece line; + while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + std::stringstream expected; + expected << '\\' << length << "-grams:"; + if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead."); +} + +// Special unigram reader because unigram's data structure is different and because we're inserting vocab words. +template void Read1Grams(util::FilePiece &f, const size_t count, Voc &vocab, ProbBackoff *unigrams) { + ReadNGramHeader(f, 1); + for (size_t i = 0; i < count; ++i) { + try { + float prob = f.ReadFloat(); + if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); + ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())]; + value.prob = prob; + switch (f.get()) { + case '\t': + value.SetBackoff(f.ReadFloat()); + if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); + break; + case '\n': + value.ZeroBackoff(); + break; + default: + UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); + } + } catch(util::Exception &e) { + e << " in the " << i << "th 1-gram at byte " << f.Offset(); + throw; + } + } + if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after unigrams at byte " << f.Offset()); +} + +template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) { + ReadNGramHeader(f, n); + + // vocab ids of words in reverse order + WordIndex vocab_ids[n]; + typename Store::Packing::Value value; + for (size_t i = 0; i < count; ++i) { + try { + value.prob = f.ReadFloat(); + for (WordIndex *vocab_out = &vocab_ids[n-1]; vocab_out >= vocab_ids; --vocab_out) { + *vocab_out = vocab.Index(f.ReadDelimited()); + } + uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n); + + switch (f.get()) { + case '\t': + value.SetBackoff(f.ReadFloat()); + if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); + break; + case '\n': + value.ZeroBackoff(); + break; + default: + UTIL_THROW(FormatLoadException, "Expected tab or newline after n-gram"); + } + store.Insert(Store::Packing::Make(key, value)); + } catch(util::Exception &e) { + e << " in the " << i << "th " << n << "-gram at byte " << f.Offset(); + throw; + } + } + + if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after " << n << "-grams at byte " << f.Offset()); + store.FinishedInserting(); +} + +template size_t GenericModel::Size(const std::vector &counts, const Config &config) { + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + size_t memory_size = VocabularyT::Size(counts[0], config.probing_multiplier); + memory_size += sizeof(ProbBackoff) * (counts[0] + 1); // +1 for hallucinate + for (unsigned char n = 2; n < counts.size(); ++n) { + memory_size += Middle::Size(counts[n - 1], config.probing_multiplier); + } + memory_size += Longest::Size(counts.back(), config.probing_multiplier); + return memory_size; +} + +template void GenericModel::SetupMemory(char *base, const std::vector &counts, const Config &config) { + char *start = base; + size_t allocated = VocabularyT::Size(counts[0], config.probing_multiplier); + vocab_.Init(start, allocated, counts[0]); + start += allocated; + unigram_ = reinterpret_cast(start); + start += sizeof(ProbBackoff) * (counts[0] + 1); + for (unsigned int n = 2; n < counts.size(); ++n) { + allocated = Middle::Size(counts[n - 1], config.probing_multiplier); + middle_.push_back(Middle(start, allocated)); + start += allocated; + } + allocated = Longest::Size(counts.back(), config.probing_multiplier); + longest_ = Longest(start, allocated); + start += allocated; + if (static_cast(start - base) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - base) << " but Size says they should take " << Size(counts, config)); +} + +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 0\n\0"; +struct BinaryFileHeader { + char magic[sizeof(kMagicBytes)]; + float zero_f, one_f, minus_half_f; + WordIndex one_word_index, max_word_index; + uint64_t one_uint64; + + void SetToReference() { + std::memcpy(magic, kMagicBytes, sizeof(magic)); + zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; + one_word_index = 1; + max_word_index = std::numeric_limits::max(); + one_uint64 = 1; + } +}; + +bool IsBinaryFormat(int fd, off_t size) { + if (size == util::kBadSize || (size <= static_cast(sizeof(BinaryFileHeader)))) return false; + // Try reading the header. + util::scoped_mmap memory(mmap(NULL, sizeof(BinaryFileHeader), PROT_READ, MAP_FILE | MAP_PRIVATE, fd, 0), sizeof(BinaryFileHeader)); + if (memory.get() == MAP_FAILED) return false; + BinaryFileHeader reference_header = BinaryFileHeader(); + reference_header.SetToReference(); + if (!memcmp(memory.get(), &reference_header, sizeof(BinaryFileHeader))) return true; + if (!memcmp(memory.get(), "mmap lm ", 8)) UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Was it built on a different machine or with a different compiler?"); + return false; +} + +std::size_t Align8(std::size_t in) { + std::size_t off = in % 8; + if (!off) return in; + return in + 8 - off; +} + +std::size_t TotalHeaderSize(unsigned int order) { + return Align8(sizeof(BinaryFileHeader) + 1 /* order */ + sizeof(uint64_t) * order /* counts */ + sizeof(float) /* probing multiplier */ + 1 /* search_tag */); +} + +void ReadBinaryHeader(const void *from, off_t size, std::vector &out, float &probing_multiplier, unsigned char &search_tag) { + const char *from_char = reinterpret_cast(from); + if (size < static_cast(1 + sizeof(BinaryFileHeader))) UTIL_THROW(FormatLoadException, "File too short to have count information."); + // Skip over the BinaryFileHeader which was read by IsBinaryFormat. + from_char += sizeof(BinaryFileHeader); + unsigned char order = *reinterpret_cast(from_char); + if (size < static_cast(TotalHeaderSize(order))) UTIL_THROW(FormatLoadException, "File too short to have full header."); + out.resize(static_cast(order)); + const uint64_t *counts = reinterpret_cast(from_char + 1); + for (std::size_t i = 0; i < out.size(); ++i) { + out[i] = static_cast(counts[i]); + } + const float *probing_ptr = reinterpret_cast(counts + out.size()); + probing_multiplier = *probing_ptr; + search_tag = *reinterpret_cast(probing_ptr + 1); +} + +void WriteBinaryHeader(void *to, const std::vector &from, float probing_multiplier, char search_tag) { + BinaryFileHeader header = BinaryFileHeader(); + header.SetToReference(); + memcpy(to, &header, sizeof(BinaryFileHeader)); + char *out = reinterpret_cast(to) + sizeof(BinaryFileHeader); + *reinterpret_cast(out) = static_cast(from.size()); + uint64_t *counts = reinterpret_cast(out + 1); + for (std::size_t i = 0; i < from.size(); ++i) { + counts[i] = from[i]; + } + float *probing_ptr = reinterpret_cast(counts + from.size()); + *probing_ptr = probing_multiplier; + *reinterpret_cast(probing_ptr + 1) = search_tag; +} + +template GenericModel::GenericModel(const char *file, Config config) : mapped_file_(util::OpenReadOrThrow(file)) { + const off_t file_size = util::SizeFile(mapped_file_.get()); + + std::vector counts; + + if (IsBinaryFormat(mapped_file_.get(), file_size)) { + memory_.reset(util::MapForRead(file_size, config.prefault, mapped_file_.get()), file_size); + + unsigned char search_tag; + ReadBinaryHeader(memory_.begin(), file_size, counts, config.probing_multiplier, search_tag); + if (config.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << config.probing_multiplier << " which is < 1.0."); + if (search_tag != Search::kBinaryTag) UTIL_THROW(FormatLoadException, "The binary file has a different search strategy than the one requested."); + size_t memory_size = Size(counts, config); + + char *start = reinterpret_cast(memory_.get()) + TotalHeaderSize(counts.size()); + if (memory_size != static_cast(memory_.end() - start)) UTIL_THROW(FormatLoadException, "The mmap file " << file << " has size " << file_size << " but " << (memory_size + TotalHeaderSize(counts.size())) << " was expected based on the number of counts and configuration."); + + SetupMemory(start, counts, config); + vocab_.LoadedBinary(); + for (typename std::vector::iterator i = middle_.begin(); i != middle_.end(); ++i) { + i->LoadedBinary(); + } + longest_.LoadedBinary(); + + } else { + if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0"); + + util::FilePiece f(file, mapped_file_.release(), config.messages); + ReadARPACounts(f, counts); + size_t memory_size = Size(counts, config); + char *start; + + if (config.write_mmap) { + // Write out an mmap file. + util::MapZeroedWrite(config.write_mmap, TotalHeaderSize(counts.size()) + memory_size, mapped_file_, memory_); + WriteBinaryHeader(memory_.get(), counts, config.probing_multiplier, Search::kBinaryTag); + start = reinterpret_cast(memory_.get()) + TotalHeaderSize(counts.size()); + } else { + memory_.reset(util::MapAnonymous(memory_size), memory_size); + start = reinterpret_cast(memory_.get()); + } + SetupMemory(start, counts, config); + try { + LoadFromARPA(f, counts, config); + } catch (FormatLoadException &e) { + e << " in file " << file; + throw; + } + } + + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + begin_sentence.valid_length_ = 1; + begin_sentence.history_[0] = vocab_.BeginSentence(); + begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff; + State null_context = State(); + null_context.valid_length_ = 0; + P::Init(begin_sentence, null_context, vocab_, counts.size()); +} + +template void GenericModel::LoadFromARPA(util::FilePiece &f, const std::vector &counts, const Config &config) { + // Read the unigrams. + Read1Grams(f, counts[0], vocab_, unigram_); + bool saw_unk = vocab_.FinishedLoading(unigram_); + if (!saw_unk) { + switch(config.unknown_missing) { + case Config::THROW_UP: + { + SpecialWordMissingException e(""); + e << " and configuration was set to throw if unknown is missing"; + throw e; + } + case Config::COMPLAIN: + if (config.messages) *config.messages << "Language model is missing . Substituting probability " << config.unknown_missing_prob << "." << std::endl; + // There's no break;. This is by design. + case Config::SILENT: + // Default probabilities for unknown. + unigram_[0].backoff = 0.0; + unigram_[0].prob = config.unknown_missing_prob; + break; + } + } + + // Read the n-grams. + for (unsigned int n = 2; n < counts.size(); ++n) { + ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]); + } + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_); + if (std::fabs(unigram_[0].backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << unigram_[0].backoff); +} + +/* Ugly optimized function. + * in_state contains the previous ngram's length and backoff probabilites to + * be used here. out_state is populated with the found ngram length and + * backoffs that the next call will find useful. + * + * The search goes in increasing order of ngram length. + */ +template FullScoreReturn GenericModel::FullScore( + const State &in_state, + const WordIndex new_word, + State &out_state) const { + + FullScoreReturn ret; + // This is end pointer passed to SumBackoffs. + const ProbBackoff &unigram = unigram_[new_word]; + if (new_word == 0) { + ret.ngram_length = out_state.valid_length_ = 0; + // all of backoff. + ret.prob = std::accumulate( + in_state.backoff_, + in_state.backoff_ + in_state.valid_length_, + unigram.prob); + return ret; + } + float *backoff_out(out_state.backoff_); + *backoff_out = unigram.backoff; + ret.prob = unigram.prob; + out_state.history_[0] = new_word; + if (in_state.valid_length_ == 0) { + ret.ngram_length = out_state.valid_length_ = 1; + // No backoff because NGramLength() == 0 and unknown can't have backoff. + return ret; + } + ++backoff_out; + + // Ok now we now that the bigram contains known words. Start by looking it up. + + uint64_t lookup_hash = static_cast(new_word); + const WordIndex *hist_iter = in_state.history_; + const WordIndex *const hist_end = hist_iter + in_state.valid_length_; + typename std::vector::const_iterator mid_iter = middle_.begin(); + for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { + if (hist_iter == hist_end) { + // Used history [in_state.history_, hist_end) and ran out. No backoff. + std::copy(in_state.history_, hist_end, out_state.history_ + 1); + ret.ngram_length = out_state.valid_length_ = in_state.valid_length_ + 1; + // ret.prob was already set. + return ret; + } + lookup_hash = CombineWordHash(lookup_hash, *hist_iter); + if (mid_iter == middle_.end()) break; + typename Middle::ConstIterator found; + if (!mid_iter->Find(lookup_hash, found)) { + // Didn't find an ngram using hist_iter. + // The history used in the found n-gram is [in_state.history_, hist_iter). + std::copy(in_state.history_, hist_iter, out_state.history_ + 1); + // Therefore, we found a (hist_iter - in_state.history_ + 1)-gram including the last word. + ret.ngram_length = out_state.valid_length_ = (hist_iter - in_state.history_) + 1; + ret.prob = std::accumulate( + in_state.backoff_ + (mid_iter - middle_.begin()), + in_state.backoff_ + in_state.valid_length_, + ret.prob); + return ret; + } + *backoff_out = found->GetValue().backoff; + ret.prob = found->GetValue().prob; + } + + typename Longest::ConstIterator found; + if (!longest_.Find(lookup_hash, found)) { + // It's an (P::Order()-1)-gram + std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1); + ret.ngram_length = out_state.valid_length_ = P::Order() - 1; + ret.prob += in_state.backoff_[P::Order() - 2]; + return ret; + } + // It's an P::Order()-gram + // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much. + std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1); + out_state.valid_length_ = P::Order() - 1; + ret.ngram_length = P::Order(); + ret.prob = found->GetValue().prob; + return ret; +} + +template class GenericModel >; +template class GenericModel; +} // namespace detail +} // namespace ngram +} // namespace lm diff --git a/klm/lm/ngram.hh b/klm/lm/ngram.hh new file mode 100644 index 00000000..899a80e8 --- /dev/null +++ b/klm/lm/ngram.hh @@ -0,0 +1,226 @@ +#ifndef LM_NGRAM__ +#define LM_NGRAM__ + +#include "lm/facade.hh" +#include "lm/ngram_config.hh" +#include "util/key_value_packing.hh" +#include "util/mmap.hh" +#include "util/probing_hash_table.hh" +#include "util/scoped.hh" +#include "util/sorted_uniform.hh" +#include "util/string_piece.hh" + +#include +#include +#include + +namespace util { class FilePiece; } + +namespace lm { +namespace ngram { + +// If you need higher order, change this and recompile. +// Having this limit means that State can be +// (kMaxOrder - 1) * sizeof(float) bytes instead of +// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead +const std::size_t kMaxOrder = 6; + +// This is a POD. +class State { + public: + bool operator==(const State &other) const { + if (valid_length_ != other.valid_length_) return false; + const WordIndex *end = history_ + valid_length_; + for (const WordIndex *first = history_, *second = other.history_; + first != end; ++first, ++second) { + if (*first != *second) return false; + } + // If the histories are equal, so are the backoffs. + return true; + } + + // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. + // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. + WordIndex history_[kMaxOrder - 1]; + float backoff_[kMaxOrder - 1]; + unsigned char valid_length_; +}; + +size_t hash_value(const State &state); + +namespace detail { + +uint64_t HashForVocab(const char *str, std::size_t len); +inline uint64_t HashForVocab(const StringPiece &str) { + return HashForVocab(str.data(), str.length()); +} + +struct Prob { + float prob; + void SetBackoff(float to); + void ZeroBackoff() {} +}; +// No inheritance so this will be a POD. +struct ProbBackoff { + float prob; + float backoff; + void SetBackoff(float to) { backoff = to; } + void ZeroBackoff() { backoff = 0.0; } +}; + +} // namespace detail + +// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. +class SortedVocabulary : public base::Vocabulary { + private: + // Sorted uniform requires a GetKey function. + struct Entry { + uint64_t GetKey() const { return key; } + uint64_t key; + bool operator<(const Entry &other) const { + return key < other.key; + } + }; + + public: + SortedVocabulary(); + + WordIndex Index(const StringPiece &str) const { + const Entry *found; + if (util::SortedUniformFind(begin_, end_, detail::HashForVocab(str), found)) { + return found - begin_ + 1; // +1 because is 0 and does not appear in the lookup table. + } else { + return 0; + } + } + + // Ignores second argument for consistency with probing hash which has a float here. + static size_t Size(std::size_t entries, float ignored = 0.0); + + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. + void Init(void *start, std::size_t allocated, std::size_t entries); + + WordIndex Insert(const StringPiece &str); + + // Returns true if unknown was seen. Reorders reorder_vocab so that the IDs are sorted. + bool FinishedLoading(detail::ProbBackoff *reorder_vocab); + + void LoadedBinary(); + + private: + Entry *begin_, *end_; + + bool saw_unk_; +}; + +namespace detail { + +// Vocabulary storing a map from uint64_t to WordIndex. +template class MapVocabulary : public base::Vocabulary { + public: + MapVocabulary(); + + WordIndex Index(const StringPiece &str) const { + typename Lookup::ConstIterator i; + return lookup_.Find(HashForVocab(str), i) ? i->GetValue() : 0; + } + + static size_t Size(std::size_t entries, float probing_multiplier) { + return Lookup::Size(entries, probing_multiplier); + } + + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. + void Init(void *start, std::size_t allocated, std::size_t entries); + + WordIndex Insert(const StringPiece &str); + + // Returns true if unknown was seen. Does nothing with reorder_vocab. + bool FinishedLoading(ProbBackoff *reorder_vocab); + + void LoadedBinary(); + + private: + typedef typename Search::template Table::T Lookup; + Lookup lookup_; + + bool saw_unk_; +}; + +// std::identity is an SGI extension :-( +struct IdentityHash : public std::unary_function { + size_t operator()(uint64_t arg) const { return static_cast(arg); } +}; + +// Should return the same results as SRI. +// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary. +template class GenericModel : public base::ModelFacade, State, VocabularyT> { + private: + typedef base::ModelFacade, State, VocabularyT> P; + public: + // Get the size of memory that will be mapped given ngram counts. This + // does not include small non-mapped control structures, such as this class + // itself. + static size_t Size(const std::vector &counts, const Config &config = Config()); + + GenericModel(const char *file, Config config = Config()); + + FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; + + private: + // Appears after Size in the cc. + void SetupMemory(char *start, const std::vector &counts, const Config &config); + + void LoadFromARPA(util::FilePiece &f, const std::vector &counts, const Config &config); + + util::scoped_fd mapped_file_; + + // memory_ is the raw block of memory backing vocab_, unigram_, [middle.begin(), middle.end()), and longest_. + util::scoped_mmap memory_; + + VocabularyT vocab_; + + ProbBackoff *unigram_; + + typedef typename Search::template Table::T Middle; + std::vector middle_; + + typedef typename Search::template Table::T Longest; + Longest longest_; +}; + +struct ProbingSearch { + typedef float Init; + + static const unsigned char kBinaryTag = 1; + + template struct Table { + typedef util::ByteAlignedPacking Packing; + typedef util::ProbingHashTable T; + }; +}; + +struct SortedUniformSearch { + // This is ignored. + typedef float Init; + + static const unsigned char kBinaryTag = 2; + + template struct Table { + typedef util::ByteAlignedPacking Packing; + typedef util::SortedUniformMap T; + }; +}; + +} // namespace detail + +// These must also be instantiated in the cc file. +typedef detail::MapVocabulary Vocabulary; +typedef detail::GenericModel Model; + +// SortedVocabulary was defined above. +typedef detail::GenericModel SortedModel; + +} // namespace ngram +} // namespace lm + +#endif // LM_NGRAM__ diff --git a/klm/lm/ngram_build_binary.cc b/klm/lm/ngram_build_binary.cc new file mode 100644 index 00000000..9dab30a1 --- /dev/null +++ b/klm/lm/ngram_build_binary.cc @@ -0,0 +1,13 @@ +#include "lm/ngram.hh" + +#include + +int main(int argc, char *argv[]) { + if (argc != 3) { + std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl; + return 1; + } + lm::ngram::Config config; + config.write_mmap = argv[2]; + lm::ngram::Model(argv[1], config); +} diff --git a/klm/lm/ngram_config.hh b/klm/lm/ngram_config.hh new file mode 100644 index 00000000..a7b3afae --- /dev/null +++ b/klm/lm/ngram_config.hh @@ -0,0 +1,58 @@ +#ifndef LM_NGRAM_CONFIG__ +#define LM_NGRAM_CONFIG__ + +/* Configuration for ngram model. Separate header to reduce pollution. */ + +#include + +namespace lm { namespace ngram { + +struct Config { + /* EFFECTIVE FOR BOTH ARPA AND BINARY READS */ + // Where to log messages including the progress bar. Set to NULL for + // silence. + std::ostream *messages; + + + + /* ONLY EFFECTIVE WHEN READING ARPA */ + + // What to do when isn't in the provided model. + typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing; + UnknownMissing unknown_missing; + + // The probability to substitute for if it's missing from the model. + // No effect if the model has or unknown_missing == THROW_UP. + float unknown_missing_prob; + + // Size multiplier for probing hash table. Must be > 1. Space is linear in + // this. Time is probing_multiplier / (probing_multiplier - 1). No effect + // for sorted variant. + // If you find yourself setting this to a low number, consider using the + // Sorted version instead which has lower memory consumption. + float probing_multiplier; + + // While loading an ARPA file, also write out this binary format file. Set + // to NULL to disable. + const char *write_mmap; + + + + /* ONLY EFFECTIVE WHEN READING BINARY */ + bool prefault; + + + + // Defaults. + Config() : + messages(&std::cerr), + unknown_missing(COMPLAIN), + unknown_missing_prob(0.0), + probing_multiplier(1.5), + write_mmap(NULL), + prefault(false) {} +}; + +} /* namespace ngram */ } /* namespace lm */ + +#endif // LM_NGRAM_CONFIG__ diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc new file mode 100644 index 00000000..d1970260 --- /dev/null +++ b/klm/lm/ngram_query.cc @@ -0,0 +1,72 @@ +#include "lm/ngram.hh" + +#include +#include +#include +#include + +#include +#include + +float FloatSec(const struct timeval &tv) { + return static_cast(tv.tv_sec) + (static_cast(tv.tv_usec) / 1000000000.0); +} + +void PrintUsage(const char *message) { + struct rusage usage; + if (getrusage(RUSAGE_SELF, &usage)) { + perror("getrusage"); + return; + } + std::cerr << message; + std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; + + // Linux doesn't set memory usage :-(. + std::ifstream status("/proc/self/status", std::ios::in); + std::string line; + while (getline(status, line)) { + if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { + std::cerr << "rss " << (line.c_str() + 7) << '\n'; + break; + } + } +} + +template void Query(const Model &model) { + PrintUsage("Loading statistics:\n"); + typename Model::State state, out; + lm::FullScoreReturn ret; + std::string word; + + while (std::cin) { + state = model.BeginSentenceState(); + float total = 0.0; + bool got = false; + while (std::cin >> word) { + got = true; + ret = model.FullScore(state, model.GetVocabulary().Index(word), out); + total += ret.prob; + std::cout << word << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << ' '; + state = out; + if (std::cin.get() == '\n') break; + } + if (!got && !std::cin) break; + ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); + total += ret.prob; + std::cout << " " << static_cast(ret.ngram_length) << ' ' << ret.prob << ' '; + std::cout << "Total: " << total << '\n'; + } + PrintUsage("After queries:\n"); +} + +int main(int argc, char *argv[]) { + if (argc < 2) { + std::cerr << "Pass language model name." << std::endl; + return 0; + } + { + lm::ngram::Model ngram(argv[1]); + Query(ngram); + } + PrintUsage("Total time including destruction:\n"); +} diff --git a/klm/lm/ngram_test.cc b/klm/lm/ngram_test.cc new file mode 100644 index 00000000..031e0348 --- /dev/null +++ b/klm/lm/ngram_test.cc @@ -0,0 +1,91 @@ +#include "lm/ngram.hh" + +#include + +#define BOOST_TEST_MODULE NGramTest +#include + +namespace lm { +namespace ngram { +namespace { + +#define StartTest(word, ngram, score) \ + ret = model.FullScore( \ + state, \ + model.GetVocabulary().Index(word), \ + out);\ + BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ + BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ + BOOST_CHECK_EQUAL(std::min(ngram, 5 - 1), out.valid_length_); + +#define AppendTest(word, ngram, score) \ + StartTest(word, ngram, score) \ + state = out; + +template void Starters(M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + StartTest("looking", 2, -0.4846522); + + // , probability plus backoff + StartTest(",", 1, -1.383514 + -0.4149733); + // probability plus backoff + StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); +} + +template void Continuation(M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + AppendTest("looking", 2, -0.484652); + AppendTest("on", 3, -0.348837); + AppendTest("a", 4, -0.0155266); + AppendTest("little", 5, -0.00306122); + State preserve = state; + AppendTest("the", 1, -4.04005); + AppendTest("biarritz", 1, -1.9889); + AppendTest("not_found", 0, -2.29666); + AppendTest("more", 1, -1.20632); + AppendTest(".", 2, -0.51363); + AppendTest("", 3, -0.0191651); + + state = preserve; + AppendTest("more", 5, -0.00181395); + AppendTest("loin", 5, -0.0432557); +} + +BOOST_AUTO_TEST_CASE(starters_probing) { Model m("test.arpa"); Starters(m); } +BOOST_AUTO_TEST_CASE(continuation_probing) { Model m("test.arpa"); Continuation(m); } +BOOST_AUTO_TEST_CASE(starters_sorted) { SortedModel m("test.arpa"); Starters(m); } +BOOST_AUTO_TEST_CASE(continuation_sorted) { SortedModel m("test.arpa"); Continuation(m); } + +BOOST_AUTO_TEST_CASE(write_and_read_probing) { + Config config; + config.write_mmap = "test.binary"; + { + Model copy_model("test.arpa", config); + } + Model binary("test.binary"); + Starters(binary); + Continuation(binary); +} + +BOOST_AUTO_TEST_CASE(write_and_read_sorted) { + Config config; + config.write_mmap = "test.binary"; + config.prefault = true; + { + SortedModel copy_model("test.arpa", config); + } + SortedModel binary("test.binary"); + Starters(binary); + Continuation(binary); +} + + +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc new file mode 100644 index 00000000..7bd23d76 --- /dev/null +++ b/klm/lm/sri.cc @@ -0,0 +1,115 @@ +#include "lm/exception.hh" +#include "lm/sri.hh" + +#include +#include + +#include + +namespace lm { +namespace sri { + +Vocabulary::Vocabulary() : sri_(new Vocab) {} + +Vocabulary::~Vocabulary() {} + +WordIndex Vocabulary::Index(const char *str) const { + WordIndex ret = sri_->getIndex(str); + // NGram wants the index of Vocab_Unknown for unknown words, but for some reason SRI returns Vocab_None here :-(. + if (ret == Vocab_None) { + return not_found_; + } else { + return ret; + } +} + +const char *Vocabulary::Word(WordIndex index) const { + return sri_->getWord(index); +} + +void Vocabulary::FinishedLoading() { + SetSpecial( + sri_->ssIndex(), + sri_->seIndex(), + sri_->unkIndex(), + sri_->highIndex() + 1); +} + +namespace { +Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) { + sri_vocab.unkIsWord() = true; + std::auto_ptr ret(new Ngram(sri_vocab, ngram_length)); + File file(file_name, "r"); + errno = 0; + if (!ret->read(file)) { + UTIL_THROW(FormatLoadException, "reading file " << file_name << " with SRI failed."); + } + return ret.release(); +} +} // namespace + +Model::Model(const char *file_name, unsigned int ngram_length) : sri_(MakeSRIModel(file_name, ngram_length, *vocab_.sri_)) { + if (!sri_->setorder()) { + UTIL_THROW(FormatLoadException, "Can't have an SRI model with order 0."); + } + vocab_.FinishedLoading(); + State begin_state = State(); + begin_state.valid_length_ = 1; + if (kMaxOrder > 1) { + begin_state.history_[0] = vocab_.BeginSentence(); + if (kMaxOrder > 2) begin_state.history_[1] = Vocab_None; + } + State null_state = State(); + null_state.valid_length_ = 0; + if (kMaxOrder > 1) null_state.history_[0] = Vocab_None; + Init(begin_state, null_state, vocab_, sri_->setorder()); + not_found_ = vocab_.NotFound(); +} + +Model::~Model() {} + +namespace { + +/* Argh SRI's wordProb knows the ngram length but doesn't return it. One more + * reason you should use my model. */ +// TODO(stolcke): fix SRILM so I don't have to do this. +unsigned int MatchedLength(Ngram &model, const WordIndex new_word, const SRIVocabIndex *const_history) { + unsigned int out_length = 0; + // This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0. + model.contextID(new_word, const_history, out_length); + return out_length + 1; +} + +} // namespace + +FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { + // If you get a compiler in this function, change SRIVocabIndex in sri.hh to match the one found in SRI's Vocab.h. + const SRIVocabIndex *const_history; + SRIVocabIndex local_history[Order()]; + if (in_state.valid_length_ < kMaxOrder - 1) { + const_history = in_state.history_; + } else { + std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, local_history); + local_history[in_state.valid_length_] = Vocab_None; + const_history = local_history; + } + FullScoreReturn ret; + if (new_word != not_found_) { + ret.ngram_length = MatchedLength(*sri_, new_word, const_history); + out_state.history_[0] = new_word; + out_state.valid_length_ = std::min(ret.ngram_length, Order() - 1); + std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1); + if (out_state.valid_length_ < kMaxOrder - 1) { + out_state.history_[out_state.valid_length_] = Vocab_None; + } + } else { + ret.ngram_length = 0; + if (kMaxOrder > 1) out_state.history_[0] = Vocab_None; + out_state.valid_length_ = 0; + } + ret.prob = sri_->wordProb(new_word, const_history); + return ret; +} + +} // namespace sri +} // namespace lm diff --git a/klm/lm/sri.hh b/klm/lm/sri.hh new file mode 100644 index 00000000..b57e9b73 --- /dev/null +++ b/klm/lm/sri.hh @@ -0,0 +1,102 @@ +#ifndef LM_SRI__ +#define LM_SRI__ + +#include "lm/facade.hh" +#include "util/murmur_hash.hh" + +#include +#include +#include + +class Ngram; +class Vocab; + +/* The ngram length reported uses some random API I found and may be wrong. + * + * See ngram, which should return equivalent results. + */ + +namespace lm { +namespace sri { + +static const unsigned int kMaxOrder = 6; + +/* This should match VocabIndex found in SRI's Vocab.h + * The reason I define this here independently is that SRI's headers + * pollute and increase compile time. + * It's difficult to extract this from their header and anyway would + * break packaging. + * If these differ there will be a compiler error in ActuallyCall. + */ +typedef unsigned int SRIVocabIndex; + +class State { + public: + // You shouldn't need to touch these, but they're public so State will be a POD. + // If valid_length_ < kMaxOrder - 1 then history_[valid_length_] == Vocab_None. + SRIVocabIndex history_[kMaxOrder - 1]; + unsigned char valid_length_; +}; + +inline bool operator==(const State &left, const State &right) { + if (left.valid_length_ != right.valid_length_) { + return false; + } + for (const SRIVocabIndex *l = left.history_, *r = right.history_; + l != left.history_ + left.valid_length_; + ++l, ++r) { + if (*l != *r) return false; + } + return true; +} + +inline size_t hash_value(const State &state) { + return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_); +} + +class Vocabulary : public base::Vocabulary { + public: + Vocabulary(); + + ~Vocabulary(); + + WordIndex Index(const StringPiece &str) const { + std::string temp(str.data(), str.length()); + return Index(temp.c_str()); + } + WordIndex Index(const std::string &str) const { + return Index(str.c_str()); + } + WordIndex Index(const char *str) const; + + const char *Word(WordIndex index) const; + + private: + friend class Model; + void FinishedLoading(); + + // The parent class isn't copyable so auto_ptr is the same as scoped_ptr + // but without the boost dependence. + mutable std::auto_ptr sri_; +}; + +class Model : public base::ModelFacade { + public: + Model(const char *file_name, unsigned int ngram_length); + + ~Model(); + + FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; + + private: + Vocabulary vocab_; + + mutable std::auto_ptr sri_; + + WordIndex not_found_; +}; + +} // namespace sri +} // namespace lm + +#endif // LM_SRI__ diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc new file mode 100644 index 00000000..e697d722 --- /dev/null +++ b/klm/lm/sri_test.cc @@ -0,0 +1,65 @@ +#include "lm/sri.hh" + +#include + +#define BOOST_TEST_MODULE SRITest +#include + +namespace lm { +namespace sri { +namespace { + +#define StartTest(word, ngram, score) \ + ret = model.FullScore( \ + state, \ + model.GetVocabulary().Index(word), \ + out);\ + BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ + BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ + BOOST_CHECK_EQUAL(std::min(ngram, 5 - 1), out.valid_length_); + +#define AppendTest(word, ngram, score) \ + StartTest(word, ngram, score) \ + state = out; + +template void Starters(M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + StartTest("looking", 2, -0.4846522); + + // , probability plus backoff + StartTest(",", 1, -1.383514 + -0.4149733); + // probability plus backoff + StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); +} + +template void Continuation(M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + AppendTest("looking", 2, -0.484652); + AppendTest("on", 3, -0.348837); + AppendTest("a", 4, -0.0155266); + AppendTest("little", 5, -0.00306122); + State preserve = state; + AppendTest("the", 1, -4.04005); + AppendTest("biarritz", 1, -1.9889); + AppendTest("not_found", 0, -2.29666); + AppendTest("more", 1, -1.20632); + AppendTest(".", 2, -0.51363); + AppendTest("", 3, -0.0191651); + + state = preserve; + AppendTest("more", 5, -0.00181395); + AppendTest("loin", 5, -0.0432557); +} + +BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); } +BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); } + +} // namespace +} // namespace sri +} // namespace lm diff --git a/klm/lm/test.arpa b/klm/lm/test.arpa new file mode 100644 index 00000000..9d674e83 --- /dev/null +++ b/klm/lm/test.arpa @@ -0,0 +1,112 @@ + +\data\ +ngram 1=34 +ngram 2=43 +ngram 3=8 +ngram 4=5 +ngram 5=3 + +\1-grams: +-1.383514 , -0.30103 +-1.139057 . -0.845098 +-1.029493 +-99 -0.4149733 +-1.995635 +-1.285941 a -0.69897 +-1.687872 also -0.30103 +-1.687872 beyond -0.30103 +-1.687872 biarritz -0.30103 +-1.687872 call -0.30103 +-1.687872 concerns -0.30103 +-1.687872 consider -0.30103 +-1.687872 considering -0.30103 +-1.687872 for -0.30103 +-1.509559 higher -0.30103 +-1.687872 however -0.30103 +-1.687872 i -0.30103 +-1.687872 immediate -0.30103 +-1.687872 in -0.30103 +-1.687872 is -0.30103 +-1.285941 little -0.69897 +-1.383514 loin -0.30103 +-1.687872 look -0.30103 +-1.285941 looking -0.4771212 +-1.206319 more -0.544068 +-1.509559 on -0.4771212 +-1.509559 screening -0.4771212 +-1.687872 small -0.30103 +-1.687872 the -0.30103 +-1.687872 to -0.30103 +-1.687872 watch -0.30103 +-1.687872 watching -0.30103 +-1.687872 what -0.30103 +-1.687872 would -0.30103 + +\2-grams: +-0.6925742 , . +-0.7522095 , however +-0.7522095 , is +-0.0602359 . +-0.4846522 looking -0.4771214 +-1.051485 screening +-1.07153 the +-1.07153 watching +-1.07153 what +-0.09132547 a little -0.69897 +-0.2922095 also call +-0.2922095 beyond immediate +-0.2705918 biarritz . +-0.2922095 call for +-0.2922095 concerns in +-0.2922095 consider watch +-0.2922095 considering consider +-0.2834328 for , +-0.5511513 higher more +-0.5845945 higher small +-0.2834328 however , +-0.2922095 i would +-0.2922095 immediate concerns +-0.2922095 in biarritz +-0.2922095 is to +-0.09021038 little more -0.1998621 +-0.7273645 loin , +-0.6925742 loin . +-0.6708385 loin +-0.2922095 look beyond +-0.4638903 looking higher +-0.4638903 looking on -0.4771212 +-0.5136299 more . -0.4771212 +-0.3561665 more loin +-0.1649931 on a -0.4771213 +-0.1649931 screening a -0.4771213 +-0.2705918 small . +-0.287799 the screening +-0.2922095 to look +-0.2622373 watch +-0.2922095 watching considering +-0.2922095 what i +-0.2922095 would also + +\3-grams: +-0.01916512 more . +-0.0283603 on a little -0.4771212 +-0.0283603 screening a little -0.4771212 +-0.01660496 a little more -0.09409451 +-0.3488368 looking higher +-0.3488368 looking on -0.4771212 +-0.1892331 little more loin +-0.04835128 looking on a -0.4771212 + +\4-grams: +-0.009249173 looking on a little -0.4771212 +-0.005464747 on a little more -0.4771212 +-0.005464747 screening a little more +-0.1453306 a little more loin +-0.01552657 looking on a -0.4771212 + +\5-grams: +-0.003061223 looking on a little +-0.001813953 looking on a little more +-0.0432557 on a little more loin + +\end\ diff --git a/klm/lm/test.binary b/klm/lm/test.binary new file mode 100644 index 00000000..90bd2b76 Binary files /dev/null and b/klm/lm/test.binary differ diff --git a/klm/lm/virtual_interface.cc b/klm/lm/virtual_interface.cc new file mode 100644 index 00000000..9c7151f9 --- /dev/null +++ b/klm/lm/virtual_interface.cc @@ -0,0 +1,22 @@ +#include "lm/virtual_interface.hh" + +#include "lm/exception.hh" + +namespace lm { +namespace base { + +Vocabulary::~Vocabulary() {} + +void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) { + begin_sentence_ = begin_sentence; + end_sentence_ = end_sentence; + not_found_ = not_found; + available_ = available; + if (begin_sentence_ == not_found_) throw SpecialWordMissingException(""); + if (end_sentence_ == not_found_) throw SpecialWordMissingException(""); +} + +Model::~Model() {} + +} // namespace base +} // namespace lm diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh new file mode 100644 index 00000000..621a129e --- /dev/null +++ b/klm/lm/virtual_interface.hh @@ -0,0 +1,156 @@ +#ifndef LM_VIRTUAL_INTERFACE__ +#define LM_VIRTUAL_INTERFACE__ + +#include "lm/word_index.hh" +#include "util/string_piece.hh" + +#include + +namespace lm { + +struct FullScoreReturn { + float prob; + unsigned char ngram_length; +}; + +namespace base { + +template class ModelFacade; + +/* Vocabulary interface. Call Index(string) and get a word index for use in + * calling Model. It provides faster convenience functions for , , and + * although you can also find these using Index. + * + * Some models do not load the mapping from index to string. If you need this, + * check if the model Vocabulary class implements such a function and access it + * directly. + * + * The Vocabulary object is always owned by the Model and can be retrieved from + * the Model using BaseVocabulary() for this abstract interface or + * GetVocabulary() for the actual implementation (in which case you'll need the + * actual implementation of the Model too). + */ +class Vocabulary { + public: + virtual ~Vocabulary(); + + WordIndex BeginSentence() const { return begin_sentence_; } + WordIndex EndSentence() const { return end_sentence_; } + WordIndex NotFound() const { return not_found_; } + // FullScoreReturn start index of unused word assignments. + WordIndex Available() const { return available_; } + + /* Most implementations allow StringPiece lookups and need only override + * Index(StringPiece). SRI requires null termination and overrides all + * three methods. + */ + virtual WordIndex Index(const StringPiece &str) const = 0; + virtual WordIndex Index(const std::string &str) const { + return Index(StringPiece(str)); + } + virtual WordIndex Index(const char *str) const { + return Index(StringPiece(str)); + } + + protected: + // Call SetSpecial afterward. + Vocabulary() {} + + Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) { + SetSpecial(begin_sentence, end_sentence, not_found, available); + } + + void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available); + + WordIndex begin_sentence_, end_sentence_, not_found_, available_; + + private: + // Disable copy constructors. They're private and undefined. + // Ersatz boost::noncopyable. + Vocabulary(const Vocabulary &); + Vocabulary &operator=(const Vocabulary &); +}; + +/* There are two ways to access a Model. + * + * + * OPTION 1: Access the Model directly (e.g. lm::ngram::Model in ngram.hh). + * Every Model implements the scoring function: + * float Score( + * const Model::State &in_state, + * const WordIndex new_word, + * Model::State &out_state) const; + * + * It can also return the length of n-gram matched by the model: + * FullScoreReturn FullScore( + * const Model::State &in_state, + * const WordIndex new_word, + * Model::State &out_state) const; + * + * There are also accessor functions: + * const State &BeginSentenceState() const; + * const State &NullContextState() const; + * const Vocabulary &GetVocabulary() const; + * unsigned int Order() const; + * + * NB: In case you're wondering why the model implementation looks like it's + * missing these methods, see facade.hh. + * + * This is the fastest way to use a model and presents a normal State class to + * be included in hypothesis state structure. + * + * + * OPTION 2: Use the virtual interface below. + * + * The virtual interface allow you to decide which Model to use at runtime + * without templatizing everything on the Model type. However, each Model has + * its own State class, so a single State cannot be efficiently provided (it + * would require using the maximum memory of any Model's State or memory + * allocation with each lookup). This means you become responsible for + * allocating memory with size StateSize() and passing it to the Score or + * FullScore functions provided here. + * + * For example, cdec has a std::string containing the entire state of a + * hypothesis. It can reserve StateSize bytes in this string for the model + * state. + * + * All the State objects are POD, so it's ok to use raw memory for storing + * State. + */ +class Model { + public: + virtual ~Model(); + + size_t StateSize() const { return state_size_; } + const void *BeginSentenceMemory() const { return begin_sentence_memory_; } + const void *NullContextMemory() const { return null_context_memory_; } + + virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + + virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + + unsigned char Order() const { return order_; } + + const Vocabulary &BaseVocabulary() const { return *base_vocab_; } + + private: + template friend class ModelFacade; + explicit Model(size_t state_size) : state_size_(state_size) {} + + const size_t state_size_; + const void *begin_sentence_memory_, *null_context_memory_; + + const Vocabulary *base_vocab_; + + unsigned char order_; + + // Disable copy constructors. They're private and undefined. + // Ersatz boost::noncopyable. + Model(const Model &); + Model &operator=(const Model &); +}; + +} // mamespace base +} // namespace lm + +#endif // LM_VIRTUAL_INTERFACE__ diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh new file mode 100644 index 00000000..67841c30 --- /dev/null +++ b/klm/lm/word_index.hh @@ -0,0 +1,11 @@ +// Separate header because this is used often. +#ifndef LM_WORD_INDEX__ +#define LM_WORD_INDEX__ + +namespace lm { +typedef unsigned int WordIndex; +} // namespace lm + +typedef lm::WordIndex LMWordIndex; + +#endif diff --git a/klm/test.sh b/klm/test.sh new file mode 100755 index 00000000..b741cdf0 --- /dev/null +++ b/klm/test.sh @@ -0,0 +1,8 @@ +#!/bin/bash +#Run tests. Requires Boost. +set -e +./compile.sh +for i in util/{file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/ngram_test; do + g++ -I. -O3 $i.cc {lm,util}/*.o -lboost_test_exec_monitor -o $i + pushd $(dirname $i) && ./$(basename $i); popd +done diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am new file mode 100644 index 00000000..d3aea6b7 --- /dev/null +++ b/klm/util/Makefile.am @@ -0,0 +1,18 @@ +if HAVE_GTEST +noinst_PROGRAMS = \ + scorer_test +TESTS = scorer_test +endif + +noinst_LIBRARIES = libklm_util.a + +libklm_util_a_SOURCES = \ + ersatz_progress.cc \ + exception.cc \ + file_piece.cc \ + mmap.cc \ + murmur_hash.cc \ + scoped.cc \ + string_piece.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc new file mode 100644 index 00000000..09e3a106 --- /dev/null +++ b/klm/util/ersatz_progress.cc @@ -0,0 +1,47 @@ +#include "util/ersatz_progress.hh" + +#include +#include +#include +#include + +namespace util { + +namespace { const unsigned char kWidth = 100; } + +ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits::max()), complete_(next_), out_(NULL) {} + +ErsatzProgress::~ErsatzProgress() { + if (!out_) return; + for (; stones_written_ < kWidth; ++stones_written_) { + (*out_) << '*'; + } + *out_ << '\n'; +} + +ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete) + : current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) { + if (!out_) { + next_ = std::numeric_limits::max(); + return; + } + *out_ << message << "\n----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"; +} + +void ErsatzProgress::Milestone() { + if (!out_) { current_ = 0; return; } + if (!complete_) return; + unsigned char stone = std::min(static_cast(kWidth), (current_ * kWidth) / complete_); + + for (; stones_written_ < stone; ++stones_written_) { + (*out_) << '*'; + } + + if (current_ >= complete_) { + next_ = std::numeric_limits::max(); + } else { + next_ = std::max(next_, (stone * complete_) / kWidth); + } +} + +} // namespace util diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh new file mode 100644 index 00000000..ea6c3bb9 --- /dev/null +++ b/klm/util/ersatz_progress.hh @@ -0,0 +1,50 @@ +#ifndef UTIL_ERSATZ_PROGRESS__ +#define UTIL_ERSATZ_PROGRESS__ + +#include +#include + +// Ersatz version of boost::progress so core language model doesn't depend on +// boost. Also adds option to print nothing. + +namespace util { +class ErsatzProgress { + public: + // No output. + ErsatzProgress(); + + // Null means no output. The null value is useful for passing along the ostream pointer from another caller. + ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete); + + ~ErsatzProgress(); + + ErsatzProgress &operator++() { + if (++current_ == next_) Milestone(); + return *this; + } + + ErsatzProgress &operator+=(std::size_t amount) { + if ((current_ += amount) >= next_) Milestone(); + return *this; + } + + void Set(std::size_t to) { + if ((current_ = to) >= next_) Milestone(); + Milestone(); + } + + private: + void Milestone(); + + std::size_t current_, next_, complete_; + unsigned char stones_written_; + std::ostream *out_; + + // noncopyable + ErsatzProgress(const ErsatzProgress &other); + ErsatzProgress &operator=(const ErsatzProgress &other); +}; + +} // namespace util + +#endif // UTIL_ERSATZ_PROGRESS__ diff --git a/klm/util/exception.cc b/klm/util/exception.cc new file mode 100644 index 00000000..dd337a76 --- /dev/null +++ b/klm/util/exception.cc @@ -0,0 +1,35 @@ +#include "util/exception.hh" + +#include +#include + +namespace util { + +Exception::Exception() throw() {} +Exception::~Exception() throw() {} + +namespace { +// The XOPEN version. +const char *HandleStrerror(int ret, const char *buf) { + if (!ret) return buf; + return NULL; +} + +// The GNU version. +const char *HandleStrerror(const char *ret, const char *buf) { + return ret; +} +} // namespace + +ErrnoException::ErrnoException() throw() : errno_(errno) { + char buf[200]; + buf[0] = 0; + const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf); + if (add) { + *this << add << ' '; + } +} + +ErrnoException::~ErrnoException() throw() {} + +} // namespace util diff --git a/klm/util/exception.hh b/klm/util/exception.hh new file mode 100644 index 00000000..124689cf --- /dev/null +++ b/klm/util/exception.hh @@ -0,0 +1,72 @@ +#ifndef UTIL_EXCEPTION__ +#define UTIL_EXCEPTION__ + +#include "util/string_piece.hh" + +#include +#include +#include + +namespace util { + +class Exception : public std::exception { + public: + Exception() throw(); + virtual ~Exception() throw(); + + const char *what() const throw() { return what_.c_str(); } + + // This helps restrict operator<< defined below. + template struct ExceptionTag { + typedef T Identity; + }; + + std::string &Str() { + return what_; + } + + protected: + std::string what_; +}; + +/* This implements the normal operator<< for Exception and all its children. + * SNIFAE means it only applies to Exception. Think of this as an ersatz + * boost::enable_if. + */ +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const Data &data) { + // Argh I had a stringstream in the exception, but the only way to get the string is by calling str(). But that's a temporary string, so virtual const char *what() const can't actually return it. + std::stringstream stream; + stream << data; + e.Str() += stream.str(); + return e; +} +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const char *data) { + e.Str() += data; + return e; +} +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const std::string &data) { + e.Str() += data; + return e; +} +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const StringPiece &str) { + e.Str().append(str.data(), str.length()); + return e; +} + +#define UTIL_THROW(Exception, Modify) { Exception UTIL_e; {UTIL_e << Modify;} throw UTIL_e; } + +class ErrnoException : public Exception { + public: + ErrnoException() throw(); + + virtual ~ErrnoException() throw(); + + int Error() { return errno_; } + + private: + int errno_; +}; + +} // namespace util + +#endif // UTIL_EXCEPTION__ diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc new file mode 100644 index 00000000..2b439499 --- /dev/null +++ b/klm/util/file_piece.cc @@ -0,0 +1,224 @@ +#include "util/file_piece.hh" + +#include "util/exception.hh" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace util { + +EndOfFileException::EndOfFileException() throw() { + *this << "End of file"; +} +EndOfFileException::~EndOfFileException() throw() {} + +ParseNumberException::ParseNumberException(StringPiece value) throw() { + *this << "Could not parse \"" << value << "\" into a float"; +} + +int OpenReadOrThrow(const char *name) { + int ret = open(name, O_RDONLY); + if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading"); + return ret; +} + +off_t SizeFile(int fd) { + struct stat sb; + if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; + return sb.st_size; +} + +FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) : + file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), + progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { + Initialize(name, show_progress, min_buffer); +} + +FilePiece::FilePiece(const char *name, int fd, std::ostream *show_progress, off_t min_buffer) : + file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), + progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { + Initialize(name, show_progress, min_buffer); +} + +void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) { + if (total_size_ == kBadSize) { + fallback_to_read_ = true; + if (show_progress) + *show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl; + } else { + fallback_to_read_ = false; + } + default_map_size_ = page_ * std::max((min_buffer / page_ + 1), 2); + position_ = NULL; + position_end_ = NULL; + mapped_offset_ = 0; + at_end_ = false; + Shift(); +} + +float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) { + SkipSpaces(); + while (last_space_ < position_) { + if (at_end_) { + // Hallucinate a null off the end of the file. + std::string buffer(position_, position_end_); + char *end; + float ret = std::strtof(buffer.c_str(), &end); + if (buffer.c_str() == end) throw ParseNumberException(buffer); + position_ += end - buffer.c_str(); + return ret; + } + Shift(); + } + char *end; + float ret = std::strtof(position_, &end); + if (end == position_) throw ParseNumberException(ReadDelimited()); + position_ = end; + return ret; +} + +void FilePiece::SkipSpaces() throw (EndOfFileException) { + for (; ; ++position_) { + if (position_ == position_end_) Shift(); + if (!isspace(*position_)) return; + } +} + +const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) { + for (const char *i = position_; i <= last_space_; ++i) { + if (isspace(*i)) return i; + } + while (!at_end_) { + size_t skip = position_end_ - position_; + Shift(); + for (const char *i = position_ + skip; i <= last_space_; ++i) { + if (isspace(*i)) return i; + } + } + return position_end_; +} + +StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) { + const char *start = position_; + do { + for (const char *i = start; i < position_end_; ++i) { + if (*i == delim) { + StringPiece ret(position_, i - position_); + position_ = i + 1; + return ret; + } + } + size_t skip = position_end_ - position_; + Shift(); + start = position_ + skip; + } while (!at_end_); + StringPiece ret(position_, position_end_ - position_); + position_ = position_end_; + return position_; +} + +void FilePiece::Shift() throw(EndOfFileException) { + if (at_end_) throw EndOfFileException(); + off_t desired_begin = position_ - data_.begin() + mapped_offset_; + progress_.Set(desired_begin); + + if (!fallback_to_read_) MMapShift(desired_begin); + // Notice an mmap failure might set the fallback. + if (fallback_to_read_) ReadShift(desired_begin); + + for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) { + if (isspace(*last_space_)) break; + } +} + +void FilePiece::MMapShift(off_t desired_begin) throw() { + // Use mmap. + off_t ignore = desired_begin % page_; + // Duplicate request for Shift means give more data. + if (position_ == data_.begin() + ignore) { + default_map_size_ *= 2; + } + // Local version so that in case of failure it doesn't overwrite the class variable. + off_t mapped_offset = desired_begin - ignore; + + off_t mapped_size; + if (default_map_size_ >= static_cast(total_size_ - mapped_offset)) { + at_end_ = true; + mapped_size = total_size_ - mapped_offset; + } else { + mapped_size = default_map_size_; + } + + // Forcibly clear the existing mmap first. + data_.reset(); + data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED); + if (data_.get() == MAP_FAILED) { + fallback_to_read_ = true; + if (desired_begin) { + if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either."); + } + return; + } + mapped_offset_ = mapped_offset; + position_ = data_.begin() + ignore; + position_end_ = data_.begin() + mapped_size; +} + +void FilePiece::ReadShift(off_t desired_begin) throw() { + assert(fallback_to_read_); + if (data_.source() != scoped_memory::MALLOC_ALLOCATED) { + // First call. + data_.reset(); + data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); + if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); + position_ = data_.begin(); + position_end_ = position_; + } + + // Bytes [data_.begin(), position_) have been consumed. + // Bytes [position_, position_end_) have been read into the buffer. + + // Start at the beginning of the buffer if there's nothing useful in it. + if (position_ == position_end_) { + mapped_offset_ += (position_end_ - data_.begin()); + position_ = data_.begin(); + position_end_ = position_; + } + + std::size_t already_read = position_end_ - data_.begin(); + + if (already_read == default_map_size_) { + if (position_ == data_.begin()) { + // Buffer too small. + std::size_t valid_length = position_end_ - position_; + default_map_size_ *= 2; + data_.call_realloc(default_map_size_); + if (!data_.get()) UTIL_THROW(ErrnoException, "realloc failed for " << default_map_size_); + position_ = data_.begin(); + position_end_ = position_ + valid_length; + } else { + size_t moving = position_end_ - position_; + memmove(data_.get(), position_, moving); + position_ = data_.begin(); + position_end_ = position_ + moving; + already_read = moving; + } + } + + ssize_t read_return = read(file_.get(), static_cast(data_.get()) + already_read, default_map_size_ - already_read); + if (read_return == -1) UTIL_THROW(ErrnoException, "read failed"); + if (read_return == 0) at_end_ = true; + position_end_ += read_return; +} + +} // namespace util diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh new file mode 100644 index 00000000..704f0ac6 --- /dev/null +++ b/klm/util/file_piece.hh @@ -0,0 +1,105 @@ +#ifndef UTIL_FILE_PIECE__ +#define UTIL_FILE_PIECE__ + +#include "util/ersatz_progress.hh" +#include "util/exception.hh" +#include "util/mmap.hh" +#include "util/scoped.hh" +#include "util/string_piece.hh" + +#include + +#include + +namespace util { + +class EndOfFileException : public Exception { + public: + EndOfFileException() throw(); + ~EndOfFileException() throw(); +}; + +class ParseNumberException : public Exception { + public: + explicit ParseNumberException(StringPiece value) throw(); + ~ParseNumberException() throw() {} +}; + +int OpenReadOrThrow(const char *name); + +// Return value for SizeFile when it can't size properly. +const off_t kBadSize = -1; +off_t SizeFile(int fd); + +class FilePiece { + public: + // 32 MB default. + explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); + // Takes ownership of fd. name is used for messages. + explicit FilePiece(const char *name, int fd, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); + + char get() throw(EndOfFileException) { + if (position_ == position_end_) Shift(); + return *(position_++); + } + + // Memory backing the returned StringPiece may vanish on the next call. + // Leaves the delimiter, if any, to be returned by get(). + StringPiece ReadDelimited() throw(EndOfFileException) { + SkipSpaces(); + return Consume(FindDelimiterOrEOF()); + } + // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. + // It is similar to getline in that way. + StringPiece ReadLine(char delim = '\n') throw(EndOfFileException); + + float ReadFloat() throw(EndOfFileException, ParseNumberException); + + void SkipSpaces() throw (EndOfFileException); + + off_t Offset() const { + return position_ - data_.begin() + mapped_offset_; + } + + // Only for testing. + void ForceFallbackToRead() { + fallback_to_read_ = true; + } + + private: + void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer); + + StringPiece Consume(const char *to) { + StringPiece ret(position_, to - position_); + position_ = to; + return ret; + } + + const char *FindDelimiterOrEOF() throw(EndOfFileException); + + void Shift() throw (EndOfFileException); + // Backends to Shift(). + void MMapShift(off_t desired_begin) throw (); + void ReadShift(off_t desired_begin) throw (); + + const char *position_, *last_space_, *position_end_; + + scoped_fd file_; + const off_t total_size_; + const off_t page_; + + size_t default_map_size_; + off_t mapped_offset_; + + // Order matters: file_ should always be destroyed after this. + scoped_memory data_; + + bool at_end_; + bool fallback_to_read_; + + ErsatzProgress progress_; +}; + +} // namespace util + +#endif // UTIL_FILE_PIECE__ diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc new file mode 100644 index 00000000..befb7866 --- /dev/null +++ b/klm/util/file_piece_test.cc @@ -0,0 +1,41 @@ +#include "util/file_piece.hh" + +#define BOOST_TEST_MODULE FilePieceTest +#include +#include +#include + +namespace util { +namespace { + +/* mmap implementation */ +BOOST_AUTO_TEST_CASE(MMapLine) { + std::fstream ref("file_piece.cc", std::ios::in); + FilePiece test("file_piece.cc", NULL, 1); + std::string ref_line; + while (getline(ref, ref_line)) { + StringPiece test_line(test.ReadLine()); + // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924 + if (!test_line.empty() || !ref_line.empty()) { + BOOST_CHECK_EQUAL(ref_line, test_line); + } + } +} + +/* read() implementation */ +BOOST_AUTO_TEST_CASE(ReadLine) { + std::fstream ref("file_piece.cc", std::ios::in); + FilePiece test("file_piece.cc", NULL, 1); + test.ForceFallbackToRead(); + std::string ref_line; + while (getline(ref, ref_line)) { + StringPiece test_line(test.ReadLine()); + // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924 + if (!test_line.empty() || !ref_line.empty()) { + BOOST_CHECK_EQUAL(ref_line, test_line); + } + } +} + +} // namespace +} // namespace util diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh new file mode 100644 index 00000000..a2f1c01d --- /dev/null +++ b/klm/util/joint_sort.hh @@ -0,0 +1,145 @@ +#ifndef UTIL_JOINT_SORT__ +#define UTIL_JOINT_SORT__ + +/* A terrifying amount of C++ to coax std::sort into soring one range while + * also permuting another range the same way. + */ + +#include "util/proxy_iterator.hh" + +#include +#include +#include + +namespace util { + +namespace detail { + +template class JointProxy; + +template class JointIter { + public: + JointIter() {} + + JointIter(const KeyIter &key_iter, const ValueIter &value_iter) : key_(key_iter), value_(value_iter) {} + + bool operator==(const JointIter &other) const { return key_ == other.key_; } + + bool operator<(const JointIter &other) const { return (key_ < other.key_); } + + std::ptrdiff_t operator-(const JointIter &other) const { return key_ - other.key_; } + + JointIter &operator+=(std::ptrdiff_t amount) { + key_ += amount; + value_ += amount; + return *this; + } + + void swap(const JointIter &other) { + std::swap(key_, other.key_); + std::swap(value_, other.value_); + } + + private: + friend class JointProxy; + KeyIter key_; + ValueIter value_; +}; + +template class JointProxy { + private: + typedef JointIter InnerIterator; + + public: + typedef struct { + typename std::iterator_traits::value_type key; + typename std::iterator_traits::value_type value; + const typename std::iterator_traits::value_type &GetKey() const { return key; } + } value_type; + + JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {} + JointProxy(const JointProxy &other) : inner_(other.inner_) {} + + operator const value_type() const { + value_type ret; + ret.key = *inner_.key_; + ret.value = *inner_.value_; + return ret; + } + + JointProxy &operator=(const JointProxy &other) { + *inner_.key_ = *other.inner_.key_; + *inner_.value_ = *other.inner_.value_; + return *this; + } + + JointProxy &operator=(const value_type &other) { + *inner_.key_ = other.key; + *inner_.value_ = other.value; + return *this; + } + + typename std::iterator_traits::reference GetKey() const { + return *(inner_.key_); + } + + void swap(JointProxy &other) { + std::swap(*inner_.key_, *other.inner_.key_); + std::swap(*inner_.value_, *other.inner_.value_); + } + + private: + friend class ProxyIterator >; + + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +template class LessWrapper : public std::binary_function { + public: + explicit LessWrapper(const Less &less) : less_(less) {} + + bool operator()(const Proxy &left, const Proxy &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const Proxy &left, const typename Proxy::value_type &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const typename Proxy::value_type &left, const Proxy &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const typename Proxy::value_type &left, const typename Proxy::value_type &right) const { + return less_(left.GetKey(), right.GetKey()); + } + + private: + const Less less_; +}; + +} // namespace detail + +template void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) { + ProxyIterator > full_begin(detail::JointProxy(key_begin, value_begin)); + detail::LessWrapper, Less> less_wrap(less); + std::sort(full_begin, full_begin + (key_end - key_begin), less_wrap); +} + + +template void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin) { + JointSort(key_begin, key_end, value_begin, std::less::value_type>()); +} + +} // namespace util + +namespace std { +template void swap(util::detail::JointIter &left, util::detail::JointIter &right) { + left.swap(right); +} + +template void swap(util::detail::JointProxy &left, util::detail::JointProxy &right) { + left.swap(right); +} +} // namespace std + +#endif // UTIL_JOINT_SORT__ diff --git a/klm/util/joint_sort_test.cc b/klm/util/joint_sort_test.cc new file mode 100644 index 00000000..4dc85916 --- /dev/null +++ b/klm/util/joint_sort_test.cc @@ -0,0 +1,50 @@ +#include "util/joint_sort.hh" + +#define BOOST_TEST_MODULE JointSortTest +#include + +namespace util { namespace { + +BOOST_AUTO_TEST_CASE(just_flip) { + char keys[2]; + int values[2]; + keys[0] = 1; values[0] = 327; + keys[1] = 0; values[1] = 87897; + JointSort(keys + 0, keys + 2, values + 0); + BOOST_CHECK_EQUAL(0, keys[0]); + BOOST_CHECK_EQUAL(87897, values[0]); + BOOST_CHECK_EQUAL(1, keys[1]); + BOOST_CHECK_EQUAL(327, values[1]); +} + +BOOST_AUTO_TEST_CASE(three) { + char keys[3]; + int values[3]; + keys[0] = 1; values[0] = 327; + keys[1] = 2; values[1] = 87897; + keys[2] = 0; values[2] = 10; + JointSort(keys + 0, keys + 3, values + 0); + BOOST_CHECK_EQUAL(0, keys[0]); + BOOST_CHECK_EQUAL(1, keys[1]); + BOOST_CHECK_EQUAL(2, keys[2]); +} + +BOOST_AUTO_TEST_CASE(char_int) { + char keys[4]; + int values[4]; + keys[0] = 3; values[0] = 327; + keys[1] = 1; values[1] = 87897; + keys[2] = 2; values[2] = 10; + keys[3] = 0; values[3] = 24347; + JointSort(keys + 0, keys + 4, values + 0); + BOOST_CHECK_EQUAL(0, keys[0]); + BOOST_CHECK_EQUAL(24347, values[0]); + BOOST_CHECK_EQUAL(1, keys[1]); + BOOST_CHECK_EQUAL(87897, values[1]); + BOOST_CHECK_EQUAL(2, keys[2]); + BOOST_CHECK_EQUAL(10, values[2]); + BOOST_CHECK_EQUAL(3, keys[3]); + BOOST_CHECK_EQUAL(327, values[3]); +} + +}} // namespace anonymous util diff --git a/klm/util/key_value_packing.hh b/klm/util/key_value_packing.hh new file mode 100644 index 00000000..450512ac --- /dev/null +++ b/klm/util/key_value_packing.hh @@ -0,0 +1,122 @@ +#ifndef UTIL_KEY_VALUE_PACKING__ +#define UTIL_KEY_VALUE_PACKING__ + +/* Why such a general interface? I'm planning on doing bit-level packing. */ + +#include +#include +#include + +#include + +namespace util { + +template struct Entry { + Key key; + Value value; + + const Key &GetKey() const { return key; } + const Value &GetValue() const { return value; } + + void Set(const Key &key_in, const Value &value_in) { + SetKey(key_in); + SetValue(value_in); + } + void SetKey(const Key &key_in) { key = key_in; } + void SetValue(const Value &value_in) { value = value_in; } + + bool operator<(const Entry &other) const { return GetKey() < other.GetKey(); } +}; + +// And now for a brief interlude to specialize std::swap. +} // namespace util +namespace std { +template void swap(util::Entry &first, util::Entry &second) { + swap(first.key, second.key); + swap(first.value, second.value); +} +}// namespace std +namespace util { + +template class AlignedPacking { + public: + typedef KeyT Key; + typedef ValueT Value; + + public: + static const std::size_t kBytes = sizeof(Entry); + static const std::size_t kBits = kBytes * 8; + + typedef Entry * MutableIterator; + typedef const Entry * ConstIterator; + typedef const Entry & ConstReference; + + static MutableIterator FromVoid(void *start) { + return reinterpret_cast(start); + } + + static Entry Make(const Key &key, const Value &value) { + Entry ret; + ret.Set(key, value); + return ret; + } +}; + +template class ByteAlignedPacking { + public: + typedef KeyT Key; + typedef ValueT Value; + + private: +#pragma pack(push) +#pragma pack(1) + struct RawEntry { + Key key; + Value value; + + const Key &GetKey() const { return key; } + const Value &GetValue() const { return value; } + + void Set(const Key &key_in, const Value &value_in) { + SetKey(key_in); + SetValue(value_in); + } + void SetKey(const Key &key_in) { key = key_in; } + void SetValue(const Value &value_in) { value = value_in; } + + bool operator<(const RawEntry &other) const { return GetKey() < other.GetKey(); } + }; +#pragma pack(pop) + + friend void std::swap<>(RawEntry&, RawEntry&); + + public: + typedef RawEntry *MutableIterator; + typedef const RawEntry *ConstIterator; + typedef RawEntry &ConstReference; + + static const std::size_t kBytes = sizeof(RawEntry); + static const std::size_t kBits = kBytes * 8; + + static MutableIterator FromVoid(void *start) { + return MutableIterator(reinterpret_cast(start)); + } + + static RawEntry Make(const Key &key, const Value &value) { + RawEntry ret; + ret.Set(key, value); + return ret; + } +}; + +} // namespace util +namespace std { +template void swap( + typename util::ByteAlignedPacking::RawEntry &first, + typename util::ByteAlignedPacking::RawEntry &second) { + swap(first.key, second.key); + swap(first.value, second.value); +} +}// namespace std + +#endif // UTIL_KEY_VALUE_PACKING__ diff --git a/klm/util/key_value_packing_test.cc b/klm/util/key_value_packing_test.cc new file mode 100644 index 00000000..a0d33fd7 --- /dev/null +++ b/klm/util/key_value_packing_test.cc @@ -0,0 +1,75 @@ +#include "util/key_value_packing.hh" + +#include +#include +#include +#include +#define BOOST_TEST_MODULE KeyValueStoreTest +#include + +#include +#include + +namespace util { +namespace { + +BOOST_AUTO_TEST_CASE(basic_in_out) { + typedef ByteAlignedPacking Packing; + void *backing = malloc(Packing::kBytes * 2); + Packing::MutableIterator i(Packing::FromVoid(backing)); + i->SetKey(10); + BOOST_CHECK_EQUAL(10, i->GetKey()); + i->SetValue(3); + BOOST_CHECK_EQUAL(3, i->GetValue()); + ++i; + i->SetKey(5); + BOOST_CHECK_EQUAL(5, i->GetKey()); + i->SetValue(42); + BOOST_CHECK_EQUAL(42, i->GetValue()); + + Packing::ConstIterator c(i); + BOOST_CHECK_EQUAL(5, c->GetKey()); + --c; + BOOST_CHECK_EQUAL(10, c->GetKey()); + BOOST_CHECK_EQUAL(42, i->GetValue()); + + BOOST_CHECK_EQUAL(5, i->GetKey()); + free(backing); +} + +BOOST_AUTO_TEST_CASE(simple_sort) { + typedef ByteAlignedPacking Packing; + char foo[Packing::kBytes * 4]; + Packing::MutableIterator begin(Packing::FromVoid(foo)); + Packing::MutableIterator i = begin; + i->SetKey(0); ++i; + i->SetKey(2); ++i; + i->SetKey(3); ++i; + i->SetKey(1); ++i; + std::sort(begin, i); + BOOST_CHECK_EQUAL(0, begin[0].GetKey()); + BOOST_CHECK_EQUAL(1, begin[1].GetKey()); + BOOST_CHECK_EQUAL(2, begin[2].GetKey()); + BOOST_CHECK_EQUAL(3, begin[3].GetKey()); +} + +BOOST_AUTO_TEST_CASE(big_sort) { + typedef ByteAlignedPacking Packing; + boost::scoped_array memory(new char[Packing::kBytes * 1000]); + Packing::MutableIterator begin(Packing::FromVoid(memory.get())); + + boost::mt19937 rng; + boost::uniform_int range(0, std::numeric_limits::max()); + boost::variate_generator > gen(rng, range); + + for (size_t i = 0; i < 1000; ++i) { + (begin + i)->SetKey(gen()); + } + std::sort(begin, begin + 1000); + for (size_t i = 0; i < 999; ++i) { + BOOST_CHECK(begin[i] < begin[i+1]); + } +} + +} // namespace +} // namespace util diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc new file mode 100644 index 00000000..648b5d0a --- /dev/null +++ b/klm/util/mmap.cc @@ -0,0 +1,95 @@ +#include "util/exception.hh" +#include "util/mmap.hh" +#include "util/scoped.hh" + +#include +#include +#include +#include +#include +#include +#include + +namespace util { + +scoped_mmap::~scoped_mmap() { + if (data_ != (void*)-1) { + if (munmap(data_, size_)) + err(1, "munmap failed "); + } +} + +void scoped_memory::reset(void *data, std::size_t size, Alloc source) { + switch(source_) { + case MMAP_ALLOCATED: + scoped_mmap(data_, size_); + break; + case ARRAY_ALLOCATED: + delete [] reinterpret_cast(data_); + break; + case MALLOC_ALLOCATED: + free(data_); + break; + case NONE_ALLOCATED: + break; + } + data_ = data; + size_ = size; + source_ = source; +} + +void scoped_memory::call_realloc(std::size_t size) { + assert(source_ == MALLOC_ALLOCATED || source_ == NONE_ALLOCATED); + void *new_data = realloc(data_, size); + if (!new_data) { + reset(); + } else { + reset(new_data, size, MALLOC_ALLOCATED); + } +} + +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset) { +#ifdef MAP_POPULATE // Linux specific + if (prefault) { + flags |= MAP_POPULATE; + } + int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; +#else + int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; +#endif + void *ret = mmap(NULL, size, protect, flags, fd, offset); + if (ret == MAP_FAILED) { + UTIL_THROW(ErrnoException, "mmap failed for size " << size << " at offset " << offset); + } + return ret; +} + +void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset) { + return MapOrThrow(size, false, MAP_FILE | MAP_PRIVATE, prefault, fd, offset); +} + +void *MapAnonymous(std::size_t size) { + return MapOrThrow(size, true, +#ifdef MAP_ANONYMOUS + MAP_ANONYMOUS // Linux +#else + MAP_ANON // BSD +#endif + | MAP_PRIVATE, false, -1, 0); +} + +void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem) { + file.reset(open(name, O_CREAT | O_RDWR | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)); + if (-1 == file.get()) + UTIL_THROW(ErrnoException, "Failed to open " << name << " for writing"); + if (-1 == ftruncate(file.get(), size)) + UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed"); + try { + mem.reset(MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0), size); + } catch (ErrnoException &e) { + e << " in file " << name; + throw; + } +} + +} // namespace util diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh new file mode 100644 index 00000000..c9068ec9 --- /dev/null +++ b/klm/util/mmap.hh @@ -0,0 +1,101 @@ +#ifndef UTIL_MMAP__ +#define UTIL_MMAP__ +// Utilities for mmaped files. + +#include "util/scoped.hh" + +#include + +#include + +namespace util { + +// (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here. +class scoped_mmap { + public: + scoped_mmap() : data_((void*)-1), size_(0) {} + scoped_mmap(void *data, std::size_t size) : data_(data), size_(size) {} + ~scoped_mmap(); + + void *get() const { return data_; } + + const char *begin() const { return reinterpret_cast(data_); } + const char *end() const { return reinterpret_cast(data_) + size_; } + std::size_t size() const { return size_; } + + void reset(void *data, std::size_t size) { + scoped_mmap other(data_, size_); + data_ = data; + size_ = size; + } + + void reset() { + reset((void*)-1, 0); + } + + private: + void *data_; + std::size_t size_; + + scoped_mmap(const scoped_mmap &); + scoped_mmap &operator=(const scoped_mmap &); +}; + +/* For when the memory might come from mmap, new char[], or malloc. Uses NULL + * and 0 for blanks even though mmap signals errors with (void*)-1). The reset + * function checks that blank for mmap. + */ +class scoped_memory { + public: + typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc; + + scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {} + + ~scoped_memory() { reset(); } + + void *get() const { return data_; } + const char *begin() const { return reinterpret_cast(data_); } + const char *end() const { return reinterpret_cast(data_) + size_; } + std::size_t size() const { return size_; } + + Alloc source() const { return source_; } + + void reset() { reset(NULL, 0, NONE_ALLOCATED); } + + void reset(void *data, std::size_t size, Alloc from); + + // realloc allows the current data to escape hence the need for this call + // If realloc fails, destroys the original too and get() returns NULL. + void call_realloc(std::size_t to); + + private: + + void *data_; + std::size_t size_; + + Alloc source_; + + scoped_memory(const scoped_memory &); + scoped_memory &operator=(const scoped_memory &); +}; + +struct scoped_mapped_file { + scoped_fd fd; + scoped_mmap mem; +}; + +// Wrapper around mmap to check it worked and hide some platform macros. +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset = 0); +void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset = 0); + +void *MapAnonymous(std::size_t size); + +// Open file name with mmap of size bytes, all of which are initially zero. +void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem); +inline void MapZeroedWrite(const char *name, std::size_t size, scoped_mapped_file &out) { + MapZeroedWrite(name, size, out.fd, out.mem); +} + +} // namespace util + +#endif // UTIL_SCOPED__ diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc new file mode 100644 index 00000000..d58a0727 --- /dev/null +++ b/klm/util/murmur_hash.cc @@ -0,0 +1,129 @@ +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/murmur_hash.hh b/klm/util/murmur_hash.hh new file mode 100644 index 00000000..638aaeb2 --- /dev/null +++ b/klm/util/murmur_hash.hh @@ -0,0 +1,14 @@ +#ifndef UTIL_MURMUR_HASH__ +#define UTIL_MURMUR_HASH__ +#include +#include + +namespace util { + +uint64_t MurmurHash64A(const void * key, std::size_t len, unsigned int seed = 0); +uint64_t MurmurHash64B(const void * key, std::size_t len, unsigned int seed = 0); +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed = 0); + +} // namespace util + +#endif // UTIL_MURMUR_HASH__ diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh new file mode 100644 index 00000000..c3529a7e --- /dev/null +++ b/klm/util/probing_hash_table.hh @@ -0,0 +1,97 @@ +#ifndef UTIL_PROBING_HASH_TABLE__ +#define UTIL_PROBING_HASH_TABLE__ + +#include +#include +#include + +#include + +namespace util { + +/* Non-standard hash table + * Buckets must be set at the beginning and must be greater than maximum number + * of elements, else an infinite loop happens. + * Memory management and initialization is externalized to make it easier to + * serialize these to disk and load them quickly. + * Uses linear probing to find value. + * Only insert and lookup operations. + */ + +template > class ProbingHashTable { + public: + typedef PackingT Packing; + typedef typename Packing::Key Key; + typedef typename Packing::MutableIterator MutableIterator; + typedef typename Packing::ConstIterator ConstIterator; + + typedef HashT Hash; + typedef EqualT Equal; + + static std::size_t Size(std::size_t entries, float multiplier) { + return std::max(entries + 1, static_cast(multiplier * static_cast(entries))) * Packing::kBytes; + } + + // Must be assigned to later. + ProbingHashTable() +#ifdef DEBUG + : initialized_(false), entries_(0) +#endif + {} + + ProbingHashTable(void *start, std::size_t allocated, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) + : begin_(Packing::FromVoid(start)), + buckets_(allocated / Packing::kBytes), + end_(begin_ + (allocated / Packing::kBytes)), + invalid_(invalid), + hash_(hash_func), + equal_(equal_func) +#ifdef DEBUG + , initialized_(true), + entries_(0) +#endif + {} + + template void Insert(const T &t) { +#ifdef DEBUG + assert(initialized_); + assert(++entries_ < buckets_); +#endif + for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { + if (equal_(i->GetKey(), invalid_)) { *i = t; return; } + if (++i == end_) { i = begin_; } + } + } + + void FinishedInserting() {} + + void LoadedBinary() {} + + template bool Find(const Key key, ConstIterator &out) const { +#ifdef DEBUG + assert(initialized_); +#endif + for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) { + Key got(i->GetKey()); + if (equal_(got, key)) { out = i; return true; } + if (equal_(got, invalid_)) { return false; } + if (++i == end_) { i = begin_; } + } + } + + private: + MutableIterator begin_; + std::size_t buckets_; + MutableIterator end_; + Key invalid_; + Hash hash_; + Equal equal_; +#ifdef DEBUG + bool initialized_; + std::size_t entries_; +#endif +}; + +} // namespace util + +#endif // UTIL_PROBING_HASH_TABLE__ diff --git a/klm/util/probing_hash_table_test.cc b/klm/util/probing_hash_table_test.cc new file mode 100644 index 00000000..ff2f5af3 --- /dev/null +++ b/klm/util/probing_hash_table_test.cc @@ -0,0 +1,30 @@ +#include "util/probing_hash_table.hh" + +#include "util/key_value_packing.hh" + +#define BOOST_TEST_MODULE ProbingHashTableTest +#include +#include + +namespace util { +namespace { + +typedef AlignedPacking Packing; +typedef ProbingHashTable > Table; + +BOOST_AUTO_TEST_CASE(simple) { + char mem[Table::Size(10, 1.2)]; + memset(mem, 0, sizeof(mem)); + + Table table(mem, sizeof(mem)); + Packing::ConstIterator i = Packing::ConstIterator(); + BOOST_CHECK(!table.Find(2, i)); + table.Insert(Packing::Make(3, 328920)); + BOOST_REQUIRE(table.Find(3, i)); + BOOST_CHECK_EQUAL(3, i->GetKey()); + BOOST_CHECK_EQUAL(static_cast(328920), i->GetValue()); + BOOST_CHECK(!table.Find(2, i)); +} + +} // namespace +} // namespace util diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh new file mode 100644 index 00000000..1c5b7089 --- /dev/null +++ b/klm/util/proxy_iterator.hh @@ -0,0 +1,94 @@ +#ifndef UTIL_PROXY_ITERATOR__ +#define UTIL_PROXY_ITERATOR__ + +#include +#include + +/* This is a RandomAccessIterator that uses a proxy to access the underlying + * data. Useful for packing data at bit offsets but still using STL + * algorithms. + * + * Normally I would use boost::iterator_facade but some people are too lazy to + * install boost and still want to use my language model. It's amazing how + * many operators an iterator has. + * + * The Proxy needs to provide: + * class InnerIterator; + * InnerIterator &Inner(); + * const InnerIterator &Inner() const; + * + * InnerIterator has to implement: + * operator==(InnerIterator) + * operator<(InnerIterator) + * operator+=(std::ptrdiff_t) + * operator-(InnerIterator) + * and of course whatever Proxy needs to dereference it. + * + * It's also a good idea to specialize std::swap for Proxy. + */ + +namespace util { +template class ProxyIterator { + private: + // Self. + typedef ProxyIterator S; + typedef typename Proxy::InnerIterator InnerIterator; + + public: + typedef std::random_access_iterator_tag iterator_category; + typedef typename Proxy::value_type value_type; + typedef std::ptrdiff_t difference_type; + typedef Proxy reference; + typedef Proxy * pointer; + + ProxyIterator() {} + + // For cast from non const to const. + template ProxyIterator(const ProxyIterator &in) : p_(*in) {} + explicit ProxyIterator(const Proxy &p) : p_(p) {} + + // p_'s operator= does value copying, but here we want iterator copying. + S &operator=(const S &other) { + I() = other.I(); + return *this; + } + + bool operator==(const S &other) const { return I() == other.I(); } + bool operator!=(const S &other) const { return !(*this == other); } + bool operator<(const S &other) const { return I() < other.I(); } + bool operator>(const S &other) const { return other < *this; } + bool operator<=(const S &other) const { return !(*this > other); } + bool operator>=(const S &other) const { return !(*this < other); } + + S &operator++() { return *this += 1; } + S operator++(int) { S ret(*this); ++*this; return ret; } + S &operator+=(std::ptrdiff_t amount) { I() += amount; return *this; } + S operator+(std::ptrdiff_t amount) const { S ret(*this); ret += amount; return ret; } + + S &operator--() { return *this -= 1; } + S operator--(int) { S ret(*this); --*this; return ret; } + S &operator-=(std::ptrdiff_t amount) { I() += (-amount); return *this; } + S operator-(std::ptrdiff_t amount) const { S ret(*this); ret -= amount; return ret; } + + std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); } + + Proxy operator*() { return p_; } + const Proxy operator*() const { return p_; } + Proxy *operator->() { return &p_; } + const Proxy *operator->() const { return &p_; } + Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); } + + private: + InnerIterator &I() { return p_.Inner(); } + const InnerIterator &I() const { return p_.Inner(); } + + Proxy p_; +}; + +template ProxyIterator operator+(std::ptrdiff_t amount, const ProxyIterator &it) { + return it + amount; +} + +} // namespace util + +#endif // UTIL_PROXY_ITERATOR__ diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc new file mode 100644 index 00000000..61394ffc --- /dev/null +++ b/klm/util/scoped.cc @@ -0,0 +1,12 @@ +#include "util/scoped.hh" + +#include +#include + +namespace util { + +scoped_fd::~scoped_fd() { + if (fd_ != -1 && close(fd_)) err(1, "Could not close file %i", fd_); +} + +} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh new file mode 100644 index 00000000..ef62a74f --- /dev/null +++ b/klm/util/scoped.hh @@ -0,0 +1,66 @@ +#ifndef UTIL_SCOPED__ +#define UTIL_SCOPED__ + +/* Other scoped objects in the style of scoped_ptr. */ + +#include + +namespace util { + +template class scoped_thing { + public: + explicit scoped_thing(T *c = static_cast(0)) : c_(c) {} + + ~scoped_thing() { if (c_) Free(c_); } + + void reset(T *c) { + if (c_) Free(c_); + c_ = c; + } + + T &operator*() { return *c_; } + T &operator->() { return *c_; } + + T *get() { return c_; } + const T *get() const { return c_; } + + private: + T *c_; + + scoped_thing(const scoped_thing &); + scoped_thing &operator=(const scoped_thing &); +}; + +class scoped_fd { + public: + scoped_fd() : fd_(-1) {} + + explicit scoped_fd(int fd) : fd_(fd) {} + + ~scoped_fd(); + + void reset(int to) { + scoped_fd other(fd_); + fd_ = to; + } + + int get() const { return fd_; } + + int operator*() const { return fd_; } + + int release() { + int ret = fd_; + fd_ = -1; + return ret; + } + + private: + int fd_; + + scoped_fd(const scoped_fd &); + scoped_fd &operator=(const scoped_fd &); +}; + +} // namespace util + +#endif // UTIL_SCOPED__ diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh new file mode 100644 index 00000000..96ec4866 --- /dev/null +++ b/klm/util/sorted_uniform.hh @@ -0,0 +1,139 @@ +#ifndef UTIL_SORTED_UNIFORM__ +#define UTIL_SORTED_UNIFORM__ + +#include +#include + +#include +#include + +namespace util { + +inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) { + std::size_t ret = static_cast(static_cast(off) / static_cast(range) * static_cast(width)); + // Cap for floating point rounding + return (ret < width) ? ret : width - 1; +} +/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) { + return static_cast(static_cast(off) * static_cast(width) / static_cast(range)); +} +inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) { + return static_cast(static_cast(off) * width / static_cast(range)); +} +inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) { + return static_cast(static_cast(off) * width / static_cast(range)); +}*/ + +template bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { + if (begin == end) return false; + Key below(begin->GetKey()); + if (key <= below) { + if (key == below) { out = begin; return true; } + return false; + } + // Make the range [begin, end]. + --end; + Key above(end->GetKey()); + if (key >= above) { + if (key == above) { out = end; return true; } + return false; + } + + // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above. + while (end - begin > 1) { + Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast(end - begin - 1)))); + Key mid(pivot->GetKey()); + if (mid < key) { + begin = pivot; + below = mid; + } else if (mid > key) { + end = pivot; + above = mid; + } else { + out = pivot; + return true; + } + } + return false; +} + +// To use this template, you need to define a Pivot function to match Key. +template class SortedUniformMap { + public: + typedef PackingT Packing; + typedef typename Packing::ConstIterator ConstIterator; + + public: + // Offer consistent API with probing hash. + static std::size_t Size(std::size_t entries, float ignore = 0.0) { + return sizeof(uint64_t) + entries * Packing::kBytes; + } + + SortedUniformMap() +#ifdef DEBUG + : initialized_(false), loaded_(false) +#endif + {} + + SortedUniformMap(void *start, std::size_t allocated) : + begin_(Packing::FromVoid(reinterpret_cast(start) + 1)), + end_(begin_), size_ptr_(reinterpret_cast(start)) +#ifdef DEBUG + , initialized_(true), loaded_(false) +#endif + {} + + void LoadedBinary() { +#ifdef DEBUG + assert(initialized_); + assert(!loaded_); + loaded_ = true; +#endif + // Restore the size. + end_ = begin_ + *size_ptr_; + } + + // Caller responsible for not exceeding specified size. Do not call after FinishedInserting. + template void Insert(const T &t) { +#ifdef DEBUG + assert(initialized_); + assert(!loaded_); +#endif + *end_ = t; + ++end_; + } + + void FinishedInserting() { +#ifdef DEBUG + assert(initialized_); + assert(!loaded_); + loaded_ = true; +#endif + std::sort(begin_, end_); + *size_ptr_ = (end_ - begin_); + } + + // Do not call before FinishedInserting. + template bool Find(const Key key, ConstIterator &out) const { +#ifdef DEBUG + assert(initialized_); + assert(loaded_); +#endif + return SortedUniformFind(ConstIterator(begin_), ConstIterator(end_), key, out); + } + + ConstIterator begin() const { return begin_; } + ConstIterator end() const { return end_; } + + private: + typename Packing::MutableIterator begin_, end_; + uint64_t *size_ptr_; +#ifdef DEBUG + bool initialized_; + bool loaded_; +#endif +}; + +} // namespace util + +#endif // UTIL_SORTED_UNIFORM__ diff --git a/klm/util/sorted_uniform_test.cc b/klm/util/sorted_uniform_test.cc new file mode 100644 index 00000000..4aa4c8aa --- /dev/null +++ b/klm/util/sorted_uniform_test.cc @@ -0,0 +1,116 @@ +#include "util/sorted_uniform.hh" + +#include "util/key_value_packing.hh" + +#include +#include +#include +#include +#include +#define BOOST_TEST_MODULE SortedUniformTest +#include + +#include +#include +#include + +namespace util { +namespace { + +template void Check(const Map &map, const boost::unordered_map &reference, const Key key) { + typename boost::unordered_map::const_iterator ref = reference.find(key); + typename Map::ConstIterator i = typename Map::ConstIterator(); + if (ref == reference.end()) { + BOOST_CHECK(!map.Find(key, i)); + } else { + // g++ can't tell that require will crash and burn. + BOOST_REQUIRE(map.Find(key, i)); + BOOST_CHECK_EQUAL(ref->second, i->GetValue()); + } +} + +typedef SortedUniformMap > TestMap; + +BOOST_AUTO_TEST_CASE(empty) { + char buf[TestMap::Size(0)]; + TestMap map(buf, TestMap::Size(0)); + map.FinishedInserting(); + TestMap::ConstIterator i; + BOOST_CHECK(!map.Find(42, i)); +} + +BOOST_AUTO_TEST_CASE(one) { + char buf[TestMap::Size(1)]; + TestMap map(buf, sizeof(buf)); + Entry e; + e.Set(42,2); + map.Insert(e); + map.FinishedInserting(); + TestMap::ConstIterator i = TestMap::ConstIterator(); + BOOST_REQUIRE(map.Find(42, i)); + BOOST_CHECK(i == map.begin()); + BOOST_CHECK(!map.Find(43, i)); + BOOST_CHECK(!map.Find(41, i)); +} + +template void RandomTest(Key upper, size_t entries, size_t queries) { + typedef unsigned char Value; + typedef SortedUniformMap > Map; + boost::scoped_array buffer(new char[Map::Size(entries)]); + Map map(buffer.get(), entries); + boost::mt19937 rng; + boost::uniform_int range_key(0, upper); + boost::uniform_int range_value(0, 255); + boost::variate_generator > gen_key(rng, range_key); + boost::variate_generator > gen_value(rng, range_value); + + boost::unordered_map reference; + Entry ent; + for (size_t i = 0; i < entries; ++i) { + Key key = gen_key(); + unsigned char value = gen_value(); + if (reference.insert(std::make_pair(key, value)).second) { + ent.Set(key, value); + map.Insert(Entry(ent)); + } + } + map.FinishedInserting(); + + // Random queries. + for (size_t i = 0; i < queries; ++i) { + const Key key = gen_key(); + Check(map, reference, key); + } + + typename boost::unordered_map::const_iterator it = reference.begin(); + for (size_t i = 0; (i < queries) && (it != reference.end()); ++i, ++it) { + Check(map, reference, it->second); + } +} + +BOOST_AUTO_TEST_CASE(basic) { + RandomTest(11, 10, 200); +} + +BOOST_AUTO_TEST_CASE(tiny_dense_random) { + RandomTest(11, 50, 200); +} + +BOOST_AUTO_TEST_CASE(small_dense_random) { + RandomTest(100, 100, 200); +} + +BOOST_AUTO_TEST_CASE(small_sparse_random) { + RandomTest(200, 15, 200); +} + +BOOST_AUTO_TEST_CASE(medium_sparse_random) { + RandomTest(32000, 1000, 2000); +} + +BOOST_AUTO_TEST_CASE(sparse_random) { + RandomTest(std::numeric_limits::max(), 100000, 2000); +} + +} // namespace +} // namespace util diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc new file mode 100644 index 00000000..6917a6bc --- /dev/null +++ b/klm/util/string_piece.cc @@ -0,0 +1,57 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// Copied from strings/stringpiece.cc with modifications + +#include "util/string_piece.hh" + +#ifdef USE_BOOST +#include +#endif + +#include +#include + +#ifdef USE_ICU +U_NAMESPACE_BEGIN +#endif + +std::ostream& operator<<(std::ostream& o, const StringPiece& piece) { + o.write(piece.data(), static_cast(piece.size())); + return o; +} + +#ifdef USE_BOOST +size_t hash_value(const StringPiece &str) { + return boost::hash_range(str.data(), str.data() + str.length()); +} +#endif + +#ifdef USE_ICU +U_NAMESPACE_END +#endif diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh new file mode 100644 index 00000000..58008d13 --- /dev/null +++ b/klm/util/string_piece.hh @@ -0,0 +1,260 @@ +/* If you use ICU in your program, then compile with -DUSE_ICU -licui18n. If + * you don't use ICU, then this will use the Google implementation from Chrome. + * This has been modified from the original version to let you choose. + */ + +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// Copied from strings/stringpiece.h with modifications +// +// A string-like object that points to a sized piece of memory. +// +// Functions or methods may use const StringPiece& parameters to accept either +// a "const char*" or a "string" value that will be implicitly converted to +// a StringPiece. The implicit conversion means that it is often appropriate +// to include this .h file in other files rather than forward-declaring +// StringPiece as would be appropriate for most other Google classes. +// +// Systematic usage of StringPiece is encouraged as it will reduce unnecessary +// conversions from "const char*" to "string" and back again. +// + +#ifndef BASE_STRING_PIECE_H__ +#define BASE_STRING_PIECE_H__ + +//Uncomment this line if you use ICU in your code. +//#define USE_ICU +//Uncomment this line if you want boost hashing for your StringPieces. +//#define USE_BOOST + +#include +#include + +#ifdef USE_ICU +#include +U_NAMESPACE_BEGIN +#else + +#include +#include +#include +#include + +class StringPiece { + public: + typedef size_t size_type; + + private: + const char* ptr_; + size_type length_; + + public: + // We provide non-explicit singleton constructors so users can pass + // in a "const char*" or a "string" wherever a "StringPiece" is + // expected. + StringPiece() : ptr_(NULL), length_(0) { } + StringPiece(const char* str) + : ptr_(str), length_((str == NULL) ? 0 : strlen(str)) { } + StringPiece(const std::string& str) + : ptr_(str.data()), length_(str.size()) { } + StringPiece(const char* offset, size_type len) + : ptr_(offset), length_(len) { } + + // data() may return a pointer to a buffer with embedded NULs, and the + // returned buffer may or may not be null terminated. Therefore it is + // typically a mistake to pass data() to a routine that expects a NUL + // terminated string. + const char* data() const { return ptr_; } + size_type size() const { return length_; } + size_type length() const { return length_; } + bool empty() const { return length_ == 0; } + + void clear() { ptr_ = NULL; length_ = 0; } + void set(const char* data, size_type len) { ptr_ = data; length_ = len; } + void set(const char* str) { + ptr_ = str; + length_ = str ? strlen(str) : 0; + } + void set(const void* data, size_type len) { + ptr_ = reinterpret_cast(data); + length_ = len; + } + + char operator[](size_type i) const { return ptr_[i]; } + + void remove_prefix(size_type n) { + ptr_ += n; + length_ -= n; + } + + void remove_suffix(size_type n) { + length_ -= n; + } + + int compare(const StringPiece& x) const { + int r = wordmemcmp(ptr_, x.ptr_, std::min(length_, x.length_)); + if (r == 0) { + if (length_ < x.length_) r = -1; + else if (length_ > x.length_) r = +1; + } + return r; + } + + std::string as_string() const { + // std::string doesn't like to take a NULL pointer even with a 0 size. + return std::string(!empty() ? data() : "", size()); + } + + void CopyToString(std::string* target) const; + void AppendToString(std::string* target) const; + + // Does "this" start with "x" + bool starts_with(const StringPiece& x) const { + return ((length_ >= x.length_) && + (wordmemcmp(ptr_, x.ptr_, x.length_) == 0)); + } + + // Does "this" end with "x" + bool ends_with(const StringPiece& x) const { + return ((length_ >= x.length_) && + (wordmemcmp(ptr_ + (length_-x.length_), x.ptr_, x.length_) == 0)); + } + + // standard STL container boilerplate + typedef char value_type; + typedef const char* pointer; + typedef const char& reference; + typedef const char& const_reference; + typedef ptrdiff_t difference_type; + static const size_type npos; + typedef const char* const_iterator; + typedef const char* iterator; + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; + iterator begin() const { return ptr_; } + iterator end() const { return ptr_ + length_; } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(ptr_ + length_); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(ptr_); + } + + size_type max_size() const { return length_; } + size_type capacity() const { return length_; } + + size_type copy(char* buf, size_type n, size_type pos = 0) const; + + size_type find(const StringPiece& s, size_type pos = 0) const; + size_type find(char c, size_type pos = 0) const; + size_type rfind(const StringPiece& s, size_type pos = npos) const; + size_type rfind(char c, size_type pos = npos) const; + + size_type find_first_of(const StringPiece& s, size_type pos = 0) const; + size_type find_first_of(char c, size_type pos = 0) const { + return find(c, pos); + } + size_type find_first_not_of(const StringPiece& s, size_type pos = 0) const; + size_type find_first_not_of(char c, size_type pos = 0) const; + size_type find_last_of(const StringPiece& s, size_type pos = npos) const; + size_type find_last_of(char c, size_type pos = npos) const { + return rfind(c, pos); + } + size_type find_last_not_of(const StringPiece& s, size_type pos = npos) const; + size_type find_last_not_of(char c, size_type pos = npos) const; + + StringPiece substr(size_type pos, size_type n = npos) const; + + static int wordmemcmp(const char* p, const char* p2, size_type N) { + return memcmp(p, p2, N); + } +}; + +inline bool operator==(const StringPiece& x, const StringPiece& y) { + if (x.size() != y.size()) + return false; + + return std::memcmp(x.data(), y.data(), x.size()) == 0; +} + +inline bool operator!=(const StringPiece& x, const StringPiece& y) { + return !(x == y); +} + +#endif + +inline bool operator<(const StringPiece& x, const StringPiece& y) { + const int r = std::memcmp(x.data(), y.data(), + std::min(x.size(), y.size())); + return ((r < 0) || ((r == 0) && (x.size() < y.size()))); +} + +inline bool operator>(const StringPiece& x, const StringPiece& y) { + return y < x; +} + +inline bool operator<=(const StringPiece& x, const StringPiece& y) { + return !(x > y); +} + +inline bool operator>=(const StringPiece& x, const StringPiece& y) { + return !(x < y); +} + +// allow StringPiece to be logged (needed for unit testing). +extern std::ostream& operator<<(std::ostream& o, const StringPiece& piece); + +#ifdef USE_BOOST +size_t hash_value(const StringPiece &str); + +/* Support for lookup of StringPiece in boost::unordered_map */ +struct StringPieceCompatibleHash : public std::unary_function { + size_t operator()(const StringPiece &str) const { + return hash_value(str); + } +}; + +struct StringPieceCompatibleEquals : public std::binary_function { + bool operator()(const StringPiece &first, const StringPiece &second) const { + return first == second; + } +}; +template typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { + return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +} +template typename T::iterator FindStringPiece(T &t, const StringPiece &key) { + return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +} +#endif + +#ifdef USE_ICU +U_NAMESPACE_END +#endif + +#endif // BASE_STRING_PIECE_H__ -- cgit v1.2.3